odbc_api_derive/lib.rs
1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, DeriveInput};
4
5/// Use this to derive the trait `FetchRow` for structs defined in the application logic.
6///
7/// # Example
8///
9/// ```
10/// use odbc_api_derive::Fetch;
11/// use odbc_api::{Connection, Error, Cursor, parameter::VarCharArray, buffers::RowVec};
12///
13/// #[derive(Default, Clone, Copy, Fetch)]
14/// struct Person {
15/// first_name: VarCharArray<255>,
16/// last_name: VarCharArray<255>,
17/// }
18///
19/// fn send_greetings(conn: &mut Connection) -> Result<(), Error> {
20/// let max_rows_in_batch = 250;
21/// let buffer = RowVec::<Person>::new(max_rows_in_batch);
22/// let mut cursor = conn.execute("SELECT first_name, last_name FROM Persons", (), None)?
23/// .expect("SELECT must yield a result set");
24/// let mut block_cursor = cursor.bind_buffer(buffer)?;
25///
26/// while let Some(batch) = block_cursor.fetch()? {
27/// for person in batch.iter() {
28/// let first = person.first_name.as_str()
29/// .expect("First name must be UTF-8")
30/// .expect("First Name must not be NULL");
31/// let last = person.last_name.as_str()
32/// .expect("Last name must be UTF-8")
33/// .expect("Last Name must not be NULL");
34/// println!("Hello {first} {last}!")
35/// }
36/// }
37/// Ok(())
38/// }
39/// ```
40#[proc_macro_derive(Fetch)]
41pub fn derive_fetch_row(item: TokenStream) -> TokenStream {
42 let input = parse_macro_input!(item as DeriveInput);
43
44 let struct_name = input.ident;
45
46 let struct_data = match input.data {
47 syn::Data::Struct(struct_data) => struct_data,
48 _ => panic!("Fetch can only be derived for structs"),
49 };
50
51 let fields = struct_data.fields;
52
53 let bindings = fields.iter().enumerate().map(|(index, field)| {
54 let field_name = field
55 .ident
56 .as_ref()
57 .expect("All struct members must be named");
58 let col_index = (index + 1) as u16;
59 quote! {
60 odbc_api::buffers::FetchRowMember::bind_to_col(
61 &mut self.#field_name,
62 #col_index,
63 &mut cursor
64 )?;
65 }
66 });
67
68 let find_truncation = fields.iter().enumerate().map(|(index, field)| {
69 let field_name = field
70 .ident
71 .as_ref()
72 .expect("All struct members must be named");
73 quote! {
74 let maybe_truncation = odbc_api::buffers::FetchRowMember::find_truncation(
75 &self.#field_name,
76 #index,
77 );
78 if let Some(truncation_info) = maybe_truncation {
79 return Some(truncation_info);
80 }
81 }
82 });
83
84 let expanded = quote! {
85 unsafe impl odbc_api::buffers::FetchRow for #struct_name {
86
87 unsafe fn bind_columns_to_cursor(
88 &mut self,
89 mut cursor: odbc_api::handles::StatementRef<'_>
90 ) -> std::result::Result<(), odbc_api::Error> {
91 #(#bindings)*
92 Ok(())
93 }
94
95 fn find_truncation(&self) -> std::option::Option<odbc_api::TruncationInfo> {
96 #(#find_truncation)*
97 None
98 }
99 }
100 };
101
102 expanded.into()
103}