fm_rs/
session.rs

1//! Session management for `FoundationModels`.
2//!
3//! A session maintains conversation context between requests.
4
5use std::collections::HashMap;
6use std::ffi::{CStr, CString, c_char, c_int, c_void};
7use std::ptr::{self, NonNull};
8use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11
12use crate::context::{ContextLimit, ContextUsage, context_usage_from_transcript};
13use crate::error::{Error, Result};
14use crate::ffi::{self, SwiftPtr};
15use crate::model::{SystemLanguageModel, error_from_swift};
16use crate::options::GenerationOptions;
17use crate::tool::{Tool, ToolResult, tools_to_json};
18
/// Type alias for the tool map used in sessions.
///
/// Keys are the names reported by each tool's `name()` method; values are
/// shared handles to the tools themselves.
type ToolMapInner = HashMap<String, Arc<dyn Tool>>;

/// Callback data shared between the session and tool callbacks.
///
/// This struct ensures safe cleanup by tracking active callbacks and
/// preventing new callbacks from starting when the session is being dropped.
/// It is shared through an `Arc`: the session keeps one reference and a second
/// one is handed to Swift as the opaque `user_data` pointer.
struct ToolCallbackData {
    /// Tools that the FFI tool callback can dispatch to, looked up by name.
    tools: Mutex<ToolMapInner>,
    /// Set to true when the session is being dropped.
    dropping: AtomicBool,
    /// Number of callbacks currently in progress.
    active_callbacks: AtomicUsize,
}
33
/// Scope guard that marks the end of a tool callback.
///
/// Creating the guard does not touch the counter; whoever constructs it is
/// expected to have incremented the counter already. Dropping the guard
/// decrements the counter by one, on every exit path of the enclosing scope.
struct CallbackGuard<'a>(&'a AtomicUsize);

impl Drop for CallbackGuard<'_> {
    fn drop(&mut self) {
        let Self(active) = self;
        active.fetch_sub(1, Ordering::SeqCst);
    }
}
42
/// Text produced by the model for a single request.
#[derive(Debug, Clone)]
pub struct Response {
    content: String,
}

impl Response {
    /// Wraps already-generated text in a `Response`.
    pub(crate) fn new(content: String) -> Self {
        Response { content }
    }

    /// Borrows the text content of the response.
    pub fn content(&self) -> &str {
        self.content.as_str()
    }

    /// Consumes the response and returns the owned text content.
    pub fn into_content(self) -> String {
        let Response { content } = self;
        content
    }
}

impl AsRef<str> for Response {
    fn as_ref(&self) -> &str {
        self.content()
    }
}

impl std::fmt::Display for Response {
    /// Writes the response text verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.content())
    }
}
77
/// A session that interacts with a language model.
///
/// A session maintains state between requests, allowing for multi-turn conversations.
/// You can reuse the same session for multiple prompts or create a new one each time.
///
/// # Example
///
/// ```rust,no_run
/// use fm_rs::{Session, SystemLanguageModel, GenerationOptions};
///
/// let model = SystemLanguageModel::new()?;
/// let session = Session::new(&model)?;
///
/// let response = session.respond("Hello!", &GenerationOptions::default())?;
/// println!("{}", response.content());
/// # Ok::<(), fm_rs::Error>(())
/// ```
pub struct Session {
    /// Non-null handle to the Swift session object; it is passed to every
    /// `ffi::fm_session_*` call and released with `ffi::fm_session_free` on drop.
    ptr: NonNull<c_void>,
    /// Arc to the callback data, shared with the FFI callback.
    /// Using Arc ensures the data stays alive while callbacks are in flight.
    /// `None` for sessions created without tools.
    tool_callback_data: Option<Arc<ToolCallbackData>>,
}
101
102impl Session {
    /// Creates a new session with the given model.
    ///
    /// The session is created without instructions and without tools.
    ///
    /// # Errors
    ///
    /// Returns an error if the Swift side reports a failure or hands back a
    /// null session.
    pub fn new(model: &SystemLanguageModel) -> Result<Self> {
        Self::create_internal(model, None, &[])
    }
107
    /// Creates a new session with instructions.
    ///
    /// Instructions define the model's behavior and role.
    /// The session is created without tools.
    ///
    /// # Errors
    ///
    /// Returns an error if `instructions` contains an interior NUL byte, if the
    /// Swift side reports a failure, or if it hands back a null session.
    pub fn with_instructions(model: &SystemLanguageModel, instructions: &str) -> Result<Self> {
        Self::create_internal(model, Some(instructions), &[])
    }
114
    /// Creates a new session with tools.
    ///
    /// Tools allow the model to call external functions during generation.
    /// The session is created without instructions.
    ///
    /// # Errors
    ///
    /// Returns an error if the tool descriptions cannot be serialized for the
    /// FFI call, if the Swift side reports a failure, or if it hands back a
    /// null session.
    pub fn with_tools(model: &SystemLanguageModel, tools: &[Arc<dyn Tool>]) -> Result<Self> {
        Self::create_internal(model, None, tools)
    }
121
    /// Creates a new session with both instructions and tools.
    ///
    /// # Errors
    ///
    /// Returns an error if `instructions` contains an interior NUL byte, if the
    /// tool descriptions cannot be serialized for the FFI call, if the Swift
    /// side reports a failure, or if it hands back a null session.
    pub fn with_instructions_and_tools(
        model: &SystemLanguageModel,
        instructions: &str,
        tools: &[Arc<dyn Tool>],
    ) -> Result<Self> {
        Self::create_internal(model, Some(instructions), tools)
    }
130
131    /// Creates a session from a transcript JSON string.
132    ///
133    /// This allows restoring a previous conversation.
134    /// Note: Restored sessions do not have tools - use `with_tools` for new sessions.
135    pub fn from_transcript(model: &SystemLanguageModel, transcript_json: &str) -> Result<Self> {
136        let transcript_c = CString::new(transcript_json)?;
137        let mut error: SwiftPtr = ptr::null_mut();
138
139        let ptr = unsafe {
140            ffi::fm_session_from_transcript(model.as_ptr(), transcript_c.as_ptr(), &raw mut error)
141        };
142
143        if !error.is_null() {
144            return Err(error_from_swift(error));
145        }
146
147        NonNull::new(ptr)
148            .map(|ptr| Self {
149                ptr,
150                tool_callback_data: None,
151            })
152            .ok_or_else(|| {
153                Error::InternalError(
154                    "Session creation from transcript returned null without error. \
155                     The transcript JSON may be malformed or incompatible."
156                        .to_string(),
157                )
158            })
159    }
160
161    /// Internal helper to create a session.
162    fn create_internal(
163        model: &SystemLanguageModel,
164        instructions: Option<&str>,
165        tools: &[Arc<dyn Tool>],
166    ) -> Result<Self> {
167        let instructions_c = instructions.map(CString::new).transpose()?;
168        let instructions_ptr = instructions_c.as_ref().map_or(ptr::null(), |s| s.as_ptr());
169
170        // Build tool map and serialize for FFI
171        let mut tool_map = HashMap::new();
172        let tools_json = if tools.is_empty() {
173            None
174        } else {
175            let tool_refs: Vec<&dyn Tool> = tools.iter().map(std::convert::AsRef::as_ref).collect();
176            for tool in tools {
177                tool_map.insert(tool.name().to_string(), Arc::clone(tool));
178            }
179            let json_str = tools_to_json(&tool_refs)?;
180            Some(CString::new(json_str)?)
181        };
182        let tools_ptr = tools_json.as_ref().map_or(ptr::null(), |s| s.as_ptr());
183
184        // Create callback data with synchronization primitives
185        let callback_data = if tools.is_empty() {
186            None
187        } else {
188            Some(Arc::new(ToolCallbackData {
189                tools: Mutex::new(tool_map),
190                dropping: AtomicBool::new(false),
191                active_callbacks: AtomicUsize::new(0),
192            }))
193        };
194
195        // Get user_data pointer for FFI (we leak an Arc clone that Swift holds)
196        let user_data = callback_data.as_ref().map_or(ptr::null_mut(), |arc| {
197            Arc::into_raw(Arc::clone(arc)) as *mut c_void
198        });
199
200        let mut error: SwiftPtr = ptr::null_mut();
201
202        let ptr = unsafe {
203            ffi::fm_session_create(
204                model.as_ptr(),
205                instructions_ptr,
206                tools_ptr,
207                user_data,
208                session_tool_callback,
209                &raw mut error,
210            )
211        };
212
213        if !error.is_null() {
214            // Clean up leaked Arc if we allocated it
215            if !user_data.is_null() {
216                let _ = unsafe { Arc::from_raw(user_data as *const ToolCallbackData) };
217            }
218            return Err(error_from_swift(error));
219        }
220
221        NonNull::new(ptr)
222            .map(|ptr| Self {
223                ptr,
224                tool_callback_data: callback_data,
225            })
226            .ok_or_else(|| {
227                // Clean up leaked Arc if we allocated it
228                if !user_data.is_null() {
229                    let _ = unsafe { Arc::from_raw(user_data as *const ToolCallbackData) };
230                }
231                Error::InternalError(
232                    "Session creation returned null without error. \
233                     Check model availability and instructions validity."
234                        .to_string(),
235                )
236            })
237    }
238
239    /// Sends a prompt and waits for the complete response.
240    ///
241    /// This method blocks until the model finishes generating.
242    pub fn respond(&self, prompt: &str, options: &GenerationOptions) -> Result<Response> {
243        let prompt_c = CString::new(prompt)?;
244        let options_json = options.to_json();
245        let options_c = CString::new(options_json)?;
246
247        let mut error: SwiftPtr = ptr::null_mut();
248
249        let response_ptr = unsafe {
250            ffi::fm_session_respond(
251                self.ptr.as_ptr(),
252                prompt_c.as_ptr(),
253                options_c.as_ptr(),
254                &raw mut error,
255            )
256        };
257
258        if !error.is_null() {
259            return Err(error_from_swift(error));
260        }
261
262        if response_ptr.is_null() {
263            return Err(Error::GenerationError("Received null response".to_string()));
264        }
265
266        let content = unsafe {
267            let cstr = CStr::from_ptr(response_ptr);
268            let s = cstr
269                .to_str()
270                .map_err(|e| Error::GenerationError(format!("Invalid UTF-8 in response: {e}")))?
271                .to_owned();
272            ffi::fm_string_free(response_ptr);
273            s
274        };
275
276        Ok(Response::new(content))
277    }
278
279    /// Sends a prompt and waits for the complete response, with a timeout.
280    ///
281    /// If `timeout` is zero, this behaves like [`respond`](Self::respond).
282    pub fn respond_with_timeout(
283        &self,
284        prompt: &str,
285        options: &GenerationOptions,
286        timeout: Duration,
287    ) -> Result<Response> {
288        if timeout.is_zero() {
289            return self.respond(prompt, options);
290        }
291
292        let timeout_ms = u64::try_from(timeout.as_millis()).map_err(|_| {
293            Error::InvalidInput("Timeout is too large to represent in milliseconds".to_string())
294        })?;
295
296        let prompt_c = CString::new(prompt)?;
297        let options_json = options.to_json();
298        let options_c = CString::new(options_json)?;
299
300        let mut error: SwiftPtr = ptr::null_mut();
301
302        let response_ptr = unsafe {
303            ffi::fm_session_respond_with_timeout(
304                self.ptr.as_ptr(),
305                prompt_c.as_ptr(),
306                options_c.as_ptr(),
307                timeout_ms,
308                &raw mut error,
309            )
310        };
311
312        if !error.is_null() {
313            return Err(error_from_swift(error));
314        }
315
316        if response_ptr.is_null() {
317            return Err(Error::GenerationError("Received null response".to_string()));
318        }
319
320        let content = unsafe {
321            let cstr = CStr::from_ptr(response_ptr);
322            let s = cstr
323                .to_str()
324                .map_err(|e| Error::GenerationError(format!("Invalid UTF-8 in response: {e}")))?
325                .to_owned();
326            ffi::fm_string_free(response_ptr);
327            s
328        };
329
330        Ok(Response::new(content))
331    }
332
333    /// Sends a prompt and streams the response.
334    ///
335    /// The `on_chunk` callback is called for each text chunk as it arrives.
336    /// This method blocks until streaming is complete.
337    ///
338    /// # Example
339    ///
340    /// ```rust,no_run
341    /// use fm_rs::{Session, SystemLanguageModel, GenerationOptions};
342    ///
343    /// let model = SystemLanguageModel::new()?;
344    /// let session = Session::new(&model)?;
345    ///
346    /// session.stream_response("Tell me a story", &GenerationOptions::default(), |chunk| {
347    ///     print!("{}", chunk);
348    /// })?;
349    /// # Ok::<(), fm_rs::Error>(())
350    /// ```
351    pub fn stream_response<F>(
352        &self,
353        prompt: &str,
354        options: &GenerationOptions,
355        on_chunk: F,
356    ) -> Result<()>
357    where
358        F: FnMut(&str) + Send + 'static,
359    {
360        let prompt_c = CString::new(prompt)?;
361        let options_json = options.to_json();
362        let options_c = CString::new(options_json)?;
363
364        // Create callback state
365        let state = Box::new(StreamState {
366            on_chunk: Mutex::new(Box::new(on_chunk)),
367            error: Mutex::new(None),
368        });
369        let state_ptr = Box::into_raw(state).cast::<c_void>();
370
371        unsafe {
372            ffi::fm_session_stream(
373                self.ptr.as_ptr(),
374                prompt_c.as_ptr(),
375                options_c.as_ptr(),
376                state_ptr,
377                stream_chunk_callback,
378                stream_done_callback,
379                stream_error_callback,
380            );
381        }
382
383        // Reclaim the state and check for errors
384        let state = unsafe { Box::from_raw(state_ptr.cast::<StreamState>()) };
385        let error = state.error.lock().map_err(|_| Error::PoisonError)?;
386        if let Some(err) = error.as_ref() {
387            return Err(Error::GenerationError(err.clone()));
388        }
389
390        Ok(())
391    }
392
    /// Cancels an ongoing stream operation.
    ///
    /// This forwards to `ffi::fm_session_cancel`; it returns nothing and
    /// reports no error.
    pub fn cancel(&self) {
        // SAFETY: `self.ptr` is non-null by construction and is only freed in `Drop`.
        unsafe {
            ffi::fm_session_cancel(self.ptr.as_ptr());
        }
    }
399
    /// Checks if the session is currently generating a response.
    ///
    /// The value is whatever `ffi::fm_session_is_responding` reports at the
    /// time of the call.
    pub fn is_responding(&self) -> bool {
        // SAFETY: `self.ptr` is non-null by construction and is only freed in `Drop`.
        unsafe { ffi::fm_session_is_responding(self.ptr.as_ptr()) }
    }
404
405    /// Gets the session transcript as a JSON string.
406    ///
407    /// This can be used to persist and restore conversations.
408    pub fn transcript_json(&self) -> Result<String> {
409        let mut error: SwiftPtr = ptr::null_mut();
410        let ptr = unsafe { ffi::fm_session_get_transcript(self.ptr.as_ptr(), &raw mut error) };
411
412        if !error.is_null() {
413            return Err(error_from_swift(error));
414        }
415
416        if ptr.is_null() {
417            return Err(Error::InternalError(
418                "Transcript retrieval returned null without error. \
419                 The session may be in an invalid state."
420                    .to_string(),
421            ));
422        }
423
424        let json = unsafe {
425            let cstr = CStr::from_ptr(ptr);
426            let s = cstr
427                .to_str()
428                .map_err(|e| Error::InternalError(format!("Invalid UTF-8 in transcript: {e}")))?
429                .to_owned();
430            ffi::fm_string_free(ptr);
431            s
432        };
433
434        Ok(json)
435    }
436
437    /// Estimates current context usage based on the session transcript.
438    pub fn context_usage(&self, limit: &ContextLimit) -> Result<ContextUsage> {
439        let transcript_json = self.transcript_json()?;
440        context_usage_from_transcript(&transcript_json, limit)
441    }
442
443    /// Returns an error if the estimated context usage exceeds the configured limit.
444    pub fn ensure_context_within(&self, limit: &ContextLimit) -> Result<()> {
445        let usage = self.context_usage(limit)?;
446        if usage.over_limit {
447            return Err(Error::InvalidInput(format!(
448                "Estimated context usage {} exceeds configured limit {} (reserved: {})",
449                usage.estimated_tokens, usage.max_tokens, usage.reserved_response_tokens
450            )));
451        }
452        Ok(())
453    }
454
455    /// Prewarms the model with an optional prompt prefix.
456    ///
457    /// This can reduce latency for the first response.
458    pub fn prewarm(&self, prompt_prefix: Option<&str>) -> Result<()> {
459        let prefix_c = prompt_prefix.map(CString::new).transpose()?;
460        let prefix_ptr = prefix_c.as_ref().map_or(ptr::null(), |s| s.as_ptr());
461
462        unsafe {
463            ffi::fm_session_prewarm(self.ptr.as_ptr(), prefix_ptr);
464        }
465
466        Ok(())
467    }
468
469    /// Sends a prompt and returns a structured JSON response.
470    ///
471    /// The schema is a JSON Schema that describes the expected output format.
472    /// The model is instructed to produce JSON that matches the schema.
473    ///
474    /// # Example
475    ///
476    /// ```rust,no_run
477    /// use fm_rs::{Session, SystemLanguageModel, GenerationOptions};
478    /// use serde::Deserialize;
479    /// use serde_json::json;
480    ///
481    /// #[derive(Deserialize)]
482    /// struct Person {
483    ///     name: String,
484    ///     age: u32,
485    /// }
486    ///
487    /// let model = SystemLanguageModel::new()?;
488    /// let session = Session::new(&model)?;
489    ///
490    /// let schema = json!({
491    ///     "type": "object",
492    ///     "properties": {
493    ///         "name": { "type": "string" },
494    ///         "age": { "type": "integer" }
495    ///     },
496    ///     "required": ["name", "age"]
497    /// });
498    ///
499    /// let json_str = session.respond_json(
500    ///     "Generate a fictional person",
501    ///     &schema,
502    ///     &GenerationOptions::default()
503    /// )?;
504    ///
505    /// let person: Person = serde_json::from_str(&json_str)?;
506    /// # Ok::<(), Box<dyn std::error::Error>>(())
507    /// ```
508    pub fn respond_json(
509        &self,
510        prompt: &str,
511        schema: &serde_json::Value,
512        options: &GenerationOptions,
513    ) -> Result<String> {
514        let prompt_c = CString::new(prompt)?;
515        let schema_json = serde_json::to_string(schema)?;
516        let schema_c = CString::new(schema_json)?;
517        let options_json = options.to_json();
518        let options_c = CString::new(options_json)?;
519
520        let mut error: SwiftPtr = ptr::null_mut();
521
522        let response_ptr = unsafe {
523            ffi::fm_session_respond_json(
524                self.ptr.as_ptr(),
525                prompt_c.as_ptr(),
526                schema_c.as_ptr(),
527                options_c.as_ptr(),
528                &raw mut error,
529            )
530        };
531
532        if !error.is_null() {
533            return Err(error_from_swift(error));
534        }
535
536        if response_ptr.is_null() {
537            return Err(Error::GenerationError(
538                "Received null response from JSON generation".to_string(),
539            ));
540        }
541
542        let content = unsafe {
543            let cstr = CStr::from_ptr(response_ptr);
544            let s = cstr
545                .to_str()
546                .map_err(|e| {
547                    Error::GenerationError(format!("Invalid UTF-8 in JSON response: {e}"))
548                })?
549                .to_owned();
550            ffi::fm_string_free(response_ptr);
551            s
552        };
553
554        Ok(content)
555    }
556
557    /// Sends a prompt and returns a deserialized structured response.
558    ///
559    /// This is a convenience method that calls `respond_json` and deserializes
560    /// the result into the specified type.
561    ///
562    /// # Example
563    ///
564    /// ```rust,no_run
565    /// use fm_rs::{Session, SystemLanguageModel, GenerationOptions};
566    /// use serde::Deserialize;
567    /// use serde_json::json;
568    ///
569    /// #[derive(Deserialize)]
570    /// struct Person {
571    ///     name: String,
572    ///     age: u32,
573    /// }
574    ///
575    /// let model = SystemLanguageModel::new()?;
576    /// let session = Session::new(&model)?;
577    ///
578    /// let schema = json!({
579    ///     "type": "object",
580    ///     "properties": {
581    ///         "name": { "type": "string" },
582    ///         "age": { "type": "integer" }
583    ///     },
584    ///     "required": ["name", "age"]
585    /// });
586    ///
587    /// let person: Person = session.respond_structured(
588    ///     "Generate a fictional person",
589    ///     &schema,
590    ///     &GenerationOptions::default()
591    /// )?;
592    /// # Ok::<(), Box<dyn std::error::Error>>(())
593    /// ```
594    pub fn respond_structured<T: serde::de::DeserializeOwned>(
595        &self,
596        prompt: &str,
597        schema: &serde_json::Value,
598        options: &GenerationOptions,
599    ) -> Result<T> {
600        let json_str = self.respond_json(prompt, schema, options)?;
601        serde_json::from_str(&json_str)
602            .map_err(|e| Error::InvalidInput(format!("Failed to deserialize response: {e}")))
603    }
604
605    /// Sends a prompt and returns a deserialized structured response using a derived schema.
606    ///
607    /// This uses the [`crate::Generable`] implementation to obtain the JSON schema.
608    pub fn respond_structured_gen<T>(&self, prompt: &str, options: &GenerationOptions) -> Result<T>
609    where
610        T: crate::Generable + serde::de::DeserializeOwned,
611    {
612        self.respond_structured(prompt, &T::schema(), options)
613    }
614
615    /// Streams a structured JSON response.
616    ///
617    /// The `on_chunk` callback receives partial JSON as it's generated.
618    /// Note that partial chunks may not be valid JSON until streaming completes.
619    ///
620    /// # Example
621    ///
622    /// ```rust,no_run
623    /// use fm_rs::{Session, SystemLanguageModel, GenerationOptions};
624    /// use serde_json::json;
625    ///
626    /// let model = SystemLanguageModel::new()?;
627    /// let session = Session::new(&model)?;
628    ///
629    /// let schema = json!({
630    ///     "type": "object",
631    ///     "properties": {
632    ///         "items": { "type": "array", "items": { "type": "string" } }
633    ///     }
634    /// });
635    ///
636    /// session.stream_json(
637    ///     "List 5 programming languages",
638    ///     &schema,
639    ///     &GenerationOptions::default(),
640    ///     |chunk| {
641    ///         print!("{chunk}");
642    ///     }
643    /// )?;
644    /// # Ok::<(), fm_rs::Error>(())
645    /// ```
646    pub fn stream_json<F>(
647        &self,
648        prompt: &str,
649        schema: &serde_json::Value,
650        options: &GenerationOptions,
651        on_chunk: F,
652    ) -> Result<()>
653    where
654        F: FnMut(&str) + Send + 'static,
655    {
656        let prompt_c = CString::new(prompt)?;
657        let schema_json = serde_json::to_string(schema)?;
658        let schema_c = CString::new(schema_json)?;
659        let options_json = options.to_json();
660        let options_c = CString::new(options_json)?;
661
662        // Create callback state
663        let state = Box::new(StreamState {
664            on_chunk: Mutex::new(Box::new(on_chunk)),
665            error: Mutex::new(None),
666        });
667        let state_ptr = Box::into_raw(state).cast::<c_void>();
668
669        unsafe {
670            ffi::fm_session_stream_json(
671                self.ptr.as_ptr(),
672                prompt_c.as_ptr(),
673                schema_c.as_ptr(),
674                options_c.as_ptr(),
675                state_ptr,
676                stream_chunk_callback,
677                stream_done_callback,
678                stream_error_callback,
679            );
680        }
681
682        // Reclaim the state and check for errors
683        let state = unsafe { Box::from_raw(state_ptr.cast::<StreamState>()) };
684        let error = state.error.lock().map_err(|_| Error::PoisonError)?;
685        if let Some(err) = error.as_ref() {
686            return Err(Error::GenerationError(err.clone()));
687        }
688
689        Ok(())
690    }
691}
692
693impl Drop for Session {
694    fn drop(&mut self) {
695        // Signal that we're dropping - new callbacks will return early
696        if let Some(ref callback_data) = self.tool_callback_data {
697            callback_data.dropping.store(true, Ordering::SeqCst);
698
699            // Wait for any in-flight callbacks to complete (with timeout)
700            let mut attempts = 0;
701            while callback_data.active_callbacks.load(Ordering::SeqCst) > 0 && attempts < 100 {
702                std::thread::sleep(std::time::Duration::from_millis(10));
703                attempts += 1;
704            }
705        }
706
707        // Now safe to free the Swift session
708        unsafe {
709            ffi::fm_session_free(self.ptr.as_ptr());
710        }
711
712        // The Arc in tool_callback_data will be dropped automatically.
713        // Swift also holds an Arc clone (via Arc::into_raw), which will be
714        // reclaimed when Swift's ToolDispatcher is deallocated.
715    }
716}
717
// SAFETY: Session is a wrapper around a Swift object that uses
// DispatchQueue for thread safety internally.
// NOTE(review): the Swift-side synchronization is not visible from this file.
// This impl also covers `tool_callback_data`, which holds `Arc<dyn Tool>`
// values — confirm that `Tool` requires `Send + Sync`.
unsafe impl Send for Session {}

// Note: Session is NOT Sync because streaming callbacks use internal mutable state.
// If you need to share a session across threads, wrap it in Arc<Mutex<Session>>.
724
/// Type alias for the chunk callback function.
///
/// The callback receives each chunk as a `&str`.
type ChunkCallbackFn = dyn FnMut(&str) + Send;

/// Internal state for streaming callbacks.
///
/// A pointer to this struct is passed to the streaming FFI functions as the
/// opaque `user_data` argument and handed back to the `stream_*_callback`
/// functions.
struct StreamState {
    /// User callback invoked for every chunk; behind a mutex because the
    /// callbacks only get shared access to the state.
    on_chunk: Mutex<Box<ChunkCallbackFn>>,
    /// Error message recorded by the error callback, if one was reported.
    error: Mutex<Option<String>>,
}
733
734/// Callback invoked when a chunk arrives during streaming.
735extern "C" fn stream_chunk_callback(user_data: *mut c_void, chunk: *const c_char) {
736    if user_data.is_null() || chunk.is_null() {
737        return;
738    }
739
740    let state = unsafe { &*(user_data as *const StreamState) };
741    let chunk_str = unsafe { CStr::from_ptr(chunk).to_string_lossy() };
742
743    if let Ok(mut on_chunk) = state.on_chunk.lock() {
744        on_chunk(&chunk_str);
745    }
746}
747
/// Callback invoked when streaming is done.
///
/// Intentionally a no-op: the `user_data` pointer is ignored here.
extern "C" fn stream_done_callback(_user_data: *mut c_void) {
    // Nothing to do - state cleanup happens in stream_response
    // (and in stream_json, which reclaims its state the same way).
}
752
753/// Callback invoked on error during streaming.
754extern "C" fn stream_error_callback(user_data: *mut c_void, _code: c_int, message: *const c_char) {
755    if user_data.is_null() {
756        return;
757    }
758
759    let state = unsafe { &*(user_data as *const StreamState) };
760    let msg = if message.is_null() {
761        "Streaming error occurred (no message provided by Swift)".to_string()
762    } else {
763        unsafe { CStr::from_ptr(message).to_string_lossy().into_owned() }
764    };
765
766    if let Ok(mut error) = state.error.lock() {
767        *error = Some(msg);
768    }
769}
770
771/// Callback invoked when a tool needs to be called during session operations.
772/// This is used by Swift's `FFITool` to call back into Rust.
773extern "C" fn session_tool_callback(
774    user_data: *mut c_void,
775    tool_name: *const c_char,
776    arguments_json: *const c_char,
777) -> *mut c_char {
778    if user_data.is_null() || tool_name.is_null() {
779        let result = ToolResult::error("Invalid callback parameters");
780        return string_to_c(result.to_json());
781    }
782
783    // user_data is a raw pointer to Arc<ToolCallbackData> (from Arc::into_raw)
784    // SAFETY: Swift holds a reference to this Arc, keeping it alive.
785    // We must NOT consume the Arc here - just borrow it.
786    let callback_data = unsafe { &*(user_data as *const ToolCallbackData) };
787
788    // Check if session is being dropped - if so, return early
789    if callback_data.dropping.load(Ordering::SeqCst) {
790        let result = ToolResult::error("Session is being dropped");
791        return string_to_c(result.to_json());
792    }
793
794    // Track that we're in a callback (guard ensures cleanup on all exit paths)
795    callback_data
796        .active_callbacks
797        .fetch_add(1, Ordering::SeqCst);
798    let _guard = CallbackGuard(&callback_data.active_callbacks);
799
800    let name = unsafe { CStr::from_ptr(tool_name).to_string_lossy().into_owned() };
801    let args_str = if arguments_json.is_null() {
802        "{}".to_string()
803    } else {
804        unsafe {
805            CStr::from_ptr(arguments_json)
806                .to_string_lossy()
807                .into_owned()
808        }
809    };
810
811    // Parse arguments (with a best-effort auto-close for truncated JSON)
812    let arguments: serde_json::Value = match parse_tool_arguments(&args_str) {
813        Ok(v) => v,
814        Err(message) => {
815            let result = ToolResult::error(message);
816            return string_to_c(result.to_json());
817        }
818    };
819
820    // Find and call the tool
821    let Ok(tools) = callback_data.tools.lock() else {
822        let result = ToolResult::error("Failed to acquire tool lock");
823        return string_to_c(result.to_json());
824    };
825
826    let Some(tool) = tools.get(&name).map(Arc::clone) else {
827        let result = ToolResult::error(format!("Unknown tool: {name}"));
828        return string_to_c(result.to_json());
829    };
830
831    // Release the lock before calling the tool (it might take a while)
832    drop(tools);
833
834    // Invoke the tool
835    let result = match tool.call(arguments) {
836        Ok(output) => ToolResult::success(output),
837        Err(e) => ToolResult::error(e.to_string()),
838    };
839
840    string_to_c(result.to_json())
841}
842
/// Converts a Rust string into a heap-allocated, NUL-terminated C string.
///
/// Ownership of the returned pointer passes to the caller. Returns a null
/// pointer when `s` contains an interior NUL byte and therefore cannot be
/// represented as a C string.
fn string_to_c(s: String) -> *mut c_char {
    CString::new(s).map_or(ptr::null_mut(), CString::into_raw)
}
850
851fn parse_tool_arguments(input: &str) -> std::result::Result<serde_json::Value, String> {
852    match serde_json::from_str(input) {
853        Ok(value) => Ok(value),
854        Err(err) => {
855            if let Some(fixed) = autoclose_json(input) {
856                match serde_json::from_str(&fixed) {
857                    Ok(value) => {
858                        // Log when auto-close fixes truncated JSON (debug builds only)
859                        #[cfg(debug_assertions)]
860                        eprintln!(
861                            "[fm-rs] autoclose_json repaired truncated tool arguments: {input:?} -> {fixed:?}"
862                        );
863                        Ok(value)
864                    }
865                    Err(fixed_err) => Err(format!(
866                        "Failed to parse arguments: {err}; attempted fix: {fixed_err}"
867                    )),
868                }
869            } else {
870                Err(format!("Failed to parse arguments: {err}"))
871            }
872        }
873    }
874}
875
/// Maximum input size for `autoclose_json` to prevent resource exhaustion (1 MB).
const AUTOCLOSE_JSON_MAX_SIZE: usize = 1024 * 1024;

/// Appends the closing `}` / `]` characters that a truncated JSON text is missing.
///
/// Returns `None` when nothing can or needs to be repaired: the input is
/// larger than `AUTOCLOSE_JSON_MAX_SIZE`, a closing bracket does not match the
/// innermost open one, the input ends inside a string literal, or every
/// bracket is already closed. Brackets inside string literals are ignored and
/// backslash escapes are honoured. The result is not validated as JSON.
fn autoclose_json(input: &str) -> Option<String> {
    // Limit input size to prevent resource exhaustion attacks
    if input.len() > AUTOCLOSE_JSON_MAX_SIZE {
        return None;
    }

    // Closers still owed, innermost last.
    let mut pending: Vec<char> = Vec::new();
    let mut inside_string = false;
    let mut escaped = false;

    for ch in input.chars() {
        if inside_string {
            match ch {
                // The character after a backslash is consumed without inspection.
                _ if escaped => escaped = false,
                '\\' => escaped = true,
                '"' => inside_string = false,
                _ => {}
            }
            continue;
        }

        match ch {
            '"' => inside_string = true,
            '{' => pending.push('}'),
            '[' => pending.push(']'),
            '}' | ']' => {
                // A closer must match the innermost open bracket.
                if pending.pop() != Some(ch) {
                    return None;
                }
            }
            _ => {}
        }
    }

    if inside_string || pending.is_empty() {
        return None;
    }

    let mut repaired = String::with_capacity(input.len() + pending.len());
    repaired.push_str(input);
    repaired.extend(pending.iter().rev());
    Some(repaired)
}
933
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that every accessor of `Response` exposes the same text.
    #[test]
    fn test_response() {
        let response = Response::new("Hello, world!".to_string());
        assert_eq!(response.content(), "Hello, world!");
        assert_eq!(response.as_ref(), "Hello, world!");
        assert_eq!(format!("{response}"), "Hello, world!");
        // `into_content` consumes the response, so it is checked last.
        assert_eq!(response.into_content(), "Hello, world!");
    }
}