1use std::ffi::CStr;
4
5use libc::{c_char, c_int, c_void, size_t};
6use libloading::Symbol;
7use once_cell::sync::OnceCell;
8
9use crate::{
10 error::{BamlSysError, Result},
11 loader::{LoadedLibrary, VERSION, get_library},
12};
13
/// Signature of the result callbacks handed to the library's `register_callbacks`
/// entry point (it takes two of these; see `RegisterCallbacksFn`).
///
/// Parameters, as named here:
/// - `call_id`: identifier of the call the callback refers to.
/// - `is_done`: C-style flag (presumably nonzero once the call has finished — confirm).
/// - `content` / `length`: pointer and length of the payload (presumably bytes — confirm).
///
/// NOTE(review): `content` is `*const i8` rather than `*const c_char`; the two
/// differ on targets where `c_char` is unsigned. Confirm against the C header.
pub type CallbackFn = extern "C" fn(
    call_id: u32,
    is_done: c_int,
    content: *const i8,
    length: size_t,
);

/// Signature of the tick callback handed to `register_callbacks`; receives only
/// the `call_id` it refers to.
pub type OnTickCallbackFn = extern "C" fn(call_id: u32);
20
/// Pointer/length pair passed by value across the FFI boundary.
///
/// Returned by the library's `call_object_constructor` and `call_object_method`
/// entry points and accepted by `free_buffer` (presumably to release memory the
/// library owns — confirm). `#[repr(C)]` fixes the field order and layout so it
/// matches the C-side struct.
///
/// NOTE(review): `ptr` is `*const i8`, not `*const c_char`; the two differ on
/// targets where `c_char` is unsigned. Confirm against the C declaration.
#[repr(C)]
pub struct Buffer {
    /// Start of the data; validity and ownership are defined by the C library.
    pub ptr: *const i8,
    /// Length associated with `ptr` (presumably a byte count — confirm).
    pub len: size_t,
}
30
31type VersionFn = unsafe extern "C" fn() -> *const c_char;
33type RegisterCallbacksFn = unsafe extern "C" fn(CallbackFn, CallbackFn, OnTickCallbackFn);
34type CreateBamlRuntimeFn =
35 unsafe extern "C" fn(*const c_char, *const c_char, *const c_char) -> *const c_void;
36type DestroyBamlRuntimeFn = unsafe extern "C" fn(*const c_void);
37type InvokeRuntimeCliFn = unsafe extern "C" fn(*const *const c_char) -> c_int;
38type CallFunctionFromCFn =
39 unsafe extern "C" fn(*const c_void, *const c_char, *const c_char, size_t, u32) -> *const c_void;
40type CancelFunctionCallFn = unsafe extern "C" fn(u32) -> *const c_void;
41type CallObjectConstructorFn = unsafe extern "C" fn(*const c_char, size_t) -> Buffer;
42type CallObjectMethodFn = unsafe extern "C" fn(*const c_void, *const c_char, size_t) -> Buffer;
43type FreeBufferFn = unsafe extern "C" fn(Buffer);
44
45#[allow(missing_docs)] pub struct Symbols {
48 pub(crate) version: Symbol<'static, VersionFn>,
49 pub(crate) register_callbacks: Symbol<'static, RegisterCallbacksFn>,
50 pub(crate) create_baml_runtime: Symbol<'static, CreateBamlRuntimeFn>,
51 pub(crate) destroy_baml_runtime: Symbol<'static, DestroyBamlRuntimeFn>,
52 pub(crate) invoke_runtime_cli: Symbol<'static, InvokeRuntimeCliFn>,
53 pub(crate) call_function_from_c: Symbol<'static, CallFunctionFromCFn>,
54 pub(crate) call_function_stream_from_c: Symbol<'static, CallFunctionFromCFn>,
55 pub(crate) call_function_parse_from_c: Symbol<'static, CallFunctionFromCFn>,
56 pub(crate) cancel_function_call: Symbol<'static, CancelFunctionCallFn>,
57 pub(crate) call_object_constructor: Symbol<'static, CallObjectConstructorFn>,
58 pub(crate) call_object_method: Symbol<'static, CallObjectMethodFn>,
59 pub(crate) free_buffer: Symbol<'static, FreeBufferFn>,
60}
61
62static SYMBOLS: OnceCell<Symbols> = OnceCell::new();
64
65pub fn get_symbols() -> Result<&'static Symbols> {
67 SYMBOLS.get_or_try_init(|| {
68 let lib = get_library()?;
69 load_symbols(lib)
70 })
71}
72
73fn load_symbols(lib: &'static LoadedLibrary) -> Result<Symbols> {
75 #[allow(unsafe_code)]
78 unsafe {
79 let version: Symbol<VersionFn> = load_symbol(&lib.library, "version")?;
80
81 let lib_version_ptr = version();
83 let lib_version = CStr::from_ptr(lib_version_ptr)
84 .to_str()
85 .unwrap_or("unknown");
86
87 if lib_version != VERSION {
88 return Err(BamlSysError::VersionMismatch {
89 expected: VERSION.to_string(),
90 actual: lib_version.to_string(),
91 });
92 }
93
94 Ok(Symbols {
95 version,
96 register_callbacks: load_symbol(&lib.library, "register_callbacks")?,
97 create_baml_runtime: load_symbol(&lib.library, "create_baml_runtime")?,
98 destroy_baml_runtime: load_symbol(&lib.library, "destroy_baml_runtime")?,
99 invoke_runtime_cli: load_symbol(&lib.library, "invoke_runtime_cli")?,
100 call_function_from_c: load_symbol(&lib.library, "call_function_from_c")?,
101 call_function_stream_from_c: load_symbol(&lib.library, "call_function_stream_from_c")?,
102 call_function_parse_from_c: load_symbol(&lib.library, "call_function_parse_from_c")?,
103 cancel_function_call: load_symbol(&lib.library, "cancel_function_call")?,
104 call_object_constructor: load_symbol(&lib.library, "call_object_constructor")?,
105 call_object_method: load_symbol(&lib.library, "call_object_method")?,
106 free_buffer: load_symbol(&lib.library, "free_buffer")?,
107 })
108 }
109}
110
111#[allow(unsafe_code)]
113unsafe fn load_symbol<T>(
114 library: &'static libloading::Library,
115 name: &'static str,
116) -> Result<Symbol<'static, T>> {
117 unsafe {
118 library
119 .get(name.as_bytes())
120 .map_err(|e| BamlSysError::SymbolNotFound {
121 symbol: name,
122 source: e,
123 })
124 }
125}