1use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::panic::{catch_unwind, AssertUnwindSafe};
7use std::sync::mpsc;
8use std::sync::{Arc, Mutex};
9
10use serde::Deserialize;
11use serde_json::json;
12
13use crate::content::{BridgeGeneratedContent, GeneratedContent};
14use crate::error::FMError;
15use crate::ffi;
16use crate::generation::{GenerationOptions, SamplingMode};
17use crate::model::ConfiguredSystemLanguageModel;
18use crate::prompt::{Instructions, Prompt, ToInstructions, ToPrompt};
19use crate::schema::GenerationSchema;
20use crate::tool::{tool_callback_trampoline, Tool, ToolRegistry};
21use crate::transcript::Transcript;
22
/// A conversation session backed by the Swift FoundationModels bridge.
///
/// Holds a handle to a bridge-side session object (released in `Drop`). When the session was
/// built with tools, it also keeps the tool registry alive, because the bridge was given a raw
/// pointer to it as the tool-callback context.
pub struct LanguageModelSession {
    // Opaque handle returned by `fm_session_create` / `fm_session_create_ex`; passed to every
    // `ffi::fm_session_*` call and released with `fm_object_release` on drop.
    ptr: *mut c_void,
    // Never read: held only so the `Arc<ToolRegistry>` whose raw pointer was handed to the
    // bridge in `SessionBuilder::build` outlives the session.
    _tool_registry: Option<Arc<ToolRegistry>>,
}

// SAFETY: NOTE(review): assumes the Swift-side session object may be used and released from
// any thread — confirm against the bridge implementation.
unsafe impl Send for LanguageModelSession {}
// SAFETY: NOTE(review): every `&self` method forwards the handle to the bridge; this assumes
// the bridge tolerates concurrent calls on the same session — confirm.
unsafe impl Sync for LanguageModelSession {}
52
53impl LanguageModelSession {
54 #[cfg(feature = "async")]
59 pub(crate) fn as_ptr(&self) -> *mut c_void {
60 self.ptr
61 }
62
63 #[must_use]
71 pub fn new() -> Self {
72 Self::try_new(None).expect("FoundationModels is not available on this OS")
73 }
74
75 #[must_use]
82 pub fn with_instructions(instructions: &str) -> Self {
83 Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
84 }
85
86 #[must_use]
90 pub fn try_new(instructions: Option<&str>) -> Option<Self> {
91 let cstring = match instructions {
92 Some(s) => Some(CString::new(s).ok()?),
93 None => None,
94 };
95 let ptr =
96 unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
97 if ptr.is_null() {
98 return None;
99 }
100 Some(Self {
101 ptr,
102 _tool_registry: None,
103 })
104 }
105
106 pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
114 self.respond_with(prompt, GenerationOptions::new())
115 }
116
117 pub fn prewarm(&self) {
121 unsafe { ffi::fm_session_prewarm(self.ptr) };
122 }
123
124 #[must_use]
127 pub fn is_responding(&self) -> bool {
128 unsafe { ffi::fm_session_is_responding(self.ptr) }
129 }
130
131 #[must_use]
136 pub fn transcript_json(&self) -> String {
137 let p = unsafe { ffi::fm_session_transcript_json(self.ptr) };
138 if p.is_null() {
139 return String::from("{}");
140 }
141 let s = unsafe { core::ffi::CStr::from_ptr(p) }
142 .to_string_lossy()
143 .into_owned();
144 unsafe { ffi::fm_string_free(p) };
145 s
146 }
147
148 pub fn log_feedback(&self, sentiment: i32, description: Option<&str>) {
152 let cstr = description.and_then(|s| CString::new(s).ok());
153 let p = cstr.as_ref().map_or(core::ptr::null(), |c| c.as_ptr());
154 unsafe { ffi::fm_session_log_feedback(self.ptr, sentiment, p) };
155 }
156
157 pub fn respond_with_json_schema(
171 &self,
172 prompt: &str,
173 schema_description: &str,
174 ) -> Result<String, FMError> {
175 let wrapped = format!(
176 "{prompt}\n\n\
177 IMPORTANT: respond with VALID JSON ONLY (no prose, no markdown \
178 fences) that matches this schema:\n\n{schema_description}\n\n\
179 Your entire response must be parseable by JSON.parse()."
180 );
181 self.respond(&wrapped)
182 }
183
184 pub fn respond_with(
190 &self,
191 prompt: &str,
192 options: GenerationOptions,
193 ) -> Result<String, FMError> {
194 self.respond_prompt_with(prompt, options)
195 }
196
197 pub fn respond_with_schema(
230 &self,
231 prompt: &str,
232 schema: &str,
233 include_schema_in_prompt: bool,
234 ) -> Result<String, FMError> {
235 self.respond_with_schema_options(
236 prompt,
237 schema,
238 include_schema_in_prompt,
239 GenerationOptions::new(),
240 )
241 }
242
243 pub fn respond_with_schema_options(
250 &self,
251 prompt: &str,
252 schema: &str,
253 include_schema_in_prompt: bool,
254 options: GenerationOptions,
255 ) -> Result<String, FMError> {
256 let prompt_c = CString::new(prompt)
257 .map_err(|e| FMError::InvalidArgument(format!("prompt NUL byte: {e}")))?;
258 let schema_c = CString::new(schema)
259 .map_err(|e| FMError::InvalidArgument(format!("schema NUL byte: {e}")))?;
260 let opts = options.to_ffi();
261 let (tx, rx) = mpsc::channel();
262 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
263 let context = Box::into_raw(tx_box).cast::<c_void>();
264
265 unsafe {
266 ffi::fm_session_respond_with_schema(
267 self.ptr,
268 prompt_c.as_ptr(),
269 schema_c.as_ptr(),
270 include_schema_in_prompt,
271 opts.temperature,
272 opts.maximum_response_tokens,
273 opts.sampling_mode,
274 opts.top_k,
275 opts.top_p,
276 context,
277 respond_trampoline,
278 );
279 }
280
281 rx.recv().map_err(|_| FMError::Unknown {
282 code: ffi::status::UNKNOWN,
283 message: "Swift bridge dropped the callback channel".into(),
284 })?
285 }
286
287 pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
296 where
297 F: FnMut(StreamEvent<'_>) + Send + 'static,
298 {
299 self.stream_with(prompt, GenerationOptions::new(), move |event| {
300 on_chunk(event);
301 })
302 }
303
304 pub fn stream_with<F>(
310 &self,
311 prompt: &str,
312 options: GenerationOptions,
313 on_chunk: F,
314 ) -> Result<(), FMError>
315 where
316 F: FnMut(StreamEvent<'_>) + Send + 'static,
317 {
318 let payload = respond_request_json(&Prompt::from(prompt), options, None, true)?;
319
320 let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
321 let state = Arc::new(StreamState {
322 on_chunk: Mutex::new(Box::new(on_chunk)),
323 done_tx: Mutex::new(Some(done_tx)),
324 });
325 let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
326
327 unsafe {
328 ffi::fm_session_stream_request_json(
329 self.ptr,
330 payload.as_ptr(),
331 context,
332 json_text_stream_trampoline,
333 )
334 };
335
336 done_rx.recv().map_err(|_| FMError::Unknown {
337 code: ffi::status::UNKNOWN,
338 message: "Swift bridge dropped the stream channel".into(),
339 })?
340 }
341}
342
343impl LanguageModelSession {
344 #[must_use]
346 pub fn builder<'a>() -> SessionBuilder<'a> {
347 SessionBuilder::new()
348 }
349
350 pub fn from_transcript(transcript: Transcript) -> Result<Self, FMError> {
356 Self::builder().transcript(transcript).build()
357 }
358
359 pub fn transcript(&self) -> Result<Transcript, FMError> {
366 Transcript::from_json_str(&self.transcript_json())
367 }
368
369 pub fn prewarm_with_prompt<P>(&self, prompt: P) -> Result<(), FMError>
375 where
376 P: ToPrompt,
377 {
378 let prompt = prompt.to_prompt()?;
379 let prompt_json = CString::new(prompt.to_bridge_json()?).map_err(|error| {
380 FMError::InvalidArgument(format!("prompt JSON contains a NUL byte: {error}"))
381 })?;
382 let mut error: *mut c_char = ptr::null_mut();
383 let status = unsafe {
384 ffi::fm_session_prewarm_prompt_json(self.ptr, prompt_json.as_ptr(), &mut error)
385 };
386 if status != ffi::status::OK {
387 return Err(crate::error::from_swift(status, error));
388 }
389 Ok(())
390 }
391
392 pub fn respond_prompt<P>(&self, prompt: P) -> Result<String, FMError>
398 where
399 P: ToPrompt,
400 {
401 self.respond_prompt_with(prompt, GenerationOptions::new())
402 }
403
404 pub fn respond_prompt_with<P>(
410 &self,
411 prompt: P,
412 options: GenerationOptions,
413 ) -> Result<String, FMError>
414 where
415 P: ToPrompt,
416 {
417 self.respond_prompt_detailed(prompt, options)
418 .map(|response| response.content)
419 }
420
421 pub fn respond_prompt_detailed<P>(
427 &self,
428 prompt: P,
429 options: GenerationOptions,
430 ) -> Result<SessionResponse<String>, FMError>
431 where
432 P: ToPrompt,
433 {
434 let prompt = prompt.to_prompt()?;
435 let payload = respond_request_json(&prompt, options, None, true)?;
436 let payload = request_response(self.ptr, &payload)?;
437 let response: BridgeTextResponse = serde_json::from_str(&payload)
438 .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
439 Ok(SessionResponse {
440 content: response.content,
441 raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
442 transcript: Transcript::from_json_str(&response.transcript_json)?,
443 })
444 }
445
446 pub fn respond_generated<P>(
452 &self,
453 prompt: P,
454 schema: &GenerationSchema,
455 include_schema_in_prompt: bool,
456 ) -> Result<GeneratedContent, FMError>
457 where
458 P: ToPrompt,
459 {
460 self.respond_generated_with(
461 prompt,
462 schema,
463 include_schema_in_prompt,
464 GenerationOptions::new(),
465 )
466 .map(|response| response.content)
467 }
468
469 pub fn respond_generated_with<P>(
475 &self,
476 prompt: P,
477 schema: &GenerationSchema,
478 include_schema_in_prompt: bool,
479 options: GenerationOptions,
480 ) -> Result<SessionResponse<GeneratedContent>, FMError>
481 where
482 P: ToPrompt,
483 {
484 let prompt = prompt.to_prompt()?;
485 let payload =
486 respond_request_json(&prompt, options, Some(schema), include_schema_in_prompt)?;
487 let payload = request_response(self.ptr, &payload)?;
488 let response: BridgeStructuredResponse = serde_json::from_str(&payload)
489 .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
490 Ok(SessionResponse {
491 content: GeneratedContent::from_bridge_payload(response.content, true)?,
492 raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
493 transcript: Transcript::from_json_str(&response.transcript_json)?,
494 })
495 }
496
497 pub fn respond_generating<P, T>(
504 &self,
505 prompt: P,
506 include_schema_in_prompt: bool,
507 options: GenerationOptions,
508 ) -> Result<SessionResponse<T>, FMError>
509 where
510 P: ToPrompt,
511 T: crate::schema::Generable,
512 {
513 let response = self.respond_generated_with(
514 prompt,
515 &T::generation_schema()?,
516 include_schema_in_prompt,
517 options,
518 )?;
519 Ok(SessionResponse {
520 content: T::from_generated_content(&response.content)?,
521 raw_content: response.raw_content,
522 transcript: response.transcript,
523 })
524 }
525
526 pub fn stream_prompt<P, F>(&self, prompt: P, on_chunk: F) -> Result<(), FMError>
532 where
533 P: ToPrompt,
534 F: FnMut(StreamEvent<'_>) + Send + 'static,
535 {
536 let prompt = prompt.to_prompt()?;
537 let prompt_text = prompt_to_plain_text(&prompt).ok_or_else(|| {
538 FMError::InvalidArgument(
539 "text streaming only supports prompts composed of text segments".into(),
540 )
541 })?;
542 self.stream_with(&prompt_text, GenerationOptions::new(), on_chunk)
543 }
544
545 pub fn stream_generated<P, F>(
551 &self,
552 prompt: P,
553 schema: &GenerationSchema,
554 include_schema_in_prompt: bool,
555 options: GenerationOptions,
556 on_event: F,
557 ) -> Result<(), FMError>
558 where
559 P: ToPrompt,
560 F: FnMut(StructuredStreamEvent) + Send + 'static,
561 {
562 let prompt = prompt.to_prompt()?;
563 let payload =
564 respond_request_json(&prompt, options, Some(schema), include_schema_in_prompt)?;
565 let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
566 let state = Arc::new(StructuredStreamState {
567 on_event: Mutex::new(Box::new(on_event)),
568 done_tx: Mutex::new(Some(done_tx)),
569 });
570 let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
571 unsafe {
572 ffi::fm_session_stream_request_json(
573 self.ptr,
574 payload.as_ptr(),
575 context,
576 structured_stream_trampoline,
577 )
578 };
579 done_rx.recv().map_err(|_| FMError::Unknown {
580 code: ffi::status::UNKNOWN,
581 message: "Swift bridge dropped the structured stream channel".into(),
582 })?
583 }
584
585 pub fn log_feedback_attachment(
591 &self,
592 request: FeedbackAttachmentRequest,
593 ) -> Result<Vec<u8>, FMError> {
594 let request_json = CString::new(request.to_bridge_json()?).map_err(|error| {
595 FMError::InvalidArgument(format!("feedback request contains a NUL byte: {error}"))
596 })?;
597 let mut length = 0usize;
598 let mut error: *mut c_char = ptr::null_mut();
599 let ptr = unsafe {
600 ffi::fm_session_log_feedback_attachment_json(
601 self.ptr,
602 request_json.as_ptr(),
603 &mut length,
604 &mut error,
605 )
606 };
607 if ptr.is_null() && !error.is_null() {
608 return Err(crate::error::from_swift(
609 ffi::status::INVALID_ARGUMENT,
610 error,
611 ));
612 }
613 if ptr.is_null() || length == 0 {
614 return Ok(Vec::new());
615 }
616 let bytes = unsafe { std::slice::from_raw_parts(ptr.cast::<u8>(), length) }.to_vec();
617 unsafe { ffi::fm_bytes_free(ptr) };
618 Ok(bytes)
619 }
620}
621
/// Builder for [`LanguageModelSession`]; obtain one with [`LanguageModelSession::builder`].
///
/// `build` rejects setting both instructions and a transcript.
pub struct SessionBuilder<'a> {
    // Optional configured model; when `None`, a null model pointer is passed to the bridge.
    model: Option<&'a ConfiguredSystemLanguageModel>,
    // Instructions serialized via `Instructions::to_bridge_json` at build time.
    instructions: Option<Instructions>,
    // Transcript to resume from, serialized via `Transcript::to_json_string` at build time.
    transcript: Option<Transcript>,
    // Tools moved into a shared `ToolRegistry` when non-empty.
    tools: Vec<Tool>,
}
629
630impl<'a> SessionBuilder<'a> {
631 const fn new() -> Self {
632 Self {
633 model: None,
634 instructions: None,
635 transcript: None,
636 tools: Vec::new(),
637 }
638 }
639
640 #[must_use]
642 pub const fn model(mut self, model: &'a ConfiguredSystemLanguageModel) -> Self {
643 self.model = Some(model);
644 self
645 }
646
647 pub fn instructions<I>(mut self, instructions: I) -> Result<Self, FMError>
649 where
650 I: ToInstructions,
651 {
652 self.instructions = Some(instructions.to_instructions()?);
653 Ok(self)
654 }
655
656 #[must_use]
658 pub fn transcript(mut self, transcript: Transcript) -> Self {
659 self.transcript = Some(transcript);
660 self
661 }
662
663 #[must_use]
665 pub fn tool(mut self, tool: Tool) -> Self {
666 self.tools.push(tool);
667 self
668 }
669
670 #[must_use]
672 pub fn tools(mut self, tools: impl IntoIterator<Item = Tool>) -> Self {
673 self.tools.extend(tools);
674 self
675 }
676
677 pub fn build(self) -> Result<LanguageModelSession, FMError> {
683 if self.instructions.is_some() && self.transcript.is_some() {
684 return Err(FMError::InvalidArgument(
685 "session builder accepts either instructions or a transcript, not both".into(),
686 ));
687 }
688
689 let instructions_json = self
690 .instructions
691 .as_ref()
692 .map(Instructions::to_bridge_json)
693 .transpose()?;
694 let transcript_json = self
695 .transcript
696 .as_ref()
697 .map(Transcript::to_json_string)
698 .transpose()?;
699 let tool_registry = if self.tools.is_empty() {
700 None
701 } else {
702 Some(Arc::new(ToolRegistry::new(self.tools)))
703 };
704 let tools_json = tool_registry
705 .as_ref()
706 .map(|registry| registry.specs_json())
707 .transpose()?;
708
709 let instructions_c = instructions_json
710 .as_deref()
711 .map(CString::new)
712 .transpose()
713 .map_err(|error| {
714 FMError::InvalidArgument(format!("instructions JSON contains a NUL byte: {error}"))
715 })?;
716 let transcript_c = transcript_json
717 .as_deref()
718 .map(CString::new)
719 .transpose()
720 .map_err(|error| {
721 FMError::InvalidArgument(format!("transcript JSON contains a NUL byte: {error}"))
722 })?;
723 let tools_c = tools_json
724 .as_deref()
725 .map(CString::new)
726 .transpose()
727 .map_err(|error| {
728 FMError::InvalidArgument(format!("tool JSON contains a NUL byte: {error}"))
729 })?;
730
731 let tool_context = tool_registry.as_ref().map_or(ptr::null_mut(), |registry| {
732 Arc::as_ptr(registry).cast_mut().cast::<c_void>()
733 });
734 let mut error: *mut c_char = ptr::null_mut();
735 let ptr = unsafe {
736 ffi::fm_session_create_ex(
737 self.model.map_or(ptr::null_mut(), |model| model.ptr),
738 instructions_c
739 .as_ref()
740 .map_or(ptr::null(), |json| json.as_ptr()),
741 transcript_c
742 .as_ref()
743 .map_or(ptr::null(), |json| json.as_ptr()),
744 tools_c.as_ref().map_or(ptr::null(), |json| json.as_ptr()),
745 tool_context,
746 tool_registry
747 .as_ref()
748 .map(|_| tool_callback_trampoline as ffi::FmToolCallback),
749 &mut error,
750 )
751 };
752 if ptr.is_null() {
753 return Err(crate::error::from_swift(
754 ffi::status::MODEL_UNAVAILABLE,
755 error,
756 ));
757 }
758 Ok(LanguageModelSession {
759 ptr,
760 _tool_registry: tool_registry,
761 })
762 }
763}
764
/// A model response together with its raw generated content and the updated transcript.
#[derive(Debug, Clone, PartialEq)]
pub struct SessionResponse<T> {
    /// The decoded response: a `String` for text requests, `GeneratedContent` for
    /// schema-guided requests, or a `Generable` value from `respond_generating`.
    pub content: T,
    /// Raw content decoded from the bridge payload's `rawContent` field.
    pub raw_content: GeneratedContent,
    /// Transcript decoded from the bridge payload's `transcriptJSON` field.
    pub transcript: Transcript,
}
772
/// One partial (or final) snapshot emitted while streaming a schema-guided response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StructuredStreamSnapshot {
    /// JSON text of the snapshot's `content` as delivered by the bridge.
    pub content_json: String,
    /// JSON text of the snapshot's `rawContent` as delivered by the bridge.
    pub raw_content_json: String,
    /// The bridge's `isComplete` flag for this snapshot.
    pub is_complete: bool,
}
780
/// Event delivered to the callback passed to `LanguageModelSession::stream_generated`.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum StructuredStreamEvent {
    /// A decoded snapshot of the content generated so far.
    Snapshot(StructuredStreamSnapshot),
    /// The bridge signalled the end of the stream.
    Done,
    /// The bridge reported an error, or a snapshot could not be decoded.
    Error(FMError),
}
789
/// Category of a problem reported in a feedback attachment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackIssueCategory {
    Unhelpful,
    TooVerbose,
    DidNotFollowInstructions,
    Incorrect,
    StereotypeOrBias,
    SuggestiveOrSexual,
    VulgarOrOffensive,
    TriggeredGuardrailUnexpectedly,
}

impl FeedbackIssueCategory {
    /// Snake-case identifier sent to the bridge in the feedback request JSON.
    const fn as_str(self) -> &'static str {
        match self {
            FeedbackIssueCategory::Unhelpful => "unhelpful",
            FeedbackIssueCategory::TooVerbose => "too_verbose",
            FeedbackIssueCategory::DidNotFollowInstructions => "did_not_follow_instructions",
            FeedbackIssueCategory::Incorrect => "incorrect",
            FeedbackIssueCategory::StereotypeOrBias => "stereotype_or_bias",
            FeedbackIssueCategory::SuggestiveOrSexual => "suggestive_or_sexual",
            FeedbackIssueCategory::VulgarOrOffensive => "vulgar_or_offensive",
            FeedbackIssueCategory::TriggeredGuardrailUnexpectedly => {
                "triggered_guardrail_unexpectedly"
            }
        }
    }
}
817
/// A single issue included in a [`FeedbackAttachmentRequest`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FeedbackIssue {
    /// Serialized as `"category"` using the category's snake-case name.
    pub category: FeedbackIssueCategory,
    /// Optional free-text explanation, serialized as `"explanation"` (null when absent).
    pub explanation: Option<String>,
}
824
/// Overall sentiment attached to a feedback request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackSentiment {
    Positive,
    Negative,
    Neutral,
}

impl FeedbackSentiment {
    /// Lower-case identifier sent to the bridge in the feedback request JSON.
    const fn as_str(self) -> &'static str {
        match self {
            FeedbackSentiment::Neutral => "neutral",
            FeedbackSentiment::Negative => "negative",
            FeedbackSentiment::Positive => "positive",
        }
    }
}
842
/// Input for `LanguageModelSession::log_feedback_attachment`; every field is optional.
#[derive(Debug, Clone, PartialEq)]
pub struct FeedbackAttachmentRequest {
    /// Serialized as `"sentiment"` (null when `None`).
    pub sentiment: Option<FeedbackSentiment>,
    /// Serialized as the `"issues"` array.
    pub issues: Vec<FeedbackIssue>,
    /// Serialized as `"desiredResponseText"`.
    pub desired_response_text: Option<String>,
    /// Serialized as `"desiredResponseContent"` via `GeneratedContent::to_bridge_value`.
    pub desired_response_content: Option<GeneratedContent>,
    /// Wrapped in a one-entry transcript and serialized as `"desiredOutputTranscriptJSON"`.
    pub desired_output: Option<crate::transcript::Entry>,
}
852
853impl FeedbackAttachmentRequest {
854 #[must_use]
856 pub const fn new() -> Self {
857 Self {
858 sentiment: None,
859 issues: Vec::new(),
860 desired_response_text: None,
861 desired_response_content: None,
862 desired_output: None,
863 }
864 }
865
866 fn to_bridge_json(&self) -> Result<String, FMError> {
867 let issues = self
868 .issues
869 .iter()
870 .map(|issue| {
871 json!({
872 "category": issue.category.as_str(),
873 "explanation": issue.explanation,
874 })
875 })
876 .collect::<Vec<_>>();
877 let desired_output_json = self
878 .desired_output
879 .as_ref()
880 .map(|entry| Transcript::from(vec![entry.clone()]).to_json_string())
881 .transpose()?;
882 let desired_response_content = self
883 .desired_response_content
884 .as_ref()
885 .map(GeneratedContent::to_bridge_value)
886 .transpose()?;
887 serde_json::to_string(&json!({
888 "sentiment": self.sentiment.map(FeedbackSentiment::as_str),
889 "issues": issues,
890 "desiredResponseText": self.desired_response_text,
891 "desiredResponseContent": desired_response_content,
892 "desiredOutputTranscriptJSON": desired_output_json,
893 }))
894 .map_err(|error| {
895 FMError::InvalidArgument(format!(
896 "feedback request is not JSON-serializable: {error}"
897 ))
898 })
899 }
900}
901
// Wire shape of a completed text response from the bridge (camelCase JSON keys).
#[derive(Debug, Deserialize)]
struct BridgeTextResponse {
    // Plain response text.
    content: String,
    // Bridge encoding of the raw generated content (JSON key `rawContent`).
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    // Updated session transcript as a JSON string (JSON key `transcriptJSON`).
    #[serde(rename = "transcriptJSON")]
    transcript_json: String,
}
910
// Wire shape of a completed schema-guided response from the bridge.
#[derive(Debug, Deserialize)]
struct BridgeStructuredResponse {
    // Bridge encoding of the schema-guided content.
    content: BridgeGeneratedContent,
    // Bridge encoding of the raw generated content (JSON key `rawContent`).
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    // Updated session transcript as a JSON string (JSON key `transcriptJSON`).
    #[serde(rename = "transcriptJSON")]
    transcript_json: String,
}
919
// Wire shape of one structured-stream snapshot, decoded in `structured_stream_trampoline`.
#[derive(Debug, Deserialize)]
struct BridgeStructuredSnapshot {
    // Content so far; its `.json` text becomes `StructuredStreamSnapshot::content_json`.
    content: BridgeGeneratedContent,
    // Raw content so far (JSON key `rawContent`).
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    // Whether the bridge marked this snapshot complete (JSON key `isComplete`).
    #[serde(rename = "isComplete")]
    is_complete: bool,
}
928
// Wire shape of one text-stream payload; empty deltas are not forwarded to the user callback.
#[derive(Debug, Deserialize)]
struct BridgeTextStreamSnapshot {
    // Text appended since the previous payload.
    delta: String,
}
933
934fn respond_request_json(
935 prompt: &Prompt,
936 options: GenerationOptions,
937 schema: Option<&GenerationSchema>,
938 include_schema_in_prompt: bool,
939) -> Result<CString, FMError> {
940 let sampling = match options.sampling() {
941 SamplingMode::Default => json!({ "mode": "default" }),
942 SamplingMode::Greedy => json!({ "mode": "greedy" }),
943 SamplingMode::TopK(k) => json!({
944 "mode": "top_k",
945 "topK": k,
946 "seed": options.sampling_seed(),
947 }),
948 SamplingMode::TopP(p) => json!({
949 "mode": "top_p",
950 "topP": p,
951 "seed": options.sampling_seed(),
952 }),
953 };
954 let include_schema_in_prompt = schema.map_or(include_schema_in_prompt, |schema| {
955 schema.effective_include_schema_in_prompt(include_schema_in_prompt)
956 });
957 let payload = serde_json::to_string(&json!({
958 "prompt": prompt.to_bridge_value(),
959 "options": {
960 "temperature": options.temperature(),
961 "maximumResponseTokens": options.maximum_response_tokens(),
962 "sampling": sampling,
963 },
964 "schemaJSON": schema.map(GenerationSchema::bridge_request_json),
965 "includeSchemaInPrompt": include_schema_in_prompt,
966 }))
967 .map_err(|error| {
968 FMError::InvalidArgument(format!("request is not JSON-serializable: {error}"))
969 })?;
970 CString::new(payload).map_err(|error| {
971 FMError::InvalidArgument(format!("request JSON contains a NUL byte: {error}"))
972 })
973}
974
975fn request_response(session: *mut c_void, payload: &CString) -> Result<String, FMError> {
976 let (tx, rx) = mpsc::channel();
977 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
978 let context = Box::into_raw(tx_box).cast::<c_void>();
979 unsafe {
980 ffi::fm_session_respond_request_json(session, payload.as_ptr(), context, respond_trampoline)
981 };
982 rx.recv().map_err(|_| FMError::Unknown {
983 code: ffi::status::UNKNOWN,
984 message: "Swift bridge dropped the JSON response channel".into(),
985 })?
986}
987
988pub(crate) fn decode_bridge_text_response(
989 payload: &str,
990) -> Result<SessionResponse<String>, FMError> {
991 let response: BridgeTextResponse = serde_json::from_str(payload)
992 .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
993 Ok(SessionResponse {
994 content: response.content,
995 raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
996 transcript: Transcript::from_json_str(&response.transcript_json)?,
997 })
998}
999
1000pub(crate) fn request_text_response_with<F>(invoke: F) -> Result<SessionResponse<String>, FMError>
1001where
1002 F: FnOnce(*mut c_void, ffi::FmRespondCallback),
1003{
1004 let (tx, rx) = mpsc::channel();
1005 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
1006 let context = Box::into_raw(tx_box).cast::<c_void>();
1007 invoke(context, respond_trampoline);
1008 let payload = rx.recv().map_err(|_| FMError::Unknown {
1009 code: ffi::status::UNKNOWN,
1010 message: "Swift bridge dropped the JSON response channel".into(),
1011 })??;
1012 decode_bridge_text_response(&payload)
1013}
1014
1015pub(crate) fn run_text_stream_with<F, C>(invoke: F, on_chunk: C) -> Result<(), FMError>
1016where
1017 F: FnOnce(*mut c_void, ffi::FmStreamCallback),
1018 C: FnMut(StreamEvent<'_>) + Send + 'static,
1019{
1020 let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
1021 let state = Arc::new(StreamState {
1022 on_chunk: Mutex::new(Box::new(on_chunk)),
1023 done_tx: Mutex::new(Some(done_tx)),
1024 });
1025 let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
1026 invoke(context, json_text_stream_trampoline);
1027 done_rx.recv().map_err(|_| FMError::Unknown {
1028 code: ffi::status::UNKNOWN,
1029 message: "Swift bridge dropped the stream channel".into(),
1030 })?
1031}
1032
1033fn prompt_to_plain_text(prompt: &Prompt) -> Option<String> {
1034 let mut text = String::new();
1035 for segment in prompt.segments() {
1036 match segment {
1037 crate::prompt::Segment::Text(segment) => text.push_str(&segment.text),
1038 crate::prompt::Segment::Structure(_) => return None,
1039 }
1040 }
1041 Some(text)
1042}
1043
1044impl Default for LanguageModelSession {
1045 fn default() -> Self {
1046 Self::new()
1047 }
1048}
1049
1050impl Drop for LanguageModelSession {
1051 fn drop(&mut self) {
1052 if !self.ptr.is_null() {
1053 unsafe { ffi::fm_object_release(self.ptr) };
1054 }
1055 }
1056}
1057
1058impl core::fmt::Debug for LanguageModelSession {
1059 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1060 f.debug_struct("LanguageModelSession")
1061 .field("ptr", &self.ptr)
1062 .finish()
1063 }
1064}
1065
/// Event delivered to text-streaming callbacks (`stream`, `stream_with`, `stream_prompt`).
#[derive(Debug)]
#[non_exhaustive]
pub enum StreamEvent<'a> {
    /// A non-empty text delta; borrowed only for the duration of the callback.
    Chunk(&'a str),
    /// The bridge signalled the end of the stream.
    Done,
    /// The bridge reported an error, or a stream payload could not be decoded.
    Error(FMError),
}
1077
1078unsafe extern "C" fn respond_trampoline(
1085 context: *mut c_void,
1086 response: *mut c_char,
1087 error: *mut c_char,
1088 status: i32,
1089) {
1090 let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
1091 let result = if status == ffi::status::OK && !response.is_null() {
1092 let s = core::ffi::CStr::from_ptr(response)
1093 .to_string_lossy()
1094 .into_owned();
1095 ffi::fm_string_free(response);
1096 Ok(s)
1097 } else {
1098 Err(crate::error::from_swift(status, error))
1099 };
1100 let _ = tx.send(result);
1101}
1102
// Boxed user callback for text streams; `Send` because the bridge invokes it via the trampoline.
type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;
1104
// Shared state behind the `context` pointer handed to `json_text_stream_trampoline`.
struct StreamState {
    // User callback; locked for each delivered event.
    on_chunk: Mutex<StreamCallback>,
    // Completion sender; `take()`n exactly once when the stream settles so only one result
    // reaches the waiting caller.
    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
}
1109
1110unsafe extern "C" fn json_text_stream_trampoline(
1116 context: *mut c_void,
1117 chunk: *mut c_char,
1118 done: bool,
1119 status: i32,
1120) {
1121 let state = Arc::from_raw(context.cast::<StreamState>());
1122 let state_for_swift = state.clone();
1123 core::mem::forget(state_for_swift);
1124
1125 let payload: Option<String> = if chunk.is_null() {
1126 None
1127 } else {
1128 let value = core::ffi::CStr::from_ptr(chunk)
1129 .to_string_lossy()
1130 .into_owned();
1131 ffi::fm_string_free(chunk);
1132 Some(value)
1133 };
1134
1135 if status != ffi::status::OK {
1136 let err = payload
1137 .map(|message| {
1138 crate::error::from_swift(
1139 status,
1140 ffi::fm_string_dup(
1141 CString::new(message)
1142 .expect("stream errors must not contain NUL bytes")
1143 .as_ptr(),
1144 ),
1145 )
1146 })
1147 .unwrap_or_else(|| crate::error::from_swift(status, ptr::null_mut()));
1148 {
1149 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1150 let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Error(err.clone()))));
1152 }
1153 if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1154 let _ = tx.send(Err(err));
1155 }
1156 drop(Arc::from_raw(Arc::as_ptr(&state)));
1157 drop(state);
1158 return;
1159 }
1160
1161 if let Some(payload) = payload {
1162 match serde_json::from_str::<BridgeTextStreamSnapshot>(&payload) {
1163 Ok(snapshot) if !snapshot.delta.is_empty() => {
1164 let chunk_panicked = {
1165 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1166 catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Chunk(&snapshot.delta))))
1168 .is_err()
1169 };
1170 if chunk_panicked {
1171 if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1172 let _ = tx.send(Err(FMError::Unknown {
1173 code: ffi::status::UNKNOWN,
1174 message: "stream callback panicked".into(),
1175 }));
1176 }
1177 drop(Arc::from_raw(Arc::as_ptr(&state)));
1178 drop(state);
1179 return;
1180 }
1181 }
1182 Ok(_) => {}
1183 Err(error) => {
1184 let err = FMError::DecodingFailure(error.to_string());
1185 {
1186 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1187 let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Error(err.clone()))));
1188 }
1189 if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1190 let _ = tx.send(Err(err));
1191 }
1192 drop(Arc::from_raw(Arc::as_ptr(&state)));
1193 drop(state);
1194 return;
1195 }
1196 }
1197 }
1198
1199 if done {
1200 {
1201 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1202 let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Done)));
1203 }
1204 if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1205 let _ = tx.send(Ok(()));
1206 }
1207 drop(Arc::from_raw(Arc::as_ptr(&state)));
1208 }
1209 drop(state);
1210}
1211
// Boxed user callback for structured streams; `Send` because it is invoked via the trampoline.
type StructuredStreamCallback = Box<dyn FnMut(StructuredStreamEvent) + Send>;
1213
// Shared state behind the `context` pointer handed to `structured_stream_trampoline`.
struct StructuredStreamState {
    // User callback; locked for each delivered event.
    on_event: Mutex<StructuredStreamCallback>,
    // Completion sender; `take()`n exactly once when the stream settles.
    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
}
1218
1219#[allow(clippy::too_many_lines)]
1222unsafe extern "C" fn structured_stream_trampoline(
1223 context: *mut c_void,
1224 chunk: *mut c_char,
1225 done: bool,
1226 status: i32,
1227) {
1228 let state = Arc::from_raw(context.cast::<StructuredStreamState>());
1229 let state_for_swift = state.clone();
1230 core::mem::forget(state_for_swift);
1231
1232 let payload: Option<String> = if chunk.is_null() {
1233 None
1234 } else {
1235 let value = core::ffi::CStr::from_ptr(chunk)
1236 .to_string_lossy()
1237 .into_owned();
1238 ffi::fm_string_free(chunk);
1239 Some(value)
1240 };
1241
1242 if status != ffi::status::OK {
1243 let err = payload
1244 .map(|message| {
1245 crate::error::from_swift(
1246 status,
1247 ffi::fm_string_dup(
1248 CString::new(message)
1249 .expect("stream errors must not contain NUL bytes")
1250 .as_ptr(),
1251 ),
1252 )
1253 })
1254 .unwrap_or_else(|| crate::error::from_swift(status, ptr::null_mut()));
1255 {
1256 let mut cb = state
1257 .on_event
1258 .lock()
1259 .expect("structured callback mutex poisoned");
1260 let _ = catch_unwind(AssertUnwindSafe(|| {
1262 cb(StructuredStreamEvent::Error(err.clone()));
1263 }));
1264 }
1265 if let Some(tx) = state
1266 .done_tx
1267 .lock()
1268 .expect("structured done_tx mutex poisoned")
1269 .take()
1270 {
1271 let _ = tx.send(Err(err));
1272 }
1273 drop(Arc::from_raw(Arc::as_ptr(&state)));
1274 drop(state);
1275 return;
1276 }
1277
1278 if let Some(payload) = payload {
1279 let snapshot: BridgeStructuredSnapshot = match serde_json::from_str(&payload) {
1280 Ok(snapshot) => snapshot,
1281 Err(error) => {
1282 let err = FMError::DecodingFailure(error.to_string());
1283 {
1284 let mut cb = state
1285 .on_event
1286 .lock()
1287 .expect("structured callback mutex poisoned");
1288 let _ = catch_unwind(AssertUnwindSafe(|| {
1289 cb(StructuredStreamEvent::Error(err.clone()));
1290 }));
1291 }
1292 if let Some(tx) = state
1293 .done_tx
1294 .lock()
1295 .expect("structured done_tx mutex poisoned")
1296 .take()
1297 {
1298 let _ = tx.send(Err(err));
1299 }
1300 drop(Arc::from_raw(Arc::as_ptr(&state)));
1301 drop(state);
1302 return;
1303 }
1304 };
1305 let snapshot_event = StructuredStreamEvent::Snapshot(StructuredStreamSnapshot {
1306 content_json: snapshot.content.json,
1307 raw_content_json: snapshot.raw_content.json,
1308 is_complete: snapshot.is_complete,
1309 });
1310 let snapshot_panicked = {
1311 let mut cb = state
1312 .on_event
1313 .lock()
1314 .expect("structured callback mutex poisoned");
1315 catch_unwind(AssertUnwindSafe(|| cb(snapshot_event))).is_err()
1317 };
1318 if snapshot_panicked {
1319 if let Some(tx) = state
1320 .done_tx
1321 .lock()
1322 .expect("structured done_tx mutex poisoned")
1323 .take()
1324 {
1325 let _ = tx.send(Err(FMError::Unknown {
1326 code: ffi::status::UNKNOWN,
1327 message: "stream callback panicked".into(),
1328 }));
1329 }
1330 drop(Arc::from_raw(Arc::as_ptr(&state)));
1331 drop(state);
1332 return;
1333 }
1334 }
1335
1336 if done {
1337 {
1338 let mut cb = state
1339 .on_event
1340 .lock()
1341 .expect("structured callback mutex poisoned");
1342 let _ = catch_unwind(AssertUnwindSafe(|| cb(StructuredStreamEvent::Done)));
1343 }
1344 if let Some(tx) = state
1345 .done_tx
1346 .lock()
1347 .expect("structured done_tx mutex poisoned")
1348 .take()
1349 {
1350 let _ = tx.send(Ok(()));
1351 }
1352 drop(Arc::from_raw(Arc::as_ptr(&state)));
1353 }
1354 drop(state);
1355}