use std::ffi::{c_void, CStr, CString};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use doom_fish_utils::completion::{error_from_cstr, AsyncCompletion, AsyncCompletionFuture};
use serde::Deserialize;
use crate::content::{BridgeGeneratedContent, GeneratedContent};
use crate::error::FMError;
use crate::ffi;
use crate::generation::GenerationOptions;
use crate::model::Adapter;
use crate::prompt::{Prompt, ToPrompt};
use crate::schema::GenerationSchema;
use crate::session::{decode_bridge_text_response, SessionResponse};
use crate::transcript::Transcript;
/// JSON shape of a structured response, as parsed by `RespondGeneratingFuture::poll`.
///
/// The keys expected in the JSON are `content`, `rawContent` and `transcriptJSON`.
#[derive(Debug, Deserialize)]
struct AsyncBridgeStructuredResponse {
/// Generated content payload; converted with `GeneratedContent::from_bridge_payload`.
content: BridgeGeneratedContent,
/// Second content payload, exposed to callers as `SessionResponse::raw_content`.
#[serde(rename = "rawContent")]
raw_content: BridgeGeneratedContent,
/// Transcript serialized as a JSON string; parsed with `Transcript::from_json_str`.
#[serde(rename = "transcriptJSON")]
transcript_json: String,
}
/// Newtype around the raw pointer delivered to `adapter_init_async_cb`, used as the
/// success value of the adapter-creation completion and unwrapped into `Adapter { ptr }`.
struct OpaquePtr(*mut c_void);
// SAFETY: NOTE(review) — raw pointers are `!Send`, so this impl is a manual promise that
// the wrapped handle may be moved to another thread. Presumably the bridge's adapter
// handle has no thread affinity; confirm against the bridge implementation.
unsafe impl Send for OpaquePtr {}
/// Completion callback passed to `ffi::fm_session_respond_request_json`.
///
/// * `status == ffi::status::OK` and `response` non-null: the response text is
///   copied into an owned `String` (lossily, if it is not valid UTF-8), the
///   bridge buffer is released with `ffi::fm_string_free`, and the completion
///   behind `ctx` is resolved with that string.
/// * Anything else: `status` and `error` are converted with
///   `crate::error::from_swift` and the completion is rejected with the
///   resulting error's display text. A non-null `response` that arrives
///   together with a failure status is still released so it cannot leak.
///
/// # Safety
///
/// `response`, when non-null, must point to a NUL-terminated string that may
/// be released with `ffi::fm_string_free`. `ctx` is forwarded unchanged to
/// `AsyncCompletion`; NOTE(review): presumably it must be the context pointer
/// produced by `AsyncCompletion::create` and be completed only once — confirm
/// against `doom_fish_utils`.
unsafe extern "C" fn respond_async_cb(
    ctx: *mut c_void,
    response: *mut std::ffi::c_char,
    error: *mut std::ffi::c_char,
    status: i32,
) {
    if status == ffi::status::OK && !response.is_null() {
        let s = unsafe { CStr::from_ptr(response) }
            .to_string_lossy()
            .into_owned();
        unsafe { ffi::fm_string_free(response) };
        unsafe { AsyncCompletion::complete_ok(ctx, s) };
    } else {
        // The success path treats `response` as owned by this callback, so a
        // buffer delivered alongside a failure status must be freed as well.
        if !response.is_null() {
            unsafe { ffi::fm_string_free(response) };
        }
        let fm_err = crate::error::from_swift(status, error);
        unsafe { AsyncCompletion::<String>::complete_err(ctx, fm_err.to_string()) };
    }
}
unsafe extern "C" fn adapter_init_async_cb(
result: *mut c_void,
error: *const std::ffi::c_char,
ctx: *mut c_void,
) {
if !error.is_null() {
let msg = unsafe { error_from_cstr(error) };
unsafe { AsyncCompletion::<OpaquePtr>::complete_err(ctx, msg) };
} else if !result.is_null() {
unsafe { AsyncCompletion::complete_ok(ctx, OpaquePtr(result)) };
} else {
unsafe { AsyncCompletion::<OpaquePtr>::complete_err(ctx, "null adapter pointer".into()) };
}
}
unsafe extern "C" fn adapter_compat_async_cb(
result: *mut c_void,
error: *const std::ffi::c_char,
ctx: *mut c_void,
) {
if !error.is_null() {
let msg = unsafe { error_from_cstr(error) };
unsafe { AsyncCompletion::<String>::complete_err(ctx, msg) };
} else if !result.is_null() {
let s = unsafe { CStr::from_ptr(result.cast::<std::ffi::c_char>()) }
.to_string_lossy()
.into_owned();
unsafe { ffi::fm_string_free(result.cast::<std::ffi::c_char>()) };
unsafe { AsyncCompletion::complete_ok(ctx, s) };
} else {
unsafe { AsyncCompletion::<String>::complete_err(ctx, "null compatibility result".into()) };
}
}
/// Future returned by `AsyncSession::respond` and
/// `AsyncSession::respond_with_options`.
///
/// Resolves to the text response decoded by `decode_bridge_text_response`.
pub struct RespondFuture {
    /// Completion that yields the raw response JSON or an error message.
    inner: AsyncCompletionFuture<String>,
}

impl std::fmt::Debug for RespondFuture {
    /// Prints only the type name; the inner completion is intentionally omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("RespondFuture");
        builder.finish_non_exhaustive()
    }
}

impl Future for RespondFuture {
    type Output = Result<SessionResponse<String>, FMError>;

    /// Polls the inner completion. An error message becomes
    /// `FMError::Unknown` with code `ffi::status::UNKNOWN`; a successful
    /// payload is decoded with `decode_bridge_text_response`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match Pin::new(&mut self.inner).poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(message)) => Poll::Ready(Err(FMError::Unknown {
                code: ffi::status::UNKNOWN,
                message,
            })),
            Poll::Ready(Ok(json)) => Poll::Ready(decode_bridge_text_response(&json)),
        }
    }
}
pub struct RespondGeneratingFuture {
inner: AsyncCompletionFuture<String>,
}
impl std::fmt::Debug for RespondGeneratingFuture {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RespondGeneratingFuture")
.finish_non_exhaustive()
}
}
impl Future for RespondGeneratingFuture {
type Output = Result<SessionResponse<GeneratedContent>, FMError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.inner).poll(cx).map(|r| {
r.map_err(|msg| FMError::Unknown {
code: ffi::status::UNKNOWN,
message: msg,
})
.and_then(|json| {
let response: AsyncBridgeStructuredResponse = serde_json::from_str(&json)
.map_err(|e| FMError::DecodingFailure(e.to_string()))?;
Ok(SessionResponse {
content: GeneratedContent::from_bridge_payload(response.content, true)?,
raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
transcript: Transcript::from_json_str(&response.transcript_json)?,
})
})
})
}
}
/// Future returned by `AsyncAdapter::from_name`; resolves to an `Adapter`.
pub struct AdapterInitFuture {
    /// Completion that yields the raw adapter pointer or an error message.
    inner: AsyncCompletionFuture<OpaquePtr>,
}

impl std::fmt::Debug for AdapterInitFuture {
    /// Prints only the type name; the inner completion is intentionally omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("AdapterInitFuture");
        builder.finish_non_exhaustive()
    }
}

impl Future for AdapterInitFuture {
    type Output = Result<Adapter, FMError>;

    /// Polls the inner completion. Every error message is surfaced as
    /// `FMError::AdapterInvalidName`; a successful pointer is wrapped in an
    /// `Adapter`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match Pin::new(&mut self.inner).poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(OpaquePtr(ptr))) => Poll::Ready(Ok(Adapter { ptr })),
            Poll::Ready(Err(message)) => Poll::Ready(Err(FMError::AdapterInvalidName(message))),
        }
    }
}
/// Future returned by `AsyncAdapter::compatibility`; resolves to the list of
/// strings decoded from the bridge's JSON array.
pub struct AdapterCompatibilityFuture {
    /// Completion that yields the raw JSON text or an error message.
    inner: AsyncCompletionFuture<String>,
}

impl std::fmt::Debug for AdapterCompatibilityFuture {
    /// Prints only the type name; the inner completion is intentionally omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("AdapterCompatibilityFuture");
        builder.finish_non_exhaustive()
    }
}

impl Future for AdapterCompatibilityFuture {
    type Output = Result<Vec<String>, FMError>;

    /// Polls the inner completion. Every error message is surfaced as
    /// `FMError::AdapterCompatibleNotFound`; a successful payload is parsed
    /// as a JSON array of strings, with parse failures reported as
    /// `FMError::DecodingFailure`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let json = match Pin::new(&mut self.inner).poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(message)) => {
                return Poll::Ready(Err(FMError::AdapterCompatibleNotFound(message)));
            }
            Poll::Ready(Ok(json)) => json,
        };

        Poll::Ready(match serde_json::from_str::<Vec<String>>(&json) {
            Ok(names) => Ok(names),
            Err(e) => Err(FMError::DecodingFailure(e.to_string())),
        })
    }
}
pub struct AsyncSession<'s> {
session: &'s crate::session::LanguageModelSession,
}
impl<'s> AsyncSession<'s> {
#[must_use]
pub fn new(session: &'s crate::session::LanguageModelSession) -> Self {
Self { session }
}
pub fn respond(&self, prompt: impl ToPrompt) -> Result<RespondFuture, FMError> {
let prompt = prompt.to_prompt()?;
let payload = build_text_request_json(&prompt, GenerationOptions::new())?;
let session_ptr = self.session.as_ptr();
let (future, ctx) = AsyncCompletion::create();
unsafe {
ffi::fm_session_respond_request_json(
session_ptr,
payload.as_ptr(),
ctx,
respond_async_cb,
);
}
Ok(RespondFuture { inner: future })
}
pub fn respond_with_options(
&self,
prompt: impl ToPrompt,
options: GenerationOptions,
) -> Result<RespondFuture, FMError> {
let prompt = prompt.to_prompt()?;
let payload = build_text_request_json(&prompt, options)?;
let session_ptr = self.session.as_ptr();
let (future, ctx) = AsyncCompletion::create();
unsafe {
ffi::fm_session_respond_request_json(
session_ptr,
payload.as_ptr(),
ctx,
respond_async_cb,
);
}
Ok(RespondFuture { inner: future })
}
pub fn respond_generating(
&self,
prompt: impl ToPrompt,
schema: &GenerationSchema,
include_schema_in_prompt: bool,
options: GenerationOptions,
) -> Result<RespondGeneratingFuture, FMError> {
let prompt = prompt.to_prompt()?;
let payload =
build_structured_request_json(&prompt, options, schema, include_schema_in_prompt)?;
let session_ptr = self.session.as_ptr();
let (future, ctx) = AsyncCompletion::create();
unsafe {
ffi::fm_session_respond_request_json(
session_ptr,
payload.as_ptr(),
ctx,
respond_async_cb,
);
}
Ok(RespondGeneratingFuture { inner: future })
}
}
impl std::fmt::Debug for AsyncSession<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AsyncSession").finish_non_exhaustive()
}
}
/// Namespace for the asynchronous adapter entry points.
pub struct AsyncAdapter;

impl AsyncAdapter {
    /// Starts asynchronous creation of the adapter called `name`.
    ///
    /// # Errors
    ///
    /// Returns `FMError::InvalidArgument` if `name` contains a NUL byte.
    pub fn from_name(name: &str) -> Result<AdapterInitFuture, FMError> {
        let c_name = Self::name_to_cstring(name)?;
        let (inner, ctx) = AsyncCompletion::create();
        unsafe {
            ffi::fm_adapter_create_from_name_async(c_name.as_ptr(), ctx, adapter_init_async_cb);
        }
        Ok(AdapterInitFuture { inner })
    }

    /// Starts an asynchronous compatibility query for the adapter called `name`.
    ///
    /// # Errors
    ///
    /// Returns `FMError::InvalidArgument` if `name` contains a NUL byte.
    pub fn compatibility(name: &str) -> Result<AdapterCompatibilityFuture, FMError> {
        let c_name = Self::name_to_cstring(name)?;
        let (inner, ctx) = AsyncCompletion::create();
        unsafe {
            ffi::fm_adapter_compatibility_async(c_name.as_ptr(), ctx, adapter_compat_async_cb);
        }
        Ok(AdapterCompatibilityFuture { inner })
    }

    /// Converts an adapter name into a C string, rejecting interior NUL bytes.
    fn name_to_cstring(name: &str) -> Result<CString, FMError> {
        match CString::new(name) {
            Ok(c_name) => Ok(c_name),
            Err(e) => Err(FMError::InvalidArgument(format!(
                "NUL byte in adapter name: {e}"
            ))),
        }
    }
}
fn build_text_request_json(
prompt: &Prompt,
options: GenerationOptions,
) -> Result<CString, FMError> {
build_request_json_inner(prompt, options, None, true)
}
/// Serializes a structured request that carries `schema` alongside the prompt.
///
/// # Errors
///
/// Propagates any error from `build_request_json_inner`.
fn build_structured_request_json(
    prompt: &Prompt,
    options: GenerationOptions,
    schema: &GenerationSchema,
    include_schema_in_prompt: bool,
) -> Result<CString, FMError> {
    let with_schema = Some(schema);
    build_request_json_inner(prompt, options, with_schema, include_schema_in_prompt)
}
/// Builds the NUL-terminated JSON request handed to the bridge.
///
/// The document has four top-level keys: `prompt`, `options` (temperature,
/// `maximumResponseTokens`, and a `sampling` object), `schemaJSON`, and
/// `includeSchemaInPrompt`. When a schema is present, the flag is whatever
/// `GenerationSchema::effective_include_schema_in_prompt` returns for the
/// requested value; otherwise the requested value is used as is.
///
/// # Errors
///
/// Returns `FMError::InvalidArgument` if the request cannot be serialized to
/// JSON or if the serialized text contains a NUL byte.
fn build_request_json_inner(
    prompt: &Prompt,
    options: GenerationOptions,
    schema: Option<&GenerationSchema>,
    include_schema_in_prompt: bool,
) -> Result<CString, FMError> {
    use crate::generation::SamplingMode;
    use serde_json::json;

    // Only the top-k and top-p modes carry a parameter and a seed.
    let sampling_json = match options.sampling() {
        SamplingMode::Default => json!({ "mode": "default" }),
        SamplingMode::Greedy => json!({ "mode": "greedy" }),
        SamplingMode::TopK(top_k) => json!({
            "mode": "top_k",
            "topK": top_k,
            "seed": options.sampling_seed(),
        }),
        SamplingMode::TopP(top_p) => json!({
            "mode": "top_p",
            "topP": top_p,
            "seed": options.sampling_seed(),
        }),
    };

    // The schema, when present, has the final say on the prompt flag.
    let schema_in_prompt = match schema {
        Some(schema) => schema.effective_include_schema_in_prompt(include_schema_in_prompt),
        None => include_schema_in_prompt,
    };

    let prompt_json = prompt.to_bridge_value();
    let options_json = json!({
        "temperature": options.temperature(),
        "maximumResponseTokens": options.maximum_response_tokens(),
        "sampling": sampling_json,
    });
    let request = json!({
        "prompt": prompt_json,
        "options": options_json,
        "schemaJSON": schema.map(GenerationSchema::bridge_request_json),
        "includeSchemaInPrompt": schema_in_prompt,
    });

    let text = match serde_json::to_string(&request) {
        Ok(text) => text,
        Err(e) => {
            return Err(FMError::InvalidArgument(format!(
                "request not JSON-serializable: {e}"
            )));
        }
    };

    match CString::new(text) {
        Ok(c_text) => Ok(c_text),
        Err(e) => Err(FMError::InvalidArgument(format!(
            "request JSON contains NUL: {e}"
        ))),
    }
}