use core::ffi::{c_char, c_void};
use core::ptr;
use std::ffi::CString;
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use crate::error::FMError;
use crate::ffi;
use crate::generation::GenerationOptions;
/// Safe Rust handle to a session object that lives on the Swift side of the
/// FoundationModels bridge.
pub struct LanguageModelSession {
    /// Opaque bridge handle. The constructors only ever produce a non-null
    /// value; it is released by the `Drop` impl.
    ptr: *mut c_void,
}

// SAFETY: NOTE(review): these impls assert that the Swift bridge tolerates the
// session handle being moved to, and used concurrently from, other threads.
// Nothing in this file synchronizes access to `ptr`, so confirm this against
// the bridge implementation.
unsafe impl Sync for LanguageModelSession {}
unsafe impl Send for LanguageModelSession {}
impl LanguageModelSession {
    /// Creates a session with no system instructions.
    ///
    /// # Panics
    /// Panics if the bridge cannot create a session (reported as FoundationModels
    /// being unavailable on this OS). Use [`LanguageModelSession::try_new`] for a
    /// non-panicking alternative.
    #[must_use]
    pub fn new() -> Self {
        Self::try_new(None).expect("FoundationModels is not available on this OS")
    }

    /// Creates a session primed with the given system `instructions`.
    ///
    /// # Panics
    /// - If `instructions` contains an interior NUL byte (it cannot be passed to
    ///   the bridge as a C string).
    /// - If the bridge cannot create a session (reported as FoundationModels being
    ///   unavailable on this OS).
    #[must_use]
    pub fn with_instructions(instructions: &str) -> Self {
        // Checked up front so a bad argument gets its own message instead of being
        // misreported as the framework being unavailable (`try_new` returns `None`
        // for both cases).
        assert!(
            !instructions.contains('\0'),
            "instructions must not contain NUL bytes"
        );
        Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
    }

    /// Attempts to create a session, optionally with system `instructions`.
    ///
    /// Returns `None` if `instructions` contains an interior NUL byte, or if
    /// `fm_session_create` returns a null handle.
    #[must_use]
    pub fn try_new(instructions: Option<&str>) -> Option<Self> {
        let instructions = instructions.map(CString::new).transpose().ok()?;
        let raw = instructions.as_ref().map_or(ptr::null(), |s| s.as_ptr());
        // SAFETY: `raw` is either null or points at a NUL-terminated string that
        // `instructions` keeps alive for the whole call.
        let handle = unsafe { ffi::fm_session_create(raw) };
        if handle.is_null() {
            None
        } else {
            Some(Self { ptr: handle })
        }
    }

    /// Converts `prompt` to a C string, rejecting interior NUL bytes.
    fn prompt_to_cstring(prompt: &str) -> Result<CString, FMError> {
        CString::new(prompt)
            .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))
    }

    /// Sends `prompt` with default [`GenerationOptions`] and waits for the full reply.
    ///
    /// # Errors
    /// Same as [`LanguageModelSession::respond_with`].
    pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
        self.respond_with(prompt, GenerationOptions::new())
    }

    /// Asks the bridge to prewarm this session (`fm_session_prewarm`).
    pub fn prewarm(&self) {
        // SAFETY: `self.ptr` is the non-null handle from `fm_session_create`; it is
        // only released in `Drop`, which cannot run while `&self` is borrowed.
        unsafe { ffi::fm_session_prewarm(self.ptr) };
    }

    /// Returns what the bridge's `fm_session_is_responding` reports for this session.
    #[must_use]
    pub fn is_responding(&self) -> bool {
        // SAFETY: see `prewarm` — the handle is live for the duration of `&self`.
        unsafe { ffi::fm_session_is_responding(self.ptr) }
    }

    /// Sends `prompt` with the given `options` and blocks the calling thread until
    /// the bridge delivers the complete response.
    ///
    /// If the bridge never invokes its completion callback, this call does not
    /// return.
    ///
    /// # Errors
    /// - `FMError::InvalidArgument` if `prompt` contains an interior NUL byte.
    /// - The error produced by `crate::error::from_swift` if the bridge reports a failure.
    /// - `FMError::Unknown` if the result channel closes without delivering a value.
    pub fn respond_with(
        &self,
        prompt: &str,
        options: GenerationOptions,
    ) -> Result<String, FMError> {
        let prompt_c = Self::prompt_to_cstring(prompt)?;
        let opts = options.to_ffi();
        let (tx, rx) = mpsc::channel::<Result<String, FMError>>();
        // Ownership of the boxed sender is handed to the bridge;
        // `respond_trampoline` reclaims it with `Box::from_raw`.
        let context = Box::into_raw(Box::new(tx)).cast::<c_void>();
        // SAFETY: `self.ptr` is a live session handle, `prompt_c` outlives the call,
        // and `context` matches the type `respond_trampoline` reconstructs.
        unsafe {
            ffi::fm_session_respond(
                self.ptr,
                prompt_c.as_ptr(),
                opts.temperature,
                opts.maximum_response_tokens,
                opts.sampling_mode,
                opts.top_k,
                opts.top_p,
                context,
                respond_trampoline,
            );
        }
        rx.recv().map_err(|_| FMError::Unknown {
            code: ffi::status::UNKNOWN,
            message: "Swift bridge dropped the callback channel".into(),
        })?
    }

    /// Streams a reply to `prompt` with default [`GenerationOptions`], passing each
    /// event to `on_chunk`.
    ///
    /// # Errors
    /// Same as [`LanguageModelSession::stream_with`].
    pub fn stream<F>(&self, prompt: &str, on_chunk: F) -> Result<(), FMError>
    where
        F: FnMut(StreamEvent<'_>) + Send + 'static,
    {
        self.stream_with(prompt, GenerationOptions::new(), on_chunk)
    }

    /// Streams a reply to `prompt` with `options`, passing each [`StreamEvent`]
    /// to `on_chunk`, and blocks until the bridge signals completion or failure.
    ///
    /// `on_chunk` must be `Send + 'static` because it is stored in state whose
    /// pointer is handed to the bridge, which keeps it until its terminal callback.
    ///
    /// # Errors
    /// - `FMError::InvalidArgument` if `prompt` contains an interior NUL byte.
    /// - The stream error reported by the bridge (also delivered to `on_chunk`
    ///   as [`StreamEvent::Error`]).
    /// - `FMError::Unknown` if the completion channel closes without a value.
    pub fn stream_with<F>(
        &self,
        prompt: &str,
        options: GenerationOptions,
        on_chunk: F,
    ) -> Result<(), FMError>
    where
        F: FnMut(StreamEvent<'_>) + Send + 'static,
    {
        let prompt_c = Self::prompt_to_cstring(prompt)?;
        let opts = options.to_ffi();
        let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
        let state = Arc::new(StreamState {
            on_chunk: Mutex::new(Box::new(on_chunk)),
            done_tx: Mutex::new(Some(done_tx)),
        });
        // The bridge owns this strong reference; `stream_trampoline` releases it
        // on the terminal event.
        let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
        // SAFETY: `self.ptr` is a live session handle, `prompt_c` outlives the call,
        // and `context` is an `Arc<StreamState>` pointer as `stream_trampoline` expects.
        unsafe {
            ffi::fm_session_stream_response(
                self.ptr,
                prompt_c.as_ptr(),
                opts.temperature,
                opts.maximum_response_tokens,
                opts.sampling_mode,
                opts.top_k,
                opts.top_p,
                context,
                stream_trampoline,
            );
        }
        done_rx.recv().map_err(|_| FMError::Unknown {
            code: ffi::status::UNKNOWN,
            message: "Swift bridge dropped the stream channel".into(),
        })?
    }
}
impl Default for LanguageModelSession {
fn default() -> Self {
Self::new()
}
}
impl Drop for LanguageModelSession {
    /// Releases the bridge-side session object when a handle is held.
    fn drop(&mut self) {
        if self.ptr.is_null() {
            return;
        }
        // SAFETY: a non-null `ptr` came from `fm_session_create`, and this is the
        // only place it is released.
        unsafe { ffi::fm_object_release(self.ptr) };
    }
}
impl core::fmt::Debug for LanguageModelSession {
    /// Shows the session as its opaque bridge handle address.
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = formatter.debug_struct("LanguageModelSession");
        builder.field("ptr", &self.ptr);
        builder.finish()
    }
}
/// An event passed to the callback of a streaming request.
#[non_exhaustive]
#[derive(Debug)]
pub enum StreamEvent<'a> {
    /// A piece of response text, borrowed only for the duration of the callback.
    Chunk(&'a str),
    /// The stream completed successfully.
    Done,
    /// The stream failed with this error.
    Error(FMError),
}
/// C callback invoked by the Swift bridge when a `fm_session_respond` request completes.
///
/// Reclaims the boxed `Sender` that `LanguageModelSession::respond_with` passed as
/// `context` and forwards the outcome through it. On success (`status == OK` with a
/// non-null `response`) it sends the response text. Otherwise it sends the error
/// built by `crate::error::from_swift(status, error)`. A non-null `response` is
/// freed with `fm_string_free` on every path.
///
/// # Safety
/// `context` must be the pointer produced by `Box::into_raw` in `respond_with`,
/// and this function must be called at most once for it. `response` must be null
/// or a NUL-terminated string owned by this callback.
unsafe extern "C" fn respond_trampoline(
    context: *mut c_void,
    response: *mut c_char,
    error: *mut c_char,
    status: i32,
) {
    // SAFETY: per the contract above, this reclaims the box leaked by `respond_with`.
    let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
    let result = if status == ffi::status::OK && !response.is_null() {
        let s = core::ffi::CStr::from_ptr(response)
            .to_string_lossy()
            .into_owned();
        ffi::fm_string_free(response);
        Ok(s)
    } else {
        // The bridge may still hand over a response buffer alongside a failure.
        // It is owned here either way, so release it rather than leak it.
        if !response.is_null() {
            ffi::fm_string_free(response);
        }
        Err(crate::error::from_swift(status, error))
    };
    // The receiver may already be gone; there is nobody left to notify then.
    let _ = tx.send(result);
}
/// Type-erased user callback for a streaming request.
type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;

/// State shared between `stream_with` and `stream_trampoline`. Its `Arc` raw
/// pointer is passed to the bridge as the callback context.
struct StreamState {
    /// The user callback, behind a mutex so it can be called through a shared `Arc`.
    on_chunk: Mutex<StreamCallback>,
    /// Completion sender, taken on the terminal event so the outcome is sent once.
    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
}
unsafe extern "C" fn stream_trampoline(
context: *mut c_void,
chunk: *mut c_char,
done: bool,
status: i32,
) {
let state = Arc::from_raw(context.cast::<StreamState>());
let state_for_swift = state.clone();
core::mem::forget(state_for_swift);
let chunk_str: Option<String> = if chunk.is_null() {
None
} else {
let s = core::ffi::CStr::from_ptr(chunk)
.to_string_lossy()
.into_owned();
ffi::fm_string_free(chunk);
Some(s)
};
if status != ffi::status::OK {
let err = crate::error::from_swift(status, ptr::null_mut());
let err_for_callback = chunk_str
.map(|m| match err.clone() {
FMError::Unknown { code, .. } => FMError::Unknown { code, message: m },
other => other,
})
.unwrap_or(err);
let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
cb(StreamEvent::Error(err_for_callback.clone()));
drop(cb);
let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
if let Some(tx) = pending_tx {
let _ = tx.send(Err(err_for_callback));
}
drop(Arc::from_raw(Arc::as_ptr(&state)));
drop(state);
return;
}
if let Some(s) = chunk_str.as_deref() {
let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
cb(StreamEvent::Chunk(s));
}
if done {
let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
cb(StreamEvent::Done);
drop(cb);
let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
if let Some(tx) = pending_tx {
let _ = tx.send(Ok(()));
}
drop(Arc::from_raw(Arc::as_ptr(&state)));
}
drop(state);
}