foundation_models/session/
mod.rs1use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::sync::mpsc;
7use std::sync::{Arc, Mutex};
8
9use crate::error::FMError;
10use crate::ffi;
11use crate::generation::GenerationOptions;
12
/// Handle to a FoundationModels language-model session owned by the Swift bridge.
///
/// The handle is obtained from `ffi::fm_session_create` (see `LanguageModelSession::try_new`)
/// and released with `ffi::fm_object_release` when the value is dropped.
pub struct LanguageModelSession {
    /// Opaque handle returned by `ffi::fm_session_create`; `try_new` never stores a null one.
    ptr: *mut c_void,
}

// SAFETY: Rust never dereferences `ptr`; it is only handed to `ffi::fm_*` functions. Moving the
// handle to another thread is sound as long as the Swift session is not tied to the thread that
// created it. NOTE(review): confirm this property in the Swift bridge.
unsafe impl Send for LanguageModelSession {}

// SAFETY: every `&self` method forwards `ptr` to the bridge, so sharing `&LanguageModelSession`
// between threads is sound only if the bridge tolerates concurrent calls on one session.
// NOTE(review): confirm the Swift side synchronises access to a session.
unsafe impl Sync for LanguageModelSession {}

42impl LanguageModelSession {
43 #[must_use]
51 pub fn new() -> Self {
52 Self::try_new(None).expect("FoundationModels is not available on this OS")
53 }
54
55 #[must_use]
62 pub fn with_instructions(instructions: &str) -> Self {
63 Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
64 }
65
66 #[must_use]
70 pub fn try_new(instructions: Option<&str>) -> Option<Self> {
71 let cstring = match instructions {
72 Some(s) => Some(CString::new(s).ok()?),
73 None => None,
74 };
75 let ptr =
76 unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
77 if ptr.is_null() {
78 return None;
79 }
80 Some(Self { ptr })
81 }
82
83 pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
91 self.respond_with(prompt, GenerationOptions::new())
92 }
93
94 pub fn prewarm(&self) {
98 unsafe { ffi::fm_session_prewarm(self.ptr) };
99 }
100
101 #[must_use]
104 pub fn is_responding(&self) -> bool {
105 unsafe { ffi::fm_session_is_responding(self.ptr) }
106 }
107
108 pub fn respond_with_json_schema(
122 &self,
123 prompt: &str,
124 schema_description: &str,
125 ) -> Result<String, FMError> {
126 let wrapped = format!(
127 "{prompt}\n\n\
128 IMPORTANT: respond with VALID JSON ONLY (no prose, no markdown \
129 fences) that matches this schema:\n\n{schema_description}\n\n\
130 Your entire response must be parseable by JSON.parse()."
131 );
132 self.respond(&wrapped)
133 }
134
135 pub fn respond_with(
141 &self,
142 prompt: &str,
143 options: GenerationOptions,
144 ) -> Result<String, FMError> {
145 let prompt_c = CString::new(prompt)
146 .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
147 let opts = options.to_ffi();
148 let (tx, rx) = mpsc::channel();
149 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
150 let context = Box::into_raw(tx_box).cast::<c_void>();
151
152 unsafe {
153 ffi::fm_session_respond(
154 self.ptr,
155 prompt_c.as_ptr(),
156 opts.temperature,
157 opts.maximum_response_tokens,
158 opts.sampling_mode,
159 opts.top_k,
160 opts.top_p,
161 context,
162 respond_trampoline,
163 );
164 }
165
166 rx.recv().map_err(|_| FMError::Unknown {
169 code: ffi::status::UNKNOWN,
170 message: "Swift bridge dropped the callback channel".into(),
171 })?
172 }
173
174 pub fn respond_with_schema(
207 &self,
208 prompt: &str,
209 schema: &str,
210 include_schema_in_prompt: bool,
211 ) -> Result<String, FMError> {
212 self.respond_with_schema_options(prompt, schema, include_schema_in_prompt, GenerationOptions::new())
213 }
214
215 pub fn respond_with_schema_options(
222 &self,
223 prompt: &str,
224 schema: &str,
225 include_schema_in_prompt: bool,
226 options: GenerationOptions,
227 ) -> Result<String, FMError> {
228 let prompt_c = CString::new(prompt)
229 .map_err(|e| FMError::InvalidArgument(format!("prompt NUL byte: {e}")))?;
230 let schema_c = CString::new(schema)
231 .map_err(|e| FMError::InvalidArgument(format!("schema NUL byte: {e}")))?;
232 let opts = options.to_ffi();
233 let (tx, rx) = mpsc::channel();
234 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
235 let context = Box::into_raw(tx_box).cast::<c_void>();
236
237 unsafe {
238 ffi::fm_session_respond_with_schema(
239 self.ptr,
240 prompt_c.as_ptr(),
241 schema_c.as_ptr(),
242 include_schema_in_prompt,
243 opts.temperature,
244 opts.maximum_response_tokens,
245 opts.sampling_mode,
246 opts.top_k,
247 opts.top_p,
248 context,
249 respond_trampoline,
250 );
251 }
252
253 rx.recv().map_err(|_| FMError::Unknown {
254 code: ffi::status::UNKNOWN,
255 message: "Swift bridge dropped the callback channel".into(),
256 })?
257 }
258
259 pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
268 where
269 F: FnMut(StreamEvent<'_>) + Send + 'static,
270 {
271 self.stream_with(prompt, GenerationOptions::new(), move |event| {
272 on_chunk(event);
273 })
274 }
275
276 pub fn stream_with<F>(
282 &self,
283 prompt: &str,
284 options: GenerationOptions,
285 on_chunk: F,
286 ) -> Result<(), FMError>
287 where
288 F: FnMut(StreamEvent<'_>) + Send + 'static,
289 {
290 let prompt_c = CString::new(prompt)
291 .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
292 let opts = options.to_ffi();
293
294 let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
298 let state = Arc::new(StreamState {
299 on_chunk: Mutex::new(Box::new(on_chunk)),
300 done_tx: Mutex::new(Some(done_tx)),
301 });
302 let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
303
304 unsafe {
305 ffi::fm_session_stream_response(
306 self.ptr,
307 prompt_c.as_ptr(),
308 opts.temperature,
309 opts.maximum_response_tokens,
310 opts.sampling_mode,
311 opts.top_k,
312 opts.top_p,
313 context,
314 stream_trampoline,
315 );
316 }
317
318 done_rx.recv().map_err(|_| FMError::Unknown {
319 code: ffi::status::UNKNOWN,
320 message: "Swift bridge dropped the stream channel".into(),
321 })?
322 }
323}
324
325impl Default for LanguageModelSession {
326 fn default() -> Self {
327 Self::new()
328 }
329}
330
331impl Drop for LanguageModelSession {
332 fn drop(&mut self) {
333 if !self.ptr.is_null() {
334 unsafe { ffi::fm_object_release(self.ptr) };
335 }
336 }
337}
338
339impl core::fmt::Debug for LanguageModelSession {
340 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
341 f.debug_struct("LanguageModelSession")
342 .field("ptr", &self.ptr)
343 .finish()
344 }
345}
346
347#[derive(Debug)]
349#[non_exhaustive]
350pub enum StreamEvent<'a> {
351 Chunk(&'a str),
353 Done,
355 Error(FMError),
357}
358
359unsafe extern "C" fn respond_trampoline(
362 context: *mut c_void,
363 response: *mut c_char,
364 error: *mut c_char,
365 status: i32,
366) {
367 let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
368 let result = if status == ffi::status::OK && !response.is_null() {
369 let s = core::ffi::CStr::from_ptr(response)
370 .to_string_lossy()
371 .into_owned();
372 ffi::fm_string_free(response);
373 Ok(s)
374 } else {
375 Err(crate::error::from_swift(status, error))
376 };
377 let _ = tx.send(result);
378}
379
380type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;
381
382struct StreamState {
383 on_chunk: Mutex<StreamCallback>,
384 done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
385}
386
387unsafe extern "C" fn stream_trampoline(
388 context: *mut c_void,
389 chunk: *mut c_char,
390 done: bool,
391 status: i32,
392) {
393 let state = Arc::from_raw(context.cast::<StreamState>());
394 let state_for_swift = state.clone();
397 core::mem::forget(state_for_swift);
398
399 let chunk_str: Option<String> = if chunk.is_null() {
400 None
401 } else {
402 let s = core::ffi::CStr::from_ptr(chunk)
403 .to_string_lossy()
404 .into_owned();
405 ffi::fm_string_free(chunk);
406 Some(s)
407 };
408
409 if status != ffi::status::OK {
410 let err = crate::error::from_swift(status, ptr::null_mut());
411 let err_for_callback = chunk_str
412 .map(|m| match err.clone() {
413 FMError::Unknown { code, .. } => FMError::Unknown { code, message: m },
414 other => other,
415 })
416 .unwrap_or(err);
417 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
418 cb(StreamEvent::Error(err_for_callback.clone()));
419 drop(cb);
420 let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
421 if let Some(tx) = pending_tx {
422 let _ = tx.send(Err(err_for_callback));
423 }
424 drop(Arc::from_raw(Arc::as_ptr(&state)));
426 drop(state);
427 return;
428 }
429
430 if let Some(s) = chunk_str.as_deref() {
431 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
432 cb(StreamEvent::Chunk(s));
433 }
434
435 if done {
436 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
437 cb(StreamEvent::Done);
438 drop(cb);
439 let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
440 if let Some(tx) = pending_tx {
441 let _ = tx.send(Ok(()));
442 }
443 drop(Arc::from_raw(Arc::as_ptr(&state)));
445 }
446 drop(state);
447}