use std::sync::Arc;
use async_trait::async_trait;
use crate::{
agent::{
middleware::{Middleware, MiddlewareContext, MiddlewareError},
runtime::RuntimeRequest,
},
prompt::PromptArgs,
schemas::agent::AgentAction,
tools::ToolContext,
};
/// Middleware that builds a prompt from the runtime's tool context right before
/// the agent plans its next step.
///
/// The prompt text is produced by a caller-supplied generator closure that is
/// handed the current [`ToolContext`].
pub struct DynamicPromptMiddleware {
    /// Generator invoked with the tool context; stored as a shared trait object
    /// whose `Send + Sync` bounds allow the middleware to cross threads.
    prompt_generator: Arc<dyn Fn(&dyn ToolContext) -> String + Send + Sync>,
}

impl DynamicPromptMiddleware {
    /// Creates the middleware from an arbitrary prompt generator closure.
    pub fn new<F>(generator: F) -> Self
    where
        F: Fn(&dyn ToolContext) -> String + Send + Sync + 'static,
    {
        let prompt_generator: Arc<dyn Fn(&dyn ToolContext) -> String + Send + Sync> =
            Arc::new(generator);
        Self { prompt_generator }
    }

    /// Creates the middleware from a plain string template.
    ///
    /// Supported placeholders are `{user_id}`, `{session_id}` and `{user_name}`
    /// (the last one looked up with `ctx.get("user_name")`). They are substituted
    /// in that order, and only when the context supplies a value; a placeholder
    /// without a value is left in the output verbatim. Because each substitution
    /// runs over the result of the previous one, a value that itself contains a
    /// later placeholder (e.g. a user id containing `{session_id}`) is expanded too.
    pub fn with_template(template: String) -> Self {
        Self::new(move |ctx: &dyn ToolContext| {
            let rendered = template.clone();
            let rendered = match ctx.user_id() {
                Some(id) => rendered.replace("{user_id}", id),
                None => rendered,
            };
            let rendered = match ctx.session_id() {
                Some(session) => rendered.replace("{session_id}", session),
                None => rendered,
            };
            match ctx.get("user_name") {
                Some(name) => rendered.replace("{user_name}", name),
                None => rendered,
            }
        })
    }
}
#[async_trait]
impl Middleware for DynamicPromptMiddleware {
    /// Injects a freshly generated prompt into the planning input.
    ///
    /// If `request` carries a runtime, the stored generator is called with the
    /// runtime's tool context. The resulting string is inserted into a copy of
    /// `request.input` as a JSON string:
    /// - under `"dynamic_system_prompt"` when the input already has a
    ///   `"chat_history"` entry;
    /// - under `"system_prompt"` otherwise.
    ///
    /// Returns `Ok(Some(modified_input))` in that case, and `Ok(None)` (input left
    /// untouched) when no runtime is attached. This implementation never
    /// returns `Err`.
    async fn before_agent_plan_with_runtime(
        &self,
        request: &RuntimeRequest,
        _steps: &[(AgentAction, String)],
        _context: &mut MiddlewareContext,
    ) -> Result<Option<PromptArgs>, MiddlewareError> {
        let runtime = match request.runtime.as_ref() {
            Some(runtime) => runtime,
            None => return Ok(None),
        };

        let dynamic_prompt = (self.prompt_generator)(runtime.context());

        // Only the presence of a chat history matters here, so a read-only key
        // lookup is enough (no mutable borrow of the map is needed).
        // NOTE(review): a distinct key is presumably used so a chat-style prompt
        // template can place the dynamic prompt separately — confirm with consumers.
        let prompt_key = if request.input.contains_key("chat_history") {
            "dynamic_system_prompt"
        } else {
            "system_prompt"
        };

        let mut modified_input = request.input.clone();
        modified_input.insert(prompt_key.to_string(), serde_json::json!(dynamic_prompt));
        Ok(Some(modified_input))
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::agent::Runtime;
    use crate::tools::SimpleContext;

    /// Builds a request whose runtime context has `user_id = "user123"`.
    fn request_with_runtime(input: PromptArgs) -> RuntimeRequest {
        let context = Arc::new(SimpleContext::new().with_user_id("user123".to_string()));
        let runtime = Arc::new(Runtime::new(
            context.clone(),
            Arc::new(crate::tools::InMemoryStore::new()),
        ));
        let state = Arc::new(tokio::sync::Mutex::new(crate::agent::AgentState::new()));
        RuntimeRequest::new(input, state).with_runtime(runtime)
    }

    /// Middleware whose prompt embeds the context's user id.
    fn user_id_middleware() -> DynamicPromptMiddleware {
        DynamicPromptMiddleware::new(|ctx: &dyn ToolContext| {
            let user_id = ctx.user_id().unwrap_or("unknown");
            format!("You are a helpful assistant for user: {}", user_id)
        })
    }

    #[tokio::test]
    async fn test_dynamic_prompt_middleware() {
        let middleware = user_id_middleware();
        let mut input = PromptArgs::new();
        input.insert("input".to_string(), serde_json::json!("test"));
        let request = request_with_runtime(input);
        let mut middleware_context = MiddlewareContext::new();

        let modified = middleware
            .before_agent_plan_with_runtime(&request, &[], &mut middleware_context)
            .await
            .expect("middleware should not fail")
            .expect("a runtime is attached, so the input must be modified");

        // Without chat history the prompt goes under "system_prompt".
        assert_eq!(
            modified.get("system_prompt"),
            Some(&serde_json::json!(
                "You are a helpful assistant for user: user123"
            ))
        );
        assert!(!modified.contains_key("dynamic_system_prompt"));
        // Original entries are preserved.
        assert_eq!(modified.get("input"), Some(&serde_json::json!("test")));
    }

    #[tokio::test]
    async fn test_dynamic_prompt_middleware_with_chat_history() {
        let middleware = user_id_middleware();
        let mut input = PromptArgs::new();
        input.insert("input".to_string(), serde_json::json!("test"));
        input.insert("chat_history".to_string(), serde_json::json!([]));
        let request = request_with_runtime(input);
        let mut middleware_context = MiddlewareContext::new();

        let modified = middleware
            .before_agent_plan_with_runtime(&request, &[], &mut middleware_context)
            .await
            .expect("middleware should not fail")
            .expect("a runtime is attached, so the input must be modified");

        // With chat history present the prompt goes under "dynamic_system_prompt".
        assert_eq!(
            modified.get("dynamic_system_prompt"),
            Some(&serde_json::json!(
                "You are a helpful assistant for user: user123"
            ))
        );
        assert!(!modified.contains_key("system_prompt"));
        assert_eq!(modified.get("chat_history"), Some(&serde_json::json!([])));
    }

    #[test]
    fn test_dynamic_prompt_with_template() {
        let middleware = DynamicPromptMiddleware::with_template(
            "Hello {user_id}, you are a valued user.".to_string(),
        );
        let context = Arc::new(SimpleContext::new().with_user_id("user123".to_string()));
        let prompt = (middleware.prompt_generator)(context.as_ref());
        assert_eq!(prompt, "Hello user123, you are a valued user.");
    }

    #[tokio::test]
    async fn test_dynamic_prompt_middleware_no_runtime() {
        let middleware =
            DynamicPromptMiddleware::new(|_ctx: &dyn ToolContext| "Default prompt".to_string());
        let state = Arc::new(tokio::sync::Mutex::new(crate::agent::AgentState::new()));
        let mut input = PromptArgs::new();
        input.insert("input".to_string(), serde_json::json!("test"));
        let request = RuntimeRequest::new(input, state);
        let mut middleware_context = MiddlewareContext::new();

        let result = middleware
            .before_agent_plan_with_runtime(&request, &[], &mut middleware_context)
            .await;

        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }
}