1use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::panic::{catch_unwind, AssertUnwindSafe};
7use std::sync::mpsc;
8use std::sync::{Arc, Mutex};
9
10use serde::Deserialize;
11use serde_json::json;
12
13use crate::content::{BridgeGeneratedContent, GeneratedContent};
14use crate::error::FMError;
15use crate::ffi;
16use crate::generation::{GenerationOptions, SamplingMode};
17use crate::model::ConfiguredSystemLanguageModel;
18use crate::prompt::{Instructions, Prompt, ToInstructions, ToPrompt};
19use crate::schema::GenerationSchema;
20use crate::tool::{tool_callback_trampoline, Tool, ToolRegistry};
21use crate::transcript::Transcript;
22
/// Owning handle to a session object created by the Swift bridge.
///
/// The wrapped pointer is passed to `fm_object_release` when the value is dropped.
pub struct LanguageModelSession {
    /// Opaque bridge pointer; `Drop` releases it when it is non-null.
    ptr: *mut c_void,
    /// Keeps the tool registry alive for as long as the session exists. Its address is handed
    /// to the bridge as the tool-callback context by `SessionBuilder::build`.
    _tool_registry: Option<Arc<ToolRegistry>>,
}
44
// SAFETY: NOTE(review): these impls assert that the bridge object behind `ptr` may be moved to
// and shared between threads. Nothing in this file verifies that — confirm against the Swift
// bridge's threading guarantees.
unsafe impl Send for LanguageModelSession {}
unsafe impl Sync for LanguageModelSession {}
52
53impl LanguageModelSession {
    /// Returns the raw bridge pointer held by this session; ownership stays with `self`.
    pub(crate) fn as_ptr(&self) -> *mut c_void {
        self.ptr
    }
61
    /// Creates a session without instructions.
    ///
    /// # Panics
    ///
    /// Panics when [`Self::try_new`] returns `None`, i.e. when the bridge hands back a null
    /// session pointer.
    #[must_use]
    pub fn new() -> Self {
        Self::try_new(None).expect("FoundationModels is not available on this OS")
    }
73
    /// Creates a session seeded with plain-text `instructions`.
    ///
    /// # Panics
    ///
    /// Panics when [`Self::try_new`] returns `None`: either `instructions` contains a NUL byte
    /// or the bridge hands back a null session pointer.
    #[must_use]
    pub fn with_instructions(instructions: &str) -> Self {
        Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
    }
84
85 #[must_use]
89 pub fn try_new(instructions: Option<&str>) -> Option<Self> {
90 let cstring = match instructions {
91 Some(s) => Some(CString::new(s).ok()?),
92 None => None,
93 };
94 let ptr =
95 unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
96 if ptr.is_null() {
97 return None;
98 }
99 Some(Self {
100 ptr,
101 _tool_registry: None,
102 })
103 }
104
    /// Sends `prompt` with default [`GenerationOptions`] and returns the response text.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Self::respond_with`].
    pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
        self.respond_with(prompt, GenerationOptions::new())
    }
115
    /// Forwards a prewarm request for this session to the bridge (`fm_session_prewarm`).
    pub fn prewarm(&self) {
        // SAFETY: NOTE(review): relies on `self.ptr` being the live session pointer stored at
        // construction; it is only released in `Drop`.
        unsafe { ffi::fm_session_prewarm(self.ptr) };
    }
122
    /// Returns the value reported by the bridge's `fm_session_is_responding` for this session.
    #[must_use]
    pub fn is_responding(&self) -> bool {
        // SAFETY: NOTE(review): relies on `self.ptr` being the live session pointer stored at
        // construction; it is only released in `Drop`.
        unsafe { ffi::fm_session_is_responding(self.ptr) }
    }
129
130 #[must_use]
135 pub fn transcript_json(&self) -> String {
136 let p = unsafe { ffi::fm_session_transcript_json(self.ptr) };
137 if p.is_null() {
138 return String::from("{}");
139 }
140 let s = unsafe { core::ffi::CStr::from_ptr(p) }
141 .to_string_lossy()
142 .into_owned();
143 unsafe { ffi::fm_string_free(p) };
144 s
145 }
146
147 pub fn log_feedback(&self, sentiment: i32, description: Option<&str>) {
151 let cstr = description.and_then(|s| CString::new(s).ok());
152 let p = cstr.as_ref().map_or(core::ptr::null(), |c| c.as_ptr());
153 unsafe { ffi::fm_session_log_feedback(self.ptr, sentiment, p) };
154 }
155
156 pub fn respond_with_json_schema(
170 &self,
171 prompt: &str,
172 schema_description: &str,
173 ) -> Result<String, FMError> {
174 let wrapped = format!(
175 "{prompt}\n\n\
176 IMPORTANT: respond with VALID JSON ONLY (no prose, no markdown \
177 fences) that matches this schema:\n\n{schema_description}\n\n\
178 Your entire response must be parseable by JSON.parse()."
179 );
180 self.respond(&wrapped)
181 }
182
    /// Sends a plain-text `prompt` with explicit generation `options` and returns the response
    /// text.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Self::respond_prompt_with`].
    pub fn respond_with(
        &self,
        prompt: &str,
        options: GenerationOptions,
    ) -> Result<String, FMError> {
        self.respond_prompt_with(prompt, options)
    }
195
196 pub fn respond_with_schema(
229 &self,
230 prompt: &str,
231 schema: &str,
232 include_schema_in_prompt: bool,
233 ) -> Result<String, FMError> {
234 self.respond_with_schema_options(
235 prompt,
236 schema,
237 include_schema_in_prompt,
238 GenerationOptions::new(),
239 )
240 }
241
242 pub fn respond_with_schema_options(
249 &self,
250 prompt: &str,
251 schema: &str,
252 include_schema_in_prompt: bool,
253 options: GenerationOptions,
254 ) -> Result<String, FMError> {
255 let prompt_c = CString::new(prompt)
256 .map_err(|e| FMError::InvalidArgument(format!("prompt NUL byte: {e}")))?;
257 let schema_c = CString::new(schema)
258 .map_err(|e| FMError::InvalidArgument(format!("schema NUL byte: {e}")))?;
259 let opts = options.to_ffi();
260 let (tx, rx) = mpsc::channel();
261 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
262 let context = Box::into_raw(tx_box).cast::<c_void>();
263
264 unsafe {
265 ffi::fm_session_respond_with_schema(
266 self.ptr,
267 prompt_c.as_ptr(),
268 schema_c.as_ptr(),
269 include_schema_in_prompt,
270 opts.temperature,
271 opts.maximum_response_tokens,
272 opts.sampling_mode,
273 opts.top_k,
274 opts.top_p,
275 context,
276 respond_trampoline,
277 );
278 }
279
280 rx.recv().map_err(|_| FMError::Unknown {
281 code: ffi::status::UNKNOWN,
282 message: "Swift bridge dropped the callback channel".into(),
283 })?
284 }
285
286 pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
295 where
296 F: FnMut(StreamEvent<'_>) + Send + 'static,
297 {
298 self.stream_with(prompt, GenerationOptions::new(), move |event| {
299 on_chunk(event);
300 })
301 }
302
303 pub fn stream_with<F>(
309 &self,
310 prompt: &str,
311 options: GenerationOptions,
312 on_chunk: F,
313 ) -> Result<(), FMError>
314 where
315 F: FnMut(StreamEvent<'_>) + Send + 'static,
316 {
317 let payload = respond_request_json(&Prompt::from(prompt), options, None, true)?;
318
319 let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
320 let state = Arc::new(StreamState {
321 on_chunk: Mutex::new(Box::new(on_chunk)),
322 done_tx: Mutex::new(Some(done_tx)),
323 });
324 let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
325
326 unsafe {
327 ffi::fm_session_stream_request_json(
328 self.ptr,
329 payload.as_ptr(),
330 context,
331 json_text_stream_trampoline,
332 )
333 };
334
335 done_rx.recv().map_err(|_| FMError::Unknown {
336 code: ffi::status::UNKNOWN,
337 message: "Swift bridge dropped the stream channel".into(),
338 })?
339 }
340}
341
342impl LanguageModelSession {
    /// Returns an empty [`SessionBuilder`] (no model, instructions, transcript, or tools).
    #[must_use]
    pub fn builder<'a>() -> SessionBuilder<'a> {
        SessionBuilder::new()
    }
348
    /// Builds a session from an existing `transcript` via [`SessionBuilder`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`SessionBuilder::build`].
    pub fn from_transcript(transcript: Transcript) -> Result<Self, FMError> {
        Self::builder().transcript(transcript).build()
    }
357
358 pub fn transcript(&self) -> Result<Transcript, FMError> {
365 Transcript::from_json_str(&self.transcript_json())
366 }
367
368 pub fn prewarm_with_prompt<P>(&self, prompt: P) -> Result<(), FMError>
374 where
375 P: ToPrompt,
376 {
377 let prompt = prompt.to_prompt()?;
378 let prompt_json = CString::new(prompt.to_bridge_json()?).map_err(|error| {
379 FMError::InvalidArgument(format!("prompt JSON contains a NUL byte: {error}"))
380 })?;
381 let mut error: *mut c_char = ptr::null_mut();
382 let status = unsafe {
383 ffi::fm_session_prewarm_prompt_json(self.ptr, prompt_json.as_ptr(), &mut error)
384 };
385 if status != ffi::status::OK {
386 return Err(crate::error::from_swift(status, error));
387 }
388 Ok(())
389 }
390
    /// Sends any [`ToPrompt`] value with default [`GenerationOptions`] and returns the response
    /// text.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Self::respond_prompt_with`].
    pub fn respond_prompt<P>(&self, prompt: P) -> Result<String, FMError>
    where
        P: ToPrompt,
    {
        self.respond_prompt_with(prompt, GenerationOptions::new())
    }
402
403 pub fn respond_prompt_with<P>(
409 &self,
410 prompt: P,
411 options: GenerationOptions,
412 ) -> Result<String, FMError>
413 where
414 P: ToPrompt,
415 {
416 self.respond_prompt_detailed(prompt, options)
417 .map(|response| response.content)
418 }
419
420 pub fn respond_prompt_detailed<P>(
426 &self,
427 prompt: P,
428 options: GenerationOptions,
429 ) -> Result<SessionResponse<String>, FMError>
430 where
431 P: ToPrompt,
432 {
433 let prompt = prompt.to_prompt()?;
434 let payload = respond_request_json(&prompt, options, None, true)?;
435 let payload = request_response(self.ptr, &payload)?;
436 let response: BridgeTextResponse = serde_json::from_str(&payload)
437 .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
438 Ok(SessionResponse {
439 content: response.content,
440 raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
441 transcript: Transcript::from_json_str(&response.transcript_json)?,
442 })
443 }
444
445 pub fn respond_generated<P>(
451 &self,
452 prompt: P,
453 schema: &GenerationSchema,
454 include_schema_in_prompt: bool,
455 ) -> Result<GeneratedContent, FMError>
456 where
457 P: ToPrompt,
458 {
459 self.respond_generated_with(
460 prompt,
461 schema,
462 include_schema_in_prompt,
463 GenerationOptions::new(),
464 )
465 .map(|response| response.content)
466 }
467
468 pub fn respond_generated_with<P>(
474 &self,
475 prompt: P,
476 schema: &GenerationSchema,
477 include_schema_in_prompt: bool,
478 options: GenerationOptions,
479 ) -> Result<SessionResponse<GeneratedContent>, FMError>
480 where
481 P: ToPrompt,
482 {
483 let prompt = prompt.to_prompt()?;
484 let payload =
485 respond_request_json(&prompt, options, Some(schema), include_schema_in_prompt)?;
486 let payload = request_response(self.ptr, &payload)?;
487 let response: BridgeStructuredResponse = serde_json::from_str(&payload)
488 .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
489 Ok(SessionResponse {
490 content: GeneratedContent::from_bridge_payload(response.content, true)?,
491 raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
492 transcript: Transcript::from_json_str(&response.transcript_json)?,
493 })
494 }
495
496 pub fn respond_generating<P, T>(
503 &self,
504 prompt: P,
505 include_schema_in_prompt: bool,
506 options: GenerationOptions,
507 ) -> Result<SessionResponse<T>, FMError>
508 where
509 P: ToPrompt,
510 T: crate::schema::Generable,
511 {
512 let response = self.respond_generated_with(
513 prompt,
514 &T::generation_schema()?,
515 include_schema_in_prompt,
516 options,
517 )?;
518 Ok(SessionResponse {
519 content: T::from_generated_content(&response.content)?,
520 raw_content: response.raw_content,
521 transcript: response.transcript,
522 })
523 }
524
525 pub fn stream_prompt<P, F>(&self, prompt: P, on_chunk: F) -> Result<(), FMError>
531 where
532 P: ToPrompt,
533 F: FnMut(StreamEvent<'_>) + Send + 'static,
534 {
535 let prompt = prompt.to_prompt()?;
536 let prompt_text = prompt_to_plain_text(&prompt).ok_or_else(|| {
537 FMError::InvalidArgument(
538 "text streaming only supports prompts composed of text segments".into(),
539 )
540 })?;
541 self.stream_with(&prompt_text, GenerationOptions::new(), on_chunk)
542 }
543
544 pub fn stream_generated<P, F>(
550 &self,
551 prompt: P,
552 schema: &GenerationSchema,
553 include_schema_in_prompt: bool,
554 options: GenerationOptions,
555 on_event: F,
556 ) -> Result<(), FMError>
557 where
558 P: ToPrompt,
559 F: FnMut(StructuredStreamEvent) + Send + 'static,
560 {
561 let prompt = prompt.to_prompt()?;
562 let payload =
563 respond_request_json(&prompt, options, Some(schema), include_schema_in_prompt)?;
564 let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
565 let state = Arc::new(StructuredStreamState {
566 on_event: Mutex::new(Box::new(on_event)),
567 done_tx: Mutex::new(Some(done_tx)),
568 });
569 let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
570 unsafe {
571 ffi::fm_session_stream_request_json(
572 self.ptr,
573 payload.as_ptr(),
574 context,
575 structured_stream_trampoline,
576 )
577 };
578 done_rx.recv().map_err(|_| FMError::Unknown {
579 code: ffi::status::UNKNOWN,
580 message: "Swift bridge dropped the structured stream channel".into(),
581 })?
582 }
583
584 pub fn log_feedback_attachment(
590 &self,
591 request: FeedbackAttachmentRequest,
592 ) -> Result<Vec<u8>, FMError> {
593 let request_json = CString::new(request.to_bridge_json()?).map_err(|error| {
594 FMError::InvalidArgument(format!("feedback request contains a NUL byte: {error}"))
595 })?;
596 let mut length = 0usize;
597 let mut error: *mut c_char = ptr::null_mut();
598 let ptr = unsafe {
599 ffi::fm_session_log_feedback_attachment_json(
600 self.ptr,
601 request_json.as_ptr(),
602 &mut length,
603 &mut error,
604 )
605 };
606 if ptr.is_null() && !error.is_null() {
607 return Err(crate::error::from_swift(
608 ffi::status::INVALID_ARGUMENT,
609 error,
610 ));
611 }
612 if ptr.is_null() || length == 0 {
613 return Ok(Vec::new());
614 }
615 let bytes = unsafe { std::slice::from_raw_parts(ptr.cast::<u8>(), length) }.to_vec();
616 unsafe { ffi::fm_bytes_free(ptr) };
617 Ok(bytes)
618 }
619}
620
/// Builder that collects the optional inputs for a [`LanguageModelSession`].
///
/// `build` rejects a builder that has both instructions and a transcript.
pub struct SessionBuilder<'a> {
    /// Model to create the session with; a null model pointer is passed when unset.
    model: Option<&'a ConfiguredSystemLanguageModel>,
    /// Instructions, serialised to bridge JSON by `build`.
    instructions: Option<Instructions>,
    /// Existing transcript, serialised to JSON by `build`.
    transcript: Option<Transcript>,
    /// Tools to register; a `ToolRegistry` is only created when this is non-empty.
    tools: Vec<Tool>,
}
628
629impl<'a> SessionBuilder<'a> {
    /// Creates an empty builder: no model, no instructions, no transcript, no tools.
    const fn new() -> Self {
        Self {
            model: None,
            instructions: None,
            transcript: None,
            tools: Vec::new(),
        }
    }
638
    /// Sets the model the session is created with, replacing any previous choice.
    #[must_use]
    pub const fn model(mut self, model: &'a ConfiguredSystemLanguageModel) -> Self {
        self.model = Some(model);
        self
    }
645
646 pub fn instructions<I>(mut self, instructions: I) -> Result<Self, FMError>
648 where
649 I: ToInstructions,
650 {
651 self.instructions = Some(instructions.to_instructions()?);
652 Ok(self)
653 }
654
655 #[must_use]
657 pub fn transcript(mut self, transcript: Transcript) -> Self {
658 self.transcript = Some(transcript);
659 self
660 }
661
    /// Appends a single tool to the list of tools registered with the session.
    #[must_use]
    pub fn tool(mut self, tool: Tool) -> Self {
        self.tools.push(tool);
        self
    }
668
    /// Appends every tool yielded by `tools` to the list of tools registered with the session.
    #[must_use]
    pub fn tools(mut self, tools: impl IntoIterator<Item = Tool>) -> Self {
        self.tools.extend(tools);
        self
    }
675
676 pub fn build(self) -> Result<LanguageModelSession, FMError> {
682 if self.instructions.is_some() && self.transcript.is_some() {
683 return Err(FMError::InvalidArgument(
684 "session builder accepts either instructions or a transcript, not both".into(),
685 ));
686 }
687
688 let instructions_json = self
689 .instructions
690 .as_ref()
691 .map(Instructions::to_bridge_json)
692 .transpose()?;
693 let transcript_json = self
694 .transcript
695 .as_ref()
696 .map(Transcript::to_json_string)
697 .transpose()?;
698 let tool_registry = if self.tools.is_empty() {
699 None
700 } else {
701 Some(Arc::new(ToolRegistry::new(self.tools)))
702 };
703 let tools_json = tool_registry
704 .as_ref()
705 .map(|registry| registry.specs_json())
706 .transpose()?;
707
708 let instructions_c = instructions_json
709 .as_deref()
710 .map(CString::new)
711 .transpose()
712 .map_err(|error| {
713 FMError::InvalidArgument(format!("instructions JSON contains a NUL byte: {error}"))
714 })?;
715 let transcript_c = transcript_json
716 .as_deref()
717 .map(CString::new)
718 .transpose()
719 .map_err(|error| {
720 FMError::InvalidArgument(format!("transcript JSON contains a NUL byte: {error}"))
721 })?;
722 let tools_c = tools_json
723 .as_deref()
724 .map(CString::new)
725 .transpose()
726 .map_err(|error| {
727 FMError::InvalidArgument(format!("tool JSON contains a NUL byte: {error}"))
728 })?;
729
730 let tool_context = tool_registry.as_ref().map_or(ptr::null_mut(), |registry| {
731 Arc::as_ptr(registry).cast_mut().cast::<c_void>()
732 });
733 let mut error: *mut c_char = ptr::null_mut();
734 let ptr = unsafe {
735 ffi::fm_session_create_ex(
736 self.model.map_or(ptr::null_mut(), |model| model.ptr),
737 instructions_c
738 .as_ref()
739 .map_or(ptr::null(), |json| json.as_ptr()),
740 transcript_c
741 .as_ref()
742 .map_or(ptr::null(), |json| json.as_ptr()),
743 tools_c.as_ref().map_or(ptr::null(), |json| json.as_ptr()),
744 tool_context,
745 tool_registry
746 .as_ref()
747 .map(|_| tool_callback_trampoline as ffi::FmToolCallback),
748 &mut error,
749 )
750 };
751 if ptr.is_null() {
752 return Err(crate::error::from_swift(
753 ffi::status::MODEL_UNAVAILABLE,
754 error,
755 ));
756 }
757 Ok(LanguageModelSession {
758 ptr,
759 _tool_registry: tool_registry,
760 })
761 }
762}
763
/// A decoded bridge response: the typed content plus the raw content and transcript that
/// accompanied it.
#[derive(Debug, Clone, PartialEq)]
pub struct SessionResponse<T> {
    /// The response value in the caller-requested form (text, generated content, or `T`).
    pub content: T,
    /// Content decoded from the bridge payload's `rawContent` field.
    pub raw_content: GeneratedContent,
    /// Transcript parsed from the bridge payload's `transcriptJSON` field.
    pub transcript: Transcript,
}
771
/// One snapshot delivered while streaming schema-guided content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StructuredStreamSnapshot {
    /// JSON text taken from the snapshot's `content` payload.
    pub content_json: String,
    /// JSON text taken from the snapshot's `rawContent` payload.
    pub raw_content_json: String,
    /// Value of the snapshot's `isComplete` flag as reported by the bridge.
    pub is_complete: bool,
}
779
/// Events passed to the callback of a structured stream.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum StructuredStreamEvent {
    /// A snapshot decoded from a stream payload.
    Snapshot(StructuredStreamSnapshot),
    /// The bridge reported that the stream is done.
    Done,
    /// The bridge reported a failing status, or a payload could not be decoded.
    Error(FMError),
}
788
/// Category of a problem reported through a [`FeedbackIssue`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackIssueCategory {
    Unhelpful,
    TooVerbose,
    DidNotFollowInstructions,
    Incorrect,
    StereotypeOrBias,
    SuggestiveOrSexual,
    VulgarOrOffensive,
    TriggeredGuardrailUnexpectedly,
}

impl FeedbackIssueCategory {
    /// Returns the snake_case identifier written into the feedback request JSON.
    const fn as_str(self) -> &'static str {
        let identifier: &'static str = match self {
            Self::Unhelpful => "unhelpful",
            Self::TooVerbose => "too_verbose",
            Self::DidNotFollowInstructions => "did_not_follow_instructions",
            Self::Incorrect => "incorrect",
            Self::StereotypeOrBias => "stereotype_or_bias",
            Self::SuggestiveOrSexual => "suggestive_or_sexual",
            Self::VulgarOrOffensive => "vulgar_or_offensive",
            Self::TriggeredGuardrailUnexpectedly => "triggered_guardrail_unexpectedly",
        };
        identifier
    }
}
816
/// A single issue attached to a [`FeedbackAttachmentRequest`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FeedbackIssue {
    /// Kind of problem; serialised via its string identifier.
    pub category: FeedbackIssueCategory,
    /// Optional free-text explanation; serialised as JSON `null` when absent.
    pub explanation: Option<String>,
}
823
/// Overall sentiment attached to a [`FeedbackAttachmentRequest`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackSentiment {
    Positive,
    Negative,
    Neutral,
}

impl FeedbackSentiment {
    /// Returns the lowercase identifier written into the feedback request JSON.
    const fn as_str(self) -> &'static str {
        let identifier: &'static str = match self {
            Self::Positive => "positive",
            Self::Negative => "negative",
            Self::Neutral => "neutral",
        };
        identifier
    }
}
841
/// Input for `LanguageModelSession::log_feedback_attachment`.
///
/// Every field is optional; unset values are serialised as JSON `null` (or an empty list for
/// `issues`).
#[derive(Debug, Clone, PartialEq)]
pub struct FeedbackAttachmentRequest {
    /// Overall sentiment, serialised under the `sentiment` key.
    pub sentiment: Option<FeedbackSentiment>,
    /// Reported issues, serialised under the `issues` key.
    pub issues: Vec<FeedbackIssue>,
    /// Desired response as plain text, serialised under `desiredResponseText`.
    pub desired_response_text: Option<String>,
    /// Desired response as generated content, serialised under `desiredResponseContent`.
    pub desired_response_content: Option<GeneratedContent>,
    /// Desired output entry; wrapped in a single-entry transcript and serialised as JSON text
    /// under `desiredOutputTranscriptJSON`.
    pub desired_output: Option<crate::transcript::Entry>,
}
851
852impl FeedbackAttachmentRequest {
853 #[must_use]
855 pub const fn new() -> Self {
856 Self {
857 sentiment: None,
858 issues: Vec::new(),
859 desired_response_text: None,
860 desired_response_content: None,
861 desired_output: None,
862 }
863 }
864
865 fn to_bridge_json(&self) -> Result<String, FMError> {
866 let issues = self
867 .issues
868 .iter()
869 .map(|issue| {
870 json!({
871 "category": issue.category.as_str(),
872 "explanation": issue.explanation,
873 })
874 })
875 .collect::<Vec<_>>();
876 let desired_output_json = self
877 .desired_output
878 .as_ref()
879 .map(|entry| Transcript::from(vec![entry.clone()]).to_json_string())
880 .transpose()?;
881 let desired_response_content = self
882 .desired_response_content
883 .as_ref()
884 .map(GeneratedContent::to_bridge_value)
885 .transpose()?;
886 serde_json::to_string(&json!({
887 "sentiment": self.sentiment.map(FeedbackSentiment::as_str),
888 "issues": issues,
889 "desiredResponseText": self.desired_response_text,
890 "desiredResponseContent": desired_response_content,
891 "desiredOutputTranscriptJSON": desired_output_json,
892 }))
893 .map_err(|error| {
894 FMError::InvalidArgument(format!(
895 "feedback request is not JSON-serializable: {error}"
896 ))
897 })
898 }
899}
900
/// Wire format of a plain-text response payload returned by the bridge.
#[derive(Debug, Deserialize)]
struct BridgeTextResponse {
    /// Response text.
    content: String,
    /// Raw generated content (`rawContent` on the wire).
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    /// Transcript encoded as a JSON string (`transcriptJSON` on the wire).
    #[serde(rename = "transcriptJSON")]
    transcript_json: String,
}
909
/// Wire format of a schema-guided response payload returned by the bridge.
#[derive(Debug, Deserialize)]
struct BridgeStructuredResponse {
    /// Generated content.
    content: BridgeGeneratedContent,
    /// Raw generated content (`rawContent` on the wire).
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    /// Transcript encoded as a JSON string (`transcriptJSON` on the wire).
    #[serde(rename = "transcriptJSON")]
    transcript_json: String,
}
918
/// Wire format of one snapshot payload in a structured stream.
#[derive(Debug, Deserialize)]
struct BridgeStructuredSnapshot {
    /// Generated content of the snapshot.
    content: BridgeGeneratedContent,
    /// Raw generated content of the snapshot (`rawContent` on the wire).
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    /// Completion flag of the snapshot (`isComplete` on the wire).
    #[serde(rename = "isComplete")]
    is_complete: bool,
}
927
/// Wire format of one payload in a text stream.
#[derive(Debug, Deserialize)]
struct BridgeTextStreamSnapshot {
    /// Text carried by this payload; empty deltas are not forwarded to the user callback.
    delta: String,
}
932
933fn respond_request_json(
934 prompt: &Prompt,
935 options: GenerationOptions,
936 schema: Option<&GenerationSchema>,
937 include_schema_in_prompt: bool,
938) -> Result<CString, FMError> {
939 let sampling = match options.sampling() {
940 SamplingMode::Default => json!({ "mode": "default" }),
941 SamplingMode::Greedy => json!({ "mode": "greedy" }),
942 SamplingMode::TopK(k) => json!({
943 "mode": "top_k",
944 "topK": k,
945 "seed": options.sampling_seed(),
946 }),
947 SamplingMode::TopP(p) => json!({
948 "mode": "top_p",
949 "topP": p,
950 "seed": options.sampling_seed(),
951 }),
952 };
953 let payload = serde_json::to_string(&json!({
954 "prompt": prompt.to_bridge_value(),
955 "options": {
956 "temperature": options.temperature(),
957 "maximumResponseTokens": options.maximum_response_tokens(),
958 "sampling": sampling,
959 },
960 "schemaJSON": schema.map(GenerationSchema::json_schema),
961 "includeSchemaInPrompt": include_schema_in_prompt,
962 }))
963 .map_err(|error| {
964 FMError::InvalidArgument(format!("request is not JSON-serializable: {error}"))
965 })?;
966 CString::new(payload).map_err(|error| {
967 FMError::InvalidArgument(format!("request JSON contains a NUL byte: {error}"))
968 })
969}
970
971fn request_response(session: *mut c_void, payload: &CString) -> Result<String, FMError> {
972 let (tx, rx) = mpsc::channel();
973 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
974 let context = Box::into_raw(tx_box).cast::<c_void>();
975 unsafe {
976 ffi::fm_session_respond_request_json(session, payload.as_ptr(), context, respond_trampoline)
977 };
978 rx.recv().map_err(|_| FMError::Unknown {
979 code: ffi::status::UNKNOWN,
980 message: "Swift bridge dropped the JSON response channel".into(),
981 })?
982}
983
984pub(crate) fn decode_bridge_text_response(
985 payload: &str,
986) -> Result<SessionResponse<String>, FMError> {
987 let response: BridgeTextResponse = serde_json::from_str(payload)
988 .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
989 Ok(SessionResponse {
990 content: response.content,
991 raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
992 transcript: Transcript::from_json_str(&response.transcript_json)?,
993 })
994}
995
996pub(crate) fn request_text_response_with<F>(invoke: F) -> Result<SessionResponse<String>, FMError>
997where
998 F: FnOnce(*mut c_void, ffi::FmRespondCallback),
999{
1000 let (tx, rx) = mpsc::channel();
1001 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
1002 let context = Box::into_raw(tx_box).cast::<c_void>();
1003 invoke(context, respond_trampoline);
1004 let payload = rx.recv().map_err(|_| FMError::Unknown {
1005 code: ffi::status::UNKNOWN,
1006 message: "Swift bridge dropped the JSON response channel".into(),
1007 })??;
1008 decode_bridge_text_response(&payload)
1009}
1010
1011pub(crate) fn run_text_stream_with<F, C>(invoke: F, on_chunk: C) -> Result<(), FMError>
1012where
1013 F: FnOnce(*mut c_void, ffi::FmStreamCallback),
1014 C: FnMut(StreamEvent<'_>) + Send + 'static,
1015{
1016 let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
1017 let state = Arc::new(StreamState {
1018 on_chunk: Mutex::new(Box::new(on_chunk)),
1019 done_tx: Mutex::new(Some(done_tx)),
1020 });
1021 let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
1022 invoke(context, json_text_stream_trampoline);
1023 done_rx.recv().map_err(|_| FMError::Unknown {
1024 code: ffi::status::UNKNOWN,
1025 message: "Swift bridge dropped the stream channel".into(),
1026 })?
1027}
1028
1029fn prompt_to_plain_text(prompt: &Prompt) -> Option<String> {
1030 let mut text = String::new();
1031 for segment in prompt.segments() {
1032 match segment {
1033 crate::prompt::Segment::Text(segment) => text.push_str(&segment.text),
1034 crate::prompt::Segment::Structure(_) => return None,
1035 }
1036 }
1037 Some(text)
1038}
1039
/// `default()` is [`LanguageModelSession::new`], so it panics under the same condition.
impl Default for LanguageModelSession {
    fn default() -> Self {
        Self::new()
    }
}
1045
1046impl Drop for LanguageModelSession {
1047 fn drop(&mut self) {
1048 if !self.ptr.is_null() {
1049 unsafe { ffi::fm_object_release(self.ptr) };
1050 }
1051 }
1052}
1053
1054impl core::fmt::Debug for LanguageModelSession {
1055 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1056 f.debug_struct("LanguageModelSession")
1057 .field("ptr", &self.ptr)
1058 .finish()
1059 }
1060}
1061
/// Events passed to the callback of a text stream.
#[derive(Debug)]
#[non_exhaustive]
pub enum StreamEvent<'a> {
    /// A non-empty text delta; the borrow is only valid for the duration of the callback.
    Chunk(&'a str),
    /// The bridge reported that the stream is done.
    Done,
    /// The bridge reported a failing status, or a payload could not be decoded.
    Error(FMError),
}
1073
1074unsafe extern "C" fn respond_trampoline(
1081 context: *mut c_void,
1082 response: *mut c_char,
1083 error: *mut c_char,
1084 status: i32,
1085) {
1086 let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
1087 let result = if status == ffi::status::OK && !response.is_null() {
1088 let s = core::ffi::CStr::from_ptr(response)
1089 .to_string_lossy()
1090 .into_owned();
1091 ffi::fm_string_free(response);
1092 Ok(s)
1093 } else {
1094 Err(crate::error::from_swift(status, error))
1095 };
1096 let _ = tx.send(result);
1097}
1098
/// Boxed user callback for text streams; `Send` because it is stored in shared stream state.
type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;
1100
/// State shared between a blocking text-stream call and its bridge callback.
struct StreamState {
    /// User callback, locked for the duration of each invocation.
    on_chunk: Mutex<StreamCallback>,
    /// Sender for the final outcome; taken (set to `None`) when the outcome is sent, so it is
    /// sent at most once.
    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
}
1105
1106unsafe extern "C" fn json_text_stream_trampoline(
1112 context: *mut c_void,
1113 chunk: *mut c_char,
1114 done: bool,
1115 status: i32,
1116) {
1117 let state = Arc::from_raw(context.cast::<StreamState>());
1118 let state_for_swift = state.clone();
1119 core::mem::forget(state_for_swift);
1120
1121 let payload: Option<String> = if chunk.is_null() {
1122 None
1123 } else {
1124 let value = core::ffi::CStr::from_ptr(chunk)
1125 .to_string_lossy()
1126 .into_owned();
1127 ffi::fm_string_free(chunk);
1128 Some(value)
1129 };
1130
1131 if status != ffi::status::OK {
1132 let err = payload
1133 .map(|message| {
1134 crate::error::from_swift(
1135 status,
1136 ffi::fm_string_dup(
1137 CString::new(message)
1138 .expect("stream errors must not contain NUL bytes")
1139 .as_ptr(),
1140 ),
1141 )
1142 })
1143 .unwrap_or_else(|| crate::error::from_swift(status, ptr::null_mut()));
1144 {
1145 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1146 let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Error(err.clone()))));
1148 }
1149 if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1150 let _ = tx.send(Err(err));
1151 }
1152 drop(Arc::from_raw(Arc::as_ptr(&state)));
1153 drop(state);
1154 return;
1155 }
1156
1157 if let Some(payload) = payload {
1158 match serde_json::from_str::<BridgeTextStreamSnapshot>(&payload) {
1159 Ok(snapshot) if !snapshot.delta.is_empty() => {
1160 let chunk_panicked = {
1161 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1162 catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Chunk(&snapshot.delta))))
1164 .is_err()
1165 };
1166 if chunk_panicked {
1167 if let Some(tx) =
1168 state.done_tx.lock().expect("done_tx mutex poisoned").take()
1169 {
1170 let _ = tx.send(Err(FMError::Unknown {
1171 code: ffi::status::UNKNOWN,
1172 message: "stream callback panicked".into(),
1173 }));
1174 }
1175 drop(Arc::from_raw(Arc::as_ptr(&state)));
1176 drop(state);
1177 return;
1178 }
1179 }
1180 Ok(_) => {}
1181 Err(error) => {
1182 let err = FMError::DecodingFailure(error.to_string());
1183 {
1184 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1185 let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Error(err.clone()))));
1186 }
1187 if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1188 let _ = tx.send(Err(err));
1189 }
1190 drop(Arc::from_raw(Arc::as_ptr(&state)));
1191 drop(state);
1192 return;
1193 }
1194 }
1195 }
1196
1197 if done {
1198 {
1199 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1200 let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Done)));
1201 }
1202 if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1203 let _ = tx.send(Ok(()));
1204 }
1205 drop(Arc::from_raw(Arc::as_ptr(&state)));
1206 }
1207 drop(state);
1208}
1209
/// Boxed user callback for structured streams; `Send` because it is stored in shared state.
type StructuredStreamCallback = Box<dyn FnMut(StructuredStreamEvent) + Send>;
1211
/// State shared between a blocking structured-stream call and its bridge callback.
struct StructuredStreamState {
    /// User callback, locked for the duration of each invocation.
    on_event: Mutex<StructuredStreamCallback>,
    /// Sender for the final outcome; taken (set to `None`) when the outcome is sent, so it is
    /// sent at most once.
    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
}
1216
1217#[allow(clippy::too_many_lines)]
1220unsafe extern "C" fn structured_stream_trampoline(
1221 context: *mut c_void,
1222 chunk: *mut c_char,
1223 done: bool,
1224 status: i32,
1225) {
1226 let state = Arc::from_raw(context.cast::<StructuredStreamState>());
1227 let state_for_swift = state.clone();
1228 core::mem::forget(state_for_swift);
1229
1230 let payload: Option<String> = if chunk.is_null() {
1231 None
1232 } else {
1233 let value = core::ffi::CStr::from_ptr(chunk)
1234 .to_string_lossy()
1235 .into_owned();
1236 ffi::fm_string_free(chunk);
1237 Some(value)
1238 };
1239
1240 if status != ffi::status::OK {
1241 let err = payload
1242 .map(|message| {
1243 crate::error::from_swift(
1244 status,
1245 ffi::fm_string_dup(
1246 CString::new(message)
1247 .expect("stream errors must not contain NUL bytes")
1248 .as_ptr(),
1249 ),
1250 )
1251 })
1252 .unwrap_or_else(|| crate::error::from_swift(status, ptr::null_mut()));
1253 {
1254 let mut cb = state
1255 .on_event
1256 .lock()
1257 .expect("structured callback mutex poisoned");
1258 let _ = catch_unwind(AssertUnwindSafe(|| {
1260 cb(StructuredStreamEvent::Error(err.clone()));
1261 }));
1262 }
1263 if let Some(tx) = state
1264 .done_tx
1265 .lock()
1266 .expect("structured done_tx mutex poisoned")
1267 .take()
1268 {
1269 let _ = tx.send(Err(err));
1270 }
1271 drop(Arc::from_raw(Arc::as_ptr(&state)));
1272 drop(state);
1273 return;
1274 }
1275
1276 if let Some(payload) = payload {
1277 let snapshot: BridgeStructuredSnapshot = match serde_json::from_str(&payload) {
1278 Ok(snapshot) => snapshot,
1279 Err(error) => {
1280 let err = FMError::DecodingFailure(error.to_string());
1281 {
1282 let mut cb = state
1283 .on_event
1284 .lock()
1285 .expect("structured callback mutex poisoned");
1286 let _ = catch_unwind(AssertUnwindSafe(|| {
1287 cb(StructuredStreamEvent::Error(err.clone()));
1288 }));
1289 }
1290 if let Some(tx) = state
1291 .done_tx
1292 .lock()
1293 .expect("structured done_tx mutex poisoned")
1294 .take()
1295 {
1296 let _ = tx.send(Err(err));
1297 }
1298 drop(Arc::from_raw(Arc::as_ptr(&state)));
1299 drop(state);
1300 return;
1301 }
1302 };
1303 let snapshot_event = StructuredStreamEvent::Snapshot(StructuredStreamSnapshot {
1304 content_json: snapshot.content.json,
1305 raw_content_json: snapshot.raw_content.json,
1306 is_complete: snapshot.is_complete,
1307 });
1308 let snapshot_panicked = {
1309 let mut cb = state
1310 .on_event
1311 .lock()
1312 .expect("structured callback mutex poisoned");
1313 catch_unwind(AssertUnwindSafe(|| cb(snapshot_event))).is_err()
1315 };
1316 if snapshot_panicked {
1317 if let Some(tx) = state
1318 .done_tx
1319 .lock()
1320 .expect("structured done_tx mutex poisoned")
1321 .take()
1322 {
1323 let _ = tx.send(Err(FMError::Unknown {
1324 code: ffi::status::UNKNOWN,
1325 message: "stream callback panicked".into(),
1326 }));
1327 }
1328 drop(Arc::from_raw(Arc::as_ptr(&state)));
1329 drop(state);
1330 return;
1331 }
1332 }
1333
1334 if done {
1335 {
1336 let mut cb = state
1337 .on_event
1338 .lock()
1339 .expect("structured callback mutex poisoned");
1340 let _ = catch_unwind(AssertUnwindSafe(|| cb(StructuredStreamEvent::Done)));
1341 }
1342 if let Some(tx) = state
1343 .done_tx
1344 .lock()
1345 .expect("structured done_tx mutex poisoned")
1346 .take()
1347 {
1348 let _ = tx.send(Ok(()));
1349 }
1350 drop(Arc::from_raw(Arc::as_ptr(&state)));
1351 }
1352 drop(state);
1353}