Skip to main content

foundation_models/session/
mod.rs

1//! [`LanguageModelSession`] — a stateful conversation with the on-device model.
2
3use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::sync::mpsc;
7use std::sync::{Arc, Mutex};
8
9use crate::error::FMError;
10use crate::ffi;
11use crate::generation::GenerationOptions;
12
/// A stateful conversation with the on-device language model.
///
/// Sessions retain their conversation history; subsequent calls to
/// [`respond`](Self::respond) build on the previous turns.
///
/// # Examples
///
/// ```rust,no_run
/// use foundation_models::LanguageModelSession;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let session = LanguageModelSession::new();
/// let answer = session.respond("Name three Norse gods.")?;
/// println!("{answer}");
/// # Ok(())
/// # }
/// ```
pub struct LanguageModelSession {
    // Opaque handle returned by `ffi::fm_session_create`. Rust only passes it
    // back through the `ffi` functions and hands it to
    // `ffi::fm_object_release` when the session is dropped.
    ptr: *mut c_void,
}
33
// SAFETY: The underlying Swift LanguageModelSession is reference-counted via
// Unmanaged.passRetained on the Swift side; sending the opaque pointer between
// threads is safe as long as we don't dereference it from Rust (we never do —
// it only travels through extern "C" calls that internally hop to the
// Swift concurrency executor).
//
// NOTE(review): `Sync` additionally allows several threads to call the
// `&self` methods on one session at the same time, so it also relies on the
// Swift bridge tolerating concurrent entry for a single handle — confirm
// against the bridge implementation.
unsafe impl Send for LanguageModelSession {}
unsafe impl Sync for LanguageModelSession {}
41
42impl LanguageModelSession {
43    /// Create a session with the model's default behaviour.
44    ///
45    /// # Panics
46    ///
47    /// Panics if `FoundationModels` is not available on this OS. Check
48    /// [`crate::SystemLanguageModel::is_available`] first if you need to
49    /// handle that gracefully.
50    #[must_use]
51    pub fn new() -> Self {
52        Self::try_new(None).expect("FoundationModels is not available on this OS")
53    }
54
55    /// Create a session with custom system instructions ("system prompt").
56    ///
57    /// # Panics
58    ///
59    /// Panics if `FoundationModels` is not available, or if `instructions`
60    /// contains an interior NUL byte.
61    #[must_use]
62    pub fn with_instructions(instructions: &str) -> Self {
63        Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
64    }
65
66    /// Fallible constructor. Returns `None` when `FoundationModels` is not
67    /// available (OS too old, model not enabled, etc.) or when `instructions`
68    /// contains an interior NUL byte.
69    #[must_use]
70    pub fn try_new(instructions: Option<&str>) -> Option<Self> {
71        let cstring = match instructions {
72            Some(s) => Some(CString::new(s).ok()?),
73            None => None,
74        };
75        let ptr =
76            unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
77        if ptr.is_null() {
78            return None;
79        }
80        Some(Self { ptr })
81    }
82
83    /// Send a prompt and block until the full response is available.
84    ///
85    /// # Errors
86    ///
87    /// Returns an [`FMError`] if the model rejects the prompt, the context
88    /// window is exceeded, the session is cancelled, or the prompt contains
89    /// an interior NUL byte.
90    pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
91        self.respond_with(prompt, GenerationOptions::new())
92    }
93
94    /// Pre-warm the model. Apple loads the weights + initialises the
95    /// inference engine so the next `respond` call is faster. Returns
96    /// immediately; the warm-up runs in the background.
97    pub fn prewarm(&self) {
98        unsafe { ffi::fm_session_prewarm(self.ptr) };
99    }
100
101    /// True if this session is currently producing a response (i.e. an
102    /// earlier `respond` / `stream` is still in flight on Apple's queue).
103    #[must_use]
104    pub fn is_responding(&self) -> bool {
105        unsafe { ffi::fm_session_is_responding(self.ptr) }
106    }
107
108    /// Prompt-engineered JSON-shape response.
109    ///
110    /// Wraps the prompt with a "respond with valid JSON matching this schema"
111    /// instruction and parses the response. The schema is a
112    /// `serde_json::Value`-style JSON string (passed as text).
113    ///
114    /// Useful for getting structured data out of the model without the
115    /// full Generable macro machinery. The model still returns plain
116    /// text — the caller must parse with `serde_json` / `serde` after.
117    ///
118    /// # Errors
119    ///
120    /// See [`respond`](Self::respond).
121    pub fn respond_with_json_schema(
122        &self,
123        prompt: &str,
124        schema_description: &str,
125    ) -> Result<String, FMError> {
126        let wrapped = format!(
127            "{prompt}\n\n\
128             IMPORTANT: respond with VALID JSON ONLY (no prose, no markdown \
129             fences) that matches this schema:\n\n{schema_description}\n\n\
130             Your entire response must be parseable by JSON.parse()."
131        );
132        self.respond(&wrapped)
133    }
134
135    /// Like [`respond`](Self::respond), but with explicit generation options.
136    ///
137    /// # Errors
138    ///
139    /// See [`respond`](Self::respond).
140    pub fn respond_with(
141        &self,
142        prompt: &str,
143        options: GenerationOptions,
144    ) -> Result<String, FMError> {
145        let prompt_c = CString::new(prompt)
146            .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
147        let opts = options.to_ffi();
148        let (tx, rx) = mpsc::channel();
149        let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
150        let context = Box::into_raw(tx_box).cast::<c_void>();
151
152        unsafe {
153            ffi::fm_session_respond(
154                self.ptr,
155                prompt_c.as_ptr(),
156                opts.temperature,
157                opts.maximum_response_tokens,
158                opts.sampling_mode,
159                opts.top_k,
160                opts.top_p,
161                context,
162                respond_trampoline,
163            );
164        }
165
166        // The Swift side dispatches the callback on its own Task executor;
167        // it is guaranteed to fire exactly once.
168        rx.recv().map_err(|_| FMError::Unknown {
169            code: ffi::status::UNKNOWN,
170            message: "Swift bridge dropped the callback channel".into(),
171        })?
172    }
173
174    /// Stream the response as the model generates it. The callback is invoked
175    /// with each delta and a final invocation with `done == true`.
176    ///
177    /// # Errors
178    ///
179    /// Returns an [`FMError`] mirroring [`respond`](Self::respond). The
180    /// callback may also receive a chunk *and* an error if the stream fails
181    /// midway.
182    pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
183    where
184        F: FnMut(StreamEvent<'_>) + Send + 'static,
185    {
186        self.stream_with(prompt, GenerationOptions::new(), move |event| {
187            on_chunk(event);
188        })
189    }
190
191    /// Like [`stream`](Self::stream), but with explicit generation options.
192    ///
193    /// # Errors
194    ///
195    /// See [`stream`](Self::stream).
196    pub fn stream_with<F>(
197        &self,
198        prompt: &str,
199        options: GenerationOptions,
200        on_chunk: F,
201    ) -> Result<(), FMError>
202    where
203        F: FnMut(StreamEvent<'_>) + Send + 'static,
204    {
205        let prompt_c = CString::new(prompt)
206            .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
207        let opts = options.to_ffi();
208
209        // The callback may be invoked many times before completion. We pair
210        // the user closure with a oneshot channel that signals "stream
211        // finished" so this function can block until the Swift Task ends.
212        let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
213        let state = Arc::new(StreamState {
214            on_chunk: Mutex::new(Box::new(on_chunk)),
215            done_tx: Mutex::new(Some(done_tx)),
216        });
217        let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
218
219        unsafe {
220            ffi::fm_session_stream_response(
221                self.ptr,
222                prompt_c.as_ptr(),
223                opts.temperature,
224                opts.maximum_response_tokens,
225                opts.sampling_mode,
226                opts.top_k,
227                opts.top_p,
228                context,
229                stream_trampoline,
230            );
231        }
232
233        done_rx.recv().map_err(|_| FMError::Unknown {
234            code: ffi::status::UNKNOWN,
235            message: "Swift bridge dropped the stream channel".into(),
236        })?
237    }
238}
239
impl Default for LanguageModelSession {
    /// Equivalent to [`LanguageModelSession::new`], including its panic when
    /// `FoundationModels` is not available on this OS.
    fn default() -> Self {
        Self::new()
    }
}
245
246impl Drop for LanguageModelSession {
247    fn drop(&mut self) {
248        if !self.ptr.is_null() {
249            unsafe { ffi::fm_object_release(self.ptr) };
250        }
251    }
252}
253
254impl core::fmt::Debug for LanguageModelSession {
255    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
256        f.debug_struct("LanguageModelSession")
257            .field("ptr", &self.ptr)
258            .finish()
259    }
260}
261
/// One event from a streaming generation.
///
/// The streaming callback receives zero or more [`Chunk`](Self::Chunk)
/// events and then exactly one terminal event: [`Done`](Self::Done) on
/// success or [`Error`](Self::Error) on failure.
#[derive(Debug)]
#[non_exhaustive]
pub enum StreamEvent<'a> {
    /// Incremental text delta. Concatenate these to reconstruct the full reply.
    /// The borrowed text is only valid for the duration of the callback.
    Chunk(&'a str),
    /// Stream finished successfully.
    Done,
    /// Stream failed; the inner error describes why.
    Error(FMError),
}
273
274// ---------- internal callback plumbing ----------
275
276unsafe extern "C" fn respond_trampoline(
277    context: *mut c_void,
278    response: *mut c_char,
279    error: *mut c_char,
280    status: i32,
281) {
282    let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
283    let result = if status == ffi::status::OK && !response.is_null() {
284        let s = core::ffi::CStr::from_ptr(response)
285            .to_string_lossy()
286            .into_owned();
287        ffi::fm_string_free(response);
288        Ok(s)
289    } else {
290        Err(crate::error::from_swift(status, error))
291    };
292    let _ = tx.send(result);
293}
294
/// Type-erased user callback invoked for every streaming event.
type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;

/// State shared between `stream_with` and `stream_trampoline` for one
/// streaming request; it crosses the FFI boundary as an `Arc` raw pointer.
struct StreamState {
    // User callback; locked for the duration of each event delivery.
    on_chunk: Mutex<StreamCallback>,
    // Completion signal for the thread blocked in `stream_with`. Taken
    // (left as `None`) when the terminal result is sent.
    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
}
301
302unsafe extern "C" fn stream_trampoline(
303    context: *mut c_void,
304    chunk: *mut c_char,
305    done: bool,
306    status: i32,
307) {
308    let state = Arc::from_raw(context.cast::<StreamState>());
309    // Bump the count back up because Swift may invoke us again before
310    // `done == true` (Arc::from_raw consumed our refcount).
311    let state_for_swift = state.clone();
312    core::mem::forget(state_for_swift);
313
314    let chunk_str: Option<String> = if chunk.is_null() {
315        None
316    } else {
317        let s = core::ffi::CStr::from_ptr(chunk)
318            .to_string_lossy()
319            .into_owned();
320        ffi::fm_string_free(chunk);
321        Some(s)
322    };
323
324    if status != ffi::status::OK {
325        let err = crate::error::from_swift(status, ptr::null_mut());
326        let err_for_callback = chunk_str
327            .map(|m| match err.clone() {
328                FMError::Unknown { code, .. } => FMError::Unknown { code, message: m },
329                other => other,
330            })
331            .unwrap_or(err);
332        let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
333        cb(StreamEvent::Error(err_for_callback.clone()));
334        drop(cb);
335        let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
336        if let Some(tx) = pending_tx {
337            let _ = tx.send(Err(err_for_callback));
338        }
339        // This was the final invocation: drop the extra ref we forgot above.
340        drop(Arc::from_raw(Arc::as_ptr(&state)));
341        drop(state);
342        return;
343    }
344
345    if let Some(s) = chunk_str.as_deref() {
346        let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
347        cb(StreamEvent::Chunk(s));
348    }
349
350    if done {
351        let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
352        cb(StreamEvent::Done);
353        drop(cb);
354        let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
355        if let Some(tx) = pending_tx {
356            let _ = tx.send(Ok(()));
357        }
358        // Final invocation: release the extra ref we forgot above.
359        drop(Arc::from_raw(Arc::as_ptr(&state)));
360    }
361    drop(state);
362}