Skip to main content

fm_bindings/
session.rs

1// Language Model Session - the main API for Foundation Models
2
3use crate::error::{Error, Result};
4use crate::ffi;
5use std::ffi::CString;
6use std::ptr::NonNull;
7use std::sync::{Arc, Condvar, Mutex};
8
/// A session for interacting with Apple's Foundation Models
///
/// This provides access to on-device language models via the FoundationModels framework.
/// Requires macOS 26+ or iOS 26+ with Apple Intelligence enabled.
///
/// The struct owns a single non-null opaque handle to the underlying session
/// object. The handle is released when the value is dropped.
///
/// # Session State
///
/// Each session maintains a transcript of all interactions (prompts, responses, etc.).
/// The transcript can be serialized to JSON for persistence and used to restore
/// sessions across app launches.
///
/// # Creating Sessions
///
/// - [`LanguageModelSession::new()`] - Create without instructions
/// - [`LanguageModelSession::with_instructions()`] - Create with system prompt
/// - [`LanguageModelSession::from_transcript_json()`] - Restore from saved transcript
///
/// # Getting Responses
///
/// - [`response()`](Self::response) - Blocking response (waits for completion)
/// - [`stream_response()`](Self::stream_response) - Streaming response (real-time chunks)
/// - [`cancel_stream()`](Self::cancel_stream) - Cancel ongoing stream
///
/// # Debug output
///
/// The `Debug` representation shows only the address of the opaque handle; it
/// never calls into the framework.
///
/// # Examples
///
/// See the method-level documentation for detailed examples:
/// - [`new()`](Self::new) and [`response()`](Self::response) for basic usage
/// - [`stream_response()`](Self::stream_response) for streaming
/// - [`transcript_json()`](Self::transcript_json) and [`from_transcript_json()`](Self::from_transcript_json) for persistence
#[derive(Debug)]
pub struct LanguageModelSession {
    /// Opaque, non-null handle to the native session object.
    ptr: NonNull<std::ffi::c_void>,
}
41
// SAFETY: These impls rely on the Swift-side LanguageModelSession being
// thread-safe (@unchecked Sendable). That guarantee lives in the Swift bridge,
// not in this file — NOTE(review): re-confirm it whenever the bridge changes,
// because `&self` methods here may be called concurrently from several threads.
unsafe impl Send for LanguageModelSession {}
unsafe impl Sync for LanguageModelSession {}
45
46impl LanguageModelSession {
47    /// Creates a new language model session without instructions
48    ///
49    /// This is equivalent to calling `with_instructions(None)`.
50    ///
51    /// # Errors
52    ///
53    /// Returns `Error::ModelNotAvailable` if Apple Intelligence is not enabled
54    /// or the system model is unavailable.
55    pub fn new() -> Result<Self> {
56        Self::with_instructions_opt(None)
57    }
58
59    /// Creates a new language model session with instructions
60    ///
61    /// Instructions define the model's persona, behavior, and guidelines for the
62    /// entire session. They are always the first entry in the session transcript.
63    ///
64    /// # Arguments
65    ///
66    /// * `instructions` - System prompt that guides the model's behavior
67    ///
68    /// # Errors
69    ///
70    /// * `Error::ModelNotAvailable` - If Apple Intelligence is not enabled
71    /// * `Error::InvalidInput` - If instructions contain a null byte
72    ///
73    /// # Examples
74    ///
75    /// ```no_run
76    /// # use fm_bindings::LanguageModelSession;
77    /// let session = LanguageModelSession::with_instructions(
78    ///     "You are a helpful coding assistant. Provide concise answers."
79    /// )?;
80    /// # Ok::<(), fm_bindings::Error>(())
81    /// ```
82    pub fn with_instructions(instructions: &str) -> Result<Self> {
83        Self::with_instructions_opt(Some(instructions))
84    }
85
86    /// Creates a new language model session with optional instructions
87    ///
88    /// # Arguments
89    ///
90    /// * `instructions` - Optional system prompt, or `None` for no instructions
91    ///
92    /// # Errors
93    ///
94    /// * `Error::ModelNotAvailable` - If Apple Intelligence is not enabled
95    /// * `Error::InvalidInput` - If instructions contain a null byte
96    fn with_instructions_opt(instructions: Option<&str>) -> Result<Self> {
97        if !unsafe { ffi::fm_check_availability() } {
98            return Err(Error::ModelNotAvailable);
99        }
100
101        let c_instructions = match instructions {
102            Some(s) => Some(
103                CString::new(s)
104                    .map_err(|_| Error::InvalidInput("Instructions contain null byte".into()))?,
105            ),
106            None => None,
107        };
108
109        let ptr = unsafe {
110            ffi::fm_create_session(
111                c_instructions
112                    .as_ref()
113                    .map_or(std::ptr::null(), |s| s.as_ptr()),
114            )
115        };
116
117        NonNull::new(ptr)
118            .map(|ptr| Self { ptr })
119            .ok_or_else(|| Error::InternalError("Failed to create session".into()))
120    }
121
122    /// Creates a session from a serialized transcript JSON
123    ///
124    /// This restores a previous session state, including the original instructions
125    /// and full conversation history. Use this to resume conversations across
126    /// app launches.
127    ///
128    /// # Arguments
129    ///
130    /// * `transcript_json` - JSON string from `transcript_json()`
131    ///
132    /// # Errors
133    ///
134    /// * `Error::ModelNotAvailable` - If Apple Intelligence is not enabled
135    /// * `Error::InvalidInput` - If JSON contains a null byte or is invalid
136    ///
137    /// # Examples
138    ///
139    /// ```no_run
140    /// # use fm_bindings::LanguageModelSession;
141    /// let json = std::fs::read_to_string("session.json")?;
142    /// let session = LanguageModelSession::from_transcript_json(&json)?;
143    /// # Ok::<(), Box<dyn std::error::Error>>(())
144    /// ```
145    pub fn from_transcript_json(transcript_json: &str) -> Result<Self> {
146        if !unsafe { ffi::fm_check_availability() } {
147            return Err(Error::ModelNotAvailable);
148        }
149
150        let c_json = CString::new(transcript_json)
151            .map_err(|_| Error::InvalidInput("Transcript JSON contains null byte".into()))?;
152
153        let ptr = unsafe { ffi::fm_create_session_from_transcript(c_json.as_ptr()) };
154
155        NonNull::new(ptr)
156            .map(|ptr| Self { ptr })
157            .ok_or_else(|| Error::InternalError("Failed to restore session from transcript".into()))
158    }
159
160    /// Gets the current session transcript as JSON
161    ///
162    /// The returned JSON can be persisted and later passed to `from_transcript_json()`
163    /// to restore the session state.
164    ///
165    /// # Returns
166    ///
167    /// JSON string representing the full transcript (instructions, prompts, responses)
168    ///
169    /// # Errors
170    ///
171    /// * `Error::InternalError` - If transcript serialization fails
172    ///
173    /// # Examples
174    ///
175    /// ```no_run
176    /// # use fm_bindings::LanguageModelSession;
177    /// let session = LanguageModelSession::new()?;
178    /// let _ = session.response("Hello")?;
179    ///
180    /// let json = session.transcript_json()?;
181    /// std::fs::write("session.json", &json)?;
182    /// # Ok::<(), Box<dyn std::error::Error>>(())
183    /// ```
184    pub fn transcript_json(&self) -> Result<String> {
185        let json_ptr = unsafe { ffi::fm_get_transcript_json(self.ptr.as_ptr()) };
186
187        if json_ptr.is_null() {
188            // Empty transcript is valid
189            return Ok("[]".to_string());
190        }
191
192        let json = unsafe {
193            let s = std::ffi::CStr::from_ptr(json_ptr)
194                .to_string_lossy()
195                .into_owned();
196            ffi::fm_free_string(json_ptr);
197            s
198        };
199
200        Ok(json)
201    }
202
203    /// Generates a complete response to the given prompt
204    ///
205    /// This method blocks until the entire response is generated and returned as a String.
206    /// The prompt and response are added to the session transcript.
207    ///
208    /// For a better user experience with incremental updates, use `stream_response` instead.
209    ///
210    /// # Arguments
211    ///
212    /// * `prompt` - The input text to send to the model
213    ///
214    /// # Errors
215    ///
216    /// * `Error::InvalidInput` - If the prompt is empty or contains a null byte
217    /// * `Error::GenerationError` - If an error occurs during generation
218    ///
219    /// # Examples
220    ///
221    /// ```no_run
222    /// # use fm_bindings::LanguageModelSession;
223    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
224    /// let session = LanguageModelSession::new()?;
225    /// let response = session.response("Explain Rust ownership")?;
226    /// println!("Response: {}", response);
227    /// # Ok(())
228    /// # }
229    /// ```
230    pub fn response(&self, prompt: &str) -> Result<String> {
231        if prompt.is_empty() {
232            return Err(Error::InvalidInput("Prompt cannot be empty".into()));
233        }
234
235        let c_prompt = CString::new(prompt)
236            .map_err(|_| Error::InvalidInput("Prompt contains null byte".into()))?;
237
238        // Shared state for collecting response
239        let state = Arc::new((Mutex::new(ResponseState::default()), Condvar::new()));
240        let state_ptr = Box::into_raw(Box::new(Arc::clone(&state)));
241
242        unsafe {
243            ffi::fm_session_response(
244                self.ptr.as_ptr(),
245                c_prompt.as_ptr(),
246                state_ptr as *mut _,
247                response_chunk_callback,
248                response_done_callback,
249                response_error_callback,
250            );
251        }
252
253        // Wait for completion
254        let (mutex, cvar) = &*state;
255        let mut response_state = mutex.lock().map_err(|_| Error::PoisonError)?;
256        while !response_state.finished {
257            response_state = cvar.wait(response_state).map_err(|_| Error::PoisonError)?;
258        }
259
260        // Check for errors
261        if let Some(error) = &response_state.error {
262            if error.contains("not available") {
263                return Err(Error::ModelNotAvailable);
264            }
265            return Err(Error::GenerationError(error.clone()));
266        }
267
268        Ok(response_state.text.clone())
269    }
270
271    /// Generates a streaming response to the given prompt
272    ///
273    /// This method calls the provided callback for each chunk as it's generated,
274    /// providing immediate feedback to the user. The prompt and complete response
275    /// are added to the session transcript.
276    ///
277    /// # Arguments
278    ///
279    /// * `prompt` - The input text to send to the model
280    /// * `on_chunk` - Callback function called for each generated chunk
281    ///
282    /// # Errors
283    ///
284    /// * `Error::InvalidInput` - If the prompt is empty or contains a null byte
285    /// * `Error::GenerationError` - If an error occurs during generation
286    ///
287    /// # Examples
288    ///
289    /// ```no_run
290    /// # use fm_bindings::LanguageModelSession;
291    /// # use std::io::{self, Write};
292    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
293    /// let session = LanguageModelSession::new()?;
294    ///
295    /// session.stream_response("Tell me a story", |chunk| {
296    ///     print!("{}", chunk);
297    ///     let _ = io::stdout().flush();
298    /// })?;
299    ///
300    /// println!(); // newline after stream completes
301    /// # Ok(())
302    /// # }
303    /// ```
304    pub fn stream_response<F>(&self, prompt: &str, on_chunk: F) -> Result<()>
305    where
306        F: FnMut(&str),
307    {
308        if prompt.is_empty() {
309            return Err(Error::InvalidInput("Prompt cannot be empty".into()));
310        }
311
312        let c_prompt = CString::new(prompt)
313            .map_err(|_| Error::InvalidInput("Prompt contains null byte".into()))?;
314
315        // Shared state for streaming
316        let state = Arc::new((Mutex::new(StreamState::default()), Condvar::new()));
317        let user_data = Box::into_raw(Box::new((
318            Arc::clone(&state),
319            Box::new(on_chunk) as Box<dyn FnMut(&str)>,
320        )));
321
322        unsafe {
323            ffi::fm_session_stream(
324                self.ptr.as_ptr(),
325                c_prompt.as_ptr(),
326                user_data as *mut _,
327                stream_chunk_callback,
328                stream_done_callback,
329                stream_error_callback,
330            );
331        }
332
333        // Wait for completion
334        let (mutex, cvar) = &*state;
335        let mut stream_state = mutex.lock().map_err(|_| Error::PoisonError)?;
336        while !stream_state.finished {
337            stream_state = cvar.wait(stream_state).map_err(|_| Error::PoisonError)?;
338        }
339
340        // Check for errors
341        if let Some(error) = &stream_state.error {
342            if error.contains("not available") {
343                return Err(Error::ModelNotAvailable);
344            }
345            return Err(Error::GenerationError(error.clone()));
346        }
347
348        Ok(())
349    }
350
351    /// Cancels the current streaming response
352    ///
353    /// This method immediately cancels any ongoing streaming operation.
354    /// The streaming callback will stop receiving tokens and the stream
355    /// will complete with the tokens received so far.
356    ///
357    /// # Notes
358    ///
359    /// * Safe to call even if no stream is active
360    /// * After cancellation, `stream_response` will return normally
361    ///
362    /// # Examples
363    ///
364    /// ```no_run
365    /// # use fm_bindings::LanguageModelSession;
366    /// # use std::sync::Arc;
367    /// # use std::thread;
368    /// # use std::time::Duration;
369    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
370    /// let session = Arc::new(LanguageModelSession::new()?);
371    /// let session_clone = Arc::clone(&session);
372    ///
373    /// // Start streaming in a thread
374    /// let handle = thread::spawn(move || {
375    ///     session_clone.stream_response("Write a long essay...", |chunk| {
376    ///         print!("{}", chunk);
377    ///     })
378    /// });
379    ///
380    /// // Cancel after a delay
381    /// thread::sleep(Duration::from_secs(2));
382    /// session.cancel_stream();
383    ///
384    /// handle.join().unwrap()?;
385    /// # Ok(())
386    /// # }
387    /// ```
388    pub fn cancel_stream(&self) {
389        unsafe {
390            ffi::fm_session_cancel_stream(self.ptr.as_ptr());
391        }
392    }
393}
394
395impl Drop for LanguageModelSession {
396    fn drop(&mut self) {
397        unsafe {
398            ffi::fm_destroy_session(self.ptr.as_ptr());
399        }
400    }
401}
402
403// Note: We intentionally don't implement Clone.
404// To create a session with the same transcript, use:
405//   let json = session.transcript_json()?;
406//   let new_session = LanguageModelSession::from_transcript_json(&json)?;
407
408// =============================================================================
409// Internal State Types
410// =============================================================================
411
/// Accumulator shared between `response()` and its C callbacks.
///
/// It lives inside an `Arc<(Mutex<ResponseState>, Condvar)>`; the callbacks
/// mutate it under the mutex and `response()` waits on the condvar until
/// `finished` becomes `true`.
#[derive(Default)]
struct ResponseState {
    /// Concatenation of every chunk received so far.
    text: String,
    /// Set by the done or error callback once no further chunks are expected.
    finished: bool,
    /// Error message recorded by the error callback, if any.
    error: Option<String>,
}
418
/// Completion state shared between `stream_response()` and its C callbacks.
///
/// Unlike `ResponseState` it stores no text: chunks go straight to the
/// user-supplied callback.
#[derive(Default)]
struct StreamState {
    /// Set by the done or error callback once the stream has ended.
    finished: bool,
    /// Error message recorded by the error callback, if any.
    error: Option<String>,
}
424
425// =============================================================================
426// C Callbacks for response()
427// =============================================================================
428
429extern "C" fn response_chunk_callback(
430    chunk: *const std::os::raw::c_char,
431    user_data: *mut std::os::raw::c_void,
432) {
433    if chunk.is_null() || user_data.is_null() {
434        return;
435    }
436
437    unsafe {
438        let state = &*(user_data as *const Arc<(Mutex<ResponseState>, Condvar)>);
439        let chunk_str = std::ffi::CStr::from_ptr(chunk).to_string_lossy();
440
441        let (mutex, _) = &**state;
442        if let Ok(mut response_state) = mutex.lock() {
443            response_state.text.push_str(&chunk_str);
444        }
445    }
446}
447
448extern "C" fn response_done_callback(user_data: *mut std::os::raw::c_void) {
449    if user_data.is_null() {
450        return;
451    }
452
453    unsafe {
454        // Take ownership back from the raw pointer
455        let state = Box::from_raw(user_data as *mut Arc<(Mutex<ResponseState>, Condvar)>);
456
457        let (mutex, cvar) = &**state;
458        if let Ok(mut response_state) = mutex.lock() {
459            response_state.finished = true;
460            cvar.notify_all();
461        }
462    }
463}
464
465extern "C" fn response_error_callback(
466    error: *const std::os::raw::c_char,
467    user_data: *mut std::os::raw::c_void,
468) {
469    if user_data.is_null() {
470        return;
471    }
472
473    unsafe {
474        // Take ownership back from the raw pointer
475        let state = Box::from_raw(user_data as *mut Arc<(Mutex<ResponseState>, Condvar)>);
476
477        let (mutex, cvar) = &**state;
478        if let Ok(mut response_state) = mutex.lock() {
479            if !error.is_null() {
480                let error_str = std::ffi::CStr::from_ptr(error)
481                    .to_string_lossy()
482                    .into_owned();
483                response_state.error = Some(error_str);
484            }
485            response_state.finished = true;
486            cvar.notify_all();
487        }
488    }
489}
490
491// =============================================================================
492// C Callbacks for stream_response()
493// =============================================================================
494
/// Type-erased per-chunk callback supplied by the caller of `stream_response()`.
// NOTE(review): outside an expression this alias means `dyn FnMut(&str) + 'static`,
// while `stream_response()` may box a closure with shorter-lived borrows. This
// appears to rely on `stream_response()` blocking until the stream ends — confirm.
type StreamCallback = Box<dyn FnMut(&str)>;
/// Payload passed through the FFI as `user_data` for the stream callbacks:
/// the shared completion state plus the user's chunk callback.
type StreamUserData = (Arc<(Mutex<StreamState>, Condvar)>, StreamCallback);
497
498extern "C" fn stream_chunk_callback(
499    chunk: *const std::os::raw::c_char,
500    user_data: *mut std::os::raw::c_void,
501) {
502    if chunk.is_null() || user_data.is_null() {
503        return;
504    }
505
506    unsafe {
507        let data = &mut *(user_data as *mut StreamUserData);
508        let chunk_str = std::ffi::CStr::from_ptr(chunk).to_string_lossy();
509        (data.1)(&chunk_str);
510    }
511}
512
513extern "C" fn stream_done_callback(user_data: *mut std::os::raw::c_void) {
514    if user_data.is_null() {
515        return;
516    }
517
518    unsafe {
519        // Take ownership back from the raw pointer
520        let data = Box::from_raw(user_data as *mut StreamUserData);
521
522        let (mutex, cvar) = &*data.0;
523        if let Ok(mut stream_state) = mutex.lock() {
524            stream_state.finished = true;
525            cvar.notify_all();
526        }
527    }
528}
529
530extern "C" fn stream_error_callback(
531    error: *const std::os::raw::c_char,
532    user_data: *mut std::os::raw::c_void,
533) {
534    if user_data.is_null() {
535        return;
536    }
537
538    unsafe {
539        // Take ownership back from the raw pointer
540        let data = Box::from_raw(user_data as *mut StreamUserData);
541
542        let (mutex, cvar) = &*data.0;
543        if let Ok(mut stream_state) = mutex.lock() {
544            if !error.is_null() {
545                let error_str = std::ffi::CStr::from_ptr(error)
546                    .to_string_lossy()
547                    .into_owned();
548                stream_state.error = Some(error_str);
549            }
550            stream_state.finished = true;
551            cvar.notify_all();
552        }
553    }
554}