Skip to main content

endpoint_gen/
definitions.rs

1use crate::rust::ToRust;
2use convert_case::{Case, Casing};
3use endpoint_gen_macros::DefinitionVariant;
4use endpoint_libs::model::{EndpointSchema, Type};
5use itertools::Itertools;
6use serde::{Deserialize, Serialize};
7use smart_default::SmartDefault;
8use smart_serde_default::smart_serde_default;
9
// Marker trait for types that can be used as Definition variants
// All types used in Definition must implement this to ensure they are properly validatable
// (Currently disabled; written as plain `//` comments so rustdoc does not attach this text
// to `Definition` below.)
// pub trait DefinitionVariant: GenElement<Self> {}

/// A single top-level definition consumed by the generator.
///
/// Serialized with serde's default (externally tagged) enum representation, since no
/// `#[serde(tag = ...)]` attribute is present. Every wrapped type implements
/// [`GenElement`]; call [`Definition::validate_self`] to run the wrapped type's validation.
#[derive(Serialize, Deserialize)]
pub enum Definition {
    /// A single endpoint schema together with its owning service.
    EndpointSchema(EndpointSchemaDefinition),
    /// All endpoint schemas of one service.
    EndpointSchemaList(EndpointSchemaListDefinition),
    /// A single enum type.
    Enum(EnumElement),
    /// A list of enum types.
    EnumList(EnumListDefinition),
    /// A list of error codes.
    ErrorCodeList(ErrorCodeListDefinition),
    /// A single struct type.
    Struct(StructElement),
    /// A list of struct types.
    StructList(StructListDefinition),
}
24
25impl Definition {
26    pub fn validate_self(&self) -> eyre::Result<()> {
27        match self {
28            Definition::Enum(e) => e.validate_element(),
29            Definition::EnumList(list) => list.validate_element(),
30            Definition::ErrorCodeList(list) => list.validate_element(),
31            Definition::Struct(s) => s.validate_element(),
32            Definition::StructList(list) => list.validate_element(),
33            Definition::EndpointSchema(schema) => schema.validate_element(),
34            Definition::EndpointSchemaList(schemas) => schemas.validate_element(),
35        }
36    }
37}
38
/// Validation hook implemented by every definition type in this module.
///
/// `T` is the implementing type itself: each impl in this file has the form
/// `impl GenElement<X> for X`, and the `where` bound requires `T` to implement the
/// trait for itself.
pub trait GenElement<T: ?Sized>
where
    T: GenElement<T>,
{
    /// Checks that the element is internally consistent (for example, that a wrapped
    /// [`Type`] is the expected variant).
    ///
    /// # Errors
    /// Returns an error describing why the element is invalid.
    fn validate_element(&self) -> eyre::Result<()>;
}
45
/// A single named error code.
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord)]
pub struct ErrorCodeSchema {
    /// Name of the error code.
    pub name: String,
    /// Numeric value of the error code.
    pub code: i64,
    /// Human-readable description; an empty string when absent from the input.
    #[serde(default)]
    pub description: String,
}
53
54impl ErrorCodeSchema {
55    pub fn new(name: impl Into<String>, code: i64, description: impl Into<String>) -> Self {
56        Self {
57            name: name.into(),
58            code,
59            description: description.into(),
60        }
61    }
62}
63
/// A list of error codes; the payload of [`Definition::ErrorCodeList`].
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
pub struct ErrorCodeListDefinition {
    /// The error codes in this list.
    pub codes: Vec<ErrorCodeSchema>,
}
68
69impl GenElement<ErrorCodeListDefinition> for ErrorCodeListDefinition {
70    fn validate_element(&self) -> eyre::Result<()> {
71        Ok(())
72    }
73}
74
/// Wraps the [Type::Enum] variant with extra config
///
/// `inner` is expected to be a [`Type::Enum`]; this is enforced by
/// [`GenElement::validate_element`], not by the type system.
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
pub struct EnumElement {
    /// Code-generation options; every flag is `false` when omitted from the input.
    #[serde(default)]
    pub config: RustGenConfig,
    /// The wrapped enum type.
    pub inner: Type,
}
82
83impl GenElement<EnumElement> for EnumElement {
84    fn validate_element(&self) -> eyre::Result<()> {
85        match &self.inner {
86            Type::Enum { .. } => Ok(()),
87            _ => eyre::bail!("Expected enum type"),
88        }
89    }
90}
91
/// A list of enum definitions sharing one list-level config; the payload of
/// [`Definition::EnumList`].
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
pub struct EnumListDefinition {
    /// List-level code-generation options; every flag is `false` when omitted.
    /// NOTE(review): each element also carries its own `config`; which one the
    /// generator honours is not visible in this file.
    #[serde(default)]
    pub config: RustGenConfig,
    /// The enums in this list; each must wrap a [`Type::Enum`].
    pub enum_elements: Vec<EnumElement>,
}
98
99impl GenElement<EnumListDefinition> for EnumListDefinition {
100    fn validate_element(&self) -> eyre::Result<()> {
101        if self.enum_elements.iter().all(|e| matches!(e.inner, Type::Enum { .. })) {
102            Ok(())
103        } else {
104            eyre::bail!("Not all elements of the EnumListDefinition are Enum types")
105        }
106    }
107}
108
109impl ToRust for EnumElement {
110    fn to_rust_ref(&self, _serde_with: bool) -> String {
111        self.validate_element()
112            .unwrap_or_else(|_| panic!("EnumElement is invalid: {self:?}"));
113
114        let name = match &self.inner {
115            Type::Enum { name, .. } => name.to_case(Case::Pascal),
116            _ => unreachable!("The previous validation ensured that this type is a valid Enum"),
117        };
118
119        if self.config.prefix_enum {
120            format!("Enum{name}")
121        } else {
122            name
123        }
124    }
125
126    fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String {
127        self.validate_element()
128            .unwrap_or_else(|_| panic!("EnumElement is invalid: {self:?}"));
129
130        let code_regex = regex::Regex::new(r"=\s*(\d+)").expect("Error building regex to extract endpoint code");
131
132        match &self.inner {
133            Type::Enum {
134                name: _,
135                variants: fields,
136            } => {
137                let mut fields = fields
138                    .iter()
139                    .map(|x| {
140                        format!(
141                            r#"
142    /// {}
143    {} = {}
144"#,
145                            x.description,
146                            if x.name.chars().last().unwrap().is_lowercase() {
147                                x.name.to_case(Case::Pascal)
148                            } else {
149                                x.name.clone()
150                            },
151                            x.value
152                        )
153                    })
154                    .sorted_by(|a, b| {
155                        // Sort by the endpoint code
156                        let code_a = {
157                            match code_regex.captures(a) {
158                                Some(code) => code[1].parse::<u64>().unwrap_or_else(|err| {
159                                    eprintln!("Sorting error: {err}: Rust output may not be sorted correctly");
160                                    0
161                                }),
162                                None => {
163                                    eprintln!("Sorting error: Rust output may not be sorted correctly");
164                                    0
165                                }
166                            }
167                        };
168
169                        let code_b = {
170                            match code_regex.captures(b) {
171                                Some(code) => code[1].parse::<u64>().unwrap_or_else(|err| {
172                                    eprintln!("Sorting error: {err}: Rust output may not be sorted correctly");
173                                    0
174                                }),
175                                None => {
176                                    eprintln!("Sorting error: Rust output may not be sorted correctly");
177                                    0
178                                }
179                            }
180                        };
181
182                        code_a.cmp(&code_b)
183                    });
184                let enum_content = format!(r#"pub enum {} {{{}}}"#, self.to_rust_ref(serde_with), fields.join(","));
185
186                if add_derives {
187                    self.add_derives(enum_content)
188                } else {
189                    enum_content
190                }
191            }
192            _ => unreachable!(),
193        }
194    }
195
196    fn add_derives(&self, input: String) -> String {
197        if self.config.worktable_support {
198            format!(
199                r#"#[derive(
200                    MemStat,
201                    Archive,
202                    Clone,
203                    Copy,
204                    Debug,
205                    Display,
206                    PartialEq,
207                    PartialOrd,
208                    Eq,
209                    Hash,
210                    Ord,
211                    EnumString,
212                    rkyv::Deserialize,
213                    rkyv::Serialize,
214                    serde::Serialize,
215                    serde::Deserialize,
216                    {}
217                )]
218                #[rkyv(compare(PartialEq), derive(Debug))]
219                #[repr(u8)]
220                {input}
221            "#,
222                if self.config.json_schema_gen {
223                    "JsonSchema,"
224                } else {
225                    Default::default()
226                }
227            )
228        } else {
229            Type::add_default_enum_derives(input)
230        }
231    }
232}
/// A list of struct definitions sharing one list-level config; the payload of
/// [`Definition::StructList`].
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
pub struct StructListDefinition {
    /// List-level code-generation options; every flag is `false` when omitted.
    /// NOTE(review): each element also carries its own `config`; which one the
    /// generator honours is not visible in this file.
    #[serde(default)]
    pub config: RustGenConfig,
    /// The structs in this list; each must wrap a [`Type::Struct`].
    pub struct_elements: Vec<StructElement>,
}
239
240impl GenElement<StructListDefinition> for StructListDefinition {
241    fn validate_element(&self) -> eyre::Result<()> {
242        if self
243            .struct_elements
244            .iter()
245            .all(|s| matches!(s.inner, Type::Struct { .. }))
246        {
247            Ok(())
248        } else {
249            eyre::bail!("Not all elements of the StructListDefinition are Struct types")
250        }
251    }
252}
253
/// Wraps the [Type::Struct] variant with extra config
///
/// `inner` is expected to be a [`Type::Struct`]; this is enforced by
/// [`GenElement::validate_element`], not by the type system.
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
pub struct StructElement {
    /// Code-generation options; every flag is `false` when omitted from the input.
    #[serde(default)]
    pub config: RustGenConfig,
    /// The wrapped struct type.
    pub inner: Type,
}
261
262impl GenElement<StructElement> for StructElement {
263    fn validate_element(&self) -> eyre::Result<()> {
264        match &self.inner {
265            Type::Struct { .. } => Ok(()),
266            _ => eyre::bail!("Expected struct type"),
267        }
268    }
269}
270
271impl ToRust for StructElement {
272    fn to_rust_ref(&self, _serde_with: bool) -> String {
273        self.validate_element()
274            .unwrap_or_else(|_| panic!("StructElement is invalid: {self:?}"));
275
276        match &self.inner {
277            Type::Struct { name, .. } => name.clone(),
278            _ => unreachable!("The previous validation ensured that this type is a valid Struct"),
279        }
280    }
281
282    fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String {
283        self.validate_element()
284            .unwrap_or_else(|_| panic!("StructElement is invalid: {self:?}"));
285
286        let (name, fields) = match &self.inner {
287            Type::Struct { name, fields } => (name.to_case(Case::Pascal), fields),
288            _ => unreachable!("The previous validation ensured that this type is a valid Struct"),
289        };
290
291        let mut fields = fields.iter().map(|x| {
292            let opt = matches!(&x.ty, Type::Optional(_));
293            let serde_with_opt = match &x.ty {
294                Type::BlockchainDecimal => "rust_decimal::serde::str",
295                Type::BlockchainAddress if serde_with => "WithBlockchainAddress",
296                Type::BlockchainTransactionHash if serde_with => "WithBlockchainTransactionHash",
297                // TODO: handle optional decimals
298                // Type::Optional(t) if matches!(**t, Type::BlockchainDecimal) => {
299                //     "WithBlockchainDecimal"
300                // }
301                // Type::Optional(t) if matches!(**t, Type::BlockchainAddress) => {
302                //     "WithBlockchainAddress"
303                // }
304                // Type::Optional(t) if matches!(**t, Type::BlockchainTransactionHash) => {
305                //     "WithBlockchainTransactionHash"
306                // }
307                _ => "",
308            };
309            format!(
310                "{} {} pub {}: {}",
311                if opt { "#[serde(default)]" } else { "" },
312                if serde_with_opt.is_empty() {
313                    "".to_string()
314                } else {
315                    format!("#[serde(with = \"{serde_with_opt}\")]")
316                },
317                x.name,
318                x.ty.to_rust_ref(serde_with)
319            )
320        });
321        let input = format!("pub struct {} {{{}}}", name, fields.join(","));
322
323        if add_derives { self.add_derives(input) } else { input }
324    }
325
326    fn add_derives(&self, input: String) -> String {
327        if self.config.worktable_support {
328            // format!(
329            //     r#"#[derive(
330            //             Clone,
331            //             Copy,
332            //             Debug,
333            //             Default,
334            //             Eq,
335            //             Hash,
336            //             Ord,
337            //             PartialEq,
338            //             PartialOrd,
339            //             derive_more::Display,
340            //             derive_more::From,
341            //             derive_more::FromStr,
342            //             derive_more::Into,
343            //             MemStat,
344            //             rkyv::Archive,
345            //             SizeMeasure,
346            //             rkyv::Deserialize,
347            //             rkyv::Serialize,
348            //             serde::Serialize,
349            //             serde::Deserialize,
350            //             derive_more::Display,
351            //         )]
352            //         #[rkyv(compare(PartialEq), derive(Debug, PartialOrd, PartialEq, Eq, Ord))]
353            //         {input}
354            //     "#
355
356            // TODO: Fix worktable support for structs
357            format!(
358                r#" #[derive(Serialize, Deserialize, Debug, Clone, {})]
359                #[serde(rename_all = "camelCase")]
360                {input}
361            "#,
362                if self.config.json_schema_gen {
363                    "JsonSchema,"
364                } else {
365                    Default::default()
366                }
367            )
368        } else {
369            Type::add_default_struct_derives(input)
370        }
371    }
372}
373
/// Per-definition options controlling Rust code generation.
///
/// Every flag is `false` by default, both via `Default` and when the key is missing
/// from serialized input.
#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, Default)]
pub struct RustGenConfig {
    /// Prefix generated enum type names with `Enum`.
    #[serde(default)]
    pub prefix_enum: bool,
    /// For enums, emit the worktable/rkyv derive set instead of the default derives;
    /// for structs, emit a reduced serde-only derive set instead of the default derives.
    #[serde(default)]
    pub worktable_support: bool,
    /// Add `JsonSchema` to the derive list; in this file it is only consulted when
    /// `worktable_support` is enabled.
    #[serde(default)]
    pub json_schema_gen: bool,
    /// NOTE(review): not read anywhere in this file; presumably selects snake_case
    /// field naming. Confirm against the generator.
    #[serde(default)]
    pub snake_case_fields: bool,
    /// NOTE(review): not read anywhere in this file; confirm semantics with callers.
    #[serde(default)]
    pub override_parent: bool,
}
387
/// A service and the endpoint schemas that belong to it.
#[derive(Serialize, Deserialize, Clone)]
pub struct GenService {
    /// Service name.
    pub name: String,
    /// Numeric service identifier.
    pub id: u16,
    /// Endpoints exposed by this service.
    pub endpoints: Vec<EndpointSchemaElement>,
}
394
395impl GenService {
396    pub fn new(name: String, id: u16, endpoints: Vec<EndpointSchemaElement>) -> Self {
397        Self { name, id, endpoints }
398    }
399}
400
/// A single endpoint schema together with the service it belongs to; the payload of
/// [`Definition::EndpointSchema`].
#[derive(Serialize, Deserialize)]
pub struct EndpointSchemaDefinition {
    /// Name of the owning service.
    pub service_name: String,
    /// Numeric id of the owning service.
    pub service_id: u16,
    /// The endpoint schema and its generator settings.
    pub schema: EndpointSchemaElement,
}
407
/// An [`EndpointSchema`] plus generator-only settings.
///
/// Converting into [`EndpointSchema`] (via `From`) keeps only `schema`, discarding
/// `frontend_facing` and `config`.
#[smart_serde_default]
#[derive(Serialize, Deserialize, SmartDefault, Clone)]
pub struct EndpointSchemaElement {
    /// Defaults to `true` via `#[smart_default(true)]`; `#[smart_serde_default]` presumably
    /// applies the same default when the key is missing during deserialization (confirm).
    /// NOTE(review): not read anywhere in this file.
    #[smart_default(true)]
    pub frontend_facing: bool,
    /// Code-generation options; every flag is `false` when omitted from the input.
    #[serde(default)]
    pub config: RustGenConfig,
    /// The wrapped endpoint schema.
    pub schema: EndpointSchema,
}
417
418impl From<EndpointSchemaElement> for EndpointSchema {
419    fn from(val: EndpointSchemaElement) -> Self {
420        EndpointSchema {
421            name: val.schema.name,
422            code: val.schema.code,
423            parameters: val.schema.parameters,
424            returns: val.schema.returns,
425            stream_response: val.schema.stream_response,
426            description: val.schema.description,
427            json_schema: val.schema.json_schema,
428            roles: val.schema.roles,
429            errors: val.schema.errors,
430        }
431    }
432}
433
434impl FromIterator<EndpointSchemaElement> for Vec<EndpointSchema> {
435    fn from_iter<T: IntoIterator<Item = EndpointSchemaElement>>(iter: T) -> Self {
436        iter.into_iter().map(|element| element.schema).collect()
437    }
438}
439
440impl GenElement<EndpointSchemaDefinition> for EndpointSchemaDefinition {
441    fn validate_element(&self) -> eyre::Result<()> {
442        Ok(())
443    }
444}
445
/// All endpoint schemas of one service; the payload of [`Definition::EndpointSchemaList`].
#[derive(Serialize, Deserialize, DefinitionVariant)]
pub struct EndpointSchemaListDefinition {
    /// Name of the service the endpoints belong to.
    pub service_name: String,
    /// Numeric id of the service the endpoints belong to.
    pub service_id: u16,
    /// List-level code-generation options; every flag is `false` when omitted.
    #[serde(default)]
    pub config: RustGenConfig,
    /// The endpoints of this service.
    pub endpoints: Vec<EndpointSchemaElement>,
}
454
455impl GenElement<EndpointSchemaListDefinition> for EndpointSchemaListDefinition {
456    fn validate_element(&self) -> eyre::Result<()> {
457        Ok(())
458    }
459}