1use std::collections::HashMap;
6use std::ffi::{CStr, CString, c_char, c_int, c_void};
7use std::ptr::{self, NonNull};
8use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11
12use crate::context::{ContextLimit, ContextUsage, context_usage_from_transcript};
13use crate::error::{Error, Result};
14use crate::ffi::{self, SwiftPtr};
15use crate::model::{SystemLanguageModel, error_from_swift};
16use crate::options::GenerationOptions;
17use crate::tool::{Tool, ToolResult, tools_to_json};
18
19type ToolMapInner = HashMap<String, Arc<dyn Tool>>;
21
22struct ToolCallbackData {
27 tools: Mutex<ToolMapInner>,
28 dropping: AtomicBool,
30 active_callbacks: AtomicUsize,
32}
33
/// Scope guard that decrements a callback counter when it is dropped.
///
/// Created right after a matching `fetch_add` on the same counter, so the
/// count is restored on every exit path of the callback, early returns included.
struct CallbackGuard<'a>(&'a AtomicUsize);

impl Drop for CallbackGuard<'_> {
    fn drop(&mut self) {
        let counter: &AtomicUsize = self.0;
        counter.fetch_sub(1, Ordering::SeqCst);
    }
}
42
/// The complete text of a single, non-streaming model response.
///
/// Read the text with [`Response::content`], through [`AsRef<str>`], or with
/// [`Display`](std::fmt::Display); take ownership of it with
/// [`Response::into_content`]. Responses compare, hash and clone by their text.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Response {
    content: String,
}

impl Response {
    /// Wraps generated text in a `Response`.
    pub(crate) fn new(content: String) -> Self {
        Self { content }
    }

    /// Borrows the generated text.
    #[must_use]
    pub fn content(&self) -> &str {
        &self.content
    }

    /// Consumes the response and returns the generated text without copying it.
    #[must_use]
    pub fn into_content(self) -> String {
        self.content
    }
}

impl AsRef<str> for Response {
    fn as_ref(&self) -> &str {
        &self.content
    }
}

impl std::fmt::Display for Response {
    /// Writes the generated text verbatim, with no decoration.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.content)
    }
}
77
78pub struct Session {
96 ptr: NonNull<c_void>,
97 tool_callback_data: Option<Arc<ToolCallbackData>>,
100}
101
102impl Session {
103 pub fn new(model: &SystemLanguageModel) -> Result<Self> {
105 Self::create_internal(model, None, &[])
106 }
107
108 pub fn with_instructions(model: &SystemLanguageModel, instructions: &str) -> Result<Self> {
112 Self::create_internal(model, Some(instructions), &[])
113 }
114
115 pub fn with_tools(model: &SystemLanguageModel, tools: &[Arc<dyn Tool>]) -> Result<Self> {
119 Self::create_internal(model, None, tools)
120 }
121
122 pub fn with_instructions_and_tools(
124 model: &SystemLanguageModel,
125 instructions: &str,
126 tools: &[Arc<dyn Tool>],
127 ) -> Result<Self> {
128 Self::create_internal(model, Some(instructions), tools)
129 }
130
131 pub fn from_transcript(model: &SystemLanguageModel, transcript_json: &str) -> Result<Self> {
136 let transcript_c = CString::new(transcript_json)?;
137 let mut error: SwiftPtr = ptr::null_mut();
138
139 let ptr = unsafe {
140 ffi::fm_session_from_transcript(model.as_ptr(), transcript_c.as_ptr(), &raw mut error)
141 };
142
143 if !error.is_null() {
144 return Err(error_from_swift(error));
145 }
146
147 NonNull::new(ptr)
148 .map(|ptr| Self {
149 ptr,
150 tool_callback_data: None,
151 })
152 .ok_or_else(|| {
153 Error::InternalError(
154 "Session creation from transcript returned null without error. \
155 The transcript JSON may be malformed or incompatible."
156 .to_string(),
157 )
158 })
159 }
160
161 fn create_internal(
163 model: &SystemLanguageModel,
164 instructions: Option<&str>,
165 tools: &[Arc<dyn Tool>],
166 ) -> Result<Self> {
167 let instructions_c = instructions.map(CString::new).transpose()?;
168 let instructions_ptr = instructions_c.as_ref().map_or(ptr::null(), |s| s.as_ptr());
169
170 let mut tool_map = HashMap::new();
172 let tools_json = if tools.is_empty() {
173 None
174 } else {
175 let tool_refs: Vec<&dyn Tool> = tools.iter().map(std::convert::AsRef::as_ref).collect();
176 for tool in tools {
177 tool_map.insert(tool.name().to_string(), Arc::clone(tool));
178 }
179 let json_str = tools_to_json(&tool_refs)?;
180 Some(CString::new(json_str)?)
181 };
182 let tools_ptr = tools_json.as_ref().map_or(ptr::null(), |s| s.as_ptr());
183
184 let callback_data = if tools.is_empty() {
186 None
187 } else {
188 Some(Arc::new(ToolCallbackData {
189 tools: Mutex::new(tool_map),
190 dropping: AtomicBool::new(false),
191 active_callbacks: AtomicUsize::new(0),
192 }))
193 };
194
195 let user_data = callback_data.as_ref().map_or(ptr::null_mut(), |arc| {
197 Arc::into_raw(Arc::clone(arc)) as *mut c_void
198 });
199
200 let mut error: SwiftPtr = ptr::null_mut();
201
202 let ptr = unsafe {
203 ffi::fm_session_create(
204 model.as_ptr(),
205 instructions_ptr,
206 tools_ptr,
207 user_data,
208 session_tool_callback,
209 &raw mut error,
210 )
211 };
212
213 if !error.is_null() {
214 if !user_data.is_null() {
216 let _ = unsafe { Arc::from_raw(user_data as *const ToolCallbackData) };
217 }
218 return Err(error_from_swift(error));
219 }
220
221 NonNull::new(ptr)
222 .map(|ptr| Self {
223 ptr,
224 tool_callback_data: callback_data,
225 })
226 .ok_or_else(|| {
227 if !user_data.is_null() {
229 let _ = unsafe { Arc::from_raw(user_data as *const ToolCallbackData) };
230 }
231 Error::InternalError(
232 "Session creation returned null without error. \
233 Check model availability and instructions validity."
234 .to_string(),
235 )
236 })
237 }
238
239 pub fn respond(&self, prompt: &str, options: &GenerationOptions) -> Result<Response> {
243 let prompt_c = CString::new(prompt)?;
244 let options_json = options.to_json();
245 let options_c = CString::new(options_json)?;
246
247 let mut error: SwiftPtr = ptr::null_mut();
248
249 let response_ptr = unsafe {
250 ffi::fm_session_respond(
251 self.ptr.as_ptr(),
252 prompt_c.as_ptr(),
253 options_c.as_ptr(),
254 &raw mut error,
255 )
256 };
257
258 if !error.is_null() {
259 return Err(error_from_swift(error));
260 }
261
262 if response_ptr.is_null() {
263 return Err(Error::GenerationError("Received null response".to_string()));
264 }
265
266 let content = unsafe {
267 let cstr = CStr::from_ptr(response_ptr);
268 let s = cstr
269 .to_str()
270 .map_err(|e| Error::GenerationError(format!("Invalid UTF-8 in response: {e}")))?
271 .to_owned();
272 ffi::fm_string_free(response_ptr);
273 s
274 };
275
276 Ok(Response::new(content))
277 }
278
279 pub fn respond_with_timeout(
283 &self,
284 prompt: &str,
285 options: &GenerationOptions,
286 timeout: Duration,
287 ) -> Result<Response> {
288 if timeout.is_zero() {
289 return self.respond(prompt, options);
290 }
291
292 let timeout_ms = u64::try_from(timeout.as_millis()).map_err(|_| {
293 Error::InvalidInput("Timeout is too large to represent in milliseconds".to_string())
294 })?;
295
296 let prompt_c = CString::new(prompt)?;
297 let options_json = options.to_json();
298 let options_c = CString::new(options_json)?;
299
300 let mut error: SwiftPtr = ptr::null_mut();
301
302 let response_ptr = unsafe {
303 ffi::fm_session_respond_with_timeout(
304 self.ptr.as_ptr(),
305 prompt_c.as_ptr(),
306 options_c.as_ptr(),
307 timeout_ms,
308 &raw mut error,
309 )
310 };
311
312 if !error.is_null() {
313 return Err(error_from_swift(error));
314 }
315
316 if response_ptr.is_null() {
317 return Err(Error::GenerationError("Received null response".to_string()));
318 }
319
320 let content = unsafe {
321 let cstr = CStr::from_ptr(response_ptr);
322 let s = cstr
323 .to_str()
324 .map_err(|e| Error::GenerationError(format!("Invalid UTF-8 in response: {e}")))?
325 .to_owned();
326 ffi::fm_string_free(response_ptr);
327 s
328 };
329
330 Ok(Response::new(content))
331 }
332
333 pub fn stream_response<F>(
352 &self,
353 prompt: &str,
354 options: &GenerationOptions,
355 on_chunk: F,
356 ) -> Result<()>
357 where
358 F: FnMut(&str) + Send + 'static,
359 {
360 let prompt_c = CString::new(prompt)?;
361 let options_json = options.to_json();
362 let options_c = CString::new(options_json)?;
363
364 let state = Box::new(StreamState {
366 on_chunk: Mutex::new(Box::new(on_chunk)),
367 error: Mutex::new(None),
368 });
369 let state_ptr = Box::into_raw(state).cast::<c_void>();
370
371 unsafe {
372 ffi::fm_session_stream(
373 self.ptr.as_ptr(),
374 prompt_c.as_ptr(),
375 options_c.as_ptr(),
376 state_ptr,
377 stream_chunk_callback,
378 stream_done_callback,
379 stream_error_callback,
380 );
381 }
382
383 let state = unsafe { Box::from_raw(state_ptr.cast::<StreamState>()) };
385 let error = state.error.lock().map_err(|_| Error::PoisonError)?;
386 if let Some(err) = error.as_ref() {
387 return Err(Error::GenerationError(err.clone()));
388 }
389
390 Ok(())
391 }
392
393 pub fn cancel(&self) {
395 unsafe {
396 ffi::fm_session_cancel(self.ptr.as_ptr());
397 }
398 }
399
400 pub fn is_responding(&self) -> bool {
402 unsafe { ffi::fm_session_is_responding(self.ptr.as_ptr()) }
403 }
404
405 pub fn transcript_json(&self) -> Result<String> {
409 let mut error: SwiftPtr = ptr::null_mut();
410 let ptr = unsafe { ffi::fm_session_get_transcript(self.ptr.as_ptr(), &raw mut error) };
411
412 if !error.is_null() {
413 return Err(error_from_swift(error));
414 }
415
416 if ptr.is_null() {
417 return Err(Error::InternalError(
418 "Transcript retrieval returned null without error. \
419 The session may be in an invalid state."
420 .to_string(),
421 ));
422 }
423
424 let json = unsafe {
425 let cstr = CStr::from_ptr(ptr);
426 let s = cstr
427 .to_str()
428 .map_err(|e| Error::InternalError(format!("Invalid UTF-8 in transcript: {e}")))?
429 .to_owned();
430 ffi::fm_string_free(ptr);
431 s
432 };
433
434 Ok(json)
435 }
436
437 pub fn context_usage(&self, limit: &ContextLimit) -> Result<ContextUsage> {
439 let transcript_json = self.transcript_json()?;
440 context_usage_from_transcript(&transcript_json, limit)
441 }
442
443 pub fn ensure_context_within(&self, limit: &ContextLimit) -> Result<()> {
445 let usage = self.context_usage(limit)?;
446 if usage.over_limit {
447 return Err(Error::InvalidInput(format!(
448 "Estimated context usage {} exceeds configured limit {} (reserved: {})",
449 usage.estimated_tokens, usage.max_tokens, usage.reserved_response_tokens
450 )));
451 }
452 Ok(())
453 }
454
455 pub fn prewarm(&self, prompt_prefix: Option<&str>) -> Result<()> {
459 let prefix_c = prompt_prefix.map(CString::new).transpose()?;
460 let prefix_ptr = prefix_c.as_ref().map_or(ptr::null(), |s| s.as_ptr());
461
462 unsafe {
463 ffi::fm_session_prewarm(self.ptr.as_ptr(), prefix_ptr);
464 }
465
466 Ok(())
467 }
468
469 pub fn respond_json(
509 &self,
510 prompt: &str,
511 schema: &serde_json::Value,
512 options: &GenerationOptions,
513 ) -> Result<String> {
514 let prompt_c = CString::new(prompt)?;
515 let schema_json = serde_json::to_string(schema)?;
516 let schema_c = CString::new(schema_json)?;
517 let options_json = options.to_json();
518 let options_c = CString::new(options_json)?;
519
520 let mut error: SwiftPtr = ptr::null_mut();
521
522 let response_ptr = unsafe {
523 ffi::fm_session_respond_json(
524 self.ptr.as_ptr(),
525 prompt_c.as_ptr(),
526 schema_c.as_ptr(),
527 options_c.as_ptr(),
528 &raw mut error,
529 )
530 };
531
532 if !error.is_null() {
533 return Err(error_from_swift(error));
534 }
535
536 if response_ptr.is_null() {
537 return Err(Error::GenerationError(
538 "Received null response from JSON generation".to_string(),
539 ));
540 }
541
542 let content = unsafe {
543 let cstr = CStr::from_ptr(response_ptr);
544 let s = cstr
545 .to_str()
546 .map_err(|e| {
547 Error::GenerationError(format!("Invalid UTF-8 in JSON response: {e}"))
548 })?
549 .to_owned();
550 ffi::fm_string_free(response_ptr);
551 s
552 };
553
554 Ok(content)
555 }
556
557 pub fn respond_structured<T: serde::de::DeserializeOwned>(
595 &self,
596 prompt: &str,
597 schema: &serde_json::Value,
598 options: &GenerationOptions,
599 ) -> Result<T> {
600 let json_str = self.respond_json(prompt, schema, options)?;
601 serde_json::from_str(&json_str)
602 .map_err(|e| Error::InvalidInput(format!("Failed to deserialize response: {e}")))
603 }
604
605 pub fn respond_structured_gen<T>(&self, prompt: &str, options: &GenerationOptions) -> Result<T>
609 where
610 T: crate::Generable + serde::de::DeserializeOwned,
611 {
612 self.respond_structured(prompt, &T::schema(), options)
613 }
614
615 pub fn stream_json<F>(
647 &self,
648 prompt: &str,
649 schema: &serde_json::Value,
650 options: &GenerationOptions,
651 on_chunk: F,
652 ) -> Result<()>
653 where
654 F: FnMut(&str) + Send + 'static,
655 {
656 let prompt_c = CString::new(prompt)?;
657 let schema_json = serde_json::to_string(schema)?;
658 let schema_c = CString::new(schema_json)?;
659 let options_json = options.to_json();
660 let options_c = CString::new(options_json)?;
661
662 let state = Box::new(StreamState {
664 on_chunk: Mutex::new(Box::new(on_chunk)),
665 error: Mutex::new(None),
666 });
667 let state_ptr = Box::into_raw(state).cast::<c_void>();
668
669 unsafe {
670 ffi::fm_session_stream_json(
671 self.ptr.as_ptr(),
672 prompt_c.as_ptr(),
673 schema_c.as_ptr(),
674 options_c.as_ptr(),
675 state_ptr,
676 stream_chunk_callback,
677 stream_done_callback,
678 stream_error_callback,
679 );
680 }
681
682 let state = unsafe { Box::from_raw(state_ptr.cast::<StreamState>()) };
684 let error = state.error.lock().map_err(|_| Error::PoisonError)?;
685 if let Some(err) = error.as_ref() {
686 return Err(Error::GenerationError(err.clone()));
687 }
688
689 Ok(())
690 }
691}
692
693impl Drop for Session {
694 fn drop(&mut self) {
695 if let Some(ref callback_data) = self.tool_callback_data {
697 callback_data.dropping.store(true, Ordering::SeqCst);
698
699 let mut attempts = 0;
701 while callback_data.active_callbacks.load(Ordering::SeqCst) > 0 && attempts < 100 {
702 std::thread::sleep(std::time::Duration::from_millis(10));
703 attempts += 1;
704 }
705 }
706
707 unsafe {
709 ffi::fm_session_free(self.ptr.as_ptr());
710 }
711
712 }
716}
717
// SAFETY: NOTE(review): `Session` stores a `NonNull<c_void>`, which makes it `!Send`
// by default. This impl asserts that the Swift session object may be used from,
// and freed by (`fm_session_free` in `Drop`), a thread other than the one that
// created it. That cannot be verified from this file; confirm against the Swift
// bridge. The impl also overrides the auto-trait check on the
// `Option<Arc<ToolCallbackData>>` field, whose registered `Arc<dyn Tool>` values
// are called from Swift's callback thread. TODO: confirm that `Tool` requires
// `Send + Sync`.
unsafe impl Send for Session {}
721
/// The caller's handler, invoked once per streamed chunk.
type ChunkCallbackFn = dyn FnMut(&str) + Send;

/// Per-call state shared with the Swift streaming API through `user_data`.
///
/// The streaming methods box this, hand Swift the raw pointer, and reclaim
/// the box once the FFI call returns.
struct StreamState {
    /// Error message recorded by the stream callbacks, if any.
    error: Mutex<Option<String>>,
    /// The chunk handler. It sits behind a `Mutex` because it is `FnMut` but is
    /// reached through a shared reference inside the C callbacks.
    on_chunk: Mutex<Box<ChunkCallbackFn>>,
}
733
734extern "C" fn stream_chunk_callback(user_data: *mut c_void, chunk: *const c_char) {
736 if user_data.is_null() || chunk.is_null() {
737 return;
738 }
739
740 let state = unsafe { &*(user_data as *const StreamState) };
741 let chunk_str = unsafe { CStr::from_ptr(chunk).to_string_lossy() };
742
743 if let Ok(mut on_chunk) = state.on_chunk.lock() {
744 on_chunk(&chunk_str);
745 }
746}
747
/// C callback through which Swift signals the end of a stream.
///
/// It has nothing to do: the streaming methods reclaim their state themselves
/// once the FFI call returns.
extern "C" fn stream_done_callback(_user_data: *mut c_void) {}
752
753extern "C" fn stream_error_callback(user_data: *mut c_void, _code: c_int, message: *const c_char) {
755 if user_data.is_null() {
756 return;
757 }
758
759 let state = unsafe { &*(user_data as *const StreamState) };
760 let msg = if message.is_null() {
761 "Streaming error occurred (no message provided by Swift)".to_string()
762 } else {
763 unsafe { CStr::from_ptr(message).to_string_lossy().into_owned() }
764 };
765
766 if let Ok(mut error) = state.error.lock() {
767 *error = Some(msg);
768 }
769}
770
771extern "C" fn session_tool_callback(
774 user_data: *mut c_void,
775 tool_name: *const c_char,
776 arguments_json: *const c_char,
777) -> *mut c_char {
778 if user_data.is_null() || tool_name.is_null() {
779 let result = ToolResult::error("Invalid callback parameters");
780 return string_to_c(result.to_json());
781 }
782
783 let callback_data = unsafe { &*(user_data as *const ToolCallbackData) };
787
788 if callback_data.dropping.load(Ordering::SeqCst) {
790 let result = ToolResult::error("Session is being dropped");
791 return string_to_c(result.to_json());
792 }
793
794 callback_data
796 .active_callbacks
797 .fetch_add(1, Ordering::SeqCst);
798 let _guard = CallbackGuard(&callback_data.active_callbacks);
799
800 let name = unsafe { CStr::from_ptr(tool_name).to_string_lossy().into_owned() };
801 let args_str = if arguments_json.is_null() {
802 "{}".to_string()
803 } else {
804 unsafe {
805 CStr::from_ptr(arguments_json)
806 .to_string_lossy()
807 .into_owned()
808 }
809 };
810
811 let arguments: serde_json::Value = match parse_tool_arguments(&args_str) {
813 Ok(v) => v,
814 Err(message) => {
815 let result = ToolResult::error(message);
816 return string_to_c(result.to_json());
817 }
818 };
819
820 let Ok(tools) = callback_data.tools.lock() else {
822 let result = ToolResult::error("Failed to acquire tool lock");
823 return string_to_c(result.to_json());
824 };
825
826 let Some(tool) = tools.get(&name).map(Arc::clone) else {
827 let result = ToolResult::error(format!("Unknown tool: {name}"));
828 return string_to_c(result.to_json());
829 };
830
831 drop(tools);
833
834 let result = match tool.call(arguments) {
836 Ok(output) => ToolResult::success(output),
837 Err(e) => ToolResult::error(e.to_string()),
838 };
839
840 string_to_c(result.to_json())
841}
842
/// Converts `s` into a heap-allocated C string whose ownership passes to the
/// caller.
///
/// Returns null if `s` contains an interior NUL byte. The pointer comes from
/// [`CString::into_raw`], so its allocation can only be released correctly by
/// handing it back to `CString::from_raw`.
fn string_to_c(s: String) -> *mut c_char {
    CString::new(s).map_or(ptr::null_mut(), CString::into_raw)
}
850
851fn parse_tool_arguments(input: &str) -> std::result::Result<serde_json::Value, String> {
852 match serde_json::from_str(input) {
853 Ok(value) => Ok(value),
854 Err(err) => {
855 if let Some(fixed) = autoclose_json(input) {
856 match serde_json::from_str(&fixed) {
857 Ok(value) => {
858 #[cfg(debug_assertions)]
860 eprintln!(
861 "[fm-rs] autoclose_json repaired truncated tool arguments: {input:?} -> {fixed:?}"
862 );
863 Ok(value)
864 }
865 Err(fixed_err) => Err(format!(
866 "Failed to parse arguments: {err}; attempted fix: {fixed_err}"
867 )),
868 }
869 } else {
870 Err(format!("Failed to parse arguments: {err}"))
871 }
872 }
873 }
874}
875
/// Inputs longer than this many bytes are never scanned for auto-closing.
const AUTOCLOSE_JSON_MAX_SIZE: usize = 1024 * 1024;

/// Appends the closing `}` / `]` needed to balance a truncated JSON document.
///
/// Brackets inside string literals, including escaped quotes, are ignored.
/// Returns `None` in these cases:
/// - the input exceeds [`AUTOCLOSE_JSON_MAX_SIZE`];
/// - a closing bracket does not match the most recent open one;
/// - the input ends inside a string;
/// - nothing is left open.
///
/// A trailing comma or colon is left as is, so the result is not guaranteed to
/// be valid JSON.
fn autoclose_json(input: &str) -> Option<String> {
    if input.len() > AUTOCLOSE_JSON_MAX_SIZE {
        return None;
    }

    // Closers still owed, innermost last.
    let mut pending: Vec<char> = Vec::new();
    let mut inside_string = false;
    let mut after_backslash = false;

    for c in input.chars() {
        if inside_string {
            match (after_backslash, c) {
                (true, _) => after_backslash = false,
                (false, '\\') => after_backslash = true,
                (false, '"') => inside_string = false,
                (false, _) => {}
            }
            continue;
        }

        match c {
            '"' => inside_string = true,
            '{' => pending.push('}'),
            '[' => pending.push(']'),
            '}' | ']' => {
                if pending.pop() != Some(c) {
                    return None;
                }
            }
            _ => {}
        }
    }

    if inside_string || pending.is_empty() {
        return None;
    }

    let mut repaired = String::with_capacity(input.len() + pending.len());
    repaired.push_str(input);
    repaired.extend(pending.iter().rev());
    Some(repaired)
}
933
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_response() {
        let response = Response::new("Hello, world!".to_string());
        assert_eq!(response.content(), "Hello, world!");
        assert_eq!(response.as_ref(), "Hello, world!");
        assert_eq!(format!("{response}"), "Hello, world!");
        assert_eq!(response.into_content(), "Hello, world!");
    }

    #[test]
    fn autoclose_json_closes_open_containers_innermost_first() {
        assert_eq!(
            autoclose_json(r#"{"a": [1, 2"#).as_deref(),
            Some(r#"{"a": [1, 2]}"#)
        );
        assert_eq!(autoclose_json("[[").as_deref(), Some("[[]]"));
    }

    #[test]
    fn autoclose_json_ignores_brackets_and_escaped_quotes_in_strings() {
        assert_eq!(
            autoclose_json(r#"{"k": "x}\"y""#).as_deref(),
            Some(r#"{"k": "x}\"y"}"#)
        );
    }

    #[test]
    fn autoclose_json_declines_unrepairable_input() {
        assert_eq!(autoclose_json(""), None);
        assert_eq!(autoclose_json(r#"{"a": 1}"#), None);
        assert_eq!(autoclose_json(r#"{"a": "open"#), None);
        assert_eq!(autoclose_json("{]"), None);
        assert_eq!(autoclose_json("}"), None);
        assert_eq!(
            autoclose_json(&"[".repeat(AUTOCLOSE_JSON_MAX_SIZE + 1)),
            None
        );
    }

    #[test]
    fn parse_tool_arguments_accepts_valid_and_truncated_json() {
        assert_eq!(parse_tool_arguments("{}"), Ok(serde_json::json!({})));
        assert_eq!(
            parse_tool_arguments(r#"{"city": "Paris""#),
            Ok(serde_json::json!({"city": "Paris"}))
        );
    }

    #[test]
    fn parse_tool_arguments_reports_unrepairable_input() {
        let err = parse_tool_arguments("not json").unwrap_err();
        assert!(err.starts_with("Failed to parse arguments:"));
        assert!(!err.contains("attempted fix"));

        let err = parse_tool_arguments(r#"{"a": 1,"#).unwrap_err();
        assert!(err.contains("attempted fix"));
    }

    #[test]
    fn string_to_c_round_trips_and_rejects_interior_nul() {
        let raw = string_to_c("hello".to_string());
        assert!(!raw.is_null());
        // SAFETY: `raw` came from `CString::into_raw` and is reclaimed once.
        let back = unsafe { CString::from_raw(raw) };
        assert_eq!(back.to_str().unwrap(), "hello");
        assert!(string_to_c("a\0b".to_string()).is_null());
    }

    #[test]
    fn callback_guard_decrements_on_drop() {
        let counter = AtomicUsize::new(1);
        {
            let _guard = CallbackGuard(&counter);
            assert_eq!(counter.load(Ordering::SeqCst), 1);
        }
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn stream_error_callback_records_message_or_fallback() {
        let state = StreamState {
            on_chunk: Mutex::new(Box::new(|_chunk: &str| {})),
            error: Mutex::new(None),
        };
        let user_data = &state as *const StreamState as *mut c_void;
        let message = CString::new("boom").unwrap();

        stream_error_callback(user_data, 1, message.as_ptr());
        assert_eq!(state.error.lock().unwrap().as_deref(), Some("boom"));

        stream_error_callback(user_data, 1, ptr::null());
        assert_eq!(
            state.error.lock().unwrap().as_deref(),
            Some("Streaming error occurred (no message provided by Swift)")
        );
    }
}