foundation_models/session/
mod.rs1use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::sync::mpsc;
7use std::sync::{Arc, Mutex};
8
9use crate::error::FMError;
10use crate::ffi;
11use crate::generation::GenerationOptions;
12
/// Owning handle to a language-model session created through the Swift bridge.
///
/// The wrapped pointer is obtained from `ffi::fm_session_create` and handed
/// back to `ffi::fm_object_release` when the value is dropped.
pub struct LanguageModelSession {
    /// Opaque session object owned by this value.
    ptr: *mut c_void,
}
33
34unsafe impl Send for LanguageModelSession {}
40unsafe impl Sync for LanguageModelSession {}
41
42impl LanguageModelSession {
43 #[must_use]
51 pub fn new() -> Self {
52 Self::try_new(None).expect("FoundationModels is not available on this OS")
53 }
54
55 #[must_use]
62 pub fn with_instructions(instructions: &str) -> Self {
63 Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
64 }
65
66 #[must_use]
70 pub fn try_new(instructions: Option<&str>) -> Option<Self> {
71 let cstring = match instructions {
72 Some(s) => Some(CString::new(s).ok()?),
73 None => None,
74 };
75 let ptr =
76 unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
77 if ptr.is_null() {
78 return None;
79 }
80 Some(Self { ptr })
81 }
82
83 pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
91 self.respond_with(prompt, GenerationOptions::new())
92 }
93
94 pub fn prewarm(&self) {
98 unsafe { ffi::fm_session_prewarm(self.ptr) };
99 }
100
101 #[must_use]
104 pub fn is_responding(&self) -> bool {
105 unsafe { ffi::fm_session_is_responding(self.ptr) }
106 }
107
108 pub fn respond_with(
114 &self,
115 prompt: &str,
116 options: GenerationOptions,
117 ) -> Result<String, FMError> {
118 let prompt_c = CString::new(prompt)
119 .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
120 let opts = options.to_ffi();
121 let (tx, rx) = mpsc::channel();
122 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
123 let context = Box::into_raw(tx_box).cast::<c_void>();
124
125 unsafe {
126 ffi::fm_session_respond(
127 self.ptr,
128 prompt_c.as_ptr(),
129 opts.temperature,
130 opts.maximum_response_tokens,
131 opts.sampling_mode,
132 opts.top_k,
133 opts.top_p,
134 context,
135 respond_trampoline,
136 );
137 }
138
139 rx.recv().map_err(|_| FMError::Unknown {
142 code: ffi::status::UNKNOWN,
143 message: "Swift bridge dropped the callback channel".into(),
144 })?
145 }
146
147 pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
156 where
157 F: FnMut(StreamEvent<'_>) + Send + 'static,
158 {
159 self.stream_with(prompt, GenerationOptions::new(), move |event| {
160 on_chunk(event);
161 })
162 }
163
164 pub fn stream_with<F>(
170 &self,
171 prompt: &str,
172 options: GenerationOptions,
173 on_chunk: F,
174 ) -> Result<(), FMError>
175 where
176 F: FnMut(StreamEvent<'_>) + Send + 'static,
177 {
178 let prompt_c = CString::new(prompt)
179 .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
180 let opts = options.to_ffi();
181
182 let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
186 let state = Arc::new(StreamState {
187 on_chunk: Mutex::new(Box::new(on_chunk)),
188 done_tx: Mutex::new(Some(done_tx)),
189 });
190 let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
191
192 unsafe {
193 ffi::fm_session_stream_response(
194 self.ptr,
195 prompt_c.as_ptr(),
196 opts.temperature,
197 opts.maximum_response_tokens,
198 opts.sampling_mode,
199 opts.top_k,
200 opts.top_p,
201 context,
202 stream_trampoline,
203 );
204 }
205
206 done_rx.recv().map_err(|_| FMError::Unknown {
207 code: ffi::status::UNKNOWN,
208 message: "Swift bridge dropped the stream channel".into(),
209 })?
210 }
211}
212
213impl Default for LanguageModelSession {
214 fn default() -> Self {
215 Self::new()
216 }
217}
218
219impl Drop for LanguageModelSession {
220 fn drop(&mut self) {
221 if !self.ptr.is_null() {
222 unsafe { ffi::fm_object_release(self.ptr) };
223 }
224 }
225}
226
227impl core::fmt::Debug for LanguageModelSession {
228 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
229 f.debug_struct("LanguageModelSession")
230 .field("ptr", &self.ptr)
231 .finish()
232 }
233}
234
235#[derive(Debug)]
237#[non_exhaustive]
238pub enum StreamEvent<'a> {
239 Chunk(&'a str),
241 Done,
243 Error(FMError),
245}
246
247unsafe extern "C" fn respond_trampoline(
250 context: *mut c_void,
251 response: *mut c_char,
252 error: *mut c_char,
253 status: i32,
254) {
255 let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
256 let result = if status == ffi::status::OK && !response.is_null() {
257 let s = core::ffi::CStr::from_ptr(response)
258 .to_string_lossy()
259 .into_owned();
260 ffi::fm_string_free(response);
261 Ok(s)
262 } else {
263 Err(crate::error::from_swift(status, error))
264 };
265 let _ = tx.send(result);
266}
267
268type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;
269
270struct StreamState {
271 on_chunk: Mutex<StreamCallback>,
272 done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
273}
274
275unsafe extern "C" fn stream_trampoline(
276 context: *mut c_void,
277 chunk: *mut c_char,
278 done: bool,
279 status: i32,
280) {
281 let state = Arc::from_raw(context.cast::<StreamState>());
282 let state_for_swift = state.clone();
285 core::mem::forget(state_for_swift);
286
287 let chunk_str: Option<String> = if chunk.is_null() {
288 None
289 } else {
290 let s = core::ffi::CStr::from_ptr(chunk)
291 .to_string_lossy()
292 .into_owned();
293 ffi::fm_string_free(chunk);
294 Some(s)
295 };
296
297 if status != ffi::status::OK {
298 let err = crate::error::from_swift(status, ptr::null_mut());
299 let err_for_callback = chunk_str
300 .map(|m| match err.clone() {
301 FMError::Unknown { code, .. } => FMError::Unknown { code, message: m },
302 other => other,
303 })
304 .unwrap_or(err);
305 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
306 cb(StreamEvent::Error(err_for_callback.clone()));
307 drop(cb);
308 let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
309 if let Some(tx) = pending_tx {
310 let _ = tx.send(Err(err_for_callback));
311 }
312 drop(Arc::from_raw(Arc::as_ptr(&state)));
314 drop(state);
315 return;
316 }
317
318 if let Some(s) = chunk_str.as_deref() {
319 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
320 cb(StreamEvent::Chunk(s));
321 }
322
323 if done {
324 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
325 cb(StreamEvent::Done);
326 drop(cb);
327 let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
328 if let Some(tx) = pending_tx {
329 let _ = tx.send(Ok(()));
330 }
331 drop(Arc::from_raw(Arc::as_ptr(&state)));
333 }
334 drop(state);
335}