Skip to main content

foundation_models/session/
mod.rs

1//! [`LanguageModelSession`] — a stateful conversation with the on-device model.
2
3use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::sync::mpsc;
7use std::sync::{Arc, Mutex};
8
9use crate::error::FMError;
10use crate::ffi;
11use crate::generation::GenerationOptions;
12
/// Handle to a multi-turn conversation with the on-device language model.
///
/// A session keeps its conversation history, so every call to
/// [`respond`](Self::respond) continues from the turns that came before it
/// rather than starting over.
///
/// # Examples
///
/// ```rust,no_run
/// use foundation_models::LanguageModelSession;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let session = LanguageModelSession::new();
/// let answer = session.respond("Name three Norse gods.")?;
/// println!("{answer}");
/// # Ok(())
/// # }
/// ```
pub struct LanguageModelSession {
    // Opaque handle obtained from `ffi::fm_session_create`. Rust never
    // dereferences it; it is only passed back through the FFI layer and is
    // released in `Drop`.
    ptr: *mut c_void,
}
33
34// SAFETY: The underlying Swift LanguageModelSession is reference-counted via
35// Unmanaged.passRetained on the Swift side; sending the opaque pointer between
36// threads is safe as long as we don't dereference it from Rust (we never do —
37// it only travels through extern "C" calls that internally hop to the
38// Swift concurrency executor).
39unsafe impl Send for LanguageModelSession {}
40unsafe impl Sync for LanguageModelSession {}
41
42impl LanguageModelSession {
43    /// Create a session with the model's default behaviour.
44    ///
45    /// # Panics
46    ///
47    /// Panics if `FoundationModels` is not available on this OS. Check
48    /// [`crate::SystemLanguageModel::is_available`] first if you need to
49    /// handle that gracefully.
50    #[must_use]
51    pub fn new() -> Self {
52        Self::try_new(None).expect("FoundationModels is not available on this OS")
53    }
54
55    /// Create a session with custom system instructions ("system prompt").
56    ///
57    /// # Panics
58    ///
59    /// Panics if `FoundationModels` is not available, or if `instructions`
60    /// contains an interior NUL byte.
61    #[must_use]
62    pub fn with_instructions(instructions: &str) -> Self {
63        Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
64    }
65
66    /// Fallible constructor. Returns `None` when `FoundationModels` is not
67    /// available (OS too old, model not enabled, etc.) or when `instructions`
68    /// contains an interior NUL byte.
69    #[must_use]
70    pub fn try_new(instructions: Option<&str>) -> Option<Self> {
71        let cstring = match instructions {
72            Some(s) => Some(CString::new(s).ok()?),
73            None => None,
74        };
75        let ptr =
76            unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
77        if ptr.is_null() {
78            return None;
79        }
80        Some(Self { ptr })
81    }
82
83    /// Send a prompt and block until the full response is available.
84    ///
85    /// # Errors
86    ///
87    /// Returns an [`FMError`] if the model rejects the prompt, the context
88    /// window is exceeded, the session is cancelled, or the prompt contains
89    /// an interior NUL byte.
90    pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
91        self.respond_with(prompt, GenerationOptions::new())
92    }
93
94    /// Like [`respond`](Self::respond), but with explicit generation options.
95    ///
96    /// # Errors
97    ///
98    /// See [`respond`](Self::respond).
99    pub fn respond_with(
100        &self,
101        prompt: &str,
102        options: GenerationOptions,
103    ) -> Result<String, FMError> {
104        let prompt_c = CString::new(prompt)
105            .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
106        let opts = options.to_ffi();
107        let (tx, rx) = mpsc::channel();
108        let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
109        let context = Box::into_raw(tx_box).cast::<c_void>();
110
111        unsafe {
112            ffi::fm_session_respond(
113                self.ptr,
114                prompt_c.as_ptr(),
115                opts.temperature,
116                opts.maximum_response_tokens,
117                opts.sampling_mode,
118                opts.top_k,
119                opts.top_p,
120                context,
121                respond_trampoline,
122            );
123        }
124
125        // The Swift side dispatches the callback on its own Task executor;
126        // it is guaranteed to fire exactly once.
127        rx.recv().map_err(|_| FMError::Unknown {
128            code: ffi::status::UNKNOWN,
129            message: "Swift bridge dropped the callback channel".into(),
130        })?
131    }
132
133    /// Stream the response as the model generates it. The callback is invoked
134    /// with each delta and a final invocation with `done == true`.
135    ///
136    /// # Errors
137    ///
138    /// Returns an [`FMError`] mirroring [`respond`](Self::respond). The
139    /// callback may also receive a chunk *and* an error if the stream fails
140    /// midway.
141    pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
142    where
143        F: FnMut(StreamEvent<'_>) + Send + 'static,
144    {
145        self.stream_with(prompt, GenerationOptions::new(), move |event| {
146            on_chunk(event);
147        })
148    }
149
150    /// Like [`stream`](Self::stream), but with explicit generation options.
151    ///
152    /// # Errors
153    ///
154    /// See [`stream`](Self::stream).
155    pub fn stream_with<F>(
156        &self,
157        prompt: &str,
158        options: GenerationOptions,
159        on_chunk: F,
160    ) -> Result<(), FMError>
161    where
162        F: FnMut(StreamEvent<'_>) + Send + 'static,
163    {
164        let prompt_c = CString::new(prompt)
165            .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
166        let opts = options.to_ffi();
167
168        // The callback may be invoked many times before completion. We pair
169        // the user closure with a oneshot channel that signals "stream
170        // finished" so this function can block until the Swift Task ends.
171        let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
172        let state = Arc::new(StreamState {
173            on_chunk: Mutex::new(Box::new(on_chunk)),
174            done_tx: Mutex::new(Some(done_tx)),
175        });
176        let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
177
178        unsafe {
179            ffi::fm_session_stream_response(
180                self.ptr,
181                prompt_c.as_ptr(),
182                opts.temperature,
183                opts.maximum_response_tokens,
184                opts.sampling_mode,
185                opts.top_k,
186                opts.top_p,
187                context,
188                stream_trampoline,
189            );
190        }
191
192        done_rx.recv().map_err(|_| FMError::Unknown {
193            code: ffi::status::UNKNOWN,
194            message: "Swift bridge dropped the stream channel".into(),
195        })?
196    }
197}
198
199impl Default for LanguageModelSession {
200    fn default() -> Self {
201        Self::new()
202    }
203}
204
205impl Drop for LanguageModelSession {
206    fn drop(&mut self) {
207        if !self.ptr.is_null() {
208            unsafe { ffi::fm_object_release(self.ptr) };
209        }
210    }
211}
212
213impl core::fmt::Debug for LanguageModelSession {
214    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
215        f.debug_struct("LanguageModelSession")
216            .field("ptr", &self.ptr)
217            .finish()
218    }
219}
220
221/// One event from a streaming generation.
222#[derive(Debug)]
223#[non_exhaustive]
224pub enum StreamEvent<'a> {
225    /// Incremental text delta. Concatenate these to reconstruct the full reply.
226    Chunk(&'a str),
227    /// Stream finished successfully.
228    Done,
229    /// Stream failed; the inner error describes why.
230    Error(FMError),
231}
232
233// ---------- internal callback plumbing ----------
234
235unsafe extern "C" fn respond_trampoline(
236    context: *mut c_void,
237    response: *mut c_char,
238    error: *mut c_char,
239    status: i32,
240) {
241    let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
242    let result = if status == ffi::status::OK && !response.is_null() {
243        let s = core::ffi::CStr::from_ptr(response)
244            .to_string_lossy()
245            .into_owned();
246        ffi::fm_string_free(response);
247        Ok(s)
248    } else {
249        Err(crate::error::from_swift(status, error))
250    };
251    let _ = tx.send(result);
252}
253
254type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;
255
256struct StreamState {
257    on_chunk: Mutex<StreamCallback>,
258    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
259}
260
261unsafe extern "C" fn stream_trampoline(
262    context: *mut c_void,
263    chunk: *mut c_char,
264    done: bool,
265    status: i32,
266) {
267    let state = Arc::from_raw(context.cast::<StreamState>());
268    // Bump the count back up because Swift may invoke us again before
269    // `done == true` (Arc::from_raw consumed our refcount).
270    let state_for_swift = state.clone();
271    core::mem::forget(state_for_swift);
272
273    let chunk_str: Option<String> = if chunk.is_null() {
274        None
275    } else {
276        let s = core::ffi::CStr::from_ptr(chunk)
277            .to_string_lossy()
278            .into_owned();
279        ffi::fm_string_free(chunk);
280        Some(s)
281    };
282
283    if status != ffi::status::OK {
284        let err = crate::error::from_swift(status, ptr::null_mut());
285        let err_for_callback = chunk_str
286            .map(|m| match err.clone() {
287                FMError::Unknown { code, .. } => FMError::Unknown { code, message: m },
288                other => other,
289            })
290            .unwrap_or(err);
291        let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
292        cb(StreamEvent::Error(err_for_callback.clone()));
293        drop(cb);
294        let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
295        if let Some(tx) = pending_tx {
296            let _ = tx.send(Err(err_for_callback));
297        }
298        // This was the final invocation: drop the extra ref we forgot above.
299        drop(Arc::from_raw(Arc::as_ptr(&state)));
300        drop(state);
301        return;
302    }
303
304    if let Some(s) = chunk_str.as_deref() {
305        let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
306        cb(StreamEvent::Chunk(s));
307    }
308
309    if done {
310        let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
311        cb(StreamEvent::Done);
312        drop(cb);
313        let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
314        if let Some(tx) = pending_tx {
315            let _ = tx.send(Ok(()));
316        }
317        // Final invocation: release the extra ref we forgot above.
318        drop(Arc::from_raw(Arc::as_ptr(&state)));
319    }
320    drop(state);
321}