Skip to main content

foundation_models/session/
mod.rs

1//! [`LanguageModelSession`] — a stateful conversation with the on-device model.
2
3use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::panic::{catch_unwind, AssertUnwindSafe};
7use std::sync::mpsc;
8use std::sync::{Arc, Mutex};
9
10use serde::Deserialize;
11use serde_json::json;
12
13use crate::content::{BridgeGeneratedContent, GeneratedContent};
14use crate::error::FMError;
15use crate::ffi;
16use crate::generation::{GenerationOptions, SamplingMode};
17use crate::model::ConfiguredSystemLanguageModel;
18use crate::prompt::{Instructions, Prompt, ToInstructions, ToPrompt};
19use crate::schema::GenerationSchema;
20use crate::tool::{tool_callback_trampoline, Tool, ToolRegistry};
21use crate::transcript::Transcript;
22
23/// A stateful conversation with the on-device language model.
24///
25/// Sessions retain their conversation history; subsequent calls to
26/// [`respond`](Self::respond) build on the previous turns.
27///
28/// # Examples
29///
30/// ```rust,no_run
31/// use foundation_models::LanguageModelSession;
32///
33/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
34/// let session = LanguageModelSession::new();
35/// let answer = session.respond("Name three Norse gods.")?;
36/// println!("{answer}");
37/// # Ok(())
38/// # }
39/// ```
40pub struct LanguageModelSession {
41    ptr: *mut c_void,
42    _tool_registry: Option<Arc<ToolRegistry>>,
43}
44
45// SAFETY: The underlying Swift LanguageModelSession is reference-counted via
46// Unmanaged.passRetained on the Swift side; sending the opaque pointer between
47// threads is safe as long as we don't dereference it from Rust (we never do —
48// it only travels through extern "C" calls that internally hop to the
49// Swift concurrency executor).
50unsafe impl Send for LanguageModelSession {}
51unsafe impl Sync for LanguageModelSession {}
52
53impl LanguageModelSession {
54    /// Return the raw opaque pointer to the underlying Swift session object.
55    ///
56    /// Used internally by `async_api` to pass the session pointer to FFI
57    /// callbacks without exposing `ptr` as a public field.
58    #[cfg(feature = "async")]
59    pub(crate) fn as_ptr(&self) -> *mut c_void {
60        self.ptr
61    }
62
63    /// Create a session with the model's default behaviour.
64    ///
65    /// # Panics
66    ///
67    /// Panics if `FoundationModels` is not available on this OS. Check
68    /// [`crate::SystemLanguageModel::is_available`] first if you need to
69    /// handle that gracefully.
70    #[must_use]
71    pub fn new() -> Self {
72        Self::try_new(None).expect("FoundationModels is not available on this OS")
73    }
74
75    /// Create a session with custom system instructions ("system prompt").
76    ///
77    /// # Panics
78    ///
79    /// Panics if `FoundationModels` is not available, or if `instructions`
80    /// contains an interior NUL byte.
81    #[must_use]
82    pub fn with_instructions(instructions: &str) -> Self {
83        Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
84    }
85
86    /// Fallible constructor. Returns `None` when `FoundationModels` is not
87    /// available (OS too old, model not enabled, etc.) or when `instructions`
88    /// contains an interior NUL byte.
89    #[must_use]
90    pub fn try_new(instructions: Option<&str>) -> Option<Self> {
91        let cstring = match instructions {
92            Some(s) => Some(CString::new(s).ok()?),
93            None => None,
94        };
95        let ptr =
96            unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
97        if ptr.is_null() {
98            return None;
99        }
100        Some(Self {
101            ptr,
102            _tool_registry: None,
103        })
104    }
105
106    /// Send a prompt and block until the full response is available.
107    ///
108    /// # Errors
109    ///
110    /// Returns an [`FMError`] if the model rejects the prompt, the context
111    /// window is exceeded, the session is cancelled, or the prompt contains
112    /// an interior NUL byte.
113    pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
114        self.respond_with(prompt, GenerationOptions::new())
115    }
116
117    /// Pre-warm the model. Apple loads the weights + initialises the
118    /// inference engine so the next `respond` call is faster. Returns
119    /// immediately; the warm-up runs in the background.
120    pub fn prewarm(&self) {
121        unsafe { ffi::fm_session_prewarm(self.ptr) };
122    }
123
124    /// True if this session is currently producing a response (i.e. an
125    /// earlier `respond` / `stream` is still in flight on Apple's queue).
126    #[must_use]
127    pub fn is_responding(&self) -> bool {
128        unsafe { ffi::fm_session_is_responding(self.ptr) }
129    }
130
131    /// Return a best-effort JSON serialisation of the session's
132    /// `Transcript` — the full history of user prompts and model
133    /// responses. Useful for persisting a chat session across
134    /// process boundaries.
135    #[must_use]
136    pub fn transcript_json(&self) -> String {
137        let p = unsafe { ffi::fm_session_transcript_json(self.ptr) };
138        if p.is_null() {
139            return String::from("{}");
140        }
141        let s = unsafe { core::ffi::CStr::from_ptr(p) }
142            .to_string_lossy()
143            .into_owned();
144        unsafe { ffi::fm_string_free(p) };
145        s
146    }
147
148    /// Log feedback on the most recent response for diagnostic /
149    /// fine-tuning purposes. `sentiment`:
150    /// `1` positive, `0` neutral, `-1` negative.
151    pub fn log_feedback(&self, sentiment: i32, description: Option<&str>) {
152        let cstr = description.and_then(|s| CString::new(s).ok());
153        let p = cstr.as_ref().map_or(core::ptr::null(), |c| c.as_ptr());
154        unsafe { ffi::fm_session_log_feedback(self.ptr, sentiment, p) };
155    }
156
157    /// Prompt-engineered JSON-shape response.
158    ///
159    /// Wraps the prompt with a "respond with valid JSON matching this schema"
160    /// instruction and parses the response. The schema is a
161    /// `serde_json::Value`-style JSON string (passed as text).
162    ///
163    /// Useful for getting structured data out of the model without the
164    /// full Generable macro machinery. The model still returns plain
165    /// text — the caller must parse with `serde_json` / `serde` after.
166    ///
167    /// # Errors
168    ///
169    /// See [`respond`](Self::respond).
170    pub fn respond_with_json_schema(
171        &self,
172        prompt: &str,
173        schema_description: &str,
174    ) -> Result<String, FMError> {
175        let wrapped = format!(
176            "{prompt}\n\n\
177             IMPORTANT: respond with VALID JSON ONLY (no prose, no markdown \
178             fences) that matches this schema:\n\n{schema_description}\n\n\
179             Your entire response must be parseable by JSON.parse()."
180        );
181        self.respond(&wrapped)
182    }
183
184    /// Like [`respond`](Self::respond), but with explicit generation options.
185    ///
186    /// # Errors
187    ///
188    /// See [`respond`](Self::respond).
189    pub fn respond_with(
190        &self,
191        prompt: &str,
192        options: GenerationOptions,
193    ) -> Result<String, FMError> {
194        self.respond_prompt_with(prompt, options)
195    }
196
197    /// Schema-driven structured response.
198    ///
199    /// Builds a `DynamicGenerationSchema` from the provided JSON
200    /// schema, runs `LanguageModelSession.respond(schema:prompt:)`,
201    /// and returns the model's `GeneratedContent.jsonString` — a
202    /// well-formed JSON string matching the requested shape.
203    ///
204    /// Supported `schema` shape (strict subset of JSON Schema):
205    ///
206    /// ```json
207    /// {
208    ///   "type": "object",
209    ///   "name": "Movie",
210    ///   "properties": {
211    ///     "title":  { "type": "string", "description": "Movie title" },
212    ///     "year":   { "type": "integer" },
213    ///     "rating": { "type": "number", "optional": true },
214    ///     "tags":   { "type": "array", "items": { "type": "string" }, "min": 1, "max": 5 }
215    ///   }
216    /// }
217    /// ```
218    ///
219    /// Primitive types: `"string"`, `"integer"`, `"number"`,
220    /// `"boolean"`, `"array"`, `"object"`. Each property may set
221    /// `"description"` and `"optional"`. Array schemas accept
222    /// `"items"` plus optional `"min"` / `"max"` element counts.
223    ///
224    /// # Errors
225    ///
226    /// See [`respond`](Self::respond) for general errors, plus a
227    /// "schema build failed" / "schema JSON is not valid" error
228    /// returned as [`FMError::Unknown`] if the schema is malformed.
229    pub fn respond_with_schema(
230        &self,
231        prompt: &str,
232        schema: &str,
233        include_schema_in_prompt: bool,
234    ) -> Result<String, FMError> {
235        self.respond_with_schema_options(
236            prompt,
237            schema,
238            include_schema_in_prompt,
239            GenerationOptions::new(),
240        )
241    }
242
243    /// [`respond_with_schema`](Self::respond_with_schema) with
244    /// explicit generation options.
245    ///
246    /// # Errors
247    ///
248    /// See [`respond_with_schema`](Self::respond_with_schema).
249    pub fn respond_with_schema_options(
250        &self,
251        prompt: &str,
252        schema: &str,
253        include_schema_in_prompt: bool,
254        options: GenerationOptions,
255    ) -> Result<String, FMError> {
256        let prompt_c = CString::new(prompt)
257            .map_err(|e| FMError::InvalidArgument(format!("prompt NUL byte: {e}")))?;
258        let schema_c = CString::new(schema)
259            .map_err(|e| FMError::InvalidArgument(format!("schema NUL byte: {e}")))?;
260        let opts = options.to_ffi();
261        let (tx, rx) = mpsc::channel();
262        let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
263        let context = Box::into_raw(tx_box).cast::<c_void>();
264
265        unsafe {
266            ffi::fm_session_respond_with_schema(
267                self.ptr,
268                prompt_c.as_ptr(),
269                schema_c.as_ptr(),
270                include_schema_in_prompt,
271                opts.temperature,
272                opts.maximum_response_tokens,
273                opts.sampling_mode,
274                opts.top_k,
275                opts.top_p,
276                context,
277                respond_trampoline,
278            );
279        }
280
281        rx.recv().map_err(|_| FMError::Unknown {
282            code: ffi::status::UNKNOWN,
283            message: "Swift bridge dropped the callback channel".into(),
284        })?
285    }
286
287    /// Stream the response as the model generates it. The callback is invoked
288    /// with each delta and a final invocation with `done == true`.
289    ///
290    /// # Errors
291    ///
292    /// Returns an [`FMError`] mirroring [`respond`](Self::respond). The
293    /// callback may also receive a chunk *and* an error if the stream fails
294    /// midway.
295    pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
296    where
297        F: FnMut(StreamEvent<'_>) + Send + 'static,
298    {
299        self.stream_with(prompt, GenerationOptions::new(), move |event| {
300            on_chunk(event);
301        })
302    }
303
304    /// Like [`stream`](Self::stream), but with explicit generation options.
305    ///
306    /// # Errors
307    ///
308    /// See [`stream`](Self::stream).
309    pub fn stream_with<F>(
310        &self,
311        prompt: &str,
312        options: GenerationOptions,
313        on_chunk: F,
314    ) -> Result<(), FMError>
315    where
316        F: FnMut(StreamEvent<'_>) + Send + 'static,
317    {
318        let payload = respond_request_json(&Prompt::from(prompt), options, None, true)?;
319
320        let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
321        let state = Arc::new(StreamState {
322            on_chunk: Mutex::new(Box::new(on_chunk)),
323            done_tx: Mutex::new(Some(done_tx)),
324        });
325        let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
326
327        unsafe {
328            ffi::fm_session_stream_request_json(
329                self.ptr,
330                payload.as_ptr(),
331                context,
332                json_text_stream_trampoline,
333            )
334        };
335
336        done_rx.recv().map_err(|_| FMError::Unknown {
337            code: ffi::status::UNKNOWN,
338            message: "Swift bridge dropped the stream channel".into(),
339        })?
340    }
341}
342
343impl LanguageModelSession {
344    /// Create a configurable session builder.
345    #[must_use]
346    pub fn builder<'a>() -> SessionBuilder<'a> {
347        SessionBuilder::new()
348    }
349
350    /// Restore a session from a transcript.
351    ///
352    /// # Errors
353    ///
354    /// Returns an [`FMError`] if the transcript cannot be encoded for Swift.
355    pub fn from_transcript(transcript: Transcript) -> Result<Self, FMError> {
356        Self::builder().transcript(transcript).build()
357    }
358
359    /// Return the typed transcript for this session.
360    ///
361    /// # Errors
362    ///
363    /// Returns an [`FMError`] if the transcript JSON returned by Swift could not
364    /// be decoded.
365    pub fn transcript(&self) -> Result<Transcript, FMError> {
366        Transcript::from_json_str(&self.transcript_json())
367    }
368
369    /// Pre-warm the model using a prompt prefix.
370    ///
371    /// # Errors
372    ///
373    /// Returns an [`FMError`] if the prompt cannot be encoded for Swift.
374    pub fn prewarm_with_prompt<P>(&self, prompt: P) -> Result<(), FMError>
375    where
376        P: ToPrompt,
377    {
378        let prompt = prompt.to_prompt()?;
379        let prompt_json = CString::new(prompt.to_bridge_json()?).map_err(|error| {
380            FMError::InvalidArgument(format!("prompt JSON contains a NUL byte: {error}"))
381        })?;
382        let mut error: *mut c_char = ptr::null_mut();
383        let status = unsafe {
384            ffi::fm_session_prewarm_prompt_json(self.ptr, prompt_json.as_ptr(), &mut error)
385        };
386        if status != ffi::status::OK {
387            return Err(crate::error::from_swift(status, error));
388        }
389        Ok(())
390    }
391
392    /// Respond to a structured prompt and return only the generated text.
393    ///
394    /// # Errors
395    ///
396    /// Returns an [`FMError`] if generation fails.
397    pub fn respond_prompt<P>(&self, prompt: P) -> Result<String, FMError>
398    where
399        P: ToPrompt,
400    {
401        self.respond_prompt_with(prompt, GenerationOptions::new())
402    }
403
404    /// Like [`respond_prompt`](Self::respond_prompt), but with explicit options.
405    ///
406    /// # Errors
407    ///
408    /// Returns an [`FMError`] if generation fails.
409    pub fn respond_prompt_with<P>(
410        &self,
411        prompt: P,
412        options: GenerationOptions,
413    ) -> Result<String, FMError>
414    where
415        P: ToPrompt,
416    {
417        self.respond_prompt_detailed(prompt, options)
418            .map(|response| response.content)
419    }
420
421    /// Respond to a structured prompt and keep the full response metadata.
422    ///
423    /// # Errors
424    ///
425    /// Returns an [`FMError`] if generation fails.
426    pub fn respond_prompt_detailed<P>(
427        &self,
428        prompt: P,
429        options: GenerationOptions,
430    ) -> Result<SessionResponse<String>, FMError>
431    where
432        P: ToPrompt,
433    {
434        let prompt = prompt.to_prompt()?;
435        let payload = respond_request_json(&prompt, options, None, true)?;
436        let payload = request_response(self.ptr, &payload)?;
437        let response: BridgeTextResponse = serde_json::from_str(&payload)
438            .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
439        Ok(SessionResponse {
440            content: response.content,
441            raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
442            transcript: Transcript::from_json_str(&response.transcript_json)?,
443        })
444    }
445
446    /// Generate structured content using an explicit schema.
447    ///
448    /// # Errors
449    ///
450    /// Returns an [`FMError`] if generation fails or the schema is invalid.
451    pub fn respond_generated<P>(
452        &self,
453        prompt: P,
454        schema: &GenerationSchema,
455        include_schema_in_prompt: bool,
456    ) -> Result<GeneratedContent, FMError>
457    where
458        P: ToPrompt,
459    {
460        self.respond_generated_with(
461            prompt,
462            schema,
463            include_schema_in_prompt,
464            GenerationOptions::new(),
465        )
466        .map(|response| response.content)
467    }
468
469    /// Like [`respond_generated`](Self::respond_generated), but with explicit options.
470    ///
471    /// # Errors
472    ///
473    /// Returns an [`FMError`] if generation fails or the schema is invalid.
474    pub fn respond_generated_with<P>(
475        &self,
476        prompt: P,
477        schema: &GenerationSchema,
478        include_schema_in_prompt: bool,
479        options: GenerationOptions,
480    ) -> Result<SessionResponse<GeneratedContent>, FMError>
481    where
482        P: ToPrompt,
483    {
484        let prompt = prompt.to_prompt()?;
485        let payload =
486            respond_request_json(&prompt, options, Some(schema), include_schema_in_prompt)?;
487        let payload = request_response(self.ptr, &payload)?;
488        let response: BridgeStructuredResponse = serde_json::from_str(&payload)
489            .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
490        Ok(SessionResponse {
491            content: GeneratedContent::from_bridge_payload(response.content, true)?,
492            raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
493            transcript: Transcript::from_json_str(&response.transcript_json)?,
494        })
495    }
496
497    /// Generate a typed Rust value using a [`crate::schema::Generable`] implementation.
498    ///
499    /// # Errors
500    ///
501    /// Returns an [`FMError`] if generation fails or the generated JSON cannot
502    /// be decoded as `T`.
503    pub fn respond_generating<P, T>(
504        &self,
505        prompt: P,
506        include_schema_in_prompt: bool,
507        options: GenerationOptions,
508    ) -> Result<SessionResponse<T>, FMError>
509    where
510        P: ToPrompt,
511        T: crate::schema::Generable,
512    {
513        let response = self.respond_generated_with(
514            prompt,
515            &T::generation_schema()?,
516            include_schema_in_prompt,
517            options,
518        )?;
519        Ok(SessionResponse {
520            content: T::from_generated_content(&response.content)?,
521            raw_content: response.raw_content,
522            transcript: response.transcript,
523        })
524    }
525
526    /// Stream a structured prompt token-by-token.
527    ///
528    /// # Errors
529    ///
530    /// Returns an [`FMError`] if the prompt cannot be encoded or generation fails.
531    pub fn stream_prompt<P, F>(&self, prompt: P, on_chunk: F) -> Result<(), FMError>
532    where
533        P: ToPrompt,
534        F: FnMut(StreamEvent<'_>) + Send + 'static,
535    {
536        let prompt = prompt.to_prompt()?;
537        let prompt_text = prompt_to_plain_text(&prompt).ok_or_else(|| {
538            FMError::InvalidArgument(
539                "text streaming only supports prompts composed of text segments".into(),
540            )
541        })?;
542        self.stream_with(&prompt_text, GenerationOptions::new(), on_chunk)
543    }
544
545    /// Stream structured generation snapshots.
546    ///
547    /// # Errors
548    ///
549    /// Returns an [`FMError`] if the prompt cannot be encoded or generation fails.
550    pub fn stream_generated<P, F>(
551        &self,
552        prompt: P,
553        schema: &GenerationSchema,
554        include_schema_in_prompt: bool,
555        options: GenerationOptions,
556        on_event: F,
557    ) -> Result<(), FMError>
558    where
559        P: ToPrompt,
560        F: FnMut(StructuredStreamEvent) + Send + 'static,
561    {
562        let prompt = prompt.to_prompt()?;
563        let payload =
564            respond_request_json(&prompt, options, Some(schema), include_schema_in_prompt)?;
565        let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
566        let state = Arc::new(StructuredStreamState {
567            on_event: Mutex::new(Box::new(on_event)),
568            done_tx: Mutex::new(Some(done_tx)),
569        });
570        let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
571        unsafe {
572            ffi::fm_session_stream_request_json(
573                self.ptr,
574                payload.as_ptr(),
575                context,
576                structured_stream_trampoline,
577            )
578        };
579        done_rx.recv().map_err(|_| FMError::Unknown {
580            code: ffi::status::UNKNOWN,
581            message: "Swift bridge dropped the structured stream channel".into(),
582        })?
583    }
584
585    /// Log a feedback attachment and return the raw bytes Apple produced.
586    ///
587    /// # Errors
588    ///
589    /// Returns an [`FMError`] if the attachment request is invalid.
590    pub fn log_feedback_attachment(
591        &self,
592        request: FeedbackAttachmentRequest,
593    ) -> Result<Vec<u8>, FMError> {
594        let request_json = CString::new(request.to_bridge_json()?).map_err(|error| {
595            FMError::InvalidArgument(format!("feedback request contains a NUL byte: {error}"))
596        })?;
597        let mut length = 0usize;
598        let mut error: *mut c_char = ptr::null_mut();
599        let ptr = unsafe {
600            ffi::fm_session_log_feedback_attachment_json(
601                self.ptr,
602                request_json.as_ptr(),
603                &mut length,
604                &mut error,
605            )
606        };
607        if ptr.is_null() && !error.is_null() {
608            return Err(crate::error::from_swift(
609                ffi::status::INVALID_ARGUMENT,
610                error,
611            ));
612        }
613        if ptr.is_null() || length == 0 {
614            return Ok(Vec::new());
615        }
616        let bytes = unsafe { std::slice::from_raw_parts(ptr.cast::<u8>(), length) }.to_vec();
617        unsafe { ffi::fm_bytes_free(ptr) };
618        Ok(bytes)
619    }
620}
621
622/// Builder for [`LanguageModelSession`].
623pub struct SessionBuilder<'a> {
624    model: Option<&'a ConfiguredSystemLanguageModel>,
625    instructions: Option<Instructions>,
626    transcript: Option<Transcript>,
627    tools: Vec<Tool>,
628}
629
630impl<'a> SessionBuilder<'a> {
631    const fn new() -> Self {
632        Self {
633            model: None,
634            instructions: None,
635            transcript: None,
636            tools: Vec::new(),
637        }
638    }
639
640    /// Use a configured system model.
641    #[must_use]
642    pub const fn model(mut self, model: &'a ConfiguredSystemLanguageModel) -> Self {
643        self.model = Some(model);
644        self
645    }
646
647    /// Set system instructions.
648    pub fn instructions<I>(mut self, instructions: I) -> Result<Self, FMError>
649    where
650        I: ToInstructions,
651    {
652        self.instructions = Some(instructions.to_instructions()?);
653        Ok(self)
654    }
655
656    /// Restore the session from a transcript.
657    #[must_use]
658    pub fn transcript(mut self, transcript: Transcript) -> Self {
659        self.transcript = Some(transcript);
660        self
661    }
662
663    /// Add one tool.
664    #[must_use]
665    pub fn tool(mut self, tool: Tool) -> Self {
666        self.tools.push(tool);
667        self
668    }
669
670    /// Add many tools.
671    #[must_use]
672    pub fn tools(mut self, tools: impl IntoIterator<Item = Tool>) -> Self {
673        self.tools.extend(tools);
674        self
675    }
676
677    /// Build the session.
678    ///
679    /// # Errors
680    ///
681    /// Returns an [`FMError`] if the configuration cannot be encoded for Swift.
682    pub fn build(self) -> Result<LanguageModelSession, FMError> {
683        if self.instructions.is_some() && self.transcript.is_some() {
684            return Err(FMError::InvalidArgument(
685                "session builder accepts either instructions or a transcript, not both".into(),
686            ));
687        }
688
689        let instructions_json = self
690            .instructions
691            .as_ref()
692            .map(Instructions::to_bridge_json)
693            .transpose()?;
694        let transcript_json = self
695            .transcript
696            .as_ref()
697            .map(Transcript::to_json_string)
698            .transpose()?;
699        let tool_registry = if self.tools.is_empty() {
700            None
701        } else {
702            Some(Arc::new(ToolRegistry::new(self.tools)))
703        };
704        let tools_json = tool_registry
705            .as_ref()
706            .map(|registry| registry.specs_json())
707            .transpose()?;
708
709        let instructions_c = instructions_json
710            .as_deref()
711            .map(CString::new)
712            .transpose()
713            .map_err(|error| {
714                FMError::InvalidArgument(format!("instructions JSON contains a NUL byte: {error}"))
715            })?;
716        let transcript_c = transcript_json
717            .as_deref()
718            .map(CString::new)
719            .transpose()
720            .map_err(|error| {
721                FMError::InvalidArgument(format!("transcript JSON contains a NUL byte: {error}"))
722            })?;
723        let tools_c = tools_json
724            .as_deref()
725            .map(CString::new)
726            .transpose()
727            .map_err(|error| {
728                FMError::InvalidArgument(format!("tool JSON contains a NUL byte: {error}"))
729            })?;
730
731        let tool_context = tool_registry.as_ref().map_or(ptr::null_mut(), |registry| {
732            Arc::as_ptr(registry).cast_mut().cast::<c_void>()
733        });
734        let mut error: *mut c_char = ptr::null_mut();
735        let ptr = unsafe {
736            ffi::fm_session_create_ex(
737                self.model.map_or(ptr::null_mut(), |model| model.ptr),
738                instructions_c
739                    .as_ref()
740                    .map_or(ptr::null(), |json| json.as_ptr()),
741                transcript_c
742                    .as_ref()
743                    .map_or(ptr::null(), |json| json.as_ptr()),
744                tools_c.as_ref().map_or(ptr::null(), |json| json.as_ptr()),
745                tool_context,
746                tool_registry
747                    .as_ref()
748                    .map(|_| tool_callback_trampoline as ffi::FmToolCallback),
749                &mut error,
750            )
751        };
752        if ptr.is_null() {
753            return Err(crate::error::from_swift(
754                ffi::status::MODEL_UNAVAILABLE,
755                error,
756            ));
757        }
758        Ok(LanguageModelSession {
759            ptr,
760            _tool_registry: tool_registry,
761        })
762    }
763}
764
765/// A detailed generation response.
766#[derive(Debug, Clone, PartialEq)]
767pub struct SessionResponse<T> {
768    pub content: T,
769    pub raw_content: GeneratedContent,
770    pub transcript: Transcript,
771}
772
/// A single snapshot emitted while streaming structured generation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StructuredStreamSnapshot {
    /// JSON text of the content generated so far.
    pub content_json: String,
    /// JSON text of the raw generated content backing this snapshot.
    pub raw_content_json: String,
    /// Whether this snapshot is marked as complete.
    pub is_complete: bool,
}
780
781/// One structured stream event.
782#[derive(Debug, Clone, PartialEq)]
783#[non_exhaustive]
784pub enum StructuredStreamEvent {
785    Snapshot(StructuredStreamSnapshot),
786    Done,
787    Error(FMError),
788}
789
/// Category of a problem reported in a feedback attachment.
///
/// Each variant maps to a fixed snake_case identifier in the bridge JSON.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackIssueCategory {
    /// Encoded as `"unhelpful"`.
    Unhelpful,
    /// Encoded as `"too_verbose"`.
    TooVerbose,
    /// Encoded as `"did_not_follow_instructions"`.
    DidNotFollowInstructions,
    /// Encoded as `"incorrect"`.
    Incorrect,
    /// Encoded as `"stereotype_or_bias"`.
    StereotypeOrBias,
    /// Encoded as `"suggestive_or_sexual"`.
    SuggestiveOrSexual,
    /// Encoded as `"vulgar_or_offensive"`.
    VulgarOrOffensive,
    /// Encoded as `"triggered_guardrail_unexpectedly"`.
    TriggeredGuardrailUnexpectedly,
}

impl FeedbackIssueCategory {
    /// The identifier sent to the Swift bridge for this category.
    const fn as_str(self) -> &'static str {
        let wire_name = match self {
            FeedbackIssueCategory::Unhelpful => "unhelpful",
            FeedbackIssueCategory::TooVerbose => "too_verbose",
            FeedbackIssueCategory::DidNotFollowInstructions => "did_not_follow_instructions",
            FeedbackIssueCategory::Incorrect => "incorrect",
            FeedbackIssueCategory::StereotypeOrBias => "stereotype_or_bias",
            FeedbackIssueCategory::SuggestiveOrSexual => "suggestive_or_sexual",
            FeedbackIssueCategory::VulgarOrOffensive => "vulgar_or_offensive",
            FeedbackIssueCategory::TriggeredGuardrailUnexpectedly => {
                "triggered_guardrail_unexpectedly"
            }
        };
        wire_name
    }
}
817
818/// One feedback issue.
819#[derive(Debug, Clone, PartialEq, Eq)]
820pub struct FeedbackIssue {
821    pub category: FeedbackIssueCategory,
822    pub explanation: Option<String>,
823}
824
/// Overall sentiment attached to a feedback request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackSentiment {
    /// Encoded as `"positive"`.
    Positive,
    /// Encoded as `"negative"`.
    Negative,
    /// Encoded as `"neutral"`.
    Neutral,
}

impl FeedbackSentiment {
    /// The identifier sent to the Swift bridge for this sentiment.
    const fn as_str(self) -> &'static str {
        let wire_name = match self {
            FeedbackSentiment::Positive => "positive",
            FeedbackSentiment::Negative => "negative",
            FeedbackSentiment::Neutral => "neutral",
        };
        wire_name
    }
}
842
843/// A full feedback attachment request.
844#[derive(Debug, Clone, PartialEq)]
845pub struct FeedbackAttachmentRequest {
846    pub sentiment: Option<FeedbackSentiment>,
847    pub issues: Vec<FeedbackIssue>,
848    pub desired_response_text: Option<String>,
849    pub desired_response_content: Option<GeneratedContent>,
850    pub desired_output: Option<crate::transcript::Entry>,
851}
852
853impl FeedbackAttachmentRequest {
854    /// Create an empty feedback request.
855    #[must_use]
856    pub const fn new() -> Self {
857        Self {
858            sentiment: None,
859            issues: Vec::new(),
860            desired_response_text: None,
861            desired_response_content: None,
862            desired_output: None,
863        }
864    }
865
866    fn to_bridge_json(&self) -> Result<String, FMError> {
867        let issues = self
868            .issues
869            .iter()
870            .map(|issue| {
871                json!({
872                    "category": issue.category.as_str(),
873                    "explanation": issue.explanation,
874                })
875            })
876            .collect::<Vec<_>>();
877        let desired_output_json = self
878            .desired_output
879            .as_ref()
880            .map(|entry| Transcript::from(vec![entry.clone()]).to_json_string())
881            .transpose()?;
882        let desired_response_content = self
883            .desired_response_content
884            .as_ref()
885            .map(GeneratedContent::to_bridge_value)
886            .transpose()?;
887        serde_json::to_string(&json!({
888            "sentiment": self.sentiment.map(FeedbackSentiment::as_str),
889            "issues": issues,
890            "desiredResponseText": self.desired_response_text,
891            "desiredResponseContent": desired_response_content,
892            "desiredOutputTranscriptJSON": desired_output_json,
893        }))
894        .map_err(|error| {
895            FMError::InvalidArgument(format!(
896                "feedback request is not JSON-serializable: {error}"
897            ))
898        })
899    }
900}
901
/// Wire format of a plain-text response from the Swift bridge.
///
/// Decoded by `decode_bridge_text_response`, which turns it into a
/// `SessionResponse<String>`.
#[derive(Debug, Deserialize)]
struct BridgeTextResponse {
    /// The response text, passed through unchanged.
    content: String,
    /// Raw generated content; converted with `GeneratedContent::from_bridge_payload`.
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    /// The session transcript encoded as a JSON string; parsed with
    /// `Transcript::from_json_str`.
    #[serde(rename = "transcriptJSON")]
    transcript_json: String,
}
910
/// Wire format of a structured response from the Swift bridge.
///
/// NOTE(review): the decoder for this type is not in this part of the file;
/// presumably it is the schema-guided respond path — confirm.
#[derive(Debug, Deserialize)]
struct BridgeStructuredResponse {
    /// The structured content in bridge form.
    content: BridgeGeneratedContent,
    /// Raw generated content in bridge form.
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    /// The session transcript encoded as a JSON string.
    #[serde(rename = "transcriptJSON")]
    transcript_json: String,
}
919
/// One snapshot of a structured stream, decoded by `structured_stream_trampoline`.
#[derive(Debug, Deserialize)]
struct BridgeStructuredSnapshot {
    /// Structured content; its `json` is forwarded as `content_json`.
    content: BridgeGeneratedContent,
    /// Raw content; its `json` is forwarded as `raw_content_json`.
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    /// Completion flag reported by the bridge; forwarded as `is_complete`.
    #[serde(rename = "isComplete")]
    is_complete: bool,
}
928
/// One event of a text stream, decoded by `json_text_stream_trampoline`.
#[derive(Debug, Deserialize)]
struct BridgeTextStreamSnapshot {
    /// Incremental text for this event; empty deltas are not forwarded to
    /// the user callback.
    delta: String,
}
933
934fn respond_request_json(
935    prompt: &Prompt,
936    options: GenerationOptions,
937    schema: Option<&GenerationSchema>,
938    include_schema_in_prompt: bool,
939) -> Result<CString, FMError> {
940    let sampling = match options.sampling() {
941        SamplingMode::Default => json!({ "mode": "default" }),
942        SamplingMode::Greedy => json!({ "mode": "greedy" }),
943        SamplingMode::TopK(k) => json!({
944            "mode": "top_k",
945            "topK": k,
946            "seed": options.sampling_seed(),
947        }),
948        SamplingMode::TopP(p) => json!({
949            "mode": "top_p",
950            "topP": p,
951            "seed": options.sampling_seed(),
952        }),
953    };
954    let include_schema_in_prompt = schema.map_or(include_schema_in_prompt, |schema| {
955        schema.effective_include_schema_in_prompt(include_schema_in_prompt)
956    });
957    let payload = serde_json::to_string(&json!({
958        "prompt": prompt.to_bridge_value(),
959        "options": {
960            "temperature": options.temperature(),
961            "maximumResponseTokens": options.maximum_response_tokens(),
962            "sampling": sampling,
963        },
964        "schemaJSON": schema.map(GenerationSchema::bridge_request_json),
965        "includeSchemaInPrompt": include_schema_in_prompt,
966    }))
967    .map_err(|error| {
968        FMError::InvalidArgument(format!("request is not JSON-serializable: {error}"))
969    })?;
970    CString::new(payload).map_err(|error| {
971        FMError::InvalidArgument(format!("request JSON contains a NUL byte: {error}"))
972    })
973}
974
975fn request_response(session: *mut c_void, payload: &CString) -> Result<String, FMError> {
976    let (tx, rx) = mpsc::channel();
977    let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
978    let context = Box::into_raw(tx_box).cast::<c_void>();
979    unsafe {
980        ffi::fm_session_respond_request_json(session, payload.as_ptr(), context, respond_trampoline)
981    };
982    rx.recv().map_err(|_| FMError::Unknown {
983        code: ffi::status::UNKNOWN,
984        message: "Swift bridge dropped the JSON response channel".into(),
985    })?
986}
987
988pub(crate) fn decode_bridge_text_response(
989    payload: &str,
990) -> Result<SessionResponse<String>, FMError> {
991    let response: BridgeTextResponse = serde_json::from_str(payload)
992        .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
993    Ok(SessionResponse {
994        content: response.content,
995        raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
996        transcript: Transcript::from_json_str(&response.transcript_json)?,
997    })
998}
999
1000pub(crate) fn request_text_response_with<F>(invoke: F) -> Result<SessionResponse<String>, FMError>
1001where
1002    F: FnOnce(*mut c_void, ffi::FmRespondCallback),
1003{
1004    let (tx, rx) = mpsc::channel();
1005    let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
1006    let context = Box::into_raw(tx_box).cast::<c_void>();
1007    invoke(context, respond_trampoline);
1008    let payload = rx.recv().map_err(|_| FMError::Unknown {
1009        code: ffi::status::UNKNOWN,
1010        message: "Swift bridge dropped the JSON response channel".into(),
1011    })??;
1012    decode_bridge_text_response(&payload)
1013}
1014
1015pub(crate) fn run_text_stream_with<F, C>(invoke: F, on_chunk: C) -> Result<(), FMError>
1016where
1017    F: FnOnce(*mut c_void, ffi::FmStreamCallback),
1018    C: FnMut(StreamEvent<'_>) + Send + 'static,
1019{
1020    let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
1021    let state = Arc::new(StreamState {
1022        on_chunk: Mutex::new(Box::new(on_chunk)),
1023        done_tx: Mutex::new(Some(done_tx)),
1024    });
1025    let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
1026    invoke(context, json_text_stream_trampoline);
1027    done_rx.recv().map_err(|_| FMError::Unknown {
1028        code: ffi::status::UNKNOWN,
1029        message: "Swift bridge dropped the stream channel".into(),
1030    })?
1031}
1032
1033fn prompt_to_plain_text(prompt: &Prompt) -> Option<String> {
1034    let mut text = String::new();
1035    for segment in prompt.segments() {
1036        match segment {
1037            crate::prompt::Segment::Text(segment) => text.push_str(&segment.text),
1038            crate::prompt::Segment::Structure(_) => return None,
1039        }
1040    }
1041    Some(text)
1042}
1043
1044impl Default for LanguageModelSession {
1045    fn default() -> Self {
1046        Self::new()
1047    }
1048}
1049
1050impl Drop for LanguageModelSession {
1051    fn drop(&mut self) {
1052        if !self.ptr.is_null() {
1053            unsafe { ffi::fm_object_release(self.ptr) };
1054        }
1055    }
1056}
1057
1058impl core::fmt::Debug for LanguageModelSession {
1059    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1060        f.debug_struct("LanguageModelSession")
1061            .field("ptr", &self.ptr)
1062            .finish()
1063    }
1064}
1065
/// One event from a streaming generation.
///
/// Events are delivered to the stream callback (for example the `on_chunk`
/// closure given to `run_text_stream_with`).
/// - A successful stream yields zero or more [`StreamEvent::Chunk`]s, then
///   [`StreamEvent::Done`].
/// - A stream that fails on the bridge side, or whose payload cannot be
///   decoded, ends with [`StreamEvent::Error`].
/// - If the callback itself panics while handling a chunk, the stream stops
///   without delivering a further event.
#[derive(Debug)]
#[non_exhaustive]
pub enum StreamEvent<'a> {
    /// Incremental text delta. Concatenate these to reconstruct the full reply.
    ///
    /// The string is only borrowed for the duration of the callback; copy it
    /// (e.g. with `to_owned`) to keep it. Empty deltas are not delivered.
    Chunk(&'a str),
    /// Stream finished successfully.
    Done,
    /// Stream failed; the inner error describes why.
    Error(FMError),
}
1077
1078// ---------- internal callback plumbing ----------
1079
1080// SAFETY: `context` is a `Box<mpsc::Sender<...>>` raw pointer created by
1081// `request_response` / `request_text_response_with`. Swift calls this callback
1082// exactly once, so there is no double-free risk. `response` and `error` are
1083// C strings owned by the Swift bridge and only valid for this call.
1084unsafe extern "C" fn respond_trampoline(
1085    context: *mut c_void,
1086    response: *mut c_char,
1087    error: *mut c_char,
1088    status: i32,
1089) {
1090    let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
1091    let result = if status == ffi::status::OK && !response.is_null() {
1092        let s = core::ffi::CStr::from_ptr(response)
1093            .to_string_lossy()
1094            .into_owned();
1095        ffi::fm_string_free(response);
1096        Ok(s)
1097    } else {
1098        Err(crate::error::from_swift(status, error))
1099    };
1100    let _ = tx.send(result);
1101}
1102
/// Boxed user callback that receives text-stream events.
type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;

/// Shared state behind the `context` pointer of a text stream.
///
/// Built (e.g. in `run_text_stream_with`) inside an `Arc` and handed to the
/// bridge via `Arc::into_raw`; `json_text_stream_trampoline` reads it.
struct StreamState {
    /// User callback, locked for each event so it never runs concurrently.
    on_chunk: Mutex<StreamCallback>,
    /// Completion sender. It is `take`n (leaving `None`) when the stream
    /// outcome is reported, so the outcome is sent at most once.
    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
}
1109
1110// SAFETY: `context` is a `Arc<StreamState>` raw pointer passed via
1111// `Arc::into_raw`. We reconstruct it with `Arc::from_raw` on every call and
1112// immediately `mem::forget` a clone so the count stays ≥ 1 until the
1113// terminal call (done=true or error). `chunk` is a Swift-owned C string valid
1114// only for the duration of this call.
1115unsafe extern "C" fn json_text_stream_trampoline(
1116    context: *mut c_void,
1117    chunk: *mut c_char,
1118    done: bool,
1119    status: i32,
1120) {
1121    let state = Arc::from_raw(context.cast::<StreamState>());
1122    let state_for_swift = state.clone();
1123    core::mem::forget(state_for_swift);
1124
1125    let payload: Option<String> = if chunk.is_null() {
1126        None
1127    } else {
1128        let value = core::ffi::CStr::from_ptr(chunk)
1129            .to_string_lossy()
1130            .into_owned();
1131        ffi::fm_string_free(chunk);
1132        Some(value)
1133    };
1134
1135    if status != ffi::status::OK {
1136        let err = payload
1137            .map(|message| {
1138                crate::error::from_swift(
1139                    status,
1140                    ffi::fm_string_dup(
1141                        CString::new(message)
1142                            .expect("stream errors must not contain NUL bytes")
1143                            .as_ptr(),
1144                    ),
1145                )
1146            })
1147            .unwrap_or_else(|| crate::error::from_swift(status, ptr::null_mut()));
1148        {
1149            let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1150            // Catch panics so they don't unwind across the FFI boundary (UB).
1151            let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Error(err.clone()))));
1152        }
1153        if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1154            let _ = tx.send(Err(err));
1155        }
1156        drop(Arc::from_raw(Arc::as_ptr(&state)));
1157        drop(state);
1158        return;
1159    }
1160
1161    if let Some(payload) = payload {
1162        match serde_json::from_str::<BridgeTextStreamSnapshot>(&payload) {
1163            Ok(snapshot) if !snapshot.delta.is_empty() => {
1164                let chunk_panicked = {
1165                    let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1166                    // Catch panics so they don't unwind across the FFI boundary.
1167                    catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Chunk(&snapshot.delta))))
1168                        .is_err()
1169                };
1170                if chunk_panicked {
1171                    if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1172                        let _ = tx.send(Err(FMError::Unknown {
1173                            code: ffi::status::UNKNOWN,
1174                            message: "stream callback panicked".into(),
1175                        }));
1176                    }
1177                    drop(Arc::from_raw(Arc::as_ptr(&state)));
1178                    drop(state);
1179                    return;
1180                }
1181            }
1182            Ok(_) => {}
1183            Err(error) => {
1184                let err = FMError::DecodingFailure(error.to_string());
1185                {
1186                    let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1187                    let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Error(err.clone()))));
1188                }
1189                if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1190                    let _ = tx.send(Err(err));
1191                }
1192                drop(Arc::from_raw(Arc::as_ptr(&state)));
1193                drop(state);
1194                return;
1195            }
1196        }
1197    }
1198
1199    if done {
1200        {
1201            let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1202            let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Done)));
1203        }
1204        if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1205            let _ = tx.send(Ok(()));
1206        }
1207        drop(Arc::from_raw(Arc::as_ptr(&state)));
1208    }
1209    drop(state);
1210}
1211
/// Boxed user callback that receives structured-stream events.
type StructuredStreamCallback = Box<dyn FnMut(StructuredStreamEvent) + Send>;

/// Shared state behind the `context` pointer of a structured stream, read by
/// `structured_stream_trampoline`.
///
/// NOTE(review): the code that builds this state is not in this part of the
/// file; presumably it mirrors `run_text_stream_with` (`Arc::into_raw`).
struct StructuredStreamState {
    /// User callback, locked for each event so it never runs concurrently.
    on_event: Mutex<StructuredStreamCallback>,
    /// Completion sender. It is `take`n (leaving `None`) when the stream
    /// outcome is reported, so the outcome is sent at most once.
    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
}
1218
1219// SAFETY: Same invariants as `json_text_stream_trampoline` above, but for
1220// `StructuredStreamState`.
1221#[allow(clippy::too_many_lines)]
1222unsafe extern "C" fn structured_stream_trampoline(
1223    context: *mut c_void,
1224    chunk: *mut c_char,
1225    done: bool,
1226    status: i32,
1227) {
1228    let state = Arc::from_raw(context.cast::<StructuredStreamState>());
1229    let state_for_swift = state.clone();
1230    core::mem::forget(state_for_swift);
1231
1232    let payload: Option<String> = if chunk.is_null() {
1233        None
1234    } else {
1235        let value = core::ffi::CStr::from_ptr(chunk)
1236            .to_string_lossy()
1237            .into_owned();
1238        ffi::fm_string_free(chunk);
1239        Some(value)
1240    };
1241
1242    if status != ffi::status::OK {
1243        let err = payload
1244            .map(|message| {
1245                crate::error::from_swift(
1246                    status,
1247                    ffi::fm_string_dup(
1248                        CString::new(message)
1249                            .expect("stream errors must not contain NUL bytes")
1250                            .as_ptr(),
1251                    ),
1252                )
1253            })
1254            .unwrap_or_else(|| crate::error::from_swift(status, ptr::null_mut()));
1255        {
1256            let mut cb = state
1257                .on_event
1258                .lock()
1259                .expect("structured callback mutex poisoned");
1260            // Catch panics so they don't unwind across the FFI boundary (UB).
1261            let _ = catch_unwind(AssertUnwindSafe(|| {
1262                cb(StructuredStreamEvent::Error(err.clone()));
1263            }));
1264        }
1265        if let Some(tx) = state
1266            .done_tx
1267            .lock()
1268            .expect("structured done_tx mutex poisoned")
1269            .take()
1270        {
1271            let _ = tx.send(Err(err));
1272        }
1273        drop(Arc::from_raw(Arc::as_ptr(&state)));
1274        drop(state);
1275        return;
1276    }
1277
1278    if let Some(payload) = payload {
1279        let snapshot: BridgeStructuredSnapshot = match serde_json::from_str(&payload) {
1280            Ok(snapshot) => snapshot,
1281            Err(error) => {
1282                let err = FMError::DecodingFailure(error.to_string());
1283                {
1284                    let mut cb = state
1285                        .on_event
1286                        .lock()
1287                        .expect("structured callback mutex poisoned");
1288                    let _ = catch_unwind(AssertUnwindSafe(|| {
1289                        cb(StructuredStreamEvent::Error(err.clone()));
1290                    }));
1291                }
1292                if let Some(tx) = state
1293                    .done_tx
1294                    .lock()
1295                    .expect("structured done_tx mutex poisoned")
1296                    .take()
1297                {
1298                    let _ = tx.send(Err(err));
1299                }
1300                drop(Arc::from_raw(Arc::as_ptr(&state)));
1301                drop(state);
1302                return;
1303            }
1304        };
1305        let snapshot_event = StructuredStreamEvent::Snapshot(StructuredStreamSnapshot {
1306            content_json: snapshot.content.json,
1307            raw_content_json: snapshot.raw_content.json,
1308            is_complete: snapshot.is_complete,
1309        });
1310        let snapshot_panicked = {
1311            let mut cb = state
1312                .on_event
1313                .lock()
1314                .expect("structured callback mutex poisoned");
1315            // Catch panics so they don't unwind across the FFI boundary.
1316            catch_unwind(AssertUnwindSafe(|| cb(snapshot_event))).is_err()
1317        };
1318        if snapshot_panicked {
1319            if let Some(tx) = state
1320                .done_tx
1321                .lock()
1322                .expect("structured done_tx mutex poisoned")
1323                .take()
1324            {
1325                let _ = tx.send(Err(FMError::Unknown {
1326                    code: ffi::status::UNKNOWN,
1327                    message: "stream callback panicked".into(),
1328                }));
1329            }
1330            drop(Arc::from_raw(Arc::as_ptr(&state)));
1331            drop(state);
1332            return;
1333        }
1334    }
1335
1336    if done {
1337        {
1338            let mut cb = state
1339                .on_event
1340                .lock()
1341                .expect("structured callback mutex poisoned");
1342            let _ = catch_unwind(AssertUnwindSafe(|| cb(StructuredStreamEvent::Done)));
1343        }
1344        if let Some(tx) = state
1345            .done_tx
1346            .lock()
1347            .expect("structured done_tx mutex poisoned")
1348            .take()
1349        {
1350            let _ = tx.send(Ok(()));
1351        }
1352        drop(Arc::from_raw(Arc::as_ptr(&state)));
1353    }
1354    drop(state);
1355}