Skip to main content

daft_ext_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Literal;
3use quote::quote;
4use syn::{DeriveInput, parse_macro_input};
5
6/// Generates the `daft_module_magic` entry point for a Daft extension module.
7///
8/// Place this attribute on a struct that implements `DaftExtension`. The macro
9/// generates the `#[no_mangle] pub extern "C" fn daft_module_magic()` symbol
10/// that Daft's module loader resolves via `dlopen`.
11///
12/// # Example
13///
14/// ```ignore
15/// use daft_ext::prelude::*;
16///
17/// #[daft_extension]
18/// struct MyExtension;
19///
20/// impl DaftExtension for MyExtension {
21///     fn install(session: &mut dyn DaftSession) {
22///         // register scalar functions here
23///     }
24/// }
25/// ```
26#[proc_macro_attribute]
27pub fn daft_extension(_attr: TokenStream, item: TokenStream) -> TokenStream {
28    let input = parse_macro_input!(item as DeriveInput);
29    let ident = &input.ident;
30    let name = pascal_to_snake(&ident.to_string());
31    let mut name_bytes = name.into_bytes();
32    name_bytes.push(0); // null terminator
33    let name_lit = Literal::byte_string(&name_bytes);
34
35    let output = quote! {
36        #input
37
38        #[unsafe(no_mangle)]
39        pub extern "C" fn daft_module_magic() -> ::daft_ext::abi::FFI_Module {
40            unsafe extern "C" fn __daft_init(
41                session: *mut ::daft_ext::abi::FFI_SessionContext,
42            ) -> ::std::ffi::c_int {
43                let session = unsafe { &mut *session };
44                let mut ctx = ::daft_ext::prelude::SessionContext::new(session);
45                <#ident as ::daft_ext::prelude::DaftExtension>::install(&mut ctx);
46                0
47            }
48
49            ::daft_ext::abi::FFI_Module {
50                daft_abi_version: ::daft_ext::abi::DAFT_ABI_VERSION,
51                // SAFETY: literal is null-terminated and valid UTF-8.
52                name: unsafe {
53                    ::std::ffi::CStr::from_bytes_with_nul_unchecked(#name_lit)
54                }.as_ptr(),
55                init: __daft_init,
56                free_string: ::daft_ext::prelude::free_string,
57            }
58        }
59    };
60
61    output.into()
62}
63
/// Convert a PascalCase identifier to snake_case.
///
/// Each uppercase character is replaced by its lowercase form, preceded by an
/// underscore unless it is the first character of the input. Every other
/// character is copied through unchanged. Consecutive capitals each start a
/// new word, so `"HTTPServer"` becomes `"h_t_t_p_server"`.
///
/// The complete Unicode lowercase mapping is emitted: a character whose
/// lowercase form is more than one `char` (e.g. `'İ'` -> `"i\u{307}"`) is
/// written out in full rather than truncated to its first `char`.
fn pascal_to_snake(s: &str) -> String {
    // A few spare bytes for the underscores inserted at word boundaries.
    let mut out = String::with_capacity(s.len() + 4);
    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            // No leading underscore for the very first character.
            if i > 0 {
                out.push('_');
            }
            // `to_lowercase` is an iterator because the mapping may yield
            // several chars; append all of them.
            out.extend(ch.to_lowercase());
        } else {
            out.push(ch);
        }
    }
    out
}