Skip to main content

foundation_models/session/
mod.rs

1//! [`LanguageModelSession`] — a stateful conversation with the on-device model.
2
3use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::sync::mpsc;
7use std::sync::{Arc, Mutex};
8
9use crate::error::FMError;
10use crate::ffi;
11use crate::generation::GenerationOptions;
12
/// A conversation with the on-device language model that remembers its
/// earlier turns.
///
/// Each call to [`respond`](Self::respond) is answered in the context of every
/// prompt and reply that came before it on the same session.
///
/// # Examples
///
/// ```rust,no_run
/// use foundation_models::LanguageModelSession;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let session = LanguageModelSession::new();
/// let reply = session.respond("Name three Norse gods.")?;
/// println!("{reply}");
/// # Ok(())
/// # }
/// ```
pub struct LanguageModelSession {
    /// Opaque, retained handle to the Swift-side session. Only ever passed
    /// back across the FFI boundary; never dereferenced from Rust.
    ptr: *mut c_void,
}
33
34// SAFETY: The underlying Swift LanguageModelSession is reference-counted via
35// Unmanaged.passRetained on the Swift side; sending the opaque pointer between
36// threads is safe as long as we don't dereference it from Rust (we never do —
37// it only travels through extern "C" calls that internally hop to the
38// Swift concurrency executor).
39unsafe impl Send for LanguageModelSession {}
40unsafe impl Sync for LanguageModelSession {}
41
42impl LanguageModelSession {
43    /// Create a session with the model's default behaviour.
44    ///
45    /// # Panics
46    ///
47    /// Panics if `FoundationModels` is not available on this OS. Check
48    /// [`crate::SystemLanguageModel::is_available`] first if you need to
49    /// handle that gracefully.
50    #[must_use]
51    pub fn new() -> Self {
52        Self::try_new(None).expect("FoundationModels is not available on this OS")
53    }
54
55    /// Create a session with custom system instructions ("system prompt").
56    ///
57    /// # Panics
58    ///
59    /// Panics if `FoundationModels` is not available, or if `instructions`
60    /// contains an interior NUL byte.
61    #[must_use]
62    pub fn with_instructions(instructions: &str) -> Self {
63        Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
64    }
65
66    /// Fallible constructor. Returns `None` when `FoundationModels` is not
67    /// available (OS too old, model not enabled, etc.) or when `instructions`
68    /// contains an interior NUL byte.
69    #[must_use]
70    pub fn try_new(instructions: Option<&str>) -> Option<Self> {
71        let cstring = match instructions {
72            Some(s) => Some(CString::new(s).ok()?),
73            None => None,
74        };
75        let ptr =
76            unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
77        if ptr.is_null() {
78            return None;
79        }
80        Some(Self { ptr })
81    }
82
83    /// Send a prompt and block until the full response is available.
84    ///
85    /// # Errors
86    ///
87    /// Returns an [`FMError`] if the model rejects the prompt, the context
88    /// window is exceeded, the session is cancelled, or the prompt contains
89    /// an interior NUL byte.
90    pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
91        self.respond_with(prompt, GenerationOptions::new())
92    }
93
94    /// Pre-warm the model. Apple loads the weights + initialises the
95    /// inference engine so the next `respond` call is faster. Returns
96    /// immediately; the warm-up runs in the background.
97    pub fn prewarm(&self) {
98        unsafe { ffi::fm_session_prewarm(self.ptr) };
99    }
100
101    /// True if this session is currently producing a response (i.e. an
102    /// earlier `respond` / `stream` is still in flight on Apple's queue).
103    #[must_use]
104    pub fn is_responding(&self) -> bool {
105        unsafe { ffi::fm_session_is_responding(self.ptr) }
106    }
107
108    /// Prompt-engineered JSON-shape response.
109    ///
110    /// Wraps the prompt with a "respond with valid JSON matching this schema"
111    /// instruction and parses the response. The schema is a
112    /// `serde_json::Value`-style JSON string (passed as text).
113    ///
114    /// Useful for getting structured data out of the model without the
115    /// full Generable macro machinery. The model still returns plain
116    /// text — the caller must parse with `serde_json` / `serde` after.
117    ///
118    /// # Errors
119    ///
120    /// See [`respond`](Self::respond).
121    pub fn respond_with_json_schema(
122        &self,
123        prompt: &str,
124        schema_description: &str,
125    ) -> Result<String, FMError> {
126        let wrapped = format!(
127            "{prompt}\n\n\
128             IMPORTANT: respond with VALID JSON ONLY (no prose, no markdown \
129             fences) that matches this schema:\n\n{schema_description}\n\n\
130             Your entire response must be parseable by JSON.parse()."
131        );
132        self.respond(&wrapped)
133    }
134
135    /// Like [`respond`](Self::respond), but with explicit generation options.
136    ///
137    /// # Errors
138    ///
139    /// See [`respond`](Self::respond).
140    pub fn respond_with(
141        &self,
142        prompt: &str,
143        options: GenerationOptions,
144    ) -> Result<String, FMError> {
145        let prompt_c = CString::new(prompt)
146            .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
147        let opts = options.to_ffi();
148        let (tx, rx) = mpsc::channel();
149        let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
150        let context = Box::into_raw(tx_box).cast::<c_void>();
151
152        unsafe {
153            ffi::fm_session_respond(
154                self.ptr,
155                prompt_c.as_ptr(),
156                opts.temperature,
157                opts.maximum_response_tokens,
158                opts.sampling_mode,
159                opts.top_k,
160                opts.top_p,
161                context,
162                respond_trampoline,
163            );
164        }
165
166        // The Swift side dispatches the callback on its own Task executor;
167        // it is guaranteed to fire exactly once.
168        rx.recv().map_err(|_| FMError::Unknown {
169            code: ffi::status::UNKNOWN,
170            message: "Swift bridge dropped the callback channel".into(),
171        })?
172    }
173
174    /// Schema-driven structured response.
175    ///
176    /// Builds a `DynamicGenerationSchema` from the provided JSON
177    /// schema, runs `LanguageModelSession.respond(schema:prompt:)`,
178    /// and returns the model's `GeneratedContent.jsonString` — a
179    /// well-formed JSON string matching the requested shape.
180    ///
181    /// Supported `schema` shape (strict subset of JSON Schema):
182    ///
183    /// ```json
184    /// {
185    ///   "type": "object",
186    ///   "name": "Movie",
187    ///   "properties": {
188    ///     "title":  { "type": "string", "description": "Movie title" },
189    ///     "year":   { "type": "integer" },
190    ///     "rating": { "type": "number", "optional": true },
191    ///     "tags":   { "type": "array", "items": { "type": "string" }, "min": 1, "max": 5 }
192    ///   }
193    /// }
194    /// ```
195    ///
196    /// Primitive types: `"string"`, `"integer"`, `"number"`,
197    /// `"boolean"`, `"array"`, `"object"`. Each property may set
198    /// `"description"` and `"optional"`. Array schemas accept
199    /// `"items"` plus optional `"min"` / `"max"` element counts.
200    ///
201    /// # Errors
202    ///
203    /// See [`respond`](Self::respond) for general errors, plus a
204    /// "schema build failed" / "schema JSON is not valid" error
205    /// returned as [`FMError::Unknown`] if the schema is malformed.
206    pub fn respond_with_schema(
207        &self,
208        prompt: &str,
209        schema: &str,
210        include_schema_in_prompt: bool,
211    ) -> Result<String, FMError> {
212        self.respond_with_schema_options(prompt, schema, include_schema_in_prompt, GenerationOptions::new())
213    }
214
215    /// [`respond_with_schema`](Self::respond_with_schema) with
216    /// explicit generation options.
217    ///
218    /// # Errors
219    ///
220    /// See [`respond_with_schema`](Self::respond_with_schema).
221    pub fn respond_with_schema_options(
222        &self,
223        prompt: &str,
224        schema: &str,
225        include_schema_in_prompt: bool,
226        options: GenerationOptions,
227    ) -> Result<String, FMError> {
228        let prompt_c = CString::new(prompt)
229            .map_err(|e| FMError::InvalidArgument(format!("prompt NUL byte: {e}")))?;
230        let schema_c = CString::new(schema)
231            .map_err(|e| FMError::InvalidArgument(format!("schema NUL byte: {e}")))?;
232        let opts = options.to_ffi();
233        let (tx, rx) = mpsc::channel();
234        let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
235        let context = Box::into_raw(tx_box).cast::<c_void>();
236
237        unsafe {
238            ffi::fm_session_respond_with_schema(
239                self.ptr,
240                prompt_c.as_ptr(),
241                schema_c.as_ptr(),
242                include_schema_in_prompt,
243                opts.temperature,
244                opts.maximum_response_tokens,
245                opts.sampling_mode,
246                opts.top_k,
247                opts.top_p,
248                context,
249                respond_trampoline,
250            );
251        }
252
253        rx.recv().map_err(|_| FMError::Unknown {
254            code: ffi::status::UNKNOWN,
255            message: "Swift bridge dropped the callback channel".into(),
256        })?
257    }
258
259    /// Stream the response as the model generates it. The callback is invoked
260    /// with each delta and a final invocation with `done == true`.
261    ///
262    /// # Errors
263    ///
264    /// Returns an [`FMError`] mirroring [`respond`](Self::respond). The
265    /// callback may also receive a chunk *and* an error if the stream fails
266    /// midway.
267    pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
268    where
269        F: FnMut(StreamEvent<'_>) + Send + 'static,
270    {
271        self.stream_with(prompt, GenerationOptions::new(), move |event| {
272            on_chunk(event);
273        })
274    }
275
276    /// Like [`stream`](Self::stream), but with explicit generation options.
277    ///
278    /// # Errors
279    ///
280    /// See [`stream`](Self::stream).
281    pub fn stream_with<F>(
282        &self,
283        prompt: &str,
284        options: GenerationOptions,
285        on_chunk: F,
286    ) -> Result<(), FMError>
287    where
288        F: FnMut(StreamEvent<'_>) + Send + 'static,
289    {
290        let prompt_c = CString::new(prompt)
291            .map_err(|e| FMError::InvalidArgument(format!("prompt contains NUL byte: {e}")))?;
292        let opts = options.to_ffi();
293
294        // The callback may be invoked many times before completion. We pair
295        // the user closure with a oneshot channel that signals "stream
296        // finished" so this function can block until the Swift Task ends.
297        let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
298        let state = Arc::new(StreamState {
299            on_chunk: Mutex::new(Box::new(on_chunk)),
300            done_tx: Mutex::new(Some(done_tx)),
301        });
302        let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
303
304        unsafe {
305            ffi::fm_session_stream_response(
306                self.ptr,
307                prompt_c.as_ptr(),
308                opts.temperature,
309                opts.maximum_response_tokens,
310                opts.sampling_mode,
311                opts.top_k,
312                opts.top_p,
313                context,
314                stream_trampoline,
315            );
316        }
317
318        done_rx.recv().map_err(|_| FMError::Unknown {
319            code: ffi::status::UNKNOWN,
320            message: "Swift bridge dropped the stream channel".into(),
321        })?
322    }
323}
324
325impl Default for LanguageModelSession {
326    fn default() -> Self {
327        Self::new()
328    }
329}
330
331impl Drop for LanguageModelSession {
332    fn drop(&mut self) {
333        if !self.ptr.is_null() {
334            unsafe { ffi::fm_object_release(self.ptr) };
335        }
336    }
337}
338
339impl core::fmt::Debug for LanguageModelSession {
340    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
341        f.debug_struct("LanguageModelSession")
342            .field("ptr", &self.ptr)
343            .finish()
344    }
345}
346
347/// One event from a streaming generation.
348#[derive(Debug)]
349#[non_exhaustive]
350pub enum StreamEvent<'a> {
351    /// Incremental text delta. Concatenate these to reconstruct the full reply.
352    Chunk(&'a str),
353    /// Stream finished successfully.
354    Done,
355    /// Stream failed; the inner error describes why.
356    Error(FMError),
357}
358
359// ---------- internal callback plumbing ----------
360
361unsafe extern "C" fn respond_trampoline(
362    context: *mut c_void,
363    response: *mut c_char,
364    error: *mut c_char,
365    status: i32,
366) {
367    let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
368    let result = if status == ffi::status::OK && !response.is_null() {
369        let s = core::ffi::CStr::from_ptr(response)
370            .to_string_lossy()
371            .into_owned();
372        ffi::fm_string_free(response);
373        Ok(s)
374    } else {
375        Err(crate::error::from_swift(status, error))
376    };
377    let _ = tx.send(result);
378}
379
380type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;
381
382struct StreamState {
383    on_chunk: Mutex<StreamCallback>,
384    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
385}
386
387unsafe extern "C" fn stream_trampoline(
388    context: *mut c_void,
389    chunk: *mut c_char,
390    done: bool,
391    status: i32,
392) {
393    let state = Arc::from_raw(context.cast::<StreamState>());
394    // Bump the count back up because Swift may invoke us again before
395    // `done == true` (Arc::from_raw consumed our refcount).
396    let state_for_swift = state.clone();
397    core::mem::forget(state_for_swift);
398
399    let chunk_str: Option<String> = if chunk.is_null() {
400        None
401    } else {
402        let s = core::ffi::CStr::from_ptr(chunk)
403            .to_string_lossy()
404            .into_owned();
405        ffi::fm_string_free(chunk);
406        Some(s)
407    };
408
409    if status != ffi::status::OK {
410        let err = crate::error::from_swift(status, ptr::null_mut());
411        let err_for_callback = chunk_str
412            .map(|m| match err.clone() {
413                FMError::Unknown { code, .. } => FMError::Unknown { code, message: m },
414                other => other,
415            })
416            .unwrap_or(err);
417        let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
418        cb(StreamEvent::Error(err_for_callback.clone()));
419        drop(cb);
420        let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
421        if let Some(tx) = pending_tx {
422            let _ = tx.send(Err(err_for_callback));
423        }
424        // This was the final invocation: drop the extra ref we forgot above.
425        drop(Arc::from_raw(Arc::as_ptr(&state)));
426        drop(state);
427        return;
428    }
429
430    if let Some(s) = chunk_str.as_deref() {
431        let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
432        cb(StreamEvent::Chunk(s));
433    }
434
435    if done {
436        let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
437        cb(StreamEvent::Done);
438        drop(cb);
439        let pending_tx = state.done_tx.lock().expect("done_tx mutex poisoned").take();
440        if let Some(tx) = pending_tx {
441            let _ = tx.send(Ok(()));
442        }
443        // Final invocation: release the extra ref we forgot above.
444        drop(Arc::from_raw(Arc::as_ptr(&state)));
445    }
446    drop(state);
447}