Skip to main content

simploxide_bindgen/
commands.rs

//! Turns the `COMMANDS.md` file into an iterator of [`crate::commands::CommandResponse`] parse results.
2
3use convert_case::{Case, Casing as _};
4
5use crate::{
6    parse_utils,
7    types::{
8        DiscriminatedUnionType, RecordType, TopLevelDocs,
9        discriminated_union_type::DiscriminatedUnionVariant,
10    },
11};
12
13pub fn parse(commands_md: &str) -> impl Iterator<Item = Result<CommandResponse, String>> {
14    let mut parser = Parser::default();
15
16    commands_md
17        .split("---")
18        .skip(1)
19        .filter_map(|s| {
20            let trimmed = s.trim();
21            (!trimmed.is_empty()).then_some(trimmed)
22        })
23        .map(move |blk| parser.parse_block(blk))
24}
25
26pub struct CommandResponse {
27    pub command: RecordType,
28    pub response: DiscriminatedUnionType,
29}
30
31/// Generates the provided trait method for the ClientApi trait.
32///
33/// The ClientApi trait definition itself must be generated by the client and should look like this
34///
35/// ```ignore
36/// pub trait ClientApi {
37///     type Error;
38///
39///     fn send_raw(&self, command: String) -> impl Future<Output = Result<Arc<{}>, Self::Error>> + Send;
40///
41///     //..
42/// }
43/// ```
44///
45/// Then the provided methods could be inserted.
46///
47/// For methods that return multiple kinds a valid responses the additional wrapper type should be
48/// generated. You can get this type definition using the
49/// [`CommandResponseTraitMethod::response_wrapper`] method.
50pub struct CommandResponseTraitMethod<'a> {
51    pub command: &'a RecordType,
52    pub response: &'a DiscriminatedUnionType,
53    pub shapes: &'a [RecordType],
54}
55
56impl<'a> CommandResponseTraitMethod<'a> {
57    pub fn new(
58        command: &'a RecordType,
59        response: &'a DiscriminatedUnionType,
60        shapes: &'a [RecordType],
61    ) -> Self {
62        Self {
63            command,
64            response,
65            shapes,
66        }
67    }
68}
69
70impl<'a> CommandResponseTraitMethod<'a> {
71    /// Instead of accepting a command type directly we can accept its arguments and construct it
72    /// internally.
73    ///
74    /// ```ignore
75    ///     fn api_show_my_address(cmd: ApiShowMyAddressCommand) -> ...
76    /// ```
77    ///
78    /// turns into
79    ///
80    /// ```ignore
81    ///     fn api_show_my_address(user_id: i64) -> ... {
82    ///         let cmd = ApiShowMyAddressCommand { user_id };
83    ///     }
84    /// ```
85    ///
86    /// This condition determines when such transformation takes place
87    fn can_inline_args(&self) -> bool {
88        !self
89            .command
90            .fields
91            .iter()
92            .any(|f| f.is_optional() || f.is_bool())
93    }
94
95    /// If response consists only of a single valid variant this variant's inner struct can be
96    /// used directly as a return value of the API method.
97    fn can_inline_response(&self) -> Option<&DiscriminatedUnionVariant> {
98        if self.responses().count() == 1 {
99            self.responses().next()
100        } else {
101            None
102        }
103    }
104
105    fn responses(&self) -> impl Iterator<Item = &'_ DiscriminatedUnionVariant> {
106        self.response.variants.iter()
107    }
108}
109
110impl<'a> std::fmt::Display for CommandResponseTraitMethod<'a> {
111    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112        self.command.write_docs_fmt(f)?;
113        write!(
114            f,
115            "    fn {}(&self",
116            self.command.name.remove_empty().to_case(Case::Snake)
117        )?;
118
119        let (ret_type, is_response_inlined) =
120            if let Some(inlined_variant) = self.can_inline_response() {
121                let typename = inlined_variant.fields[0].typ.clone();
122                (format!("Arc<{typename}>"), true)
123            } else {
124                (self.response.name.clone(), false)
125            };
126
127        if self.can_inline_args() {
128            for field in self.command.fields.iter() {
129                write!(f, ", {}: {}", field.rust_name, field.typ)?;
130            }
131
132            writeln!(
133                f,
134                ") -> impl Future<Output = Result<{ret_type}, Self::Error>> + Send {{ async move {{",
135            )?;
136            write!(f, "        let command = {} {{", self.command.name)?;
137
138            for (ix, field) in self.command.fields.iter().enumerate() {
139                if ix > 0 {
140                    write!(f, ", ")?;
141                }
142
143                write!(f, "{}", field.rust_name)?;
144            }
145            writeln!(f, "}};")?;
146        } else {
147            writeln!(
148                f,
149                ", command: {}) -> impl Future<Output = Result<{ret_type}, Self::Error>> + Send {{ async move {{",
150                self.command.name,
151            )?;
152        }
153
154        writeln!(
155            f,
156            "        let response: {} = self.send(command).await?;",
157            self.response.name
158        )?;
159
160        let into_inner = if is_response_inlined {
161            ".into_inner()"
162        } else {
163            ""
164        };
165
166        writeln!(f, "        Ok(response{})", into_inner)?;
167
168        writeln!(f, "        }}")?;
169        writeln!(f, "    }}")
170    }
171}
172
173/// Use this formatter for command types instead of the standard std::fmt::Display impl of the
174/// [`RecordType`]. This impl strips down all serialization attributes and undocumented fields.
175pub struct CommandFmt<'a>(pub &'a RecordType);
176
177impl std::fmt::Display for CommandFmt<'_> {
178    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179        self.0.write_docs_fmt(f)?;
180
181        writeln!(f, "#[derive(Debug, Clone, PartialEq)]")?;
182        writeln!(f, "#[cfg_attr(feature = \"bon\", derive(::bon::Builder))]")?;
183
184        writeln!(f, "pub struct {} {{", self.0.name)?;
185
186        for field in self.0.fields.iter() {
187            writeln!(f, "    pub {}: {},", field.rust_name, field.typ)?;
188        }
189
190        writeln!(f, "}}")?;
191
192        if self.0.fields.iter().any(|f| f.is_optional() || f.is_bool()) {
193            writeln!(f)?;
194            writeln!(f, "impl {} {{", self.0.name)?;
195            writeln!(
196                f,
197                "    /// Creates a command with all `Option` parameters set to `None` and all `bool` parameters set to false"
198            )?;
199            write!(f, "    pub fn new(")?;
200
201            for (i, field) in self
202                .0
203                .fields
204                .iter()
205                .filter(|f| !f.is_optional() && !f.is_bool())
206                .enumerate()
207            {
208                if i > 0 {
209                    write!(f, ", ")?;
210                }
211                write!(f, "{}: {}", field.rust_name, field.typ)?;
212            }
213            writeln!(f, ") -> Self {{")?;
214
215            writeln!(f, "        Self {{")?;
216            for field in self.0.fields.iter() {
217                if field.is_optional() {
218                    writeln!(f, "            {}: None,", field.rust_name)?;
219                } else if field.is_bool() {
220                    writeln!(f, "            {}: false,", field.rust_name)?;
221                } else {
222                    writeln!(f, "            {},", field.rust_name)?;
223                }
224            }
225            writeln!(f, "        }}")?;
226            writeln!(f, "    }}")?;
227            writeln!(f, "}}")?;
228        }
229
230        Ok(())
231    }
232}
233
234pub struct ResponseFmt<'a>(pub &'a DiscriminatedUnionType);
235
236impl std::fmt::Display for ResponseFmt<'_> {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        writeln!(
239            f,
240            "#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]"
241        )?;
242        writeln!(f, "#[serde(tag = \"type\")]")?;
243        writeln!(f, "pub enum {} {{", self.0.name)?;
244
245        for variant in &self.0.variants {
246            for comment_line in &variant.doc_comments {
247                writeln!(
248                    f,
249                    "    /// {}",
250                    crate::types::convert_doc_links(comment_line)
251                )?;
252            }
253            writeln!(f, "    #[serde(rename = \"{}\")]", variant.api_name)?;
254            writeln!(
255                f,
256                "    {}(Arc<{}>),",
257                variant.rust_name, variant.fields[0].typ
258            )?;
259        }
260        writeln!(f, "}}\n")?;
261
262        writeln!(f, "impl {} {{", self.0.name)?;
263        // Gen getters
264        if self.0.variants.len() > 1 {
265            for var in self.0.variants.iter() {
266                assert_eq!(var.fields.len(), 1, "Discriminated union is not disjointed");
267                assert!(
268                    var.fields[0].rust_name.is_empty(),
269                    "Discriminated union is not disjointed"
270                );
271
272                writeln!(
273                    f,
274                    "    pub fn {}(&self) -> Option<&{}> {{",
275                    var.rust_name.remove_empty().to_case(Case::Snake),
276                    var.fields[0].typ
277                )?;
278
279                writeln!(f, "        if let Self::{}(ret) = self {{", var.rust_name)?;
280                writeln!(f, "            Some(ret)",)?;
281                writeln!(f, "        }} else {{ None }}",)?;
282                writeln!(f, "    }}\n")?;
283            }
284        } else {
285            let var = &self.0.variants[0];
286            assert_eq!(var.fields.len(), 1, "Discriminated union is not disjointed");
287            assert!(
288                var.fields[0].rust_name.is_empty(),
289                "Discriminated union is not disjointed"
290            );
291
292            writeln!(
293                f,
294                "    pub fn into_inner(self) -> Arc<{}> {{",
295                var.fields[0].typ
296            )?;
297
298            writeln!(f, "        match self {{")?;
299            writeln!(f, "            Self::{}(inner) => inner, ", var.rust_name)?;
300            writeln!(f, "        }}",)?;
301            writeln!(f, "    }}\n")?;
302        }
303        writeln!(f, "}}")
304    }
305}
306
307#[derive(Default)]
308struct Parser {
309    current_doc_section: Option<DocSection>,
310}
311
312impl Parser {
313    pub fn parse_block(&mut self, block: &str) -> Result<CommandResponse, String> {
314        self.parser(block.lines().map(str::trim))
315            .map_err(|e| format!("{e} in block\n```\n{block}\n```"))
316    }
317
318    fn parser<'a>(
319        &mut self,
320        mut lines: impl Iterator<Item = &'a str>,
321    ) -> Result<CommandResponse, String> {
322        const DOC_SECTION_PAT: &str = parse_utils::H2;
323        const TYPENAME_PAT: &str = parse_utils::H3;
324        const TYPEKINDS_PAT: &str = parse_utils::BOLD;
325
326        let mut next =
327            parse_utils::skip_empty(&mut lines).ok_or_else(|| "Got an empty block".to_owned())?;
328
329        let mut command_docs: Vec<String> = Vec::new();
330
331        let (typename, mut typekind) = loop {
332            if let Some(section_name) = next.strip_prefix(DOC_SECTION_PAT) {
333                let mut doc_section = DocSection::new(section_name.to_owned());
334
335                next = parse_utils::parse_doc_lines(&mut lines, &mut doc_section.contents, |s| {
336                    s.starts_with(TYPENAME_PAT)
337                })
338                .ok_or_else(|| format!("Failed to find a typename by pattern {TYPENAME_PAT:?} after the doc section"))?;
339
340                self.current_doc_section.replace(doc_section);
341            } else if let Some(name) = next.strip_prefix(TYPENAME_PAT) {
342                next = parse_utils::parse_doc_lines(&mut lines, &mut command_docs, |s| {
343                    s.starts_with(TYPEKINDS_PAT)
344                })
345                .map(|s| s.strip_prefix(TYPEKINDS_PAT).unwrap())
346                .ok_or_else(|| format!("Failed to find a typekind by pattern {TYPEKINDS_PAT:?} after the inner docs "))?;
347
348                break (name, next);
349            }
350        };
351
352        let command_name = typename.to_case(Case::Pascal);
353        let mut command = RecordType::new(command_name.clone(), vec![]);
354
355        loop {
356            if typekind.starts_with("Parameters") {
357                typekind = parse_utils::parse_record_fields(
358                    &mut lines,
359                    &mut command.fields,
360                    |s| s.starts_with(TYPEKINDS_PAT),
361                )?
362                .map(|s| s.strip_prefix(TYPEKINDS_PAT).unwrap())
363                .ok_or_else(|| format!(
364                    "Failed to find a command syntax after parameters by pattern {TYPENAME_PAT:?}"
365                ))?;
366            } else if typekind.starts_with("Syntax") {
367                parse_utils::parse_syntax(&mut lines, &mut command.syntax)?;
368                break;
369            }
370        }
371
372        let mut response_variants: Vec<DiscriminatedUnionVariant> = Vec::with_capacity(4);
373
374        parse_utils::skip_while(&mut lines, |s| !s.starts_with("**Response")).ok_or_else(|| {
375            "Failed to find responses section by pattern \"**Response\"".to_owned()
376        })?;
377
378        let mut variant_docline = Vec::new();
379
380        while let Some(docline) = parse_utils::skip_empty(&mut lines) {
381            if docline.starts_with(TYPEKINDS_PAT) {
382                break;
383            } else {
384                variant_docline.push(docline.to_owned());
385            }
386
387            let (mut variant, next) = parse_utils::parse_discriminated_union_variant(&mut lines)?;
388            assert!(next.map(|s| s.is_empty()).unwrap_or(true));
389            variant.doc_comments = std::mem::take(&mut variant_docline);
390            response_variants.push(variant);
391        }
392
393        let response =
394            DiscriminatedUnionType::new(format!("{command_name}Response"), response_variants);
395
396        if let Some(ref outer_docs) = self.current_doc_section {
397            command
398                .doc_comments
399                .push(format!("### {}", outer_docs.header.clone()));
400
401            command.doc_comments.push(String::new());
402
403            command
404                .doc_comments
405                .extend(outer_docs.contents.iter().cloned());
406
407            command.doc_comments.push(String::new());
408            command.doc_comments.push("----".to_owned());
409            command.doc_comments.push(String::new());
410        }
411
412        command.doc_comments.extend(command_docs);
413        Ok(CommandResponse { command, response })
414    }
415}
416
/// A documentation section of `COMMANDS.md`: its header text and the doc lines beneath it.
#[derive(Default, Clone)]
struct DocSection {
    header: String,
    contents: Vec<String>,
}

impl DocSection {
    /// Creates a section with the given header and no content lines yet.
    fn new(header: String) -> Self {
        Self {
            header,
            ..Self::default()
        }
    }
}