use std::sync::Arc;
use async_trait::async_trait;
use cognis_core::error::Result;
use cognis_core::language_models::chat_model::BaseChatModel;
use super::types::{AgentMiddleware, AsyncModelHandler, ModelCallResult, ModelRequest};
/// Agent middleware that retries a failed model call against an ordered list
/// of fallback chat models.
///
/// When the wrapped model call returns an error, each entry of
/// `fallback_models` is tried in turn. The first successful response is
/// returned.
pub struct ModelFallbackMiddleware {
/// Models to try, in order, after the primary call fails. For each attempt
/// the request's `model` is replaced with the fallback and the rest of the
/// request is copied from the original.
pub fallback_models: Vec<Arc<dyn BaseChatModel>>,
}
impl ModelFallbackMiddleware {
pub fn new(fallback_models: Vec<Arc<dyn BaseChatModel>>) -> Self {
Self { fallback_models }
}
}
/// Falls back to alternate chat models when the wrapped model call fails.
#[async_trait]
impl AgentMiddleware for ModelFallbackMiddleware {
    fn name(&self) -> &str {
        "ModelFallbackMiddleware"
    }

    /// Calls `handler` with `request`. On failure, retries with each entry of
    /// `fallback_models` in order and returns the first successful response.
    ///
    /// Every retry reuses the original request's messages, system message,
    /// tool choice, tools, response format, state and model settings. Only
    /// `model` is replaced.
    ///
    /// # Errors
    ///
    /// If every attempt fails, returns the error from the last attempt:
    /// - the primary error when no fallbacks are configured;
    /// - otherwise the error from the last fallback.
    ///
    /// Errors from earlier attempts, including the primary one, are dropped.
    async fn wrap_model_call(
        &self,
        request: &ModelRequest,
        handler: &AsyncModelHandler,
    ) -> Result<ModelCallResult> {
        let primary_error = match handler(request).await {
            Ok(response) => return Ok(ModelCallResult::Response(response)),
            Err(e) => e,
        };

        // No fallbacks configured: surface the primary failure without
        // cloning any part of the request.
        let Some(first_fallback) = self.fallback_models.first() else {
            return Err(primary_error);
        };

        // Clone the request payload once. Only `model` differs between
        // attempts, so re-cloning messages/tools/state per fallback would be
        // wasted work.
        let mut fallback_request = ModelRequest {
            model: Arc::clone(first_fallback),
            messages: request.messages.clone(),
            system_message: request.system_message.clone(),
            tool_choice: request.tool_choice.clone(),
            tools: request.tools.clone(),
            response_format: request.response_format.clone(),
            state: request.state.clone(),
            model_settings: request.model_settings.clone(),
        };

        let mut last_error = primary_error;
        for fallback in &self.fallback_models {
            // Cheap refcount bump; swaps in the next model to try.
            fallback_request.model = Arc::clone(fallback);
            match handler(&fallback_request).await {
                Ok(response) => return Ok(ModelCallResult::Response(response)),
                Err(e) => last_error = e,
            }
        }
        Err(last_error)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built middleware with no fallbacks reports its name and
    /// holds an empty fallback list.
    #[test]
    fn test_model_fallback_new() {
        let middleware = ModelFallbackMiddleware::new(Vec::new());
        assert_eq!(middleware.fallback_models.len(), 0);
        assert_eq!(middleware.name(), "ModelFallbackMiddleware");
    }
}