casper_contract_sdk_codegen/
lib.rs

1pub mod support;
2
3use casper_contract_sdk::{
4    abi::{Declaration, Definition, Primitive},
5    casper_executor_wasm_common::flags::EntryPointFlags,
6    schema::{Schema, SchemaType},
7};
8use codegen::{Field, Scope, Type};
9use indexmap::IndexMap;
10use serde::{Deserialize, Serialize};
11use std::{
12    collections::{BTreeMap, VecDeque},
13    iter,
14    str::FromStr,
15};
16
/// Traits derived on every struct and enum this generator emits (including the client and
/// the per-entry-point call-data structs); the Borsh pair is what lets call data be
/// serialized with `borsh::to_vec`.
const DEFAULT_DERIVED_TRAITS: &[&str] = &[
    // Standard comparison / debugging traits.
    "Clone", "Debug", "PartialEq", "Eq", "PartialOrd", "Ord", "Hash",
    // Wire format.
    "BorshSerialize", "BorshDeserialize",
];
28
/// Turns an arbitrary type declaration into text usable inside a Rust identifier.
///
/// ASCII letters and digits are kept as-is; every other character (punctuation, whitespace,
/// non-ASCII characters and `_` itself) becomes `_`, one underscore per `char`.
fn slugify_type(input: &str) -> String {
    input
        .chars()
        .map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '_' })
        .collect()
}
43
44#[derive(Debug, Deserialize, Serialize)]
45enum Specialized {
46    Result { ok: Declaration, err: Declaration },
47    Option { some: Declaration },
48}
49
50#[derive(Deserialize, Serialize)]
51pub struct Codegen {
52    schema: Schema,
53    type_mapping: BTreeMap<Declaration, String>,
54    specialized_types: BTreeMap<Declaration, Specialized>,
55}
56
57impl FromStr for Codegen {
58    type Err = serde_json::Error;
59
60    fn from_str(s: &str) -> Result<Self, Self::Err> {
61        let schema: Schema = serde_json::from_str(s)?;
62        Ok(Self::new(schema))
63    }
64}
65
66impl Codegen {
67    pub fn new(schema: Schema) -> Self {
68        Self {
69            schema,
70            type_mapping: Default::default(),
71            specialized_types: Default::default(),
72        }
73    }
74
75    pub fn from_file(path: &str) -> Result<Self, std::io::Error> {
76        let file = std::fs::File::open(path)?;
77        let schema: Schema = serde_json::from_reader(file)?;
78        Ok(Self::new(schema))
79    }
80
81    pub fn gen(&mut self) -> String {
82        let mut scope = Scope::new();
83
84        scope.import("borsh", "self");
85        scope.import("borsh", "BorshSerialize");
86        scope.import("borsh", "BorshDeserialize");
87        scope.import("casper_contract_sdk_codegen::support", "IntoResult");
88        scope.import("casper_contract_sdk_codegen::support", "IntoOption");
89        scope.import("casper_contract_sdk", "Selector");
90        scope.import("casper_contract_sdk", "ToCallData");
91
92        let _head = self
93            .schema
94            .definitions
95            .first()
96            .expect("No definitions found.");
97
98        match &self.schema.type_ {
99            SchemaType::Contract { state } => {
100                if !self.schema.definitions.has_definition(state) {
101                    panic!(
102                        "Missing state definition. Expected to find a definition for {}.",
103                        &state
104                    )
105                };
106            }
107            SchemaType::Interface => {}
108        }
109
110        // Initialize a queue with the first definition
111        let mut queue = VecDeque::new();
112
113        // Create a set to keep track of processed definitions
114        let mut processed = std::collections::HashSet::new();
115
116        let mut graph: IndexMap<_, VecDeque<_>> = IndexMap::new();
117
118        for (def_index, (next_decl, next_def)) in self.schema.definitions.iter().enumerate() {
119            println!(
120                "{def_index}. decl={decl}",
121                def_index = def_index,
122                decl = next_decl
123            );
124
125            queue.push_back(next_decl);
126
127            while let Some(decl) = queue.pop_front() {
128                if processed.contains(decl) {
129                    continue;
130                }
131
132                processed.insert(decl);
133                graph.entry(next_decl).or_default().push_back(decl);
134                // graph.find
135
136                match Primitive::from_str(decl) {
137                    Ok(primitive) => {
138                        println!("Processing primitive type {primitive:?}");
139                        continue;
140                    }
141                    Err(_) => {
142                        // Not a primitive type
143                    }
144                };
145
146                let def = self
147                    .schema
148                    .definitions
149                    .get(decl)
150                    .unwrap_or_else(|| panic!("Missing definition for {}", decl));
151
152                // graph.entry(next_decl).or_default().push(decl);
153                // println!("Processing type {decl}");
154
155                // Enqueue all unprocessed definitions that depend on the current definition
156                match def {
157                    Definition::Primitive(_primitive) => {
158                        continue;
159                    }
160                    Definition::Mapping { key, value } => {
161                        if !processed.contains(key) {
162                            queue.push_front(key);
163                            continue;
164                        }
165
166                        if !processed.contains(value) {
167                            queue.push_front(value);
168                            continue;
169                        }
170                    }
171                    Definition::Sequence { decl } => {
172                        queue.push_front(decl);
173                    }
174                    Definition::FixedSequence { length: _, decl } => {
175                        if !processed.contains(decl) {
176                            queue.push_front(decl);
177                            continue;
178                        }
179                    }
180                    Definition::Tuple { items } => {
181                        for item in items {
182                            if !processed.contains(item) {
183                                queue.push_front(item);
184                                continue;
185                            }
186                        }
187
188                        // queue.push_front(decl);
189                    }
190                    Definition::Enum { items } => {
191                        for item in items {
192                            if !processed.contains(&item.decl) {
193                                queue.push_front(&item.decl);
194                                continue;
195                            }
196                        }
197                    }
198                    Definition::Struct { items } => {
199                        for item in items {
200                            if !processed.contains(&item.decl) {
201                                queue.push_front(&item.decl);
202                                continue;
203                            }
204                        }
205                    }
206                }
207            }
208
209            match next_def {
210                Definition::Primitive(_) => {}
211                Definition::Mapping { key, value } => {
212                    assert!(processed.contains(key));
213                    assert!(processed.contains(value));
214                }
215                Definition::Sequence { decl } => {
216                    assert!(processed.contains(decl));
217                }
218                Definition::FixedSequence { length: _, decl } => {
219                    assert!(processed.contains(decl));
220                }
221                Definition::Tuple { items } => {
222                    for item in items {
223                        assert!(processed.contains(&item));
224                    }
225                }
226                Definition::Enum { items } => {
227                    for item in items {
228                        assert!(processed.contains(&item.decl));
229                    }
230                }
231                Definition::Struct { items } => {
232                    for item in items {
233                        assert!(processed.contains(&item.decl));
234                    }
235                }
236            }
237        }
238        dbg!(&graph);
239
240        let mut counter = iter::successors(Some(0usize), |prev| prev.checked_add(1));
241
242        for (_decl, deps) in graph {
243            for decl in deps.into_iter().rev() {
244                // println!("generate {decl}");
245
246                let def = self
247                    .schema
248                    .definitions
249                    .get(decl)
250                    .cloned()
251                    .or_else(|| Primitive::from_str(decl).ok().map(Definition::Primitive))
252                    .unwrap_or_else(|| panic!("Missing definition for {}", decl));
253
254                match def {
255                    Definition::Primitive(primitive) => {
256                        let (from, to) = match primitive {
257                            Primitive::Char => ("Char", "char"),
258                            Primitive::U8 => ("U8", "u8"),
259                            Primitive::I8 => ("I8", "i8"),
260                            Primitive::U16 => ("U16", "u16"),
261                            Primitive::I16 => ("I16", "i16"),
262                            Primitive::U32 => ("U32", "u32"),
263                            Primitive::I32 => ("I32", "i32"),
264                            Primitive::U64 => ("U64", "u64"),
265                            Primitive::I64 => ("I64", "i64"),
266                            Primitive::U128 => ("U128", "u128"),
267                            Primitive::I128 => ("I128", "i128"),
268                            Primitive::Bool => ("Bool", "bool"),
269                            Primitive::F32 => ("F32", "f32"),
270                            Primitive::F64 => ("F64", "f64"),
271                        };
272
273                        scope.new_type_alias(from, to).vis("pub");
274                        self.type_mapping.insert(decl.to_string(), from.to_string());
275                    }
276                    Definition::Mapping { key: _, value: _ } => {
277                        // println!("Processing mapping type {key:?} -> {value:?}");
278                        todo!()
279                    }
280                    Definition::Sequence { decl: seq_decl } => {
281                        println!("Processing sequence type {decl:?}");
282                        if decl.as_str() == "String"
283                            && Primitive::from_str(&seq_decl) == Ok(Primitive::Char)
284                        {
285                            self.type_mapping
286                                .insert("String".to_owned(), "String".to_owned());
287                        } else {
288                            let mapped_type = self
289                                .type_mapping
290                                .get(&seq_decl)
291                                .unwrap_or_else(|| panic!("Missing type mapping for {}", seq_decl));
292                            let type_name =
293                                format!("Sequence{}_{seq_decl}", counter.next().unwrap());
294                            scope.new_type_alias(&type_name, format!("Vec<{}>", mapped_type));
295                            self.type_mapping.insert(decl.to_string(), type_name);
296                        }
297                    }
298                    Definition::FixedSequence {
299                        length,
300                        decl: fixed_seq_decl,
301                    } => {
302                        let mapped_type =
303                            self.type_mapping.get(&fixed_seq_decl).unwrap_or_else(|| {
304                                panic!("Missing type mapping for {}", fixed_seq_decl)
305                            });
306
307                        let type_name = format!(
308                            "FixedSequence{}_{length}_{fixed_seq_decl}",
309                            counter.next().unwrap()
310                        );
311                        scope.new_type_alias(&type_name, format!("[{}; {}]", mapped_type, length));
312                        self.type_mapping.insert(decl.to_string(), type_name);
313                    }
314                    Definition::Tuple { items } => {
315                        if decl.as_str() == "()" && items.is_empty() {
316                            self.type_mapping.insert("()".to_owned(), "()".to_owned());
317                            continue;
318                        }
319
320                        println!("Processing tuple type {items:?}");
321                        let struct_name = slugify_type(decl);
322
323                        let r#struct = scope
324                            .new_struct(&struct_name)
325                            .doc(&format!("Declared as {decl}"));
326
327                        for trait_name in DEFAULT_DERIVED_TRAITS {
328                            r#struct.derive(trait_name);
329                        }
330
331                        if items.is_empty() {
332                            r#struct.tuple_field(Type::new("()"));
333                        } else {
334                            for item in items {
335                                let mapped_type = self
336                                    .type_mapping
337                                    .get(&item)
338                                    .unwrap_or_else(|| panic!("Missing type mapping for {}", item));
339                                r#struct.tuple_field(mapped_type);
340                            }
341                        }
342
343                        self.type_mapping.insert(decl.to_string(), struct_name);
344                    }
345                    Definition::Enum { items } => {
346                        println!("Processing enum type {decl} {items:?}");
347
348                        let mut items: Vec<&casper_contract_sdk::abi::EnumVariant> =
349                            items.iter().collect();
350
351                        let mut specialized = None;
352
353                        if decl.starts_with("Result")
354                            && items.len() == 2
355                            && items[0].name == "Ok"
356                            && items[1].name == "Err"
357                        {
358                            specialized = Some(Specialized::Result {
359                                ok: items[0].decl.clone(),
360                                err: items[1].decl.clone(),
361                            });
362
363                            // NOTE: Because we're not doing the standard library Result, and also
364                            // to simplify things we're using default impl of
365                            // BorshSerialize/BorshDeserialize, we have to flip the order of enums.
366                            // The standard library defines Result as Ok, Err, but the borsh impl
367                            // serializes Err as 0, and Ok as 1. So, by flipping the order we can
368                            // enforce byte for byte compatibility between our "custom" Result and a
369                            // real Result.
370                            items.reverse();
371                        }
372
373                        if decl.starts_with("Option")
374                            && items.len() == 2
375                            && items[0].name == "None"
376                            && items[1].name == "Some"
377                        {
378                            specialized = Some(Specialized::Option {
379                                some: items[1].decl.clone(),
380                            });
381
382                            items.reverse();
383                        }
384
385                        let enum_name = slugify_type(decl);
386
387                        let r#enum = scope
388                            .new_enum(&enum_name)
389                            .vis("pub")
390                            .doc(&format!("Declared as {decl}"));
391
392                        for trait_name in DEFAULT_DERIVED_TRAITS {
393                            r#enum.derive(trait_name);
394                        }
395
396                        for item in &items {
397                            let variant = r#enum.new_variant(&item.name);
398
399                            let def = self.type_mapping.get(&item.decl).unwrap_or_else(|| {
400                                panic!("Missing type mapping for {}", item.decl)
401                            });
402
403                            variant.tuple(def);
404                        }
405
406                        self.type_mapping
407                            .insert(decl.to_string(), enum_name.to_owned());
408
409                        match specialized {
410                            Some(Specialized::Result { ok, err }) => {
411                                let ok_type = self
412                                    .type_mapping
413                                    .get(&ok)
414                                    .unwrap_or_else(|| panic!("Missing type mapping for {}", ok));
415                                let err_type = self
416                                    .type_mapping
417                                    .get(&err)
418                                    .unwrap_or_else(|| panic!("Missing type mapping for {}", err));
419
420                                let impl_block = scope
421                                    .new_impl(&enum_name)
422                                    .impl_trait(format!("IntoResult<{ok_type}, {err_type}>"));
423
424                                let func = impl_block.new_fn("into_result").arg_self().ret(
425                                    Type::new(format!(
426                                        "Result<{ok_type}, {err_type}>",
427                                        ok_type = ok_type,
428                                        err_type = err_type
429                                    )),
430                                );
431                                func.line("match self {")
432                                    .line(format!("{enum_name}::Ok(ok) => Ok(ok),"))
433                                    .line(format!("{enum_name}::Err(err) => Err(err),"))
434                                    .line("}");
435                            }
436                            Some(Specialized::Option { some }) => {
437                                let some_type = self.type_mapping.get(&some).unwrap_or_else(|| {
438                                    panic!("Missing type mapping for {}", &some)
439                                });
440
441                                let impl_block = scope
442                                    .new_impl(&enum_name)
443                                    .impl_trait(format!("IntoOption<{some_type}>"));
444
445                                let func = impl_block
446                                    .new_fn("into_option")
447                                    .arg_self()
448                                    .ret(Type::new(format!("Option<{some_type}>",)));
449                                func.line("match self {")
450                                    .line(format!("{enum_name}::None => None,"))
451                                    .line(format!("{enum_name}::Some(some) => Some(some),"))
452                                    .line("}");
453                            }
454                            None => {}
455                        }
456                    }
457                    Definition::Struct { items } => {
458                        println!("Processing struct type {items:?}");
459
460                        let type_name = slugify_type(decl);
461
462                        let r#struct = scope.new_struct(&type_name);
463
464                        for trait_name in DEFAULT_DERIVED_TRAITS {
465                            r#struct.derive(trait_name);
466                        }
467
468                        for item in items {
469                            let mapped_type =
470                                self.type_mapping.get(&item.decl).unwrap_or_else(|| {
471                                    panic!("Missing type mapping for {}", item.decl)
472                                });
473                            let field = Field::new(&item.name, Type::new(mapped_type))
474                                .doc(format!("Declared as {}", item.decl))
475                                .to_owned();
476
477                            r#struct.push_field(field);
478                        }
479                        self.type_mapping.insert(decl.to_string(), type_name);
480                    }
481                }
482            }
483        }
484
485        let struct_name = format!("{}Client", self.schema.name);
486        let client = scope.new_struct(&struct_name).vis("pub");
487
488        for trait_name in DEFAULT_DERIVED_TRAITS {
489            client.derive(trait_name);
490        }
491
492        let mut field = Field::new("address", Type::new("[u8; 32]"));
493        field.vis("pub");
494
495        client.push_field(field);
496
497        let client_impl = scope.new_impl(&struct_name);
498
499        for entry_point in &self.schema.entry_points {
500            let func = client_impl.new_fn(&entry_point.name);
501            func.vis("pub");
502
503            let result_type = self
504                .type_mapping
505                .get(&entry_point.result)
506                .unwrap_or_else(|| panic!("Missing type mapping for {}", entry_point.result));
507
508            if entry_point.flags.contains(EntryPointFlags::CONSTRUCTOR) {
509                func.ret(Type::new(format!(
510                    "Result<{}, casper_contract_sdk::types::CallError>",
511                    &struct_name
512                )))
513                .generic("C")
514                .bound("C", "casper_contract_sdk::Contract");
515            } else {
516                func.ret(Type::new(format!(
517                    "Result<casper_contract_sdk::host::CallResult<{result_type}>, casper_contract_sdk::types::CallError>"
518                )));
519                func.arg_ref_self();
520            }
521
522            for arg in &entry_point.arguments {
523                let mapped_type = self
524                    .type_mapping
525                    .get(&arg.decl)
526                    .unwrap_or_else(|| panic!("Missing type mapping for {}", arg.decl));
527                let arg_ty = Type::new(mapped_type);
528                func.arg(&arg.name, arg_ty);
529            }
530
531            func.line("let value = 0; // TODO: Transferring values");
532
533            let input_struct_name =
534                format!("{}_{}", slugify_type(&self.schema.name), &entry_point.name);
535
536            if entry_point.arguments.is_empty() {
537                func.line(format!(r#"let call_data = {input_struct_name};"#));
538            } else {
539                func.line(format!(r#"let call_data = {input_struct_name} {{ "#));
540                for arg in &entry_point.arguments {
541                    func.line(format!("{},", arg.name));
542                }
543                func.line("};");
544            }
545
546            if entry_point.flags.contains(EntryPointFlags::CONSTRUCTOR) {
547                // if !entry_point.arguments.is_empty() {
548                //     func.line(r#"let create_result = C::create(SELECTOR, Some(&input_data))?;"#);
549                // } else {
550                func.line(r#"let create_result = C::create(call_data)?;"#);
551                // }
552
553                func.line(format!(
554                    r#"let result = {struct_name} {{ address: create_result.contract_address }};"#,
555                    struct_name = &struct_name
556                ));
557                func.line("Ok(result)");
558                continue;
559            } else {
560                func.line(r#"casper_contract_sdk::host::call(&self.address, value, call_data)"#);
561            }
562        }
563
564        for entry_point in &self.schema.entry_points {
565            // Generate arg structure similar to what casper-contract-macros is doing
566            let struct_name = format!("{}_{}", &self.schema.name, &entry_point.name);
567            let input_struct = scope.new_struct(&struct_name);
568
569            for trait_name in DEFAULT_DERIVED_TRAITS {
570                input_struct.derive(trait_name);
571            }
572
573            for argument in &entry_point.arguments {
574                let mapped_type = self.type_mapping.get(&argument.decl).unwrap_or_else(|| {
575                    panic!(
576                        "Missing type mapping for {} when generating input arg {}",
577                        argument.decl, &struct_name
578                    )
579                });
580                input_struct.push_field(Field::new(&argument.name, Type::new(mapped_type)));
581            }
582
583            let impl_block = scope.new_impl(&struct_name).impl_trait("ToCallData");
584
585            let input_data_func = impl_block
586                .new_fn("input_data")
587                .arg_ref_self()
588                .ret(Type::new("Option<Vec<u8>>"));
589
590            if entry_point.arguments.is_empty() {
591                input_data_func.line(r#"None"#);
592            } else {
593                input_data_func
594                        .line(r#"let input_data = borsh::to_vec(&self).expect("Serialization to succeed");"#)
595                        .line(r#"Some(input_data)"#);
596            }
597        }
598
599        scope.to_string()
600    }
601}
602
#[cfg(test)]
mod tests {
    use super::*;

    /// Generic brackets, the unit type, separators and path segments all collapse to `_`.
    #[test]
    fn should_slugify_complex_type() {
        assert_eq!(
            slugify_type("Option<Result<(), vm2_cep18::error::Cep18Error>>"),
            "Option_Result_____vm2_cep18__error__Cep18Error__",
        );
    }
}