foundation_models/session/
mod.rs1use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::sync::mpsc;
7use std::sync::{Arc, Mutex};
8
9use crate::error::FMError;
10use crate::ffi;
11use crate::generation::GenerationOptions;
12
/// Owning handle to a language-model session living on the Swift side of the
/// bridge.
///
/// The wrapped pointer is produced by `ffi::fm_session_create` and handed back
/// to `ffi::fm_object_release` when the handle is dropped.
pub struct LanguageModelSession {
    /// Opaque session object owned by this handle; never dereferenced in Rust.
    ptr: *mut c_void,
}
33
34unsafe impl Send for LanguageModelSession {}
40unsafe impl Sync for LanguageModelSession {}
41
42impl LanguageModelSession {
43 #[must_use]
51 pub fn new() -> Self {
52 Self::try_new(None).expect("FoundationModels is not available on this OS")
53 }
54
55 #[must_use]
62 pub fn with_instructions(instructions: &str) -> Self {
63 Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
64 }
65
66 #[must_use]
70 pub fn try_new(instructions: Option<&str>) -> Option<Self> {
71 let cstring = match instructions {
72 Some(s) => Some(CString::new(s).ok()?),
73 None => None,
74 };
75 let ptr =
76 unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
77 if ptr.is_null() {
78 return None;
79 }
80 Some(Self { ptr })
81 }
82
83 pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
91 self.respond_with(prompt, GenerationOptions::new())
92 }
93
94 pub fn respond_with(
100 &self,
101 prompt: &str,
102 options: GenerationOptions,
103 ) -> Result<String, FMError> {
104 let prompt_c = CString::new(prompt)
105 .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
106 let opts = options.to_ffi();
107 let (tx, rx) = mpsc::channel();
108 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
109 let context = Box::into_raw(tx_box).cast::<c_void>();
110
111 unsafe {
112 ffi::fm_session_respond(
113 self.ptr,
114 prompt_c.as_ptr(),
115 opts.temperature,
116 opts.maximum_response_tokens,
117 opts.sampling_mode,
118 opts.top_k,
119 opts.top_p,
120 context,
121 respond_trampoline,
122 );
123 }
124
125 rx.recv().map_err(|_| FMError::Unknown {
128 code: ffi::status::UNKNOWN,
129 message: "Swift bridge dropped the callback channel".into(),
130 })?
131 }
132
133 pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
142 where
143 F: FnMut(StreamEvent<'_>) + Send + 'static,
144 {
145 self.stream_with(prompt, GenerationOptions::new(), move |event| {
146 on_chunk(event);
147 })
148 }
149
150 pub fn stream_with<F>(
156 &self,
157 prompt: &str,
158 options: GenerationOptions,
159 on_chunk: F,
160 ) -> Result<(), FMError>
161 where
162 F: FnMut(StreamEvent<'_>) + Send + 'static,
163 {
164 let prompt_c = CString::new(prompt)
165 .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
166 let opts = options.to_ffi();
167
168 let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
172 let state = Arc::new(StreamState {
173 on_chunk: Mutex::new(Box::new(on_chunk)),
174 done_tx: Mutex::new(Some(done_tx)),
175 });
176 let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
177
178 unsafe {
179 ffi::fm_session_stream_response(
180 self.ptr,
181 prompt_c.as_ptr(),
182 opts.temperature,
183 opts.maximum_response_tokens,
184 opts.sampling_mode,
185 opts.top_k,
186 opts.top_p,
187 context,
188 stream_trampoline,
189 );
190 }
191
192 done_rx.recv().map_err(|_| FMError::Unknown {
193 code: ffi::status::UNKNOWN,
194 message: "Swift bridge dropped the stream channel".into(),
195 })?
196 }
197}
198
199impl Default for LanguageModelSession {
200 fn default() -> Self {
201 Self::new()
202 }
203}
204
205impl Drop for LanguageModelSession {
206 fn drop(&mut self) {
207 if !self.ptr.is_null() {
208 unsafe { ffi::fm_object_release(self.ptr) };
209 }
210 }
211}
212
213impl core::fmt::Debug for LanguageModelSession {
214 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
215 f.debug_struct("LanguageModelSession")
216 .field("ptr", &self.ptr)
217 .finish()
218 }
219}
220
221#[derive(Debug)]
223#[non_exhaustive]
224pub enum StreamEvent<'a> {
225 Chunk(&'a str),
227 Done,
229 Error(FMError),
231}
232
233unsafe extern "C" fn respond_trampoline(
236 context: *mut c_void,
237 response: *mut c_char,
238 error: *mut c_char,
239 status: i32,
240) {
241 let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
242 let result = if status == ffi::status::OK && !response.is_null() {
243 let s = core::ffi::CStr::from_ptr(response)
244 .to_string_lossy()
245 .into_owned();
246 ffi::fm_string_free(response);
247 Ok(s)
248 } else {
249 Err(crate::error::from_swift(status, error))
250 };
251 let _ = tx.send(result);
252}
253
254type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;
255
256struct StreamState {
257 on_chunk: Mutex<StreamCallback>,
258 done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
259}
260
261unsafe extern "C" fn stream_trampoline(
262 context: *mut c_void,
263 chunk: *mut c_char,
264 done: bool,
265 status: i32,
266) {
267 let state = Arc::from_raw(context.cast::<StreamState>());
268 let state_for_swift = state.clone();
271 core::mem::forget(state_for_swift);
272
273 let chunk_str: Option<String> = if chunk.is_null() {
274 None
275 } else {
276 let s = core::ffi::CStr::from_ptr(chunk)
277 .to_string_lossy()
278 .into_owned();
279 ffi::fm_string_free(chunk);
280 Some(s)
281 };
282
283 if status != ffi::status::OK {
284 let err = crate::error::from_swift(status, ptr::null_mut());
285 let err_for_callback = chunk_str
286 .map(|m| match err.clone() {
287 FMError::Unknown { code, .. } => FMError::Unknown { code, message: m },
288 other => other,
289 })
290 .unwrap_or(err);
291 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
292 cb(StreamEvent::Error(err_for_callback.clone()));
293 drop(cb);
294 let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
295 if let Some(tx) = pending_tx {
296 let _ = tx.send(Err(err_for_callback));
297 }
298 drop(Arc::from_raw(Arc::as_ptr(&state)));
300 drop(state);
301 return;
302 }
303
304 if let Some(s) = chunk_str.as_deref() {
305 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
306 cb(StreamEvent::Chunk(s));
307 }
308
309 if done {
310 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
311 cb(StreamEvent::Done);
312 drop(cb);
313 let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
314 if let Some(tx) = pending_tx {
315 let _ = tx.send(Ok(()));
316 }
317 drop(Arc::from_raw(Arc::as_ptr(&state)));
319 }
320 drop(state);
321}