use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tensorlogic_ir::EinsumGraph;
use thiserror::Error;
/// Errors surfaced by multi-model registration, routing, and ensembles.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum MultiModelError {
    /// No model is registered under the given id.
    #[error("Model not found: {0}")]
    ModelNotFound(String),
    /// Ensemble members produced outputs that cannot be combined.
    #[error("Incompatible model outputs")]
    IncompatibleOutputs,
    /// The ensemble configuration is invalid (reason in the payload).
    #[error("Invalid ensemble configuration: {0}")]
    InvalidEnsemble(String),
    /// No model could be selected for a request (reason in the payload).
    #[error("Model routing failed: {0}")]
    RoutingFailed(String),
    /// A resource budget was exceeded (reason in the payload).
    #[error("Resource limit exceeded: {0}")]
    ResourceLimitExceeded(String),
}
/// How predictions from multiple models are combined into one result.
///
/// NOTE(review): only the configuration side lives in this module; the
/// actual combination is performed by a `TlEnsembleExecutor` implementor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum EnsembleStrategy {
    /// Unweighted mean of member outputs.
    Average,
    /// Mean weighted by per-model weights.
    WeightedAverage,
    /// Discrete vote; the majority wins.
    MajorityVote,
    /// Vote over combined member scores/probabilities.
    SoftVote,
    /// A meta-model combines member outputs.
    Stacking,
    /// Sequential members correct predecessors' errors.
    Boosting,
    /// Takes the single most confident member's output.
    MaxConfidence,
}
/// Descriptive and operational metadata for a registered model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Unique identifier; used as the registry key by the coordinator.
    pub id: String,
    /// Human-readable name.
    pub name: String,
    /// Version string (free-form).
    pub version: String,
    /// Expected input tensor shapes, one per input.
    pub input_shapes: Vec<Vec<usize>>,
    /// Produced output tensor shapes, one per output.
    pub output_shapes: Vec<Vec<usize>>,
    /// Relative weight — NOTE(review): not read by this module; presumably
    /// consumed by weighted ensemble executors.
    pub weight: f64,
    /// Routing priority; higher wins under `RoutingStrategy::Priority`.
    pub priority: u32,
    /// Estimated resource footprint of running this model.
    pub resource_requirements: ResourceRequirements,
}
/// Estimated resources a model needs to execute.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceRequirements {
    /// Host (CPU) memory, in bytes.
    pub memory_bytes: usize,
    /// GPU memory in bytes; `None` when unknown or not applicable.
    pub gpu_memory_bytes: Option<usize>,
    /// Estimated floating-point operation count.
    pub estimated_flops: f64,
    /// Estimated latency in milliseconds.
    pub estimated_latency_ms: f64,
}
/// Configuration for executing a group of models as an ensemble.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnsembleConfig {
    /// How member outputs are combined.
    pub strategy: EnsembleStrategy,
    /// Optional per-model weights, keyed by model id.
    pub model_weights: HashMap<String, f64>,
    /// Minimum number of members required to run the ensemble.
    pub min_models: usize,
    /// Whether members may execute concurrently.
    pub parallel_execution: bool,
    /// Per-model timeout in milliseconds; `None` means no timeout.
    pub model_timeout_ms: Option<u64>,
}
impl Default for EnsembleConfig {
fn default() -> Self {
Self {
strategy: EnsembleStrategy::Average,
model_weights: HashMap::new(),
min_models: 1,
parallel_execution: true,
model_timeout_ms: None,
}
}
}
impl EnsembleConfig {
    /// Preset: majority voting over at least three members.
    pub fn voting() -> Self {
        let mut config = Self::default();
        config.strategy = EnsembleStrategy::MajorityVote;
        config.min_models = 3;
        config
    }
    /// Preset: weighted averaging using the supplied per-model weights.
    pub fn weighted_average(weights: HashMap<String, f64>) -> Self {
        let mut config = Self::default();
        config.strategy = EnsembleStrategy::WeightedAverage;
        config.model_weights = weights;
        config
    }
    /// Preset: stacked generalization with members run in parallel.
    pub fn stacking() -> Self {
        let mut config = Self::default();
        config.strategy = EnsembleStrategy::Stacking;
        config.parallel_execution = true;
        config
    }
}
/// Strategy used by the coordinator to pick one model for a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RoutingStrategy {
    /// Highest `ModelMetadata::priority` wins.
    Priority,
    /// Lowest `estimated_latency_ms` wins.
    LowestLatency,
    /// Best accuracy — currently falls back to priority selection.
    BestAccuracy,
    /// Cycle through the registered models in turn.
    RoundRobin,
    /// Pick a model at random.
    Random,
    /// Cascade entry — currently falls back to lowest-latency selection.
    Cascade,
    /// Route on input features — currently falls back to priority selection.
    ContentBased,
}
/// Configuration for cascaded (tiered) model execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CascadeConfig {
    /// Model ids in execution order (the provided constructors put the
    /// fastest model first).
    pub model_sequence: Vec<String>,
    /// Exit thresholds between stages; the constructors supply one fewer
    /// threshold than models.
    pub confidence_thresholds: Vec<f64>,
    /// Allow stopping early once a stage meets its threshold.
    pub enable_early_exit: bool,
    /// Upper bound on how many models may run.
    pub max_models: usize,
}
impl CascadeConfig {
pub fn two_tier(fast_model: String, accurate_model: String, threshold: f64) -> Self {
Self {
model_sequence: vec![fast_model, accurate_model],
confidence_thresholds: vec![threshold],
enable_early_exit: true,
max_models: 2,
}
}
pub fn three_tier(
fast: String,
medium: String,
accurate: String,
thresholds: (f64, f64),
) -> Self {
Self {
model_sequence: vec![fast, medium, accurate],
confidence_thresholds: vec![thresholds.0, thresholds.1],
enable_early_exit: true,
max_models: 3,
}
}
}
/// Registry and router for multiple compiled model graphs.
///
/// Owns the graphs and their metadata, an optional ensemble configuration,
/// the active routing strategy, and running usage statistics.
pub struct MultiModelCoordinator {
    // Graphs keyed by model id.
    models: HashMap<String, EinsumGraph>,
    // Metadata keyed by the same model id.
    metadata: HashMap<String, ModelMetadata>,
    // Installed ensemble configuration, if any.
    ensemble_config: Option<EnsembleConfig>,
    // Strategy consulted by `select_model`.
    routing_strategy: RoutingStrategy,
    // Counters updated as models register and requests are routed.
    stats: CoordinationStats,
}
impl MultiModelCoordinator {
    /// Creates an empty coordinator: no models, no ensemble configuration,
    /// priority routing, zeroed statistics.
    pub fn new() -> Self {
        Self {
            models: HashMap::default(),
            metadata: HashMap::default(),
            ensemble_config: None,
            routing_strategy: RoutingStrategy::Priority,
            stats: CoordinationStats::default(),
        }
    }
pub fn register_model(
&mut self,
graph: EinsumGraph,
metadata: ModelMetadata,
) -> Result<(), MultiModelError> {
let id = metadata.id.clone();
self.models.insert(id.clone(), graph);
self.metadata.insert(id, metadata);
self.stats.total_models += 1;
Ok(())
}
pub fn unregister_model(&mut self, model_id: &str) -> Result<(), MultiModelError> {
self.models
.remove(model_id)
.ok_or_else(|| MultiModelError::ModelNotFound(model_id.to_string()))?;
self.metadata.remove(model_id);
self.stats.total_models = self.models.len();
Ok(())
}
    /// Installs (or replaces) the ensemble configuration.
    pub fn set_ensemble_config(&mut self, config: EnsembleConfig) {
        self.ensemble_config = Some(config);
    }
    /// Sets the strategy consulted by `select_model`.
    pub fn set_routing_strategy(&mut self, strategy: RoutingStrategy) {
        self.routing_strategy = strategy;
    }
pub fn select_model(
&mut self,
_input_features: Option<&[f64]>,
) -> Result<String, MultiModelError> {
if self.models.is_empty() {
return Err(MultiModelError::RoutingFailed(
"No models registered".to_string(),
));
}
let selected = match self.routing_strategy {
RoutingStrategy::Priority => self.select_by_priority(),
RoutingStrategy::LowestLatency => self.select_by_latency(),
RoutingStrategy::BestAccuracy => self.select_by_accuracy(),
RoutingStrategy::RoundRobin => self.select_round_robin(),
RoutingStrategy::Random => self.select_random(),
RoutingStrategy::Cascade => self.select_cascade(),
RoutingStrategy::ContentBased => self.select_content_based(_input_features),
};
if let Ok(ref id) = selected {
self.stats.total_routings += 1;
*self.stats.model_usage.entry(id.clone()).or_insert(0) += 1;
}
selected
}
fn select_by_priority(&self) -> Result<String, MultiModelError> {
self.metadata
.iter()
.max_by_key(|(_, meta)| meta.priority)
.map(|(id, _)| id.clone())
.ok_or_else(|| MultiModelError::RoutingFailed("No models available".to_string()))
}
fn select_by_latency(&self) -> Result<String, MultiModelError> {
self.metadata
.iter()
.min_by(|(_, a), (_, b)| {
a.resource_requirements
.estimated_latency_ms
.partial_cmp(&b.resource_requirements.estimated_latency_ms)
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(id, _)| id.clone())
.ok_or_else(|| MultiModelError::RoutingFailed("No models available".to_string()))
}
    /// Accuracy-based selection.
    ///
    /// Placeholder: no accuracy data is tracked yet, so this delegates to
    /// priority selection.
    fn select_by_accuracy(&self) -> Result<String, MultiModelError> {
        self.select_by_priority()
    }
fn select_round_robin(&mut self) -> Result<String, MultiModelError> {
let model_ids: Vec<_> = self.models.keys().cloned().collect();
if model_ids.is_empty() {
return Err(MultiModelError::RoutingFailed(
"No models available".to_string(),
));
}
let idx = self.stats.total_routings % model_ids.len();
Ok(model_ids[idx].clone())
}
fn select_random(&self) -> Result<String, MultiModelError> {
let model_ids: Vec<_> = self.models.keys().cloned().collect();
if model_ids.is_empty() {
return Err(MultiModelError::RoutingFailed(
"No models available".to_string(),
));
}
Ok(model_ids[0].clone())
}
    /// Cascade entry selection.
    ///
    /// Placeholder: picks the lowest-latency model (matching the
    /// cheapest-model-first intent of a cascade); full cascade execution is
    /// not implemented in this module.
    fn select_cascade(&self) -> Result<String, MultiModelError> {
        self.select_by_latency()
    }
    /// Content-based selection.
    ///
    /// Placeholder: input features are currently ignored and priority
    /// selection is used instead.
    fn select_content_based(&self, _features: Option<&[f64]>) -> Result<String, MultiModelError> {
        self.select_by_priority()
    }
pub fn get_model(&self, model_id: &str) -> Option<&EinsumGraph> {
self.models.get(model_id)
}
pub fn get_metadata(&self, model_id: &str) -> Option<&ModelMetadata> {
self.metadata.get(model_id)
}
pub fn model_ids(&self) -> Vec<String> {
self.models.keys().cloned().collect()
}
pub fn stats(&self) -> &CoordinationStats {
&self.stats
}
pub fn has_ensemble(&self) -> bool {
self.ensemble_config.is_some()
}
pub fn ensemble_config(&self) -> Option<&EnsembleConfig> {
self.ensemble_config.as_ref()
}
pub fn total_resource_requirements(&self) -> ResourceRequirements {
let mut total = ResourceRequirements {
memory_bytes: 0,
gpu_memory_bytes: Some(0),
estimated_flops: 0.0,
estimated_latency_ms: 0.0,
};
for metadata in self.metadata.values() {
let req = &metadata.resource_requirements;
total.memory_bytes += req.memory_bytes;
if let (Some(total_gpu), Some(req_gpu)) = (total.gpu_memory_bytes, req.gpu_memory_bytes)
{
total.gpu_memory_bytes = Some(total_gpu + req_gpu);
}
total.estimated_flops += req.estimated_flops;
total.estimated_latency_ms = total.estimated_latency_ms.max(req.estimated_latency_ms);
}
total
}
}
impl Default for MultiModelCoordinator {
fn default() -> Self {
Self::new()
}
}
/// Running counters describing coordinator activity.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CoordinationStats {
    /// Number of currently registered models.
    pub total_models: usize,
    /// Number of successful routing decisions.
    pub total_routings: usize,
    /// Number of ensemble executions — NOTE(review): never incremented in
    /// this module; presumably updated by an executor elsewhere.
    pub total_ensemble_executions: usize,
    /// Per-model selection counts, keyed by model id.
    pub model_usage: HashMap<String, usize>,
}
impl CoordinationStats {
    /// The model routed to most often, with its count; `None` when nothing
    /// has been recorded yet. Ties resolve to the last entry encountered in
    /// map iteration order, matching `max_by_key`.
    pub fn most_used_model(&self) -> Option<(String, usize)> {
        let mut best: Option<(String, usize)> = None;
        for (id, &count) in &self.model_usage {
            let replace = match best {
                Some((_, best_count)) => count >= best_count,
                None => true,
            };
            if replace {
                best = Some((id.clone(), count));
            }
        }
        best
    }
    /// Fraction of total routings each model received; empty when no usage
    /// has been recorded.
    pub fn usage_distribution(&self) -> HashMap<String, f64> {
        let total: usize = self.model_usage.values().sum();
        if total == 0 {
            HashMap::new()
        } else {
            let denom = total as f64;
            self.model_usage
                .iter()
                .map(|(id, &count)| (id.clone(), count as f64 / denom))
                .collect()
        }
    }
}
/// Executes a set of models as an ensemble and combines their outputs.
///
/// Implementors supply the tensor runtime; this module only carries the
/// strategy configuration.
pub trait TlEnsembleExecutor {
    /// Prediction/tensor type produced by the runtime.
    type Output;
    /// Error type raised by the runtime.
    type Error;
    /// Runs the given models on `inputs` and combines the results according
    /// to `strategy`.
    fn execute_ensemble(
        &self,
        models: &[&EinsumGraph],
        inputs: &[Self::Output],
        strategy: EnsembleStrategy,
    ) -> Result<Self::Output, Self::Error>;
    /// Combines already-computed predictions according to `strategy`.
    fn aggregate_predictions(
        &self,
        predictions: &[Self::Output],
        strategy: EnsembleStrategy,
    ) -> Result<Self::Output, Self::Error>;
}
/// Routes raw input feature vectors to a concrete model id.
pub trait TlModelRouter {
    /// Chooses a model id for the given input.
    fn route_to_model(&self, input: &[f64]) -> Result<String, MultiModelError>;
    /// Confidence that `model_id` is the right choice for `input` —
    /// NOTE(review): the value range (e.g. 0.0–1.0) is implementor-defined.
    fn routing_confidence(&self, input: &[f64], model_id: &str) -> f64;
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{EinsumNode, OpType};

    // Builds a minimal one-node identity-einsum graph for registry tests.
    fn create_test_graph(_id: &str) -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        graph.nodes.push(EinsumNode {
            op: OpType::Einsum {
                spec: "ij->ij".to_string(),
            },
            inputs: vec![],
            outputs: vec![0],
            metadata: Default::default(),
        });
        graph
    }

    // Builds metadata with the given id/priority and fixed resource needs
    // (1 MiB host, 512 KiB GPU, 1 GFLOP, 10 ms).
    fn create_test_metadata(id: &str, priority: u32) -> ModelMetadata {
        ModelMetadata {
            id: id.to_string(),
            name: format!("Model {}", id),
            version: "1.0".to_string(),
            input_shapes: vec![vec![10, 10]],
            output_shapes: vec![vec![10, 10]],
            weight: 1.0,
            priority,
            resource_requirements: ResourceRequirements {
                memory_bytes: 1024 * 1024,
                gpu_memory_bytes: Some(512 * 1024),
                estimated_flops: 1e9,
                estimated_latency_ms: 10.0,
            },
        }
    }

    // Preset constructors install the matching ensemble strategy.
    #[test]
    fn test_ensemble_strategy() {
        let config = EnsembleConfig::voting();
        assert_eq!(config.strategy, EnsembleStrategy::MajorityVote);
        let mut weights = HashMap::new();
        weights.insert("model1".to_string(), 0.6);
        weights.insert("model2".to_string(), 0.4);
        let config = EnsembleConfig::weighted_average(weights);
        assert_eq!(config.strategy, EnsembleStrategy::WeightedAverage);
    }

    // Cascade builders wire up sequences and per-stage thresholds.
    #[test]
    fn test_cascade_config() {
        let config = CascadeConfig::two_tier("fast".to_string(), "accurate".to_string(), 0.9);
        assert_eq!(config.model_sequence.len(), 2);
        assert_eq!(config.confidence_thresholds[0], 0.9);
        let config = CascadeConfig::three_tier(
            "fast".to_string(),
            "medium".to_string(),
            "accurate".to_string(),
            (0.8, 0.95),
        );
        assert_eq!(config.model_sequence.len(), 3);
    }

    // A fresh coordinator starts empty.
    #[test]
    fn test_coordinator_creation() {
        let coordinator = MultiModelCoordinator::new();
        assert_eq!(coordinator.models.len(), 0);
        assert_eq!(coordinator.stats.total_models, 0);
    }

    // Registration stores the graph and updates the model count.
    #[test]
    fn test_model_registration() {
        let mut coordinator = MultiModelCoordinator::new();
        let graph = create_test_graph("model1");
        let metadata = create_test_metadata("model1", 1);
        assert!(coordinator.register_model(graph, metadata).is_ok());
        assert_eq!(coordinator.stats.total_models, 1);
        assert!(coordinator.get_model("model1").is_some());
    }

    // Unregistration removes the graph and restores the count.
    #[test]
    fn test_model_unregistration() {
        let mut coordinator = MultiModelCoordinator::new();
        let graph = create_test_graph("model1");
        let metadata = create_test_metadata("model1", 1);
        coordinator.register_model(graph, metadata).expect("unwrap");
        assert!(coordinator.unregister_model("model1").is_ok());
        assert_eq!(coordinator.stats.total_models, 0);
        assert!(coordinator.get_model("model1").is_none());
    }

    // Priority routing picks the higher-priority model.
    #[test]
    fn test_routing_by_priority() {
        let mut coordinator = MultiModelCoordinator::new();
        coordinator
            .register_model(
                create_test_graph("model1"),
                create_test_metadata("model1", 1),
            )
            .expect("unwrap");
        coordinator
            .register_model(
                create_test_graph("model2"),
                create_test_metadata("model2", 5),
            )
            .expect("unwrap");
        coordinator.set_routing_strategy(RoutingStrategy::Priority);
        let selected = coordinator.select_model(None).expect("unwrap");
        assert_eq!(selected, "model2");
    }

    // Latency routing picks the lower-latency model.
    #[test]
    fn test_routing_by_latency() {
        let mut coordinator = MultiModelCoordinator::new();
        let mut meta1 = create_test_metadata("model1", 1);
        meta1.resource_requirements.estimated_latency_ms = 20.0;
        let mut meta2 = create_test_metadata("model2", 1);
        meta2.resource_requirements.estimated_latency_ms = 5.0;
        coordinator
            .register_model(create_test_graph("model1"), meta1)
            .expect("unwrap");
        coordinator
            .register_model(create_test_graph("model2"), meta2)
            .expect("unwrap");
        coordinator.set_routing_strategy(RoutingStrategy::LowestLatency);
        let selected = coordinator.select_model(None).expect("unwrap");
        assert_eq!(selected, "model2");
    }

    // Installing an ensemble config is observable via the accessors.
    #[test]
    fn test_ensemble_configuration() {
        let mut coordinator = MultiModelCoordinator::new();
        assert!(!coordinator.has_ensemble());
        coordinator.set_ensemble_config(EnsembleConfig::voting());
        assert!(coordinator.has_ensemble());
        assert_eq!(
            coordinator.ensemble_config().expect("unwrap").strategy,
            EnsembleStrategy::MajorityVote
        );
    }

    // Memory and GPU memory sum across registered models.
    #[test]
    fn test_total_resource_requirements() {
        let mut coordinator = MultiModelCoordinator::new();
        coordinator
            .register_model(
                create_test_graph("model1"),
                create_test_metadata("model1", 1),
            )
            .expect("unwrap");
        coordinator
            .register_model(
                create_test_graph("model2"),
                create_test_metadata("model2", 1),
            )
            .expect("unwrap");
        let total = coordinator.total_resource_requirements();
        assert_eq!(total.memory_bytes, 2 * 1024 * 1024);
        assert_eq!(total.gpu_memory_bytes, Some(2 * 512 * 1024));
    }

    // most_used_model and usage_distribution reflect recorded counts.
    #[test]
    fn test_coordination_stats() {
        let mut stats = CoordinationStats::default();
        stats.model_usage.insert("model1".to_string(), 10);
        stats.model_usage.insert("model2".to_string(), 5);
        let (id, count) = stats.most_used_model().expect("unwrap");
        assert_eq!(id, "model1");
        assert_eq!(count, 10);
        let dist = stats.usage_distribution();
        assert_eq!(dist.get("model1").expect("unwrap"), &(10.0 / 15.0));
    }

    // Consecutive round-robin selections hit different models.
    #[test]
    fn test_round_robin_routing() {
        let mut coordinator = MultiModelCoordinator::new();
        coordinator
            .register_model(
                create_test_graph("model1"),
                create_test_metadata("model1", 1),
            )
            .expect("unwrap");
        coordinator
            .register_model(
                create_test_graph("model2"),
                create_test_metadata("model2", 1),
            )
            .expect("unwrap");
        coordinator.set_routing_strategy(RoutingStrategy::RoundRobin);
        let id1 = coordinator.select_model(None).expect("unwrap");
        let id2 = coordinator.select_model(None).expect("unwrap");
        assert_ne!(id1, id2);
    }

    // model_ids lists every registered id.
    #[test]
    fn test_model_ids() {
        let mut coordinator = MultiModelCoordinator::new();
        coordinator
            .register_model(
                create_test_graph("model1"),
                create_test_metadata("model1", 1),
            )
            .expect("unwrap");
        coordinator
            .register_model(
                create_test_graph("model2"),
                create_test_metadata("model2", 1),
            )
            .expect("unwrap");
        let ids = coordinator.model_ids();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&"model1".to_string()));
        assert!(ids.contains(&"model2".to_string()));
    }
}