Skip to main content

foundation_models/session/
mod.rs

1//! [`LanguageModelSession`] — a stateful conversation with the on-device model.
2
3use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::sync::mpsc;
7use std::sync::{Arc, Mutex};
8
9use serde::Deserialize;
10use serde_json::json;
11
12use crate::content::GeneratedContent;
13use crate::error::FMError;
14use crate::ffi;
15use crate::generation::{GenerationOptions, SamplingMode};
16use crate::model::ConfiguredSystemLanguageModel;
17use crate::prompt::{Instructions, Prompt, ToInstructions, ToPrompt};
18use crate::schema::GenerationSchema;
19use crate::tool::{tool_callback_trampoline, Tool, ToolRegistry};
20use crate::transcript::Transcript;
21
/// A stateful conversation with the on-device language model.
///
/// Sessions retain their conversation history; subsequent calls to
/// [`respond`](Self::respond) build on the previous turns.
///
/// # Examples
///
/// ```rust,no_run
/// use foundation_models::LanguageModelSession;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let session = LanguageModelSession::new();
/// let answer = session.respond("Name three Norse gods.")?;
/// println!("{answer}");
/// # Ok(())
/// # }
/// ```
pub struct LanguageModelSession {
    // Opaque session handle returned by `ffi::fm_session_create` /
    // `ffi::fm_session_create_ex`. The constructors reject a null result, so
    // this is non-null. Rust never dereferences it; it is only passed back
    // into the bridge's `extern "C"` functions.
    ptr: *mut c_void,
    // Tool registry whose raw address is handed to the bridge as the tool
    // callback context when the session is built with tools. Holding the
    // `Arc` here keeps that address valid for as long as the session exists.
    // `None` for sessions created without tools.
    _tool_registry: Option<Arc<ToolRegistry>>,
}
43
// SAFETY: The underlying Swift LanguageModelSession is reference-counted via
// Unmanaged.passRetained on the Swift side; sending the opaque pointer between
// threads is safe as long as we don't dereference it from Rust (we never do —
// it only travels through extern "C" calls that internally hop to the
// Swift concurrency executor).
//
// NOTE(review): the argument above covers moving the handle across threads
// (`Send`). `Sync` additionally assumes the bridge functions tolerate being
// called concurrently through `&self` on the same session — confirm against
// the Swift side.
unsafe impl Send for LanguageModelSession {}
unsafe impl Sync for LanguageModelSession {}
51
52impl LanguageModelSession {
53    /// Create a session with the model's default behaviour.
54    ///
55    /// # Panics
56    ///
57    /// Panics if `FoundationModels` is not available on this OS. Check
58    /// [`crate::SystemLanguageModel::is_available`] first if you need to
59    /// handle that gracefully.
60    #[must_use]
61    pub fn new() -> Self {
62        Self::try_new(None).expect("FoundationModels is not available on this OS")
63    }
64
65    /// Create a session with custom system instructions ("system prompt").
66    ///
67    /// # Panics
68    ///
69    /// Panics if `FoundationModels` is not available, or if `instructions`
70    /// contains an interior NUL byte.
71    #[must_use]
72    pub fn with_instructions(instructions: &str) -> Self {
73        Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
74    }
75
76    /// Fallible constructor. Returns `None` when `FoundationModels` is not
77    /// available (OS too old, model not enabled, etc.) or when `instructions`
78    /// contains an interior NUL byte.
79    #[must_use]
80    pub fn try_new(instructions: Option<&str>) -> Option<Self> {
81        let cstring = match instructions {
82            Some(s) => Some(CString::new(s).ok()?),
83            None => None,
84        };
85        let ptr =
86            unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
87        if ptr.is_null() {
88            return None;
89        }
90        Some(Self {
91            ptr,
92            _tool_registry: None,
93        })
94    }
95
96    /// Send a prompt and block until the full response is available.
97    ///
98    /// # Errors
99    ///
100    /// Returns an [`FMError`] if the model rejects the prompt, the context
101    /// window is exceeded, the session is cancelled, or the prompt contains
102    /// an interior NUL byte.
103    pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
104        self.respond_with(prompt, GenerationOptions::new())
105    }
106
107    /// Pre-warm the model. Apple loads the weights + initialises the
108    /// inference engine so the next `respond` call is faster. Returns
109    /// immediately; the warm-up runs in the background.
110    pub fn prewarm(&self) {
111        unsafe { ffi::fm_session_prewarm(self.ptr) };
112    }
113
114    /// True if this session is currently producing a response (i.e. an
115    /// earlier `respond` / `stream` is still in flight on Apple's queue).
116    #[must_use]
117    pub fn is_responding(&self) -> bool {
118        unsafe { ffi::fm_session_is_responding(self.ptr) }
119    }
120
121    /// Return a best-effort JSON serialisation of the session's
122    /// `Transcript` — the full history of user prompts and model
123    /// responses. Useful for persisting a chat session across
124    /// process boundaries.
125    #[must_use]
126    pub fn transcript_json(&self) -> String {
127        let p = unsafe { ffi::fm_session_transcript_json(self.ptr) };
128        if p.is_null() {
129            return String::from("{}");
130        }
131        let s = unsafe { core::ffi::CStr::from_ptr(p) }
132            .to_string_lossy()
133            .into_owned();
134        unsafe { ffi::fm_string_free(p) };
135        s
136    }
137
138    /// Log feedback on the most recent response for diagnostic /
139    /// fine-tuning purposes. `sentiment`:
140    /// `1` positive, `0` neutral, `-1` negative.
141    pub fn log_feedback(&self, sentiment: i32, description: Option<&str>) {
142        let cstr = description.and_then(|s| CString::new(s).ok());
143        let p = cstr.as_ref().map_or(core::ptr::null(), |c| c.as_ptr());
144        unsafe { ffi::fm_session_log_feedback(self.ptr, sentiment, p) };
145    }
146
147    /// Prompt-engineered JSON-shape response.
148    ///
149    /// Wraps the prompt with a "respond with valid JSON matching this schema"
150    /// instruction and parses the response. The schema is a
151    /// `serde_json::Value`-style JSON string (passed as text).
152    ///
153    /// Useful for getting structured data out of the model without the
154    /// full Generable macro machinery. The model still returns plain
155    /// text — the caller must parse with `serde_json` / `serde` after.
156    ///
157    /// # Errors
158    ///
159    /// See [`respond`](Self::respond).
160    pub fn respond_with_json_schema(
161        &self,
162        prompt: &str,
163        schema_description: &str,
164    ) -> Result<String, FMError> {
165        let wrapped = format!(
166            "{prompt}\n\n\
167             IMPORTANT: respond with VALID JSON ONLY (no prose, no markdown \
168             fences) that matches this schema:\n\n{schema_description}\n\n\
169             Your entire response must be parseable by JSON.parse()."
170        );
171        self.respond(&wrapped)
172    }
173
174    /// Like [`respond`](Self::respond), but with explicit generation options.
175    ///
176    /// # Errors
177    ///
178    /// See [`respond`](Self::respond).
179    pub fn respond_with(
180        &self,
181        prompt: &str,
182        options: GenerationOptions,
183    ) -> Result<String, FMError> {
184        self.respond_prompt_with(prompt, options)
185    }
186
187    /// Schema-driven structured response.
188    ///
189    /// Builds a `DynamicGenerationSchema` from the provided JSON
190    /// schema, runs `LanguageModelSession.respond(schema:prompt:)`,
191    /// and returns the model's `GeneratedContent.jsonString` — a
192    /// well-formed JSON string matching the requested shape.
193    ///
194    /// Supported `schema` shape (strict subset of JSON Schema):
195    ///
196    /// ```json
197    /// {
198    ///   "type": "object",
199    ///   "name": "Movie",
200    ///   "properties": {
201    ///     "title":  { "type": "string", "description": "Movie title" },
202    ///     "year":   { "type": "integer" },
203    ///     "rating": { "type": "number", "optional": true },
204    ///     "tags":   { "type": "array", "items": { "type": "string" }, "min": 1, "max": 5 }
205    ///   }
206    /// }
207    /// ```
208    ///
209    /// Primitive types: `"string"`, `"integer"`, `"number"`,
210    /// `"boolean"`, `"array"`, `"object"`. Each property may set
211    /// `"description"` and `"optional"`. Array schemas accept
212    /// `"items"` plus optional `"min"` / `"max"` element counts.
213    ///
214    /// # Errors
215    ///
216    /// See [`respond`](Self::respond) for general errors, plus a
217    /// "schema build failed" / "schema JSON is not valid" error
218    /// returned as [`FMError::Unknown`] if the schema is malformed.
219    pub fn respond_with_schema(
220        &self,
221        prompt: &str,
222        schema: &str,
223        include_schema_in_prompt: bool,
224    ) -> Result<String, FMError> {
225        self.respond_with_schema_options(
226            prompt,
227            schema,
228            include_schema_in_prompt,
229            GenerationOptions::new(),
230        )
231    }
232
233    /// [`respond_with_schema`](Self::respond_with_schema) with
234    /// explicit generation options.
235    ///
236    /// # Errors
237    ///
238    /// See [`respond_with_schema`](Self::respond_with_schema).
239    pub fn respond_with_schema_options(
240        &self,
241        prompt: &str,
242        schema: &str,
243        include_schema_in_prompt: bool,
244        options: GenerationOptions,
245    ) -> Result<String, FMError> {
246        let prompt_c = CString::new(prompt)
247            .map_err(|e| FMError::InvalidArgument(format!("prompt NUL byte: {e}")))?;
248        let schema_c = CString::new(schema)
249            .map_err(|e| FMError::InvalidArgument(format!("schema NUL byte: {e}")))?;
250        let opts = options.to_ffi();
251        let (tx, rx) = mpsc::channel();
252        let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
253        let context = Box::into_raw(tx_box).cast::<c_void>();
254
255        unsafe {
256            ffi::fm_session_respond_with_schema(
257                self.ptr,
258                prompt_c.as_ptr(),
259                schema_c.as_ptr(),
260                include_schema_in_prompt,
261                opts.temperature,
262                opts.maximum_response_tokens,
263                opts.sampling_mode,
264                opts.top_k,
265                opts.top_p,
266                context,
267                respond_trampoline,
268            );
269        }
270
271        rx.recv().map_err(|_| FMError::Unknown {
272            code: ffi::status::UNKNOWN,
273            message: "Swift bridge dropped the callback channel".into(),
274        })?
275    }
276
277    /// Stream the response as the model generates it. The callback is invoked
278    /// with each delta and a final invocation with `done == true`.
279    ///
280    /// # Errors
281    ///
282    /// Returns an [`FMError`] mirroring [`respond`](Self::respond). The
283    /// callback may also receive a chunk *and* an error if the stream fails
284    /// midway.
285    pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
286    where
287        F: FnMut(StreamEvent<'_>) + Send + 'static,
288    {
289        self.stream_with(prompt, GenerationOptions::new(), move |event| {
290            on_chunk(event);
291        })
292    }
293
294    /// Like [`stream`](Self::stream), but with explicit generation options.
295    ///
296    /// # Errors
297    ///
298    /// See [`stream`](Self::stream).
299    pub fn stream_with<F>(
300        &self,
301        prompt: &str,
302        options: GenerationOptions,
303        on_chunk: F,
304    ) -> Result<(), FMError>
305    where
306        F: FnMut(StreamEvent<'_>) + Send + 'static,
307    {
308        let payload = respond_request_json(&Prompt::from(prompt), options, None, true)?;
309
310        let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
311        let state = Arc::new(StreamState {
312            on_chunk: Mutex::new(Box::new(on_chunk)),
313            done_tx: Mutex::new(Some(done_tx)),
314        });
315        let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
316
317        unsafe {
318            ffi::fm_session_stream_request_json(
319                self.ptr,
320                payload.as_ptr(),
321                context,
322                json_text_stream_trampoline,
323            )
324        };
325
326        done_rx.recv().map_err(|_| FMError::Unknown {
327            code: ffi::status::UNKNOWN,
328            message: "Swift bridge dropped the stream channel".into(),
329        })?
330    }
331}
332
333impl LanguageModelSession {
334    /// Create a configurable session builder.
335    #[must_use]
336    pub fn builder<'a>() -> SessionBuilder<'a> {
337        SessionBuilder::new()
338    }
339
340    /// Restore a session from a transcript.
341    ///
342    /// # Errors
343    ///
344    /// Returns an [`FMError`] if the transcript cannot be encoded for Swift.
345    pub fn from_transcript(transcript: Transcript) -> Result<Self, FMError> {
346        Self::builder().transcript(transcript).build()
347    }
348
349    /// Return the typed transcript for this session.
350    ///
351    /// # Errors
352    ///
353    /// Returns an [`FMError`] if the transcript JSON returned by Swift could not
354    /// be decoded.
355    pub fn transcript(&self) -> Result<Transcript, FMError> {
356        Transcript::from_json_str(&self.transcript_json())
357    }
358
359    /// Pre-warm the model using a prompt prefix.
360    ///
361    /// # Errors
362    ///
363    /// Returns an [`FMError`] if the prompt cannot be encoded for Swift.
364    pub fn prewarm_with_prompt<P>(&self, prompt: P) -> Result<(), FMError>
365    where
366        P: ToPrompt,
367    {
368        let prompt = prompt.to_prompt()?;
369        let prompt_json = CString::new(prompt.to_bridge_json()?).map_err(|error| {
370            FMError::InvalidArgument(format!("prompt JSON contains a NUL byte: {error}"))
371        })?;
372        let mut error: *mut c_char = ptr::null_mut();
373        let status = unsafe {
374            ffi::fm_session_prewarm_prompt_json(self.ptr, prompt_json.as_ptr(), &mut error)
375        };
376        if status != ffi::status::OK {
377            return Err(crate::error::from_swift(status, error));
378        }
379        Ok(())
380    }
381
382    /// Respond to a structured prompt and return only the generated text.
383    ///
384    /// # Errors
385    ///
386    /// Returns an [`FMError`] if generation fails.
387    pub fn respond_prompt<P>(&self, prompt: P) -> Result<String, FMError>
388    where
389        P: ToPrompt,
390    {
391        self.respond_prompt_with(prompt, GenerationOptions::new())
392    }
393
394    /// Like [`respond_prompt`](Self::respond_prompt), but with explicit options.
395    ///
396    /// # Errors
397    ///
398    /// Returns an [`FMError`] if generation fails.
399    pub fn respond_prompt_with<P>(
400        &self,
401        prompt: P,
402        options: GenerationOptions,
403    ) -> Result<String, FMError>
404    where
405        P: ToPrompt,
406    {
407        self.respond_prompt_detailed(prompt, options)
408            .map(|response| response.content)
409    }
410
411    /// Respond to a structured prompt and keep the full response metadata.
412    ///
413    /// # Errors
414    ///
415    /// Returns an [`FMError`] if generation fails.
416    pub fn respond_prompt_detailed<P>(
417        &self,
418        prompt: P,
419        options: GenerationOptions,
420    ) -> Result<SessionResponse<String>, FMError>
421    where
422        P: ToPrompt,
423    {
424        let prompt = prompt.to_prompt()?;
425        let payload = respond_request_json(&prompt, options, None, true)?;
426        let payload = request_response(self.ptr, &payload)?;
427        let response: BridgeTextResponse = serde_json::from_str(&payload)
428            .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
429        Ok(SessionResponse {
430            content: response.content,
431            raw_content: GeneratedContent::from_bridge_json(
432                &response.raw_content_json,
433                true,
434                None,
435            )?,
436            transcript: Transcript::from_json_str(&response.transcript_json)?,
437        })
438    }
439
440    /// Generate structured content using an explicit schema.
441    ///
442    /// # Errors
443    ///
444    /// Returns an [`FMError`] if generation fails or the schema is invalid.
445    pub fn respond_generated<P>(
446        &self,
447        prompt: P,
448        schema: &GenerationSchema,
449        include_schema_in_prompt: bool,
450    ) -> Result<GeneratedContent, FMError>
451    where
452        P: ToPrompt,
453    {
454        self.respond_generated_with(
455            prompt,
456            schema,
457            include_schema_in_prompt,
458            GenerationOptions::new(),
459        )
460        .map(|response| response.content)
461    }
462
463    /// Like [`respond_generated`](Self::respond_generated), but with explicit options.
464    ///
465    /// # Errors
466    ///
467    /// Returns an [`FMError`] if generation fails or the schema is invalid.
468    pub fn respond_generated_with<P>(
469        &self,
470        prompt: P,
471        schema: &GenerationSchema,
472        include_schema_in_prompt: bool,
473        options: GenerationOptions,
474    ) -> Result<SessionResponse<GeneratedContent>, FMError>
475    where
476        P: ToPrompt,
477    {
478        let prompt = prompt.to_prompt()?;
479        let payload =
480            respond_request_json(&prompt, options, Some(schema), include_schema_in_prompt)?;
481        let payload = request_response(self.ptr, &payload)?;
482        let response: BridgeStructuredResponse = serde_json::from_str(&payload)
483            .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
484        Ok(SessionResponse {
485            content: GeneratedContent::from_bridge_json(&response.content_json, true, None)?,
486            raw_content: GeneratedContent::from_bridge_json(
487                &response.raw_content_json,
488                true,
489                None,
490            )?,
491            transcript: Transcript::from_json_str(&response.transcript_json)?,
492        })
493    }
494
495    /// Generate a typed Rust value using a [`crate::schema::Generable`] implementation.
496    ///
497    /// # Errors
498    ///
499    /// Returns an [`FMError`] if generation fails or the generated JSON cannot
500    /// be decoded as `T`.
501    pub fn respond_generating<P, T>(
502        &self,
503        prompt: P,
504        include_schema_in_prompt: bool,
505        options: GenerationOptions,
506    ) -> Result<SessionResponse<T>, FMError>
507    where
508        P: ToPrompt,
509        T: crate::schema::Generable,
510    {
511        let response = self.respond_generated_with(
512            prompt,
513            &T::generation_schema()?,
514            include_schema_in_prompt,
515            options,
516        )?;
517        Ok(SessionResponse {
518            content: T::from_generated_content(&response.content)?,
519            raw_content: response.raw_content,
520            transcript: response.transcript,
521        })
522    }
523
524    /// Stream a structured prompt token-by-token.
525    ///
526    /// # Errors
527    ///
528    /// Returns an [`FMError`] if the prompt cannot be encoded or generation fails.
529    pub fn stream_prompt<P, F>(&self, prompt: P, on_chunk: F) -> Result<(), FMError>
530    where
531        P: ToPrompt,
532        F: FnMut(StreamEvent<'_>) + Send + 'static,
533    {
534        let prompt = prompt.to_prompt()?;
535        let prompt_text = prompt_to_plain_text(&prompt).ok_or_else(|| {
536            FMError::InvalidArgument(
537                "text streaming only supports prompts composed of text segments".into(),
538            )
539        })?;
540        self.stream_with(&prompt_text, GenerationOptions::new(), on_chunk)
541    }
542
543    /// Stream structured generation snapshots.
544    ///
545    /// # Errors
546    ///
547    /// Returns an [`FMError`] if the prompt cannot be encoded or generation fails.
548    pub fn stream_generated<P, F>(
549        &self,
550        prompt: P,
551        schema: &GenerationSchema,
552        include_schema_in_prompt: bool,
553        options: GenerationOptions,
554        on_event: F,
555    ) -> Result<(), FMError>
556    where
557        P: ToPrompt,
558        F: FnMut(StructuredStreamEvent) + Send + 'static,
559    {
560        let prompt = prompt.to_prompt()?;
561        let payload =
562            respond_request_json(&prompt, options, Some(schema), include_schema_in_prompt)?;
563        let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
564        let state = Arc::new(StructuredStreamState {
565            on_event: Mutex::new(Box::new(on_event)),
566            done_tx: Mutex::new(Some(done_tx)),
567        });
568        let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
569        unsafe {
570            ffi::fm_session_stream_request_json(
571                self.ptr,
572                payload.as_ptr(),
573                context,
574                structured_stream_trampoline,
575            )
576        };
577        done_rx.recv().map_err(|_| FMError::Unknown {
578            code: ffi::status::UNKNOWN,
579            message: "Swift bridge dropped the structured stream channel".into(),
580        })?
581    }
582
583    /// Log a feedback attachment and return the raw bytes Apple produced.
584    ///
585    /// # Errors
586    ///
587    /// Returns an [`FMError`] if the attachment request is invalid.
588    pub fn log_feedback_attachment(
589        &self,
590        request: FeedbackAttachmentRequest,
591    ) -> Result<Vec<u8>, FMError> {
592        let request_json = CString::new(request.to_bridge_json()?).map_err(|error| {
593            FMError::InvalidArgument(format!("feedback request contains a NUL byte: {error}"))
594        })?;
595        let mut length = 0usize;
596        let mut error: *mut c_char = ptr::null_mut();
597        let ptr = unsafe {
598            ffi::fm_session_log_feedback_attachment_json(
599                self.ptr,
600                request_json.as_ptr(),
601                &mut length,
602                &mut error,
603            )
604        };
605        if ptr.is_null() && !error.is_null() {
606            return Err(crate::error::from_swift(
607                ffi::status::INVALID_ARGUMENT,
608                error,
609            ));
610        }
611        if ptr.is_null() || length == 0 {
612            return Ok(Vec::new());
613        }
614        let bytes = unsafe { std::slice::from_raw_parts(ptr.cast::<u8>(), length) }.to_vec();
615        unsafe { ffi::fm_bytes_free(ptr) };
616        Ok(bytes)
617    }
618}
619
/// Builder for [`LanguageModelSession`].
///
/// Instructions and a transcript are mutually exclusive: `build` rejects a
/// builder that has both set.
pub struct SessionBuilder<'a> {
    // Optional configured model; when absent a null model pointer is passed
    // to the bridge.
    model: Option<&'a ConfiguredSystemLanguageModel>,
    // Optional system instructions, encoded to bridge JSON at build time.
    instructions: Option<Instructions>,
    // Optional transcript to restore from, encoded to JSON at build time.
    transcript: Option<Transcript>,
    // Tools to register; wrapped in a `ToolRegistry` at build time when
    // non-empty.
    tools: Vec<Tool>,
}
627
628impl<'a> SessionBuilder<'a> {
629    const fn new() -> Self {
630        Self {
631            model: None,
632            instructions: None,
633            transcript: None,
634            tools: Vec::new(),
635        }
636    }
637
638    /// Use a configured system model.
639    #[must_use]
640    pub const fn model(mut self, model: &'a ConfiguredSystemLanguageModel) -> Self {
641        self.model = Some(model);
642        self
643    }
644
645    /// Set system instructions.
646    pub fn instructions<I>(mut self, instructions: I) -> Result<Self, FMError>
647    where
648        I: ToInstructions,
649    {
650        self.instructions = Some(instructions.to_instructions()?);
651        Ok(self)
652    }
653
654    /// Restore the session from a transcript.
655    #[must_use]
656    pub fn transcript(mut self, transcript: Transcript) -> Self {
657        self.transcript = Some(transcript);
658        self
659    }
660
661    /// Add one tool.
662    #[must_use]
663    pub fn tool(mut self, tool: Tool) -> Self {
664        self.tools.push(tool);
665        self
666    }
667
668    /// Add many tools.
669    #[must_use]
670    pub fn tools(mut self, tools: impl IntoIterator<Item = Tool>) -> Self {
671        self.tools.extend(tools);
672        self
673    }
674
675    /// Build the session.
676    ///
677    /// # Errors
678    ///
679    /// Returns an [`FMError`] if the configuration cannot be encoded for Swift.
680    pub fn build(self) -> Result<LanguageModelSession, FMError> {
681        if self.instructions.is_some() && self.transcript.is_some() {
682            return Err(FMError::InvalidArgument(
683                "session builder accepts either instructions or a transcript, not both".into(),
684            ));
685        }
686
687        let instructions_json = self
688            .instructions
689            .as_ref()
690            .map(Instructions::to_bridge_json)
691            .transpose()?;
692        let transcript_json = self
693            .transcript
694            .as_ref()
695            .map(Transcript::to_json_string)
696            .transpose()?;
697        let tool_registry = if self.tools.is_empty() {
698            None
699        } else {
700            Some(Arc::new(ToolRegistry::new(self.tools)))
701        };
702        let tools_json = tool_registry
703            .as_ref()
704            .map(|registry| registry.specs_json())
705            .transpose()?;
706
707        let instructions_c = instructions_json
708            .as_deref()
709            .map(CString::new)
710            .transpose()
711            .map_err(|error| {
712                FMError::InvalidArgument(format!("instructions JSON contains a NUL byte: {error}"))
713            })?;
714        let transcript_c = transcript_json
715            .as_deref()
716            .map(CString::new)
717            .transpose()
718            .map_err(|error| {
719                FMError::InvalidArgument(format!("transcript JSON contains a NUL byte: {error}"))
720            })?;
721        let tools_c = tools_json
722            .as_deref()
723            .map(CString::new)
724            .transpose()
725            .map_err(|error| {
726                FMError::InvalidArgument(format!("tool JSON contains a NUL byte: {error}"))
727            })?;
728
729        let tool_context = tool_registry.as_ref().map_or(ptr::null_mut(), |registry| {
730            Arc::as_ptr(registry).cast_mut().cast::<c_void>()
731        });
732        let mut error: *mut c_char = ptr::null_mut();
733        let ptr = unsafe {
734            ffi::fm_session_create_ex(
735                self.model.map_or(ptr::null_mut(), |model| model.ptr),
736                instructions_c
737                    .as_ref()
738                    .map_or(ptr::null(), |json| json.as_ptr()),
739                transcript_c
740                    .as_ref()
741                    .map_or(ptr::null(), |json| json.as_ptr()),
742                tools_c.as_ref().map_or(ptr::null(), |json| json.as_ptr()),
743                tool_context,
744                tool_registry
745                    .as_ref()
746                    .map(|_| tool_callback_trampoline as ffi::FmToolCallback),
747                &mut error,
748            )
749        };
750        if ptr.is_null() {
751            return Err(crate::error::from_swift(
752                ffi::status::MODEL_UNAVAILABLE,
753                error,
754            ));
755        }
756        Ok(LanguageModelSession {
757            ptr,
758            _tool_registry: tool_registry,
759        })
760    }
761}
762
/// A detailed generation response.
///
/// `T` is `String` for text responses, [`GeneratedContent`] for
/// schema-driven responses, or a caller-chosen type for typed generation.
#[derive(Debug, Clone, PartialEq)]
pub struct SessionResponse<T> {
    /// The response payload in the form the caller asked for.
    pub content: T,
    /// The raw generated content decoded from the bridge's `rawContentJSON`.
    pub raw_content: GeneratedContent,
    /// The transcript decoded from the bridge's `transcriptJSON`.
    pub transcript: Transcript,
}
770
/// One structured-generation stream snapshot.
// NOTE(review): the fields appear to mirror `BridgeStructuredSnapshot`
// (`contentJSON`, `rawContentJSON`, `isComplete`); the conversion is not
// visible in this part of the file — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StructuredStreamSnapshot {
    /// Generated content so far, as a JSON string.
    pub content_json: String,
    /// Raw generated content so far, as a JSON string.
    pub raw_content_json: String,
    /// Completion flag reported with this snapshot.
    pub is_complete: bool,
}
778
/// One structured stream event, delivered to the callback passed to
/// `LanguageModelSession::stream_generated`.
// NOTE(review): the variant semantics below are inferred from the names; the
// code that emits these events is not visible here — confirm.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum StructuredStreamEvent {
    /// A new snapshot of the content being generated.
    Snapshot(StructuredStreamSnapshot),
    /// The stream finished.
    Done,
    /// The stream failed with the given error.
    Error(FMError),
}
787
/// One feedback issue category.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackIssueCategory {
    Unhelpful,
    TooVerbose,
    DidNotFollowInstructions,
    Incorrect,
    StereotypeOrBias,
    SuggestiveOrSexual,
    VulgarOrOffensive,
    TriggeredGuardrailUnexpectedly,
}

impl FeedbackIssueCategory {
    /// Snake-case identifier used for this category in the bridge JSON.
    const fn as_str(self) -> &'static str {
        let label = match self {
            Self::TriggeredGuardrailUnexpectedly => "triggered_guardrail_unexpectedly",
            Self::VulgarOrOffensive => "vulgar_or_offensive",
            Self::SuggestiveOrSexual => "suggestive_or_sexual",
            Self::StereotypeOrBias => "stereotype_or_bias",
            Self::Incorrect => "incorrect",
            Self::DidNotFollowInstructions => "did_not_follow_instructions",
            Self::TooVerbose => "too_verbose",
            Self::Unhelpful => "unhelpful",
        };
        label
    }
}
815
/// One feedback issue.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FeedbackIssue {
    /// Which kind of problem is being reported.
    pub category: FeedbackIssueCategory,
    /// Optional free-form explanation; serialised as JSON `null` when absent.
    pub explanation: Option<String>,
}
822
/// Feedback sentiment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackSentiment {
    Positive,
    Negative,
    Neutral,
}

impl FeedbackSentiment {
    /// Lower-case identifier used for this sentiment in the bridge JSON.
    const fn as_str(self) -> &'static str {
        let label = match self {
            Self::Neutral => "neutral",
            Self::Negative => "negative",
            Self::Positive => "positive",
        };
        label
    }
}
840
/// A full feedback attachment request.
///
/// Every field is optional (or may be empty); absent values are serialised
/// as JSON `null` in the bridge payload.
#[derive(Debug, Clone, PartialEq)]
pub struct FeedbackAttachmentRequest {
    /// Overall sentiment, if any.
    pub sentiment: Option<FeedbackSentiment>,
    /// Specific issues being reported.
    pub issues: Vec<FeedbackIssue>,
    /// Desired response, as plain text.
    pub desired_response_text: Option<String>,
    /// Desired response, as generated content (sent as its JSON string).
    pub desired_response_content: Option<GeneratedContent>,
    /// Desired output entry; sent as the JSON of a one-entry transcript.
    pub desired_output: Option<crate::transcript::Entry>,
}
850
851impl FeedbackAttachmentRequest {
852    /// Create an empty feedback request.
853    #[must_use]
854    pub const fn new() -> Self {
855        Self {
856            sentiment: None,
857            issues: Vec::new(),
858            desired_response_text: None,
859            desired_response_content: None,
860            desired_output: None,
861        }
862    }
863
864    fn to_bridge_json(&self) -> Result<String, FMError> {
865        let issues = self
866            .issues
867            .iter()
868            .map(|issue| {
869                json!({
870                    "category": issue.category.as_str(),
871                    "explanation": issue.explanation,
872                })
873            })
874            .collect::<Vec<_>>();
875        let desired_output_json = self
876            .desired_output
877            .as_ref()
878            .map(|entry| Transcript::from(vec![entry.clone()]).to_json_string())
879            .transpose()?;
880        let desired_response_content_json = self
881            .desired_response_content
882            .as_ref()
883            .map(GeneratedContent::json_string)
884            .transpose()?;
885        serde_json::to_string(&json!({
886            "sentiment": self.sentiment.map(FeedbackSentiment::as_str),
887            "issues": issues,
888            "desiredResponseText": self.desired_response_text,
889            "desiredResponseContentJSON": desired_response_content_json,
890            "desiredOutputTranscriptJSON": desired_output_json,
891        }))
892        .map_err(|error| {
893            FMError::InvalidArgument(format!(
894                "feedback request is not JSON-serializable: {error}"
895            ))
896        })
897    }
898}
899
900#[derive(Debug, Deserialize)]
901struct BridgeTextResponse {
902    content: String,
903    #[serde(rename = "rawContentJSON")]
904    raw_content_json: String,
905    #[serde(rename = "transcriptJSON")]
906    transcript_json: String,
907}
908
909#[derive(Debug, Deserialize)]
910struct BridgeStructuredResponse {
911    #[serde(rename = "contentJSON")]
912    content_json: String,
913    #[serde(rename = "rawContentJSON")]
914    raw_content_json: String,
915    #[serde(rename = "transcriptJSON")]
916    transcript_json: String,
917}
918
919#[derive(Debug, Deserialize)]
920struct BridgeStructuredSnapshot {
921    #[serde(rename = "contentJSON")]
922    content_json: String,
923    #[serde(rename = "rawContentJSON")]
924    raw_content_json: String,
925    #[serde(rename = "isComplete")]
926    is_complete: bool,
927}
928
929#[derive(Debug, Deserialize)]
930struct BridgeTextStreamSnapshot {
931    delta: String,
932}
933
934fn respond_request_json(
935    prompt: &Prompt,
936    options: GenerationOptions,
937    schema: Option<&GenerationSchema>,
938    include_schema_in_prompt: bool,
939) -> Result<CString, FMError> {
940    let sampling = match options.sampling() {
941        SamplingMode::Default => json!({ "mode": "default" }),
942        SamplingMode::Greedy => json!({ "mode": "greedy" }),
943        SamplingMode::TopK(k) => json!({
944            "mode": "top_k",
945            "topK": k,
946            "seed": options.sampling_seed(),
947        }),
948        SamplingMode::TopP(p) => json!({
949            "mode": "top_p",
950            "topP": p,
951            "seed": options.sampling_seed(),
952        }),
953    };
954    let payload = serde_json::to_string(&json!({
955        "prompt": prompt.to_bridge_value(),
956        "options": {
957            "temperature": options.temperature(),
958            "maximumResponseTokens": options.maximum_response_tokens(),
959            "sampling": sampling,
960        },
961        "schemaJSON": schema.map(GenerationSchema::json_schema),
962        "includeSchemaInPrompt": include_schema_in_prompt,
963    }))
964    .map_err(|error| {
965        FMError::InvalidArgument(format!("request is not JSON-serializable: {error}"))
966    })?;
967    CString::new(payload).map_err(|error| {
968        FMError::InvalidArgument(format!("request JSON contains a NUL byte: {error}"))
969    })
970}
971
972fn request_response(session: *mut c_void, payload: &CString) -> Result<String, FMError> {
973    let (tx, rx) = mpsc::channel();
974    let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
975    let context = Box::into_raw(tx_box).cast::<c_void>();
976    unsafe {
977        ffi::fm_session_respond_request_json(session, payload.as_ptr(), context, respond_trampoline)
978    };
979    rx.recv().map_err(|_| FMError::Unknown {
980        code: ffi::status::UNKNOWN,
981        message: "Swift bridge dropped the JSON response channel".into(),
982    })?
983}
984
985fn prompt_to_plain_text(prompt: &Prompt) -> Option<String> {
986    let mut text = String::new();
987    for segment in prompt.segments() {
988        match segment {
989            crate::prompt::Segment::Text(segment) => text.push_str(&segment.text),
990            crate::prompt::Segment::Structure(_) => return None,
991        }
992    }
993    Some(text)
994}
995
996impl Default for LanguageModelSession {
997    fn default() -> Self {
998        Self::new()
999    }
1000}
1001
1002impl Drop for LanguageModelSession {
1003    fn drop(&mut self) {
1004        if !self.ptr.is_null() {
1005            unsafe { ffi::fm_object_release(self.ptr) };
1006        }
1007    }
1008}
1009
1010impl core::fmt::Debug for LanguageModelSession {
1011    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1012        f.debug_struct("LanguageModelSession")
1013            .field("ptr", &self.ptr)
1014            .finish()
1015    }
1016}
1017
1018/// One event from a streaming generation.
1019#[derive(Debug)]
1020#[non_exhaustive]
1021pub enum StreamEvent<'a> {
1022    /// Incremental text delta. Concatenate these to reconstruct the full reply.
1023    Chunk(&'a str),
1024    /// Stream finished successfully.
1025    Done,
1026    /// Stream failed; the inner error describes why.
1027    Error(FMError),
1028}
1029
1030// ---------- internal callback plumbing ----------
1031
1032unsafe extern "C" fn respond_trampoline(
1033    context: *mut c_void,
1034    response: *mut c_char,
1035    error: *mut c_char,
1036    status: i32,
1037) {
1038    let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
1039    let result = if status == ffi::status::OK && !response.is_null() {
1040        let s = core::ffi::CStr::from_ptr(response)
1041            .to_string_lossy()
1042            .into_owned();
1043        ffi::fm_string_free(response);
1044        Ok(s)
1045    } else {
1046        Err(crate::error::from_swift(status, error))
1047    };
1048    let _ = tx.send(result);
1049}
1050
1051type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;
1052
1053struct StreamState {
1054    on_chunk: Mutex<StreamCallback>,
1055    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
1056}
1057
1058unsafe extern "C" fn json_text_stream_trampoline(
1059    context: *mut c_void,
1060    chunk: *mut c_char,
1061    done: bool,
1062    status: i32,
1063) {
1064    let state = Arc::from_raw(context.cast::<StreamState>());
1065    let state_for_swift = state.clone();
1066    core::mem::forget(state_for_swift);
1067
1068    let payload: Option<String> = if chunk.is_null() {
1069        None
1070    } else {
1071        let value = core::ffi::CStr::from_ptr(chunk)
1072            .to_string_lossy()
1073            .into_owned();
1074        ffi::fm_string_free(chunk);
1075        Some(value)
1076    };
1077
1078    if status != ffi::status::OK {
1079        let err = payload
1080            .map(|message| {
1081                crate::error::from_swift(
1082                    status,
1083                    ffi::fm_string_dup(
1084                        CString::new(message)
1085                            .expect("stream errors must not contain NUL bytes")
1086                            .as_ptr(),
1087                    ),
1088                )
1089            })
1090            .unwrap_or_else(|| crate::error::from_swift(status, ptr::null_mut()));
1091        let mut callback = state.on_chunk.lock().expect("user callback mutex poisoned");
1092        callback(StreamEvent::Error(err.clone()));
1093        drop(callback);
1094        if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1095            let _ = tx.send(Err(err));
1096        }
1097        drop(Arc::from_raw(Arc::as_ptr(&state)));
1098        drop(state);
1099        return;
1100    }
1101
1102    if let Some(payload) = payload {
1103        match serde_json::from_str::<BridgeTextStreamSnapshot>(&payload) {
1104            Ok(snapshot) if !snapshot.delta.is_empty() => {
1105                let mut callback = state.on_chunk.lock().expect("user callback mutex poisoned");
1106                callback(StreamEvent::Chunk(&snapshot.delta));
1107            }
1108            Ok(_) => {}
1109            Err(error) => {
1110                let err = FMError::DecodingFailure(error.to_string());
1111                let mut callback = state.on_chunk.lock().expect("user callback mutex poisoned");
1112                callback(StreamEvent::Error(err.clone()));
1113                drop(callback);
1114                if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1115                    let _ = tx.send(Err(err));
1116                }
1117                drop(Arc::from_raw(Arc::as_ptr(&state)));
1118                drop(state);
1119                return;
1120            }
1121        }
1122    }
1123
1124    if done {
1125        let mut callback = state.on_chunk.lock().expect("user callback mutex poisoned");
1126        callback(StreamEvent::Done);
1127        drop(callback);
1128        if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1129            let _ = tx.send(Ok(()));
1130        }
1131        drop(Arc::from_raw(Arc::as_ptr(&state)));
1132    }
1133    drop(state);
1134}
1135
1136type StructuredStreamCallback = Box<dyn FnMut(StructuredStreamEvent) + Send>;
1137
1138struct StructuredStreamState {
1139    on_event: Mutex<StructuredStreamCallback>,
1140    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
1141}
1142
1143unsafe extern "C" fn structured_stream_trampoline(
1144    context: *mut c_void,
1145    chunk: *mut c_char,
1146    done: bool,
1147    status: i32,
1148) {
1149    let state = Arc::from_raw(context.cast::<StructuredStreamState>());
1150    let state_for_swift = state.clone();
1151    core::mem::forget(state_for_swift);
1152
1153    let payload: Option<String> = if chunk.is_null() {
1154        None
1155    } else {
1156        let value = core::ffi::CStr::from_ptr(chunk)
1157            .to_string_lossy()
1158            .into_owned();
1159        ffi::fm_string_free(chunk);
1160        Some(value)
1161    };
1162
1163    if status != ffi::status::OK {
1164        let err = payload
1165            .map(|message| {
1166                crate::error::from_swift(
1167                    status,
1168                    ffi::fm_string_dup(
1169                        CString::new(message)
1170                            .expect("stream errors must not contain NUL bytes")
1171                            .as_ptr(),
1172                    ),
1173                )
1174            })
1175            .unwrap_or_else(|| crate::error::from_swift(status, ptr::null_mut()));
1176        let mut callback = state
1177            .on_event
1178            .lock()
1179            .expect("structured callback mutex poisoned");
1180        callback(StructuredStreamEvent::Error(err.clone()));
1181        drop(callback);
1182        if let Some(tx) = state
1183            .done_tx
1184            .lock()
1185            .expect("structured done_tx mutex poisoned")
1186            .take()
1187        {
1188            let _ = tx.send(Err(err));
1189        }
1190        drop(Arc::from_raw(Arc::as_ptr(&state)));
1191        drop(state);
1192        return;
1193    }
1194
1195    if let Some(payload) = payload {
1196        let snapshot: BridgeStructuredSnapshot = match serde_json::from_str(&payload) {
1197            Ok(snapshot) => snapshot,
1198            Err(error) => {
1199                let err = FMError::DecodingFailure(error.to_string());
1200                let mut callback = state
1201                    .on_event
1202                    .lock()
1203                    .expect("structured callback mutex poisoned");
1204                callback(StructuredStreamEvent::Error(err.clone()));
1205                drop(callback);
1206                if let Some(tx) = state
1207                    .done_tx
1208                    .lock()
1209                    .expect("structured done_tx mutex poisoned")
1210                    .take()
1211                {
1212                    let _ = tx.send(Err(err));
1213                }
1214                drop(Arc::from_raw(Arc::as_ptr(&state)));
1215                drop(state);
1216                return;
1217            }
1218        };
1219        let mut callback = state
1220            .on_event
1221            .lock()
1222            .expect("structured callback mutex poisoned");
1223        callback(StructuredStreamEvent::Snapshot(StructuredStreamSnapshot {
1224            content_json: snapshot.content_json,
1225            raw_content_json: snapshot.raw_content_json,
1226            is_complete: snapshot.is_complete,
1227        }));
1228    }
1229
1230    if done {
1231        let mut callback = state
1232            .on_event
1233            .lock()
1234            .expect("structured callback mutex poisoned");
1235        callback(StructuredStreamEvent::Done);
1236        drop(callback);
1237        if let Some(tx) = state
1238            .done_tx
1239            .lock()
1240            .expect("structured done_tx mutex poisoned")
1241            .take()
1242        {
1243            let _ = tx.send(Ok(()));
1244        }
1245        drop(Arc::from_raw(Arc::as_ptr(&state)));
1246    }
1247    drop(state);
1248}