use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use crate::{
agent::{
context_engineering::ModelRequest,
middleware::{Middleware, MiddlewareContext, MiddlewareError},
},
language_models::llm::LLM,
};
/// Middleware that chooses a language model for each request through a
/// pluggable selector callback.
pub struct DynamicModelMiddleware {
/// Callback invoked with the outgoing request; returns the model to use, or
/// `None` when it has no choice to offer. Shared behind an `Arc` and required
/// to be `Send + Sync`.
model_selector: Arc<dyn Fn(&ModelRequest) -> Option<Arc<dyn LLM>> + Send + Sync>,
/// Models handed to the constructor, keyed by name.
/// NOTE(review): nothing visible in this file reads this field; the built-in
/// selectors capture their own clone of the map instead.
available_models: HashMap<String, Arc<dyn LLM>>,
}
impl DynamicModelMiddleware {
    /// Builds a middleware from a custom selector and a registry of models.
    ///
    /// `selector` receives each [`ModelRequest`] and returns the model to use, or
    /// `None` when it has no choice to offer. `models` is stored as the registry
    /// of available models, keyed by name.
    pub fn new<F>(selector: F, models: HashMap<String, Arc<dyn LLM>>) -> Self
    where
        F: Fn(&ModelRequest) -> Option<Arc<dyn LLM>> + Send + Sync + 'static,
    {
        Self {
            model_selector: Arc::new(selector),
            available_models: models,
        }
    }

    /// Selects a model by the number of messages in the request.
    ///
    /// Each `(threshold, model_name)` pair means "use `model_name` once the
    /// request holds at least `threshold` messages". The pair with the highest
    /// satisfied threshold whose model is present in `models` wins, regardless of
    /// the order in which the pairs were supplied. Pairs with equal thresholds
    /// keep their supplied order.
    ///
    /// When no pair applies, an arbitrary model from `models` is returned
    /// (`HashMap` iteration order is unspecified), or `None` if `models` is empty.
    pub fn from_message_count(
        models: HashMap<String, Arc<dyn LLM>>,
        thresholds: Vec<(usize, String)>,
    ) -> Self {
        // Sort once, highest threshold first, so the first satisfied entry is the
        // most specific one. Scanning in caller order made higher tiers
        // unreachable whenever the list was given in ascending order.
        // `sort_by` is stable, so ties keep the caller's relative order.
        let mut ordered = thresholds;
        ordered.sort_by(|a, b| b.0.cmp(&a.0));

        let models_clone = models.clone();
        Self::new(
            move |request: &ModelRequest| {
                let message_count = request.messages.len();
                ordered
                    .iter()
                    .filter(|(threshold, _)| message_count >= *threshold)
                    // A satisfied threshold naming an unregistered model is
                    // skipped in favour of the next lower one.
                    .find_map(|(_, model_name)| models_clone.get(model_name))
                    .or_else(|| models_clone.values().next())
                    .cloned()
            },
            models,
        )
    }

    /// Selects a model only for requests that carry a user id.
    ///
    /// Returns `None` when the request has no runtime or the runtime context has
    /// no user id. Otherwise returns an arbitrary model from `models`
    /// (`HashMap` iteration order is unspecified), or `None` if `models` is empty.
    ///
    /// NOTE(review): the user id's value is not consulted, only its presence, so
    /// no per-user preference is actually applied yet.
    pub fn from_user_preference(models: HashMap<String, Arc<dyn LLM>>) -> Self {
        let models_clone = models.clone();
        Self::new(
            move |request: &ModelRequest| {
                if let Some(runtime) = request.runtime() {
                    if runtime.context().user_id().is_some() {
                        return models_clone.values().next().cloned();
                    }
                }
                None
            },
            models,
        )
    }

    /// Selects a model from the `"cost_tier"` value in the runtime context.
    ///
    /// Tier mapping: `"premium"` -> `"premium_model"`, `"budget"` ->
    /// `"budget_model"`, anything else -> `"standard_model"`.
    ///
    /// * Tier present: returns the mapped model, or `None` if that model is not
    ///   registered (there is no fallback on this path).
    /// * No runtime or no tier: returns `"standard_model"` if registered,
    ///   otherwise an arbitrary model from `models`, otherwise `None`.
    pub fn from_cost_tier(models: HashMap<String, Arc<dyn LLM>>) -> Self {
        let models_clone = models.clone();
        Self::new(
            move |request: &ModelRequest| {
                if let Some(runtime) = request.runtime() {
                    if let Some(cost_tier) = runtime.context().get("cost_tier") {
                        let model_name = match cost_tier {
                            "premium" => "premium_model",
                            "budget" => "budget_model",
                            _ => "standard_model",
                        };
                        return models_clone.get(model_name).cloned();
                    }
                }
                models_clone
                    .get("standard_model")
                    .or_else(|| models_clone.values().next())
                    .cloned()
            },
            models,
        )
    }
}
#[async_trait]
impl Middleware for DynamicModelMiddleware {
    /// Runs the model selector against `request` before the model is called.
    ///
    /// Always returns `Ok(None)`, i.e. no replacement request is produced.
    ///
    /// NOTE(review): the selector's result is currently discarded, so the
    /// selected model is never applied to the request. Wiring it in needs a
    /// way to attach a model to `ModelRequest`, which is not visible here.
    async fn before_model_call(
        &self,
        request: &ModelRequest,
        _context: &mut MiddlewareContext,
    ) -> Result<Option<ModelRequest>, MiddlewareError> {
        // Evaluate the selector for this request; the choice is held until the
        // end of the call and then dropped unused.
        let _chosen = (*self.model_selector)(request);
        Ok(None)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentState;
    use crate::schemas::Message;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// With an empty model registry the message-count selector has nothing to
    /// return, and the middleware hook leaves the request untouched.
    #[tokio::test]
    async fn test_dynamic_model_from_message_count() {
        let models: HashMap<String, Arc<dyn LLM>> = HashMap::new();
        let thresholds = vec![
            (10, "standard_model".to_string()),
            (20, "large_model".to_string()),
        ];
        let middleware = DynamicModelMiddleware::from_message_count(models, thresholds);

        let state = Arc::new(Mutex::new(AgentState::new()));
        let messages = vec![Message::new_human_message("Hello"); 15];
        let request = ModelRequest::new(messages, vec![], state);

        // 15 messages satisfies the first threshold, but no model is registered
        // under any name, so neither the threshold lookup nor the fallback can
        // produce a model.
        assert!((middleware.model_selector)(&request).is_none());

        let mut middleware_context = MiddlewareContext::new();
        let result = middleware
            .before_model_call(&request, &mut middleware_context)
            .await;

        // The hook must succeed and must not substitute a different request.
        match result {
            Ok(None) => {}
            Ok(Some(_)) => panic!("before_model_call unexpectedly replaced the request"),
            Err(_) => panic!("before_model_call returned an error"),
        }
    }
}