Skip to main content

endpoint_gen/
rust.rs

1use crate::definitions::EnumElement;
2use crate::docs::Data;
3use convert_case::{Case, Casing};
4use endpoint_libs::model::{EndpointErrorSchema, EnumVariant, Type};
5use eyre::bail;
6use itertools::Itertools;
7use std::collections::{BTreeSet, HashMap};
8use std::fs::File;
9use std::io::Write;
10use std::path::Path;
11use std::process::Command;
12
/// Renders schema types as Rust source text for the generated model file.
pub trait ToRust {
    /// Returns the Rust type expression used to *refer* to the type
    /// (for example `Vec<u8>` or `Option<String>`).
    ///
    /// `serde_with` selects the representation meant to be paired with
    /// `#[serde(with = ..)]` field adapters.
    fn to_rust_ref(&self, serde_with: bool) -> String;

    /// Returns a full Rust declaration for the type. When `add_derives` is true the
    /// declaration is expected to be decorated via [`ToRust::add_derives`].
    fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String;

    /// Decorates an already-rendered declaration (`input`) with derive attributes.
    fn add_derives(&self, input: String) -> String;
}
18
19impl ToRust for Type {
20    fn to_rust_ref(&self, serde_with: bool) -> String {
21        match self {
22            Type::UInt32 => "u32".to_owned(),
23            Type::Int32 => "i32".to_owned(),
24            Type::Int64 => "i64".to_owned(),
25            Type::Float64 => "f64".to_owned(),
26            Type::TimeStampMs => "i64".to_owned(),
27            Type::Struct { name, .. } => name.clone(),
28            Type::StructRef(name) => name.clone(),
29            Type::Object => "serde_json::Value".to_owned(),
30            // Type::DataTable { name, .. } => format!("Vec<{name}>"),
31            Type::StructTable { struct_ref } => format!("Vec<{struct_ref}>"),
32            Type::Vec(ele) => {
33                format!("Vec<{}>", ele.to_rust_ref(serde_with))
34            }
35            Type::Unit => "()".to_owned(),
36            Type::Optional(t) => {
37                format!("Option<{}>", t.to_rust_ref(serde_with))
38            }
39            Type::Boolean => "bool".to_owned(),
40            Type::String => "String".to_owned(),
41            Type::Bytea => "Vec<u8>".to_owned(),
42            Type::UUID => "Uuid".to_owned(),
43            Type::NanoId { len } => format!("Nanoid<{len}, Base62Alphabet>"),
44            Type::IpAddr => "IpAddr".to_owned(),
45            Type::Enum { name, .. } => format!("Enum{}", name.to_case(Case::Pascal),),
46            Type::EnumRef { name, prefixed_name } => {
47                if *prefixed_name {
48                    format!("Enum{}", name.to_case(Case::Pascal),)
49                } else {
50                    name.to_case(Case::Pascal)
51                }
52            }
53            Type::BlockchainDecimal => "Decimal".to_owned(),
54            Type::BlockchainAddress if serde_with => "Address".to_owned(),
55            Type::BlockchainTransactionHash if serde_with => "H256".to_owned(),
56            Type::BlockchainAddress => "BlockchainAddress".to_owned(),
57            Type::BlockchainTransactionHash => "BlockchainTransactionHash".to_owned(),
58        }
59    }
60
61    fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String {
62        let code_regex = regex::Regex::new(r"=\s*(\d+)").expect("Error building regex to extract endpoint code");
63
64        match self {
65            Type::Struct { name, fields } => {
66                let mut fields = fields.iter().map(|x| {
67                    let opt = matches!(&x.ty, Type::Optional(_));
68                    let serde_with_opt = match &x.ty {
69                        Type::BlockchainDecimal => "rust_decimal::serde::str",
70                        Type::BlockchainAddress if serde_with => "WithBlockchainAddress",
71                        Type::BlockchainTransactionHash if serde_with => "WithBlockchainTransactionHash",
72                        // TODO: handle optional decimals
73                        // Type::Optional(t) if matches!(**t, Type::BlockchainDecimal) => {
74                        //     "WithBlockchainDecimal"
75                        // }
76                        // Type::Optional(t) if matches!(**t, Type::BlockchainAddress) => {
77                        //     "WithBlockchainAddress"
78                        // }
79                        // Type::Optional(t) if matches!(**t, Type::BlockchainTransactionHash) => {
80                        //     "WithBlockchainTransactionHash"
81                        // }
82                        _ => "",
83                    };
84                    format!(
85                        "{} {} pub {}: {}",
86                        if opt { "#[serde(default)]" } else { "" },
87                        if serde_with_opt.is_empty() {
88                            "".to_string()
89                        } else {
90                            format!("#[serde(with = \"{serde_with_opt}\")]")
91                        },
92                        x.name,
93                        x.ty.to_rust_ref(serde_with)
94                    )
95                });
96                let input = format!("pub struct {} {{{}}}", name, fields.join(","));
97
98                if add_derives { self.add_derives(input) } else { input }
99            }
100            Type::Enum { name, variants: fields } => {
101                let mut fields = fields
102                    .iter()
103                    .map(|x| {
104                        format!(
105                            r#"
106    /// {}
107    {} = {}
108"#,
109                            x.description,
110                            if x.name.chars().last().unwrap().is_lowercase() {
111                                x.name.to_case(Case::Pascal)
112                            } else {
113                                x.name.clone()
114                            },
115                            x.value
116                        )
117                    })
118                    .sorted_by(|a, b| {
119                        // Sort by the endpoint code
120                        let code_a = {
121                            match code_regex.captures(a) {
122                                Some(code) => code[1].parse::<u64>().unwrap_or_else(|err| {
123                                    eprintln!("Sorting error: {err}: Rust output may not be sorted correctly");
124                                    0
125                                }),
126                                None => {
127                                    eprintln!("Sorting error: Rust output may not be sorted correctly");
128                                    0
129                                }
130                            }
131                        };
132
133                        let code_b = {
134                            match code_regex.captures(b) {
135                                Some(code) => code[1].parse::<u64>().unwrap_or_else(|err| {
136                                    eprintln!("Sorting error: {err}: Rust output may not be sorted correctly");
137                                    0
138                                }),
139                                None => {
140                                    eprintln!("Sorting error: Rust output may not be sorted correctly");
141                                    0
142                                }
143                            }
144                        };
145
146                        code_a.cmp(&code_b)
147                    });
148                let enum_content = format!(
149                    r#"pub enum Enum{} {{{}}}"#,
150                    name.to_case(Case::Pascal),
151                    fields.join(",")
152                );
153
154                if add_derives {
155                    self.add_derives(enum_content)
156                } else {
157                    enum_content
158                }
159            }
160            x => x.to_rust_ref(serde_with),
161        }
162    }
163
164    fn add_derives(&self, input: String) -> String {
165        match self {
166            Self::Enum { .. } => Self::add_default_enum_derives(input),
167            Self::Struct { .. } => Self::add_default_struct_derives(input),
168            _ => input,
169        }
170    }
171}
172
173pub fn collect_rust_recursive_types(t: Type) -> Vec<Type> {
174    match t {
175        Type::Struct { ref fields, .. } => {
176            let mut v = vec![t.clone()];
177            for x in fields {
178                v.extend(collect_rust_recursive_types(x.ty.clone()));
179            }
180            v
181        }
182        // Type::DataTable { name, fields } => {
183        //     collect_rust_recursive_types(Type::struct_(name, fields))
184        // }
185        // Type::StructTable { struct_ref } => {
186        //     collect_rust_recursive_types(Type::struct_ref(struct_ref))
187        // }
188        Type::Vec(x) => collect_rust_recursive_types(*x),
189        Type::Optional(x) => collect_rust_recursive_types(*x),
190        _ => vec![],
191    }
192}
193
194fn endpoint_error_enum_name(endpoint_name: &str) -> String {
195    format!("{}Error", endpoint_name.to_case(Case::Pascal))
196}
197
198fn endpoint_error_variant_name(error: &EndpointErrorSchema) -> String {
199    error.name.to_case(Case::Pascal)
200}
201
202pub(crate) fn error_code_variant_name(name: &str) -> String {
203    name.to_case(Case::Pascal)
204}
205
206fn endpoint_error_code_expr(error: &EndpointErrorSchema) -> String {
207    format!("EnumErrorCode::{}", error_code_variant_name(error.code.variant()))
208}
209
/// Renders `value` as a double-quoted Rust string literal that can be pasted into generated code.
///
/// Uses Rust's own escaping (`str::escape_debug`) rather than JSON escaping. JSON emits `\b`,
/// `\f` and `\u00XX` for control characters, and none of those are valid escapes in a Rust
/// string literal. `escape_debug` produces only Rust-valid escapes (`\n`, `\"`, `\\`, `\u{8}`,
/// ...) and leaves printable Unicode as-is.
fn rust_string_literal(value: &str) -> String {
    format!("\"{}\"", value.escape_debug())
}
213
214fn gen_endpoint_error_enum(
215    endpoint_name: &str,
216    errors: &[EndpointErrorSchema],
217    mut writer: impl Write,
218) -> eyre::Result<()> {
219    if errors.is_empty() {
220        return Ok(());
221    }
222
223    let enum_name = endpoint_error_enum_name(endpoint_name);
224    writeln!(
225        writer,
226        "#[derive(Serialize, Deserialize, Debug, Clone)]\npub enum {enum_name} {{"
227    )?;
228
229    for error in errors {
230        let variant_name = endpoint_error_variant_name(error);
231        if !error.message.is_empty() {
232            writeln!(writer, "    /// {}", error.message)?;
233        }
234        if error.fields.is_empty() {
235            writeln!(writer, "    {variant_name},")?;
236        } else {
237            let fields = error
238                .fields
239                .iter()
240                .map(|field| format!("{}: {}", field.name, field.ty.to_rust_ref(true)))
241                .join(", ");
242            writeln!(writer, "    {variant_name} {{ {fields} }},")?;
243        }
244    }
245
246    writeln!(writer, "}}\n")?;
247    writeln!(writer, "impl From<{enum_name}> for CustomError {{")?;
248    writeln!(writer, "    fn from(err: {enum_name}) -> Self {{")?;
249    writeln!(writer, "        match err {{")?;
250
251    for error in errors {
252        let variant_name = endpoint_error_variant_name(error);
253        let code_expr = endpoint_error_code_expr(error);
254        let message = &error.message;
255        let kind = rust_string_literal(&variant_name);
256        if error.fields.is_empty() {
257            writeln!(
258                writer,
259                "            {enum_name}::{variant_name} => CustomError::new({code_expr}).with_message({}).with_kind({kind}),",
260                rust_string_literal(message),
261            )?;
262        } else {
263            let field_names = error.fields.iter().map(|field| field.name.as_str()).join(", ");
264            let json_fields = error
265                .fields
266                .iter()
267                .map(|field| format!(r#""{}": {}"#, field.name.to_case(Case::Camel), field.name))
268                .join(", ");
269            writeln!(
270                writer,
271                "            {enum_name}::{variant_name} {{ {field_names} }} => CustomError::new({code_expr}).with_message({}).with_kind({kind}).with_details(serde_json::json!({{ {json_fields} }})),",
272                rust_string_literal(message),
273            )?;
274        }
275    }
276
277    writeln!(writer, "        }}")?;
278    writeln!(writer, "    }}")?;
279    writeln!(writer, "}}\n")?;
280
281    Ok(())
282}
283
284pub fn gen_model_rs(data: &Data) -> eyre::Result<()> {
285    let db_filename = data.output_dir.join("model.rs");
286
287    // Ensure the parent directories exist
288    if let Some(parent) = db_filename.parent() {
289        std::fs::create_dir_all(parent)?;
290    }
291
292    let worktable_imports = if data.enums.iter().any(|e| e.config.worktable_support)
293        || data.structs.iter().any(|s| s.config.worktable_support)
294    {
295        r#"use worktable::prelude::*;
296           use rkyv::Archive;
297        "#
298    } else {
299        ""
300    };
301
302    let json_schema_imports = if data.enums.iter().any(|e| e.config.json_schema_gen)
303        || data.structs.iter().any(|s| s.config.json_schema_gen)
304    {
305        r#"use schemars::{schema_for, JsonSchema};"#
306    } else {
307        ""
308    };
309
310    let mut model_file = File::create(&db_filename)?;
311    write!(
312        &mut model_file,
313        "use endpoint_libs::libs::error_code::ErrorCode;
314        use endpoint_libs::libs::ws::*;
315        use endpoint_libs::libs::types::*;
316        use endpoint_libs::libs::ws::toolbox::CustomError;
317        use num_derive::FromPrimitive;
318        use serde::*;
319        use strum_macros::{{Display, EnumString}};
320        use uuid::Uuid;
321        use psc_nanoid::{{Nanoid, alphabet::Base62Alphabet}};
322        use std::net::IpAddr;
323        {worktable_imports}
324        {json_schema_imports}
325        ",
326    )?;
327
328    for e in &data.enums {
329        writeln!(&mut model_file, "{}", e.to_rust_decl(false, true))?;
330    }
331    for s in &data.structs {
332        writeln!(&mut model_file, "{}", s.to_rust_decl(false, true))?;
333    }
334    check_endpoint_codes(data, &mut model_file)?;
335    dump_endpoint_schema(data, &mut model_file)?;
336
337    let enum_ = Type::enum_(
338        "ErrorCode",
339        data.error_codes
340            .iter()
341            .map(|x| EnumVariant::new_with_description(error_code_variant_name(&x.name), x.description.clone(), x.code))
342            .collect(),
343    );
344    writeln!(&mut model_file, "{}", enum_.to_rust_decl(false, true))?;
345    writeln!(
346        &mut model_file,
347        r#"
348impl From<EnumErrorCode> for ErrorCode {{
349    fn from(e: EnumErrorCode) -> Self {{
350        ErrorCode::new(e as _)
351    }}
352}}
353    "#
354    )?;
355
356    let mut endpoint_reqres_types = BTreeSet::new();
357    for s in &data.services {
358        for e in &s.endpoints {
359            let req = Type::struct_(format!("{}Request", e.schema.name), e.schema.parameters.clone());
360            let resp = Type::struct_(format!("{}Response", e.schema.name), e.schema.returns.clone());
361            endpoint_reqres_types.extend(
362                [
363                    collect_rust_recursive_types(req),
364                    collect_rust_recursive_types(resp),
365                    e.schema
366                        .stream_response
367                        .clone()
368                        .into_iter()
369                        .flat_map(Type::try_unwrap)
370                        .collect::<Vec<_>>(),
371                    e.schema
372                        .errors
373                        .iter()
374                        .flat_map(|error| {
375                            error
376                                .fields
377                                .iter()
378                                .flat_map(|field| collect_rust_recursive_types(field.ty.clone()))
379                        })
380                        .collect::<Vec<_>>(),
381                ]
382                .concat(),
383            );
384        }
385    }
386    for s in endpoint_reqres_types {
387        write!(&mut model_file, r#"{}"#, s.to_rust_decl(true, true))?;
388    }
389
390    for s in &data.services {
391        for endpoint in &s.endpoints {
392            gen_endpoint_error_enum(&endpoint.schema.name, &endpoint.schema.errors, &mut model_file)?;
393        }
394    }
395
396    for s in &data.services {
397        for endpoint in &s.endpoints {
398            let roles_list = resolve_roles_ids(&endpoint.schema.roles, &data.enums)
399                .into_iter()
400                .map(|x| x.to_string())
401                .join(", ");
402
403            write!(
404                &mut model_file,
405                "
406impl WsRequest for {end_name2}Request {{
407    type Response = {end_name2}Response;
408    const METHOD_ID: u32 = {code};
409    const ROLES: &[u32] = &[{roles_list}];
410    const SCHEMA: &'static str = r#\"{schema}\"#;
411}}
412impl WsResponse for {end_name2}Response {{
413    type Request = {end_name2}Request;
414}}
415",
416                end_name2 = endpoint.schema.name.to_case(Case::Pascal),
417                code = endpoint.schema.code,
418                schema = serde_json::to_string_pretty(&endpoint.schema).unwrap()
419            )?;
420        }
421    }
422    model_file.flush()?;
423    drop(model_file);
424    rustfmt(&db_filename)?;
425
426    Ok(())
427}
428
429/// Resolves the IDs of roles from a list of role names and a list of enum types.
430/// endpoint_roles: vec!["Role1::Value1", "Role1::Value2"]
431fn resolve_roles_ids(endpoint_roles: &Vec<String>, all_enums: &Vec<EnumElement>) -> Vec<i64> {
432    let mut all_enums_typed: HashMap<String, Vec<EnumVariant>> = HashMap::new();
433    for e in all_enums {
434        if let Type::Enum { name: _, variants } = &e.inner {
435            all_enums_typed.insert(e.to_rust_ref(false), variants.clone());
436        }
437    }
438
439    let mut roles_ids = vec![];
440    for role in endpoint_roles {
441        let (role_enum_name, role_variant_name) = role.split_once("::").unwrap_or(("", role.as_str()));
442
443        if let Some(role_enum_variants) = all_enums_typed.get(role_enum_name) {
444            if let Some(role_variant_in_endpoint) = role_enum_variants.iter().find(|v| v.name == role_variant_name) {
445                roles_ids.push(role_variant_in_endpoint.value);
446            } else {
447                eprintln!("Warning: Role variant '{role_variant_name}' not found in enum '{role_enum_name}'");
448            }
449        } else {
450            eprintln!("Warning: Role enum '{role_enum_name}' not found");
451        }
452    }
453    // check there is not duplicate roles ids and print error if there are
454    let mut roles_ids_set: BTreeSet<i64> = BTreeSet::new();
455    for id in &roles_ids {
456        if !roles_ids_set.insert(*id) {
457            eprintln!("Warning: Duplicate role ID found: {id}");
458        }
459    }
460
461    roles_ids_set.into_iter().collect()
462}
463
464pub fn rustfmt(f: &Path) -> eyre::Result<()> {
465    let exit = Command::new("rustfmt")
466        .arg("--edition")
467        .arg("2021")
468        .arg(f)
469        .spawn()?
470        .wait()?;
471    if !exit.success() {
472        bail!("failed to rustfmt {:?}", exit);
473    }
474    Ok(())
475}
476
477pub fn check_endpoint_codes(data: &Data, mut writer: impl Write) -> eyre::Result<()> {
478    let mut variants = vec![];
479    for s in &data.services {
480        for e in &s.endpoints {
481            variants.push(EnumVariant::new(e.schema.name.clone(), e.schema.code as _));
482        }
483    }
484    let enum_ = Type::enum_("Endpoint", variants);
485    writeln!(writer, "{}", enum_.to_rust_decl(false, true))?;
486    // if it compiles, there're no duplicate codes or names
487    Ok(())
488}
489pub fn dump_endpoint_schema(data: &Data, mut writer: impl Write) -> eyre::Result<()> {
490    let mut cases = vec![];
491    for s in &data.services {
492        for e in &s.endpoints {
493            cases.push(format!(
494                "Self::{name} => {name}Request::SCHEMA,",
495                name = e.schema.name.to_case(Case::Pascal),
496            ));
497        }
498    }
499    let code = format!(
500        r#"
501    impl EnumEndpoint {{
502        pub fn schema(&self) -> endpoint_libs::model::EndpointSchema {{
503            let schema = match self {{
504                {cases}
505            }};
506            serde_json::from_str(schema).unwrap()
507        }}
508    }}
509    "#,
510        cases = cases.join("\n")
511    );
512    writeln!(writer, "{code}")?;
513    Ok(())
514}
515
#[cfg(test)]
mod tests {
    use regex::Regex;

    /// The pattern `=\s*(\d+)` must extract the discriminant from rendered enum-variant text,
    /// whatever whitespace (including newlines) surrounds the `=` and the trailing comma.
    #[test]
    fn test_extract_number_from_error_code() {
        let pattern = Regex::new(r"=\s*(\d+)").unwrap();

        let cases: [(&str, u64); 4] = [
            // Newline between the number and the comma
            (
                r#"    ///
      LoginStep2 = 10003
  ,"#,
                10003,
            ),
            // Spaces around `=`, no newline
            ("Authorize = 10000,", 10000),
            // No spaces at all
            ("SomeError=12345,", 12345),
            // Preceded by a doc-comment line, newline before the comma
            (
                r#"/// SQL R0019 UnauthorizedMessage
    UnauthorizedMessage = 45349677
, "#,
                45349677,
            ),
        ];

        for (text, expected) in cases {
            let captures = pattern.captures(text).expect("Should match");
            let number: u64 = captures[1].parse().expect("Should parse as u64");
            assert_eq!(number, expected, "input: {text:?}");
        }
    }
}