Skip to main content

arcis_interface/
json.rs

1use crate::types::{CircuitInterface, ScalarKind, Value};
2use serde::{
3    de::{self, MapAccess, Visitor},
4    Deserialize,
5    Deserializer,
6    Serialize,
7};
8use serde_json::{json, Value as JsonValue};
9use std::fmt;
10
/// Serializable pair of input and output lists, each entry being a plain string.
///
/// Serde's `Serialize`/`Deserialize` are derived, so the serialized form has
/// the two fields `inputs` and `outputs`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ManticoreInterface {
    /// Input entries. NOTE(review): what each string denotes is not visible in
    /// this file — confirm against callers.
    pub inputs: Vec<String>,
    /// Output entries. NOTE(review): what each string denotes is not visible in
    /// this file — confirm against callers.
    pub outputs: Vec<String>,
}
16
17impl ManticoreInterface {
18    pub fn new(inputs: Vec<String>, outputs: Vec<String>) -> Self {
19        Self { inputs, outputs }
20    }
21
22    pub fn serialize(&self) -> Result<String, serde_json::Error> {
23        serde_json::to_string(self)
24    }
25
26    pub fn from_json(input: &str) -> Result<Self, serde_json::Error> {
27        serde_json::from_str(input)
28    }
29}
30
31impl<'de> Deserialize<'de> for CircuitInterface {
32    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
33    where
34        D: Deserializer<'de>,
35    {
36        enum Field {
37            Name,
38            Inputs,
39            Outputs,
40        }
41
42        impl<'de> Deserialize<'de> for Field {
43            fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
44            where
45                D: Deserializer<'de>,
46            {
47                struct FieldVisitor;
48
49                impl Visitor<'_> for FieldVisitor {
50                    type Value = Field;
51
52                    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
53                        formatter.write_str("`name`, `inputs`, or `outputs`")
54                    }
55
56                    fn visit_str<E>(self, value: &str) -> Result<Field, E>
57                    where
58                        E: de::Error,
59                    {
60                        match value {
61                            "name" => Ok(Field::Name),
62                            "inputs" => Ok(Field::Inputs),
63                            "outputs" => Ok(Field::Outputs),
64                            _ => Err(de::Error::unknown_field(value, FIELDS)),
65                        }
66                    }
67                }
68
69                deserializer.deserialize_identifier(FieldVisitor)
70            }
71        }
72
73        struct CircuitInterfaceVisitor;
74
75        impl<'de> Visitor<'de> for CircuitInterfaceVisitor {
76            type Value = CircuitInterface;
77
78            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
79                formatter.write_str("struct CircuitInterface")
80            }
81
82            fn visit_map<V>(self, mut map: V) -> Result<CircuitInterface, V::Error>
83            where
84                V: MapAccess<'de>,
85            {
86                let mut name = None;
87                let mut inputs = None;
88                let mut outputs = None;
89
90                while let Some(key) = map.next_key()? {
91                    match key {
92                        Field::Name => {
93                            if name.is_some() {
94                                return Err(de::Error::duplicate_field("name"));
95                            }
96                            name = Some(map.next_value()?);
97                        }
98                        Field::Inputs => {
99                            if inputs.is_some() {
100                                return Err(de::Error::duplicate_field("inputs"));
101                            }
102                            inputs = Some(map.next_value()?);
103                        }
104                        Field::Outputs => {
105                            if outputs.is_some() {
106                                return Err(de::Error::duplicate_field("outputs"));
107                            }
108                            outputs = Some(map.next_value()?);
109                        }
110                    }
111                }
112
113                let name = name.ok_or_else(|| de::Error::missing_field("name"))?;
114                let inputs = inputs.ok_or_else(|| de::Error::missing_field("inputs"))?;
115                let outputs = outputs.ok_or_else(|| de::Error::missing_field("output"))?;
116
117                Ok(CircuitInterface {
118                    name,
119                    inputs,
120                    outputs,
121                })
122            }
123        }
124
125        const FIELDS: &[&str] = &["name", "inputs", "outputs"];
126        deserializer.deserialize_struct("CircuitInterface", FIELDS, CircuitInterfaceVisitor)
127    }
128}
129
130impl Value {
131    /// Maps scalar size and kind to JSON type name.
132    ///
133    /// Returns standard Rust type names for common sizes (e.g., "u32", "i64"),
134    /// or generic names for non-standard sizes ("scalar" for unsigned, "signed_integer" for
135    /// signed).
136    fn get_scalar_type_name(size_in_bits: usize, kind: ScalarKind) -> &'static str {
137        match kind {
138            ScalarKind::Unsigned => match size_in_bits {
139                8 => "u8",
140                16 => "u16",
141                32 => "u32",
142                64 => "u64",
143                128 => "u128",
144                _ => "scalar",
145            },
146            ScalarKind::Signed => match size_in_bits {
147                8 => "i8",
148                16 => "i16",
149                32 => "i32",
150                64 => "i64",
151                128 => "i128",
152                _ => "signed_integer",
153            },
154        }
155    }
156}
157
158impl Serialize for Value {
159    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
160    where
161        S: serde::Serializer,
162    {
163        let json_value = match self {
164            Value::MScalar { size_in_bits } => json!({
165                "type": "mscalar",
166                "size_in_bits": size_in_bits
167            }),
168            Value::MFloat { size_in_bits } => json!({
169                "type": "mfloat",
170                "size_in_bits": size_in_bits
171            }),
172            Value::MBool => json!({
173                "type": "mbool"
174            }),
175            Value::Scalar { size_in_bits, kind } => json!({
176                "type": Value::get_scalar_type_name(*size_in_bits, *kind),
177                "size_in_bits": size_in_bits
178            }),
179            Value::Float { size_in_bits } => json!({
180                "type": "float",
181                "size_in_bits": size_in_bits
182            }),
183            Value::Bool => json!({
184                "type": "bool"
185            }),
186            Value::Ciphertext { size_in_bits } => json!({
187                "type": "ciphertext",
188                "size_in_bits": size_in_bits
189            }),
190            Value::ArcisX25519Pubkey => json!({
191                "type": "arcis_x25519_pubkey",
192            }),
193            Value::Point => json!({
194                "type": "point"
195            }),
196            Value::Array(vec) => json!({
197                "type": "array",
198                "content": vec
199            }),
200            Value::Tuple(vec) => json!({
201                "type": "tuple",
202                "content": vec
203            }),
204            Value::Struct(vec) => json!({
205                "type": "struct",
206                "content": vec
207            }),
208        };
209        json_value.serialize(serializer)
210    }
211}
212
213impl<'de> Deserialize<'de> for Value {
214    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
215    where
216        D: serde::Deserializer<'de>,
217    {
218        let json_value = JsonValue::deserialize(deserializer)?;
219
220        match json_value {
221            JsonValue::Object(map) => {
222                let type_ = map
223                    .get("type")
224                    .and_then(JsonValue::as_str)
225                    .ok_or_else(|| serde::de::Error::missing_field("type"))?;
226
227                match type_ {
228                    "mscalar" => {
229                        let size_in_bits = map
230                            .get("size_in_bits")
231                            .and_then(JsonValue::as_u64)
232                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
233                        Ok(Value::MScalar {
234                            size_in_bits: size_in_bits as usize,
235                        })
236                    }
237                    "mfloat" => {
238                        let size_in_bits = map
239                            .get("size_in_bits")
240                            .and_then(JsonValue::as_u64)
241                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
242                        Ok(Value::MFloat {
243                            size_in_bits: size_in_bits as usize,
244                        })
245                    }
246                    "mbool" => Ok(Value::MBool),
247                    "u8" => Ok(Value::Scalar {
248                        size_in_bits: 8,
249                        kind: ScalarKind::Unsigned,
250                    }),
251                    "u16" => Ok(Value::Scalar {
252                        size_in_bits: 16,
253                        kind: ScalarKind::Unsigned,
254                    }),
255                    "u32" => Ok(Value::Scalar {
256                        size_in_bits: 32,
257                        kind: ScalarKind::Unsigned,
258                    }),
259                    "u64" => Ok(Value::Scalar {
260                        size_in_bits: 64,
261                        kind: ScalarKind::Unsigned,
262                    }),
263                    "u128" => Ok(Value::Scalar {
264                        size_in_bits: 128,
265                        kind: ScalarKind::Unsigned,
266                    }),
267                    "i8" => Ok(Value::Scalar {
268                        size_in_bits: 8,
269                        kind: ScalarKind::Signed,
270                    }),
271                    "i16" => Ok(Value::Scalar {
272                        size_in_bits: 16,
273                        kind: ScalarKind::Signed,
274                    }),
275                    "i32" => Ok(Value::Scalar {
276                        size_in_bits: 32,
277                        kind: ScalarKind::Signed,
278                    }),
279                    "i64" => Ok(Value::Scalar {
280                        size_in_bits: 64,
281                        kind: ScalarKind::Signed,
282                    }),
283                    "i128" => Ok(Value::Scalar {
284                        size_in_bits: 128,
285                        kind: ScalarKind::Signed,
286                    }),
287                    "scalar" => {
288                        let size_in_bits = map
289                            .get("size_in_bits")
290                            .and_then(JsonValue::as_u64)
291                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
292                        Ok(Value::Scalar {
293                            size_in_bits: size_in_bits as usize,
294                            kind: ScalarKind::Unsigned,
295                        })
296                    }
297                    "signed_integer" => {
298                        let size_in_bits = map
299                            .get("size_in_bits")
300                            .and_then(JsonValue::as_u64)
301                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
302                        Ok(Value::Scalar {
303                            size_in_bits: size_in_bits as usize,
304                            kind: ScalarKind::Signed,
305                        })
306                    }
307                    "ciphertext" => {
308                        let size_in_bits = map
309                            .get("size_in_bits")
310                            .and_then(JsonValue::as_u64)
311                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
312                        Ok(Value::Ciphertext {
313                            size_in_bits: size_in_bits as usize,
314                        })
315                    }
316                    "arcis_x25519_pubkey" => Ok(Value::ArcisX25519Pubkey),
317                    "point" => Ok(Value::Point),
318                    "float" => {
319                        let size_in_bits = map
320                            .get("size_in_bits")
321                            .and_then(JsonValue::as_u64)
322                            .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
323                        Ok(Value::Float {
324                            size_in_bits: size_in_bits as usize,
325                        })
326                    }
327                    "bool" => Ok(Value::Bool),
328                    "array" | "tuple" | "struct" => {
329                        let content = map
330                            .get("content")
331                            .ok_or_else(|| serde::de::Error::missing_field("content"))?;
332                        let vec: Vec<Value> =
333                            serde_json::from_value(content.clone()).map_err(|e| {
334                                serde::de::Error::custom(format!(
335                                    "Failed to deserialize content: {}",
336                                    e
337                                ))
338                            })?;
339                        match type_ {
340                            "array" => Ok(Value::Array(vec)),
341                            "tuple" => Ok(Value::Tuple(vec)),
342                            "struct" => Ok(Value::Struct(vec)),
343                            _ => unreachable!(),
344                        }
345                    }
346                    _ => Err(serde::de::Error::unknown_variant(
347                        type_,
348                        &[
349                            "mscalar",
350                            "mfloat",
351                            "mbool",
352                            "u8",
353                            "u16",
354                            "u32",
355                            "u64",
356                            "u128",
357                            "i8",
358                            "i16",
359                            "i32",
360                            "i64",
361                            "i128",
362                            "scalar",
363                            "signed_integer",
364                            "float",
365                            "bool",
366                            "array",
367                            "tuple",
368                            "struct",
369                            "ciphertext",
370                            "arcis_x25519_pubkey",
371                            "point",
372                        ],
373                    )),
374                }
375            }
376            _ => Err(serde::de::Error::invalid_type(
377                serde::de::Unexpected::Other("non-object"),
378                &"object",
379            )),
380        }
381    }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Unsigned plaintext scalar of the given width.
    fn unsigned(size_in_bits: usize) -> Value {
        Value::Scalar {
            size_in_bits,
            kind: ScalarKind::Unsigned,
        }
    }

    /// Signed plaintext scalar of the given width.
    fn signed(size_in_bits: usize) -> Value {
        Value::Scalar {
            size_in_bits,
            kind: ScalarKind::Signed,
        }
    }

    /// Checks that `value` serializes to `{"type": tag, "size_in_bits": bits}`
    /// and that this JSON deserializes back into `value`.
    fn assert_sized_round_trip(value: &Value, tag: &str, bits: usize) {
        let expected = json!({
            "type": tag,
            "size_in_bits": bits
        });
        assert_eq!(serde_json::to_value(value).unwrap(), expected);

        let text = serde_json::to_string(&expected).unwrap();
        let parsed: Value = serde_json::from_str(&text).unwrap();
        assert_eq!(&parsed, value);
    }

    #[test]
    fn test_mscalar_serialization() {
        let encoded = serde_json::to_value(Value::MScalar { size_in_bits: 32 }).unwrap();
        let expected = json!({
            "type": "mscalar",
            "size_in_bits": 32
        });
        assert_eq!(encoded, expected);
    }

    #[test]
    fn test_mbool_serialization() {
        let encoded = serde_json::to_value(Value::MBool).unwrap();
        assert_eq!(encoded, json!({ "type": "mbool" }));
    }

    #[test]
    fn test_bool_serialization() {
        let encoded = serde_json::to_value(Value::Bool).unwrap();
        assert_eq!(encoded, json!({ "type": "bool" }));
    }

    #[test]
    fn test_array_serialization() {
        let array = Value::Array(vec![unsigned(60), Value::Bool]);
        let encoded = serde_json::to_value(array).unwrap();
        let expected = json!({
            "type": "array",
            "content": [
                { "type": "scalar", "size_in_bits": 60 },
                { "type": "bool" }
            ]
        });
        assert_eq!(encoded, expected);
    }

    #[test]
    fn test_nested_structure_serialization() {
        let tuple = Value::Tuple(vec![Value::MScalar { size_in_bits: 32 }, Value::MBool]);
        let array = Value::Array(vec![unsigned(64), Value::Bool]);
        let original = Value::Struct(vec![tuple, array]);

        // Round-trip through a JSON string and compare with the original.
        let text = serde_json::to_string(&original).unwrap();
        let restored: Value = serde_json::from_str(&text).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn test_mscalar_deserialization() {
        let parsed: Value =
            serde_json::from_str(r#"{"type": "mscalar", "size_in_bits": 32}"#).unwrap();
        assert_eq!(parsed, Value::MScalar { size_in_bits: 32 });
    }

    #[test]
    fn test_array_deserialization() {
        let json = r#"
        {
            "type": "array",
            "content": [
                {"type": "scalar", "size_in_bits": 64},
                {"type": "bool"}
            ]
        }"#;
        let parsed: Value = serde_json::from_str(json).unwrap();
        let expected = Value::Array(vec![unsigned(64), Value::Bool]);
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_invalid_type_deserialization() {
        let outcome = serde_json::from_str::<Value>(r#"{"type": "invalid_type"}"#);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_missing_size_in_bits_deserialization() {
        let outcome = serde_json::from_str::<Value>(r#"{"type": "mscalar"}"#);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_plaintext_type_serialization() {
        // Standard unsigned widths, plus one non-standard width that falls
        // back to the generic "scalar" tag.
        let cases = [
            (8, "u8"),
            (16, "u16"),
            (32, "u32"),
            (64, "u64"),
            (128, "u128"),
            (24, "scalar"),
        ];

        for (bits, tag) in cases {
            assert_sized_round_trip(&unsigned(bits), tag, bits);
        }
    }

    #[test]
    fn test_signed_integer_serialization() {
        // Standard signed widths, plus one non-standard width that falls back
        // to the generic "signed_integer" tag.
        let cases = [
            (8, "i8"),
            (16, "i16"),
            (32, "i32"),
            (64, "i64"),
            (128, "i128"),
            (24, "signed_integer"),
        ];

        for (bits, tag) in cases {
            assert_sized_round_trip(&signed(bits), tag, bits);
        }
    }

    #[test]
    fn test_backward_compatibility() {
        // JSON written without any 'kind' field must still read as unsigned.
        let legacy_documents = [
            (r#"{"type": "u8", "size_in_bits": 8}"#, 8),
            (r#"{"type": "u16", "size_in_bits": 16}"#, 16),
            (r#"{"type": "u32", "size_in_bits": 32}"#, 32),
            (r#"{"type": "u64", "size_in_bits": 64}"#, 64),
            (r#"{"type": "u128", "size_in_bits": 128}"#, 128),
            (r#"{"type": "scalar", "size_in_bits": 24}"#, 24),
        ];

        for (document, bits) in legacy_documents {
            let parsed: Value = serde_json::from_str(document).unwrap();
            assert_eq!(parsed, unsigned(bits));
        }
    }

    #[test]
    fn test_mixed_signed_unsigned_serialization() {
        // Same width, different signedness: the tags must differ.
        let unsigned_json = serde_json::to_value(&unsigned(32)).unwrap();
        let signed_json = serde_json::to_value(&signed(32)).unwrap();

        assert_eq!(
            unsigned_json,
            json!({ "type": "u32", "size_in_bits": 32 })
        );
        assert_eq!(signed_json, json!({ "type": "i32", "size_in_bits": 32 }));

        assert_ne!(unsigned_json, signed_json);
    }
}