1#![warn(
208 missing_docs,
209 rustdoc::bare_urls,
210 rustdoc::broken_intra_doc_links,
211 rustdoc::invalid_codeblock_attributes
212)]
213#![cfg_attr(
214 not(test),
215 deny(
216 clippy::expect_used,
217 clippy::panic,
218 clippy::todo,
219 clippy::unimplemented,
220 clippy::unwrap_used
221 )
222)]
223
224use std::convert::Infallible;
225use std::ffi::{CStr, CString, NulError};
226use std::future::Future;
227use std::pin::Pin;
228use std::task::{Context as StdContext, Poll};
229
230use futures_channel::mpsc;
231use futures_core::Stream;
232
233#[cfg(aimx_bridge)]
234use std::ffi::{c_char, c_void};
235
236#[cfg(aimx_bridge)]
237use std::ptr::null;
238
239#[cfg(aimx_bridge)]
240use std::ptr::NonNull;
241
242#[cfg(aimx_bridge)]
243use std::sync::Arc;
244
245#[cfg(aimx_bridge)]
246use futures_channel::oneshot;
247
// Raw bindings to the Swift bridge around Apple's on-device Foundation Models
// framework. They are only compiled (and linked) when the build sets `cfg(aimx_bridge)`.
//
// Conventions visible from the Rust side of this file:
// - text crosses the boundary as NUL-terminated `*const c_char` (callers build `CString`s);
// - session handles are opaque `*mut c_void`s owned by `SessionHandle`;
// - when an option is unset, Rust passes `-1.0` / `-1` for `temperature` / `max_tokens`
//   (see `GenerationConfig`); NOTE(review): presumably the bridge treats negatives as
//   "use the framework default" — confirm against the Swift implementation.
#[cfg(aimx_bridge)]
unsafe extern "C" {
    /// Returns a status code matched against the `FM_*` constants by `availability()`.
    fn fm_availability_reason() -> i32;
    /// Creates a session with the given instructions; a null return is mapped to
    /// `Error::Unavailable(AvailabilityError::Unknown)` by `SessionHandle::from_raw`.
    fn fm_session_create(instructions: *const c_char) -> *mut c_void;
    /// Like `fm_session_create`, additionally registering tools.
    ///
    /// NOTE(review): presumably `tools_json` describes the tools, `tool_ctx` is handed back
    /// unchanged to `tool_dispatch`, and `tool_dispatch` receives (ctx, tool name, JSON
    /// arguments, reply ctx, reply callback) — confirm against the Swift side.
    fn fm_session_create_with_tools(
        instructions: *const c_char,
        tools_json: *const c_char,
        tool_ctx: *mut c_void,
        tool_dispatch: extern "C" fn(
            *mut c_void,
            *const c_char,
            *const c_char,
            *mut c_void,
            extern "C" fn(*mut c_void, *const c_char, *const c_char),
        ),
    ) -> *mut c_void;
    /// Releases a session created by one of the `fm_session_create*` functions.
    fn fm_session_destroy(handle: *mut c_void);
    /// Requests a single text response; the result is delivered through `callback`
    /// together with `ctx`. NOTE(review): the two string arguments look like
    /// (text, error) — confirm which may be null.
    fn fm_session_respond(
        handle: *mut c_void,
        prompt: *const c_char,
        temperature: f64,
        max_tokens: i64,
        ctx: *mut c_void,
        callback: extern "C" fn(*mut c_void, *const c_char, *const c_char),
    );
    /// Like `fm_session_respond`, constrained by a JSON schema description.
    fn fm_session_respond_structured(
        handle: *mut c_void,
        prompt: *const c_char,
        schema_json: *const c_char,
        temperature: f64,
        max_tokens: i64,
        ctx: *mut c_void,
        callback: extern "C" fn(*mut c_void, *const c_char, *const c_char),
    );
    /// Streams a response: `on_token` is called per chunk and `on_done` once at the end.
    /// NOTE(review): presumably `on_done`'s string is an error message or null — confirm.
    fn fm_session_stream(
        handle: *mut c_void,
        prompt: *const c_char,
        temperature: f64,
        max_tokens: i64,
        ctx: *mut c_void,
        on_token: extern "C" fn(*mut c_void, *const c_char),
        on_done: extern "C" fn(*mut c_void, *const c_char),
    );
}
294
/// Marker bound equal to [`Send`] on native targets.
///
/// Blanket-implemented for every qualifying type; never implement it by hand.
#[cfg(not(target_family = "wasm"))]
pub trait WasmCompatSend: Send {}

#[cfg(not(target_family = "wasm"))]
impl<T: Send> WasmCompatSend for T {}

/// Marker bound that imposes no requirement on WebAssembly targets.
///
/// Blanket-implemented for every type; never implement it by hand.
#[cfg(target_family = "wasm")]
pub trait WasmCompatSend {}

#[cfg(target_family = "wasm")]
impl<T> WasmCompatSend for T {}

/// Marker bound equal to [`Sync`] on native targets.
///
/// Blanket-implemented for every qualifying type; never implement it by hand.
#[cfg(not(target_family = "wasm"))]
pub trait WasmCompatSync: Sync {}

#[cfg(not(target_family = "wasm"))]
impl<T: Sync> WasmCompatSync for T {}

/// Marker bound that imposes no requirement on WebAssembly targets.
///
/// Blanket-implemented for every type; never implement it by hand.
#[cfg(target_family = "wasm")]
pub trait WasmCompatSync {}

#[cfg(target_family = "wasm")]
impl<T> WasmCompatSync for T {}
324
/// Declares a public `String`-backed newtype with the conversions, comparisons and
/// serde support shared by every text wrapper in this crate.
///
/// Outer attributes written before the type name (including `///` docs) are forwarded to
/// the generated struct. The struct is `#[serde(transparent)]`, so it (de)serializes
/// exactly like its inner `String`. Every generated public method carries its own doc
/// comment so invocations do not trip the crate-level `missing_docs` lint.
macro_rules! string_newtype {
    ($(#[$meta:meta])* $name:ident) => {
        $(#[$meta])*
        #[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
        #[serde(transparent)]
        pub struct $name(String);

        impl $name {
            /// Wraps any value convertible into a `String`.
            pub fn new(value: impl Into<String>) -> Self {
                Self(value.into())
            }

            /// Borrows the wrapped text.
            pub fn as_str(&self) -> &str {
                &self.0
            }

            /// Consumes the wrapper and returns the owned `String`.
            pub fn into_string(self) -> String {
                self.0
            }

            /// Returns `true` when the wrapped text has zero length.
            pub fn is_empty(&self) -> bool {
                self.0.is_empty()
            }
        }

        impl From<String> for $name {
            fn from(value: String) -> Self {
                Self(value)
            }
        }

        impl From<&str> for $name {
            fn from(value: &str) -> Self {
                Self(value.to_owned())
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        // Displays the raw text with no quoting or decoration.
        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str(self.as_str())
            }
        }

        // Allows `wrapper == "literal"` and `"literal" == wrapper` without allocating.
        impl PartialEq<&str> for $name {
            fn eq(&self, other: &&str) -> bool {
                self.as_str() == *other
            }
        }

        impl PartialEq<$name> for &str {
            fn eq(&self, other: &$name) -> bool {
                *self == other.as_str()
            }
        }
    };
}
391
string_newtype!(
    /// Unvalidated system-instruction text; converted into [`SystemInstructions`]
    /// (which rejects interior NUL bytes) when a session is built.
    InstructionsText
);
string_newtype!(
    /// Unvalidated prompt text; see [`Prompt`] for the NUL-checked form.
    PromptText
);
string_newtype!(
    /// Text produced by the model.
    ResponseText
);
/// Alias of [`ResponseText`].
pub type GeneratedText = ResponseText;

string_newtype!(
    /// Name of a [`GenerationSchema`].
    GenerationSchemaName
);
/// Alias of [`GenerationSchemaName`].
pub type ResponseSchemaName = GenerationSchemaName;
/// Alias of [`GenerationSchemaName`].
pub type SchemaName = GenerationSchemaName;

string_newtype!(
    /// Name of a property inside a [`GenerationSchema`].
    GenerationSchemaPropertyName
);
/// Alias of [`GenerationSchemaPropertyName`].
pub type ResponseFieldName = GenerationSchemaPropertyName;
/// Alias of [`GenerationSchemaPropertyName`].
pub type SchemaPropertyName = GenerationSchemaPropertyName;

string_newtype!(
    /// Human-readable description attached to a schema or one of its properties.
    SchemaDescription
);
string_newtype!(
    /// Name of a tool; compared against the tool name requested through the bridge.
    ToolName
);
string_newtype!(
    /// Human-readable description of a tool, sent to the bridge with its parameters.
    ToolDescription
);
string_newtype!(
    /// Successful output of a tool handler.
    ToolOutput
);
441
442#[derive(Debug, Clone, PartialEq, Eq)]
444pub struct Prompt {
445 text: String,
446 c_text: CString,
447}
448
449impl Prompt {
450 pub fn new(value: impl Into<String>) -> Result<Self, Error> {
469 let text = value.into();
470 let c_text = CString::new(text.clone())?;
471
472 Ok(Self { text, c_text })
473 }
474
475 pub fn as_str(&self) -> &str {
477 &self.text
478 }
479
480 #[cfg(aimx_bridge)]
481 fn as_ptr(&self) -> *const c_char {
482 self.c_text.as_ptr()
483 }
484}
485
486impl TryFrom<&str> for Prompt {
487 type Error = Error;
488
489 fn try_from(value: &str) -> Result<Self, Self::Error> {
490 Self::new(value)
491 }
492}
493
494impl TryFrom<String> for Prompt {
495 type Error = Error;
496
497 fn try_from(value: String) -> Result<Self, Self::Error> {
498 Self::new(value)
499 }
500}
501
502impl TryFrom<PromptText> for Prompt {
503 type Error = Error;
504
505 fn try_from(value: PromptText) -> Result<Self, Self::Error> {
506 Self::new(value.into_string())
507 }
508}
509
510impl AsRef<str> for Prompt {
511 fn as_ref(&self) -> &str {
512 self.as_str()
513 }
514}
515
516pub type PromptInput = Prompt;
518
519#[derive(Debug, Clone, PartialEq, Eq)]
521pub struct SystemInstructions {
522 text: String,
523 c_text: CString,
524}
525
526impl SystemInstructions {
527 pub fn new(value: impl Into<String>) -> Result<Self, Error> {
546 let text = value.into();
547 let c_text = CString::new(text.clone())?;
548
549 Ok(Self { text, c_text })
550 }
551
552 pub fn empty() -> Self {
557 Self {
558 text: String::new(),
559 c_text: CString::default(),
560 }
561 }
562
563 pub fn as_str(&self) -> &str {
565 &self.text
566 }
567
568 #[cfg(aimx_bridge)]
569 fn as_ptr(&self) -> *const c_char {
570 self.c_text.as_ptr()
571 }
572}
573
574impl Default for SystemInstructions {
575 fn default() -> Self {
576 Self::empty()
577 }
578}
579
580impl TryFrom<&str> for SystemInstructions {
581 type Error = Error;
582
583 fn try_from(value: &str) -> Result<Self, Self::Error> {
584 Self::new(value)
585 }
586}
587
588impl TryFrom<String> for SystemInstructions {
589 type Error = Error;
590
591 fn try_from(value: String) -> Result<Self, Self::Error> {
592 Self::new(value)
593 }
594}
595
596impl TryFrom<InstructionsText> for SystemInstructions {
597 type Error = Error;
598
599 fn try_from(value: InstructionsText) -> Result<Self, Self::Error> {
600 Self::new(value.into_string())
601 }
602}
603
604pub type Instructions = SystemInstructions;
606
607#[derive(Debug, Clone, Copy, PartialEq)]
609pub struct Temperature(f64);
610
611impl Temperature {
612 pub const MIN: f64 = 0.0;
614 pub const MAX: f64 = 2.0;
616
617 pub fn new(value: f64) -> Result<Self, Error> {
638 if (Self::MIN..=Self::MAX).contains(&value) {
639 Ok(Self(value))
640 } else {
641 Err(Error::InvalidTemperature(value))
642 }
643 }
644
645 pub fn as_f64(self) -> f64 {
647 self.0
648 }
649}
650
651impl TryFrom<f64> for Temperature {
652 type Error = Error;
653
654 fn try_from(value: f64) -> Result<Self, Self::Error> {
655 Self::new(value)
656 }
657}
658
659#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
661pub struct MaxTokens(usize);
662
663impl MaxTokens {
664 pub const MAX: usize = i64::MAX as usize;
666
667 pub fn new(value: usize) -> Result<Self, Error> {
688 if value <= Self::MAX {
689 Ok(Self(value))
690 } else {
691 Err(Error::InvalidMaxTokens(value))
692 }
693 }
694
695 pub fn get(self) -> usize {
697 self.0
698 }
699}
700
701impl TryFrom<usize> for MaxTokens {
702 type Error = Error;
703
704 fn try_from(value: usize) -> Result<Self, Self::Error> {
705 Self::new(value)
706 }
707}
708
709#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
711#[error("{message}")]
712pub struct GenerationError {
713 message: String,
714}
715
716impl GenerationError {
717 pub fn new(message: impl Into<String>) -> Self {
719 Self {
720 message: message.into(),
721 }
722 }
723
724 pub fn as_str(&self) -> &str {
726 &self.message
727 }
728}
729
730impl From<String> for GenerationError {
731 fn from(message: String) -> Self {
732 Self::new(message)
733 }
734}
735
736impl From<&str> for GenerationError {
737 fn from(message: &str) -> Self {
738 Self::new(message)
739 }
740}
741
742#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
744#[error("{message}")]
745pub struct ToolCallError {
746 message: String,
747}
748
749impl ToolCallError {
750 pub fn new(message: impl Into<String>) -> Self {
752 Self {
753 message: message.into(),
754 }
755 }
756
757 pub fn as_str(&self) -> &str {
759 &self.message
760 }
761}
762
763impl From<String> for ToolCallError {
764 fn from(message: String) -> Self {
765 Self::new(message)
766 }
767}
768
769impl From<&str> for ToolCallError {
770 fn from(message: &str) -> Self {
771 Self::new(message)
772 }
773}
774
775pub type ToolResult = Result<ToolOutput, ToolCallError>;
777
// Outcome of a single model text generation.
type ModelTextResult = Result<ResponseText, GenerationError>;
// Sending half of an unbounded channel of generation results (by its name, the
// streaming path — confirm at the use sites).
type StreamSender = mpsc::UnboundedSender<ModelTextResult>;
// Receiving half matching `StreamSender`.
type StreamReceiver = mpsc::UnboundedReceiver<ModelTextResult>;
// Type-erased tool handler stored by `ToolDefinition`.
type ToolHandlerBox = Box<dyn ToolHandler>;

// One-shot channel that delivers exactly one generation result.
#[cfg(aimx_bridge)]
type ResponseSender = oneshot::Sender<ModelTextResult>;
#[cfg(aimx_bridge)]
type ResponseReceiver = oneshot::Receiver<ModelTextResult>;
// Same signature as the reply callback passed to `tool_dispatch` in
// `fm_session_create_with_tools`.
#[cfg(aimx_bridge)]
type ToolResultCallback = extern "C" fn(*mut c_void, *const c_char, *const c_char);
789
/// Reason the on-device model cannot be used, as returned by [`availability`].
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum AvailabilityError {
    /// Bridge status code 1. Builds without the Swift bridge always return this variant.
    #[error("device is not eligible (requires Apple Silicon M1 or later)")]
    DeviceNotEligible,
    /// Bridge status code 2.
    #[error("Apple Intelligence is not enabled in System Settings")]
    NotEnabled,
    /// Bridge status code 3.
    #[error("the on-device model is not ready yet")]
    ModelNotReady,
    /// Any unrecognised bridge status code, or a null session handle from the bridge.
    #[error("unknown availability state")]
    Unknown,
}

/// Alias of [`AvailabilityError`].
pub type UnavailabilityReason = AvailabilityError;
811
/// Error type returned by this crate's fallible constructors and generation APIs.
///
/// NOTE(review): each wrapping variant prints its inner error in the message *and* exposes
/// it through `source()`. Reporters that walk the chain (e.g. `{:#}` in anyhow) therefore
/// print it twice. Changing either side alters observable output, so both are kept.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The model cannot be used. This variant is also used when the bridge returns a null
    /// session handle.
    #[error("Apple Intelligence unavailable: {0}")]
    Unavailable(#[source] AvailabilityError),

    /// The model reported a failure while generating.
    #[error("generation error: {0}")]
    Generation(#[from] GenerationError),

    /// Text destined for the C bridge contained an interior NUL byte.
    #[error("argument contains a null byte: {0}")]
    NullByte(#[from] NulError),

    /// A temperature outside `Temperature::MIN..=Temperature::MAX` (or NaN).
    #[error("temperature {0} is out of range; expected 0.0 – 2.0")]
    InvalidTemperature(f64),

    /// A token limit above `MaxTokens::MAX`.
    #[error("max_tokens {0} is out of range; expected no more than i64::MAX")]
    InvalidMaxTokens(usize),

    /// A JSON serialization or deserialization step failed.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// A tool failed.
    #[error("tool '{name}' failed: {error}")]
    ToolError {
        /// Name of the tool that failed.
        name: ToolName,
        /// Error reported by the tool.
        #[source]
        error: ToolCallError,
    },
}
849
850impl From<Infallible> for Error {
851 fn from(error: Infallible) -> Self {
852 match error {}
853 }
854}
855
856#[derive(Debug, Default, Clone)]
864pub struct GenerationOptions {
865 temperature: Option<Temperature>,
866 max_tokens: Option<MaxTokens>,
867}
868
869impl GenerationOptions {
870 pub fn new() -> Self {
872 Self::default()
873 }
874
875 pub fn temperature(mut self, temperature: Temperature) -> Self {
879 self.temperature = Some(temperature);
880 self
881 }
882
883 pub fn with_temperature(mut self, temperature: Temperature) -> Self {
885 self = self.temperature(temperature);
886 self
887 }
888
889 pub fn try_temperature(self, temperature: f64) -> Result<Self, Error> {
899 Ok(self.temperature(Temperature::new(temperature)?))
900 }
901
902 pub fn max_tokens(mut self, max_tokens: MaxTokens) -> Self {
908 self.max_tokens = Some(max_tokens);
909 self
910 }
911
912 pub fn with_max_tokens(mut self, max_tokens: MaxTokens) -> Self {
914 self = self.max_tokens(max_tokens);
915 self
916 }
917
918 pub fn try_max_tokens(self, max_tokens: usize) -> Result<Self, Error> {
928 Ok(self.max_tokens(MaxTokens::new(max_tokens)?))
929 }
930
931 pub fn temperature_value(&self) -> Option<Temperature> {
933 self.temperature
934 }
935
936 pub fn max_tokens_value(&self) -> Option<MaxTokens> {
938 self.max_tokens
939 }
940
941 pub fn validate(&self) -> Result<(), Error> {
963 GenerationConfig::try_from(self).map(|_| ())
964 }
965
966 fn validated(&self) -> Result<GenerationConfig, Error> {
967 GenerationConfig::try_from(self)
968 }
969}
970
971#[derive(Debug, Clone, Copy, Default)]
972struct GenerationConfig {
973 temperature: Option<Temperature>,
974 max_tokens: Option<MaxTokens>,
975}
976
977impl GenerationConfig {
978 fn ffi_temperature(self) -> f64 {
979 self.temperature.map(Temperature::as_f64).unwrap_or(-1.0)
980 }
981
982 fn ffi_max_tokens(self) -> i64 {
983 self.max_tokens
984 .map(|max_tokens| max_tokens.get() as i64)
985 .unwrap_or(-1)
986 }
987}
988
989impl TryFrom<&GenerationOptions> for GenerationConfig {
990 type Error = Error;
991
992 fn try_from(options: &GenerationOptions) -> Result<Self, Self::Error> {
993 Ok(Self {
994 temperature: options.temperature,
995 max_tokens: options.max_tokens,
996 })
997 }
998}
999
1000#[derive(Debug, Clone, serde::Serialize)]
1004#[serde(rename_all = "lowercase")]
1005pub enum GenerationSchemaPropertyType {
1006 String,
1008 Integer,
1010 Double,
1012 Bool,
1014}
1015
1016pub type ResponseFieldType = GenerationSchemaPropertyType;
1018pub type SchemaPropertyType = GenerationSchemaPropertyType;
1020
/// Whether a schema property must appear in generated output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum GenerationSchemaPropertyRequirement {
    /// The property must be present (the default).
    #[default]
    Required,
    /// The property may be omitted.
    Optional,
}

impl GenerationSchemaPropertyRequirement {
    /// `true` for [`GenerationSchemaPropertyRequirement::Optional`].
    pub fn is_optional(self) -> bool {
        self == Self::Optional
    }

    /// `true` for [`GenerationSchemaPropertyRequirement::Required`].
    pub fn is_required(self) -> bool {
        // Only two variants exist, so "required" is exactly "not optional".
        !self.is_optional()
    }
}
1042
/// One property of a [`GenerationSchema`].
///
/// Serialized by a hand-written `Serialize` impl with the fields `name`, an optional
/// `description`, `type`, and an `optional` boolean.
#[derive(Debug, Clone)]
pub struct GenerationSchemaProperty {
    /// Property name.
    pub name: GenerationSchemaPropertyName,
    /// Optional human-readable description; omitted from the serialized form when `None`.
    pub description: Option<SchemaDescription>,
    /// Primitive type of the property's value; serialized under the key `type`.
    pub property_type: GenerationSchemaPropertyType,
    /// Whether the property must be present; serialized as the `optional` boolean.
    pub requirement: GenerationSchemaPropertyRequirement,
}

/// Alias of [`GenerationSchemaProperty`].
pub type ResponseField = GenerationSchemaProperty;
/// Alias of [`GenerationSchemaProperty`].
pub type SchemaProperty = GenerationSchemaProperty;
1060
1061impl GenerationSchemaProperty {
1062 pub fn new(
1064 name: impl Into<GenerationSchemaPropertyName>,
1065 property_type: GenerationSchemaPropertyType,
1066 ) -> Self {
1067 Self {
1068 name: name.into(),
1069 description: None,
1070 property_type,
1071 requirement: GenerationSchemaPropertyRequirement::Required,
1072 }
1073 }
1074
1075 pub fn description(mut self, description: impl Into<SchemaDescription>) -> Self {
1077 self.description = Some(description.into());
1078 self
1079 }
1080
1081 pub fn optional(mut self) -> Self {
1083 self.requirement = GenerationSchemaPropertyRequirement::Optional;
1084 self
1085 }
1086
1087 pub fn required(mut self) -> Self {
1089 self.requirement = GenerationSchemaPropertyRequirement::Required;
1090 self
1091 }
1092}
1093
1094impl serde::Serialize for GenerationSchemaProperty {
1095 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1096 where
1097 S: serde::Serializer,
1098 {
1099 use serde::ser::SerializeStruct;
1100
1101 let field_count = if self.description.is_some() { 4 } else { 3 };
1102 let mut state = serializer.serialize_struct("GenerationSchemaProperty", field_count)?;
1103 state.serialize_field("name", &self.name)?;
1104 if let Some(description) = &self.description {
1105 state.serialize_field("description", description)?;
1106 }
1107 state.serialize_field("type", &self.property_type)?;
1108 state.serialize_field("optional", &self.requirement.is_optional())?;
1109 state.end()
1110 }
1111}
1112
1113#[derive(Debug, Clone, serde::Serialize)]
1125pub struct GenerationSchema {
1126 pub name: GenerationSchemaName,
1128 #[serde(skip_serializing_if = "Option::is_none")]
1130 pub description: Option<SchemaDescription>,
1131 pub properties: Vec<GenerationSchemaProperty>,
1133}
1134
1135pub type ResponseSchema = GenerationSchema;
1137pub type Schema = GenerationSchema;
1139
1140impl GenerationSchema {
1141 pub fn new(name: impl Into<GenerationSchemaName>) -> Self {
1143 Self {
1144 name: name.into(),
1145 description: None,
1146 properties: Vec::new(),
1147 }
1148 }
1149
1150 pub fn description(mut self, description: impl Into<SchemaDescription>) -> Self {
1152 self.description = Some(description.into());
1153 self
1154 }
1155
1156 pub fn property(mut self, property: GenerationSchemaProperty) -> Self {
1158 self.properties.push(property);
1159 self
1160 }
1161}
1162
1163pub struct ToolDefinition {
1172 pub name: ToolName,
1174 pub description: ToolDescription,
1176 pub parameters: GenerationSchema,
1178 handler: ToolHandlerBox,
1179}
1180
1181impl ToolDefinition {
1182 pub fn new(
1184 name: impl Into<ToolName>,
1185 description: impl Into<ToolDescription>,
1186 parameters: GenerationSchema,
1187 handler: impl Fn(serde_json::Value) -> ToolResult + WasmCompatSend + WasmCompatSync + 'static,
1188 ) -> Self {
1189 Self::builder(name, description, parameters).handler(handler)
1190 }
1191
1192 pub fn builder(
1194 name: impl Into<ToolName>,
1195 description: impl Into<ToolDescription>,
1196 parameters: GenerationSchema,
1197 ) -> ToolDefinitionBuilder {
1198 ToolDefinitionBuilder {
1199 name: name.into(),
1200 description: description.into(),
1201 parameters,
1202 }
1203 }
1204
1205 pub fn from_handler(
1207 name: impl Into<ToolName>,
1208 description: impl Into<ToolDescription>,
1209 parameters: GenerationSchema,
1210 handler: impl Fn(serde_json::Value) -> ToolResult + WasmCompatSend + WasmCompatSync + 'static,
1211 ) -> Self {
1212 Self::new(name, description, parameters, handler)
1213 }
1214
1215 #[cfg(aimx_bridge)]
1216 fn bridge_description(&self) -> serde_json::Value {
1217 serde_json::json!({
1218 "name": self.name.as_str(),
1219 "description": self.description.as_str(),
1220 "properties": &self.parameters.properties,
1221 })
1222 }
1223}
1224
1225#[derive(Debug, Clone)]
1227pub struct ToolDefinitionBuilder {
1228 name: ToolName,
1229 description: ToolDescription,
1230 parameters: GenerationSchema,
1231}
1232
1233impl ToolDefinitionBuilder {
1234 pub fn handler(
1236 self,
1237 handler: impl Fn(serde_json::Value) -> ToolResult + WasmCompatSend + WasmCompatSync + 'static,
1238 ) -> ToolDefinition {
1239 ToolDefinition {
1240 name: self.name,
1241 description: self.description,
1242 parameters: self.parameters,
1243 handler: Box::new(handler),
1244 }
1245 }
1246}
1247
/// A tool that can be described to the model and invoked with JSON arguments.
pub trait Tool: std::fmt::Debug + WasmCompatSend + WasmCompatSync {
    /// Name used to select this tool.
    fn name(&self) -> &ToolName;

    /// Human-readable description of the tool.
    fn description(&self) -> &ToolDescription;

    /// Schema describing the expected arguments.
    fn parameters(&self) -> &GenerationSchema;

    /// Runs the tool on `args`.
    ///
    /// The [`ToolDefinition`] implementation converts a panicking handler into a
    /// [`ToolCallError`] instead of unwinding.
    fn call(&self, args: serde_json::Value) -> ToolResult;
}
1268
1269impl Tool for ToolDefinition {
1270 fn name(&self) -> &ToolName {
1271 &self.name
1272 }
1273
1274 fn description(&self) -> &ToolDescription {
1275 &self.description
1276 }
1277
1278 fn parameters(&self) -> &GenerationSchema {
1279 &self.parameters
1280 }
1281
1282 fn call(&self, args: serde_json::Value) -> ToolResult {
1283 call_tool_handler(self.handler.as_ref(), args)
1284 }
1285}
1286
1287trait ToolHandler: WasmCompatSend + WasmCompatSync {
1288 fn call(&self, args: serde_json::Value) -> ToolResult;
1289}
1290
1291impl<F> ToolHandler for F
1292where
1293 F: Fn(serde_json::Value) -> ToolResult + WasmCompatSend + WasmCompatSync,
1294{
1295 fn call(&self, args: serde_json::Value) -> ToolResult {
1296 self(args)
1297 }
1298}
1299
1300fn call_tool_handler(handler: &dyn ToolHandler, args: serde_json::Value) -> ToolResult {
1301 match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| handler.call(args))) {
1302 Ok(result) => result,
1303 Err(payload) => Err(ToolCallError::new(format!(
1304 "tool handler panicked: {}",
1305 panic_payload_message(payload.as_ref())
1306 ))),
1307 }
1308}
1309
// Extracts a readable message from a panic payload. `panic!` with a literal produces a
// `&'static str` payload and a formatted panic produces a `String`; anything else gets a
// fixed placeholder.
fn panic_payload_message(payload: &(dyn std::any::Any + Send)) -> String {
    payload
        .downcast_ref::<&'static str>()
        .map(|text| (*text).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "non-string panic payload".to_owned())
}
1321
1322impl std::fmt::Debug for ToolDefinition {
1323 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1324 f.debug_struct("ToolDefinition")
1325 .field("name", &self.name)
1326 .field("description", &self.description)
1327 .finish_non_exhaustive()
1328 }
1329}
1330
// Registry of tool handlers, keyed by tool name, that a bridged session dispatches into.
#[cfg(aimx_bridge)]
struct ToolsContext {
    tools: Vec<(ToolName, ToolHandlerBox)>,
}

#[cfg(aimx_bridge)]
impl ToolsContext {
    // Keeps only each tool's name and handler; descriptions and schemas are not needed here.
    fn from_definitions(tools: Vec<ToolDefinition>) -> Arc<Self> {
        let tools = tools
            .into_iter()
            .map(|ToolDefinition { name, handler, .. }| (name, handler))
            .collect();
        Arc::new(ToolsContext { tools })
    }

    // Invokes the first tool named `name`; unknown names yield a `ToolCallError`.
    fn call(&self, name: &str, args: serde_json::Value) -> ToolResult {
        let Some((_, handler)) = self
            .tools
            .iter()
            .find(|(tool_name, _)| tool_name.as_str() == name)
        else {
            return Err(ToolCallError::new(format!("unknown tool: {name}")));
        };
        call_tool_handler(&**handler, args)
    }
}
1361
// Status codes returned by `fm_availability_reason`, mapped to `AvailabilityError` by
// `availability()`. They are only read when the Swift bridge is compiled in, so stub
// builds allow `dead_code` instead of warning.
#[cfg_attr(not(aimx_bridge), allow(dead_code))]
const FM_AVAILABLE: i32 = 0;
#[cfg_attr(not(aimx_bridge), allow(dead_code))]
const FM_DEVICE_NOT_ELIGIBLE: i32 = 1;
#[cfg_attr(not(aimx_bridge), allow(dead_code))]
const FM_NOT_ENABLED: i32 = 2;
#[cfg_attr(not(aimx_bridge), allow(dead_code))]
const FM_MODEL_NOT_READY: i32 = 3;
1368
1369pub fn is_available() -> bool {
1374 availability().is_ok()
1375}
1376
1377pub fn availability() -> Result<(), AvailabilityError> {
1384 #[cfg(aimx_bridge)]
1385 {
1386 let code = unsafe { fm_availability_reason() };
1387 match code {
1388 FM_AVAILABLE => Ok(()),
1389 FM_DEVICE_NOT_ELIGIBLE => Err(AvailabilityError::DeviceNotEligible),
1390 FM_NOT_ENABLED => Err(AvailabilityError::NotEnabled),
1391 FM_MODEL_NOT_READY => Err(AvailabilityError::ModelNotReady),
1392 _ => Err(AvailabilityError::Unknown),
1393 }
1394 }
1395 #[cfg(not(aimx_bridge))]
1396 Err(AvailabilityError::DeviceNotEligible)
1397}
1398
1399#[derive(Debug, Default, Clone, Copy)]
1403pub struct AppleIntelligenceModels {
1404 _private: (),
1405}
1406
1407impl AppleIntelligenceModels {
1408 pub fn new() -> Self {
1410 Self::default()
1411 }
1412
1413 pub fn availability(&self) -> Result<(), AvailabilityError> {
1423 availability()
1424 }
1425
1426 pub fn is_available(&self) -> bool {
1428 self.availability().is_ok()
1429 }
1430
1431 pub fn session(&self) -> LanguageModelSessionBuilder {
1433 LanguageModelSessionBuilder::new()
1434 }
1435
1436 pub fn agent(&self) -> LanguageModelSessionBuilder {
1438 self.session()
1439 }
1440
1441 pub async fn respond<P>(&self, prompt: P) -> Result<String, Error>
1449 where
1450 P: TryInto<Prompt>,
1451 P::Error: Into<Error>,
1452 {
1453 Ok(self.generate_text(prompt).await?.into_string())
1454 }
1455
1456 pub async fn generate<P>(&self, prompt: P) -> Result<GeneratedText, Error>
1462 where
1463 P: TryInto<Prompt>,
1464 P::Error: Into<Error>,
1465 {
1466 self.generate_text(prompt).await
1467 }
1468
1469 pub async fn generate_with_options<P>(
1475 &self,
1476 prompt: P,
1477 options: &GenerationOptions,
1478 ) -> Result<GeneratedText, Error>
1479 where
1480 P: TryInto<Prompt>,
1481 P::Error: Into<Error>,
1482 {
1483 self.generate_text_with_options(prompt, options).await
1484 }
1485
1486 pub async fn complete<P>(&self, prompt: P) -> Result<ResponseText, Error>
1492 where
1493 P: TryInto<Prompt>,
1494 P::Error: Into<Error>,
1495 {
1496 self.generate_text(prompt).await
1497 }
1498
1499 pub async fn generate_text<P>(&self, prompt: P) -> Result<ResponseText, Error>
1505 where
1506 P: TryInto<Prompt>,
1507 P::Error: Into<Error>,
1508 {
1509 let options = GenerationOptions::default();
1510 self.generate_text_with_options(prompt, &options).await
1511 }
1512
1513 pub async fn generate_text_with_options<P>(
1522 &self,
1523 prompt: P,
1524 options: &GenerationOptions,
1525 ) -> Result<ResponseText, Error>
1526 where
1527 P: TryInto<Prompt>,
1528 P::Error: Into<Error>,
1529 {
1530 LanguageModel::generate_text_with_options(self, prompt, options.clone()).await
1531 }
1532
1533 pub fn stream_text<P>(&self, prompt: P) -> Result<ResponseStream, Error>
1539 where
1540 P: TryInto<Prompt>,
1541 P::Error: Into<Error>,
1542 {
1543 let options = GenerationOptions::default();
1544 self.stream_text_with_options(prompt, &options)
1545 }
1546
1547 pub fn stream_text_with_options<P>(
1553 &self,
1554 prompt: P,
1555 options: &GenerationOptions,
1556 ) -> Result<ResponseStream, Error>
1557 where
1558 P: TryInto<Prompt>,
1559 P::Error: Into<Error>,
1560 {
1561 LanguageModel::stream_text_with_options(self, prompt, options.clone())
1562 }
1563
1564 pub fn stream_generate<P>(&self, prompt: P) -> Result<ResponseStream, Error>
1570 where
1571 P: TryInto<Prompt>,
1572 P::Error: Into<Error>,
1573 {
1574 self.stream_text(prompt)
1575 }
1576
1577 pub fn stream_generate_with_options<P>(
1583 &self,
1584 prompt: P,
1585 options: &GenerationOptions,
1586 ) -> Result<ResponseStream, Error>
1587 where
1588 P: TryInto<Prompt>,
1589 P::Error: Into<Error>,
1590 {
1591 self.stream_text_with_options(prompt, options)
1592 }
1593}
1594
/// Alias of [`AppleIntelligenceModels`].
pub type SystemLanguageModel = AppleIntelligenceModels;
/// Alias of [`AppleIntelligenceModels`].
pub type FoundationModels = AppleIntelligenceModels;
/// Alias of [`AppleIntelligenceModels`].
pub type Client = AppleIntelligenceModels;
1601
1602impl LanguageModel for AppleIntelligenceModels {
1603 fn generate_text_with_options<P>(
1604 &self,
1605 prompt: P,
1606 options: GenerationOptions,
1607 ) -> impl Future<Output = Result<ResponseText, Error>> + '_
1608 where
1609 P: TryInto<Prompt>,
1610 P::Error: Into<Error>,
1611 {
1612 let prompt = prompt.try_into().map_err(Into::into);
1613 let builder = self.session().options(options.clone());
1614
1615 async move {
1616 let prompt = prompt?;
1617 let session = builder.build()?;
1618 session.generate_prompt_with_options(prompt, &options).await
1619 }
1620 }
1621
1622 fn stream_text_with_options<P>(
1623 &self,
1624 prompt: P,
1625 options: GenerationOptions,
1626 ) -> Result<ResponseStream, Error>
1627 where
1628 P: TryInto<Prompt>,
1629 P::Error: Into<Error>,
1630 {
1631 let prompt = prompt.try_into().map_err(Into::into)?;
1632 let session = self.session().options(options.clone()).build()?;
1633 session.stream_text_with_options(prompt, &options)
1634 }
1635}
1636
/// Minimal text-generation interface; [`AppleIntelligenceModels`] implements it.
pub trait LanguageModel {
    /// Generates a complete response to `prompt` using `options`.
    ///
    /// The returned future borrows `self` (`+ '_`).
    fn generate_text_with_options<P>(
        &self,
        prompt: P,
        options: GenerationOptions,
    ) -> impl Future<Output = Result<ResponseText, Error>> + '_
    where
        P: TryInto<Prompt>,
        P::Error: Into<Error>;

    /// Starts streaming a response to `prompt` using `options`.
    ///
    /// # Errors
    ///
    /// Implementations return an error if the stream cannot be started.
    fn stream_text_with_options<P>(
        &self,
        prompt: P,
        options: GenerationOptions,
    ) -> Result<ResponseStream, Error>
    where
        P: TryInto<Prompt>,
        P::Error: Into<Error>;
}
1669
/// Completion-style names for [`LanguageModel`] methods.
///
/// Blanket-implemented for every `LanguageModel`; both methods simply forward.
pub trait CompletionModel: LanguageModel {
    /// Forwards to [`LanguageModel::generate_text_with_options`].
    fn completion<P>(
        &self,
        prompt: P,
        options: GenerationOptions,
    ) -> impl Future<Output = Result<ResponseText, Error>> + '_
    where
        P: TryInto<Prompt>,
        P::Error: Into<Error>,
    {
        self.generate_text_with_options(prompt, options)
    }

    /// Forwards to [`LanguageModel::stream_text_with_options`].
    ///
    /// # Errors
    ///
    /// Whatever the underlying `stream_text_with_options` returns.
    fn stream_completion<P>(
        &self,
        prompt: P,
        options: GenerationOptions,
    ) -> Result<ResponseStream, Error>
    where
        P: TryInto<Prompt>,
        P::Error: Into<Error>,
    {
        self.stream_text_with_options(prompt, options)
    }
}

impl<T> CompletionModel for T where T: LanguageModel {}
1710
/// One-call helper that sends a prompt with default [`GenerationOptions`].
///
/// Blanket-implemented for every [`LanguageModel`].
pub trait GenerateText: LanguageModel {
    /// Generates a response to `prompt` using `GenerationOptions::default()`.
    ///
    /// A prompt conversion error is reported when the returned future is awaited.
    fn prompt<P>(&self, prompt: P) -> impl Future<Output = Result<ResponseText, Error>> + '_
    where
        P: TryInto<Prompt>,
        P::Error: Into<Error>,
    {
        // Converted before building the future, so the async block captures only the
        // converted `Result` and not the caller's `P`.
        let prompt = prompt.try_into().map_err(Into::into);

        async move {
            let prompt = prompt?;
            self.generate_text_with_options(prompt, GenerationOptions::default())
                .await
        }
    }
}

impl<T> GenerateText for T where T: LanguageModel {}
1735
1736#[derive(Debug)]
1738pub struct LanguageModelSessionBuilder {
1739 instructions: InstructionsText,
1740 tools: Vec<ToolDefinition>,
1741 default_options: GenerationOptions,
1742}
1743
1744impl LanguageModelSessionBuilder {
1745 pub fn new() -> Self {
1747 Self {
1748 instructions: InstructionsText::new(""),
1749 tools: Vec::new(),
1750 default_options: GenerationOptions::default(),
1751 }
1752 }
1753
1754 pub fn instructions(mut self, instructions: impl Into<InstructionsText>) -> Self {
1756 self.instructions = instructions.into();
1757 self
1758 }
1759
1760 pub fn preamble(self, instructions: impl Into<InstructionsText>) -> Self {
1762 self.instructions(instructions)
1763 }
1764
1765 pub fn tool(mut self, tool: ToolDefinition) -> Self {
1767 self.tools.push(tool);
1768 self
1769 }
1770
1771 pub fn tools(mut self, tools: impl IntoIterator<Item = ToolDefinition>) -> Self {
1773 self.tools.extend(tools);
1774 self
1775 }
1776
1777 pub fn temperature(mut self, temperature: Temperature) -> Self {
1779 self.default_options = self.default_options.temperature(temperature);
1780 self
1781 }
1782
1783 pub fn with_temperature(mut self, temperature: Temperature) -> Self {
1785 self = self.temperature(temperature);
1786 self
1787 }
1788
1789 pub fn try_temperature(mut self, temperature: f64) -> Result<Self, Error> {
1796 self.default_options = self.default_options.try_temperature(temperature)?;
1797 Ok(self)
1798 }
1799
1800 pub fn max_tokens(mut self, max_tokens: MaxTokens) -> Self {
1802 self.default_options = self.default_options.max_tokens(max_tokens);
1803 self
1804 }
1805
1806 pub fn with_max_tokens(mut self, max_tokens: MaxTokens) -> Self {
1808 self = self.max_tokens(max_tokens);
1809 self
1810 }
1811
1812 pub fn try_max_tokens(mut self, max_tokens: usize) -> Result<Self, Error> {
1819 self.default_options = self.default_options.try_max_tokens(max_tokens)?;
1820 Ok(self)
1821 }
1822
1823 pub fn options(mut self, options: GenerationOptions) -> Self {
1825 self.default_options = options;
1826 self
1827 }
1828
1829 pub fn build(self) -> Result<LanguageModelSession, Error> {
1853 let instructions = SystemInstructions::try_from(self.instructions)?;
1854 LanguageModelSession::create(instructions, self.tools, self.default_options)
1855 }
1856}
1857
1858impl Default for LanguageModelSessionBuilder {
1859 fn default() -> Self {
1860 Self::new()
1861 }
1862}
1863
1864pub type SessionBuilder = LanguageModelSessionBuilder;
1866
1867pub async fn respond<P>(prompt: P) -> Result<String, Error>
1880where
1881 P: TryInto<Prompt>,
1882 P::Error: Into<Error>,
1883{
1884 AppleIntelligenceModels::default().respond(prompt).await
1885}
1886
1887pub async fn respond_with_options<P>(
1895 prompt: P,
1896 options: &GenerationOptions,
1897) -> Result<String, Error>
1898where
1899 P: TryInto<Prompt>,
1900 P::Error: Into<Error>,
1901{
1902 Ok(AppleIntelligenceModels::default()
1903 .generate_with_options(prompt, options)
1904 .await?
1905 .into_string())
1906}
1907
1908pub async fn generate<P>(prompt: P) -> Result<String, Error>
1917where
1918 P: TryInto<Prompt>,
1919 P::Error: Into<Error>,
1920{
1921 respond(prompt).await
1922}
1923
1924pub async fn generate_with_options<P>(
1930 prompt: P,
1931 options: &GenerationOptions,
1932) -> Result<String, Error>
1933where
1934 P: TryInto<Prompt>,
1935 P::Error: Into<Error>,
1936{
1937 respond_with_options(prompt, options).await
1938}
1939
1940pub fn stream_generate<P>(prompt: P) -> Result<ResponseStream, Error>
1948where
1949 P: TryInto<Prompt>,
1950 P::Error: Into<Error>,
1951{
1952 AppleIntelligenceModels::default().stream_generate(prompt)
1953}
1954
1955pub fn stream_generate_with_options<P>(
1963 prompt: P,
1964 options: &GenerationOptions,
1965) -> Result<ResponseStream, Error>
1966where
1967 P: TryInto<Prompt>,
1968 P::Error: Into<Error>,
1969{
1970 AppleIntelligenceModels::default().stream_generate_with_options(prompt, options)
1971}
1972
1973#[cfg(aimx_bridge)]
1977#[derive(Debug)]
1978struct SessionHandle(NonNull<c_void>);
1979
1980#[cfg(aimx_bridge)]
1981impl SessionHandle {
1982 fn from_raw(handle: *mut c_void) -> Result<Self, Error> {
1983 NonNull::new(handle)
1984 .map(Self)
1985 .ok_or(Error::Unavailable(AvailabilityError::Unknown))
1986 }
1987
1988 fn as_ptr(&self) -> *mut c_void {
1989 self.0.as_ptr()
1990 }
1991}
1992
1993#[cfg(aimx_bridge)]
1994impl Drop for SessionHandle {
1995 fn drop(&mut self) {
1996 unsafe {
1997 fm_session_destroy(self.as_ptr());
1998 }
1999 }
2000}
2001
2002#[cfg(aimx_bridge)]
2003unsafe impl Send for SessionHandle {}
2004
2005#[cfg(aimx_bridge)]
2006unsafe impl Sync for SessionHandle {}
2007
2008pub struct LanguageModelSession {
2028 default_options: GenerationOptions,
2029 #[cfg(aimx_bridge)]
2030 handle: Arc<SessionHandle>,
2031 #[cfg(aimx_bridge)]
2034 _tools: Option<Arc<ToolsContext>>,
2035}
2036
2037impl std::fmt::Debug for LanguageModelSession {
2038 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2039 f.debug_struct("LanguageModelSession")
2040 .field("default_options", &self.default_options)
2041 .finish_non_exhaustive()
2042 }
2043}
2044
2045impl LanguageModelSession {
2046 pub fn builder() -> LanguageModelSessionBuilder {
2048 LanguageModelSessionBuilder::new()
2049 }
2050
2051 pub fn new() -> Result<Self, Error> {
2057 Self::builder().build()
2058 }
2059
2060 pub fn with_instructions<I>(instructions: I) -> Result<Self, Error>
2072 where
2073 I: TryInto<SystemInstructions>,
2074 I::Error: Into<Error>,
2075 {
2076 let instructions = instructions.try_into().map_err(Into::into)?;
2077 Self::create(instructions, Vec::new(), GenerationOptions::default())
2078 }
2079
2080 pub fn with_tools<I>(instructions: I, tools: Vec<ToolDefinition>) -> Result<Self, Error>
2092 where
2093 I: TryInto<SystemInstructions>,
2094 I::Error: Into<Error>,
2095 {
2096 let instructions = instructions.try_into().map_err(Into::into)?;
2097 Self::create(instructions, tools, GenerationOptions::default())
2098 }
2099
2100 fn create(
2101 instructions: SystemInstructions,
2102 tools: Vec<ToolDefinition>,
2103 default_options: GenerationOptions,
2104 ) -> Result<Self, Error> {
2105 default_options.validate()?;
2106 availability().map_err(Error::Unavailable)?;
2107
2108 #[cfg(aimx_bridge)]
2109 {
2110 Self::create_bridge_session(instructions, tools, default_options)
2111 }
2112 #[cfg(not(aimx_bridge))]
2113 {
2114 let _ = (instructions, tools, default_options);
2115 Err(Error::Unavailable(AvailabilityError::DeviceNotEligible))
2116 }
2117 }
2118
2119 #[cfg(aimx_bridge)]
2120 fn create_bridge_session(
2121 instructions: SystemInstructions,
2122 tools: Vec<ToolDefinition>,
2123 default_options: GenerationOptions,
2124 ) -> Result<Self, Error> {
2125 if tools.is_empty() {
2126 return Self::create_plain_bridge_session(instructions, default_options);
2127 }
2128
2129 Self::create_tool_bridge_session(instructions, tools, default_options)
2130 }
2131
2132 #[cfg(aimx_bridge)]
2133 fn create_plain_bridge_session(
2134 instructions: SystemInstructions,
2135 default_options: GenerationOptions,
2136 ) -> Result<Self, Error> {
2137 let handle = unsafe { fm_session_create(instructions.as_ptr()) };
2138
2139 Ok(Self {
2140 default_options,
2141 handle: Arc::new(SessionHandle::from_raw(handle)?),
2142 _tools: None,
2143 })
2144 }
2145
2146 #[cfg(aimx_bridge)]
2147 fn create_tool_bridge_session(
2148 instructions: SystemInstructions,
2149 tools: Vec<ToolDefinition>,
2150 default_options: GenerationOptions,
2151 ) -> Result<Self, Error> {
2152 let tool_descriptions = tools
2153 .iter()
2154 .map(ToolDefinition::bridge_description)
2155 .collect::<Vec<_>>();
2156 let c_tools_json = CString::new(serde_json::to_vec(&tool_descriptions)?)?;
2157 let tools_ctx = ToolsContext::from_definitions(tools);
2158 let tool_ctx_ptr = Arc::as_ptr(&tools_ctx) as *mut c_void;
2159
2160 let handle = unsafe {
2161 fm_session_create_with_tools(
2162 instructions.as_ptr(),
2163 c_tools_json.as_ptr(),
2164 tool_ctx_ptr,
2165 tool_dispatch,
2166 )
2167 };
2168
2169 Ok(Self {
2170 default_options,
2171 handle: Arc::new(SessionHandle::from_raw(handle)?),
2172 _tools: Some(tools_ctx),
2173 })
2174 }
2175
2176 pub async fn respond<P>(&self, prompt: P) -> Result<String, Error>
2186 where
2187 P: TryInto<Prompt>,
2188 P::Error: Into<Error>,
2189 {
2190 Ok(self.respond_to(prompt).await?.into_string())
2191 }
2192
2193 pub async fn respond_to<P>(&self, prompt: P) -> Result<ResponseText, Error>
2199 where
2200 P: TryInto<Prompt>,
2201 P::Error: Into<Error>,
2202 {
2203 self.respond_to_with_options(prompt, &self.default_options)
2204 .await
2205 }
2206
2207 pub async fn complete<P>(&self, prompt: P) -> Result<ResponseText, Error>
2213 where
2214 P: TryInto<Prompt>,
2215 P::Error: Into<Error>,
2216 {
2217 self.respond_to(prompt).await
2218 }
2219
2220 pub async fn generate<P>(&self, prompt: P) -> Result<GeneratedText, Error>
2226 where
2227 P: TryInto<Prompt>,
2228 P::Error: Into<Error>,
2229 {
2230 self.respond_to(prompt).await
2231 }
2232
2233 pub async fn generate_text<P>(&self, prompt: P) -> Result<ResponseText, Error>
2239 where
2240 P: TryInto<Prompt>,
2241 P::Error: Into<Error>,
2242 {
2243 self.respond_to(prompt).await
2244 }
2245
2246 pub async fn respond_with_options<P>(
2254 &self,
2255 prompt: P,
2256 options: &GenerationOptions,
2257 ) -> Result<String, Error>
2258 where
2259 P: TryInto<Prompt>,
2260 P::Error: Into<Error>,
2261 {
2262 Ok(self
2263 .respond_to_with_options(prompt, options)
2264 .await?
2265 .into_string())
2266 }
2267
2268 pub async fn respond_to_with_options<P>(
2276 &self,
2277 prompt: P,
2278 options: &GenerationOptions,
2279 ) -> Result<ResponseText, Error>
2280 where
2281 P: TryInto<Prompt>,
2282 P::Error: Into<Error>,
2283 {
2284 let prompt = prompt.try_into().map_err(Into::into)?;
2285 self.generate_prompt_with_options(prompt, options).await
2286 }
2287
2288 pub async fn complete_with_options<P>(
2294 &self,
2295 prompt: P,
2296 options: &GenerationOptions,
2297 ) -> Result<ResponseText, Error>
2298 where
2299 P: TryInto<Prompt>,
2300 P::Error: Into<Error>,
2301 {
2302 self.respond_to_with_options(prompt, options).await
2303 }
2304
2305 pub async fn generate_with_options<P>(
2311 &self,
2312 prompt: P,
2313 options: &GenerationOptions,
2314 ) -> Result<GeneratedText, Error>
2315 where
2316 P: TryInto<Prompt>,
2317 P::Error: Into<Error>,
2318 {
2319 self.respond_to_with_options(prompt, options).await
2320 }
2321
2322 pub async fn generate_text_with_options<P>(
2328 &self,
2329 prompt: P,
2330 options: &GenerationOptions,
2331 ) -> Result<ResponseText, Error>
2332 where
2333 P: TryInto<Prompt>,
2334 P::Error: Into<Error>,
2335 {
2336 self.respond_to_with_options(prompt, options).await
2337 }
2338
2339 async fn generate_prompt_with_options(
2340 &self,
2341 prompt: Prompt,
2342 options: &GenerationOptions,
2343 ) -> Result<ResponseText, Error> {
2344 let config = options.validated()?;
2345 #[cfg(aimx_bridge)]
2346 {
2347 let handle = Arc::clone(&self.handle);
2348 let (tx, rx) = oneshot::channel::<ModelTextResult>();
2349 let ctx = Box::into_raw(Box::new(ResponseContext {
2350 tx,
2351 _handle: handle,
2352 })) as *mut c_void;
2353
2354 unsafe {
2355 fm_session_respond(
2356 self.handle.as_ptr(),
2357 prompt.as_ptr(),
2358 config.ffi_temperature(),
2359 config.ffi_max_tokens(),
2360 ctx,
2361 respond_callback,
2362 );
2363 }
2364
2365 receive_response(rx).await
2366 }
2367 #[cfg(not(aimx_bridge))]
2368 {
2369 let _ = (prompt, config);
2370 Err(Error::Unavailable(AvailabilityError::DeviceNotEligible))
2371 }
2372 }
2373
2374 pub async fn respond_as<T, P>(&self, prompt: P, schema: &GenerationSchema) -> Result<T, Error>
2386 where
2387 T: serde::de::DeserializeOwned,
2388 P: TryInto<Prompt>,
2389 P::Error: Into<Error>,
2390 {
2391 self.respond_generating(prompt, schema).await
2392 }
2393
2394 pub async fn respond_generating<T, P>(
2403 &self,
2404 prompt: P,
2405 schema: &GenerationSchema,
2406 ) -> Result<T, Error>
2407 where
2408 T: serde::de::DeserializeOwned,
2409 P: TryInto<Prompt>,
2410 P::Error: Into<Error>,
2411 {
2412 self.respond_generating_with_options(prompt, schema, &self.default_options)
2413 .await
2414 }
2415
2416 pub async fn generate_object<T, P>(
2422 &self,
2423 prompt: P,
2424 schema: &GenerationSchema,
2425 ) -> Result<T, Error>
2426 where
2427 T: serde::de::DeserializeOwned,
2428 P: TryInto<Prompt>,
2429 P::Error: Into<Error>,
2430 {
2431 self.respond_generating(prompt, schema).await
2432 }
2433
2434 pub async fn respond_as_with_options<T, P>(
2442 &self,
2443 prompt: P,
2444 schema: &GenerationSchema,
2445 options: &GenerationOptions,
2446 ) -> Result<T, Error>
2447 where
2448 T: serde::de::DeserializeOwned,
2449 P: TryInto<Prompt>,
2450 P::Error: Into<Error>,
2451 {
2452 self.respond_generating_with_options(prompt, schema, options)
2453 .await
2454 }
2455
2456 pub async fn respond_generating_with_options<T, P>(
2462 &self,
2463 prompt: P,
2464 schema: &GenerationSchema,
2465 options: &GenerationOptions,
2466 ) -> Result<T, Error>
2467 where
2468 T: serde::de::DeserializeOwned,
2469 P: TryInto<Prompt>,
2470 P::Error: Into<Error>,
2471 {
2472 let prompt = prompt.try_into().map_err(Into::into)?;
2473 let config = options.validated()?;
2474
2475 self.respond_generating_prompt_with_config(prompt, schema, config)
2476 .await
2477 }
2478
2479 pub async fn generate_object_with_options<T, P>(
2485 &self,
2486 prompt: P,
2487 schema: &GenerationSchema,
2488 options: &GenerationOptions,
2489 ) -> Result<T, Error>
2490 where
2491 T: serde::de::DeserializeOwned,
2492 P: TryInto<Prompt>,
2493 P::Error: Into<Error>,
2494 {
2495 self.respond_generating_with_options(prompt, schema, options)
2496 .await
2497 }
2498
2499 async fn respond_generating_prompt_with_config<T>(
2500 &self,
2501 prompt: Prompt,
2502 schema: &GenerationSchema,
2503 config: GenerationConfig,
2504 ) -> Result<T, Error>
2505 where
2506 T: serde::de::DeserializeOwned,
2507 {
2508 #[cfg(aimx_bridge)]
2509 {
2510 let handle = Arc::clone(&self.handle);
2511 let (tx, rx) = oneshot::channel::<ModelTextResult>();
2512 let ctx = Box::into_raw(Box::new(ResponseContext {
2513 tx,
2514 _handle: handle,
2515 })) as *mut c_void;
2516 let c_schema_json = CString::new(serde_json::to_vec(schema)?)?;
2517
2518 unsafe {
2519 fm_session_respond_structured(
2520 self.handle.as_ptr(),
2521 prompt.as_ptr(),
2522 c_schema_json.as_ptr(),
2523 config.ffi_temperature(),
2524 config.ffi_max_tokens(),
2525 ctx,
2526 respond_callback,
2527 );
2528 }
2529
2530 let json = receive_response(rx).await?.into_string();
2531 serde_json::from_str(&json).map_err(Error::from)
2532 }
2533 #[cfg(not(aimx_bridge))]
2534 {
2535 let _ = (prompt, schema, config);
2536 Err(Error::Unavailable(AvailabilityError::DeviceNotEligible))
2537 }
2538 }
2539
2540 pub fn stream<P>(&self, prompt: P) -> Result<ResponseStream, Error>
2560 where
2561 P: TryInto<Prompt>,
2562 P::Error: Into<Error>,
2563 {
2564 self.stream_response(prompt)
2565 }
2566
2567 pub fn stream_response<P>(&self, prompt: P) -> Result<ResponseStream, Error>
2573 where
2574 P: TryInto<Prompt>,
2575 P::Error: Into<Error>,
2576 {
2577 self.stream_response_with_options(prompt, &self.default_options)
2578 }
2579
2580 pub fn stream_generate<P>(&self, prompt: P) -> Result<ResponseStream, Error>
2586 where
2587 P: TryInto<Prompt>,
2588 P::Error: Into<Error>,
2589 {
2590 self.stream_response(prompt)
2591 }
2592
2593 pub fn stream_text<P>(&self, prompt: P) -> Result<ResponseStream, Error>
2599 where
2600 P: TryInto<Prompt>,
2601 P::Error: Into<Error>,
2602 {
2603 self.stream_response(prompt)
2604 }
2605
2606 pub fn stream_with_options<P>(
2614 &self,
2615 prompt: P,
2616 options: &GenerationOptions,
2617 ) -> Result<ResponseStream, Error>
2618 where
2619 P: TryInto<Prompt>,
2620 P::Error: Into<Error>,
2621 {
2622 self.stream_response_with_options(prompt, options)
2623 }
2624
2625 pub fn stream_response_with_options<P>(
2633 &self,
2634 prompt: P,
2635 options: &GenerationOptions,
2636 ) -> Result<ResponseStream, Error>
2637 where
2638 P: TryInto<Prompt>,
2639 P::Error: Into<Error>,
2640 {
2641 let prompt = prompt.try_into().map_err(Into::into)?;
2642 let config = options.validated()?;
2643
2644 self.stream_prompt_with_config(prompt, config)
2645 }
2646
2647 pub fn stream_generate_with_options<P>(
2653 &self,
2654 prompt: P,
2655 options: &GenerationOptions,
2656 ) -> Result<ResponseStream, Error>
2657 where
2658 P: TryInto<Prompt>,
2659 P::Error: Into<Error>,
2660 {
2661 self.stream_response_with_options(prompt, options)
2662 }
2663
2664 pub fn stream_text_with_options<P>(
2670 &self,
2671 prompt: P,
2672 options: &GenerationOptions,
2673 ) -> Result<ResponseStream, Error>
2674 where
2675 P: TryInto<Prompt>,
2676 P::Error: Into<Error>,
2677 {
2678 self.stream_response_with_options(prompt, options)
2679 }
2680
2681 fn stream_prompt_with_config(
2682 &self,
2683 prompt: Prompt,
2684 config: GenerationConfig,
2685 ) -> Result<ResponseStream, Error> {
2686 #[cfg(aimx_bridge)]
2687 {
2688 let handle = Arc::clone(&self.handle);
2689 let (tx, rx) = mpsc::unbounded::<ModelTextResult>();
2690 let ctx = Box::into_raw(Box::new(StreamContext {
2691 tx,
2692 _handle: handle,
2693 })) as *mut c_void;
2694
2695 unsafe {
2696 fm_session_stream(
2697 self.handle.as_ptr(),
2698 prompt.as_ptr(),
2699 config.ffi_temperature(),
2700 config.ffi_max_tokens(),
2701 ctx,
2702 stream_token_callback,
2703 stream_done_callback,
2704 );
2705 }
2706
2707 Ok(ResponseStream { rx })
2708 }
2709 #[cfg(not(aimx_bridge))]
2710 {
2711 let _ = (prompt, config);
2712 Err(Error::Unavailable(AvailabilityError::DeviceNotEligible))
2713 }
2714 }
2715}
2716
2717pub type Session = LanguageModelSession;
2719
2720impl LanguageModel for LanguageModelSession {
2721 fn generate_text_with_options<P>(
2722 &self,
2723 prompt: P,
2724 options: GenerationOptions,
2725 ) -> impl Future<Output = Result<ResponseText, Error>> + '_
2726 where
2727 P: TryInto<Prompt>,
2728 P::Error: Into<Error>,
2729 {
2730 let prompt = prompt.try_into().map_err(Into::into);
2731
2732 async move {
2733 let prompt = prompt?;
2734 self.generate_prompt_with_options(prompt, &options).await
2735 }
2736 }
2737
2738 fn stream_text_with_options<P>(
2739 &self,
2740 prompt: P,
2741 options: GenerationOptions,
2742 ) -> Result<ResponseStream, Error>
2743 where
2744 P: TryInto<Prompt>,
2745 P::Error: Into<Error>,
2746 {
2747 LanguageModelSession::stream_text_with_options(self, prompt, &options)
2748 }
2749}
2750
2751pub struct ResponseStream {
2761 rx: StreamReceiver,
2762}
2763
2764impl Stream for ResponseStream {
2765 type Item = Result<ResponseText, Error>;
2766
2767 fn poll_next(mut self: Pin<&mut Self>, cx: &mut StdContext<'_>) -> Poll<Option<Self::Item>> {
2768 Pin::new(&mut self.rx)
2769 .poll_next(cx)
2770 .map(|opt| opt.map(|r| r.map_err(Error::from)))
2771 }
2772}
2773
/// Awaits the single reply delivered through `respond_callback`.
///
/// If the sending half is dropped without a value, a [`GenerationError`] is
/// returned instead.
#[cfg(aimx_bridge)]
async fn receive_response(receiver: ResponseReceiver) -> Result<ResponseText, Error> {
    match receiver.await {
        Ok(reply) => reply.map_err(Error::from),
        Err(_canceled) => Err(Error::from(GenerationError::new(
            "session was dropped before responding",
        ))),
    }
}
2781
/// Callback context for one non-streaming request.
///
/// Boxed and handed to the bridge as `ctx`; `respond_callback` turns it back into
/// a `Box` and drops it. Holding `_handle` keeps the bridge session alive until
/// the reply arrives.
// Fields drop in declaration order: `tx` first, then `_handle`.
#[cfg(aimx_bridge)]
struct ResponseContext {
    /// Delivers the reply to `receive_response`.
    tx: ResponseSender,
    /// Keep-alive reference to the bridge session; never read.
    _handle: Arc<SessionHandle>,
}
2790
/// C callback the bridge invokes to complete an `fm_session_respond*` call.
///
/// Reclaims the boxed `ResponseContext` behind `ctx` and forwards the outcome to
/// the waiting receiver. A non-null `error` takes precedence over `result`. If
/// the bridge supplies neither, an explicit error is sent. Previously the sender
/// was silently dropped, and the caller saw a misleading "session was dropped"
/// error.
#[cfg(aimx_bridge)]
extern "C" fn respond_callback(ctx: *mut c_void, result: *const c_char, error: *const c_char) {
    if ctx.is_null() {
        // Nothing to reclaim and nobody to notify; never dereference null.
        return;
    }

    // SAFETY: non-null `ctx` values originate from `Box::into_raw` of a
    // `ResponseContext` in `LanguageModelSession`. NOTE(review): relies on the
    // bridge invoking this callback exactly once per request.
    let context = unsafe { Box::from_raw(ctx.cast::<ResponseContext>()) };

    let outcome = if let Some(msg) = callback_owned_text(error) {
        Err(GenerationError::from(msg))
    } else if let Some(text) = callback_owned_text(result) {
        Ok(ResponseText::from(text))
    } else {
        Err(GenerationError::new(
            "bridge completed without a result or an error",
        ))
    };

    // The receiver may already be gone (caller dropped the future); that is fine.
    context.tx.send(outcome).ok();
}
2804
/// Copies a C string received from the bridge into an owned `String`.
///
/// Returns `None` for a null pointer; invalid UTF-8 is replaced lossily.
#[cfg(aimx_bridge)]
fn callback_owned_text(ptr: *const c_char) -> Option<String> {
    (!ptr.is_null()).then(|| {
        // SAFETY: non-null pointers from the bridge are assumed to reference a
        // NUL-terminated string valid for the duration of the callback.
        let borrowed = unsafe { CStr::from_ptr(ptr) };
        borrowed.to_string_lossy().into_owned()
    })
}
2813
/// Callback context for one streaming request.
///
/// Boxed and handed to the bridge as `ctx`. `stream_token_callback` only borrows
/// it, and `stream_done_callback` reclaims and drops it, which closes the channel
/// and ends the `ResponseStream`.
// Fields drop in declaration order: `tx` first, then `_handle`.
#[cfg(aimx_bridge)]
struct StreamContext {
    /// Sends each chunk (or the final error) to the `ResponseStream`.
    tx: StreamSender,
    /// Keep-alive reference to the bridge session; never read.
    _handle: Arc<SessionHandle>,
}
2820
/// C callback the bridge invokes for each streamed token.
///
/// Borrows (does not free) the `StreamContext` behind `ctx` and forwards the
/// token text to the `ResponseStream`. Null `ctx` or `token` pointers are ignored.
#[cfg(aimx_bridge)]
extern "C" fn stream_token_callback(ctx: *mut c_void, token: *const c_char) {
    if ctx.is_null() {
        return;
    }

    // SAFETY: non-null `ctx` is the boxed `StreamContext` leaked by
    // `stream_prompt_with_config`; it stays alive until `stream_done_callback`
    // frees it. NOTE(review): relies on the bridge never delivering tokens after
    // the done callback.
    let stream_ctx = unsafe { &*ctx.cast::<StreamContext>() };

    let Some(text) = callback_owned_text(token) else {
        return;
    };

    // A closed channel only means the consumer dropped the stream.
    stream_ctx
        .tx
        .unbounded_send(Ok(ResponseText::from(text)))
        .ok();
}
2836
/// C callback the bridge invokes once a stream has finished.
///
/// Reclaims the boxed `StreamContext` behind `ctx`. If `error` is non-null, it is
/// forwarded as a final `Err` item. Dropping the context then closes the channel,
/// ending the `ResponseStream`, and releases this request's session reference.
#[cfg(aimx_bridge)]
extern "C" fn stream_done_callback(ctx: *mut c_void, error: *const c_char) {
    if ctx.is_null() {
        // Nothing to reclaim; never pass null to `Box::from_raw`.
        return;
    }

    // SAFETY: non-null `ctx` is the boxed `StreamContext` leaked by
    // `stream_prompt_with_config`. NOTE(review): relies on the bridge calling this
    // exactly once per stream.
    let stream_ctx = unsafe { Box::from_raw(ctx.cast::<StreamContext>()) };

    if let Some(msg) = callback_owned_text(error) {
        stream_ctx
            .tx
            .unbounded_send(Err(GenerationError::from(msg)))
            .ok();
    }
}
2850
/// C entry point the bridge calls to run a tool.
///
/// Executes the tool synchronously and reports its output or error back through
/// `result_cb`, passing `result_ctx` through unchanged.
#[cfg(aimx_bridge)]
extern "C" fn tool_dispatch(
    ctx: *mut c_void,
    name_ptr: *const c_char,
    args_ptr: *const c_char,
    result_ctx: *mut c_void,
    result_cb: ToolResultCallback,
) {
    send_tool_result(
        result_ctx,
        result_cb,
        dispatch_tool_call(ctx, name_ptr, args_ptr),
    );
}
2864
/// Resolves the tool named by `name_ptr` in the `ToolsContext` at `ctx` and calls
/// it with the JSON arguments at `args_ptr`.
///
/// A null context, a missing name, or missing/invalid arguments are returned as
/// [`ToolCallError`] values, as is any error from the tool itself.
#[cfg(aimx_bridge)]
fn dispatch_tool_call(
    ctx: *mut c_void,
    name_ptr: *const c_char,
    args_ptr: *const c_char,
) -> ToolResult {
    if ctx.is_null() {
        return Err(ToolCallError::new("missing tool context"));
    }

    // SAFETY: non-null `ctx` is the `ToolsContext` pointer registered with
    // `fm_session_create_with_tools`. NOTE(review): calls may arrive on bridge
    // threads, so `ToolsContext` is presumably `Sync` — confirm.
    let registry = unsafe { &*ctx.cast::<ToolsContext>() };

    with_callback_text(name_ptr, "tool name", |name| {
        parse_tool_args(args_ptr).and_then(|args| registry.call(name, args))
    })?
}
2882
/// Parses the tool-argument C string received from the bridge as JSON.
///
/// # Errors
///
/// Returns a [`ToolCallError`] for a null pointer or malformed JSON.
#[cfg(aimx_bridge)]
fn parse_tool_args(args_ptr: *const c_char) -> Result<serde_json::Value, ToolCallError> {
    if args_ptr.is_null() {
        return Err(ToolCallError::new("missing tool arguments"));
    }

    // SAFETY: non-null pointers from the bridge are assumed to reference a
    // NUL-terminated string valid for the duration of the call.
    let raw_json = unsafe { CStr::from_ptr(args_ptr) }.to_bytes();
    let parsed = serde_json::from_slice::<serde_json::Value>(raw_json);
    parsed.map_err(|parse_error| {
        ToolCallError::new(format!("invalid tool args JSON: {parse_error}"))
    })
}
2893
/// Borrows the C string at `ptr` as text (lossily decoded) and passes it to `f`.
///
/// # Errors
///
/// Returns a [`ToolCallError`] reading `missing {label}` when `ptr` is null.
#[cfg(aimx_bridge)]
fn with_callback_text<R>(
    ptr: *const c_char,
    label: &str,
    f: impl FnOnce(&str) -> R,
) -> Result<R, ToolCallError> {
    if ptr.is_null() {
        return Err(ToolCallError::new(format!("missing {label}")));
    }

    // SAFETY: non-null pointers from the bridge are assumed to reference a
    // NUL-terminated string valid for the duration of the call.
    let decoded = unsafe { CStr::from_ptr(ptr) }.to_string_lossy();
    let text: &str = &decoded;
    Ok(f(text))
}
2907
/// Reports a tool's outcome to the bridge: output on success, message on failure.
#[cfg(aimx_bridge)]
fn send_tool_result(result_ctx: *mut c_void, result_cb: ToolResultCallback, result: ToolResult) {
    let output = match result {
        Ok(output) => output,
        Err(error) => return send_tool_error(result_ctx, result_cb, error.as_str()),
    };
    send_tool_output(result_ctx, result_cb, output);
}
2915
/// Sends a successful tool output to the bridge as a C string.
///
/// Output containing an interior NUL cannot be encoded and is reported as a tool
/// error instead.
#[cfg(aimx_bridge)]
fn send_tool_output(result_ctx: *mut c_void, result_cb: ToolResultCallback, output: ToolOutput) {
    match CString::new(output.into_string()) {
        Err(nul_error) => {
            let message = format!("tool result contains a null byte: {nul_error}");
            send_tool_error(result_ctx, result_cb, &message);
        }
        // `c_output` lives until the end of this arm, so the pointer stays valid
        // for the whole callback.
        Ok(c_output) => result_cb(result_ctx, c_output.as_ptr(), null()),
    }
}
2927
/// Sends a tool error message to the bridge.
///
/// A message with an interior NUL is replaced by the fixed
/// `TOOL_ERROR_ENCODING_FAILURE` text.
#[cfg(aimx_bridge)]
fn send_tool_error(result_ctx: *mut c_void, result_cb: ToolResultCallback, message: &str) {
    if let Ok(c_error) = CString::new(message) {
        result_cb(result_ctx, null(), c_error.as_ptr());
    } else {
        result_cb(
            result_ctx,
            null(),
            TOOL_ERROR_ENCODING_FAILURE.as_ptr().cast::<c_char>(),
        );
    }
}
2939
/// Fallback message sent to the bridge when a tool error string itself contains
/// an interior NUL byte and therefore cannot be converted into a `CString`.
///
/// A C-string literal guarantees the terminating NUL at compile time instead of
/// relying on a hand-written `\0` in a byte string.
#[cfg(aimx_bridge)]
const TOOL_ERROR_ENCODING_FAILURE: &CStr = c"tool error contains a null byte";
2942
2943#[cfg(test)]
2946mod tests {
2947 use super::*;
2948 use proptest::prelude::*;
2949
2950 #[test]
2953 fn test_is_available_returns_without_panic() {
2954 let _ = is_available();
2955 }
2956
2957 #[test]
2958 fn test_availability_result_is_consistent() {
2959 let avail = availability();
2960 assert_eq!(is_available(), avail.is_ok());
2961 }
2962
2963 #[test]
2964 fn test_options_default_is_valid() -> Result<(), Error> {
2965 let opts = GenerationOptions::default();
2966 assert!(opts.validate().is_ok());
2967 let config = opts.validated()?;
2968 assert_eq!(config.ffi_temperature(), -1.0);
2969 assert_eq!(config.ffi_max_tokens(), -1);
2970 Ok(())
2971 }
2972
2973 #[test]
2974 fn test_options_valid_temperature() -> Result<(), Error> {
2975 for (temp, expected_ffi) in [(0.0_f64, 0.0), (1.0, 1.0), (2.0, 2.0)] {
2976 let opts = GenerationOptions::new().try_temperature(temp)?;
2977 assert!(
2978 opts.validate().is_ok(),
2979 "temperature {temp} should be valid"
2980 );
2981 let config = opts.validated()?;
2982 assert_eq!(config.ffi_temperature(), expected_ffi);
2983 }
2984 Ok(())
2985 }
2986
2987 #[test]
2988 fn test_options_invalid_temperature() {
2989 for temp in [-f64::INFINITY, -0.1_f64, 2.001, f64::INFINITY, f64::NAN] {
2990 assert!(
2991 GenerationOptions::new().try_temperature(temp).is_err(),
2992 "temperature {temp} should be invalid"
2993 );
2994 }
2995 }
2996
2997 #[test]
2998 fn test_options_invalid_max_tokens() {
2999 if usize::BITS < i64::BITS {
3000 return;
3001 }
3002
3003 let invalid = MaxTokens::MAX + 1;
3004
3005 assert!(matches!(
3006 GenerationOptions::new().try_max_tokens(invalid),
3007 Err(Error::InvalidMaxTokens(value)) if value == invalid
3008 ));
3009 assert!(matches!(
3010 MaxTokens::new(invalid),
3011 Err(Error::InvalidMaxTokens(value)) if value == invalid
3012 ));
3013 }
3014
3015 #[test]
3016 fn test_session_creation_fails_gracefully_when_unavailable() {
3017 if is_available() {
3018 return; }
3020 assert!(matches!(
3021 LanguageModelSession::new(),
3022 Err(Error::Unavailable(_))
3023 ));
3024 }
3025
3026 #[test]
3027 fn test_null_byte_in_prompt_returns_error() {
3028 let result = futures_executor::block_on(respond("hello\0world"));
3029 assert!(matches!(result, Err(Error::NullByte(_))));
3030 }
3031
3032 #[test]
3033 fn test_prompt_and_instruction_inputs_reject_null_bytes_before_availability() {
3034 let prompt = Prompt::try_from("hello\0world");
3035 let instructions = SystemInstructions::try_from("system\0prompt");
3036
3037 assert!(matches!(prompt, Err(Error::NullByte(_))));
3038 assert!(matches!(instructions, Err(Error::NullByte(_))));
3039 }
3040
3041 #[test]
3042 fn test_session_builder_validates_options_before_availability() {
3043 let result = AppleIntelligenceModels::default()
3044 .session()
3045 .instructions("Valid system prompt")
3046 .try_temperature(2.5)
3047 .and_then(LanguageModelSessionBuilder::build);
3048
3049 assert!(matches!(result, Err(Error::InvalidTemperature(value)) if value == 2.5));
3050 }
3051
3052 #[test]
3053 fn test_options_expose_typed_values() -> Result<(), Error> {
3054 let temperature = Temperature::new(0.4)?;
3055 let max_tokens = MaxTokens::new(128)?;
3056 let opts = GenerationOptions::new()
3057 .temperature(temperature)
3058 .max_tokens(max_tokens);
3059
3060 assert_eq!(opts.temperature_value(), Some(temperature));
3061 assert_eq!(opts.max_tokens_value(), Some(max_tokens));
3062 Ok(())
3063 }
3064
3065 #[test]
3066 fn test_schema_property_requirement_serializes_as_optional_flag() -> Result<(), Error> {
3067 let schema = GenerationSchema::new("Answer")
3068 .property(GenerationSchemaProperty::new(
3069 "required",
3070 GenerationSchemaPropertyType::String,
3071 ))
3072 .property(
3073 GenerationSchemaProperty::new("maybe", GenerationSchemaPropertyType::String)
3074 .optional(),
3075 );
3076
3077 let json = serde_json::to_value(schema)?;
3078
3079 assert_eq!(json["properties"][0]["optional"], false);
3080 assert_eq!(json["properties"][1]["optional"], true);
3081 assert!(GenerationSchemaPropertyRequirement::Required.is_required());
3082 assert!(GenerationSchemaPropertyRequirement::Optional.is_optional());
3083 Ok(())
3084 }
3085
3086 #[test]
3087 fn test_session_builder_validates_instructions_before_availability() {
3088 let result = AppleIntelligenceModels::default()
3089 .session()
3090 .instructions("bad\0instructions")
3091 .build();
3092
3093 assert!(matches!(result, Err(Error::NullByte(_))));
3094 }
3095
3096 #[test]
3097 fn test_string_newtypes_round_trip_through_display_and_inner_value() {
3098 let cases = [
3099 PromptText::new("prompt").into_string(),
3100 ResponseText::new("response").to_string(),
3101 GenerationSchemaName::new("GenerationSchema").to_string(),
3102 GenerationSchemaPropertyName::new("field").to_string(),
3103 ToolName::new("tool").to_string(),
3104 ToolOutput::new("output").to_string(),
3105 ];
3106
3107 assert_eq!(
3108 cases,
3109 [
3110 "prompt",
3111 "response",
3112 "GenerationSchema",
3113 "field",
3114 "tool",
3115 "output"
3116 ]
3117 );
3118 }
3119
3120 #[test]
3121 fn test_schema_builder() -> Result<(), Error> {
3122 let schema = GenerationSchema::new("Point")
3123 .description("A 2D point")
3124 .property(
3125 GenerationSchemaProperty::new("x", GenerationSchemaPropertyType::Double)
3126 .description("X axis"),
3127 )
3128 .property(GenerationSchemaProperty::new(
3129 "y",
3130 GenerationSchemaPropertyType::Double,
3131 ));
3132 assert_eq!(schema.name, "Point");
3133 assert_eq!(schema.properties.len(), 2);
3134 let json = serde_json::to_string(&schema)?;
3135 assert!(json.contains("\"x\""));
3136 assert!(json.contains("\"double\""));
3137 Ok(())
3138 }
3139
3140 #[test]
3141 fn test_tool_definition_builder() -> Result<(), ToolCallError> {
3142 let tool = ToolDefinition::builder(
3143 "add",
3144 "Add two numbers",
3145 GenerationSchema::new("AddArgs")
3146 .property(GenerationSchemaProperty::new(
3147 "a",
3148 GenerationSchemaPropertyType::Double,
3149 ))
3150 .property(GenerationSchemaProperty::new(
3151 "b",
3152 GenerationSchemaPropertyType::Double,
3153 )),
3154 )
3155 .handler(|args| {
3156 let a = args["a"].as_f64().unwrap_or(0.0);
3157 let b = args["b"].as_f64().unwrap_or(0.0);
3158 Ok(ToolOutput::from(format!("{}", a + b)))
3159 });
3160 assert_eq!(tool.name, "add");
3161 let result = tool.call(serde_json::json!({"a": 3.0, "b": 4.0}));
3162 assert_eq!(result?, "7");
3163 Ok(())
3164 }
3165
3166 #[test]
3167 fn test_tool_definition_new_and_trait_boundary() -> Result<(), ToolCallError> {
3168 let tool = ToolDefinition::new(
3169 "echo",
3170 "Echo an input string",
3171 GenerationSchema::new("EchoArgs").property(GenerationSchemaProperty::new(
3172 "value",
3173 GenerationSchemaPropertyType::String,
3174 )),
3175 |args| {
3176 args["value"]
3177 .as_str()
3178 .map(ToolOutput::from)
3179 .ok_or_else(|| ToolCallError::new("missing value"))
3180 },
3181 );
3182
3183 assert_eq!(tool.name().as_str(), "echo");
3184 assert_eq!(tool.description().as_str(), "Echo an input string");
3185 assert_eq!(tool.parameters().name, "EchoArgs");
3186 assert_eq!(tool.call(serde_json::json!({"value": "hello"}))?, "hello");
3187 assert!(tool.call(serde_json::json!({})).is_err());
3188 Ok(())
3189 }
3190
3191 #[test]
3192 fn test_tool_handler_panic_returns_tool_error() {
3193 let tool = ToolDefinition::new(
3194 "panic_tool",
3195 "Tool that fails inside user code",
3196 GenerationSchema::new("PanicArgs"),
3197 |_| -> ToolResult {
3198 std::panic::resume_unwind(Box::new("boom"));
3199 },
3200 );
3201
3202 let error = tool.call(serde_json::json!({})).err();
3203
3204 assert!(
3205 error
3206 .as_ref()
3207 .is_some_and(|error| error.as_str().contains("tool handler panicked: boom")),
3208 "expected panic to be converted into ToolCallError"
3209 );
3210 }
3211
3212 proptest! {
3213 #[test]
3214 fn proptest_prompt_input_matches_c_string_null_boundary(input in ".*") {
3215 let result = Prompt::try_from(input.as_str());
3216
3217 if input.contains('\0') {
3218 prop_assert!(matches!(result, Err(Error::NullByte(_))));
3219 } else {
3220 match result {
3221 Ok(prompt) => prop_assert_eq!(prompt.as_str(), input.as_str()),
3222 Err(error) => prop_assert!(false, "unexpected prompt error: {error}"),
3223 }
3224 }
3225 }
3226
3227 #[test]
3228 fn proptest_instructions_match_c_string_null_boundary(input in ".*") {
3229 let result = SystemInstructions::try_from(input.as_str());
3230
3231 if input.contains('\0') {
3232 prop_assert!(matches!(result, Err(Error::NullByte(_))));
3233 } else {
3234 match result {
3235 Ok(instructions) => prop_assert_eq!(instructions.as_str(), input.as_str()),
3236 Err(error) => prop_assert!(false, "unexpected instructions error: {error}"),
3237 }
3238 }
3239 }
3240
3241 #[test]
3242 fn proptest_temperature_validation_matches_closed_interval(temp in any::<f64>()) {
3243 let result = Temperature::new(temp);
3244
3245 if (Temperature::MIN..=Temperature::MAX).contains(&temp) {
3246 match result {
3247 Ok(temperature) => prop_assert_eq!(temperature.as_f64(), temp),
3248 Err(error) => prop_assert!(false, "unexpected temperature error: {error}"),
3249 }
3250 } else {
3251 prop_assert!(matches!(result, Err(Error::InvalidTemperature(value)) if value.to_bits() == temp.to_bits()));
3252 }
3253 }
3254
3255 #[test]
3256 fn proptest_generation_options_preserve_max_tokens(max_tokens in any::<usize>()) {
3257 if max_tokens <= MaxTokens::MAX {
3258 match GenerationOptions::new().try_max_tokens(max_tokens) {
3259 Ok(opts) => match opts.validated() {
3260 Ok(config) => prop_assert_eq!(config.ffi_max_tokens(), max_tokens as i64),
3261 Err(error) => prop_assert!(false, "unexpected options error: {error}"),
3262 },
3263 Err(error) => prop_assert!(false, "unexpected max token error: {error}"),
3264 }
3265 } else {
3266 prop_assert!(matches!(
3267 GenerationOptions::new().try_max_tokens(max_tokens),
3268 Err(Error::InvalidMaxTokens(value)) if value == max_tokens
3269 ));
3270 }
3271 }
3272 }
3273
3274 #[test]
3279 #[ignore = "requires Apple Intelligence (macOS 26+, Apple Silicon, AI enabled)"]
3280 fn test_simple_respond() -> Result<(), Error> {
3281 let response =
3282 futures_executor::block_on(respond("Reply with only the number: what is 2 + 2?"))?;
3283 assert!(
3284 response.as_str().contains('4'),
3285 "expected '4' in: {response:?}"
3286 );
3287 Ok(())
3288 }
3289
3290 #[test]
3291 #[ignore = "requires Apple Intelligence"]
3292 fn test_respond_with_low_temperature() -> Result<(), Error> {
3293 let opts = GenerationOptions::new().temperature(Temperature::new(0.0)?);
3294 let r = futures_executor::block_on(respond_with_options(
3295 "Reply with only the word: capital of France?",
3296 &opts,
3297 ))?;
3298 assert!(
3299 r.as_str().to_lowercase().contains("paris"),
3300 "expected Paris in: {r:?}"
3301 );
3302 Ok(())
3303 }
3304
3305 #[test]
3306 #[ignore = "requires Apple Intelligence"]
3307 fn test_multi_turn_session() -> Result<(), Error> {
3308 let session = LanguageModelSession::with_instructions(
3309 "Reply to every message with exactly one word.",
3310 )?;
3311 let r1 = futures_executor::block_on(session.respond_to("Say hello."))?;
3312 let r2 = futures_executor::block_on(session.respond_to("Say goodbye."))?;
3313 assert!(!r1.is_empty(), "first response was empty");
3314 assert!(!r2.is_empty(), "second response was empty");
3315 Ok(())
3316 }
3317
3318 #[test]
3319 #[ignore = "requires Apple Intelligence"]
3320 fn test_streaming_yields_chunks() -> Result<(), Error> {
3321 let session = LanguageModelSession::new()?;
3322 let stream = session.stream_response("Count: one two three")?;
3323
3324 let chunks: Vec<ResponseText> =
3325 futures_executor::block_on_stream(stream).collect::<Result<_, _>>()?;
3326
3327 assert!(!chunks.is_empty(), "stream produced no chunks");
3328 let full = chunks
3329 .into_iter()
3330 .map(ResponseText::into_string)
3331 .collect::<Vec<_>>()
3332 .join("");
3333 assert!(!full.is_empty(), "concatenated response was empty");
3334 Ok(())
3335 }
3336
3337 #[test]
3338 #[ignore = "requires Apple Intelligence"]
3339 fn test_structured_generation() -> Result<(), Error> {
3340 use serde::Deserialize;
3341
3342 #[derive(Debug, Deserialize)]
3343 struct MathAnswer {
3344 value: f64,
3345 explanation: String,
3346 }
3347
3348 let session = LanguageModelSession::new()?;
3349 let schema = GenerationSchema::new("MathAnswer")
3350 .description("A numeric answer with a brief explanation")
3351 .property(
3352 GenerationSchemaProperty::new("value", GenerationSchemaPropertyType::Double)
3353 .description("The numeric result"),
3354 )
3355 .property(
3356 GenerationSchemaProperty::new("explanation", GenerationSchemaPropertyType::String)
3357 .description("One-sentence explanation"),
3358 );
3359
3360 let answer: MathAnswer =
3361 futures_executor::block_on(session.respond_generating("What is 6 × 7?", &schema))?;
3362
3363 assert!(
3364 (answer.value - 42.0).abs() < 0.5,
3365 "expected 42, got {}",
3366 answer.value
3367 );
3368 assert!(!answer.explanation.is_empty(), "explanation was empty");
3369 Ok(())
3370 }
3371
3372 #[test]
3373 #[ignore = "requires Apple Intelligence"]
3374 fn test_tool_calling() -> Result<(), Error> {
3375 let tool = ToolDefinition::builder(
3376 "add_numbers",
3377 "Add two numbers together and return the sum",
3378 GenerationSchema::new("AddArgs")
3379 .property(
3380 GenerationSchemaProperty::new("a", GenerationSchemaPropertyType::Double)
3381 .description("First number"),
3382 )
3383 .property(
3384 GenerationSchemaProperty::new("b", GenerationSchemaPropertyType::Double)
3385 .description("Second number"),
3386 ),
3387 )
3388 .handler(|args| {
3389 let a = args["a"].as_f64().unwrap_or(0.0);
3390 let b = args["b"].as_f64().unwrap_or(0.0);
3391 Ok(ToolOutput::from(format!("{}", a + b)))
3392 });
3393
3394 let session = LanguageModelSession::with_tools(
3395 "You are a calculator. Use the add_numbers tool when asked to add.",
3396 vec![tool],
3397 )?;
3398
3399 let response = futures_executor::block_on(session.respond_to("What is 15 + 27?"))?;
3400
3401 assert!(
3402 response.as_str().contains("42"),
3403 "expected 42 in response: {response:?}"
3404 );
3405 Ok(())
3406 }
3407}