use core::ffi::{c_char, c_void};
use std::ffi::CString;
use std::path::Path;
use std::ptr;
use std::sync::mpsc;
use serde_json::Value;
#[cfg(feature = "async")]
use doom_fish_utils::completion::{error_from_cstr, AsyncCompletion};
use crate::error::{from_swift, FMError, Unavailability};
use crate::ffi;
/// Translates the integer availability code reported by the Swift bridge into an
/// [`Availability`].
///
/// `0` means available; `1`, `2`, `3` and `-1` map to specific unavailability reasons, and any
/// other value is reported as [`Unavailability::Unknown`].
fn availability_from_code(code: i32) -> Availability {
    let reason = match code {
        0 => return Availability::Available,
        -1 => Unavailability::OsTooOld,
        1 => Unavailability::DeviceNotEligible,
        2 => Unavailability::AppleIntelligenceNotEnabled,
        3 => Unavailability::ModelNotReady,
        _ => Unavailability::Unknown,
    };
    Availability::Unavailable(reason)
}
/// Copies a bridge-allocated C string into an owned `String` and then frees the original.
///
/// A null pointer yields an empty string. Invalid UTF-8 is replaced with U+FFFD. A non-null
/// `ptr` is passed to `ffi::fm_string_free` after copying, so the caller must not reuse it.
fn owned_string(ptr: *mut c_char) -> String {
    if ptr.is_null() {
        String::new()
    } else {
        // SAFETY: `ptr` is non-null and assumed to be a NUL-terminated string produced by the
        // bridge; its contents are copied before it is released below.
        let copied = String::from(unsafe { core::ffi::CStr::from_ptr(ptr) }.to_string_lossy());
        // SAFETY: ownership of `ptr` was transferred to us and it is not touched afterwards.
        unsafe { ffi::fm_string_free(ptr) };
        copied
    }
}
/// Like `owned_string`, but a null pointer becomes the empty JSON array `"[]"`.
///
/// This lets callers feed the result straight into `serde_json` when decoding list responses.
fn json_string(ptr: *mut c_char) -> String {
    if ptr.is_null() {
        "[]".to_owned()
    } else {
        owned_string(ptr)
    }
}
/// Shared implementation behind both `token_count` methods.
///
/// `model_ptr` is the model handle as an address (`0` for no handle). The prompt is passed to
/// `ffi::fm_system_model_token_count_prompt_async`, and the string delivered to
/// `token_count_async_cb` is parsed as a `usize`.
///
/// # Errors
///
/// - [`FMError::InvalidArgument`] if `prompt` contains an interior NUL byte.
/// - [`FMError::Unknown`] if the bridge completes with an error message.
/// - [`FMError::DecodingFailure`] if the returned text is not a valid `usize`.
#[cfg(feature = "async")]
async fn token_count_inner(model_ptr: usize, prompt: &str) -> Result<usize, FMError> {
    let c_prompt = match CString::new(prompt) {
        Ok(c_prompt) => c_prompt,
        Err(error) => {
            return Err(FMError::InvalidArgument(format!(
                "prompt contains an interior NUL byte: {error}"
            )));
        }
    };

    let (completion, context) = AsyncCompletion::<String>::create();
    // SAFETY: `c_prompt` stays alive until this function returns (i.e. across the await below),
    // and `context` is completed exactly once by `token_count_async_cb`.
    unsafe {
        ffi::fm_system_model_token_count_prompt_async(
            model_ptr as *mut c_void,
            c_prompt.as_ptr(),
            context,
            token_count_async_cb,
        );
    }

    let raw_count = match completion.await {
        Ok(raw_count) => raw_count,
        Err(message) => {
            return Err(FMError::Unknown {
                code: ffi::status::UNKNOWN,
                message,
            });
        }
    };

    raw_count.parse::<usize>().map_err(|error| {
        FMError::DecodingFailure(format!(
            "token count bridge returned invalid integer: {error}"
        ))
    })
}
/// C completion callback for `ffi::fm_system_model_token_count_prompt_async`.
///
/// Completes the [`AsyncCompletion`] behind `ctx` as follows:
/// - If `error` is non-null, it fails with the message decoded from `error`.
/// - Otherwise, if `result` is non-null, it succeeds with the (lossily decoded) contents of
///   `result`, which is then freed with `ffi::fm_string_free`.
/// - If both are null, it fails with `"null token count result"`.
///
/// # Safety
///
/// `ctx` must be the context returned by `AsyncCompletion::<String>::create`. `result` and
/// `error` must each be null or point at a NUL-terminated string.
#[cfg(feature = "async")]
unsafe extern "C" fn token_count_async_cb(
    result: *mut c_void,
    error: *const c_char,
    ctx: *mut c_void,
) {
    if error.is_null() {
        if result.is_null() {
            unsafe {
                AsyncCompletion::<String>::complete_err(ctx, "null token count result".into());
            }
            return;
        }
        let count_text = result.cast::<c_char>();
        // SAFETY: `count_text` is non-null and NUL-terminated per the contract above; it is
        // copied before being released.
        let value =
            String::from(unsafe { core::ffi::CStr::from_ptr(count_text) }.to_string_lossy());
        unsafe { ffi::fm_string_free(count_text) };
        unsafe { AsyncCompletion::<String>::complete_ok(ctx, value) };
    } else {
        let message = unsafe { error_from_cstr(error) };
        unsafe { AsyncCompletion::<String>::complete_err(ctx, message) };
    }
}
/// Entry point for the on-device system language model exposed by the Swift bridge.
///
/// The associated functions query the bridge without holding a model handle. Use
/// [`SystemLanguageModel::default_model`], [`SystemLanguageModel::with_use_case`] or
/// [`SystemLanguageModel::with_adapter`] to obtain a [`ConfiguredSystemLanguageModel`].
#[derive(Debug, Clone, Copy)]
pub struct SystemLanguageModel;

impl SystemLanguageModel {
    /// Returns `true` if the bridge reports the system model as available.
    #[must_use]
    pub fn is_available() -> bool {
        // SAFETY: argument-less bridge query.
        unsafe { ffi::fm_system_model_is_available() }
    }

    /// Returns whether the system model is available and, if not, the reason why.
    #[must_use]
    pub fn availability() -> Availability {
        // SAFETY: argument-less bridge query.
        let code = unsafe { ffi::fm_system_model_availability_code() };
        availability_from_code(code)
    }

    /// Creates a handle to the default system model.
    ///
    /// Returns `None` if the bridge hands back a null handle.
    #[must_use]
    pub fn default_model() -> Option<ConfiguredSystemLanguageModel> {
        // SAFETY: argument-less bridge call; a non-null result is owned by the returned
        // handle, which releases it on drop.
        let ptr = unsafe { ffi::fm_system_model_create_default() };
        (!ptr.is_null()).then_some(ConfiguredSystemLanguageModel { ptr })
    }

    /// Creates a system model configured for `use_case` with the given `guardrails`.
    ///
    /// # Errors
    ///
    /// Returns the error built by `from_swift` with `ffi::status::MODEL_UNAVAILABLE` (and the
    /// bridge's error message, if any) when the bridge fails to create the model.
    pub fn with_use_case(
        use_case: UseCase,
        guardrails: Guardrails,
    ) -> Result<ConfiguredSystemLanguageModel, FMError> {
        let mut error: *mut c_char = ptr::null_mut();
        // SAFETY: `error` is a valid out-pointer for the duration of the call.
        let ptr = unsafe {
            ffi::fm_system_model_create(use_case.as_ffi(), guardrails.as_ffi(), &mut error)
        };
        if ptr.is_null() {
            return Err(from_swift(ffi::status::MODEL_UNAVAILABLE, error));
        }
        Ok(ConfiguredSystemLanguageModel { ptr })
    }

    /// Creates a system model that uses `adapter`, with the given `guardrails`.
    ///
    /// # Errors
    ///
    /// Returns the error built by `from_swift` with `ffi::status::MODEL_UNAVAILABLE` (and the
    /// bridge's error message, if any) when the bridge fails to create the model.
    pub fn with_adapter(
        adapter: &Adapter,
        guardrails: Guardrails,
    ) -> Result<ConfiguredSystemLanguageModel, FMError> {
        let mut error: *mut c_char = ptr::null_mut();
        // SAFETY: `adapter.ptr` is a live adapter handle borrowed for the call; `error` is a
        // valid out-pointer.
        let ptr = unsafe {
            ffi::fm_system_model_create_with_adapter(adapter.ptr, guardrails.as_ffi(), &mut error)
        };
        if ptr.is_null() {
            return Err(from_swift(ffi::status::MODEL_UNAVAILABLE, error));
        }
        Ok(ConfiguredSystemLanguageModel { ptr })
    }

    /// Returns the language identifiers the bridge reports when queried without a model handle.
    ///
    /// A null or undecodable JSON response yields an empty list.
    #[must_use]
    pub fn supported_languages() -> Vec<String> {
        // SAFETY: a null model handle is passed; the returned string is freed by `json_string`.
        let json = unsafe { ffi::fm_system_model_supported_languages_json(ptr::null_mut()) };
        serde_json::from_str(&json_string(json)).unwrap_or_default()
    }

    /// Returns `true` if the bridge reports `locale_identifier` as supported, querying without a
    /// model handle.
    ///
    /// An identifier that contains an interior NUL byte cannot cross the C boundary and is
    /// reported as unsupported.
    #[must_use]
    pub fn supports_locale(locale_identifier: &str) -> bool {
        CString::new(locale_identifier).is_ok_and(|locale| {
            // SAFETY: `locale` is a valid NUL-terminated string that outlives the call.
            unsafe { ffi::fm_system_model_supports_locale(ptr::null_mut(), locale.as_ptr()) }
        })
    }

    /// Counts the tokens in `prompt`, querying the bridge without a model handle.
    ///
    /// # Errors
    ///
    /// - [`FMError::InvalidArgument`] if `prompt` contains an interior NUL byte.
    /// - [`FMError::Unknown`] if the bridge reports an error.
    /// - [`FMError::DecodingFailure`] if the bridge's result is not a valid `usize`.
    #[cfg(feature = "async")]
    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
    pub async fn token_count(prompt: &str) -> Result<usize, FMError> {
        token_count_inner(ptr::null_mut::<c_void>() as usize, prompt).await
    }
}
/// An owned handle to a system language model created by the Swift bridge.
///
/// The underlying object is released with `ffi::fm_object_release` when the handle is dropped.
pub struct ConfiguredSystemLanguageModel {
    pub(crate) ptr: *mut c_void,
}

impl ConfiguredSystemLanguageModel {
    /// Returns whether this model is available and, if not, the reason why.
    #[must_use]
    pub fn availability(&self) -> Availability {
        // SAFETY: `self.ptr` is the handle owned by `self` and is only released on drop.
        availability_from_code(unsafe { ffi::fm_system_model_availability_code_for(self.ptr) })
    }

    /// Returns `true` if [`Self::availability`] is [`Availability::Available`].
    #[must_use]
    pub fn is_available(&self) -> bool {
        matches!(self.availability(), Availability::Available)
    }

    /// Returns the language identifiers this model supports.
    ///
    /// A null or undecodable JSON response yields an empty list.
    #[must_use]
    pub fn supported_languages(&self) -> Vec<String> {
        // SAFETY: `self.ptr` is live; the returned string is freed by `json_string`.
        let json = unsafe { ffi::fm_system_model_supported_languages_json(self.ptr) };
        serde_json::from_str(&json_string(json)).unwrap_or_default()
    }

    /// Returns `true` if this model supports `locale_identifier`.
    ///
    /// An identifier that contains an interior NUL byte cannot cross the C boundary and is
    /// reported as unsupported.
    #[must_use]
    pub fn supports_locale(&self, locale_identifier: &str) -> bool {
        CString::new(locale_identifier).is_ok_and(|locale| {
            // SAFETY: `self.ptr` is live and `locale` outlives the call.
            unsafe { ffi::fm_system_model_supports_locale(self.ptr, locale.as_ptr()) }
        })
    }

    /// Counts the tokens in `prompt` using this model.
    ///
    /// # Errors
    ///
    /// - [`FMError::InvalidArgument`] if `prompt` contains an interior NUL byte.
    /// - [`FMError::Unknown`] if the bridge reports an error.
    /// - [`FMError::DecodingFailure`] if the bridge's result is not a valid `usize`.
    #[cfg(feature = "async")]
    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
    // The future borrows `self`, which holds a raw pointer and is therefore not `Sync`, so the
    // future cannot be `Send`.
    #[allow(clippy::future_not_send)]
    pub async fn token_count(&self, prompt: &str) -> Result<usize, FMError> {
        let model_ptr = self.ptr as usize;
        token_count_inner(model_ptr, prompt).await
    }
}

impl Drop for ConfiguredSystemLanguageModel {
    fn drop(&mut self) {
        if !self.ptr.is_null() {
            // SAFETY: `self.ptr` is owned by this handle and released exactly once here.
            unsafe { ffi::fm_object_release(self.ptr) };
        }
    }
}

impl core::fmt::Debug for ConfiguredSystemLanguageModel {
    /// Formats the model together with its current availability, which is queried from the
    /// bridge on every call.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("ConfiguredSystemLanguageModel")
            .field("availability", &self.availability())
            .finish()
    }
}
/// The task a configured system model is specialised for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UseCase {
    /// General-purpose use (bridge code `0`).
    General = 0,
    /// Content tagging (bridge code `1`).
    ContentTagging = 1,
}

impl UseCase {
    /// Returns the integer code the Swift bridge expects for this use case.
    ///
    /// The discriminants above are the bridge codes, so a plain cast suffices.
    const fn as_ffi(self) -> i32 {
        self as i32
    }
}
/// The guardrail configuration applied to a configured system model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Guardrails {
    /// The default guardrails (bridge code `0`).
    Default = 0,
    /// Permissive guardrails for content transformations (bridge code `1`).
    PermissiveContentTransformations = 1,
}

impl Guardrails {
    /// Returns the integer code the Swift bridge expects for this guardrail setting.
    ///
    /// The discriminants above are the bridge codes, so a plain cast suffices.
    const fn as_ffi(self) -> i32 {
        self as i32
    }
}
/// A handle to a model adapter loaded through the Swift bridge.
pub struct Adapter {
    pub(crate) ptr: *mut c_void,
}

impl Adapter {
    /// Loads an adapter from an asset file at `path`.
    ///
    /// The path is converted to UTF-8 lossily before being handed to the bridge.
    ///
    /// # Errors
    ///
    /// - [`FMError::InvalidArgument`] if the path contains an interior NUL byte.
    /// - The error built by `from_swift` with `ffi::status::ADAPTER_INVALID_ASSET` if the bridge
    ///   fails to load the file.
    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, FMError> {
        let lossy_path = path.as_ref().to_string_lossy().into_owned();
        let c_path = match CString::new(lossy_path) {
            Ok(c_path) => c_path,
            Err(error) => {
                return Err(FMError::InvalidArgument(format!(
                    "adapter path contains an interior NUL byte: {error}"
                )));
            }
        };
        let mut bridge_error: *mut c_char = ptr::null_mut();
        // SAFETY: `c_path` outlives the call and `bridge_error` is a valid out-pointer.
        let handle =
            unsafe { ffi::fm_adapter_create_from_file(c_path.as_ptr(), &mut bridge_error) };
        if handle.is_null() {
            Err(from_swift(ffi::status::ADAPTER_INVALID_ASSET, bridge_error))
        } else {
            Ok(Self { ptr: handle })
        }
    }

    /// Loads an adapter by `name`.
    ///
    /// # Errors
    ///
    /// - [`FMError::InvalidArgument`] if `name` contains a NUL byte.
    /// - The error built by `from_swift` with `ffi::status::ADAPTER_INVALID_NAME` if the bridge
    ///   fails to load it.
    pub fn from_name(name: &str) -> Result<Self, FMError> {
        let c_name = match CString::new(name) {
            Ok(c_name) => c_name,
            Err(error) => {
                return Err(FMError::InvalidArgument(format!(
                    "adapter name contains NUL byte: {error}"
                )));
            }
        };
        let mut bridge_error: *mut c_char = ptr::null_mut();
        // SAFETY: `c_name` outlives the call and `bridge_error` is a valid out-pointer.
        let handle =
            unsafe { ffi::fm_adapter_create_from_name(c_name.as_ptr(), &mut bridge_error) };
        if handle.is_null() {
            Err(from_swift(ffi::status::ADAPTER_INVALID_NAME, bridge_error))
        } else {
            Ok(Self { ptr: handle })
        }
    }

    /// Asks the bridge to compile this adapter, blocking the current thread until the bridge's
    /// completion callback reports the outcome.
    ///
    /// # Errors
    ///
    /// Returns the error reported through the callback, or [`FMError::Unknown`] if the sender
    /// is dropped without a result being delivered.
    pub fn compile(&self) -> Result<(), FMError> {
        let (sender, receiver) = mpsc::channel::<Result<(), FMError>>();
        // Ownership of the boxed sender passes through the bridge to
        // `adapter_compile_trampoline`, which reclaims it with `Box::from_raw`.
        let context = Box::into_raw(Box::new(sender)).cast::<c_void>();
        // SAFETY: `self.ptr` is a live adapter handle; `context` has the type the trampoline
        // expects.
        unsafe { ffi::fm_adapter_compile(self.ptr, context, adapter_compile_trampoline) };
        match receiver.recv() {
            Ok(outcome) => outcome,
            Err(_) => Err(FMError::Unknown {
                code: ffi::status::UNKNOWN,
                message: "Swift bridge dropped the adapter compile callback".into(),
            }),
        }
    }

    /// Returns the adapter's creator-defined metadata as raw JSON text.
    ///
    /// Returns an empty string if the bridge returns a null pointer.
    #[must_use]
    pub fn creator_defined_metadata_json(&self) -> String {
        // SAFETY: `self.ptr` is live; the returned string is freed by `owned_string`.
        owned_string(unsafe { ffi::fm_adapter_metadata_json(self.ptr) })
    }

    /// Returns the adapter's creator-defined metadata parsed as JSON.
    ///
    /// # Errors
    ///
    /// Returns [`FMError::DecodingFailure`] if the metadata text is not valid JSON (including
    /// when it is empty).
    pub fn creator_defined_metadata(&self) -> Result<Value, FMError> {
        let raw_json = self.creator_defined_metadata_json();
        serde_json::from_str::<Value>(&raw_json)
            .map_err(|error| FMError::DecodingFailure(error.to_string()))
    }

    /// Returns the identifiers of adapters the bridge reports as compatible with `name`.
    ///
    /// A `name` containing a NUL byte, or a null or undecodable response, yields an empty list.
    #[must_use]
    pub fn compatible_adapter_identifiers(name: &str) -> Vec<String> {
        match CString::new(name) {
            Ok(c_name) => {
                // SAFETY: `c_name` outlives the call; the result is freed by `json_string`.
                let raw = unsafe { ffi::fm_adapter_compatible_identifiers_json(c_name.as_ptr()) };
                serde_json::from_str::<Vec<String>>(&json_string(raw)).unwrap_or_default()
            }
            Err(_) => Vec::new(),
        }
    }

    /// Asks the bridge to remove obsolete adapters.
    ///
    /// # Errors
    ///
    /// Returns the error built by `from_swift` from the bridge's status whenever that status is
    /// not `ffi::status::OK`.
    pub fn remove_obsolete_adapters() -> Result<(), FMError> {
        let mut bridge_error: *mut c_char = ptr::null_mut();
        // SAFETY: `bridge_error` is a valid out-pointer for the duration of the call.
        let status = unsafe { ffi::fm_adapter_remove_obsolete(&mut bridge_error) };
        if status == ffi::status::OK {
            Ok(())
        } else {
            Err(from_swift(status, bridge_error))
        }
    }
}
impl Drop for Adapter {
    /// Releases the adapter handle through the bridge; a null handle is left alone.
    fn drop(&mut self) {
        if self.ptr.is_null() {
            return;
        }
        // SAFETY: `self.ptr` is owned by this handle and released exactly once here.
        unsafe { ffi::fm_object_release(self.ptr) };
    }
}

impl core::fmt::Debug for Adapter {
    /// Formats as an opaque `Adapter { .. }`; the handle itself is not printed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("Adapter");
        builder.finish_non_exhaustive()
    }
}
unsafe extern "C" fn adapter_compile_trampoline(
context: *mut c_void,
response: *mut c_char,
error: *mut c_char,
status: i32,
) {
let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<(), FMError>>>());
if !response.is_null() {
unsafe { ffi::fm_string_free(response) };
}
let result = if status == ffi::status::OK {
Ok(())
} else {
Err(from_swift(status, error))
};
let _ = tx.send(result);
}
/// Whether the system language model can currently be used, as reported by the bridge.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Availability {
    /// The bridge reported the model as available (code `0`).
    Available,
    /// The bridge reported the model as unavailable; the payload carries the reason.
    Unavailable(Unavailability),
}