use async_trait::async_trait;
use serde_json::{json, Value};
use super::Middleware;
use crate::agent::DeepAgentError;
/// Middleware that surfaces the agent's todo list to the model as a system message.
///
/// The todo list is read from the JSON agent state under a configurable key
/// (`"todos"` unless constructed with [`TodoListMiddleware::with_key`]).
#[derive(Debug, Clone)]
pub struct TodoListMiddleware {
    /// Key in the agent state object whose value is the array of todo items.
    state_key: String,
}
impl TodoListMiddleware {
pub fn new() -> Self {
Self {
state_key: "todos".to_string(),
}
}
pub fn with_key(key: impl Into<String>) -> Self {
Self {
state_key: key.into(),
}
}
pub fn format_todos(todos: &[Value]) -> String {
let mut lines = Vec::new();
for (i, todo) in todos.iter().enumerate() {
let content = todo
.get("content")
.and_then(|v| v.as_str())
.unwrap_or("(untitled)");
let status = todo
.get("status")
.and_then(|v| v.as_str())
.unwrap_or("pending");
let icon = match status {
"completed" => "[x]",
"in_progress" => "[~]",
_ => "[ ]",
};
lines.push(format!("{}. {} {}", i + 1, icon, content));
}
lines.join("\n")
}
}
impl Default for TodoListMiddleware {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Middleware for TodoListMiddleware {
fn name(&self) -> &str {
"todo_list"
}
async fn before_model(&self, state: &mut Value) -> std::result::Result<(), DeepAgentError> {
let todos = match state.get(&self.state_key).and_then(|v| v.as_array()) {
Some(arr) if !arr.is_empty() => arr.clone(),
_ => return Ok(()),
};
let formatted = Self::format_todos(&todos);
let system_content = format!(
"Current task list:\n{}\n\nUpdate task status as you work through them.",
formatted
);
if let Some(messages) = state.get_mut("messages").and_then(|v| v.as_array_mut()) {
let todo_msg = json!({
"type": "system",
"content": system_content
});
let is_todo_system = messages.first().is_some_and(|m| {
m.get("type").and_then(|t| t.as_str()) == Some("system")
&& m.get("content")
.and_then(|c| c.as_str())
.is_some_and(|c| c.starts_with("Current task list:"))
});
if is_todo_system {
messages[0] = todo_msg;
} else {
messages.insert(0, todo_msg);
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_todo_middleware_injects_system_prompt() {
        let middleware = TodoListMiddleware::new();
        let mut state = json!({
            "messages": [
                {"type": "human", "content": "Plan a project"}
            ],
            "todos": [
                {"content": "Design API", "status": "pending"},
                {"content": "Write tests", "status": "completed"}
            ]
        });
        middleware.before_model(&mut state).await.unwrap();
        let messages = state["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0]["type"], "system");
        let content = messages[0]["content"].as_str().unwrap();
        assert!(content.contains("Design API"));
        assert!(content.contains("Write tests"));
        assert!(content.contains("[ ]"));
        assert!(content.contains("[x]"));
    }

    #[tokio::test]
    async fn test_todo_middleware_no_todos_no_injection() {
        let middleware = TodoListMiddleware::new();
        let mut state = json!({
            "messages": [
                {"type": "human", "content": "Hello"}
            ]
        });
        middleware.before_model(&mut state).await.unwrap();
        let messages = state["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 1);
    }

    #[test]
    fn test_format_todos() {
        let todos = vec![
            json!({"content": "Task A", "status": "completed"}),
            json!({"content": "Task B", "status": "in_progress"}),
            json!({"content": "Task C", "status": "pending"}),
        ];
        let formatted = TodoListMiddleware::format_todos(&todos);
        assert!(formatted.contains("[x] Task A"));
        assert!(formatted.contains("[~] Task B"));
        assert!(formatted.contains("[ ] Task C"));
    }

    /// Pins the fallbacks for missing/unknown fields, the 1-based numbering,
    /// the newline separator, and the empty-input case.
    #[test]
    fn test_format_todos_defaults_and_numbering() {
        let todos = vec![
            json!({"status": "completed"}),
            json!({"content": "Task B", "status": "blocked"}),
            json!({"content": "Task C"}),
        ];
        assert_eq!(
            TodoListMiddleware::format_todos(&todos),
            "1. [x] (untitled)\n2. [ ] Task B\n3. [ ] Task C"
        );
        assert_eq!(TodoListMiddleware::format_todos(&[]), "");
    }

    #[tokio::test]
    async fn test_todo_middleware_custom_key() {
        let middleware = TodoListMiddleware::with_key("tasks");
        let mut state = json!({
            "messages": [{"type": "human", "content": "Go"}],
            "tasks": [{"content": "Do thing", "status": "pending"}]
        });
        middleware.before_model(&mut state).await.unwrap();
        let messages = state["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 2);
        assert!(messages[0]["content"]
            .as_str()
            .unwrap()
            .contains("Do thing"));
    }

    #[tokio::test]
    async fn test_todo_middleware_empty_todos_no_injection() {
        let middleware = TodoListMiddleware::new();
        let mut state = json!({
            "messages": [{"type": "human", "content": "Hi"}],
            "todos": []
        });
        middleware.before_model(&mut state).await.unwrap();
        assert_eq!(state["messages"].as_array().unwrap().len(), 1);
    }

    /// A second call must refresh the injected task list rather than add another one.
    #[tokio::test]
    async fn test_todo_middleware_refreshes_existing_task_list() {
        let middleware = TodoListMiddleware::new();
        let mut state = json!({
            "messages": [{"type": "human", "content": "Plan"}],
            "todos": [{"content": "Design API", "status": "pending"}]
        });
        middleware.before_model(&mut state).await.unwrap();
        state["todos"] = json!([{"content": "Design API", "status": "completed"}]);
        middleware.before_model(&mut state).await.unwrap();

        let messages = state["messages"].as_array().unwrap();
        assert_eq!(messages.len(), 2);
        let content = messages[0]["content"].as_str().unwrap();
        assert!(content.contains("[x] Design API"));
        assert!(!content.contains("[ ] Design API"));
        assert_eq!(messages[1]["type"], "human");
    }

    /// Without a `messages` array there is nowhere to inject; the call must not fail.
    #[tokio::test]
    async fn test_todo_middleware_without_messages_is_noop() {
        let middleware = TodoListMiddleware::new();
        let mut state = json!({
            "todos": [{"content": "A", "status": "pending"}]
        });
        middleware.before_model(&mut state).await.unwrap();
        assert!(state.get("messages").is_none());
    }
}