use super::{
DType, DeviceType, GenerateParams, GeneratedToken, LlmBackend, ModelArchitecture, ModelConfig,
ModelInfo, Quantization, SpecialTokens, StreamEvent, TokenStream, Tokenizer,
};
use crate::error::{Result, RuvLLMError};
use std::path::{Path, PathBuf};
use std::sync::mpsc;
use std::time::Instant;
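// Selects which hardware Core ML may dispatch work to; each variant maps 1:1
// onto Core ML's `MLComputeUnits` options when a model is loaded.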
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ComputeUnits {
CpuOnly,
CpuAndGpu,
CpuAndNeuralEngine,
#[default]
All,
}
impl ComputeUnits {
pub fn description(&self) -> &'static str {
match self {
Self::CpuOnly => "CPU only",
Self::CpuAndGpu => "CPU + GPU",
Self::CpuAndNeuralEngine => "CPU + Neural Engine (ANE)",
Self::All => "CPU + GPU + Neural Engine",
}
}
pub fn uses_ane(&self) -> bool {
matches!(self, Self::CpuAndNeuralEngine | Self::All)
}
pub fn uses_gpu(&self) -> bool {
matches!(self, Self::CpuAndGpu | Self::All)
}
}
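// Best-effort snapshot of the Apple Neural Engine on this machine. Detection
// is purely compile-time (macOS on aarch64); no runtime hardware query is made.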
#[derive(Debug, Clone)]
pub struct AneCapabilities {
pub available: bool,
pub tops: f32,
pub max_model_size_mb: usize,
pub supported_ops: Vec<String>,
}
impl Default for AneCapabilities {
fn default() -> Self {
Self::detect()
}
}
impl AneCapabilities {
pub fn detect() -> Self {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
Self {
available: true,
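// Heuristic figures, not a hardware query: ANE throughput varies by chip
// generation (roughly 11 TOPS on M1 up to ~38 TOPS on M4), so 38.0 is an
// optimistic assumption.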
tops: 38.0,
max_model_size_mb: 2048,
supported_ops: vec![
"MatMul".to_string(),
"Conv2D".to_string(),
"GELU".to_string(),
"SiLU".to_string(),
"LayerNorm".to_string(),
"Softmax".to_string(),
"Add".to_string(),
"Mul".to_string(),
],
}
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
Self {
available: false,
tops: 0.0,
max_model_size_mb: 0,
supported_ops: vec![],
}
}
}
pub fn is_model_suitable(&self, model_size_mb: usize) -> bool {
self.available && model_size_mb <= self.max_model_size_mb
}
}
#[cfg(all(target_os = "macos", target_arch = "aarch64", feature = "coreml"))]
pub mod coreml_native {
use super::*;
use objc2::rc::Retained;
use objc2::runtime::AnyObject;
use objc2::{msg_send_id, ClassType};
use objc2_core_ml::{
MLComputeUnits as MLComputeUnitsObjc, MLDictionaryFeatureProvider, MLFeatureProvider,
MLFeatureValue, MLModel, MLModelConfiguration, MLMultiArray, MLMultiArrayDataType,
MLPredictionOptions,
};
use objc2_foundation::{NSArray, NSDictionary, NSNumber, NSString, NSURL};
pub struct CoreMLModelHandle {
model: Retained<MLModel>,
model_path: PathBuf,
description: String,
input_names: Vec<String>,
output_names: Vec<String>,
vocab_size: Option<usize>,
hidden_size: Option<usize>,
}
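// SAFETY: `MLModel` predictions are treated as thread-safe here (the
// Objective-C object is only read after loading), and the remaining fields
// are owned plain data, so the handle can be shared and sent across threads.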
unsafe impl Send for CoreMLModelHandle {}
unsafe impl Sync for CoreMLModelHandle {}
impl CoreMLModelHandle {
pub fn load(path: &Path, compute_units: ComputeUnits) -> Result<Self> {
if !path.exists() {
return Err(RuvLLMError::NotFound(format!(
"Core ML model not found: {}",
path.display()
)));
}
let url = NSURL::from_file_path(path).ok_or_else(|| {
RuvLLMError::CoreML(format!("Invalid model path: {}", path.display()))
})?;
let config = unsafe { MLModelConfiguration::new() };
let ml_compute_units = match compute_units {
ComputeUnits::CpuOnly => MLComputeUnitsObjc::CPUOnly,
ComputeUnits::CpuAndGpu => MLComputeUnitsObjc::CPUAndGPU,
ComputeUnits::CpuAndNeuralEngine => MLComputeUnitsObjc::CPUAndNeuralEngine,
ComputeUnits::All => MLComputeUnitsObjc::All,
};
unsafe {
config.setComputeUnits(ml_compute_units);
}
let model =
unsafe { MLModel::modelWithContentsOfURL_configuration_error(&url, &config) }
.map_err(|e| {
RuvLLMError::CoreML(format!(
"Failed to load Core ML model from {}: {}",
path.display(),
e.localizedDescription()
))
})?;
let (description, input_names, output_names, vocab_size, hidden_size) =
Self::extract_model_info(&model);
Ok(Self {
model,
model_path: path.to_path_buf(),
description,
input_names,
output_names,
vocab_size,
hidden_size,
})
}
fn extract_model_info(
model: &MLModel,
) -> (
String,
Vec<String>,
Vec<String>,
Option<usize>,
Option<usize>,
) {
unsafe {
let desc = model.modelDescription();
let input_desc = desc.inputDescriptionsByName();
let output_desc = desc.outputDescriptionsByName();
let input_count = input_desc.count();
let output_count = output_desc.count();
let input_names: Vec<String> = input_desc
.allKeys()
.iter()
.map(|key| key.to_string())
.collect();
let output_names: Vec<String> = output_desc
.allKeys()
.iter()
.map(|key| key.to_string())
.collect();
let description = format!("Inputs: {}, Outputs: {}", input_count, output_count);
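// Core ML's model description does not expose transformer hyperparameters
// such as vocab or hidden size; callers supply them via `ModelConfig` instead.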
let vocab_size = None;
let hidden_size = None;
(
description,
input_names,
output_names,
vocab_size,
hidden_size,
)
}
}
pub fn create_input_array(&self, token_ids: &[i32]) -> Result<Retained<MLMultiArray>> {
let seq_len = token_ids.len();
unsafe {
let shape_vec: Vec<Retained<NSNumber>> = vec![
NSNumber::new_isize(1),
NSNumber::new_isize(seq_len as isize),
];
let shape = NSArray::from_retained_slice(&shape_vec);
use objc2::rc::Allocated;
let alloc: Allocated<MLMultiArray> = msg_send_id![MLMultiArray::class(), alloc];
let array = MLMultiArray::initWithShape_dataType_error(
alloc,
&shape,
MLMultiArrayDataType::Int32,
)
.map_err(|e| {
RuvLLMError::CoreML(format!(
"Failed to create input MLMultiArray: {}",
e.localizedDescription()
))
})?;
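// SAFETY: the array was created above with shape [1, seq_len] and Int32
// storage, so writing `seq_len` i32 values through its data pointer stays
// in bounds.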
let ptr = array.dataPointer().as_ptr() as *mut i32;
for (i, &token_id) in token_ids.iter().enumerate() {
*ptr.add(i) = token_id;
}
Ok(array)
}
}
pub fn create_float_array(&self, shape: &[usize]) -> Result<Retained<MLMultiArray>> {
unsafe {
let shape_vec: Vec<Retained<NSNumber>> = shape
.iter()
.map(|&d| NSNumber::new_isize(d as isize))
.collect();
let ns_shape = NSArray::from_retained_slice(&shape_vec);
use objc2::rc::Allocated;
let alloc: Allocated<MLMultiArray> = msg_send_id![MLMultiArray::class(), alloc];
let array = MLMultiArray::initWithShape_dataType_error(
alloc,
&ns_shape,
MLMultiArrayDataType::Float32,
)
.map_err(|e| {
RuvLLMError::CoreML(format!(
"Failed to create float MLMultiArray: {}",
e.localizedDescription()
))
})?;
Ok(array)
}
}
pub fn predict(&self, input_name: &str, token_ids: &[i32]) -> Result<Vec<f32>> {
let input_array = self.create_input_array(token_ids)?;
unsafe {
let input_key = NSString::from_str(input_name);
let feature_value = MLFeatureValue::featureValueWithMultiArray(&input_array);
use objc2::runtime::ProtocolObject;
let dict: Retained<NSDictionary<NSString, AnyObject>> = msg_send_id![NSDictionary::<NSString, AnyObject>::class(), dictionaryWithObject: &*feature_value, forKey: &*input_key];
use objc2::rc::Allocated;
let alloc: Allocated<MLDictionaryFeatureProvider> =
msg_send_id![MLDictionaryFeatureProvider::class(), alloc];
let provider = MLDictionaryFeatureProvider::initWithDictionary_error(alloc, &*dict)
.map_err(|e| {
RuvLLMError::CoreML(format!(
"Failed to create feature provider: {}",
e.localizedDescription()
))
})?;
let options = MLPredictionOptions::new();
let provider_ref = ProtocolObject::from_ref(&*provider);
let output = self
.model
.predictionFromFeatures_options_error(provider_ref, &options)
.map_err(|e| {
RuvLLMError::CoreML(format!(
"Prediction failed: {}",
e.localizedDescription()
))
})?;
let output_name = self
.output_names
.first()
.ok_or_else(|| RuvLLMError::CoreML("No output features found".to_string()))?;
let output_key = NSString::from_str(output_name);
let output_value = MLFeatureProvider::featureValueForName(&*output, &output_key)
.ok_or_else(|| {
RuvLLMError::CoreML(format!("Output feature '{}' not found", output_name))
})?;
let output_array = output_value.multiArrayValue().ok_or_else(|| {
RuvLLMError::CoreML("Output is not a multi-array".to_string())
})?;
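// Assumes the output buffer holds contiguous Float32 values; a model that
// emits Float16 logits would need an explicit conversion here.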
let count = output_array.count() as usize;
let ptr = output_array.dataPointer().as_ptr() as *const f32;
let logits: Vec<f32> = (0..count).map(|i| *ptr.add(i)).collect();
Ok(logits)
}
}
pub fn get_embeddings(
&self,
input_name: &str,
token_ids: &[i32],
embedding_output_name: Option<&str>,
) -> Result<Vec<f32>> {
let input_array = self.create_input_array(token_ids)?;
unsafe {
use objc2::rc::Allocated;
use objc2::runtime::ProtocolObject;
let input_key = NSString::from_str(input_name);
let feature_value = MLFeatureValue::featureValueWithMultiArray(&input_array);
let dict: Retained<NSDictionary<NSString, AnyObject>> = msg_send_id![NSDictionary::<NSString, AnyObject>::class(), dictionaryWithObject: &*feature_value, forKey: &*input_key];
let alloc: Allocated<MLDictionaryFeatureProvider> =
msg_send_id![MLDictionaryFeatureProvider::class(), alloc];
let provider = MLDictionaryFeatureProvider::initWithDictionary_error(alloc, &*dict)
.map_err(|e| {
RuvLLMError::CoreML(format!(
"Failed to create feature provider: {}",
e.localizedDescription()
))
})?;
let options = MLPredictionOptions::new();
let provider_ref = ProtocolObject::from_ref(&*provider);
let output = self
.model
.predictionFromFeatures_options_error(provider_ref, &options)
.map_err(|e| {
RuvLLMError::CoreML(format!(
"Prediction failed: {}",
e.localizedDescription()
))
})?;
let embedding_name = embedding_output_name.map(String::from).or_else(|| {
for name in &self.output_names {
let lower = name.to_lowercase();
if lower.contains("embed")
|| lower.contains("hidden")
|| lower.contains("pooled")
|| lower.contains("last_hidden")
{
return Some(name.clone());
}
}
self.output_names.first().cloned()
});
let output_name = embedding_name.ok_or_else(|| {
RuvLLMError::CoreML("No embedding output found in model".to_string())
})?;
let output_key = NSString::from_str(&output_name);
let output_value = MLFeatureProvider::featureValueForName(&*output, &output_key)
.ok_or_else(|| {
RuvLLMError::CoreML(format!("Embedding output '{}' not found", output_name))
})?;
let output_array = output_value.multiArrayValue().ok_or_else(|| {
RuvLLMError::CoreML("Embedding output is not a multi-array".to_string())
})?;
let count = output_array.count() as usize;
let ptr = output_array.dataPointer().as_ptr() as *const f32;
let shape_count = output_array.shape().count() as usize;
if shape_count >= 3 {
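// A rank-3 output is assumed to be [batch, seq_len, hidden] in row-major
// order, as produced by typical transformer exports.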
let shape_arr = output_array.shape();
let seq_len = shape_arr.objectAtIndex(1).intValue() as usize;
let hidden_dim = shape_arr.objectAtIndex(2).intValue() as usize;
// Take the hidden state of the final token; saturating_sub guards against
// an empty sequence underflowing the index.
let last_token_start = seq_len.saturating_sub(1) * hidden_dim;
let embeddings: Vec<f32> = (0..hidden_dim)
.map(|i| *ptr.add(last_token_start + i))
.collect();
Ok(embeddings)
} else {
let embeddings: Vec<f32> = (0..count).map(|i| *ptr.add(i)).collect();
Ok(embeddings)
}
}
}
pub fn model(&self) -> &MLModel {
&self.model
}
pub fn path(&self) -> &Path {
&self.model_path
}
pub fn description(&self) -> &str {
&self.description
}
pub fn input_names(&self) -> &[String] {
&self.input_names
}
pub fn output_names(&self) -> &[String] {
&self.output_names
}
pub fn num_inputs(&self) -> usize {
self.input_names.len()
}
pub fn num_outputs(&self) -> usize {
self.output_names.len()
}
}
impl std::fmt::Debug for CoreMLModelHandle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CoreMLModelHandle")
.field("model_path", &self.model_path)
.field("description", &self.description)
.field("input_names", &self.input_names)
.field("output_names", &self.output_names)
.finish()
}
}
}
#[cfg(all(target_os = "macos", target_arch = "aarch64", feature = "coreml"))]
pub use coreml_native::CoreMLModelHandle;
#[cfg(all(
target_os = "macos",
target_arch = "aarch64",
feature = "coreml",
feature = "candle"
))]
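// Token-by-token iterator over Core ML generation. Each `next()` re-runs the
// full forward pass over the growing context (there is no KV cache), so the
// cost of each step grows with the sequence length.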
pub struct CoreMLStreamIterator<'a> {
model_handle: &'a CoreMLModelHandle,
tokenizer: &'a crate::tokenizer::RuvTokenizer,
input_ids: Vec<i32>,
max_tokens: usize,
temperature: f32,
top_p: f32,
input_feature_name: String,
eos_token_id: u32,
vocab_size: usize,
generated_count: usize,
finished: bool,
}
#[cfg(all(
target_os = "macos",
target_arch = "aarch64",
feature = "coreml",
feature = "candle"
))]
impl<'a> CoreMLStreamIterator<'a> {
pub fn new(
model_handle: &'a CoreMLModelHandle,
tokenizer: &'a crate::tokenizer::RuvTokenizer,
input_ids: Vec<i32>,
max_tokens: usize,
temperature: f32,
top_p: f32,
input_feature_name: String,
eos_token_id: u32,
vocab_size: usize,
) -> Self {
Self {
model_handle,
tokenizer,
input_ids,
max_tokens,
temperature,
top_p,
input_feature_name,
eos_token_id,
vocab_size,
generated_count: 0,
finished: false,
}
}
fn sample_token(&self, logits: &[f32]) -> Result<u32> {
use rand::Rng;
if logits.is_empty() {
return Err(RuvLLMError::Generation("Empty logits".to_string()));
}
// A non-positive temperature is treated as greedy decoding (argmax) rather
// than as a request to sample from unscaled logits.
if self.temperature <= 0.0 {
return Ok(logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i as u32)
.unwrap_or(0));
}
let scaled_logits: Vec<f32> = if self.temperature != 1.0 {
logits.iter().map(|&x| x / self.temperature).collect()
} else {
logits.to_vec()
};
let max_logit = scaled_logits
.iter()
.cloned()
.fold(f32::NEG_INFINITY, f32::max);
let exp_logits: Vec<f32> = scaled_logits
.iter()
.map(|&x| (x - max_logit).exp())
.collect();
let sum_exp: f32 = exp_logits.iter().sum();
let probs: Vec<f32> = exp_logits.iter().map(|&x| x / sum_exp).collect();
if self.top_p < 1.0 {
let mut indexed_probs: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
indexed_probs
.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let mut cumsum = 0.0;
let mut cutoff_idx = indexed_probs.len();
for (i, (_, p)) in indexed_probs.iter().enumerate() {
cumsum += p;
if cumsum >= self.top_p {
cutoff_idx = i + 1;
break;
}
}
let filtered: Vec<(usize, f32)> = indexed_probs[..cutoff_idx].to_vec();
let filter_sum: f32 = filtered.iter().map(|(_, p)| p).sum();
let normalized: Vec<(usize, f32)> = filtered
.into_iter()
.map(|(i, p)| (i, p / filter_sum))
.collect();
let mut rng = rand::thread_rng();
let r: f32 = rng.gen();
let mut cumsum = 0.0;
for (idx, p) in &normalized {
cumsum += p;
if r < cumsum {
return Ok(*idx as u32);
}
}
return Ok(normalized.last().map(|(i, _)| *i as u32).unwrap_or(0));
}
let mut rng = rand::thread_rng();
let r: f32 = rng.gen();
let mut cumsum = 0.0;
for (idx, &p) in probs.iter().enumerate() {
cumsum += p;
if r < cumsum {
return Ok(idx as u32);
}
}
Ok(probs
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i as u32)
.unwrap_or(0))
}
}
#[cfg(all(
target_os = "macos",
target_arch = "aarch64",
feature = "coreml",
feature = "candle"
))]
impl<'a> Iterator for CoreMLStreamIterator<'a> {
type Item = Result<GeneratedToken>;
fn next(&mut self) -> Option<Self::Item> {
if self.finished || self.generated_count >= self.max_tokens {
return None;
}
let logits = match self
.model_handle
.predict(&self.input_feature_name, &self.input_ids)
{
Ok(l) => l,
Err(e) => {
self.finished = true;
return Some(Err(e));
}
};
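// The model may emit logits for every position ([seq_len, vocab] flattened);
// keep only the final row. A shorter buffer is assumed to already be [vocab].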
let last_token_logits = if logits.len() >= self.vocab_size {
&logits[logits.len() - self.vocab_size..]
} else {
&logits
};
let next_token = match self.sample_token(last_token_logits) {
Ok(t) => t,
Err(e) => {
self.finished = true;
return Some(Err(e));
}
};
if next_token == self.eos_token_id {
self.finished = true;
return None;
}
let text = self.tokenizer.decode(&[next_token]).unwrap_or_default();
self.input_ids.push(next_token as i32);
self.generated_count += 1;
Some(Ok(GeneratedToken {
id: next_token,
text,
logprob: None,
is_special: false,
}))
}
}
#[cfg(all(
target_os = "macos",
target_arch = "aarch64",
feature = "coreml",
feature = "candle"
))]
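// SAFETY: the iterator holds shared references to the model handle (declared
// Send + Sync above) plus owned state; this assumes the tokenizer is likewise
// safe to use from another thread.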
unsafe impl<'a> Send for CoreMLStreamIterator<'a> {}
#[cfg(feature = "coreml")]
pub struct CoreMLBackend {
compute_units: ComputeUnits,
ane_caps: AneCapabilities,
cache_dir: PathBuf,
model_info: Option<ModelInfo>,
loaded: bool,
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
model_handle: Option<CoreMLModelHandle>,
#[cfg(feature = "candle")]
tokenizer: Option<crate::tokenizer::RuvTokenizer>,
input_feature_name: String,
eos_token_id: u32,
vocab_size: usize,
}
#[cfg(feature = "coreml")]
impl std::fmt::Debug for CoreMLBackend {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CoreMLBackend")
.field("compute_units", &self.compute_units)
.field("ane_caps", &self.ane_caps)
.field("cache_dir", &self.cache_dir)
.field("model_info", &self.model_info)
.field("loaded", &self.loaded)
.field("input_feature_name", &self.input_feature_name)
.field("eos_token_id", &self.eos_token_id)
.field("vocab_size", &self.vocab_size)
.finish()
}
}
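// SAFETY: all Objective-C state is reached through `CoreMLModelHandle`, which
// is treated as thread-safe above; everything else is owned Rust data.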
#[cfg(feature = "coreml")]
unsafe impl Send for CoreMLBackend {}
#[cfg(feature = "coreml")]
unsafe impl Sync for CoreMLBackend {}
#[cfg(feature = "coreml")]
impl Default for CoreMLBackend {
fn default() -> Self {
Self {
compute_units: ComputeUnits::All,
ane_caps: AneCapabilities::detect(),
cache_dir: dirs::cache_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join("ruvllm")
.join("coreml"),
model_info: None,
loaded: false,
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
model_handle: None,
#[cfg(feature = "candle")]
tokenizer: None,
input_feature_name: "input_ids".to_string(),
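// LLaMA-family defaults; override via `with_eos_token_id` / `with_vocab_size`
// or by attaching a tokenizer.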
eos_token_id: 2,
vocab_size: 32000,
}
}
}
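// Typical usage, as a sketch (the paths below are hypothetical, and
// `load_tokenizer` requires the `candle` feature):
//
//     let mut backend = CoreMLBackend::new()?
//         .with_compute_units(ComputeUnits::CpuAndNeuralEngine);
//     backend.load_tokenizer("path/to/tokenizer.json")?;
//     backend.load_model("path/to/model.mlmodelc", ModelConfig::default())?;
//     let text = backend.generate("Hello", GenerateParams::default())?;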
#[cfg(feature = "coreml")]
impl CoreMLBackend {
pub fn new() -> Result<Self> {
let caps = AneCapabilities::detect();
if !caps.available {
return Err(RuvLLMError::Config(
"Apple Neural Engine not available on this device".to_string(),
));
}
let cache_dir = dirs::cache_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join("ruvllm")
.join("coreml");
std::fs::create_dir_all(&cache_dir).map_err(|e| {
RuvLLMError::Storage(format!("Failed to create Core ML cache directory: {}", e))
})?;
Ok(Self {
compute_units: ComputeUnits::All,
ane_caps: caps,
cache_dir,
model_info: None,
loaded: false,
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
model_handle: None,
#[cfg(feature = "candle")]
tokenizer: None,
input_feature_name: "input_ids".to_string(),
eos_token_id: 2,
vocab_size: 32000,
})
}
#[cfg(feature = "candle")]
pub fn with_tokenizer(mut self, tokenizer: crate::tokenizer::RuvTokenizer) -> Self {
self.eos_token_id = tokenizer.eos_token_id();
self.vocab_size = tokenizer.vocab_size();
self.tokenizer = Some(tokenizer);
self
}
pub fn with_input_feature_name(mut self, name: impl Into<String>) -> Self {
self.input_feature_name = name.into();
self
}
pub fn with_eos_token_id(mut self, eos_token_id: u32) -> Self {
self.eos_token_id = eos_token_id;
self
}
pub fn with_vocab_size(mut self, vocab_size: usize) -> Self {
self.vocab_size = vocab_size;
self
}
#[cfg(feature = "candle")]
pub fn load_tokenizer(&mut self, model_id_or_path: &str) -> Result<()> {
let tokenizer = if std::path::Path::new(model_id_or_path).exists() {
crate::tokenizer::RuvTokenizer::from_file(std::path::Path::new(model_id_or_path))?
} else {
crate::tokenizer::RuvTokenizer::from_pretrained(model_id_or_path)?
};
self.eos_token_id = tokenizer.eos_token_id();
self.vocab_size = tokenizer.vocab_size();
self.tokenizer = Some(tokenizer);
Ok(())
}
pub fn with_compute_units(mut self, units: ComputeUnits) -> Self {
self.compute_units = units;
self
}
pub fn ane_capabilities(&self) -> &AneCapabilities {
&self.ane_caps
}
pub fn is_model_ane_suitable(&self, model_size_mb: usize) -> bool {
self.ane_caps.is_model_suitable(model_size_mb)
}
fn get_coreml_cache_path(&self, model_path: &Path) -> PathBuf {
let model_name = model_path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("model");
self.cache_dir.join(format!("{}.mlmodelc", model_name))
}
fn convert_to_coreml(&self, _gguf_path: &Path, _output_path: &Path) -> Result<()> {
Err(RuvLLMError::NotImplemented(
"GGUF to Core ML conversion not yet implemented. \
Use `coremltools` Python package to convert models, or \
use pre-converted Core ML models."
.to_string(),
))
}
fn validate_coreml_path(path: &Path) -> Result<()> {
if !path.exists() {
return Err(RuvLLMError::NotFound(format!(
"Model path does not exist: {}",
path.display()
)));
}
let extension = path.extension().and_then(|e| e.to_str());
match extension {
Some("mlmodelc") => {
if !path.is_dir() {
return Err(RuvLLMError::CoreML(
".mlmodelc should be a directory (compiled Core ML model)".to_string(),
));
}
let model_mil = path.join("model.mil");
let coreml_data = path.join("coremldata.bin");
let weights = path.join("weights");
if !model_mil.exists() && !coreml_data.exists() && !weights.exists() {
return Err(RuvLLMError::CoreML(format!(
"Invalid .mlmodelc directory: missing expected files at {}",
path.display()
)));
}
}
Some("mlmodel") => {
if !path.is_file() {
return Err(RuvLLMError::CoreML(".mlmodel should be a file".to_string()));
}
}
Some("mlpackage") => {
if !path.is_dir() {
return Err(RuvLLMError::CoreML(
".mlpackage should be a directory".to_string(),
));
}
}
_ => {
return Err(RuvLLMError::CoreML(format!(
"Unsupported Core ML model format. Expected .mlmodel, .mlmodelc, or .mlpackage: {}",
path.display()
)));
}
}
Ok(())
}
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
pub fn model_handle(&self) -> Option<&CoreMLModelHandle> {
self.model_handle.as_ref()
}
pub fn compute_units(&self) -> ComputeUnits {
self.compute_units
}
#[cfg(all(target_os = "macos", target_arch = "aarch64", feature = "candle"))]
fn sample_token(&self, logits: &[f32], temperature: f32, top_p: f32) -> Result<u32> {
use rand::Rng;
if logits.is_empty() {
return Err(RuvLLMError::Generation("Empty logits".to_string()));
}
// Non-positive temperature selects greedy decoding (argmax), mirroring the
// streaming iterator's sampler.
if temperature <= 0.0 {
return Ok(logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i as u32)
.unwrap_or(0));
}
let scaled_logits: Vec<f32> = if temperature != 1.0 {
logits.iter().map(|&x| x / temperature).collect()
} else {
logits.to_vec()
};
let max_logit = scaled_logits
.iter()
.cloned()
.fold(f32::NEG_INFINITY, f32::max);
let exp_logits: Vec<f32> = scaled_logits
.iter()
.map(|&x| (x - max_logit).exp())
.collect();
let sum_exp: f32 = exp_logits.iter().sum();
let probs: Vec<f32> = exp_logits.iter().map(|&x| x / sum_exp).collect();
if top_p < 1.0 {
let mut indexed_probs: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
indexed_probs
.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let mut cumsum = 0.0;
let mut cutoff_idx = indexed_probs.len();
for (i, (_, p)) in indexed_probs.iter().enumerate() {
cumsum += p;
if cumsum >= top_p {
cutoff_idx = i + 1;
break;
}
}
let filtered: Vec<(usize, f32)> = indexed_probs[..cutoff_idx].to_vec();
let filter_sum: f32 = filtered.iter().map(|(_, p)| p).sum();
let normalized: Vec<(usize, f32)> = filtered
.into_iter()
.map(|(i, p)| (i, p / filter_sum))
.collect();
let mut rng = rand::thread_rng();
let r: f32 = rng.gen();
let mut cumsum = 0.0;
for (idx, p) in &normalized {
cumsum += p;
if r < cumsum {
return Ok(*idx as u32);
}
}
return Ok(normalized.last().map(|(i, _)| *i as u32).unwrap_or(0));
}
let mut rng = rand::thread_rng();
let r: f32 = rng.gen();
let mut cumsum = 0.0;
for (idx, &p) in probs.iter().enumerate() {
cumsum += p;
if r < cumsum {
return Ok(idx as u32);
}
}
Ok(probs
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i as u32)
.unwrap_or(0))
}
}
#[cfg(feature = "coreml")]
impl LlmBackend for CoreMLBackend {
fn load_model(&mut self, model_id: &str, config: ModelConfig) -> Result<()> {
let path = Path::new(model_id);
let extension = path.extension().and_then(|e| e.to_str());
if matches!(extension, Some("mlmodelc" | "mlmodel" | "mlpackage")) {
Self::validate_coreml_path(path)?;
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
let handle = CoreMLModelHandle::load(path, self.compute_units)?;
let input_names = handle.input_names();
let output_names = handle.output_names();
tracing::info!(
"Loaded Core ML model: {} (inputs: {:?}, outputs: {:?})",
path.display(),
input_names,
output_names
);
let memory_usage = if path.is_dir() {
walkdir_size(path).unwrap_or(0)
} else {
std::fs::metadata(path)
.map(|m| m.len() as usize)
.unwrap_or(0)
};
self.model_info = Some(ModelInfo {
name: path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("unknown")
.to_string(),
architecture: config.architecture,
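// Core ML metadata does not report a parameter count, so fall back to
// config-supplied values (or LLaMA-7B-style defaults) for the rest.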
num_parameters: 0,
vocab_size: config.vocab_size.unwrap_or(32000),
hidden_size: config.hidden_size.unwrap_or(4096),
num_layers: config.num_layers.unwrap_or(32),
max_context_length: config.max_sequence_length,
quantization: config.quantization,
memory_usage,
});
self.model_handle = Some(handle);
self.loaded = true;
return Ok(());
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
return Err(RuvLLMError::Config(
"Core ML model loading is only supported on macOS aarch64 (Apple Silicon)"
.to_string(),
));
}
}
if matches!(extension, Some("gguf")) {
let coreml_path = self.get_coreml_cache_path(path);
if !coreml_path.exists() {
self.convert_to_coreml(path, &coreml_path)?;
}
return self.load_model(&coreml_path.to_string_lossy(), config);
}
Err(RuvLLMError::NotFound(format!(
"Unsupported model format. Expected .mlmodel, .mlmodelc, .mlpackage, or .gguf: {}",
model_id
)))
}
fn generate(&self, prompt: &str, params: GenerateParams) -> Result<String> {
if !self.loaded {
return Err(RuvLLMError::InvalidOperation("No model loaded".to_string()));
}
#[cfg(all(target_os = "macos", target_arch = "aarch64", feature = "candle"))]
{
let model_handle = self.model_handle.as_ref().ok_or_else(|| {
RuvLLMError::InvalidOperation("Model handle not initialized".to_string())
})?;
let tokenizer = self.tokenizer.as_ref().ok_or_else(|| {
RuvLLMError::Config(
"Tokenizer not loaded. Call load_tokenizer() or use with_tokenizer() first."
.to_string(),
)
})?;
let mut input_ids: Vec<i32> = tokenizer
.encode(prompt)?
.into_iter()
.map(|t| t as i32)
.collect();
let max_tokens = params.max_tokens;
let temperature = params.temperature;
let top_p = params.top_p;
let _start_time = Instant::now();
let mut generated_tokens: Vec<u32> = Vec::with_capacity(max_tokens);
for _ in 0..max_tokens {
let logits = model_handle.predict(&self.input_feature_name, &input_ids)?;
let vocab_size = self.vocab_size;
let last_token_logits = if logits.len() >= vocab_size {
&logits[logits.len() - vocab_size..]
} else {
&logits
};
let next_token = self.sample_token(last_token_logits, temperature, top_p)?;
if next_token == self.eos_token_id {
break;
}
generated_tokens.push(next_token);
input_ids.push(next_token as i32);
}
let output = tokenizer.decode(&generated_tokens)?;
return Ok(output);
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", feature = "candle")))]
{
let _ = (prompt, params);
Err(RuvLLMError::Config(
"Core ML inference requires macOS aarch64 with candle feature enabled".to_string(),
))
}
}
fn generate_stream(
&self,
prompt: &str,
params: GenerateParams,
) -> Result<Box<dyn Iterator<Item = Result<GeneratedToken>> + Send + '_>> {
#[cfg(all(target_os = "macos", target_arch = "aarch64", feature = "candle"))]
{
if !self.loaded {
return Err(RuvLLMError::InvalidOperation("No model loaded".to_string()));
}
let model_handle = self.model_handle.as_ref().ok_or_else(|| {
RuvLLMError::InvalidOperation("Model handle not initialized".to_string())
})?;
let tokenizer = self.tokenizer.as_ref().ok_or_else(|| {
RuvLLMError::Config(
"Tokenizer not loaded. Call load_tokenizer() or use with_tokenizer() first."
.to_string(),
)
})?;
let input_ids: Vec<i32> = tokenizer
.encode(prompt)?
.into_iter()
.map(|t| t as i32)
.collect();
let max_tokens = params.max_tokens;
let temperature = params.temperature;
let top_p = params.top_p;
let input_feature_name = self.input_feature_name.clone();
let eos_token_id = self.eos_token_id;
let vocab_size = self.vocab_size;
let iter = CoreMLStreamIterator::new(
model_handle,
tokenizer,
input_ids,
max_tokens,
temperature,
top_p,
input_feature_name,
eos_token_id,
vocab_size,
);
return Ok(Box::new(iter));
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", feature = "candle")))]
{
let _ = (prompt, params);
Err(RuvLLMError::Config(
"Core ML streaming requires macOS aarch64 with candle feature enabled".to_string(),
))
}
}
fn generate_stream_v2(&self, prompt: &str, params: GenerateParams) -> Result<TokenStream> {
#[cfg(all(target_os = "macos", target_arch = "aarch64", feature = "candle"))]
{
if !self.loaded {
return Err(RuvLLMError::InvalidOperation("No model loaded".to_string()));
}
let model_handle = self.model_handle.as_ref().ok_or_else(|| {
RuvLLMError::InvalidOperation("Model handle not initialized".to_string())
})?;
let tokenizer = self.tokenizer.as_ref().ok_or_else(|| {
RuvLLMError::Config(
"Tokenizer not loaded. Call load_tokenizer() or use with_tokenizer() first."
.to_string(),
)
})?;
let mut input_ids: Vec<i32> = tokenizer
.encode(prompt)?
.into_iter()
.map(|t| t as i32)
.collect();
let max_tokens = params.max_tokens;
let temperature = params.temperature;
let top_p = params.top_p;
let start_time = Instant::now();
let (tx, rx) = mpsc::channel::<StreamEvent>();
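// Note: generation runs to completion on the calling thread before the
// stream is returned; the unbounded channel merely buffers events. True
// incremental streaming would need a background worker.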
let mut generated_count = 0;
for _step in 0..max_tokens {
let logits = match model_handle.predict(&self.input_feature_name, &input_ids) {
Ok(l) => l,
Err(e) => {
let _ = tx.send(StreamEvent::Error(e.to_string()));
break;
}
};
let last_token_logits = if logits.len() >= self.vocab_size {
&logits[logits.len() - self.vocab_size..]
} else {
&logits
};
let next_token = match self.sample_token(last_token_logits, temperature, top_p) {
Ok(t) => t,
Err(e) => {
let _ = tx.send(StreamEvent::Error(e.to_string()));
break;
}
};
if next_token == self.eos_token_id {
break;
}
let text = tokenizer.decode(&[next_token]).unwrap_or_default();
let _ = tx.send(StreamEvent::Token(GeneratedToken {
id: next_token,
text,
logprob: None,
is_special: false, // EOS breaks out of the loop above, so it is never emitted here
}));
input_ids.push(next_token as i32);
generated_count += 1;
}
let elapsed = start_time.elapsed();
let tokens_per_sec = generated_count as f64 / elapsed.as_secs_f64();
let _ = tx.send(StreamEvent::Done {
total_tokens: input_ids.len(),
duration_ms: elapsed.as_millis() as u64,
tokens_per_second: tokens_per_sec,
});
return Ok(TokenStream::new(rx));
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", feature = "candle")))]
{
let _ = (prompt, params);
Err(RuvLLMError::Config(
"Core ML streaming requires macOS aarch64 with candle feature enabled".to_string(),
))
}
}
fn get_embeddings(&self, text: &str) -> Result<Vec<f32>> {
#[cfg(all(target_os = "macos", target_arch = "aarch64", feature = "candle"))]
{
if !self.loaded {
return Err(RuvLLMError::InvalidOperation("No model loaded".to_string()));
}
let model_handle = self.model_handle.as_ref().ok_or_else(|| {
RuvLLMError::InvalidOperation("Model handle not initialized".to_string())
})?;
let tokenizer = self.tokenizer.as_ref().ok_or_else(|| {
RuvLLMError::Config(
"Tokenizer not loaded. Call load_tokenizer() or use with_tokenizer() first."
.to_string(),
)
})?;
let token_ids: Vec<i32> = tokenizer
.encode(text)?
.into_iter()
.map(|t| t as i32)
.collect();
let embeddings = model_handle.get_embeddings(
&self.input_feature_name,
&token_ids,
None, // auto-detect the embedding output by name
)?;
return Ok(embeddings);
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64", feature = "candle")))]
{
let _ = text;
Err(RuvLLMError::Config(
"Core ML embeddings require macOS aarch64 with candle feature enabled".to_string(),
))
}
}
fn tokenizer(&self) -> Option<&dyn Tokenizer> {
#[cfg(feature = "candle")]
{
self.tokenizer.as_ref().map(|t| t as &dyn Tokenizer)
}
#[cfg(not(feature = "candle"))]
{
None
}
}
fn is_model_loaded(&self) -> bool {
self.loaded
}
fn model_info(&self) -> Option<ModelInfo> {
self.model_info.clone()
}
fn unload_model(&mut self) {
self.loaded = false;
self.model_info = None;
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
self.model_handle = None;
}
}
}
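// Recursively sums file sizes under `path`; compiled `.mlmodelc` bundles are
// directories, so a plain metadata query is not enough.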
#[cfg(feature = "coreml")]
fn walkdir_size(path: &Path) -> std::io::Result<usize> {
let mut total = 0;
if path.is_dir() {
for entry in std::fs::read_dir(path)? {
let entry = entry?;
let path = entry.path();
if path.is_dir() {
total += walkdir_size(&path)?;
} else {
total += std::fs::metadata(&path)?.len() as usize;
}
}
} else {
total = std::fs::metadata(path)?.len() as usize;
}
Ok(total)
}
#[cfg(not(feature = "coreml"))]
#[derive(Debug)]
pub struct CoreMLBackend;
#[cfg(not(feature = "coreml"))]
impl CoreMLBackend {
pub fn new() -> Result<Self> {
Err(RuvLLMError::Config(
"Core ML feature not enabled. Enable with `coreml` feature flag.".to_string(),
))
}
}
#[cfg(not(feature = "coreml"))]
impl LlmBackend for CoreMLBackend {
fn load_model(&mut self, _model_id: &str, _config: ModelConfig) -> Result<()> {
Err(RuvLLMError::Config(
"Core ML feature not enabled".to_string(),
))
}
fn generate(&self, _prompt: &str, _params: GenerateParams) -> Result<String> {
Err(RuvLLMError::Config(
"Core ML feature not enabled".to_string(),
))
}
fn generate_stream(
&self,
_prompt: &str,
_params: GenerateParams,
) -> Result<Box<dyn Iterator<Item = Result<GeneratedToken>> + Send + '_>> {
Err(RuvLLMError::Config(
"Core ML feature not enabled".to_string(),
))
}
fn generate_stream_v2(&self, _prompt: &str, _params: GenerateParams) -> Result<TokenStream> {
Err(RuvLLMError::Config(
"Core ML feature not enabled".to_string(),
))
}
fn get_embeddings(&self, _text: &str) -> Result<Vec<f32>> {
Err(RuvLLMError::Config(
"Core ML feature not enabled".to_string(),
))
}
fn tokenizer(&self) -> Option<&dyn Tokenizer> {
None
}
fn is_model_loaded(&self) -> bool {
false
}
fn model_info(&self) -> Option<ModelInfo> {
None
}
fn unload_model(&mut self) {}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_compute_units_default() {
let units = ComputeUnits::default();
assert_eq!(units, ComputeUnits::All);
}
#[test]
fn test_compute_units_uses_ane() {
assert!(ComputeUnits::CpuAndNeuralEngine.uses_ane());
assert!(ComputeUnits::All.uses_ane());
assert!(!ComputeUnits::CpuOnly.uses_ane());
assert!(!ComputeUnits::CpuAndGpu.uses_ane());
}
#[test]
fn test_compute_units_uses_gpu() {
assert!(ComputeUnits::CpuAndGpu.uses_gpu());
assert!(ComputeUnits::All.uses_gpu());
assert!(!ComputeUnits::CpuOnly.uses_gpu());
assert!(!ComputeUnits::CpuAndNeuralEngine.uses_gpu());
}
#[test]
fn test_compute_units_description() {
assert_eq!(ComputeUnits::CpuOnly.description(), "CPU only");
assert_eq!(ComputeUnits::CpuAndGpu.description(), "CPU + GPU");
assert_eq!(
ComputeUnits::CpuAndNeuralEngine.description(),
"CPU + Neural Engine (ANE)"
);
assert_eq!(ComputeUnits::All.description(), "CPU + GPU + Neural Engine");
}
#[test]
fn test_compute_units_clone() {
let units = ComputeUnits::CpuAndNeuralEngine;
let cloned = units.clone();
assert_eq!(units, cloned);
}
#[test]
fn test_compute_units_copy() {
let units = ComputeUnits::All;
let copied: ComputeUnits = units;
assert_eq!(units, copied);
}
#[test]
fn test_compute_units_debug() {
let debug_str = format!("{:?}", ComputeUnits::CpuAndNeuralEngine);
assert!(debug_str.contains("CpuAndNeuralEngine"));
}
#[test]
fn test_compute_units_eq() {
assert_eq!(ComputeUnits::CpuOnly, ComputeUnits::CpuOnly);
assert_ne!(ComputeUnits::CpuOnly, ComputeUnits::CpuAndGpu);
assert_ne!(ComputeUnits::All, ComputeUnits::CpuAndNeuralEngine);
}
#[test]
fn test_ane_capabilities_detect() {
let caps = AneCapabilities::detect();
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
assert!(caps.available);
assert!(caps.tops > 0.0);
assert!(!caps.supported_ops.is_empty());
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
assert!(!caps.available);
}
}
#[test]
fn test_ane_capabilities_default() {
let caps = AneCapabilities::default();
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
assert!(caps.available);
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
assert!(!caps.available);
}
}
#[test]
fn test_ane_capabilities_model_suitability() {
let caps = AneCapabilities {
available: true,
tops: 38.0,
max_model_size_mb: 2048,
supported_ops: vec!["MatMul".to_string()],
};
assert!(caps.is_model_suitable(1000));
assert!(caps.is_model_suitable(2048));
assert!(!caps.is_model_suitable(4096));
assert!(caps.is_model_suitable(0));
assert!(caps.is_model_suitable(1));
}
#[test]
fn test_ane_capabilities_unavailable_device() {
let caps = AneCapabilities {
available: false,
tops: 0.0,
max_model_size_mb: 0,
supported_ops: vec![],
};
assert!(!caps.is_model_suitable(100));
assert!(!caps.is_model_suitable(0));
}
#[test]
fn test_ane_capabilities_clone() {
let caps = AneCapabilities {
available: true,
tops: 38.0,
max_model_size_mb: 2048,
supported_ops: vec!["MatMul".to_string(), "GELU".to_string()],
};
let cloned = caps.clone();
assert_eq!(caps.available, cloned.available);
assert_eq!(caps.tops, cloned.tops);
assert_eq!(caps.max_model_size_mb, cloned.max_model_size_mb);
assert_eq!(caps.supported_ops, cloned.supported_ops);
}
#[test]
fn test_ane_capabilities_debug() {
let caps = AneCapabilities::detect();
let debug_str = format!("{:?}", caps);
assert!(debug_str.contains("AneCapabilities"));
assert!(debug_str.contains("available"));
assert!(debug_str.contains("tops"));
}
#[test]
fn test_ane_capabilities_supported_ops() {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
let caps = AneCapabilities::detect();
assert!(caps.supported_ops.contains(&"MatMul".to_string()));
assert!(caps.supported_ops.contains(&"GELU".to_string()));
assert!(caps.supported_ops.contains(&"SiLU".to_string()));
assert!(caps.supported_ops.contains(&"LayerNorm".to_string()));
assert!(caps.supported_ops.contains(&"Softmax".to_string()));
}
}
#[test]
fn test_ane_capabilities_tops_reasonable() {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
let caps = AneCapabilities::detect();
assert!(caps.tops >= 10.0);
assert!(caps.tops <= 50.0);
}
}
#[cfg(feature = "coreml")]
mod coreml_backend_tests {
use super::*;
#[test]
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn test_coreml_backend_new_on_apple_silicon() {
let backend = CoreMLBackend::new();
assert!(backend.is_ok());
let backend = backend.unwrap();
assert!(!backend.is_model_loaded());
assert!(backend.model_info().is_none());
}
#[test]
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn test_coreml_backend_new_on_non_apple_silicon() {
let backend = CoreMLBackend::new();
assert!(backend.is_err());
let err = backend.unwrap_err();
assert!(err
.to_string()
.contains("Apple Neural Engine not available"));
}
#[test]
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn test_coreml_backend_with_compute_units() {
let backend = CoreMLBackend::new()
.unwrap()
.with_compute_units(ComputeUnits::CpuAndNeuralEngine);
assert_eq!(backend.compute_units(), ComputeUnits::CpuAndNeuralEngine);
}
#[test]
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn test_coreml_backend_ane_capabilities() {
let backend = CoreMLBackend::new().unwrap();
let caps = backend.ane_capabilities();
assert!(caps.available);
assert!(caps.tops > 0.0);
}
#[test]
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn test_coreml_backend_is_model_ane_suitable() {
let backend = CoreMLBackend::new().unwrap();
assert!(backend.is_model_ane_suitable(1000));
assert!(backend.is_model_ane_suitable(2048));
assert!(!backend.is_model_ane_suitable(5000));
}
#[test]
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn test_coreml_backend_unsupported_format() {
let mut backend = CoreMLBackend::new().unwrap();
let result = backend.load_model("model.safetensors", ModelConfig::default());
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.to_string().contains("Unsupported model format"));
}
#[test]
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn test_coreml_backend_gguf_conversion_not_implemented() {
let mut backend = CoreMLBackend::new().unwrap();
let result = backend.load_model("/nonexistent/model.gguf", ModelConfig::default());
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.to_string().contains("not yet implemented")
|| err.to_string().contains("conversion")
);
}
#[test]
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn test_coreml_backend_generate_requires_loaded_model() {
let backend = CoreMLBackend::new().unwrap();
let result = backend.generate("Hello", GenerateParams::default());
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.to_string().contains("No model loaded"));
}
#[test]
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn test_coreml_backend_unload_model() {
let mut backend = CoreMLBackend::new().unwrap();
backend.unload_model();
assert!(!backend.is_model_loaded());
}
#[test]
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn test_coreml_backend_tokenizer_not_available() {
let backend = CoreMLBackend::new().unwrap();
assert!(backend.tokenizer().is_none());
}
#[test]
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn test_coreml_backend_generate_stream_requires_model() {
let backend = CoreMLBackend::new().unwrap();
let result = backend.generate_stream("Hello", GenerateParams::default());
assert!(result.is_err());
match result {
Err(err) => {
let msg = err.to_string();
assert!(
msg.contains("No model loaded")
|| msg.contains("Tokenizer")
|| msg.contains("requires macOS aarch64"),
"Unexpected error: {}",
msg
);
}
Ok(_) => panic!("Expected error when no model loaded, got Ok"),
}
}
#[test]
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn test_coreml_backend_get_embeddings_requires_model() {
let backend = CoreMLBackend::new().unwrap();
let result = backend.get_embeddings("Test text");
assert!(result.is_err());
let err = result.unwrap_err();
let msg = err.to_string();
assert!(
msg.contains("No model loaded")
|| msg.contains("Tokenizer")
|| msg.contains("requires macOS aarch64"),
"Unexpected error: {}",
msg
);
}
#[test]
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn test_coreml_backend_cache_directory() {
let backend = CoreMLBackend::new().unwrap();
assert!(backend.cache_dir.to_str().unwrap().contains("coreml"));
}
#[test]
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn test_coreml_backend_validate_path_nonexistent() {
let result =
CoreMLBackend::validate_coreml_path(Path::new("/nonexistent/model.mlmodel"));
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("does not exist"));
}
#[test]
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn test_coreml_backend_validate_path_wrong_extension() {
let temp_dir = std::env::temp_dir();
let temp_file = temp_dir.join("test_model.txt");
std::fs::write(&temp_file, "test").unwrap();
let result = CoreMLBackend::validate_coreml_path(&temp_file);
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Unsupported Core ML model format"));
std::fs::remove_file(temp_file).ok();
}
}
#[cfg(not(feature = "coreml"))]
mod stub_backend_tests {
use super::*;
#[test]
fn test_stub_backend_new_returns_error() {
let result = CoreMLBackend::new();
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.to_string().contains("feature not enabled"));
}
}
#[test]
fn test_backend_trait_bounds() {
fn assert_send_sync<T: Send + Sync>() {}
#[cfg(feature = "coreml")]
assert_send_sync::<CoreMLBackend>();
}
#[test]
fn test_model_suitability_boundary_values() {
let caps = AneCapabilities {
available: true,
tops: 38.0,
max_model_size_mb: 2048,
supported_ops: vec!["MatMul".to_string()],
};
assert!(caps.is_model_suitable(2048));
assert!(!caps.is_model_suitable(2049));
assert!(caps.is_model_suitable(2047));
}
#[test]
fn test_compute_units_all_variants() {
let variants = [
ComputeUnits::CpuOnly,
ComputeUnits::CpuAndGpu,
ComputeUnits::CpuAndNeuralEngine,
ComputeUnits::All,
];
for variant in &variants {
let _ = variant.description();
let _ = variant.uses_ane();
let _ = variant.uses_gpu();
let _ = format!("{:?}", variant);
}
}
#[test]
fn test_ane_capabilities_empty_ops() {
let caps = AneCapabilities {
available: true,
tops: 38.0,
max_model_size_mb: 2048,
supported_ops: vec![],
};
assert!(caps.is_model_suitable(1000));
}
#[test]
fn test_ane_capabilities_max_tops_value() {
let caps = AneCapabilities {
available: true,
tops: f32::MAX,
max_model_size_mb: usize::MAX,
supported_ops: vec!["MatMul".to_string()],
};
assert!(caps.available);
assert!(caps.is_model_suitable(usize::MAX - 1));
}
#[test]
fn test_ane_capabilities_zero_values() {
let caps = AneCapabilities {
available: true,
tops: 0.0,
max_model_size_mb: 0,
supported_ops: vec![],
};
assert!(caps.is_model_suitable(0));
assert!(!caps.is_model_suitable(1));
}
}