arcis_interface/
json.rs

1use crate::types::{CircuitInterface, Value};
2use serde::{
3    de::{self, MapAccess, Visitor},
4    Deserialize,
5    Deserializer,
6    Serialize,
7};
8use serde_json::{json, Value as JsonValue};
9use std::fmt;
10
11#[derive(Serialize, Deserialize, Debug)]
12pub struct ManticoreInterface {
13    pub inputs: Vec<String>,
14    pub outputs: Vec<String>,
15}
16
17impl ManticoreInterface {
18    pub fn new(inputs: Vec<String>, outputs: Vec<String>) -> Self {
19        Self { inputs, outputs }
20    }
21
22    pub fn serialize(&self) -> Result<String, serde_json::Error> {
23        serde_json::to_string(self)
24    }
25
26    pub fn from_json(input: &str) -> Result<Self, serde_json::Error> {
27        serde_json::from_str(input)
28    }
29}
30
31impl<'de> Deserialize<'de> for CircuitInterface {
32    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
33    where
34        D: Deserializer<'de>,
35    {
36        enum Field {
37            Name,
38            Inputs,
39            Outputs,
40        }
41
42        impl<'de> Deserialize<'de> for Field {
43            fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
44            where
45                D: Deserializer<'de>,
46            {
47                struct FieldVisitor;
48
49                impl Visitor<'_> for FieldVisitor {
50                    type Value = Field;
51
52                    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
53                        formatter.write_str("`name`, `inputs`, or `outputs`")
54                    }
55
56                    fn visit_str<E>(self, value: &str) -> Result<Field, E>
57                    where
58                        E: de::Error,
59                    {
60                        match value {
61                            "name" => Ok(Field::Name),
62                            "inputs" => Ok(Field::Inputs),
63                            "outputs" => Ok(Field::Outputs),
64                            _ => Err(de::Error::unknown_field(value, FIELDS)),
65                        }
66                    }
67                }
68
69                deserializer.deserialize_identifier(FieldVisitor)
70            }
71        }
72
73        struct CircuitInterfaceVisitor;
74
75        impl<'de> Visitor<'de> for CircuitInterfaceVisitor {
76            type Value = CircuitInterface;
77
78            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
79                formatter.write_str("struct CircuitInterface")
80            }
81
82            fn visit_map<V>(self, mut map: V) -> Result<CircuitInterface, V::Error>
83            where
84                V: MapAccess<'de>,
85            {
86                let mut name = None;
87                let mut inputs = None;
88                let mut outputs = None;
89
90                while let Some(key) = map.next_key()? {
91                    match key {
92                        Field::Name => {
93                            if name.is_some() {
94                                return Err(de::Error::duplicate_field("name"));
95                            }
96                            name = Some(map.next_value()?);
97                        }
98                        Field::Inputs => {
99                            if inputs.is_some() {
100                                return Err(de::Error::duplicate_field("inputs"));
101                            }
102                            inputs = Some(map.next_value()?);
103                        }
104                        Field::Outputs => {
105                            if outputs.is_some() {
106                                return Err(de::Error::duplicate_field("outputs"));
107                            }
108                            outputs = Some(map.next_value()?);
109                        }
110                    }
111                }
112
113                let name = name.ok_or_else(|| de::Error::missing_field("name"))?;
114                let inputs = inputs.ok_or_else(|| de::Error::missing_field("inputs"))?;
115                let outputs = outputs.ok_or_else(|| de::Error::missing_field("output"))?;
116
117                Ok(CircuitInterface {
118                    name,
119                    inputs,
120                    outputs,
121                })
122            }
123        }
124
125        const FIELDS: &[&str] = &["name", "inputs", "outputs"];
126        deserializer.deserialize_struct("CircuitInterface", FIELDS, CircuitInterfaceVisitor)
127    }
128}
129
130impl Value {
131    fn get_scalar_type_name(size_in_bits: usize) -> &'static str {
132        match size_in_bits {
133            8 => "u8",
134            16 => "u16",
135            32 => "u32",
136            64 => "u64",
137            128 => "u128",
138            _ => "scalar", // fallback for non-standard sizes
139        }
140    }
141}
142
143impl Serialize for Value {
144    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
145    where
146        S: serde::Serializer,
147    {
148        let json_value = match self {
149            Value::MScalar { size_in_bits } => json!({
150                "type": "mscalar",
151                "size_in_bits": size_in_bits
152            }),
153            Value::MFloat { size_in_bits } => json!({
154                "type": "mfloat",
155                "size_in_bits": size_in_bits
156            }),
157            Value::MBool => json!({
158                "type": "mbool"
159            }),
160            Value::Scalar { size_in_bits } => json!({
161                "type": Value::get_scalar_type_name(*size_in_bits),
162                "size_in_bits": size_in_bits
163            }),
164            Value::Float { size_in_bits } => json!({
165                "type": "float",
166                "size_in_bits": size_in_bits
167            }),
168            Value::Bool => json!({
169                "type": "bool"
170            }),
171            Value::Ciphertext { size_in_bits } => json!({
172                "type": "ciphertext",
173                "size_in_bits": size_in_bits
174            }),
175            Value::PublicKey { size_in_bits } => json!({
176                "type": "public_key",
177                "size_in_bits": size_in_bits
178            }),
179            Value::Array(vec) => json!({
180                "type": "array",
181                "content": vec
182            }),
183            Value::Tuple(vec) => json!({
184                "type": "tuple",
185                "content": vec
186            }),
187            Value::Struct(vec) => json!({
188                "type": "struct",
189                "content": vec
190            }),
191        };
192        json_value.serialize(serializer)
193    }
194}
195
196impl<'de> Deserialize<'de> for Value {
197    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
198    where
199        D: serde::Deserializer<'de>,
200    {
201        let json_value = JsonValue::deserialize(deserializer)?;
202
203        match json_value {
204            JsonValue::Object(map) => {
205                let type_ = map
206                    .get("type")
207                    .and_then(JsonValue::as_str)
208                    .ok_or_else(|| serde::de::Error::missing_field("type"))?;
209
210                match type_ {
211                    "mscalar" => {
212                        let size_in_bits = map
213                            .get("size_in_bits")
214                            .and_then(JsonValue::as_u64)
215                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
216                        Ok(Value::MScalar {
217                            size_in_bits: size_in_bits as usize,
218                        })
219                    }
220                    "mfloat" => {
221                        let size_in_bits = map
222                            .get("size_in_bits")
223                            .and_then(JsonValue::as_u64)
224                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
225                        Ok(Value::MFloat {
226                            size_in_bits: size_in_bits as usize,
227                        })
228                    }
229                    "mbool" => Ok(Value::MBool),
230                    "u8" => Ok(Value::Scalar { size_in_bits: 8 }),
231                    "u16" => Ok(Value::Scalar { size_in_bits: 16 }),
232                    "u32" => Ok(Value::Scalar { size_in_bits: 32 }),
233                    "u64" => Ok(Value::Scalar { size_in_bits: 64 }),
234                    "u128" => Ok(Value::Scalar { size_in_bits: 128 }),
235                    "scalar" => {
236                        let size_in_bits = map
237                            .get("size_in_bits")
238                            .and_then(JsonValue::as_u64)
239                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
240                        Ok(Value::Scalar {
241                            size_in_bits: size_in_bits as usize,
242                        })
243                    }
244                    "ciphertext" => {
245                        let size_in_bits = map
246                            .get("size_in_bits")
247                            .and_then(JsonValue::as_u64)
248                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
249                        Ok(Value::Ciphertext {
250                            size_in_bits: size_in_bits as usize,
251                        })
252                    }
253                    "public_key" => {
254                        let size_in_bits = map
255                            .get("size_in_bits")
256                            .and_then(JsonValue::as_u64)
257                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
258                        Ok(Value::PublicKey {
259                            size_in_bits: size_in_bits as usize,
260                        })
261                    }
262                    "float" => {
263                        let size_in_bits = map
264                            .get("size_in_bits")
265                            .and_then(JsonValue::as_u64)
266                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
267                        Ok(Value::Float {
268                            size_in_bits: size_in_bits as usize,
269                        })
270                    }
271                    "bool" => Ok(Value::Bool),
272                    "array" | "tuple" | "struct" => {
273                        let content = map
274                            .get("content")
275                            .ok_or_else(|| serde::de::Error::missing_field("content"))?;
276                        let vec: Vec<Value> =
277                            serde_json::from_value(content.clone()).map_err(|e| {
278                                serde::de::Error::custom(format!(
279                                    "Failed to deserialize content: {}",
280                                    e
281                                ))
282                            })?;
283                        match type_ {
284                            "array" => Ok(Value::Array(vec)),
285                            "tuple" => Ok(Value::Tuple(vec)),
286                            "struct" => Ok(Value::Struct(vec)),
287                            _ => unreachable!(),
288                        }
289                    }
290                    _ => Err(serde::de::Error::unknown_variant(
291                        type_,
292                        &[
293                            "mscalar",
294                            "mfloat",
295                            "mbool",
296                            "u8",
297                            "u16",
298                            "u32",
299                            "u64",
300                            "u128",
301                            "scalar",
302                            "float",
303                            "bool",
304                            "array",
305                            "tuple",
306                            "struct",
307                            "ciphertext",
308                            "public_key",
309                        ],
310                    )),
311                }
312            }
313            _ => Err(serde::de::Error::invalid_type(
314                serde::de::Unexpected::Other("non-object"),
315                &"object",
316            )),
317        }
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_mscalar_serialization() {
        let value = Value::MScalar { size_in_bits: 32 };
        let serialized = serde_json::to_value(value).unwrap();
        assert_eq!(
            serialized,
            json!({
                "type": "mscalar",
                "size_in_bits": 32
            })
        );
    }

    #[test]
    fn test_mbool_serialization() {
        let value = Value::MBool;
        let serialized = serde_json::to_value(value).unwrap();
        assert_eq!(
            serialized,
            json!({
                "type": "mbool"
            })
        );
    }

    #[test]
    fn test_bool_serialization() {
        let value = Value::Bool;
        let serialized = serde_json::to_value(value).unwrap();
        assert_eq!(
            serialized,
            json!({
                "type": "bool"
            })
        );
    }

    #[test]
    fn test_array_serialization() {
        let value = Value::Array(vec![Value::Scalar { size_in_bits: 60 }, Value::Bool]);
        let serialized = serde_json::to_value(value).unwrap();
        assert_eq!(
            serialized,
            json!({
                "type": "array",
                "content": [
                    {
                        "type": "scalar",
                        "size_in_bits": 60
                    },
                    {
                        "type": "bool"
                    }
                ]
            })
        );
    }

    #[test]
    fn test_nested_structure_serialization() {
        let value = Value::Struct(vec![
            Value::Tuple(vec![Value::MScalar { size_in_bits: 32 }, Value::MBool]),
            Value::Array(vec![Value::Scalar { size_in_bits: 64 }, Value::Bool]),
        ]);
        let serialized = serde_json::to_string(&value).unwrap();
        let deserialized: Value = serde_json::from_str(&serialized).unwrap();
        assert_eq!(value, deserialized);
    }

    #[test]
    fn test_mscalar_deserialization() {
        let json = r#"{"type": "mscalar", "size_in_bits": 32}"#;
        let deserialized: Value = serde_json::from_str(json).unwrap();
        assert_eq!(deserialized, Value::MScalar { size_in_bits: 32 });
    }

    #[test]
    fn test_array_deserialization() {
        let json = r#"
        {
            "type": "array",
            "content": [
                {"type": "scalar", "size_in_bits": 64},
                {"type": "bool"}
            ]
        }"#;
        let deserialized: Value = serde_json::from_str(json).unwrap();
        assert_eq!(
            deserialized,
            Value::Array(vec![Value::Scalar { size_in_bits: 64 }, Value::Bool])
        );
    }

    #[test]
    fn test_invalid_type_deserialization() {
        let json = r#"{"type": "invalid_type"}"#;
        let result: Result<Value, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_missing_size_in_bits_deserialization() {
        let json = r#"{"type": "mscalar"}"#;
        let result: Result<Value, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_plaintext_type_serialization() {
        // Standard widths get a named tag; any other width falls back to "scalar".
        let test_cases = [
            (8, "u8"),
            (16, "u16"),
            (32, "u32"),
            (64, "u64"),
            (128, "u128"),
            (24, "scalar"),
        ];

        for (size_in_bits, expected_type) in test_cases {
            let value = Value::Scalar { size_in_bits };
            let expected = json!({
                "type": expected_type,
                "size_in_bits": size_in_bits
            });

            // Serialization produces the expected tagged object...
            assert_eq!(serde_json::to_value(&value).unwrap(), expected);

            // ...and that object's text form deserializes back to the same value.
            let json = serde_json::to_string(&expected).unwrap();
            let deserialized: Value = serde_json::from_str(&json).unwrap();
            assert_eq!(deserialized, value);
        }
    }

    #[test]
    fn test_round_trip_all_variants() {
        // One instance of every variant, including empty and nested composites.
        let values = vec![
            Value::MScalar { size_in_bits: 64 },
            Value::MFloat { size_in_bits: 64 },
            Value::MBool,
            Value::Scalar { size_in_bits: 8 },
            Value::Scalar { size_in_bits: 24 },
            Value::Float { size_in_bits: 64 },
            Value::Bool,
            Value::Ciphertext { size_in_bits: 256 },
            Value::PublicKey { size_in_bits: 256 },
            Value::Array(vec![Value::Bool, Value::MBool]),
            Value::Tuple(vec![]),
            Value::Struct(vec![Value::Tuple(vec![Value::Scalar { size_in_bits: 32 }])]),
        ];

        for value in values {
            let json = serde_json::to_string(&value).unwrap();
            let deserialized: Value = serde_json::from_str(&json).unwrap();
            assert_eq!(deserialized, value, "round trip through {}", json);
        }
    }

    #[test]
    fn test_malformed_input_rejected() {
        let cases = [
            r#"[]"#,
            r#"{"size_in_bits": 32}"#,
            r#"{"type": 5}"#,
            r#"{"type": "mscalar", "size_in_bits": "32"}"#,
            r#"{"type": "mscalar", "size_in_bits": -1}"#,
            r#"{"type": "array"}"#,
            r#"{"type": "array", "content": {"type": "bool"}}"#,
        ];

        for json in cases {
            assert!(
                serde_json::from_str::<Value>(json).is_err(),
                "malformed input was accepted: {}",
                json
            );
        }
    }
}