use crate::{AiClient, ClientError, ClientMetrics};
use futures::future::join_all;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
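/// Routes a prompt across multiple `AiClient` implementations and fuses
/// their responses according to an `OrchestrationStrategy`.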
pub struct AiOrchestrator {
clients: Vec<Arc<Box<dyn AiClient>>>,
capabilities: HashMap<String, ModelCapabilities>,
strategy: OrchestrationStrategy,
metrics: ClientMetrics,
cache: ResponseCache,
}
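/// Static profile of a model: strengths, typical latency, pricing, context
/// window, and feature support.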
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCapabilities {
pub name: String,
pub strengths: Vec<Strength>,
pub avg_latency_ms: u64,
pub cost_per_1k_tokens: f32,
pub max_context_length: usize,
pub supports_streaming: bool,
pub supports_vision: bool,
pub supports_function_calling: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Strength {
Reasoning,
Creativity,
CodeGeneration,
Mathematics,
Language,
Analysis,
Vision,
Speed,
}
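/// How a prompt is distributed across clients: all at once (`Parallel`),
/// one after another (`Sequential`), routed to the best-suited model
/// (`Specialized`), cross-checked across models (`Consensus`), blended by
/// weight (`WeightedFusion`), winner-takes-all (`Tournament`), or chosen
/// per task type (`Adaptive`).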
#[derive(Debug, Clone)]
pub enum OrchestrationStrategy {
Parallel,
Sequential,
Specialized,
Consensus,
WeightedFusion,
Tournament,
Adaptive,
}
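/// The merged result of an orchestration run: fused text, overall
/// confidence, per-model contributions, consensus analysis, and run metrics.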
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusedResponse {
pub content: String,
pub confidence: f64,
pub contributions: Vec<ModelContribution>,
pub consensus: ConsensusAnalysis,
pub metrics: OrchestrationMetrics,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelContribution {
pub model: String,
pub response: String,
pub confidence: f64,
pub weight: f64,
pub latency_ms: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsensusAnalysis {
pub agreement_score: f64,
pub key_points: Vec<String>,
pub disagreements: Vec<String>,
pub fact_verification: Vec<FactCheck>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactCheck {
pub statement: String,
pub models_agreeing: Vec<String>,
pub confidence: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrchestrationMetrics {
pub total_latency_ms: u64,
pub models_used: usize,
pub cache_hit: bool,
pub tokens_saved: u32,
pub cost_estimate: f32,
}
impl AiOrchestrator {
pub fn new(clients: Vec<Box<dyn AiClient>>) -> Self {
        let clients = clients.into_iter().map(Arc::new).collect();
Self {
clients,
capabilities: Self::detect_capabilities(),
strategy: OrchestrationStrategy::Adaptive,
metrics: ClientMetrics::new(),
cache: ResponseCache::new(1000),
}
}
pub fn with_strategy(mut self, strategy: OrchestrationStrategy) -> Self {
self.strategy = strategy;
self
}
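    /// Sends `prompt` through the selected strategy and returns a fused
    /// answer, consulting the response cache first.
    ///
    /// Illustrative usage; `OpenAiClient` and `AnthropicClient` stand in for
    /// whatever `AiClient` implementations this crate actually provides:
    ///
    /// ```ignore
    /// let orchestrator = AiOrchestrator::new(vec![
    ///     Box::new(OpenAiClient::default()) as Box<dyn AiClient>,
    ///     Box::new(AnthropicClient::default()),
    /// ])
    /// .with_strategy(OrchestrationStrategy::Consensus);
    ///
    /// let fused = orchestrator.query("Summarize the CAP theorem").await?;
    /// println!("{} (confidence {:.2})", fused.content, fused.confidence);
    /// ```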
pub async fn query(&self, prompt: &str) -> Result<FusedResponse, ClientError> {
let start = std::time::Instant::now();
        if let Some(mut cached) = self.cache.get(prompt).await {
            cached.metrics.cache_hit = true;
            return Ok(cached);
        }
let task_type = self.analyze_prompt(prompt);
let selected_strategy = self.select_strategy(&task_type);
let response = match selected_strategy {
OrchestrationStrategy::Parallel => {
self.execute_parallel(prompt).await?
}
OrchestrationStrategy::Sequential => {
self.execute_sequential(prompt).await?
}
OrchestrationStrategy::Specialized => {
self.execute_specialized(prompt, &task_type).await?
}
OrchestrationStrategy::Consensus => {
self.execute_consensus(prompt).await?
}
OrchestrationStrategy::WeightedFusion => {
self.execute_weighted_fusion(prompt).await?
}
OrchestrationStrategy::Tournament => {
self.execute_tournament(prompt).await?
}
OrchestrationStrategy::Adaptive => {
self.execute_adaptive(prompt, &task_type).await?
}
};
let latency = start.elapsed().as_millis() as u64;
self.metrics.record_request(true, latency, Some(response.metrics.tokens_saved));
self.cache.set(prompt, response.clone()).await;
Ok(response)
}
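    /// Fans the prompt out to every client at once via `gather_responses` and
    /// fuses whatever comes back; total latency is bounded by the slowest
    /// client rather than the sum of all of them.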
    async fn execute_parallel(&self, prompt: &str) -> Result<FusedResponse, ClientError> {
        let results = self.gather_responses(prompt).await;
        self.fuse_responses(results)
    }
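    /// Gathers all responses, assigns each a weight from its heuristic
    /// confidence, latency, and a per-model prior, then merges them; see
    /// `weighted_merge` for what "merge" currently means.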
async fn execute_weighted_fusion(&self, prompt: &str) -> Result<FusedResponse, ClientError> {
let results = self.gather_responses(prompt).await;
let mut contributions = Vec::new();
for (model, response, latency) in &results {
if let Ok(content) = response {
let confidence = self.calculate_confidence(content, prompt);
let weight = self.calculate_weight(model, confidence, *latency);
contributions.push(ModelContribution {
model: model.to_string(),
response: content.clone(),
confidence,
weight,
latency_ms: *latency,
});
}
        }
        if contributions.is_empty() {
            return Err(ClientError::config("No successful responses to fuse", None));
        }
let fused_content = self.weighted_merge(&contributions);
let consensus = self.analyze_consensus(&contributions);
let total_confidence = self.calculate_total_confidence(&contributions);
Ok(FusedResponse {
content: fused_content,
confidence: total_confidence,
contributions,
consensus,
metrics: OrchestrationMetrics {
total_latency_ms: results.iter().map(|(_, _, l)| l).max().copied().unwrap_or(0),
models_used: results.len(),
cache_hit: false,
tokens_saved: 0,
cost_estimate: self.estimate_cost(&results),
},
})
}
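    /// Scores every response with `score_response` and returns the single
    /// best one; losing responses are kept as zero-weight contributions so
    /// callers can still inspect them.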
async fn execute_tournament(&self, prompt: &str) -> Result<FusedResponse, ClientError> {
let results = self.gather_responses(prompt).await;
let mut scored_responses = Vec::new();
for (model, response, latency) in &results {
if let Ok(content) = response {
let score = self.score_response(content, prompt);
scored_responses.push((model.clone(), content.clone(), score, *latency));
}
}
        // `total_cmp` gives a total order over f64, so a NaN score cannot
        // panic the way `partial_cmp(...).unwrap()` would.
        scored_responses.sort_by(|a, b| b.2.total_cmp(&a.2));
if let Some((winner_model, winner_content, winner_score, _winner_latency)) = scored_responses.first() {
Ok(FusedResponse {
content: winner_content.clone(),
confidence: winner_score / 100.0,
contributions: scored_responses.iter().map(|(model, content, score, latency)| {
ModelContribution {
model: model.clone(),
response: content.clone(),
confidence: score / 100.0,
weight: if model == winner_model { 1.0 } else { 0.0 },
latency_ms: *latency,
}
}).collect(),
consensus: ConsensusAnalysis {
agreement_score: 0.0,
key_points: vec![format!("Winner: {}", winner_model)],
disagreements: vec![],
fact_verification: vec![],
},
metrics: OrchestrationMetrics {
total_latency_ms: results.iter().map(|(_, _, l)| l).max().copied().unwrap_or(0),
models_used: results.len(),
cache_hit: false,
tokens_saved: 0,
cost_estimate: self.estimate_cost(&results),
},
})
} else {
Err(ClientError::config("No valid responses in tournament", None))
}
}
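    // Naive keyword routing. A real classifier (or embedding similarity)
    // would be more robust; substring matching keeps this dependency-free.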
fn analyze_prompt(&self, prompt: &str) -> TaskType {
let prompt_lower = prompt.to_lowercase();
if prompt_lower.contains("code") || prompt_lower.contains("function") || prompt_lower.contains("implement") {
TaskType::Code
} else if prompt_lower.contains("creative") || prompt_lower.contains("story") || prompt_lower.contains("poem") {
TaskType::Creative
} else if prompt_lower.contains("analyze") || prompt_lower.contains("explain") {
TaskType::Analysis
} else if prompt_lower.contains("math") || prompt_lower.contains("calculate") {
TaskType::Mathematics
} else {
TaskType::General
}
}
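    // Heuristic confidence: start at 0.5 and add small bonuses for plausible
    // length, terminal punctuation, code fences when code was asked for, and
    // visible structure. A cheap proxy, not a calibrated probability.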
fn calculate_confidence(&self, response: &str, prompt: &str) -> f64 {
let mut confidence: f64 = 0.5;
let expected_length = prompt.len() * 10;
let actual_length = response.len();
if actual_length > expected_length / 2 && actual_length < expected_length * 3 {
confidence += 0.1;
}
if response.ends_with('.') || response.ends_with('!') || response.ends_with('?') {
confidence += 0.1;
}
if prompt.contains("code") && response.contains("```") {
confidence += 0.2;
}
if response.contains('\n') && response.contains(':') {
confidence += 0.1;
}
confidence.min(1.0)
}
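    // weight = confidence * 1 / (1 + latency_s) * model_prior, capped at 1.0.
    // The per-model priors below are rough, hand-tuned values.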
fn calculate_weight(&self, model: &str, confidence: f64, latency: u64) -> f64 {
let base_weight = confidence;
let latency_factor = 1.0 / (1.0 + (latency as f64 / 1000.0));
let capability_factor = match model {
"gpt-4" => 1.2,
"claude-3-opus" => 1.15,
"gemini-1.5-pro" => 1.1,
_ => 1.0,
};
(base_weight * latency_factor * capability_factor).min(1.0)
}
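    /// Sends the prompt to every client concurrently and collects
    /// `(model name, result, latency in ms)` triples in client order.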
async fn gather_responses(&self, prompt: &str) -> Vec<(String, Result<String, ClientError>, u64)> {
let futures = self.clients.iter().map(|client| {
let client = client.clone();
let prompt = prompt.to_string();
async move {
let start = std::time::Instant::now();
let result = client.send_prompt(&prompt).await;
let latency = start.elapsed().as_millis() as u64;
(client.name().to_string(), result, latency)
}
});
join_all(futures).await
}
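    // Despite the name, this currently selects the single highest-weight
    // response; a genuine fusion would blend or re-synthesize the texts.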
fn weighted_merge(&self, contributions: &[ModelContribution]) -> String {
contributions
.iter()
.max_by(|a, b| a.weight.partial_cmp(&b.weight).unwrap())
.map(|c| c.response.clone())
.unwrap_or_default()
}
fn analyze_consensus(&self, contributions: &[ModelContribution]) -> ConsensusAnalysis {
        // Guard the empty case; 0.0 / 0.0 would yield a NaN agreement score.
        let avg_confidence = if contributions.is_empty() {
            0.0
        } else {
            contributions.iter().map(|c| c.confidence).sum::<f64>() / contributions.len() as f64
        };
ConsensusAnalysis {
agreement_score: avg_confidence,
key_points: vec!["Models processed query successfully".to_string()],
disagreements: vec![],
fact_verification: vec![],
}
}
fn calculate_total_confidence(&self, contributions: &[ModelContribution]) -> f64 {
let weighted_sum: f64 = contributions.iter().map(|c| c.confidence * c.weight).sum();
let weight_sum: f64 = contributions.iter().map(|c| c.weight).sum();
if weight_sum > 0.0 {
weighted_sum / weight_sum
} else {
0.0
}
}
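    // Tournament score out of 100: fixed bonuses for length, prompt-term
    // overlap, multi-line structure, and terminal punctuation.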
fn score_response(&self, response: &str, prompt: &str) -> f64 {
let mut score = 50.0;
if response.len() > 50 {
score += 10.0;
}
let prompt_words: Vec<&str> = prompt.split_whitespace().collect();
let response_words: Vec<&str> = response.split_whitespace().collect();
let matching_words = prompt_words.iter()
.filter(|w| response_words.contains(w))
.count();
        if !prompt_words.is_empty() {
            score += (matching_words as f64 / prompt_words.len() as f64) * 20.0;
        }
if response.contains('\n') {
score += 10.0;
}
if response.ends_with('.') || response.ends_with('!') || response.ends_with('?') {
score += 10.0;
}
score.min(100.0)
}
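    // Uses the rough "~4 characters per token" rule of thumb and
    // hand-maintained per-1k-token rates; treat the result as an
    // order-of-magnitude estimate, not billing data.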
fn estimate_cost(&self, results: &[(String, Result<String, ClientError>, u64)]) -> f32 {
let mut total_cost = 0.0;
for (model, result, _) in results {
if let Ok(response) = result {
                let tokens = (response.len() / 4) as f32;
                let rate = match model.as_str() {
"gpt-4" => 0.03,
"claude-3-opus" => 0.025,
"gemini-1.5-pro" => 0.02,
_ => 0.01,
};
total_cost += (tokens / 1000.0) * rate;
}
}
total_cost
}
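    // Static capability table rather than runtime detection; the latency,
    // cost, and context figures are rough estimates that will drift as
    // providers update their models.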
fn detect_capabilities() -> HashMap<String, ModelCapabilities> {
let mut caps = HashMap::new();
caps.insert("gpt-4".to_string(), ModelCapabilities {
name: "GPT-4".to_string(),
strengths: vec![Strength::Reasoning, Strength::CodeGeneration, Strength::Analysis],
avg_latency_ms: 2000,
cost_per_1k_tokens: 0.03,
max_context_length: 128000,
supports_streaming: true,
supports_vision: true,
supports_function_calling: true,
});
caps.insert("claude-3-opus".to_string(), ModelCapabilities {
name: "Claude 3 Opus".to_string(),
strengths: vec![Strength::Creativity, Strength::Language, Strength::Analysis],
avg_latency_ms: 2500,
cost_per_1k_tokens: 0.025,
max_context_length: 200000,
supports_streaming: true,
supports_vision: true,
supports_function_calling: false,
});
caps.insert("gemini-1.5-pro".to_string(), ModelCapabilities {
name: "Gemini 1.5 Pro".to_string(),
strengths: vec![Strength::Speed, Strength::Mathematics, Strength::Vision],
avg_latency_ms: 1500,
cost_per_1k_tokens: 0.02,
max_context_length: 1000000,
supports_streaming: false,
supports_vision: true,
supports_function_calling: true,
});
caps
}
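    // The three executors below are placeholders: Sequential and Specialized
    // currently delegate to the parallel path, and Consensus reuses weighted
    // fusion until a real voting implementation lands.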
async fn execute_sequential(&self, prompt: &str) -> Result<FusedResponse, ClientError> {
self.execute_parallel(prompt).await
}
async fn execute_specialized(&self, prompt: &str, _task_type: &TaskType) -> Result<FusedResponse, ClientError> {
self.execute_parallel(prompt).await
}
async fn execute_consensus(&self, prompt: &str) -> Result<FusedResponse, ClientError> {
self.execute_weighted_fusion(prompt).await
}
async fn execute_adaptive(&self, prompt: &str, task_type: &TaskType) -> Result<FusedResponse, ClientError> {
match task_type {
TaskType::Code => self.execute_specialized(prompt, task_type).await,
TaskType::Creative => self.execute_tournament(prompt).await,
_ => self.execute_weighted_fusion(prompt).await,
}
}
    fn select_strategy(&self, task_type: &TaskType) -> OrchestrationStrategy {
        // Honor an explicitly configured strategy; only Adaptive routes by task.
        if !matches!(self.strategy, OrchestrationStrategy::Adaptive) {
            return self.strategy.clone();
        }
        match task_type {
            TaskType::Code => OrchestrationStrategy::Specialized,
            TaskType::Creative => OrchestrationStrategy::Tournament,
            TaskType::Analysis => OrchestrationStrategy::WeightedFusion,
            TaskType::Mathematics => OrchestrationStrategy::Consensus,
            TaskType::General => OrchestrationStrategy::Adaptive,
        }
    }
    fn fuse_responses(&self, results: Vec<(String, Result<String, ClientError>, u64)>) -> Result<FusedResponse, ClientError> {
        // Derive run metrics from the actual results before consuming them.
        let total_latency_ms = results.iter().map(|(_, _, l)| *l).max().unwrap_or(0);
        let models_used = results.len();
        let cost_estimate = self.estimate_cost(&results);
        let mut contributions = Vec::new();
        for (model, result, latency) in results {
            if let Ok(response) = result {
                contributions.push(ModelContribution {
                    model,
                    response,
                    confidence: 0.8,
                    weight: 0.0,
                    latency_ms: latency,
                });
            }
        }
        if contributions.is_empty() {
            return Err(ClientError::config("No successful responses", None));
        }
        // Share weight evenly across the responses that actually succeeded.
        let even_weight = 1.0 / contributions.len() as f64;
        for contribution in &mut contributions {
            contribution.weight = even_weight;
        }
        let content = contributions[0].response.clone();
        let consensus = self.analyze_consensus(&contributions);
        let confidence = self.calculate_total_confidence(&contributions);
        Ok(FusedResponse {
            content,
            confidence,
            contributions,
            consensus,
            metrics: OrchestrationMetrics {
                total_latency_ms,
                models_used,
                cache_hit: false,
                tokens_saved: 0,
                cost_estimate,
            },
        })
    }
}
#[derive(Debug, Clone)]
enum TaskType {
Code,
Creative,
Analysis,
Mathematics,
General,
}
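/// Prompt-keyed response cache backed by `moka`'s async cache, bounded to
/// `capacity` entries with a one-hour time-to-live.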
struct ResponseCache {
cache: moka::future::Cache<String, FusedResponse>,
}
impl ResponseCache {
fn new(capacity: u64) -> Self {
Self {
cache: moka::future::Cache::builder()
.max_capacity(capacity)
.time_to_live(std::time::Duration::from_secs(3600))
.build(),
}
}
async fn get(&self, key: &str) -> Option<FusedResponse> {
self.cache.get(key).await
}
async fn set(&self, key: &str, value: FusedResponse) {
self.cache.insert(key.to_string(), value).await;
}
}