Skip to main content

foundation_models/session/
mod.rs

1//! [`LanguageModelSession`] — a stateful conversation with the on-device model.
2
3use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::panic::{catch_unwind, AssertUnwindSafe};
7use std::sync::mpsc;
8use std::sync::{Arc, Mutex};
9
10use serde::Deserialize;
11use serde_json::json;
12
13use crate::content::{BridgeGeneratedContent, GeneratedContent};
14use crate::error::FMError;
15use crate::ffi;
16use crate::generation::{GenerationOptions, SamplingMode};
17use crate::model::ConfiguredSystemLanguageModel;
18use crate::prompt::{Instructions, Prompt, ToInstructions, ToPrompt};
19use crate::schema::GenerationSchema;
20use crate::tool::{tool_callback_trampoline, Tool, ToolRegistry};
21use crate::transcript::Transcript;
22
/// A stateful conversation with the on-device language model.
///
/// Sessions retain their conversation history; subsequent calls to
/// [`respond`](Self::respond) build on the previous turns.
///
/// # Examples
///
/// ```rust,no_run
/// use foundation_models::LanguageModelSession;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let session = LanguageModelSession::new();
/// let answer = session.respond("Name three Norse gods.")?;
/// println!("{answer}");
/// # Ok(())
/// # }
/// ```
pub struct LanguageModelSession {
    // Opaque handle to the Swift-side session object. Rust never dereferences
    // it; it is only passed back into `ffi::*` calls.
    ptr: *mut c_void,
    // Tool registry whose address `SessionBuilder::build` hands to Swift as the
    // tool-callback context; `None` for sessions created without tools.
    // NOTE(review): presumably held here so the registry outlives the Swift
    // session that may call back into it — confirm against the `Drop` impl.
    _tool_registry: Option<Arc<ToolRegistry>>,
}
44
// SAFETY: The underlying Swift LanguageModelSession is reference-counted via
// Unmanaged.passRetained on the Swift side; sending the opaque pointer between
// threads is safe as long as we don't dereference it from Rust (we never do —
// it only travels through extern "C" calls that internally hop to the
// Swift concurrency executor).
// NOTE(review): the argument above covers moving the pointer across threads
// (`Send`). `Sync` additionally permits concurrent `&self` FFI calls from
// several threads, which presumably relies on the Swift bridge serialising
// access to the session — confirm on the Swift side.
unsafe impl Send for LanguageModelSession {}
unsafe impl Sync for LanguageModelSession {}
52
53impl LanguageModelSession {
    /// Return the raw opaque pointer to the underlying Swift session object.
    ///
    /// Used internally by `async_api` to pass the session pointer to FFI
    /// callbacks without exposing `ptr` as a public field. The pointer is
    /// returned as-is; no ownership is transferred by this call.
    pub(crate) fn as_ptr(&self) -> *mut c_void {
        self.ptr
    }
61
62    /// Create a session with the model's default behaviour.
63    ///
64    /// # Panics
65    ///
66    /// Panics if `FoundationModels` is not available on this OS. Check
67    /// [`crate::SystemLanguageModel::is_available`] first if you need to
68    /// handle that gracefully.
69    #[must_use]
70    pub fn new() -> Self {
71        Self::try_new(None).expect("FoundationModels is not available on this OS")
72    }
73
74    /// Create a session with custom system instructions ("system prompt").
75    ///
76    /// # Panics
77    ///
78    /// Panics if `FoundationModels` is not available, or if `instructions`
79    /// contains an interior NUL byte.
80    #[must_use]
81    pub fn with_instructions(instructions: &str) -> Self {
82        Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
83    }
84
85    /// Fallible constructor. Returns `None` when `FoundationModels` is not
86    /// available (OS too old, model not enabled, etc.) or when `instructions`
87    /// contains an interior NUL byte.
88    #[must_use]
89    pub fn try_new(instructions: Option<&str>) -> Option<Self> {
90        let cstring = match instructions {
91            Some(s) => Some(CString::new(s).ok()?),
92            None => None,
93        };
94        let ptr =
95            unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
96        if ptr.is_null() {
97            return None;
98        }
99        Some(Self {
100            ptr,
101            _tool_registry: None,
102        })
103    }
104
105    /// Send a prompt and block until the full response is available.
106    ///
107    /// # Errors
108    ///
109    /// Returns an [`FMError`] if the model rejects the prompt, the context
110    /// window is exceeded, the session is cancelled, or the prompt contains
111    /// an interior NUL byte.
112    pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
113        self.respond_with(prompt, GenerationOptions::new())
114    }
115
116    /// Pre-warm the model. Apple loads the weights + initialises the
117    /// inference engine so the next `respond` call is faster. Returns
118    /// immediately; the warm-up runs in the background.
119    pub fn prewarm(&self) {
120        unsafe { ffi::fm_session_prewarm(self.ptr) };
121    }
122
123    /// True if this session is currently producing a response (i.e. an
124    /// earlier `respond` / `stream` is still in flight on Apple's queue).
125    #[must_use]
126    pub fn is_responding(&self) -> bool {
127        unsafe { ffi::fm_session_is_responding(self.ptr) }
128    }
129
130    /// Return a best-effort JSON serialisation of the session's
131    /// `Transcript` — the full history of user prompts and model
132    /// responses. Useful for persisting a chat session across
133    /// process boundaries.
134    #[must_use]
135    pub fn transcript_json(&self) -> String {
136        let p = unsafe { ffi::fm_session_transcript_json(self.ptr) };
137        if p.is_null() {
138            return String::from("{}");
139        }
140        let s = unsafe { core::ffi::CStr::from_ptr(p) }
141            .to_string_lossy()
142            .into_owned();
143        unsafe { ffi::fm_string_free(p) };
144        s
145    }
146
147    /// Log feedback on the most recent response for diagnostic /
148    /// fine-tuning purposes. `sentiment`:
149    /// `1` positive, `0` neutral, `-1` negative.
150    pub fn log_feedback(&self, sentiment: i32, description: Option<&str>) {
151        let cstr = description.and_then(|s| CString::new(s).ok());
152        let p = cstr.as_ref().map_or(core::ptr::null(), |c| c.as_ptr());
153        unsafe { ffi::fm_session_log_feedback(self.ptr, sentiment, p) };
154    }
155
156    /// Prompt-engineered JSON-shape response.
157    ///
158    /// Wraps the prompt with a "respond with valid JSON matching this schema"
159    /// instruction and parses the response. The schema is a
160    /// `serde_json::Value`-style JSON string (passed as text).
161    ///
162    /// Useful for getting structured data out of the model without the
163    /// full Generable macro machinery. The model still returns plain
164    /// text — the caller must parse with `serde_json` / `serde` after.
165    ///
166    /// # Errors
167    ///
168    /// See [`respond`](Self::respond).
169    pub fn respond_with_json_schema(
170        &self,
171        prompt: &str,
172        schema_description: &str,
173    ) -> Result<String, FMError> {
174        let wrapped = format!(
175            "{prompt}\n\n\
176             IMPORTANT: respond with VALID JSON ONLY (no prose, no markdown \
177             fences) that matches this schema:\n\n{schema_description}\n\n\
178             Your entire response must be parseable by JSON.parse()."
179        );
180        self.respond(&wrapped)
181    }
182
183    /// Like [`respond`](Self::respond), but with explicit generation options.
184    ///
185    /// # Errors
186    ///
187    /// See [`respond`](Self::respond).
188    pub fn respond_with(
189        &self,
190        prompt: &str,
191        options: GenerationOptions,
192    ) -> Result<String, FMError> {
193        self.respond_prompt_with(prompt, options)
194    }
195
196    /// Schema-driven structured response.
197    ///
198    /// Builds a `DynamicGenerationSchema` from the provided JSON
199    /// schema, runs `LanguageModelSession.respond(schema:prompt:)`,
200    /// and returns the model's `GeneratedContent.jsonString` — a
201    /// well-formed JSON string matching the requested shape.
202    ///
203    /// Supported `schema` shape (strict subset of JSON Schema):
204    ///
205    /// ```json
206    /// {
207    ///   "type": "object",
208    ///   "name": "Movie",
209    ///   "properties": {
210    ///     "title":  { "type": "string", "description": "Movie title" },
211    ///     "year":   { "type": "integer" },
212    ///     "rating": { "type": "number", "optional": true },
213    ///     "tags":   { "type": "array", "items": { "type": "string" }, "min": 1, "max": 5 }
214    ///   }
215    /// }
216    /// ```
217    ///
218    /// Primitive types: `"string"`, `"integer"`, `"number"`,
219    /// `"boolean"`, `"array"`, `"object"`. Each property may set
220    /// `"description"` and `"optional"`. Array schemas accept
221    /// `"items"` plus optional `"min"` / `"max"` element counts.
222    ///
223    /// # Errors
224    ///
225    /// See [`respond`](Self::respond) for general errors, plus a
226    /// "schema build failed" / "schema JSON is not valid" error
227    /// returned as [`FMError::Unknown`] if the schema is malformed.
228    pub fn respond_with_schema(
229        &self,
230        prompt: &str,
231        schema: &str,
232        include_schema_in_prompt: bool,
233    ) -> Result<String, FMError> {
234        self.respond_with_schema_options(
235            prompt,
236            schema,
237            include_schema_in_prompt,
238            GenerationOptions::new(),
239        )
240    }
241
242    /// [`respond_with_schema`](Self::respond_with_schema) with
243    /// explicit generation options.
244    ///
245    /// # Errors
246    ///
247    /// See [`respond_with_schema`](Self::respond_with_schema).
248    pub fn respond_with_schema_options(
249        &self,
250        prompt: &str,
251        schema: &str,
252        include_schema_in_prompt: bool,
253        options: GenerationOptions,
254    ) -> Result<String, FMError> {
255        let prompt_c = CString::new(prompt)
256            .map_err(|e| FMError::InvalidArgument(format!("prompt NUL byte: {e}")))?;
257        let schema_c = CString::new(schema)
258            .map_err(|e| FMError::InvalidArgument(format!("schema NUL byte: {e}")))?;
259        let opts = options.to_ffi();
260        let (tx, rx) = mpsc::channel();
261        let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
262        let context = Box::into_raw(tx_box).cast::<c_void>();
263
264        unsafe {
265            ffi::fm_session_respond_with_schema(
266                self.ptr,
267                prompt_c.as_ptr(),
268                schema_c.as_ptr(),
269                include_schema_in_prompt,
270                opts.temperature,
271                opts.maximum_response_tokens,
272                opts.sampling_mode,
273                opts.top_k,
274                opts.top_p,
275                context,
276                respond_trampoline,
277            );
278        }
279
280        rx.recv().map_err(|_| FMError::Unknown {
281            code: ffi::status::UNKNOWN,
282            message: "Swift bridge dropped the callback channel".into(),
283        })?
284    }
285
286    /// Stream the response as the model generates it. The callback is invoked
287    /// with each delta and a final invocation with `done == true`.
288    ///
289    /// # Errors
290    ///
291    /// Returns an [`FMError`] mirroring [`respond`](Self::respond). The
292    /// callback may also receive a chunk *and* an error if the stream fails
293    /// midway.
294    pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
295    where
296        F: FnMut(StreamEvent<'_>) + Send + 'static,
297    {
298        self.stream_with(prompt, GenerationOptions::new(), move |event| {
299            on_chunk(event);
300        })
301    }
302
303    /// Like [`stream`](Self::stream), but with explicit generation options.
304    ///
305    /// # Errors
306    ///
307    /// See [`stream`](Self::stream).
308    pub fn stream_with<F>(
309        &self,
310        prompt: &str,
311        options: GenerationOptions,
312        on_chunk: F,
313    ) -> Result<(), FMError>
314    where
315        F: FnMut(StreamEvent<'_>) + Send + 'static,
316    {
317        let payload = respond_request_json(&Prompt::from(prompt), options, None, true)?;
318
319        let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
320        let state = Arc::new(StreamState {
321            on_chunk: Mutex::new(Box::new(on_chunk)),
322            done_tx: Mutex::new(Some(done_tx)),
323        });
324        let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
325
326        unsafe {
327            ffi::fm_session_stream_request_json(
328                self.ptr,
329                payload.as_ptr(),
330                context,
331                json_text_stream_trampoline,
332            )
333        };
334
335        done_rx.recv().map_err(|_| FMError::Unknown {
336            code: ffi::status::UNKNOWN,
337            message: "Swift bridge dropped the stream channel".into(),
338        })?
339    }
340}
341
342impl LanguageModelSession {
    /// Create a configurable session builder.
    ///
    /// The returned [`SessionBuilder`] starts with no model, instructions,
    /// transcript or tools; finish with [`SessionBuilder::build`].
    #[must_use]
    pub fn builder<'a>() -> SessionBuilder<'a> {
        SessionBuilder::new()
    }
348
349    /// Restore a session from a transcript.
350    ///
351    /// # Errors
352    ///
353    /// Returns an [`FMError`] if the transcript cannot be encoded for Swift.
354    pub fn from_transcript(transcript: Transcript) -> Result<Self, FMError> {
355        Self::builder().transcript(transcript).build()
356    }
357
358    /// Return the typed transcript for this session.
359    ///
360    /// # Errors
361    ///
362    /// Returns an [`FMError`] if the transcript JSON returned by Swift could not
363    /// be decoded.
364    pub fn transcript(&self) -> Result<Transcript, FMError> {
365        Transcript::from_json_str(&self.transcript_json())
366    }
367
368    /// Pre-warm the model using a prompt prefix.
369    ///
370    /// # Errors
371    ///
372    /// Returns an [`FMError`] if the prompt cannot be encoded for Swift.
373    pub fn prewarm_with_prompt<P>(&self, prompt: P) -> Result<(), FMError>
374    where
375        P: ToPrompt,
376    {
377        let prompt = prompt.to_prompt()?;
378        let prompt_json = CString::new(prompt.to_bridge_json()?).map_err(|error| {
379            FMError::InvalidArgument(format!("prompt JSON contains a NUL byte: {error}"))
380        })?;
381        let mut error: *mut c_char = ptr::null_mut();
382        let status = unsafe {
383            ffi::fm_session_prewarm_prompt_json(self.ptr, prompt_json.as_ptr(), &mut error)
384        };
385        if status != ffi::status::OK {
386            return Err(crate::error::from_swift(status, error));
387        }
388        Ok(())
389    }
390
391    /// Respond to a structured prompt and return only the generated text.
392    ///
393    /// # Errors
394    ///
395    /// Returns an [`FMError`] if generation fails.
396    pub fn respond_prompt<P>(&self, prompt: P) -> Result<String, FMError>
397    where
398        P: ToPrompt,
399    {
400        self.respond_prompt_with(prompt, GenerationOptions::new())
401    }
402
403    /// Like [`respond_prompt`](Self::respond_prompt), but with explicit options.
404    ///
405    /// # Errors
406    ///
407    /// Returns an [`FMError`] if generation fails.
408    pub fn respond_prompt_with<P>(
409        &self,
410        prompt: P,
411        options: GenerationOptions,
412    ) -> Result<String, FMError>
413    where
414        P: ToPrompt,
415    {
416        self.respond_prompt_detailed(prompt, options)
417            .map(|response| response.content)
418    }
419
420    /// Respond to a structured prompt and keep the full response metadata.
421    ///
422    /// # Errors
423    ///
424    /// Returns an [`FMError`] if generation fails.
425    pub fn respond_prompt_detailed<P>(
426        &self,
427        prompt: P,
428        options: GenerationOptions,
429    ) -> Result<SessionResponse<String>, FMError>
430    where
431        P: ToPrompt,
432    {
433        let prompt = prompt.to_prompt()?;
434        let payload = respond_request_json(&prompt, options, None, true)?;
435        let payload = request_response(self.ptr, &payload)?;
436        let response: BridgeTextResponse = serde_json::from_str(&payload)
437            .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
438        Ok(SessionResponse {
439            content: response.content,
440            raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
441            transcript: Transcript::from_json_str(&response.transcript_json)?,
442        })
443    }
444
445    /// Generate structured content using an explicit schema.
446    ///
447    /// # Errors
448    ///
449    /// Returns an [`FMError`] if generation fails or the schema is invalid.
450    pub fn respond_generated<P>(
451        &self,
452        prompt: P,
453        schema: &GenerationSchema,
454        include_schema_in_prompt: bool,
455    ) -> Result<GeneratedContent, FMError>
456    where
457        P: ToPrompt,
458    {
459        self.respond_generated_with(
460            prompt,
461            schema,
462            include_schema_in_prompt,
463            GenerationOptions::new(),
464        )
465        .map(|response| response.content)
466    }
467
468    /// Like [`respond_generated`](Self::respond_generated), but with explicit options.
469    ///
470    /// # Errors
471    ///
472    /// Returns an [`FMError`] if generation fails or the schema is invalid.
473    pub fn respond_generated_with<P>(
474        &self,
475        prompt: P,
476        schema: &GenerationSchema,
477        include_schema_in_prompt: bool,
478        options: GenerationOptions,
479    ) -> Result<SessionResponse<GeneratedContent>, FMError>
480    where
481        P: ToPrompt,
482    {
483        let prompt = prompt.to_prompt()?;
484        let payload =
485            respond_request_json(&prompt, options, Some(schema), include_schema_in_prompt)?;
486        let payload = request_response(self.ptr, &payload)?;
487        let response: BridgeStructuredResponse = serde_json::from_str(&payload)
488            .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
489        Ok(SessionResponse {
490            content: GeneratedContent::from_bridge_payload(response.content, true)?,
491            raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
492            transcript: Transcript::from_json_str(&response.transcript_json)?,
493        })
494    }
495
496    /// Generate a typed Rust value using a [`crate::schema::Generable`] implementation.
497    ///
498    /// # Errors
499    ///
500    /// Returns an [`FMError`] if generation fails or the generated JSON cannot
501    /// be decoded as `T`.
502    pub fn respond_generating<P, T>(
503        &self,
504        prompt: P,
505        include_schema_in_prompt: bool,
506        options: GenerationOptions,
507    ) -> Result<SessionResponse<T>, FMError>
508    where
509        P: ToPrompt,
510        T: crate::schema::Generable,
511    {
512        let response = self.respond_generated_with(
513            prompt,
514            &T::generation_schema()?,
515            include_schema_in_prompt,
516            options,
517        )?;
518        Ok(SessionResponse {
519            content: T::from_generated_content(&response.content)?,
520            raw_content: response.raw_content,
521            transcript: response.transcript,
522        })
523    }
524
525    /// Stream a structured prompt token-by-token.
526    ///
527    /// # Errors
528    ///
529    /// Returns an [`FMError`] if the prompt cannot be encoded or generation fails.
530    pub fn stream_prompt<P, F>(&self, prompt: P, on_chunk: F) -> Result<(), FMError>
531    where
532        P: ToPrompt,
533        F: FnMut(StreamEvent<'_>) + Send + 'static,
534    {
535        let prompt = prompt.to_prompt()?;
536        let prompt_text = prompt_to_plain_text(&prompt).ok_or_else(|| {
537            FMError::InvalidArgument(
538                "text streaming only supports prompts composed of text segments".into(),
539            )
540        })?;
541        self.stream_with(&prompt_text, GenerationOptions::new(), on_chunk)
542    }
543
544    /// Stream structured generation snapshots.
545    ///
546    /// # Errors
547    ///
548    /// Returns an [`FMError`] if the prompt cannot be encoded or generation fails.
549    pub fn stream_generated<P, F>(
550        &self,
551        prompt: P,
552        schema: &GenerationSchema,
553        include_schema_in_prompt: bool,
554        options: GenerationOptions,
555        on_event: F,
556    ) -> Result<(), FMError>
557    where
558        P: ToPrompt,
559        F: FnMut(StructuredStreamEvent) + Send + 'static,
560    {
561        let prompt = prompt.to_prompt()?;
562        let payload =
563            respond_request_json(&prompt, options, Some(schema), include_schema_in_prompt)?;
564        let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
565        let state = Arc::new(StructuredStreamState {
566            on_event: Mutex::new(Box::new(on_event)),
567            done_tx: Mutex::new(Some(done_tx)),
568        });
569        let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
570        unsafe {
571            ffi::fm_session_stream_request_json(
572                self.ptr,
573                payload.as_ptr(),
574                context,
575                structured_stream_trampoline,
576            )
577        };
578        done_rx.recv().map_err(|_| FMError::Unknown {
579            code: ffi::status::UNKNOWN,
580            message: "Swift bridge dropped the structured stream channel".into(),
581        })?
582    }
583
584    /// Log a feedback attachment and return the raw bytes Apple produced.
585    ///
586    /// # Errors
587    ///
588    /// Returns an [`FMError`] if the attachment request is invalid.
589    pub fn log_feedback_attachment(
590        &self,
591        request: FeedbackAttachmentRequest,
592    ) -> Result<Vec<u8>, FMError> {
593        let request_json = CString::new(request.to_bridge_json()?).map_err(|error| {
594            FMError::InvalidArgument(format!("feedback request contains a NUL byte: {error}"))
595        })?;
596        let mut length = 0usize;
597        let mut error: *mut c_char = ptr::null_mut();
598        let ptr = unsafe {
599            ffi::fm_session_log_feedback_attachment_json(
600                self.ptr,
601                request_json.as_ptr(),
602                &mut length,
603                &mut error,
604            )
605        };
606        if ptr.is_null() && !error.is_null() {
607            return Err(crate::error::from_swift(
608                ffi::status::INVALID_ARGUMENT,
609                error,
610            ));
611        }
612        if ptr.is_null() || length == 0 {
613            return Ok(Vec::new());
614        }
615        let bytes = unsafe { std::slice::from_raw_parts(ptr.cast::<u8>(), length) }.to_vec();
616        unsafe { ffi::fm_bytes_free(ptr) };
617        Ok(bytes)
618    }
619}
620
/// Builder for [`LanguageModelSession`].
///
/// Obtained from [`LanguageModelSession::builder`]; every setting is
/// optional. Instructions and a transcript are mutually exclusive, which
/// [`build`](Self::build) enforces.
pub struct SessionBuilder<'a> {
    // Configured system model to use; a null model pointer is passed to the
    // bridge when unset.
    model: Option<&'a ConfiguredSystemLanguageModel>,
    // System instructions, already converted via `ToInstructions`.
    instructions: Option<Instructions>,
    // Transcript to restore the session from.
    transcript: Option<Transcript>,
    // Tools to register; an empty list means no tool registry is created.
    tools: Vec<Tool>,
}
628
629impl<'a> SessionBuilder<'a> {
    /// Create an empty builder: no model, instructions, transcript or tools.
    ///
    /// Private; callers go through [`LanguageModelSession::builder`].
    const fn new() -> Self {
        Self {
            model: None,
            instructions: None,
            transcript: None,
            tools: Vec::new(),
        }
    }
638
    /// Use a configured system model.
    ///
    /// The builder only borrows `model`, so it must outlive the builder.
    /// Calling this again replaces the previous choice.
    #[must_use]
    pub const fn model(mut self, model: &'a ConfiguredSystemLanguageModel) -> Self {
        self.model = Some(model);
        self
    }
645
646    /// Set system instructions.
647    pub fn instructions<I>(mut self, instructions: I) -> Result<Self, FMError>
648    where
649        I: ToInstructions,
650    {
651        self.instructions = Some(instructions.to_instructions()?);
652        Ok(self)
653    }
654
655    /// Restore the session from a transcript.
656    #[must_use]
657    pub fn transcript(mut self, transcript: Transcript) -> Self {
658        self.transcript = Some(transcript);
659        self
660    }
661
    /// Add one tool.
    ///
    /// The tool is appended to those already registered on this builder.
    #[must_use]
    pub fn tool(mut self, tool: Tool) -> Self {
        self.tools.push(tool);
        self
    }
668
    /// Add many tools.
    ///
    /// The tools are appended, in iteration order, to those already
    /// registered on this builder.
    #[must_use]
    pub fn tools(mut self, tools: impl IntoIterator<Item = Tool>) -> Self {
        self.tools.extend(tools);
        self
    }
675
676    /// Build the session.
677    ///
678    /// # Errors
679    ///
680    /// Returns an [`FMError`] if the configuration cannot be encoded for Swift.
681    pub fn build(self) -> Result<LanguageModelSession, FMError> {
682        if self.instructions.is_some() && self.transcript.is_some() {
683            return Err(FMError::InvalidArgument(
684                "session builder accepts either instructions or a transcript, not both".into(),
685            ));
686        }
687
688        let instructions_json = self
689            .instructions
690            .as_ref()
691            .map(Instructions::to_bridge_json)
692            .transpose()?;
693        let transcript_json = self
694            .transcript
695            .as_ref()
696            .map(Transcript::to_json_string)
697            .transpose()?;
698        let tool_registry = if self.tools.is_empty() {
699            None
700        } else {
701            Some(Arc::new(ToolRegistry::new(self.tools)))
702        };
703        let tools_json = tool_registry
704            .as_ref()
705            .map(|registry| registry.specs_json())
706            .transpose()?;
707
708        let instructions_c = instructions_json
709            .as_deref()
710            .map(CString::new)
711            .transpose()
712            .map_err(|error| {
713                FMError::InvalidArgument(format!("instructions JSON contains a NUL byte: {error}"))
714            })?;
715        let transcript_c = transcript_json
716            .as_deref()
717            .map(CString::new)
718            .transpose()
719            .map_err(|error| {
720                FMError::InvalidArgument(format!("transcript JSON contains a NUL byte: {error}"))
721            })?;
722        let tools_c = tools_json
723            .as_deref()
724            .map(CString::new)
725            .transpose()
726            .map_err(|error| {
727                FMError::InvalidArgument(format!("tool JSON contains a NUL byte: {error}"))
728            })?;
729
730        let tool_context = tool_registry.as_ref().map_or(ptr::null_mut(), |registry| {
731            Arc::as_ptr(registry).cast_mut().cast::<c_void>()
732        });
733        let mut error: *mut c_char = ptr::null_mut();
734        let ptr = unsafe {
735            ffi::fm_session_create_ex(
736                self.model.map_or(ptr::null_mut(), |model| model.ptr),
737                instructions_c
738                    .as_ref()
739                    .map_or(ptr::null(), |json| json.as_ptr()),
740                transcript_c
741                    .as_ref()
742                    .map_or(ptr::null(), |json| json.as_ptr()),
743                tools_c.as_ref().map_or(ptr::null(), |json| json.as_ptr()),
744                tool_context,
745                tool_registry
746                    .as_ref()
747                    .map(|_| tool_callback_trampoline as ffi::FmToolCallback),
748                &mut error,
749            )
750        };
751        if ptr.is_null() {
752            return Err(crate::error::from_swift(
753                ffi::status::MODEL_UNAVAILABLE,
754                error,
755            ));
756        }
757        Ok(LanguageModelSession {
758            ptr,
759            _tool_registry: tool_registry,
760        })
761    }
762}
763
/// A detailed generation response.
///
/// `T` is the type of the primary content: `String` for text responses,
/// [`GeneratedContent`] for schema-driven responses, or a caller type
/// produced through `Generable`.
#[derive(Debug, Clone, PartialEq)]
pub struct SessionResponse<T> {
    /// The primary response content.
    pub content: T,
    /// Content decoded from the bridge response's `rawContent` field.
    pub raw_content: GeneratedContent,
    /// Transcript decoded from the bridge response's `transcriptJSON` field.
    pub transcript: Transcript,
}
771
/// One structured-generation stream snapshot.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StructuredStreamSnapshot {
    /// Snapshot content as a JSON string.
    // NOTE(review): presumably derived from the bridge's `content` field — confirm.
    pub content_json: String,
    /// Raw snapshot content as a JSON string.
    // NOTE(review): presumably derived from the bridge's `rawContent` field — confirm.
    pub raw_content_json: String,
    /// Whether this snapshot is flagged as complete.
    pub is_complete: bool,
}
779
/// One structured stream event, as delivered to the callback passed to
/// `LanguageModelSession::stream_generated`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum StructuredStreamEvent {
    /// A new snapshot of the content generated so far.
    Snapshot(StructuredStreamSnapshot),
    /// The stream finished.
    Done,
    /// The stream failed with the given error.
    Error(FMError),
}
788
/// One feedback issue category.
///
/// Each variant is sent to the bridge as a snake_case string (see `as_str`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackIssueCategory {
    Unhelpful,
    TooVerbose,
    DidNotFollowInstructions,
    Incorrect,
    StereotypeOrBias,
    SuggestiveOrSexual,
    VulgarOrOffensive,
    TriggeredGuardrailUnexpectedly,
}
801
impl FeedbackIssueCategory {
    /// Wire name of the category: the string placed in the `"category"`
    /// field of the feedback request JSON.
    const fn as_str(self) -> &'static str {
        match self {
            Self::Unhelpful => "unhelpful",
            Self::TooVerbose => "too_verbose",
            Self::DidNotFollowInstructions => "did_not_follow_instructions",
            Self::Incorrect => "incorrect",
            Self::StereotypeOrBias => "stereotype_or_bias",
            Self::SuggestiveOrSexual => "suggestive_or_sexual",
            Self::VulgarOrOffensive => "vulgar_or_offensive",
            Self::TriggeredGuardrailUnexpectedly => "triggered_guardrail_unexpectedly",
        }
    }
}
816
/// One feedback issue.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FeedbackIssue {
    /// Which kind of problem is being reported.
    pub category: FeedbackIssueCategory,
    /// Optional free-text explanation; serialised as JSON `null` when absent.
    pub explanation: Option<String>,
}
823
/// Feedback sentiment.
///
/// Each variant is sent to the bridge as a lowercase string (see `as_str`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackSentiment {
    Positive,
    Negative,
    Neutral,
}
831
impl FeedbackSentiment {
    /// Wire name of the sentiment: the string placed in the `"sentiment"`
    /// field of the feedback request JSON.
    const fn as_str(self) -> &'static str {
        match self {
            Self::Positive => "positive",
            Self::Negative => "negative",
            Self::Neutral => "neutral",
        }
    }
}
841
/// A full feedback attachment request.
///
/// All fields are optional; absent values are serialised as JSON `null`.
#[derive(Debug, Clone, PartialEq)]
pub struct FeedbackAttachmentRequest {
    /// Overall sentiment, if any.
    pub sentiment: Option<FeedbackSentiment>,
    /// Individual issues being reported (may be empty).
    pub issues: Vec<FeedbackIssue>,
    /// Desired response as plain text (`desiredResponseText`).
    pub desired_response_text: Option<String>,
    /// Desired response as structured content (`desiredResponseContent`).
    pub desired_response_content: Option<GeneratedContent>,
    /// Desired output as a transcript entry; sent as a one-entry transcript
    /// JSON string (`desiredOutputTranscriptJSON`).
    pub desired_output: Option<crate::transcript::Entry>,
}
851
852impl FeedbackAttachmentRequest {
853    /// Create an empty feedback request.
854    #[must_use]
855    pub const fn new() -> Self {
856        Self {
857            sentiment: None,
858            issues: Vec::new(),
859            desired_response_text: None,
860            desired_response_content: None,
861            desired_output: None,
862        }
863    }
864
865    fn to_bridge_json(&self) -> Result<String, FMError> {
866        let issues = self
867            .issues
868            .iter()
869            .map(|issue| {
870                json!({
871                    "category": issue.category.as_str(),
872                    "explanation": issue.explanation,
873                })
874            })
875            .collect::<Vec<_>>();
876        let desired_output_json = self
877            .desired_output
878            .as_ref()
879            .map(|entry| Transcript::from(vec![entry.clone()]).to_json_string())
880            .transpose()?;
881        let desired_response_content = self
882            .desired_response_content
883            .as_ref()
884            .map(GeneratedContent::to_bridge_value)
885            .transpose()?;
886        serde_json::to_string(&json!({
887            "sentiment": self.sentiment.map(FeedbackSentiment::as_str),
888            "issues": issues,
889            "desiredResponseText": self.desired_response_text,
890            "desiredResponseContent": desired_response_content,
891            "desiredOutputTranscriptJSON": desired_output_json,
892        }))
893        .map_err(|error| {
894            FMError::InvalidArgument(format!(
895                "feedback request is not JSON-serializable: {error}"
896            ))
897        })
898    }
899}
900
/// JSON shape of a complete text response payload, as parsed by
/// `decode_bridge_text_response`.
#[derive(Debug, Deserialize)]
struct BridgeTextResponse {
    /// Response text.
    content: String,
    /// Generated-content form of the response (JSON key `rawContent`).
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    /// Transcript encoded as a JSON string nested inside the payload
    /// (JSON key `transcriptJSON`).
    #[serde(rename = "transcriptJSON")]
    transcript_json: String,
}
909
/// JSON shape of a complete structured response payload.
///
/// Same layout as `BridgeTextResponse`, except that `content` is itself
/// generated content rather than a plain string.
#[derive(Debug, Deserialize)]
struct BridgeStructuredResponse {
    /// Structured response content.
    content: BridgeGeneratedContent,
    /// Raw generated content (JSON key `rawContent`).
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    /// Transcript encoded as a JSON string nested inside the payload
    /// (JSON key `transcriptJSON`).
    #[serde(rename = "transcriptJSON")]
    transcript_json: String,
}
918
/// JSON shape of one structured streaming snapshot, as parsed by
/// `structured_stream_trampoline`.
#[derive(Debug, Deserialize)]
struct BridgeStructuredSnapshot {
    /// Snapshot content.
    content: BridgeGeneratedContent,
    /// Raw snapshot content (JSON key `rawContent`).
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    /// Completion flag reported by the bridge (JSON key `isComplete`).
    #[serde(rename = "isComplete")]
    is_complete: bool,
}
927
/// JSON shape of one text streaming chunk, as parsed by
/// `json_text_stream_trampoline`.
#[derive(Debug, Deserialize)]
struct BridgeTextStreamSnapshot {
    /// Incremental text; empty deltas are not forwarded to the user callback.
    delta: String,
}
932
933fn respond_request_json(
934    prompt: &Prompt,
935    options: GenerationOptions,
936    schema: Option<&GenerationSchema>,
937    include_schema_in_prompt: bool,
938) -> Result<CString, FMError> {
939    let sampling = match options.sampling() {
940        SamplingMode::Default => json!({ "mode": "default" }),
941        SamplingMode::Greedy => json!({ "mode": "greedy" }),
942        SamplingMode::TopK(k) => json!({
943            "mode": "top_k",
944            "topK": k,
945            "seed": options.sampling_seed(),
946        }),
947        SamplingMode::TopP(p) => json!({
948            "mode": "top_p",
949            "topP": p,
950            "seed": options.sampling_seed(),
951        }),
952    };
953    let payload = serde_json::to_string(&json!({
954        "prompt": prompt.to_bridge_value(),
955        "options": {
956            "temperature": options.temperature(),
957            "maximumResponseTokens": options.maximum_response_tokens(),
958            "sampling": sampling,
959        },
960        "schemaJSON": schema.map(GenerationSchema::json_schema),
961        "includeSchemaInPrompt": include_schema_in_prompt,
962    }))
963    .map_err(|error| {
964        FMError::InvalidArgument(format!("request is not JSON-serializable: {error}"))
965    })?;
966    CString::new(payload).map_err(|error| {
967        FMError::InvalidArgument(format!("request JSON contains a NUL byte: {error}"))
968    })
969}
970
971fn request_response(session: *mut c_void, payload: &CString) -> Result<String, FMError> {
972    let (tx, rx) = mpsc::channel();
973    let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
974    let context = Box::into_raw(tx_box).cast::<c_void>();
975    unsafe {
976        ffi::fm_session_respond_request_json(session, payload.as_ptr(), context, respond_trampoline)
977    };
978    rx.recv().map_err(|_| FMError::Unknown {
979        code: ffi::status::UNKNOWN,
980        message: "Swift bridge dropped the JSON response channel".into(),
981    })?
982}
983
984pub(crate) fn decode_bridge_text_response(
985    payload: &str,
986) -> Result<SessionResponse<String>, FMError> {
987    let response: BridgeTextResponse = serde_json::from_str(payload)
988        .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
989    Ok(SessionResponse {
990        content: response.content,
991        raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
992        transcript: Transcript::from_json_str(&response.transcript_json)?,
993    })
994}
995
996pub(crate) fn request_text_response_with<F>(invoke: F) -> Result<SessionResponse<String>, FMError>
997where
998    F: FnOnce(*mut c_void, ffi::FmRespondCallback),
999{
1000    let (tx, rx) = mpsc::channel();
1001    let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
1002    let context = Box::into_raw(tx_box).cast::<c_void>();
1003    invoke(context, respond_trampoline);
1004    let payload = rx.recv().map_err(|_| FMError::Unknown {
1005        code: ffi::status::UNKNOWN,
1006        message: "Swift bridge dropped the JSON response channel".into(),
1007    })??;
1008    decode_bridge_text_response(&payload)
1009}
1010
1011pub(crate) fn run_text_stream_with<F, C>(invoke: F, on_chunk: C) -> Result<(), FMError>
1012where
1013    F: FnOnce(*mut c_void, ffi::FmStreamCallback),
1014    C: FnMut(StreamEvent<'_>) + Send + 'static,
1015{
1016    let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
1017    let state = Arc::new(StreamState {
1018        on_chunk: Mutex::new(Box::new(on_chunk)),
1019        done_tx: Mutex::new(Some(done_tx)),
1020    });
1021    let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
1022    invoke(context, json_text_stream_trampoline);
1023    done_rx.recv().map_err(|_| FMError::Unknown {
1024        code: ffi::status::UNKNOWN,
1025        message: "Swift bridge dropped the stream channel".into(),
1026    })?
1027}
1028
1029fn prompt_to_plain_text(prompt: &Prompt) -> Option<String> {
1030    let mut text = String::new();
1031    for segment in prompt.segments() {
1032        match segment {
1033            crate::prompt::Segment::Text(segment) => text.push_str(&segment.text),
1034            crate::prompt::Segment::Structure(_) => return None,
1035        }
1036    }
1037    Some(text)
1038}
1039
impl Default for LanguageModelSession {
    /// Equivalent to [`LanguageModelSession::new`].
    fn default() -> Self {
        Self::new()
    }
}
1045
1046impl Drop for LanguageModelSession {
1047    fn drop(&mut self) {
1048        if !self.ptr.is_null() {
1049            unsafe { ffi::fm_object_release(self.ptr) };
1050        }
1051    }
1052}
1053
1054impl core::fmt::Debug for LanguageModelSession {
1055    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1056        f.debug_struct("LanguageModelSession")
1057            .field("ptr", &self.ptr)
1058            .finish()
1059    }
1060}
1061
/// One event from a streaming generation.
///
/// Events are handed to the streaming callback one at a time. The text in
/// [`Chunk`](Self::Chunk) is borrowed and only valid while the callback runs;
/// copy it if it must outlive the call.
#[derive(Debug)]
#[non_exhaustive]
pub enum StreamEvent<'a> {
    /// Incremental text delta. Concatenate these to reconstruct the full reply.
    Chunk(&'a str),
    /// Stream finished successfully.
    Done,
    /// Stream failed; the inner error describes why.
    Error(FMError),
}
1073
1074// ---------- internal callback plumbing ----------
1075
1076// SAFETY: `context` is a `Box<mpsc::Sender<...>>` raw pointer created by
1077// `request_response` / `request_text_response_with`. Swift calls this callback
1078// exactly once, so there is no double-free risk. `response` and `error` are
1079// C strings owned by the Swift bridge and only valid for this call.
1080unsafe extern "C" fn respond_trampoline(
1081    context: *mut c_void,
1082    response: *mut c_char,
1083    error: *mut c_char,
1084    status: i32,
1085) {
1086    let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
1087    let result = if status == ffi::status::OK && !response.is_null() {
1088        let s = core::ffi::CStr::from_ptr(response)
1089            .to_string_lossy()
1090            .into_owned();
1091        ffi::fm_string_free(response);
1092        Ok(s)
1093    } else {
1094        Err(crate::error::from_swift(status, error))
1095    };
1096    let _ = tx.send(result);
1097}
1098
/// Boxed user callback that receives the events of a text stream.
type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;
1100
/// State shared with `json_text_stream_trampoline` through its opaque
/// `context` pointer.
struct StreamState {
    /// User callback; locked for the duration of each invocation.
    on_chunk: Mutex<StreamCallback>,
    /// Sender for the stream's final outcome; taken (left as `None`) once the
    /// outcome has been sent.
    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
}
1105
1106// SAFETY: `context` is a `Arc<StreamState>` raw pointer passed via
1107// `Arc::into_raw`. We reconstruct it with `Arc::from_raw` on every call and
1108// immediately `mem::forget` a clone so the count stays ≥ 1 until the
1109// terminal call (done=true or error). `chunk` is a Swift-owned C string valid
1110// only for the duration of this call.
1111unsafe extern "C" fn json_text_stream_trampoline(
1112    context: *mut c_void,
1113    chunk: *mut c_char,
1114    done: bool,
1115    status: i32,
1116) {
1117    let state = Arc::from_raw(context.cast::<StreamState>());
1118    let state_for_swift = state.clone();
1119    core::mem::forget(state_for_swift);
1120
1121    let payload: Option<String> = if chunk.is_null() {
1122        None
1123    } else {
1124        let value = core::ffi::CStr::from_ptr(chunk)
1125            .to_string_lossy()
1126            .into_owned();
1127        ffi::fm_string_free(chunk);
1128        Some(value)
1129    };
1130
1131    if status != ffi::status::OK {
1132        let err = payload
1133            .map(|message| {
1134                crate::error::from_swift(
1135                    status,
1136                    ffi::fm_string_dup(
1137                        CString::new(message)
1138                            .expect("stream errors must not contain NUL bytes")
1139                            .as_ptr(),
1140                    ),
1141                )
1142            })
1143            .unwrap_or_else(|| crate::error::from_swift(status, ptr::null_mut()));
1144        {
1145            let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1146            // Catch panics so they don't unwind across the FFI boundary (UB).
1147            let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Error(err.clone()))));
1148        }
1149        if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1150            let _ = tx.send(Err(err));
1151        }
1152        drop(Arc::from_raw(Arc::as_ptr(&state)));
1153        drop(state);
1154        return;
1155    }
1156
1157    if let Some(payload) = payload {
1158        match serde_json::from_str::<BridgeTextStreamSnapshot>(&payload) {
1159            Ok(snapshot) if !snapshot.delta.is_empty() => {
1160                let chunk_panicked = {
1161                    let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1162                    // Catch panics so they don't unwind across the FFI boundary.
1163                    catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Chunk(&snapshot.delta))))
1164                        .is_err()
1165                };
1166                if chunk_panicked {
1167                    if let Some(tx) =
1168                        state.done_tx.lock().expect("done_tx mutex poisoned").take()
1169                    {
1170                        let _ = tx.send(Err(FMError::Unknown {
1171                            code: ffi::status::UNKNOWN,
1172                            message: "stream callback panicked".into(),
1173                        }));
1174                    }
1175                    drop(Arc::from_raw(Arc::as_ptr(&state)));
1176                    drop(state);
1177                    return;
1178                }
1179            }
1180            Ok(_) => {}
1181            Err(error) => {
1182                let err = FMError::DecodingFailure(error.to_string());
1183                {
1184                    let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1185                    let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Error(err.clone()))));
1186                }
1187                if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1188                    let _ = tx.send(Err(err));
1189                }
1190                drop(Arc::from_raw(Arc::as_ptr(&state)));
1191                drop(state);
1192                return;
1193            }
1194        }
1195    }
1196
1197    if done {
1198        {
1199            let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1200            let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Done)));
1201        }
1202        if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1203            let _ = tx.send(Ok(()));
1204        }
1205        drop(Arc::from_raw(Arc::as_ptr(&state)));
1206    }
1207    drop(state);
1208}
1209
/// Boxed user callback that receives the events of a structured stream.
type StructuredStreamCallback = Box<dyn FnMut(StructuredStreamEvent) + Send>;
1211
/// State shared with `structured_stream_trampoline` through its opaque
/// `context` pointer.
struct StructuredStreamState {
    /// User callback; locked for the duration of each invocation.
    on_event: Mutex<StructuredStreamCallback>,
    /// Sender for the stream's final outcome; taken (left as `None`) once the
    /// outcome has been sent.
    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
}
1216
1217// SAFETY: Same invariants as `json_text_stream_trampoline` above, but for
1218// `StructuredStreamState`.
1219#[allow(clippy::too_many_lines)]
1220unsafe extern "C" fn structured_stream_trampoline(
1221    context: *mut c_void,
1222    chunk: *mut c_char,
1223    done: bool,
1224    status: i32,
1225) {
1226    let state = Arc::from_raw(context.cast::<StructuredStreamState>());
1227    let state_for_swift = state.clone();
1228    core::mem::forget(state_for_swift);
1229
1230    let payload: Option<String> = if chunk.is_null() {
1231        None
1232    } else {
1233        let value = core::ffi::CStr::from_ptr(chunk)
1234            .to_string_lossy()
1235            .into_owned();
1236        ffi::fm_string_free(chunk);
1237        Some(value)
1238    };
1239
1240    if status != ffi::status::OK {
1241        let err = payload
1242            .map(|message| {
1243                crate::error::from_swift(
1244                    status,
1245                    ffi::fm_string_dup(
1246                        CString::new(message)
1247                            .expect("stream errors must not contain NUL bytes")
1248                            .as_ptr(),
1249                    ),
1250                )
1251            })
1252            .unwrap_or_else(|| crate::error::from_swift(status, ptr::null_mut()));
1253        {
1254            let mut cb = state
1255                .on_event
1256                .lock()
1257                .expect("structured callback mutex poisoned");
1258            // Catch panics so they don't unwind across the FFI boundary (UB).
1259            let _ = catch_unwind(AssertUnwindSafe(|| {
1260                cb(StructuredStreamEvent::Error(err.clone()));
1261            }));
1262        }
1263        if let Some(tx) = state
1264            .done_tx
1265            .lock()
1266            .expect("structured done_tx mutex poisoned")
1267            .take()
1268        {
1269            let _ = tx.send(Err(err));
1270        }
1271        drop(Arc::from_raw(Arc::as_ptr(&state)));
1272        drop(state);
1273        return;
1274    }
1275
1276    if let Some(payload) = payload {
1277        let snapshot: BridgeStructuredSnapshot = match serde_json::from_str(&payload) {
1278            Ok(snapshot) => snapshot,
1279            Err(error) => {
1280                let err = FMError::DecodingFailure(error.to_string());
1281                {
1282                    let mut cb = state
1283                        .on_event
1284                        .lock()
1285                        .expect("structured callback mutex poisoned");
1286                    let _ = catch_unwind(AssertUnwindSafe(|| {
1287                        cb(StructuredStreamEvent::Error(err.clone()));
1288                    }));
1289                }
1290                if let Some(tx) = state
1291                    .done_tx
1292                    .lock()
1293                    .expect("structured done_tx mutex poisoned")
1294                    .take()
1295                {
1296                    let _ = tx.send(Err(err));
1297                }
1298                drop(Arc::from_raw(Arc::as_ptr(&state)));
1299                drop(state);
1300                return;
1301            }
1302        };
1303        let snapshot_event = StructuredStreamEvent::Snapshot(StructuredStreamSnapshot {
1304            content_json: snapshot.content.json,
1305            raw_content_json: snapshot.raw_content.json,
1306            is_complete: snapshot.is_complete,
1307        });
1308        let snapshot_panicked = {
1309            let mut cb = state
1310                .on_event
1311                .lock()
1312                .expect("structured callback mutex poisoned");
1313            // Catch panics so they don't unwind across the FFI boundary.
1314            catch_unwind(AssertUnwindSafe(|| cb(snapshot_event))).is_err()
1315        };
1316        if snapshot_panicked {
1317            if let Some(tx) = state
1318                .done_tx
1319                .lock()
1320                .expect("structured done_tx mutex poisoned")
1321                .take()
1322            {
1323                let _ = tx.send(Err(FMError::Unknown {
1324                    code: ffi::status::UNKNOWN,
1325                    message: "stream callback panicked".into(),
1326                }));
1327            }
1328            drop(Arc::from_raw(Arc::as_ptr(&state)));
1329            drop(state);
1330            return;
1331        }
1332    }
1333
1334    if done {
1335        {
1336            let mut cb = state
1337                .on_event
1338                .lock()
1339                .expect("structured callback mutex poisoned");
1340            let _ = catch_unwind(AssertUnwindSafe(|| cb(StructuredStreamEvent::Done)));
1341        }
1342        if let Some(tx) = state
1343            .done_tx
1344            .lock()
1345            .expect("structured done_tx mutex poisoned")
1346            .take()
1347        {
1348            let _ = tx.send(Ok(()));
1349        }
1350        drop(Arc::from_raw(Arc::as_ptr(&state)));
1351    }
1352    drop(state);
1353}