fm_bindings/session.rs
1// src/session.rs
2// Language Model Session - the main API for Foundation Models
3
4use super::error::{Error, Result};
5use super::ffi;
6use std::ffi::CString;
7use std::sync::{Arc, Condvar, Mutex};
8
/// A session for interacting with Apple's Foundation Models
///
/// This provides access to on-device language models via the FoundationModels framework.
/// Requires macOS 26+ or iOS 26+ with Apple Intelligence enabled.
///
/// The session is a zero-sized marker value: it carries no native handle, is only
/// constructed through [`LanguageModelSession::new`] (the field is private), and is
/// therefore free to clone.
///
/// # Examples
///
/// ## Blocking response
/// ```no_run
/// # use fm_bindings::LanguageModelSession;
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let session = LanguageModelSession::new()?;
/// let response = session.response("What is Rust?")?;
/// println!("{}", response);
/// # Ok(())
/// # }
/// ```
///
/// ## Streaming response
/// ```no_run
/// # use fm_bindings::LanguageModelSession;
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let session = LanguageModelSession::new()?;
/// session.stream_response("What is Rust?", |chunk| {
///     print!("{}", chunk);
/// })?;
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Debug)]
pub struct LanguageModelSession {
    _private: (),
}
42
43impl LanguageModelSession {
44 /// Creates a new language model session
45 ///
46 /// This checks that the Foundation Model is available on the system.
47 ///
48 /// # Errors
49 ///
50 /// Returns `Error::ModelNotAvailable` if Apple Intelligence is not enabled
51 /// or the system model is unavailable.
52 pub fn new() -> Result<Self> {
53 // Check availability before creating the session (fail-fast)
54 let is_available = unsafe { ffi::fm_check_availability() };
55
56 if !is_available {
57 return Err(Error::ModelNotAvailable);
58 }
59
60 Ok(Self { _private: () })
61 }
62
63 /// Generates a complete response to the given prompt
64 ///
65 /// This method blocks until the entire response is generated and returned as a String.
66 /// For a better user experience with incremental updates, use `stream_response` instead.
67 ///
68 /// # Arguments
69 ///
70 /// * `prompt` - The input text to send to the model
71 ///
72 /// # Errors
73 ///
74 /// * `Error::ModelNotAvailable` - If the Foundation Model is not available
75 /// * `Error::InvalidInput` - If the prompt is empty or invalid
76 /// * `Error::GenerationError` - If an error occurs during generation
77 ///
78 /// # Examples
79 ///
80 /// ```no_run
81 /// # use fm_bindings::LanguageModelSession;
82 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
83 /// let session = LanguageModelSession::new()?;
84 /// let response = session.response("Explain Rust ownership")?;
85 /// println!("Response: {}", response);
86 /// # Ok(())
87 /// # }
88 /// ```
89 pub fn response(&self, prompt: &str) -> Result<String> {
90 if prompt.is_empty() {
91 return Err(Error::InvalidInput("Prompt cannot be empty".into()));
92 }
93
94 // Create C string for FFI
95 let c_prompt = CString::new(prompt)
96 .map_err(|_| Error::InvalidInput("Prompt contains null byte".into()))?;
97
98 // Shared state for collecting response
99 let state = Arc::new((Mutex::new(ResponseState::default()), Condvar::new()));
100 let state_clone = Arc::clone(&state);
101
102 // Call Swift FFI with blocking response mode
103 unsafe {
104 ffi::fm_response(
105 c_prompt.as_ptr(),
106 Box::into_raw(Box::new(state_clone)) as *mut _,
107 response_callback,
108 response_done_callback,
109 response_error_callback,
110 );
111 }
112
113 // Wait for completion
114 let (mutex, cvar) = &*state;
115 let mut response_state = mutex.lock().map_err(|_| Error::PoisonError)?;
116 while !response_state.finished {
117 response_state = cvar.wait(response_state).map_err(|_| Error::PoisonError)?;
118 }
119
120 // Check for errors
121 if let Some(error) = &response_state.error {
122 if error.contains("not available") {
123 return Err(Error::ModelNotAvailable);
124 }
125 return Err(Error::GenerationError(error.clone()));
126 }
127
128 Ok(response_state.text.clone())
129 }
130
131 /// Generates a streaming response to the given prompt
132 ///
133 /// This method calls the provided callback for each chunk as it's generated,
134 /// providing immediate feedback to the user. The callback receives string slices
135 /// containing incremental text deltas.
136 ///
137 /// # Arguments
138 ///
139 /// * `prompt` - The input text to send to the model
140 /// * `on_chunk` - Callback function called for each generated chunk
141 ///
142 /// # Errors
143 ///
144 /// * `Error::ModelNotAvailable` - If the Foundation Model is not available
145 /// * `Error::InvalidInput` - If the prompt is empty or invalid
146 /// * `Error::GenerationError` - If an error occurs during generation
147 ///
148 /// # Examples
149 ///
150 /// ```no_run
151 /// # use fm_bindings::LanguageModelSession;
152 /// # use std::io::{self, Write};
153 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
154 /// let session = LanguageModelSession::new()?;
155 ///
156 /// session.stream_response("Tell me a story", |chunk| {
157 /// print!("{}", chunk);
158 /// let _ = io::stdout().flush();
159 /// })?;
160 ///
161 /// println!(); // newline after stream completes
162 /// # Ok(())
163 /// # }
164 /// ```
165 pub fn stream_response<F>(&self, prompt: &str, on_chunk: F) -> Result<()>
166 where
167 F: FnMut(&str),
168 {
169 if prompt.is_empty() {
170 return Err(Error::InvalidInput("Prompt cannot be empty".into()));
171 }
172
173 // Create C string for FFI
174 let c_prompt = CString::new(prompt)
175 .map_err(|_| Error::InvalidInput("Prompt contains null byte".into()))?;
176
177 // Shared state for streaming
178 let state = Arc::new((Mutex::new(StreamState::default()), Condvar::new()));
179 let state_clone = Arc::clone(&state);
180
181 // Call Swift FFI with streaming mode
182 unsafe {
183 ffi::fm_start_stream(
184 c_prompt.as_ptr(),
185 Box::into_raw(Box::new((
186 state_clone,
187 Box::new(on_chunk) as Box<dyn FnMut(&str)>,
188 ))) as *mut _,
189 stream_chunk_callback,
190 stream_done_callback,
191 stream_error_callback,
192 );
193 }
194
195 // Wait for completion
196 let (mutex, cvar) = &*state;
197 let mut stream_state = mutex.lock().map_err(|_| Error::PoisonError)?;
198 while !stream_state.finished {
199 stream_state = cvar.wait(stream_state).map_err(|_| Error::PoisonError)?;
200 }
201
202 // Check for errors
203 if let Some(error) = &stream_state.error {
204 if error.contains("not available") {
205 return Err(Error::ModelNotAvailable);
206 }
207 return Err(Error::GenerationError(error.clone()));
208 }
209
210 Ok(())
211 }
212
213 /// Cancels the current streaming response
214 ///
215 /// This method immediately cancels any ongoing streaming operation started with
216 /// `stream_response`. The streaming callback will stop receiving tokens and the
217 /// stream will complete with the tokens received so far.
218 ///
219 /// # Notes
220 ///
221 /// * This is a global operation that cancels the current stream
222 /// * Safe to call even if no stream is active
223 /// * After cancellation, the `stream_response` method will return normally
224 ///
225 /// # Examples
226 ///
227 /// ```no_run
228 /// # use fm_bindings::LanguageModelSession;
229 /// # use std::thread;
230 /// # use std::time::Duration;
231 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
232 /// let session = LanguageModelSession::new()?;
233 /// let session_clone = session.clone();
234 ///
235 /// // Start streaming in a thread
236 /// thread::spawn(move || {
237 /// session_clone.stream_response("Long prompt...", |chunk| {
238 /// print!("{}", chunk);
239 /// }).ok();
240 /// });
241 ///
242 /// // Cancel after a delay
243 /// thread::sleep(Duration::from_secs(2));
244 /// session.cancel_stream();
245 /// # Ok(())
246 /// # }
247 /// ```
248 pub fn cancel_stream(&self) {
249 unsafe {
250 ffi::fm_stop_stream();
251 }
252 }
253}
254
255// Internal State Types
256
/// Completion state shared between `response` and its FFI callbacks.
#[derive(Default)]
struct ResponseState {
    /// Set by the done or error callback; the waiting caller blocks until this is true.
    finished: bool,
    /// Error message reported by the error callback, if any.
    error: Option<String>,
    /// Text accumulated from every chunk callback so far.
    text: String,
}
263
/// Completion state shared between `stream_response` and its FFI callbacks.
///
/// Unlike the blocking variant, no text is buffered here: chunks go straight to the
/// caller's closure.
#[derive(Default)]
struct StreamState {
    /// Error message reported by the error callback, if any.
    error: Option<String>,
    /// Set by the done or error callback; the waiting caller blocks until this is true.
    finished: bool,
}
269
270// C Callbacks for response()
271
272extern "C" fn response_callback(
273 chunk: *const std::os::raw::c_char,
274 user_data: *mut std::os::raw::c_void,
275) {
276 if chunk.is_null() || user_data.is_null() {
277 return;
278 }
279
280 unsafe {
281 let state = &*(user_data as *const Arc<(Mutex<ResponseState>, Condvar)>);
282 let chunk_str = std::ffi::CStr::from_ptr(chunk).to_string_lossy();
283
284 let (mutex, _) = &**state;
285 if let Ok(mut response_state) = mutex.lock() {
286 response_state.text.push_str(&chunk_str);
287 }
288 }
289}
290
291extern "C" fn response_done_callback(user_data: *mut std::os::raw::c_void) {
292 if user_data.is_null() {
293 return;
294 }
295
296 unsafe {
297 let state = Box::from_raw(user_data as *mut Arc<(Mutex<ResponseState>, Condvar)>);
298 let state_arc = (*state).clone();
299 drop(state); // Drop the Box, but Arc is still alive
300
301 let (mutex, cvar) = &*state_arc;
302 if let Ok(mut response_state) = mutex.lock() {
303 response_state.finished = true;
304 cvar.notify_all();
305 }
306 }
307}
308
309extern "C" fn response_error_callback(
310 error: *const std::os::raw::c_char,
311 user_data: *mut std::os::raw::c_void,
312) {
313 if user_data.is_null() {
314 return;
315 }
316
317 unsafe {
318 let state = Box::from_raw(user_data as *mut Arc<(Mutex<ResponseState>, Condvar)>);
319 let state_arc = (*state).clone();
320 drop(state); // Drop the Box, but Arc is still alive
321
322 let (mutex, cvar) = &*state_arc;
323 if let Ok(mut response_state) = mutex.lock() {
324 if !error.is_null() {
325 let error_str = std::ffi::CStr::from_ptr(error)
326 .to_string_lossy()
327 .into_owned();
328 response_state.error = Some(error_str);
329 }
330
331 response_state.finished = true;
332 cvar.notify_all();
333 }
334 }
335}
336
337// C Callbacks for stream_response()
338
339type StreamCallback = Box<dyn FnMut(&str)>;
340type StreamUserData = (Arc<(Mutex<StreamState>, Condvar)>, StreamCallback);
341
342extern "C" fn stream_chunk_callback(
343 chunk: *const std::os::raw::c_char,
344 user_data: *mut std::os::raw::c_void,
345) {
346 if chunk.is_null() || user_data.is_null() {
347 return;
348 }
349
350 unsafe {
351 let data = &mut *(user_data as *mut StreamUserData);
352 let chunk_str = std::ffi::CStr::from_ptr(chunk).to_string_lossy();
353 (data.1)(&chunk_str);
354 }
355}
356
357extern "C" fn stream_done_callback(user_data: *mut std::os::raw::c_void) {
358 if user_data.is_null() {
359 return;
360 }
361
362 unsafe {
363 let data = Box::from_raw(user_data as *mut StreamUserData);
364 let (mutex, cvar) = &*data.0;
365 if let Ok(mut stream_state) = mutex.lock() {
366 stream_state.finished = true;
367 cvar.notify_all();
368 }
369 }
370}
371
372extern "C" fn stream_error_callback(
373 error: *const std::os::raw::c_char,
374 user_data: *mut std::os::raw::c_void,
375) {
376 if user_data.is_null() {
377 return;
378 }
379
380 unsafe {
381 let data = Box::from_raw(user_data as *mut StreamUserData);
382 let (mutex, cvar) = &*data.0;
383 if let Ok(mut stream_state) = mutex.lock() {
384 if !error.is_null() {
385 let error_str = std::ffi::CStr::from_ptr(error)
386 .to_string_lossy()
387 .into_owned();
388 stream_state.error = Some(error_str);
389 }
390
391 stream_state.finished = true;
392 cvar.notify_all();
393 }
394 }
395}