// duckdb_loadable_macros/lib.rs

1#![allow(clippy::redundant_clone)]
2use proc_macro2::{Ident, Span};
3
4use syn::{parse_macro_input, spanned::Spanned, Item};
5
6use proc_macro::TokenStream;
7use quote::quote_spanned;
8
9use darling::{ast::NestedMeta, Error, FromMeta};
10
11use std::env;
12
13/// For parsing the arguments to the duckdb_entrypoint_c_api macro
14#[derive(Debug, FromMeta)]
15struct CEntryPointMacroArgs {
16    #[darling(default)]
17    /// The name to be given to this extension. This name is used in the entrypoint function called by duckdb
18    ext_name: Option<String>,
19    /// The minimum C API version this extension requires. It is recommended to set this to the lowest possible version
20    /// at which your extension still compiles
21    min_duckdb_version: Option<String>,
22}
23
24/// Wraps an entrypoint function to expose an unsafe extern "C" function of the same name.
25/// Warning: experimental!
26#[proc_macro_attribute]
27pub fn duckdb_entrypoint_c_api(attr: TokenStream, item: TokenStream) -> TokenStream {
28    let attr_args = match NestedMeta::parse_meta_list(attr.into()) {
29        Ok(v) => v,
30        Err(e) => {
31            return TokenStream::from(Error::from(e).write_errors());
32        }
33    };
34
35    let args = match CEntryPointMacroArgs::from_list(&attr_args) {
36        Ok(v) => v,
37        Err(e) => {
38            return TokenStream::from(e.write_errors());
39        }
40    };
41
42    // Set the minimum duckdb version (dev by default)
43    let minimum_duckdb_version = match args.min_duckdb_version {
44        Some(i) => i,
45        None => env::var("DUCKDB_EXTENSION_MIN_DUCKDB_VERSION").expect("Please either set env var DUCKDB_EXTENSION_MIN_DUCKDB_VERSION or pass it as an argument to the proc macro").to_string()
46    };
47
48    let extension_name = match args.ext_name {
49        Some(i) => i,
50        None => env::var("DUCKDB_EXTENSION_NAME").expect("Please either set env var DUCKDB_EXTENSION_MIN_DUCKDB_VERSION or pass it as an argument to the proc macro").to_string()
51    };
52
53    let ast = parse_macro_input!(item as syn::Item);
54
55    match ast {
56        Item::Fn(func) => {
57            let c_entrypoint = Ident::new(format!("{}_init_c_api", extension_name).as_str(), Span::call_site());
58            let prefixed_original_function = func.sig.ident.clone();
59            let c_entrypoint_internal = Ident::new(
60                format!("{}_init_c_api_internal", extension_name).as_str(),
61                Span::call_site(),
62            );
63
64            quote_spanned! {func.span()=>
65                /// # Safety
66                ///
67                /// Internal Entrypoint for error handling
68                pub unsafe fn #c_entrypoint_internal(info: ffi::duckdb_extension_info, access: *const ffi::duckdb_extension_access) -> Result<bool, Box<dyn std::error::Error>> {
69                    let have_api_struct = ffi::duckdb_rs_extension_api_init(info, access, #minimum_duckdb_version).unwrap();
70
71                    if !have_api_struct {
72                        // initialization failed to return an api struct, likely due to an API version mismatch, we can simply return here
73                        return Ok(false);
74                    }
75
76                    // TODO: handle error here?
77                    let db : ffi::duckdb_database = *(*access).get_database.unwrap()(info);
78                    let connection = Connection::open_from_raw(db.cast())?;
79
80                    #prefixed_original_function(connection)?;
81
82                    Ok(true)
83                }
84
85                /// # Safety
86                ///
87                /// Entrypoint that will be called by DuckDB
88                #[no_mangle]
89                pub unsafe extern "C" fn #c_entrypoint(info: ffi::duckdb_extension_info, access: *const ffi::duckdb_extension_access) -> bool {
90                    let init_result = #c_entrypoint_internal(info, access);
91
92                    if let Err(x) = init_result {
93                        let error_c_string = std::ffi::CString::new(x.to_string());
94
95                        match error_c_string {
96                            Ok(e) => {
97                                (*access).set_error.unwrap()(info, e.as_ptr());
98                            },
99                            Err(_e) => {
100                                let error_alloc_failure = c"An error occured but the extension failed to allocate memory for an error string";
101                                (*access).set_error.unwrap()(info, error_alloc_failure.as_ptr());
102                            }
103                        }
104                        return false;
105                    }
106
107                    init_result.unwrap()
108                }
109
110                #func
111            }
112            .into()
113        }
114        _ => panic!("Only function items are allowed on duckdb_entrypoint"),
115    }
116}
117
118/// Wraps an entrypoint function to expose an unsafe extern "C" function of the same name.
119#[proc_macro_attribute]
120pub fn duckdb_entrypoint(_attr: TokenStream, item: TokenStream) -> TokenStream {
121    let ast = parse_macro_input!(item as syn::Item);
122    match ast {
123        Item::Fn(mut func) => {
124            let c_entrypoint = func.sig.ident.clone();
125            let c_entrypoint_version = Ident::new(
126                c_entrypoint.to_string().replace("_init", "_version").as_str(),
127                Span::call_site(),
128            );
129
130            let original_funcname = func.sig.ident.to_string();
131            func.sig.ident = Ident::new(format!("_{}", original_funcname).as_str(), func.sig.ident.span());
132
133            let prefixed_original_function = func.sig.ident.clone();
134
135            quote_spanned! {func.span()=>
136                #func
137
138                /// # Safety
139                ///
140                /// Will be called by duckdb
141                #[no_mangle]
142                pub unsafe extern "C" fn #c_entrypoint(db: *mut c_void) {
143                    unsafe {
144                        let connection = Connection::open_from_raw(db.cast()).expect("can't open db connection");
145                        #prefixed_original_function(connection).expect("init failed");
146                    }
147                }
148
149                /// # Safety
150                ///
151                /// Predefined function, don't need to change unless you are sure
152                #[no_mangle]
153                pub unsafe extern "C" fn #c_entrypoint_version() -> *const c_char {
154                    unsafe {
155                        ffi::duckdb_library_version()
156                    }
157                }
158
159
160            }
161            .into()
162        }
163        _ => panic!("Only function items are allowed on duckdb_entrypoint"),
164    }
165}