use std::collections::BTreeSet;
use cratestack_core::{Field, Model, TypeArity};
use quote::quote;
use crate::shared::{ident, scalar_model_fields, to_snake_case};
/// Generates the Postgres row decoders for `model`.
///
/// Two impls are emitted for the model struct:
/// * `sqlx::FromRow<'r, sqlx::postgres::PgRow>`, which initializes every scalar field from the row;
/// * `::cratestack::FromPartialPgRow`, whose `decode_partial_pg_row` receives the list of
///   `selected` column names and builds each field with `partial_row_field_tokens`.
///
/// Only the fields returned by `scalar_model_fields(model, model_names)` are initialized.
/// `enum_names` is forwarded to the per-field helpers so they know which field types must be
/// parsed from text.
pub(crate) fn generate_pg_from_row_impl(
    model: &Model,
    model_names: &BTreeSet<&str>,
    enum_names: &BTreeSet<&str>,
) -> proc_macro2::TokenStream {
    let struct_ident = ident(&model.name);
    let fields = scalar_model_fields(model, model_names);

    // Field initializers for the full-row decoder, then for the partial-row decoder.
    let full_inits: Vec<_> = fields
        .iter()
        .map(|field| row_field_tokens(field, enum_names))
        .collect();
    let partial_inits: Vec<_> = fields
        .iter()
        .map(|field| partial_row_field_tokens(field, enum_names))
        .collect();

    let from_row = quote! {
        impl<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> for #struct_ident {
            fn from_row(row: &'r sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
                use sqlx::Row;
                Ok(Self {
                    #(#full_inits)*
                })
            }
        }
    };

    let from_partial_row = quote! {
        impl ::cratestack::FromPartialPgRow for #struct_ident {
            fn decode_partial_pg_row(
                row: &sqlx::postgres::PgRow,
                selected: &[&str],
            ) -> ::std::result::Result<Self, sqlx::Error> {
                use sqlx::Row;
                Ok(Self {
                    #(#partial_inits)*
                })
            }
        }
    };

    quote! {
        #from_row
        #from_partial_row
    }
}
/// Builds the initializer for one field of the generated `FromPartialPgRow` impl.
///
/// The field is decoded with `row_field_decode_expr` only when its snake_case column name is
/// present in the generated function's `selected` slice. Otherwise it is set to
/// `::std::default::Default::default()`, so every scalar field type must implement `Default`.
///
/// NOTE(review): membership is checked against `to_snake_case(&field.name)`, but the decode
/// expression reads the column named `field.name` verbatim. Presumably the partial SELECT
/// aliases columns back to the field name. Confirm against the query builder.
fn partial_row_field_tokens(
    field: &Field,
    enum_names: &BTreeSet<&str>,
) -> proc_macro2::TokenStream {
    let member = ident(&field.name);
    let column = to_snake_case(&field.name);
    let decoded = row_field_decode_expr(field, enum_names);
    // Runtime check emitted into the generated code: was this column requested?
    let is_selected = quote! { selected.iter().any(|c| *c == #column) };
    quote! {
        #member: if #is_selected {
            #decoded
        } else {
            ::std::default::Default::default()
        },
    }
}
/// Builds the expression that decodes `field` from a row binding named `row`.
///
/// The column is looked up by `field.name` verbatim. Fields whose type name is not in
/// `enum_names` decode directly with `row.try_get(..)?`. Enum fields are read as text:
/// `String`, `Option<String>` or `Vec<String>`, depending on the field's `TypeArity`. The text
/// is then parsed with `str::parse` into `super::types::<Enum>`. A parse failure becomes
/// `sqlx::Error::Decode` wrapping a `std::io::Error` of kind `InvalidData`.
///
/// The emitted code uses `?` and calls `Row::try_get`. It must therefore be expanded inside a
/// function returning `Result<_, sqlx::Error>`, with `row` and the `sqlx::Row` trait in scope.
/// The generated `from_row` and `decode_partial_pg_row` bodies both meet these requirements.
fn row_field_decode_expr(field: &Field, enum_names: &BTreeSet<&str>) -> proc_macro2::TokenStream {
    let field_name = &field.name;
    if !enum_names.contains(field.ty.name.as_str()) {
        return quote! { row.try_get(#field_name)? };
    }
    let enum_ident = ident(&field.ty.name);
    // The parse-error conversion is the same for every arity, so build it once. `error` refers
    // to the closure parameter bound by each `map_err(|error| ..)` below.
    let decode_error = quote! {
        sqlx::Error::Decode(Box::new(::std::io::Error::new(
            ::std::io::ErrorKind::InvalidData,
            error,
        )))
    };
    match field.ty.arity {
        TypeArity::Required => quote! {
            {
                let raw: String = row.try_get(#field_name)?;
                raw.parse::<super::types::#enum_ident>().map_err(|error| #decode_error)?
            }
        },
        TypeArity::Optional => quote! {
            {
                let raw: Option<String> = row.try_get(#field_name)?;
                // Parse only when present; `transpose` turns Option<Result> into Result<Option>.
                raw.map(|value| {
                    value.parse::<super::types::#enum_ident>().map_err(|error| #decode_error)
                })
                .transpose()?
            }
        },
        TypeArity::List => quote! {
            {
                let raw: Vec<String> = row.try_get(#field_name)?;
                // Collecting into Result short-circuits on the first unparseable element.
                raw.into_iter()
                    .map(|value| value.parse::<super::types::#enum_ident>().map_err(|error| #decode_error))
                    .collect::<Result<Vec<_>, sqlx::Error>>()?
            }
        },
    }
}
fn row_field_tokens(field: &Field, enum_names: &BTreeSet<&str>) -> proc_macro2::TokenStream {
let field_ident = ident(&field.name);
let field_name = &field.name;
if !enum_names.contains(field.ty.name.as_str()) {
return quote! {
#field_ident: row.try_get(#field_name)?,
};
}
let enum_ident = ident(&field.ty.name);
let parse_error = |error: proc_macro2::TokenStream| {
quote! {
sqlx::Error::Decode(Box::new(::std::io::Error::new(
::std::io::ErrorKind::InvalidData,
#error,
)))
}
};
match field.ty.arity {
TypeArity::Required => {
let decode_error = parse_error(quote! { error });
quote! {
#field_ident: {
let raw: String = row.try_get(#field_name)?;
raw.parse::<super::types::#enum_ident>().map_err(|error| #decode_error)?
},
}
}
TypeArity::Optional => {
let decode_error = parse_error(quote! { error });
quote! {
#field_ident: {
let raw: Option<String> = row.try_get(#field_name)?;
raw.map(|value| value.parse::<super::types::#enum_ident>().map_err(|error| #decode_error)).transpose()?
},
}
}
TypeArity::List => {
let decode_error = parse_error(quote! { error });
quote! {
#field_ident: {
let raw: Vec<String> = row.try_get(#field_name)?;
raw.into_iter()
.map(|value| value.parse::<super::types::#enum_ident>().map_err(|error| #decode_error))
.collect::<Result<Vec<_>, sqlx::Error>>()?
},
}
}
}
}