1use core::ffi::{c_char, c_void};
4use core::ptr;
5use std::ffi::CString;
6use std::panic::{catch_unwind, AssertUnwindSafe};
7use std::sync::mpsc;
8use std::sync::{Arc, Mutex};
9
10use serde::Deserialize;
11use serde_json::json;
12
13use crate::content::{BridgeGeneratedContent, GeneratedContent};
14use crate::error::FMError;
15use crate::ffi;
16use crate::generation::{GenerationOptions, SamplingMode};
17use crate::model::ConfiguredSystemLanguageModel;
18use crate::prompt::{Instructions, Prompt, ToInstructions, ToPrompt};
19use crate::schema::GenerationSchema;
20use crate::tool::{tool_callback_trampoline, Tool, ToolRegistry};
21use crate::transcript::Transcript;
22
/// A conversation session backed by an opaque handle obtained from the Swift bridge.
///
/// The handle is released with `fm_object_release` when the session is dropped.
pub struct LanguageModelSession {
    /// Opaque bridge handle. The constructors in this module reject a null handle, so
    /// sessions they produce always carry a non-null pointer.
    ptr: *mut c_void,
    /// Keeps the tool registry alive for as long as the session: the builder hands the
    /// bridge a raw `Arc::as_ptr` of this registry as the tool-callback context.
    _tool_registry: Option<Arc<ToolRegistry>>,
}

// SAFETY: NOTE(review): moving a session to another thread is sound only if the bridge
// handle may be used and released from any thread; that is not visible here — confirm
// against the Swift side.
unsafe impl Send for LanguageModelSession {}
// SAFETY: NOTE(review): `&self` methods call into the bridge with the same handle, possibly
// concurrently; this assumes the Swift session tolerates concurrent calls — confirm.
unsafe impl Sync for LanguageModelSession {}
52
53impl LanguageModelSession {
54 pub(crate) fn as_ptr(&self) -> *mut c_void {
59 self.ptr
60 }
61
62 #[must_use]
70 pub fn new() -> Self {
71 Self::try_new(None).expect("FoundationModels is not available on this OS")
72 }
73
74 #[must_use]
81 pub fn with_instructions(instructions: &str) -> Self {
82 Self::try_new(Some(instructions)).expect("FoundationModels is not available on this OS")
83 }
84
85 #[must_use]
89 pub fn try_new(instructions: Option<&str>) -> Option<Self> {
90 let cstring = match instructions {
91 Some(s) => Some(CString::new(s).ok()?),
92 None => None,
93 };
94 let ptr =
95 unsafe { ffi::fm_session_create(cstring.as_ref().map_or(ptr::null(), |s| s.as_ptr())) };
96 if ptr.is_null() {
97 return None;
98 }
99 Some(Self {
100 ptr,
101 _tool_registry: None,
102 })
103 }
104
105 pub fn respond(&self, prompt: &str) -> Result<String, FMError> {
113 self.respond_with(prompt, GenerationOptions::new())
114 }
115
116 pub fn prewarm(&self) {
120 unsafe { ffi::fm_session_prewarm(self.ptr) };
121 }
122
123 #[must_use]
126 pub fn is_responding(&self) -> bool {
127 unsafe { ffi::fm_session_is_responding(self.ptr) }
128 }
129
130 #[must_use]
135 pub fn transcript_json(&self) -> String {
136 let p = unsafe { ffi::fm_session_transcript_json(self.ptr) };
137 if p.is_null() {
138 return String::from("{}");
139 }
140 let s = unsafe { core::ffi::CStr::from_ptr(p) }
141 .to_string_lossy()
142 .into_owned();
143 unsafe { ffi::fm_string_free(p) };
144 s
145 }
146
147 pub fn log_feedback(&self, sentiment: i32, description: Option<&str>) {
151 let cstr = description.and_then(|s| CString::new(s).ok());
152 let p = cstr.as_ref().map_or(core::ptr::null(), |c| c.as_ptr());
153 unsafe { ffi::fm_session_log_feedback(self.ptr, sentiment, p) };
154 }
155
156 pub fn respond_with_json_schema(
170 &self,
171 prompt: &str,
172 schema_description: &str,
173 ) -> Result<String, FMError> {
174 let wrapped = format!(
175 "{prompt}\n\n\
176 IMPORTANT: respond with VALID JSON ONLY (no prose, no markdown \
177 fences) that matches this schema:\n\n{schema_description}\n\n\
178 Your entire response must be parseable by JSON.parse()."
179 );
180 self.respond(&wrapped)
181 }
182
183 pub fn respond_with(
189 &self,
190 prompt: &str,
191 options: GenerationOptions,
192 ) -> Result<String, FMError> {
193 self.respond_prompt_with(prompt, options)
194 }
195
196 pub fn respond_with_schema(
229 &self,
230 prompt: &str,
231 schema: &str,
232 include_schema_in_prompt: bool,
233 ) -> Result<String, FMError> {
234 self.respond_with_schema_options(
235 prompt,
236 schema,
237 include_schema_in_prompt,
238 GenerationOptions::new(),
239 )
240 }
241
242 pub fn respond_with_schema_options(
249 &self,
250 prompt: &str,
251 schema: &str,
252 include_schema_in_prompt: bool,
253 options: GenerationOptions,
254 ) -> Result<String, FMError> {
255 let prompt_c = CString::new(prompt)
256 .map_err(|e| FMError::InvalidArgument(format!("prompt NUL byte: {e}")))?;
257 let schema_c = CString::new(schema)
258 .map_err(|e| FMError::InvalidArgument(format!("schema NUL byte: {e}")))?;
259 let opts = options.to_ffi();
260 let (tx, rx) = mpsc::channel();
261 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
262 let context = Box::into_raw(tx_box).cast::<c_void>();
263
264 unsafe {
265 ffi::fm_session_respond_with_schema(
266 self.ptr,
267 prompt_c.as_ptr(),
268 schema_c.as_ptr(),
269 include_schema_in_prompt,
270 opts.temperature,
271 opts.maximum_response_tokens,
272 opts.sampling_mode,
273 opts.top_k,
274 opts.top_p,
275 context,
276 respond_trampoline,
277 );
278 }
279
280 rx.recv().map_err(|_| FMError::Unknown {
281 code: ffi::status::UNKNOWN,
282 message: "Swift bridge dropped the callback channel".into(),
283 })?
284 }
285
286 pub fn stream<F>(&self, prompt: &str, mut on_chunk: F) -> Result<(), FMError>
295 where
296 F: FnMut(StreamEvent<'_>) + Send + 'static,
297 {
298 self.stream_with(prompt, GenerationOptions::new(), move |event| {
299 on_chunk(event);
300 })
301 }
302
303 pub fn stream_with<F>(
309 &self,
310 prompt: &str,
311 options: GenerationOptions,
312 on_chunk: F,
313 ) -> Result<(), FMError>
314 where
315 F: FnMut(StreamEvent<'_>) + Send + 'static,
316 {
317 let payload = respond_request_json(&Prompt::from(prompt), options, None, true)?;
318
319 let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
320 let state = Arc::new(StreamState {
321 on_chunk: Mutex::new(Box::new(on_chunk)),
322 done_tx: Mutex::new(Some(done_tx)),
323 });
324 let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
325
326 unsafe {
327 ffi::fm_session_stream_request_json(
328 self.ptr,
329 payload.as_ptr(),
330 context,
331 json_text_stream_trampoline,
332 )
333 };
334
335 done_rx.recv().map_err(|_| FMError::Unknown {
336 code: ffi::status::UNKNOWN,
337 message: "Swift bridge dropped the stream channel".into(),
338 })?
339 }
340}
341
342impl LanguageModelSession {
343 #[must_use]
345 pub fn builder<'a>() -> SessionBuilder<'a> {
346 SessionBuilder::new()
347 }
348
349 pub fn from_transcript(transcript: Transcript) -> Result<Self, FMError> {
355 Self::builder().transcript(transcript).build()
356 }
357
358 pub fn transcript(&self) -> Result<Transcript, FMError> {
365 Transcript::from_json_str(&self.transcript_json())
366 }
367
368 pub fn prewarm_with_prompt<P>(&self, prompt: P) -> Result<(), FMError>
374 where
375 P: ToPrompt,
376 {
377 let prompt = prompt.to_prompt()?;
378 let prompt_json = CString::new(prompt.to_bridge_json()?).map_err(|error| {
379 FMError::InvalidArgument(format!("prompt JSON contains a NUL byte: {error}"))
380 })?;
381 let mut error: *mut c_char = ptr::null_mut();
382 let status = unsafe {
383 ffi::fm_session_prewarm_prompt_json(self.ptr, prompt_json.as_ptr(), &mut error)
384 };
385 if status != ffi::status::OK {
386 return Err(crate::error::from_swift(status, error));
387 }
388 Ok(())
389 }
390
391 pub fn respond_prompt<P>(&self, prompt: P) -> Result<String, FMError>
397 where
398 P: ToPrompt,
399 {
400 self.respond_prompt_with(prompt, GenerationOptions::new())
401 }
402
403 pub fn respond_prompt_with<P>(
409 &self,
410 prompt: P,
411 options: GenerationOptions,
412 ) -> Result<String, FMError>
413 where
414 P: ToPrompt,
415 {
416 self.respond_prompt_detailed(prompt, options)
417 .map(|response| response.content)
418 }
419
420 pub fn respond_prompt_detailed<P>(
426 &self,
427 prompt: P,
428 options: GenerationOptions,
429 ) -> Result<SessionResponse<String>, FMError>
430 where
431 P: ToPrompt,
432 {
433 let prompt = prompt.to_prompt()?;
434 let payload = respond_request_json(&prompt, options, None, true)?;
435 let payload = request_response(self.ptr, &payload)?;
436 let response: BridgeTextResponse = serde_json::from_str(&payload)
437 .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
438 Ok(SessionResponse {
439 content: response.content,
440 raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
441 transcript: Transcript::from_json_str(&response.transcript_json)?,
442 })
443 }
444
445 pub fn respond_generated<P>(
451 &self,
452 prompt: P,
453 schema: &GenerationSchema,
454 include_schema_in_prompt: bool,
455 ) -> Result<GeneratedContent, FMError>
456 where
457 P: ToPrompt,
458 {
459 self.respond_generated_with(
460 prompt,
461 schema,
462 include_schema_in_prompt,
463 GenerationOptions::new(),
464 )
465 .map(|response| response.content)
466 }
467
468 pub fn respond_generated_with<P>(
474 &self,
475 prompt: P,
476 schema: &GenerationSchema,
477 include_schema_in_prompt: bool,
478 options: GenerationOptions,
479 ) -> Result<SessionResponse<GeneratedContent>, FMError>
480 where
481 P: ToPrompt,
482 {
483 let prompt = prompt.to_prompt()?;
484 let payload =
485 respond_request_json(&prompt, options, Some(schema), include_schema_in_prompt)?;
486 let payload = request_response(self.ptr, &payload)?;
487 let response: BridgeStructuredResponse = serde_json::from_str(&payload)
488 .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
489 Ok(SessionResponse {
490 content: GeneratedContent::from_bridge_payload(response.content, true)?,
491 raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
492 transcript: Transcript::from_json_str(&response.transcript_json)?,
493 })
494 }
495
496 pub fn respond_generating<P, T>(
503 &self,
504 prompt: P,
505 include_schema_in_prompt: bool,
506 options: GenerationOptions,
507 ) -> Result<SessionResponse<T>, FMError>
508 where
509 P: ToPrompt,
510 T: crate::schema::Generable,
511 {
512 let response = self.respond_generated_with(
513 prompt,
514 &T::generation_schema()?,
515 include_schema_in_prompt,
516 options,
517 )?;
518 Ok(SessionResponse {
519 content: T::from_generated_content(&response.content)?,
520 raw_content: response.raw_content,
521 transcript: response.transcript,
522 })
523 }
524
525 pub fn stream_prompt<P, F>(&self, prompt: P, on_chunk: F) -> Result<(), FMError>
531 where
532 P: ToPrompt,
533 F: FnMut(StreamEvent<'_>) + Send + 'static,
534 {
535 let prompt = prompt.to_prompt()?;
536 let prompt_text = prompt_to_plain_text(&prompt).ok_or_else(|| {
537 FMError::InvalidArgument(
538 "text streaming only supports prompts composed of text segments".into(),
539 )
540 })?;
541 self.stream_with(&prompt_text, GenerationOptions::new(), on_chunk)
542 }
543
544 pub fn stream_generated<P, F>(
550 &self,
551 prompt: P,
552 schema: &GenerationSchema,
553 include_schema_in_prompt: bool,
554 options: GenerationOptions,
555 on_event: F,
556 ) -> Result<(), FMError>
557 where
558 P: ToPrompt,
559 F: FnMut(StructuredStreamEvent) + Send + 'static,
560 {
561 let prompt = prompt.to_prompt()?;
562 let payload =
563 respond_request_json(&prompt, options, Some(schema), include_schema_in_prompt)?;
564 let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
565 let state = Arc::new(StructuredStreamState {
566 on_event: Mutex::new(Box::new(on_event)),
567 done_tx: Mutex::new(Some(done_tx)),
568 });
569 let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
570 unsafe {
571 ffi::fm_session_stream_request_json(
572 self.ptr,
573 payload.as_ptr(),
574 context,
575 structured_stream_trampoline,
576 )
577 };
578 done_rx.recv().map_err(|_| FMError::Unknown {
579 code: ffi::status::UNKNOWN,
580 message: "Swift bridge dropped the structured stream channel".into(),
581 })?
582 }
583
584 pub fn log_feedback_attachment(
590 &self,
591 request: FeedbackAttachmentRequest,
592 ) -> Result<Vec<u8>, FMError> {
593 let request_json = CString::new(request.to_bridge_json()?).map_err(|error| {
594 FMError::InvalidArgument(format!("feedback request contains a NUL byte: {error}"))
595 })?;
596 let mut length = 0usize;
597 let mut error: *mut c_char = ptr::null_mut();
598 let ptr = unsafe {
599 ffi::fm_session_log_feedback_attachment_json(
600 self.ptr,
601 request_json.as_ptr(),
602 &mut length,
603 &mut error,
604 )
605 };
606 if ptr.is_null() && !error.is_null() {
607 return Err(crate::error::from_swift(
608 ffi::status::INVALID_ARGUMENT,
609 error,
610 ));
611 }
612 if ptr.is_null() || length == 0 {
613 return Ok(Vec::new());
614 }
615 let bytes = unsafe { std::slice::from_raw_parts(ptr.cast::<u8>(), length) }.to_vec();
616 unsafe { ffi::fm_bytes_free(ptr) };
617 Ok(bytes)
618 }
619}
620
621pub struct SessionBuilder<'a> {
623 model: Option<&'a ConfiguredSystemLanguageModel>,
624 instructions: Option<Instructions>,
625 transcript: Option<Transcript>,
626 tools: Vec<Tool>,
627}
628
629impl<'a> SessionBuilder<'a> {
630 const fn new() -> Self {
631 Self {
632 model: None,
633 instructions: None,
634 transcript: None,
635 tools: Vec::new(),
636 }
637 }
638
639 #[must_use]
641 pub const fn model(mut self, model: &'a ConfiguredSystemLanguageModel) -> Self {
642 self.model = Some(model);
643 self
644 }
645
646 pub fn instructions<I>(mut self, instructions: I) -> Result<Self, FMError>
648 where
649 I: ToInstructions,
650 {
651 self.instructions = Some(instructions.to_instructions()?);
652 Ok(self)
653 }
654
655 #[must_use]
657 pub fn transcript(mut self, transcript: Transcript) -> Self {
658 self.transcript = Some(transcript);
659 self
660 }
661
662 #[must_use]
664 pub fn tool(mut self, tool: Tool) -> Self {
665 self.tools.push(tool);
666 self
667 }
668
669 #[must_use]
671 pub fn tools(mut self, tools: impl IntoIterator<Item = Tool>) -> Self {
672 self.tools.extend(tools);
673 self
674 }
675
676 pub fn build(self) -> Result<LanguageModelSession, FMError> {
682 if self.instructions.is_some() && self.transcript.is_some() {
683 return Err(FMError::InvalidArgument(
684 "session builder accepts either instructions or a transcript, not both".into(),
685 ));
686 }
687
688 let instructions_json = self
689 .instructions
690 .as_ref()
691 .map(Instructions::to_bridge_json)
692 .transpose()?;
693 let transcript_json = self
694 .transcript
695 .as_ref()
696 .map(Transcript::to_json_string)
697 .transpose()?;
698 let tool_registry = if self.tools.is_empty() {
699 None
700 } else {
701 Some(Arc::new(ToolRegistry::new(self.tools)))
702 };
703 let tools_json = tool_registry
704 .as_ref()
705 .map(|registry| registry.specs_json())
706 .transpose()?;
707
708 let instructions_c = instructions_json
709 .as_deref()
710 .map(CString::new)
711 .transpose()
712 .map_err(|error| {
713 FMError::InvalidArgument(format!("instructions JSON contains a NUL byte: {error}"))
714 })?;
715 let transcript_c = transcript_json
716 .as_deref()
717 .map(CString::new)
718 .transpose()
719 .map_err(|error| {
720 FMError::InvalidArgument(format!("transcript JSON contains a NUL byte: {error}"))
721 })?;
722 let tools_c = tools_json
723 .as_deref()
724 .map(CString::new)
725 .transpose()
726 .map_err(|error| {
727 FMError::InvalidArgument(format!("tool JSON contains a NUL byte: {error}"))
728 })?;
729
730 let tool_context = tool_registry.as_ref().map_or(ptr::null_mut(), |registry| {
731 Arc::as_ptr(registry).cast_mut().cast::<c_void>()
732 });
733 let mut error: *mut c_char = ptr::null_mut();
734 let ptr = unsafe {
735 ffi::fm_session_create_ex(
736 self.model.map_or(ptr::null_mut(), |model| model.ptr),
737 instructions_c
738 .as_ref()
739 .map_or(ptr::null(), |json| json.as_ptr()),
740 transcript_c
741 .as_ref()
742 .map_or(ptr::null(), |json| json.as_ptr()),
743 tools_c.as_ref().map_or(ptr::null(), |json| json.as_ptr()),
744 tool_context,
745 tool_registry
746 .as_ref()
747 .map(|_| tool_callback_trampoline as ffi::FmToolCallback),
748 &mut error,
749 )
750 };
751 if ptr.is_null() {
752 return Err(crate::error::from_swift(
753 ffi::status::MODEL_UNAVAILABLE,
754 error,
755 ));
756 }
757 Ok(LanguageModelSession {
758 ptr,
759 _tool_registry: tool_registry,
760 })
761 }
762}
763
/// Complete result of a non-streaming request: the decoded content plus the raw content
/// and transcript the bridge returned with it.
#[derive(Debug, Clone, PartialEq)]
pub struct SessionResponse<T> {
    /// Decoded response content.
    pub content: T,
    /// The bridge's `rawContent` for this response.
    pub raw_content: GeneratedContent,
    /// Session transcript decoded from the bridge's `transcriptJSON`.
    pub transcript: Transcript,
}

/// One snapshot of a schema-guided stream, as JSON text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StructuredStreamSnapshot {
    /// JSON text of the content so far.
    pub content_json: String,
    /// JSON text of the raw content so far.
    pub raw_content_json: String,
    /// The bridge's `isComplete` flag for this snapshot.
    pub is_complete: bool,
}

/// Event delivered to a structured-stream callback.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum StructuredStreamEvent {
    /// A new snapshot of the generated content.
    Snapshot(StructuredStreamSnapshot),
    /// The stream finished successfully.
    Done,
    /// The stream failed; the same error is returned from the streaming call.
    Error(FMError),
}
788
/// Category of a problem reported in feedback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackIssueCategory {
    Unhelpful,
    TooVerbose,
    DidNotFollowInstructions,
    Incorrect,
    StereotypeOrBias,
    SuggestiveOrSexual,
    VulgarOrOffensive,
    TriggeredGuardrailUnexpectedly,
}

impl FeedbackIssueCategory {
    /// Snake-case identifier used for this category in the bridge's feedback JSON.
    const fn as_str(self) -> &'static str {
        match self {
            FeedbackIssueCategory::Unhelpful => "unhelpful",
            FeedbackIssueCategory::TooVerbose => "too_verbose",
            FeedbackIssueCategory::DidNotFollowInstructions => "did_not_follow_instructions",
            FeedbackIssueCategory::Incorrect => "incorrect",
            FeedbackIssueCategory::StereotypeOrBias => "stereotype_or_bias",
            FeedbackIssueCategory::SuggestiveOrSexual => "suggestive_or_sexual",
            FeedbackIssueCategory::VulgarOrOffensive => "vulgar_or_offensive",
            FeedbackIssueCategory::TriggeredGuardrailUnexpectedly => {
                "triggered_guardrail_unexpectedly"
            }
        }
    }
}

/// A single reported issue: its category plus an optional free-text explanation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FeedbackIssue {
    pub category: FeedbackIssueCategory,
    pub explanation: Option<String>,
}

/// Overall sentiment attached to feedback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeedbackSentiment {
    Positive,
    Negative,
    Neutral,
}

impl FeedbackSentiment {
    /// Lower-case identifier used for this sentiment in the bridge's feedback JSON.
    const fn as_str(self) -> &'static str {
        match self {
            FeedbackSentiment::Positive => "positive",
            FeedbackSentiment::Negative => "negative",
            FeedbackSentiment::Neutral => "neutral",
        }
    }
}
841
842#[derive(Debug, Clone, PartialEq)]
844pub struct FeedbackAttachmentRequest {
845 pub sentiment: Option<FeedbackSentiment>,
846 pub issues: Vec<FeedbackIssue>,
847 pub desired_response_text: Option<String>,
848 pub desired_response_content: Option<GeneratedContent>,
849 pub desired_output: Option<crate::transcript::Entry>,
850}
851
852impl FeedbackAttachmentRequest {
853 #[must_use]
855 pub const fn new() -> Self {
856 Self {
857 sentiment: None,
858 issues: Vec::new(),
859 desired_response_text: None,
860 desired_response_content: None,
861 desired_output: None,
862 }
863 }
864
865 fn to_bridge_json(&self) -> Result<String, FMError> {
866 let issues = self
867 .issues
868 .iter()
869 .map(|issue| {
870 json!({
871 "category": issue.category.as_str(),
872 "explanation": issue.explanation,
873 })
874 })
875 .collect::<Vec<_>>();
876 let desired_output_json = self
877 .desired_output
878 .as_ref()
879 .map(|entry| Transcript::from(vec![entry.clone()]).to_json_string())
880 .transpose()?;
881 let desired_response_content = self
882 .desired_response_content
883 .as_ref()
884 .map(GeneratedContent::to_bridge_value)
885 .transpose()?;
886 serde_json::to_string(&json!({
887 "sentiment": self.sentiment.map(FeedbackSentiment::as_str),
888 "issues": issues,
889 "desiredResponseText": self.desired_response_text,
890 "desiredResponseContent": desired_response_content,
891 "desiredOutputTranscriptJSON": desired_output_json,
892 }))
893 .map_err(|error| {
894 FMError::InvalidArgument(format!(
895 "feedback request is not JSON-serializable: {error}"
896 ))
897 })
898 }
899}
900
/// Bridge JSON for a completed text response.
#[derive(Debug, Deserialize)]
struct BridgeTextResponse {
    /// Response text.
    content: String,
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    /// Transcript as a JSON string embedded in the payload; decoded separately.
    #[serde(rename = "transcriptJSON")]
    transcript_json: String,
}

/// Bridge JSON for a completed schema-guided response.
#[derive(Debug, Deserialize)]
struct BridgeStructuredResponse {
    content: BridgeGeneratedContent,
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    /// Transcript as a JSON string embedded in the payload; decoded separately.
    #[serde(rename = "transcriptJSON")]
    transcript_json: String,
}

/// Bridge JSON for one structured-stream snapshot.
#[derive(Debug, Deserialize)]
struct BridgeStructuredSnapshot {
    content: BridgeGeneratedContent,
    #[serde(rename = "rawContent")]
    raw_content: BridgeGeneratedContent,
    #[serde(rename = "isComplete")]
    is_complete: bool,
}

/// Bridge JSON for one text-stream chunk; `delta` may be empty, and empty deltas are skipped.
#[derive(Debug, Deserialize)]
struct BridgeTextStreamSnapshot {
    delta: String,
}
932
933fn respond_request_json(
934 prompt: &Prompt,
935 options: GenerationOptions,
936 schema: Option<&GenerationSchema>,
937 include_schema_in_prompt: bool,
938) -> Result<CString, FMError> {
939 let sampling = match options.sampling() {
940 SamplingMode::Default => json!({ "mode": "default" }),
941 SamplingMode::Greedy => json!({ "mode": "greedy" }),
942 SamplingMode::TopK(k) => json!({
943 "mode": "top_k",
944 "topK": k,
945 "seed": options.sampling_seed(),
946 }),
947 SamplingMode::TopP(p) => json!({
948 "mode": "top_p",
949 "topP": p,
950 "seed": options.sampling_seed(),
951 }),
952 };
953 let include_schema_in_prompt = schema.map_or(include_schema_in_prompt, |schema| {
954 schema.effective_include_schema_in_prompt(include_schema_in_prompt)
955 });
956 let payload = serde_json::to_string(&json!({
957 "prompt": prompt.to_bridge_value(),
958 "options": {
959 "temperature": options.temperature(),
960 "maximumResponseTokens": options.maximum_response_tokens(),
961 "sampling": sampling,
962 },
963 "schemaJSON": schema.map(GenerationSchema::bridge_request_json),
964 "includeSchemaInPrompt": include_schema_in_prompt,
965 }))
966 .map_err(|error| {
967 FMError::InvalidArgument(format!("request is not JSON-serializable: {error}"))
968 })?;
969 CString::new(payload).map_err(|error| {
970 FMError::InvalidArgument(format!("request JSON contains a NUL byte: {error}"))
971 })
972}
973
974fn request_response(session: *mut c_void, payload: &CString) -> Result<String, FMError> {
975 let (tx, rx) = mpsc::channel();
976 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
977 let context = Box::into_raw(tx_box).cast::<c_void>();
978 unsafe {
979 ffi::fm_session_respond_request_json(session, payload.as_ptr(), context, respond_trampoline)
980 };
981 rx.recv().map_err(|_| FMError::Unknown {
982 code: ffi::status::UNKNOWN,
983 message: "Swift bridge dropped the JSON response channel".into(),
984 })?
985}
986
987pub(crate) fn decode_bridge_text_response(
988 payload: &str,
989) -> Result<SessionResponse<String>, FMError> {
990 let response: BridgeTextResponse = serde_json::from_str(payload)
991 .map_err(|error| FMError::DecodingFailure(error.to_string()))?;
992 Ok(SessionResponse {
993 content: response.content,
994 raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
995 transcript: Transcript::from_json_str(&response.transcript_json)?,
996 })
997}
998
999pub(crate) fn request_text_response_with<F>(invoke: F) -> Result<SessionResponse<String>, FMError>
1000where
1001 F: FnOnce(*mut c_void, ffi::FmRespondCallback),
1002{
1003 let (tx, rx) = mpsc::channel();
1004 let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
1005 let context = Box::into_raw(tx_box).cast::<c_void>();
1006 invoke(context, respond_trampoline);
1007 let payload = rx.recv().map_err(|_| FMError::Unknown {
1008 code: ffi::status::UNKNOWN,
1009 message: "Swift bridge dropped the JSON response channel".into(),
1010 })??;
1011 decode_bridge_text_response(&payload)
1012}
1013
1014pub(crate) fn run_text_stream_with<F, C>(invoke: F, on_chunk: C) -> Result<(), FMError>
1015where
1016 F: FnOnce(*mut c_void, ffi::FmStreamCallback),
1017 C: FnMut(StreamEvent<'_>) + Send + 'static,
1018{
1019 let (done_tx, done_rx) = mpsc::channel::<Result<(), FMError>>();
1020 let state = Arc::new(StreamState {
1021 on_chunk: Mutex::new(Box::new(on_chunk)),
1022 done_tx: Mutex::new(Some(done_tx)),
1023 });
1024 let context = Arc::into_raw(state).cast::<c_void>().cast_mut();
1025 invoke(context, json_text_stream_trampoline);
1026 done_rx.recv().map_err(|_| FMError::Unknown {
1027 code: ffi::status::UNKNOWN,
1028 message: "Swift bridge dropped the stream channel".into(),
1029 })?
1030}
1031
1032fn prompt_to_plain_text(prompt: &Prompt) -> Option<String> {
1033 let mut text = String::new();
1034 for segment in prompt.segments() {
1035 match segment {
1036 crate::prompt::Segment::Text(segment) => text.push_str(&segment.text),
1037 crate::prompt::Segment::Structure(_) => return None,
1038 }
1039 }
1040 Some(text)
1041}
1042
1043impl Default for LanguageModelSession {
1044 fn default() -> Self {
1045 Self::new()
1046 }
1047}
1048
1049impl Drop for LanguageModelSession {
1050 fn drop(&mut self) {
1051 if !self.ptr.is_null() {
1052 unsafe { ffi::fm_object_release(self.ptr) };
1053 }
1054 }
1055}
1056
1057impl core::fmt::Debug for LanguageModelSession {
1058 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1059 f.debug_struct("LanguageModelSession")
1060 .field("ptr", &self.ptr)
1061 .finish()
1062 }
1063}
1064
/// Event delivered to a text-stream callback.
#[derive(Debug)]
#[non_exhaustive]
pub enum StreamEvent<'a> {
    /// A non-empty text delta, borrowed only for the duration of the callback.
    Chunk(&'a str),
    /// The stream finished successfully.
    Done,
    /// The stream failed; the same error is returned from the streaming call.
    Error(FMError),
}
1076
1077unsafe extern "C" fn respond_trampoline(
1084 context: *mut c_void,
1085 response: *mut c_char,
1086 error: *mut c_char,
1087 status: i32,
1088) {
1089 let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
1090 let result = if status == ffi::status::OK && !response.is_null() {
1091 let s = core::ffi::CStr::from_ptr(response)
1092 .to_string_lossy()
1093 .into_owned();
1094 ffi::fm_string_free(response);
1095 Ok(s)
1096 } else {
1097 Err(crate::error::from_swift(status, error))
1098 };
1099 let _ = tx.send(result);
1100}
1101
/// Boxed user callback for text streams.
type StreamCallback = Box<dyn FnMut(StreamEvent<'_>) + Send>;

/// State shared between a text-streaming call and its bridge callback.
struct StreamState {
    /// User callback; the mutex serializes invocations coming from bridge callbacks.
    on_chunk: Mutex<StreamCallback>,
    /// One-shot channel to the waiting caller, `take()`n when the result is sent.
    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
}
1108
1109unsafe extern "C" fn json_text_stream_trampoline(
1115 context: *mut c_void,
1116 chunk: *mut c_char,
1117 done: bool,
1118 status: i32,
1119) {
1120 let state = Arc::from_raw(context.cast::<StreamState>());
1121 let state_for_swift = state.clone();
1122 core::mem::forget(state_for_swift);
1123
1124 let payload: Option<String> = if chunk.is_null() {
1125 None
1126 } else {
1127 let value = core::ffi::CStr::from_ptr(chunk)
1128 .to_string_lossy()
1129 .into_owned();
1130 ffi::fm_string_free(chunk);
1131 Some(value)
1132 };
1133
1134 if status != ffi::status::OK {
1135 let err = payload
1136 .map(|message| {
1137 crate::error::from_swift(
1138 status,
1139 ffi::fm_string_dup(
1140 CString::new(message)
1141 .expect("stream errors must not contain NUL bytes")
1142 .as_ptr(),
1143 ),
1144 )
1145 })
1146 .unwrap_or_else(|| crate::error::from_swift(status, ptr::null_mut()));
1147 {
1148 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1149 let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Error(err.clone()))));
1151 }
1152 if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1153 let _ = tx.send(Err(err));
1154 }
1155 drop(Arc::from_raw(Arc::as_ptr(&state)));
1156 drop(state);
1157 return;
1158 }
1159
1160 if let Some(payload) = payload {
1161 match serde_json::from_str::<BridgeTextStreamSnapshot>(&payload) {
1162 Ok(snapshot) if !snapshot.delta.is_empty() => {
1163 let chunk_panicked = {
1164 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1165 catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Chunk(&snapshot.delta))))
1167 .is_err()
1168 };
1169 if chunk_panicked {
1170 if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1171 let _ = tx.send(Err(FMError::Unknown {
1172 code: ffi::status::UNKNOWN,
1173 message: "stream callback panicked".into(),
1174 }));
1175 }
1176 drop(Arc::from_raw(Arc::as_ptr(&state)));
1177 drop(state);
1178 return;
1179 }
1180 }
1181 Ok(_) => {}
1182 Err(error) => {
1183 let err = FMError::DecodingFailure(error.to_string());
1184 {
1185 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1186 let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Error(err.clone()))));
1187 }
1188 if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1189 let _ = tx.send(Err(err));
1190 }
1191 drop(Arc::from_raw(Arc::as_ptr(&state)));
1192 drop(state);
1193 return;
1194 }
1195 }
1196 }
1197
1198 if done {
1199 {
1200 let mut cb = state.on_chunk.lock().expect("user callback mutex poisoned");
1201 let _ = catch_unwind(AssertUnwindSafe(|| cb(StreamEvent::Done)));
1202 }
1203 if let Some(tx) = state.done_tx.lock().expect("done_tx mutex poisoned").take() {
1204 let _ = tx.send(Ok(()));
1205 }
1206 drop(Arc::from_raw(Arc::as_ptr(&state)));
1207 }
1208 drop(state);
1209}
1210
/// Boxed user callback for structured (schema-guided) streams.
type StructuredStreamCallback = Box<dyn FnMut(StructuredStreamEvent) + Send>;

/// State shared between a structured-streaming call and its bridge callback.
struct StructuredStreamState {
    /// User callback; the mutex serializes invocations coming from bridge callbacks.
    on_event: Mutex<StructuredStreamCallback>,
    /// One-shot channel to the waiting caller, `take()`n when the result is sent.
    done_tx: Mutex<Option<mpsc::Sender<Result<(), FMError>>>>,
}
1217
1218#[allow(clippy::too_many_lines)]
1221unsafe extern "C" fn structured_stream_trampoline(
1222 context: *mut c_void,
1223 chunk: *mut c_char,
1224 done: bool,
1225 status: i32,
1226) {
1227 let state = Arc::from_raw(context.cast::<StructuredStreamState>());
1228 let state_for_swift = state.clone();
1229 core::mem::forget(state_for_swift);
1230
1231 let payload: Option<String> = if chunk.is_null() {
1232 None
1233 } else {
1234 let value = core::ffi::CStr::from_ptr(chunk)
1235 .to_string_lossy()
1236 .into_owned();
1237 ffi::fm_string_free(chunk);
1238 Some(value)
1239 };
1240
1241 if status != ffi::status::OK {
1242 let err = payload
1243 .map(|message| {
1244 crate::error::from_swift(
1245 status,
1246 ffi::fm_string_dup(
1247 CString::new(message)
1248 .expect("stream errors must not contain NUL bytes")
1249 .as_ptr(),
1250 ),
1251 )
1252 })
1253 .unwrap_or_else(|| crate::error::from_swift(status, ptr::null_mut()));
1254 {
1255 let mut cb = state
1256 .on_event
1257 .lock()
1258 .expect("structured callback mutex poisoned");
1259 let _ = catch_unwind(AssertUnwindSafe(|| {
1261 cb(StructuredStreamEvent::Error(err.clone()));
1262 }));
1263 }
1264 if let Some(tx) = state
1265 .done_tx
1266 .lock()
1267 .expect("structured done_tx mutex poisoned")
1268 .take()
1269 {
1270 let _ = tx.send(Err(err));
1271 }
1272 drop(Arc::from_raw(Arc::as_ptr(&state)));
1273 drop(state);
1274 return;
1275 }
1276
1277 if let Some(payload) = payload {
1278 let snapshot: BridgeStructuredSnapshot = match serde_json::from_str(&payload) {
1279 Ok(snapshot) => snapshot,
1280 Err(error) => {
1281 let err = FMError::DecodingFailure(error.to_string());
1282 {
1283 let mut cb = state
1284 .on_event
1285 .lock()
1286 .expect("structured callback mutex poisoned");
1287 let _ = catch_unwind(AssertUnwindSafe(|| {
1288 cb(StructuredStreamEvent::Error(err.clone()));
1289 }));
1290 }
1291 if let Some(tx) = state
1292 .done_tx
1293 .lock()
1294 .expect("structured done_tx mutex poisoned")
1295 .take()
1296 {
1297 let _ = tx.send(Err(err));
1298 }
1299 drop(Arc::from_raw(Arc::as_ptr(&state)));
1300 drop(state);
1301 return;
1302 }
1303 };
1304 let snapshot_event = StructuredStreamEvent::Snapshot(StructuredStreamSnapshot {
1305 content_json: snapshot.content.json,
1306 raw_content_json: snapshot.raw_content.json,
1307 is_complete: snapshot.is_complete,
1308 });
1309 let snapshot_panicked = {
1310 let mut cb = state
1311 .on_event
1312 .lock()
1313 .expect("structured callback mutex poisoned");
1314 catch_unwind(AssertUnwindSafe(|| cb(snapshot_event))).is_err()
1316 };
1317 if snapshot_panicked {
1318 if let Some(tx) = state
1319 .done_tx
1320 .lock()
1321 .expect("structured done_tx mutex poisoned")
1322 .take()
1323 {
1324 let _ = tx.send(Err(FMError::Unknown {
1325 code: ffi::status::UNKNOWN,
1326 message: "stream callback panicked".into(),
1327 }));
1328 }
1329 drop(Arc::from_raw(Arc::as_ptr(&state)));
1330 drop(state);
1331 return;
1332 }
1333 }
1334
1335 if done {
1336 {
1337 let mut cb = state
1338 .on_event
1339 .lock()
1340 .expect("structured callback mutex poisoned");
1341 let _ = catch_unwind(AssertUnwindSafe(|| cb(StructuredStreamEvent::Done)));
1342 }
1343 if let Some(tx) = state
1344 .done_tx
1345 .lock()
1346 .expect("structured done_tx mutex poisoned")
1347 .take()
1348 {
1349 let _ = tx.send(Ok(()));
1350 }
1351 drop(Arc::from_raw(Arc::as_ptr(&state)));
1352 }
1353 drop(state);
1354}