#![allow(unsafe_code)]
use std::{
collections::HashMap,
ffi::{c_char, c_void, CString},
};
use prost::Message;
use crate::{
args::FunctionArgs,
async_stream::AsyncStreamingCall,
codec::{traits::DecodeHandle, BamlDecode},
error::BamlError,
ffi::{self, callbacks},
proto::baml_cffi_v1::{
invocation_response::Response as InvResponse,
invocation_response_success::Result as InvSuccessResult, CffiValueHolder,
InvocationResponse,
},
raw_objects::{Audio, Collector, HTTPRequest, Image, Pdf, TypeBuilder, Video},
stream::StreamingCall,
};
/// Owning handle to a native BAML runtime created through the C FFI layer.
///
/// The only constructor in this file, [`BamlRuntime::new`], rejects a null
/// handle. The handle is released by the `Drop` impl.
pub struct BamlRuntime {
    // Opaque runtime pointer returned by `ffi::create_baml_runtime`. It is passed to
    // every FFI call and to the raw-object constructors (images, collectors, ...).
    ptr: *const c_void,
}
// SAFETY: `BamlRuntime` only stores an opaque pointer, and this file never
// dereferences it on the Rust side; every use goes through FFI calls.
// NOTE(review): this impl asserts that the native runtime may be moved to another
// thread. Nothing in this file establishes that, so confirm it against the native
// library's threading guarantees.
#[allow(unsafe_code)]
unsafe impl Send for BamlRuntime {}
// SAFETY: the same reasoning as for `Send` applies. In addition, every method takes
// `&self`, so a shared `BamlRuntime` may issue concurrent FFI calls on one handle.
// NOTE(review): confirm that the native runtime tolerates concurrent use.
#[allow(unsafe_code)]
unsafe impl Sync for BamlRuntime {}
/// A lazily initialized [`BamlRuntime`], suitable for storing in a `static`.
pub type StaticRuntimeType = once_cell::sync::Lazy<BamlRuntime>;
impl BamlRuntime {
    /// Creates a runtime from BAML source files and environment variables.
    ///
    /// The callback layer is initialized first (`callbacks::initialize_callbacks`).
    /// Then `files` and `env` are serialized to JSON and passed, together with
    /// `baml_src_dir`, to `ffi::create_baml_runtime`.
    ///
    /// # Errors
    ///
    /// Returns an internal [`BamlError`] if any of the following happens:
    /// - callback initialization fails;
    /// - a map cannot be serialized;
    /// - `baml_src_dir` or either JSON string contains an interior NUL byte;
    /// - the FFI entry point returns an error;
    /// - the FFI entry point returns a null runtime pointer.
    pub fn new(
        baml_src_dir: &str,
        files: &HashMap<String, String>,
        env: &HashMap<String, String>,
    ) -> Result<Self, BamlError> {
        callbacks::initialize_callbacks()
            .map_err(|e| BamlError::internal(format!("Failed to load BAML library: {e}")))?;
        let files_json = json_encode_map(files)?;
        let env_json = json_encode_map(env)?;
        let dir_cstr = CString::new(baml_src_dir)
            .map_err(|_| BamlError::internal("invalid baml_src_dir path (contains null byte)"))?;
        // NOTE(review): serde_json presumably escapes NUL as `\u0000`, which would make
        // the two JSON NUL checks below unreachable in practice. Confirm before relying
        // on it.
        let files_cstr = CString::new(files_json)
            .map_err(|_| BamlError::internal("invalid files json (contains null byte)"))?;
        let env_cstr = CString::new(env_json)
            .map_err(|_| BamlError::internal("invalid env json (contains null byte)"))?;
        // SAFETY: all three pointers come from `CString`s that stay alive until the end
        // of this function, so they are valid NUL-terminated strings for the whole call.
        // NOTE(review): this assumes the native side copies the strings rather than
        // keeping the pointers after it returns. Confirm.
        #[allow(unsafe_code)]
        let ptr = unsafe {
            ffi::create_baml_runtime(dir_cstr.as_ptr(), files_cstr.as_ptr(), env_cstr.as_ptr())
                .map_err(|e| BamlError::internal(format!("Failed to load BAML library: {e}")))?
        };
        // A null handle is treated as construction failure; every other method relies
        // on `ptr` being non-null.
        if ptr.is_null() {
            return Err(BamlError::internal("failed to create runtime"));
        }
        Ok(BamlRuntime { ptr })
    }
    /// Calls the BAML function `name` with `args` and waits for its final result.
    ///
    /// The call is started with `ffi::call_function_from_c`. The result is then
    /// awaited synchronously on the registered callback channel
    /// (`receiver.recv()`), which blocks the calling thread. From async code,
    /// prefer [`BamlRuntime::call_function_async`].
    ///
    /// If `args.cancellation_token` is set, cancelling it requests cancellation of
    /// this call through `ffi::cancel_function_call`.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the following happens:
    /// - the arguments fail to encode;
    /// - `name` contains a NUL byte;
    /// - the call cannot be started;
    /// - the function reports an error;
    /// - a partial result arrives instead of a final one;
    /// - the callback channel closes;
    /// - the payload cannot be decoded into `T`.
    pub fn call_function<T: BamlDecode>(
        &self,
        name: &str,
        args: &FunctionArgs,
    ) -> Result<T, BamlError> {
        let encoded = args.encode()?;
        let name_cstr =
            CString::new(name).map_err(|_| BamlError::internal("invalid function name"))?;
        // Register the callback first. Its `id` is handed to the native call, which
        // reports back through it. Every failure path below unregisters it again.
        let (id, receiver) = callbacks::create_callback();
        // SAFETY: `self.ptr` is the non-null handle checked in `new`. `name_cstr` is
        // NUL-terminated. `encoded` is passed as pointer plus its exact length. Both are
        // locals that outlive this call.
        #[allow(unsafe_code)]
        let buf = unsafe {
            ffi::call_function_from_c(
                self.ptr,
                name_cstr.as_ptr(),
                encoded.as_ptr().cast::<c_char>(),
                encoded.len(),
                id,
            )
            .map_err(|e| {
                // The call could not be issued, so drop the pending registration.
                callbacks::remove_callback(id);
                BamlError::internal(format!("Failed to load BAML library: {e}"))
            })?
        };
        // An error in the immediate response is treated as a failed start: unregister
        // the callback and surface the error.
        ffi::decode_async_response(buf).map_err(|e| {
            callbacks::remove_callback(id);
            BamlError::internal(e)
        })?;
        // The guard is bound to a named (underscore-prefixed) variable so it lives until
        // this function returns. `let _ = ...` would drop it immediately.
        // NOTE(review): the hook is installed only after the call has started. Confirm
        // that `on_cancel` still fires for a token that was already cancelled.
        let _cancel_guard = args.cancellation_token.as_ref().map(|token| {
            token.on_cancel(move || {
                // SAFETY: only the callback id crosses the FFI boundary here.
                // The result is ignored; cancellation is best-effort.
                #[allow(unsafe_code)]
                unsafe {
                    let _ = ffi::cancel_function_call(id);
                }
            })
        });
        let result = receiver.recv();
        // A non-streaming call is expected to finish with a single `Final` payload:
        // a protobuf-encoded `CffiValueHolder`.
        match result {
            Ok(callbacks::CallbackResult::Final(data)) => {
                let holder = CffiValueHolder::decode(&data[..])
                    .map_err(|e| BamlError::internal(format!("decode error: {e}")))?;
                T::baml_decode(&holder)
            }
            Ok(callbacks::CallbackResult::Partial(_)) => Err(BamlError::internal(
                "unexpected partial result in sync call",
            )),
            Ok(callbacks::CallbackResult::Error(e)) => Err(e),
            Err(_) => Err(BamlError::internal("callback channel closed")),
        }
    }
    /// Starts a streaming call of `name` and returns a [`StreamingCall`] that
    /// yields its partial and final results.
    ///
    /// If `args.on_tick` is set, a fresh collector named `"on-tick-collector"` is
    /// created. It is attached to the call's encoded arguments and also handed to
    /// the tick callback. Any cancellation guard is moved into the returned handle.
    ///
    /// # Errors
    ///
    /// Returns an error if the arguments fail to encode, if `name` contains a NUL
    /// byte, or if the call cannot be started.
    pub fn call_function_stream<TPartial, TFinal>(
        &self,
        name: &str,
        args: &FunctionArgs,
    ) -> Result<StreamingCall<TPartial, TFinal>, BamlError>
    where
        TPartial: BamlDecode + Send + 'static,
        TFinal: Clone + BamlDecode + Send + 'static,
    {
        // With an `on_tick` callback, one collector is shared between the callback data
        // and the encoded arguments, so the callback can observe this call.
        let on_tick_data = args.on_tick.as_ref().map(|cb| {
            let collector = self.new_collector("on-tick-collector");
            let data = callbacks::OnTickData {
                callback: cb.clone(),
                collector: collector.clone(),
            };
            (data, collector)
        });
        let extra_collector = on_tick_data.as_ref().map(|(_, c)| c);
        let encoded = args.encode_with_extra_collector(extra_collector)?;
        let name_cstr =
            CString::new(name).map_err(|_| BamlError::internal("invalid function name"))?;
        let (id, receiver) =
            callbacks::create_callback_with_on_tick(on_tick_data.map(|(d, _)| d));
        // SAFETY: `self.ptr` is the non-null handle checked in `new`. `name_cstr` is
        // NUL-terminated. `encoded` is passed as pointer plus its exact length.
        // NOTE(review): both locals are dropped when this method returns, while the
        // stream keeps running. This relies on the native side copying them during the
        // call. Confirm.
        #[allow(unsafe_code)]
        let buf = unsafe {
            ffi::call_function_stream_from_c(
                self.ptr,
                name_cstr.as_ptr(),
                encoded.as_ptr().cast::<c_char>(),
                encoded.len(),
                id,
            )
            .map_err(|e| {
                callbacks::remove_callback(id);
                BamlError::internal(format!("Failed to load BAML library: {e}"))
            })?
        };
        ffi::decode_async_response(buf).map_err(|e| {
            callbacks::remove_callback(id);
            BamlError::internal(e)
        })?;
        let cancel_guard = args.cancellation_token.as_ref().map(|token| {
            token.on_cancel(move || {
                // SAFETY: only the callback id crosses the FFI boundary here.
                #[allow(unsafe_code)]
                unsafe {
                    let _ = ffi::cancel_function_call(id);
                }
            })
        });
        // The guard moves into the handle. Presumably this keeps the cancel hook
        // registered for as long as the caller holds the stream.
        Ok(StreamingCall::new(id, receiver, cancel_guard))
    }
    /// Async counterpart of [`BamlRuntime::call_function`]: it awaits the final
    /// result instead of blocking the thread.
    ///
    /// NOTE(review): if the returned future is dropped before it completes, nothing in
    /// this method cancels the native call or unregisters the callback. Confirm
    /// whether `callbacks` cleans up on its own.
    ///
    /// # Errors
    ///
    /// Same conditions as [`BamlRuntime::call_function`].
    pub async fn call_function_async<T: BamlDecode>(
        &self,
        name: &str,
        args: &FunctionArgs,
    ) -> Result<T, BamlError> {
        let encoded = args.encode()?;
        let name_cstr =
            CString::new(name).map_err(|_| BamlError::internal("invalid function name"))?;
        let (id, receiver) = callbacks::create_async_callback();
        // SAFETY: `self.ptr` is the non-null handle checked in `new`. `name_cstr` is
        // NUL-terminated. `encoded` is passed as pointer plus its exact length. Both
        // locals live inside this future until it completes.
        #[allow(unsafe_code)]
        let buf = unsafe {
            ffi::call_function_from_c(
                self.ptr,
                name_cstr.as_ptr(),
                encoded.as_ptr().cast::<c_char>(),
                encoded.len(),
                id,
            )
            .map_err(|e| {
                callbacks::remove_callback(id);
                BamlError::internal(format!("Failed to load BAML library: {e}"))
            })?
        };
        ffi::decode_async_response(buf).map_err(|e| {
            callbacks::remove_callback(id);
            BamlError::internal(e)
        })?;
        // The named binding keeps the cancel hook alive across the `.await` below.
        let _cancel_guard = args.cancellation_token.as_ref().map(|token| {
            token.on_cancel(move || {
                // SAFETY: only the callback id crosses the FFI boundary here.
                #[allow(unsafe_code)]
                unsafe {
                    let _ = ffi::cancel_function_call(id);
                }
            })
        });
        match receiver.recv().await {
            Ok(callbacks::CallbackResult::Final(data)) => {
                let holder = CffiValueHolder::decode(&data[..])
                    .map_err(|e| BamlError::internal(format!("decode error: {e}")))?;
                T::baml_decode(&holder)
            }
            Ok(callbacks::CallbackResult::Partial(_)) => Err(BamlError::internal(
                "unexpected partial result in async call",
            )),
            Ok(callbacks::CallbackResult::Error(e)) => Err(e),
            Err(_) => Err(BamlError::internal("callback channel closed")),
        }
    }
    /// Async counterpart of [`BamlRuntime::call_function_stream`]. It returns an
    /// [`AsyncStreamingCall`] backed by an async callback channel.
    ///
    /// # Errors
    ///
    /// Returns an error if the arguments fail to encode, if `name` contains a NUL
    /// byte, or if the call cannot be started.
    pub fn call_function_stream_async<TPartial, TFinal>(
        &self,
        name: &str,
        args: &FunctionArgs,
    ) -> Result<AsyncStreamingCall<TPartial, TFinal>, BamlError>
    where
        TPartial: BamlDecode + Send + 'static,
        TFinal: Clone + BamlDecode + Send + 'static,
    {
        // One collector is shared between the `on_tick` callback and the call's arguments.
        let on_tick_data = args.on_tick.as_ref().map(|cb| {
            let collector = self.new_collector("on-tick-collector");
            let data = callbacks::OnTickData {
                callback: cb.clone(),
                collector: collector.clone(),
            };
            (data, collector)
        });
        let extra_collector = on_tick_data.as_ref().map(|(_, c)| c);
        let encoded = args.encode_with_extra_collector(extra_collector)?;
        let name_cstr =
            CString::new(name).map_err(|_| BamlError::internal("invalid function name"))?;
        let (id, receiver) =
            callbacks::create_async_callback_with_on_tick(on_tick_data.map(|(d, _)| d));
        // SAFETY: `self.ptr` is the non-null handle checked in `new`. `name_cstr` is
        // NUL-terminated. `encoded` is passed as pointer plus its exact length.
        // NOTE(review): both locals are dropped when this method returns, while the
        // stream keeps running. This relies on the native side copying them. Confirm.
        #[allow(unsafe_code)]
        let buf = unsafe {
            ffi::call_function_stream_from_c(
                self.ptr,
                name_cstr.as_ptr(),
                encoded.as_ptr().cast::<c_char>(),
                encoded.len(),
                id,
            )
            .map_err(|e| {
                callbacks::remove_callback(id);
                BamlError::internal(format!("Failed to load BAML library: {e}"))
            })?
        };
        ffi::decode_async_response(buf).map_err(|e| {
            callbacks::remove_callback(id);
            BamlError::internal(e)
        })?;
        let cancel_guard = args.cancellation_token.as_ref().map(|token| {
            token.on_cancel(move || {
                // SAFETY: only the callback id crosses the FFI boundary here.
                #[allow(unsafe_code)]
                unsafe {
                    let _ = ffi::cancel_function_call(id);
                }
            })
        });
        Ok(AsyncStreamingCall::new(id, receiver, cancel_guard))
    }
    /// Parses `llm_response` as output of `function_name` via
    /// `ffi::call_function_parse_from_c`, blocking until the result arrives.
    ///
    /// The text is sent as the `"text"` argument. When `stream` is true, a
    /// `"stream": true` argument is added and the result must arrive as a `Partial`
    /// payload. Otherwise it must arrive as a `Final` payload.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the following happens:
    /// - encoding fails;
    /// - `function_name` contains a NUL byte;
    /// - the call cannot be started;
    /// - parsing fails on the native side;
    /// - the payload kind does not match `stream`;
    /// - the channel closes;
    /// - decoding into `T` fails.
    pub fn parse<T: BamlDecode>(
        &self,
        function_name: &str,
        llm_response: &str,
        stream: bool,
    ) -> Result<T, BamlError> {
        let args = FunctionArgs::new().arg("text", llm_response);
        let args = if stream {
            args.arg("stream", true)
        } else {
            args
        };
        let encoded = args.encode()?;
        let name_cstr = CString::new(function_name)
            .map_err(|_| BamlError::internal("invalid function name"))?;
        let (id, receiver) = callbacks::create_callback();
        // SAFETY: `self.ptr` is the non-null handle checked in `new`. `name_cstr` is
        // NUL-terminated. `encoded` is passed as pointer plus its exact length. Both are
        // locals that outlive this synchronous call.
        #[allow(unsafe_code)]
        let buf = unsafe {
            ffi::call_function_parse_from_c(
                self.ptr,
                name_cstr.as_ptr(),
                encoded.as_ptr().cast::<c_char>(),
                encoded.len(),
                id,
            )
            .map_err(|e| {
                callbacks::remove_callback(id);
                BamlError::internal(format!("Failed to load BAML library: {e}"))
            })?
        };
        ffi::decode_async_response(buf).map_err(|e| {
            callbacks::remove_callback(id);
            BamlError::internal(format!("function parse error: {e}"))
        })?;
        // The expected payload kind depends on `stream`: Partial for streaming
        // parses, Final otherwise. The other kind is rejected as unexpected.
        match receiver.recv() {
            Ok(callbacks::CallbackResult::Final(data)) => {
                if stream {
                    Err(BamlError::internal("unexpected final result in parse call"))
                } else {
                    let holder = CffiValueHolder::decode(&data[..])
                        .map_err(|e| BamlError::internal(format!("decode error: {e}")))?;
                    T::baml_decode(&holder)
                }
            }
            Ok(callbacks::CallbackResult::Partial(data)) => {
                if stream {
                    let holder = CffiValueHolder::decode(&data[..])
                        .map_err(|e| BamlError::internal(format!("decode error: {e}")))?;
                    T::baml_decode(&holder)
                } else {
                    Err(BamlError::internal(
                        "unexpected partial result in parse call",
                    ))
                }
            }
            Ok(callbacks::CallbackResult::Error(e)) => Err(e),
            Err(_) => Err(BamlError::internal("callback channel closed")),
        }
    }
    /// Builds the HTTP request for calling `name` with `args` and returns it as an
    /// [`HTTPRequest`] handle. This blocks until the native side responds.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be built or decoded.
    pub fn build_request(&self, name: &str, args: &FunctionArgs) -> Result<HTTPRequest, BamlError> {
        self.build_request_inner(name, args)
    }
    /// Builds the HTTP request for a streaming call of `name`.
    ///
    /// NOTE(review): this currently does exactly what [`BamlRuntime::build_request`]
    /// does. No stream flag is passed, unlike `parse`, which adds a `"stream"`
    /// argument. Confirm whether the native side needs one here.
    pub fn build_request_stream(
        &self,
        name: &str,
        args: &FunctionArgs,
    ) -> Result<HTTPRequest, BamlError> {
        self.build_request_inner(name, args)
    }
    /// Async counterpart of [`BamlRuntime::build_request`].
    pub async fn build_request_async(
        &self,
        name: &str,
        args: &FunctionArgs,
    ) -> Result<HTTPRequest, BamlError> {
        self.build_request_inner_async(name, args).await
    }
    /// Async counterpart of [`BamlRuntime::build_request_stream`].
    ///
    /// NOTE(review): identical to `build_request_async`; no stream flag is passed.
    pub async fn build_request_stream_async(
        &self,
        name: &str,
        args: &FunctionArgs,
    ) -> Result<HTTPRequest, BamlError> {
        self.build_request_inner_async(name, args).await
    }
    /// Shared blocking implementation of the `build_request*` methods, using
    /// `ffi::build_request_from_c`. The single `Final` payload must be an
    /// `InvocationResponse` that carries an object handle.
    fn build_request_inner(
        &self,
        name: &str,
        args: &FunctionArgs,
    ) -> Result<HTTPRequest, BamlError> {
        let encoded = args.encode()?;
        let name_cstr =
            CString::new(name).map_err(|_| BamlError::internal("invalid function name"))?;
        let (id, receiver) = callbacks::create_callback();
        // SAFETY: `self.ptr` is the non-null handle checked in `new`. `name_cstr` is
        // NUL-terminated. `encoded` is passed as pointer plus its exact length. Both are
        // locals that outlive this synchronous call.
        #[allow(unsafe_code)]
        let buf = unsafe {
            ffi::build_request_from_c(
                self.ptr,
                name_cstr.as_ptr(),
                encoded.as_ptr().cast::<c_char>(),
                encoded.len(),
                id,
            )
            .map_err(|e| {
                callbacks::remove_callback(id);
                BamlError::internal(format!("Failed to load BAML library: {e}"))
            })?
        };
        ffi::decode_async_response(buf).map_err(|e| {
            callbacks::remove_callback(id);
            BamlError::internal(e)
        })?;
        match receiver.recv() {
            Ok(callbacks::CallbackResult::Final(data)) => {
                Self::decode_http_request_from_invocation_response(&data, self.ptr)
            }
            Ok(callbacks::CallbackResult::Partial(_)) => Err(BamlError::internal(
                "unexpected partial result in build_request call",
            )),
            Ok(callbacks::CallbackResult::Error(e)) => Err(e),
            Err(_) => Err(BamlError::internal("callback channel closed")),
        }
    }
    /// Async version of `build_request_inner`: it awaits the callback channel
    /// instead of blocking.
    async fn build_request_inner_async(
        &self,
        name: &str,
        args: &FunctionArgs,
    ) -> Result<HTTPRequest, BamlError> {
        let encoded = args.encode()?;
        let name_cstr =
            CString::new(name).map_err(|_| BamlError::internal("invalid function name"))?;
        let (id, receiver) = callbacks::create_async_callback();
        // SAFETY: `self.ptr` is the non-null handle checked in `new`. `name_cstr` is
        // NUL-terminated. `encoded` is passed as pointer plus its exact length. Both
        // locals live inside this future until it completes.
        #[allow(unsafe_code)]
        let buf = unsafe {
            ffi::build_request_from_c(
                self.ptr,
                name_cstr.as_ptr(),
                encoded.as_ptr().cast::<c_char>(),
                encoded.len(),
                id,
            )
            .map_err(|e| {
                callbacks::remove_callback(id);
                BamlError::internal(format!("Failed to load BAML library: {e}"))
            })?
        };
        ffi::decode_async_response(buf).map_err(|e| {
            callbacks::remove_callback(id);
            BamlError::internal(e)
        })?;
        match receiver.recv().await {
            Ok(callbacks::CallbackResult::Final(data)) => {
                Self::decode_http_request_from_invocation_response(&data, self.ptr)
            }
            Ok(callbacks::CallbackResult::Partial(_)) => Err(BamlError::internal(
                "unexpected partial result in build_request async call",
            )),
            Ok(callbacks::CallbackResult::Error(e)) => Err(e),
            Err(_) => Err(BamlError::internal("callback channel closed")),
        }
    }
    /// Decodes a protobuf `InvocationResponse` and turns its object handle into an
    /// [`HTTPRequest`]. `runtime_ptr` is forwarded to `HTTPRequest::decode_handle`.
    ///
    /// # Errors
    ///
    /// Returns an internal error in these cases:
    /// - protobuf decoding fails;
    /// - the response is a success whose result is not an object handle;
    /// - the response is empty.
    ///
    /// An error response is surfaced with its message.
    fn decode_http_request_from_invocation_response(
        data: &[u8],
        runtime_ptr: *const c_void,
    ) -> Result<HTTPRequest, BamlError> {
        let response = InvocationResponse::decode(data)
            .map_err(|e| BamlError::internal(format!("decode InvocationResponse error: {e}")))?;
        match response.response {
            Some(InvResponse::Success(success)) => match success.result {
                Some(InvSuccessResult::Object(handle)) => {
                    HTTPRequest::decode_handle(handle, runtime_ptr)
                }
                other => Err(BamlError::internal(format!(
                    "expected object handle in InvocationResponse, got: {other:?}"
                ))),
            },
            Some(InvResponse::Error(msg)) => Err(BamlError::internal(msg)),
            None => Err(BamlError::internal(
                "empty response in InvocationResponse for build_request",
            )),
        }
    }
    /// Creates an [`Image`] that references `url`, bound to this runtime.
    pub fn new_image_from_url(&self, url: &str, mime_type: Option<&str>) -> Image {
        Image::from_url(self.ptr, url, mime_type)
    }
    /// Creates an [`Image`] from base64-encoded data, bound to this runtime.
    pub fn new_image_from_base64(&self, base64: &str, mime_type: Option<&str>) -> Image {
        Image::from_base64(self.ptr, base64, mime_type)
    }
    /// Creates an [`Audio`] that references `url`, bound to this runtime.
    pub fn new_audio_from_url(&self, url: &str, mime_type: Option<&str>) -> Audio {
        Audio::from_url(self.ptr, url, mime_type)
    }
    /// Creates an [`Audio`] from base64-encoded data, bound to this runtime.
    pub fn new_audio_from_base64(&self, base64: &str, mime_type: Option<&str>) -> Audio {
        Audio::from_base64(self.ptr, base64, mime_type)
    }
    /// Creates a [`Pdf`] that references `url`, bound to this runtime.
    pub fn new_pdf_from_url(&self, url: &str, mime_type: Option<&str>) -> Pdf {
        Pdf::from_url(self.ptr, url, mime_type)
    }
    /// Creates a [`Pdf`] from base64-encoded data, bound to this runtime.
    pub fn new_pdf_from_base64(&self, base64: &str, mime_type: Option<&str>) -> Pdf {
        Pdf::from_base64(self.ptr, base64, mime_type)
    }
    /// Creates a [`Video`] that references `url`, bound to this runtime.
    pub fn new_video_from_url(&self, url: &str, mime_type: Option<&str>) -> Video {
        Video::from_url(self.ptr, url, mime_type)
    }
    /// Creates a [`Video`] from base64-encoded data, bound to this runtime.
    pub fn new_video_from_base64(&self, base64: &str, mime_type: Option<&str>) -> Video {
        Video::from_base64(self.ptr, base64, mime_type)
    }
    /// Creates a [`Collector`] named `name`, bound to this runtime.
    pub fn new_collector(&self, name: &str) -> Collector {
        Collector::new(self.ptr, name)
    }
    /// Creates a [`TypeBuilder`] bound to this runtime.
    pub fn new_type_builder(&self) -> TypeBuilder {
        TypeBuilder::new(self.ptr)
    }
}
impl Drop for BamlRuntime {
    /// Destroys the native runtime that this wrapper owns.
    ///
    /// Whatever the FFI layer reports is deliberately discarded, because `drop`
    /// has no way to surface an error.
    fn drop(&mut self) {
        let runtime_handle = self.ptr;
        // SAFETY: `runtime_handle` comes from `BamlRuntime::new`, which rejects null.
        // `drop` runs at most once per value, so the handle is destroyed exactly once.
        // NOTE(review): raw objects created from this runtime also receive `ptr`.
        // Confirm that they cannot outlive it.
        #[allow(unsafe_code)]
        let _ = unsafe { ffi::destroy_baml_runtime(runtime_handle) };
    }
}
/// Serializes a string-to-string map into a JSON object string.
///
/// Key order in the output follows `HashMap` iteration order and is therefore
/// unspecified.
///
/// # Errors
///
/// Returns an internal [`BamlError`] with the serializer's message if
/// `serde_json` fails.
fn json_encode_map(map: &HashMap<String, String>) -> Result<String, BamlError> {
    match serde_json::to_string(map) {
        Ok(json) => Ok(json),
        Err(err) => Err(BamlError::internal(format!("failed to encode map: {err}"))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_json_encode_empty_map() {
        let map: HashMap<String, String> = HashMap::new();
        let result = json_encode_map(&map).unwrap();
        assert_eq!(result, "{}");
    }

    #[test]
    fn test_json_encode_simple_map() {
        let mut map = HashMap::new();
        map.insert("key".to_string(), "value".to_string());
        let result = json_encode_map(&map).unwrap();
        assert_eq!(result, "{\"key\":\"value\"}");
    }

    /// Quotes and newlines in file contents must be escaped so the output is still
    /// one valid JSON string.
    #[test]
    fn test_json_encode_escapes_special_characters() {
        let mut map = HashMap::new();
        map.insert("k".to_string(), "a\"b\nc".to_string());
        let result = json_encode_map(&map).unwrap();
        assert_eq!(result, r#"{"k":"a\"b\nc"}"#);
    }

    /// `BamlRuntime::new` passes the encoded JSON through `CString::new`, so an
    /// interior NUL in a value must come out escaped rather than raw.
    #[test]
    fn test_json_encode_escapes_nul_bytes() {
        let mut map = HashMap::new();
        map.insert("nul".to_string(), "a\0b".to_string());
        let result = json_encode_map(&map).unwrap();
        assert!(!result.contains('\0'));
        assert_eq!(result, r#"{"nul":"a\u0000b"}"#);
    }

    /// `HashMap` iteration order is unspecified, so multi-key output is checked by
    /// decoding it rather than by comparing strings.
    #[test]
    fn test_json_encode_multi_key_round_trip() {
        let mut map = HashMap::new();
        map.insert("a.baml".to_string(), "class A {}".to_string());
        map.insert("b.baml".to_string(), "function F() -> A {}".to_string());
        let result = json_encode_map(&map).unwrap();
        let decoded: HashMap<String, String> = serde_json::from_str(&result).unwrap();
        assert_eq!(decoded, map);
    }
}