duckdb_loadable_macros/
lib.rs

1#![allow(clippy::redundant_clone)]
2use proc_macro2::{Ident, Span};
3
4use syn::{parse_macro_input, spanned::Spanned, Item};
5
6use proc_macro::TokenStream;
7use quote::quote_spanned;
8
9use darling::{ast::NestedMeta, Error, FromMeta};
10
11use std::env;
12
13const DEFAULT_DUCKDB_VERSION: &str = "v1.2.0";
14
15/// For parsing the arguments to the duckdb_entrypoint_c_api macro
16#[derive(Debug, FromMeta)]
17struct CEntryPointMacroArgs {
18    #[darling(default)]
19    /// The name to be given to this extension. This name is used in the entrypoint function called by duckdb
20    ext_name: Option<String>,
21    /// The minimum C API version this extension requires. It is recommended to set this to the lowest possible version
22    /// at which your extension still compiles
23    min_duckdb_version: Option<String>,
24}
25
26/// Wraps an entrypoint function to expose an unsafe extern "C" function of the same name.
27/// Warning: experimental!
28#[proc_macro_attribute]
29pub fn duckdb_entrypoint_c_api(attr: TokenStream, item: TokenStream) -> TokenStream {
30    let attr_args = match NestedMeta::parse_meta_list(attr.into()) {
31        Ok(v) => v,
32        Err(e) => {
33            return TokenStream::from(Error::from(e).write_errors());
34        }
35    };
36
37    let args = match CEntryPointMacroArgs::from_list(&attr_args) {
38        Ok(v) => v,
39        Err(e) => {
40            return TokenStream::from(e.write_errors());
41        }
42    };
43
44    // Set the minimum duckdb version (dev by default)
45    let minimum_duckdb_version = match (args.min_duckdb_version, env::var("DUCKDB_EXTENSION_MIN_DUCKDB_VERSION")) {
46        (Some(i), _) => i,
47        (None, Ok(i)) => i.to_string(),
48        _ => DEFAULT_DUCKDB_VERSION.to_string(),
49    };
50
51    let extension_name = match (args.ext_name, env::var("DUCKDB_EXTENSION_NAME")) {
52        (Some(i), _) => i,
53        (None, Ok(i)) => i.to_string(),
54        _ => env::var("CARGO_PKG_NAME").unwrap().to_string(),
55    };
56
57    let ast = parse_macro_input!(item as syn::Item);
58
59    match ast {
60        Item::Fn(func) => {
61            let c_entrypoint = Ident::new(format!("{extension_name}_init_c_api").as_str(), Span::call_site());
62            let prefixed_original_function = func.sig.ident.clone();
63            let c_entrypoint_internal = Ident::new(
64                format!("{extension_name}_init_c_api_internal").as_str(),
65                Span::call_site(),
66            );
67
68            quote_spanned! {func.span()=>
69                /// # Safety
70                ///
71                /// Internal Entrypoint for error handling
72                pub unsafe fn #c_entrypoint_internal(info: ffi::duckdb_extension_info, access: *const ffi::duckdb_extension_access) -> Result<bool, Box<dyn std::error::Error>> {
73                    let have_api_struct = ffi::duckdb_rs_extension_api_init(info, access, #minimum_duckdb_version).unwrap();
74
75                    if !have_api_struct {
76                        // initialization failed to return an api struct, likely due to an API version mismatch, we can simply return here
77                        return Ok(false);
78                    }
79
80                    // TODO: handle error here?
81                    let db : ffi::duckdb_database = *(*access).get_database.unwrap()(info);
82                    let connection = Connection::open_from_raw(db.cast())?;
83
84                    #prefixed_original_function(connection)?;
85
86                    Ok(true)
87                }
88
89                /// # Safety
90                ///
91                /// Entrypoint that will be called by DuckDB
92                #[no_mangle]
93                pub unsafe extern "C" fn #c_entrypoint(info: ffi::duckdb_extension_info, access: *const ffi::duckdb_extension_access) -> bool {
94                    let init_result = #c_entrypoint_internal(info, access);
95
96                    if let Err(x) = init_result {
97                        let error_c_string = std::ffi::CString::new(x.to_string());
98
99                        match error_c_string {
100                            Ok(e) => {
101                                (*access).set_error.unwrap()(info, e.as_ptr());
102                            },
103                            Err(_e) => {
104                                let error_alloc_failure = c"An error occured but the extension failed to allocate memory for an error string";
105                                (*access).set_error.unwrap()(info, error_alloc_failure.as_ptr());
106                            }
107                        }
108                        return false;
109                    }
110
111                    init_result.unwrap()
112                }
113
114                #func
115            }
116            .into()
117        }
118        _ => panic!("Only function items are allowed on duckdb_entrypoint"),
119    }
120}
121
122/// Wraps an entrypoint function to expose an unsafe extern "C" function of the same name.
123#[proc_macro_attribute]
124pub fn duckdb_entrypoint(_attr: TokenStream, item: TokenStream) -> TokenStream {
125    let ast = parse_macro_input!(item as syn::Item);
126    match ast {
127        Item::Fn(mut func) => {
128            let c_entrypoint = func.sig.ident.clone();
129            let c_entrypoint_version = Ident::new(
130                c_entrypoint.to_string().replace("_init", "_version").as_str(),
131                Span::call_site(),
132            );
133
134            let original_funcname = func.sig.ident.to_string();
135            func.sig.ident = Ident::new(format!("_{original_funcname}").as_str(), func.sig.ident.span());
136
137            let prefixed_original_function = func.sig.ident.clone();
138
139            quote_spanned! {func.span()=>
140                #func
141
142                /// # Safety
143                ///
144                /// Will be called by duckdb
145                #[unsafe(no_mangle)]
146                pub unsafe extern "C" fn #c_entrypoint(db: *mut std::ffi::c_void) {
147                    unsafe {
148                        let connection = Connection::open_from_raw(db.cast()).expect("can't open db connection");
149                        #prefixed_original_function(connection).expect("init failed");
150                    }
151                }
152
153                /// # Safety
154                ///
155                /// Predefined function, don't need to change unless you are sure
156                #[unsafe(no_mangle)]
157                pub unsafe extern "C" fn #c_entrypoint_version() -> *const std::ffi::c_char {
158                    unsafe {
159                        ffi::duckdb_library_version()
160                    }
161                }
162
163
164            }
165            .into()
166        }
167        _ => panic!("Only function items are allowed on duckdb_entrypoint"),
168    }
169}