1use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::panic::{catch_unwind, AssertUnwindSafe};
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::mpsc;
9use std::sync::{Arc, Mutex};
10
11use serde::Deserialize;
12use serde_json::json;
13
14use crate::content::{BridgeGeneratedContent, GeneratedContent};
15use crate::error::FMError;
16use crate::ffi;
17use crate::generation::{GenerationOptions, SamplingMode};
18use crate::model::ConfiguredSystemLanguageModel;
19use crate::prompt::{Instructions, Prompt, ToInstructions, ToPrompt};
20use crate::schema::GenerationSchema;
21use crate::tool::{tool_callback_trampoline, Tool, ToolRegistry};
22use crate::transcript::Transcript;
23
/// Owning handle to a language-model session object created by the Swift bridge.
///
/// The bridge object is released with `ffi::fm_object_release` when the handle is dropped.
pub struct LanguageModelSession {
    // Opaque bridge session; every constructor in this file rejects a null pointer.
    ptr: *mut c_void,
    // Keeps the `ToolRegistry` alive for the session's lifetime: `SessionBuilder::build`
    // hands the bridge a raw pointer to it as the tool-callback context.
    _tool_registry: Option<Arc<ToolRegistry>>,
}

// SAFETY: NOTE(review): these impls assert that the bridge session may be moved to and used
// concurrently from any thread (all methods take `&self`). Nothing in this file enforces
// that; it relies on the Swift side being thread-safe — confirm against the bridge.
unsafe impl Send for LanguageModelSession {}
unsafe impl Sync for LanguageModelSession {}
53
54impl LanguageModelSession {
55 #[cfg(feature = "async")]
60 pub(crate) fn as_ptr(&self) -> *mut c_void {
61 self.ptr
62 }
63
64 #[must_use]
72 pub fn new() -> Self {
73 Self::try_new(None).expect("FoundationModels is not available on this OS")
74 }
75
76 #[must_use]
83 pub fn with_instructions(instructions: &str) -> Self {
84 Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
85 }
86
87 #[must_use]
91 pub fn try_new(instructions: Option<&str>) -> Option<Self> {
92 let cstring = match instructions {
93 Some(s) => Some(CString::new(s).ok()?),
94 None => None,
95 };
96 let ptr =
97 unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
98 if ptr.is_null() {
99 return None;
100 }
101 Some(Self {
102 ptr,
103 _tool_registry: None,
104 })
105 }
106
107 pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
115 self.respond_with(prompt, GenerationOptions::new())
116 }
117
118 pub fn prewarm(&self) {
122 unsafe { ffi::fm_session_prewarm(self.ptr) };
123 }
124
125 #[must_use]
128 pub fn is_responding(&self) -> bool {
129 unsafe { ffi::fm_session_is_responding(self.ptr) }
130 }
131
132 #[must_use]
137 pub fn transcript_json(&self) -> String {
138 let p = unsafe { ffi::fm_session_transcript_json(self.ptr) };
139 if p.is_null() {
140 return String::from("{}");
141 }
142 let s = unsafe { core::ffi::CStr::from_ptr(p) }
143 .to_string_lossy()
144 .into_owned();
145 unsafe { ffi::fm_string_free(p) };
146 s
147 }
148
149 pub fn log_feedback(&self, sentiment: i32, description: Option<&str>) {
153 let cstr = description.and_then(|s| CString::new(s).ok());
154 let p = cstr.as_ref().map_or(core::ptr::null(), |c| c.as_ptr());
155 unsafe { ffi::fm_session_log_feedback(self.ptr, sentiment, p) };
156 }
157
158 pub fn respond_with_json_schema(
172 &self,
173 prompt: &str,
174 schema_description: &str,
175 ) -> Result<String, FMError> {
176 let wrapped = format!(
177 "{prompt}\n\n\
178 IMPORTANT: respond with VALID JSON ONLY (no prose, no markdown \
179 fences) that matches this schema:\n\n{schema_description}\n\n\
180 Your entire response must be parseable by JSON.parse()."
181 );
182 self.respond(&wrapped)
183 }
184
185 pub fn respond_with(
191 &self,
192 prompt: &str,
193 options: GenerationOptions,
194 ) -> Result<String, FMError> {
195 self.respond_prompt_with(prompt, options)
196 }
197
198 pub fn respond_with_schema(
231 &self,
232 prompt: &str,
233 schema: &str,
234 include_schema_in_prompt: bool,
235 ) -> Result<String, FMError> {
236 self.respond_with_schema_options(
237 prompt,
238 schema,
239 include_schema_in_prompt,
240 GenerationOptions::new(),
241 )
242 }
243
244 pub fn respond_with_schema_options(
251 &self,
252 prompt: &str,
253 schema: &str,
254 include_schema_in_prompt: bool,
255 options: GenerationOptions,
256 ) -> Result<String, FMError> {
257 let prompt_c = CString::new(prompt)
258 .map_err(|e| FMError::InvalidArgument(format!("prompt NUL byte: {e}")))?;
259 let schema_c = CString::new(schema)
260 .map_err(|e| FMError::InvalidArgument(format!("schema NUL byte: {e}")))?;
261 let opts = options.to_ffi();
262 let (tx, rx) = mpsc::channel();
263 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
264 let context = Box::into_raw(tx_box).cast::<c_void>();
265
266 unsafe {
267 ffi::fm_session_respond_with_schema(
268 self.ptr,
269 prompt_c.as_ptr(),
270 schema_c.as_ptr(),
271 include_schema_in_prompt,
272 opts.temperature,
273 opts.maximum_response_tokens,
274 opts.sampling_mode,
275 opts.top_k,
276 opts.top_p,
277 context,
278 respond_trampoline,
279 );
280 }
281
282 rx.recv().map_err(|_| FMError::Unknown {
283 code: ffi::status::UNKNOWN,
284 message: "Swift bridge dropped the callback channel".into(),
285 })?
286 }
287
288 pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
297 where
298 F: FnMut(StreamEvent<'_>) + Send + 'static,
299 {
300 self.stream_with(prompt, GenerationOptions::new(), move |event| {
301 on_chunk(event);
302 })
303 }
304
305 pub fn stream_with<F>(
311 &self,
312 prompt: &str,
313 options: GenerationOptions,
314 on_chunk: F,
315 ) -> Result<(), FMError>
316 where
317 F: FnMut(StreamEvent<'_>) + Send + 'static,
318 {
319 let payload = respond_request_json(&Prompt::from(prompt), options, None, true)?;
320
321 let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
322 let state = Arc::new(StreamState {
323 on_chunk: Mutex::new(Box::new(on_chunk)),
324 done_tx: Mutex::new(Some(done_tx)),
325 finished: AtomicBool::new(false),
326 });
327 let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
328
329 unsafe {
330 ffi::fm_session_stream_request_json(
331 self.ptr,
332 payload.as_ptr(),
333 context,
334 json_text_stream_trampoline,
335 )
336 };
337
338 done_rx.recv().map_err(|_| FMError::Unknown {
339 code: ffi::status::UNKNOWN,
340 message: "Swift bridge dropped the stream channel".into(),
341 })?
342 }
343}
344
345impl LanguageModelSession {
346 #[must_use]
348 pub fn builder<'a>() -> SessionBuilder<'a> {
349 SessionBuilder::new()
350 }
351
352 pub fn from_transcript(transcript: Transcript) -> Result<Self, FMError> {
358 Self::builder().transcript(transcript).build()
359 }
360
361 pub fn transcript(&self) -> Result<Transcript, FMError> {
368 Transcript::from_json_str(&self.transcript_json())
369 }
370
371 pub fn prewarm_with_prompt<P>(&self, prompt: P) -> Result<(), FMError>
377 where
378 P: ToPrompt,
379 {
380 let prompt = prompt.to_prompt()?;
381 let prompt_json = CString::new(prompt.to_bridge_json()?).map_err(|error| {
382 FMError::InvalidArgument(format!("prompt JSON contains a NUL byte: {error}"))
383 })?;
384 let mut error: *mut c_char = ptr::null_mut();
385 let status = unsafe {
386 ffi::fm_session_prewarm_prompt_json(self.ptr, prompt_json.as_ptr(), &mut error)
387 };
388 if status != ffi::status::OK {
389 return Err(crate::error::from_swift(status, error));
390 }
391 Ok(())
392 }
393
394 pub fn respond_prompt<P>(&self, prompt: P) -> Result<String, FMError>
400 where
401 P: ToPrompt,
402 {
403 self.respond_prompt_with(prompt, GenerationOptions::new())
404 }
405
406 pub fn respond_prompt_with<P>(
412 &self,
413 prompt: P,
414 options: GenerationOptions,
415 ) -> Result<String, FMError>
416 where
417 P: ToPrompt,
418 {
419 self.respond_prompt_detailed(prompt, options)
420 .map(|response| response.content)
421 }
422
423 pub fn respond_prompt_detailed<P>(
429 &self,
430 prompt: P,
431 options: GenerationOptions,
432 ) -> Result<SessionResponse<String>, FMError>
433 where
434 P: ToPrompt,
435 {
436 let prompt = prompt.to_prompt()?;
437 let payload = respond_request_json(&prompt, options, None, true)?;
438 let payload = request_response(self.ptr, &payload)?;
439 let response: BridgeTextResponse = serde_json::from_str(&payload)
440 .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
441 Ok(SessionResponse {
442 content: response.content,
443 raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
444 transcript: Transcript::from_json_str(&response.transcript_json)?,
445 })
446 }
447
448 pub fn respond_generated<P>(
454 &self,
455 prompt: P,
456 schema: &GenerationSchema,
457 include_schema_in_prompt: bool,
458 ) -> Result<GeneratedContent, FMError>
459 where
460 P: ToPrompt,
461 {
462 self.respond_generated_with(
463 prompt,
464 schema,
465 include_schema_in_prompt,
466 GenerationOptions::new(),
467 )
468 .map(|response| response.content)
469 }
470
471 pub fn respond_generated_with<P>(
477 &self,
478 prompt: P,
479 schema: &GenerationSchema,
480 include_schema_in_prompt: bool,
481 options: GenerationOptions,
482 ) -> Result<SessionResponse<GeneratedContent>, FMError>
483 where
484 P: ToPrompt,
485 {
486 let prompt = prompt.to_prompt()?;
487 let payload =
488 respond_request_json(&prompt, options, Some(schema), include_schema_in_prompt)?;
489 let payload = request_response(self.ptr, &payload)?;
490 let response: BridgeStructuredResponse = serde_json::from_str(&payload)
491 .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
492 Ok(SessionResponse {
493 content: GeneratedContent::from_bridge_payload(response.content, true)?,
494 raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
495 transcript: Transcript::from_json_str(&response.transcript_json)?,
496 })
497 }
498
499 pub fn respond_generating<P, T>(
506 &self,
507 prompt: P,
508 include_schema_in_prompt: bool,
509 options: GenerationOptions,
510 ) -> Result<SessionResponse<T>, FMError>
511 where
512 P: ToPrompt,
513 T: crate::schema::Generable,
514 {
515 let response = self.respond_generated_with(
516 prompt,
517 &T::generation_schema()?,
518 include_schema_in_prompt,
519 options,
520 )?;
521 Ok(SessionResponse {
522 content: T::from_generated_content(&response.content)?,
523 raw_content: response.raw_content,
524 transcript: response.transcript,
525 })
526 }
527
528 pub fn stream_prompt<P, F>(&self, prompt: P, on_chunk: F) -> Result<(), FMError>
534 where
535 P: ToPrompt,
536 F: FnMut(StreamEvent<'_>) + Send + 'static,
537 {
538 let prompt = prompt.to_prompt()?;
539 let prompt_text = prompt_to_plain_text(&prompt).ok_or_else(|| {
540 FMError::InvalidArgument(
541 "text streaming only supports prompts composed of text segments".into(),
542 )
543 })?;
544 self.stream_with(&prompt_text, GenerationOptions::new(), on_chunk)
545 }
546
547 pub fn stream_generated<P, F>(
553 &self,
554 prompt: P,
555 schema: &GenerationSchema,
556 include_schema_in_prompt: bool,
557 options: GenerationOptions,
558 on_event: F,
559 ) -> Result<(), FMError>
560 where
561 P: ToPrompt,
562 F: FnMut(StructuredStreamEvent) + Send + 'static,
563 {
564 let prompt = prompt.to_prompt()?;
565 let payload =
566 respond_request_json(&prompt, options, Some(schema), include_schema_in_prompt)?;
567 let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
568 let state = Arc::new(StructuredStreamState {
569 on_event: Mutex::new(Box::new(on_event)),
570 done_tx: Mutex::new(Some(done_tx)),
571 finished: AtomicBool::new(false),
572 });
573 let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
574 unsafe {
575 ffi::fm_session_stream_request_json(
576 self.ptr,
577 payload.as_ptr(),
578 context,
579 structured_stream_trampoline,
580 )
581 };
582 done_rx.recv().map_err(|_| FMError::Unknown {
583 code: ffi::status::UNKNOWN,
584 message: "Swift bridge dropped the structured stream channel".into(),
585 })?
586 }
587
588 pub fn log_feedback_attachment(
594 &self,
595 request: FeedbackAttachmentRequest,
596 ) -> Result<Vec<u8>, FMError> {
597 let request_json = CString::new(request.to_bridge_json()?).map_err(|error| {
598 FMError::InvalidArgument(format!("feedback request contains a NUL byte: {error}"))
599 })?;
600 let mut length = 0usize;
601 let mut error: *mut c_char = ptr::null_mut();
602 let ptr = unsafe {
603 ffi::fm_session_log_feedback_attachment_json(
604 self.ptr,
605 request_json.as_ptr(),
606 &mut length,
607 &mut error,
608 )
609 };
610 if ptr.is_null() && !error.is_null() {
611 return Err(crate::error::from_swift(
612 ffi::status::INVALID_ARGUMENT,
613 error,
614 ));
615 }
616 if ptr.is_null() || length == 0 {
617 return Ok(Vec::new());
618 }
619 let bytes = unsafe { std::slice::from_raw_parts(ptr.cast::<u8>(), length) }.to_vec();
620 unsafe { ffi::fm_bytes_free(ptr) };
621 Ok(bytes)
622 }
623}
624
/// Builder for [`LanguageModelSession`], obtained from [`LanguageModelSession::builder`].
pub struct SessionBuilder<'a> {
    // Configured model; when `None`, `build` passes a null model pointer to the bridge.
    model: Option<&'a ConfiguredSystemLanguageModel>,
    // Mutually exclusive with `transcript`; `build` rejects having both.
    instructions: Option<Instructions>,
    transcript: Option<Transcript>,
    // Tools exposed to the session; wrapped in a `ToolRegistry` by `build` when non-empty.
    tools: Vec<Tool>,
}
632
633impl<'a> SessionBuilder<'a> {
634 const fn new() -> Self {
635 Self {
636 model: None,
637 instructions: None,
638 transcript: None,
639 tools: Vec::new(),
640 }
641 }
642
643 #[must_use]
645 pub const fn model(mut self, model: &'a ConfiguredSystemLanguageModel) -> Self {
646 self.model = Some(model);
647 self
648 }
649
650 pub fn instructions<I>(mut self, instructions: I) -> Result<Self, FMError>
652 where
653 I: ToInstructions,
654 {
655 self.instructions = Some(instructions.to_instructions()?);
656 Ok(self)
657 }
658
659 #[must_use]
661 pub fn transcript(mut self, transcript: Transcript) -> Self {
662 self.transcript = Some(transcript);
663 self
664 }
665
666 #[must_use]
668 pub fn tool(mut self, tool: Tool) -> Self {
669 self.tools.push(tool);
670 self
671 }
672
673 #[must_use]
675 pub fn tools(mut self, tools: impl IntoIterator<Item = Tool>) -> Self {
676 self.tools.extend(tools);
677 self
678 }
679
680 pub fn build(self) -> Result<LanguageModelSession, FMError> {
686 if self.instructions.is_some() && self.transcript.is_some() {
687 return Err(FMError::InvalidArgument(
688 "session builder accepts either instructions or a transcript, not both".into(),
689 ));
690 }
691
692 let instructions_json = self
693 .instructions
694 .as_ref()
695 .map(Instructions::to_bridge_json)
696 .transpose()?;
697 let transcript_json = self
698 .transcript
699 .as_ref()
700 .map(Transcript::to_json_string)
701 .transpose()?;
702 let tool_registry = if self.tools.is_empty() {
703 None
704 } else {
705 Some(Arc::new(ToolRegistry::new(self.tools)))
706 };
707 let tools_json = tool_registry
708 .as_ref()
709 .map(|registry| registry.specs_json())
710 .transpose()?;
711
712 let instructions_c = instructions_json
713 .as_deref()
714 .map(CString::new)
715 .transpose()
716 .map_err(|error| {
717 FMError::InvalidArgument(format!("instructions JSON contains a NUL byte: {error}"))
718 })?;
719 let transcript_c = transcript_json
720 .as_deref()
721 .map(CString::new)
722 .transpose()
723 .map_err(|error| {
724 FMError::InvalidArgument(format!("transcript JSON contains a NUL byte: {error}"))
725 })?;
726 let tools_c = tools_json
727 .as_deref()
728 .map(CString::new)
729 .transpose()
730 .map_err(|error| {
731 FMError::InvalidArgument(format!("tool JSON contains a NUL byte: {error}"))
732 })?;
733
734 let tool_context = tool_registry.as_ref().map_or(ptr::null_mut(), |registry| {
735 Arc::as_ptr(registry).cast_mut().cast::<c_void>()
736 });
737 let mut error: *mut c_char = ptr::null_mut();
738 let ptr = unsafe {
739 ffi::fm_session_create_ex(
740 self.model.map_or(ptr::null_mut(), |model| model.ptr),
741 instructions_c
742 .as_ref()
743 .map_or(ptr::null(), |json| json.as_ptr()),
744 transcript_c
745 .as_ref()
746 .map_or(ptr::null(), |json| json.as_ptr()),
747 tools_c.as_ref().map_or(ptr::null(), |json| json.as_ptr()),
748 tool_context,
749 tool_registry
750 .as_ref()
751 .map(|_| tool_callback_trampoline as ffi::FmToolCallback),
752 &mut error,
753 )
754 };
755 if ptr.is_null() {
756 return Err(crate::error::from_swift(
757 ffi::status::MODEL_UNAVAILABLE,
758 error,
759 ));
760 }
761 Ok(LanguageModelSession {
762 ptr,
763 _tool_registry: tool_registry,
764 })
765 }
766}
767
/// Result of a detailed request.
///
/// Holds the decoded `content`, the raw generated content the bridge returned alongside it,
/// and the session transcript after the request.
#[derive(Debug, Clone, PartialEq)]
pub struct SessionResponse<T> {
    pub content: T,
    pub raw_content: GeneratedContent,
    pub transcript: Transcript,
}

/// One snapshot from a structured stream, carried as JSON text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StructuredStreamSnapshot {
    /// JSON text of the snapshot's content (possibly partial while streaming).
    pub content_json: String,
    /// JSON text of the snapshot's raw content.
    pub raw_content_json: String,
    /// Copied from the bridge snapshot's `isComplete` flag.
    pub is_complete: bool,
}

/// Event delivered to a `stream_generated` callback.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum StructuredStreamEvent {
    /// A decoded snapshot from the bridge.
    Snapshot(StructuredStreamSnapshot),
    /// The stream finished successfully.
    Done,
    /// The stream failed (bridge error status or an undecodable snapshot).
    Error(FMError),
}
792
/// Category of problem reported in a [`FeedbackIssue`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackIssueCategory {
    Unhelpful,
    TooVerbose,
    DidNotFollowInstructions,
    Incorrect,
    StereotypeOrBias,
    SuggestiveOrSexual,
    VulgarOrOffensive,
    TriggeredGuardrailUnexpectedly,
}

impl FeedbackIssueCategory {
    /// Snake-case wire name used in feedback-request JSON.
    const fn as_str(self) -> &'static str {
        use self::FeedbackIssueCategory as Category;

        match self {
            Category::Unhelpful => "unhelpful",
            Category::TooVerbose => "too_verbose",
            Category::DidNotFollowInstructions => "did_not_follow_instructions",
            Category::Incorrect => "incorrect",
            Category::StereotypeOrBias => "stereotype_or_bias",
            Category::SuggestiveOrSexual => "suggestive_or_sexual",
            Category::VulgarOrOffensive => "vulgar_or_offensive",
            Category::TriggeredGuardrailUnexpectedly => "triggered_guardrail_unexpectedly",
        }
    }
}
820
/// One issue reported in a [`FeedbackAttachmentRequest`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FeedbackIssue {
    /// Issue category, serialised by its snake-case wire name.
    pub category: FeedbackIssueCategory,
    /// Optional free-text explanation; serialised as JSON `null` when absent.
    pub explanation: Option<String>,
}
827
/// Overall sentiment attached to a [`FeedbackAttachmentRequest`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackSentiment {
    Positive,
    Negative,
    Neutral,
}

impl FeedbackSentiment {
    /// Lowercase wire name used in feedback-request JSON.
    const fn as_str(self) -> &'static str {
        use self::FeedbackSentiment as Sentiment;

        match self {
            Sentiment::Positive => "positive",
            Sentiment::Negative => "negative",
            Sentiment::Neutral => "neutral",
        }
    }
}
845
/// Input to [`LanguageModelSession::log_feedback_attachment`]; every field is optional.
#[derive(Debug, Clone, PartialEq)]
pub struct FeedbackAttachmentRequest {
    /// Overall sentiment; serialised as its wire name, or `null` when unset.
    pub sentiment: Option<FeedbackSentiment>,
    /// Reported issues; serialised as an array (possibly empty).
    pub issues: Vec<FeedbackIssue>,
    /// Desired response as plain text.
    pub desired_response_text: Option<String>,
    /// Desired response as generated content.
    pub desired_response_content: Option<GeneratedContent>,
    /// Desired output entry; wrapped in a one-entry transcript when serialised.
    pub desired_output: Option<crate::transcript::Entry>,
}
855
856impl FeedbackAttachmentRequest {
857 #[must_use]
859 pub const fn new() -> Self {
860 Self {
861 sentiment: None,
862 issues: Vec::new(),
863 desired_response_text: None,
864 desired_response_content: None,
865 desired_output: None,
866 }
867 }
868
869 fn to_bridge_json(&self) -> Result<String, FMError> {
870 let issues = self
871 .issues
872 .iter()
873 .map(|issue| {
874 json!({
875 "category": issue.category.as_str(),
876 "explanation": issue.explanation,
877 })
878 })
879 .collect::<Vec<_>>();
880 let desired_output_json = self
881 .desired_output
882 .as_ref()
883 .map(|entry| Transcript::from(vec![entry.clone()]).to_json_string())
884 .transpose()?;
885 let desired_response_content = self
886 .desired_response_content
887 .as_ref()
888 .map(GeneratedContent::to_bridge_value)
889 .transpose()?;
890 serde_json::to_string(&json!({
891 "sentiment": self.sentiment.map(FeedbackSentiment::as_str),
892 "issues": issues,
893 "desiredResponseText": self.desired_response_text,
894 "desiredResponseContent": desired_response_content,
895 "desiredOutputTranscriptJSON": desired_output_json,
896 }))
897 .map_err(|error| {
898 FMError::InvalidArgument(format!(
899 "feedback request is not JSON-serializable: {error}"
900 ))
901 })
902 }
903}
904
/// Wire shape of a text response, decoded from the JSON string the bridge delivers.
#[derive(Debug, Deserialize)]
struct BridgeTextResponse {
    content: String,
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    // The transcript is itself a JSON document, embedded as a string and parsed separately
    // with `Transcript::from_json_str`.
    #[serde(rename = "transcriptJSON")]
    transcript_json: String,
}

/// Wire shape of a schema-guided (structured) response.
#[derive(Debug, Deserialize)]
struct BridgeStructuredResponse {
    content: BridgeGeneratedContent,
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    // Embedded transcript JSON, parsed separately.
    #[serde(rename = "transcriptJSON")]
    transcript_json: String,
}

/// One snapshot chunk received by `structured_stream_trampoline`.
#[derive(Debug, Deserialize)]
struct BridgeStructuredSnapshot {
    content: BridgeGeneratedContent,
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    #[serde(rename = "isComplete")]
    is_complete: bool,
}

/// One snapshot chunk received by `json_text_stream_trampoline`; only non-empty deltas are
/// forwarded to the user.
#[derive(Debug, Deserialize)]
struct BridgeTextStreamSnapshot {
    delta: String,
}
936
937fn respond_request_json(
938 prompt: &Prompt,
939 options: GenerationOptions,
940 schema: Option<&GenerationSchema>,
941 include_schema_in_prompt: bool,
942) -> Result<CString, FMError> {
943 let sampling = match options.sampling() {
944 SamplingMode::Default => json!({ "mode": "default" }),
945 SamplingMode::Greedy => json!({ "mode": "greedy" }),
946 SamplingMode::TopK(k) => json!({
947 "mode": "top_k",
948 "topK": k,
949 "seed": options.sampling_seed(),
950 }),
951 SamplingMode::TopP(p) => json!({
952 "mode": "top_p",
953 "topP": p,
954 "seed": options.sampling_seed(),
955 }),
956 };
957 let include_schema_in_prompt = schema.map_or(include_schema_in_prompt, |schema| {
958 schema.effective_include_schema_in_prompt(include_schema_in_prompt)
959 });
960 let payload = serde_json::to_string(&json!({
961 "prompt": prompt.to_bridge_value(),
962 "options": {
963 "temperature": options.temperature(),
964 "maximumResponseTokens": options.maximum_response_tokens(),
965 "sampling": sampling,
966 },
967 "schemaJSON": schema.map(GenerationSchema::bridge_request_json),
968 "includeSchemaInPrompt": include_schema_in_prompt,
969 }))
970 .map_err(|error| {
971 FMError::InvalidArgument(format!("request is not JSON-serializable: {error}"))
972 })?;
973 CString::new(payload).map_err(|error| {
974 FMError::InvalidArgument(format!("request JSON contains a NUL byte: {error}"))
975 })
976}
977
978fn request_response(session: *mut c_void, payload: &CString) -> Result<String, FMError> {
979 let (tx, rx) = mpsc::channel();
980 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
981 let context = Box::into_raw(tx_box).cast::<c_void>();
982 unsafe {
983 ffi::fm_session_respond_request_json(session, payload.as_ptr(), context, respond_trampoline)
984 };
985 rx.recv().map_err(|_| FMError::Unknown {
986 code: ffi::status::UNKNOWN,
987 message: "Swift bridge dropped the JSON response channel".into(),
988 })?
989}
990
991pub(crate) fn decode_bridge_text_response(
992 payload: &str,
993) -> Result<SessionResponse<String>, FMError> {
994 let response: BridgeTextResponse = serde_json::from_str(payload)
995 .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
996 Ok(SessionResponse {
997 content: response.content,
998 raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
999 transcript: Transcript::from_json_str(&response.transcript_json)?,
1000 })
1001}
1002
1003pub(crate) fn request_text_response_with<F>(invoke: F) -> Result<SessionResponse<String>, FMError>
1004where
1005 F: FnOnce(*mut c_void, ffi::FmRespondCallback),
1006{
1007 let (tx, rx) = mpsc::channel();
1008 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
1009 let context = Box::into_raw(tx_box).cast::<c_void>();
1010 invoke(context, respond_trampoline);
1011 let payload = rx.recv().map_err(|_| FMError::Unknown {
1012 code: ffi::status::UNKNOWN,
1013 message: "Swift bridge dropped the JSON response channel".into(),
1014 })??;
1015 decode_bridge_text_response(&payload)
1016}
1017
1018pub(crate) fn run_text_stream_with<F, C>(invoke: F, on_chunk: C) -> Result<(), FMError>
1019where
1020 F: FnOnce(*mut c_void, ffi::FmStreamCallback),
1021 C: FnMut(StreamEvent<'_>) + Send + 'static,
1022{
1023 let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
1024 let state = Arc::new(StreamState {
1025 on_chunk: Mutex::new(Box::new(on_chunk)),
1026 done_tx: Mutex::new(Some(done_tx)),
1027 finished: AtomicBool::new(false),
1028 });
1029 let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
1030 invoke(context, json_text_stream_trampoline);
1031 done_rx.recv().map_err(|_| FMError::Unknown {
1032 code: ffi::status::UNKNOWN,
1033 message: "Swift bridge dropped the stream channel".into(),
1034 })?
1035}
1036
1037fn prompt_to_plain_text(prompt: &Prompt) -> Option<String> {
1038 let mut text = String::new();
1039 for segment in prompt.segments() {
1040 match segment {
1041 crate::prompt::Segment::Text(segment) => text.push_str(&segment.text),
1042 crate::prompt::Segment::Structure(_) => return None,
1043 }
1044 }
1045 Some(text)
1046}
1047
1048impl Default for LanguageModelSession {
1049 fn default() -> Self {
1050 Self::new()
1051 }
1052}
1053
1054impl Drop for LanguageModelSession {
1055 fn drop(&mut self) {
1056 if !self.ptr.is_null() {
1057 unsafe { ffi::fm_object_release(self.ptr) };
1058 }
1059 }
1060}
1061
1062impl core::fmt::Debug for LanguageModelSession {
1063 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1064 f.debug_struct("LanguageModelSession")
1065 .field("ptr", &self.ptr)
1066 .finish()
1067 }
1068}
1069
/// Event delivered to a text-stream callback (`stream`, `stream_with`, `stream_prompt`).
#[derive(Debug)]
#[non_exhaustive]
pub enum StreamEvent<'a> {
    /// A non-empty text delta, borrowed for the duration of the callback only.
    Chunk(&'a str),
    /// The stream finished successfully.
    Done,
    /// The stream failed (bridge error status or an undecodable snapshot).
    Error(FMError),
}
1081
1082unsafe extern "C" fn respond_trampoline(
1089 context: *mut c_void,
1090 response: *mut c_char,
1091 error: *mut c_char,
1092 status: i32,
1093) {
1094 let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
1095 let result = if status == ffi::status::OK && !response.is_null() {
1096 let s = core::ffi::CStr::from_ptr(response)
1097 .to_string_lossy()
1098 .into_owned();
1099 ffi::fm_string_free(response);
1100 Ok(s)
1101 } else {
1102 Err(crate::error::from_swift(status, error))
1103 };
1104 let _ = tx.send(result);
1105}
1106
/// Boxed user callback for text streams.
type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;

/// Shared state behind the `context` pointer handed to `json_text_stream_trampoline`.
struct StreamState {
    // User callback; locked while each event is delivered.
    on_chunk: Mutex<StreamCallback>,
    // Completion channel for the blocked caller; `take`n so at most one result is sent.
    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
    // Set when the stream fails inside the trampoline (undecodable snapshot or panicking
    // callback); later events are then ignored.
    finished: AtomicBool,
}
1119
1120unsafe extern "C" fn json_text_stream_trampoline(
1131 context: *mut c_void,
1132 chunk: *mut c_char,
1133 done: bool,
1134 status: i32,
1135) {
1136 let state = Arc::from_raw(context.cast::<StreamState>());
1137 let state_for_swift = state.clone();
1138 core::mem::forget(state_for_swift);
1139
1140 let already_finished = state.finished.load(Ordering::Acquire);
1141
1142 let payload: Option<String> = if chunk.is_null() {
1143 None
1144 } else {
1145 let value = core::ffi::CStr::from_ptr(chunk)
1146 .to_string_lossy()
1147 .into_owned();
1148 ffi::fm_string_free(chunk);
1149 Some(value)
1150 };
1151
1152 if status != ffi::status::OK {
1153 if !already_finished {
1155 let err = payload
1156 .map(|message| {
1157 crate::error::from_swift(
1158 status,
1159 ffi::fm_string_dup(
1160 CString::new(message)
1161 .expect("stream errors must not contain NUL bytes")
1162 .as_ptr(),
1163 ),
1164 )
1165 })
1166 .unwrap_or_else(|| crate::error::from_swift(status, ptr::null_mut()));
1167 {
1168 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1169 let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Error(err.clone()))));
1171 }
1172 if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1173 let _ = tx.send(Err(err));
1174 }
1175 }
1176 drop(Arc::from_raw(Arc::as_ptr(&state)));
1177 drop(state);
1178 return;
1179 }
1180
1181 if !already_finished {
1182 if let Some(payload) = payload {
1183 match serde_json::from_str::<BridgeTextStreamSnapshot>(&payload) {
1184 Ok(snapshot) if !snapshot.delta.is_empty() => {
1185 let chunk_panicked = {
1186 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1187 catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Chunk(&snapshot.delta))))
1189 .is_err()
1190 };
1191 if chunk_panicked {
1192 state.finished.store(true, Ordering::Release);
1195 if let Some(tx) =
1196 state.done_tx.lock().expect("done_tx mutex poisoned").take()
1197 {
1198 let _ = tx.send(Err(FMError::Unknown {
1199 code: ffi::status::UNKNOWN,
1200 message: "stream callback panicked".into(),
1201 }));
1202 }
1203 drop(state);
1204 return;
1205 }
1206 }
1207 Ok(_) => {}
1208 Err(error) => {
1209 let err = FMError::DecodingFailure(error.to_string());
1211 state.finished.store(true, Ordering::Release);
1212 {
1213 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1214 let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Error(err.clone()))));
1215 }
1216 if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1217 let _ = tx.send(Err(err));
1218 }
1219 drop(state);
1220 return;
1221 }
1222 }
1223 }
1224 }
1225
1226 if done {
1227 if !already_finished {
1228 {
1229 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1230 let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Done)));
1231 }
1232 if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1233 let _ = tx.send(Ok(()));
1234 }
1235 }
1236 drop(Arc::from_raw(Arc::as_ptr(&state)));
1237 }
1238 drop(state);
1239}
1240
/// Boxed user callback for structured streams.
type StructuredStreamCallback = Box<dyn FnMut(StructuredStreamEvent) + Send>;

/// Shared state behind the `context` pointer handed to `structured_stream_trampoline`.
struct StructuredStreamState {
    // User callback; locked while each event is delivered.
    on_event: Mutex<StructuredStreamCallback>,
    // Completion channel for the blocked caller; `take`n so at most one result is sent.
    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
    // Set when the stream fails inside the trampoline (undecodable snapshot or panicking
    // callback); later events are then ignored.
    finished: AtomicBool,
}
1249
1250#[allow(clippy::too_many_lines)]
1253unsafe extern "C" fn structured_stream_trampoline(
1254 context: *mut c_void,
1255 chunk: *mut c_char,
1256 done: bool,
1257 status: i32,
1258) {
1259 let state = Arc::from_raw(context.cast::<StructuredStreamState>());
1260 let state_for_swift = state.clone();
1261 core::mem::forget(state_for_swift);
1262
1263 let already_finished = state.finished.load(Ordering::Acquire);
1264
1265 let payload: Option<String> = if chunk.is_null() {
1266 None
1267 } else {
1268 let value = core::ffi::CStr::from_ptr(chunk)
1269 .to_string_lossy()
1270 .into_owned();
1271 ffi::fm_string_free(chunk);
1272 Some(value)
1273 };
1274
1275 if status != ffi::status::OK {
1276 if !already_finished {
1278 let err = payload
1279 .map(|message| {
1280 crate::error::from_swift(
1281 status,
1282 ffi::fm_string_dup(
1283 CString::new(message)
1284 .expect("stream errors must not contain NUL bytes")
1285 .as_ptr(),
1286 ),
1287 )
1288 })
1289 .unwrap_or_else(|| crate::error::from_swift(status, ptr::null_mut()));
1290 {
1291 let mut cb = state
1292 .on_event
1293 .lock()
1294 .expect("structured callback mutex poisoned");
1295 let _ = catch_unwind(AssertUnwindSafe(|| {
1297 cb(StructuredStreamEvent::Error(err.clone()));
1298 }));
1299 }
1300 if let Some(tx) = state
1301 .done_tx
1302 .lock()
1303 .expect("structured done_tx mutex poisoned")
1304 .take()
1305 {
1306 let _ = tx.send(Err(err));
1307 }
1308 }
1309 drop(Arc::from_raw(Arc::as_ptr(&state)));
1310 drop(state);
1311 return;
1312 }
1313
1314 if !already_finished {
1315 if let Some(payload) = payload {
1316 let snapshot: BridgeStructuredSnapshot = match serde_json::from_str(&payload) {
1317 Ok(snapshot) => snapshot,
1318 Err(error) => {
1319 let err = FMError::DecodingFailure(error.to_string());
1322 state.finished.store(true, Ordering::Release);
1323 {
1324 let mut cb = state
1325 .on_event
1326 .lock()
1327 .expect("structured callback mutex poisoned");
1328 let _ = catch_unwind(AssertUnwindSafe(|| {
1329 cb(StructuredStreamEvent::Error(err.clone()));
1330 }));
1331 }
1332 if let Some(tx) = state
1333 .done_tx
1334 .lock()
1335 .expect("structured done_tx mutex poisoned")
1336 .take()
1337 {
1338 let _ = tx.send(Err(err));
1339 }
1340 drop(state);
1341 return;
1342 }
1343 };
1344 let snapshot_event = StructuredStreamEvent::Snapshot(StructuredStreamSnapshot {
1345 content_json: snapshot.content.json,
1346 raw_content_json: snapshot.raw_content.json,
1347 is_complete: snapshot.is_complete,
1348 });
1349 let snapshot_panicked = {
1350 let mut cb = state
1351 .on_event
1352 .lock()
1353 .expect("structured callback mutex poisoned");
1354 catch_unwind(AssertUnwindSafe(|| cb(snapshot_event))).is_err()
1356 };
1357 if snapshot_panicked {
1358 state.finished.store(true, Ordering::Release);
1361 if let Some(tx) = state
1362 .done_tx
1363 .lock()
1364 .expect("structured done_tx mutex poisoned")
1365 .take()
1366 {
1367 let _ = tx.send(Err(FMError::Unknown {
1368 code: ffi::status::UNKNOWN,
1369 message: "stream callback panicked".into(),
1370 }));
1371 }
1372 drop(state);
1373 return;
1374 }
1375 }
1376 }
1377
1378 if done {
1379 if !already_finished {
1380 {
1381 let mut cb = state
1382 .on_event
1383 .lock()
1384 .expect("structured callback mutex poisoned");
1385 let _ = catch_unwind(AssertUnwindSafe(|| cb(StructuredStreamEvent::Done)));
1386 }
1387 if let Some(tx) = state
1388 .done_tx
1389 .lock()
1390 .expect("structured done_tx mutex poisoned")
1391 .take()
1392 {
1393 let _ = tx.send(Ok(()));
1394 }
1395 }
1396 drop(Arc::from_raw(Arc::as_ptr(&state)));
1397 }
1398 drop(state);
1399}