odbc_api_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput};
4
5/// Use this to derive the trait `FetchRow` for structs defined in the application logic.
6///
7/// # Example
8///
9/// ```
10/// use odbc_api_derive::Fetch;
11/// use odbc_api::{Connection, Error, Cursor, parameter::VarCharArray, buffers::RowVec};
12///
13/// #[derive(Default, Clone, Copy, Fetch)]
14/// struct Person {
15///     first_name: VarCharArray<255>,
16///     last_name: VarCharArray<255>,
17/// }
18///
19/// fn send_greetings(conn: &mut Connection) -> Result<(), Error> {
20///     let max_rows_in_batch = 250;
21///     let buffer = RowVec::<Person>::new(max_rows_in_batch);
22///     let mut cursor = conn.execute("SELECT first_name, last_name FROM Persons", (), None)?
23///         .expect("SELECT must yield a result set");
24///     let mut block_cursor = cursor.bind_buffer(buffer)?;
25///
26///     while let Some(batch) = block_cursor.fetch()? {
27///         for person in batch.iter() {
28///             let first = person.first_name.as_str()
29///                 .expect("First name must be UTF-8")
30///                 .expect("First Name must not be NULL");
31///             let last = person.last_name.as_str()
32///                 .expect("Last name must be UTF-8")
33///                 .expect("Last Name must not be NULL");
34///             println!("Hello {first} {last}!")
35///         }
36///     }
37///     Ok(())
38/// }
39/// ```
40#[proc_macro_derive(Fetch)]
41pub fn derive_fetch_row(item: TokenStream) -> TokenStream {
42    let input = parse_macro_input!(item as DeriveInput);
43
44    let struct_name = input.ident;
45
46    let struct_data = match input.data {
47        syn::Data::Struct(struct_data) => struct_data,
48        _ => panic!("Fetch can only be derived for structs"),
49    };
50
51    let fields = struct_data.fields;
52
53    let bindings = fields.iter().enumerate().map(|(index, field)| {
54        let field_name = field
55            .ident
56            .as_ref()
57            .expect("All struct members must be named");
58        let col_index = (index + 1) as u16;
59        quote! {
60            odbc_api::buffers::FetchRowMember::bind_to_col(
61                &mut self.#field_name,
62                #col_index,
63                &mut cursor
64            )?;
65        }
66    });
67
68    let find_truncation = fields.iter().enumerate().map(|(index, field)| {
69        let field_name = field
70            .ident
71            .as_ref()
72            .expect("All struct members must be named");
73        quote! {
74            let maybe_truncation = odbc_api::buffers::FetchRowMember::find_truncation(
75                &self.#field_name,
76                #index,
77            );
78            if let Some(truncation_info) = maybe_truncation {
79                return Some(truncation_info);
80            }
81        }
82    });
83
84    let expanded = quote! {
85        unsafe impl odbc_api::buffers::FetchRow for #struct_name {
86
87            unsafe fn bind_columns_to_cursor(
88                &mut self,
89                mut cursor: odbc_api::handles::StatementRef<'_>
90            ) -> std::result::Result<(), odbc_api::Error> {
91                #(#bindings)*
92                Ok(())
93            }
94
95            fn find_truncation(&self) -> std::option::Option<odbc_api::TruncationInfo> {
96                #(#find_truncation)*
97                None
98            }
99        }
100    };
101
102    expanded.into()
103}