1use libc::{c_char, c_int, c_void, size_t};
4use libloading::Symbol;
5use once_cell::sync::OnceCell;
6
7use crate::{
8 error::{BamlSysError, Result},
9 loader::{get_library, LoadedLibrary, VERSION},
10};
11
/// Signature of the result callback handed to the native library.
///
/// Two callbacks of this type (plus one [`OnTickCallbackFn`]) are passed to the
/// library's `register_callbacks` export (see `RegisterCallbacksFn`).
///
/// Arguments:
/// - `call_id`: identifier of the call being reported on; presumably the `u32`
///   passed as the last argument of the `call_function*_from_c` exports — TODO confirm.
/// - `is_done`: C integer flag; assumed non-zero on the final delivery for a
///   call — NOTE(review): confirm against the native library.
/// - `content` / `length`: pointer to, and byte count of, the payload.
///   NOTE(review): ownership and lifetime of `content` are not visible here; confirm
///   it is only valid for the duration of the callback.
pub type CallbackFn =
    extern "C" fn(call_id: u32, is_done: c_int, content: *const i8, length: size_t);

/// Signature of the "tick" callback registered alongside the two [`CallbackFn`]s;
/// it receives only the call identifier.
/// NOTE(review): when ticks fire is defined by the native library — confirm there.
pub type OnTickCallbackFn = extern "C" fn(call_id: u32);
18
/// A `(pointer, length)` byte buffer exchanged with the native library over the C ABI.
///
/// `#[repr(C)]` pins the field order and layout so it matches the library's own
/// struct definition. The type derives neither `Clone` nor `Copy`, so handing it to
/// the library's `free_buffer` export moves it out of the caller. The version buffer
/// read during symbol loading is released that way; presumably every buffer the
/// library returns must be — NOTE(review): confirm against the native library.
#[repr(C)]
#[allow(missing_docs)]
pub struct Buffer {
    /// Start of the data. May be null; readers check `is_null()` before
    /// dereferencing.
    pub ptr: *const i8,
    /// Number of bytes available at `ptr`.
    pub len: size_t,
}
28
// Function-pointer types for the C exports of the BAML library. Each alias must
// match the exported signature exactly; a mismatch is undefined behaviour when
// called.

/// `version`: returns the library's version string as a [`Buffer`].
type VersionFn = unsafe extern "C" fn() -> Buffer;
/// `register_callbacks`: installs two result callbacks and one tick callback.
type RegisterCallbacksFn = unsafe extern "C" fn(CallbackFn, CallbackFn, OnTickCallbackFn);
/// `create_baml_runtime`: takes three C strings and returns an opaque runtime
/// handle. NOTE(review): meaning of the three strings is not visible here — confirm.
type CreateBamlRuntimeFn =
    unsafe extern "C" fn(*const c_char, *const c_char, *const c_char) -> *const c_void;
/// `destroy_baml_runtime`: presumably releases a handle from
/// `create_baml_runtime` — TODO confirm.
type DestroyBamlRuntimeFn = unsafe extern "C" fn(*const c_void);
/// `invoke_runtime_cli`: takes a pointer to an array of C strings and returns a C
/// int. NOTE(review): assumed argv-style (and possibly null-terminated) — confirm.
type InvokeRuntimeCliFn = unsafe extern "C" fn(*const *const c_char) -> c_int;
/// Shared signature of `call_function_from_c`, `call_function_stream_from_c` and
/// `call_function_parse_from_c`: (runtime handle, C string, data pointer, data
/// length, call id) -> [`Buffer`]. NOTE(review): argument roles inferred from
/// names/types only — confirm.
type CallFunctionFromCFn =
    unsafe extern "C" fn(*const c_void, *const c_char, *const c_char, size_t, u32) -> Buffer;
/// `cancel_function_call`: takes a call id and returns a [`Buffer`].
type CancelFunctionCallFn = unsafe extern "C" fn(u32) -> Buffer;
/// `call_object_constructor`: takes a (pointer, length) byte payload and returns a
/// [`Buffer`].
type CallObjectConstructorFn = unsafe extern "C" fn(*const c_char, size_t) -> Buffer;
/// `call_object_method`: takes an opaque pointer plus a (pointer, length) payload
/// and returns a [`Buffer`].
type CallObjectMethodFn = unsafe extern "C" fn(*const c_void, *const c_char, size_t) -> Buffer;
/// `free_buffer`: hands a [`Buffer`] back to the library for release.
type FreeBufferFn = unsafe extern "C" fn(Buffer);
42
/// Table of function pointers resolved from the loaded BAML library.
///
/// Each field holds the export of the same name. Every `Symbol` borrows the
/// library with a `'static` lifetime. The table is built by `load_symbols`, which
/// checks the library's reported version before resolving the remaining exports.
#[allow(missing_docs)]
pub struct Symbols {
    pub(crate) version: Symbol<'static, VersionFn>,
    pub(crate) register_callbacks: Symbol<'static, RegisterCallbacksFn>,
    pub(crate) create_baml_runtime: Symbol<'static, CreateBamlRuntimeFn>,
    pub(crate) destroy_baml_runtime: Symbol<'static, DestroyBamlRuntimeFn>,
    pub(crate) invoke_runtime_cli: Symbol<'static, InvokeRuntimeCliFn>,
    // The three call_function* exports share one signature.
    pub(crate) call_function_from_c: Symbol<'static, CallFunctionFromCFn>,
    pub(crate) call_function_stream_from_c: Symbol<'static, CallFunctionFromCFn>,
    pub(crate) call_function_parse_from_c: Symbol<'static, CallFunctionFromCFn>,
    pub(crate) cancel_function_call: Symbol<'static, CancelFunctionCallFn>,
    pub(crate) call_object_constructor: Symbol<'static, CallObjectConstructorFn>,
    pub(crate) call_object_method: Symbol<'static, CallObjectMethodFn>,
    pub(crate) free_buffer: Symbol<'static, FreeBufferFn>,
}
59
60static SYMBOLS: OnceCell<Symbols> = OnceCell::new();
62
63pub fn get_symbols() -> Result<&'static Symbols> {
65 SYMBOLS.get_or_try_init(|| {
66 let lib = get_library()?;
67 load_symbols(lib)
68 })
69}
70
71fn load_symbols(lib: &'static LoadedLibrary) -> Result<Symbols> {
73 #[allow(unsafe_code)]
76 unsafe {
77 let free_buffer: Symbol<FreeBufferFn> = load_symbol(&lib.library, "free_buffer")?;
79 let version: Symbol<VersionFn> = load_symbol(&lib.library, "version")?;
80
81 let version_buf = version();
83 let lib_version = if !version_buf.ptr.is_null() && version_buf.len > 0 {
84 let bytes = std::slice::from_raw_parts(version_buf.ptr as *const u8, version_buf.len);
85 String::from_utf8_lossy(bytes).into_owned()
86 } else {
87 "unknown".to_string()
88 };
89 free_buffer(version_buf); if lib_version != VERSION {
92 return Err(BamlSysError::VersionMismatch {
93 expected: VERSION.to_string(),
94 actual: lib_version.to_string(),
95 });
96 }
97
98 Ok(Symbols {
99 version,
100 register_callbacks: load_symbol(&lib.library, "register_callbacks")?,
101 create_baml_runtime: load_symbol(&lib.library, "create_baml_runtime")?,
102 destroy_baml_runtime: load_symbol(&lib.library, "destroy_baml_runtime")?,
103 invoke_runtime_cli: load_symbol(&lib.library, "invoke_runtime_cli")?,
104 call_function_from_c: load_symbol(&lib.library, "call_function_from_c")?,
105 call_function_stream_from_c: load_symbol(&lib.library, "call_function_stream_from_c")?,
106 call_function_parse_from_c: load_symbol(&lib.library, "call_function_parse_from_c")?,
107 cancel_function_call: load_symbol(&lib.library, "cancel_function_call")?,
108 call_object_constructor: load_symbol(&lib.library, "call_object_constructor")?,
109 call_object_method: load_symbol(&lib.library, "call_object_method")?,
110 free_buffer,
111 })
112 }
113}
114
115#[allow(unsafe_code)]
117unsafe fn load_symbol<T>(
118 library: &'static libloading::Library,
119 name: &'static str,
120) -> Result<Symbol<'static, T>> {
121 unsafe {
122 library
123 .get(name.as_bytes())
124 .map_err(|e| BamlSysError::SymbolNotFound {
125 symbol: name,
126 source: e,
127 })
128 }
129}