1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2use std::collections::BTreeMap;
3use std::fmt;
4
5use crate::types::Fr;
6use crate::Error;
7
/// Removes a single leading `0x` (or `0X`) hex prefix from `s`, if present.
///
/// Only one prefix is stripped, so `"0x0xab"` becomes `"0xab"`. Strings without
/// a prefix are returned unchanged.
fn strip_0x(s: &str) -> &str {
    s.strip_prefix("0x")
        .or_else(|| s.strip_prefix("0X"))
        .unwrap_or(s)
}
11
12fn decode_selector_hex(s: &str) -> Result<[u8; 4], Error> {
13 let raw = strip_0x(s);
14 if raw.len() > 8 {
15 return Err(Error::InvalidData(
16 "function selector must fit in 4 bytes".to_owned(),
17 ));
18 }
19 let padded = format!("{raw:0>8}");
20 let bytes = hex::decode(padded).map_err(|e| Error::InvalidData(e.to_string()))?;
21 let mut out = [0u8; 4];
22 out.copy_from_slice(&bytes);
23 Ok(out)
24}
25
/// A 4-byte function selector.
///
/// The bytes are public. Through this type's hand-written `Serialize` /
/// `Deserialize` impls it is represented as a `0x`-prefixed lowercase hex
/// string, e.g. `"0xaabbccdd"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct FunctionSelector(pub [u8; 4]);
29
30impl FunctionSelector {
31 pub fn from_hex(value: &str) -> Result<Self, Error> {
33 Ok(Self(decode_selector_hex(value)?))
34 }
35
36 pub fn from_name(_name: &str) -> Result<Self, Error> {
40 Err(Error::Abi(
41 "function selector derivation is not implemented yet".to_owned(),
42 ))
43 }
44}
45
46impl fmt::Display for FunctionSelector {
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 write!(f, "0x{}", hex::encode(self.0))
49 }
50}
51
52impl Serialize for FunctionSelector {
53 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
54 where
55 S: Serializer,
56 {
57 serializer.serialize_str(&self.to_string())
58 }
59}
60
61impl<'de> Deserialize<'de> for FunctionSelector {
62 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
63 where
64 D: Deserializer<'de>,
65 {
66 let s = String::deserialize(deserializer)?;
67 Self::from_hex(&s).map_err(serde::de::Error::custom)
68 }
69}
70
/// An event selector, wrapping a single [`Fr`] value.
///
/// Serialization comes from serde's derived impls for a newtype struct around
/// the inner `Fr`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct EventSelector(pub Fr);
74
/// The kind of a contract function, as stored in a function artifact's
/// `function_type` field.
///
/// Serialized as a lowercase string: `"private"`, `"public"` or `"utility"`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FunctionType {
    /// Serialized as `"private"`.
    Private,
    /// Serialized as `"public"`.
    Public,
    /// Serialized as `"utility"`.
    Utility,
}
86
/// The type of an ABI parameter or return value.
///
/// Internally tagged by a `"kind"` key carrying the snake_case variant name,
/// e.g. `{ "kind": "field" }` or
/// `{ "kind": "integer", "sign": "unsigned", "width": 64 }`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum AbiType {
    /// A field element (the value form is `AbiValue::Field(Fr)`).
    Field,
    /// A boolean.
    Boolean,
    /// An integer type.
    Integer {
        /// Signedness as a free-form string (the test fixtures use `"unsigned"`).
        sign: String,
        /// Integer width — presumably in bits (fixtures use 64); TODO confirm.
        width: u16,
    },
    /// A fixed-length array of a single element type.
    Array {
        /// Type of every element.
        element: Box<Self>,
        /// Number of elements.
        length: usize,
    },
    /// A string with a declared length.
    String {
        /// Declared length (units not established here — presumably bytes).
        length: usize,
    },
    /// A named struct made of named, typed fields.
    Struct {
        /// Struct name.
        name: String,
        /// Fields in declaration order.
        fields: Vec<AbiParameter>,
    },
    /// A tuple of positional element types.
    Tuple {
        /// Element types in order.
        elements: Vec<Self>,
    },
}
127
/// A concrete ABI value; its variants mirror those of [`AbiType`].
///
/// Adjacently tagged: serialized as `{ "kind": <snake_case variant>, "value": <payload> }`,
/// e.g. `{"kind":"boolean","value":true}`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
pub enum AbiValue {
    /// A field element.
    Field(Fr),
    /// A boolean.
    Boolean(bool),
    /// An integer.
    ///
    /// NOTE(review): `i128` cannot represent unsigned values above `i128::MAX`,
    /// although `AbiType::Integer` allows arbitrary widths.
    Integer(i128),
    /// Array elements in order.
    Array(Vec<Self>),
    /// A string.
    String(String),
    /// Struct fields keyed by field name (a `BTreeMap`, so iterated in key order).
    Struct(BTreeMap<String, Self>),
    /// Tuple elements in order.
    Tuple(Vec<Self>),
}
147
/// A named, typed parameter (also used for struct fields in [`AbiType::Struct`]).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AbiParameter {
    /// Parameter name.
    pub name: String,
    /// Parameter type; stored under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub typ: AbiType,
    /// Optional free-form visibility string; `None` when absent from the input.
    #[serde(default)]
    pub visibility: Option<String>,
}
160
/// Description of a single contract function within a [`ContractArtifact`].
///
/// Only `name` and `function_type` are required when deserializing; every
/// other field falls back to its `Default` value when missing.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct FunctionArtifact {
    /// Function name, used by the artifact lookup helpers.
    pub name: String,
    /// Kind of the function (private, public or utility).
    pub function_type: FunctionType,
    /// Whether the function is an initializer; `false` when omitted.
    #[serde(default)]
    pub is_initializer: bool,
    /// Whether the function is static; `false` when omitted.
    #[serde(default)]
    pub is_static: bool,
    /// Parameters in declaration order; empty when omitted.
    #[serde(default)]
    pub parameters: Vec<AbiParameter>,
    /// Return types in order; empty when omitted.
    #[serde(default)]
    pub return_types: Vec<AbiType>,
    /// Optional function selector (hex string in JSON); `None` when omitted.
    #[serde(default)]
    pub selector: Option<FunctionSelector>,
}
184
/// A contract artifact: the contract name plus its function descriptions.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ContractArtifact {
    /// Contract name (included in lookup error messages).
    pub name: String,
    /// Functions in the order they appear in the input; empty when omitted.
    #[serde(default)]
    pub functions: Vec<FunctionArtifact>,
}
194
195impl ContractArtifact {
196 pub fn from_json(json: &str) -> Result<Self, Error> {
198 serde_json::from_str(json).map_err(Error::from)
199 }
200
201 pub fn find_function(&self, name: &str) -> Result<&FunctionArtifact, Error> {
203 self.functions
204 .iter()
205 .find(|f| f.name == name)
206 .ok_or_else(|| {
207 Error::Abi(format!(
208 "function '{}' not found in artifact '{}'",
209 name, self.name
210 ))
211 })
212 }
213
214 pub fn find_function_by_type(
216 &self,
217 name: &str,
218 function_type: &FunctionType,
219 ) -> Result<&FunctionArtifact, Error> {
220 self.functions
221 .iter()
222 .find(|f| f.name == name && &f.function_type == function_type)
223 .ok_or_else(|| {
224 Error::Abi(format!(
225 "{:?} function '{}' not found in artifact '{}'",
226 function_type, name, self.name
227 ))
228 })
229 }
230}
231
#[cfg(test)]
#[allow(clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    // A single public function taking one field parameter.
    const MINIMAL_ARTIFACT: &str = r#"
    {
        "name": "TestContract",
        "functions": [
            {
                "name": "increment",
                "function_type": "public",
                "is_initializer": false,
                "is_static": false,
                "parameters": [
                    { "name": "value", "type": { "kind": "field" } }
                ],
                "return_types": []
            }
        ]
    }
    "#;

    // Covers every `FunctionType`, an initializer, and static functions.
    const MULTI_FUNCTION_ARTIFACT: &str = r#"
    {
        "name": "TokenContract",
        "functions": [
            {
                "name": "constructor",
                "function_type": "private",
                "is_initializer": true,
                "is_static": false,
                "parameters": [
                    { "name": "admin", "type": { "kind": "field" } },
                    { "name": "name", "type": { "kind": "string", "length": 31 } }
                ],
                "return_types": []
            },
            {
                "name": "transfer",
                "function_type": "private",
                "is_initializer": false,
                "is_static": false,
                "parameters": [
                    { "name": "from", "type": { "kind": "field" } },
                    { "name": "to", "type": { "kind": "field" } },
                    { "name": "amount", "type": { "kind": "integer", "sign": "unsigned", "width": 64 } }
                ],
                "return_types": []
            },
            {
                "name": "balance_of",
                "function_type": "utility",
                "is_initializer": false,
                "is_static": true,
                "parameters": [
                    { "name": "owner", "type": { "kind": "field" } }
                ],
                "return_types": [
                    { "kind": "integer", "sign": "unsigned", "width": 64 }
                ]
            },
            {
                "name": "total_supply",
                "function_type": "public",
                "is_initializer": false,
                "is_static": true,
                "parameters": [],
                "return_types": [
                    { "kind": "integer", "sign": "unsigned", "width": 64 }
                ]
            }
        ]
    }
    "#;

    #[test]
    fn function_type_roundtrip() {
        for (ft, expected) in [
            (FunctionType::Private, "\"private\""),
            (FunctionType::Public, "\"public\""),
            (FunctionType::Utility, "\"utility\""),
        ] {
            let json = serde_json::to_string(&ft).expect("serialize FunctionType");
            assert_eq!(json, expected);
            let decoded: FunctionType =
                serde_json::from_str(&json).expect("deserialize FunctionType");
            assert_eq!(decoded, ft);
        }
    }

    #[test]
    fn function_selector_hex_roundtrip() {
        let selector = FunctionSelector::from_hex("0xaabbccdd").expect("valid hex");
        assert_eq!(selector.0, [0xaa, 0xbb, 0xcc, 0xdd]);
        assert_eq!(selector.to_string(), "0xaabbccdd");

        let json = serde_json::to_string(&selector).expect("serialize selector");
        let decoded: FunctionSelector = serde_json::from_str(&json).expect("deserialize selector");
        assert_eq!(decoded, selector);
    }

    #[test]
    fn function_selector_rejects_too_long() {
        let result = FunctionSelector::from_hex("0xaabbccddee");
        assert!(result.is_err());
    }

    #[test]
    fn function_selector_left_pads_short_hex() {
        let short = FunctionSelector::from_hex("0x1").expect("short hex is left-padded");
        assert_eq!(short.0, [0x00, 0x00, 0x00, 0x01]);

        let unprefixed = FunctionSelector::from_hex("abc").expect("0x prefix is optional");
        assert_eq!(unprefixed.0, [0x00, 0x00, 0x0a, 0xbc]);
    }

    #[test]
    fn function_selector_rejects_non_hex() {
        assert!(FunctionSelector::from_hex("0xzz").is_err());
        assert!(FunctionSelector::from_hex("0xaabbccdg").is_err());
        assert!(FunctionSelector::from_hex("0x0xab").is_err());
    }

    #[test]
    fn function_selector_displays_lowercase() {
        let selector = FunctionSelector([0xAB, 0xCD, 0xEF, 0x01]);
        assert_eq!(selector.to_string(), "0xabcdef01");
    }

    #[test]
    fn event_selector_roundtrip() {
        let selector = EventSelector(Fr::from(42u64));
        let json = serde_json::to_string(&selector).expect("serialize EventSelector");
        let decoded: EventSelector =
            serde_json::from_str(&json).expect("deserialize EventSelector");
        assert_eq!(decoded, selector);
    }

    #[test]
    fn load_minimal_artifact() {
        let artifact = ContractArtifact::from_json(MINIMAL_ARTIFACT).expect("parse artifact");
        assert_eq!(artifact.name, "TestContract");
        assert_eq!(artifact.functions.len(), 1);
        assert_eq!(artifact.functions[0].name, "increment");
        assert_eq!(artifact.functions[0].function_type, FunctionType::Public);
        assert!(!artifact.functions[0].is_initializer);
        assert_eq!(artifact.functions[0].parameters.len(), 1);
        assert_eq!(artifact.functions[0].parameters[0].name, "value");
    }

    #[test]
    fn load_multi_function_artifact() {
        let artifact =
            ContractArtifact::from_json(MULTI_FUNCTION_ARTIFACT).expect("parse artifact");
        assert_eq!(artifact.name, "TokenContract");
        assert_eq!(artifact.functions.len(), 4);

        let constructor = &artifact.functions[0];
        assert_eq!(constructor.name, "constructor");
        assert_eq!(constructor.function_type, FunctionType::Private);
        assert!(constructor.is_initializer);
        assert_eq!(constructor.parameters.len(), 2);

        let transfer = &artifact.functions[1];
        assert_eq!(transfer.name, "transfer");
        assert_eq!(transfer.function_type, FunctionType::Private);
        assert!(!transfer.is_static);

        let balance = &artifact.functions[2];
        assert_eq!(balance.name, "balance_of");
        assert_eq!(balance.function_type, FunctionType::Utility);
        assert!(balance.is_static);
        assert_eq!(balance.return_types.len(), 1);

        let supply = &artifact.functions[3];
        assert_eq!(supply.name, "total_supply");
        assert_eq!(supply.function_type, FunctionType::Public);
        assert!(supply.is_static);
    }

    #[test]
    fn function_artifact_optional_fields_default() {
        let artifact = ContractArtifact::from_json(
            r#"{ "name": "Bare", "functions": [ { "name": "f", "function_type": "private" } ] }"#,
        )
        .expect("parse artifact with omitted optional fields");
        let function = &artifact.functions[0];
        assert!(!function.is_initializer);
        assert!(!function.is_static);
        assert!(function.parameters.is_empty());
        assert!(function.return_types.is_empty());
        assert_eq!(function.selector, None);
    }

    #[test]
    fn function_artifact_parses_selector() {
        let artifact = ContractArtifact::from_json(
            r#"{ "name": "Sel", "functions": [ { "name": "f", "function_type": "public", "selector": "0x12345678" } ] }"#,
        )
        .expect("parse artifact with selector");
        assert_eq!(
            artifact.functions[0].selector,
            Some(FunctionSelector([0x12, 0x34, 0x56, 0x78]))
        );
    }

    #[test]
    fn artifact_without_functions_is_empty() {
        let artifact =
            ContractArtifact::from_json(r#"{ "name": "Empty" }"#).expect("parse empty artifact");
        assert!(artifact.functions.is_empty());
        assert!(artifact.find_function("anything").is_err());
    }

    #[test]
    fn find_function_by_name() {
        let artifact =
            ContractArtifact::from_json(MULTI_FUNCTION_ARTIFACT).expect("parse artifact");

        let transfer = artifact.find_function("transfer").expect("find transfer");
        assert_eq!(transfer.name, "transfer");
        assert_eq!(transfer.function_type, FunctionType::Private);
    }

    #[test]
    fn find_function_not_found() {
        let artifact =
            ContractArtifact::from_json(MULTI_FUNCTION_ARTIFACT).expect("parse artifact");

        let result = artifact.find_function("nonexistent");
        assert!(result.is_err());
    }

    #[test]
    fn find_function_by_type() {
        let artifact =
            ContractArtifact::from_json(MULTI_FUNCTION_ARTIFACT).expect("parse artifact");

        let balance = artifact
            .find_function_by_type("balance_of", &FunctionType::Utility)
            .expect("find balance_of as utility");
        assert_eq!(balance.name, "balance_of");

        let wrong_type = artifact.find_function_by_type("balance_of", &FunctionType::Public);
        assert!(wrong_type.is_err());
    }

    #[test]
    fn abi_value_field_roundtrip() {
        let value = AbiValue::Field(Fr::from(1u64));
        let json = serde_json::to_string(&value).expect("serialize AbiValue::Field");
        assert!(json.contains("field"));
        let decoded: AbiValue = serde_json::from_str(&json).expect("deserialize AbiValue");
        assert_eq!(decoded, value);
    }

    #[test]
    fn abi_value_boolean_roundtrip() {
        let value = AbiValue::Boolean(true);
        let json = serde_json::to_string(&value).expect("serialize");
        let decoded: AbiValue = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded, value);
    }

    #[test]
    fn abi_value_uses_adjacent_tagging() {
        // Pins the wire format: tag under "kind", payload under "value".
        let json = serde_json::to_string(&AbiValue::Boolean(true)).expect("serialize");
        assert_eq!(json, r#"{"kind":"boolean","value":true}"#);
    }

    #[test]
    fn abi_value_integer_roundtrip() {
        let value = AbiValue::Integer(42);
        let json = serde_json::to_string(&value).expect("serialize");
        let decoded: AbiValue = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded, value);
    }

    #[test]
    fn abi_value_array_roundtrip() {
        let value = AbiValue::Array(vec![
            AbiValue::Field(Fr::from(1u64)),
            AbiValue::Field(Fr::from(2u64)),
        ]);
        let json = serde_json::to_string(&value).expect("serialize");
        let decoded: AbiValue = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded, value);
    }

    #[test]
    fn abi_value_struct_roundtrip() {
        let mut fields = BTreeMap::new();
        fields.insert("x".to_owned(), AbiValue::Field(Fr::from(1u64)));
        fields.insert("y".to_owned(), AbiValue::Integer(2));
        let value = AbiValue::Struct(fields);
        let json = serde_json::to_string(&value).expect("serialize");
        let decoded: AbiValue = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded, value);
    }

    #[test]
    fn abi_type_uses_internal_tagging() {
        // Pins the wire format: unit variants carry only the "kind" tag.
        let json = serde_json::to_string(&AbiType::Field).expect("serialize");
        assert_eq!(json, r#"{"kind":"field"}"#);
    }

    #[test]
    fn abi_type_struct_roundtrip() {
        let typ = AbiType::Struct {
            name: "Point".to_owned(),
            fields: vec![
                AbiParameter {
                    name: "x".to_owned(),
                    typ: AbiType::Field,
                    visibility: None,
                },
                AbiParameter {
                    name: "y".to_owned(),
                    typ: AbiType::Field,
                    visibility: None,
                },
            ],
        };
        let json = serde_json::to_string(&typ).expect("serialize AbiType::Struct");
        let decoded: AbiType = serde_json::from_str(&json).expect("deserialize AbiType::Struct");
        assert_eq!(decoded, typ);
    }

    #[test]
    fn abi_type_array_roundtrip() {
        let typ = AbiType::Array {
            element: Box::new(AbiType::Field),
            length: 10,
        };
        let json = serde_json::to_string(&typ).expect("serialize");
        let decoded: AbiType = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded, typ);
    }

    #[test]
    fn artifact_from_invalid_json_fails() {
        let result = ContractArtifact::from_json("not json");
        assert!(result.is_err());
    }
}