use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::Type;
use super::parse::StructInfo;
/// Emits the tokens of a `from_row` function that builds `Self` from a
/// `statiq::row::OdbcRow`.
///
/// Each entry of `info.fields` becomes one `name: value` initializer:
/// * fields flagged `is_ignored` are initialized with `Default::default()`;
/// * every other field is initialized with the accessor expression chosen by
///   [`get_row_getter`] from the field's Rust type and its SQL column name.
///
/// The generated function returns `Result<Self, statiq::error::SqlError>`.
pub fn generate_from_row(info: &StructInfo) -> TokenStream {
    let initializers: Vec<TokenStream> = info
        .fields
        .iter()
        .map(|column| {
            let member = format_ident!("{}", column.rust_name);
            if column.is_ignored {
                // Ignored fields never touch the row.
                quote! {
                    #member: Default::default()
                }
            } else {
                let value_expr = get_row_getter(&column.ty, &column.sql_name);
                quote! {
                    #member: #value_expr
                }
            }
        })
        .collect();

    quote! {
        fn from_row(row: &statiq::row::OdbcRow) -> Result<Self, statiq::error::SqlError> {
            Ok(Self {
                #(#initializers),*
            })
        }
    }
}
/// Chooses the `row` accessor expression for a field of type `ty` that is read
/// from column `col`.
///
/// The type is rendered to a string with every space removed and then matched
/// by spelling. A type written as `Option<T>` (also `std::option::Option<T>` or
/// `core::option::Option<T>`, with or without a leading `::`) is dispatched to
/// [`optional_getter`] with `T`; any other type goes to [`required_getter`].
fn get_row_getter(ty: &Type, col: &str) -> TokenStream {
    let ty_str = quote!(#ty).to_string().replace(' ', "");
    match option_inner(&ty_str) {
        Some(inner) => optional_getter(inner, col),
        None => required_getter(&ty_str, col),
    }
}

/// Returns the `T` of a space-free `Option<T>` spelling, or `None` when
/// `ty_str` is not an `Option`.
///
/// Exactly one `Option<` prefix and one closing `>` are removed. Trimming
/// *all* trailing `>` characters would turn `Option<Vec<u8>>` into `Vec<u8`
/// and `Option<DateTime<Utc>>` into `DateTime<Utc`, which then fail to match
/// their dedicated accessors.
fn option_inner(ty_str: &str) -> Option<&str> {
    const OPTION_PREFIXES: [&str; 5] = [
        "Option<",
        "std::option::Option<",
        "core::option::Option<",
        "::std::option::Option<",
        "::core::option::Option<",
    ];

    OPTION_PREFIXES
        .iter()
        .find_map(|prefix| ty_str.strip_prefix(*prefix))
        .and_then(|rest| rest.strip_suffix('>'))
}
/// Builds the accessor expression for a non-optional field.
///
/// `ty_str` is the field type's spelling with spaces removed and `col` is the
/// SQL column name, which is embedded in the generated code as a string
/// literal. Every generated expression uses `?`, so it must be expanded inside
/// a function whose error type accepts the accessor's error.
///
/// Types without a dedicated accessor fall back to reading the column as a
/// string and parsing it with `FromStr`; a parse failure becomes
/// `statiq::error::SqlError::row_mapping(col, message)`.
fn required_getter(ty_str: &str, col: &str) -> TokenStream {
    match ty_str {
        "bool" => quote! { row.get_bool(#col)? },
        "u8" => quote! { row.get_u8(#col)? },
        // `i8` is read through the `i16` accessor. The narrowing is checked so
        // an out-of-range value is reported instead of silently wrapping the
        // way an `as i8` cast would.
        "i8" => quote! {
            <i8 as ::std::convert::TryFrom<_>>::try_from(row.get_i16(#col)?).map_err(|e| {
                statiq::error::SqlError::row_mapping(#col, format!("{}", e))
            })?
        },
        "i16" => quote! { row.get_i16(#col)? },
        "i32" => quote! { row.get_i32(#col)? },
        "i64" => quote! { row.get_i64(#col)? },
        "f32" => quote! { row.get_f32(#col)? },
        "f64" => quote! { row.get_f64(#col)? },
        "Decimal" | "rust_decimal::Decimal" | "rust_decimal::decimal::Decimal" => {
            quote! { row.get_decimal(#col)? }
        }
        "String" => quote! { row.get_string(#col)? },
        "Vec<u8>" => quote! { row.get_bytes(#col)? },
        "NaiveDate" | "chrono::NaiveDate" => quote! { row.get_naive_date(#col)? },
        "NaiveTime" | "chrono::NaiveTime" => quote! { row.get_naive_time(#col)? },
        "DateTime<Utc>" | "chrono::DateTime<Utc>" | "chrono::DateTime<chrono::Utc>" => {
            quote! { row.get_datetime(#col)? }
        }
        "DateTime<FixedOffset>"
        | "chrono::DateTime<FixedOffset>"
        | "chrono::DateTime<chrono::FixedOffset>" => {
            quote! { row.get_datetime_offset(#col)? }
        }
        "Uuid" | "uuid::Uuid" => quote! { row.get_uuid(#col)? },
        // Fallback: parse the textual column value. The closure parameter is
        // left un-annotated and formatted with `format!` so any `FromStr::Err`
        // that implements `Display` is accepted; pinning it to one boxed error
        // type would reject ordinary targets such as `u32`.
        _ => quote! {
            row.get_string(#col)?.parse().map_err(|e| {
                statiq::error::SqlError::row_mapping(#col, format!("{}", e))
            })?
        },
    }
}
/// Builds the accessor expression for an `Option<T>` field.
///
/// `inner` is the spelling of `T` with spaces removed and `col` is the SQL
/// column name, which is embedded in the generated code as a string literal.
/// Every generated expression uses `?` and evaluates to an `Option`.
///
/// Types without a dedicated `_opt` accessor fall back to reading the column
/// as an optional string and parsing it with `FromStr`; a parse failure is
/// returned as `statiq::error::SqlError::row_mapping(col, message)` rather than
/// panicking inside the generated function.
fn optional_getter(inner: &str, col: &str) -> TokenStream {
    match inner {
        "bool" => quote! { row.get_bool_opt(#col)? },
        "u8" => quote! { row.get_u8_opt(#col)? },
        // `i8` is read through the `i16` accessor. The narrowing is checked so
        // an out-of-range value is reported instead of silently wrapping the
        // way an `as i8` cast would.
        "i8" => quote! {
            row.get_i16_opt(#col)?
                .map(|v| {
                    <i8 as ::std::convert::TryFrom<_>>::try_from(v).map_err(|e| {
                        statiq::error::SqlError::row_mapping(#col, format!("{}", e))
                    })
                })
                .transpose()?
        },
        "i16" => quote! { row.get_i16_opt(#col)? },
        "i32" => quote! { row.get_i32_opt(#col)? },
        "i64" => quote! { row.get_i64_opt(#col)? },
        "f32" => quote! { row.get_f32_opt(#col)? },
        "f64" => quote! { row.get_f64_opt(#col)? },
        "Decimal" | "rust_decimal::Decimal" | "rust_decimal::decimal::Decimal" => {
            quote! { row.get_decimal_opt(#col)? }
        }
        "String" => quote! { row.get_string_opt(#col)? },
        "Vec<u8>" => quote! { row.get_bytes_opt(#col)? },
        "NaiveDate" | "chrono::NaiveDate" => quote! { row.get_naive_date_opt(#col)? },
        "NaiveTime" | "chrono::NaiveTime" => quote! { row.get_naive_time_opt(#col)? },
        "DateTime<Utc>" | "chrono::DateTime<Utc>" | "chrono::DateTime<chrono::Utc>" => {
            quote! { row.get_datetime_opt(#col)? }
        }
        "DateTime<FixedOffset>"
        | "chrono::DateTime<FixedOffset>"
        | "chrono::DateTime<chrono::FixedOffset>" => {
            quote! { row.get_datetime_offset_opt(#col)? }
        }
        "Uuid" | "uuid::Uuid" => quote! { row.get_uuid_opt(#col)? },
        // Fallback: parse the textual value when one is present. The error is
        // rendered with `{:?}` so the only bound on `FromStr::Err` is `Debug`,
        // the same bound the previous `unwrap()` imposed; `transpose` turns the
        // `Option<Result<_, _>>` into a `Result<Option<_>, _>` for `?`.
        _ => quote! {
            row.get_string_opt(#col)?
                .map(|s| {
                    s.parse().map_err(|e| {
                        statiq::error::SqlError::row_mapping(#col, format!("{:?}", e))
                    })
                })
                .transpose()?
        },
    }
}