#![warn(
missing_docs,
rustdoc::bare_urls,
rustdoc::broken_intra_doc_links,
rustdoc::invalid_codeblock_attributes
)]
#![cfg_attr(
not(test),
deny(
clippy::expect_used,
clippy::panic,
clippy::todo,
clippy::unimplemented,
clippy::unwrap_used
)
)]
use std::convert::Infallible;
use std::ffi::{CStr, CString, NulError};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context as StdContext, Poll};
use futures_channel::mpsc;
use futures_core::Stream;
#[cfg(aimx_bridge)]
use std::ffi::{c_char, c_void};
#[cfg(aimx_bridge)]
use std::ptr::null;
#[cfg(aimx_bridge)]
use std::ptr::NonNull;
#[cfg(aimx_bridge)]
use std::sync::Arc;
#[cfg(aimx_bridge)]
use futures_channel::oneshot;
// Raw C ABI of the native bridge, only declared when the `aimx_bridge` cfg is
// set. The ownership notes below are inferred from how this file calls these
// functions — NOTE(review): confirm against the bridge's own header.
#[cfg(aimx_bridge)]
unsafe extern "C" {
// Returns a status code that `availability()` maps through the `FM_*`
// constants (0 = available, 1..=3 = specific reasons, anything else = unknown).
fn fm_availability_reason() -> i32;
// Creates a session from NUL-terminated instructions. A null return is turned
// into `Error::Unavailable(AvailabilityError::Unknown)` by `SessionHandle::from_raw`.
fn fm_session_create(instructions: *const c_char) -> *mut c_void;
// Like `fm_session_create`, plus a JSON array of tool descriptions and an
// opaque context pointer (an `Arc<ToolsContext>` target in this crate) that is
// passed back to `tool_dispatch`. NOTE(review): the dispatch arguments are
// presumably (tool ctx, tool name, JSON arguments, reply ctx, reply callback).
fn fm_session_create_with_tools(
instructions: *const c_char,
tools_json: *const c_char,
tool_ctx: *mut c_void,
tool_dispatch: extern "C" fn(
*mut c_void,
*const c_char,
*const c_char,
*mut c_void,
extern "C" fn(*mut c_void, *const c_char, *const c_char),
),
) -> *mut c_void;
// Destroys a session; called from `SessionHandle::drop`.
fn fm_session_destroy(handle: *mut c_void);
// Starts a one-shot request. `GenerationConfig` passes `-1` for an unset
// temperature / max token count. `ctx` is handed back to `callback`;
// NOTE(review): the two strings are presumably (response text, error text).
fn fm_session_respond(
handle: *mut c_void,
prompt: *const c_char,
temperature: f64,
max_tokens: i64,
ctx: *mut c_void,
callback: extern "C" fn(*mut c_void, *const c_char, *const c_char),
);
// Structured-output variant of `fm_session_respond`; `schema_json` is a
// serialized schema. Callers are outside this view.
fn fm_session_respond_structured(
handle: *mut c_void,
prompt: *const c_char,
schema_json: *const c_char,
temperature: f64,
max_tokens: i64,
ctx: *mut c_void,
callback: extern "C" fn(*mut c_void, *const c_char, *const c_char),
);
// Streaming variant. NOTE(review): `on_token` presumably fires per chunk and
// `on_done` once at the end (possibly with an error); callers are outside this view.
fn fm_session_stream(
handle: *mut c_void,
prompt: *const c_char,
temperature: f64,
max_tokens: i64,
ctx: *mut c_void,
on_token: extern "C" fn(*mut c_void, *const c_char),
on_done: extern "C" fn(*mut c_void, *const c_char),
);
}
/// Thread-safety marker that requires `Send` everywhere except on targets
/// whose `target_family` is `wasm`, where it has no supertrait.
///
/// Every (sized) type that satisfies the bound implements it automatically.
#[cfg(not(target_family = "wasm"))]
pub trait WasmCompatSend: Send {}
#[cfg(not(target_family = "wasm"))]
impl<T: Send> WasmCompatSend for T {}

/// Thread-safety marker that requires `Sync` everywhere except on targets
/// whose `target_family` is `wasm`, where it has no supertrait.
///
/// Every (sized) type that satisfies the bound implements it automatically.
#[cfg(not(target_family = "wasm"))]
pub trait WasmCompatSync: Sync {}
#[cfg(not(target_family = "wasm"))]
impl<T: Sync> WasmCompatSync for T {}

/// `wasm` variant of [`WasmCompatSend`]: no `Send` requirement.
#[cfg(target_family = "wasm")]
pub trait WasmCompatSend {}
#[cfg(target_family = "wasm")]
impl<T> WasmCompatSend for T {}

/// `wasm` variant of [`WasmCompatSync`]: no `Sync` requirement.
#[cfg(target_family = "wasm")]
pub trait WasmCompatSync {}
#[cfg(target_family = "wasm")]
impl<T> WasmCompatSync for T {}
/// Declares a transparent `String` newtype with the crate's standard set of
/// constructors, conversions, comparisons against `&str`, and serde support.
///
/// Attributes written before the type name (including `///` doc comments)
/// are forwarded onto the generated struct.
macro_rules! string_newtype {
    ($(#[$meta:meta])* $name:ident) => {
        $(#[$meta])*
        #[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
        #[serde(transparent)]
        pub struct $name(String);

        impl $name {
            /// Wraps any value convertible into a `String`.
            pub fn new(value: impl Into<String>) -> Self {
                Self(value.into())
            }

            /// Borrows the wrapped text.
            pub fn as_str(&self) -> &str {
                self.0.as_str()
            }

            /// Consumes the wrapper and returns the owned `String`.
            pub fn into_string(self) -> String {
                let Self(inner) = self;
                inner
            }

            /// Returns `true` when the wrapped text has zero length.
            pub fn is_empty(&self) -> bool {
                self.as_str().is_empty()
            }
        }

        impl From<String> for $name {
            fn from(value: String) -> Self {
                Self::new(value)
            }
        }

        impl From<&str> for $name {
            fn from(value: &str) -> Self {
                Self::new(value)
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                &self.0
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str(&self.0)
            }
        }

        impl PartialEq<&str> for $name {
            fn eq(&self, other: &&str) -> bool {
                self.0 == **other
            }
        }

        impl PartialEq<$name> for &str {
            fn eq(&self, other: &$name) -> bool {
                other.0 == **self
            }
        }
    };
}
string_newtype!(
    /// Raw system-instructions text, before validation into [`SystemInstructions`].
    InstructionsText
);
string_newtype!(
    /// Raw prompt text that can be converted into a [`Prompt`].
    PromptText
);
string_newtype!(
    /// Text returned by the model for a request.
    ResponseText
);
/// Alias for [`ResponseText`], used by the `generate*` methods.
pub type GeneratedText = ResponseText;

string_newtype!(
    /// Name of a [`GenerationSchema`].
    GenerationSchemaName
);
/// Alias for [`GenerationSchemaName`].
pub type ResponseSchemaName = GenerationSchemaName;
/// Alias for [`GenerationSchemaName`].
pub type SchemaName = GenerationSchemaName;

string_newtype!(
    /// Name of a single [`GenerationSchemaProperty`].
    GenerationSchemaPropertyName
);
/// Alias for [`GenerationSchemaPropertyName`].
pub type ResponseFieldName = GenerationSchemaPropertyName;
/// Alias for [`GenerationSchemaPropertyName`].
pub type SchemaPropertyName = GenerationSchemaPropertyName;

string_newtype!(
    /// Optional human-readable description attached to a schema or property.
    SchemaDescription
);
string_newtype!(
    /// Name under which a tool is registered and dispatched.
    ToolName
);
string_newtype!(
    /// Human-readable description of a tool.
    ToolDescription
);
string_newtype!(
    /// Successful output of a tool handler.
    ToolOutput
);
/// A prompt that has been checked to contain no interior NUL bytes.
///
/// The text is stored both as a `String` and as a NUL-terminated `CString`,
/// so it can be handed to the native bridge as a C string pointer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Prompt {
    text: String,
    c_text: CString,
}

impl Prompt {
    /// Builds a prompt from any string-like value.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NullByte`] if the text contains a NUL byte.
    pub fn new(value: impl Into<String>) -> Result<Self, Error> {
        let text: String = value.into();
        let c_text = CString::new(text.as_bytes())?;
        Ok(Prompt { text, c_text })
    }

    /// Borrows the prompt text.
    pub fn as_str(&self) -> &str {
        self.text.as_str()
    }

    /// Pointer to the NUL-terminated copy; valid while `self` is alive.
    #[cfg(aimx_bridge)]
    fn as_ptr(&self) -> *const c_char {
        self.c_text.as_ptr()
    }
}

impl TryFrom<&str> for Prompt {
    type Error = Error;

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        Prompt::new(value)
    }
}

impl TryFrom<String> for Prompt {
    type Error = Error;

    fn try_from(value: String) -> Result<Self, Self::Error> {
        Prompt::new(value)
    }
}

impl TryFrom<PromptText> for Prompt {
    type Error = Error;

    fn try_from(value: PromptText) -> Result<Self, Self::Error> {
        Prompt::new(value.into_string())
    }
}

impl AsRef<str> for Prompt {
    fn as_ref(&self) -> &str {
        &self.text
    }
}

/// Alias for [`Prompt`].
pub type PromptInput = Prompt;
/// System instructions for a session, checked to contain no NUL bytes.
///
/// Stored both as a `String` and as a NUL-terminated `CString` for the
/// native bridge. The default value is empty instructions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SystemInstructions {
    text: String,
    c_text: CString,
}

impl SystemInstructions {
    /// Builds instructions from any string-like value.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NullByte`] if the text contains a NUL byte.
    pub fn new(value: impl Into<String>) -> Result<Self, Error> {
        let text: String = value.into();
        let c_text = CString::new(text.as_bytes())?;
        Ok(SystemInstructions { text, c_text })
    }

    /// Empty instructions; infallible because an empty string has no NUL byte.
    pub fn empty() -> Self {
        SystemInstructions {
            text: String::default(),
            c_text: CString::default(),
        }
    }

    /// Borrows the instruction text.
    pub fn as_str(&self) -> &str {
        self.text.as_str()
    }

    /// Pointer to the NUL-terminated copy; valid while `self` is alive.
    #[cfg(aimx_bridge)]
    fn as_ptr(&self) -> *const c_char {
        self.c_text.as_ptr()
    }
}

impl Default for SystemInstructions {
    fn default() -> Self {
        SystemInstructions::empty()
    }
}

impl TryFrom<&str> for SystemInstructions {
    type Error = Error;

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        SystemInstructions::new(value)
    }
}

impl TryFrom<String> for SystemInstructions {
    type Error = Error;

    fn try_from(value: String) -> Result<Self, Self::Error> {
        SystemInstructions::new(value)
    }
}

impl TryFrom<InstructionsText> for SystemInstructions {
    type Error = Error;

    fn try_from(value: InstructionsText) -> Result<Self, Self::Error> {
        SystemInstructions::new(value.into_string())
    }
}

/// Alias for [`SystemInstructions`].
pub type Instructions = SystemInstructions;
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Temperature(f64);
impl Temperature {
pub const MIN: f64 = 0.0;
pub const MAX: f64 = 2.0;
pub fn new(value: f64) -> Result<Self, Error> {
if (Self::MIN..=Self::MAX).contains(&value) {
Ok(Self(value))
} else {
Err(Error::InvalidTemperature(value))
}
}
pub fn as_f64(self) -> f64 {
self.0
}
}
impl TryFrom<f64> for Temperature {
type Error = Error;
fn try_from(value: f64) -> Result<Self, Self::Error> {
Self::new(value)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct MaxTokens(usize);
impl MaxTokens {
pub const MAX: usize = i64::MAX as usize;
pub fn new(value: usize) -> Result<Self, Error> {
if value <= Self::MAX {
Ok(Self(value))
} else {
Err(Error::InvalidMaxTokens(value))
}
}
pub fn get(self) -> usize {
self.0
}
}
impl TryFrom<usize> for MaxTokens {
type Error = Error;
fn try_from(value: usize) -> Result<Self, Self::Error> {
Self::new(value)
}
}
/// Failure reported by the model while generating a response.
///
/// `Display` prints the message verbatim.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
#[error("{message}")]
pub struct GenerationError {
    message: String,
}

impl GenerationError {
    /// Wraps a failure message.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        GenerationError { message }
    }

    /// Borrows the failure message.
    pub fn as_str(&self) -> &str {
        self.message.as_str()
    }
}

impl From<String> for GenerationError {
    fn from(message: String) -> Self {
        GenerationError::new(message)
    }
}

impl From<&str> for GenerationError {
    fn from(message: &str) -> Self {
        GenerationError::new(message)
    }
}
/// Failure returned by a tool handler, or synthesized when a handler panics
/// or an unknown tool is requested.
///
/// `Display` prints the message verbatim.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
#[error("{message}")]
pub struct ToolCallError {
    message: String,
}

impl ToolCallError {
    /// Wraps a failure message.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        ToolCallError { message }
    }

    /// Borrows the failure message.
    pub fn as_str(&self) -> &str {
        self.message.as_str()
    }
}

impl From<String> for ToolCallError {
    fn from(message: String) -> Self {
        ToolCallError::new(message)
    }
}

impl From<&str> for ToolCallError {
    fn from(message: &str) -> Self {
        ToolCallError::new(message)
    }
}
pub type ToolResult = Result<ToolOutput, ToolCallError>;
type ModelTextResult = Result<ResponseText, GenerationError>;
type StreamSender = mpsc::UnboundedSender<ModelTextResult>;
type StreamReceiver = mpsc::UnboundedReceiver<ModelTextResult>;
type ToolHandlerBox = Box<dyn ToolHandler>;
#[cfg(aimx_bridge)]
type ResponseSender = oneshot::Sender<ModelTextResult>;
#[cfg(aimx_bridge)]
type ResponseReceiver = oneshot::Receiver<ModelTextResult>;
#[cfg(aimx_bridge)]
type ToolResultCallback = extern "C" fn(*mut c_void, *const c_char, *const c_char);
/// Reason the on-device model cannot be used.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum AvailabilityError {
    /// Bridge status code 1. Also reported when the crate is built without
    /// the native bridge.
    #[error("device is not eligible (requires Apple Silicon M1 or later)")]
    DeviceNotEligible,
    /// Bridge status code 2.
    #[error("Apple Intelligence is not enabled in System Settings")]
    NotEnabled,
    /// Bridge status code 3.
    #[error("the on-device model is not ready yet")]
    ModelNotReady,
    /// Any other status code, or a null session handle from the bridge.
    #[error("unknown availability state")]
    Unknown,
}

/// Alias for [`AvailabilityError`].
pub type UnavailabilityReason = AvailabilityError;
/// Crate-wide error type.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The model cannot be used on this device or build.
    #[error("Apple Intelligence unavailable: {0}")]
    Unavailable(#[source] AvailabilityError),
    /// The model reported a failure while generating.
    #[error("generation error: {0}")]
    Generation(#[from] GenerationError),
    /// A string destined for the bridge contained a NUL byte.
    #[error("argument contains a null byte: {0}")]
    NullByte(#[from] NulError),
    /// A temperature outside `Temperature::MIN..=Temperature::MAX`.
    #[error("temperature {0} is out of range; expected 0.0 – 2.0")]
    InvalidTemperature(f64),
    /// A token limit above `MaxTokens::MAX`.
    #[error("max_tokens {0} is out of range; expected no more than i64::MAX")]
    InvalidMaxTokens(usize),
    /// JSON (de)serialization failed.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
    /// The named tool failed with the given error.
    #[error("tool '{name}' failed: {error}")]
    ToolError {
        /// Tool that failed.
        name: ToolName,
        /// The tool's failure.
        #[source]
        error: ToolCallError,
    },
}

// Lets infallible conversions (e.g. `Prompt: TryInto<Prompt>`) satisfy the
// `P::Error: Into<Error>` bounds used throughout this crate.
impl From<Infallible> for Error {
    fn from(never: Infallible) -> Self {
        match never {}
    }
}
/// Per-request sampling options.
///
/// Unset fields are forwarded to the bridge as a `-1` sentinel. Presumably
/// this lets the model choose its own default; confirm against the bridge.
#[derive(Debug, Default, Clone)]
pub struct GenerationOptions {
    temperature: Option<Temperature>,
    max_tokens: Option<MaxTokens>,
}

impl GenerationOptions {
    /// Options with nothing set.
    pub fn new() -> Self {
        GenerationOptions {
            temperature: None,
            max_tokens: None,
        }
    }

    /// Sets the temperature.
    pub fn temperature(self, temperature: Temperature) -> Self {
        GenerationOptions {
            temperature: Some(temperature),
            ..self
        }
    }

    /// Same as [`GenerationOptions::temperature`].
    pub fn with_temperature(self, temperature: Temperature) -> Self {
        self.temperature(temperature)
    }

    /// Validates and sets a raw temperature.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidTemperature`] for out-of-range values.
    pub fn try_temperature(self, temperature: f64) -> Result<Self, Error> {
        Temperature::new(temperature).map(|value| self.temperature(value))
    }

    /// Sets the maximum token count.
    pub fn max_tokens(self, max_tokens: MaxTokens) -> Self {
        GenerationOptions {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Same as [`GenerationOptions::max_tokens`].
    pub fn with_max_tokens(self, max_tokens: MaxTokens) -> Self {
        self.max_tokens(max_tokens)
    }

    /// Validates and sets a raw maximum token count.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidMaxTokens`] for values above `MaxTokens::MAX`.
    pub fn try_max_tokens(self, max_tokens: usize) -> Result<Self, Error> {
        MaxTokens::new(max_tokens).map(|value| self.max_tokens(value))
    }

    /// The configured temperature, if any.
    pub fn temperature_value(&self) -> Option<Temperature> {
        self.temperature
    }

    /// The configured maximum token count, if any.
    pub fn max_tokens_value(&self) -> Option<MaxTokens> {
        self.max_tokens
    }

    /// Checks that the options can be turned into a request configuration.
    pub fn validate(&self) -> Result<(), Error> {
        self.validated()?;
        Ok(())
    }

    // Converts into the internal configuration consumed by the bridge calls.
    fn validated(&self) -> Result<GenerationConfig, Error> {
        GenerationConfig::try_from(self)
    }
}
/// Validated snapshot of [`GenerationOptions`] in the shape the FFI needs.
///
/// The `ffi_*` accessors (and so the fields they read) are only used when
/// the native bridge is compiled in. Builds without `aimx_bridge` would
/// otherwise emit dead-code warnings.
#[cfg_attr(not(aimx_bridge), allow(dead_code))]
#[derive(Debug, Clone, Copy, Default)]
struct GenerationConfig {
    temperature: Option<Temperature>,
    max_tokens: Option<MaxTokens>,
}

#[cfg_attr(not(aimx_bridge), allow(dead_code))]
impl GenerationConfig {
    /// Temperature for the FFI call, or `-1.0` when unset.
    fn ffi_temperature(self) -> f64 {
        self.temperature.map_or(-1.0, Temperature::as_f64)
    }

    /// Maximum token count for the FFI call, or `-1` when unset.
    ///
    /// `MaxTokens` already guarantees the value fits in an `i64`. The checked
    /// conversion saturates rather than silently wrapping if that invariant
    /// were ever broken.
    fn ffi_max_tokens(self) -> i64 {
        self.max_tokens.map_or(-1, |max_tokens| {
            i64::try_from(max_tokens.get()).unwrap_or(i64::MAX)
        })
    }
}

impl TryFrom<&GenerationOptions> for GenerationConfig {
    type Error = Error;

    /// Copies the options. Currently infallible, because `Temperature` and
    /// `MaxTokens` are validated at construction.
    fn try_from(options: &GenerationOptions) -> Result<Self, Self::Error> {
        Ok(Self {
            temperature: options.temperature,
            max_tokens: options.max_tokens,
        })
    }
}
/// Primitive type of a [`GenerationSchemaProperty`].
///
/// Serialized in lowercase: `"string"`, `"integer"`, `"double"`, `"bool"`.
#[derive(Debug, Clone, serde::Serialize)]
#[serde(rename_all = "lowercase")]
pub enum GenerationSchemaPropertyType {
    /// Text value.
    String,
    /// Integer value.
    Integer,
    /// Floating-point value.
    Double,
    /// Boolean value.
    Bool,
}

/// Alias for [`GenerationSchemaPropertyType`].
pub type SchemaPropertyType = GenerationSchemaPropertyType;
/// Alias for [`GenerationSchemaPropertyType`].
pub type ResponseFieldType = GenerationSchemaPropertyType;
/// Whether a schema property must be present. Defaults to `Required`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum GenerationSchemaPropertyRequirement {
    /// The property must be present.
    #[default]
    Required,
    /// The property may be omitted.
    Optional,
}

impl GenerationSchemaPropertyRequirement {
    /// `true` for [`GenerationSchemaPropertyRequirement::Optional`].
    pub fn is_optional(self) -> bool {
        self == Self::Optional
    }

    /// `true` for [`GenerationSchemaPropertyRequirement::Required`].
    pub fn is_required(self) -> bool {
        // Only two variants exist, so "required" is exactly "not optional".
        !self.is_optional()
    }
}
/// One property of a [`GenerationSchema`].
#[derive(Debug, Clone)]
pub struct GenerationSchemaProperty {
    /// Property name (serialized as `name`).
    pub name: GenerationSchemaPropertyName,
    /// Optional description (serialized only when present).
    pub description: Option<SchemaDescription>,
    /// Primitive type (serialized as `type`).
    pub property_type: GenerationSchemaPropertyType,
    /// Presence requirement (serialized as the boolean `optional`).
    pub requirement: GenerationSchemaPropertyRequirement,
}

/// Alias for [`GenerationSchemaProperty`].
pub type ResponseField = GenerationSchemaProperty;
/// Alias for [`GenerationSchemaProperty`].
pub type SchemaProperty = GenerationSchemaProperty;

impl GenerationSchemaProperty {
    /// A required property with no description.
    pub fn new(
        name: impl Into<GenerationSchemaPropertyName>,
        property_type: GenerationSchemaPropertyType,
    ) -> Self {
        let name = name.into();
        GenerationSchemaProperty {
            name,
            description: None,
            property_type,
            requirement: GenerationSchemaPropertyRequirement::default(),
        }
    }

    /// Attaches a description.
    pub fn description(self, description: impl Into<SchemaDescription>) -> Self {
        GenerationSchemaProperty {
            description: Some(description.into()),
            ..self
        }
    }

    /// Marks the property as optional.
    pub fn optional(self) -> Self {
        GenerationSchemaProperty {
            requirement: GenerationSchemaPropertyRequirement::Optional,
            ..self
        }
    }

    /// Marks the property as required.
    pub fn required(self) -> Self {
        GenerationSchemaProperty {
            requirement: GenerationSchemaPropertyRequirement::Required,
            ..self
        }
    }
}
/// Serializes as `{ name, [description], type, optional }`. `description` is
/// emitted only when set, and `optional` is a boolean derived from `requirement`.
impl serde::Serialize for GenerationSchemaProperty {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        // `name`, `type`, and `optional` are always present.
        let len = 3 + usize::from(self.description.is_some());
        let mut fields = serializer.serialize_struct("GenerationSchemaProperty", len)?;
        fields.serialize_field("name", &self.name)?;
        if let Some(text) = self.description.as_ref() {
            fields.serialize_field("description", text)?;
        }
        fields.serialize_field("type", &self.property_type)?;
        let optional = self.requirement.is_optional();
        fields.serialize_field("optional", &optional)?;
        fields.end()
    }
}
/// A named object schema: an optional description plus a list of properties.
#[derive(Debug, Clone, serde::Serialize)]
pub struct GenerationSchema {
    /// Schema name.
    pub name: GenerationSchemaName,
    /// Optional description (omitted from the JSON when `None`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<SchemaDescription>,
    /// Properties in insertion order.
    pub properties: Vec<GenerationSchemaProperty>,
}

/// Alias for [`GenerationSchema`].
pub type ResponseSchema = GenerationSchema;
/// Alias for [`GenerationSchema`].
pub type Schema = GenerationSchema;

impl GenerationSchema {
    /// An empty schema with the given name.
    pub fn new(name: impl Into<GenerationSchemaName>) -> Self {
        GenerationSchema {
            name: name.into(),
            description: None,
            properties: Vec::new(),
        }
    }

    /// Attaches a description.
    pub fn description(self, description: impl Into<SchemaDescription>) -> Self {
        GenerationSchema {
            description: Some(description.into()),
            ..self
        }
    }

    /// Appends a property.
    pub fn property(mut self, property: GenerationSchemaProperty) -> Self {
        self.properties.extend(std::iter::once(property));
        self
    }
}
/// A tool the model may call: name, description, argument schema, and the
/// Rust handler invoked with the call's JSON arguments.
pub struct ToolDefinition {
    /// Name used to route tool calls to this definition.
    pub name: ToolName,
    /// Human-readable description sent to the bridge.
    pub description: ToolDescription,
    /// Schema of the tool's arguments.
    pub parameters: GenerationSchema,
    handler: ToolHandlerBox,
}

impl ToolDefinition {
    /// Creates a tool definition from its parts.
    pub fn new(
        name: impl Into<ToolName>,
        description: impl Into<ToolDescription>,
        parameters: GenerationSchema,
        handler: impl Fn(serde_json::Value) -> ToolResult + WasmCompatSend + WasmCompatSync + 'static,
    ) -> Self {
        ToolDefinition {
            name: name.into(),
            description: description.into(),
            parameters,
            handler: Box::new(handler),
        }
    }

    /// Starts a [`ToolDefinitionBuilder`]; finish it with `.handler(...)`.
    pub fn builder(
        name: impl Into<ToolName>,
        description: impl Into<ToolDescription>,
        parameters: GenerationSchema,
    ) -> ToolDefinitionBuilder {
        let name = name.into();
        let description = description.into();
        ToolDefinitionBuilder {
            name,
            description,
            parameters,
        }
    }

    /// Same as [`ToolDefinition::new`].
    pub fn from_handler(
        name: impl Into<ToolName>,
        description: impl Into<ToolDescription>,
        parameters: GenerationSchema,
        handler: impl Fn(serde_json::Value) -> ToolResult + WasmCompatSend + WasmCompatSync + 'static,
    ) -> Self {
        ToolDefinition::new(name, description, parameters, handler)
    }

    // JSON object describing this tool to the native bridge.
    #[cfg(aimx_bridge)]
    fn bridge_description(&self) -> serde_json::Value {
        let Self {
            name,
            description,
            parameters,
            ..
        } = self;
        serde_json::json!({
            "name": name.as_str(),
            "description": description.as_str(),
            "properties": &parameters.properties,
        })
    }
}
#[derive(Debug, Clone)]
pub struct ToolDefinitionBuilder {
name: ToolName,
description: ToolDescription,
parameters: GenerationSchema,
}
impl ToolDefinitionBuilder {
pub fn handler(
self,
handler: impl Fn(serde_json::Value) -> ToolResult + WasmCompatSend + WasmCompatSync + 'static,
) -> ToolDefinition {
ToolDefinition {
name: self.name,
description: self.description,
parameters: self.parameters,
handler: Box::new(handler),
}
}
}
/// Read-only view of a tool and a way to invoke it.
pub trait Tool: std::fmt::Debug + WasmCompatSend + WasmCompatSync {
    /// The tool's name.
    fn name(&self) -> &ToolName;
    /// The tool's description.
    fn description(&self) -> &ToolDescription;
    /// Schema of the tool's arguments.
    fn parameters(&self) -> &GenerationSchema;
    /// Invokes the tool with JSON arguments.
    fn call(&self, args: serde_json::Value) -> ToolResult;
}

impl Tool for ToolDefinition {
    fn name(&self) -> &ToolName {
        &self.name
    }

    fn description(&self) -> &ToolDescription {
        &self.description
    }

    fn parameters(&self) -> &GenerationSchema {
        &self.parameters
    }

    /// Runs the handler. A panic inside it is reported as a `ToolCallError`.
    fn call(&self, args: serde_json::Value) -> ToolResult {
        call_tool_handler(&*self.handler, args)
    }
}
// Object-safe wrapper over tool closures so they can be boxed as `dyn ToolHandler`.
trait ToolHandler: WasmCompatSend + WasmCompatSync {
    fn call(&self, args: serde_json::Value) -> ToolResult;
}

// Any suitably thread-safe `Fn(Value) -> ToolResult` closure is a handler.
impl<F> ToolHandler for F
where
    F: Fn(serde_json::Value) -> ToolResult + WasmCompatSend + WasmCompatSync,
{
    fn call(&self, args: serde_json::Value) -> ToolResult {
        (*self)(args)
    }
}
/// Invokes `handler` and converts an unwinding panic into a `ToolCallError`.
///
/// Panics are only caught when they unwind; with `panic = "abort"` the
/// process still aborts.
fn call_tool_handler(handler: &dyn ToolHandler, args: serde_json::Value) -> ToolResult {
    let outcome =
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || handler.call(args)));
    outcome.unwrap_or_else(|payload| {
        let detail = panic_payload_message(payload.as_ref());
        Err(ToolCallError::new(format!("tool handler panicked: {}", detail)))
    })
}
/// Extracts a human-readable message from a panic payload.
///
/// `panic!` payloads are either `&'static str` or `String`. Anything else
/// yields a fixed placeholder.
fn panic_payload_message(payload: &(dyn std::any::Any + Send)) -> String {
    payload
        .downcast_ref::<&'static str>()
        .map(|text| (*text).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "non-string panic payload".to_owned())
}
// Shows only the name and description. The boxed handler has no `Debug`
// impl, so the output is marked non-exhaustive.
impl std::fmt::Debug for ToolDefinition {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = f.debug_struct("ToolDefinition");
        view.field("name", &self.name);
        view.field("description", &self.description);
        view.finish_non_exhaustive()
    }
}
// Name-to-handler table shared with the native bridge, which receives a raw
// pointer to it when a tool-enabled session is created.
#[cfg(aimx_bridge)]
struct ToolsContext {
    tools: Vec<(ToolName, ToolHandlerBox)>,
}

#[cfg(aimx_bridge)]
impl ToolsContext {
    // Keeps only each definition's name and handler.
    fn from_definitions(tools: Vec<ToolDefinition>) -> Arc<Self> {
        let tools = tools
            .into_iter()
            .map(|ToolDefinition { name, handler, .. }| (name, handler))
            .collect();
        Arc::new(ToolsContext { tools })
    }

    // Dispatches to the first tool whose name matches, or reports an unknown tool.
    fn call(&self, name: &str, args: serde_json::Value) -> ToolResult {
        let Some((_, handler)) = self
            .tools
            .iter()
            .find(|(tool_name, _)| tool_name.as_str() == name)
        else {
            return Err(ToolCallError::new(format!("unknown tool: {name}")));
        };
        call_tool_handler(handler.as_ref(), args)
    }
}
// Status codes returned by `fm_availability_reason`. They are only consulted
// when the native bridge is compiled in. The `cfg_attr` silences the
// dead-code warnings they would otherwise trigger in builds without `aimx_bridge`.
#[cfg_attr(not(aimx_bridge), allow(dead_code))]
const FM_AVAILABLE: i32 = 0;
#[cfg_attr(not(aimx_bridge), allow(dead_code))]
const FM_DEVICE_NOT_ELIGIBLE: i32 = 1;
#[cfg_attr(not(aimx_bridge), allow(dead_code))]
const FM_NOT_ENABLED: i32 = 2;
#[cfg_attr(not(aimx_bridge), allow(dead_code))]
const FM_MODEL_NOT_READY: i32 = 3;
/// `true` when [`availability`] reports the model as usable.
pub fn is_available() -> bool {
    matches!(availability(), Ok(()))
}

/// Reports whether the on-device model can be used, and why not if it can't.
///
/// Without the `aimx_bridge` cfg this always returns
/// `Err(AvailabilityError::DeviceNotEligible)`.
pub fn availability() -> Result<(), AvailabilityError> {
    #[cfg(aimx_bridge)]
    {
        // SAFETY: the function takes no arguments and returns a plain `i32`.
        let reason = unsafe { fm_availability_reason() };
        let error = match reason {
            FM_AVAILABLE => return Ok(()),
            FM_DEVICE_NOT_ELIGIBLE => AvailabilityError::DeviceNotEligible,
            FM_NOT_ENABLED => AvailabilityError::NotEnabled,
            FM_MODEL_NOT_READY => AvailabilityError::ModelNotReady,
            _ => AvailabilityError::Unknown,
        };
        Err(error)
    }
    #[cfg(not(aimx_bridge))]
    {
        Err(AvailabilityError::DeviceNotEligible)
    }
}
/// Stateless entry point to the on-device model.
///
/// Each text-generation call builds a fresh [`LanguageModelSession`].
#[derive(Debug, Default, Clone, Copy)]
pub struct AppleIntelligenceModels {
    _private: (),
}

impl AppleIntelligenceModels {
    /// Creates the entry point (equivalent to `Default::default()`).
    pub fn new() -> Self {
        AppleIntelligenceModels { _private: () }
    }

    /// Same as the free `availability` function.
    pub fn availability(&self) -> Result<(), AvailabilityError> {
        availability()
    }

    /// `true` when the model is usable.
    pub fn is_available(&self) -> bool {
        matches!(self.availability(), Ok(()))
    }

    /// Starts a session builder.
    pub fn session(&self) -> LanguageModelSessionBuilder {
        LanguageModelSessionBuilder::default()
    }

    /// Same as [`AppleIntelligenceModels::session`].
    pub fn agent(&self) -> LanguageModelSessionBuilder {
        LanguageModelSessionBuilder::new()
    }

    /// Generates text with default options and returns it as a `String`.
    pub async fn respond<P>(&self, prompt: P) -> Result<String, Error>
    where
        P: TryInto<Prompt>,
        P::Error: Into<Error>,
    {
        self.generate_text(prompt)
            .await
            .map(ResponseText::into_string)
    }

    /// Generates text with default options.
    pub async fn generate<P>(&self, prompt: P) -> Result<GeneratedText, Error>
    where
        P: TryInto<Prompt>,
        P::Error: Into<Error>,
    {
        self.generate_text(prompt).await
    }

    /// Generates text with the given options.
    pub async fn generate_with_options<P>(
        &self,
        prompt: P,
        options: &GenerationOptions,
    ) -> Result<GeneratedText, Error>
    where
        P: TryInto<Prompt>,
        P::Error: Into<Error>,
    {
        self.generate_text_with_options(prompt, options).await
    }

    /// Generates text with default options.
    pub async fn complete<P>(&self, prompt: P) -> Result<ResponseText, Error>
    where
        P: TryInto<Prompt>,
        P::Error: Into<Error>,
    {
        self.generate(prompt).await
    }

    /// Generates text with default options.
    pub async fn generate_text<P>(&self, prompt: P) -> Result<ResponseText, Error>
    where
        P: TryInto<Prompt>,
        P::Error: Into<Error>,
    {
        self.generate_text_with_options(prompt, &GenerationOptions::new())
            .await
    }

    /// Generates text with the given options via the [`LanguageModel`] impl.
    pub async fn generate_text_with_options<P>(
        &self,
        prompt: P,
        options: &GenerationOptions,
    ) -> Result<ResponseText, Error>
    where
        P: TryInto<Prompt>,
        P::Error: Into<Error>,
    {
        <Self as LanguageModel>::generate_text_with_options(self, prompt, options.clone()).await
    }

    /// Streams text with default options.
    pub fn stream_text<P>(&self, prompt: P) -> Result<ResponseStream, Error>
    where
        P: TryInto<Prompt>,
        P::Error: Into<Error>,
    {
        self.stream_text_with_options(prompt, &GenerationOptions::new())
    }

    /// Streams text with the given options via the [`LanguageModel`] impl.
    pub fn stream_text_with_options<P>(
        &self,
        prompt: P,
        options: &GenerationOptions,
    ) -> Result<ResponseStream, Error>
    where
        P: TryInto<Prompt>,
        P::Error: Into<Error>,
    {
        <Self as LanguageModel>::stream_text_with_options(self, prompt, options.clone())
    }

    /// Same as [`AppleIntelligenceModels::stream_text`].
    pub fn stream_generate<P>(&self, prompt: P) -> Result<ResponseStream, Error>
    where
        P: TryInto<Prompt>,
        P::Error: Into<Error>,
    {
        self.stream_text(prompt)
    }

    /// Same as [`AppleIntelligenceModels::stream_text_with_options`].
    pub fn stream_generate_with_options<P>(
        &self,
        prompt: P,
        options: &GenerationOptions,
    ) -> Result<ResponseStream, Error>
    where
        P: TryInto<Prompt>,
        P::Error: Into<Error>,
    {
        self.stream_text_with_options(prompt, options)
    }
}

/// Alias for [`AppleIntelligenceModels`].
pub type SystemLanguageModel = AppleIntelligenceModels;
/// Alias for [`AppleIntelligenceModels`].
pub type FoundationModels = AppleIntelligenceModels;
/// Alias for [`AppleIntelligenceModels`].
pub type Client = AppleIntelligenceModels;
impl LanguageModel for AppleIntelligenceModels {
fn generate_text_with_options<P>(
&self,
prompt: P,
options: GenerationOptions,
) -> impl Future<Output = Result<ResponseText, Error>> + '_
where
P: TryInto<Prompt>,
P::Error: Into<Error>,
{
let prompt = prompt.try_into().map_err(Into::into);
let builder = self.session().options(options.clone());
async move {
let prompt = prompt?;
let session = builder.build()?;
session.generate_prompt_with_options(prompt, &options).await
}
}
fn stream_text_with_options<P>(
&self,
prompt: P,
options: GenerationOptions,
) -> Result<ResponseStream, Error>
where
P: TryInto<Prompt>,
P::Error: Into<Error>,
{
let prompt = prompt.try_into().map_err(Into::into)?;
let session = self.session().options(options.clone()).build()?;
session.stream_text_with_options(prompt, &options)
}
}
/// A model that can generate text from a prompt, either in one piece or as a stream.
///
/// Prompts are accepted as anything convertible into [`Prompt`]. Conversion
/// failures are reported as [`Error`].
pub trait LanguageModel {
    /// Starts a streaming generation with the given options.
    fn stream_text_with_options<P>(
        &self,
        prompt: P,
        options: GenerationOptions,
    ) -> Result<ResponseStream, Error>
    where
        P: TryInto<Prompt>,
        P::Error: Into<Error>;

    /// Generates a complete response with the given options.
    fn generate_text_with_options<P>(
        &self,
        prompt: P,
        options: GenerationOptions,
    ) -> impl Future<Output = Result<ResponseText, Error>> + '_
    where
        P: TryInto<Prompt>,
        P::Error: Into<Error>;
}
pub trait CompletionModel: LanguageModel {
fn completion<P>(
&self,
prompt: P,
options: GenerationOptions,
) -> impl Future<Output = Result<ResponseText, Error>> + '_
where
P: TryInto<Prompt>,
P::Error: Into<Error>,
{
self.generate_text_with_options(prompt, options)
}
fn stream_completion<P>(
&self,
prompt: P,
options: GenerationOptions,
) -> Result<ResponseStream, Error>
where
P: TryInto<Prompt>,
P::Error: Into<Error>,
{
self.stream_text_with_options(prompt, options)
}
}
impl<T> CompletionModel for T where T: LanguageModel {}
pub trait GenerateText: LanguageModel {
fn prompt<P>(&self, prompt: P) -> impl Future<Output = Result<ResponseText, Error>> + '_
where
P: TryInto<Prompt>,
P::Error: Into<Error>,
{
let prompt = prompt.try_into().map_err(Into::into);
async move {
let prompt = prompt?;
self.generate_text_with_options(prompt, GenerationOptions::default())
.await
}
}
}
impl<T> GenerateText for T where T: LanguageModel {}
#[derive(Debug)]
pub struct LanguageModelSessionBuilder {
instructions: InstructionsText,
tools: Vec<ToolDefinition>,
default_options: GenerationOptions,
}
impl LanguageModelSessionBuilder {
pub fn new() -> Self {
Self {
instructions: InstructionsText::new(""),
tools: Vec::new(),
default_options: GenerationOptions::default(),
}
}
pub fn instructions(mut self, instructions: impl Into<InstructionsText>) -> Self {
self.instructions = instructions.into();
self
}
pub fn preamble(self, instructions: impl Into<InstructionsText>) -> Self {
self.instructions(instructions)
}
pub fn tool(mut self, tool: ToolDefinition) -> Self {
self.tools.push(tool);
self
}
pub fn tools(mut self, tools: impl IntoIterator<Item = ToolDefinition>) -> Self {
self.tools.extend(tools);
self
}
pub fn temperature(mut self, temperature: Temperature) -> Self {
self.default_options = self.default_options.temperature(temperature);
self
}
pub fn with_temperature(mut self, temperature: Temperature) -> Self {
self = self.temperature(temperature);
self
}
pub fn try_temperature(mut self, temperature: f64) -> Result<Self, Error> {
self.default_options = self.default_options.try_temperature(temperature)?;
Ok(self)
}
pub fn max_tokens(mut self, max_tokens: MaxTokens) -> Self {
self.default_options = self.default_options.max_tokens(max_tokens);
self
}
pub fn with_max_tokens(mut self, max_tokens: MaxTokens) -> Self {
self = self.max_tokens(max_tokens);
self
}
pub fn try_max_tokens(mut self, max_tokens: usize) -> Result<Self, Error> {
self.default_options = self.default_options.try_max_tokens(max_tokens)?;
Ok(self)
}
pub fn options(mut self, options: GenerationOptions) -> Self {
self.default_options = options;
self
}
pub fn build(self) -> Result<LanguageModelSession, Error> {
let instructions = SystemInstructions::try_from(self.instructions)?;
LanguageModelSession::create(instructions, self.tools, self.default_options)
}
}
impl Default for LanguageModelSessionBuilder {
fn default() -> Self {
Self::new()
}
}
pub type SessionBuilder = LanguageModelSessionBuilder;
/// One-shot text generation with default options, returned as a `String`.
pub async fn respond<P>(prompt: P) -> Result<String, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    AppleIntelligenceModels::new().respond(prompt).await
}

/// One-shot text generation with the given options, returned as a `String`.
pub async fn respond_with_options<P>(
    prompt: P,
    options: &GenerationOptions,
) -> Result<String, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let model = AppleIntelligenceModels::new();
    model
        .generate_with_options(prompt, options)
        .await
        .map(ResponseText::into_string)
}

/// Same as [`respond`].
pub async fn generate<P>(prompt: P) -> Result<String, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    respond(prompt).await
}

/// Same as [`respond_with_options`].
pub async fn generate_with_options<P>(
    prompt: P,
    options: &GenerationOptions,
) -> Result<String, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    respond_with_options(prompt, options).await
}

/// One-shot streaming generation with default options.
pub fn stream_generate<P>(prompt: P) -> Result<ResponseStream, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let model = AppleIntelligenceModels::new();
    model.stream_generate(prompt)
}

/// One-shot streaming generation with the given options.
pub fn stream_generate_with_options<P>(
    prompt: P,
    options: &GenerationOptions,
) -> Result<ResponseStream, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let model = AppleIntelligenceModels::new();
    model.stream_generate_with_options(prompt, options)
}
// Owning, non-null pointer to a native session; destroyed on drop.
#[cfg(aimx_bridge)]
#[derive(Debug)]
struct SessionHandle(NonNull<c_void>);

#[cfg(aimx_bridge)]
impl SessionHandle {
    // Takes ownership of a pointer returned by a session-creation call.
    // A null pointer is reported as `Unavailable(Unknown)`.
    fn from_raw(handle: *mut c_void) -> Result<Self, Error> {
        match NonNull::new(handle) {
            Some(pointer) => Ok(SessionHandle(pointer)),
            None => Err(Error::Unavailable(AvailabilityError::Unknown)),
        }
    }

    // Raw pointer for FFI calls; ownership stays with `self`.
    fn as_ptr(&self) -> *mut c_void {
        let SessionHandle(pointer) = self;
        pointer.as_ptr()
    }
}

#[cfg(aimx_bridge)]
impl Drop for SessionHandle {
    fn drop(&mut self) {
        let raw = self.as_ptr();
        // SAFETY: `raw` is the non-null pointer this handle took ownership of
        // in `from_raw`, and `drop` runs once per handle.
        unsafe { fm_session_destroy(raw) }
    }
}

// SAFETY (review): these impls let `Arc<SessionHandle>` cross threads. They
// assume the bridge's session object may be used and destroyed from any
// thread. NOTE(review): confirm this against the bridge implementation.
#[cfg(aimx_bridge)]
unsafe impl Send for SessionHandle {}
#[cfg(aimx_bridge)]
unsafe impl Sync for SessionHandle {}
/// A native model session with default generation options and, optionally,
/// registered tools.
///
/// Build one with [`LanguageModelSessionBuilder`] or the `new` / `with_*`
/// constructors.
pub struct LanguageModelSession {
    default_options: GenerationOptions,
    // Field order matters: fields drop in declaration order. When this struct
    // holds the last `Arc`, the native session is destroyed before the tool
    // context whose pointer it was given.
    // NOTE(review): in-flight requests clone `handle` but not `_tools`, so the
    // tools context can be freed while the native session is still alive.
    // Consider tying both lifetimes together inside `SessionHandle`.
    #[cfg(aimx_bridge)]
    handle: Arc<SessionHandle>,
    #[cfg(aimx_bridge)]
    _tools: Option<Arc<ToolsContext>>,
}

// Only the default options are shown; the native handle and tools are opaque.
impl std::fmt::Debug for LanguageModelSession {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = f.debug_struct("LanguageModelSession");
        view.field("default_options", &self.default_options);
        view.finish_non_exhaustive()
    }
}
impl LanguageModelSession {
/// Starts a builder with empty instructions, no tools, and default options.
pub fn builder() -> LanguageModelSessionBuilder {
    LanguageModelSessionBuilder::default()
}
pub fn new() -> Result<Self, Error> {
Self::builder().build()
}
pub fn with_instructions<I>(instructions: I) -> Result<Self, Error>
where
I: TryInto<SystemInstructions>,
I::Error: Into<Error>,
{
let instructions = instructions.try_into().map_err(Into::into)?;
Self::create(instructions, Vec::new(), GenerationOptions::default())
}
pub fn with_tools<I>(instructions: I, tools: Vec<ToolDefinition>) -> Result<Self, Error>
where
I: TryInto<SystemInstructions>,
I::Error: Into<Error>,
{
let instructions = instructions.try_into().map_err(Into::into)?;
Self::create(instructions, tools, GenerationOptions::default())
}
fn create(
instructions: SystemInstructions,
tools: Vec<ToolDefinition>,
default_options: GenerationOptions,
) -> Result<Self, Error> {
default_options.validate()?;
availability().map_err(Error::Unavailable)?;
#[cfg(aimx_bridge)]
{
Self::create_bridge_session(instructions, tools, default_options)
}
#[cfg(not(aimx_bridge))]
{
let _ = (instructions, tools, default_options);
Err(Error::Unavailable(AvailabilityError::DeviceNotEligible))
}
}
/// Chooses the plain or tool-enabled native constructor based on whether any
/// tools were registered.
#[cfg(aimx_bridge)]
fn create_bridge_session(
    instructions: SystemInstructions,
    tools: Vec<ToolDefinition>,
    default_options: GenerationOptions,
) -> Result<Self, Error> {
    if tools.is_empty() {
        Self::create_plain_bridge_session(instructions, default_options)
    } else {
        Self::create_tool_bridge_session(instructions, tools, default_options)
    }
}
/// Creates a native session without tools.
///
/// # Errors
///
/// Returns `Unavailable(Unknown)` if the bridge returns a null handle.
#[cfg(aimx_bridge)]
fn create_plain_bridge_session(
    instructions: SystemInstructions,
    default_options: GenerationOptions,
) -> Result<Self, Error> {
    // SAFETY: `instructions` owns a NUL-terminated string that stays alive for
    // the whole call. NOTE(review): assumes the bridge copies the string
    // rather than retaining the pointer.
    let raw = unsafe { fm_session_create(instructions.as_ptr()) };
    let handle = SessionHandle::from_raw(raw).map(Arc::new)?;
    Ok(Self {
        default_options,
        handle,
        _tools: None,
    })
}
/// Creates a native session with tool support.
///
/// The tool descriptions are serialized to a JSON array. They are handed to
/// the bridge together with a raw pointer to a shared `ToolsContext` and the
/// `tool_dispatch` callback (defined outside this view). The session keeps the
/// `Arc<ToolsContext>` in `_tools` so the pointer stays valid while the session
/// itself is alive.
///
/// NOTE(review): in-flight requests clone only `handle` (see
/// `generate_prompt_with_options`), not `_tools`. If the session is dropped
/// while the native side can still dispatch tool calls, `tool_ctx_ptr` may
/// dangle. Consider moving the tools context into `SessionHandle` so both
/// share a single lifetime.
///
/// # Errors
///
/// * `Error::Json` if the descriptions fail to serialize.
/// * `Error::NullByte` if the JSON contains a NUL byte.
/// * `Error::Unavailable(AvailabilityError::Unknown)` if the bridge returns a null handle.
#[cfg(aimx_bridge)]
fn create_tool_bridge_session(
instructions: SystemInstructions,
tools: Vec<ToolDefinition>,
default_options: GenerationOptions,
) -> Result<Self, Error> {
let tool_descriptions = tools
.iter()
.map(ToolDefinition::bridge_description)
.collect::<Vec<_>>();
let c_tools_json = CString::new(serde_json::to_vec(&tool_descriptions)?)?;
// Only names and handlers are kept; the descriptions already went into the JSON above.
let tools_ctx = ToolsContext::from_definitions(tools);
// Borrowed pointer: ownership stays with `tools_ctx`, which is stored in `_tools` below.
let tool_ctx_ptr = Arc::as_ptr(&tools_ctx) as *mut c_void;
// SAFETY (review): both C strings outlive the call. `tool_ctx_ptr` points
// into an `Arc` kept by the returned session; see the NOTE above about
// in-flight requests.
let handle = unsafe {
fm_session_create_with_tools(
instructions.as_ptr(),
c_tools_json.as_ptr(),
tool_ctx_ptr,
tool_dispatch,
)
};
Ok(Self {
default_options,
handle: Arc::new(SessionHandle::from_raw(handle)?),
_tools: Some(tools_ctx),
})
}
pub async fn respond<P>(&self, prompt: P) -> Result<String, Error>
where
P: TryInto<Prompt>,
P::Error: Into<Error>,
{
Ok(self.respond_to(prompt).await?.into_string())
}
/// Generates a response with the session's default options.
pub async fn respond_to<P>(&self, prompt: P) -> Result<ResponseText, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let defaults = &self.default_options;
    self.respond_to_with_options(prompt, defaults).await
}
/// Same as [`LanguageModelSession::respond_to`].
pub async fn complete<P>(&self, prompt: P) -> Result<ResponseText, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let response = self.respond_to(prompt).await;
    response
}
/// Same as [`LanguageModelSession::respond_to`].
pub async fn generate<P>(&self, prompt: P) -> Result<GeneratedText, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let response = self.respond_to(prompt).await;
    response
}
/// Same as [`LanguageModelSession::respond_to`].
pub async fn generate_text<P>(&self, prompt: P) -> Result<ResponseText, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let response = self.respond_to(prompt).await;
    response
}
pub async fn respond_with_options<P>(
&self,
prompt: P,
options: &GenerationOptions,
) -> Result<String, Error>
where
P: TryInto<Prompt>,
P::Error: Into<Error>,
{
Ok(self
.respond_to_with_options(prompt, options)
.await?
.into_string())
}
/// Converts the prompt and generates a response with the given options.
///
/// # Errors
///
/// Fails on a prompt conversion error, invalid options, or a generation failure.
pub async fn respond_to_with_options<P>(
    &self,
    prompt: P,
    options: &GenerationOptions,
) -> Result<ResponseText, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    match prompt.try_into() {
        Ok(ready) => self.generate_prompt_with_options(ready, options).await,
        Err(conversion_error) => Err(conversion_error.into()),
    }
}
/// Alias of [`Self::respond_to_with_options`].
///
/// # Errors
///
/// Same conditions as [`Self::respond_to_with_options`].
pub async fn complete_with_options<P>(
    &self,
    prompt: P,
    options: &GenerationOptions,
) -> Result<ResponseText, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let reply = self.respond_to_with_options(prompt, options).await?;
    Ok(reply)
}
/// Alias of [`Self::respond_to_with_options`] that names its result `GeneratedText`.
///
/// # Errors
///
/// Same conditions as [`Self::respond_to_with_options`].
pub async fn generate_with_options<P>(
    &self,
    prompt: P,
    options: &GenerationOptions,
) -> Result<GeneratedText, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let text: GeneratedText = self.respond_to_with_options(prompt, options).await?;
    Ok(text)
}
/// Alias of [`Self::respond_to_with_options`].
///
/// # Errors
///
/// Same conditions as [`Self::respond_to_with_options`].
pub async fn generate_text_with_options<P>(
    &self,
    prompt: P,
    options: &GenerationOptions,
) -> Result<ResponseText, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let reply = self.respond_to_with_options(prompt, options).await?;
    Ok(reply)
}
/// Validates `options` and runs a single unstructured request for `prompt`.
///
/// With the native bridge, the request is handed to the session together with a boxed
/// `ResponseContext`. `respond_callback` later reclaims that box and delivers the
/// outcome over a oneshot channel. Without the bridge, this validates the options and
/// then returns `Error::Unavailable(AvailabilityError::DeviceNotEligible)`.
async fn generate_prompt_with_options(
    &self,
    prompt: Prompt,
    options: &GenerationOptions,
) -> Result<ResponseText, Error> {
    let config = options.validated()?;
    #[cfg(aimx_bridge)]
    {
        let (tx, rx) = oneshot::channel::<ModelTextResult>();
        {
            // The context carries a handle clone so the session stays referenced until
            // the callback has run.
            let context = Box::new(ResponseContext {
                tx,
                _handle: Arc::clone(&self.handle),
            });
            let ctx = Box::into_raw(context).cast::<c_void>();
            // SAFETY: the handle and prompt pointers are valid for the duration of the
            // call (`self` and `prompt` outlive it). Ownership of `ctx` passes to the
            // bridge, which is expected to return it exactly once via `respond_callback`.
            unsafe {
                fm_session_respond(
                    self.handle.as_ptr(),
                    prompt.as_ptr(),
                    config.ffi_temperature(),
                    config.ffi_max_tokens(),
                    ctx,
                    respond_callback,
                );
            }
        }
        receive_response(rx).await
    }
    #[cfg(not(aimx_bridge))]
    {
        let _ = (prompt, config);
        Err(Error::Unavailable(AvailabilityError::DeviceNotEligible))
    }
}
/// Alias of [`Self::respond_generating`]: structured generation with the default options.
///
/// # Errors
///
/// Same conditions as [`Self::respond_generating_with_options`].
pub async fn respond_as<T, P>(&self, prompt: P, schema: &GenerationSchema) -> Result<T, Error>
where
    T: serde::de::DeserializeOwned,
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    self.respond_generating_with_options(prompt, schema, &self.default_options)
        .await
}
/// Generates a value shaped by `schema` using the default options and deserializes it as `T`.
///
/// # Errors
///
/// Same conditions as [`Self::respond_generating_with_options`].
pub async fn respond_generating<T, P>(
    &self,
    prompt: P,
    schema: &GenerationSchema,
) -> Result<T, Error>
where
    T: serde::de::DeserializeOwned,
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let options = &self.default_options;
    self.respond_generating_with_options(prompt, schema, options)
        .await
}
/// Alias of [`Self::respond_generating`].
///
/// # Errors
///
/// Same conditions as [`Self::respond_generating_with_options`].
pub async fn generate_object<T, P>(
    &self,
    prompt: P,
    schema: &GenerationSchema,
) -> Result<T, Error>
where
    T: serde::de::DeserializeOwned,
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    self.respond_generating_with_options(prompt, schema, &self.default_options)
        .await
}
/// Alias of [`Self::respond_generating_with_options`].
///
/// # Errors
///
/// Same conditions as [`Self::respond_generating_with_options`].
pub async fn respond_as_with_options<T, P>(
    &self,
    prompt: P,
    schema: &GenerationSchema,
    options: &GenerationOptions,
) -> Result<T, Error>
where
    T: serde::de::DeserializeOwned,
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let value: T = self
        .respond_generating_with_options(prompt, schema, options)
        .await?;
    Ok(value)
}
/// Generates a value shaped by `schema` with explicit `options` and deserializes it as `T`.
///
/// The prompt is converted first, then the options are validated, and only then is the
/// request issued.
///
/// # Errors
///
/// Returns an error in these cases:
/// - the prompt conversion fails;
/// - `options` fail validation;
/// - the schema cannot be encoded;
/// - generation fails;
/// - the reply does not deserialize as `T`;
/// - the crate was built without the native bridge (`Error::Unavailable`).
pub async fn respond_generating_with_options<T, P>(
    &self,
    prompt: P,
    schema: &GenerationSchema,
    options: &GenerationOptions,
) -> Result<T, Error>
where
    T: serde::de::DeserializeOwned,
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let prompt: Prompt = match prompt.try_into() {
        Ok(prompt) => prompt,
        Err(error) => return Err(error.into()),
    };
    let validated = options.validated()?;
    self.respond_generating_prompt_with_config(prompt, schema, validated)
        .await
}
/// Alias of [`Self::respond_generating_with_options`].
///
/// # Errors
///
/// Same conditions as [`Self::respond_generating_with_options`].
pub async fn generate_object_with_options<T, P>(
    &self,
    prompt: P,
    schema: &GenerationSchema,
    options: &GenerationOptions,
) -> Result<T, Error>
where
    T: serde::de::DeserializeOwned,
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let value: T = self
        .respond_generating_with_options(prompt, schema, options)
        .await?;
    Ok(value)
}
/// Runs a schema-constrained request and deserializes the model's JSON reply into `T`.
///
/// `schema` is serialized to JSON and passed to the bridge with the prompt and the
/// validated generation settings. The raw reply text is then parsed with `serde_json`.
///
/// # Errors
///
/// Returns an error if the schema cannot be encoded as a C string, if generation fails,
/// if the reply is not valid JSON for `T`, or, without the native bridge,
/// `Error::Unavailable(AvailabilityError::DeviceNotEligible)`.
async fn respond_generating_prompt_with_config<T>(
    &self,
    prompt: Prompt,
    schema: &GenerationSchema,
    config: GenerationConfig,
) -> Result<T, Error>
where
    T: serde::de::DeserializeOwned,
{
    #[cfg(aimx_bridge)]
    {
        // Encode the schema *before* leaking the response context. Otherwise an early
        // `?` return would strand the boxed context, and the session-handle clone it
        // owns, with no callback left to reclaim it.
        let c_schema_json = CString::new(serde_json::to_vec(schema)?)?;
        let (tx, rx) = oneshot::channel::<ModelTextResult>();
        {
            let ctx = Box::into_raw(Box::new(ResponseContext {
                tx,
                _handle: Arc::clone(&self.handle),
            })) as *mut c_void;
            // SAFETY: the handle, prompt and schema pointers are valid for the duration
            // of the call (`self`, `prompt` and `c_schema_json` outlive it). Ownership of
            // `ctx` passes to the bridge, which is expected to return it exactly once via
            // `respond_callback`.
            unsafe {
                fm_session_respond_structured(
                    self.handle.as_ptr(),
                    prompt.as_ptr(),
                    c_schema_json.as_ptr(),
                    config.ffi_temperature(),
                    config.ffi_max_tokens(),
                    ctx,
                    respond_callback,
                );
            }
        }
        let json = receive_response(rx).await?.into_string();
        serde_json::from_str(&json).map_err(Error::from)
    }
    #[cfg(not(aimx_bridge))]
    {
        let _ = (prompt, schema, config);
        Err(Error::Unavailable(AvailabilityError::DeviceNotEligible))
    }
}
/// Alias of [`Self::stream_response`]: streams a reply using the default options.
///
/// # Errors
///
/// Same conditions as [`Self::stream_response_with_options`].
pub fn stream<P>(&self, prompt: P) -> Result<ResponseStream, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    self.stream_response_with_options(prompt, &self.default_options)
}
/// Starts a streaming request for `prompt` with the session's default options.
///
/// # Errors
///
/// Same conditions as [`Self::stream_response_with_options`].
pub fn stream_response<P>(&self, prompt: P) -> Result<ResponseStream, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let options = &self.default_options;
    self.stream_response_with_options(prompt, options)
}
/// Alias of [`Self::stream_response`].
///
/// # Errors
///
/// Same conditions as [`Self::stream_response_with_options`].
pub fn stream_generate<P>(&self, prompt: P) -> Result<ResponseStream, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    self.stream_response_with_options(prompt, &self.default_options)
}
/// Alias of [`Self::stream_response`].
///
/// # Errors
///
/// Same conditions as [`Self::stream_response_with_options`].
pub fn stream_text<P>(&self, prompt: P) -> Result<ResponseStream, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    self.stream_response_with_options(prompt, &self.default_options)
}
/// Alias of [`Self::stream_response_with_options`].
///
/// # Errors
///
/// Same conditions as [`Self::stream_response_with_options`].
pub fn stream_with_options<P>(
    &self,
    prompt: P,
    options: &GenerationOptions,
) -> Result<ResponseStream, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let stream = self.stream_response_with_options(prompt, options)?;
    Ok(stream)
}
/// Starts a streaming request for `prompt` with explicit `options`.
///
/// The native request is issued before this function returns. Generation failures are
/// delivered later as `Err` items of the returned stream.
///
/// # Errors
///
/// Fails immediately if the prompt conversion fails, if `options` fail validation, or,
/// without the native bridge, with `Error::Unavailable`.
pub fn stream_response_with_options<P>(
    &self,
    prompt: P,
    options: &GenerationOptions,
) -> Result<ResponseStream, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let prompt: Prompt = match prompt.try_into() {
        Ok(prompt) => prompt,
        Err(error) => return Err(error.into()),
    };
    let validated = options.validated()?;
    self.stream_prompt_with_config(prompt, validated)
}
/// Alias of [`Self::stream_response_with_options`].
///
/// # Errors
///
/// Same conditions as [`Self::stream_response_with_options`].
pub fn stream_generate_with_options<P>(
    &self,
    prompt: P,
    options: &GenerationOptions,
) -> Result<ResponseStream, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let stream = self.stream_response_with_options(prompt, options)?;
    Ok(stream)
}
/// Alias of [`Self::stream_response_with_options`].
///
/// # Errors
///
/// Same conditions as [`Self::stream_response_with_options`].
pub fn stream_text_with_options<P>(
    &self,
    prompt: P,
    options: &GenerationOptions,
) -> Result<ResponseStream, Error>
where
    P: TryInto<Prompt>,
    P::Error: Into<Error>,
{
    let stream = self.stream_response_with_options(prompt, options)?;
    Ok(stream)
}
/// Issues a streaming request for `prompt` with already-validated `config`.
///
/// The native call is made synchronously:
/// - `stream_token_callback` pushes each chunk into an unbounded channel;
/// - `stream_done_callback` reclaims the boxed `StreamContext` when generation
///   finishes, which closes the channel and ends the stream.
///
/// Without the bridge, this returns `Error::Unavailable(AvailabilityError::DeviceNotEligible)`.
fn stream_prompt_with_config(
    &self,
    prompt: Prompt,
    config: GenerationConfig,
) -> Result<ResponseStream, Error> {
    #[cfg(aimx_bridge)]
    {
        let (tx, rx) = mpsc::unbounded::<ModelTextResult>();
        let context = Box::new(StreamContext {
            tx,
            _handle: Arc::clone(&self.handle),
        });
        let ctx = Box::into_raw(context).cast::<c_void>();
        // SAFETY: the handle and prompt pointers are valid for the duration of the call.
        // Ownership of `ctx` passes to the bridge: the token callback only borrows it,
        // and `stream_done_callback` is expected to reclaim it exactly once.
        unsafe {
            fm_session_stream(
                self.handle.as_ptr(),
                prompt.as_ptr(),
                config.ffi_temperature(),
                config.ffi_max_tokens(),
                ctx,
                stream_token_callback,
                stream_done_callback,
            );
        }
        Ok(ResponseStream { rx })
    }
    #[cfg(not(aimx_bridge))]
    {
        let _ = (prompt, config);
        Err(Error::Unavailable(AvailabilityError::DeviceNotEligible))
    }
}
}
/// Shorter alias for [`LanguageModelSession`].
pub type Session = LanguageModelSession;
impl LanguageModel for LanguageModelSession {
fn generate_text_with_options<P>(
&self,
prompt: P,
options: GenerationOptions,
) -> impl Future<Output = Result<ResponseText, Error>> + '_
where
P: TryInto<Prompt>,
P::Error: Into<Error>,
{
let prompt = prompt.try_into().map_err(Into::into);
async move {
let prompt = prompt?;
self.generate_prompt_with_options(prompt, &options).await
}
}
fn stream_text_with_options<P>(
&self,
prompt: P,
options: GenerationOptions,
) -> Result<ResponseStream, Error>
where
P: TryInto<Prompt>,
P::Error: Into<Error>,
{
LanguageModelSession::stream_text_with_options(self, prompt, &options)
}
}
/// Stream of text chunks produced by a streaming generation request.
///
/// Each item is either one chunk delivered by the bridge or the error that ended
/// generation. The stream finishes when the bridge signals completion and the sending
/// side of the channel is dropped.
pub struct ResponseStream {
// Receiving half of the channel fed by the stream callbacks.
rx: StreamReceiver,
}
impl Stream for ResponseStream {
type Item = Result<ResponseText, Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut StdContext<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.rx)
.poll_next(cx)
.map(|opt| opt.map(|r| r.map_err(Error::from)))
}
}
/// Awaits the single outcome of a one-shot request.
///
/// If the sender is dropped without sending anything, this reports a generation error
/// saying the session was dropped before responding.
#[cfg(aimx_bridge)]
async fn receive_response(receiver: ResponseReceiver) -> Result<ResponseText, Error> {
    match receiver.await {
        Ok(outcome) => outcome.map_err(Error::from),
        Err(_canceled) => {
            Err(GenerationError::new("session was dropped before responding").into())
        }
    }
}
/// Per-request state handed to the bridge for a one-shot response.
///
/// Boxed and leaked with `Box::into_raw` when the request is issued. `respond_callback`
/// reclaims it and sends the outcome through `tx`.
#[cfg(aimx_bridge)]
struct ResponseContext {
tx: ResponseSender,
// Clone of the session handle held for the lifetime of the request, so the handle is
// not released while the bridge still uses it (assumes `SessionHandle`'s drop tears
// down the native session — confirm).
_handle: Arc<SessionHandle>,
}
/// C callback that completes a one-shot (`fm_session_respond*`) request.
///
/// Reclaims the `ResponseContext` that the request leaked with `Box::into_raw` and
/// forwards the outcome to the waiting receiver. An error message takes precedence
/// over a result.
///
/// If the bridge supplies neither, an explicit error is sent. Silently dropping the
/// sender would make the receiver misreport the failure as the session having been
/// dropped.
#[cfg(aimx_bridge)]
extern "C" fn respond_callback(ctx: *mut c_void, result: *const c_char, error: *const c_char) {
    if ctx.is_null() {
        // Nothing to complete; reclaiming a null pointer would be undefined behavior.
        return;
    }
    // SAFETY: a non-null `ctx` is the `Box<ResponseContext>` leaked by the request that
    // registered this callback. The bridge is expected to invoke it exactly once, so the
    // box is reclaimed exactly once.
    let context = unsafe { Box::from_raw(ctx.cast::<ResponseContext>()) };
    // `result` is only read when no error was reported, matching the bridge's contract.
    let outcome: ModelTextResult = if let Some(msg) = callback_owned_text(error) {
        Err(GenerationError::from(msg))
    } else if let Some(text) = callback_owned_text(result) {
        Ok(ResponseText::from(text))
    } else {
        Err(GenerationError::new(
            "bridge returned neither a result nor an error",
        ))
    };
    // The receiver may already be gone if the caller dropped the future; that is fine.
    context.tx.send(outcome).ok();
}
/// Copies a C string passed by the bridge into an owned `String`.
///
/// Returns `None` for a null pointer. Invalid UTF-8 is replaced lossily.
#[cfg(aimx_bridge)]
fn callback_owned_text(ptr: *const c_char) -> Option<String> {
    if ptr.is_null() {
        None
    } else {
        // SAFETY: the bridge passes either null (handled above) or a valid
        // NUL-terminated string that stays alive for the duration of the callback.
        let borrowed = unsafe { CStr::from_ptr(ptr) };
        Some(String::from(borrowed.to_string_lossy()))
    }
}
/// Per-request state handed to the bridge for a streaming response.
///
/// Boxed and leaked when the stream starts. `stream_token_callback` borrows it for each
/// chunk, and `stream_done_callback` reclaims it, which drops `tx` and closes the stream.
#[cfg(aimx_bridge)]
struct StreamContext {
tx: StreamSender,
// Clone of the session handle held while the stream is in flight.
_handle: Arc<SessionHandle>,
}
/// C callback invoked for each streamed chunk.
///
/// Borrows the stream context without taking ownership, and forwards non-null chunks
/// to the stream. Send failures (receiver already dropped) are ignored.
#[cfg(aimx_bridge)]
extern "C" fn stream_token_callback(ctx: *mut c_void, token: *const c_char) {
    // SAFETY: `ctx` is the `Box<StreamContext>` leaked by `stream_prompt_with_config`;
    // it stays alive until `stream_done_callback` reclaims it.
    let context = unsafe { &*ctx.cast_const().cast::<StreamContext>() };
    if let Some(chunk) = callback_owned_text(token) {
        let _ = context.tx.unbounded_send(Ok(ResponseText::from(chunk)));
    }
}
/// C callback invoked once when a streaming request finishes.
///
/// Reclaims the stream context. If `error` is non-null, its message is sent as a final
/// `Err` item. Dropping the context then closes the channel, ending the stream.
#[cfg(aimx_bridge)]
extern "C" fn stream_done_callback(ctx: *mut c_void, error: *const c_char) {
    // SAFETY: `ctx` is the `Box<StreamContext>` leaked by `stream_prompt_with_config`,
    // and this completion callback is expected to run exactly once.
    let context = unsafe { Box::from_raw(ctx.cast::<StreamContext>()) };
    let Some(message) = callback_owned_text(error) else {
        return;
    };
    let _ = context
        .tx
        .unbounded_send(Err(GenerationError::from(message)));
}
/// C entry point the bridge calls when the model invokes a tool.
///
/// Resolves and runs the named tool, then reports the outcome through `result_cb`
/// with `result_ctx`.
#[cfg(aimx_bridge)]
extern "C" fn tool_dispatch(
    ctx: *mut c_void,
    name_ptr: *const c_char,
    args_ptr: *const c_char,
    result_ctx: *mut c_void,
    result_cb: ToolResultCallback,
) {
    send_tool_result(
        result_ctx,
        result_cb,
        dispatch_tool_call(ctx, name_ptr, args_ptr),
    );
}
/// Decodes a tool invocation from raw bridge pointers and runs it.
///
/// Validation order:
/// 1. a null tool context is reported as "missing tool context";
/// 2. a null name is reported as "missing tool name";
/// 3. the arguments are parsed as JSON, and the call is dispatched by name.
#[cfg(aimx_bridge)]
fn dispatch_tool_call(
    ctx: *mut c_void,
    name_ptr: *const c_char,
    args_ptr: *const c_char,
) -> ToolResult {
    // SAFETY: a non-null `ctx` is the `Arc<ToolsContext>` pointer registered when the
    // session was created; `as_ref` yields `None` for null.
    let Some(tools) = (unsafe { ctx.cast_const().cast::<ToolsContext>().as_ref() }) else {
        return Err(ToolCallError::new("missing tool context"));
    };
    let outcome = with_callback_text(name_ptr, "tool name", |name| {
        parse_tool_args(args_ptr).and_then(|args| tools.call(name, args))
    });
    outcome?
}
/// Parses the bridge-supplied tool arguments as a JSON value.
///
/// # Errors
///
/// Returns a `ToolCallError` if the pointer is null or the bytes are not valid JSON.
#[cfg(aimx_bridge)]
fn parse_tool_args(args_ptr: *const c_char) -> Result<serde_json::Value, ToolCallError> {
    if args_ptr.is_null() {
        return Err(ToolCallError::new("missing tool arguments"));
    }
    // SAFETY: non-null (checked above) and, per the bridge contract, NUL-terminated and
    // valid for the duration of the dispatch.
    let raw = unsafe { CStr::from_ptr(args_ptr) }.to_bytes();
    match serde_json::from_slice::<serde_json::Value>(raw) {
        Ok(value) => Ok(value),
        Err(error) => Err(ToolCallError::new(format!(
            "invalid tool args JSON: {error}"
        ))),
    }
}
/// Borrows a bridge C string as `&str` (lossily decoded) and passes it to `f`.
///
/// # Errors
///
/// Returns `ToolCallError("missing {label}")` if `ptr` is null; `f` is not called then.
#[cfg(aimx_bridge)]
fn with_callback_text<R>(
    ptr: *const c_char,
    label: &str,
    f: impl FnOnce(&str) -> R,
) -> Result<R, ToolCallError> {
    if !ptr.is_null() {
        // SAFETY: non-null and, per the bridge contract, a valid NUL-terminated string
        // for the duration of the callback.
        let text = unsafe { CStr::from_ptr(ptr) }.to_string_lossy();
        return Ok(f(&*text));
    }
    Err(ToolCallError::new(format!("missing {label}")))
}
/// Reports a tool outcome to the bridge: output on success, the error message on failure.
#[cfg(aimx_bridge)]
fn send_tool_result(result_ctx: *mut c_void, result_cb: ToolResultCallback, result: ToolResult) {
    let output = match result {
        Ok(output) => output,
        Err(error) => return send_tool_error(result_ctx, result_cb, error.as_str()),
    };
    send_tool_output(result_ctx, result_cb, output);
}
/// Sends successful tool output to the bridge as a C string with a null error pointer.
///
/// Output containing an interior NUL byte cannot be represented as a C string, so it is
/// reported as a tool error instead.
#[cfg(aimx_bridge)]
fn send_tool_output(result_ctx: *mut c_void, result_cb: ToolResultCallback, output: ToolOutput) {
    let text = output.into_string();
    match CString::new(text) {
        Err(nul) => {
            let message = format!("tool result contains a null byte: {nul}");
            send_tool_error(result_ctx, result_cb, &message);
        }
        // `c_output` lives until after the callback returns.
        Ok(c_output) => result_cb(result_ctx, c_output.as_ptr(), null()),
    }
}
/// Reports a tool failure to the bridge through `result_cb`, with a null result pointer.
///
/// Interior NUL bytes cannot be carried by a C string. They are replaced with U+FFFD so
/// the actual diagnostic still reaches the model instead of being discarded. The static
/// `TOOL_ERROR_ENCODING_FAILURE` message is kept as a last-resort fallback.
#[cfg(aimx_bridge)]
fn send_tool_error(result_ctx: *mut c_void, result_cb: ToolResultCallback, message: &str) {
    let encoded =
        CString::new(message).or_else(|_| CString::new(message.replace('\0', "\u{FFFD}")));
    match encoded {
        // `c_error` lives until after the callback returns.
        Ok(c_error) => result_cb(result_ctx, null(), c_error.as_ptr()),
        Err(_) => result_cb(
            result_ctx,
            null(),
            TOOL_ERROR_ENCODING_FAILURE.as_ptr().cast::<c_char>(),
        ),
    }
}
/// Static fallback message sent to the bridge when a tool error message cannot be
/// converted into a `CString`.
///
/// The trailing `\0` lets the bytes be passed directly as a C string pointer.
#[cfg(aimx_bridge)]
const TOOL_ERROR_ENCODING_FAILURE: &[u8] = b"tool error contains a null byte\0";
// Unit tests covering:
// - option validation;
// - prompt and instruction conversion;
// - schema and tool builders;
// - the unavailable-host fallback.
// Tests marked `#[ignore]` exercise the real model and need an Apple Intelligence host.
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
// Smoke test: probing availability must simply return, whatever the host.
#[test]
fn test_is_available_returns_without_panic() {
let _ = is_available();
}
// `is_available()` must agree with whether `availability()` returned `Ok`.
#[test]
fn test_availability_result_is_consistent() {
let avail = availability();
assert_eq!(is_available(), avail.is_ok());
}
// Default options are valid and leave both knobs unset, which the FFI layer receives
// as -1.0 (temperature) and -1 (max tokens).
#[test]
fn test_options_default_is_valid() -> Result<(), Error> {
let opts = GenerationOptions::default();
assert!(opts.validate().is_ok());
let config = opts.validated()?;
assert_eq!(config.ffi_temperature(), -1.0);
assert_eq!(config.ffi_max_tokens(), -1);
Ok(())
}
// Boundary and interior temperatures in [0.0, 2.0] are accepted and forwarded unchanged.
#[test]
fn test_options_valid_temperature() -> Result<(), Error> {
for (temp, expected_ffi) in [(0.0_f64, 0.0), (1.0, 1.0), (2.0, 2.0)] {
let opts = GenerationOptions::new().try_temperature(temp)?;
assert!(
opts.validate().is_ok(),
"temperature {temp} should be valid"
);
let config = opts.validated()?;
assert_eq!(config.ffi_temperature(), expected_ffi);
}
Ok(())
}
// Temperatures outside [0.0, 2.0], both infinities, and NaN are rejected.
#[test]
fn test_options_invalid_temperature() {
for temp in [-f64::INFINITY, -0.1_f64, 2.001, f64::INFINITY, f64::NAN] {
assert!(
GenerationOptions::new().try_temperature(temp).is_err(),
"temperature {temp} should be invalid"
);
}
}
// One past `MaxTokens::MAX` is rejected by both the options builder and the newtype
// constructor, and the error echoes the offending value.
#[test]
fn test_options_invalid_max_tokens() {
// Skipped when `usize` is narrower than `i64`; presumably `MaxTokens::MAX + 1`
// could overflow there — TODO confirm.
if usize::BITS < i64::BITS {
return;
}
let invalid = MaxTokens::MAX + 1;
assert!(matches!(
GenerationOptions::new().try_max_tokens(invalid),
Err(Error::InvalidMaxTokens(value)) if value == invalid
));
assert!(matches!(
MaxTokens::new(invalid),
Err(Error::InvalidMaxTokens(value)) if value == invalid
));
}
// On hosts without Apple Intelligence, session creation reports `Error::Unavailable`.
// The test is a no-op when the model is available.
#[test]
fn test_session_creation_fails_gracefully_when_unavailable() {
if is_available() {
return; }
assert!(matches!(
LanguageModelSession::new(),
Err(Error::Unavailable(_))
));
}
// A NUL byte in the prompt surfaces as `Error::NullByte` from the free `respond` helper.
#[test]
fn test_null_byte_in_prompt_returns_error() {
let result = futures_executor::block_on(respond("hello\0world"));
assert!(matches!(result, Err(Error::NullByte(_))));
}
// NUL bytes are rejected while converting prompts and instructions, with no model
// availability involved.
#[test]
fn test_prompt_and_instruction_inputs_reject_null_bytes_before_availability() {
let prompt = Prompt::try_from("hello\0world");
let instructions = SystemInstructions::try_from("system\0prompt");
assert!(matches!(prompt, Err(Error::NullByte(_))));
assert!(matches!(instructions, Err(Error::NullByte(_))));
}
// An out-of-range temperature on the session builder is reported as
// `InvalidTemperature` carrying the rejected value.
#[test]
fn test_session_builder_validates_options_before_availability() {
let result = AppleIntelligenceModels::default()
.session()
.instructions("Valid system prompt")
.try_temperature(2.5)
.and_then(LanguageModelSessionBuilder::build);
assert!(matches!(result, Err(Error::InvalidTemperature(value)) if value == 2.5));
}
// Typed option getters return exactly the newtype values that were set.
#[test]
fn test_options_expose_typed_values() -> Result<(), Error> {
let temperature = Temperature::new(0.4)?;
let max_tokens = MaxTokens::new(128)?;
let opts = GenerationOptions::new()
.temperature(temperature)
.max_tokens(max_tokens);
assert_eq!(opts.temperature_value(), Some(temperature));
assert_eq!(opts.max_tokens_value(), Some(max_tokens));
Ok(())
}
// Property requirement serializes as a boolean `optional` flag in the schema JSON.
#[test]
fn test_schema_property_requirement_serializes_as_optional_flag() -> Result<(), Error> {
let schema = GenerationSchema::new("Answer")
.property(GenerationSchemaProperty::new(
"required",
GenerationSchemaPropertyType::String,
))
.property(
GenerationSchemaProperty::new("maybe", GenerationSchemaPropertyType::String)
.optional(),
);
let json = serde_json::to_value(schema)?;
assert_eq!(json["properties"][0]["optional"], false);
assert_eq!(json["properties"][1]["optional"], true);
assert!(GenerationSchemaPropertyRequirement::Required.is_required());
assert!(GenerationSchemaPropertyRequirement::Optional.is_optional());
Ok(())
}
// A NUL byte in builder instructions is reported as `Error::NullByte` by `build()`.
#[test]
fn test_session_builder_validates_instructions_before_availability() {
let result = AppleIntelligenceModels::default()
.session()
.instructions("bad\0instructions")
.build();
assert!(matches!(result, Err(Error::NullByte(_))));
}
// String newtypes render back to the exact text they were constructed from.
#[test]
fn test_string_newtypes_round_trip_through_display_and_inner_value() {
let cases = [
PromptText::new("prompt").into_string(),
ResponseText::new("response").to_string(),
GenerationSchemaName::new("GenerationSchema").to_string(),
GenerationSchemaPropertyName::new("field").to_string(),
ToolName::new("tool").to_string(),
ToolOutput::new("output").to_string(),
];
assert_eq!(
cases,
[
"prompt",
"response",
"GenerationSchema",
"field",
"tool",
"output"
]
);
}
// The schema builder records name and properties, and serializes property names and
// types.
#[test]
fn test_schema_builder() -> Result<(), Error> {
let schema = GenerationSchema::new("Point")
.description("A 2D point")
.property(
GenerationSchemaProperty::new("x", GenerationSchemaPropertyType::Double)
.description("X axis"),
)
.property(GenerationSchemaProperty::new(
"y",
GenerationSchemaPropertyType::Double,
));
assert_eq!(schema.name, "Point");
assert_eq!(schema.properties.len(), 2);
let json = serde_json::to_string(&schema)?;
assert!(json.contains("\"x\""));
assert!(json.contains("\"double\""));
Ok(())
}
// A builder-constructed tool invokes its handler with the JSON arguments.
#[test]
fn test_tool_definition_builder() -> Result<(), ToolCallError> {
let tool = ToolDefinition::builder(
"add",
"Add two numbers",
GenerationSchema::new("AddArgs")
.property(GenerationSchemaProperty::new(
"a",
GenerationSchemaPropertyType::Double,
))
.property(GenerationSchemaProperty::new(
"b",
GenerationSchemaPropertyType::Double,
)),
)
.handler(|args| {
let a = args["a"].as_f64().unwrap_or(0.0);
let b = args["b"].as_f64().unwrap_or(0.0);
Ok(ToolOutput::from(format!("{}", a + b)))
});
assert_eq!(tool.name, "add");
let result = tool.call(serde_json::json!({"a": 3.0, "b": 4.0}));
assert_eq!(result?, "7");
Ok(())
}
// `ToolDefinition::new` exposes its metadata through accessors. The handler's own
// error (missing `value`) propagates out of `call`.
#[test]
fn test_tool_definition_new_and_trait_boundary() -> Result<(), ToolCallError> {
let tool = ToolDefinition::new(
"echo",
"Echo an input string",
GenerationSchema::new("EchoArgs").property(GenerationSchemaProperty::new(
"value",
GenerationSchemaPropertyType::String,
)),
|args| {
args["value"]
.as_str()
.map(ToolOutput::from)
.ok_or_else(|| ToolCallError::new("missing value"))
},
);
assert_eq!(tool.name().as_str(), "echo");
assert_eq!(tool.description().as_str(), "Echo an input string");
assert_eq!(tool.parameters().name, "EchoArgs");
assert_eq!(tool.call(serde_json::json!({"value": "hello"}))?, "hello");
assert!(tool.call(serde_json::json!({})).is_err());
Ok(())
}
// An unwind inside a tool handler (triggered here with `resume_unwind`) is converted
// into a `ToolCallError` that mentions the panic payload.
#[test]
fn test_tool_handler_panic_returns_tool_error() {
let tool = ToolDefinition::new(
"panic_tool",
"Tool that fails inside user code",
GenerationSchema::new("PanicArgs"),
|_| -> ToolResult {
std::panic::resume_unwind(Box::new("boom"));
},
);
let error = tool.call(serde_json::json!({})).err();
assert!(
error
.as_ref()
.is_some_and(|error| error.as_str().contains("tool handler panicked: boom")),
"expected panic to be converted into ToolCallError"
);
}
proptest! {
// Prompt conversion fails exactly when the input contains a NUL byte, and otherwise
// preserves the text.
#[test]
fn proptest_prompt_input_matches_c_string_null_boundary(input in ".*") {
let result = Prompt::try_from(input.as_str());
if input.contains('\0') {
prop_assert!(matches!(result, Err(Error::NullByte(_))));
} else {
match result {
Ok(prompt) => prop_assert_eq!(prompt.as_str(), input.as_str()),
Err(error) => prop_assert!(false, "unexpected prompt error: {error}"),
}
}
}
// Same NUL boundary for system instructions.
#[test]
fn proptest_instructions_match_c_string_null_boundary(input in ".*") {
let result = SystemInstructions::try_from(input.as_str());
if input.contains('\0') {
prop_assert!(matches!(result, Err(Error::NullByte(_))));
} else {
match result {
Ok(instructions) => prop_assert_eq!(instructions.as_str(), input.as_str()),
Err(error) => prop_assert!(false, "unexpected instructions error: {error}"),
}
}
}
// `Temperature::new` accepts exactly the closed interval [MIN, MAX]. Rejections echo
// the input bit-for-bit, which also covers NaN.
#[test]
fn proptest_temperature_validation_matches_closed_interval(temp in any::<f64>()) {
let result = Temperature::new(temp);
if (Temperature::MIN..=Temperature::MAX).contains(&temp) {
match result {
Ok(temperature) => prop_assert_eq!(temperature.as_f64(), temp),
Err(error) => prop_assert!(false, "unexpected temperature error: {error}"),
}
} else {
prop_assert!(matches!(result, Err(Error::InvalidTemperature(value)) if value.to_bits() == temp.to_bits()));
}
}
// Any `max_tokens` up to `MaxTokens::MAX` reaches the FFI layer unchanged. Larger
// values are rejected with the offending value.
#[test]
fn proptest_generation_options_preserve_max_tokens(max_tokens in any::<usize>()) {
if max_tokens <= MaxTokens::MAX {
match GenerationOptions::new().try_max_tokens(max_tokens) {
Ok(opts) => match opts.validated() {
Ok(config) => prop_assert_eq!(config.ffi_max_tokens(), max_tokens as i64),
Err(error) => prop_assert!(false, "unexpected options error: {error}"),
},
Err(error) => prop_assert!(false, "unexpected max token error: {error}"),
}
} else {
prop_assert!(matches!(
GenerationOptions::new().try_max_tokens(max_tokens),
Err(Error::InvalidMaxTokens(value)) if value == max_tokens
));
}
}
}
// --- Integration tests below require the real on-device model (run with --ignored). ---
// End-to-end single response through the free `respond` helper.
#[test]
#[ignore = "requires Apple Intelligence (macOS 26+, Apple Silicon, AI enabled)"]
fn test_simple_respond() -> Result<(), Error> {
let response =
futures_executor::block_on(respond("Reply with only the number: what is 2 + 2?"))?;
assert!(
response.as_str().contains('4'),
"expected '4' in: {response:?}"
);
Ok(())
}
// A request with temperature 0.0 through the free `respond_with_options` helper.
#[test]
#[ignore = "requires Apple Intelligence"]
fn test_respond_with_low_temperature() -> Result<(), Error> {
let opts = GenerationOptions::new().temperature(Temperature::new(0.0)?);
let r = futures_executor::block_on(respond_with_options(
"Reply with only the word: capital of France?",
&opts,
))?;
assert!(
r.as_str().to_lowercase().contains("paris"),
"expected Paris in: {r:?}"
);
Ok(())
}
// Two sequential requests on the same session both produce non-empty replies.
#[test]
#[ignore = "requires Apple Intelligence"]
fn test_multi_turn_session() -> Result<(), Error> {
let session = LanguageModelSession::with_instructions(
"Reply to every message with exactly one word.",
)?;
let r1 = futures_executor::block_on(session.respond_to("Say hello."))?;
let r2 = futures_executor::block_on(session.respond_to("Say goodbye."))?;
assert!(!r1.is_empty(), "first response was empty");
assert!(!r2.is_empty(), "second response was empty");
Ok(())
}
// Streaming yields at least one chunk, and the concatenated text is non-empty.
#[test]
#[ignore = "requires Apple Intelligence"]
fn test_streaming_yields_chunks() -> Result<(), Error> {
let session = LanguageModelSession::new()?;
let stream = session.stream_response("Count: one two three")?;
let chunks: Vec<ResponseText> =
futures_executor::block_on_stream(stream).collect::<Result<_, _>>()?;
assert!(!chunks.is_empty(), "stream produced no chunks");
let full = chunks
.into_iter()
.map(ResponseText::into_string)
.collect::<Vec<_>>()
.join("");
assert!(!full.is_empty(), "concatenated response was empty");
Ok(())
}
// Structured generation deserializes the model's JSON into a Rust struct.
#[test]
#[ignore = "requires Apple Intelligence"]
fn test_structured_generation() -> Result<(), Error> {
use serde::Deserialize;
#[derive(Debug, Deserialize)]
struct MathAnswer {
value: f64,
explanation: String,
}
let session = LanguageModelSession::new()?;
let schema = GenerationSchema::new("MathAnswer")
.description("A numeric answer with a brief explanation")
.property(
GenerationSchemaProperty::new("value", GenerationSchemaPropertyType::Double)
.description("The numeric result"),
)
.property(
GenerationSchemaProperty::new("explanation", GenerationSchemaPropertyType::String)
.description("One-sentence explanation"),
);
let answer: MathAnswer =
futures_executor::block_on(session.respond_generating("What is 6 × 7?", &schema))?;
assert!(
(answer.value - 42.0).abs() < 0.5,
"expected 42, got {}",
answer.value
);
assert!(!answer.explanation.is_empty(), "explanation was empty");
Ok(())
}
// The model is expected to call the registered Rust tool and report its sum.
#[test]
#[ignore = "requires Apple Intelligence"]
fn test_tool_calling() -> Result<(), Error> {
let tool = ToolDefinition::builder(
"add_numbers",
"Add two numbers together and return the sum",
GenerationSchema::new("AddArgs")
.property(
GenerationSchemaProperty::new("a", GenerationSchemaPropertyType::Double)
.description("First number"),
)
.property(
GenerationSchemaProperty::new("b", GenerationSchemaPropertyType::Double)
.description("Second number"),
),
)
.handler(|args| {
let a = args["a"].as_f64().unwrap_or(0.0);
let b = args["b"].as_f64().unwrap_or(0.0);
Ok(ToolOutput::from(format!("{}", a + b)))
});
let session = LanguageModelSession::with_tools(
"You are a calculator. Use the add_numbers tool when asked to add.",
vec![tool],
)?;
let response = futures_executor::block_on(session.respond_to("What is 15 + 27?"))?;
assert!(
response.as_str().contains("42"),
"expected 42 in response: {response:?}"
);
Ok(())
}
}