use anyhow::{Context, Result};
use std::path::Path;
use syn::{Attribute, Fields, Meta, Type};
/// A message definition extracted from a Rust struct.
///
/// Values are produced by `MessageType::parse_source` / `MessageType::parse_file`
/// from structs whose derive attributes mark them as messages.
#[derive(Debug, Clone)]
pub struct MessageType {
/// Identifier of the source struct.
pub name: String,
/// Trimmed, non-empty doc-comment lines attached to the struct, in source order.
pub docs: Vec<String>,
/// Named fields in declaration order; empty for unit structs.
pub fields: Vec<Field>,
}
/// A single named field of a [`MessageType`].
#[derive(Debug, Clone)]
#[allow(dead_code)] pub struct Field {
/// Identifier of the field as written in the Rust source.
pub name: String,
/// Parsed type of the field. For a field declared as `Option<T>` this is the
/// type parsed from `T`; the optionality is carried by `optional` instead.
pub ty: FieldType,
/// Trimmed, non-empty doc-comment lines attached to the field, in source order.
pub docs: Vec<String>,
/// `true` when the field's declared type is `Option<T>`.
pub optional: bool,
/// Name taken from a `#[serde(rename = "...")]` attribute on the field, when
/// one was recognized; `None` otherwise.
pub serde_rename: Option<String>,
}
/// The type of a message field, as recognized by the parser in this file.
#[derive(Debug, Clone, PartialEq)]
#[allow(dead_code)] pub enum FieldType {
/// Rust `String`.
String,
/// Rust `bool`.
Bool,
/// Rust `i8`.
I8,
/// Rust `i16`.
I16,
/// Rust `i32`.
I32,
/// Rust `i64`.
I64,
/// Rust `u8`.
U8,
/// Rust `u16`.
U16,
/// Rust `u32`.
U32,
/// Rust `u64`.
U64,
/// Rust `f32`.
F32,
/// Rust `f64`.
F64,
/// `Vec<T>`, holding the parsed element type.
Vec(Box<FieldType>),
/// An optional value wrapping an inner type.
///
/// NOTE(review): the parser in this file does not construct this variant;
/// `Option<T>` fields are reported through `Field::optional` instead.
Option(Box<FieldType>),
/// Any other type, identified by the last segment of its path.
Custom(String),
}
impl MessageType {
pub fn parse_file(path: &Path) -> Result<Vec<MessageType>> {
let content =
std::fs::read_to_string(path).with_context(|| format!("Failed to read {path:?}"))?;
Self::parse_source(&content)
}
pub fn parse_source(source: &str) -> Result<Vec<MessageType>> {
let file = syn::parse_file(source).context("Failed to parse Rust source")?;
let mut messages = Vec::new();
for item in &file.items {
if let syn::Item::Struct(s) = item
&& is_message_struct(&s.attrs)
{
messages.push(Self::from_struct(s)?);
}
}
Ok(messages)
}
fn from_struct(s: &syn::ItemStruct) -> Result<MessageType> {
let name = s.ident.to_string();
let docs = extract_docs(&s.attrs);
let fields = match &s.fields {
Fields::Named(fields) => fields
.named
.iter()
.map(Field::from_syn_field)
.collect::<Result<Vec<_>>>()?,
Fields::Unnamed(_) => {
anyhow::bail!("Tuple structs are not supported for message generation")
}
Fields::Unit => Vec::new(),
};
Ok(MessageType { name, docs, fields })
}
}
impl Field {
    /// Builds a [`Field`] from a parsed struct field.
    ///
    /// # Errors
    ///
    /// Fails when the field has no identifier, or when its type is rejected by
    /// `parse_field_type`.
    fn from_syn_field(f: &syn::Field) -> Result<Field> {
        // The name is checked first so a missing identifier is reported before
        // any type error.
        let ident = f.ident.as_ref().context("Field must have a name")?;
        let (ty, optional) = parse_field_type(&f.ty)?;
        Ok(Field {
            name: ident.to_string(),
            ty,
            docs: extract_docs(&f.attrs),
            optional,
            serde_rename: extract_serde_rename(&f.attrs),
        })
    }
}
fn is_message_struct(attrs: &[Attribute]) -> bool {
for attr in attrs {
if let Meta::List(meta_list) = &attr.meta
&& meta_list.path.is_ident("derive")
{
let tokens = meta_list.tokens.to_string();
if tokens.contains("Message")
|| (tokens.contains("Serialize") && tokens.contains("Deserialize"))
{
return true;
}
}
}
false
}
/// Collects the documentation lines carried by `attrs`.
///
/// Only `#[doc = "..."]` attributes with a string-literal value are considered
/// (this is the form `///` comments take after parsing). Each value is trimmed
/// of surrounding whitespace, and values that are empty after trimming are
/// dropped. Order follows the attribute order.
fn extract_docs(attrs: &[Attribute]) -> Vec<String> {
    attrs
        .iter()
        .filter(|attr| attr.path().is_ident("doc"))
        .filter_map(|attr| match &attr.meta {
            Meta::NameValue(name_value) => match &name_value.value {
                syn::Expr::Lit(syn::ExprLit {
                    lit: syn::Lit::Str(text),
                    ..
                }) => Some(text.value()),
                _ => None,
            },
            _ => None,
        })
        .filter_map(|raw| {
            let line = raw.trim();
            (!line.is_empty()).then(|| line.to_string())
        })
        .collect()
}
fn extract_serde_rename(attrs: &[Attribute]) -> Option<String> {
for attr in attrs {
if attr.path().is_ident("serde")
&& let Meta::List(meta_list) = &attr.meta
{
let nested = meta_list.tokens.to_string();
if let Some(rename) = nested.strip_prefix("rename = \"")
&& let Some(end) = rename.find('"')
{
return Some(rename[..end].to_string());
}
}
}
None
}
fn parse_field_type(ty: &Type) -> Result<(FieldType, bool)> {
match ty {
Type::Path(type_path) => {
let path = &type_path.path;
if path.is_ident("String") {
return Ok((FieldType::String, false));
}
if path.is_ident("bool") {
return Ok((FieldType::Bool, false));
}
if path.is_ident("i8") {
return Ok((FieldType::I8, false));
}
if path.is_ident("i16") {
return Ok((FieldType::I16, false));
}
if path.is_ident("i32") {
return Ok((FieldType::I32, false));
}
if path.is_ident("i64") {
return Ok((FieldType::I64, false));
}
if path.is_ident("u8") {
return Ok((FieldType::U8, false));
}
if path.is_ident("u16") {
return Ok((FieldType::U16, false));
}
if path.is_ident("u32") {
return Ok((FieldType::U32, false));
}
if path.is_ident("u64") {
return Ok((FieldType::U64, false));
}
if path.is_ident("f32") {
return Ok((FieldType::F32, false));
}
if path.is_ident("f64") {
return Ok((FieldType::F64, false));
}
if let Some(segment) = path.segments.last() {
let ident = &segment.ident;
if ident == "Vec"
&& let syn::PathArguments::AngleBracketed(args) = &segment.arguments
&& let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first()
{
let (inner_field_type, _) = parse_field_type(inner_ty)?;
return Ok((FieldType::Vec(Box::new(inner_field_type)), false));
}
if ident == "Option"
&& let syn::PathArguments::AngleBracketed(args) = &segment.arguments
&& let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first()
{
let (inner_field_type, _) = parse_field_type(inner_ty)?;
return Ok((inner_field_type, true)); }
return Ok((FieldType::Custom(ident.to_string()), false));
}
Ok((
FieldType::Custom(
path.segments
.last()
.map(|s| s.ident.to_string())
.unwrap_or_else(|| "Unknown".to_string()),
),
false,
))
}
_ => anyhow::bail!("Unsupported field type: {}", quote::quote!(#ty)),
}
}
// Unit tests for `MessageType::parse_source`, driven by inline Rust snippets.
#[cfg(test)]
mod tests {
// Test names use a triple underscore to separate the function under test from
// the scenario, which is not snake case.
#![allow(non_snake_case)]
use super::*;
// A struct deriving Serialize + Deserialize is extracted with its name, docs and fields.
#[test]
fn parse_source___extracts_message_struct() {
let source = r#"
use serde::{Serialize, Deserialize};
/// A test message.
#[derive(Serialize, Deserialize)]
pub struct TestMessage {
/// The name field.
pub name: String,
pub count: i32,
}
"#;
let messages = MessageType::parse_source(source).unwrap();
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].name, "TestMessage");
assert_eq!(messages[0].docs, vec!["A test message."]);
assert_eq!(messages[0].fields.len(), 2);
assert_eq!(messages[0].fields[0].name, "name");
assert_eq!(messages[0].fields[0].ty, FieldType::String);
assert_eq!(messages[0].fields[0].docs, vec!["The name field."]);
}
// `Option<T>` sets the optional flag and reports the inner type.
#[test]
fn parse_source___handles_optional_fields() {
let source = r#"
use serde::{Serialize, Deserialize};
#[derive(Serialize, Deserialize)]
pub struct TestMessage {
pub required: String,
pub optional: Option<String>,
}
"#;
let messages = MessageType::parse_source(source).unwrap();
assert!(!messages[0].fields[0].optional);
assert!(messages[0].fields[1].optional);
assert_eq!(messages[0].fields[1].ty, FieldType::String); }
// `#[serde(rename = "...")]` is captured without changing the field's Rust name.
#[test]
fn parse_source___handles_serde_rename() {
let source = r#"
use serde::{Serialize, Deserialize};
#[derive(Serialize, Deserialize)]
pub struct TestMessage {
#[serde(rename = "old_name")]
pub new_name: String,
}
"#;
let messages = MessageType::parse_source(source).unwrap();
assert_eq!(messages[0].fields[0].name, "new_name");
assert_eq!(
messages[0].fields[0].serde_rename,
Some("old_name".to_string())
);
}
// Nested `Vec`s are parsed recursively.
#[test]
fn parse_source___handles_nested_vec() {
let source = r#"
use serde::{Serialize, Deserialize};
#[derive(Serialize, Deserialize)]
pub struct TestMessage {
pub items: Vec<Vec<String>>,
}
"#;
let messages = MessageType::parse_source(source).unwrap();
assert_eq!(
messages[0].fields[0].ty,
FieldType::Vec(Box::new(FieldType::Vec(Box::new(FieldType::String))))
);
}
// A field whose type is another struct is reported as `Custom` with that name.
#[test]
fn parse_source___handles_custom_types() {
let source = r#"
use serde::{Serialize, Deserialize};
#[derive(Serialize, Deserialize)]
pub struct Address {
pub street: String,
pub city: String,
}
#[derive(Serialize, Deserialize)]
pub struct Person {
pub name: String,
pub address: Address,
}
"#;
let messages = MessageType::parse_source(source).unwrap();
assert_eq!(messages.len(), 2);
assert_eq!(messages[1].name, "Person");
assert_eq!(
messages[1].fields[1].ty,
FieldType::Custom("Address".to_string())
);
}
// `Vec<Custom>` keeps the custom element type.
#[test]
fn parse_source___handles_vec_of_custom_types() {
let source = r#"
use serde::{Serialize, Deserialize};
#[derive(Serialize, Deserialize)]
pub struct Item {
pub id: u64,
}
#[derive(Serialize, Deserialize)]
pub struct Container {
pub items: Vec<Item>,
}
"#;
let messages = MessageType::parse_source(source).unwrap();
assert_eq!(
messages[1].fields[0].ty,
FieldType::Vec(Box::new(FieldType::Custom("Item".to_string())))
);
}
// `Option<Vec<T>>` is an optional field whose type is the `Vec`.
#[test]
fn parse_source___handles_option_of_vec() {
let source = r#"
use serde::{Serialize, Deserialize};
#[derive(Serialize, Deserialize)]
pub struct TestMessage {
pub tags: Option<Vec<String>>,
}
"#;
let messages = MessageType::parse_source(source).unwrap();
assert!(messages[0].fields[0].optional);
assert_eq!(
messages[0].fields[0].ty,
FieldType::Vec(Box::new(FieldType::String))
);
}
// Structs without the expected derives are skipped.
#[test]
fn parse_source___ignores_non_message_structs() {
let source = r#"
use serde::{Serialize, Deserialize};
// Should be ignored - no derives
pub struct Ignored {
pub data: String,
}
// Should be included
#[derive(Serialize, Deserialize)]
pub struct Included {
pub data: String,
}
"#;
let messages = MessageType::parse_source(source).unwrap();
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].name, "Included");
}
// Each signed and unsigned integer width maps to its own variant.
#[test]
fn parse_source___handles_all_integer_types() {
let source = r#"
use serde::{Serialize, Deserialize};
#[derive(Serialize, Deserialize)]
pub struct AllInts {
pub i8_val: i8,
pub i16_val: i16,
pub i32_val: i32,
pub i64_val: i64,
pub u8_val: u8,
pub u16_val: u16,
pub u32_val: u32,
pub u64_val: u64,
}
"#;
let messages = MessageType::parse_source(source).unwrap();
assert_eq!(messages[0].fields[0].ty, FieldType::I8);
assert_eq!(messages[0].fields[1].ty, FieldType::I16);
assert_eq!(messages[0].fields[2].ty, FieldType::I32);
assert_eq!(messages[0].fields[3].ty, FieldType::I64);
assert_eq!(messages[0].fields[4].ty, FieldType::U8);
assert_eq!(messages[0].fields[5].ty, FieldType::U16);
assert_eq!(messages[0].fields[6].ty, FieldType::U32);
assert_eq!(messages[0].fields[7].ty, FieldType::U64);
}
// Consecutive doc-comment lines become separate entries, in order.
#[test]
fn parse_source___handles_multiline_docs() {
let source = r#"
use serde::{Serialize, Deserialize};
/// First line.
/// Second line.
/// Third line.
#[derive(Serialize, Deserialize)]
pub struct TestMessage {
/// Field doc line 1.
/// Field doc line 2.
pub field: String,
}
"#;
let messages = MessageType::parse_source(source).unwrap();
assert_eq!(messages[0].docs.len(), 3);
assert_eq!(messages[0].docs[0], "First line.");
assert_eq!(messages[0].docs[1], "Second line.");
assert_eq!(messages[0].docs[2], "Third line.");
assert_eq!(messages[0].fields[0].docs.len(), 2);
assert_eq!(messages[0].fields[0].docs[0], "Field doc line 1.");
assert_eq!(messages[0].fields[0].docs[1], "Field doc line 2.");
}
// A `Message` derive alone is enough to select a struct.
#[test]
fn parse_source___handles_message_derive() {
let source = r#"
#[derive(Message)]
pub struct TestMessage {
pub data: String,
}
"#;
let messages = MessageType::parse_source(source).unwrap();
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].name, "TestMessage");
}
}