foundation_models/session/
mod.rs1use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::sync::mpsc;
7use std::sync::{Arc, Mutex};
8
9use crate::error::FMError;
10use crate::ffi;
11use crate::generation::GenerationOptions;
12
/// Owning handle to a language-model session object that lives on the Swift
/// side of the FFI bridge.
///
/// The handle is obtained from `ffi::fm_session_create` and handed back to
/// `ffi::fm_object_release` when the value is dropped.
pub struct LanguageModelSession {
    /// Opaque bridge-side session pointer; the visible constructor never stores null.
    ptr: *mut c_void,
}

// SAFETY: the struct only holds an opaque handle that is passed back to the
// bridge; it owns no thread-affine Rust data.
// NOTE(review): soundness relies on the bridge accepting calls on this handle
// from any thread — confirm against the Swift side.
unsafe impl Send for LanguageModelSession {}

// SAFETY: all `&self` methods only forward the handle to bridge functions.
// NOTE(review): this assumes those bridge entry points tolerate concurrent
// calls on the same session — confirm against the Swift side.
unsafe impl Sync for LanguageModelSession {}
41
42impl LanguageModelSession {
43 #[must_use]
51 pub fn new() -> Self {
52 Self::try_new(None).expect("FoundationModels is not available on this OS")
53 }
54
55 #[must_use]
62 pub fn with_instructions(instructions: &str) -> Self {
63 Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
64 }
65
66 #[must_use]
70 pub fn try_new(instructions: Option<&str>) -> Option<Self> {
71 let cstring = match instructions {
72 Some(s) => Some(CString::new(s).ok()?),
73 None => None,
74 };
75 let ptr =
76 unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
77 if ptr.is_null() {
78 return None;
79 }
80 Some(Self { ptr })
81 }
82
83 pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
91 self.respond_with(prompt, GenerationOptions::new())
92 }
93
94 pub fn prewarm(&self) {
98 unsafe { ffi::fm_session_prewarm(self.ptr) };
99 }
100
101 #[must_use]
104 pub fn is_responding(&self) -> bool {
105 unsafe { ffi::fm_session_is_responding(self.ptr) }
106 }
107
108 pub fn respond_with_json_schema(
122 &self,
123 prompt: &str,
124 schema_description: &str,
125 ) -> Result<String, FMError> {
126 let wrapped = format!(
127 "{prompt}\n\n\
128 IMPORTANT: respond with VALID JSON ONLY (no prose, no markdown \
129 fences) that matches this schema:\n\n{schema_description}\n\n\
130 Your entire response must be parseable by JSON.parse()."
131 );
132 self.respond(&wrapped)
133 }
134
135 pub fn respond_with(
141 &self,
142 prompt: &str,
143 options: GenerationOptions,
144 ) -> Result<String, FMError> {
145 let prompt_c = CString::new(prompt)
146 .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
147 let opts = options.to_ffi();
148 let (tx, rx) = mpsc::channel();
149 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
150 let context = Box::into_raw(tx_box).cast::<c_void>();
151
152 unsafe {
153 ffi::fm_session_respond(
154 self.ptr,
155 prompt_c.as_ptr(),
156 opts.temperature,
157 opts.maximum_response_tokens,
158 opts.sampling_mode,
159 opts.top_k,
160 opts.top_p,
161 context,
162 respond_trampoline,
163 );
164 }
165
166 rx.recv().map_err(|_| FMError::Unknown {
169 code: ffi::status::UNKNOWN,
170 message: "Swift bridge dropped the callback channel".into(),
171 })?
172 }
173
174 pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
183 where
184 F: FnMut(StreamEvent<'_>) + Send + 'static,
185 {
186 self.stream_with(prompt, GenerationOptions::new(), move |event| {
187 on_chunk(event);
188 })
189 }
190
191 pub fn stream_with<F>(
197 &self,
198 prompt: &str,
199 options: GenerationOptions,
200 on_chunk: F,
201 ) -> Result<(), FMError>
202 where
203 F: FnMut(StreamEvent<'_>) + Send + 'static,
204 {
205 let prompt_c = CString::new(prompt)
206 .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
207 let opts = options.to_ffi();
208
209 let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
213 let state = Arc::new(StreamState {
214 on_chunk: Mutex::new(Box::new(on_chunk)),
215 done_tx: Mutex::new(Some(done_tx)),
216 });
217 let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
218
219 unsafe {
220 ffi::fm_session_stream_response(
221 self.ptr,
222 prompt_c.as_ptr(),
223 opts.temperature,
224 opts.maximum_response_tokens,
225 opts.sampling_mode,
226 opts.top_k,
227 opts.top_p,
228 context,
229 stream_trampoline,
230 );
231 }
232
233 done_rx.recv().map_err(|_| FMError::Unknown {
234 code: ffi::status::UNKNOWN,
235 message: "Swift bridge dropped the stream channel".into(),
236 })?
237 }
238}
239
240impl Default for LanguageModelSession {
241 fn default() -> Self {
242 Self::new()
243 }
244}
245
246impl Drop for LanguageModelSession {
247 fn drop(&mut self) {
248 if !self.ptr.is_null() {
249 unsafe { ffi::fm_object_release(self.ptr) };
250 }
251 }
252}
253
254impl core::fmt::Debug for LanguageModelSession {
255 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
256 f.debug_struct("LanguageModelSession")
257 .field("ptr", &self.ptr)
258 .finish()
259 }
260}
261
262#[derive(Debug)]
264#[non_exhaustive]
265pub enum StreamEvent<'a> {
266 Chunk(&'a str),
268 Done,
270 Error(FMError),
272}
273
274unsafe extern "C" fn respond_trampoline(
277 context: *mut c_void,
278 response: *mut c_char,
279 error: *mut c_char,
280 status: i32,
281) {
282 let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
283 let result = if status == ffi::status::OK && !response.is_null() {
284 let s = core::ffi::CStr::from_ptr(response)
285 .to_string_lossy()
286 .into_owned();
287 ffi::fm_string_free(response);
288 Ok(s)
289 } else {
290 Err(crate::error::from_swift(status, error))
291 };
292 let _ = tx.send(result);
293}
294
295type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;
296
297struct StreamState {
298 on_chunk: Mutex<StreamCallback>,
299 done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
300}
301
302unsafe extern "C" fn stream_trampoline(
303 context: *mut c_void,
304 chunk: *mut c_char,
305 done: bool,
306 status: i32,
307) {
308 let state = Arc::from_raw(context.cast::<StreamState>());
309 let state_for_swift = state.clone();
312 core::mem::forget(state_for_swift);
313
314 let chunk_str: Option<String> = if chunk.is_null() {
315 None
316 } else {
317 let s = core::ffi::CStr::from_ptr(chunk)
318 .to_string_lossy()
319 .into_owned();
320 ffi::fm_string_free(chunk);
321 Some(s)
322 };
323
324 if status != ffi::status::OK {
325 let err = crate::error::from_swift(status, ptr::null_mut());
326 let err_for_callback = chunk_str
327 .map(|m| match err.clone() {
328 FMError::Unknown { code, .. } => FMError::Unknown { code, message: m },
329 other => other,
330 })
331 .unwrap_or(err);
332 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
333 cb(StreamEvent::Error(err_for_callback.clone()));
334 drop(cb);
335 let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
336 if let Some(tx) = pending_tx {
337 let _ = tx.send(Err(err_for_callback));
338 }
339 drop(Arc::from_raw(Arc::as_ptr(&state)));
341 drop(state);
342 return;
343 }
344
345 if let Some(s) = chunk_str.as_deref() {
346 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
347 cb(StreamEvent::Chunk(s));
348 }
349
350 if done {
351 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
352 cb(StreamEvent::Done);
353 drop(cb);
354 let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
355 if let Some(tx) = pending_tx {
356 let _ = tx.send(Ok(()));
357 }
358 drop(Arc::from_raw(Arc::as_ptr(&state)));
360 }
361 drop(state);
362}