use std::ffi::CString;
use std::os::raw::c_char;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use libloading::{Library, Symbol};
use serde_json::Value;
use crate::configuration::Configuration;
use crate::error::{FoundryLocalError, Result};
/// Narrows a byte length to the `i32` used by the native request structs.
///
/// `name` only labels the value in the error message.
///
/// # Errors
/// Returns [`FoundryLocalError::CommandExecution`] when `len` is larger than `i32::MAX`.
fn checked_i32_length(name: &str, len: usize) -> Result<i32> {
    match i32::try_from(len) {
        Ok(narrowed) => Ok(narrowed),
        Err(_) => Err(FoundryLocalError::CommandExecution {
            reason: format!("{name} length {len} exceeds i32::MAX"),
        }),
    }
}
/// Request passed to the native `execute_command` and `execute_command_with_callback` exports.
///
/// `#[repr(C)]` pins the field layout; presumably it mirrors the native core's struct
/// (NOTE(review): confirm against the C header), so fields must not be reordered.
/// The pointers borrow `CString`s owned by the caller, which must outlive the native call.
#[repr(C)]
struct RequestBuffer {
/// NUL-terminated command name.
command: *const c_char,
/// Length of `command` in bytes, excluding the trailing NUL.
command_length: i32,
/// NUL-terminated JSON parameters; an empty string when the caller passed no params.
data: *const c_char,
/// Length of `data` in bytes, excluding the trailing NUL.
data_length: i32,
}
/// Out-parameter that the native core fills in for every command call.
///
/// `#[repr(C)]` keeps the field layout fixed for the C side; fields must not be reordered.
/// Both buffers are allocated on the native side and are released with `free_native_buffer`
/// after their contents have been copied out.
#[repr(C)]
struct ResponseBuffer {
    /// Result payload bytes; null when the core produced none.
    data: *mut u8,
    /// Length of `data` in bytes.
    data_length: u32,
    /// Error message bytes; a non-null pointer with a non-zero length marks the call as failed.
    error: *mut u8,
    /// Length of `error` in bytes.
    error_length: u32,
}

impl ResponseBuffer {
    /// Builds an empty response (null pointers, zero lengths) for the core to populate.
    fn new() -> Self {
        let empty = std::ptr::null_mut::<u8>();
        Self {
            error: empty,
            error_length: 0,
            data: empty,
            data_length: 0,
        }
    }
}
/// Request passed to the optional native `execute_command_with_binary` export.
///
/// Same leading fields as `RequestBuffer`, plus a raw binary payload. `#[repr(C)]` pins the
/// layout; presumably it mirrors the native core's struct (NOTE(review): confirm), so fields
/// must not be reordered. All pointers borrow caller-owned data that must outlive the call.
#[repr(C)]
struct StreamingRequestBuffer {
/// NUL-terminated command name.
command: *const c_char,
/// Length of `command` in bytes, excluding the trailing NUL.
command_length: i32,
/// NUL-terminated JSON parameters; an empty string when the caller passed no params.
data: *const c_char,
/// Length of `data` in bytes, excluding the trailing NUL.
data_length: i32,
/// Binary payload; null when the payload is empty.
binary_data: *const u8,
/// Length of `binary_data` in bytes.
binary_data_length: i32,
}
// Function-pointer types for the symbols resolved from the native core in `CoreInterop::new`.
// NOTE(review): these signatures are assumed to match the native exports; a mismatch is UB.

/// `execute_command(request, response)`: runs one command and fills `response`.
type ExecuteCommandFn = unsafe extern "C" fn(*const RequestBuffer, *mut ResponseBuffer);
/// Per-chunk streaming callback `(data, length, user_data) -> status`.
/// `streaming_trampoline` returns 0 to continue and 1 to ask the core to stop
/// (presumably honored by the core — confirm).
type CallbackFn = unsafe extern "C" fn(*const u8, i32, *mut std::ffi::c_void) -> i32;
/// `execute_command_with_callback(request, response, callback, user_data)`: streaming variant;
/// `user_data` is passed back verbatim to every `callback` invocation.
type ExecuteCommandWithCallbackFn = unsafe extern "C" fn(
*const RequestBuffer,
*mut ResponseBuffer,
CallbackFn,
*mut std::ffi::c_void,
);
/// `execute_command_with_binary(request, response)`: optional export that accepts a binary payload.
type ExecuteCommandWithBinaryFn =
unsafe extern "C" fn(*const StreamingRequestBuffer, *mut ResponseBuffer);
// Platform-specific file extension of the native core library, used to build its file name.
// NOTE(review): only Windows, macOS and Linux define this constant; other targets will fail to
// compile where `LIB_EXTENSION` is referenced.
#[cfg(target_os = "windows")]
const LIB_EXTENSION: &str = "dll";
#[cfg(target_os = "macos")]
const LIB_EXTENSION: &str = "dylib";
#[cfg(target_os = "linux")]
const LIB_EXTENSION: &str = "so";
/// Releases a buffer handed back by the native core inside a `ResponseBuffer`.
///
/// A null pointer is a no-op. On Unix the buffer goes to libc `free`; on Windows to `LocalFree`.
///
/// # Safety
/// `ptr` must be null or a pointer allocated by the allocator paired with the platform call
/// above (NOTE(review): assumed to match how the native core allocates — confirm), and it must
/// not be used or freed again afterwards.
unsafe fn free_native_buffer(ptr: *mut u8) {
    if !ptr.is_null() {
        #[cfg(unix)]
        {
            extern "C" {
                fn free(ptr: *mut std::ffi::c_void);
            }
            free(ptr.cast::<std::ffi::c_void>());
        }
        #[cfg(windows)]
        {
            extern "system" {
                fn LocalFree(hMem: *mut std::ffi::c_void) -> *mut std::ffi::c_void;
            }
            LocalFree(ptr.cast::<std::ffi::c_void>());
        }
    }
}
/// Per-call state reached by `streaming_trampoline` through the native `user_data` pointer.
///
/// Native chunks are raw bytes that may split a multi-byte UTF-8 sequence, so an incomplete
/// trailing sequence is kept in `buf` until the next chunk (or [`Self::flush`]) completes it.
/// Invalid sequences are reported as U+FFFD, matching `String::from_utf8_lossy`.
struct StreamingCallbackState<'a> {
    /// User callback receiving decoded text fragments.
    callback: &'a mut dyn FnMut(&str),
    /// Bytes of an incomplete UTF-8 sequence carried over from earlier chunks.
    buf: Vec<u8>,
    /// Optional caller-owned flag; reading `true` means "cancel the stream".
    cancel_flag: Option<Arc<AtomicBool>>,
    /// Latched once cancellation has been observed; never reset afterwards.
    cancelled_observed: bool,
}

impl<'a> StreamingCallbackState<'a> {
    /// State for a stream that cannot be cancelled.
    fn new(callback: &'a mut dyn FnMut(&str)) -> Self {
        Self::with_cancel_flag(callback, None)
    }

    /// State for a stream that stops once `cancel_flag` becomes `true`.
    fn new_cancellable(callback: &'a mut dyn FnMut(&str), cancel_flag: Arc<AtomicBool>) -> Self {
        Self::with_cancel_flag(callback, Some(cancel_flag))
    }

    /// Shared constructor for both public variants.
    fn with_cancel_flag(
        callback: &'a mut dyn FnMut(&str),
        cancel_flag: Option<Arc<AtomicBool>>,
    ) -> Self {
        Self {
            callback,
            buf: Vec::new(),
            cancel_flag,
            cancelled_observed: false,
        }
    }

    /// Checks the cancel flag and latches the result.
    ///
    /// Returns `true` if cancellation has been observed now or on any earlier call. Once
    /// latched it stays `true` even if the flag is later cleared. Otherwise the stream could
    /// resume delivering chunks after the call had already committed to reporting cancellation.
    fn mark_cancelled_if_requested(&mut self) -> bool {
        if !self.cancelled_observed {
            self.cancelled_observed = self
                .cancel_flag
                .as_ref()
                .is_some_and(|f| f.load(Ordering::Relaxed));
        }
        self.cancelled_observed
    }

    /// Whether cancellation was observed by [`Self::mark_cancelled_if_requested`].
    fn cancellation_observed(&self) -> bool {
        self.cancelled_observed
    }

    /// Appends `bytes` and forwards every complete UTF-8 prefix to the callback.
    ///
    /// Each invalid sequence is forwarded as a single U+FFFD and skipped. A truncated sequence
    /// at the end is kept in `buf` for the next call.
    fn push(&mut self, bytes: &[u8]) {
        const REPLACEMENT: &str = "\u{FFFD}";
        self.buf.extend_from_slice(bytes);
        loop {
            match std::str::from_utf8(&self.buf) {
                Ok(text) => {
                    if !text.is_empty() {
                        (self.callback)(text);
                    }
                    self.buf.clear();
                    return;
                }
                Err(e) => {
                    let valid_len = e.valid_up_to();
                    if valid_len > 0 {
                        // SAFETY: `valid_up_to` guarantees `buf[..valid_len]` is valid UTF-8.
                        let valid = unsafe { std::str::from_utf8_unchecked(&self.buf[..valid_len]) };
                        (self.callback)(valid);
                    }
                    match e.error_len() {
                        Some(bad_len) => {
                            // Invalid sequence mid-stream: report it rather than dropping it.
                            (self.callback)(REPLACEMENT);
                            self.buf.drain(..valid_len + bad_len);
                        }
                        None => {
                            // Truncated sequence: wait for the remaining bytes.
                            self.buf.drain(..valid_len);
                            return;
                        }
                    }
                }
            }
        }
    }

    /// Emits any leftover bytes lossily at end of stream.
    ///
    /// After cancellation, leftover bytes are discarded without invoking the callback.
    fn flush(&mut self) {
        if self.cancelled_observed {
            self.buf.clear();
            return;
        }
        if !self.buf.is_empty() {
            let text = String::from_utf8_lossy(&self.buf).into_owned();
            (self.callback)(&text);
            self.buf.clear();
        }
    }
}
/// C-ABI callback handed to `execute_command_with_callback`.
///
/// Forwards each native chunk to the `StreamingCallbackState` behind `user_data`. Returns `0`
/// to continue and `1` to ask the core to stop. Stop is requested when cancellation is
/// observed, when `user_data` is null, or when the user callback panics; the panic is caught
/// so it never unwinds across the FFI boundary.
///
/// Cancellation is checked before the empty/null-chunk shortcut, so empty chunks can still
/// stop the stream.
///
/// # Safety
/// `user_data` must be null or point to a live `StreamingCallbackState` that nothing else
/// accesses during this call. When `data` is non-null and `length > 0`, `data` must be valid
/// for reads of `length` bytes.
unsafe extern "C" fn streaming_trampoline(
    data: *const u8,
    length: i32,
    user_data: *mut std::ffi::c_void,
) -> i32 {
    const CONTINUE: i32 = 0;
    const STOP: i32 = 1;
    if user_data.is_null() {
        // No state to deliver into; dereferencing would be UB.
        return STOP;
    }
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        // SAFETY: non-null and, per the contract above, points to the caller's live state.
        let state = &mut *(user_data as *mut StreamingCallbackState<'_>);
        if state.mark_cancelled_if_requested() {
            return STOP;
        }
        if data.is_null() || length <= 0 {
            return CONTINUE;
        }
        // SAFETY: `data` is non-null and `length > 0`, so the cast is lossless; the core
        // guarantees `length` readable bytes.
        let chunk = std::slice::from_raw_parts(data, length as usize);
        state.push(chunk);
        CONTINUE
    }));
    outcome.unwrap_or(STOP)
}
pub(crate) struct CoreInterop {
_library: Library,
#[cfg(target_os = "windows")]
_dependency_libs: Vec<Library>,
execute_command: unsafe extern "C" fn(*const RequestBuffer, *mut ResponseBuffer),
execute_command_with_callback: unsafe extern "C" fn(
*const RequestBuffer,
*mut ResponseBuffer,
CallbackFn,
*mut std::ffi::c_void,
),
execute_command_with_binary:
Option<unsafe extern "C" fn(*const StreamingRequestBuffer, *mut ResponseBuffer)>,
}
impl std::fmt::Debug for CoreInterop {
    /// Renders an opaque `CoreInterop { .. }`; no fields are listed.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("CoreInterop");
        builder.finish_non_exhaustive()
    }
}
impl CoreInterop {
pub fn new(config: &mut Configuration) -> Result<Self> {
let lib_path = Self::resolve_library_path(config)?;
#[cfg(target_os = "windows")]
let _dependency_libs = Self::load_windows_dependencies(&lib_path)?;
let library = unsafe {
Library::new(&lib_path).map_err(|e| FoundryLocalError::LibraryLoad {
reason: format!(
"Failed to load native library at {}: {e}",
lib_path.display()
),
})?
};
let execute_command: ExecuteCommandFn = unsafe {
let sym: Symbol<ExecuteCommandFn> =
library
.get(b"execute_command\0")
.map_err(|e| FoundryLocalError::LibraryLoad {
reason: format!("Symbol 'execute_command' not found: {e}"),
})?;
*sym
};
let execute_command_with_callback: ExecuteCommandWithCallbackFn = unsafe {
let sym: Symbol<ExecuteCommandWithCallbackFn> = library
.get(b"execute_command_with_callback\0")
.map_err(|e| FoundryLocalError::LibraryLoad {
reason: format!("Symbol 'execute_command_with_callback' not found: {e}"),
})?;
*sym
};
let execute_command_with_binary: Option<ExecuteCommandWithBinaryFn> = unsafe {
library
.get::<ExecuteCommandWithBinaryFn>(b"execute_command_with_binary\0")
.ok()
.map(|sym| *sym)
};
Ok(Self {
_library: library,
#[cfg(target_os = "windows")]
_dependency_libs,
execute_command,
execute_command_with_callback,
execute_command_with_binary,
})
}
pub fn execute_command(&self, command: &str, params: Option<&Value>) -> Result<String> {
let cmd = CString::new(command).map_err(|e| FoundryLocalError::CommandExecution {
reason: format!("Invalid command string: {e}"),
})?;
let data_json = match params {
Some(v) => serde_json::to_string(v)?,
None => String::new(),
};
let data_cstr =
CString::new(data_json.as_str()).map_err(|e| FoundryLocalError::CommandExecution {
reason: format!("Invalid data string: {e}"),
})?;
let request = RequestBuffer {
command: cmd.as_ptr(),
command_length: checked_i32_length("command", cmd.as_bytes().len())?,
data: data_cstr.as_ptr(),
data_length: checked_i32_length("data", data_cstr.as_bytes().len())?,
};
let mut response = ResponseBuffer::new();
unsafe {
(self.execute_command)(&request, &mut response);
}
Self::process_response(response)
}
pub fn execute_command_with_binary(
&self,
command: &str,
params: Option<&Value>,
binary_data: &[u8],
) -> Result<String> {
let native_fn = self.execute_command_with_binary.ok_or_else(|| {
FoundryLocalError::CommandExecution {
reason: "execute_command_with_binary is not supported by this native core \
(symbol not found)"
.into(),
}
})?;
let cmd = CString::new(command).map_err(|e| FoundryLocalError::CommandExecution {
reason: format!("Invalid command string: {e}"),
})?;
let data_json = match params {
Some(v) => serde_json::to_string(v)?,
None => String::new(),
};
let data_cstr =
CString::new(data_json.as_str()).map_err(|e| FoundryLocalError::CommandExecution {
reason: format!("Invalid data string: {e}"),
})?;
let request = StreamingRequestBuffer {
command: cmd.as_ptr(),
command_length: checked_i32_length("command", cmd.as_bytes().len())?,
data: data_cstr.as_ptr(),
data_length: checked_i32_length("data", data_cstr.as_bytes().len())?,
binary_data: if binary_data.is_empty() {
std::ptr::null()
} else {
binary_data.as_ptr()
},
binary_data_length: checked_i32_length("binary data", binary_data.len())?,
};
let mut response = ResponseBuffer::new();
unsafe {
(native_fn)(&request, &mut response);
}
Self::process_response(response)
}
pub fn execute_command_streaming<F>(
&self,
command: &str,
params: Option<&Value>,
mut callback: F,
) -> Result<String>
where
F: FnMut(&str),
{
self.execute_command_streaming_impl(command, params, &mut callback, None)
}
pub fn execute_command_streaming_cancellable<F>(
&self,
command: &str,
params: Option<&Value>,
mut callback: F,
cancel_flag: Arc<AtomicBool>,
) -> Result<String>
where
F: FnMut(&str),
{
self.execute_command_streaming_impl(command, params, &mut callback, Some(cancel_flag))
}
fn execute_command_streaming_impl(
&self,
command: &str,
params: Option<&Value>,
callback: &mut dyn FnMut(&str),
cancel_flag: Option<Arc<AtomicBool>>,
) -> Result<String> {
let cmd = CString::new(command).map_err(|e| FoundryLocalError::CommandExecution {
reason: format!("Invalid command string: {e}"),
})?;
let data_json = match params {
Some(v) => serde_json::to_string(v)?,
None => String::new(),
};
let data_cstr =
CString::new(data_json.as_str()).map_err(|e| FoundryLocalError::CommandExecution {
reason: format!("Invalid data string: {e}"),
})?;
let request = RequestBuffer {
command: cmd.as_ptr(),
command_length: checked_i32_length("command", cmd.as_bytes().len())?,
data: data_cstr.as_ptr(),
data_length: checked_i32_length("data", data_cstr.as_bytes().len())?,
};
let mut response = ResponseBuffer::new();
let mut state = match cancel_flag {
Some(flag) => StreamingCallbackState::new_cancellable(callback, flag),
None => StreamingCallbackState::new(callback),
};
let user_data = &mut state as *mut StreamingCallbackState<'_> as *mut std::ffi::c_void;
unsafe {
(self.execute_command_with_callback)(
&request,
&mut response,
streaming_trampoline,
user_data,
);
}
let cancelled = state.cancellation_observed();
state.flush();
if cancelled {
Self::process_response(response).ok();
return Err(FoundryLocalError::CommandExecution {
reason: "Operation cancelled".to_string(),
});
}
Self::process_response(response)
}
pub async fn execute_command_async(
self: &Arc<Self>,
command: String,
params: Option<Value>,
) -> Result<String> {
let this = Arc::clone(self);
tokio::task::spawn_blocking(move || this.execute_command(&command, params.as_ref()))
.await
.map_err(|e| FoundryLocalError::CommandExecution {
reason: format!("task join error: {e}"),
})?
}
pub async fn execute_command_streaming_async<F>(
self: &Arc<Self>,
command: String,
params: Option<Value>,
callback: F,
) -> Result<String>
where
F: FnMut(&str) + Send + 'static,
{
let this = Arc::clone(self);
tokio::task::spawn_blocking(move || {
this.execute_command_streaming(&command, params.as_ref(), callback)
})
.await
.map_err(|e| FoundryLocalError::CommandExecution {
reason: format!("task join error: {e}"),
})?
}
pub async fn execute_command_streaming_cancellable_async<F>(
self: &Arc<Self>,
command: String,
params: Option<Value>,
callback: F,
cancel_flag: Arc<AtomicBool>,
) -> Result<String>
where
F: FnMut(&str) + Send + 'static,
{
let this = Arc::clone(self);
tokio::task::spawn_blocking(move || {
this.execute_command_streaming_cancellable(
&command,
params.as_ref(),
callback,
cancel_flag,
)
})
.await
.map_err(|e| FoundryLocalError::CommandExecution {
reason: format!("task join error: {e}"),
})?
}
pub async fn execute_command_streaming_channel(
self: &Arc<Self>,
command: String,
params: Option<Value>,
) -> Result<tokio::sync::mpsc::UnboundedReceiver<Result<String>>> {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<Result<String>>();
let this = Arc::clone(self);
tokio::task::spawn_blocking(move || {
let tx_chunk = tx.clone();
let result =
this.execute_command_streaming(&command, params.as_ref(), move |chunk: &str| {
let _ = tx_chunk.send(Ok(chunk.to_owned()));
});
match result {
Ok(_final_payload) => {
}
Err(e) => {
let _ = tx.send(Err(e));
}
}
});
Ok(rx)
}
unsafe fn read_native_buffer(ptr: *mut u8, len: u32) -> Option<String> {
if ptr.is_null() || len == 0 {
return None;
}
let slice = std::slice::from_raw_parts(ptr, len as usize);
Some(String::from_utf8_lossy(slice).into_owned())
}
fn process_response(response: ResponseBuffer) -> Result<String> {
let error_str = unsafe { Self::read_native_buffer(response.error, response.error_length) };
let data_str = unsafe { Self::read_native_buffer(response.data, response.data_length) };
unsafe {
free_native_buffer(response.data);
free_native_buffer(response.error);
}
if let Some(err) = error_str {
Err(FoundryLocalError::CommandExecution { reason: err })
} else {
Ok(data_str.unwrap_or_default())
}
}
fn resolve_library_path(config: &Configuration) -> Result<PathBuf> {
let lib_name = format!("Microsoft.AI.Foundry.Local.Core.{LIB_EXTENSION}");
if let Some(dir) = config.params.get("FoundryLocalCorePath") {
let p = Path::new(dir).join(&lib_name);
if p.exists() {
return Ok(p);
}
let p = Path::new(dir);
if p.exists() && p.is_file() {
return Ok(p.to_path_buf());
}
}
if let Some(dir) = option_env!("FOUNDRY_NATIVE_DIR") {
let p = Path::new(dir).join(&lib_name);
if p.exists() {
return Ok(p);
}
}
if let Ok(exe) = std::env::current_exe() {
if let Some(dir) = exe.parent() {
let p = dir.join(&lib_name);
if p.exists() {
return Ok(p);
}
}
}
Err(FoundryLocalError::LibraryLoad {
reason: format!(
"Could not locate native library '{lib_name}'. \
Set the FoundryLocalCorePath config option."
),
})
}
#[cfg(target_os = "windows")]
fn load_windows_dependencies(core_lib_path: &Path) -> Result<Vec<Library>> {
let dir = core_lib_path.parent().unwrap_or_else(|| Path::new("."));
let mut libs = Vec::new();
for dep in &["onnxruntime.dll", "onnxruntime-genai.dll"] {
let dep_path = dir.join(dep);
if dep_path.exists() {
let lib = unsafe {
Library::new(&dep_path).map_err(|e| FoundryLocalError::LibraryLoad {
reason: format!("Failed to load dependency {dep}: {e}"),
})?
};
libs.push(lib);
}
}
#[cfg(feature = "winml")]
{
let winml_dep = "Microsoft.Windows.AI.MachineLearning.dll";
let winml_path = dir.join(winml_dep);
if winml_path.exists() {
let lib = unsafe {
Library::new(&winml_path).map_err(|e| FoundryLocalError::LibraryLoad {
reason: format!("Failed to load dependency {winml_dep}: {e}"),
})?
};
libs.push(lib);
}
}
Ok(libs)
}
}
#[cfg(test)]
mod tests {
    use super::{checked_i32_length, StreamingCallbackState};
    use crate::error::FoundryLocalError;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Raising the flag after a chunk was pushed is not noticed retroactively.
    #[test]
    fn cancellation_request_after_callback_is_not_observed_until_next_callback() {
        let flag = Arc::new(AtomicBool::new(false));
        let mut sink = |_: &str| {};
        let mut state = StreamingCallbackState::new_cancellable(&mut sink, Arc::clone(&flag));
        state.push(b"100");
        flag.store(true, Ordering::Relaxed);
        assert!(!state.cancellation_observed());
    }

    /// Polling a raised flag records the cancellation.
    #[test]
    fn cancellation_is_recorded_when_callback_observes_cancel_flag() {
        let flag = Arc::new(AtomicBool::new(true));
        let mut sink = |_: &str| {};
        let mut state = StreamingCallbackState::new_cancellable(&mut sink, flag);
        let observed_now = state.mark_cancelled_if_requested();
        assert!(observed_now);
        assert!(state.cancellation_observed());
    }

    /// A partial UTF-8 sequence pending at cancellation is discarded, not emitted lossily.
    #[test]
    fn flush_drops_buffer_after_cancellation_without_callback() {
        let flag = Arc::new(AtomicBool::new(true));
        let mut received: Vec<String> = Vec::new();
        {
            let mut sink = |piece: &str| received.push(piece.to_owned());
            let mut state = StreamingCallbackState::new_cancellable(&mut sink, flag);
            state.push(&[0xE2]);
            assert!(state.mark_cancelled_if_requested());
            state.flush();
        }
        assert!(received.is_empty());
    }

    /// `i32::MAX` itself is accepted; one past it is rejected with a descriptive error.
    #[test]
    fn checked_i32_length_rejects_too_large_values() {
        let limit = i32::MAX as usize;
        assert_eq!(checked_i32_length("data", limit).unwrap(), i32::MAX);
        let overflow = checked_i32_length("data", limit + 1).unwrap_err();
        match overflow {
            FoundryLocalError::CommandExecution { reason } => {
                assert!(reason.contains("exceeds i32::MAX"));
            }
            err => panic!("unexpected error: {err:?}"),
        }
    }
}