foundation_models/session/
mod.rs1use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::sync::mpsc;
7use std::sync::{Arc, Mutex};
8
9use crate::error::FMError;
10use crate::ffi;
11use crate::generation::GenerationOptions;
12
/// Handle to a session object created through the Swift FoundationModels bridge.
///
/// Construct it with `LanguageModelSession::new`, `LanguageModelSession::with_instructions`
/// or `LanguageModelSession::try_new`. The bridge-side object is released with
/// `fm_object_release` when the handle is dropped.
pub struct LanguageModelSession {
    // Opaque pointer returned by `fm_session_create`. The constructors never store a null
    // pointer; it is only ever passed back to `ffi::*` functions, never dereferenced in Rust.
    ptr: *mut c_void,
}
33
// SAFETY: the handle holds only an opaque pointer to a bridge-owned object. Every method
// passes it to a bridge function by value, and Rust code never dereferences it.
// NOTE(review): soundness relies on the bridge's session functions being callable from any
// thread (and, for `Sync`, concurrently through `&self`). Confirm this against the Swift side.
unsafe impl Send for LanguageModelSession {}
unsafe impl Sync for LanguageModelSession {}
41
42impl LanguageModelSession {
43 #[must_use]
51 pub fn new() -> Self {
52 Self::try_new(None).expect("FoundationModels is not available on this OS")
53 }
54
55 #[must_use]
62 pub fn with_instructions(instructions: &str) -> Self {
63 Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
64 }
65
66 #[must_use]
70 pub fn try_new(instructions: Option<&str>) -> Option<Self> {
71 let cstring = match instructions {
72 Some(s) => Some(CString::new(s).ok()?),
73 None => None,
74 };
75 let ptr =
76 unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
77 if ptr.is_null() {
78 return None;
79 }
80 Some(Self { ptr })
81 }
82
83 pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
91 self.respond_with(prompt, GenerationOptions::new())
92 }
93
94 pub fn prewarm(&self) {
98 unsafe { ffi::fm_session_prewarm(self.ptr) };
99 }
100
101 #[must_use]
104 pub fn is_responding(&self) -> bool {
105 unsafe { ffi::fm_session_is_responding(self.ptr) }
106 }
107
108 #[must_use]
113 pub fn transcript_json(&self) -> String {
114 let p = unsafe { ffi::fm_session_transcript_json(self.ptr) };
115 if p.is_null() {
116 return String::from("{}");
117 }
118 let s = unsafe { core::ffi::CStr::from_ptr(p) }
119 .to_string_lossy()
120 .into_owned();
121 unsafe { ffi::fm_string_free(p) };
122 s
123 }
124
125 pub fn log_feedback(&self, sentiment: i32, description: Option<&str>) {
129 let cstr = description.and_then(|s| CString::new(s).ok());
130 let p = cstr.as_ref().map_or(core::ptr::null(), |c| c.as_ptr());
131 unsafe { ffi::fm_session_log_feedback(self.ptr, sentiment, p) };
132 }
133
134 pub fn respond_with_json_schema(
148 &self,
149 prompt: &str,
150 schema_description: &str,
151 ) -> Result<String, FMError> {
152 let wrapped = format!(
153 "{prompt}\n\n\
154 IMPORTANT: respond with VALID JSON ONLY (no prose, no markdown \
155 fences) that matches this schema:\n\n{schema_description}\n\n\
156 Your entire response must be parseable by JSON.parse()."
157 );
158 self.respond(&wrapped)
159 }
160
161 pub fn respond_with(
167 &self,
168 prompt: &str,
169 options: GenerationOptions,
170 ) -> Result<String, FMError> {
171 let prompt_c = CString::new(prompt)
172 .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
173 let opts = options.to_ffi();
174 let (tx, rx) = mpsc::channel();
175 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
176 let context = Box::into_raw(tx_box).cast::<c_void>();
177
178 unsafe {
179 ffi::fm_session_respond(
180 self.ptr,
181 prompt_c.as_ptr(),
182 opts.temperature,
183 opts.maximum_response_tokens,
184 opts.sampling_mode,
185 opts.top_k,
186 opts.top_p,
187 context,
188 respond_trampoline,
189 );
190 }
191
192 rx.recv().map_err(|_| FMError::Unknown {
195 code: ffi::status::UNKNOWN,
196 message: "Swift bridge dropped the callback channel".into(),
197 })?
198 }
199
200 pub fn respond_with_schema(
233 &self,
234 prompt: &str,
235 schema: &str,
236 include_schema_in_prompt: bool,
237 ) -> Result<String, FMError> {
238 self.respond_with_schema_options(prompt, schema, include_schema_in_prompt, GenerationOptions::new())
239 }
240
241 pub fn respond_with_schema_options(
248 &self,
249 prompt: &str,
250 schema: &str,
251 include_schema_in_prompt: bool,
252 options: GenerationOptions,
253 ) -> Result<String, FMError> {
254 let prompt_c = CString::new(prompt)
255 .map_err(|e| FMError::InvalidArgument(format!("prompt NUL byte: {e}")))?;
256 let schema_c = CString::new(schema)
257 .map_err(|e| FMError::InvalidArgument(format!("schema NUL byte: {e}")))?;
258 let opts = options.to_ffi();
259 let (tx, rx) = mpsc::channel();
260 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
261 let context = Box::into_raw(tx_box).cast::<c_void>();
262
263 unsafe {
264 ffi::fm_session_respond_with_schema(
265 self.ptr,
266 prompt_c.as_ptr(),
267 schema_c.as_ptr(),
268 include_schema_in_prompt,
269 opts.temperature,
270 opts.maximum_response_tokens,
271 opts.sampling_mode,
272 opts.top_k,
273 opts.top_p,
274 context,
275 respond_trampoline,
276 );
277 }
278
279 rx.recv().map_err(|_| FMError::Unknown {
280 code: ffi::status::UNKNOWN,
281 message: "Swift bridge dropped the callback channel".into(),
282 })?
283 }
284
285 pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
294 where
295 F: FnMut(StreamEvent<'_>) + Send + 'static,
296 {
297 self.stream_with(prompt, GenerationOptions::new(), move |event| {
298 on_chunk(event);
299 })
300 }
301
302 pub fn stream_with<F>(
308 &self,
309 prompt: &str,
310 options: GenerationOptions,
311 on_chunk: F,
312 ) -> Result<(), FMError>
313 where
314 F: FnMut(StreamEvent<'_>) + Send + 'static,
315 {
316 let prompt_c = CString::new(prompt)
317 .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
318 let opts = options.to_ffi();
319
320 let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
324 let state = Arc::new(StreamState {
325 on_chunk: Mutex::new(Box::new(on_chunk)),
326 done_tx: Mutex::new(Some(done_tx)),
327 });
328 let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
329
330 unsafe {
331 ffi::fm_session_stream_response(
332 self.ptr,
333 prompt_c.as_ptr(),
334 opts.temperature,
335 opts.maximum_response_tokens,
336 opts.sampling_mode,
337 opts.top_k,
338 opts.top_p,
339 context,
340 stream_trampoline,
341 );
342 }
343
344 done_rx.recv().map_err(|_| FMError::Unknown {
345 code: ffi::status::UNKNOWN,
346 message: "Swift bridge dropped the stream channel".into(),
347 })?
348 }
349}
350
351impl Default for LanguageModelSession {
352 fn default() -> Self {
353 Self::new()
354 }
355}
356
357impl Drop for LanguageModelSession {
358 fn drop(&mut self) {
359 if !self.ptr.is_null() {
360 unsafe { ffi::fm_object_release(self.ptr) };
361 }
362 }
363}
364
365impl core::fmt::Debug for LanguageModelSession {
366 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
367 f.debug_struct("LanguageModelSession")
368 .field("ptr", &self.ptr)
369 .finish()
370 }
371}
372
/// One event delivered to the callback given to `LanguageModelSession::stream` or
/// `LanguageModelSession::stream_with`.
#[derive(Debug)]
#[non_exhaustive]
pub enum StreamEvent<'a> {
    /// Text carried by one streaming callback from the bridge. It is borrowed only for the
    /// duration of the callback invocation. NOTE(review): whether this is a delta or the
    /// cumulative response is decided by the Swift side; confirm and document.
    Chunk(&'a str),
    /// The stream completed successfully; `stream_with` then returns `Ok(())`.
    Done,
    /// The stream failed; `stream_with` then returns this same error.
    Error(FMError),
}
384
385unsafe extern "C" fn respond_trampoline(
388 context: *mut c_void,
389 response: *mut c_char,
390 error: *mut c_char,
391 status: i32,
392) {
393 let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
394 let result = if status == ffi::status::OK && !response.is_null() {
395 let s = core::ffi::CStr::from_ptr(response)
396 .to_string_lossy()
397 .into_owned();
398 ffi::fm_string_free(response);
399 Ok(s)
400 } else {
401 Err(crate::error::from_swift(status, error))
402 };
403 let _ = tx.send(result);
404}
405
/// Type-erased user callback stored in `StreamState` for the duration of a stream.
type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;
407
/// Shared state behind the `context` pointer handed to `fm_session_stream_response`.
///
/// It lives in an `Arc`. `stream_with` gives one strong reference to the bridge via
/// `Arc::into_raw`, and `stream_trampoline` releases it on a terminal event. Each trampoline
/// invocation also holds a temporary reference of its own.
struct StreamState {
    // The user's callback. It sits behind a `Mutex` because the trampoline only has shared
    // (`Arc`) access but must call an `FnMut`.
    on_chunk: Mutex<StreamCallback>,
    // Completion sender for the `stream_with` call blocked on the receiver. It is taken (set
    // to `None`) when the final result is sent.
    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
}
412
413unsafe extern "C" fn stream_trampoline(
414 context: *mut c_void,
415 chunk: *mut c_char,
416 done: bool,
417 status: i32,
418) {
419 let state = Arc::from_raw(context.cast::<StreamState>());
420 let state_for_swift = state.clone();
423 core::mem::forget(state_for_swift);
424
425 let chunk_str: Option<String> = if chunk.is_null() {
426 None
427 } else {
428 let s = core::ffi::CStr::from_ptr(chunk)
429 .to_string_lossy()
430 .into_owned();
431 ffi::fm_string_free(chunk);
432 Some(s)
433 };
434
435 if status != ffi::status::OK {
436 let err = crate::error::from_swift(status, ptr::null_mut());
437 let err_for_callback = chunk_str
438 .map(|m| match err.clone() {
439 FMError::Unknown { code, .. } => FMError::Unknown { code, message: m },
440 other => other,
441 })
442 .unwrap_or(err);
443 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
444 cb(StreamEvent::Error(err_for_callback.clone()));
445 drop(cb);
446 let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
447 if let Some(tx) = pending_tx {
448 let _ = tx.send(Err(err_for_callback));
449 }
450 drop(Arc::from_raw(Arc::as_ptr(&state)));
452 drop(state);
453 return;
454 }
455
456 if let Some(s) = chunk_str.as_deref() {
457 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
458 cb(StreamEvent::Chunk(s));
459 }
460
461 if done {
462 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
463 cb(StreamEvent::Done);
464 drop(cb);
465 let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
466 if let Some(tx) = pending_tx {
467 let _ = tx.send(Ok(()));
468 }
469 drop(Arc::from_raw(Arc::as_ptr(&state)));
471 }
472 drop(state);
473}