1use crate::rust::ToRust;
2use convert_case::{Case, Casing};
3use endpoint_gen_macros::DefinitionVariant;
4use endpoint_libs::model::{EndpointSchema, Type};
5use itertools::Itertools;
6use serde::{Deserialize, Serialize};
7use smart_default::SmartDefault;
8use smart_serde_default::smart_serde_default;
9
10#[derive(Serialize, Deserialize)]
15pub enum Definition {
16 EndpointSchema(EndpointSchemaDefinition),
17 EndpointSchemaList(EndpointSchemaListDefinition),
18 Enum(EnumElement),
19 EnumList(EnumListDefinition),
20 ErrorCodeList(ErrorCodeListDefinition),
21 Struct(StructElement),
22 StructList(StructListDefinition),
23}
24
25impl Definition {
26 pub fn validate_self(&self) -> eyre::Result<()> {
27 match self {
28 Definition::Enum(e) => e.validate_element(),
29 Definition::EnumList(list) => list.validate_element(),
30 Definition::ErrorCodeList(list) => list.validate_element(),
31 Definition::Struct(s) => s.validate_element(),
32 Definition::StructList(list) => list.validate_element(),
33 Definition::EndpointSchema(schema) => schema.validate_element(),
34 Definition::EndpointSchemaList(schemas) => schemas.validate_element(),
35 }
36 }
37}
38
39pub trait GenElement<T: ?Sized>
40where
41 T: GenElement<T>,
42{
43 fn validate_element(&self) -> eyre::Result<()>;
44}
45
46#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord)]
47pub struct ErrorCodeSchema {
48 pub name: String,
49 pub code: i64,
50 #[serde(default)]
51 pub description: String,
52}
53
54impl ErrorCodeSchema {
55 pub fn new(name: impl Into<String>, code: i64, description: impl Into<String>) -> Self {
56 Self {
57 name: name.into(),
58 code,
59 description: description.into(),
60 }
61 }
62}
63
64#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
65pub struct ErrorCodeListDefinition {
66 pub codes: Vec<ErrorCodeSchema>,
67}
68
69impl GenElement<ErrorCodeListDefinition> for ErrorCodeListDefinition {
70 fn validate_element(&self) -> eyre::Result<()> {
71 Ok(())
72 }
73}
74
75#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
77pub struct EnumElement {
78 #[serde(default)]
79 pub config: RustGenConfig,
80 pub inner: Type,
81}
82
83impl GenElement<EnumElement> for EnumElement {
84 fn validate_element(&self) -> eyre::Result<()> {
85 match &self.inner {
86 Type::Enum { .. } => Ok(()),
87 _ => eyre::bail!("Expected enum type"),
88 }
89 }
90}
91
92#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
93pub struct EnumListDefinition {
94 #[serde(default)]
95 pub config: RustGenConfig,
96 pub enum_elements: Vec<EnumElement>,
97}
98
99impl GenElement<EnumListDefinition> for EnumListDefinition {
100 fn validate_element(&self) -> eyre::Result<()> {
101 if self.enum_elements.iter().all(|e| matches!(e.inner, Type::Enum { .. })) {
102 Ok(())
103 } else {
104 eyre::bail!("Not all elements of the EnumListDefinition are Enum types")
105 }
106 }
107}
108
109impl ToRust for EnumElement {
110 fn to_rust_ref(&self, _serde_with: bool) -> String {
111 self.validate_element()
112 .unwrap_or_else(|_| panic!("EnumElement is invalid: {self:?}"));
113
114 let name = match &self.inner {
115 Type::Enum { name, .. } => name.to_case(Case::Pascal),
116 _ => unreachable!("The previous validation ensured that this type is a valid Enum"),
117 };
118
119 if self.config.prefix_enum {
120 format!("Enum{name}")
121 } else {
122 name
123 }
124 }
125
126 fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String {
127 self.validate_element()
128 .unwrap_or_else(|_| panic!("EnumElement is invalid: {self:?}"));
129
130 let code_regex = regex::Regex::new(r"=\s*(\d+)").expect("Error building regex to extract endpoint code");
131
132 match &self.inner {
133 Type::Enum {
134 name: _,
135 variants: fields,
136 } => {
137 let mut fields = fields
138 .iter()
139 .map(|x| {
140 format!(
141 r#"
142 /// {}
143 {} = {}
144"#,
145 x.description,
146 if x.name.chars().last().unwrap().is_lowercase() {
147 x.name.to_case(Case::Pascal)
148 } else {
149 x.name.clone()
150 },
151 x.value
152 )
153 })
154 .sorted_by(|a, b| {
155 let code_a = {
157 match code_regex.captures(a) {
158 Some(code) => code[1].parse::<u64>().unwrap_or_else(|err| {
159 eprintln!("Sorting error: {err}: Rust output may not be sorted correctly");
160 0
161 }),
162 None => {
163 eprintln!("Sorting error: Rust output may not be sorted correctly");
164 0
165 }
166 }
167 };
168
169 let code_b = {
170 match code_regex.captures(b) {
171 Some(code) => code[1].parse::<u64>().unwrap_or_else(|err| {
172 eprintln!("Sorting error: {err}: Rust output may not be sorted correctly");
173 0
174 }),
175 None => {
176 eprintln!("Sorting error: Rust output may not be sorted correctly");
177 0
178 }
179 }
180 };
181
182 code_a.cmp(&code_b)
183 });
184 let enum_content = format!(r#"pub enum {} {{{}}}"#, self.to_rust_ref(serde_with), fields.join(","));
185
186 if add_derives {
187 self.add_derives(enum_content)
188 } else {
189 enum_content
190 }
191 }
192 _ => unreachable!(),
193 }
194 }
195
196 fn add_derives(&self, input: String) -> String {
197 if self.config.worktable_support {
198 format!(
199 r#"#[derive(
200 MemStat,
201 Archive,
202 Clone,
203 Copy,
204 Debug,
205 Display,
206 PartialEq,
207 PartialOrd,
208 Eq,
209 Hash,
210 Ord,
211 EnumString,
212 rkyv::Deserialize,
213 rkyv::Serialize,
214 serde::Serialize,
215 serde::Deserialize,
216 {}
217 )]
218 #[rkyv(compare(PartialEq), derive(Debug))]
219 #[repr(u8)]
220 {input}
221 "#,
222 if self.config.json_schema_gen {
223 "JsonSchema,"
224 } else {
225 Default::default()
226 }
227 )
228 } else {
229 Type::add_default_enum_derives(input)
230 }
231 }
232}
233#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
234pub struct StructListDefinition {
235 #[serde(default)]
236 pub config: RustGenConfig,
237 pub struct_elements: Vec<StructElement>,
238}
239
240impl GenElement<StructListDefinition> for StructListDefinition {
241 fn validate_element(&self) -> eyre::Result<()> {
242 if self
243 .struct_elements
244 .iter()
245 .all(|s| matches!(s.inner, Type::Struct { .. }))
246 {
247 Ok(())
248 } else {
249 eyre::bail!("Not all elements of the StructListDefinition are Struct types")
250 }
251 }
252}
253
254#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
256pub struct StructElement {
257 #[serde(default)]
258 pub config: RustGenConfig,
259 pub inner: Type,
260}
261
262impl GenElement<StructElement> for StructElement {
263 fn validate_element(&self) -> eyre::Result<()> {
264 match &self.inner {
265 Type::Struct { .. } => Ok(()),
266 _ => eyre::bail!("Expected struct type"),
267 }
268 }
269}
270
271impl ToRust for StructElement {
272 fn to_rust_ref(&self, _serde_with: bool) -> String {
273 self.validate_element()
274 .unwrap_or_else(|_| panic!("StructElement is invalid: {self:?}"));
275
276 match &self.inner {
277 Type::Struct { name, .. } => name.clone(),
278 _ => unreachable!("The previous validation ensured that this type is a valid Struct"),
279 }
280 }
281
282 fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String {
283 self.validate_element()
284 .unwrap_or_else(|_| panic!("StructElement is invalid: {self:?}"));
285
286 let (name, fields) = match &self.inner {
287 Type::Struct { name, fields } => (name.to_case(Case::Pascal), fields),
288 _ => unreachable!("The previous validation ensured that this type is a valid Struct"),
289 };
290
291 let mut fields = fields.iter().map(|x| {
292 let opt = matches!(&x.ty, Type::Optional(_));
293 let serde_with_opt = match &x.ty {
294 Type::BlockchainDecimal => "rust_decimal::serde::str",
295 Type::BlockchainAddress if serde_with => "WithBlockchainAddress",
296 Type::BlockchainTransactionHash if serde_with => "WithBlockchainTransactionHash",
297 _ => "",
308 };
309 format!(
310 "{} {} pub {}: {}",
311 if opt { "#[serde(default)]" } else { "" },
312 if serde_with_opt.is_empty() {
313 "".to_string()
314 } else {
315 format!("#[serde(with = \"{serde_with_opt}\")]")
316 },
317 x.name,
318 x.ty.to_rust_ref(serde_with)
319 )
320 });
321 let input = format!("pub struct {} {{{}}}", name, fields.join(","));
322
323 if add_derives { self.add_derives(input) } else { input }
324 }
325
326 fn add_derives(&self, input: String) -> String {
327 if self.config.worktable_support {
328 format!(
358 r#" #[derive(Serialize, Deserialize, Debug, Clone, {})]
359 #[serde(rename_all = "camelCase")]
360 {input}
361 "#,
362 if self.config.json_schema_gen {
363 "JsonSchema,"
364 } else {
365 Default::default()
366 }
367 )
368 } else {
369 Type::add_default_struct_derives(input)
370 }
371 }
372}
373
374#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, Default)]
375pub struct RustGenConfig {
376 #[serde(default)]
377 pub prefix_enum: bool,
378 #[serde(default)]
379 pub worktable_support: bool,
380 #[serde(default)]
381 pub json_schema_gen: bool,
382 #[serde(default)]
383 pub snake_case_fields: bool,
384 #[serde(default)]
385 pub override_parent: bool,
386}
387
388#[derive(Serialize, Deserialize, Clone)]
389pub struct GenService {
390 pub name: String,
391 pub id: u16,
392 pub endpoints: Vec<EndpointSchemaElement>,
393}
394
395impl GenService {
396 pub fn new(name: String, id: u16, endpoints: Vec<EndpointSchemaElement>) -> Self {
397 Self { name, id, endpoints }
398 }
399}
400
401#[derive(Serialize, Deserialize)]
402pub struct EndpointSchemaDefinition {
403 pub service_name: String,
404 pub service_id: u16,
405 pub schema: EndpointSchemaElement,
406}
407
408#[smart_serde_default]
409#[derive(Serialize, Deserialize, SmartDefault, Clone)]
410pub struct EndpointSchemaElement {
411 #[smart_default(true)]
412 pub frontend_facing: bool,
413 #[serde(default)]
414 pub config: RustGenConfig,
415 pub schema: EndpointSchema,
416}
417
418impl From<EndpointSchemaElement> for EndpointSchema {
419 fn from(val: EndpointSchemaElement) -> Self {
420 EndpointSchema {
421 name: val.schema.name,
422 code: val.schema.code,
423 parameters: val.schema.parameters,
424 returns: val.schema.returns,
425 stream_response: val.schema.stream_response,
426 description: val.schema.description,
427 json_schema: val.schema.json_schema,
428 roles: val.schema.roles,
429 errors: val.schema.errors,
430 }
431 }
432}
433
434impl FromIterator<EndpointSchemaElement> for Vec<EndpointSchema> {
435 fn from_iter<T: IntoIterator<Item = EndpointSchemaElement>>(iter: T) -> Self {
436 iter.into_iter().map(|element| element.schema).collect()
437 }
438}
439
440impl GenElement<EndpointSchemaDefinition> for EndpointSchemaDefinition {
441 fn validate_element(&self) -> eyre::Result<()> {
442 Ok(())
443 }
444}
445
446#[derive(Serialize, Deserialize, DefinitionVariant)]
447pub struct EndpointSchemaListDefinition {
448 pub service_name: String,
449 pub service_id: u16,
450 #[serde(default)]
451 pub config: RustGenConfig,
452 pub endpoints: Vec<EndpointSchemaElement>,
453}
454
455impl GenElement<EndpointSchemaListDefinition> for EndpointSchemaListDefinition {
456 fn validate_element(&self) -> eyre::Result<()> {
457 Ok(())
458 }
459}