use std::ffi::{c_char, c_void, CStr, CString};
use std::sync::Arc;
use ringo_fm_sys as sys;
use tokio::sync::oneshot;
use crate::error::{Error, Result};
use crate::generated::{GeneratedContent, GeneratedContentTag};
use crate::handle::{check_error, FmString, ManagedRef};
use crate::model::SystemLanguageModel;
use crate::options::GenerationOptions;
use crate::prompt::Prompt;
use crate::schema::GenerationSchema;
use crate::stream::ResponseStream;
use crate::tool::ToolHandle;
use crate::transcript::Transcript;
/// Zero-sized marker used as the tag parameter of `ManagedRef` for the native
/// handles created in this module (see `LanguageModelSession::handle`).
pub(crate) struct SessionTag;
/// A language-model session backed by a native `ringo_fm_sys` session object.
///
/// Instances are built by `default`, `new` or `from_transcript`, each of which
/// wraps the pointer returned by the corresponding native constructor.
pub struct LanguageModelSession {
/// Shared, tagged wrapper around the native session pointer.
pub(crate) handle: Arc<ManagedRef<SessionTag>>,
/// Tool handles supplied at construction. Never read in this module;
/// presumably retained so the tool pointers passed to the native
/// constructor stay valid for the session's lifetime — TODO confirm.
_tools: Vec<ToolHandle>,
}
impl LanguageModelSession {
pub fn default() -> Result<Self> {
let ptr = unsafe { sys::FMLanguageModelSessionCreateDefault() };
Ok(Self {
handle: Arc::new(ManagedRef::from_owned(ptr)?),
_tools: Vec::new(),
})
}
pub fn new(
model: Option<&SystemLanguageModel>,
instructions: Option<&str>,
tools: Vec<ToolHandle>,
) -> Result<Self> {
let instr_c = match instructions {
Some(s) => Some(CString::new(s).map_err(|e| Error::Native(e.to_string()))?),
None => None,
};
let instr_ptr = instr_c.as_ref().map_or(std::ptr::null(), |c| c.as_ptr());
let model_ptr = model.map_or(std::ptr::null(), |m| m.handle.as_ptr());
let mut tool_ptrs: Vec<*const c_void> = tools.iter().map(|t| t.as_ptr()).collect();
let (tools_arg, tool_count) = if tool_ptrs.is_empty() {
(std::ptr::null_mut(), 0)
} else {
(tool_ptrs.as_mut_ptr(), tool_ptrs.len() as i32)
};
let ptr = unsafe {
sys::FMLanguageModelSessionCreateFromSystemLanguageModel(
model_ptr,
instr_ptr,
tools_arg,
tool_count,
)
};
Ok(Self {
handle: Arc::new(ManagedRef::from_owned(ptr)?),
_tools: tools,
})
}
pub fn is_responding(&self) -> bool {
unsafe { sys::FMLanguageModelSessionIsResponding(self.handle.as_ptr()) }
}
pub fn reset(&self) {
unsafe { sys::FMLanguageModelSessionReset(self.handle.as_ptr()) };
}
pub fn prewarm(&self, prompt_prefix: Option<&str>) -> Result<()> {
let prefix_c = match prompt_prefix {
Some(s) => Some(CString::new(s).map_err(|e| Error::Native(e.to_string()))?),
None => None,
};
let prefix_ptr = prefix_c.as_ref().map_or(std::ptr::null(), |c| c.as_ptr());
unsafe { sys::FMLanguageModelSessionPrewarm(self.handle.as_ptr(), prefix_ptr) };
Ok(())
}
pub async fn respond<P: Into<Prompt>>(&self, prompt: P) -> Result<String> {
self.respond_with(prompt, &GenerationOptions::default()).await
}
pub async fn respond_with<P: Into<Prompt>>(
&self,
prompt: P,
options: &GenerationOptions,
) -> Result<String> {
let composed = prompt.into().into_composed()?;
let opts_json = options.to_json()?;
let opts_c = match opts_json {
Some(s) => Some(CString::new(s).map_err(|e| Error::Native(e.to_string()))?),
None => None,
};
let opts_ptr = opts_c.as_ref().map_or(std::ptr::null(), |c| c.as_ptr());
let (tx, rx) = oneshot::channel::<Result<String>>();
let user_info = Box::into_raw(Box::new(tx)) as *mut c_void;
let task = unsafe {
sys::FMLanguageModelSessionRespond(
self.handle.as_ptr(),
composed.into_raw(),
opts_ptr,
user_info,
Some(text_trampoline),
)
};
let _cancel = CancelOnDrop::new(task);
match rx.await {
Ok(r) => r,
Err(_) => Err(Error::Native("response callback dropped".into())),
}
}
pub fn stream<P: Into<Prompt>>(&self, prompt: P) -> Result<ResponseStream> {
self.stream_with(prompt, &GenerationOptions::default())
}
pub fn stream_with<P: Into<Prompt>>(
&self,
prompt: P,
options: &GenerationOptions,
) -> Result<ResponseStream> {
let composed = prompt.into().into_composed()?;
let opts_json = options.to_json()?;
let opts_c = match opts_json {
Some(s) => Some(CString::new(s).map_err(|e| Error::Native(e.to_string()))?),
None => None,
};
let opts_ptr = opts_c.as_ref().map_or(std::ptr::null(), |c| c.as_ptr());
let stream_ptr = unsafe {
sys::FMLanguageModelSessionStreamResponse(
self.handle.as_ptr(),
composed.into_raw(),
opts_ptr,
)
};
ResponseStream::start(stream_ptr)
}
pub async fn respond_with_schema<P: Into<Prompt>>(
&self,
prompt: P,
schema: &GenerationSchema,
options: &GenerationOptions,
) -> Result<GeneratedContent> {
let composed = prompt.into().into_composed()?;
let opts_json = options.to_json()?;
let opts_c = match opts_json {
Some(s) => Some(CString::new(s).map_err(|e| Error::Native(e.to_string()))?),
None => None,
};
let opts_ptr = opts_c.as_ref().map_or(std::ptr::null(), |c| c.as_ptr());
let (tx, rx) = oneshot::channel::<Result<GeneratedContent>>();
let user_info = Box::into_raw(Box::new(tx)) as *mut c_void;
let task = unsafe {
sys::FMLanguageModelSessionRespondWithSchema(
self.handle.as_ptr(),
composed.into_raw(),
schema.as_ptr(),
opts_ptr,
user_info,
Some(structured_trampoline),
)
};
let _cancel = CancelOnDrop::new(task);
match rx.await {
Ok(r) => r,
Err(_) => Err(Error::Native("structured response callback dropped".into())),
}
}
pub async fn respond_with_json_schema<P: Into<Prompt>>(
&self,
prompt: P,
schema_json: &str,
options: &GenerationOptions,
) -> Result<GeneratedContent> {
let composed = prompt.into().into_composed()?;
let schema_c = CString::new(schema_json).map_err(|e| Error::Native(e.to_string()))?;
let opts_json = options.to_json()?;
let opts_c = match opts_json {
Some(s) => Some(CString::new(s).map_err(|e| Error::Native(e.to_string()))?),
None => None,
};
let opts_ptr = opts_c.as_ref().map_or(std::ptr::null(), |c| c.as_ptr());
let (tx, rx) = oneshot::channel::<Result<GeneratedContent>>();
let user_info = Box::into_raw(Box::new(tx)) as *mut c_void;
let task = unsafe {
sys::FMLanguageModelSessionRespondWithSchemaFromJSON(
self.handle.as_ptr(),
composed.into_raw(),
schema_c.as_ptr(),
opts_ptr,
user_info,
Some(structured_trampoline),
)
};
let _cancel = CancelOnDrop::new(task);
match rx.await {
Ok(r) => r,
Err(_) => Err(Error::Native("structured response callback dropped".into())),
}
}
pub fn transcript(&self) -> Result<Transcript> {
let mut code: i32 = 0;
let mut desc: *mut c_char = std::ptr::null_mut();
let ptr = unsafe {
sys::FMLanguageModelSessionGetTranscriptJSONString(self.handle.as_ptr(), &mut code, &mut desc)
};
check_error(code, desc)?;
let json = FmString::from_raw(ptr)
.ok_or_else(|| Error::Native("transcript JSON null".into()))?
.to_string()?;
Transcript::from_json(&json)
}
pub fn from_transcript(
transcript_json: &str,
model: Option<&SystemLanguageModel>,
tools: Vec<ToolHandle>,
) -> Result<Self> {
let c = CString::new(transcript_json).map_err(|e| Error::Native(e.to_string()))?;
let mut code: i32 = 0;
let mut desc: *mut c_char = std::ptr::null_mut();
let transcript_session = unsafe {
sys::FMTranscriptCreateFromJSONString(c.as_ptr(), &mut code, &mut desc)
};
check_error(code, desc)?;
if transcript_session.is_null() {
return Err(Error::Native("transcript create returned null".into()));
}
let _drop_intermediate = ManagedRef::<SessionTag>::from_owned(transcript_session)?;
let model_ptr = model.map_or(std::ptr::null(), |m| m.handle.as_ptr());
let mut tool_ptrs: Vec<*const c_void> = tools.iter().map(|t| t.as_ptr()).collect();
let (tools_arg, tool_count) = if tool_ptrs.is_empty() {
(std::ptr::null_mut(), 0)
} else {
(tool_ptrs.as_mut_ptr(), tool_ptrs.len() as i32)
};
let ptr = unsafe {
sys::FMLanguageModelSessionCreateFromTranscript(
_drop_intermediate.as_ptr(),
model_ptr,
tools_arg,
tool_count,
)
};
Ok(Self {
handle: Arc::new(ManagedRef::from_owned(ptr)?),
_tools: tools,
})
}
}
impl std::fmt::Debug for LanguageModelSession {
    /// Prints only the type name; the fields are deliberately elided and the
    /// output is marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("LanguageModelSession");
        builder.finish_non_exhaustive()
    }
}
/// Guard that calls `FMTaskCancel` on the wrapped task when it is dropped,
/// unless the task reference is null.
pub(crate) struct CancelOnDrop {
/// Native task reference returned by one of the `...Respond...` calls; may be null.
task: sys::FMTaskRef,
}
impl CancelOnDrop {
pub(crate) fn new(task: sys::FMTaskRef) -> Self {
Self { task }
}
}
impl Drop for CancelOnDrop {
    /// Requests cancellation of the wrapped task; a null task is ignored.
    fn drop(&mut self) {
        let task = self.task;
        if task.is_null() {
            // Nothing was started, so there is nothing to cancel.
            return;
        }
        // SAFETY: `task` is the non-null reference this guard was created with.
        unsafe { sys::FMTaskCancel(task) };
    }
}
pub(crate) unsafe extern "C" fn text_trampoline(
status: i32,
content: *const c_char,
length: usize,
user_info: *mut c_void,
) {
let tx = unsafe { Box::from_raw(user_info as *mut oneshot::Sender<Result<String>>) };
let result = if status == crate::error::status::SUCCESS {
if content.is_null() {
Ok(String::new())
} else {
let slice = unsafe { std::slice::from_raw_parts(content as *const u8, length) };
match std::str::from_utf8(slice) {
Ok(s) => Ok(s.to_owned()),
Err(e) => Err(Error::Native(format!("non-utf8 model output: {e}"))),
}
}
} else {
let debug = if content.is_null() {
String::new()
} else {
unsafe { CStr::from_ptr(content) }.to_string_lossy().into_owned()
};
Err(Error::from_status(status, debug))
};
let _ = tx.send(result);
}
unsafe extern "C" fn structured_trampoline(
status: i32,
content: sys::FMGeneratedContentRef,
user_info: *mut c_void,
) {
let tx = unsafe { Box::from_raw(user_info as *mut oneshot::Sender<Result<GeneratedContent>>) };
let result = if status == crate::error::status::SUCCESS {
match ManagedRef::<GeneratedContentTag>::from_owned(content) {
Ok(handle) => Ok(GeneratedContent { handle }),
Err(e) => Err(e),
}
} else {
Err(Error::from_status(status, String::new()))
};
let _ = tx.send(result);
}