use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use crate::{
agent::{
context_engineering::ModelRequest,
middleware::{Middleware, MiddlewareContext, MiddlewareError},
},
schemas::StructuredOutputStrategy,
};
/// Callback that inspects a [`ModelRequest`] and optionally yields a structured-output strategy.
type FormatSelector =
    dyn Fn(&ModelRequest) -> Option<Box<dyn StructuredOutputStrategy>> + Send + Sync;

/// Middleware that runs a caller-supplied format selector before each model call.
///
/// The selector is stored behind an [`Arc`] so the middleware owns it for its whole lifetime.
pub struct DynamicResponseFormatMiddleware {
    /// Selector invoked with the outgoing request in `before_model_call`.
    format_selector: Arc<FormatSelector>,
    /// Named strategies supplied at construction time.
    ///
    /// NOTE(review): nothing in this file reads this map yet — confirm how it is meant to be
    /// consulted (e.g. by name from the selector).
    available_formats: HashMap<String, Box<dyn StructuredOutputStrategy>>,
}
impl DynamicResponseFormatMiddleware {
    /// Creates a middleware from an arbitrary selector and a map of named formats.
    ///
    /// # Arguments
    /// * `selector` - called with each outgoing [`ModelRequest`]; may return the strategy to use,
    ///   or `None` to express no preference.
    /// * `formats` - named strategies stored on the middleware.
    pub fn new<F>(selector: F, formats: HashMap<String, Box<dyn StructuredOutputStrategy>>) -> Self
    where
        F: Fn(&ModelRequest) -> Option<Box<dyn StructuredOutputStrategy>> + Send + Sync + 'static,
    {
        Self {
            format_selector: Arc::new(selector),
            available_formats: formats,
        }
    }

    /// Builds a middleware whose selector classifies requests by message count.
    ///
    /// Requests with fewer than `threshold` messages map to the name `"simple"`; all others map
    /// to `"detailed"`.
    ///
    /// NOTE(review): the chosen name is not yet resolved against `formats`, so the selector
    /// currently always returns `None`.
    pub fn from_message_count(
        formats: HashMap<String, Box<dyn StructuredOutputStrategy>>,
        threshold: usize,
    ) -> Self {
        Self::new(
            move |request: &ModelRequest| {
                // Static literals: the closure no longer has to own two heap-allocated Strings.
                let _format_name: &'static str = if request.messages.len() < threshold {
                    "simple"
                } else {
                    "detailed"
                };
                // TODO: look up `_format_name` in the configured formats once the selector can
                // reach them.
                None
            },
            formats,
        )
    }

    /// Builds a middleware whose selector inspects the `"user_role"` entry of the request's
    /// runtime context.
    ///
    /// NOTE(review): the role value is read but not yet used, so the selector currently always
    /// returns `None`.
    pub fn from_user_role(formats: HashMap<String, Box<dyn StructuredOutputStrategy>>) -> Self {
        Self::new(
            |request: &ModelRequest| {
                // Nested `if let` keeps any temporary returned by `context()` alive while the
                // looked-up value is borrowed.
                if let Some(runtime) = request.runtime() {
                    if let Some(_user_role) = runtime.context().get("user_role") {
                        // TODO: map `_user_role` to a strategy.
                        return None;
                    }
                }
                None
            },
            formats,
        )
    }
}
#[async_trait]
impl Middleware for DynamicResponseFormatMiddleware {
    /// Runs the configured format selector against `request`.
    ///
    /// The selector is called exactly once. Its result is currently dropped, and the method
    /// always returns `Ok(None)`. The request is not replaced, assuming `None` means "leave
    /// unchanged" in the [`Middleware`] contract — confirm.
    async fn before_model_call(
        &self,
        request: &ModelRequest,
        _context: &mut MiddlewareContext,
    ) -> Result<Option<ModelRequest>, MiddlewareError> {
        let selector = &*self.format_selector;
        // NOTE(review): the chosen strategy is not applied to the request yet.
        let _chosen_format = selector(request);
        Ok(None)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentState;
    use crate::schemas::Message;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// A message-count middleware with a below-threshold request must not error.
    #[tokio::test]
    async fn test_dynamic_response_format_from_message_count() {
        let no_formats: HashMap<String, Box<dyn StructuredOutputStrategy>> = HashMap::new();
        let subject = DynamicResponseFormatMiddleware::from_message_count(no_formats, 5);

        let shared_state = Arc::new(Mutex::new(AgentState::new()));
        let history: Vec<Message> = std::iter::repeat(Message::new_human_message("Hello"))
            .take(3)
            .collect();
        let req = ModelRequest::new(history, vec![], shared_state);

        let mut ctx = MiddlewareContext::new();
        let outcome = subject.before_model_call(&req, &mut ctx).await;

        assert!(outcome.is_ok());
    }
}