1use std::{collections::BTreeSet, env, path::PathBuf};
7
8pub use casper_contract_schema;
9use casper_contract_schema::{
10 Access, Argument, CallMethod, ContractSchema, CustomType, Entrypoint, EnumVariant, Event,
11 NamedCLType, StructMember, UserError
12};
13
14use convert_case::{Boundary, Case, Casing};
15
16use odra_core::args::EntrypointArgument;
17
/// Casper Contract Schema format version written into
/// `ContractSchema::casper_contract_schema_version` by [`schema`].
const CCSV: u8 = 1;
19
20mod custom_type;
21mod ty;
22
23pub use ty::NamedCLTyped;
24
/// Supplies the entry points of a contract for schema generation.
pub trait SchemaEntrypoints {
    /// Returns the contract's entry points.
    ///
    /// [`schema`] treats entries named `init` and `upgrade` specially: neither is listed
    /// among the schema's regular entry points, and the arguments/description of `init`
    /// are moved into the schema's `call` section.
    fn schema_entrypoints() -> Vec<Entrypoint>;
}
30
31pub trait SchemaEvents {
33 fn schema_events() -> Vec<Event> {
35 vec![]
36 }
37
38 fn custom_types() -> Vec<Option<CustomType>> {
43 vec![]
44 }
45}
46
47pub trait SchemaCustomTypes {
49 fn schema_types() -> Vec<Option<CustomType>> {
51 vec![]
52 }
53}
54
55pub trait SchemaErrors {
57 fn schema_errors() -> Vec<UserError> {
59 vec![]
60 }
61}
62
/// Marker trait for schema elements that declare no errors or events of their own.
///
/// Implementing it provides [`SchemaErrors`] and [`SchemaEvents`] through the blanket
/// impls that follow, both relying on the traits' default (empty) methods.
pub trait SchemaCustomElement {}

// Blanket impls: every `SchemaCustomElement` reports no errors and no events.
impl<T: SchemaCustomElement> SchemaErrors for T {}
impl<T: SchemaCustomElement> SchemaEvents for T {}
68
69pub fn argument<T: NamedCLTyped + EntrypointArgument>(name: &str) -> Argument {
71 if T::is_required() {
72 Argument::new(name, "", <T as NamedCLTyped>::ty())
73 } else {
74 Argument::new_opt(name, "", <T as NamedCLTyped>::ty())
75 }
76}
77
78pub fn entry_point<T: NamedCLTyped>(
80 name: &str,
81 description: &str,
82 is_mutable: bool,
83 arguments: Vec<Argument>
84) -> Entrypoint {
85 Entrypoint {
86 name: name.into(),
87 description: Some(description.to_string()),
88 is_mutable,
89 arguments,
90 return_ty: T::ty().into(),
91 is_contract_context: true,
92 access: Access::Public
93 }
94}
95
96pub fn struct_member<T: NamedCLTyped>(name: &str) -> StructMember {
98 StructMember {
99 name: name.to_string(),
100 description: None,
101 ty: T::ty().into()
102 }
103}
104
105pub fn enum_typed_variant<T: NamedCLTyped>(name: &str, discriminant: u16) -> EnumVariant {
107 EnumVariant {
108 name: name.to_string(),
109 description: None,
110 discriminant,
111 ty: T::ty().into()
112 }
113}
114
115pub fn enum_variant(name: &str, discriminant: u16) -> EnumVariant {
117 enum_typed_variant::<()>(name, discriminant)
118}
119
120pub fn enum_custom_type_variant(name: &str, discriminant: u16, custom_type: &str) -> EnumVariant {
122 EnumVariant {
123 name: name.to_string(),
124 description: None,
125 discriminant,
126 ty: NamedCLType::Custom(custom_type.into()).into()
127 }
128}
129
130pub fn custom_struct(name: &str, members: Vec<StructMember>) -> CustomType {
132 CustomType::Struct {
133 name: name.into(),
134 description: None,
135 members
136 }
137}
138
139pub fn custom_enum(name: &str, variants: Vec<EnumVariant>) -> CustomType {
141 CustomType::Enum {
142 name: name.into(),
143 description: None,
144 variants
145 }
146}
147
148pub fn event(name: &str) -> Event {
150 Event {
151 name: name.into(),
152 ty: name.into()
153 }
154}
155
156pub fn error(name: &str, description: &str, discriminant: u16) -> UserError {
158 UserError {
159 name: name.into(),
160 description: Some(description.into()),
161 discriminant
162 }
163}
164
165pub fn schema<T: SchemaEntrypoints + SchemaEvents + SchemaCustomTypes + SchemaErrors>(
170 module_name: &str,
171 contract_name: &str,
172 contract_version: &str,
173 authors: Vec<String>,
174 repository: &str,
175 homepage: &str
176) -> ContractSchema {
177 let entry_points = T::schema_entrypoints();
178 let events = T::schema_events();
179 let errors = T::schema_errors();
180 let types = BTreeSet::from_iter(T::schema_types())
181 .into_iter()
182 .flatten()
183 .collect();
184
185 let init_ep = entry_points.iter().find(|e| e.name == "init");
186
187 let init_args = init_ep.map(|e| e.arguments.clone()).unwrap_or_default();
188
189 let init_description = init_ep.and_then(|e| e.description.clone());
190
191 let entry_points = entry_points
192 .into_iter()
193 .filter(|e| e.name != "init" && e.name != "upgrade")
194 .collect();
195
196 let wasm_file_name = format!("{}.wasm", module_name);
197
198 let repository = match repository {
199 "" => None,
200 _ => Some(repository.to_string())
201 };
202
203 let homepage = match homepage {
204 "" => None,
205 _ => Some(homepage.to_string())
206 };
207
208 ContractSchema {
209 casper_contract_schema_version: CCSV,
210 toolchain: env!("RUSTC_VERSION").to_string(),
211 contract_name: contract_name.to_string(),
212 contract_version: contract_version.to_string(),
213 types,
214 entry_points,
215 events,
216 call: Some(call_method(wasm_file_name, init_description, &init_args)),
217 authors,
218 repository,
219 homepage,
220 errors
221 }
222}
223
224pub fn find_schema_file_path(
226 contract_name: &str,
227 root_path: PathBuf
228) -> Result<PathBuf, &'static str> {
229 let mut path = root_path
230 .join(format!("{}_schema.json", camel_to_snake(contract_name)))
231 .with_extension("json");
232
233 let mut checked_paths = vec![];
234 for _ in 0..2 {
235 if path.exists() && path.is_file() {
236 return Ok(path);
237 } else {
238 checked_paths.push(path.clone());
239 path = path.parent().unwrap().to_path_buf();
240 }
241 }
242 Err("Schema not found")
243}
244
/// Lists every `.json` regular file located directly inside `root_path` (non-recursive).
///
/// The returned paths are sorted, so the result does not depend on directory iteration order.
///
/// # Errors
///
/// - `"Schemas not found"` if `root_path` is not an existing directory.
/// - `"Failed to read directory"` / `"Failed to read directory entry"` if listing fails.
pub fn find_schemas_file_paths(root_path: PathBuf) -> Result<Vec<PathBuf>, &'static str> {
    // `is_dir` is already false for paths that do not exist.
    if !root_path.is_dir() {
        return Err("Schemas not found");
    }

    let mut paths = Vec::new();
    for entry in root_path.read_dir().map_err(|_| "Failed to read directory")? {
        let path = entry.map_err(|_| "Failed to read directory entry")?.path();
        if path.is_file() && path.extension() == Some("json".as_ref()) {
            paths.push(path);
        }
    }
    // `read_dir` yields entries in a platform/filesystem-dependent order; sort so callers
    // get a deterministic result.
    paths.sort();
    Ok(paths)
}
260
261fn call_method(
262 file_name: String,
263 description: Option<String>,
264 constructor_args: &[Argument]
265) -> CallMethod {
266 CallMethod {
267 wasm_file_name: file_name.to_string(),
268 description,
269 arguments: vec![
270 Argument {
271 name: odra_core::consts::PACKAGE_HASH_KEY_NAME_ARG.to_string(),
272 description: Some("The arg name for the package hash key name.".to_string()),
273 ty: NamedCLType::String.into(),
274 optional: false
275 },
276 Argument {
277 name: odra_core::consts::ALLOW_KEY_OVERRIDE_ARG.to_string(),
278 description: Some("If true and the key specified in odra_cfg_package_hash_key_name already exists, it will be overwritten.".to_string()),
279 ty: NamedCLType::Bool.into(),
280 optional: false
281 },
282 Argument {
283 name: odra_core::consts::IS_UPGRADABLE_ARG.to_string(),
284 description: Some(
285 "The arg name for the contract upgradeability setting.".to_string()
286 ),
287 ty: NamedCLType::Bool.into(),
288 optional: false
289 },
290 Argument {
291 name: odra_core::consts::IS_UPGRADE_ARG.to_string(),
292 description: Some(
293 "The arg name for telling the installer that the contract is being upgraded.".to_string()
294 ),
295 ty: NamedCLType::Bool.into(),
296 optional: false
297 },
298 ]
299 .iter()
300 .chain(constructor_args.iter())
301 .cloned()
302 .collect()
303 }
304}
305
306pub fn camel_to_snake<T: ToString>(text: T) -> String {
308 text.to_string()
309 .from_case(Case::UpperCamel)
310 .without_boundaries(&[Boundary::UpperDigit, Boundary::LowerDigit])
311 .to_case(Case::Snake)
312}
313
#[cfg(test)]
mod test {
    use odra_core::args::Maybe;
    use odra_core::prelude::Address;

    use super::*;

    #[test]
    fn test_argument() {
        let arg = super::argument::<u32>("arg1");
        assert_eq!(arg.name, "arg1");
        assert_eq!(arg.ty, casper_contract_schema::NamedCLType::U32.into());
        assert!(!arg.optional);
    }

    #[test]
    fn test_opt_argument() {
        let arg = super::argument::<Maybe<u32>>("arg1");
        assert_eq!(arg.name, "arg1");
        // The schema type is the inner type; optionality is carried by the `optional` flag.
        assert_eq!(arg.ty, casper_contract_schema::NamedCLType::U32.into());
        assert!(arg.optional);
    }

    #[test]
    fn test_entry_point() {
        let arg = super::argument::<u32>("arg1");
        let entry_point = super::entry_point::<u32>("entry1", "description", true, vec![arg]);
        assert_eq!(entry_point.name, "entry1");
        assert_eq!(entry_point.description, Some("description".to_string()));
        assert!(entry_point.is_mutable);
        assert!(entry_point.is_contract_context);
        assert_eq!(entry_point.arguments.len(), 1);
        assert_eq!(
            entry_point.return_ty,
            casper_contract_schema::NamedCLType::U32.into()
        );
    }

    #[test]
    fn test_struct_member() {
        let member = super::struct_member::<u32>("member1");
        assert_eq!(member.name, "member1");
        assert_eq!(member.ty, casper_contract_schema::NamedCLType::U32.into());
    }

    #[test]
    fn test_enum_typed_variant() {
        let variant = super::enum_typed_variant::<Address>("variant1", 1);
        assert_eq!(variant.name, "variant1");
        assert_eq!(variant.discriminant, 1);
        assert_eq!(variant.ty, casper_contract_schema::NamedCLType::Key.into());
    }

    #[test]
    fn test_enum_variant() {
        let variant = super::enum_variant("variant1", 1);
        assert_eq!(variant.name, "variant1");
        assert_eq!(variant.discriminant, 1);
        assert_eq!(variant.ty, casper_contract_schema::NamedCLType::Unit.into());
    }

    #[test]
    fn test_custom_struct() {
        let member = super::struct_member::<u32>("member1");
        let custom_struct = super::custom_struct("struct1", vec![member]);
        match custom_struct {
            casper_contract_schema::CustomType::Struct { name, members, .. } => {
                assert_eq!(name, "struct1".into());
                assert_eq!(members.len(), 1);
            }
            _ => panic!("Expected CustomType::Struct")
        }
    }

    #[test]
    fn test_custom_enum() {
        let variant1 = super::enum_variant("variant1", 1);
        let variant2 = super::enum_typed_variant::<String>("v2", 2);
        let variant3 = super::enum_custom_type_variant("v3", 3, "Type1");
        let custom_enum = super::custom_enum("enum1", vec![variant1, variant2, variant3]);
        match custom_enum {
            casper_contract_schema::CustomType::Enum { name, variants, .. } => {
                assert_eq!(name, "enum1".into());
                assert_eq!(variants.len(), 3);
                assert_eq!(variants[0].ty, NamedCLType::Unit.into());
                assert_eq!(variants[1].ty, NamedCLType::String.into());
                assert_eq!(variants[2].ty, NamedCLType::Custom("Type1".into()).into());
            }
            _ => panic!("Expected CustomType::Enum")
        }
    }

    #[test]
    fn test_event() {
        let event = super::event("event1");
        assert_eq!(event.name, "event1");
    }

    #[test]
    fn test_error() {
        let error = super::error("error1", "description", 1);
        assert_eq!(error.name, "error1");
        assert_eq!(error.description, Some("description".to_string()));
        assert_eq!(error.discriminant, 1);
    }

    #[test]
    fn test_camel_to_snake() {
        assert_eq!(super::camel_to_snake("MyContract"), "my_contract");
        // Letter/digit transitions are not word boundaries.
        assert_eq!(super::camel_to_snake("Erc20"), "erc20");
    }

    #[test]
    fn test_schema() {
        struct TestSchema;

        impl SchemaEntrypoints for TestSchema {
            fn schema_entrypoints() -> Vec<Entrypoint> {
                vec![
                    entry_point::<u32>(
                        "entry1",
                        "description",
                        true,
                        vec![super::argument::<u32>("arg1")]
                    ),
                    entry_point::<()>(
                        "init",
                        "constructor",
                        true,
                        vec![super::argument::<u32>("init_arg")]
                    ),
                ]
            }
        }

        impl SchemaEvents for TestSchema {
            fn schema_events() -> Vec<Event> {
                vec![event("event1")]
            }
        }

        impl SchemaCustomTypes for TestSchema {
            fn schema_types() -> Vec<Option<CustomType>> {
                vec![
                    Some(custom_struct(
                        "struct1",
                        vec![struct_member::<u32>("member1")]
                    )),
                    Some(custom_enum("enum1", vec![enum_variant("variant1", 1)])),
                    // Duplicates and `None` placeholders must not reach the schema.
                    Some(custom_struct(
                        "struct1",
                        vec![struct_member::<u32>("member1")]
                    )),
                    None,
                ]
            }
        }

        impl SchemaErrors for TestSchema {
            fn schema_errors() -> Vec<UserError> {
                vec![]
            }
        }

        let schema = super::schema::<TestSchema>(
            "module_name",
            "contract_name",
            "contract_version",
            vec!["author".to_string()],
            "repository",
            "homepage"
        );

        assert_eq!(schema.contract_name, "contract_name");
        assert_eq!(schema.contract_version, "contract_version");
        assert_eq!(schema.authors, vec!["author".to_string()]);
        assert_eq!(schema.repository, Some("repository".to_string()));
        assert_eq!(schema.homepage, Some("homepage".to_string()));
        // `init` is not listed as a regular entry point.
        assert_eq!(schema.entry_points.len(), 1);
        assert_eq!(schema.entry_points[0].name, "entry1");
        assert_eq!(schema.types.len(), 2);
        assert_eq!(schema.errors.len(), 0);
        assert_eq!(schema.events.len(), 1);

        // `init` is folded into the call section after the four installer arguments.
        let call = schema.call.as_ref().expect("call section should be present");
        assert_eq!(call.wasm_file_name, "module_name.wasm");
        assert_eq!(call.description, Some("constructor".to_string()));
        assert_eq!(call.arguments.len(), 5);
        assert_eq!(call.arguments[4].name, "init_arg");

        // Empty repository/homepage are reported as absent.
        let minimal = super::schema::<TestSchema>(
            "module_name",
            "contract_name",
            "contract_version",
            vec![],
            "",
            ""
        );
        assert_eq!(minimal.repository, None);
        assert_eq!(minimal.homepage, None);
    }
}