// casper_contract_sdk_codegen/lib.rs
pub mod support;
2
3use casper_contract_sdk::{
4 abi::{Declaration, Definition, Primitive},
5 casper_executor_wasm_common::flags::EntryPointFlags,
6 schema::{Schema, SchemaType},
7};
8use codegen::{Field, Scope, Type};
9use indexmap::IndexMap;
10use serde::{Deserialize, Serialize};
11use std::{
12 collections::{BTreeMap, VecDeque},
13 iter,
14 str::FromStr,
15};
16
/// Trait names derived on every struct and enum that the generator emits.
///
/// Each entry is handed verbatim to the `codegen` builder's `derive` call, so
/// the generated source needs the Borsh derive macros in scope.
const DEFAULT_DERIVED_TRAITS: &[&str] = &[
    "Clone",
    "Debug",
    "PartialEq",
    "Eq",
    "PartialOrd",
    "Ord",
    "Hash",
    "BorshSerialize",
    "BorshDeserialize",
];
28
/// Flattens a type declaration into an identifier-friendly string.
///
/// ASCII letters and digits are kept as they are; every other character
/// (punctuation, whitespace, underscores, non-ASCII) becomes a single `_`.
/// The result therefore has exactly one character per input character.
fn slugify_type(input: &str) -> String {
    input
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
        .collect()
}
43
44#[derive(Debug, Deserialize, Serialize)]
45enum Specialized {
46 Result { ok: Declaration, err: Declaration },
47 Option { some: Declaration },
48}
49
50#[derive(Deserialize, Serialize)]
51pub struct Codegen {
52 schema: Schema,
53 type_mapping: BTreeMap<Declaration, String>,
54 specialized_types: BTreeMap<Declaration, Specialized>,
55}
56
57impl FromStr for Codegen {
58 type Err = serde_json::Error;
59
60 fn from_str(s: &str) -> Result<Self, Self::Err> {
61 let schema: Schema = serde_json::from_str(s)?;
62 Ok(Self::new(schema))
63 }
64}
65
66impl Codegen {
67 pub fn new(schema: Schema) -> Self {
68 Self {
69 schema,
70 type_mapping: Default::default(),
71 specialized_types: Default::default(),
72 }
73 }
74
75 pub fn from_file(path: &str) -> Result<Self, std::io::Error> {
76 let file = std::fs::File::open(path)?;
77 let schema: Schema = serde_json::from_reader(file)?;
78 Ok(Self::new(schema))
79 }
80
81 pub fn gen(&mut self) -> String {
82 let mut scope = Scope::new();
83
84 scope.import("borsh", "self");
85 scope.import("borsh", "BorshSerialize");
86 scope.import("borsh", "BorshDeserialize");
87 scope.import("casper_contract_sdk_codegen::support", "IntoResult");
88 scope.import("casper_contract_sdk_codegen::support", "IntoOption");
89 scope.import("casper_contract_sdk", "Selector");
90 scope.import("casper_contract_sdk", "ToCallData");
91
92 let _head = self
93 .schema
94 .definitions
95 .first()
96 .expect("No definitions found.");
97
98 match &self.schema.type_ {
99 SchemaType::Contract { state } => {
100 if !self.schema.definitions.has_definition(state) {
101 panic!(
102 "Missing state definition. Expected to find a definition for {}.",
103 &state
104 )
105 };
106 }
107 SchemaType::Interface => {}
108 }
109
110 let mut queue = VecDeque::new();
112
113 let mut processed = std::collections::HashSet::new();
115
116 let mut graph: IndexMap<_, VecDeque<_>> = IndexMap::new();
117
118 for (def_index, (next_decl, next_def)) in self.schema.definitions.iter().enumerate() {
119 println!(
120 "{def_index}. decl={decl}",
121 def_index = def_index,
122 decl = next_decl
123 );
124
125 queue.push_back(next_decl);
126
127 while let Some(decl) = queue.pop_front() {
128 if processed.contains(decl) {
129 continue;
130 }
131
132 processed.insert(decl);
133 graph.entry(next_decl).or_default().push_back(decl);
134 match Primitive::from_str(decl) {
137 Ok(primitive) => {
138 println!("Processing primitive type {primitive:?}");
139 continue;
140 }
141 Err(_) => {
142 }
144 };
145
146 let def = self
147 .schema
148 .definitions
149 .get(decl)
150 .unwrap_or_else(|| panic!("Missing definition for {}", decl));
151
152 match def {
157 Definition::Primitive(_primitive) => {
158 continue;
159 }
160 Definition::Mapping { key, value } => {
161 if !processed.contains(key) {
162 queue.push_front(key);
163 continue;
164 }
165
166 if !processed.contains(value) {
167 queue.push_front(value);
168 continue;
169 }
170 }
171 Definition::Sequence { decl } => {
172 queue.push_front(decl);
173 }
174 Definition::FixedSequence { length: _, decl } => {
175 if !processed.contains(decl) {
176 queue.push_front(decl);
177 continue;
178 }
179 }
180 Definition::Tuple { items } => {
181 for item in items {
182 if !processed.contains(item) {
183 queue.push_front(item);
184 continue;
185 }
186 }
187
188 }
190 Definition::Enum { items } => {
191 for item in items {
192 if !processed.contains(&item.decl) {
193 queue.push_front(&item.decl);
194 continue;
195 }
196 }
197 }
198 Definition::Struct { items } => {
199 for item in items {
200 if !processed.contains(&item.decl) {
201 queue.push_front(&item.decl);
202 continue;
203 }
204 }
205 }
206 }
207 }
208
209 match next_def {
210 Definition::Primitive(_) => {}
211 Definition::Mapping { key, value } => {
212 assert!(processed.contains(key));
213 assert!(processed.contains(value));
214 }
215 Definition::Sequence { decl } => {
216 assert!(processed.contains(decl));
217 }
218 Definition::FixedSequence { length: _, decl } => {
219 assert!(processed.contains(decl));
220 }
221 Definition::Tuple { items } => {
222 for item in items {
223 assert!(processed.contains(&item));
224 }
225 }
226 Definition::Enum { items } => {
227 for item in items {
228 assert!(processed.contains(&item.decl));
229 }
230 }
231 Definition::Struct { items } => {
232 for item in items {
233 assert!(processed.contains(&item.decl));
234 }
235 }
236 }
237 }
238 dbg!(&graph);
239
240 let mut counter = iter::successors(Some(0usize), |prev| prev.checked_add(1));
241
242 for (_decl, deps) in graph {
243 for decl in deps.into_iter().rev() {
244 let def = self
247 .schema
248 .definitions
249 .get(decl)
250 .cloned()
251 .or_else(|| Primitive::from_str(decl).ok().map(Definition::Primitive))
252 .unwrap_or_else(|| panic!("Missing definition for {}", decl));
253
254 match def {
255 Definition::Primitive(primitive) => {
256 let (from, to) = match primitive {
257 Primitive::Char => ("Char", "char"),
258 Primitive::U8 => ("U8", "u8"),
259 Primitive::I8 => ("I8", "i8"),
260 Primitive::U16 => ("U16", "u16"),
261 Primitive::I16 => ("I16", "i16"),
262 Primitive::U32 => ("U32", "u32"),
263 Primitive::I32 => ("I32", "i32"),
264 Primitive::U64 => ("U64", "u64"),
265 Primitive::I64 => ("I64", "i64"),
266 Primitive::U128 => ("U128", "u128"),
267 Primitive::I128 => ("I128", "i128"),
268 Primitive::Bool => ("Bool", "bool"),
269 Primitive::F32 => ("F32", "f32"),
270 Primitive::F64 => ("F64", "f64"),
271 };
272
273 scope.new_type_alias(from, to).vis("pub");
274 self.type_mapping.insert(decl.to_string(), from.to_string());
275 }
276 Definition::Mapping { key: _, value: _ } => {
277 todo!()
279 }
280 Definition::Sequence { decl: seq_decl } => {
281 println!("Processing sequence type {decl:?}");
282 if decl.as_str() == "String"
283 && Primitive::from_str(&seq_decl) == Ok(Primitive::Char)
284 {
285 self.type_mapping
286 .insert("String".to_owned(), "String".to_owned());
287 } else {
288 let mapped_type = self
289 .type_mapping
290 .get(&seq_decl)
291 .unwrap_or_else(|| panic!("Missing type mapping for {}", seq_decl));
292 let type_name =
293 format!("Sequence{}_{seq_decl}", counter.next().unwrap());
294 scope.new_type_alias(&type_name, format!("Vec<{}>", mapped_type));
295 self.type_mapping.insert(decl.to_string(), type_name);
296 }
297 }
298 Definition::FixedSequence {
299 length,
300 decl: fixed_seq_decl,
301 } => {
302 let mapped_type =
303 self.type_mapping.get(&fixed_seq_decl).unwrap_or_else(|| {
304 panic!("Missing type mapping for {}", fixed_seq_decl)
305 });
306
307 let type_name = format!(
308 "FixedSequence{}_{length}_{fixed_seq_decl}",
309 counter.next().unwrap()
310 );
311 scope.new_type_alias(&type_name, format!("[{}; {}]", mapped_type, length));
312 self.type_mapping.insert(decl.to_string(), type_name);
313 }
314 Definition::Tuple { items } => {
315 if decl.as_str() == "()" && items.is_empty() {
316 self.type_mapping.insert("()".to_owned(), "()".to_owned());
317 continue;
318 }
319
320 println!("Processing tuple type {items:?}");
321 let struct_name = slugify_type(decl);
322
323 let r#struct = scope
324 .new_struct(&struct_name)
325 .doc(&format!("Declared as {decl}"));
326
327 for trait_name in DEFAULT_DERIVED_TRAITS {
328 r#struct.derive(trait_name);
329 }
330
331 if items.is_empty() {
332 r#struct.tuple_field(Type::new("()"));
333 } else {
334 for item in items {
335 let mapped_type = self
336 .type_mapping
337 .get(&item)
338 .unwrap_or_else(|| panic!("Missing type mapping for {}", item));
339 r#struct.tuple_field(mapped_type);
340 }
341 }
342
343 self.type_mapping.insert(decl.to_string(), struct_name);
344 }
345 Definition::Enum { items } => {
346 println!("Processing enum type {decl} {items:?}");
347
348 let mut items: Vec<&casper_contract_sdk::abi::EnumVariant> =
349 items.iter().collect();
350
351 let mut specialized = None;
352
353 if decl.starts_with("Result")
354 && items.len() == 2
355 && items[0].name == "Ok"
356 && items[1].name == "Err"
357 {
358 specialized = Some(Specialized::Result {
359 ok: items[0].decl.clone(),
360 err: items[1].decl.clone(),
361 });
362
363 items.reverse();
371 }
372
373 if decl.starts_with("Option")
374 && items.len() == 2
375 && items[0].name == "None"
376 && items[1].name == "Some"
377 {
378 specialized = Some(Specialized::Option {
379 some: items[1].decl.clone(),
380 });
381
382 items.reverse();
383 }
384
385 let enum_name = slugify_type(decl);
386
387 let r#enum = scope
388 .new_enum(&enum_name)
389 .vis("pub")
390 .doc(&format!("Declared as {decl}"));
391
392 for trait_name in DEFAULT_DERIVED_TRAITS {
393 r#enum.derive(trait_name);
394 }
395
396 for item in &items {
397 let variant = r#enum.new_variant(&item.name);
398
399 let def = self.type_mapping.get(&item.decl).unwrap_or_else(|| {
400 panic!("Missing type mapping for {}", item.decl)
401 });
402
403 variant.tuple(def);
404 }
405
406 self.type_mapping
407 .insert(decl.to_string(), enum_name.to_owned());
408
409 match specialized {
410 Some(Specialized::Result { ok, err }) => {
411 let ok_type = self
412 .type_mapping
413 .get(&ok)
414 .unwrap_or_else(|| panic!("Missing type mapping for {}", ok));
415 let err_type = self
416 .type_mapping
417 .get(&err)
418 .unwrap_or_else(|| panic!("Missing type mapping for {}", err));
419
420 let impl_block = scope
421 .new_impl(&enum_name)
422 .impl_trait(format!("IntoResult<{ok_type}, {err_type}>"));
423
424 let func = impl_block.new_fn("into_result").arg_self().ret(
425 Type::new(format!(
426 "Result<{ok_type}, {err_type}>",
427 ok_type = ok_type,
428 err_type = err_type
429 )),
430 );
431 func.line("match self {")
432 .line(format!("{enum_name}::Ok(ok) => Ok(ok),"))
433 .line(format!("{enum_name}::Err(err) => Err(err),"))
434 .line("}");
435 }
436 Some(Specialized::Option { some }) => {
437 let some_type = self.type_mapping.get(&some).unwrap_or_else(|| {
438 panic!("Missing type mapping for {}", &some)
439 });
440
441 let impl_block = scope
442 .new_impl(&enum_name)
443 .impl_trait(format!("IntoOption<{some_type}>"));
444
445 let func = impl_block
446 .new_fn("into_option")
447 .arg_self()
448 .ret(Type::new(format!("Option<{some_type}>",)));
449 func.line("match self {")
450 .line(format!("{enum_name}::None => None,"))
451 .line(format!("{enum_name}::Some(some) => Some(some),"))
452 .line("}");
453 }
454 None => {}
455 }
456 }
457 Definition::Struct { items } => {
458 println!("Processing struct type {items:?}");
459
460 let type_name = slugify_type(decl);
461
462 let r#struct = scope.new_struct(&type_name);
463
464 for trait_name in DEFAULT_DERIVED_TRAITS {
465 r#struct.derive(trait_name);
466 }
467
468 for item in items {
469 let mapped_type =
470 self.type_mapping.get(&item.decl).unwrap_or_else(|| {
471 panic!("Missing type mapping for {}", item.decl)
472 });
473 let field = Field::new(&item.name, Type::new(mapped_type))
474 .doc(format!("Declared as {}", item.decl))
475 .to_owned();
476
477 r#struct.push_field(field);
478 }
479 self.type_mapping.insert(decl.to_string(), type_name);
480 }
481 }
482 }
483 }
484
485 let struct_name = format!("{}Client", self.schema.name);
486 let client = scope.new_struct(&struct_name).vis("pub");
487
488 for trait_name in DEFAULT_DERIVED_TRAITS {
489 client.derive(trait_name);
490 }
491
492 let mut field = Field::new("address", Type::new("[u8; 32]"));
493 field.vis("pub");
494
495 client.push_field(field);
496
497 let client_impl = scope.new_impl(&struct_name);
498
499 for entry_point in &self.schema.entry_points {
500 let func = client_impl.new_fn(&entry_point.name);
501 func.vis("pub");
502
503 let result_type = self
504 .type_mapping
505 .get(&entry_point.result)
506 .unwrap_or_else(|| panic!("Missing type mapping for {}", entry_point.result));
507
508 if entry_point.flags.contains(EntryPointFlags::CONSTRUCTOR) {
509 func.ret(Type::new(format!(
510 "Result<{}, casper_contract_sdk::types::CallError>",
511 &struct_name
512 )))
513 .generic("C")
514 .bound("C", "casper_contract_sdk::Contract");
515 } else {
516 func.ret(Type::new(format!(
517 "Result<casper_contract_sdk::host::CallResult<{result_type}>, casper_contract_sdk::types::CallError>"
518 )));
519 func.arg_ref_self();
520 }
521
522 for arg in &entry_point.arguments {
523 let mapped_type = self
524 .type_mapping
525 .get(&arg.decl)
526 .unwrap_or_else(|| panic!("Missing type mapping for {}", arg.decl));
527 let arg_ty = Type::new(mapped_type);
528 func.arg(&arg.name, arg_ty);
529 }
530
531 func.line("let value = 0; // TODO: Transferring values");
532
533 let input_struct_name =
534 format!("{}_{}", slugify_type(&self.schema.name), &entry_point.name);
535
536 if entry_point.arguments.is_empty() {
537 func.line(format!(r#"let call_data = {input_struct_name};"#));
538 } else {
539 func.line(format!(r#"let call_data = {input_struct_name} {{ "#));
540 for arg in &entry_point.arguments {
541 func.line(format!("{},", arg.name));
542 }
543 func.line("};");
544 }
545
546 if entry_point.flags.contains(EntryPointFlags::CONSTRUCTOR) {
547 func.line(r#"let create_result = C::create(call_data)?;"#);
551 func.line(format!(
554 r#"let result = {struct_name} {{ address: create_result.contract_address }};"#,
555 struct_name = &struct_name
556 ));
557 func.line("Ok(result)");
558 continue;
559 } else {
560 func.line(r#"casper_contract_sdk::host::call(&self.address, value, call_data)"#);
561 }
562 }
563
564 for entry_point in &self.schema.entry_points {
565 let struct_name = format!("{}_{}", &self.schema.name, &entry_point.name);
567 let input_struct = scope.new_struct(&struct_name);
568
569 for trait_name in DEFAULT_DERIVED_TRAITS {
570 input_struct.derive(trait_name);
571 }
572
573 for argument in &entry_point.arguments {
574 let mapped_type = self.type_mapping.get(&argument.decl).unwrap_or_else(|| {
575 panic!(
576 "Missing type mapping for {} when generating input arg {}",
577 argument.decl, &struct_name
578 )
579 });
580 input_struct.push_field(Field::new(&argument.name, Type::new(mapped_type)));
581 }
582
583 let impl_block = scope.new_impl(&struct_name).impl_trait("ToCallData");
584
585 let input_data_func = impl_block
586 .new_fn("input_data")
587 .arg_ref_self()
588 .ret(Type::new("Option<Vec<u8>>"));
589
590 if entry_point.arguments.is_empty() {
591 input_data_func.line(r#"None"#);
592 } else {
593 input_data_func
594 .line(r#"let input_data = borsh::to_vec(&self).expect("Serialization to succeed");"#)
595 .line(r#"Some(input_data)"#);
596 }
597 }
598
599 scope.to_string()
600 }
601}
602
#[cfg(test)]
mod tests {
    use super::*;

    /// Every non-alphanumeric character of a nested generic declaration is
    /// replaced by exactly one underscore.
    #[test]
    fn should_slugify_complex_type() {
        let input = "Option<Result<(), vm2_cep18::error::Cep18Error>>";
        let expected = "Option_Result_____vm2_cep18__error__Cep18Error__";

        assert_eq!(slugify_type(input), expected);
    }
}