Skip to main content

foundation_models/session/
mod.rs

1//! [`LanguageModelSession`] — a stateful conversation with the on-device model.
2
3use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::sync::mpsc;
7use std::sync::{Arc, Mutex};
8
9use crate::error::FMError;
10use crate::ffi;
11use crate::generation::GenerationOptions;
12
/// A stateful conversation with the on-device language model.
///
/// Each session keeps its own transcript, so every call to
/// [`respond`](Self::respond) (or [`stream`](Self::stream)) sees the turns
/// that came before it.
///
/// # Examples
///
/// ```rust,no_run
/// use foundation_models::LanguageModelSession;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let session = LanguageModelSession::new();
/// let reply = session.respond("Name three Norse gods.")?;
/// println!("{reply}");
/// # Ok(())
/// # }
/// ```
pub struct LanguageModelSession {
    /// Opaque handle to the retained Swift session. It is non-null for every
    /// value built by `try_new`, is never dereferenced in Rust, and is
    /// released in `Drop`.
    ptr: *mut c_void,
}

// SAFETY (Send): the Swift `LanguageModelSession` behind `ptr` is kept alive
// by the retain taken with `Unmanaged.passRetained` on the Swift side. Rust
// never dereferences the pointer. It is only handed to `extern "C"` entry
// points, which hop onto the Swift concurrency executor themselves, so
// moving the handle to another thread is sound.
//
// SAFETY (Sync): sharing `&LanguageModelSession` lets several threads call
// the `&self` methods at once. Those methods only forward `ptr` to the FFI
// layer. NOTE(review): this relies on the Swift bridge tolerating
// overlapping calls on one session. Confirm on the Swift side.
unsafe impl Send for LanguageModelSession {}
unsafe impl Sync for LanguageModelSession {}
41
42impl LanguageModelSession {
43    /// Create a session with the model's default behaviour.
44    ///
45    /// # Panics
46    ///
47    /// Panics if `FoundationModels` is not available on this OS. Check
48    /// [`crate::SystemLanguageModel::is_available`] first if you need to
49    /// handle that gracefully.
50    #[must_use]
51    pub fn new() -> Self {
52        Self::try_new(None).expect("FoundationModels is not available on this OS")
53    }
54
55    /// Create a session with custom system instructions ("system prompt").
56    ///
57    /// # Panics
58    ///
59    /// Panics if `FoundationModels` is not available, or if `instructions`
60    /// contains an interior NUL byte.
61    #[must_use]
62    pub fn with_instructions(instructions: &str) -> Self {
63        Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
64    }
65
66    /// Fallible constructor. Returns `None` when `FoundationModels` is not
67    /// available (OS too old, model not enabled, etc.) or when `instructions`
68    /// contains an interior NUL byte.
69    #[must_use]
70    pub fn try_new(instructions: Option<&str>) -> Option<Self> {
71        let cstring = match instructions {
72            Some(s) => Some(CString::new(s).ok()?),
73            None => None,
74        };
75        let ptr =
76            unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
77        if ptr.is_null() {
78            return None;
79        }
80        Some(Self { ptr })
81    }
82
83    /// Send a prompt and block until the full response is available.
84    ///
85    /// # Errors
86    ///
87    /// Returns an [`FMError`] if the model rejects the prompt, the context
88    /// window is exceeded, the session is cancelled, or the prompt contains
89    /// an interior NUL byte.
90    pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
91        self.respond_with(prompt, GenerationOptions::new())
92    }
93
94    /// Pre-warm the model. Apple loads the weights + initialises the
95    /// inference engine so the next `respond` call is faster. Returns
96    /// immediately; the warm-up runs in the background.
97    pub fn prewarm(&self) {
98        unsafe { ffi::fm_session_prewarm(self.ptr) };
99    }
100
101    /// True if this session is currently producing a response (i.e. an
102    /// earlier `respond` / `stream` is still in flight on Apple's queue).
103    #[must_use]
104    pub fn is_responding(&self) -> bool {
105        unsafe { ffi::fm_session_is_responding(self.ptr) }
106    }
107
108    /// Like [`respond`](Self::respond), but with explicit generation options.
109    ///
110    /// # Errors
111    ///
112    /// See [`respond`](Self::respond).
113    pub fn respond_with(
114        &self,
115        prompt: &str,
116        options: GenerationOptions,
117    ) -> Result<String, FMError> {
118        let prompt_c = CString::new(prompt)
119            .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
120        let opts = options.to_ffi();
121        let (tx, rx) = mpsc::channel();
122        let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
123        let context = Box::into_raw(tx_box).cast::<c_void>();
124
125        unsafe {
126            ffi::fm_session_respond(
127                self.ptr,
128                prompt_c.as_ptr(),
129                opts.temperature,
130                opts.maximum_response_tokens,
131                opts.sampling_mode,
132                opts.top_k,
133                opts.top_p,
134                context,
135                respond_trampoline,
136            );
137        }
138
139        // The Swift side dispatches the callback on its own Task executor;
140        // it is guaranteed to fire exactly once.
141        rx.recv().map_err(|_| FMError::Unknown {
142            code: ffi::status::UNKNOWN,
143            message: "Swift bridge dropped the callback channel".into(),
144        })?
145    }
146
147    /// Stream the response as the model generates it. The callback is invoked
148    /// with each delta and a final invocation with `done == true`.
149    ///
150    /// # Errors
151    ///
152    /// Returns an [`FMError`] mirroring [`respond`](Self::respond). The
153    /// callback may also receive a chunk *and* an error if the stream fails
154    /// midway.
155    pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
156    where
157        F: FnMut(StreamEvent<'_>) + Send + 'static,
158    {
159        self.stream_with(prompt, GenerationOptions::new(), move |event| {
160            on_chunk(event);
161        })
162    }
163
164    /// Like [`stream`](Self::stream), but with explicit generation options.
165    ///
166    /// # Errors
167    ///
168    /// See [`stream`](Self::stream).
169    pub fn stream_with<F>(
170        &self,
171        prompt: &str,
172        options: GenerationOptions,
173        on_chunk: F,
174    ) -> Result<(), FMError>
175    where
176        F: FnMut(StreamEvent<'_>) + Send + 'static,
177    {
178        let prompt_c = CString::new(prompt)
179            .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
180        let opts = options.to_ffi();
181
182        // The callback may be invoked many times before completion. We pair
183        // the user closure with a oneshot channel that signals "stream
184        // finished" so this function can block until the Swift Task ends.
185        let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
186        let state = Arc::new(StreamState {
187            on_chunk: Mutex::new(Box::new(on_chunk)),
188            done_tx: Mutex::new(Some(done_tx)),
189        });
190        let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
191
192        unsafe {
193            ffi::fm_session_stream_response(
194                self.ptr,
195                prompt_c.as_ptr(),
196                opts.temperature,
197                opts.maximum_response_tokens,
198                opts.sampling_mode,
199                opts.top_k,
200                opts.top_p,
201                context,
202                stream_trampoline,
203            );
204        }
205
206        done_rx.recv().map_err(|_| FMError::Unknown {
207            code: ffi::status::UNKNOWN,
208            message: "Swift bridge dropped the stream channel".into(),
209        })?
210    }
211}
212
213impl Default for LanguageModelSession {
214    fn default() -> Self {
215        Self::new()
216    }
217}
218
219impl Drop for LanguageModelSession {
220    fn drop(&mut self) {
221        if !self.ptr.is_null() {
222            unsafe { ffi::fm_object_release(self.ptr) };
223        }
224    }
225}
226
227impl core::fmt::Debug for LanguageModelSession {
228    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
229        f.debug_struct("LanguageModelSession")
230            .field("ptr", &self.ptr)
231            .finish()
232    }
233}
234
235/// One event from a streaming generation.
236#[derive(Debug)]
237#[non_exhaustive]
238pub enum StreamEvent<'a> {
239    /// Incremental text delta. Concatenate these to reconstruct the full reply.
240    Chunk(&'a str),
241    /// Stream finished successfully.
242    Done,
243    /// Stream failed; the inner error describes why.
244    Error(FMError),
245}
246
247// ---------- internal callback plumbing ----------
248
249unsafe extern "C" fn respond_trampoline(
250    context: *mut c_void,
251    response: *mut c_char,
252    error: *mut c_char,
253    status: i32,
254) {
255    let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
256    let result = if status == ffi::status::OK && !response.is_null() {
257        let s = core::ffi::CStr::from_ptr(response)
258            .to_string_lossy()
259            .into_owned();
260        ffi::fm_string_free(response);
261        Ok(s)
262    } else {
263        Err(crate::error::from_swift(status, error))
264    };
265    let _ = tx.send(result);
266}
267
268type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;
269
270struct StreamState {
271    on_chunk: Mutex<StreamCallback>,
272    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
273}
274
275unsafe extern "C" fn stream_trampoline(
276    context: *mut c_void,
277    chunk: *mut c_char,
278    done: bool,
279    status: i32,
280) {
281    let state = Arc::from_raw(context.cast::<StreamState>());
282    // Bump the count back up because Swift may invoke us again before
283    // `done == true` (Arc::from_raw consumed our refcount).
284    let state_for_swift = state.clone();
285    core::mem::forget(state_for_swift);
286
287    let chunk_str: Option<String> = if chunk.is_null() {
288        None
289    } else {
290        let s = core::ffi::CStr::from_ptr(chunk)
291            .to_string_lossy()
292            .into_owned();
293        ffi::fm_string_free(chunk);
294        Some(s)
295    };
296
297    if status != ffi::status::OK {
298        let err = crate::error::from_swift(status, ptr::null_mut());
299        let err_for_callback = chunk_str
300            .map(|m| match err.clone() {
301                FMError::Unknown { code, .. } => FMError::Unknown { code, message: m },
302                other => other,
303            })
304            .unwrap_or(err);
305        let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
306        cb(StreamEvent::Error(err_for_callback.clone()));
307        drop(cb);
308        let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
309        if let Some(tx) = pending_tx {
310            let _ = tx.send(Err(err_for_callback));
311        }
312        // This was the final invocation: drop the extra ref we forgot above.
313        drop(Arc::from_raw(Arc::as_ptr(&state)));
314        drop(state);
315        return;
316    }
317
318    if let Some(s) = chunk_str.as_deref() {
319        let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
320        cb(StreamEvent::Chunk(s));
321    }
322
323    if done {
324        let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
325        cb(StreamEvent::Done);
326        drop(cb);
327        let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
328        if let Some(tx) = pending_tx {
329            let _ = tx.send(Ok(()));
330        }
331        // Final invocation: release the extra ref we forgot above.
332        drop(Arc::from_raw(Arc::as_ptr(&state)));
333    }
334    drop(state);
335}