1use crate::rust::ToRust;
2use convert_case::{Case, Casing};
3use endpoint_gen_macros::DefinitionVariant;
4use endpoint_libs::model::{EndpointSchema, Type};
5use itertools::Itertools;
6use serde::{Deserialize, Serialize};
7use smart_default::SmartDefault;
8use smart_serde_default::smart_serde_default;
9
10#[derive(Serialize, Deserialize)]
15pub enum Definition {
16 EndpointSchema(EndpointSchemaDefinition),
17 EndpointSchemaList(EndpointSchemaListDefinition),
18 Enum(EnumElement),
19 EnumList(EnumListDefinition),
20 Struct(StructElement),
21 StructList(StructListDefinition),
22}
23
24impl Definition {
25 pub fn validate_self(&self) -> eyre::Result<()> {
26 match self {
27 Definition::Enum(e) => e.validate_element(),
28 Definition::EnumList(list) => list.validate_element(),
29 Definition::Struct(s) => s.validate_element(),
30 Definition::StructList(list) => list.validate_element(),
31 Definition::EndpointSchema(schema) => schema.validate_element(),
32 Definition::EndpointSchemaList(schemas) => schemas.validate_element(),
33 }
34 }
35}
36
37pub trait GenElement<T: ?Sized>
38where
39 T: GenElement<T>,
40{
41 fn validate_element(&self) -> eyre::Result<()>;
42}
43
44#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
46pub struct EnumElement {
47 #[serde(default)]
48 pub config: RustGenConfig,
49 pub inner: Type,
50}
51
52impl GenElement<EnumElement> for EnumElement {
53 fn validate_element(&self) -> eyre::Result<()> {
54 match &self.inner {
55 Type::Enum { .. } => Ok(()),
56 _ => eyre::bail!("Expected enum type"),
57 }
58 }
59}
60
61#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
62pub struct EnumListDefinition {
63 #[serde(default)]
64 pub config: RustGenConfig,
65 pub enum_elements: Vec<EnumElement>,
66}
67
68impl GenElement<EnumListDefinition> for EnumListDefinition {
69 fn validate_element(&self) -> eyre::Result<()> {
70 if self.enum_elements.iter().all(|e| matches!(e.inner, Type::Enum { .. })) {
71 Ok(())
72 } else {
73 eyre::bail!("Not all elements of the EnumListDefinition are Enum types")
74 }
75 }
76}
77
78impl ToRust for EnumElement {
79 fn to_rust_ref(&self, _serde_with: bool) -> String {
80 self.validate_element()
81 .unwrap_or_else(|_| panic!("EnumElement is invalid: {self:?}"));
82
83 let name = match &self.inner {
84 Type::Enum { name, .. } => name.to_case(Case::Pascal),
85 _ => unreachable!("The previous validation ensured that this type is a valid Enum"),
86 };
87
88 if self.config.prefix_enum {
89 format!("Enum{name}")
90 } else {
91 name
92 }
93 }
94
95 fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String {
96 self.validate_element()
97 .unwrap_or_else(|_| panic!("EnumElement is invalid: {self:?}"));
98
99 let code_regex = regex::Regex::new(r"=\s*(\d+)").expect("Error building regex to extract endpoint code");
100
101 match &self.inner {
102 Type::Enum {
103 name: _,
104 variants: fields,
105 } => {
106 let mut fields = fields
107 .iter()
108 .map(|x| {
109 format!(
110 r#"
111 /// {}
112 {} = {}
113"#,
114 x.description,
115 if x.name.chars().last().unwrap().is_lowercase() {
116 x.name.to_case(Case::Pascal)
117 } else {
118 x.name.clone()
119 },
120 x.value
121 )
122 })
123 .sorted_by(|a, b| {
124 let code_a = {
126 match code_regex.captures(a) {
127 Some(code) => code[1].parse::<u64>().unwrap_or_else(|err| {
128 eprintln!("Sorting error: {err}: Rust output may not be sorted correctly");
129 0
130 }),
131 None => {
132 eprintln!("Sorting error: Rust output may not be sorted correctly");
133 0
134 }
135 }
136 };
137
138 let code_b = {
139 match code_regex.captures(b) {
140 Some(code) => code[1].parse::<u64>().unwrap_or_else(|err| {
141 eprintln!("Sorting error: {err}: Rust output may not be sorted correctly");
142 0
143 }),
144 None => {
145 eprintln!("Sorting error: Rust output may not be sorted correctly");
146 0
147 }
148 }
149 };
150
151 code_a.cmp(&code_b)
152 });
153 let enum_content = format!(r#"pub enum {} {{{}}}"#, self.to_rust_ref(serde_with), fields.join(","));
154
155 if add_derives {
156 self.add_derives(enum_content)
157 } else {
158 enum_content
159 }
160 }
161 _ => unreachable!(),
162 }
163 }
164
165 fn add_derives(&self, input: String) -> String {
166 if self.config.worktable_support {
167 format!(
168 r#"#[derive(
169 MemStat,
170 Archive,
171 Clone,
172 Copy,
173 Debug,
174 Display,
175 PartialEq,
176 PartialOrd,
177 Eq,
178 Hash,
179 Ord,
180 EnumString,
181 rkyv::Deserialize,
182 rkyv::Serialize,
183 serde::Serialize,
184 serde::Deserialize,
185 {}
186 )]
187 #[rkyv(compare(PartialEq), derive(Debug))]
188 #[repr(u8)]
189 {input}
190 "#,
191 if self.config.json_schema_gen {
192 "JsonSchema,"
193 } else {
194 Default::default()
195 }
196 )
197 } else {
198 Type::add_default_enum_derives(input)
199 }
200 }
201}
202#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
203pub struct StructListDefinition {
204 #[serde(default)]
205 pub config: RustGenConfig,
206 pub struct_elements: Vec<StructElement>,
207}
208
209impl GenElement<StructListDefinition> for StructListDefinition {
210 fn validate_element(&self) -> eyre::Result<()> {
211 if self
212 .struct_elements
213 .iter()
214 .all(|s| matches!(s.inner, Type::Struct { .. }))
215 {
216 Ok(())
217 } else {
218 eyre::bail!("Not all elements of the StructListDefinition are Struct types")
219 }
220 }
221}
222
223#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, DefinitionVariant)]
225pub struct StructElement {
226 #[serde(default)]
227 pub config: RustGenConfig,
228 pub inner: Type,
229}
230
231impl GenElement<StructElement> for StructElement {
232 fn validate_element(&self) -> eyre::Result<()> {
233 match &self.inner {
234 Type::Struct { .. } => Ok(()),
235 _ => eyre::bail!("Expected struct type"),
236 }
237 }
238}
239
240impl ToRust for StructElement {
241 fn to_rust_ref(&self, _serde_with: bool) -> String {
242 self.validate_element()
243 .unwrap_or_else(|_| panic!("StructElement is invalid: {self:?}"));
244
245 match &self.inner {
246 Type::Struct { name, .. } => name.clone(),
247 _ => unreachable!("The previous validation ensured that this type is a valid Struct"),
248 }
249 }
250
251 fn to_rust_decl(&self, serde_with: bool, add_derives: bool) -> String {
252 self.validate_element()
253 .unwrap_or_else(|_| panic!("StructElement is invalid: {self:?}"));
254
255 let (name, fields) = match &self.inner {
256 Type::Struct { name, fields } => (name.to_case(Case::Pascal), fields),
257 _ => unreachable!("The previous validation ensured that this type is a valid Struct"),
258 };
259
260 let mut fields = fields.iter().map(|x| {
261 let opt = matches!(&x.ty, Type::Optional(_));
262 let serde_with_opt = match &x.ty {
263 Type::BlockchainDecimal => "rust_decimal::serde::str",
264 Type::BlockchainAddress if serde_with => "WithBlockchainAddress",
265 Type::BlockchainTransactionHash if serde_with => "WithBlockchainTransactionHash",
266 _ => "",
277 };
278 format!(
279 "{} {} pub {}: {}",
280 if opt { "#[serde(default)]" } else { "" },
281 if serde_with_opt.is_empty() {
282 "".to_string()
283 } else {
284 format!("#[serde(with = \"{serde_with_opt}\")]")
285 },
286 x.name,
287 x.ty.to_rust_ref(serde_with)
288 )
289 });
290 let input = format!("pub struct {} {{{}}}", name, fields.join(","));
291
292 if add_derives { self.add_derives(input) } else { input }
293 }
294
295 fn add_derives(&self, input: String) -> String {
296 if self.config.worktable_support {
297 format!(
327 r#" #[derive(Serialize, Deserialize, Debug, Clone, {})]
328 #[serde(rename_all = "camelCase")]
329 {input}
330 "#,
331 if self.config.json_schema_gen {
332 "JsonSchema,"
333 } else {
334 Default::default()
335 }
336 )
337 } else {
338 Type::add_default_struct_derives(input)
339 }
340 }
341}
342
343#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, PartialOrd, Eq, Ord, Default)]
344pub struct RustGenConfig {
345 #[serde(default)]
346 pub prefix_enum: bool,
347 #[serde(default)]
348 pub worktable_support: bool,
349 #[serde(default)]
350 pub json_schema_gen: bool,
351 #[serde(default)]
352 pub snake_case_fields: bool,
353 #[serde(default)]
354 pub override_parent: bool,
355}
356
357#[derive(Serialize, Deserialize, Clone)]
358pub struct GenService {
359 pub name: String,
360 pub id: u16,
361 pub endpoints: Vec<EndpointSchemaElement>,
362}
363
364impl GenService {
365 pub fn new(name: String, id: u16, endpoints: Vec<EndpointSchemaElement>) -> Self {
366 Self { name, id, endpoints }
367 }
368}
369
370#[derive(Serialize, Deserialize)]
371pub struct EndpointSchemaDefinition {
372 pub service_name: String,
373 pub service_id: u16,
374 pub schema: EndpointSchemaElement,
375}
376
377#[smart_serde_default]
378#[derive(Serialize, Deserialize, SmartDefault, Clone)]
379pub struct EndpointSchemaElement {
380 #[smart_default(true)]
381 pub frontend_facing: bool,
382 #[serde(default)]
383 pub config: RustGenConfig,
384 pub schema: EndpointSchema,
385}
386
387impl From<EndpointSchemaElement> for EndpointSchema {
388 fn from(val: EndpointSchemaElement) -> Self {
389 EndpointSchema {
390 name: val.schema.name,
391 code: val.schema.code,
392 parameters: val.schema.parameters,
393 returns: val.schema.returns,
394 stream_response: val.schema.stream_response,
395 description: val.schema.description,
396 json_schema: val.schema.json_schema,
397 roles: val.schema.roles,
398 }
399 }
400}
401
402impl FromIterator<EndpointSchemaElement> for Vec<EndpointSchema> {
403 fn from_iter<T: IntoIterator<Item = EndpointSchemaElement>>(iter: T) -> Self {
404 iter.into_iter().map(|element| element.schema).collect()
405 }
406}
407
408impl GenElement<EndpointSchemaDefinition> for EndpointSchemaDefinition {
409 fn validate_element(&self) -> eyre::Result<()> {
410 Ok(())
411 }
412}
413
414#[derive(Serialize, Deserialize, DefinitionVariant)]
415pub struct EndpointSchemaListDefinition {
416 pub service_name: String,
417 pub service_id: u16,
418 #[serde(default)]
419 pub config: RustGenConfig,
420 pub endpoints: Vec<EndpointSchemaElement>,
421}
422
423impl GenElement<EndpointSchemaListDefinition> for EndpointSchemaListDefinition {
424 fn validate_element(&self) -> eyre::Result<()> {
425 Ok(())
426 }
427}