1use crate::types::{CircuitInterface, Value};
2use serde::{
3 de::{self, MapAccess, Visitor},
4 Deserialize,
5 Deserializer,
6 Serialize,
7};
8use serde_json::{json, Value as JsonValue};
9use std::fmt;
10
11#[derive(Serialize, Deserialize, Debug)]
12pub struct ManticoreInterface {
13 pub inputs: Vec<String>,
14 pub outputs: Vec<String>,
15}
16
17impl ManticoreInterface {
18 pub fn new(inputs: Vec<String>, outputs: Vec<String>) -> Self {
19 Self { inputs, outputs }
20 }
21
22 pub fn serialize(&self) -> Result<String, serde_json::Error> {
23 serde_json::to_string(self)
24 }
25
26 pub fn from_json(input: &str) -> Result<Self, serde_json::Error> {
27 serde_json::from_str(input)
28 }
29}
30
31impl<'de> Deserialize<'de> for CircuitInterface {
32 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
33 where
34 D: Deserializer<'de>,
35 {
36 enum Field {
37 Name,
38 Inputs,
39 Outputs,
40 }
41
42 impl<'de> Deserialize<'de> for Field {
43 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
44 where
45 D: Deserializer<'de>,
46 {
47 struct FieldVisitor;
48
49 impl Visitor<'_> for FieldVisitor {
50 type Value = Field;
51
52 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
53 formatter.write_str("`name`, `inputs`, or `outputs`")
54 }
55
56 fn visit_str<E>(self, value: &str) -> Result<Field, E>
57 where
58 E: de::Error,
59 {
60 match value {
61 "name" => Ok(Field::Name),
62 "inputs" => Ok(Field::Inputs),
63 "outputs" => Ok(Field::Outputs),
64 _ => Err(de::Error::unknown_field(value, FIELDS)),
65 }
66 }
67 }
68
69 deserializer.deserialize_identifier(FieldVisitor)
70 }
71 }
72
73 struct CircuitInterfaceVisitor;
74
75 impl<'de> Visitor<'de> for CircuitInterfaceVisitor {
76 type Value = CircuitInterface;
77
78 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
79 formatter.write_str("struct CircuitInterface")
80 }
81
82 fn visit_map<V>(self, mut map: V) -> Result<CircuitInterface, V::Error>
83 where
84 V: MapAccess<'de>,
85 {
86 let mut name = None;
87 let mut inputs = None;
88 let mut outputs = None;
89
90 while let Some(key) = map.next_key()? {
91 match key {
92 Field::Name => {
93 if name.is_some() {
94 return Err(de::Error::duplicate_field("name"));
95 }
96 name = Some(map.next_value()?);
97 }
98 Field::Inputs => {
99 if inputs.is_some() {
100 return Err(de::Error::duplicate_field("inputs"));
101 }
102 inputs = Some(map.next_value()?);
103 }
104 Field::Outputs => {
105 if outputs.is_some() {
106 return Err(de::Error::duplicate_field("outputs"));
107 }
108 outputs = Some(map.next_value()?);
109 }
110 }
111 }
112
113 let name = name.ok_or_else(|| de::Error::missing_field("name"))?;
114 let inputs = inputs.ok_or_else(|| de::Error::missing_field("inputs"))?;
115 let outputs = outputs.ok_or_else(|| de::Error::missing_field("output"))?;
116
117 Ok(CircuitInterface {
118 name,
119 inputs,
120 outputs,
121 })
122 }
123 }
124
125 const FIELDS: &[&str] = &["name", "inputs", "outputs"];
126 deserializer.deserialize_struct("CircuitInterface", FIELDS, CircuitInterfaceVisitor)
127 }
128}
129
130impl Value {
131 fn get_scalar_type_name(size_in_bits: usize) -> &'static str {
132 match size_in_bits {
133 8 => "u8",
134 16 => "u16",
135 32 => "u32",
136 64 => "u64",
137 128 => "u128",
138 _ => "scalar", }
140 }
141}
142
143impl Serialize for Value {
144 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
145 where
146 S: serde::Serializer,
147 {
148 let json_value = match self {
149 Value::MScalar { size_in_bits } => json!({
150 "type": "mscalar",
151 "size_in_bits": size_in_bits
152 }),
153 Value::MFloat { size_in_bits } => json!({
154 "type": "mfloat",
155 "size_in_bits": size_in_bits
156 }),
157 Value::MBool => json!({
158 "type": "mbool"
159 }),
160 Value::Scalar { size_in_bits } => json!({
161 "type": Value::get_scalar_type_name(*size_in_bits),
162 "size_in_bits": size_in_bits
163 }),
164 Value::Float { size_in_bits } => json!({
165 "type": "float",
166 "size_in_bits": size_in_bits
167 }),
168 Value::Bool => json!({
169 "type": "bool"
170 }),
171 Value::Ciphertext { size_in_bits } => json!({
172 "type": "ciphertext",
173 "size_in_bits": size_in_bits
174 }),
175 Value::PublicKey { size_in_bits } => json!({
176 "type": "public_key",
177 "size_in_bits": size_in_bits
178 }),
179 Value::Array(vec) => json!({
180 "type": "array",
181 "content": vec
182 }),
183 Value::Tuple(vec) => json!({
184 "type": "tuple",
185 "content": vec
186 }),
187 Value::Struct(vec) => json!({
188 "type": "struct",
189 "content": vec
190 }),
191 };
192 json_value.serialize(serializer)
193 }
194}
195
196impl<'de> Deserialize<'de> for Value {
197 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
198 where
199 D: serde::Deserializer<'de>,
200 {
201 let json_value = JsonValue::deserialize(deserializer)?;
202
203 match json_value {
204 JsonValue::Object(map) => {
205 let type_ = map
206 .get("type")
207 .and_then(JsonValue::as_str)
208 .ok_or_else(|| serde::de::Error::missing_field("type"))?;
209
210 match type_ {
211 "mscalar" => {
212 let size_in_bits = map
213 .get("size_in_bits")
214 .and_then(JsonValue::as_u64)
215 .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
216 Ok(Value::MScalar {
217 size_in_bits: size_in_bits as usize,
218 })
219 }
220 "mfloat" => {
221 let size_in_bits = map
222 .get("size_in_bits")
223 .and_then(JsonValue::as_u64)
224 .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
225 Ok(Value::MFloat {
226 size_in_bits: size_in_bits as usize,
227 })
228 }
229 "mbool" => Ok(Value::MBool),
230 "u8" => Ok(Value::Scalar { size_in_bits: 8 }),
231 "u16" => Ok(Value::Scalar { size_in_bits: 16 }),
232 "u32" => Ok(Value::Scalar { size_in_bits: 32 }),
233 "u64" => Ok(Value::Scalar { size_in_bits: 64 }),
234 "u128" => Ok(Value::Scalar { size_in_bits: 128 }),
235 "scalar" => {
236 let size_in_bits = map
237 .get("size_in_bits")
238 .and_then(JsonValue::as_u64)
239 .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
240 Ok(Value::Scalar {
241 size_in_bits: size_in_bits as usize,
242 })
243 }
244 "ciphertext" => {
245 let size_in_bits = map
246 .get("size_in_bits")
247 .and_then(JsonValue::as_u64)
248 .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
249 Ok(Value::Ciphertext {
250 size_in_bits: size_in_bits as usize,
251 })
252 }
253 "public_key" => {
254 let size_in_bits = map
255 .get("size_in_bits")
256 .and_then(JsonValue::as_u64)
257 .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
258 Ok(Value::PublicKey {
259 size_in_bits: size_in_bits as usize,
260 })
261 }
262 "float" => {
263 let size_in_bits = map
264 .get("size_in_bits")
265 .and_then(JsonValue::as_u64)
266 .ok_or_else(|| serde::de::Error::missing_field("size_in_bits"))?;
267 Ok(Value::Float {
268 size_in_bits: size_in_bits as usize,
269 })
270 }
271 "bool" => Ok(Value::Bool),
272 "array" | "tuple" | "struct" => {
273 let content = map
274 .get("content")
275 .ok_or_else(|| serde::de::Error::missing_field("content"))?;
276 let vec: Vec<Value> =
277 serde_json::from_value(content.clone()).map_err(|e| {
278 serde::de::Error::custom(format!(
279 "Failed to deserialize content: {}",
280 e
281 ))
282 })?;
283 match type_ {
284 "array" => Ok(Value::Array(vec)),
285 "tuple" => Ok(Value::Tuple(vec)),
286 "struct" => Ok(Value::Struct(vec)),
287 _ => unreachable!(),
288 }
289 }
290 _ => Err(serde::de::Error::unknown_variant(
291 type_,
292 &[
293 "mscalar",
294 "mfloat",
295 "mbool",
296 "u8",
297 "u16",
298 "u32",
299 "u64",
300 "u128",
301 "scalar",
302 "float",
303 "bool",
304 "array",
305 "tuple",
306 "struct",
307 "ciphertext",
308 "public_key",
309 ],
310 )),
311 }
312 }
313 _ => Err(serde::de::Error::invalid_type(
314 serde::de::Unexpected::Other("non-object"),
315 &"object",
316 )),
317 }
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// `MScalar` is written with the `mscalar` tag and its bit width.
    #[test]
    fn test_mscalar_serialization() {
        let expected = json!({ "type": "mscalar", "size_in_bits": 32 });
        let actual = serde_json::to_value(Value::MScalar { size_in_bits: 32 }).unwrap();
        assert_eq!(actual, expected);
    }

    /// `MBool` is written as a bare `mbool` tag.
    #[test]
    fn test_mbool_serialization() {
        let expected = json!({ "type": "mbool" });
        let actual = serde_json::to_value(Value::MBool).unwrap();
        assert_eq!(actual, expected);
    }

    /// `Bool` is written as a bare `bool` tag.
    #[test]
    fn test_bool_serialization() {
        let expected = json!({ "type": "bool" });
        let actual = serde_json::to_value(Value::Bool).unwrap();
        assert_eq!(actual, expected);
    }

    /// Arrays list their elements under `content`; a 60-bit scalar has no
    /// width-specific tag and therefore uses `scalar`.
    #[test]
    fn test_array_serialization() {
        let elements = vec![Value::Scalar { size_in_bits: 60 }, Value::Bool];
        let actual = serde_json::to_value(Value::Array(elements)).unwrap();
        let expected = json!({
            "type": "array",
            "content": [
                { "type": "scalar", "size_in_bits": 60 },
                { "type": "bool" }
            ]
        });
        assert_eq!(actual, expected);
    }

    /// A struct holding a tuple and an array survives a text round trip.
    #[test]
    fn test_nested_structure_serialization() {
        let tuple = Value::Tuple(vec![Value::MScalar { size_in_bits: 32 }, Value::MBool]);
        let array = Value::Array(vec![Value::Scalar { size_in_bits: 64 }, Value::Bool]);
        let value = Value::Struct(vec![tuple, array]);

        let text = serde_json::to_string(&value).unwrap();
        let round_tripped: Value = serde_json::from_str(&text).unwrap();
        assert_eq!(value, round_tripped);
    }

    /// The `mscalar` tag reads back as `MScalar` with the given width.
    #[test]
    fn test_mscalar_deserialization() {
        let text = r#"{"type": "mscalar", "size_in_bits": 32}"#;
        let parsed: Value = serde_json::from_str(text).unwrap();
        assert_eq!(parsed, Value::MScalar { size_in_bits: 32 });
    }

    /// An `array` object reads back with its elements in order.
    #[test]
    fn test_array_deserialization() {
        let text = r#"
        {
            "type": "array",
            "content": [
                {"type": "scalar", "size_in_bits": 64},
                {"type": "bool"}
            ]
        }"#;
        let parsed: Value = serde_json::from_str(text).unwrap();
        let expected = Value::Array(vec![Value::Scalar { size_in_bits: 64 }, Value::Bool]);
        assert_eq!(parsed, expected);
    }

    /// An unknown `type` tag is rejected.
    #[test]
    fn test_invalid_type_deserialization() {
        let outcome = serde_json::from_str::<Value>(r#"{"type": "invalid_type"}"#);
        assert!(outcome.is_err());
    }

    /// `mscalar` without `size_in_bits` is rejected.
    #[test]
    fn test_missing_size_in_bits_deserialization() {
        let outcome = serde_json::from_str::<Value>(r#"{"type": "mscalar"}"#);
        assert!(outcome.is_err());
    }

    /// Plain scalars use width-specific tags where available and fall back
    /// to `scalar` otherwise; each encoding must also read back unchanged.
    #[test]
    fn test_plaintext_type_serialization() {
        let test_cases = [
            (Value::Scalar { size_in_bits: 8 }, "u8"),
            (Value::Scalar { size_in_bits: 16 }, "u16"),
            (Value::Scalar { size_in_bits: 32 }, "u32"),
            (Value::Scalar { size_in_bits: 64 }, "u64"),
            (Value::Scalar { size_in_bits: 128 }, "u128"),
            (Value::Scalar { size_in_bits: 24 }, "scalar"),
        ];

        for (value, expected_type) in test_cases {
            // Every case in the table is a scalar; anything else is a bug in
            // the table itself.
            let expected = match &value {
                Value::Scalar { size_in_bits } => json!({
                    "type": expected_type,
                    "size_in_bits": size_in_bits
                }),
                _ => unreachable!(),
            };

            let actual = serde_json::to_value(&value).unwrap();
            assert_eq!(actual, expected);

            let text = serde_json::to_string(&expected).unwrap();
            let round_tripped: Value = serde_json::from_str(&text).unwrap();
            assert_eq!(round_tripped, value);
        }
    }
}