1use crate::types::{CircuitInterface, Value};
2use serde::{
3 de::{self, MapAccess, Visitor},
4 Deserialize,
5 Deserializer,
6 Serialize,
7};
8use serde_json::{json, Value as JsonValue};
9use std::fmt;
10
11#[derive(Serialize, Deserialize, Debug)]
12pub struct ManticoreInterface {
13 pub inputs: Vec<String>,
14 pub outputs: Vec<String>,
15}
16
17impl ManticoreInterface {
18 pub fn new(inputs: Vec<String>, outputs: Vec<String>) -> Self {
19 Self { inputs, outputs }
20 }
21
22 pub fn serialize(&self) -> Result<String, serde_json::Error> {
23 serde_json::to_string(self)
24 }
25
26 pub fn from_json(input: &str) -> Result<Self, serde_json::Error> {
27 serde_json::from_str(input)
28 }
29}
30
31impl<'de> Deserialize<'de> for CircuitInterface {
32 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
33 where
34 D: Deserializer<'de>,
35 {
36 enum Field {
37 Name,
38 Inputs,
39 Outputs,
40 }
41
42 impl<'de> Deserialize<'de> for Field {
43 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
44 where
45 D: Deserializer<'de>,
46 {
47 struct FieldVisitor;
48
49 impl Visitor<'_> for FieldVisitor {
50 type Value = Field;
51
52 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
53 formatter.write_str("`name`, `inputs`, or `outputs`")
54 }
55
56 fn visit_str<E>(self, value: &str) -> Result<Field, E>
57 where
58 E: de::Error,
59 {
60 match value {
61 "name" => Ok(Field::Name),
62 "inputs" => Ok(Field::Inputs),
63 "outputs" => Ok(Field::Outputs),
64 _ => Err(de::Error::unknown_field(value, FIELDS)),
65 }
66 }
67 }
68
69 deserializer.deserialize_identifier(FieldVisitor)
70 }
71 }
72
73 struct CircuitInterfaceVisitor;
74
75 impl<'de> Visitor<'de> for CircuitInterfaceVisitor {
76 type Value = CircuitInterface;
77
78 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
79 formatter.write_str("struct CircuitInterface")
80 }
81
82 fn visit_map<V>(self, mut map: V) -> Result<CircuitInterface, V::Error>
83 where
84 V: MapAccess<'de>,
85 {
86 let mut name = None;
87 let mut inputs = None;
88 let mut outputs = None;
89
90 while let Some(key) = map.next_key()? {
91 match key {
92 Field::Name => {
93 if name.is_some() {
94 return Err(de::Error::duplicate_field("name"));
95 }
96 name = Some(map.next_value()?);
97 }
98 Field::Inputs => {
99 if inputs.is_some() {
100 return Err(de::Error::duplicate_field("inputs"));
101 }
102 inputs = Some(map.next_value()?);
103 }
104 Field::Outputs => {
105 if outputs.is_some() {
106 return Err(de::Error::duplicate_field("outputs"));
107 }
108 outputs = Some(map.next_value()?);
109 }
110 }
111 }
112
113 let name = name.ok_or_else(|| de::Error::missing_field("name"))?;
114 let inputs = inputs.ok_or_else(|| de::Error::missing_field("inputs"))?;
115 let outputs = outputs.ok_or_else(|| de::Error::missing_field("output"))?;
116
117 Ok(CircuitInterface {
118 name,
119 inputs,
120 outputs,
121 })
122 }
123 }
124
125 const FIELDS: &[&str] = &["name", "inputs", "outputs"];
126 deserializer.deserialize_struct("CircuitInterface", FIELDS, CircuitInterfaceVisitor)
127 }
128}
129
130impl Value {
131 fn get_scalar_type_name(size_in_bits: usize) -> &'static str {
132 match size_in_bits {
133 8 => "u8",
134 16 => "u16",
135 32 => "u32",
136 64 => "u64",
137 128 => "u128",
138 _ => "scalar", }
140 }
141}
142
143impl Serialize for Value {
144 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
145 where
146 S: serde::Serializer,
147 {
148 let json_value = match self {
149 Value::MScalar { size_in_bits } => json!({
150 "type": "mscalar",
151 "size_in_bits": size_in_bits
152 }),
153 Value::MFloat { size_in_bits } => json!({
154 "type": "mfloat",
155 "size_in_bits": size_in_bits
156 }),
157 Value::MBool => json!({
158 "type": "mbool"
159 }),
160 Value::Scalar { size_in_bits } => json!({
161 "type": Value::get_scalar_type_name(*size_in_bits),
162 "size_in_bits": size_in_bits
163 }),
164 Value::Float { size_in_bits } => json!({
165 "type": "float",
166 "size_in_bits": size_in_bits
167 }),
168 Value::Bool => json!({
169 "type": "bool"
170 }),
171 Value::Ciphertext { size_in_bits } => json!({
172 "type": "ciphertext",
173 "size_in_bits": size_in_bits
174 }),
175 Value::PublicKey { size_in_bits } => json!({
176 "type": "public_key",
177 "size_in_bits": size_in_bits
178 }),
179 Value::Point => json!({
180 "type": "point"
181 }),
182 Value::Array(vec) => json!({
183 "type": "array",
184 "content": vec
185 }),
186 Value::Tuple(vec) => json!({
187 "type": "tuple",
188 "content": vec
189 }),
190 Value::Struct(vec) => json!({
191 "type": "struct",
192 "content": vec
193 }),
194 };
195 json_value.serialize(serializer)
196 }
197}
198
199impl<'de> Deserialize<'de> for Value {
200 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
201 where
202 D: serde::Deserializer<'de>,
203 {
204 let json_value = JsonValue::deserialize(deserializer)?;
205
206 match json_value {
207 JsonValue::Object(map) => {
208 let type_ = map
209 .get("type")
210 .and_then(JsonValue::as_str)
211 .ok_or_else(|| serde::de::Error::missing_field("type"))?;
212
213 match type_ {
214 "mscalar" => {
215 let size_in_bits = map
216 .get("size_in_bits")
217 .and_then(JsonValue::as_u64)
218 .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
219 Ok(Value::MScalar {
220 size_in_bits: size_in_bits as usize,
221 })
222 }
223 "mfloat" => {
224 let size_in_bits = map
225 .get("size_in_bits")
226 .and_then(JsonValue::as_u64)
227 .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
228 Ok(Value::MFloat {
229 size_in_bits: size_in_bits as usize,
230 })
231 }
232 "mbool" => Ok(Value::MBool),
233 "u8" => Ok(Value::Scalar { size_in_bits: 8 }),
234 "u16" => Ok(Value::Scalar { size_in_bits: 16 }),
235 "u32" => Ok(Value::Scalar { size_in_bits: 32 }),
236 "u64" => Ok(Value::Scalar { size_in_bits: 64 }),
237 "u128" => Ok(Value::Scalar { size_in_bits: 128 }),
238 "scalar" => {
239 let size_in_bits = map
240 .get("size_in_bits")
241 .and_then(JsonValue::as_u64)
242 .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
243 Ok(Value::Scalar {
244 size_in_bits: size_in_bits as usize,
245 })
246 }
247 "ciphertext" => {
248 let size_in_bits = map
249 .get("size_in_bits")
250 .and_then(JsonValue::as_u64)
251 .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
252 Ok(Value::Ciphertext {
253 size_in_bits: size_in_bits as usize,
254 })
255 }
256 "public_key" => {
257 let size_in_bits = map
258 .get("size_in_bits")
259 .and_then(JsonValue::as_u64)
260 .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
261 Ok(Value::PublicKey {
262 size_in_bits: size_in_bits as usize,
263 })
264 }
265 "float" => {
266 let size_in_bits = map
267 .get("size_in_bits")
268 .and_then(JsonValue::as_u64)
269 .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
270 Ok(Value::Float {
271 size_in_bits: size_in_bits as usize,
272 })
273 }
274 "bool" => Ok(Value::Bool),
275 "array" | "tuple" | "struct" => {
276 let content = map
277 .get("content")
278 .ok_or_else(|| serde::de::Error::missing_field("content"))?;
279 let vec: Vec<Value> =
280 serde_json::from_value(content.clone()).map_err(|e| {
281 serde::de::Error::custom(format!(
282 "Failed to deserialize content: {}",
283 e
284 ))
285 })?;
286 match type_ {
287 "array" => Ok(Value::Array(vec)),
288 "tuple" => Ok(Value::Tuple(vec)),
289 "struct" => Ok(Value::Struct(vec)),
290 _ => unreachable!(),
291 }
292 }
293 _ => Err(serde::de::Error::unknown_variant(
294 type_,
295 &[
296 "mscalar",
297 "mfloat",
298 "mbool",
299 "u8",
300 "u16",
301 "u32",
302 "u64",
303 "u128",
304 "scalar",
305 "float",
306 "bool",
307 "array",
308 "tuple",
309 "struct",
310 "ciphertext",
311 "public_key",
312 ],
313 )),
314 }
315 }
316 _ => Err(serde::de::Error::invalid_type(
317 serde::de::Unexpected::Other("non-object"),
318 &"object",
319 )),
320 }
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Serializes `value` into a JSON tree, panicking if serialization fails.
    fn to_json_tree(value: &Value) -> serde_json::Value {
        serde_json::to_value(value).expect("Value serialization should not fail")
    }

    /// Parses `text` as a `Value`.
    fn parse_value(text: &str) -> Result<Value, serde_json::Error> {
        serde_json::from_str(text)
    }

    #[test]
    fn test_mscalar_serialization() {
        let actual = to_json_tree(&Value::MScalar { size_in_bits: 32 });
        assert_eq!(actual, json!({ "type": "mscalar", "size_in_bits": 32 }));
    }

    #[test]
    fn test_mbool_serialization() {
        assert_eq!(to_json_tree(&Value::MBool), json!({ "type": "mbool" }));
    }

    #[test]
    fn test_bool_serialization() {
        assert_eq!(to_json_tree(&Value::Bool), json!({ "type": "bool" }));
    }

    #[test]
    fn test_array_serialization() {
        let array = Value::Array(vec![Value::Scalar { size_in_bits: 60 }, Value::Bool]);
        let expected = json!({
            "type": "array",
            "content": [
                { "type": "scalar", "size_in_bits": 60 },
                { "type": "bool" }
            ]
        });
        assert_eq!(to_json_tree(&array), expected);
    }

    #[test]
    fn test_nested_structure_serialization() {
        let original = Value::Struct(vec![
            Value::Tuple(vec![Value::MScalar { size_in_bits: 32 }, Value::MBool]),
            Value::Array(vec![Value::Scalar { size_in_bits: 64 }, Value::Bool]),
        ]);
        let text = serde_json::to_string(&original).unwrap();
        let round_tripped = parse_value(&text).unwrap();
        assert_eq!(original, round_tripped);
    }

    #[test]
    fn test_mscalar_deserialization() {
        let parsed = parse_value(r#"{"type": "mscalar", "size_in_bits": 32}"#).unwrap();
        assert_eq!(parsed, Value::MScalar { size_in_bits: 32 });
    }

    #[test]
    fn test_array_deserialization() {
        let text = r#"
        {
            "type": "array",
            "content": [
                {"type": "scalar", "size_in_bits": 64},
                {"type": "bool"}
            ]
        }"#;
        assert_eq!(
            parse_value(text).unwrap(),
            Value::Array(vec![Value::Scalar { size_in_bits: 64 }, Value::Bool])
        );
    }

    #[test]
    fn test_invalid_type_deserialization() {
        assert!(parse_value(r#"{"type": "invalid_type"}"#).is_err());
    }

    #[test]
    fn test_missing_size_in_bits_deserialization() {
        assert!(parse_value(r#"{"type": "mscalar"}"#).is_err());
    }

    #[test]
    fn test_plaintext_type_serialization() {
        // (width, expected JSON tag) pairs: fixed widths get named tags,
        // anything else falls back to "scalar".
        let cases: [(usize, &str); 6] = [
            (8, "u8"),
            (16, "u16"),
            (32, "u32"),
            (64, "u64"),
            (128, "u128"),
            (24, "scalar"),
        ];

        for (bits, expected_type) in cases {
            let value = Value::Scalar { size_in_bits: bits };
            let expected = json!({ "type": expected_type, "size_in_bits": bits });
            assert_eq!(to_json_tree(&value), expected);

            let text = serde_json::to_string(&expected).unwrap();
            assert_eq!(parse_value(&text).unwrap(), value);
        }
    }
}