use convert_case::{Case, Casing as _};
use crate::{
parse_utils,
types::{
DiscriminatedUnionType, Field, RecordType, TopLevelDocs,
discriminated_union_type::DiscriminatedUnionVariant,
},
};
/// Parses the commands markdown into a lazy sequence of commands and their responses.
///
/// The text is split on `---`. Everything before the first separator is skipped. Each
/// remaining piece that is non-empty after trimming is handed, in order, to a single
/// shared [`Parser`]. Each item is the parse result of one such block.
pub fn parse(commands_md: &str) -> impl Iterator<Item = Result<CommandResponse, String>> {
    let mut parser = Parser::default();
    let blocks = commands_md
        .split("---")
        .skip(1)
        .map(str::trim)
        .filter(|block| !block.is_empty());
    blocks.map(move |block| parser.parse_block(block))
}

/// A parsed command paired with the union of its possible responses.
pub struct CommandResponse {
    /// The command record; its fields are the command parameters.
    pub command: RecordType,
    /// The command's response union (the parser names it `<Command>Response`).
    pub response: DiscriminatedUnionType,
}
/// Borrowed inputs needed to render one command as a client trait method.
pub struct CommandResponseTraitMethod<'a> {
    /// The command's parameter record.
    pub command: &'a RecordType,
    /// The union of responses the command may produce.
    pub response: &'a DiscriminatedUnionType,
    /// Record shapes of the response variants. They are zipped pairwise with
    /// `response.variants`, so they presumably share the same order — TODO confirm at call sites.
    pub shapes: &'a [RecordType],
}

impl<'a> CommandResponseTraitMethod<'a> {
    /// Bundles the command, its response union and the response shapes.
    pub fn new(
        command: &'a RecordType,
        response: &'a DiscriminatedUnionType,
        shapes: &'a [RecordType],
    ) -> Self {
        Self { shapes, response, command }
    }
}
impl<'a> CommandResponseTraitMethod<'a> {
    /// Builds the wrapper enum the generated method returns when the command has
    /// more than one non-`ChatCmdError` response.
    ///
    /// Returns `None` when exactly one such response exists; the method then returns it
    /// directly. Otherwise every valid response variant is paired with its shape. A
    /// variant whose shape has exactly one field is rewritten so that its first field is
    /// unnamed and carries that shape field's type.
    pub fn response_wrapper(&self) -> Option<ResponseWrapperFmt> {
        if self.can_inline_response().is_some() {
            return None;
        }
        let variants: Vec<DiscriminatedUnionVariant> = self
            .valid_responses()
            .zip(self.valid_response_shapes())
            .map(|(original, shape)| {
                let mut variant = original.clone();
                if let [only_field] = &shape.fields[..] {
                    // Collapse single-field shapes so the wrapper holds the field type directly.
                    variant.fields[0] = Field {
                        api_name: String::new(),
                        rust_name: String::new(),
                        typ: only_field.typ.clone(),
                    };
                }
                variant
            })
            .collect();
        Some(ResponseWrapperFmt(DiscriminatedUnionType::new(
            self.response_wrapper_name(),
            variants,
        )))
    }

    /// True when no command field is optional or boolean, i.e. every field can be
    /// passed as a plain method argument.
    fn can_inline_args(&self) -> bool {
        self.command
            .fields
            .iter()
            .all(|field| !(field.is_optional() || field.is_bool()))
    }

    /// The sole non-`ChatCmdError` response variant, if there is exactly one.
    fn can_inline_response(&self) -> Option<&DiscriminatedUnionVariant> {
        let mut responses = self.valid_responses();
        let sole = responses.next()?;
        responses.next().is_none().then_some(sole)
    }

    /// The only field of the only non-`ChatCmdError` shape, if there is exactly one
    /// such shape and it has exactly one field.
    fn can_inline_response_shape(&self) -> Option<&Field> {
        let mut shapes = self.valid_response_shapes();
        let sole = shapes.next()?;
        if shapes.next().is_some() {
            return None;
        }
        match &sole.fields[..] {
            [field] => Some(field),
            _ => None,
        }
    }

    /// Response variants other than `ChatCmdError`.
    fn valid_responses(&self) -> impl Iterator<Item = &'_ DiscriminatedUnionVariant> {
        self.response
            .variants
            .iter()
            .filter(|variant| variant.rust_name != "ChatCmdError")
    }

    /// Response shapes other than `ChatCmdError`.
    fn valid_response_shapes(&self) -> impl Iterator<Item = &'_ RecordType> {
        self.shapes
            .iter()
            .filter(|shape| shape.name != "ChatCmdError")
    }

    /// Name of the wrapper enum: the response union's name with an `s` appended.
    fn response_wrapper_name(&self) -> String {
        let base = &self.response.name;
        format!("{base}s")
    }
}
impl<'a> std::fmt::Display for CommandResponseTraitMethod<'a> {
/// Renders the command as an async client trait method.
///
/// Output, in order:
/// - the command's doc comments;
/// - the signature `fn <snake_case name>(&self, ...) -> impl Future<Output = Result<Ret, Self::Error>> + Send`;
/// - a body that sends `command.to_command_string()` through `self.send_raw`,
///   deserializes the JSON into the response enum (`self.response.name`), and maps
///   each variant to `Ok`/`Err`.
///
/// `Ret` depends on the number of non-`ChatCmdError` responses:
/// - Exactly one: `Ret` is `Arc<T>`. `T` is the type of the sole shape field when that
///   shape has exactly one field; otherwise it is the variant's first field type.
/// - Otherwise: `Ret` is the `<response name>s` wrapper enum.
///
/// NOTE(review): the emitted code names `Arc`, `Future`, `serde_json`,
/// `BadResponseError`, `Self::Error` and `self.send_raw`. It assumes these are in scope
/// where the output is placed; nothing here checks that.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.command.write_docs_fmt(f)?;
write!(
f,
" fn {}(&self",
self.command.name.remove_empty().to_case(Case::Snake)
)?;
// `ret_type` goes into the signature. `unwrapped_response_typename` is only read by
// the wrapper-enum match arms further down.
let (ret_type, unwrapped_response_typename) =
if let Some(inlined_variant) = self.can_inline_response() {
let typename = if let Some(field) = self.can_inline_response_shape() {
field.typ.clone()
} else {
inlined_variant.fields[0].typ.clone()
};
(format!("Arc<{typename}>"), typename)
} else {
let typename = self.response_wrapper_name();
(typename.clone(), typename)
};
// No optional/bool fields: each field becomes its own argument and the command
// struct is rebuilt inside the body. Otherwise the caller passes the whole struct
// as a single `command` argument.
if self.can_inline_args() {
for field in self.command.fields.iter() {
write!(f, ", {}: {}", field.rust_name, field.typ)?;
}
writeln!(
f,
") -> impl Future<Output = Result<{ret_type}, Self::Error>> + Send {{ async move {{",
)?;
write!(f, " let command = {} {{", self.command.name)?;
// Field-init shorthand: `Name {a, b}` reuses the argument names.
for (ix, field) in self.command.fields.iter().enumerate() {
if ix > 0 {
write!(f, ", ")?;
}
write!(f, "{}", field.rust_name)?;
}
writeln!(f, "}};")?;
} else {
writeln!(
f,
", command: {}) -> impl Future<Output = Result<{ret_type}, Self::Error>> + Send {{ async move {{",
self.command.name,
)?;
}
writeln!(
f,
" let json = self.send_raw(command.to_command_string()).await?;"
)?;
writeln!(
f,
" // Safe to unwrap because unrecognized JSON goes to undocumented variant"
)?;
writeln!(
f,
" let response = serde_json::from_value(json).unwrap();"
)?;
writeln!(f, " match response {{")?;
// Single valid response: return it, or just its only field, in an `Arc`.
// Several: re-wrap each in the wrapper enum. Single-field shapes are unwrapped
// the same way `response_wrapper` collapses them.
if let Some(variant) = self.can_inline_response() {
if let Some(field) = self.can_inline_response_shape() {
writeln!(
f,
" {}::{}(resp) => Ok(Arc::new(resp.{})),",
self.response.name, variant.rust_name, field.rust_name,
)?;
} else {
writeln!(
f,
" {}::{}(resp) => Ok(Arc::new(resp)),",
self.response.name, variant.rust_name
)?;
}
} else {
for (variant, shape) in self.valid_responses().zip(self.valid_response_shapes()) {
if shape.fields.len() == 1 {
writeln!(
f,
" {resp_name}::{var_name}(resp) => Ok({typename}::{var_name}(Arc::new(resp.{field}))),",
resp_name = self.response.name,
typename = unwrapped_response_typename,
var_name = variant.rust_name,
field = shape.fields[0].rust_name,
)?;
} else {
writeln!(
f,
" {}::{var_name}(resp) => Ok({}::{var_name}(Arc::new(resp))),",
self.response.name,
unwrapped_response_typename,
var_name = variant.rust_name,
)?;
}
}
}
// These two error arms are emitted unconditionally. This assumes every response
// enum has `ChatCmdError` and `Undocumented` variants — TODO confirm.
writeln!(
f,
" {}::ChatCmdError(resp) => Err(BadResponseError::ChatCmdError(Arc::new(resp.chat_error)).into()),",
self.response.name,
)?;
writeln!(
f,
" {}::Undocumented(resp) => Err(BadResponseError::Undocumented(resp).into()),",
self.response.name,
)?;
// Close the `match`, the `async move` block and the fn body.
writeln!(f, " }}")?;
writeln!(f, " }}")?;
writeln!(f, " }}")
}
}
pub struct CommandFmt<'a>(pub &'a RecordType);
impl std::fmt::Display for CommandFmt<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.write_docs_fmt(f)?;
writeln!(f, "#[derive(Debug, Clone, PartialEq)]")?;
writeln!(f, "#[cfg_attr(feature = \"bon\", derive(::bon::Builder))]")?;
writeln!(f, "pub struct {} {{", self.0.name)?;
for field in self.0.fields.iter() {
writeln!(f, " pub {}: {},", field.rust_name, field.typ)?;
}
writeln!(f, "}}")
}
}
/// Renders a response wrapper union as a serde-tagged Rust enum plus accessors.
pub struct ResponseWrapperFmt(pub DiscriminatedUnionType);

impl std::fmt::Display for ResponseWrapperFmt {
    /// Writes a `#[serde(tag = "type")]` enum with one `Arc<T>` tuple variant per union
    /// variant, renamed to the variant's API name. It then writes an inherent `impl`
    /// with one `snake_case` accessor per variant returning `Option<&T>`.
    ///
    /// # Panics
    ///
    /// Panics before writing anything if a variant does not carry exactly one field,
    /// or if that field has a non-empty `rust_name`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let wrapper = &self.0;
        // Validate first. The enum body below indexes `fields[0]`, so checking afterwards
        // would panic with an index-out-of-bounds error, after partial output, instead of
        // this message.
        for var in &wrapper.variants {
            assert_eq!(var.fields.len(), 1, "Discriminated union is not disjointed");
            assert!(
                var.fields[0].rust_name.is_empty(),
                "Discriminated union is not disjointed"
            );
        }

        writeln!(
            f,
            "#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]"
        )?;
        writeln!(f, "#[serde(tag = \"type\")]")?;
        writeln!(f, "pub enum {} {{", wrapper.name)?;
        for variant in &wrapper.variants {
            for comment_line in &variant.doc_comments {
                writeln!(f, " /// {}", comment_line)?;
            }
            writeln!(f, " #[serde(rename = \"{}\")]", variant.api_name)?;
            writeln!(
                f,
                " {}(Arc<{}>),",
                variant.rust_name, variant.fields[0].typ
            )?;
        }
        writeln!(f, "}}\n")?;

        writeln!(f, "impl {} {{", wrapper.name)?;
        for var in &wrapper.variants {
            writeln!(
                f,
                " pub fn {}(&self) -> Option<&{}> {{",
                var.rust_name.remove_empty().to_case(Case::Snake),
                var.fields[0].typ
            )?;
            writeln!(f, " if let Self::{}(ret) = self {{", var.rust_name)?;
            writeln!(f, " Some(ret)")?;
            writeln!(f, " }} else {{ None }}")?;
            writeln!(f, " }}\n")?;
        }
        writeln!(f, "}}")
    }
}
/// Stateful parser for the `---`-separated blocks of the commands markdown.
///
/// The only state carried between blocks is the most recent doc section. Its header
/// and contents are prepended to the docs of every command parsed after it, until a
/// newer doc section replaces it.
#[derive(Default)]
struct Parser {
    /// The latest doc section seen so far, if any.
    current_doc_section: Option<DocSection>,
}

impl Parser {
    /// Parses one markdown block into a command record and its response union.
    ///
    /// Every line is trimmed before parsing.
    ///
    /// # Errors
    ///
    /// Returns the underlying parse error with the full offending block appended.
    pub fn parse_block(&mut self, block: &str) -> Result<CommandResponse, String> {
        self.parser(block.lines().map(str::trim))
            .map_err(|e| format!("{e} in block\n```\n{block}\n```"))
    }

    /// Parses the (trimmed) lines of a single block.
    ///
    /// Expected layout, in the order it is consumed:
    /// 1. any number of doc-section lines (`parse_utils::H2`) with their doc lines; the
    ///    last one is remembered for later blocks;
    /// 2. a typename line (`parse_utils::H3`) followed by the command's doc lines;
    /// 3. typekind lines (`parse_utils::BOLD`): zero or more `Parameters` sections, then
    ///    a `Syntax` section;
    /// 4. a `**Response...` line followed by response variants, each preceded by exactly
    ///    one doc line, until the next typekind line or the end of the block.
    ///
    /// # Errors
    ///
    /// Returns a description of the first structural mismatch. This includes blocks
    /// that start with an unexpected line and unexpected typekinds; before, those
    /// caused an endless loop.
    fn parser<'a>(
        &mut self,
        mut lines: impl Iterator<Item = &'a str>,
    ) -> Result<CommandResponse, String> {
        const DOC_SECTION_PAT: &str = parse_utils::H2;
        const TYPENAME_PAT: &str = parse_utils::H3;
        const TYPEKINDS_PAT: &str = parse_utils::BOLD;

        let mut next =
            parse_utils::skip_empty(&mut lines).ok_or_else(|| "Got an empty block".to_owned())?;
        let mut command_docs: Vec<String> = Vec::new();

        // Header: optional doc sections, then the typename line and its docs.
        let (typename, mut typekind) = loop {
            if let Some(section_name) = next.strip_prefix(DOC_SECTION_PAT) {
                let mut doc_section = DocSection::new(section_name.to_owned());
                next = parse_utils::parse_doc_lines(&mut lines, &mut doc_section.contents, |s| {
                    s.starts_with(TYPENAME_PAT)
                })
                .ok_or_else(|| {
                    format!(
                        "Failed to find a typename by pattern {TYPENAME_PAT:?} after the doc section"
                    )
                })?;
                self.current_doc_section = Some(doc_section);
            } else if let Some(name) = next.strip_prefix(TYPENAME_PAT) {
                next = parse_utils::parse_doc_lines(&mut lines, &mut command_docs, |s| {
                    s.starts_with(TYPEKINDS_PAT)
                })
                .map(|s| s.strip_prefix(TYPEKINDS_PAT).unwrap())
                .ok_or_else(|| {
                    format!(
                        "Failed to find a typekind by pattern {TYPEKINDS_PAT:?} after the inner docs "
                    )
                })?;
                break (name, next);
            } else {
                // `next` would never change here, so looping again would spin forever.
                return Err(format!(
                    "Expected a doc section by pattern {DOC_SECTION_PAT:?} or a typename by pattern {TYPENAME_PAT:?}, got {next:?}"
                ));
            }
        };

        let command_name = typename.to_case(Case::Pascal);
        let mut command = RecordType::new(command_name.clone(), vec![]);

        // Body: any number of `Parameters` sections, terminated by `Syntax`.
        loop {
            if typekind.starts_with("Parameters") {
                typekind = parse_utils::parse_record_fields(
                    &mut lines,
                    &mut command.fields,
                    |s| s.starts_with(TYPEKINDS_PAT),
                )?
                .map(|s| s.strip_prefix(TYPEKINDS_PAT).unwrap())
                .ok_or_else(|| {
                    format!(
                        "Failed to find a command syntax after parameters by pattern {TYPEKINDS_PAT:?}"
                    )
                })?;
            } else if typekind.starts_with("Syntax") {
                parse_utils::parse_syntax(&mut lines, &mut command.syntax)?;
                break;
            } else {
                // `typekind` would never change here, so looping again would spin forever.
                return Err(format!(
                    "Unexpected typekind {typekind:?}: expected \"Parameters\" or \"Syntax\""
                ));
            }
        }

        parse_utils::skip_while(&mut lines, |s| !s.starts_with("**Response")).ok_or_else(|| {
            "Failed to find responses section by pattern \"**Response\"".to_owned()
        })?;

        // Responses: one doc line, then the variant itself, repeated until the next
        // typekind line or the end of the block.
        let mut response_variants: Vec<DiscriminatedUnionVariant> = Vec::with_capacity(4);
        while let Some(docline) = parse_utils::skip_empty(&mut lines) {
            if docline.starts_with(TYPEKINDS_PAT) {
                break;
            }
            let (mut variant, trailing) =
                parse_utils::parse_discriminated_union_variant(&mut lines)?;
            if let Some(extra) = trailing.filter(|s| !s.is_empty()) {
                return Err(format!(
                    "Expected an empty line after a response variant, got {extra:?}"
                ));
            }
            variant.doc_comments = vec![docline.to_owned()];
            response_variants.push(variant);
        }
        let response =
            DiscriminatedUnionType::new(format!("{command_name}Response"), response_variants);

        // Prepend the enclosing doc section, separated from the command's own docs by a rule.
        if let Some(outer_docs) = &self.current_doc_section {
            command
                .doc_comments
                .push(format!("### {}", outer_docs.header));
            command.doc_comments.push(String::new());
            command
                .doc_comments
                .extend(outer_docs.contents.iter().cloned());
            command.doc_comments.push(String::new());
            command.doc_comments.push("----".to_owned());
            command.doc_comments.push(String::new());
        }
        command.doc_comments.extend(command_docs);
        Ok(CommandResponse { command, response })
    }
}
/// A documentation section: a header plus the doc lines collected under it.
#[derive(Default, Clone)]
struct DocSection {
    /// Section title (the parser stores the text following the section marker).
    header: String,
    /// Doc lines collected under the header.
    contents: Vec<String>,
}

impl DocSection {
    /// Creates a section with the given header and no contents yet.
    fn new(header: String) -> Self {
        Self {
            header,
            ..Self::default()
        }
    }
}