// endpoint_gen/definitions.rs

1use crate::rust::ToRust;
2use convert_case::{Case, Casing};
3use endpoint_gen_macros::DefinitionVariant;
4use endpoint_libs::model::{EndpointSchema, Type};
5use itertools::Itertools;
6use serde::{Deserialize, Serialize};
7use smart_default::SmartDefault;
8use smart_serde_default::smart_serde_default;
9
10/// Marker trait for types that can be used as Definition variants
11/// All types used in Definition must implement this to ensure they are properly validatable
12// pub trait DefinitionVariant: GenElement<Self> {}
13
14#[derive(Serialize, Deserialize)]
15pub enum Definition {
16    EndpointSchema(EndpointSchemaDefinition),
17    EndpointSchemaList(EndpointSchemaListDefinition),
18    Enum(EnumElement),
19    EnumList(EnumListDefinition),
20    Struct(StructElement),
21    StructList(StructListDefinition),
22}
23
24impl Definition {
25    pub fn validate_self(&self) -> eyre::Result<()> {
26        match self {
27            Definition::Enum(e) => e.validate_element(),
28            Definition::EnumList(list) => list.validate_element(),
29            Definition::Struct(s) => s.validate_element(),
30            Definition::StructList(list) => list.validate_element(),
31            Definition::EndpointSchema(schema) => schema.validate_element(),
32            Definition::EndpointSchemaList(schemas) => schemas.validate_element(),
33        }
34    }
35}
36
37pub trait GenElement<T: ?Sized>
38where
39    T: GenElement<T>,
40{
41    fn validate_element(&self) -> eyre::Result<()>;
42}
43
44/// Wraps the [Type::Enum] variant with extra config
45#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
46pub struct EnumElement {
47    #[serde(default)]
48    pub config: RustGenConfig,
49    pub inner: Type,
50}
51
52impl GenElement<EnumElement> for EnumElement {
53    fn validate_element(&self) -> eyre::Result<()> {
54        match &self.inner {
55            Type::Enum { .. } => Ok(()),
56            _ => eyre::bail!("Expected enum type"),
57        }
58    }
59}
60
61#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
62pub struct EnumListDefinition {
63    #[serde(default)]
64    pub config: RustGenConfig,
65    pub enum_elements: Vec<EnumElement>,
66}
67
68impl GenElement<EnumListDefinition> for EnumListDefinition {
69    fn validate_element(&self) -> eyre::Result<()> {
70        if self.enum_elements.iter().all(|e| matches!(e.inner, Type::Enum { .. })) {
71            Ok(())
72        } else {
73            eyre::bail!("Not all elements of the EnumListDefinition are Enum types")
74        }
75    }
76}
77
78impl ToRust for EnumElement {
79    fn to_rust_ref(&self, _serde_with: bool) -> String {
80        self.validate_element()
81            .unwrap_or_else(|_| panic!("EnumElement is invalid: {self:?}"));
82
83        let name = match &self.inner {
84            Type::Enum { name, .. } => name.to_case(Case::Pascal),
85            _ => unreachable!("The previous validation ensured that this type is a valid Enum"),
86        };
87
88        if self.config.prefix_enum {
89            format!("Enum{name}")
90        } else {
91            name
92        }
93    }
94
95    fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String {
96        self.validate_element()
97            .unwrap_or_else(|_| panic!("EnumElement is invalid: {self:?}"));
98
99        let code_regex = regex::Regex::new(r"=\s*(\d+)").expect("Error building regex to extract endpoint code");
100
101        match &self.inner {
102            Type::Enum {
103                name: _,
104                variants: fields,
105            } => {
106                let mut fields = fields
107                    .iter()
108                    .map(|x| {
109                        format!(
110                            r#"
111    /// {}
112    {} = {}
113"#,
114                            x.description,
115                            if x.name.chars().last().unwrap().is_lowercase() {
116                                x.name.to_case(Case::Pascal)
117                            } else {
118                                x.name.clone()
119                            },
120                            x.value
121                        )
122                    })
123                    .sorted_by(|a, b| {
124                        // Sort by the endpoint code
125                        let code_a = {
126                            match code_regex.captures(a) {
127                                Some(code) => code[1].parse::<u64>().unwrap_or_else(|err| {
128                                    eprintln!("Sorting error: {err}: Rust output may not be sorted correctly");
129                                    0
130                                }),
131                                None => {
132                                    eprintln!("Sorting error: Rust output may not be sorted correctly");
133                                    0
134                                }
135                            }
136                        };
137
138                        let code_b = {
139                            match code_regex.captures(b) {
140                                Some(code) => code[1].parse::<u64>().unwrap_or_else(|err| {
141                                    eprintln!("Sorting error: {err}: Rust output may not be sorted correctly");
142                                    0
143                                }),
144                                None => {
145                                    eprintln!("Sorting error: Rust output may not be sorted correctly");
146                                    0
147                                }
148                            }
149                        };
150
151                        code_a.cmp(&code_b)
152                    });
153                let enum_content = format!(r#"pub enum {} {{{}}}"#, self.to_rust_ref(serde_with), fields.join(","));
154
155                if add_derives {
156                    self.add_derives(enum_content)
157                } else {
158                    enum_content
159                }
160            }
161            _ => unreachable!(),
162        }
163    }
164
165    fn add_derives(&self, input: String) -> String {
166        if self.config.worktable_support {
167            format!(
168                r#"#[derive(
169                    MemStat,
170                    Archive,
171                    Clone,
172                    Copy,
173                    Debug,
174                    Display,
175                    PartialEq,
176                    PartialOrd,
177                    Eq,
178                    Hash,
179                    Ord,
180                    EnumString,
181                    rkyv::Deserialize,
182                    rkyv::Serialize,
183                    serde::Serialize,
184                    serde::Deserialize,
185                    {}
186                )]
187                #[rkyv(compare(PartialEq), derive(Debug))]
188                #[repr(u8)]
189                {input}
190            "#,
191                if self.config.json_schema_gen {
192                    "JsonSchema,"
193                } else {
194                    Default::default()
195                }
196            )
197        } else {
198            Type::add_default_enum_derives(input)
199        }
200    }
201}
202#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
203pub struct StructListDefinition {
204    #[serde(default)]
205    pub config: RustGenConfig,
206    pub struct_elements: Vec<StructElement>,
207}
208
209impl GenElement<StructListDefinition> for StructListDefinition {
210    fn validate_element(&self) -> eyre::Result<()> {
211        if self
212            .struct_elements
213            .iter()
214            .all(|s| matches!(s.inner, Type::Struct { .. }))
215        {
216            Ok(())
217        } else {
218            eyre::bail!("Not all elements of the StructListDefinition are Struct types")
219        }
220    }
221}
222
223/// Wraps the [Type::Struct] variant with extra config
224#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
225pub struct StructElement {
226    #[serde(default)]
227    pub config: RustGenConfig,
228    pub inner: Type,
229}
230
231impl GenElement<StructElement> for StructElement {
232    fn validate_element(&self) -> eyre::Result<()> {
233        match &self.inner {
234            Type::Struct { .. } => Ok(()),
235            _ => eyre::bail!("Expected struct type"),
236        }
237    }
238}
239
240impl ToRust for StructElement {
241    fn to_rust_ref(&self, _serde_with: bool) -> String {
242        self.validate_element()
243            .unwrap_or_else(|_| panic!("StructElement is invalid: {self:?}"));
244
245        match &self.inner {
246            Type::Struct { name, .. } => name.clone(),
247            _ => unreachable!("The previous validation ensured that this type is a valid Struct"),
248        }
249    }
250
251    fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String {
252        self.validate_element()
253            .unwrap_or_else(|_| panic!("StructElement is invalid: {self:?}"));
254
255        let (name, fields) = match &self.inner {
256            Type::Struct { name, fields } => (name.to_case(Case::Pascal), fields),
257            _ => unreachable!("The previous validation ensured that this type is a valid Struct"),
258        };
259
260        let mut fields = fields.iter().map(|x| {
261            let opt = matches!(&x.ty, Type::Optional(_));
262            let serde_with_opt = match &x.ty {
263                Type::BlockchainDecimal => "rust_decimal::serde::str",
264                Type::BlockchainAddress if serde_with => "WithBlockchainAddress",
265                Type::BlockchainTransactionHash if serde_with => "WithBlockchainTransactionHash",
266                // TODO: handle optional decimals
267                // Type::Optional(t) if matches!(**t, Type::BlockchainDecimal) => {
268                //     "WithBlockchainDecimal"
269                // }
270                // Type::Optional(t) if matches!(**t, Type::BlockchainAddress) => {
271                //     "WithBlockchainAddress"
272                // }
273                // Type::Optional(t) if matches!(**t, Type::BlockchainTransactionHash) => {
274                //     "WithBlockchainTransactionHash"
275                // }
276                _ => "",
277            };
278            format!(
279                "{} {} pub {}: {}",
280                if opt { "#[serde(default)]" } else { "" },
281                if serde_with_opt.is_empty() {
282                    "".to_string()
283                } else {
284                    format!("#[serde(with = \"{serde_with_opt}\")]")
285                },
286                x.name,
287                x.ty.to_rust_ref(serde_with)
288            )
289        });
290        let input = format!("pub struct {} {{{}}}", name, fields.join(","));
291
292        if add_derives { self.add_derives(input) } else { input }
293    }
294
295    fn add_derives(&self, input: String) -> String {
296        if self.config.worktable_support {
297            // format!(
298            //     r#"#[derive(
299            //             Clone,
300            //             Copy,
301            //             Debug,
302            //             Default,
303            //             Eq,
304            //             Hash,
305            //             Ord,
306            //             PartialEq,
307            //             PartialOrd,
308            //             derive_more::Display,
309            //             derive_more::From,
310            //             derive_more::FromStr,
311            //             derive_more::Into,
312            //             MemStat,
313            //             rkyv::Archive,
314            //             SizeMeasure,
315            //             rkyv::Deserialize,
316            //             rkyv::Serialize,
317            //             serde::Serialize,
318            //             serde::Deserialize,
319            //             derive_more::Display,
320            //         )]
321            //         #[rkyv(compare(PartialEq), derive(Debug, PartialOrd, PartialEq, Eq, Ord))]
322            //         {input}
323            //     "#
324
325            // TODO: Fix worktable support for structs
326            format!(
327                r#" #[derive(Serialize, Deserialize, Debug, Clone, {})]
328                #[serde(rename_all = "camelCase")]
329                {input}
330            "#,
331                if self.config.json_schema_gen {
332                    "JsonSchema,"
333                } else {
334                    Default::default()
335                }
336            )
337        } else {
338            Type::add_default_struct_derives(input)
339        }
340    }
341}
342
343#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, Default)]
344pub struct RustGenConfig {
345    #[serde(default)]
346    pub prefix_enum: bool,
347    #[serde(default)]
348    pub worktable_support: bool,
349    #[serde(default)]
350    pub json_schema_gen: bool,
351    #[serde(default)]
352    pub snake_case_fields: bool,
353    #[serde(default)]
354    pub override_parent: bool,
355}
356
357#[derive(Serialize, Deserialize, Clone)]
358pub struct GenService {
359    pub name: String,
360    pub id: u16,
361    pub endpoints: Vec<EndpointSchemaElement>,
362}
363
364impl GenService {
365    pub fn new(name: String, id: u16, endpoints: Vec<EndpointSchemaElement>) -> Self {
366        Self { name, id, endpoints }
367    }
368}
369
/// A single endpoint schema, tagged with the name and id of the service it belongs to.
#[derive(Serialize, Deserialize)]
pub struct EndpointSchemaDefinition {
    pub service_name: String,
    pub service_id: u16,
    pub schema: EndpointSchemaElement,
}
376
/// An [`EndpointSchema`] plus generator-specific options.
///
/// Defaults come from `SmartDefault`, so `frontend_facing` is `true` by default.
/// NOTE(review): this assumes `#[smart_serde_default]` makes serde use those same defaults
/// for missing fields; confirm.
#[smart_serde_default]
#[derive(Serialize, Deserialize, SmartDefault, Clone)]
pub struct EndpointSchemaElement {
    /// Whether the endpoint is frontend-facing; defaults to `true`.
    #[smart_default(true)]
    pub frontend_facing: bool,
    /// Generation options; the default config is used when absent.
    #[serde(default)]
    pub config: RustGenConfig,
    /// The underlying endpoint schema.
    pub schema: EndpointSchema,
}
386
387impl From<EndpointSchemaElement> for EndpointSchema {
388    fn from(val: EndpointSchemaElement) -> Self {
389        EndpointSchema {
390            name: val.schema.name,
391            code: val.schema.code,
392            parameters: val.schema.parameters,
393            returns: val.schema.returns,
394            stream_response: val.schema.stream_response,
395            description: val.schema.description,
396            json_schema: val.schema.json_schema,
397            roles: val.schema.roles,
398        }
399    }
400}
401
402impl FromIterator<EndpointSchemaElement> for Vec<EndpointSchema> {
403    fn from_iter<T: IntoIterator<Item = EndpointSchemaElement>>(iter: T) -> Self {
404        iter.into_iter().map(|element| element.schema).collect()
405    }
406}
407
impl GenElement<EndpointSchemaDefinition> for EndpointSchemaDefinition {
    /// Performs no checks: every `EndpointSchemaDefinition` is accepted as valid.
    fn validate_element(&self) -> eyre::Result<()> {
        Ok(())
    }
}
413
/// The endpoints of one service, identified by name and id, plus a list-level
/// [`RustGenConfig`].
#[derive(Serialize, Deserialize, DefinitionVariant)]
pub struct EndpointSchemaListDefinition {
    pub service_name: String,
    pub service_id: u16,
    /// List-level generation options; the default config is used when absent.
    #[serde(default)]
    pub config: RustGenConfig,
    pub endpoints: Vec<EndpointSchemaElement>,
}
422
impl GenElement<EndpointSchemaListDefinition> for EndpointSchemaListDefinition {
    /// Performs no checks: every `EndpointSchemaListDefinition` is accepted as valid.
    fn validate_element(&self) -> eyre::Result<()> {
        Ok(())
    }
}