Skip to main content

foundation_models/session/
mod.rs

1//! [`LanguageModelSession`] — a stateful conversation with the on-device model.
2
3use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::sync::mpsc;
7use std::sync::{Arc, Mutex};
8
9use crate::error::FMError;
10use crate::ffi;
11use crate::generation::GenerationOptions;
12
/// Handle to one conversation with the on-device language model.
///
/// A session keeps its conversation history, so every call to
/// [`respond`](Self::respond) continues from the turns that came before it.
///
/// # Examples
///
/// ```rust,no_run
/// use foundation_models::LanguageModelSession;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let session = LanguageModelSession::new();
/// let answer = session.respond("Name three Norse gods.")?;
/// println!("{answer}");
/// # Ok(())
/// # }
/// ```
pub struct LanguageModelSession {
    /// Opaque handle to the Swift-side session. Rust never dereferences it;
    /// it is only passed back to the `ffi::fm_session_*` functions and
    /// released in `Drop`.
    ptr: *mut c_void,
}
33
34// SAFETY: The underlying Swift LanguageModelSession is reference-counted via
35// Unmanaged.passRetained on the Swift side; sending the opaque pointer between
36// threads is safe as long as we don't dereference it from Rust (we never do —
37// it only travels through extern "C" calls that internally hop to the
38// Swift concurrency executor).
39unsafe impl Send for LanguageModelSession {}
40unsafe impl Sync for LanguageModelSession {}
41
42impl LanguageModelSession {
43    /// Create a session with the model's default behaviour.
44    ///
45    /// # Panics
46    ///
47    /// Panics if `FoundationModels` is not available on this OS. Check
48    /// [`crate::SystemLanguageModel::is_available`] first if you need to
49    /// handle that gracefully.
50    #[must_use]
51    pub fn new() -> Self {
52        Self::try_new(None).expect("FoundationModels is not available on this OS")
53    }
54
55    /// Create a session with custom system instructions ("system prompt").
56    ///
57    /// # Panics
58    ///
59    /// Panics if `FoundationModels` is not available, or if `instructions`
60    /// contains an interior NUL byte.
61    #[must_use]
62    pub fn with_instructions(instructions: &str) -> Self {
63        Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
64    }
65
66    /// Fallible constructor. Returns `None` when `FoundationModels` is not
67    /// available (OS too old, model not enabled, etc.) or when `instructions`
68    /// contains an interior NUL byte.
69    #[must_use]
70    pub fn try_new(instructions: Option<&str>) -> Option<Self> {
71        let cstring = match instructions {
72            Some(s) => Some(CString::new(s).ok()?),
73            None => None,
74        };
75        let ptr =
76            unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
77        if ptr.is_null() {
78            return None;
79        }
80        Some(Self { ptr })
81    }
82
83    /// Send a prompt and block until the full response is available.
84    ///
85    /// # Errors
86    ///
87    /// Returns an [`FMError`] if the model rejects the prompt, the context
88    /// window is exceeded, the session is cancelled, or the prompt contains
89    /// an interior NUL byte.
90    pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
91        self.respond_with(prompt, GenerationOptions::new())
92    }
93
94    /// Pre-warm the model. Apple loads the weights + initialises the
95    /// inference engine so the next `respond` call is faster. Returns
96    /// immediately; the warm-up runs in the background.
97    pub fn prewarm(&self) {
98        unsafe { ffi::fm_session_prewarm(self.ptr) };
99    }
100
101    /// True if this session is currently producing a response (i.e. an
102    /// earlier `respond` / `stream` is still in flight on Apple's queue).
103    #[must_use]
104    pub fn is_responding(&self) -> bool {
105        unsafe { ffi::fm_session_is_responding(self.ptr) }
106    }
107
108    /// Return a best-effort JSON serialisation of the session's
109    /// `Transcript` — the full history of user prompts and model
110    /// responses. Useful for persisting a chat session across
111    /// process boundaries.
112    #[must_use]
113    pub fn transcript_json(&self) -> String {
114        let p = unsafe { ffi::fm_session_transcript_json(self.ptr) };
115        if p.is_null() {
116            return String::from("{}");
117        }
118        let s = unsafe { core::ffi::CStr::from_ptr(p) }
119            .to_string_lossy()
120            .into_owned();
121        unsafe { ffi::fm_string_free(p) };
122        s
123    }
124
125    /// Log feedback on the most recent response for diagnostic /
126    /// fine-tuning purposes. `sentiment`:
127    /// `1` positive, `0` neutral, `-1` negative.
128    pub fn log_feedback(&self, sentiment: i32, description: Option<&str>) {
129        let cstr = description.and_then(|s| CString::new(s).ok());
130        let p = cstr.as_ref().map_or(core::ptr::null(), |c| c.as_ptr());
131        unsafe { ffi::fm_session_log_feedback(self.ptr, sentiment, p) };
132    }
133
134    /// Prompt-engineered JSON-shape response.
135    ///
136    /// Wraps the prompt with a "respond with valid JSON matching this schema"
137    /// instruction and parses the response. The schema is a
138    /// `serde_json::Value`-style JSON string (passed as text).
139    ///
140    /// Useful for getting structured data out of the model without the
141    /// full Generable macro machinery. The model still returns plain
142    /// text — the caller must parse with `serde_json` / `serde` after.
143    ///
144    /// # Errors
145    ///
146    /// See [`respond`](Self::respond).
147    pub fn respond_with_json_schema(
148        &self,
149        prompt: &str,
150        schema_description: &str,
151    ) -> Result<String, FMError> {
152        let wrapped = format!(
153            "{prompt}\n\n\
154             IMPORTANT: respond with VALID JSON ONLY (no prose, no markdown \
155             fences) that matches this schema:\n\n{schema_description}\n\n\
156             Your entire response must be parseable by JSON.parse()."
157        );
158        self.respond(&wrapped)
159    }
160
161    /// Like [`respond`](Self::respond), but with explicit generation options.
162    ///
163    /// # Errors
164    ///
165    /// See [`respond`](Self::respond).
166    pub fn respond_with(
167        &self,
168        prompt: &str,
169        options: GenerationOptions,
170    ) -> Result<String, FMError> {
171        let prompt_c = CString::new(prompt)
172            .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
173        let opts = options.to_ffi();
174        let (tx, rx) = mpsc::channel();
175        let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
176        let context = Box::into_raw(tx_box).cast::<c_void>();
177
178        unsafe {
179            ffi::fm_session_respond(
180                self.ptr,
181                prompt_c.as_ptr(),
182                opts.temperature,
183                opts.maximum_response_tokens,
184                opts.sampling_mode,
185                opts.top_k,
186                opts.top_p,
187                context,
188                respond_trampoline,
189            );
190        }
191
192        // The Swift side dispatches the callback on its own Task executor;
193        // it is guaranteed to fire exactly once.
194        rx.recv().map_err(|_| FMError::Unknown {
195            code: ffi::status::UNKNOWN,
196            message: "Swift bridge dropped the callback channel".into(),
197        })?
198    }
199
200    /// Schema-driven structured response.
201    ///
202    /// Builds a `DynamicGenerationSchema` from the provided JSON
203    /// schema, runs `LanguageModelSession.respond(schema:prompt:)`,
204    /// and returns the model's `GeneratedContent.jsonString` — a
205    /// well-formed JSON string matching the requested shape.
206    ///
207    /// Supported `schema` shape (strict subset of JSON Schema):
208    ///
209    /// ```json
210    /// {
211    ///   "type": "object",
212    ///   "name": "Movie",
213    ///   "properties": {
214    ///     "title":  { "type": "string", "description": "Movie title" },
215    ///     "year":   { "type": "integer" },
216    ///     "rating": { "type": "number", "optional": true },
217    ///     "tags":   { "type": "array", "items": { "type": "string" }, "min": 1, "max": 5 }
218    ///   }
219    /// }
220    /// ```
221    ///
222    /// Primitive types: `"string"`, `"integer"`, `"number"`,
223    /// `"boolean"`, `"array"`, `"object"`. Each property may set
224    /// `"description"` and `"optional"`. Array schemas accept
225    /// `"items"` plus optional `"min"` / `"max"` element counts.
226    ///
227    /// # Errors
228    ///
229    /// See [`respond`](Self::respond) for general errors, plus a
230    /// "schema build failed" / "schema JSON is not valid" error
231    /// returned as [`FMError::Unknown`] if the schema is malformed.
232    pub fn respond_with_schema(
233        &self,
234        prompt: &str,
235        schema: &str,
236        include_schema_in_prompt: bool,
237    ) -> Result<String, FMError> {
238        self.respond_with_schema_options(prompt, schema, include_schema_in_prompt, GenerationOptions::new())
239    }
240
241    /// [`respond_with_schema`](Self::respond_with_schema) with
242    /// explicit generation options.
243    ///
244    /// # Errors
245    ///
246    /// See [`respond_with_schema`](Self::respond_with_schema).
247    pub fn respond_with_schema_options(
248        &self,
249        prompt: &str,
250        schema: &str,
251        include_schema_in_prompt: bool,
252        options: GenerationOptions,
253    ) -> Result<String, FMError> {
254        let prompt_c = CString::new(prompt)
255            .map_err(|e| FMError::InvalidArgument(format!("prompt NUL byte: {e}")))?;
256        let schema_c = CString::new(schema)
257            .map_err(|e| FMError::InvalidArgument(format!("schema NUL byte: {e}")))?;
258        let opts = options.to_ffi();
259        let (tx, rx) = mpsc::channel();
260        let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
261        let context = Box::into_raw(tx_box).cast::<c_void>();
262
263        unsafe {
264            ffi::fm_session_respond_with_schema(
265                self.ptr,
266                prompt_c.as_ptr(),
267                schema_c.as_ptr(),
268                include_schema_in_prompt,
269                opts.temperature,
270                opts.maximum_response_tokens,
271                opts.sampling_mode,
272                opts.top_k,
273                opts.top_p,
274                context,
275                respond_trampoline,
276            );
277        }
278
279        rx.recv().map_err(|_| FMError::Unknown {
280            code: ffi::status::UNKNOWN,
281            message: "Swift bridge dropped the callback channel".into(),
282        })?
283    }
284
285    /// Stream the response as the model generates it. The callback is invoked
286    /// with each delta and a final invocation with `done == true`.
287    ///
288    /// # Errors
289    ///
290    /// Returns an [`FMError`] mirroring [`respond`](Self::respond). The
291    /// callback may also receive a chunk *and* an error if the stream fails
292    /// midway.
293    pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
294    where
295        F: FnMut(StreamEvent<'_>) + Send + 'static,
296    {
297        self.stream_with(prompt, GenerationOptions::new(), move |event| {
298            on_chunk(event);
299        })
300    }
301
302    /// Like [`stream`](Self::stream), but with explicit generation options.
303    ///
304    /// # Errors
305    ///
306    /// See [`stream`](Self::stream).
307    pub fn stream_with<F>(
308        &self,
309        prompt: &str,
310        options: GenerationOptions,
311        on_chunk: F,
312    ) -> Result<(), FMError>
313    where
314        F: FnMut(StreamEvent<'_>) + Send + 'static,
315    {
316        let prompt_c = CString::new(prompt)
317            .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
318        let opts = options.to_ffi();
319
320        // The callback may be invoked many times before completion. We pair
321        // the user closure with a oneshot channel that signals "stream
322        // finished" so this function can block until the Swift Task ends.
323        let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
324        let state = Arc::new(StreamState {
325            on_chunk: Mutex::new(Box::new(on_chunk)),
326            done_tx: Mutex::new(Some(done_tx)),
327        });
328        let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
329
330        unsafe {
331            ffi::fm_session_stream_response(
332                self.ptr,
333                prompt_c.as_ptr(),
334                opts.temperature,
335                opts.maximum_response_tokens,
336                opts.sampling_mode,
337                opts.top_k,
338                opts.top_p,
339                context,
340                stream_trampoline,
341            );
342        }
343
344        done_rx.recv().map_err(|_| FMError::Unknown {
345            code: ffi::status::UNKNOWN,
346            message: "Swift bridge dropped the stream channel".into(),
347        })?
348    }
349}
350
351impl Default for LanguageModelSession {
352    fn default() -> Self {
353        Self::new()
354    }
355}
356
357impl Drop for LanguageModelSession {
358    fn drop(&mut self) {
359        if !self.ptr.is_null() {
360            unsafe { ffi::fm_object_release(self.ptr) };
361        }
362    }
363}
364
365impl core::fmt::Debug for LanguageModelSession {
366    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
367        f.debug_struct("LanguageModelSession")
368            .field("ptr", &self.ptr)
369            .finish()
370    }
371}
372
373/// One event from a streaming generation.
374#[derive(Debug)]
375#[non_exhaustive]
376pub enum StreamEvent<'a> {
377    /// Incremental text delta. Concatenate these to reconstruct the full reply.
378    Chunk(&'a str),
379    /// Stream finished successfully.
380    Done,
381    /// Stream failed; the inner error describes why.
382    Error(FMError),
383}
384
385// ---------- internal callback plumbing ----------
386
387unsafe extern "C" fn respond_trampoline(
388    context: *mut c_void,
389    response: *mut c_char,
390    error: *mut c_char,
391    status: i32,
392) {
393    let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
394    let result = if status == ffi::status::OK && !response.is_null() {
395        let s = core::ffi::CStr::from_ptr(response)
396            .to_string_lossy()
397            .into_owned();
398        ffi::fm_string_free(response);
399        Ok(s)
400    } else {
401        Err(crate::error::from_swift(status, error))
402    };
403    let _ = tx.send(result);
404}
405
406type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;
407
408struct StreamState {
409    on_chunk: Mutex<StreamCallback>,
410    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
411}
412
413unsafe extern "C" fn stream_trampoline(
414    context: *mut c_void,
415    chunk: *mut c_char,
416    done: bool,
417    status: i32,
418) {
419    let state = Arc::from_raw(context.cast::<StreamState>());
420    // Bump the count back up because Swift may invoke us again before
421    // `done == true` (Arc::from_raw consumed our refcount).
422    let state_for_swift = state.clone();
423    core::mem::forget(state_for_swift);
424
425    let chunk_str: Option<String> = if chunk.is_null() {
426        None
427    } else {
428        let s = core::ffi::CStr::from_ptr(chunk)
429            .to_string_lossy()
430            .into_owned();
431        ffi::fm_string_free(chunk);
432        Some(s)
433    };
434
435    if status != ffi::status::OK {
436        let err = crate::error::from_swift(status, ptr::null_mut());
437        let err_for_callback = chunk_str
438            .map(|m| match err.clone() {
439                FMError::Unknown { code, .. } => FMError::Unknown { code, message: m },
440                other => other,
441            })
442            .unwrap_or(err);
443        let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
444        cb(StreamEvent::Error(err_for_callback.clone()));
445        drop(cb);
446        let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
447        if let Some(tx) = pending_tx {
448            let _ = tx.send(Err(err_for_callback));
449        }
450        // This was the final invocation: drop the extra ref we forgot above.
451        drop(Arc::from_raw(Arc::as_ptr(&state)));
452        drop(state);
453        return;
454    }
455
456    if let Some(s) = chunk_str.as_deref() {
457        let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
458        cb(StreamEvent::Chunk(s));
459    }
460
461    if done {
462        let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
463        cb(StreamEvent::Done);
464        drop(cb);
465        let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
466        if let Some(tx) = pending_tx {
467            let _ = tx.send(Ok(()));
468        }
469        // Final invocation: release the extra ref we forgot above.
470        drop(Arc::from_raw(Arc::as_ptr(&state)));
471    }
472    drop(state);
473}