1use crate::definitions::EnumElement;
2use crate::docs::Data;
3use convert_case::{Case, Casing};
4use endpoint_libs::model::{EndpointErrorSchema, EnumVariant, Type};
5use eyre::bail;
6use itertools::Itertools;
7use std::collections::{BTreeSet, HashMap};
8use std::fs::File;
9use std::io::Write;
10use std::path::Path;
11use std::process::Command;
12
/// Conversion of a schema type description into Rust source code.
pub trait ToRust {
    /// Returns the Rust type expression used to *refer* to this type
    /// (for example inside a field declaration or a generic argument).
    ///
    /// `serde_with` lets implementations choose alternative spellings for types
    /// that are (de)serialized through a custom `#[serde(with = ...)]` adapter.
    fn to_rust_ref(&self, serde_with: bool) -> String;

    /// Returns a Rust item declaration for this type.
    ///
    /// When `add_derives` is true, implementations decorate the declaration via
    /// [`ToRust::add_derives`].
    fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String;

    /// Decorates an already-rendered declaration (`input`) with derive
    /// attributes suitable for this type and returns the result.
    fn add_derives(&self, input: String) -> String;
}
18
19impl ToRust for Type {
20 fn to_rust_ref(&self, serde_with: bool) -> String {
21 match self {
22 Type::UInt32 => "u32".to_owned(),
23 Type::Int32 => "i32".to_owned(),
24 Type::Int64 => "i64".to_owned(),
25 Type::Float64 => "f64".to_owned(),
26 Type::TimeStampMs => "i64".to_owned(),
27 Type::Struct { name, .. } => name.clone(),
28 Type::StructRef(name) => name.clone(),
29 Type::Object => "serde_json::Value".to_owned(),
30 Type::StructTable { struct_ref } => format!("Vec<{struct_ref}>"),
32 Type::Vec(ele) => {
33 format!("Vec<{}>", ele.to_rust_ref(serde_with))
34 }
35 Type::Unit => "()".to_owned(),
36 Type::Optional(t) => {
37 format!("Option<{}>", t.to_rust_ref(serde_with))
38 }
39 Type::Boolean => "bool".to_owned(),
40 Type::String => "String".to_owned(),
41 Type::Bytea => "Vec<u8>".to_owned(),
42 Type::UUID => "Uuid".to_owned(),
43 Type::NanoId { len } => format!("Nanoid<{len}, Base62Alphabet>"),
44 Type::IpAddr => "IpAddr".to_owned(),
45 Type::Enum { name, .. } => format!("Enum{}", name.to_case(Case::Pascal),),
46 Type::EnumRef { name, prefixed_name } => {
47 if *prefixed_name {
48 format!("Enum{}", name.to_case(Case::Pascal),)
49 } else {
50 name.to_case(Case::Pascal)
51 }
52 }
53 Type::BlockchainDecimal => "Decimal".to_owned(),
54 Type::BlockchainAddress if serde_with => "Address".to_owned(),
55 Type::BlockchainTransactionHash if serde_with => "H256".to_owned(),
56 Type::BlockchainAddress => "BlockchainAddress".to_owned(),
57 Type::BlockchainTransactionHash => "BlockchainTransactionHash".to_owned(),
58 }
59 }
60
61 fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String {
62 let code_regex = regex::Regex::new(r"=\s*(\d+)").expect("Error building regex to extract endpoint code");
63
64 match self {
65 Type::Struct { name, fields } => {
66 let mut fields = fields.iter().map(|x| {
67 let opt = matches!(&x.ty, Type::Optional(_));
68 let serde_with_opt = match &x.ty {
69 Type::BlockchainDecimal => "rust_decimal::serde::str",
70 Type::BlockchainAddress if serde_with => "WithBlockchainAddress",
71 Type::BlockchainTransactionHash if serde_with => "WithBlockchainTransactionHash",
72 _ => "",
83 };
84 format!(
85 "{} {} pub {}: {}",
86 if opt { "#[serde(default)]" } else { "" },
87 if serde_with_opt.is_empty() {
88 "".to_string()
89 } else {
90 format!("#[serde(with = \"{serde_with_opt}\")]")
91 },
92 x.name,
93 x.ty.to_rust_ref(serde_with)
94 )
95 });
96 let input = format!("pub struct {} {{{}}}", name, fields.join(","));
97
98 if add_derives { self.add_derives(input) } else { input }
99 }
100 Type::Enum { name, variants: fields } => {
101 let mut fields = fields
102 .iter()
103 .map(|x| {
104 format!(
105 r#"
106 /// {}
107 {} = {}
108"#,
109 x.description,
110 if x.name.chars().last().unwrap().is_lowercase() {
111 x.name.to_case(Case::Pascal)
112 } else {
113 x.name.clone()
114 },
115 x.value
116 )
117 })
118 .sorted_by(|a, b| {
119 let code_a = {
121 match code_regex.captures(a) {
122 Some(code) => code[1].parse::<u64>().unwrap_or_else(|err| {
123 eprintln!("Sorting error: {err}: Rust output may not be sorted correctly");
124 0
125 }),
126 None => {
127 eprintln!("Sorting error: Rust output may not be sorted correctly");
128 0
129 }
130 }
131 };
132
133 let code_b = {
134 match code_regex.captures(b) {
135 Some(code) => code[1].parse::<u64>().unwrap_or_else(|err| {
136 eprintln!("Sorting error: {err}: Rust output may not be sorted correctly");
137 0
138 }),
139 None => {
140 eprintln!("Sorting error: Rust output may not be sorted correctly");
141 0
142 }
143 }
144 };
145
146 code_a.cmp(&code_b)
147 });
148 let enum_content = format!(
149 r#"pub enum Enum{} {{{}}}"#,
150 name.to_case(Case::Pascal),
151 fields.join(",")
152 );
153
154 if add_derives {
155 self.add_derives(enum_content)
156 } else {
157 enum_content
158 }
159 }
160 x => x.to_rust_ref(serde_with),
161 }
162 }
163
164 fn add_derives(&self, input: String) -> String {
165 match self {
166 Self::Enum { .. } => Self::add_default_enum_derives(input),
167 Self::Struct { .. } => Self::add_default_struct_derives(input),
168 _ => input,
169 }
170 }
171}
172
173pub fn collect_rust_recursive_types(t: Type) -> Vec<Type> {
174 match t {
175 Type::Struct { ref fields, .. } => {
176 let mut v = vec![t.clone()];
177 for x in fields {
178 v.extend(collect_rust_recursive_types(x.ty.clone()));
179 }
180 v
181 }
182 Type::Vec(x) => collect_rust_recursive_types(*x),
189 Type::Optional(x) => collect_rust_recursive_types(*x),
190 _ => vec![],
191 }
192}
193
194fn endpoint_error_enum_name(endpoint_name: &str) -> String {
195 format!("{}Error", endpoint_name.to_case(Case::Pascal))
196}
197
198fn endpoint_error_variant_name(error: &EndpointErrorSchema) -> String {
199 error.name.to_case(Case::Pascal)
200}
201
202pub(crate) fn error_code_variant_name(name: &str) -> String {
203 name.to_case(Case::Pascal)
204}
205
206fn endpoint_error_code_expr(error: &EndpointErrorSchema) -> String {
207 format!("EnumErrorCode::{}", error_code_variant_name(error.code.variant()))
208}
209
/// Renders `value` as a double-quoted Rust string literal suitable for
/// splicing into generated source.
///
/// Uses `str`'s `Debug` formatting, which escapes quotes, backslashes and
/// control characters with Rust escape syntax (`\n`, `\t`, `\u{8}`, ...).
/// JSON serialization was used before, but it emits escapes such as `\b`,
/// `\f` and `\u0008` that are not valid inside a Rust string literal.
fn rust_string_literal(value: &str) -> String {
    format!("{value:?}")
}
213
214fn gen_endpoint_error_enum(
215 endpoint_name: &str,
216 errors: &[EndpointErrorSchema],
217 mut writer: impl Write,
218) -> eyre::Result<()> {
219 if errors.is_empty() {
220 return Ok(());
221 }
222
223 let enum_name = endpoint_error_enum_name(endpoint_name);
224 writeln!(
225 writer,
226 "#[derive(Serialize, Deserialize, Debug, Clone)]\npub enum {enum_name} {{"
227 )?;
228
229 for error in errors {
230 let variant_name = endpoint_error_variant_name(error);
231 if !error.message.is_empty() {
232 writeln!(writer, " /// {}", error.message)?;
233 }
234 if error.fields.is_empty() {
235 writeln!(writer, " {variant_name},")?;
236 } else {
237 let fields = error
238 .fields
239 .iter()
240 .map(|field| format!("{}: {}", field.name, field.ty.to_rust_ref(true)))
241 .join(", ");
242 writeln!(writer, " {variant_name} {{ {fields} }},")?;
243 }
244 }
245
246 writeln!(writer, "}}\n")?;
247 writeln!(writer, "impl From<{enum_name}> for CustomError {{")?;
248 writeln!(writer, " fn from(err: {enum_name}) -> Self {{")?;
249 writeln!(writer, " match err {{")?;
250
251 for error in errors {
252 let variant_name = endpoint_error_variant_name(error);
253 let code_expr = endpoint_error_code_expr(error);
254 let message = &error.message;
255 let kind = rust_string_literal(&variant_name);
256 if error.fields.is_empty() {
257 writeln!(
258 writer,
259 " {enum_name}::{variant_name} => CustomError::new({code_expr}).with_message({}).with_kind({kind}),",
260 rust_string_literal(message),
261 )?;
262 } else {
263 let field_names = error.fields.iter().map(|field| field.name.as_str()).join(", ");
264 let json_fields = error
265 .fields
266 .iter()
267 .map(|field| format!(r#""{}": {}"#, field.name.to_case(Case::Camel), field.name))
268 .join(", ");
269 writeln!(
270 writer,
271 " {enum_name}::{variant_name} {{ {field_names} }} => CustomError::new({code_expr}).with_message({}).with_kind({kind}).with_details(serde_json::json!({{ {json_fields} }})),",
272 rust_string_literal(message),
273 )?;
274 }
275 }
276
277 writeln!(writer, " }}")?;
278 writeln!(writer, " }}")?;
279 writeln!(writer, "}}\n")?;
280
281 Ok(())
282}
283
284pub fn gen_model_rs(data: &Data) -> eyre::Result<()> {
285 let db_filename = data.output_dir.join("model.rs");
286
287 if let Some(parent) = db_filename.parent() {
289 std::fs::create_dir_all(parent)?;
290 }
291
292 let worktable_imports = if data.enums.iter().any(|e| e.config.worktable_support)
293 || data.structs.iter().any(|s| s.config.worktable_support)
294 {
295 r#"use worktable::prelude::*;
296 use rkyv::Archive;
297 "#
298 } else {
299 ""
300 };
301
302 let json_schema_imports = if data.enums.iter().any(|e| e.config.json_schema_gen)
303 || data.structs.iter().any(|s| s.config.json_schema_gen)
304 {
305 r#"use schemars::{schema_for, JsonSchema};"#
306 } else {
307 ""
308 };
309
310 let mut model_file = File::create(&db_filename)?;
311 write!(
312 &mut model_file,
313 "use endpoint_libs::libs::error_code::ErrorCode;
314 use endpoint_libs::libs::ws::*;
315 use endpoint_libs::libs::types::*;
316 use endpoint_libs::libs::ws::toolbox::CustomError;
317 use num_derive::FromPrimitive;
318 use serde::*;
319 use strum_macros::{{Display, EnumString}};
320 use uuid::Uuid;
321 use psc_nanoid::{{Nanoid, alphabet::Base62Alphabet}};
322 use std::net::IpAddr;
323 {worktable_imports}
324 {json_schema_imports}
325 ",
326 )?;
327
328 for e in &data.enums {
329 writeln!(&mut model_file, "{}", e.to_rust_decl(false, true))?;
330 }
331 for s in &data.structs {
332 writeln!(&mut model_file, "{}", s.to_rust_decl(false, true))?;
333 }
334 check_endpoint_codes(data, &mut model_file)?;
335 dump_endpoint_schema(data, &mut model_file)?;
336
337 let enum_ = Type::enum_(
338 "ErrorCode",
339 data.error_codes
340 .iter()
341 .map(|x| EnumVariant::new_with_description(error_code_variant_name(&x.name), x.description.clone(), x.code))
342 .collect(),
343 );
344 writeln!(&mut model_file, "{}", enum_.to_rust_decl(false, true))?;
345 writeln!(
346 &mut model_file,
347 r#"
348impl From<EnumErrorCode> for ErrorCode {{
349 fn from(e: EnumErrorCode) -> Self {{
350 ErrorCode::new(e as _)
351 }}
352}}
353 "#
354 )?;
355
356 let mut endpoint_reqres_types = BTreeSet::new();
357 for s in &data.services {
358 for e in &s.endpoints {
359 let req = Type::struct_(format!("{}Request", e.schema.name), e.schema.parameters.clone());
360 let resp = Type::struct_(format!("{}Response", e.schema.name), e.schema.returns.clone());
361 endpoint_reqres_types.extend(
362 [
363 collect_rust_recursive_types(req),
364 collect_rust_recursive_types(resp),
365 e.schema
366 .stream_response
367 .clone()
368 .into_iter()
369 .flat_map(Type::try_unwrap)
370 .collect::<Vec<_>>(),
371 e.schema
372 .errors
373 .iter()
374 .flat_map(|error| {
375 error
376 .fields
377 .iter()
378 .flat_map(|field| collect_rust_recursive_types(field.ty.clone()))
379 })
380 .collect::<Vec<_>>(),
381 ]
382 .concat(),
383 );
384 }
385 }
386 for s in endpoint_reqres_types {
387 write!(&mut model_file, r#"{}"#, s.to_rust_decl(true, true))?;
388 }
389
390 for s in &data.services {
391 for endpoint in &s.endpoints {
392 gen_endpoint_error_enum(&endpoint.schema.name, &endpoint.schema.errors, &mut model_file)?;
393 }
394 }
395
396 for s in &data.services {
397 for endpoint in &s.endpoints {
398 let roles_list = resolve_roles_ids(&endpoint.schema.roles, &data.enums)
399 .into_iter()
400 .map(|x| x.to_string())
401 .join(", ");
402
403 write!(
404 &mut model_file,
405 "
406impl WsRequest for {end_name2}Request {{
407 type Response = {end_name2}Response;
408 const METHOD_ID: u32 = {code};
409 const ROLES: &[u32] = &[{roles_list}];
410 const SCHEMA: &'static str = r#\"{schema}\"#;
411}}
412impl WsResponse for {end_name2}Response {{
413 type Request = {end_name2}Request;
414}}
415",
416 end_name2 = endpoint.schema.name.to_case(Case::Pascal),
417 code = endpoint.schema.code,
418 schema = serde_json::to_string_pretty(&endpoint.schema).unwrap()
419 )?;
420 }
421 }
422 model_file.flush()?;
423 drop(model_file);
424 rustfmt(&db_filename)?;
425
426 Ok(())
427}
428
429fn resolve_roles_ids(endpoint_roles: &Vec<String>, all_enums: &Vec<EnumElement>) -> Vec<i64> {
432 let mut all_enums_typed: HashMap<String, Vec<EnumVariant>> = HashMap::new();
433 for e in all_enums {
434 if let Type::Enum { name: _, variants } = &e.inner {
435 all_enums_typed.insert(e.to_rust_ref(false), variants.clone());
436 }
437 }
438
439 let mut roles_ids = vec![];
440 for role in endpoint_roles {
441 let (role_enum_name, role_variant_name) = role.split_once("::").unwrap_or(("", role.as_str()));
442
443 if let Some(role_enum_variants) = all_enums_typed.get(role_enum_name) {
444 if let Some(role_variant_in_endpoint) = role_enum_variants.iter().find(|v| v.name == role_variant_name) {
445 roles_ids.push(role_variant_in_endpoint.value);
446 } else {
447 eprintln!("Warning: Role variant '{role_variant_name}' not found in enum '{role_enum_name}'");
448 }
449 } else {
450 eprintln!("Warning: Role enum '{role_enum_name}' not found");
451 }
452 }
453 let mut roles_ids_set: BTreeSet<i64> = BTreeSet::new();
455 for id in &roles_ids {
456 if !roles_ids_set.insert(*id) {
457 eprintln!("Warning: Duplicate role ID found: {id}");
458 }
459 }
460
461 roles_ids_set.into_iter().collect()
462}
463
464pub fn rustfmt(f: &Path) -> eyre::Result<()> {
465 let exit = Command::new("rustfmt")
466 .arg("--edition")
467 .arg("2021")
468 .arg(f)
469 .spawn()?
470 .wait()?;
471 if !exit.success() {
472 bail!("failed to rustfmt {:?}", exit);
473 }
474 Ok(())
475}
476
477pub fn check_endpoint_codes(data: &Data, mut writer: impl Write) -> eyre::Result<()> {
478 let mut variants = vec![];
479 for s in &data.services {
480 for e in &s.endpoints {
481 variants.push(EnumVariant::new(e.schema.name.clone(), e.schema.code as _));
482 }
483 }
484 let enum_ = Type::enum_("Endpoint", variants);
485 writeln!(writer, "{}", enum_.to_rust_decl(false, true))?;
486 Ok(())
488}
489pub fn dump_endpoint_schema(data: &Data, mut writer: impl Write) -> eyre::Result<()> {
490 let mut cases = vec![];
491 for s in &data.services {
492 for e in &s.endpoints {
493 cases.push(format!(
494 "Self::{name} => {name}Request::SCHEMA,",
495 name = e.schema.name.to_case(Case::Pascal),
496 ));
497 }
498 }
499 let code = format!(
500 r#"
501 impl EnumEndpoint {{
502 pub fn schema(&self) -> endpoint_libs::model::EndpointSchema {{
503 let schema = match self {{
504 {cases}
505 }};
506 serde_json::from_str(schema).unwrap()
507 }}
508 }}
509 "#,
510 cases = cases.join("\n")
511 );
512 writeln!(writer, "{code}")?;
513 Ok(())
514}
515
#[cfg(test)]
mod tests {
    use regex::Regex;

    /// The `= <code>` pattern must find the numeric code in a rendered enum
    /// variant, regardless of leading doc comments, spacing around `=`, or
    /// trailing commas.
    #[test]
    fn test_extract_number_from_error_code() {
        let pattern = Regex::new(r"=\s*(\d+)").unwrap();

        let cases: [(&str, u64); 4] = [
            (
                r#" ///
    LoginStep2 = 10003
,"#,
                10003,
            ),
            ("Authorize = 10000,", 10000),
            ("SomeError=12345,", 12345),
            (
                r#"/// SQL R0019 UnauthorizedMessage
    UnauthorizedMessage = 45349677
, "#,
                45349677,
            ),
        ];

        for (text, expected) in cases {
            let caps = pattern.captures(text).expect("Should match");
            let number: u64 = caps[1].parse().expect("Should parse as u64");
            assert_eq!(number, expected);
        }
    }
}