use alef_core::ir::{PrimitiveType, TypeRef};
/// Convert a parsed `syn::Type` into the IR's [`TypeRef`].
///
/// Handling by syntactic form:
/// - path types (`u32`, `Vec<T>`, `Option<T>`, …) are resolved by name;
/// - references (`&T`, `&mut T`) resolve via their pointee;
/// - parenthesized `(T)` and invisible (macro-produced) groups are unwrapped;
/// - `()` is [`TypeRef::Unit`]; any other tuple becomes a `Named` tuple string;
/// - slices `[T]` become `Vec` (or `Bytes` for `[u8]`);
/// - `dyn Trait` resolves to the name of its first trait bound ("DynObject" if none);
/// - `impl Trait` resolves to the name of its first trait bound ("ImplTrait" if none),
///   except `impl Into<T>` / `impl AsRef<T>`, which resolve to `T`;
/// - everything else falls back to `Named` holding the type's source text.
pub fn resolve_type(ty: &syn::Type) -> TypeRef {
    match ty {
        syn::Type::Path(type_path) => resolve_path_type(type_path),
        syn::Type::Reference(type_ref) => resolve_reference_type(type_ref),
        // Transparent wrappers: without unwrapping them, e.g. `&(dyn Trait + Send)`
        // fell through to the textual fallback below.
        syn::Type::Paren(paren) => resolve_type(&paren.elem),
        syn::Type::Group(group) => resolve_type(&group.elem),
        syn::Type::Tuple(tuple) => {
            if tuple.elems.is_empty() {
                TypeRef::Unit
            } else {
                let parts: Vec<String> = tuple.elems.iter().map(type_to_string).collect();
                TypeRef::Named(format!("({})", parts.join(", ")))
            }
        }
        syn::Type::Slice(slice) => resolve_slice_type(&slice.elem),
        syn::Type::TraitObject(trait_obj) => match principal_trait_segment(&trait_obj.bounds) {
            Some(seg) => TypeRef::Named(seg.ident.to_string()),
            None => TypeRef::Named("DynObject".to_string()),
        },
        syn::Type::ImplTrait(impl_trait) => match principal_trait_segment(&impl_trait.bounds) {
            Some(seg) => {
                let trait_name = seg.ident.to_string();
                // `impl Into<T>` / `impl AsRef<T>` parameters accept values usable as `T`,
                // so expose `T` itself rather than the conversion trait.
                let conversion_target = if trait_name == "Into" || trait_name == "AsRef" {
                    extract_single_generic_arg(seg)
                } else {
                    None
                };
                conversion_target.unwrap_or_else(|| TypeRef::Named(trait_name))
            }
            None => TypeRef::Named("ImplTrait".to_string()),
        },
        _ => TypeRef::Named(type_to_string(ty)),
    }
}

/// Last path segment of the first *trait* bound in `bounds`.
///
/// Lifetime bounds (and any other non-trait bound) are skipped, so `dyn 'a + Trait`
/// still yields `Trait` instead of being treated as having no trait at all.
fn principal_trait_segment<'a>(
    bounds: impl IntoIterator<Item = &'a syn::TypeParamBound>,
) -> Option<&'a syn::PathSegment> {
    bounds.into_iter().find_map(|bound| match bound {
        syn::TypeParamBound::Trait(trait_bound) => trait_bound.path.segments.last(),
        _ => None,
    })
}
/// Render a type as compact source text, e.g. `Vec<u8>` or `Box<dyn std::error::Error>`.
///
/// The token-stream printer puts a space between tokens (`Vec < u8 >`). Those spaces
/// are dropped, except where both neighbours are identifier characters. In that case
/// the space separates two tokens (`dyn Error`, `&'a mut T`, `*const u8`), and removing
/// it would produce invalid text such as `dynError` or `&'amutT`.
pub fn type_to_string(ty: &syn::Type) -> String {
    use quote::ToTokens;

    fn is_word_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    let spaced = ty.to_token_stream().to_string();
    let mut compact = String::with_capacity(spaced.len());
    let mut prev: Option<char> = None;
    let mut chars = spaced.chars().peekable();
    while let Some(c) = chars.next() {
        if c == ' ' {
            let separates_words = matches!(prev, Some(p) if is_word_char(p))
                && matches!(chars.peek(), Some(&n) if is_word_char(n));
            if !separates_words {
                continue;
            }
        }
        compact.push(c);
        prev = Some(c);
    }
    compact
}
/// Resolve a path type (`u32`, `String`, `Vec<T>`, `serde_json::Value`, …) to a [`TypeRef`].
///
/// Resolution is by the *last* path segment, so `std::collections::HashMap<K, V>` and
/// `HashMap<K, V>` resolve identically. The one exception is `serde_json::Value`, which
/// is matched by its full two-segment path. Wrapper types are peeled:
/// - `Box`/`Arc`/`Rc`/`Cow` resolve to their pointee;
/// - `Result<T, E>` resolves to `T`;
/// - `Vec<u8>` is [`TypeRef::Bytes`].
///
/// Unrecognised names become [`TypeRef::Named`]. A missing generic argument resolves to
/// `Named("unknown")`, except for `Cow`, which defaults to `String`.
fn resolve_path_type(type_path: &syn::TypePath) -> TypeRef {
    let segments = &type_path.path.segments;
    let segment = match segments.last() {
        Some(seg) => seg,
        None => return TypeRef::Named(String::new()),
    };
    if segments.len() == 2 && segments[0].ident == "serde_json" && segments[1].ident == "Value" {
        return TypeRef::Json;
    }
    let unknown = || TypeRef::Named("unknown".into());
    let ident = segment.ident.to_string();
    match ident.as_str() {
        "bool" => TypeRef::Primitive(PrimitiveType::Bool),
        "u8" => TypeRef::Primitive(PrimitiveType::U8),
        "u16" => TypeRef::Primitive(PrimitiveType::U16),
        "u32" => TypeRef::Primitive(PrimitiveType::U32),
        "u64" => TypeRef::Primitive(PrimitiveType::U64),
        "i8" => TypeRef::Primitive(PrimitiveType::I8),
        "i16" => TypeRef::Primitive(PrimitiveType::I16),
        "i32" => TypeRef::Primitive(PrimitiveType::I32),
        "i64" => TypeRef::Primitive(PrimitiveType::I64),
        "f32" => TypeRef::Primitive(PrimitiveType::F32),
        "f64" => TypeRef::Primitive(PrimitiveType::F64),
        "usize" => TypeRef::Primitive(PrimitiveType::Usize),
        "isize" => TypeRef::Primitive(PrimitiveType::Isize),
        "String" => TypeRef::String,
        // A by-value `str` only occurs as the pointee of an indirection (`Box<str>`,
        // `Arc<str>`, `Cow<'_, str>`, `impl AsRef<str>`); like `&str` it is a string.
        "str" => TypeRef::String,
        "char" => TypeRef::Char,
        "PathBuf" => TypeRef::Path,
        "Bytes" => TypeRef::Bytes,
        "JsonValue" => TypeRef::Json,
        "Vec" => match extract_single_generic_arg(segment) {
            // `Vec<u8>` is a byte buffer, not a list of small integers.
            Some(TypeRef::Primitive(PrimitiveType::U8)) => TypeRef::Bytes,
            Some(inner) => TypeRef::Vec(Box::new(inner)),
            None => TypeRef::Vec(Box::new(unknown())),
        },
        "Option" => {
            let inner = extract_single_generic_arg(segment).unwrap_or_else(unknown);
            TypeRef::Optional(Box::new(inner))
        }
        "HashMap" | "BTreeMap" => {
            let (k, v) = extract_two_generic_args(segment);
            TypeRef::Map(Box::new(k), Box::new(v))
        }
        // Only the success type is modelled; the error type is handled separately.
        "Result" => extract_single_generic_arg(segment).unwrap_or_else(unknown),
        // Smart pointers are transparent for binding purposes.
        "Box" | "Arc" | "Rc" => extract_single_generic_arg(segment).unwrap_or_else(unknown),
        "Duration" => TypeRef::Duration,
        "SecretString" => TypeRef::String,
        "Cow" => extract_single_generic_arg(segment).unwrap_or(TypeRef::String),
        other => TypeRef::Named(other.to_string()),
    }
}
/// Resolve a reference type `&T` / `&mut T` through its pointee (mutability is ignored).
///
/// Pointees with special meaning behind a reference:
/// - `&[T]` goes through slice resolution (`Bytes` for `[u8]`);
/// - `&str` → [`TypeRef::String`];
/// - `&Path` (any path whose last segment is `Path`) → [`TypeRef::Path`].
///
/// Any other pointee is resolved with [`resolve_type`].
fn resolve_reference_type(type_ref: &syn::TypeReference) -> TypeRef {
    let pointee: &syn::Type = &type_ref.elem;
    if let syn::Type::Slice(slice) = pointee {
        return resolve_slice_type(&slice.elem);
    }
    let last_ident = match pointee {
        syn::Type::Path(p) => p.path.segments.last().map(|seg| seg.ident.to_string()),
        _ => None,
    };
    match last_ident.as_deref() {
        Some("str") => TypeRef::String,
        Some("Path") => TypeRef::Path,
        _ => resolve_type(pointee),
    }
}
/// Resolve the element type of a slice `[T]`.
///
/// `[u8]` is a byte buffer ([`TypeRef::Bytes`]); any other element type `T` becomes
/// [`TypeRef::Vec`] of the resolved `T`.
fn resolve_slice_type(elem: &syn::Type) -> TypeRef {
    match resolve_type(elem) {
        TypeRef::Primitive(PrimitiveType::U8) => TypeRef::Bytes,
        element => TypeRef::Vec(Box::new(element)),
    }
}
/// Return a clone of the first *type* argument in `segment`'s angle brackets.
///
/// Lifetime and const arguments are skipped, so `Cow<'a, str>` yields `str`.
/// Returns `None` if the segment has no angle-bracketed arguments or none of them
/// is a type.
pub fn extract_single_generic_arg_syn(segment: &syn::PathSegment) -> Option<Box<syn::Type>> {
    match &segment.arguments {
        syn::PathArguments::AngleBracketed(bracketed) => {
            bracketed.args.iter().find_map(|arg| match arg {
                syn::GenericArgument::Type(ty) => Some(Box::new(ty.clone())),
                _ => None,
            })
        }
        _ => None,
    }
}
fn extract_single_generic_arg(segment: &syn::PathSegment) -> Option<TypeRef> {
if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
for arg in &args.args {
if let syn::GenericArgument::Type(ty) = arg {
return Some(resolve_type(ty));
}
}
}
None
}
/// Resolve the first two *type* arguments of `segment` (e.g. the `K` and `V` of a map).
///
/// Lifetime and const arguments are skipped, and any third type argument (such as a
/// hasher `S`) is ignored. A missing argument resolves to `Named("unknown")`.
fn extract_two_generic_args(segment: &syn::PathSegment) -> (TypeRef, TypeRef) {
    let resolved: Vec<TypeRef> = match &segment.arguments {
        syn::PathArguments::AngleBracketed(bracketed) => bracketed
            .args
            .iter()
            .filter_map(|arg| match arg {
                syn::GenericArgument::Type(ty) => Some(resolve_type(ty)),
                _ => None,
            })
            .collect(),
        _ => Vec::new(),
    };
    let unknown = || TypeRef::Named("unknown".into());
    let mut remaining = resolved.into_iter();
    let key = remaining.next().unwrap_or_else(unknown);
    let value = remaining.next().unwrap_or_else(unknown);
    (key, value)
}
pub fn is_option_type(ty: &syn::Type) -> Option<TypeRef> {
if let syn::Type::Path(type_path) = ty {
if let Some(segment) = type_path.path.segments.last() {
if segment.ident == "Option" {
return extract_single_generic_arg(segment);
}
}
}
None
}
/// Return the error type of a `Result` type as source text.
///
/// - `Result<T, E>` (any path ending in `Result` with two or more type arguments)
///   → `E`, rendered with [`type_to_string`].
/// - A one-argument alias `module::Result<T>` is assumed to follow the std convention
///   `Result<T, module::Error>`:
///   - `io::Result<T>` → `io::Error`;
///   - `serde_json::Result<T>` → `serde_json::Error`;
///   - `anyhow::Result<T>` → `anyhow::Error`.
/// - An unqualified one-argument `Result<T>` keeps the historical `anyhow::Error`
///   assumption.
/// - Any other input → `None`. This includes non-`Result` types and `Result`
///   without type arguments (e.g. `fmt::Result`).
pub fn extract_result_error_type(ty: &syn::Type) -> Option<String> {
    let type_path = match ty {
        syn::Type::Path(type_path) => type_path,
        _ => return None,
    };
    let segments = &type_path.path.segments;
    let segment = segments.last()?;
    if segment.ident != "Result" {
        return None;
    }
    let args = match &segment.arguments {
        syn::PathArguments::AngleBracketed(bracketed) => &bracketed.args,
        _ => return None,
    };
    let mut type_args = args.iter().filter_map(|arg| match arg {
        syn::GenericArgument::Type(ty) => Some(ty),
        _ => None,
    });
    // No type arguments at all: not something we can describe.
    type_args.next()?;
    if let Some(error_ty) = type_args.next() {
        return Some(type_to_string(error_ty));
    }
    // Single-argument alias: derive the error type from the alias's module path.
    // `segments` is non-empty here because `last()` succeeded above.
    let module: Vec<String> = segments
        .iter()
        .take(segments.len() - 1)
        .map(|s| s.ident.to_string())
        .collect();
    if module.is_empty() {
        Some("anyhow::Error".to_string())
    } else {
        Some(format!("{}::Error", module.join("::")))
    }
}
/// If `ty` is a path type whose last segment is `Result`, return its first type
/// argument (the success type), borrowed from `ty`. Otherwise return `None`.
///
/// Lifetime and const arguments are skipped when looking for the first type argument.
pub fn unwrap_result_type(ty: &syn::Type) -> Option<&syn::Type> {
    let type_path = match ty {
        syn::Type::Path(type_path) => type_path,
        _ => return None,
    };
    let segment = type_path
        .path
        .segments
        .last()
        .filter(|seg| seg.ident == "Result")?;
    match &segment.arguments {
        syn::PathArguments::AngleBracketed(bracketed) => {
            bracketed.args.iter().find_map(|arg| match arg {
                syn::GenericArgument::Type(inner_ty) => Some(inner_ty),
                _ => None,
            })
        }
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a Rust type from source text; test inputs are expected to be well-formed.
    fn parse_type(src: &str) -> syn::Type {
        syn::parse_str(src).expect("test input should be a valid Rust type")
    }

    /// Assert that the type written as `src` resolves to `expected`.
    #[track_caller]
    fn assert_resolves(src: &str, expected: TypeRef) {
        assert_eq!(resolve_type(&parse_type(src)), expected, "resolving `{}`", src);
    }

    #[test]
    fn test_primitives() {
        let cases = [
            ("bool", PrimitiveType::Bool),
            ("u32", PrimitiveType::U32),
            ("f64", PrimitiveType::F64),
            ("usize", PrimitiveType::Usize),
        ];
        for (src, primitive) in cases {
            assert_resolves(src, TypeRef::Primitive(primitive));
        }
    }

    #[test]
    fn test_string_types() {
        for src in ["String", "&str"] {
            assert_resolves(src, TypeRef::String);
        }
    }

    #[test]
    fn test_bytes_types() {
        for src in ["Vec<u8>", "&[u8]", "Bytes"] {
            assert_resolves(src, TypeRef::Bytes);
        }
    }

    #[test]
    fn test_vec() {
        assert_resolves("Vec<String>", TypeRef::Vec(Box::new(TypeRef::String)));
    }

    #[test]
    fn test_option() {
        let expected = TypeRef::Optional(Box::new(TypeRef::Primitive(PrimitiveType::U64)));
        assert_resolves("Option<u64>", expected);
    }

    #[test]
    fn test_map() {
        let expected = TypeRef::Map(
            Box::new(TypeRef::String),
            Box::new(TypeRef::Primitive(PrimitiveType::U32)),
        );
        assert_resolves("HashMap<String, u32>", expected);
    }

    #[test]
    fn test_path_types() {
        for src in ["PathBuf", "&Path"] {
            assert_resolves(src, TypeRef::Path);
        }
    }

    #[test]
    fn test_unit() {
        assert_resolves("()", TypeRef::Unit);
    }

    #[test]
    fn test_json() {
        for src in ["serde_json::Value", "JsonValue"] {
            assert_resolves(src, TypeRef::Json);
        }
    }

    #[test]
    fn test_box_arc_unwrap() {
        assert_resolves("Box<String>", TypeRef::String);
        assert_resolves("Arc<u32>", TypeRef::Primitive(PrimitiveType::U32));
    }

    #[test]
    fn test_result_unwrap() {
        assert_resolves("Result<String, Error>", TypeRef::String);
    }

    #[test]
    fn test_named() {
        assert_resolves("MyCustomType", TypeRef::Named("MyCustomType".into()));
    }

    #[test]
    fn test_trait_object() {
        assert_resolves("dyn MyTrait", TypeRef::Named("MyTrait".into()));
    }

    #[test]
    fn test_box_dyn_trait() {
        assert_resolves("Box<dyn MyTrait>", TypeRef::Named("MyTrait".into()));
    }

    #[test]
    fn test_duration() {
        assert_resolves("Duration", TypeRef::Duration);
    }

    #[test]
    fn test_secret_string() {
        assert_resolves("SecretString", TypeRef::String);
    }

    #[test]
    fn test_impl_trait() {
        assert_resolves("impl Into<String>", TypeRef::String);
    }

    #[test]
    fn test_extract_result_error() {
        let ty = parse_type("Result<String, MyError>");
        assert_eq!(extract_result_error_type(&ty).as_deref(), Some("MyError"));
    }
}