odbc-api-derive 12.1.1

Derive macros for the odbc-api crate
Documentation
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Fields};

/// Expands `#[derive(Fetch)]` on `input` into an `unsafe impl odbc_api::buffers::FetchRow`.
///
/// The generated `bind_columns_to_cursor` binds every named field, in declaration order, via
/// `odbc_api::buffers::FetchRowMember::bind_to_col`. Column indices are 1-based, so the first
/// field is bound to column `1`, the second to column `2`, and so on. Binding stops at the first
/// error, which is propagated with `?`.
///
/// The generated `find_truncation` asks each field, in declaration order, for truncation info.
/// It passes the field's zero-based position and returns the first `Some` it encounters, or
/// `None` if no field reports anything.
///
/// Generic parameters (e.g. `const N: usize`) and the where-clause of the struct are forwarded
/// to the generated impl. For non-generic structs this adds no tokens.
///
/// Unsupported input is reported as a `compile_error!` invocation rather than a panic. This
/// covers enums, unions, tuple and unit structs, and structs with more fields than a `u16`
/// column index can address.
pub fn expand(input: DeriveInput) -> TokenStream {
    let struct_name = input.ident;
    // Without forwarding these, `struct Row<const N: usize> { .. }` would expand to
    // `impl FetchRow for Row`, which fails to compile. All three parts are empty token
    // streams for a struct without generics.
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

    let Data::Struct(struct_data) = input.data else {
        return quote! { compile_error!("Fetch can only be derived for structs"); };
    };

    let Fields::Named(named_fields) = struct_data.fields else {
        return quote! { compile_error!("Fetch can only be derived for structs with named fields"); };
    };
    let fields = named_fields.named;

    // Column indices are emitted as 1-based `u16` literals. Reject structs whose last field
    // would need an index beyond `u16::MAX`, instead of letting the index silently wrap around
    // and bind fields to the wrong columns.
    if fields.len() > usize::from(u16::MAX) {
        return quote! {
            compile_error!("Fetch can only be derived for structs with at most 65535 fields");
        };
    }

    // Re-creatable iterator over the field identifiers. It is used twice: once for binding,
    // once for truncation checks.
    let field_names = || {
        fields.iter().map(|f| {
            f.ident
                .as_ref()
                .expect("All fields in a struct with named fields must be named.")
        })
    };

    let bindings = field_names().enumerate().map(|(index, field_name)| {
        // Cannot fail: the field count has been checked against `u16::MAX` above.
        let col_index = u16::try_from(index + 1)
            .expect("Field count has been checked to fit into a u16 column index.");
        quote! {
            odbc_api::buffers::FetchRowMember::bind_to_col(
                &mut self.#field_name,
                #col_index,
                &mut cursor
            )?;
        }
    });

    // Truncation checks receive the zero-based field position (emitted as a `usize` literal).
    let find_truncation = field_names().enumerate().map(|(index, field_name)| {
        quote! {
            let maybe_truncation = odbc_api::buffers::FetchRowMember::find_truncation(
                &self.#field_name,
                #index,
            );
            if let Some(truncation_info) = maybe_truncation {
                return Some(truncation_info);
            }
        }
    });

    quote! {
        unsafe impl #impl_generics odbc_api::buffers::FetchRow for #struct_name #ty_generics #where_clause {

            unsafe fn bind_columns_to_cursor(
                &mut self,
                mut cursor: odbc_api::handles::StatementRef<'_>
            ) -> std::result::Result<(), odbc_api::Error> {
                #(#bindings)*
                Ok(())
            }

            fn find_truncation(&self) -> std::option::Option<odbc_api::TruncationInfo> {
                #(#find_truncation)*
                None
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use proc_macro2::TokenStream;
    use quote::quote;
    use syn::{DeriveInput, parse2};

    use super::expand;

    /// Parses `input` as a derive input, expands it, and asserts the result equals `expected`.
    ///
    /// `proc_macro2::TokenStream` does not implement `PartialEq`, so both sides are compared
    /// through their string rendering.
    ///
    /// # Panics
    ///
    /// Panics if `input` is not a valid `DeriveInput` or if the expansion differs from
    /// `expected`.
    fn assert_expands_to(input: TokenStream, expected: TokenStream) {
        let derive_input: DeriveInput = parse2(input).unwrap();
        let actual = expand(derive_input);
        assert_eq!(expected.to_string(), actual.to_string());
    }

    /// A struct with named fields gets one column binding and one truncation check per field.
    #[test]
    fn derive_fetch_for_row() {
        let row_definition = quote! {
            struct MyRow {
                a: i64,
                b: VarCharArray<50>,
            }
        };

        let expected_impl = quote! {
            unsafe impl odbc_api::buffers::FetchRow for MyRow {

                unsafe fn bind_columns_to_cursor(
                    &mut self,
                    mut cursor: odbc_api::handles::StatementRef<'_>
                ) -> std::result::Result<(), odbc_api::Error> {
                    odbc_api::buffers::FetchRowMember::bind_to_col(
                        &mut self.a,
                        1u16,
                        &mut cursor
                    )?;
                    odbc_api::buffers::FetchRowMember::bind_to_col(
                        &mut self.b,
                        2u16,
                        &mut cursor
                    )?;
                    Ok(())
                }

                fn find_truncation(&self) -> std::option::Option<odbc_api::TruncationInfo> {
                    let maybe_truncation = odbc_api::buffers::FetchRowMember::find_truncation(
                        &self.a,
                        0usize,
                    );
                    if let Some(truncation_info) = maybe_truncation {
                        return Some(truncation_info);
                    }
                    let maybe_truncation = odbc_api::buffers::FetchRowMember::find_truncation(
                        &self.b,
                        1usize,
                    );
                    if let Some(truncation_info) = maybe_truncation {
                        return Some(truncation_info);
                    }
                    None
                }
            }
        };

        assert_expands_to(row_definition, expected_impl);
    }

    /// Enums are rejected with a `compile_error!` instead of an impl.
    #[test]
    fn compiler_error_when_deriving_for_enum() {
        assert_expands_to(
            quote! {
                enum NotAStruct {}
            },
            quote! {
                compile_error!("Fetch can only be derived for structs");
            },
        );
    }

    /// Tuple structs have no field names to bind, so they are rejected as well.
    #[test]
    fn compiler_error_when_deriving_for_tuple_struct() {
        assert_expands_to(
            quote! {
                struct TupleStruct(i64, i64);
            },
            quote! {
                compile_error!("Fetch can only be derived for structs with named fields");
            },
        );
    }
}