use std::collections::HashMap;
use syn::{Error, File, Item, ItemEnum, ItemStruct, ItemTrait, Result, TraitItem};
/// A service description extracted from one Rust source file, as produced by
/// [`ServiceDefinition::parse`].
#[derive(Debug)]
pub struct ServiceDefinition {
    /// The trait carrying a service attribute (`#[rpc_trait]`, `#[service]`, or
    /// `#[rpcnet::service]`). When built by `parse`, its methods have already
    /// passed `validate_trait_methods`.
    pub service_trait: ItemTrait,
    /// Every top-level `struct` and `enum` in the file, keyed by its identifier
    /// name. This includes types not referenced by the trait.
    pub types: HashMap<String, ServiceType>,
    /// Every top-level `use` item in the file, in source order.
    pub imports: Vec<syn::ItemUse>,
}
/// A user-defined type found at the top level of the service file.
#[derive(Debug)]
pub enum ServiceType {
    /// A top-level `struct` item.
    Struct(ItemStruct),
    /// A top-level `enum` item.
    Enum(ItemEnum),
}
impl ServiceDefinition {
pub fn parse(content: &str) -> Result<Self> {
let ast: File = syn::parse_str(content)?;
let mut service_trait = None;
let mut types = HashMap::new();
let mut imports = Vec::new();
for item in ast.items {
match item {
Item::Trait(trait_item) => {
if has_rpcnet_service_attribute(&trait_item) {
if service_trait.is_some() {
return Err(Error::new_spanned(
&trait_item,
"Multiple service traits found. Only one service per file is supported."
));
}
validate_trait_methods(&trait_item)?;
service_trait = Some(trait_item);
}
}
Item::Struct(struct_item) => {
types.insert(
struct_item.ident.to_string(),
ServiceType::Struct(struct_item),
);
}
Item::Enum(enum_item) => {
types.insert(enum_item.ident.to_string(), ServiceType::Enum(enum_item));
}
Item::Use(use_item) => {
imports.push(use_item);
}
_ => {} }
}
let service_trait = service_trait.ok_or_else(|| {
syn::Error::new(
proc_macro2::Span::call_site(),
"No service trait found. Add #[rpc_trait] attribute to your trait.",
)
})?;
Ok(ServiceDefinition {
service_trait,
types,
imports,
})
}
pub fn service_name(&self) -> &syn::Ident {
&self.service_trait.ident
}
pub fn methods(&self) -> Vec<&syn::TraitItemFn> {
self.service_trait
.items
.iter()
.filter_map(|item| {
if let TraitItem::Fn(method) = item {
Some(method)
} else {
None
}
})
.collect()
}
}
/// Reports whether `trait_item` carries an attribute that marks it as the RPC service.
///
/// Three attribute paths are recognised:
/// - the bare idents `rpc_trait` and `service`;
/// - any two-segment path `rpcnet::service`. Only the segment idents are
///   compared, so a leading `::` is not rejected.
fn has_rpcnet_service_attribute(trait_item: &ItemTrait) -> bool {
    trait_item
        .attrs
        .iter()
        .map(|attr| attr.path())
        .any(|path| {
            if path.is_ident("rpc_trait") || path.is_ident("service") {
                return true;
            }
            // Match exactly two segments: a third `next()` must be `None`.
            let mut segs = path.segments.iter();
            match (segs.next(), segs.next(), segs.next()) {
                (Some(head), Some(tail), None) => {
                    head.ident == "rpcnet" && tail.ident == "service"
                }
                _ => false,
            }
        })
}
fn validate_trait_methods(trait_item: &ItemTrait) -> Result<()> {
for item in &trait_item.items {
if let TraitItem::Fn(method) = item {
if method.sig.asyncness.is_none() {
return Err(Error::new_spanned(
&method.sig,
"Service methods must be async",
));
}
if method.sig.inputs.is_empty() {
return Err(Error::new_spanned(
&method.sig,
"Service methods must have &self as first parameter",
));
}
match &method.sig.output {
syn::ReturnType::Type(_, ty) => {
let type_str = quote::quote!(#ty).to_string();
if !type_str.contains("Result") {
return Err(Error::new_spanned(
ty,
"Service methods must return Result<Response, Error>",
));
}
}
syn::ReturnType::Default => {
return Err(Error::new_spanned(
&method.sig,
"Service methods must return Result<Response, Error>",
));
}
}
}
}
Ok(())
}