// fm_bindings/src/session.rs
// Language Model Session - the main API for Foundation Models

4use super::error::{Error, Result};
5use super::ffi;
6use std::ffi::CString;
7use std::sync::{Arc, Condvar, Mutex};
8
/// A session for interacting with Apple's Foundation Models
///
/// This provides access to on-device language models via the FoundationModels framework.
/// Requires macOS 26+ or iOS 26+ with Apple Intelligence enabled.
///
/// The session value itself carries no data (it only has a private unit field),
/// so cloning it is free.
///
/// # Examples
///
/// ## Blocking response
/// ```no_run
/// # use fm_bindings::LanguageModelSession;
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let session = LanguageModelSession::new()?;
/// let response = session.response("What is Rust?")?;
/// println!("{}", response);
/// # Ok(())
/// # }
/// ```
///
/// ## Streaming response
/// ```no_run
/// # use fm_bindings::LanguageModelSession;
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let session = LanguageModelSession::new()?;
/// session.stream_response("What is Rust?", |chunk| {
///     print!("{}", chunk);
/// })?;
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Debug)]
pub struct LanguageModelSession {
    // Private so that sessions can only be obtained through `new()`, which
    // performs the availability check.
    _private: (),
}
42
43impl LanguageModelSession {
44    /// Creates a new language model session
45    ///
46    /// This checks that the Foundation Model is available on the system.
47    ///
48    /// # Errors
49    ///
50    /// Returns `Error::ModelNotAvailable` if Apple Intelligence is not enabled
51    /// or the system model is unavailable.
52    pub fn new() -> Result<Self> {
53        // Check availability before creating the session (fail-fast)
54        let is_available = unsafe { ffi::fm_check_availability() };
55
56        if !is_available {
57            return Err(Error::ModelNotAvailable);
58        }
59
60        Ok(Self { _private: () })
61    }
62
63    /// Generates a complete response to the given prompt
64    ///
65    /// This method blocks until the entire response is generated and returned as a String.
66    /// For a better user experience with incremental updates, use `stream_response` instead.
67    ///
68    /// # Arguments
69    ///
70    /// * `prompt` - The input text to send to the model
71    ///
72    /// # Errors
73    ///
74    /// * `Error::ModelNotAvailable` - If the Foundation Model is not available
75    /// * `Error::InvalidInput` - If the prompt is empty or invalid
76    /// * `Error::GenerationError` - If an error occurs during generation
77    ///
78    /// # Examples
79    ///
80    /// ```no_run
81    /// # use fm_bindings::LanguageModelSession;
82    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
83    /// let session = LanguageModelSession::new()?;
84    /// let response = session.response("Explain Rust ownership")?;
85    /// println!("Response: {}", response);
86    /// # Ok(())
87    /// # }
88    /// ```
89    pub fn response(&self, prompt: &str) -> Result<String> {
90        if prompt.is_empty() {
91            return Err(Error::InvalidInput("Prompt cannot be empty".into()));
92        }
93
94        // Create C string for FFI
95        let c_prompt = CString::new(prompt)
96            .map_err(|_| Error::InvalidInput("Prompt contains null byte".into()))?;
97
98        // Shared state for collecting response
99        let state = Arc::new((Mutex::new(ResponseState::default()), Condvar::new()));
100        let state_clone = Arc::clone(&state);
101
102        // Call Swift FFI with blocking response mode
103        unsafe {
104            ffi::fm_response(
105                c_prompt.as_ptr(),
106                Box::into_raw(Box::new(state_clone)) as *mut _,
107                response_callback,
108                response_done_callback,
109                response_error_callback,
110            );
111        }
112
113        // Wait for completion
114        let (mutex, cvar) = &*state;
115        let mut response_state = mutex.lock().map_err(|_| Error::PoisonError)?;
116        while !response_state.finished {
117            response_state = cvar.wait(response_state).map_err(|_| Error::PoisonError)?;
118        }
119
120        // Check for errors
121        if let Some(error) = &response_state.error {
122            if error.contains("not available") {
123                return Err(Error::ModelNotAvailable);
124            }
125            return Err(Error::GenerationError(error.clone()));
126        }
127
128        Ok(response_state.text.clone())
129    }
130
131    /// Generates a streaming response to the given prompt
132    ///
133    /// This method calls the provided callback for each chunk as it's generated,
134    /// providing immediate feedback to the user. The callback receives string slices
135    /// containing incremental text deltas.
136    ///
137    /// # Arguments
138    ///
139    /// * `prompt` - The input text to send to the model
140    /// * `on_chunk` - Callback function called for each generated chunk
141    ///
142    /// # Errors
143    ///
144    /// * `Error::ModelNotAvailable` - If the Foundation Model is not available
145    /// * `Error::InvalidInput` - If the prompt is empty or invalid
146    /// * `Error::GenerationError` - If an error occurs during generation
147    ///
148    /// # Examples
149    ///
150    /// ```no_run
151    /// # use fm_bindings::LanguageModelSession;
152    /// # use std::io::{self, Write};
153    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
154    /// let session = LanguageModelSession::new()?;
155    ///
156    /// session.stream_response("Tell me a story", |chunk| {
157    ///     print!("{}", chunk);
158    ///     let _ = io::stdout().flush();
159    /// })?;
160    ///
161    /// println!(); // newline after stream completes
162    /// # Ok(())
163    /// # }
164    /// ```
165    pub fn stream_response<F>(&self, prompt: &str, on_chunk: F) -> Result<()>
166    where
167        F: FnMut(&str),
168    {
169        if prompt.is_empty() {
170            return Err(Error::InvalidInput("Prompt cannot be empty".into()));
171        }
172
173        // Create C string for FFI
174        let c_prompt = CString::new(prompt)
175            .map_err(|_| Error::InvalidInput("Prompt contains null byte".into()))?;
176
177        // Shared state for streaming
178        let state = Arc::new((Mutex::new(StreamState::default()), Condvar::new()));
179        let state_clone = Arc::clone(&state);
180
181        // Call Swift FFI with streaming mode
182        unsafe {
183            ffi::fm_start_stream(
184                c_prompt.as_ptr(),
185                Box::into_raw(Box::new((
186                    state_clone,
187                    Box::new(on_chunk) as Box<dyn FnMut(&str)>,
188                ))) as *mut _,
189                stream_chunk_callback,
190                stream_done_callback,
191                stream_error_callback,
192            );
193        }
194
195        // Wait for completion
196        let (mutex, cvar) = &*state;
197        let mut stream_state = mutex.lock().map_err(|_| Error::PoisonError)?;
198        while !stream_state.finished {
199            stream_state = cvar.wait(stream_state).map_err(|_| Error::PoisonError)?;
200        }
201
202        // Check for errors
203        if let Some(error) = &stream_state.error {
204            if error.contains("not available") {
205                return Err(Error::ModelNotAvailable);
206            }
207            return Err(Error::GenerationError(error.clone()));
208        }
209
210        Ok(())
211    }
212
213    /// Cancels the current streaming response
214    ///
215    /// This method immediately cancels any ongoing streaming operation started with
216    /// `stream_response`. The streaming callback will stop receiving tokens and the
217    /// stream will complete with the tokens received so far.
218    ///
219    /// # Notes
220    ///
221    /// * This is a global operation that cancels the current stream
222    /// * Safe to call even if no stream is active
223    /// * After cancellation, the `stream_response` method will return normally
224    ///
225    /// # Examples
226    ///
227    /// ```no_run
228    /// # use fm_bindings::LanguageModelSession;
229    /// # use std::thread;
230    /// # use std::time::Duration;
231    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
232    /// let session = LanguageModelSession::new()?;
233    /// let session_clone = session.clone();
234    ///
235    /// // Start streaming in a thread
236    /// thread::spawn(move || {
237    ///     session_clone.stream_response("Long prompt...", |chunk| {
238    ///         print!("{}", chunk);
239    ///     }).ok();
240    /// });
241    ///
242    /// // Cancel after a delay
243    /// thread::sleep(Duration::from_secs(2));
244    /// session.cancel_stream();
245    /// # Ok(())
246    /// # }
247    /// ```
248    pub fn cancel_stream(&self) {
249        unsafe {
250            ffi::fm_stop_stream();
251        }
252    }
253}
254
// Internal state types
256
/// Shared completion state for a blocking `response()` call.
///
/// Written by the C callbacks and read by the caller once `finished` is set.
#[derive(Default)]
struct ResponseState {
    /// Set once the done or error callback has fired.
    finished: bool,
    /// Error message reported through the error callback, if any.
    error: Option<String>,
    /// Text accumulated from every chunk received so far.
    text: String,
}
263
/// Shared completion state for a `stream_response()` call.
///
/// Chunks are handed straight to the user's callback, so only the outcome of
/// the stream is tracked here.
#[derive(Default)]
struct StreamState {
    /// Error message recorded for the stream, if any.
    error: Option<String>,
    /// Set once the done or error callback has fired.
    finished: bool,
}
269
// C callbacks for response()
271
272extern "C" fn response_callback(
273    chunk: *const std::os::raw::c_char,
274    user_data: *mut std::os::raw::c_void,
275) {
276    if chunk.is_null() || user_data.is_null() {
277        return;
278    }
279
280    unsafe {
281        let state = &*(user_data as *const Arc<(Mutex<ResponseState>, Condvar)>);
282        let chunk_str = std::ffi::CStr::from_ptr(chunk).to_string_lossy();
283
284        let (mutex, _) = &**state;
285        if let Ok(mut response_state) = mutex.lock() {
286            response_state.text.push_str(&chunk_str);
287        }
288    }
289}
290
291extern "C" fn response_done_callback(user_data: *mut std::os::raw::c_void) {
292    if user_data.is_null() {
293        return;
294    }
295
296    unsafe {
297        let state = Box::from_raw(user_data as *mut Arc<(Mutex<ResponseState>, Condvar)>);
298        let state_arc = (*state).clone();
299        drop(state); // Drop the Box, but Arc is still alive
300
301        let (mutex, cvar) = &*state_arc;
302        if let Ok(mut response_state) = mutex.lock() {
303            response_state.finished = true;
304            cvar.notify_all();
305        }
306    }
307}
308
309extern "C" fn response_error_callback(
310    error: *const std::os::raw::c_char,
311    user_data: *mut std::os::raw::c_void,
312) {
313    if user_data.is_null() {
314        return;
315    }
316
317    unsafe {
318        let state = Box::from_raw(user_data as *mut Arc<(Mutex<ResponseState>, Condvar)>);
319        let state_arc = (*state).clone();
320        drop(state); // Drop the Box, but Arc is still alive
321
322        let (mutex, cvar) = &*state_arc;
323        if let Ok(mut response_state) = mutex.lock() {
324            if !error.is_null() {
325                let error_str = std::ffi::CStr::from_ptr(error)
326                    .to_string_lossy()
327                    .into_owned();
328                response_state.error = Some(error_str);
329            }
330
331            response_state.finished = true;
332            cvar.notify_all();
333        }
334    }
335}
336
// C callbacks for stream_response()
338
339type StreamCallback = Box<dyn FnMut(&str)>;
340type StreamUserData = (Arc<(Mutex<StreamState>, Condvar)>, StreamCallback);
341
342extern "C" fn stream_chunk_callback(
343    chunk: *const std::os::raw::c_char,
344    user_data: *mut std::os::raw::c_void,
345) {
346    if chunk.is_null() || user_data.is_null() {
347        return;
348    }
349
350    unsafe {
351        let data = &mut *(user_data as *mut StreamUserData);
352        let chunk_str = std::ffi::CStr::from_ptr(chunk).to_string_lossy();
353        (data.1)(&chunk_str);
354    }
355}
356
357extern "C" fn stream_done_callback(user_data: *mut std::os::raw::c_void) {
358    if user_data.is_null() {
359        return;
360    }
361
362    unsafe {
363        let data = Box::from_raw(user_data as *mut StreamUserData);
364        let (mutex, cvar) = &*data.0;
365        if let Ok(mut stream_state) = mutex.lock() {
366            stream_state.finished = true;
367            cvar.notify_all();
368        }
369    }
370}
371
372extern "C" fn stream_error_callback(
373    error: *const std::os::raw::c_char,
374    user_data: *mut std::os::raw::c_void,
375) {
376    if user_data.is_null() {
377        return;
378    }
379
380    unsafe {
381        let data = Box::from_raw(user_data as *mut StreamUserData);
382        let (mutex, cvar) = &*data.0;
383        if let Ok(mut stream_state) = mutex.lock() {
384            if !error.is_null() {
385                let error_str = std::ffi::CStr::from_ptr(error)
386                    .to_string_lossy()
387                    .into_owned();
388                stream_state.error = Some(error_str);
389            }
390
391            stream_state.finished = true;
392            cvar.notify_all();
393        }
394    }
395}