// endpoint_gen/rust.rs

1use crate::definitions::EnumElement;
2use crate::docs::{self, Data};
3use convert_case::{Case, Casing};
4use endpoint_libs::model::{EnumVariant, Field, Type};
5use eyre::bail;
6use itertools::Itertools;
7use std::collections::{BTreeSet, HashMap};
8use std::fs::File;
9use std::io::Write;
10use std::path::Path;
11use std::process::Command;
12
/// Renders a schema type as Rust source code.
pub trait ToRust {
    /// Returns the Rust type expression used to *refer* to this type (e.g. in a field or
    /// generic argument). `serde_with` selects the serde-adapter-friendly spelling for types
    /// that have one.
    fn to_rust_ref(&self, serde_with: bool) -> String;

    /// Returns the Rust *declaration* for this type (e.g. a `pub struct` / `pub enum` item).
    /// Types without a declaration of their own may fall back to their reference form.
    /// When `add_derives` is `true`, the declaration is passed through [`ToRust::add_derives`].
    fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String;

    /// Decorates an already rendered declaration `input` with derive/attribute lines and
    /// returns the result.
    fn add_derives(&self, input: String) -> String;
}
18
19impl ToRust for Type {
20    fn to_rust_ref(&self, serde_with: bool) -> String {
21        match self {
22            Type::UInt32 => "u32".to_owned(),
23            Type::Int32 => "i32".to_owned(),
24            Type::Int64 => "i64".to_owned(),
25            Type::Float64 => "f64".to_owned(),
26            Type::TimeStampMs => "i64".to_owned(),
27            Type::Struct { name, .. } => name.clone(),
28            Type::StructRef(name) => name.clone(),
29            Type::Object => "serde_json::Value".to_owned(),
30            // Type::DataTable { name, .. } => format!("Vec<{name}>"),
31            Type::StructTable { struct_ref } => format!("Vec<{struct_ref}>"),
32            Type::Vec(ele) => {
33                format!("Vec<{}>", ele.to_rust_ref(serde_with))
34            }
35            Type::Unit => "()".to_owned(),
36            Type::Optional(t) => {
37                format!("Option<{}>", t.to_rust_ref(serde_with))
38            }
39            Type::Boolean => "bool".to_owned(),
40            Type::String => "String".to_owned(),
41            Type::Bytea => "Vec<u8>".to_owned(),
42            Type::UUID => "Uuid".to_owned(),
43            Type::NanoId { len } => format!("Nanoid<{len}, Base62Alphabet>"),
44            Type::IpAddr => "IpAddr".to_owned(),
45            Type::Enum { name, .. } => format!("Enum{}", name.to_case(Case::Pascal),),
46            Type::EnumRef { name, prefixed_name } => {
47                if *prefixed_name {
48                    format!("Enum{}", name.to_case(Case::Pascal),)
49                } else {
50                    name.to_case(Case::Pascal)
51                }
52            }
53            Type::BlockchainDecimal => "Decimal".to_owned(),
54            Type::BlockchainAddress if serde_with => "Address".to_owned(),
55            Type::BlockchainTransactionHash if serde_with => "H256".to_owned(),
56            Type::BlockchainAddress => "BlockchainAddress".to_owned(),
57            Type::BlockchainTransactionHash => "BlockchainTransactionHash".to_owned(),
58        }
59    }
60
61    fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String {
62        let code_regex = regex::Regex::new(r"=\s*(\d+)").expect("Error building regex to extract endpoint code");
63
64        match self {
65            Type::Struct { name, fields } => {
66                let mut fields = fields.iter().map(|x| {
67                    let opt = matches!(&x.ty, Type::Optional(_));
68                    let serde_with_opt = match &x.ty {
69                        Type::BlockchainDecimal => "rust_decimal::serde::str",
70                        Type::BlockchainAddress if serde_with => "WithBlockchainAddress",
71                        Type::BlockchainTransactionHash if serde_with => "WithBlockchainTransactionHash",
72                        // TODO: handle optional decimals
73                        // Type::Optional(t) if matches!(**t, Type::BlockchainDecimal) => {
74                        //     "WithBlockchainDecimal"
75                        // }
76                        // Type::Optional(t) if matches!(**t, Type::BlockchainAddress) => {
77                        //     "WithBlockchainAddress"
78                        // }
79                        // Type::Optional(t) if matches!(**t, Type::BlockchainTransactionHash) => {
80                        //     "WithBlockchainTransactionHash"
81                        // }
82                        _ => "",
83                    };
84                    format!(
85                        "{} {} pub {}: {}",
86                        if opt { "#[serde(default)]" } else { "" },
87                        if serde_with_opt.is_empty() {
88                            "".to_string()
89                        } else {
90                            format!("#[serde(with = \"{serde_with_opt}\")]")
91                        },
92                        x.name,
93                        x.ty.to_rust_ref(serde_with)
94                    )
95                });
96                let input = format!("pub struct {} {{{}}}", name, fields.join(","));
97
98                if add_derives { self.add_derives(input) } else { input }
99            }
100            Type::Enum { name, variants: fields } => {
101                let mut fields = fields
102                    .iter()
103                    .map(|x| {
104                        format!(
105                            r#"
106    /// {}
107    {} = {}
108"#,
109                            x.description,
110                            if x.name.chars().last().unwrap().is_lowercase() {
111                                x.name.to_case(Case::Pascal)
112                            } else {
113                                x.name.clone()
114                            },
115                            x.value
116                        )
117                    })
118                    .sorted_by(|a, b| {
119                        // Sort by the endpoint code
120                        let code_a = {
121                            match code_regex.captures(a) {
122                                Some(code) => code[1].parse::<u64>().unwrap_or_else(|err| {
123                                    eprintln!("Sorting error: {err}: Rust output may not be sorted correctly");
124                                    0
125                                }),
126                                None => {
127                                    eprintln!("Sorting error: Rust output may not be sorted correctly");
128                                    0
129                                }
130                            }
131                        };
132
133                        let code_b = {
134                            match code_regex.captures(b) {
135                                Some(code) => code[1].parse::<u64>().unwrap_or_else(|err| {
136                                    eprintln!("Sorting error: {err}: Rust output may not be sorted correctly");
137                                    0
138                                }),
139                                None => {
140                                    eprintln!("Sorting error: Rust output may not be sorted correctly");
141                                    0
142                                }
143                            }
144                        };
145
146                        code_a.cmp(&code_b)
147                    });
148                let enum_content = format!(
149                    r#"pub enum Enum{} {{{}}}"#,
150                    name.to_case(Case::Pascal),
151                    fields.join(",")
152                );
153
154                if add_derives {
155                    self.add_derives(enum_content)
156                } else {
157                    enum_content
158                }
159            }
160            x => x.to_rust_ref(serde_with),
161        }
162    }
163
164    fn add_derives(&self, input: String) -> String {
165        match self {
166            Self::Enum { .. } => Self::add_default_enum_derives(input),
167            Self::Struct { .. } => Self::add_default_struct_derives(input),
168            _ => input,
169        }
170    }
171}
172
173pub fn collect_rust_recursive_types(t: Type) -> Vec<Type> {
174    match t {
175        Type::Struct { ref fields, .. } => {
176            let mut v = vec![t.clone()];
177            for x in fields {
178                v.extend(collect_rust_recursive_types(x.ty.clone()));
179            }
180            v
181        }
182        // Type::DataTable { name, fields } => {
183        //     collect_rust_recursive_types(Type::struct_(name, fields))
184        // }
185        // Type::StructTable { struct_ref } => {
186        //     collect_rust_recursive_types(Type::struct_ref(struct_ref))
187        // }
188        Type::Vec(x) => collect_rust_recursive_types(*x),
189        Type::Optional(x) => collect_rust_recursive_types(*x),
190        _ => vec![],
191    }
192}
193
194pub fn gen_model_rs(data: &Data) -> eyre::Result<()> {
195    let db_filename = data.output_dir.join("model.rs");
196
197    // Ensure the parent directories exist
198    if let Some(parent) = db_filename.parent() {
199        std::fs::create_dir_all(parent)?;
200    }
201
202    let worktable_imports = if data.enums.iter().any(|e| e.config.worktable_support)
203        || data.structs.iter().any(|s| s.config.worktable_support)
204    {
205        r#"use worktable::prelude::*;
206           use rkyv::Archive;
207        "#
208    } else {
209        ""
210    };
211
212    let json_schema_imports = if data.enums.iter().any(|e| e.config.json_schema_gen)
213        || data.structs.iter().any(|s| s.config.json_schema_gen)
214    {
215        r#"use schemars::{schema_for, JsonSchema};"#
216    } else {
217        ""
218    };
219
220    let mut model_file = File::create(&db_filename)?;
221    write!(
222        &mut model_file,
223        "use endpoint_libs::libs::error_code::ErrorCode;
224        use endpoint_libs::libs::ws::*;
225        use endpoint_libs::libs::types::*;
226        use num_derive::FromPrimitive;
227        use serde::*;
228        use strum_macros::{{Display, EnumString}};
229        use uuid::Uuid;
230        use psc_nanoid::{{Nanoid, alphabet::Base62Alphabet}};
231        use std::net::IpAddr;
232        {worktable_imports}
233        {json_schema_imports}
234        ",
235    )?;
236
237    for e in &data.enums {
238        writeln!(&mut model_file, "{}", e.to_rust_decl(false, true))?;
239    }
240    for s in &data.structs {
241        writeln!(&mut model_file, "{}", s.to_rust_decl(false, true))?;
242    }
243    check_endpoint_codes(data, &mut model_file)?;
244    dump_endpoint_schema(data, &mut model_file)?;
245
246    let errors = docs::get_error_messages(&data.project_root)?;
247    let rule = regex::Regex::new(r"\{[\w]+}")?;
248
249    for e in &errors.codes {
250        let name = format!("Error{}", e.symbol.to_case(Case::Pascal));
251        let s = Type::struct_(
252            name,
253            rule.find_iter(&e.message)
254                .map(|m| m.as_str())
255                .map(|s| s.trim_matches('{').trim_matches('}'))
256                .map(|s| Field::new(s.to_string(), Type::String))
257                .collect(),
258        );
259        writeln!(&mut model_file, "{}", s.to_rust_decl(true, true))?;
260    }
261    let enum_ = Type::enum_(
262        "ErrorCode",
263        errors
264            .codes
265            .into_iter()
266            .map(|x| {
267                EnumVariant::new_with_description(
268                    x.symbol.to_case(Case::Pascal),
269                    format!("{} {}", x.source, x.message),
270                    x.code,
271                )
272            })
273            .collect(),
274    );
275    writeln!(&mut model_file, "{}", enum_.to_rust_decl(false, true))?;
276    writeln!(
277        &mut model_file,
278        r#"
279impl From<EnumErrorCode> for ErrorCode {{
280    fn from(e: EnumErrorCode) -> Self {{
281        ErrorCode::new(e as _)
282    }}
283}}
284    "#
285    )?;
286
287    let mut endpoint_reqres_types = BTreeSet::new();
288    for s in &data.services {
289        for e in &s.endpoints {
290            let req = Type::struct_(format!("{}Request", e.schema.name), e.schema.parameters.clone());
291            let resp = Type::struct_(format!("{}Response", e.schema.name), e.schema.returns.clone());
292            endpoint_reqres_types.extend(
293                [
294                    collect_rust_recursive_types(req),
295                    collect_rust_recursive_types(resp),
296                    e.schema
297                        .stream_response
298                        .clone()
299                        .into_iter()
300                        .flat_map(Type::try_unwrap)
301                        .collect::<Vec<_>>(),
302                ]
303                .concat(),
304            );
305        }
306    }
307    for s in endpoint_reqres_types {
308        write!(&mut model_file, r#"{}"#, s.to_rust_decl(true, true))?;
309    }
310
311    for s in &data.services {
312        for endpoint in &s.endpoints {
313            let roles_list = resolve_roles_ids(&endpoint.schema.roles, &data.enums)
314                .into_iter()
315                .map(|x| x.to_string())
316                .join(", ");
317
318            write!(
319                &mut model_file,
320                "
321impl WsRequest for {end_name2}Request {{
322    type Response = {end_name2}Response;
323    const METHOD_ID: u32 = {code};
324    const ROLES: &[u32] = &[{roles_list}];
325    const SCHEMA: &'static str = r#\"{schema}\"#;
326}}
327impl WsResponse for {end_name2}Response {{
328    type Request = {end_name2}Request;
329}}
330",
331                end_name2 = endpoint.schema.name.to_case(Case::Pascal),
332                code = endpoint.schema.code,
333                schema = serde_json::to_string_pretty(&endpoint.schema).unwrap()
334            )?;
335        }
336    }
337    model_file.flush()?;
338    drop(model_file);
339    rustfmt(&db_filename)?;
340
341    Ok(())
342}
343
344/// Resolves the IDs of roles from a list of role names and a list of enum types.
345/// endpoint_roles: vec!["Role1::Value1", "Role1::Value2"]
346fn resolve_roles_ids(endpoint_roles: &Vec<String>, all_enums: &Vec<EnumElement>) -> Vec<i64> {
347    let mut all_enums_typed: HashMap<String, Vec<EnumVariant>> = HashMap::new();
348    for e in all_enums {
349        if let Type::Enum { name: _, variants } = &e.inner {
350            all_enums_typed.insert(e.to_rust_ref(false), variants.clone());
351        }
352    }
353
354    let mut roles_ids = vec![];
355    for role in endpoint_roles {
356        let (role_enum_name, role_variant_name) = role.split_once("::").unwrap_or(("", role.as_str()));
357
358        if let Some(role_enum_variants) = all_enums_typed.get(role_enum_name) {
359            if let Some(role_variant_in_endpoint) = role_enum_variants.iter().find(|v| v.name == role_variant_name) {
360                roles_ids.push(role_variant_in_endpoint.value);
361            } else {
362                eprintln!("Warning: Role variant '{role_variant_name}' not found in enum '{role_enum_name}'");
363            }
364        } else {
365            eprintln!("Warning: Role enum '{role_enum_name}' not found");
366        }
367    }
368    // check there is not duplicate roles ids and print error if there are
369    let mut roles_ids_set: BTreeSet<i64> = BTreeSet::new();
370    for id in &roles_ids {
371        if !roles_ids_set.insert(*id) {
372            eprintln!("Warning: Duplicate role ID found: {id}");
373        }
374    }
375
376    roles_ids_set.into_iter().collect()
377}
378
379pub fn rustfmt(f: &Path) -> eyre::Result<()> {
380    let exit = Command::new("rustfmt")
381        .arg("--edition")
382        .arg("2021")
383        .arg(f)
384        .spawn()?
385        .wait()?;
386    if !exit.success() {
387        bail!("failed to rustfmt {:?}", exit);
388    }
389    Ok(())
390}
391
392pub fn check_endpoint_codes(data: &Data, mut writer: impl Write) -> eyre::Result<()> {
393    let mut variants = vec![];
394    for s in &data.services {
395        for e in &s.endpoints {
396            variants.push(EnumVariant::new(e.schema.name.clone(), e.schema.code as _));
397        }
398    }
399    let enum_ = Type::enum_("Endpoint", variants);
400    writeln!(writer, "{}", enum_.to_rust_decl(false, true))?;
401    // if it compiles, there're no duplicate codes or names
402    Ok(())
403}
404pub fn dump_endpoint_schema(data: &Data, mut writer: impl Write) -> eyre::Result<()> {
405    let mut cases = vec![];
406    for s in &data.services {
407        for e in &s.endpoints {
408            cases.push(format!(
409                "Self::{name} => {name}Request::SCHEMA,",
410                name = e.schema.name.to_case(Case::Pascal),
411            ));
412        }
413    }
414    let code = format!(
415        r#"
416    impl EnumEndpoint {{
417        pub fn schema(&self) -> endpoint_libs::model::EndpointSchema {{
418            let schema = match self {{
419                {cases}
420            }};
421            serde_json::from_str(schema).unwrap()
422        }}
423    }}
424    "#,
425        cases = cases.join("\n")
426    );
427    writeln!(writer, "{code}")?;
428    Ok(())
429}
430
#[cfg(test)]
mod tests {
    use regex::Regex;

    /// Checks that the pattern `=\s*(\d+)` extracts the numeric value from rendered
    /// enum-variant text, whatever whitespace and doc-comment lines surround it.
    #[test]
    fn test_extract_number_from_error_code() {
        let number_after_equals = Regex::new(r"=\s*(\d+)").unwrap();

        let cases: [(&str, u64); 4] = [
            // Doc-comment line, variant line, then the separating comma on its own line
            (
                r#"    ///
      LoginStep2 = 10003
  ,"#,
                10003,
            ),
            // Spaces around `=`, everything on one line
            ("Authorize = 10000,", 10000),
            // No spaces around `=`
            ("SomeError=12345,", 12345),
            // Doc comment containing other digits (`R0019`) before the variant line
            (
                r#"/// SQL R0019 UnauthorizedMessage
    UnauthorizedMessage = 45349677
, "#,
                45349677,
            ),
        ];

        for (text, expected) in cases {
            let caps = number_after_equals.captures(text).expect("Should match");
            let number: u64 = caps[1].parse().expect("Should parse as u64");
            assert_eq!(number, expected, "input: {text:?}");
        }
    }
}
467}