1use crate::definitions::EnumElement;
2use crate::docs::{self, Data};
3use convert_case::{Case, Casing};
4use endpoint_libs::model::{EnumVariant, Field, Type};
5use eyre::bail;
6use itertools::Itertools;
7use std::collections::{BTreeSet, HashMap};
8use std::fs::File;
9use std::io::Write;
10use std::path::Path;
11use std::process::Command;
12
13pub trait ToRust {
14 fn to_rust_ref(&self, serde_with: bool) -> String;
15 fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String;
16 fn add_derives(&self, input: String) -> String;
17}
18
19impl ToRust for Type {
20 fn to_rust_ref(&self, serde_with: bool) -> String {
21 match self {
22 Type::UInt32 => "u32".to_owned(),
23 Type::Int32 => "i32".to_owned(),
24 Type::Int64 => "i64".to_owned(),
25 Type::Float64 => "f64".to_owned(),
26 Type::TimeStampMs => "i64".to_owned(),
27 Type::Struct { name, .. } => name.clone(),
28 Type::StructRef(name) => name.clone(),
29 Type::Object => "serde_json::Value".to_owned(),
30 Type::StructTable { struct_ref } => format!("Vec<{struct_ref}>"),
32 Type::Vec(ele) => {
33 format!("Vec<{}>", ele.to_rust_ref(serde_with))
34 }
35 Type::Unit => "()".to_owned(),
36 Type::Optional(t) => {
37 format!("Option<{}>", t.to_rust_ref(serde_with))
38 }
39 Type::Boolean => "bool".to_owned(),
40 Type::String => "String".to_owned(),
41 Type::Bytea => "Vec<u8>".to_owned(),
42 Type::UUID => "Uuid".to_owned(),
43 Type::NanoId { len } => format!("Nanoid<{len}, Base62Alphabet>"),
44 Type::IpAddr => "IpAddr".to_owned(),
45 Type::Enum { name, .. } => format!("Enum{}", name.to_case(Case::Pascal),),
46 Type::EnumRef { name, prefixed_name } => {
47 if *prefixed_name {
48 format!("Enum{}", name.to_case(Case::Pascal),)
49 } else {
50 name.to_case(Case::Pascal)
51 }
52 }
53 Type::BlockchainDecimal => "Decimal".to_owned(),
54 Type::BlockchainAddress if serde_with => "Address".to_owned(),
55 Type::BlockchainTransactionHash if serde_with => "H256".to_owned(),
56 Type::BlockchainAddress => "BlockchainAddress".to_owned(),
57 Type::BlockchainTransactionHash => "BlockchainTransactionHash".to_owned(),
58 }
59 }
60
61 fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String {
62 let code_regex = regex::Regex::new(r"=\s*(\d+)").expect("Error building regex to extract endpoint code");
63
64 match self {
65 Type::Struct { name, fields } => {
66 let mut fields = fields.iter().map(|x| {
67 let opt = matches!(&x.ty, Type::Optional(_));
68 let serde_with_opt = match &x.ty {
69 Type::BlockchainDecimal => "rust_decimal::serde::str",
70 Type::BlockchainAddress if serde_with => "WithBlockchainAddress",
71 Type::BlockchainTransactionHash if serde_with => "WithBlockchainTransactionHash",
72 _ => "",
83 };
84 format!(
85 "{} {} pub {}: {}",
86 if opt { "#[serde(default)]" } else { "" },
87 if serde_with_opt.is_empty() {
88 "".to_string()
89 } else {
90 format!("#[serde(with = \"{serde_with_opt}\")]")
91 },
92 x.name,
93 x.ty.to_rust_ref(serde_with)
94 )
95 });
96 let input = format!("pub struct {} {{{}}}", name, fields.join(","));
97
98 if add_derives { self.add_derives(input) } else { input }
99 }
100 Type::Enum { name, variants: fields } => {
101 let mut fields = fields
102 .iter()
103 .map(|x| {
104 format!(
105 r#"
106 /// {}
107 {} = {}
108"#,
109 x.description,
110 if x.name.chars().last().unwrap().is_lowercase() {
111 x.name.to_case(Case::Pascal)
112 } else {
113 x.name.clone()
114 },
115 x.value
116 )
117 })
118 .sorted_by(|a, b| {
119 let code_a = {
121 match code_regex.captures(a) {
122 Some(code) => code[1].parse::<u64>().unwrap_or_else(|err| {
123 eprintln!("Sorting error: {err}: Rust output may not be sorted correctly");
124 0
125 }),
126 None => {
127 eprintln!("Sorting error: Rust output may not be sorted correctly");
128 0
129 }
130 }
131 };
132
133 let code_b = {
134 match code_regex.captures(b) {
135 Some(code) => code[1].parse::<u64>().unwrap_or_else(|err| {
136 eprintln!("Sorting error: {err}: Rust output may not be sorted correctly");
137 0
138 }),
139 None => {
140 eprintln!("Sorting error: Rust output may not be sorted correctly");
141 0
142 }
143 }
144 };
145
146 code_a.cmp(&code_b)
147 });
148 let enum_content = format!(
149 r#"pub enum Enum{} {{{}}}"#,
150 name.to_case(Case::Pascal),
151 fields.join(",")
152 );
153
154 if add_derives {
155 self.add_derives(enum_content)
156 } else {
157 enum_content
158 }
159 }
160 x => x.to_rust_ref(serde_with),
161 }
162 }
163
164 fn add_derives(&self, input: String) -> String {
165 match self {
166 Self::Enum { .. } => Self::add_default_enum_derives(input),
167 Self::Struct { .. } => Self::add_default_struct_derives(input),
168 _ => input,
169 }
170 }
171}
172
173pub fn collect_rust_recursive_types(t: Type) -> Vec<Type> {
174 match t {
175 Type::Struct { ref fields, .. } => {
176 let mut v = vec![t.clone()];
177 for x in fields {
178 v.extend(collect_rust_recursive_types(x.ty.clone()));
179 }
180 v
181 }
182 Type::Vec(x) => collect_rust_recursive_types(*x),
189 Type::Optional(x) => collect_rust_recursive_types(*x),
190 _ => vec![],
191 }
192}
193
194pub fn gen_model_rs(data: &Data) -> eyre::Result<()> {
195 let db_filename = data.output_dir.join("model.rs");
196
197 if let Some(parent) = db_filename.parent() {
199 std::fs::create_dir_all(parent)?;
200 }
201
202 let worktable_imports = if data.enums.iter().any(|e| e.config.worktable_support)
203 || data.structs.iter().any(|s| s.config.worktable_support)
204 {
205 r#"use worktable::prelude::*;
206 use rkyv::Archive;
207 "#
208 } else {
209 ""
210 };
211
212 let json_schema_imports = if data.enums.iter().any(|e| e.config.json_schema_gen)
213 || data.structs.iter().any(|s| s.config.json_schema_gen)
214 {
215 r#"use schemars::{schema_for, JsonSchema};"#
216 } else {
217 ""
218 };
219
220 let mut model_file = File::create(&db_filename)?;
221 write!(
222 &mut model_file,
223 "use endpoint_libs::libs::error_code::ErrorCode;
224 use endpoint_libs::libs::ws::*;
225 use endpoint_libs::libs::types::*;
226 use num_derive::FromPrimitive;
227 use serde::*;
228 use strum_macros::{{Display, EnumString}};
229 use uuid::Uuid;
230 use psc_nanoid::{{Nanoid, alphabet::Base62Alphabet}};
231 use std::net::IpAddr;
232 {worktable_imports}
233 {json_schema_imports}
234 ",
235 )?;
236
237 for e in &data.enums {
238 writeln!(&mut model_file, "{}", e.to_rust_decl(false, true))?;
239 }
240 for s in &data.structs {
241 writeln!(&mut model_file, "{}", s.to_rust_decl(false, true))?;
242 }
243 check_endpoint_codes(data, &mut model_file)?;
244 dump_endpoint_schema(data, &mut model_file)?;
245
246 let errors = docs::get_error_messages(&data.project_root)?;
247 let rule = regex::Regex::new(r"\{[\w]+}")?;
248
249 for e in &errors.codes {
250 let name = format!("Error{}", e.symbol.to_case(Case::Pascal));
251 let s = Type::struct_(
252 name,
253 rule.find_iter(&e.message)
254 .map(|m| m.as_str())
255 .map(|s| s.trim_matches('{').trim_matches('}'))
256 .map(|s| Field::new(s.to_string(), Type::String))
257 .collect(),
258 );
259 writeln!(&mut model_file, "{}", s.to_rust_decl(true, true))?;
260 }
261 let enum_ = Type::enum_(
262 "ErrorCode",
263 errors
264 .codes
265 .into_iter()
266 .map(|x| {
267 EnumVariant::new_with_description(
268 x.symbol.to_case(Case::Pascal),
269 format!("{} {}", x.source, x.message),
270 x.code,
271 )
272 })
273 .collect(),
274 );
275 writeln!(&mut model_file, "{}", enum_.to_rust_decl(false, true))?;
276 writeln!(
277 &mut model_file,
278 r#"
279impl From<EnumErrorCode> for ErrorCode {{
280 fn from(e: EnumErrorCode) -> Self {{
281 ErrorCode::new(e as _)
282 }}
283}}
284 "#
285 )?;
286
287 let mut endpoint_reqres_types = BTreeSet::new();
288 for s in &data.services {
289 for e in &s.endpoints {
290 let req = Type::struct_(format!("{}Request", e.schema.name), e.schema.parameters.clone());
291 let resp = Type::struct_(format!("{}Response", e.schema.name), e.schema.returns.clone());
292 endpoint_reqres_types.extend(
293 [
294 collect_rust_recursive_types(req),
295 collect_rust_recursive_types(resp),
296 e.schema
297 .stream_response
298 .clone()
299 .into_iter()
300 .flat_map(Type::try_unwrap)
301 .collect::<Vec<_>>(),
302 ]
303 .concat(),
304 );
305 }
306 }
307 for s in endpoint_reqres_types {
308 write!(&mut model_file, r#"{}"#, s.to_rust_decl(true, true))?;
309 }
310
311 for s in &data.services {
312 for endpoint in &s.endpoints {
313 let roles_list = resolve_roles_ids(&endpoint.schema.roles, &data.enums)
314 .into_iter()
315 .map(|x| x.to_string())
316 .join(", ");
317
318 write!(
319 &mut model_file,
320 "
321impl WsRequest for {end_name2}Request {{
322 type Response = {end_name2}Response;
323 const METHOD_ID: u32 = {code};
324 const ROLES: &[u32] = &[{roles_list}];
325 const SCHEMA: &'static str = r#\"{schema}\"#;
326}}
327impl WsResponse for {end_name2}Response {{
328 type Request = {end_name2}Request;
329}}
330",
331 end_name2 = endpoint.schema.name.to_case(Case::Pascal),
332 code = endpoint.schema.code,
333 schema = serde_json::to_string_pretty(&endpoint.schema).unwrap()
334 )?;
335 }
336 }
337 model_file.flush()?;
338 drop(model_file);
339 rustfmt(&db_filename)?;
340
341 Ok(())
342}
343
344fn resolve_roles_ids(endpoint_roles: &Vec<String>, all_enums: &Vec<EnumElement>) -> Vec<i64> {
347 let mut all_enums_typed: HashMap<String, Vec<EnumVariant>> = HashMap::new();
348 for e in all_enums {
349 if let Type::Enum { name: _, variants } = &e.inner {
350 all_enums_typed.insert(e.to_rust_ref(false), variants.clone());
351 }
352 }
353
354 let mut roles_ids = vec![];
355 for role in endpoint_roles {
356 let (role_enum_name, role_variant_name) = role.split_once("::").unwrap_or(("", role.as_str()));
357
358 if let Some(role_enum_variants) = all_enums_typed.get(role_enum_name) {
359 if let Some(role_variant_in_endpoint) = role_enum_variants.iter().find(|v| v.name == role_variant_name) {
360 roles_ids.push(role_variant_in_endpoint.value);
361 } else {
362 eprintln!("Warning: Role variant '{role_variant_name}' not found in enum '{role_enum_name}'");
363 }
364 } else {
365 eprintln!("Warning: Role enum '{role_enum_name}' not found");
366 }
367 }
368 let mut roles_ids_set: BTreeSet<i64> = BTreeSet::new();
370 for id in &roles_ids {
371 if !roles_ids_set.insert(*id) {
372 eprintln!("Warning: Duplicate role ID found: {id}");
373 }
374 }
375
376 roles_ids_set.into_iter().collect()
377}
378
379pub fn rustfmt(f: &Path) -> eyre::Result<()> {
380 let exit = Command::new("rustfmt")
381 .arg("--edition")
382 .arg("2021")
383 .arg(f)
384 .spawn()?
385 .wait()?;
386 if !exit.success() {
387 bail!("failed to rustfmt {:?}", exit);
388 }
389 Ok(())
390}
391
392pub fn check_endpoint_codes(data: &Data, mut writer: impl Write) -> eyre::Result<()> {
393 let mut variants = vec![];
394 for s in &data.services {
395 for e in &s.endpoints {
396 variants.push(EnumVariant::new(e.schema.name.clone(), e.schema.code as _));
397 }
398 }
399 let enum_ = Type::enum_("Endpoint", variants);
400 writeln!(writer, "{}", enum_.to_rust_decl(false, true))?;
401 Ok(())
403}
404pub fn dump_endpoint_schema(data: &Data, mut writer: impl Write) -> eyre::Result<()> {
405 let mut cases = vec![];
406 for s in &data.services {
407 for e in &s.endpoints {
408 cases.push(format!(
409 "Self::{name} => {name}Request::SCHEMA,",
410 name = e.schema.name.to_case(Case::Pascal),
411 ));
412 }
413 }
414 let code = format!(
415 r#"
416 impl EnumEndpoint {{
417 pub fn schema(&self) -> endpoint_libs::model::EndpointSchema {{
418 let schema = match self {{
419 {cases}
420 }};
421 serde_json::from_str(schema).unwrap()
422 }}
423 }}
424 "#,
425 cases = cases.join("\n")
426 );
427 writeln!(writer, "{code}")?;
428 Ok(())
429}
430
431#[cfg(test)]
432mod tests {
433 use regex::Regex;
434
435 #[test]
436 fn test_extract_number_from_error_code() {
437 let re = Regex::new(r"=\s*(\d+)").unwrap();
438
439 let text1 = r#" ///
441 LoginStep2 = 10003
442 ,"#;
443 let caps1 = re.captures(text1).expect("Should match");
444 let number1: u64 = caps1[1].parse().expect("Should parse as u64");
445 assert_eq!(number1, 10003);
446
447 let text2 = "Authorize = 10000,";
449 let caps2 = re.captures(text2).expect("Should match");
450 let number2: u64 = caps2[1].parse().expect("Should parse as u64");
451 assert_eq!(number2, 10000);
452
453 let text3 = "SomeError=12345,";
455 let caps3 = re.captures(text3).expect("Should match");
456 let number3: u64 = caps3[1].parse().expect("Should parse as u64");
457 assert_eq!(number3, 12345);
458
459 let text4 = r#"/// SQL R0019 UnauthorizedMessage
461 UnauthorizedMessage = 45349677
462, "#;
463 let caps4 = re.captures(text4).expect("Should match");
464 let number4: u64 = caps4[1].parse().expect("Should parse as u64");
465 assert_eq!(number4, 45349677);
466 }
467}