use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use trustformers_core::error::Result;
use trustformers_core::Tensor;
/// Version constant exported by this module (currently `5`).
///
/// NOTE(review): presumably the Core ML specification version the converter targets;
/// nothing else in this file reads it — confirm before relying on that meaning.
pub const COREML_VERSION: u32 = 5;

/// Converts TrustformeRS models into Core ML model files.
///
/// Holds the conversion settings plus the optimization passes and validation rules
/// that are built from those settings when the converter is constructed.
pub struct CoreMLModelConverter {
    /// Settings that drive every conversion stage.
    config: CoreMLConverterConfig,
    /// Graph rewrites, run in order over the converted graph.
    optimization_passes: Vec<Box<dyn OptimizationPass>>,
    /// Checks run against the graph; any error aborts the conversion.
    validation_rules: Vec<Box<dyn ValidationRule>>,
}
/// User-facing settings for a Core ML conversion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoreMLConverterConfig {
    /// Minimum iOS version, echoed into the validation report (e.g. `"14.0"`).
    pub target_ios_version: String,
    /// Selects which optimization passes are run.
    pub optimization_level: OptimizationLevel,
    /// When `true`, weight data is routed through the converter's compression step.
    pub enable_compression: bool,
    /// Weight/activation quantization to apply after optimization; `None` disables it.
    pub quantization: Option<CoreMLQuantizationConfig>,
    /// Weight pruning to apply after quantization; `None` disables it.
    pub pruning: Option<PruningConfig>,
    /// On-disk layout (and file extension) of the written model.
    pub output_format: CoreMLFormat,
    /// Compute units the model is intended for; reflected in metadata and reports.
    pub hardware_target: HardwareTarget,
}
/// How much graph optimization the converter performs.
///
/// Each level runs every pass of the level below it, then adds its own.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum OptimizationLevel {
    /// No optimization passes.
    None,
    /// Constant folding and dead-code elimination.
    Basic,
    /// `Basic` plus operator fusion and layout optimization.
    Aggressive,
    /// `Aggressive` plus aggressive fusion and precision optimization.
    Maximum,
}
/// On-disk form of the converted model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CoreMLFormat {
    /// A single file with the `.mlmodel` extension.
    MLModel,
    /// A single file with the `.mlmodelc` extension.
    MLModelC,
    /// A `.mlpackage` directory; the model goes under `Data/com.apple.CoreML/`.
    MLPackage,
}
/// Compute units the converted model should target.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum HardwareTarget {
    /// CPU, GPU and Neural Engine.
    All,
    /// Neural Engine only.
    NeuralEngine,
    /// GPU only.
    GPU,
    /// CPU only.
    CPU,
}
/// Quantization settings applied to the converted graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoreMLQuantizationConfig {
    /// Bit width requested for weights.
    pub weight_bits: QuantizationBits,
    /// Bit width for activations; `None` leaves activations unquantized.
    /// Only 8- and 16-bit activation widths are acted on by the converter.
    pub activation_bits: Option<QuantizationBits>,
    /// Quantization algorithm to use.
    pub method: QuantizationMethod,
    /// Number of calibration samples.
    /// NOTE(review): not read anywhere in this file — presumably reserved for
    /// activation calibration.
    pub calibration_size: usize,
    /// Request per-channel rather than per-tensor quantization parameters.
    pub per_channel: bool,
}
/// Supported quantization bit widths.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationBits {
    /// 1 bit.
    Bit1,
    /// 2 bits.
    Bit2,
    /// 4 bits.
    Bit4,
    /// 8 bits.
    Bit8,
    /// 16 bits.
    Bit16,
}
/// Quantization algorithm, named after the technique it selects.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationMethod {
    /// Linear (affine) quantization.
    Linear,
    /// Lookup-table quantization.
    LookupTable,
    /// K-means clustering of weight values.
    KMeans,
    /// Caller-defined scheme.
    Custom,
}
/// Weight pruning settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruningConfig {
    /// Desired sparsity as a fraction (e.g. `0.5` is reported as 50%).
    pub target_sparsity: f32,
    /// Pruning criterion.
    /// NOTE(review): not read anywhere in this file.
    pub method: PruningMethod,
    /// Whether to prune whole structures instead of individual elements.
    /// NOTE(review): not read anywhere in this file.
    pub structured: bool,
    /// Weight names (exact match against the graph's weight keys) that are never pruned.
    pub exclude_layers: Vec<String>,
}
/// Criterion used to choose which weights are pruned.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PruningMethod {
    /// By weight magnitude.
    Magnitude,
    /// By gradient information.
    Gradient,
    /// At random.
    Random,
    /// By structure (channels, heads, ...).
    Structured,
}
/// A rewrite applied to a [`CoreMLModelGraph`] during conversion.
pub trait OptimizationPass: Send + Sync {
    /// Human-readable pass name, listed in the optimization report.
    fn name(&self) -> &str;
    /// Rewrites `model` in place; an error aborts the conversion.
    fn apply(&self, model: &mut CoreMLModelGraph) -> Result<()>;
    /// Whether this pass should run for `config`; consulted before `apply`.
    fn should_apply(&self, config: &CoreMLConverterConfig) -> bool;
}
/// A check performed on a [`CoreMLModelGraph`] during conversion.
pub trait ValidationRule: Send + Sync {
    /// Human-readable rule name.
    fn name(&self) -> &str;
    /// Returns an error to reject `model`; `Ok(())` accepts it.
    fn validate(&self, model: &CoreMLModelGraph) -> Result<()>;
}
/// In-memory representation of a Core ML model produced by the converter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoreMLModelGraph {
    /// Model name.
    pub name: String,
    /// Model version string.
    pub version: String,
    /// Model input descriptions.
    pub inputs: Vec<TensorSpec>,
    /// Model output descriptions.
    pub outputs: Vec<TensorSpec>,
    /// Layers in the order they were converted.
    pub layers: Vec<CoreMLLayer>,
    /// Weight blobs keyed by weight name.
    pub weights: HashMap<String, WeightBlob>,
    /// Descriptive metadata and performance hints.
    pub metadata: ModelMetadata,
}
/// Name, shape and element type of a model input or output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorSpec {
    /// Tensor name.
    pub name: String,
    /// Dimensions. Signed — presumably so a negative value can mark a flexible
    /// dimension (TODO confirm).
    pub shape: Vec<i64>,
    /// Element type.
    pub dtype: CoreMLDataType,
    /// Optional human-readable description.
    pub description: Option<String>,
}
/// Element types a tensor spec or weight blob can declare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CoreMLDataType {
    Float32,
    Float16,
    Int32,
    Int16,
    Int8,
    UInt8,
    Bool,
}
/// One layer of the converted graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoreMLLayer {
    /// Layer name (taken from the source operation).
    pub name: String,
    /// Kind of layer.
    pub layer_type: LayerType,
    /// Names of the layer's input tensors.
    pub inputs: Vec<String>,
    /// Names of the layer's output tensors.
    pub outputs: Vec<String>,
    /// Layer hyper-parameters.
    pub params: LayerParams,
    /// Activation quantization metadata, if any was attached.
    pub quantization: Option<LayerQuantization>,
}
/// Layer kinds known to the converter.
///
/// Source operations the converter does not recognize become `Custom` carrying
/// the original operation type name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LayerType {
    Convolution,
    InnerProduct,
    BatchNorm,
    Activation,
    Pooling,
    Padding,
    Concat,
    Split,
    Reshape,
    Transpose,
    Reduce,
    Softmax,
    Embedding,
    LSTM,
    GRU,
    Attention,
    /// Unrecognized operation; holds the source operation type.
    Custom(String),
}
/// Named, typed layer hyper-parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerParams {
    /// Parameter values keyed by parameter name.
    pub params: HashMap<String, ParamValue>,
}
/// A single typed layer parameter value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ParamValue {
    Int(i64),
    Float(f32),
    String(String),
    IntArray(Vec<i64>),
    FloatArray(Vec<f32>),
    Bool(bool),
}
/// Encoded weight tensor stored in the converted model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightBlob {
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Declared element type.
    pub dtype: CoreMLDataType,
    /// Quantization metadata, if the weight was marked for quantization.
    pub quantization: Option<WeightQuantization>,
    /// Encoded weight bytes as produced by the converter.
    pub data: Vec<u8>,
}
/// Activation quantization parameters attached to a layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerQuantization {
    /// Bit width.
    pub bits: u8,
    /// Quantization scale.
    pub scale: f32,
    /// Quantization zero point.
    pub zero_point: i32,
}
/// Quantization metadata attached to a weight blob.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WeightQuantization {
    /// Quantization scheme.
    pub qtype: QuantizationType,
    /// Lookup-table entries, for lookup-table quantization.
    pub lookup_table: Option<Vec<f32>>,
    /// Scale factors (one per tensor or per channel).
    pub scales: Option<Vec<f32>>,
    /// Zero points matching `scales`.
    pub zero_points: Option<Vec<i32>>,
}
/// Weight quantization schemes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationType {
    /// Single scale/zero point for the whole tensor.
    Linear,
    /// Values mapped through a lookup table.
    LookupTable,
    /// Separate scale/zero point per channel.
    PerChannel,
}
/// Descriptive metadata written with the converted model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Short model description.
    pub description: String,
    /// Model author.
    pub author: String,
    /// License identifier, if any.
    pub license: Option<String>,
    /// Free-form key/value metadata.
    pub user_defined: HashMap<String, String>,
    /// Runtime hints for the model consumer.
    pub performance_hints: PerformanceHints,
}
/// Optional performance hints stored in the model metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceHints {
    /// Compute units the model targets (e.g. `"cpu"`, `"gpu"`, `"neuralEngine"`).
    pub compute_units: Vec<String>,
    /// Expected inference latency in milliseconds.
    pub expected_latency_ms: Option<f32>,
    /// Expected memory footprint in megabytes.
    pub memory_footprint_mb: Option<f32>,
    /// Free-form power-efficiency hint.
    pub power_efficiency: Option<String>,
}
impl CoreMLModelConverter {
    /// Creates a converter for `config`.
    ///
    /// The optimization passes are chosen once, here, from
    /// `config.optimization_level`; the validation rules are the same fixed set for
    /// every configuration.
    pub fn new(config: CoreMLConverterConfig) -> Self {
        let optimization_passes = Self::create_optimization_passes(&config);
        let validation_rules = Self::create_validation_rules();
        Self {
            config,
            optimization_passes,
            validation_rules,
        }
    }

    /// Converts the model at `model_path` and writes it next to `output_path`.
    ///
    /// Stages, in order: load → convert to a [`CoreMLModelGraph`] → validate →
    /// optimize → quantize (if `config.quantization` is set) → prune (if
    /// `config.pruning` is set) → validate again → serialize and write.
    ///
    /// The extension of `output_path` is replaced according to
    /// `config.output_format`, so [`ConversionResult::output_path`] is the path that
    /// was actually written and can differ from `output_path`.
    ///
    /// # Errors
    ///
    /// Propagates the first failure of any stage: a rejecting validation rule, a
    /// failing optimization pass, tensor data that cannot be read, a serialization
    /// error, or an I/O error while creating directories or writing the file.
    pub fn convert(&self, model_path: &Path, output_path: &Path) -> Result<ConversionResult> {
        let trustformers_model = self.load_trustformers_model(model_path)?;
        let mut coreml_graph = self.convert_to_coreml_graph(trustformers_model)?;

        // Validate before and after rewriting so a transformation that produces an
        // invalid graph is caught before anything is written to disk.
        self.validate_model(&coreml_graph)?;
        self.apply_optimizations(&mut coreml_graph)?;
        if let Some(quant_config) = &self.config.quantization {
            self.apply_quantization(&mut coreml_graph, quant_config)?;
        }
        if let Some(pruning_config) = &self.config.pruning {
            self.apply_pruning(&mut coreml_graph, pruning_config)?;
        }
        self.validate_model(&coreml_graph)?;

        let output_info = self.write_coreml_model(&coreml_graph, output_path)?;
        Ok(ConversionResult {
            output_path: output_info.path,
            model_size_mb: output_info.size_mb,
            compression_ratio: output_info.compression_ratio,
            optimization_report: self.generate_optimization_report(&coreml_graph),
            validation_report: self.generate_validation_report(&coreml_graph),
        })
    }

    /// Builds the optimization passes for `config.optimization_level`.
    ///
    /// Each level contains every pass of the level below it, in the same order:
    /// `None` → no passes; `Basic` → constant folding, dead-code elimination;
    /// `Aggressive` → adds operator fusion, layout optimization; `Maximum` → adds
    /// aggressive fusion, precision optimization.
    fn create_optimization_passes(
        config: &CoreMLConverterConfig,
    ) -> Vec<Box<dyn OptimizationPass>> {
        let mut passes: Vec<Box<dyn OptimizationPass>> = Vec::new();
        match config.optimization_level {
            OptimizationLevel::None => {},
            OptimizationLevel::Basic => {
                passes.push(Box::new(ConstantFoldingPass));
                passes.push(Box::new(DeadCodeEliminationPass));
            },
            OptimizationLevel::Aggressive => {
                passes.push(Box::new(ConstantFoldingPass));
                passes.push(Box::new(DeadCodeEliminationPass));
                passes.push(Box::new(OperatorFusionPass));
                passes.push(Box::new(LayoutOptimizationPass));
            },
            OptimizationLevel::Maximum => {
                passes.push(Box::new(ConstantFoldingPass));
                passes.push(Box::new(DeadCodeEliminationPass));
                passes.push(Box::new(OperatorFusionPass));
                passes.push(Box::new(LayoutOptimizationPass));
                passes.push(Box::new(AggressiveFusionPass));
                passes.push(Box::new(PrecisionOptimizationPass));
            },
        }
        passes
    }

    /// Returns the fixed set of validation rules run on every conversion.
    fn create_validation_rules() -> Vec<Box<dyn ValidationRule>> {
        vec![
            Box::new(SupportedOperationsRule),
            Box::new(TensorShapeRule),
            Box::new(DataTypeRule),
            Box::new(MemoryLimitRule),
            Box::new(HardwareCompatibilityRule),
        ]
    }

    /// Loads the source TrustformeRS model.
    ///
    /// NOTE(review): placeholder — `_path` is ignored and an empty model (no weights,
    /// no operations) is always returned.
    fn load_trustformers_model(&self, _path: &Path) -> Result<TrustformersModel> {
        Ok(TrustformersModel {
            weights: HashMap::new(),
            graph: Vec::new(),
        })
    }

    /// Translates a loaded TrustformeRS model into a [`CoreMLModelGraph`].
    ///
    /// Operations are converted in graph order and every weight tensor becomes a
    /// [`WeightBlob`] under the same name. Input/output specs and metadata come from
    /// the `create_*` helpers.
    ///
    /// # Errors
    ///
    /// Fails if any operation or weight cannot be converted.
    fn convert_to_coreml_graph(&self, model: TrustformersModel) -> Result<CoreMLModelGraph> {
        let mut layers = Vec::with_capacity(model.graph.len());
        for op in model.graph {
            layers.push(self.convert_operation(op)?);
        }

        let mut weights = HashMap::with_capacity(model.weights.len());
        for (name, tensor) in model.weights {
            let weight_blob = self.convert_weight(name.clone(), tensor)?;
            weights.insert(name, weight_blob);
        }

        Ok(CoreMLModelGraph {
            name: "TrustformersModel".to_string(),
            version: "1.0.0".to_string(),
            inputs: self.create_input_specs(),
            outputs: self.create_output_specs(),
            layers,
            weights,
            metadata: self.create_metadata(),
        })
    }

    /// Converts one source operation into a [`CoreMLLayer`].
    ///
    /// Known operation types are mapped to their layer kind; anything else becomes
    /// [`LayerType::Custom`] carrying the original type name. No quantization is
    /// attached at this stage.
    fn convert_operation(&self, op: Operation) -> Result<CoreMLLayer> {
        let layer_type = match op.op_type.as_str() {
            "Conv2d" => LayerType::Convolution,
            "Linear" => LayerType::InnerProduct,
            "BatchNorm2d" => LayerType::BatchNorm,
            "ReLU" => LayerType::Activation,
            "MaxPool2d" => LayerType::Pooling,
            _ => LayerType::Custom(op.op_type),
        };
        Ok(CoreMLLayer {
            name: op.name,
            layer_type,
            inputs: op.inputs,
            outputs: op.outputs,
            params: self.convert_params(op.params),
            quantization: None,
        })
    }

    /// Parses raw string parameters into typed [`ParamValue`]s.
    ///
    /// Each value becomes the first type that accepts it, tried in this order:
    /// `i64`, `f32`, `bool` (exactly `"true"`/`"false"`), otherwise the original
    /// string. Because `f32` parsing also accepts `"inf"`/`"NaN"`, those become floats.
    fn convert_params(&self, params: HashMap<String, String>) -> LayerParams {
        let mut converted = HashMap::with_capacity(params.len());
        for (key, value) in params {
            let typed = if let Ok(int_val) = value.parse::<i64>() {
                ParamValue::Int(int_val)
            } else if let Ok(float_val) = value.parse::<f32>() {
                ParamValue::Float(float_val)
            } else if value == "true" || value == "false" {
                ParamValue::Bool(value == "true")
            } else {
                ParamValue::String(value)
            };
            converted.insert(key, typed);
        }
        LayerParams { params: converted }
    }

    /// Converts one weight tensor into a [`WeightBlob`].
    ///
    /// The blob is tagged [`CoreMLDataType::Float32`] and carries no quantization
    /// metadata. Its bytes are little-endian `f32`s, passed through
    /// `compress_weight_data` when `enable_compression` is set. `_name` is unused.
    ///
    /// # Errors
    ///
    /// Fails if the tensor's data cannot be read.
    fn convert_weight(&self, _name: String, tensor: Tensor) -> Result<WeightBlob> {
        let shape = tensor.shape().to_vec();
        let tensor_data = tensor.data()?;
        let data = if self.config.enable_compression {
            self.compress_weight_data(&tensor_data)?
        } else {
            Self::encode_f32_le(&tensor_data)
        };
        Ok(WeightBlob {
            shape,
            dtype: CoreMLDataType::Float32,
            quantization: None,
            data,
        })
    }

    /// Encodes weight data for the compressed path.
    ///
    /// NOTE(review): no compression is implemented yet — the output is identical to
    /// the uncompressed little-endian encoding.
    fn compress_weight_data(&self, data: &[f32]) -> Result<Vec<u8>> {
        Ok(Self::encode_f32_le(data))
    }

    /// Serializes `values` as consecutive 4-byte little-endian IEEE-754 floats.
    ///
    /// A fixed byte order (rather than the host's native order) keeps the written
    /// model identical regardless of which machine ran the conversion.
    fn encode_f32_le(values: &[f32]) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(values.len() * std::mem::size_of::<f32>());
        for value in values {
            bytes.extend_from_slice(&value.to_le_bytes());
        }
        bytes
    }

    /// Returns the model's input specs.
    ///
    /// NOTE(review): hard-coded to one `float32` tensor named `input` of shape
    /// `[1, 3, 224, 224]`, independent of the actual model.
    fn create_input_specs(&self) -> Vec<TensorSpec> {
        vec![TensorSpec {
            name: "input".to_string(),
            shape: vec![1, 3, 224, 224],
            dtype: CoreMLDataType::Float32,
            description: Some("Model input".to_string()),
        }]
    }

    /// Returns the model's output specs.
    ///
    /// NOTE(review): hard-coded to one `float32` tensor named `output` of shape
    /// `[1, 1000]`, independent of the actual model.
    fn create_output_specs(&self) -> Vec<TensorSpec> {
        vec![TensorSpec {
            name: "output".to_string(),
            shape: vec![1, 1000],
            dtype: CoreMLDataType::Float32,
            description: Some("Model output".to_string()),
        }]
    }

    /// Builds model metadata; the compute-unit hint follows `config.hardware_target`.
    fn create_metadata(&self) -> ModelMetadata {
        let compute_units = match self.config.hardware_target {
            HardwareTarget::All => vec![
                "cpu".to_string(),
                "gpu".to_string(),
                "neuralEngine".to_string(),
            ],
            HardwareTarget::NeuralEngine => vec!["neuralEngine".to_string()],
            HardwareTarget::GPU => vec!["gpu".to_string()],
            HardwareTarget::CPU => vec!["cpu".to_string()],
        };
        ModelMetadata {
            description: "Model converted from TrustformeRS".to_string(),
            author: "TrustformeRS".to_string(),
            license: Some("MIT".to_string()),
            user_defined: HashMap::new(),
            performance_hints: PerformanceHints {
                compute_units,
                expected_latency_ms: None,
                memory_footprint_mb: None,
                power_efficiency: None,
            },
        }
    }

    /// Runs every validation rule in order, stopping at the first error.
    fn validate_model(&self, model: &CoreMLModelGraph) -> Result<()> {
        for rule in &self.validation_rules {
            rule.validate(model)?;
        }
        Ok(())
    }

    /// Applies, in order, every pass whose `should_apply` accepts the config.
    fn apply_optimizations(&self, model: &mut CoreMLModelGraph) -> Result<()> {
        for pass in &self.optimization_passes {
            if pass.should_apply(&self.config) {
                pass.apply(model)?;
            }
        }
        Ok(())
    }

    /// Marks eligible weights for quantization and, when `activation_bits` is set,
    /// attaches activation quantization to every layer.
    fn apply_quantization(
        &self,
        model: &mut CoreMLModelGraph,
        config: &CoreMLQuantizationConfig,
    ) -> Result<()> {
        for (name, weight) in &mut model.weights {
            if self.should_quantize_weight(name) {
                self.quantize_weight(weight, config)?;
            }
        }
        if config.activation_bits.is_some() {
            for layer in &mut model.layers {
                self.quantize_layer_activations(layer, config)?;
            }
        }
        Ok(())
    }

    /// Returns `false` for weights whose name contains `final` or `output`
    /// (case-sensitive substring match), `true` otherwise — presumably to keep the
    /// model head at full precision.
    fn should_quantize_weight(&self, name: &str) -> bool {
        !name.contains("final") && !name.contains("output")
    }

    /// Attaches quantization metadata for `config` to `weight`.
    ///
    /// The scheme follows `config.method`. Lookup-table and k-means (palettization)
    /// methods yield [`QuantizationType::LookupTable`]. Linear and custom methods
    /// yield [`QuantizationType::PerChannel`] when `config.per_channel` is set and
    /// [`QuantizationType::Linear`] otherwise.
    ///
    /// NOTE(review): only metadata is recorded. `weight.data` is not rewritten, no
    /// scales/zero points/lookup table are computed, and `config.weight_bits` has no
    /// field in [`WeightQuantization`] to be stored in.
    fn quantize_weight(
        &self,
        weight: &mut WeightBlob,
        config: &CoreMLQuantizationConfig,
    ) -> Result<()> {
        let qtype = match config.method {
            QuantizationMethod::LookupTable | QuantizationMethod::KMeans => {
                QuantizationType::LookupTable
            },
            QuantizationMethod::Linear | QuantizationMethod::Custom => {
                if config.per_channel {
                    QuantizationType::PerChannel
                } else {
                    QuantizationType::Linear
                }
            },
        };
        weight.quantization = Some(WeightQuantization {
            qtype,
            lookup_table: None,
            scales: None,
            zero_points: None,
        });
        Ok(())
    }

    /// Attaches activation quantization to `layer` for 8- or 16-bit activations.
    ///
    /// Sub-byte activation widths (1, 2, 4 bits) are not supported and leave the
    /// layer untouched. The scale (`1.0`) and zero point (`0`) are placeholders; no
    /// calibration is performed.
    fn quantize_layer_activations(
        &self,
        layer: &mut CoreMLLayer,
        config: &CoreMLQuantizationConfig,
    ) -> Result<()> {
        if let Some(bits) = config.activation_bits {
            let num_bits: u8 = match bits {
                QuantizationBits::Bit8 => 8,
                QuantizationBits::Bit16 => 16,
                QuantizationBits::Bit1 | QuantizationBits::Bit2 | QuantizationBits::Bit4 => {
                    return Ok(());
                },
            };
            layer.quantization = Some(LayerQuantization {
                bits: num_bits,
                scale: 1.0,
                zero_point: 0,
            });
        }
        Ok(())
    }

    /// Prunes every weight whose name is not listed in `config.exclude_layers`.
    fn apply_pruning(&self, model: &mut CoreMLModelGraph, config: &PruningConfig) -> Result<()> {
        for (name, weight) in &mut model.weights {
            if !config.exclude_layers.contains(name) {
                self.prune_weight(weight, config)?;
            }
        }
        Ok(())
    }

    /// Prunes a single weight.
    ///
    /// NOTE(review): placeholder — only prints the target sparsity to stdout; the
    /// weight data is left unchanged.
    fn prune_weight(&self, _weight: &mut WeightBlob, config: &PruningConfig) -> Result<()> {
        println!(
            "Pruning weight to {}% sparsity",
            config.target_sparsity * 100.0
        );
        Ok(())
    }

    /// Serializes `model` and writes it in the layout selected by
    /// `config.output_format`.
    ///
    /// The extension of `output_path` is replaced with `mlmodel`, `mlmodelc` or
    /// `mlpackage`. For packages the file is written to
    /// `<name>.mlpackage/Data/com.apple.CoreML/model.mlmodel`, and every missing
    /// directory on that path is created first.
    ///
    /// Returns the written path, its size in MiB, and the ratio of the float32 weight
    /// size to the written size.
    ///
    /// NOTE(review): a complete `.mlpackage` also expects a `Manifest.json`, which is
    /// not written here.
    ///
    /// # Errors
    ///
    /// Fails on serialization errors or on any I/O error while creating directories
    /// or writing the file.
    fn write_coreml_model(
        &self,
        model: &CoreMLModelGraph,
        output_path: &Path,
    ) -> Result<OutputInfo> {
        let model_data = self.serialize_model(model)?;
        if let Some(parent) = output_path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let model_path = match self.config.output_format {
            CoreMLFormat::MLModel => output_path.with_extension("mlmodel"),
            CoreMLFormat::MLModelC => output_path.with_extension("mlmodelc"),
            CoreMLFormat::MLPackage => {
                // Create the whole `Data/com.apple.CoreML` chain, not just the package
                // root, otherwise the write below fails with `NotFound`.
                let data_dir = output_path
                    .with_extension("mlpackage")
                    .join("Data")
                    .join("com.apple.CoreML");
                std::fs::create_dir_all(&data_dir)?;
                data_dir.join("model.mlmodel")
            },
        };
        std::fs::write(&model_path, &model_data)?;

        let size_mb = model_data.len() as f32 / (1024.0 * 1024.0);
        let original_size_mb = self.calculate_original_size(model);
        let compression_ratio = original_size_mb / size_mb;
        Ok(OutputInfo {
            path: model_path,
            size_mb,
            compression_ratio,
        })
    }

    /// Serializes the whole graph to JSON bytes with `serde_json`.
    ///
    /// NOTE(review): presumably a stand-in for the Core ML protobuf encoding — confirm.
    fn serialize_model(&self, model: &CoreMLModelGraph) -> Result<Vec<u8>> {
        Ok(serde_json::to_vec(model)?)
    }

    /// Size in MiB of all weights counted as 4-byte floats.
    ///
    /// This is deliberately independent of each blob's `dtype` and encoded `data`, so
    /// it measures the unquantized float32 size used as the compression baseline.
    fn calculate_original_size(&self, model: &CoreMLModelGraph) -> f32 {
        let weight_size: usize = model
            .weights
            .values()
            .map(|w| w.shape.iter().product::<usize>() * 4)
            .sum();
        weight_size as f32 / (1024.0 * 1024.0)
    }

    /// Summarizes which optimizations the configuration requested.
    fn generate_optimization_report(&self, _model: &CoreMLModelGraph) -> OptimizationReport {
        OptimizationReport {
            passes_applied: self
                .optimization_passes
                .iter()
                .filter(|p| p.should_apply(&self.config))
                .map(|p| p.name().to_string())
                .collect(),
            compression_achieved: self.config.enable_compression,
            quantization_applied: self.config.quantization.is_some(),
            pruning_applied: self.config.pruning.is_some(),
            hardware_optimizations: match self.config.hardware_target {
                HardwareTarget::NeuralEngine => vec!["Neural Engine optimizations".to_string()],
                HardwareTarget::GPU => vec!["GPU optimizations".to_string()],
                HardwareTarget::All | HardwareTarget::CPU => vec![],
            },
        }
    }

    /// Builds the validation report: target iOS version, supported devices, and
    /// layer/weight counts of `model`.
    fn generate_validation_report(&self, model: &CoreMLModelGraph) -> ValidationReport {
        ValidationReport {
            ios_version: self.config.target_ios_version.clone(),
            supported_devices: self.get_supported_devices(),
            warnings: vec![],
            info: vec![
                format!("Model has {} layers", model.layers.len()),
                format!("Model has {} weights", model.weights.len()),
            ],
        }
    }

    /// Human-readable device list for the configured hardware target.
    fn get_supported_devices(&self) -> Vec<String> {
        match self.config.hardware_target {
            HardwareTarget::NeuralEngine => {
                vec!["iPhone 11+".to_string(), "iPad Pro 2018+".to_string()]
            },
            HardwareTarget::All | HardwareTarget::GPU | HardwareTarget::CPU => {
                vec!["All iOS devices".to_string()]
            },
        }
    }
}
/// Source model as loaded from TrustformeRS, before conversion.
struct TrustformersModel {
    /// Weight tensors keyed by name.
    weights: HashMap<String, Tensor>,
    /// Operations, converted to layers in this order.
    graph: Vec<Operation>,
}
/// One operation of the source model graph.
struct Operation {
    /// Operation name (becomes the layer name).
    name: String,
    /// Source operation type, e.g. `"Conv2d"` or `"ReLU"`.
    op_type: String,
    /// Input tensor names.
    inputs: Vec<String>,
    /// Output tensor names.
    outputs: Vec<String>,
    /// Raw string parameters, parsed into typed values during conversion.
    params: HashMap<String, String>,
}
/// Result of writing the model file.
struct OutputInfo {
    /// Path actually written.
    path: PathBuf,
    /// Written size in MiB.
    size_mb: f32,
    /// Float32 weight size divided by written size.
    compression_ratio: f32,
}
/// Outcome of a successful conversion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversionResult {
    /// Path of the written model (its extension reflects the output format).
    pub output_path: PathBuf,
    /// Written model size in MiB.
    pub model_size_mb: f32,
    /// Float32 weight size divided by written model size.
    pub compression_ratio: f32,
    /// Which optimizations were requested/applied.
    pub optimization_report: OptimizationReport,
    /// Target platform summary and model statistics.
    pub validation_report: ValidationReport,
}
/// Summary of the optimizations performed during conversion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationReport {
    /// Names of the optimization passes that ran.
    pub passes_applied: Vec<String>,
    /// Mirrors the `enable_compression` setting.
    pub compression_achieved: bool,
    /// Whether quantization was configured.
    pub quantization_applied: bool,
    /// Whether pruning was configured.
    pub pruning_applied: bool,
    /// Hardware-specific optimizations associated with the target.
    pub hardware_optimizations: Vec<String>,
}
/// Summary of the converted model's target platform and contents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationReport {
    /// Target iOS version from the configuration.
    pub ios_version: String,
    /// Human-readable list of supported devices.
    pub supported_devices: Vec<String>,
    /// Non-fatal issues found during validation.
    pub warnings: Vec<String>,
    /// Informational messages (e.g. layer and weight counts).
    pub info: Vec<String>,
}
/// Optimization pass named "ConstantFolding"; included from `Basic` upward.
///
/// NOTE(review): placeholder — it leaves the graph unchanged.
struct ConstantFoldingPass;

impl OptimizationPass for ConstantFoldingPass {
    fn name(&self) -> &str {
        "ConstantFolding"
    }

    fn apply(&self, _graph: &mut CoreMLModelGraph) -> Result<()> {
        // No rewriting implemented yet.
        Ok(())
    }

    /// Always eligible: inclusion is decided when the pass list is built.
    fn should_apply(&self, _cfg: &CoreMLConverterConfig) -> bool {
        true
    }
}
/// Optimization pass named "DeadCodeElimination"; included from `Basic` upward.
///
/// NOTE(review): placeholder — it leaves the graph unchanged.
struct DeadCodeEliminationPass;

impl OptimizationPass for DeadCodeEliminationPass {
    fn name(&self) -> &str {
        "DeadCodeElimination"
    }

    fn apply(&self, _graph: &mut CoreMLModelGraph) -> Result<()> {
        // No rewriting implemented yet.
        Ok(())
    }

    /// Always eligible: inclusion is decided when the pass list is built.
    fn should_apply(&self, _cfg: &CoreMLConverterConfig) -> bool {
        true
    }
}
/// Optimization pass named "OperatorFusion"; runs at `Aggressive` and `Maximum`.
///
/// NOTE(review): placeholder — it leaves the graph unchanged.
struct OperatorFusionPass;

impl OptimizationPass for OperatorFusionPass {
    fn name(&self) -> &str {
        "OperatorFusion"
    }

    fn apply(&self, _graph: &mut CoreMLModelGraph) -> Result<()> {
        // No rewriting implemented yet.
        Ok(())
    }

    fn should_apply(&self, cfg: &CoreMLConverterConfig) -> bool {
        match cfg.optimization_level {
            OptimizationLevel::Aggressive | OptimizationLevel::Maximum => true,
            OptimizationLevel::None | OptimizationLevel::Basic => false,
        }
    }
}
/// Optimization pass named "LayoutOptimization"; runs at `Aggressive` and `Maximum`.
///
/// NOTE(review): placeholder — it leaves the graph unchanged.
struct LayoutOptimizationPass;

impl OptimizationPass for LayoutOptimizationPass {
    fn name(&self) -> &str {
        "LayoutOptimization"
    }

    fn apply(&self, _graph: &mut CoreMLModelGraph) -> Result<()> {
        // No rewriting implemented yet.
        Ok(())
    }

    fn should_apply(&self, cfg: &CoreMLConverterConfig) -> bool {
        let level = cfg.optimization_level;
        level == OptimizationLevel::Aggressive || level == OptimizationLevel::Maximum
    }
}
/// Optimization pass named "AggressiveFusion"; runs only at `Maximum`.
///
/// NOTE(review): placeholder — it leaves the graph unchanged.
struct AggressiveFusionPass;

impl OptimizationPass for AggressiveFusionPass {
    fn name(&self) -> &str {
        "AggressiveFusion"
    }

    fn apply(&self, _graph: &mut CoreMLModelGraph) -> Result<()> {
        // No rewriting implemented yet.
        Ok(())
    }

    fn should_apply(&self, cfg: &CoreMLConverterConfig) -> bool {
        cfg.optimization_level == OptimizationLevel::Maximum
    }
}
/// Optimization pass named "PrecisionOptimization"; runs only at `Maximum`.
///
/// NOTE(review): placeholder — it leaves the graph unchanged.
struct PrecisionOptimizationPass;

impl OptimizationPass for PrecisionOptimizationPass {
    fn name(&self) -> &str {
        "PrecisionOptimization"
    }

    fn apply(&self, _graph: &mut CoreMLModelGraph) -> Result<()> {
        // No rewriting implemented yet.
        Ok(())
    }

    fn should_apply(&self, cfg: &CoreMLConverterConfig) -> bool {
        match cfg.optimization_level {
            OptimizationLevel::Maximum => true,
            _ => false,
        }
    }
}
/// Validation rule named "SupportedOperations" — presumably meant to reject layers
/// Core ML cannot run.
///
/// NOTE(review): placeholder — every model is accepted.
struct SupportedOperationsRule;

impl ValidationRule for SupportedOperationsRule {
    fn name(&self) -> &str {
        "SupportedOperations"
    }

    fn validate(&self, _graph: &CoreMLModelGraph) -> Result<()> {
        // No checks implemented yet.
        Ok(())
    }
}
/// Validation rule named "TensorShape".
///
/// NOTE(review): placeholder — every model is accepted.
struct TensorShapeRule;

impl ValidationRule for TensorShapeRule {
    fn name(&self) -> &str {
        "TensorShape"
    }

    fn validate(&self, _graph: &CoreMLModelGraph) -> Result<()> {
        // No checks implemented yet.
        Ok(())
    }
}
/// Validation rule named "DataType".
///
/// NOTE(review): placeholder — every model is accepted.
struct DataTypeRule;

impl ValidationRule for DataTypeRule {
    fn name(&self) -> &str {
        "DataType"
    }

    fn validate(&self, _graph: &CoreMLModelGraph) -> Result<()> {
        // No checks implemented yet.
        Ok(())
    }
}
/// Validation rule named "MemoryLimit".
///
/// NOTE(review): placeholder — every model is accepted.
struct MemoryLimitRule;

impl ValidationRule for MemoryLimitRule {
    fn name(&self) -> &str {
        "MemoryLimit"
    }

    fn validate(&self, _graph: &CoreMLModelGraph) -> Result<()> {
        // No checks implemented yet.
        Ok(())
    }
}
/// Validation rule named "HardwareCompatibility".
///
/// NOTE(review): placeholder — every model is accepted.
struct HardwareCompatibilityRule;

impl ValidationRule for HardwareCompatibilityRule {
    fn name(&self) -> &str {
        "HardwareCompatibility"
    }

    fn validate(&self, _graph: &CoreMLModelGraph) -> Result<()> {
        // No checks implemented yet.
        Ok(())
    }
}
impl Default for CoreMLConverterConfig {
    /// iOS 14.0 target, `Basic` optimization, compression on, no quantization or
    /// pruning, `.mlmodel` output, all compute units.
    fn default() -> Self {
        let target_ios_version = String::from("14.0");
        Self {
            hardware_target: HardwareTarget::All,
            output_format: CoreMLFormat::MLModel,
            pruning: None,
            quantization: None,
            enable_compression: true,
            optimization_level: OptimizationLevel::Basic,
            target_ios_version,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_converter_creation() {
        let config = CoreMLConverterConfig::default();
        let converter = CoreMLModelConverter::new(config);
        assert!(!converter.optimization_passes.is_empty());
        assert!(!converter.validation_rules.is_empty());
    }

    #[test]
    fn test_quantization_config() {
        let config = CoreMLQuantizationConfig {
            weight_bits: QuantizationBits::Bit8,
            activation_bits: Some(QuantizationBits::Bit8),
            method: QuantizationMethod::Linear,
            calibration_size: 1000,
            per_channel: true,
        };
        assert_eq!(config.weight_bits, QuantizationBits::Bit8);
        assert!(config.per_channel);
    }

    #[test]
    fn test_pruning_config() {
        let config = PruningConfig {
            target_sparsity: 0.5,
            method: PruningMethod::Magnitude,
            structured: false,
            exclude_layers: vec!["output".to_string()],
        };
        assert_eq!(config.target_sparsity, 0.5);
        assert!(config.exclude_layers.contains(&"output".to_string()));
    }

    #[test]
    fn test_default_converter_config() {
        let config = CoreMLConverterConfig::default();
        assert!(!config.target_ios_version.is_empty());
        assert!(matches!(
            config.optimization_level,
            OptimizationLevel::Basic
        ));
        assert!(matches!(config.output_format, CoreMLFormat::MLModel));
    }

    #[test]
    fn test_optimization_level_variants() {
        let levels = vec![
            OptimizationLevel::None,
            OptimizationLevel::Basic,
            OptimizationLevel::Aggressive,
            OptimizationLevel::Maximum,
        ];
        assert_eq!(levels.len(), 4);
    }

    #[test]
    fn test_coreml_format_variants() {
        let formats = vec![
            CoreMLFormat::MLModel,
            CoreMLFormat::MLModelC,
            CoreMLFormat::MLPackage,
        ];
        assert_eq!(formats.len(), 3);
    }

    #[test]
    fn test_hardware_target_variants() {
        let targets = vec![
            HardwareTarget::All,
            HardwareTarget::NeuralEngine,
            HardwareTarget::GPU,
            HardwareTarget::CPU,
        ];
        assert_eq!(targets.len(), 4);
    }

    #[test]
    fn test_quantization_bits_variants() {
        let bits = vec![
            QuantizationBits::Bit1,
            QuantizationBits::Bit2,
            QuantizationBits::Bit4,
            QuantizationBits::Bit8,
            QuantizationBits::Bit16,
        ];
        assert_eq!(bits.len(), 5);
    }

    #[test]
    fn test_quantization_method_variants() {
        // Cover every variant, like the other `*_variants` tests.
        let methods = vec![
            QuantizationMethod::Linear,
            QuantizationMethod::LookupTable,
            QuantizationMethod::KMeans,
            QuantizationMethod::Custom,
        ];
        assert_eq!(methods.len(), 4);
    }

    #[test]
    fn test_pruning_method_variants() {
        let methods = vec![
            PruningMethod::Magnitude,
            PruningMethod::Gradient,
            PruningMethod::Random,
            PruningMethod::Structured,
        ];
        assert_eq!(methods.len(), 4);
    }

    #[test]
    fn test_converter_has_optimization_passes() {
        let config = CoreMLConverterConfig::default();
        let converter = CoreMLModelConverter::new(config);
        assert!(!converter.optimization_passes.is_empty());
    }

    #[test]
    fn test_converter_has_validation_rules() {
        let config = CoreMLConverterConfig::default();
        let converter = CoreMLModelConverter::new(config);
        assert!(converter.validation_rules.len() >= 3);
    }

    #[test]
    fn test_quantization_config_4bit() {
        let config = CoreMLQuantizationConfig {
            weight_bits: QuantizationBits::Bit4,
            activation_bits: None,
            method: QuantizationMethod::KMeans,
            calibration_size: 500,
            per_channel: false,
        };
        assert_eq!(config.weight_bits, QuantizationBits::Bit4);
        assert!(config.activation_bits.is_none());
        assert!(!config.per_channel);
    }

    #[test]
    fn test_pruning_config_structured() {
        let config = PruningConfig {
            target_sparsity: 0.9,
            method: PruningMethod::Structured,
            structured: true,
            exclude_layers: vec![],
        };
        assert_eq!(config.target_sparsity, 0.9);
        assert!(config.structured);
        assert!(config.exclude_layers.is_empty());
    }

    #[test]
    fn test_pruning_config_sparsity_bounds() {
        let config = PruningConfig {
            target_sparsity: 0.5,
            method: PruningMethod::Magnitude,
            structured: false,
            exclude_layers: vec!["output".to_string()],
        };
        assert!(config.target_sparsity >= 0.0);
        assert!(config.target_sparsity <= 1.0);
    }

    #[test]
    fn test_converter_config_with_compression() {
        let mut config = CoreMLConverterConfig::default();
        config.enable_compression = true;
        assert!(config.enable_compression);
    }

    #[test]
    fn test_converter_config_hardware_target_gpu() {
        let mut config = CoreMLConverterConfig::default();
        config.hardware_target = HardwareTarget::GPU;
        assert_eq!(config.hardware_target, HardwareTarget::GPU);
    }

    #[test]
    fn test_converter_config_aggressive_optimization() {
        let mut config = CoreMLConverterConfig::default();
        config.optimization_level = OptimizationLevel::Aggressive;
        assert_eq!(config.optimization_level, OptimizationLevel::Aggressive);
    }

    #[test]
    fn test_coreml_version_constant() {
        assert_eq!(COREML_VERSION, 5);
    }

    #[test]
    fn test_quantization_bits_equality() {
        assert_eq!(QuantizationBits::Bit8, QuantizationBits::Bit8);
        assert_ne!(QuantizationBits::Bit4, QuantizationBits::Bit8);
    }

    #[test]
    fn test_optimization_level_equality() {
        assert_eq!(OptimizationLevel::None, OptimizationLevel::None);
        assert_ne!(OptimizationLevel::None, OptimizationLevel::Maximum);
    }

    #[test]
    fn test_hardware_target_equality() {
        assert_eq!(HardwareTarget::All, HardwareTarget::All);
        assert_ne!(HardwareTarget::CPU, HardwareTarget::GPU);
    }

    #[test]
    fn test_converter_config_with_quantization_and_pruning() {
        let config = CoreMLConverterConfig {
            target_ios_version: "17.0".to_string(),
            optimization_level: OptimizationLevel::Maximum,
            enable_compression: true,
            quantization: Some(CoreMLQuantizationConfig {
                weight_bits: QuantizationBits::Bit4,
                activation_bits: Some(QuantizationBits::Bit8),
                method: QuantizationMethod::Linear,
                calibration_size: 2000,
                per_channel: true,
            }),
            pruning: Some(PruningConfig {
                target_sparsity: 0.7,
                method: PruningMethod::Magnitude,
                structured: false,
                exclude_layers: vec!["embed".to_string()],
            }),
            output_format: CoreMLFormat::MLPackage,
            hardware_target: HardwareTarget::NeuralEngine,
        };
        assert!(config.quantization.is_some());
        assert!(config.pruning.is_some());
    }

    /// Each optimization level adds passes on top of the previous one.
    #[test]
    fn test_passes_per_optimization_level() {
        let expected = [
            (OptimizationLevel::None, 0usize),
            (OptimizationLevel::Basic, 2),
            (OptimizationLevel::Aggressive, 4),
            (OptimizationLevel::Maximum, 6),
        ];
        for &(level, count) in expected.iter() {
            let config = CoreMLConverterConfig {
                optimization_level: level,
                ..CoreMLConverterConfig::default()
            };
            let converter = CoreMLModelConverter::new(config);
            assert_eq!(
                converter.optimization_passes.len(),
                count,
                "unexpected pass count for {:?}",
                level
            );
        }
    }

    /// `Maximum` runs every pass, in dependency order.
    #[test]
    fn test_maximum_pass_order() {
        let config = CoreMLConverterConfig {
            optimization_level: OptimizationLevel::Maximum,
            ..CoreMLConverterConfig::default()
        };
        let converter = CoreMLModelConverter::new(config);
        let names: Vec<&str> = converter
            .optimization_passes
            .iter()
            .map(|p| p.name())
            .collect();
        assert_eq!(
            names,
            vec![
                "ConstantFolding",
                "DeadCodeElimination",
                "OperatorFusion",
                "LayoutOptimization",
                "AggressiveFusion",
                "PrecisionOptimization",
            ]
        );
    }

    /// Raw string params are typed as int, then float, then bool, else string.
    #[test]
    fn test_convert_params_infers_types() {
        let converter = CoreMLModelConverter::new(CoreMLConverterConfig::default());
        let mut raw = std::collections::HashMap::new();
        raw.insert("kernel".to_string(), "3".to_string());
        raw.insert("momentum".to_string(), "0.5".to_string());
        raw.insert("bias".to_string(), "true".to_string());
        raw.insert("padding".to_string(), "same".to_string());

        let converted = converter.convert_params(raw).params;
        assert!(matches!(converted.get("kernel"), Some(ParamValue::Int(3))));
        assert!(matches!(
            converted.get("momentum"),
            Some(ParamValue::Float(v)) if (*v - 0.5).abs() < 1e-6
        ));
        assert!(matches!(converted.get("bias"), Some(ParamValue::Bool(true))));
        assert!(matches!(
            converted.get("padding"),
            Some(ParamValue::String(s)) if s == "same"
        ));
    }

    /// Weights named like the model head are excluded from quantization.
    #[test]
    fn test_should_quantize_weight_skips_final_and_output() {
        let converter = CoreMLModelConverter::new(CoreMLConverterConfig::default());
        assert!(converter.should_quantize_weight("encoder.layer0.weight"));
        assert!(!converter.should_quantize_weight("final_norm.weight"));
        assert!(!converter.should_quantize_weight("output.bias"));
    }

    /// The config survives a JSON round-trip unchanged.
    #[test]
    fn test_converter_config_serde_roundtrip() {
        let config = CoreMLConverterConfig {
            quantization: Some(CoreMLQuantizationConfig {
                weight_bits: QuantizationBits::Bit4,
                activation_bits: None,
                method: QuantizationMethod::KMeans,
                calibration_size: 8,
                per_channel: false,
            }),
            ..CoreMLConverterConfig::default()
        };
        let json = serde_json::to_string(&config).expect("config serializes");
        let restored: CoreMLConverterConfig =
            serde_json::from_str(&json).expect("config deserializes");
        assert_eq!(restored.target_ios_version, config.target_ios_version);
        assert_eq!(restored.optimization_level, config.optimization_level);
        assert_eq!(restored.output_format, config.output_format);
        assert_eq!(restored.hardware_target, config.hardware_target);
        assert_eq!(
            restored.quantization.map(|q| q.method),
            Some(QuantizationMethod::KMeans)
        );
    }
}