Skip to main content

foundation_models/session/
mod.rs

1//! [`LanguageModelSession`] — a stateful conversation with the on-device model.
2
3use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::panic::{catch_unwind, AssertUnwindSafe};
7use std::sync::mpsc;
8use std::sync::{Arc, Mutex};
9
10use serde::Deserialize;
11use serde_json::json;
12
13use crate::content::{BridgeGeneratedContent, GeneratedContent};
14use crate::error::FMError;
15use crate::ffi;
16use crate::generation::{GenerationOptions, SamplingMode};
17use crate::model::ConfiguredSystemLanguageModel;
18use crate::prompt::{Instructions, Prompt, ToInstructions, ToPrompt};
19use crate::schema::GenerationSchema;
20use crate::tool::{tool_callback_trampoline, Tool, ToolRegistry};
21use crate::transcript::Transcript;
22
23/// A stateful conversation with the on-device language model.
24///
25/// Sessions retain their conversation history; subsequent calls to
26/// [`respond`](Self::respond) build on the previous turns.
27///
28/// # Examples
29///
30/// ```rust,no_run
31/// use foundation_models::LanguageModelSession;
32///
33/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
34/// let session = LanguageModelSession::new();
35/// let answer = session.respond("Name three Norse gods.")?;
36/// println!("{answer}");
37/// # Ok(())
38/// # }
39/// ```
40pub struct LanguageModelSession {
41    ptr: *mut c_void,
42    _tool_registry: Option<Arc<ToolRegistry>>,
43}
44
45// SAFETY: The underlying Swift LanguageModelSession is reference-counted via
46// Unmanaged.passRetained on the Swift side; sending the opaque pointer between
47// threads is safe as long as we don't dereference it from Rust (we never do —
48// it only travels through extern "C" calls that internally hop to the
49// Swift concurrency executor).
50unsafe impl Send for LanguageModelSession {}
51unsafe impl Sync for LanguageModelSession {}
52
53impl LanguageModelSession {
54    /// Return the raw opaque pointer to the underlying Swift session object.
55    ///
56    /// Used internally by `async_api` to pass the session pointer to FFI
57    /// callbacks without exposing `ptr` as a public field.
58    pub(crate) fn as_ptr(&self) -> *mut c_void {
59        self.ptr
60    }
61
62    /// Create a session with the model's default behaviour.
63    ///
64    /// # Panics
65    ///
66    /// Panics if `FoundationModels` is not available on this OS. Check
67    /// [`crate::SystemLanguageModel::is_available`] first if you need to
68    /// handle that gracefully.
69    #[must_use]
70    pub fn new() -> Self {
71        Self::try_new(None).expect("FoundationModels is not available on this OS")
72    }
73
74    /// Create a session with custom system instructions ("system prompt").
75    ///
76    /// # Panics
77    ///
78    /// Panics if `FoundationModels` is not available, or if `instructions`
79    /// contains an interior NUL byte.
80    #[must_use]
81    pub fn with_instructions(instructions: &str) -> Self {
82        Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
83    }
84
85    /// Fallible constructor. Returns `None` when `FoundationModels` is not
86    /// available (OS too old, model not enabled, etc.) or when `instructions`
87    /// contains an interior NUL byte.
88    #[must_use]
89    pub fn try_new(instructions: Option<&str>) -> Option<Self> {
90        let cstring = match instructions {
91            Some(s) => Some(CString::new(s).ok()?),
92            None => None,
93        };
94        let ptr =
95            unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
96        if ptr.is_null() {
97            return None;
98        }
99        Some(Self {
100            ptr,
101            _tool_registry: None,
102        })
103    }
104
105    /// Send a prompt and block until the full response is available.
106    ///
107    /// # Errors
108    ///
109    /// Returns an [`FMError`] if the model rejects the prompt, the context
110    /// window is exceeded, the session is cancelled, or the prompt contains
111    /// an interior NUL byte.
112    pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
113        self.respond_with(prompt, GenerationOptions::new())
114    }
115
116    /// Pre-warm the model. Apple loads the weights + initialises the
117    /// inference engine so the next `respond` call is faster. Returns
118    /// immediately; the warm-up runs in the background.
119    pub fn prewarm(&self) {
120        unsafe { ffi::fm_session_prewarm(self.ptr) };
121    }
122
123    /// True if this session is currently producing a response (i.e. an
124    /// earlier `respond` / `stream` is still in flight on Apple's queue).
125    #[must_use]
126    pub fn is_responding(&self) -> bool {
127        unsafe { ffi::fm_session_is_responding(self.ptr) }
128    }
129
130    /// Return a best-effort JSON serialisation of the session's
131    /// `Transcript` — the full history of user prompts and model
132    /// responses. Useful for persisting a chat session across
133    /// process boundaries.
134    #[must_use]
135    pub fn transcript_json(&self) -> String {
136        let p = unsafe { ffi::fm_session_transcript_json(self.ptr) };
137        if p.is_null() {
138            return String::from("{}");
139        }
140        let s = unsafe { core::ffi::CStr::from_ptr(p) }
141            .to_string_lossy()
142            .into_owned();
143        unsafe { ffi::fm_string_free(p) };
144        s
145    }
146
147    /// Log feedback on the most recent response for diagnostic /
148    /// fine-tuning purposes. `sentiment`:
149    /// `1` positive, `0` neutral, `-1` negative.
150    pub fn log_feedback(&self, sentiment: i32, description: Option<&str>) {
151        let cstr = description.and_then(|s| CString::new(s).ok());
152        let p = cstr.as_ref().map_or(core::ptr::null(), |c| c.as_ptr());
153        unsafe { ffi::fm_session_log_feedback(self.ptr, sentiment, p) };
154    }
155
156    /// Prompt-engineered JSON-shape response.
157    ///
158    /// Wraps the prompt with a "respond with valid JSON matching this schema"
159    /// instruction and parses the response. The schema is a
160    /// `serde_json::Value`-style JSON string (passed as text).
161    ///
162    /// Useful for getting structured data out of the model without the
163    /// full Generable macro machinery. The model still returns plain
164    /// text — the caller must parse with `serde_json` / `serde` after.
165    ///
166    /// # Errors
167    ///
168    /// See [`respond`](Self::respond).
169    pub fn respond_with_json_schema(
170        &self,
171        prompt: &str,
172        schema_description: &str,
173    ) -> Result<String, FMError> {
174        let wrapped = format!(
175            "{prompt}\n\n\
176             IMPORTANT: respond with VALID JSON ONLY (no prose, no markdown \
177             fences) that matches this schema:\n\n{schema_description}\n\n\
178             Your entire response must be parseable by JSON.parse()."
179        );
180        self.respond(&wrapped)
181    }
182
183    /// Like [`respond`](Self::respond), but with explicit generation options.
184    ///
185    /// # Errors
186    ///
187    /// See [`respond`](Self::respond).
188    pub fn respond_with(
189        &self,
190        prompt: &str,
191        options: GenerationOptions,
192    ) -> Result<String, FMError> {
193        self.respond_prompt_with(prompt, options)
194    }
195
196    /// Schema-driven structured response.
197    ///
198    /// Builds a `DynamicGenerationSchema` from the provided JSON
199    /// schema, runs `LanguageModelSession.respond(schema:prompt:)`,
200    /// and returns the model's `GeneratedContent.jsonString` — a
201    /// well-formed JSON string matching the requested shape.
202    ///
203    /// Supported `schema` shape (strict subset of JSON Schema):
204    ///
205    /// ```json
206    /// {
207    ///   "type": "object",
208    ///   "name": "Movie",
209    ///   "properties": {
210    ///     "title":  { "type": "string", "description": "Movie title" },
211    ///     "year":   { "type": "integer" },
212    ///     "rating": { "type": "number", "optional": true },
213    ///     "tags":   { "type": "array", "items": { "type": "string" }, "min": 1, "max": 5 }
214    ///   }
215    /// }
216    /// ```
217    ///
218    /// Primitive types: `"string"`, `"integer"`, `"number"`,
219    /// `"boolean"`, `"array"`, `"object"`. Each property may set
220    /// `"description"` and `"optional"`. Array schemas accept
221    /// `"items"` plus optional `"min"` / `"max"` element counts.
222    ///
223    /// # Errors
224    ///
225    /// See [`respond`](Self::respond) for general errors, plus a
226    /// "schema build failed" / "schema JSON is not valid" error
227    /// returned as [`FMError::Unknown`] if the schema is malformed.
228    pub fn respond_with_schema(
229        &self,
230        prompt: &str,
231        schema: &str,
232        include_schema_in_prompt: bool,
233    ) -> Result<String, FMError> {
234        self.respond_with_schema_options(
235            prompt,
236            schema,
237            include_schema_in_prompt,
238            GenerationOptions::new(),
239        )
240    }
241
242    /// [`respond_with_schema`](Self::respond_with_schema) with
243    /// explicit generation options.
244    ///
245    /// # Errors
246    ///
247    /// See [`respond_with_schema`](Self::respond_with_schema).
248    pub fn respond_with_schema_options(
249        &self,
250        prompt: &str,
251        schema: &str,
252        include_schema_in_prompt: bool,
253        options: GenerationOptions,
254    ) -> Result<String, FMError> {
255        let prompt_c = CString::new(prompt)
256            .map_err(|e| FMError::InvalidArgument(format!("prompt NUL byte: {e}")))?;
257        let schema_c = CString::new(schema)
258            .map_err(|e| FMError::InvalidArgument(format!("schema NUL byte: {e}")))?;
259        let opts = options.to_ffi();
260        let (tx, rx) = mpsc::channel();
261        let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
262        let context = Box::into_raw(tx_box).cast::<c_void>();
263
264        unsafe {
265            ffi::fm_session_respond_with_schema(
266                self.ptr,
267                prompt_c.as_ptr(),
268                schema_c.as_ptr(),
269                include_schema_in_prompt,
270                opts.temperature,
271                opts.maximum_response_tokens,
272                opts.sampling_mode,
273                opts.top_k,
274                opts.top_p,
275                context,
276                respond_trampoline,
277            );
278        }
279
280        rx.recv().map_err(|_| FMError::Unknown {
281            code: ffi::status::UNKNOWN,
282            message: "Swift bridge dropped the callback channel".into(),
283        })?
284    }
285
286    /// Stream the response as the model generates it. The callback is invoked
287    /// with each delta and a final invocation with `done == true`.
288    ///
289    /// # Errors
290    ///
291    /// Returns an [`FMError`] mirroring [`respond`](Self::respond). The
292    /// callback may also receive a chunk *and* an error if the stream fails
293    /// midway.
294    pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
295    where
296        F: FnMut(StreamEvent<'_>) + Send + 'static,
297    {
298        self.stream_with(prompt, GenerationOptions::new(), move |event| {
299            on_chunk(event);
300        })
301    }
302
303    /// Like [`stream`](Self::stream), but with explicit generation options.
304    ///
305    /// # Errors
306    ///
307    /// See [`stream`](Self::stream).
308    pub fn stream_with<F>(
309        &self,
310        prompt: &str,
311        options: GenerationOptions,
312        on_chunk: F,
313    ) -> Result<(), FMError>
314    where
315        F: FnMut(StreamEvent<'_>) + Send + 'static,
316    {
317        let payload = respond_request_json(&Prompt::from(prompt), options, None, true)?;
318
319        let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
320        let state = Arc::new(StreamState {
321            on_chunk: Mutex::new(Box::new(on_chunk)),
322            done_tx: Mutex::new(Some(done_tx)),
323        });
324        let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
325
326        unsafe {
327            ffi::fm_session_stream_request_json(
328                self.ptr,
329                payload.as_ptr(),
330                context,
331                json_text_stream_trampoline,
332            )
333        };
334
335        done_rx.recv().map_err(|_| FMError::Unknown {
336            code: ffi::status::UNKNOWN,
337            message: "Swift bridge dropped the stream channel".into(),
338        })?
339    }
340}
341
342impl LanguageModelSession {
343    /// Create a configurable session builder.
344    #[must_use]
345    pub fn builder<'a>() -> SessionBuilder<'a> {
346        SessionBuilder::new()
347    }
348
349    /// Restore a session from a transcript.
350    ///
351    /// # Errors
352    ///
353    /// Returns an [`FMError`] if the transcript cannot be encoded for Swift.
354    pub fn from_transcript(transcript: Transcript) -> Result<Self, FMError> {
355        Self::builder().transcript(transcript).build()
356    }
357
358    /// Return the typed transcript for this session.
359    ///
360    /// # Errors
361    ///
362    /// Returns an [`FMError`] if the transcript JSON returned by Swift could not
363    /// be decoded.
364    pub fn transcript(&self) -> Result<Transcript, FMError> {
365        Transcript::from_json_str(&self.transcript_json())
366    }
367
368    /// Pre-warm the model using a prompt prefix.
369    ///
370    /// # Errors
371    ///
372    /// Returns an [`FMError`] if the prompt cannot be encoded for Swift.
373    pub fn prewarm_with_prompt<P>(&self, prompt: P) -> Result<(), FMError>
374    where
375        P: ToPrompt,
376    {
377        let prompt = prompt.to_prompt()?;
378        let prompt_json = CString::new(prompt.to_bridge_json()?).map_err(|error| {
379            FMError::InvalidArgument(format!("prompt JSON contains a NUL byte: {error}"))
380        })?;
381        let mut error: *mut c_char = ptr::null_mut();
382        let status = unsafe {
383            ffi::fm_session_prewarm_prompt_json(self.ptr, prompt_json.as_ptr(), &mut error)
384        };
385        if status != ffi::status::OK {
386            return Err(crate::error::from_swift(status, error));
387        }
388        Ok(())
389    }
390
391    /// Respond to a structured prompt and return only the generated text.
392    ///
393    /// # Errors
394    ///
395    /// Returns an [`FMError`] if generation fails.
396    pub fn respond_prompt<P>(&self, prompt: P) -> Result<String, FMError>
397    where
398        P: ToPrompt,
399    {
400        self.respond_prompt_with(prompt, GenerationOptions::new())
401    }
402
403    /// Like [`respond_prompt`](Self::respond_prompt), but with explicit options.
404    ///
405    /// # Errors
406    ///
407    /// Returns an [`FMError`] if generation fails.
408    pub fn respond_prompt_with<P>(
409        &self,
410        prompt: P,
411        options: GenerationOptions,
412    ) -> Result<String, FMError>
413    where
414        P: ToPrompt,
415    {
416        self.respond_prompt_detailed(prompt, options)
417            .map(|response| response.content)
418    }
419
420    /// Respond to a structured prompt and keep the full response metadata.
421    ///
422    /// # Errors
423    ///
424    /// Returns an [`FMError`] if generation fails.
425    pub fn respond_prompt_detailed<P>(
426        &self,
427        prompt: P,
428        options: GenerationOptions,
429    ) -> Result<SessionResponse<String>, FMError>
430    where
431        P: ToPrompt,
432    {
433        let prompt = prompt.to_prompt()?;
434        let payload = respond_request_json(&prompt, options, None, true)?;
435        let payload = request_response(self.ptr, &payload)?;
436        let response: BridgeTextResponse = serde_json::from_str(&payload)
437            .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
438        Ok(SessionResponse {
439            content: response.content,
440            raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
441            transcript: Transcript::from_json_str(&response.transcript_json)?,
442        })
443    }
444
445    /// Generate structured content using an explicit schema.
446    ///
447    /// # Errors
448    ///
449    /// Returns an [`FMError`] if generation fails or the schema is invalid.
450    pub fn respond_generated<P>(
451        &self,
452        prompt: P,
453        schema: &GenerationSchema,
454        include_schema_in_prompt: bool,
455    ) -> Result<GeneratedContent, FMError>
456    where
457        P: ToPrompt,
458    {
459        self.respond_generated_with(
460            prompt,
461            schema,
462            include_schema_in_prompt,
463            GenerationOptions::new(),
464        )
465        .map(|response| response.content)
466    }
467
468    /// Like [`respond_generated`](Self::respond_generated), but with explicit options.
469    ///
470    /// # Errors
471    ///
472    /// Returns an [`FMError`] if generation fails or the schema is invalid.
473    pub fn respond_generated_with<P>(
474        &self,
475        prompt: P,
476        schema: &GenerationSchema,
477        include_schema_in_prompt: bool,
478        options: GenerationOptions,
479    ) -> Result<SessionResponse<GeneratedContent>, FMError>
480    where
481        P: ToPrompt,
482    {
483        let prompt = prompt.to_prompt()?;
484        let payload =
485            respond_request_json(&prompt, options, Some(schema), include_schema_in_prompt)?;
486        let payload = request_response(self.ptr, &payload)?;
487        let response: BridgeStructuredResponse = serde_json::from_str(&payload)
488            .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
489        Ok(SessionResponse {
490            content: GeneratedContent::from_bridge_payload(response.content, true)?,
491            raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
492            transcript: Transcript::from_json_str(&response.transcript_json)?,
493        })
494    }
495
496    /// Generate a typed Rust value using a [`crate::schema::Generable`] implementation.
497    ///
498    /// # Errors
499    ///
500    /// Returns an [`FMError`] if generation fails or the generated JSON cannot
501    /// be decoded as `T`.
502    pub fn respond_generating<P, T>(
503        &self,
504        prompt: P,
505        include_schema_in_prompt: bool,
506        options: GenerationOptions,
507    ) -> Result<SessionResponse<T>, FMError>
508    where
509        P: ToPrompt,
510        T: crate::schema::Generable,
511    {
512        let response = self.respond_generated_with(
513            prompt,
514            &T::generation_schema()?,
515            include_schema_in_prompt,
516            options,
517        )?;
518        Ok(SessionResponse {
519            content: T::from_generated_content(&response.content)?,
520            raw_content: response.raw_content,
521            transcript: response.transcript,
522        })
523    }
524
525    /// Stream a structured prompt token-by-token.
526    ///
527    /// # Errors
528    ///
529    /// Returns an [`FMError`] if the prompt cannot be encoded or generation fails.
530    pub fn stream_prompt<P, F>(&self, prompt: P, on_chunk: F) -> Result<(), FMError>
531    where
532        P: ToPrompt,
533        F: FnMut(StreamEvent<'_>) + Send + 'static,
534    {
535        let prompt = prompt.to_prompt()?;
536        let prompt_text = prompt_to_plain_text(&prompt).ok_or_else(|| {
537            FMError::InvalidArgument(
538                "text streaming only supports prompts composed of text segments".into(),
539            )
540        })?;
541        self.stream_with(&prompt_text, GenerationOptions::new(), on_chunk)
542    }
543
544    /// Stream structured generation snapshots.
545    ///
546    /// # Errors
547    ///
548    /// Returns an [`FMError`] if the prompt cannot be encoded or generation fails.
549    pub fn stream_generated<P, F>(
550        &self,
551        prompt: P,
552        schema: &GenerationSchema,
553        include_schema_in_prompt: bool,
554        options: GenerationOptions,
555        on_event: F,
556    ) -> Result<(), FMError>
557    where
558        P: ToPrompt,
559        F: FnMut(StructuredStreamEvent) + Send + 'static,
560    {
561        let prompt = prompt.to_prompt()?;
562        let payload =
563            respond_request_json(&prompt, options, Some(schema), include_schema_in_prompt)?;
564        let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
565        let state = Arc::new(StructuredStreamState {
566            on_event: Mutex::new(Box::new(on_event)),
567            done_tx: Mutex::new(Some(done_tx)),
568        });
569        let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
570        unsafe {
571            ffi::fm_session_stream_request_json(
572                self.ptr,
573                payload.as_ptr(),
574                context,
575                structured_stream_trampoline,
576            )
577        };
578        done_rx.recv().map_err(|_| FMError::Unknown {
579            code: ffi::status::UNKNOWN,
580            message: "Swift bridge dropped the structured stream channel".into(),
581        })?
582    }
583
584    /// Log a feedback attachment and return the raw bytes Apple produced.
585    ///
586    /// # Errors
587    ///
588    /// Returns an [`FMError`] if the attachment request is invalid.
589    pub fn log_feedback_attachment(
590        &self,
591        request: FeedbackAttachmentRequest,
592    ) -> Result<Vec<u8>, FMError> {
593        let request_json = CString::new(request.to_bridge_json()?).map_err(|error| {
594            FMError::InvalidArgument(format!("feedback request contains a NUL byte: {error}"))
595        })?;
596        let mut length = 0usize;
597        let mut error: *mut c_char = ptr::null_mut();
598        let ptr = unsafe {
599            ffi::fm_session_log_feedback_attachment_json(
600                self.ptr,
601                request_json.as_ptr(),
602                &mut length,
603                &mut error,
604            )
605        };
606        if ptr.is_null() && !error.is_null() {
607            return Err(crate::error::from_swift(
608                ffi::status::INVALID_ARGUMENT,
609                error,
610            ));
611        }
612        if ptr.is_null() || length == 0 {
613            return Ok(Vec::new());
614        }
615        let bytes = unsafe { std::slice::from_raw_parts(ptr.cast::<u8>(), length) }.to_vec();
616        unsafe { ffi::fm_bytes_free(ptr) };
617        Ok(bytes)
618    }
619}
620
621/// Builder for [`LanguageModelSession`].
622pub struct SessionBuilder<'a> {
623    model: Option<&'a ConfiguredSystemLanguageModel>,
624    instructions: Option<Instructions>,
625    transcript: Option<Transcript>,
626    tools: Vec<Tool>,
627}
628
629impl<'a> SessionBuilder<'a> {
630    const fn new() -> Self {
631        Self {
632            model: None,
633            instructions: None,
634            transcript: None,
635            tools: Vec::new(),
636        }
637    }
638
639    /// Use a configured system model.
640    #[must_use]
641    pub const fn model(mut self, model: &'a ConfiguredSystemLanguageModel) -> Self {
642        self.model = Some(model);
643        self
644    }
645
646    /// Set system instructions.
647    pub fn instructions<I>(mut self, instructions: I) -> Result<Self, FMError>
648    where
649        I: ToInstructions,
650    {
651        self.instructions = Some(instructions.to_instructions()?);
652        Ok(self)
653    }
654
655    /// Restore the session from a transcript.
656    #[must_use]
657    pub fn transcript(mut self, transcript: Transcript) -> Self {
658        self.transcript = Some(transcript);
659        self
660    }
661
662    /// Add one tool.
663    #[must_use]
664    pub fn tool(mut self, tool: Tool) -> Self {
665        self.tools.push(tool);
666        self
667    }
668
669    /// Add many tools.
670    #[must_use]
671    pub fn tools(mut self, tools: impl IntoIterator<Item = Tool>) -> Self {
672        self.tools.extend(tools);
673        self
674    }
675
676    /// Build the session.
677    ///
678    /// # Errors
679    ///
680    /// Returns an [`FMError`] if the configuration cannot be encoded for Swift.
681    pub fn build(self) -> Result<LanguageModelSession, FMError> {
682        if self.instructions.is_some() && self.transcript.is_some() {
683            return Err(FMError::InvalidArgument(
684                "session builder accepts either instructions or a transcript, not both".into(),
685            ));
686        }
687
688        let instructions_json = self
689            .instructions
690            .as_ref()
691            .map(Instructions::to_bridge_json)
692            .transpose()?;
693        let transcript_json = self
694            .transcript
695            .as_ref()
696            .map(Transcript::to_json_string)
697            .transpose()?;
698        let tool_registry = if self.tools.is_empty() {
699            None
700        } else {
701            Some(Arc::new(ToolRegistry::new(self.tools)))
702        };
703        let tools_json = tool_registry
704            .as_ref()
705            .map(|registry| registry.specs_json())
706            .transpose()?;
707
708        let instructions_c = instructions_json
709            .as_deref()
710            .map(CString::new)
711            .transpose()
712            .map_err(|error| {
713                FMError::InvalidArgument(format!("instructions JSON contains a NUL byte: {error}"))
714            })?;
715        let transcript_c = transcript_json
716            .as_deref()
717            .map(CString::new)
718            .transpose()
719            .map_err(|error| {
720                FMError::InvalidArgument(format!("transcript JSON contains a NUL byte: {error}"))
721            })?;
722        let tools_c = tools_json
723            .as_deref()
724            .map(CString::new)
725            .transpose()
726            .map_err(|error| {
727                FMError::InvalidArgument(format!("tool JSON contains a NUL byte: {error}"))
728            })?;
729
730        let tool_context = tool_registry.as_ref().map_or(ptr::null_mut(), |registry| {
731            Arc::as_ptr(registry).cast_mut().cast::<c_void>()
732        });
733        let mut error: *mut c_char = ptr::null_mut();
734        let ptr = unsafe {
735            ffi::fm_session_create_ex(
736                self.model.map_or(ptr::null_mut(), |model| model.ptr),
737                instructions_c
738                    .as_ref()
739                    .map_or(ptr::null(), |json| json.as_ptr()),
740                transcript_c
741                    .as_ref()
742                    .map_or(ptr::null(), |json| json.as_ptr()),
743                tools_c.as_ref().map_or(ptr::null(), |json| json.as_ptr()),
744                tool_context,
745                tool_registry
746                    .as_ref()
747                    .map(|_| tool_callback_trampoline as ffi::FmToolCallback),
748                &mut error,
749            )
750        };
751        if ptr.is_null() {
752            return Err(crate::error::from_swift(
753                ffi::status::MODEL_UNAVAILABLE,
754                error,
755            ));
756        }
757        Ok(LanguageModelSession {
758            ptr,
759            _tool_registry: tool_registry,
760        })
761    }
762}
763
764/// A detailed generation response.
765#[derive(Debug, Clone, PartialEq)]
766pub struct SessionResponse<T> {
767    pub content: T,
768    pub raw_content: GeneratedContent,
769    pub transcript: Transcript,
770}
771
/// One snapshot emitted while a structured generation is streaming.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StructuredStreamSnapshot {
    /// JSON text of the content generated so far.
    pub content_json: String,
    /// JSON text of the raw content generated so far.
    pub raw_content_json: String,
    /// Completion flag for this snapshot.
    pub is_complete: bool,
}
779
780/// One structured stream event.
781#[derive(Debug, Clone, PartialEq)]
782#[non_exhaustive]
783pub enum StructuredStreamEvent {
784    Snapshot(StructuredStreamSnapshot),
785    Done,
786    Error(FMError),
787}
788
/// One feedback issue category.
///
/// Each variant is sent to the bridge as the snake_case identifier noted on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackIssueCategory {
    /// Sent as `"unhelpful"`.
    Unhelpful,
    /// Sent as `"too_verbose"`.
    TooVerbose,
    /// Sent as `"did_not_follow_instructions"`.
    DidNotFollowInstructions,
    /// Sent as `"incorrect"`.
    Incorrect,
    /// Sent as `"stereotype_or_bias"`.
    StereotypeOrBias,
    /// Sent as `"suggestive_or_sexual"`.
    SuggestiveOrSexual,
    /// Sent as `"vulgar_or_offensive"`.
    VulgarOrOffensive,
    /// Sent as `"triggered_guardrail_unexpectedly"`.
    TriggeredGuardrailUnexpectedly,
}

impl FeedbackIssueCategory {
    /// Identifier used for this category in the bridge JSON.
    const fn as_str(self) -> &'static str {
        // Exhaustive on purpose: a new variant must get a wire name here.
        match self {
            FeedbackIssueCategory::Unhelpful => "unhelpful",
            FeedbackIssueCategory::TooVerbose => "too_verbose",
            FeedbackIssueCategory::DidNotFollowInstructions => "did_not_follow_instructions",
            FeedbackIssueCategory::Incorrect => "incorrect",
            FeedbackIssueCategory::StereotypeOrBias => "stereotype_or_bias",
            FeedbackIssueCategory::SuggestiveOrSexual => "suggestive_or_sexual",
            FeedbackIssueCategory::VulgarOrOffensive => "vulgar_or_offensive",
            FeedbackIssueCategory::TriggeredGuardrailUnexpectedly => {
                "triggered_guardrail_unexpectedly"
            }
        }
    }
}
816
817/// One feedback issue.
818#[derive(Debug, Clone, PartialEq, Eq)]
819pub struct FeedbackIssue {
820    pub category: FeedbackIssueCategory,
821    pub explanation: Option<String>,
822}
823
/// Feedback sentiment.
///
/// Each variant is sent to the bridge as the lowercase word noted on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackSentiment {
    /// Sent as `"positive"`.
    Positive,
    /// Sent as `"negative"`.
    Negative,
    /// Sent as `"neutral"`.
    Neutral,
}

impl FeedbackSentiment {
    /// Identifier used for this sentiment in the bridge JSON.
    const fn as_str(self) -> &'static str {
        // Exhaustive on purpose: a new variant must get a wire name here.
        match self {
            FeedbackSentiment::Positive => "positive",
            FeedbackSentiment::Negative => "negative",
            FeedbackSentiment::Neutral => "neutral",
        }
    }
}
841
842/// A full feedback attachment request.
843#[derive(Debug, Clone, PartialEq)]
844pub struct FeedbackAttachmentRequest {
845    pub sentiment: Option<FeedbackSentiment>,
846    pub issues: Vec<FeedbackIssue>,
847    pub desired_response_text: Option<String>,
848    pub desired_response_content: Option<GeneratedContent>,
849    pub desired_output: Option<crate::transcript::Entry>,
850}
851
852impl FeedbackAttachmentRequest {
853    /// Create an empty feedback request.
854    #[must_use]
855    pub const fn new() -> Self {
856        Self {
857            sentiment: None,
858            issues: Vec::new(),
859            desired_response_text: None,
860            desired_response_content: None,
861            desired_output: None,
862        }
863    }
864
865    fn to_bridge_json(&self) -> Result<String, FMError> {
866        let issues = self
867            .issues
868            .iter()
869            .map(|issue| {
870                json!({
871                    "category": issue.category.as_str(),
872                    "explanation": issue.explanation,
873                })
874            })
875            .collect::<Vec<_>>();
876        let desired_output_json = self
877            .desired_output
878            .as_ref()
879            .map(|entry| Transcript::from(vec![entry.clone()]).to_json_string())
880            .transpose()?;
881        let desired_response_content = self
882            .desired_response_content
883            .as_ref()
884            .map(GeneratedContent::to_bridge_value)
885            .transpose()?;
886        serde_json::to_string(&json!({
887            "sentiment": self.sentiment.map(FeedbackSentiment::as_str),
888            "issues": issues,
889            "desiredResponseText": self.desired_response_text,
890            "desiredResponseContent": desired_response_content,
891            "desiredOutputTranscriptJSON": desired_output_json,
892        }))
893        .map_err(|error| {
894            FMError::InvalidArgument(format!(
895                "feedback request is not JSON-serializable: {error}"
896            ))
897        })
898    }
899}
900
901#[derive(Debug, Deserialize)]
902struct BridgeTextResponse {
903    content: String,
904    #[serde(rename = "rawContent")]
905    raw_content: BridgeGeneratedContent,
906    #[serde(rename = "transcriptJSON")]
907    transcript_json: String,
908}
909
910#[derive(Debug, Deserialize)]
911struct BridgeStructuredResponse {
912    content: BridgeGeneratedContent,
913    #[serde(rename = "rawContent")]
914    raw_content: BridgeGeneratedContent,
915    #[serde(rename = "transcriptJSON")]
916    transcript_json: String,
917}
918
919#[derive(Debug, Deserialize)]
920struct BridgeStructuredSnapshot {
921    content: BridgeGeneratedContent,
922    #[serde(rename = "rawContent")]
923    raw_content: BridgeGeneratedContent,
924    #[serde(rename = "isComplete")]
925    is_complete: bool,
926}
927
/// Wire shape of one text streaming payload, decoded by
/// `json_text_stream_trampoline`.
#[derive(Debug, Deserialize)]
struct BridgeTextStreamSnapshot {
    /// Incremental text for this payload; an empty delta is not forwarded to
    /// the user callback.
    delta: String,
}
932
933fn respond_request_json(
934    prompt: &Prompt,
935    options: GenerationOptions,
936    schema: Option<&GenerationSchema>,
937    include_schema_in_prompt: bool,
938) -> Result<CString, FMError> {
939    let sampling = match options.sampling() {
940        SamplingMode::Default => json!({ "mode": "default" }),
941        SamplingMode::Greedy => json!({ "mode": "greedy" }),
942        SamplingMode::TopK(k) => json!({
943            "mode": "top_k",
944            "topK": k,
945            "seed": options.sampling_seed(),
946        }),
947        SamplingMode::TopP(p) => json!({
948            "mode": "top_p",
949            "topP": p,
950            "seed": options.sampling_seed(),
951        }),
952    };
953    let include_schema_in_prompt = schema.map_or(include_schema_in_prompt, |schema| {
954        schema.effective_include_schema_in_prompt(include_schema_in_prompt)
955    });
956    let payload = serde_json::to_string(&json!({
957        "prompt": prompt.to_bridge_value(),
958        "options": {
959            "temperature": options.temperature(),
960            "maximumResponseTokens": options.maximum_response_tokens(),
961            "sampling": sampling,
962        },
963        "schemaJSON": schema.map(GenerationSchema::bridge_request_json),
964        "includeSchemaInPrompt": include_schema_in_prompt,
965    }))
966    .map_err(|error| {
967        FMError::InvalidArgument(format!("request is not JSON-serializable: {error}"))
968    })?;
969    CString::new(payload).map_err(|error| {
970        FMError::InvalidArgument(format!("request JSON contains a NUL byte: {error}"))
971    })
972}
973
974fn request_response(session: *mut c_void, payload: &CString) -> Result<String, FMError> {
975    let (tx, rx) = mpsc::channel();
976    let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
977    let context = Box::into_raw(tx_box).cast::<c_void>();
978    unsafe {
979        ffi::fm_session_respond_request_json(session, payload.as_ptr(), context, respond_trampoline)
980    };
981    rx.recv().map_err(|_| FMError::Unknown {
982        code: ffi::status::UNKNOWN,
983        message: "Swift bridge dropped the JSON response channel".into(),
984    })?
985}
986
987pub(crate) fn decode_bridge_text_response(
988    payload: &str,
989) -> Result<SessionResponse<String>, FMError> {
990    let response: BridgeTextResponse = serde_json::from_str(payload)
991        .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
992    Ok(SessionResponse {
993        content: response.content,
994        raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
995        transcript: Transcript::from_json_str(&response.transcript_json)?,
996    })
997}
998
999pub(crate) fn request_text_response_with<F>(invoke: F) -> Result<SessionResponse<String>, FMError>
1000where
1001    F: FnOnce(*mut c_void, ffi::FmRespondCallback),
1002{
1003    let (tx, rx) = mpsc::channel();
1004    let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
1005    let context = Box::into_raw(tx_box).cast::<c_void>();
1006    invoke(context, respond_trampoline);
1007    let payload = rx.recv().map_err(|_| FMError::Unknown {
1008        code: ffi::status::UNKNOWN,
1009        message: "Swift bridge dropped the JSON response channel".into(),
1010    })??;
1011    decode_bridge_text_response(&payload)
1012}
1013
1014pub(crate) fn run_text_stream_with<F, C>(invoke: F, on_chunk: C) -> Result<(), FMError>
1015where
1016    F: FnOnce(*mut c_void, ffi::FmStreamCallback),
1017    C: FnMut(StreamEvent<'_>) + Send + 'static,
1018{
1019    let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
1020    let state = Arc::new(StreamState {
1021        on_chunk: Mutex::new(Box::new(on_chunk)),
1022        done_tx: Mutex::new(Some(done_tx)),
1023    });
1024    let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
1025    invoke(context, json_text_stream_trampoline);
1026    done_rx.recv().map_err(|_| FMError::Unknown {
1027        code: ffi::status::UNKNOWN,
1028        message: "Swift bridge dropped the stream channel".into(),
1029    })?
1030}
1031
1032fn prompt_to_plain_text(prompt: &Prompt) -> Option<String> {
1033    let mut text = String::new();
1034    for segment in prompt.segments() {
1035        match segment {
1036            crate::prompt::Segment::Text(segment) => text.push_str(&segment.text),
1037            crate::prompt::Segment::Structure(_) => return None,
1038        }
1039    }
1040    Some(text)
1041}
1042
1043impl Default for LanguageModelSession {
1044    fn default() -> Self {
1045        Self::new()
1046    }
1047}
1048
1049impl Drop for LanguageModelSession {
1050    fn drop(&mut self) {
1051        if !self.ptr.is_null() {
1052            unsafe { ffi::fm_object_release(self.ptr) };
1053        }
1054    }
1055}
1056
1057impl core::fmt::Debug for LanguageModelSession {
1058    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1059        f.debug_struct("LanguageModelSession")
1060            .field("ptr", &self.ptr)
1061            .finish()
1062    }
1063}
1064
/// One event from a streaming generation.
///
/// The lifetime `'a` belongs to the borrowed text in [`StreamEvent::Chunk`]:
/// the trampoline lends the callback a slice of a buffer it owns, so copy the
/// text if it needs to be kept after the callback returns.
#[derive(Debug)]
#[non_exhaustive]
pub enum StreamEvent<'a> {
    /// Incremental text delta. Concatenate these to reconstruct the full reply.
    Chunk(&'a str),
    /// Stream finished successfully.
    Done,
    /// Stream failed; the inner error describes why.
    Error(FMError),
}
1076
1077// ---------- internal callback plumbing ----------
1078
1079// SAFETY: `context` is a `Box<mpsc::Sender<...>>` raw pointer created by
1080// `request_response` / `request_text_response_with`. Swift calls this callback
1081// exactly once, so there is no double-free risk. `response` and `error` are
1082// C strings owned by the Swift bridge and only valid for this call.
1083unsafe extern "C" fn respond_trampoline(
1084    context: *mut c_void,
1085    response: *mut c_char,
1086    error: *mut c_char,
1087    status: i32,
1088) {
1089    let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
1090    let result = if status == ffi::status::OK && !response.is_null() {
1091        let s = core::ffi::CStr::from_ptr(response)
1092            .to_string_lossy()
1093            .into_owned();
1094        ffi::fm_string_free(response);
1095        Ok(s)
1096    } else {
1097        Err(crate::error::from_swift(status, error))
1098    };
1099    let _ = tx.send(result);
1100}
1101
/// Boxed user callback that receives every [`StreamEvent`] of a text stream.
type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;
1103
/// State shared between `run_text_stream_with` and
/// `json_text_stream_trampoline` through the raw callback context.
struct StreamState {
    /// User callback, locked for the duration of each invocation.
    on_chunk: Mutex<StreamCallback>,
    /// Sender for the final stream result; taken (left as `None`) once the
    /// result has been reported so that it is sent at most once.
    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
}
1108
1109// SAFETY: `context` is a `Arc<StreamState>` raw pointer passed via
1110// `Arc::into_raw`. We reconstruct it with `Arc::from_raw` on every call and
1111// immediately `mem::forget` a clone so the count stays ≥ 1 until the
1112// terminal call (done=true or error). `chunk` is a Swift-owned C string valid
1113// only for the duration of this call.
1114unsafe extern "C" fn json_text_stream_trampoline(
1115    context: *mut c_void,
1116    chunk: *mut c_char,
1117    done: bool,
1118    status: i32,
1119) {
1120    let state = Arc::from_raw(context.cast::<StreamState>());
1121    let state_for_swift = state.clone();
1122    core::mem::forget(state_for_swift);
1123
1124    let payload: Option<String> = if chunk.is_null() {
1125        None
1126    } else {
1127        let value = core::ffi::CStr::from_ptr(chunk)
1128            .to_string_lossy()
1129            .into_owned();
1130        ffi::fm_string_free(chunk);
1131        Some(value)
1132    };
1133
1134    if status != ffi::status::OK {
1135        let err = payload
1136            .map(|message| {
1137                crate::error::from_swift(
1138                    status,
1139                    ffi::fm_string_dup(
1140                        CString::new(message)
1141                            .expect("stream errors must not contain NUL bytes")
1142                            .as_ptr(),
1143                    ),
1144                )
1145            })
1146            .unwrap_or_else(|| crate::error::from_swift(status, ptr::null_mut()));
1147        {
1148            let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1149            // Catch panics so they don't unwind across the FFI boundary (UB).
1150            let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Error(err.clone()))));
1151        }
1152        if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1153            let _ = tx.send(Err(err));
1154        }
1155        drop(Arc::from_raw(Arc::as_ptr(&state)));
1156        drop(state);
1157        return;
1158    }
1159
1160    if let Some(payload) = payload {
1161        match serde_json::from_str::<BridgeTextStreamSnapshot>(&payload) {
1162            Ok(snapshot) if !snapshot.delta.is_empty() => {
1163                let chunk_panicked = {
1164                    let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1165                    // Catch panics so they don't unwind across the FFI boundary.
1166                    catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Chunk(&snapshot.delta))))
1167                        .is_err()
1168                };
1169                if chunk_panicked {
1170                    if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1171                        let _ = tx.send(Err(FMError::Unknown {
1172                            code: ffi::status::UNKNOWN,
1173                            message: "stream callback panicked".into(),
1174                        }));
1175                    }
1176                    drop(Arc::from_raw(Arc::as_ptr(&state)));
1177                    drop(state);
1178                    return;
1179                }
1180            }
1181            Ok(_) => {}
1182            Err(error) => {
1183                let err = FMError::DecodingFailure(error.to_string());
1184                {
1185                    let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1186                    let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Error(err.clone()))));
1187                }
1188                if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1189                    let _ = tx.send(Err(err));
1190                }
1191                drop(Arc::from_raw(Arc::as_ptr(&state)));
1192                drop(state);
1193                return;
1194            }
1195        }
1196    }
1197
1198    if done {
1199        {
1200            let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1201            let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Done)));
1202        }
1203        if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1204            let _ = tx.send(Ok(()));
1205        }
1206        drop(Arc::from_raw(Arc::as_ptr(&state)));
1207    }
1208    drop(state);
1209}
1210
/// Boxed user callback that receives every [`StructuredStreamEvent`] of a
/// structured stream.
type StructuredStreamCallback = Box<dyn FnMut(StructuredStreamEvent) + Send>;
1212
/// State handed to `structured_stream_trampoline` through the raw callback
/// context.
struct StructuredStreamState {
    /// User callback, locked for the duration of each invocation.
    on_event: Mutex<StructuredStreamCallback>,
    /// Sender for the final stream result; taken (left as `None`) once the
    /// result has been reported so that it is sent at most once.
    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
}
1217
1218// SAFETY: Same invariants as `json_text_stream_trampoline` above, but for
1219// `StructuredStreamState`.
1220#[allow(clippy::too_many_lines)]
1221unsafe extern "C" fn structured_stream_trampoline(
1222    context: *mut c_void,
1223    chunk: *mut c_char,
1224    done: bool,
1225    status: i32,
1226) {
1227    let state = Arc::from_raw(context.cast::<StructuredStreamState>());
1228    let state_for_swift = state.clone();
1229    core::mem::forget(state_for_swift);
1230
1231    let payload: Option<String> = if chunk.is_null() {
1232        None
1233    } else {
1234        let value = core::ffi::CStr::from_ptr(chunk)
1235            .to_string_lossy()
1236            .into_owned();
1237        ffi::fm_string_free(chunk);
1238        Some(value)
1239    };
1240
1241    if status != ffi::status::OK {
1242        let err = payload
1243            .map(|message| {
1244                crate::error::from_swift(
1245                    status,
1246                    ffi::fm_string_dup(
1247                        CString::new(message)
1248                            .expect("stream errors must not contain NUL bytes")
1249                            .as_ptr(),
1250                    ),
1251                )
1252            })
1253            .unwrap_or_else(|| crate::error::from_swift(status, ptr::null_mut()));
1254        {
1255            let mut cb = state
1256                .on_event
1257                .lock()
1258                .expect("structured callback mutex poisoned");
1259            // Catch panics so they don't unwind across the FFI boundary (UB).
1260            let _ = catch_unwind(AssertUnwindSafe(|| {
1261                cb(StructuredStreamEvent::Error(err.clone()));
1262            }));
1263        }
1264        if let Some(tx) = state
1265            .done_tx
1266            .lock()
1267            .expect("structured done_tx mutex poisoned")
1268            .take()
1269        {
1270            let _ = tx.send(Err(err));
1271        }
1272        drop(Arc::from_raw(Arc::as_ptr(&state)));
1273        drop(state);
1274        return;
1275    }
1276
1277    if let Some(payload) = payload {
1278        let snapshot: BridgeStructuredSnapshot = match serde_json::from_str(&payload) {
1279            Ok(snapshot) => snapshot,
1280            Err(error) => {
1281                let err = FMError::DecodingFailure(error.to_string());
1282                {
1283                    let mut cb = state
1284                        .on_event
1285                        .lock()
1286                        .expect("structured callback mutex poisoned");
1287                    let _ = catch_unwind(AssertUnwindSafe(|| {
1288                        cb(StructuredStreamEvent::Error(err.clone()));
1289                    }));
1290                }
1291                if let Some(tx) = state
1292                    .done_tx
1293                    .lock()
1294                    .expect("structured done_tx mutex poisoned")
1295                    .take()
1296                {
1297                    let _ = tx.send(Err(err));
1298                }
1299                drop(Arc::from_raw(Arc::as_ptr(&state)));
1300                drop(state);
1301                return;
1302            }
1303        };
1304        let snapshot_event = StructuredStreamEvent::Snapshot(StructuredStreamSnapshot {
1305            content_json: snapshot.content.json,
1306            raw_content_json: snapshot.raw_content.json,
1307            is_complete: snapshot.is_complete,
1308        });
1309        let snapshot_panicked = {
1310            let mut cb = state
1311                .on_event
1312                .lock()
1313                .expect("structured callback mutex poisoned");
1314            // Catch panics so they don't unwind across the FFI boundary.
1315            catch_unwind(AssertUnwindSafe(|| cb(snapshot_event))).is_err()
1316        };
1317        if snapshot_panicked {
1318            if let Some(tx) = state
1319                .done_tx
1320                .lock()
1321                .expect("structured done_tx mutex poisoned")
1322                .take()
1323            {
1324                let _ = tx.send(Err(FMError::Unknown {
1325                    code: ffi::status::UNKNOWN,
1326                    message: "stream callback panicked".into(),
1327                }));
1328            }
1329            drop(Arc::from_raw(Arc::as_ptr(&state)));
1330            drop(state);
1331            return;
1332        }
1333    }
1334
1335    if done {
1336        {
1337            let mut cb = state
1338                .on_event
1339                .lock()
1340                .expect("structured callback mutex poisoned");
1341            let _ = catch_unwind(AssertUnwindSafe(|| cb(StructuredStreamEvent::Done)));
1342        }
1343        if let Some(tx) = state
1344            .done_tx
1345            .lock()
1346            .expect("structured done_tx mutex poisoned")
1347            .take()
1348        {
1349            let _ = tx.send(Ok(()));
1350        }
1351        drop(Arc::from_raw(Arc::as_ptr(&state)));
1352    }
1353    drop(state);
1354}