1use num_bigint::{BigInt, BigUint};
2use num_traits::{Num, Zero};
3use std::collections::{BTreeMap, HashSet};
4use thiserror::Error;
5
6use acvm::{AcirField, FieldElement};
7use serde::Serialize;
8
9use crate::errors::InputParserError;
10use crate::{Abi, AbiType};
11
12pub mod json;
13mod toml;
14
/// A parsed input value, checked against an expected [`AbiType`] by
/// `InputValue::find_type_mismatch`.
#[derive(Debug, Clone, Serialize, PartialEq)]
pub enum InputValue {
    /// A single field element. Also used to carry integer and boolean values.
    Field(FieldElement),
    /// A string; its byte length is compared against `AbiType::String { length }`.
    String(String),
    /// An ordered sequence of values, used for both arrays and tuples.
    Vec(Vec<InputValue>),
    /// A struct value, mapping each field name to its value.
    Struct(BTreeMap<String, InputValue>),
}
25
/// Describes why an [`InputValue`] does not conform to an expected [`AbiType`].
///
/// Every variant carries a `path` locating the offending value within the overall input.
#[derive(Debug, Error)]
pub enum InputTypecheckingError {
    /// A field element is too wide for the expected integer type, or is neither 0 nor 1 where a
    /// boolean is expected.
    #[error("Value {value:?} does not fall within range of allowable values for a {typ:?}")]
    OutsideOfValidRange { path: String, typ: AbiType, value: InputValue },
    /// An array, tuple or string does not have the length required by its type.
    #[error(
        "Type {typ:?} is expected to have length {expected_length} but value {value:?} has length {actual_length}"
    )]
    LengthMismatch {
        path: String,
        typ: AbiType,
        value: InputValue,
        expected_length: usize,
        actual_length: usize,
    },
    /// A struct value lacks a field that its type declares.
    #[error(
        "Could not find value for required field `{expected_field}`. Found values for fields {found_fields:?}"
    )]
    MissingField { path: String, expected_field: String, found_fields: Vec<String> },
    /// A struct value contains a field that its type does not declare.
    #[error(
        "Additional unexpected field was provided for type {typ:?}. Found field named `{extra_field}`"
    )]
    UnexpectedField { path: String, typ: AbiType, extra_field: String },
    /// The kind of value cannot represent the expected type at all (e.g. a string where a field
    /// element is expected).
    #[error("Type {typ:?} and value {value:?} do not match")]
    IncompatibleTypes { path: String, typ: AbiType, value: InputValue },
}
51
52impl InputTypecheckingError {
53 pub(crate) fn path(&self) -> &str {
54 match self {
55 InputTypecheckingError::OutsideOfValidRange { path, .. }
56 | InputTypecheckingError::LengthMismatch { path, .. }
57 | InputTypecheckingError::MissingField { path, .. }
58 | InputTypecheckingError::UnexpectedField { path, .. }
59 | InputTypecheckingError::IncompatibleTypes { path, .. } => path,
60 }
61 }
62}
63
64impl InputValue {
65 pub(crate) fn find_type_mismatch(
67 &self,
68 abi_param: &AbiType,
69 path: String,
70 ) -> Result<(), InputTypecheckingError> {
71 match (self, abi_param) {
72 (InputValue::Field(_), AbiType::Field) => Ok(()),
73 (InputValue::Field(field_element), AbiType::Integer { width, .. }) => {
74 if field_element.num_bits() <= *width {
75 Ok(())
76 } else {
77 Err(InputTypecheckingError::OutsideOfValidRange {
78 path,
79 typ: abi_param.clone(),
80 value: self.clone(),
81 })
82 }
83 }
84 (InputValue::Field(field_element), AbiType::Boolean) => {
85 if field_element.is_one() || field_element.is_zero() {
86 Ok(())
87 } else {
88 Err(InputTypecheckingError::OutsideOfValidRange {
89 path,
90 typ: abi_param.clone(),
91 value: self.clone(),
92 })
93 }
94 }
95
96 (InputValue::Vec(array_elements), AbiType::Array { length, typ, .. }) => {
97 if array_elements.len() != *length as usize {
98 return Err(InputTypecheckingError::LengthMismatch {
99 path,
100 typ: abi_param.clone(),
101 value: self.clone(),
102 expected_length: *length as usize,
103 actual_length: array_elements.len(),
104 });
105 }
106 for (i, element) in array_elements.iter().enumerate() {
108 let mut path = path.clone();
109 path.push_str(&format!("[{i}]"));
110
111 element.find_type_mismatch(typ, path)?;
112 }
113 Ok(())
114 }
115
116 (InputValue::String(string), AbiType::String { length }) => {
117 if string.len() == *length as usize {
118 Ok(())
119 } else {
120 Err(InputTypecheckingError::LengthMismatch {
121 path,
122 typ: abi_param.clone(),
123 value: self.clone(),
124 actual_length: string.len(),
125 expected_length: *length as usize,
126 })
127 }
128 }
129
130 (InputValue::Struct(map), AbiType::Struct { fields, .. }) => {
131 for (field_name, field_type) in fields {
132 if let Some(value) = map.get(field_name) {
133 let mut path = path.clone();
134 path.push_str(&format!(".{field_name}"));
135 value.find_type_mismatch(field_type, path)?;
136 } else {
137 return Err(InputTypecheckingError::MissingField {
138 path,
139 expected_field: field_name.to_string(),
140 found_fields: map.keys().cloned().collect(),
141 });
142 }
143 }
144
145 if map.len() > fields.len() {
146 let expected_fields: HashSet<String> =
147 fields.iter().map(|(field, _)| field.to_string()).collect();
148 let extra_field = map.keys().find(|&key| !expected_fields.contains(key)).cloned().expect("`map` is larger than the expected type's `fields` so it must contain an unexpected field");
149 return Err(InputTypecheckingError::UnexpectedField {
150 path,
151 typ: abi_param.clone(),
152 extra_field: extra_field.to_string(),
153 });
154 }
155
156 Ok(())
157 }
158
159 (InputValue::Vec(vec_elements), AbiType::Tuple { fields }) => {
160 if vec_elements.len() != fields.len() {
161 return Err(InputTypecheckingError::LengthMismatch {
162 path,
163 typ: abi_param.clone(),
164 value: self.clone(),
165 actual_length: vec_elements.len(),
166 expected_length: fields.len(),
167 });
168 }
169 for (i, (element, expected_typ)) in vec_elements.iter().zip(fields).enumerate() {
171 let mut path = path.clone();
172 path.push_str(&format!(".{i}"));
173 element.find_type_mismatch(expected_typ, path)?;
174 }
175 Ok(())
176 }
177
178 _ => Err(InputTypecheckingError::IncompatibleTypes {
180 path,
181 typ: abi_param.clone(),
182 value: self.clone(),
183 }),
184 }
185 }
186
187 pub fn matches_abi(&self, abi_param: &AbiType) -> bool {
189 self.find_type_mismatch(abi_param, String::new()).is_ok()
190 }
191}
192
/// File formats supported for reading and writing program inputs.
#[derive(Debug, Clone, PartialEq, Eq)]
// `EnumIter` is derived only under `cfg(test)`, where tests iterate over every format.
#[cfg_attr(test, derive(strum_macros::EnumIter))]
pub enum Format {
    /// JSON, with file extension `json`.
    Json,
    /// TOML, with file extension `toml`.
    Toml,
}
201
202impl Format {
203 pub fn ext(&self) -> &'static str {
204 match self {
205 Format::Json => "json",
206 Format::Toml => "toml",
207 }
208 }
209
210 pub fn from_ext(ext: &str) -> Option<Self> {
211 match ext {
212 "json" => Some(Self::Json),
213 "toml" => Some(Self::Toml),
214 _ => None,
215 }
216 }
217}
218
219impl Format {
220 pub fn parse(
221 &self,
222 input_string: &str,
223 abi: &Abi,
224 ) -> Result<BTreeMap<String, InputValue>, InputParserError> {
225 match self {
226 Format::Json => json::parse_json(input_string, abi),
227 Format::Toml => toml::parse_toml(input_string, abi),
228 }
229 }
230
231 pub fn serialize(
232 &self,
233 input_map: &BTreeMap<String, InputValue>,
234 abi: &Abi,
235 ) -> Result<String, InputParserError> {
236 match self {
237 Format::Json => json::serialize_to_json(input_map, abi),
238 Format::Toml => toml::serialize_to_toml(input_map, abi),
239 }
240 }
241}
242
#[cfg(test)]
mod serialization_tests {
    use std::collections::BTreeMap;

    use acvm::{AcirField, FieldElement};
    use strum::IntoEnumIterator;

    use crate::{
        Abi, AbiParameter, AbiReturnType, AbiType, AbiVisibility, MAIN_RETURN_NAME, Sign,
        input_parser::InputValue,
    };

    use super::Format;

    /// Builds a private ABI parameter with the given name and type.
    fn private_param(name: &str, typ: AbiType) -> AbiParameter {
        AbiParameter { name: name.into(), typ, visibility: AbiVisibility::Private }
    }

    /// Serializing an input map and parsing it back must reproduce it exactly, in every format.
    #[test]
    fn serialization_round_trip() {
        let my_struct_type = AbiType::Struct {
            path: "MyStruct".into(),
            fields: vec![
                ("field1".into(), AbiType::Integer { sign: Sign::Unsigned, width: 8 }),
                ("field2".into(), AbiType::Array { length: 2, typ: Box::new(AbiType::Boolean) }),
            ],
        };

        let abi = Abi {
            parameters: vec![
                private_param("foo", AbiType::Field),
                private_param("signed_example", AbiType::Integer { sign: Sign::Signed, width: 8 }),
                private_param("bar", my_struct_type),
            ],
            return_type: Some(AbiReturnType {
                abi_type: AbiType::String { length: 5 },
                visibility: AbiVisibility::Public,
            }),
            error_types: Default::default(),
        };

        let my_struct_value = InputValue::Struct(BTreeMap::from([
            ("field1".into(), InputValue::Field(255u128.into())),
            (
                "field2".into(),
                InputValue::Vec(vec![InputValue::Field(true.into()), InputValue::Field(false.into())]),
            ),
        ]));

        // 240 is the 8-bit two's complement encoding of -16.
        let input_map: BTreeMap<String, InputValue> = BTreeMap::from([
            ("foo".into(), InputValue::Field(FieldElement::one())),
            ("signed_example".into(), InputValue::Field(FieldElement::from(240u128))),
            ("bar".into(), my_struct_value),
            (MAIN_RETURN_NAME.into(), InputValue::String("hello".to_owned())),
        ]);

        Format::iter().for_each(|format| {
            let serialized = format.serialize(&input_map, &abi).unwrap();
            let round_tripped = format.parse(&serialized, &abi).unwrap();
            assert_eq!(input_map, round_tripped);
        });
    }
}
321
322fn parse_str_to_field(value: &str, arg_name: &str) -> Result<FieldElement, InputParserError> {
323 let big_num = if let Some(hex) = value.strip_prefix("0x") {
324 BigUint::from_str_radix(hex, 16)
325 } else {
326 BigUint::from_str_radix(value, 10)
327 };
328 big_num
329 .map_err(|err_msg| InputParserError::ParseStr {
330 arg_name: arg_name.into(),
331 value: value.into(),
332 error: err_msg.to_string(),
333 })
334 .and_then(|bigint| {
335 if bigint < FieldElement::modulus() {
336 Ok(field_from_big_uint(bigint))
337 } else {
338 Err(InputParserError::InputExceedsFieldModulus {
339 arg_name: arg_name.into(),
340 value: value.to_string(),
341 })
342 }
343 })
344}
345
346fn parse_str_to_signed(
347 value: &str,
348 width: u32,
349 arg_name: &str,
350) -> Result<FieldElement, InputParserError> {
351 let big_num = if let Some(hex) = value.strip_prefix("-0x") {
352 BigInt::from_str_radix(hex, 16).map(|value| -value)
353 } else if let Some(hex) = value.strip_prefix("0x") {
354 BigInt::from_str_radix(hex, 16)
355 } else {
356 BigInt::from_str_radix(value, 10)
357 };
358
359 big_num
360 .map_err(|err_msg| InputParserError::ParseStr {
361 arg_name: arg_name.into(),
362 value: value.into(),
363 error: err_msg.to_string(),
364 })
365 .and_then(|bigint| {
366 let min = if width == 128 { i128::MIN } else { -(1 << (width - 1)) };
367 let max = if width == 128 { i128::MAX } else { (1 << (width - 1)) - 1 };
368
369 let max = BigInt::from(max);
370 let min = BigInt::from(min);
371
372 if bigint < min {
373 return Err(InputParserError::InputUnderflowsMinimum {
374 arg_name: arg_name.into(),
375 value: bigint.to_string(),
376 min: min.to_string(),
377 });
378 }
379
380 if bigint > max {
381 return Err(InputParserError::InputOverflowsMaximum {
382 arg_name: arg_name.into(),
383 value: bigint.to_string(),
384 max: max.to_string(),
385 });
386 }
387
388 let modulus: BigInt = FieldElement::modulus().into();
389 let bigint = if bigint.sign() == num_bigint::Sign::Minus {
390 BigInt::from(2).pow(width) + bigint
391 } else {
392 bigint
393 };
394 if bigint.is_zero() || (bigint.sign() == num_bigint::Sign::Plus && bigint < modulus) {
395 Ok(field_from_big_int(bigint))
396 } else {
397 Err(InputParserError::InputExceedsFieldModulus {
398 arg_name: arg_name.into(),
399 value: value.to_string(),
400 })
401 }
402 })
403}
404
405fn parse_integer_to_signed(
406 integer: i128,
407 width: u32,
408 arg_name: &str,
409) -> Result<FieldElement, InputParserError> {
410 let min = if width == 128 { i128::MIN } else { -(1 << (width - 1)) };
411 let max = if width == 128 { i128::MAX } else { (1 << (width - 1)) - 1 };
412
413 if integer < min {
414 return Err(InputParserError::InputUnderflowsMinimum {
415 arg_name: arg_name.into(),
416 value: integer.to_string(),
417 min: min.to_string(),
418 });
419 }
420
421 if integer > max {
422 return Err(InputParserError::InputOverflowsMaximum {
423 arg_name: arg_name.into(),
424 value: integer.to_string(),
425 max: max.to_string(),
426 });
427 }
428
429 let integer = if integer < 0 {
430 FieldElement::from(2u32).pow(&width.into()) + FieldElement::from(integer)
431 } else {
432 FieldElement::from(integer)
433 };
434 Ok(integer)
435}
436
437fn field_from_big_uint(bigint: BigUint) -> FieldElement {
438 FieldElement::from_be_bytes_reduce(&bigint.to_bytes_be())
439}
440
441fn field_from_big_int(bigint: BigInt) -> FieldElement {
442 match bigint.sign() {
443 num_bigint::Sign::Minus => {
444 unreachable!(
445 "Unsupported negative value; it should only be called with a positive value"
446 )
447 }
448 num_bigint::Sign::NoSign => FieldElement::zero(),
449 num_bigint::Sign::Plus => FieldElement::from_be_bytes_reduce(&bigint.to_bytes_be().1),
450 }
451}
452
453fn field_to_signed_hex(f: FieldElement, bit_size: u32) -> String {
454 let f_u128 = f.to_u128();
455 let max = if bit_size == 128 { i128::MAX as u128 } else { (1 << (bit_size - 1)) - 1 };
456 if f_u128 > max {
457 let f = FieldElement::from(2u32).pow(&bit_size.into()) - f;
458 format!("-0x{}", f.to_hex())
459 } else {
460 format!("0x{}", f.to_hex())
461 }
462}
463
#[cfg(test)]
mod test {
    use acvm::{AcirField, FieldElement};
    use num_bigint::BigUint;
    use strum::IntoEnumIterator;

    use super::{Format, parse_str_to_field, parse_str_to_signed};

    /// Reinterprets a field element's big-endian bytes as an unsigned big integer.
    fn big_uint_from_field(field: FieldElement) -> BigUint {
        BigUint::from_bytes_be(&field.to_be_bytes())
    }

    #[test]
    fn parse_empty_str_fails() {
        let result = parse_str_to_field("", "arg_name");
        assert!(result.is_err());
    }

    /// Every field element round-trips through both its hex and decimal string forms.
    #[test]
    fn parse_fields_from_strings() {
        let samples = [
            FieldElement::zero(),
            FieldElement::one(),
            FieldElement::from(u128::MAX) + FieldElement::one(),
            -FieldElement::one(),
        ];

        for expected in samples {
            let as_hex = format!("0x{}", expected.to_hex());
            assert_eq!(parse_str_to_field(&as_hex, "arg_name").unwrap(), expected);

            let as_decimal = big_uint_from_field(expected).to_string();
            assert_eq!(parse_str_to_field(&as_decimal, "arg_name").unwrap(), expected);
        }
    }

    /// The modulus itself is not a canonical field element and must be rejected.
    #[test]
    fn rejects_noncanonical_fields() {
        let modulus_str = FieldElement::modulus().to_string();
        assert!(parse_str_to_field(&modulus_str, "arg_name").is_err());
    }

    #[test]
    fn test_parse_str_to_signed() {
        // (input, width, expected two's complement encoding)
        let accepted: [(&str, u32, u128); 9] = [
            ("1", 8, 1),
            ("-1", 8, 255),
            ("-1", 16, 65535),
            ("127", 8, 127),
            ("-128", 8, 128),
            ("-1", 8, 255),
            ("32767", 16, 32767),
            ("-32768", 16, 32768),
            ("-1", 16, 65535),
        ];
        for (input, width, expected) in accepted {
            assert_eq!(
                parse_str_to_signed(input, width, "arg_name").unwrap(),
                FieldElement::from(expected),
                "parsing {input:?} at width {width}"
            );
        }

        // Just outside the representable range for each width.
        let rejected: [(&str, u32); 4] = [("128", 8), ("-129", 8), ("32768", 16), ("-32769", 16)];
        for (input, width) in rejected {
            assert!(
                parse_str_to_signed(input, width, "arg_name").is_err(),
                "{input:?} should not fit in width {width}"
            );
        }
    }

    #[test]
    fn test_from_ext() {
        Format::iter().for_each(|fmt| assert_eq!(Format::from_ext(fmt.ext()), Some(fmt)));
        assert_eq!(Format::from_ext("invalid extension"), None);
    }
}
556
#[cfg(test)]
mod arbitrary {
    use proptest::prelude::*;

    use crate::{AbiType, Sign};

    /// Strategy yielding a signed integer ABI type of width `2..=64` together with a value that
    /// fits in it.
    ///
    /// Values are drawn from the full inclusive range `[-2^(width-1), 2^(width-1) - 1]`, so the
    /// extreme values at both ends can be generated.
    pub(super) fn arb_signed_integer_type_and_value() -> BoxedStrategy<(AbiType, i64)> {
        (2u32..=64)
            .prop_flat_map(|width| {
                let typ = Just(AbiType::Integer { width, sign: Sign::Signed });
                // Arithmetic right shifts of i64's extremes give exactly the bounds of a
                // `width`-bit signed integer. This avoids overflowing `2i64.pow(63)` at
                // width 64, and the range is inclusive so the maximum is not excluded.
                let shift = 64 - width;
                let value = (i64::MIN >> shift)..=(i64::MAX >> shift);
                (typ, value)
            })
            .boxed()
    }
}