use std::sync::Arc;
use std::time::Duration;
use tokio::time::timeout;
use crate::common::traits::AIProvider;
use crate::config::layered::{Config, ModelConfig};
use crate::errors::LocalModelError;
pub enum ExecutionStrategy {
LocalOnly {
model: Arc<dyn AIProvider>,
},
Hybrid {
local: Arc<dyn AIProvider>,
cloud: Arc<dyn AIProvider>,
threshold: Duration,
},
Dynamic {
fast_model: Arc<dyn AIProvider>,
quality_model: Arc<dyn AIProvider>,
cloud: Option<Arc<dyn AIProvider>>,
},
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum ExecutionMode {
Local,
Hybrid,
Dynamic,
CloudOnly,
}
impl ExecutionMode {
pub fn from_str(s: &str) -> Option<Self> {
match s {
"local" => Some(ExecutionMode::Local),
"hybrid" => Some(ExecutionMode::Hybrid),
"dynamic" => Some(ExecutionMode::Dynamic),
"cloud" => Some(ExecutionMode::CloudOnly),
_ => None,
}
}
}
impl ExecutionStrategy {
pub async fn create(config: &Config) -> Result<Self, LocalModelError> {
let create_provider = |model_name: &str,
model_config: &ModelConfig|
-> Result<Arc<dyn AIProvider>, LocalModelError> {
Ok(Arc::new(create_placeholder_model(model_name)?))
};
match &config.strategy.mode {
ExecutionMode::Local => {
let model = create_provider(&config.model.name, &config.model)?;
Ok(Self::LocalOnly { model })
}
ExecutionMode::Hybrid => {
let local = create_provider(&config.model.name, &config.model)?;
let cloud = create_placeholder_cloud_adapter();
Ok(Self::Hybrid {
local,
cloud: Arc::new(cloud),
threshold: Duration::from_millis(config.strategy.timeout_ms),
})
}
ExecutionMode::Dynamic => {
let fast_model_name = config
.strategy
.fast_model
.clone()
.unwrap_or_else(|| get_smaller_model(&config.model.name));
let quality_model_name = config.model.name.clone();
let fast_adapter = create_provider(&fast_model_name, &config.model)?;
let quality_adapter = create_provider(&quality_model_name, &config.model)?;
let cloud = if config.fallback.enabled {
Some(Arc::new(create_placeholder_cloud_adapter()) as Arc<dyn AIProvider>)
} else {
None
};
Ok(Self::Dynamic {
fast_model: fast_adapter,
quality_model: quality_adapter,
cloud,
})
}
ExecutionMode::CloudOnly => {
let cloud = create_placeholder_cloud_adapter();
Ok(Self::LocalOnly {
model: Arc::new(cloud),
})
}
}
}
pub async fn complete(&self, prompt: &str) -> Result<String, LocalModelError> {
match self {
ExecutionStrategy::LocalOnly { model } => model.complete(prompt).await,
ExecutionStrategy::Hybrid {
local,
cloud,
threshold,
} => match timeout(threshold.clone(), local.complete(prompt)).await {
Ok(Ok(response)) => Ok(response),
Ok(Err(e)) => {
tracing::warn!("Local model failed: {}, falling back to cloud", e);
cloud.complete(prompt).await
}
Err(_) => {
tracing::warn!("Local model timed out, falling back to cloud");
cloud.complete(prompt).await
}
},
ExecutionStrategy::Dynamic {
fast_model,
quality_model,
cloud,
} => {
let complexity_analysis =
crate::core::task_complexity::TaskComplexity::analyze(prompt);
match complexity_analysis {
crate::core::task_complexity::TaskComplexity::Simple => {
fast_model.complete(prompt).await
}
crate::core::task_complexity::TaskComplexity::Medium => {
quality_model.complete(prompt).await
}
crate::core::task_complexity::TaskComplexity::Complex => {
match quality_model.complete(prompt).await {
Ok(resp) => Ok(resp),
Err(e) if cloud.is_some() => {
tracing::warn!("Quality model failed: {}, using cloud", e);
cloud.as_ref().unwrap().complete(prompt).await
}
Err(e) => Err(e),
}
}
}
}
}
}
}
struct PlaceholderCloudAdapter;
#[async_trait::async_trait]
impl AIProvider for PlaceholderCloudAdapter {
async fn complete(&self, prompt: &str) -> Result<String, LocalModelError> {
Ok(format!("CLOUD_RESPONSE: {}", prompt)) }
async fn is_available(&self) -> bool {
true
}
async fn name(&self) -> String {
"PlaceholderCloud".to_string()
}
}
fn create_placeholder_cloud_adapter() -> PlaceholderCloudAdapter {
PlaceholderCloudAdapter
}
fn get_smaller_model(larger_model: &str) -> String {
match larger_model {
"qwen2.5-coder-14b-q4" => "qwen2.5-coder-3b-q4".to_string(),
"qwen2.5-coder-7b-q4" => "qwen2.5-coder-3b-q4".to_string(),
"llama3.1-70b-q4" => "qwen2.5-coder-7b-q4".to_string(),
_ => "qwen2.5-coder-1.5b-q4".to_string(), }
}
fn create_placeholder_model(_model_name: &str) -> Result<PlaceholderLocalAdapter, LocalModelError> {
Ok(PlaceholderLocalAdapter {})
}
struct PlaceholderLocalAdapter;
#[async_trait::async_trait]
impl AIProvider for PlaceholderLocalAdapter {
async fn complete(&self, prompt: &str) -> Result<String, LocalModelError> {
Ok(format!("LOCAL_RESPONSE: {}", prompt)) }
async fn is_available(&self) -> bool {
true }
async fn name(&self) -> String {
"PlaceholderLocal".to_string()
}
}