1use ark_ff::{BigInteger, PrimeField};
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3use std::collections::BTreeMap;
4use std::fmt;
5
6use crate::hash::poseidon2_hash_bytes;
7use crate::types::Fr;
8use crate::Error;
9
/// Removes a single leading `0x` or `0X` prefix from a hex string, if present.
///
/// Returns the input unchanged when it carries no prefix. Both letter cases are
/// accepted because hex digits themselves are case-insensitive, so `"0XAABBCCDD"`
/// is as valid a spelling as `"0xaabbccdd"`.
fn strip_0x(s: &str) -> &str {
    s.strip_prefix("0x")
        .or_else(|| s.strip_prefix("0X"))
        .unwrap_or(s)
}
13
14fn decode_selector_hex(s: &str) -> Result<[u8; 4], Error> {
15 let raw = strip_0x(s);
16 if raw.len() > 8 {
17 return Err(Error::InvalidData(
18 "function selector must fit in 4 bytes".to_owned(),
19 ));
20 }
21 let padded = format!("{raw:0>8}");
22 let bytes = hex::decode(padded).map_err(|e| Error::InvalidData(e.to_string()))?;
23 let mut out = [0u8; 4];
24 out.copy_from_slice(&bytes);
25 Ok(out)
26}
27
28fn field_to_selector_bytes(field: Fr) -> [u8; 4] {
29 let raw = field.0.into_bigint().to_bytes_be();
30 let mut padded = [0u8; 32];
31 padded[32 - raw.len()..].copy_from_slice(&raw);
32 let mut out = [0u8; 4];
33 out.copy_from_slice(&padded[28..]);
34 out
35}
36
37fn selector_bytes_to_field(bytes: [u8; 4]) -> Fr {
38 Fr::from(u64::from(u32::from_be_bytes(bytes)))
39}
40
41fn selector_from_signature(signature: &str) -> [u8; 4] {
42 let hash = poseidon2_hash_bytes(signature.as_bytes());
43 field_to_selector_bytes(hash)
44}
45
46#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
48pub struct FunctionSelector(pub [u8; 4]);
49
50impl FunctionSelector {
51 pub fn from_hex(value: &str) -> Result<Self, Error> {
53 Ok(Self(decode_selector_hex(value)?))
54 }
55
56 pub fn from_field(field: Fr) -> Self {
58 Self(field_to_selector_bytes(field))
59 }
60
61 pub fn from_signature(signature: &str) -> Self {
72 Self(selector_from_signature(signature))
73 }
74
75 pub fn from_name(_name: &str) -> Result<Self, Error> {
79 Err(Error::Abi(
80 "function selector derivation is not implemented yet".to_owned(),
81 ))
82 }
83
84 pub fn to_field(self) -> Fr {
86 selector_bytes_to_field(self.0)
87 }
88}
89
90impl fmt::Display for FunctionSelector {
91 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92 write!(f, "0x{}", hex::encode(self.0))
93 }
94}
95
96impl Serialize for FunctionSelector {
97 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
98 where
99 S: Serializer,
100 {
101 serializer.serialize_str(&self.to_string())
102 }
103}
104
105impl<'de> Deserialize<'de> for FunctionSelector {
106 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
107 where
108 D: Deserializer<'de>,
109 {
110 let s = String::deserialize(deserializer)?;
111 Self::from_hex(&s).map_err(serde::de::Error::custom)
112 }
113}
114
115#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
117pub struct AuthorizationSelector(pub [u8; 4]);
118
119impl AuthorizationSelector {
120 pub fn from_hex(value: &str) -> Result<Self, Error> {
122 Ok(Self(decode_selector_hex(value)?))
123 }
124
125 pub fn from_field(field: Fr) -> Self {
127 Self(field_to_selector_bytes(field))
128 }
129
130 pub fn from_signature(signature: &str) -> Self {
132 Self(selector_from_signature(signature))
133 }
134
135 pub fn to_field(self) -> Fr {
137 selector_bytes_to_field(self.0)
138 }
139}
140
141impl fmt::Display for AuthorizationSelector {
142 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143 write!(f, "0x{}", hex::encode(self.0))
144 }
145}
146
147impl Serialize for AuthorizationSelector {
148 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
149 where
150 S: Serializer,
151 {
152 serializer.serialize_str(&self.to_string())
153 }
154}
155
156impl<'de> Deserialize<'de> for AuthorizationSelector {
157 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
158 where
159 D: Deserializer<'de>,
160 {
161 let s = String::deserialize(deserializer)?;
162 Self::from_hex(&s).map_err(serde::de::Error::custom)
163 }
164}
165
/// An event selector, stored as a full field element rather than a 4-byte value.
///
/// Its serde representation is whatever `Fr`'s own `Serialize`/`Deserialize` impls produce.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct EventSelector(pub Fr);
169
/// The kind of a contract function as recorded in an artifact.
///
/// Serialized as a lowercase string: `"private"`, `"public"` or `"utility"`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FunctionType {
    /// Serialized as `"private"`.
    Private,
    /// Serialized as `"public"`.
    Public,
    /// Serialized as `"utility"`.
    Utility,
}
181
/// The type of an ABI parameter or return value.
///
/// Internally tagged for serde: each variant is an object whose `"kind"` field holds the
/// snake_case variant name, alongside the variant's own fields. For example,
/// `{"kind": "field"}` or `{"kind": "integer", "sign": "unsigned", "width": 64}`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum AbiType {
    /// A field element.
    Field,
    /// A boolean.
    Boolean,
    /// A fixed-width integer.
    Integer {
        /// Signedness as a free-form string (artifacts seen here use `"unsigned"`).
        /// NOTE(review): the set of accepted values is not validated here.
        sign: String,
        /// Bit width of the integer.
        width: u16,
    },
    /// A fixed-length array of a single element type.
    Array {
        /// Element type; boxed because the type is recursive.
        element: Box<Self>,
        /// Number of elements.
        length: usize,
    },
    /// A fixed-length string.
    String {
        /// Declared string length (presumably in bytes — confirm against the artifact producer).
        length: usize,
    },
    /// A named struct with ordered, named fields.
    Struct {
        /// Struct name.
        name: String,
        /// Fields, each with its own name and type.
        fields: Vec<AbiParameter>,
    },
    /// An anonymous tuple of element types.
    Tuple {
        /// Element types, in order.
        elements: Vec<Self>,
    },
}
222
/// A concrete value of some [`AbiType`].
///
/// Adjacently tagged for serde: serialized as `{"kind": "<variant>", "value": <payload>}`
/// with snake_case variant names.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
pub enum AbiValue {
    /// A field element.
    Field(Fr),
    /// A boolean.
    Boolean(bool),
    /// An integer value.
    /// NOTE(review): stored as `i128`, so unsigned 128-bit values above `i128::MAX`
    /// cannot be represented.
    Integer(i128),
    /// Array elements, in order.
    Array(Vec<Self>),
    /// A string value.
    String(String),
    /// Struct fields keyed by name; `BTreeMap` keeps keys sorted, so serialization
    /// order is deterministic.
    Struct(BTreeMap<String, Self>),
    /// Tuple elements, in order.
    Tuple(Vec<Self>),
}
242
/// A named, typed parameter. Also used for the fields of [`AbiType::Struct`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AbiParameter {
    /// Parameter name.
    pub name: String,
    /// Parameter type. Serialized under the JSON key `"type"`, since `type` is a Rust keyword.
    #[serde(rename = "type")]
    pub typ: AbiType,
    /// Optional visibility string; `None` when absent from the input.
    #[serde(default)]
    pub visibility: Option<String>,
}
255
/// Description of one contract function within a [`ContractArtifact`].
///
/// Every field except `name` and `function_type` is optional in the input and falls back
/// to its `Default` (false, empty, or `None`).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct FunctionArtifact {
    /// Function name.
    pub name: String,
    /// Whether the function is private, public or utility.
    pub function_type: FunctionType,
    /// Whether the function is marked as an initializer; defaults to `false`.
    #[serde(default)]
    pub is_initializer: bool,
    /// Whether the function is marked static; defaults to `false`.
    #[serde(default)]
    pub is_static: bool,
    /// Parameters in declaration order; defaults to empty.
    #[serde(default)]
    pub parameters: Vec<AbiParameter>,
    /// Return types; defaults to empty.
    #[serde(default)]
    pub return_types: Vec<AbiType>,
    /// Selector if the artifact provides one; defaults to `None`.
    #[serde(default)]
    pub selector: Option<FunctionSelector>,
}
279
/// A contract artifact: the contract name plus its function descriptions.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ContractArtifact {
    /// Contract name.
    pub name: String,
    /// Functions in artifact order; defaults to empty when absent from the input.
    #[serde(default)]
    pub functions: Vec<FunctionArtifact>,
}
289
290impl ContractArtifact {
291 pub fn from_json(json: &str) -> Result<Self, Error> {
293 serde_json::from_str(json).map_err(Error::from)
294 }
295
296 pub fn find_function(&self, name: &str) -> Result<&FunctionArtifact, Error> {
298 self.functions
299 .iter()
300 .find(|f| f.name == name)
301 .ok_or_else(|| {
302 Error::Abi(format!(
303 "function '{}' not found in artifact '{}'",
304 name, self.name
305 ))
306 })
307 }
308
309 pub fn find_function_by_type(
311 &self,
312 name: &str,
313 function_type: &FunctionType,
314 ) -> Result<&FunctionArtifact, Error> {
315 self.functions
316 .iter()
317 .find(|f| f.name == name && &f.function_type == function_type)
318 .ok_or_else(|| {
319 Error::Abi(format!(
320 "{:?} function '{}' not found in artifact '{}'",
321 function_type, name, self.name
322 ))
323 })
324 }
325}
326
#[cfg(test)]
// `expect` and `panic` are acceptable in tests: a failure should abort the test loudly.
#[allow(clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    // Smallest useful artifact: one public function taking a single `field` parameter.
    const MINIMAL_ARTIFACT: &str = r#"
    {
      "name": "TestContract",
      "functions": [
        {
          "name": "increment",
          "function_type": "public",
          "is_initializer": false,
          "is_static": false,
          "parameters": [
            { "name": "value", "type": { "kind": "field" } }
          ],
          "return_types": []
        }
      ]
    }
    "#;

    // Covers every `FunctionType`, an initializer, static functions, and the
    // `string` and `integer` parameter kinds, as well as return types.
    const MULTI_FUNCTION_ARTIFACT: &str = r#"
    {
      "name": "TokenContract",
      "functions": [
        {
          "name": "constructor",
          "function_type": "private",
          "is_initializer": true,
          "is_static": false,
          "parameters": [
            { "name": "admin", "type": { "kind": "field" } },
            { "name": "name", "type": { "kind": "string", "length": 31 } }
          ],
          "return_types": []
        },
        {
          "name": "transfer",
          "function_type": "private",
          "is_initializer": false,
          "is_static": false,
          "parameters": [
            { "name": "from", "type": { "kind": "field" } },
            { "name": "to", "type": { "kind": "field" } },
            { "name": "amount", "type": { "kind": "integer", "sign": "unsigned", "width": 64 } }
          ],
          "return_types": []
        },
        {
          "name": "balance_of",
          "function_type": "utility",
          "is_initializer": false,
          "is_static": true,
          "parameters": [
            { "name": "owner", "type": { "kind": "field" } }
          ],
          "return_types": [
            { "kind": "integer", "sign": "unsigned", "width": 64 }
          ]
        },
        {
          "name": "total_supply",
          "function_type": "public",
          "is_initializer": false,
          "is_static": true,
          "parameters": [],
          "return_types": [
            { "kind": "integer", "sign": "unsigned", "width": 64 }
          ]
        }
      ]
    }
    "#;

    // Pins the lowercase wire names produced by `#[serde(rename_all = "lowercase")]`.
    #[test]
    fn function_type_roundtrip() {
        for (ft, expected) in [
            (FunctionType::Private, "\"private\""),
            (FunctionType::Public, "\"public\""),
            (FunctionType::Utility, "\"utility\""),
        ] {
            let json = serde_json::to_string(&ft).expect("serialize FunctionType");
            assert_eq!(json, expected);
            let decoded: FunctionType =
                serde_json::from_str(&json).expect("deserialize FunctionType");
            assert_eq!(decoded, ft);
        }
    }

    // Hex parsing, `Display` formatting and the serde string form must agree.
    #[test]
    fn function_selector_hex_roundtrip() {
        let selector = FunctionSelector::from_hex("0xaabbccdd").expect("valid hex");
        assert_eq!(selector.0, [0xaa, 0xbb, 0xcc, 0xdd]);
        assert_eq!(selector.to_string(), "0xaabbccdd");

        let json = serde_json::to_string(&selector).expect("serialize selector");
        let decoded: FunctionSelector = serde_json::from_str(&json).expect("deserialize selector");
        assert_eq!(decoded, selector);
    }

    #[test]
    fn authorization_selector_hex_roundtrip() {
        let selector = AuthorizationSelector::from_hex("0x01020304").expect("valid hex");
        assert_eq!(selector.0, [0x01, 0x02, 0x03, 0x04]);
        assert_eq!(selector.to_string(), "0x01020304");

        let json = serde_json::to_string(&selector).expect("serialize selector");
        let decoded: AuthorizationSelector =
            serde_json::from_str(&json).expect("deserialize selector");
        assert_eq!(decoded, selector);
    }

    // Five bytes of hex must not be silently truncated to four.
    #[test]
    fn function_selector_rejects_too_long() {
        let result = FunctionSelector::from_hex("0xaabbccddee");
        assert!(result.is_err());
    }

    #[test]
    fn event_selector_roundtrip() {
        let selector = EventSelector(Fr::from(42u64));
        let json = serde_json::to_string(&selector).expect("serialize EventSelector");
        let decoded: EventSelector =
            serde_json::from_str(&json).expect("deserialize EventSelector");
        assert_eq!(decoded, selector);
    }

    #[test]
    fn load_minimal_artifact() {
        let artifact = ContractArtifact::from_json(MINIMAL_ARTIFACT).expect("parse artifact");
        assert_eq!(artifact.name, "TestContract");
        assert_eq!(artifact.functions.len(), 1);
        assert_eq!(artifact.functions[0].name, "increment");
        assert_eq!(artifact.functions[0].function_type, FunctionType::Public);
        assert!(!artifact.functions[0].is_initializer);
        assert_eq!(artifact.functions[0].parameters.len(), 1);
        assert_eq!(artifact.functions[0].parameters[0].name, "value");
    }

    // Functions must be kept in artifact order with their flags intact.
    #[test]
    fn load_multi_function_artifact() {
        let artifact =
            ContractArtifact::from_json(MULTI_FUNCTION_ARTIFACT).expect("parse artifact");
        assert_eq!(artifact.name, "TokenContract");
        assert_eq!(artifact.functions.len(), 4);

        let constructor = &artifact.functions[0];
        assert_eq!(constructor.name, "constructor");
        assert_eq!(constructor.function_type, FunctionType::Private);
        assert!(constructor.is_initializer);
        assert_eq!(constructor.parameters.len(), 2);

        let transfer = &artifact.functions[1];
        assert_eq!(transfer.name, "transfer");
        assert_eq!(transfer.function_type, FunctionType::Private);
        assert!(!transfer.is_static);

        let balance = &artifact.functions[2];
        assert_eq!(balance.name, "balance_of");
        assert_eq!(balance.function_type, FunctionType::Utility);
        assert!(balance.is_static);
        assert_eq!(balance.return_types.len(), 1);

        let supply = &artifact.functions[3];
        assert_eq!(supply.name, "total_supply");
        assert_eq!(supply.function_type, FunctionType::Public);
        assert!(supply.is_static);
    }

    #[test]
    fn find_function_by_name() {
        let artifact =
            ContractArtifact::from_json(MULTI_FUNCTION_ARTIFACT).expect("parse artifact");

        let transfer = artifact.find_function("transfer").expect("find transfer");
        assert_eq!(transfer.name, "transfer");
        assert_eq!(transfer.function_type, FunctionType::Private);
    }

    #[test]
    fn find_function_not_found() {
        let artifact =
            ContractArtifact::from_json(MULTI_FUNCTION_ARTIFACT).expect("parse artifact");

        let result = artifact.find_function("nonexistent");
        assert!(result.is_err());
    }

    // A name match alone is not enough; the function type must match too.
    #[test]
    fn find_function_by_type() {
        let artifact =
            ContractArtifact::from_json(MULTI_FUNCTION_ARTIFACT).expect("parse artifact");

        let balance = artifact
            .find_function_by_type("balance_of", &FunctionType::Utility)
            .expect("find balance_of as utility");
        assert_eq!(balance.name, "balance_of");

        let wrong_type = artifact.find_function_by_type("balance_of", &FunctionType::Public);
        assert!(wrong_type.is_err());
    }

    // The adjacently tagged form carries the snake_case variant name in `kind`.
    #[test]
    fn abi_value_field_roundtrip() {
        let value = AbiValue::Field(Fr::from(1u64));
        let json = serde_json::to_string(&value).expect("serialize AbiValue::Field");
        assert!(json.contains("field"));
        let decoded: AbiValue = serde_json::from_str(&json).expect("deserialize AbiValue");
        assert_eq!(decoded, value);
    }

    #[test]
    fn abi_value_boolean_roundtrip() {
        let value = AbiValue::Boolean(true);
        let json = serde_json::to_string(&value).expect("serialize");
        let decoded: AbiValue = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded, value);
    }

    #[test]
    fn abi_value_integer_roundtrip() {
        let value = AbiValue::Integer(42);
        let json = serde_json::to_string(&value).expect("serialize");
        let decoded: AbiValue = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded, value);
    }

    #[test]
    fn abi_value_array_roundtrip() {
        let value = AbiValue::Array(vec![
            AbiValue::Field(Fr::from(1u64)),
            AbiValue::Field(Fr::from(2u64)),
        ]);
        let json = serde_json::to_string(&value).expect("serialize");
        let decoded: AbiValue = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded, value);
    }

    // Struct fields of mixed variants must survive the round trip.
    #[test]
    fn abi_value_struct_roundtrip() {
        let mut fields = BTreeMap::new();
        fields.insert("x".to_owned(), AbiValue::Field(Fr::from(1u64)));
        fields.insert("y".to_owned(), AbiValue::Integer(2));
        let value = AbiValue::Struct(fields);
        let json = serde_json::to_string(&value).expect("serialize");
        let decoded: AbiValue = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded, value);
    }

    // Exercises the `"type"` rename on `AbiParameter` nested inside an internally tagged enum.
    #[test]
    fn abi_type_struct_roundtrip() {
        let typ = AbiType::Struct {
            name: "Point".to_owned(),
            fields: vec![
                AbiParameter {
                    name: "x".to_owned(),
                    typ: AbiType::Field,
                    visibility: None,
                },
                AbiParameter {
                    name: "y".to_owned(),
                    typ: AbiType::Field,
                    visibility: None,
                },
            ],
        };
        let json = serde_json::to_string(&typ).expect("serialize AbiType::Struct");
        let decoded: AbiType = serde_json::from_str(&json).expect("deserialize AbiType::Struct");
        assert_eq!(decoded, typ);
    }

    #[test]
    fn abi_type_array_roundtrip() {
        let typ = AbiType::Array {
            element: Box::new(AbiType::Field),
            length: 10,
        };
        let json = serde_json::to_string(&typ).expect("serialize");
        let decoded: AbiType = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded, typ);
    }

    #[test]
    fn artifact_from_invalid_json_fails() {
        let result = ContractArtifact::from_json("not json");
        assert!(result.is_err());
    }

    #[test]
    fn from_signature_is_deterministic() {
        let a = FunctionSelector::from_signature("sponsor_unconditionally()");
        let b = FunctionSelector::from_signature("sponsor_unconditionally()");
        assert_eq!(a, b);
    }

    #[test]
    fn from_signature_different_inputs_differ() {
        let a = FunctionSelector::from_signature("sponsor_unconditionally()");
        let b = FunctionSelector::from_signature("claim_and_end_setup((Field),u128,Field,Field)");
        assert_ne!(a, b);
    }

    // The empty signature must hash without panicking and stay deterministic.
    #[test]
    fn from_signature_empty_string() {
        let a = FunctionSelector::from_signature("");
        let b = FunctionSelector::from_signature("");
        assert_eq!(a, b);
    }

    // NOTE(review): `selector.0` is a `[u8; 4]`, so this length check always holds by type;
    // consider pinning a known selector value for this signature instead.
    #[test]
    fn from_signature_produces_4_bytes() {
        let selector = FunctionSelector::from_signature("transfer(Field,Field,u64)");
        assert_eq!(selector.0.len(), 4);
    }

    // `to_field` followed by `from_field` must be lossless for any 4-byte selector.
    #[test]
    fn function_selector_roundtrips_through_field() {
        let selector = FunctionSelector::from_signature("set_authorized(Field,bool)");
        assert_eq!(FunctionSelector::from_field(selector.to_field()), selector);
    }

    #[test]
    fn authorization_selector_roundtrips_through_field() {
        let selector =
            AuthorizationSelector::from_signature("CallAuthorization((Field),(u32),Field)");
        assert_eq!(
            AuthorizationSelector::from_field(selector.to_field()),
            selector
        );
    }
}
660}