use std::collections::HashMap;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use cognis_core::error::Result;
use cognis_core::messages::Message;
use super::types::{AgentMiddleware, AgentState, AsyncModelHandler, ModelCallResult, ModelRequest};
/// Lifecycle state of a single [`Todo`] item.
///
/// Serialized in `snake_case` (`"pending"`, `"in_progress"`, `"completed"`), matching the
/// status names the default system prompt tells the model to use.
///
/// This is a fieldless enum, so it is `Copy` and `Hash`: it can be passed by value and used
/// as a map/set key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TodoStatus {
    /// Not started yet. This is the [`Default`] state.
    #[default]
    Pending,
    /// Currently being worked on.
    InProgress,
    /// Finished.
    Completed,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Todo {
pub content: String,
pub status: TodoStatus,
}
impl Todo {
    /// Creates a new todo with the given description, starting out as
    /// [`TodoStatus::Pending`].
    pub fn new(content: impl Into<String>) -> Self {
        let content: String = content.into();
        Self {
            status: TodoStatus::Pending,
            content,
        }
    }

    /// Builder-style setter: consumes this todo and returns it with `status` applied.
    pub fn with_status(self, status: TodoStatus) -> Self {
        Self { status, ..self }
    }

    /// Moves the todo to [`TodoStatus::InProgress`], regardless of its current state.
    pub fn start(&mut self) {
        self.transition_to(TodoStatus::InProgress);
    }

    /// Moves the todo to [`TodoStatus::Completed`], regardless of its current state.
    pub fn complete(&mut self) {
        self.transition_to(TodoStatus::Completed);
    }

    /// Returns `true` only when the status is [`TodoStatus::Completed`].
    pub fn is_completed(&self) -> bool {
        matches!(self.status, TodoStatus::Completed)
    }

    /// Unconditionally overwrites the status; no transition validation is performed.
    fn transition_to(&mut self, next: TodoStatus) {
        self.status = next;
    }
}
/// Agent middleware that renders a todo list into the model's system prompt and reports
/// progress on that list after each model call.
#[derive(Debug, Clone)]
pub struct TodoListMiddleware {
    /// Prompt template. Every occurrence of the literal `{todos}` is replaced with the
    /// rendered task list before each model call.
    pub system_prompt: String,
    /// Key under which the serialized todo list is looked up in the agent state's
    /// `extra` map.
    pub state_key: String,
}
impl TodoListMiddleware {
    /// Creates a middleware using the `Default` configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the prompt template.
    ///
    /// Each `{todos}` placeholder in `prompt` is substituted with the rendered task list.
    /// A template without the placeholder will not show the tasks to the model.
    #[must_use]
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = prompt.into();
        self
    }

    /// Changes the key used to look up the todo list in the agent state's `extra` map.
    #[must_use]
    pub fn with_state_key(mut self, key: impl Into<String>) -> Self {
        self.state_key = key.into();
        self
    }

    /// Reads the todo list stored under `self.state_key` in `state.extra`.
    ///
    /// Returns an empty list when the key is missing or when the stored value does not
    /// deserialize as a list of [`Todo`]s. Malformed data is ignored on a best-effort basis
    /// rather than reported.
    fn get_todos(&self, state: &AgentState) -> Vec<Todo> {
        state
            .extra
            .get(&self.state_key)
            // Deserialize straight from the borrowed `Value` (`&Value` is a serde
            // `Deserializer`). This avoids deep-cloning the whole JSON tree, which
            // `from_value` would require; this runs on every model call.
            .and_then(|v| Vec::<Todo>::deserialize(v).ok())
            .unwrap_or_default()
    }

    /// Renders `todos` as a numbered checklist, one task per line.
    ///
    /// Markers: `[ ]` pending, `[~]` in progress, `[x]` completed. An empty slice renders
    /// as a fixed "no tasks" sentence.
    fn format_todos(&self, todos: &[Todo]) -> String {
        if todos.is_empty() {
            return "No tasks in the todo list.".into();
        }
        todos
            .iter()
            .enumerate()
            .map(|(i, todo)| {
                let status_icon = match todo.status {
                    TodoStatus::Pending => "[ ]",
                    TodoStatus::InProgress => "[~]",
                    TodoStatus::Completed => "[x]",
                };
                format!("{}. {} {}", i + 1, status_icon, todo.content)
            })
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Substitutes the rendered todo list for every `{todos}` in the prompt template.
    fn build_system_prompt(&self, todos: &[Todo]) -> String {
        let formatted = self.format_todos(todos);
        self.system_prompt.replace("{todos}", &formatted)
    }

    /// Returns how many tool calls the most recent message requested, but only when that
    /// message is an AI message with more than one tool call. Otherwise returns `None`.
    fn check_parallel_calls(&self, state: &AgentState) -> Option<usize> {
        if let Some(Message::Ai(ai_msg)) = state.messages.last() {
            let call_count = ai_msg.tool_calls.len();
            return (call_count > 1).then_some(call_count);
        }
        None
    }
}
impl Default for TodoListMiddleware {
    /// Builds the middleware with the built-in task-management prompt (which contains a
    /// `{todos}` placeholder) and `"todo_list"` as the state key.
    fn default() -> Self {
        // Named so the multi-line template reads as a single unit.
        const DEFAULT_SYSTEM_PROMPT: &str =
            "You are working through a task list. Here is the current status of your tasks:\n\n\
             {todos}\n\n\
             Instructions for task management:\n\
             - Work through tasks in order, marking them as in_progress when you start and completed when done.\n\
             - If you discover sub-tasks, add them to the list.\n\
             - If a task is blocked, note the reason and move on to the next available task.\n\
             - Provide a brief status update when completing each task.\n\
             - Use parallel tool calls when tasks are independent and can be done simultaneously.";

        let state_key = String::from("todo_list");
        let system_prompt = DEFAULT_SYSTEM_PROMPT.to_owned();
        Self {
            system_prompt,
            state_key,
        }
    }
}
#[async_trait]
impl AgentMiddleware for TodoListMiddleware {
    /// Identifier reported for this middleware.
    fn name(&self) -> &str {
        "TodoListMiddleware"
    }

    /// Injects the rendered todo prompt into the system message, then delegates to
    /// `handler` with the rewritten request.
    ///
    /// If the request already carries a system message, its text is kept and the todo
    /// prompt is appended after a blank line. Otherwise the todo prompt alone becomes the
    /// system message. All other request fields are cloned unchanged.
    async fn wrap_model_call(
        &self,
        request: &ModelRequest,
        handler: &AsyncModelHandler,
    ) -> Result<ModelCallResult> {
        let todo_prompt = self.build_system_prompt(&self.get_todos(&request.state));

        let system_message = match &request.system_message {
            Some(existing) => Message::system(format!(
                "{}\n\n{}",
                existing.content().text(),
                todo_prompt
            )),
            None => Message::system(todo_prompt),
        };

        let augmented = ModelRequest {
            model: request.model.clone(),
            messages: request.messages.clone(),
            system_message: Some(system_message),
            tool_choice: request.tool_choice.clone(),
            tools: request.tools.clone(),
            response_format: request.response_format.clone(),
            state: request.state.clone(),
            model_settings: request.model_settings.clone(),
        };

        Ok(ModelCallResult::Response(handler(&augmented).await?))
    }

    /// Produces state updates after a model call.
    ///
    /// - `parallel_tool_calls`: the tool-call count, when the last message is an AI message
    ///   requesting more than one tool call.
    /// - `todo_progress`: `{ completed, total, all_done }`, when the todo list is non-empty.
    ///
    /// Returns `Ok(None)` when neither applies.
    async fn after_model(&self, state: &AgentState) -> Result<Option<HashMap<String, Value>>> {
        let mut updates: HashMap<String, Value> = HashMap::new();

        if let Some(call_count) = self.check_parallel_calls(state) {
            updates.insert(
                "parallel_tool_calls".to_string(),
                serde_json::json!(call_count),
            );
        }

        let todos = self.get_todos(state);
        if !todos.is_empty() {
            let total = todos.len();
            let completed = todos.iter().filter(|todo| todo.is_completed()).count();
            let progress = serde_json::json!({
                "completed": completed,
                "total": total,
                "all_done": completed == total
            });
            updates.insert("todo_progress".to_string(), progress);
        }

        Ok(if updates.is_empty() { None } else { Some(updates) })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default-configured middleware shared by most tests.
    fn middleware() -> TodoListMiddleware {
        TodoListMiddleware::new()
    }

    #[test]
    fn test_todo_new() {
        let todo = Todo::new("Write tests");
        assert!(!todo.is_completed());
        assert_eq!(todo.status, TodoStatus::Pending);
        assert_eq!(todo.content, "Write tests");
    }

    #[test]
    fn test_todo_lifecycle() {
        let mut todo = Todo::new("Implement feature");
        assert_eq!(todo.status, TodoStatus::Pending);

        todo.start();
        assert_eq!(todo.status, TodoStatus::InProgress);

        todo.complete();
        assert_eq!(todo.status, TodoStatus::Completed);
        assert!(todo.is_completed());
    }

    #[test]
    fn test_todo_status_serde() {
        let cases = [
            (TodoStatus::Pending, "\"pending\""),
            (TodoStatus::InProgress, "\"in_progress\""),
            (TodoStatus::Completed, "\"completed\""),
        ];
        for (status, expected) in &cases {
            assert_eq!(serde_json::to_string(status).unwrap(), *expected);
        }

        let parsed: TodoStatus = serde_json::from_str("\"completed\"").unwrap();
        assert_eq!(parsed, TodoStatus::Completed);
    }

    #[test]
    fn test_todo_serde() {
        let original = Todo::new("Test task").with_status(TodoStatus::InProgress);
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Todo = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.status, TodoStatus::InProgress);
        assert_eq!(decoded.content, "Test task");
    }

    #[test]
    fn test_format_todos_empty() {
        assert_eq!(middleware().format_todos(&[]), "No tasks in the todo list.");
    }

    #[test]
    fn test_format_todos_with_items() {
        let todos = [
            Todo::new("Task 1"),
            Todo::new("Task 2").with_status(TodoStatus::InProgress),
            Todo::new("Task 3").with_status(TodoStatus::Completed),
        ];
        let rendered = middleware().format_todos(&todos);
        for needle in ["[ ] Task 1", "[~] Task 2", "[x] Task 3"] {
            assert!(
                rendered.contains(needle),
                "missing {:?} in {:?}",
                needle,
                rendered
            );
        }
    }

    #[test]
    fn test_build_system_prompt() {
        let prompt = middleware().build_system_prompt(&[Todo::new("Do something")]);
        assert!(prompt.contains("working through a task list"));
        assert!(prompt.contains("[ ] Do something"));
    }

    #[test]
    fn test_get_todos_from_state() {
        let mut state = AgentState::default();
        let stored = vec![Todo::new("A"), Todo::new("B")];
        state.set_extra("todo_list", serde_json::to_value(&stored).unwrap());

        let todos = middleware().get_todos(&state);
        assert_eq!(todos.len(), 2);
        assert_eq!(todos[0].content, "A");
    }

    #[test]
    fn test_get_todos_empty_state() {
        let state = AgentState::default();
        assert!(middleware().get_todos(&state).is_empty());
    }

    #[test]
    fn test_middleware_name() {
        let mw = middleware();
        assert_eq!(mw.name(), "TodoListMiddleware");
    }

    #[test]
    fn test_middleware_builder() {
        let mw = middleware()
            .with_system_prompt("Custom: {todos}")
            .with_state_key("my_todos");
        assert_eq!(mw.state_key, "my_todos");
        assert_eq!(mw.system_prompt, "Custom: {todos}");
    }
}