// baml_sys/symbols.rs

//! FFI symbol loading and storage.
2
use libc::{c_char, c_int, c_void, size_t};
use libloading::Symbol;
use once_cell::sync::OnceCell;

use crate::{
    error::{BamlSysError, Result},
    loader::{get_library, LoadedLibrary, VERSION},
};
11
12/// Callback function type for results.
13pub type CallbackFn =
14    extern "C" fn(call_id: u32, is_done: c_int, content: *const i8, length: size_t);
15
16/// Callback function type for streaming ticks.
17pub type OnTickCallbackFn = extern "C" fn(call_id: u32);
18
19/// Buffer returned from object operations.
20#[repr(C)]
21#[allow(missing_docs)] // FFI struct fields are self-explanatory
22pub struct Buffer {
23    /// Pointer to the buffer data.
24    pub ptr: *const i8,
25    /// Length of the buffer.
26    pub len: size_t,
27}
28
29// Type aliases for FFI function signatures
30type VersionFn = unsafe extern "C" fn() -> Buffer;
31type RegisterCallbacksFn = unsafe extern "C" fn(CallbackFn, CallbackFn, OnTickCallbackFn);
32type CreateBamlRuntimeFn =
33    unsafe extern "C" fn(*const c_char, *const c_char, *const c_char) -> *const c_void;
34type DestroyBamlRuntimeFn = unsafe extern "C" fn(*const c_void);
35type InvokeRuntimeCliFn = unsafe extern "C" fn(*const *const c_char) -> c_int;
36type CallFunctionFromCFn =
37    unsafe extern "C" fn(*const c_void, *const c_char, *const c_char, size_t, u32) -> Buffer;
38type CancelFunctionCallFn = unsafe extern "C" fn(u32) -> Buffer;
39type CallObjectConstructorFn = unsafe extern "C" fn(*const c_char, size_t) -> Buffer;
40type CallObjectMethodFn = unsafe extern "C" fn(*const c_void, *const c_char, size_t) -> Buffer;
41type FreeBufferFn = unsafe extern "C" fn(Buffer);
42
43/// Loaded symbols from the dynamic library.
44#[allow(missing_docs)] // FFI symbol fields match their C function names
45pub struct Symbols {
46    pub(crate) version: Symbol<'static, VersionFn>,
47    pub(crate) register_callbacks: Symbol<'static, RegisterCallbacksFn>,
48    pub(crate) create_baml_runtime: Symbol<'static, CreateBamlRuntimeFn>,
49    pub(crate) destroy_baml_runtime: Symbol<'static, DestroyBamlRuntimeFn>,
50    pub(crate) invoke_runtime_cli: Symbol<'static, InvokeRuntimeCliFn>,
51    pub(crate) call_function_from_c: Symbol<'static, CallFunctionFromCFn>,
52    pub(crate) call_function_stream_from_c: Symbol<'static, CallFunctionFromCFn>,
53    pub(crate) call_function_parse_from_c: Symbol<'static, CallFunctionFromCFn>,
54    pub(crate) cancel_function_call: Symbol<'static, CancelFunctionCallFn>,
55    pub(crate) call_object_constructor: Symbol<'static, CallObjectConstructorFn>,
56    pub(crate) call_object_method: Symbol<'static, CallObjectMethodFn>,
57    pub(crate) free_buffer: Symbol<'static, FreeBufferFn>,
58}
59
60/// Global symbols instance.
61static SYMBOLS: OnceCell<Symbols> = OnceCell::new();
62
63/// Get the loaded symbols, initializing if necessary.
64pub fn get_symbols() -> Result<&'static Symbols> {
65    SYMBOLS.get_or_try_init(|| {
66        let lib = get_library()?;
67        load_symbols(lib)
68    })
69}
70
71/// Load all symbols from the library.
72fn load_symbols(lib: &'static LoadedLibrary) -> Result<Symbols> {
73    // Safety: We're loading symbols from a dynamic library that should
74    // have been built with the matching C ABI.
75    #[allow(unsafe_code)]
76    unsafe {
77        // Load free_buffer first so we can clean up the version buffer
78        let free_buffer: Symbol<FreeBufferFn> = load_symbol(&lib.library, "free_buffer")?;
79        let version: Symbol<VersionFn> = load_symbol(&lib.library, "version")?;
80
81        // Verify version matches - version() now returns Buffer
82        let version_buf = version();
83        let lib_version = if !version_buf.ptr.is_null() && version_buf.len > 0 {
84            let bytes = std::slice::from_raw_parts(version_buf.ptr as *const u8, version_buf.len);
85            String::from_utf8_lossy(bytes).into_owned()
86        } else {
87            "unknown".to_string()
88        };
89        free_buffer(version_buf); // Clean up immediately
90
91        if lib_version != VERSION {
92            return Err(BamlSysError::VersionMismatch {
93                expected: VERSION.to_string(),
94                actual: lib_version.to_string(),
95            });
96        }
97
98        Ok(Symbols {
99            version,
100            register_callbacks: load_symbol(&lib.library, "register_callbacks")?,
101            create_baml_runtime: load_symbol(&lib.library, "create_baml_runtime")?,
102            destroy_baml_runtime: load_symbol(&lib.library, "destroy_baml_runtime")?,
103            invoke_runtime_cli: load_symbol(&lib.library, "invoke_runtime_cli")?,
104            call_function_from_c: load_symbol(&lib.library, "call_function_from_c")?,
105            call_function_stream_from_c: load_symbol(&lib.library, "call_function_stream_from_c")?,
106            call_function_parse_from_c: load_symbol(&lib.library, "call_function_parse_from_c")?,
107            cancel_function_call: load_symbol(&lib.library, "cancel_function_call")?,
108            call_object_constructor: load_symbol(&lib.library, "call_object_constructor")?,
109            call_object_method: load_symbol(&lib.library, "call_object_method")?,
110            free_buffer,
111        })
112    }
113}
114
115/// Load a single symbol from the library.
116#[allow(unsafe_code)]
117unsafe fn load_symbol<T>(
118    library: &'static libloading::Library,
119    name: &'static str,
120) -> Result<Symbol<'static, T>> {
121    unsafe {
122        library
123            .get(name.as_bytes())
124            .map_err(|e| BamlSysError::SymbolNotFound {
125                symbol: name,
126                source: e,
127            })
128    }
129}