arcis_interface/
json.rs

1use crate::types::{CircuitInterface, Value};
2use serde::{
3    de::{self, MapAccess, Visitor},
4    Deserialize,
5    Deserializer,
6    Serialize,
7};
8use serde_json::{json, Value as JsonValue};
9use std::fmt;
10
11#[derive(Serialize, Deserialize, Debug)]
12pub struct ManticoreInterface {
13    pub inputs: Vec<String>,
14    pub outputs: Vec<String>,
15}
16
17impl ManticoreInterface {
18    pub fn new(inputs: Vec<String>, outputs: Vec<String>) -> Self {
19        Self { inputs, outputs }
20    }
21
22    pub fn serialize(&self) -> Result<String, serde_json::Error> {
23        serde_json::to_string(self)
24    }
25
26    pub fn from_json(input: &str) -> Result<Self, serde_json::Error> {
27        serde_json::from_str(input)
28    }
29}
30
31impl<'de> Deserialize<'de> for CircuitInterface {
32    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
33    where
34        D: Deserializer<'de>,
35    {
36        enum Field {
37            Name,
38            Inputs,
39            Outputs,
40        }
41
42        impl<'de> Deserialize<'de> for Field {
43            fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
44            where
45                D: Deserializer<'de>,
46            {
47                struct FieldVisitor;
48
49                impl Visitor<'_> for FieldVisitor {
50                    type Value = Field;
51
52                    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
53                        formatter.write_str("`name`, `inputs`, or `outputs`")
54                    }
55
56                    fn visit_str<E>(self, value: &str) -> Result<Field, E>
57                    where
58                        E: de::Error,
59                    {
60                        match value {
61                            "name" => Ok(Field::Name),
62                            "inputs" => Ok(Field::Inputs),
63                            "outputs" => Ok(Field::Outputs),
64                            _ => Err(de::Error::unknown_field(value, FIELDS)),
65                        }
66                    }
67                }
68
69                deserializer.deserialize_identifier(FieldVisitor)
70            }
71        }
72
73        struct CircuitInterfaceVisitor;
74
75        impl<'de> Visitor<'de> for CircuitInterfaceVisitor {
76            type Value = CircuitInterface;
77
78            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
79                formatter.write_str("struct CircuitInterface")
80            }
81
82            fn visit_map<V>(self, mut map: V) -> Result<CircuitInterface, V::Error>
83            where
84                V: MapAccess<'de>,
85            {
86                let mut name = None;
87                let mut inputs = None;
88                let mut outputs = None;
89
90                while let Some(key) = map.next_key()? {
91                    match key {
92                        Field::Name => {
93                            if name.is_some() {
94                                return Err(de::Error::duplicate_field("name"));
95                            }
96                            name = Some(map.next_value()?);
97                        }
98                        Field::Inputs => {
99                            if inputs.is_some() {
100                                return Err(de::Error::duplicate_field("inputs"));
101                            }
102                            inputs = Some(map.next_value()?);
103                        }
104                        Field::Outputs => {
105                            if outputs.is_some() {
106                                return Err(de::Error::duplicate_field("outputs"));
107                            }
108                            outputs = Some(map.next_value()?);
109                        }
110                    }
111                }
112
113                let name = name.ok_or_else(|| de::Error::missing_field("name"))?;
114                let inputs = inputs.ok_or_else(|| de::Error::missing_field("inputs"))?;
115                let outputs = outputs.ok_or_else(|| de::Error::missing_field("output"))?;
116
117                Ok(CircuitInterface {
118                    name,
119                    inputs,
120                    outputs,
121                })
122            }
123        }
124
125        const FIELDS: &[&str] = &["name", "inputs", "outputs"];
126        deserializer.deserialize_struct("CircuitInterface", FIELDS, CircuitInterfaceVisitor)
127    }
128}
129
130impl Value {
131    fn get_scalar_type_name(size_in_bits: usize) -> &'static str {
132        match size_in_bits {
133            8 => "u8",
134            16 => "u16",
135            32 => "u32",
136            64 => "u64",
137            128 => "u128",
138            _ => "scalar", // fallback for non-standard sizes
139        }
140    }
141}
142
143impl Serialize for Value {
144    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
145    where
146        S: serde::Serializer,
147    {
148        let json_value = match self {
149            Value::MScalar { size_in_bits } => json!({
150                "type": "mscalar",
151                "size_in_bits": size_in_bits
152            }),
153            Value::MFloat { size_in_bits } => json!({
154                "type": "mfloat",
155                "size_in_bits": size_in_bits
156            }),
157            Value::MBool => json!({
158                "type": "mbool"
159            }),
160            Value::Scalar { size_in_bits } => json!({
161                "type": Value::get_scalar_type_name(*size_in_bits),
162                "size_in_bits": size_in_bits
163            }),
164            Value::Float { size_in_bits } => json!({
165                "type": "float",
166                "size_in_bits": size_in_bits
167            }),
168            Value::Bool => json!({
169                "type": "bool"
170            }),
171            Value::Ciphertext { size_in_bits } => json!({
172                "type": "ciphertext",
173                "size_in_bits": size_in_bits
174            }),
175            Value::PublicKey { size_in_bits } => json!({
176                "type": "public_key",
177                "size_in_bits": size_in_bits
178            }),
179            Value::Point => json!({
180                "type": "point"
181            }),
182            Value::Array(vec) => json!({
183                "type": "array",
184                "content": vec
185            }),
186            Value::Tuple(vec) => json!({
187                "type": "tuple",
188                "content": vec
189            }),
190            Value::Struct(vec) => json!({
191                "type": "struct",
192                "content": vec
193            }),
194        };
195        json_value.serialize(serializer)
196    }
197}
198
199impl<'de> Deserialize<'de> for Value {
200    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
201    where
202        D: serde::Deserializer<'de>,
203    {
204        let json_value = JsonValue::deserialize(deserializer)?;
205
206        match json_value {
207            JsonValue::Object(map) => {
208                let type_ = map
209                    .get("type")
210                    .and_then(JsonValue::as_str)
211                    .ok_or_else(|| serde::de::Error::missing_field("type"))?;
212
213                match type_ {
214                    "mscalar" => {
215                        let size_in_bits = map
216                            .get("size_in_bits")
217                            .and_then(JsonValue::as_u64)
218                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
219                        Ok(Value::MScalar {
220                            size_in_bits: size_in_bits as usize,
221                        })
222                    }
223                    "mfloat" => {
224                        let size_in_bits = map
225                            .get("size_in_bits")
226                            .and_then(JsonValue::as_u64)
227                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
228                        Ok(Value::MFloat {
229                            size_in_bits: size_in_bits as usize,
230                        })
231                    }
232                    "mbool" => Ok(Value::MBool),
233                    "u8" => Ok(Value::Scalar { size_in_bits: 8 }),
234                    "u16" => Ok(Value::Scalar { size_in_bits: 16 }),
235                    "u32" => Ok(Value::Scalar { size_in_bits: 32 }),
236                    "u64" => Ok(Value::Scalar { size_in_bits: 64 }),
237                    "u128" => Ok(Value::Scalar { size_in_bits: 128 }),
238                    "scalar" => {
239                        let size_in_bits = map
240                            .get("size_in_bits")
241                            .and_then(JsonValue::as_u64)
242                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
243                        Ok(Value::Scalar {
244                            size_in_bits: size_in_bits as usize,
245                        })
246                    }
247                    "ciphertext" => {
248                        let size_in_bits = map
249                            .get("size_in_bits")
250                            .and_then(JsonValue::as_u64)
251                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
252                        Ok(Value::Ciphertext {
253                            size_in_bits: size_in_bits as usize,
254                        })
255                    }
256                    "public_key" => {
257                        let size_in_bits = map
258                            .get("size_in_bits")
259                            .and_then(JsonValue::as_u64)
260                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
261                        Ok(Value::PublicKey {
262                            size_in_bits: size_in_bits as usize,
263                        })
264                    }
265                    "float" => {
266                        let size_in_bits = map
267                            .get("size_in_bits")
268                            .and_then(JsonValue::as_u64)
269                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
270                        Ok(Value::Float {
271                            size_in_bits: size_in_bits as usize,
272                        })
273                    }
274                    "bool" => Ok(Value::Bool),
275                    "array" | "tuple" | "struct" => {
276                        let content = map
277                            .get("content")
278                            .ok_or_else(|| serde::de::Error::missing_field("content"))?;
279                        let vec: Vec<Value> =
280                            serde_json::from_value(content.clone()).map_err(|e| {
281                                serde::de::Error::custom(format!(
282                                    "Failed to deserialize content: {}",
283                                    e
284                                ))
285                            })?;
286                        match type_ {
287                            "array" => Ok(Value::Array(vec)),
288                            "tuple" => Ok(Value::Tuple(vec)),
289                            "struct" => Ok(Value::Struct(vec)),
290                            _ => unreachable!(),
291                        }
292                    }
293                    _ => Err(serde::de::Error::unknown_variant(
294                        type_,
295                        &[
296                            "mscalar",
297                            "mfloat",
298                            "mbool",
299                            "u8",
300                            "u16",
301                            "u32",
302                            "u64",
303                            "u128",
304                            "scalar",
305                            "float",
306                            "bool",
307                            "array",
308                            "tuple",
309                            "struct",
310                            "ciphertext",
311                            "public_key",
312                        ],
313                    )),
314                }
315            }
316            _ => Err(serde::de::Error::invalid_type(
317                serde::de::Unexpected::Other("non-object"),
318                &"object",
319            )),
320        }
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_mscalar_serialization() {
        let value = Value::MScalar { size_in_bits: 32 };
        let serialized = serde_json::to_value(value).unwrap();
        assert_eq!(
            serialized,
            json!({
                "type": "mscalar",
                "size_in_bits": 32
            })
        );
    }

    #[test]
    fn test_mbool_serialization() {
        let value = Value::MBool;
        let serialized = serde_json::to_value(value).unwrap();
        assert_eq!(serialized, json!({ "type": "mbool" }));
    }

    #[test]
    fn test_bool_serialization() {
        let value = Value::Bool;
        let serialized = serde_json::to_value(value).unwrap();
        assert_eq!(serialized, json!({ "type": "bool" }));
    }

    #[test]
    fn test_array_serialization() {
        let value = Value::Array(vec![Value::Scalar { size_in_bits: 60 }, Value::Bool]);
        let serialized = serde_json::to_value(value).unwrap();
        assert_eq!(
            serialized,
            json!({
                "type": "array",
                "content": [
                    { "type": "scalar", "size_in_bits": 60 },
                    { "type": "bool" }
                ]
            })
        );
    }

    #[test]
    fn test_nested_structure_serialization() {
        let value = Value::Struct(vec![
            Value::Tuple(vec![Value::MScalar { size_in_bits: 32 }, Value::MBool]),
            Value::Array(vec![Value::Scalar { size_in_bits: 64 }, Value::Bool]),
        ]);
        let serialized = serde_json::to_string(&value).unwrap();
        let deserialized: Value = serde_json::from_str(&serialized).unwrap();
        assert_eq!(value, deserialized);
    }

    #[test]
    fn test_sized_and_unit_variants_round_trip() {
        let values = vec![
            Value::MScalar { size_in_bits: 64 },
            Value::MFloat { size_in_bits: 64 },
            Value::Float { size_in_bits: 32 },
            Value::Ciphertext { size_in_bits: 256 },
            Value::PublicKey { size_in_bits: 256 },
            Value::MBool,
            Value::Bool,
        ];
        for value in values {
            let text = serde_json::to_string(&value).unwrap();
            let back: Value = serde_json::from_str(&text).unwrap();
            assert_eq!(back, value);
        }
    }

    #[test]
    fn test_mscalar_deserialization() {
        let json = r#"{"type": "mscalar", "size_in_bits": 32}"#;
        let deserialized: Value = serde_json::from_str(json).unwrap();
        assert_eq!(deserialized, Value::MScalar { size_in_bits: 32 });
    }

    #[test]
    fn test_array_deserialization() {
        let json = r#"
        {
            "type": "array",
            "content": [
                {"type": "scalar", "size_in_bits": 64},
                {"type": "bool"}
            ]
        }"#;
        let deserialized: Value = serde_json::from_str(json).unwrap();
        assert_eq!(
            deserialized,
            Value::Array(vec![Value::Scalar { size_in_bits: 64 }, Value::Bool])
        );
    }

    #[test]
    fn test_invalid_type_deserialization() {
        let json = r#"{"type": "invalid_type"}"#;
        let result: Result<Value, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_missing_size_in_bits_deserialization() {
        let json = r#"{"type": "mscalar"}"#;
        let result: Result<Value, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_missing_content_deserialization() {
        let json = r#"{"type": "array"}"#;
        let result: Result<Value, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_non_object_deserialization() {
        assert!(serde_json::from_str::<Value>("42").is_err());
        assert!(serde_json::from_str::<Value>(r#"["mbool"]"#).is_err());
    }

    #[test]
    fn test_plaintext_type_serialization() {
        // (bit width, expected "type" tag); non-standard widths fall back to "scalar".
        let test_cases = [
            (8, "u8"),
            (16, "u16"),
            (32, "u32"),
            (64, "u64"),
            (128, "u128"),
            (24, "scalar"),
        ];

        for (size_in_bits, expected_type) in test_cases {
            let value = Value::Scalar { size_in_bits };
            let expected = json!({
                "type": expected_type,
                "size_in_bits": size_in_bits
            });

            // Serialization produces the expected tagged object.
            assert_eq!(serde_json::to_value(&value).unwrap(), expected);

            // Deserializing that object through the string path restores the value.
            let text = serde_json::to_string(&expected).unwrap();
            let deserialized: Value = serde_json::from_str(&text).unwrap();
            assert_eq!(deserialized, value);
        }
    }
}