baml_sys/
symbols.rs

1//! FFI symbol loading and storage.
2
3use std::ffi::CStr;
4
5use libc::{c_char, c_int, c_void, size_t};
6use libloading::Symbol;
7use once_cell::sync::OnceCell;
8
9use crate::{
10    error::{BamlSysError, Result},
11    loader::{LoadedLibrary, VERSION, get_library},
12};
13
14/// Callback function type for results.
15pub type CallbackFn =
16    extern "C" fn(call_id: u32, is_done: c_int, content: *const i8, length: size_t);
17
18/// Callback function type for streaming ticks.
19pub type OnTickCallbackFn = extern "C" fn(call_id: u32);
20
21/// Buffer returned from object operations.
22#[repr(C)]
23#[allow(missing_docs)] // FFI struct fields are self-explanatory
24pub struct Buffer {
25    /// Pointer to the buffer data.
26    pub ptr: *const i8,
27    /// Length of the buffer.
28    pub len: size_t,
29}
30
31// Type aliases for FFI function signatures
32type VersionFn = unsafe extern "C" fn() -> *const c_char;
33type RegisterCallbacksFn = unsafe extern "C" fn(CallbackFn, CallbackFn, OnTickCallbackFn);
34type CreateBamlRuntimeFn =
35    unsafe extern "C" fn(*const c_char, *const c_char, *const c_char) -> *const c_void;
36type DestroyBamlRuntimeFn = unsafe extern "C" fn(*const c_void);
37type InvokeRuntimeCliFn = unsafe extern "C" fn(*const *const c_char) -> c_int;
38type CallFunctionFromCFn =
39    unsafe extern "C" fn(*const c_void, *const c_char, *const c_char, size_t, u32) -> *const c_void;
40type CancelFunctionCallFn = unsafe extern "C" fn(u32) -> *const c_void;
41type CallObjectConstructorFn = unsafe extern "C" fn(*const c_char, size_t) -> Buffer;
42type CallObjectMethodFn = unsafe extern "C" fn(*const c_void, *const c_char, size_t) -> Buffer;
43type FreeBufferFn = unsafe extern "C" fn(Buffer);
44
45/// Loaded symbols from the dynamic library.
46#[allow(missing_docs)] // FFI symbol fields match their C function names
47pub struct Symbols {
48    pub(crate) version: Symbol<'static, VersionFn>,
49    pub(crate) register_callbacks: Symbol<'static, RegisterCallbacksFn>,
50    pub(crate) create_baml_runtime: Symbol<'static, CreateBamlRuntimeFn>,
51    pub(crate) destroy_baml_runtime: Symbol<'static, DestroyBamlRuntimeFn>,
52    pub(crate) invoke_runtime_cli: Symbol<'static, InvokeRuntimeCliFn>,
53    pub(crate) call_function_from_c: Symbol<'static, CallFunctionFromCFn>,
54    pub(crate) call_function_stream_from_c: Symbol<'static, CallFunctionFromCFn>,
55    pub(crate) call_function_parse_from_c: Symbol<'static, CallFunctionFromCFn>,
56    pub(crate) cancel_function_call: Symbol<'static, CancelFunctionCallFn>,
57    pub(crate) call_object_constructor: Symbol<'static, CallObjectConstructorFn>,
58    pub(crate) call_object_method: Symbol<'static, CallObjectMethodFn>,
59    pub(crate) free_buffer: Symbol<'static, FreeBufferFn>,
60}
61
62/// Global symbols instance.
63static SYMBOLS: OnceCell<Symbols> = OnceCell::new();
64
65/// Get the loaded symbols, initializing if necessary.
66pub fn get_symbols() -> Result<&'static Symbols> {
67    SYMBOLS.get_or_try_init(|| {
68        let lib = get_library()?;
69        load_symbols(lib)
70    })
71}
72
73/// Load all symbols from the library.
74fn load_symbols(lib: &'static LoadedLibrary) -> Result<Symbols> {
75    // Safety: We're loading symbols from a dynamic library that should
76    // have been built with the matching C ABI.
77    #[allow(unsafe_code)]
78    unsafe {
79        let version: Symbol<VersionFn> = load_symbol(&lib.library, "version")?;
80
81        // Verify version matches
82        let lib_version_ptr = version();
83        let lib_version = CStr::from_ptr(lib_version_ptr)
84            .to_str()
85            .unwrap_or("unknown");
86
87        if lib_version != VERSION {
88            return Err(BamlSysError::VersionMismatch {
89                expected: VERSION.to_string(),
90                actual: lib_version.to_string(),
91            });
92        }
93
94        Ok(Symbols {
95            version,
96            register_callbacks: load_symbol(&lib.library, "register_callbacks")?,
97            create_baml_runtime: load_symbol(&lib.library, "create_baml_runtime")?,
98            destroy_baml_runtime: load_symbol(&lib.library, "destroy_baml_runtime")?,
99            invoke_runtime_cli: load_symbol(&lib.library, "invoke_runtime_cli")?,
100            call_function_from_c: load_symbol(&lib.library, "call_function_from_c")?,
101            call_function_stream_from_c: load_symbol(&lib.library, "call_function_stream_from_c")?,
102            call_function_parse_from_c: load_symbol(&lib.library, "call_function_parse_from_c")?,
103            cancel_function_call: load_symbol(&lib.library, "cancel_function_call")?,
104            call_object_constructor: load_symbol(&lib.library, "call_object_constructor")?,
105            call_object_method: load_symbol(&lib.library, "call_object_method")?,
106            free_buffer: load_symbol(&lib.library, "free_buffer")?,
107        })
108    }
109}
110
111/// Load a single symbol from the library.
112#[allow(unsafe_code)]
113unsafe fn load_symbol<T>(
114    library: &'static libloading::Library,
115    name: &'static str,
116) -> Result<Symbol<'static, T>> {
117    unsafe {
118        library
119            .get(name.as_bytes())
120            .map_err(|e| BamlSysError::SymbolNotFound {
121                symbol: name,
122                source: e,
123            })
124    }
125}