use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
/// Lifecycle state shared by [`PlanStep`] and [`TodoItem`].
///
/// Serialized in `snake_case`, e.g. `InProgress` <-> `"in_progress"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StepStatus {
    /// Not started yet; the state every new step/item begins in.
    Pending,
    /// Started (see `PlanStep::start`) but not finished.
    InProgress,
    /// Finished successfully.
    Completed,
    /// Finished with an error (see `PlanStep::fail`).
    Failed,
    /// Marked as skipped (see `PlanStep::skip`).
    Skipped,
}

impl Default for StepStatus {
    /// Fresh work starts out as [`StepStatus::Pending`].
    fn default() -> Self {
        Self::Pending
    }
}
/// A single unit of work inside a [`Plan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanStep {
    /// Unique identifier (a random UUID v4 string when built via [`PlanStep::new`]).
    pub id: String,
    /// Human-readable summary of the work.
    pub description: String,
    /// Current lifecycle state.
    pub status: StepStatus,
    /// IDs of steps that must be `Completed` before this one is ready.
    pub dependencies: Vec<String>,
    /// Result recorded by [`PlanStep::complete`].
    pub output: Option<String>,
    /// Failure message recorded by [`PlanStep::fail`]; cleared on completion.
    pub error: Option<String>,
    /// Expected duration in seconds, set via [`PlanStep::estimated`].
    pub estimated_duration: Option<u64>,
    /// Measured duration. NOTE(review): no method in this type sets it; presumably
    /// filled in by the executor — confirm who writes it and in which unit.
    pub actual_duration: Option<u64>,
}

impl PlanStep {
    /// Creates a pending step with a fresh UUID and no dependencies.
    pub fn new(description: impl Into<String>) -> Self {
        Self {
            id: uuid::Uuid::new_v4().to_string(),
            description: description.into(),
            status: StepStatus::Pending,
            dependencies: Vec::new(),
            output: None,
            error: None,
            estimated_duration: None,
            actual_duration: None,
        }
    }

    /// Builder: adds `step_id` to the list of prerequisites.
    pub fn depends_on(mut self, step_id: impl Into<String>) -> Self {
        self.dependencies.push(step_id.into());
        self
    }

    /// Builder: records the estimated duration in seconds.
    pub fn estimated(mut self, seconds: u64) -> Self {
        self.estimated_duration = Some(seconds);
        self
    }

    /// Marks the step as in progress.
    pub fn start(&mut self) {
        self.status = StepStatus::InProgress;
    }

    /// Marks the step as completed and stores `output`.
    ///
    /// Any error left over from an earlier failed attempt is cleared, so a
    /// `Completed` step never carries a stale failure message.
    pub fn complete(&mut self, output: Option<String>) {
        self.status = StepStatus::Completed;
        self.output = output;
        self.error = None;
    }

    /// Marks the step as failed and records `error`.
    pub fn fail(&mut self, error: impl Into<String>) {
        self.status = StepStatus::Failed;
        self.error = Some(error.into());
    }

    /// Marks the step as skipped.
    pub fn skip(&mut self) {
        self.status = StepStatus::Skipped;
    }

    /// Returns `true` if the step is still `Pending` and every dependency ID
    /// appears in `completed_steps`.
    pub fn is_ready(&self, completed_steps: &[String]) -> bool {
        self.status == StepStatus::Pending
            && self.dependencies.iter().all(|d| completed_steps.contains(d))
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Plan {
pub id: String,
pub name: String,
pub description: Option<String>,
pub steps: Vec<PlanStep>,
pub created_at: chrono::DateTime<chrono::Utc>,
pub updated_at: chrono::DateTime<chrono::Utc>,
pub metadata: HashMap<String, String>,
}
impl Plan {
pub fn new(name: impl Into<String>) -> Self {
let now = chrono::Utc::now();
Self {
id: uuid::Uuid::new_v4().to_string(),
name: name.into(),
description: None,
steps: Vec::new(),
created_at: now,
updated_at: now,
metadata: HashMap::new(),
}
}
pub fn description(mut self, desc: impl Into<String>) -> Self {
self.description = Some(desc.into());
self
}
pub fn add_step(&mut self, step: PlanStep) {
self.steps.push(step);
self.updated_at = chrono::Utc::now();
}
pub fn get_step(&self, id: &str) -> Option<&PlanStep> {
self.steps.iter().find(|s| s.id == id)
}
pub fn get_step_mut(&mut self, id: &str) -> Option<&mut PlanStep> {
self.steps.iter_mut().find(|s| s.id == id)
}
pub fn completed_steps(&self) -> Vec<String> {
self.steps
.iter()
.filter(|s| s.status == StepStatus::Completed)
.map(|s| s.id.clone())
.collect()
}
pub fn next_step(&self) -> Option<&PlanStep> {
let completed = self.completed_steps();
self.steps.iter().find(|s| s.is_ready(&completed))
}
pub fn progress(&self) -> f64 {
if self.steps.is_empty() {
return 0.0;
}
let completed = self.steps.iter().filter(|s| s.status == StepStatus::Completed).count();
completed as f64 / self.steps.len() as f64 * 100.0
}
pub fn is_complete(&self) -> bool {
self.steps.iter().all(|s| {
s.status == StepStatus::Completed || s.status == StepStatus::Skipped
})
}
pub fn has_failed(&self) -> bool {
self.steps.iter().any(|s| s.status == StepStatus::Failed)
}
pub fn step_count(&self) -> usize {
self.steps.len()
}
}
/// Relative urgency of a [`TodoItem`].
///
/// Variants are declared from least to most urgent, so the derived `Ord` ranks
/// `Low < Medium < High < Critical`; this lets callers sort or threshold by
/// priority. Serialized in `snake_case`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TodoPriority {
    Low,
    Medium,
    High,
    Critical,
}

impl Default for TodoPriority {
    /// New items default to [`TodoPriority::Medium`].
    fn default() -> Self {
        TodoPriority::Medium
    }
}
/// A single entry in a [`TodoList`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TodoItem {
    /// Unique identifier (a random UUID v4 string when built via [`TodoItem::new`]).
    pub id: String,
    /// The task text.
    pub content: String,
    /// Lifecycle state; [`TodoItem::complete`] moves it to `Completed`.
    pub status: StepStatus,
    /// Urgency; `Medium` unless overridden by the `priority` builder.
    pub priority: TodoPriority,
    /// Free-form labels.
    pub tags: Vec<String>,
    /// Optional deadline (UTC) consulted by [`TodoItem::is_overdue`].
    pub due_date: Option<chrono::DateTime<chrono::Utc>>,
    /// Creation time (UTC), captured by [`TodoItem::new`].
    pub created_at: chrono::DateTime<chrono::Utc>,
}

impl TodoItem {
    /// Creates a pending, medium-priority item with no tags and no deadline.
    pub fn new(content: impl Into<String>) -> Self {
        Self {
            id: uuid::Uuid::new_v4().to_string(),
            content: content.into(),
            status: StepStatus::Pending,
            priority: TodoPriority::Medium,
            tags: Vec::new(),
            due_date: None,
            created_at: chrono::Utc::now(),
        }
    }

    /// Builder: replaces the priority.
    pub fn priority(self, priority: TodoPriority) -> Self {
        Self { priority, ..self }
    }

    /// Builder: appends a tag.
    pub fn tag(mut self, tag: impl Into<String>) -> Self {
        let label = tag.into();
        self.tags.push(label);
        self
    }

    /// Builder: sets the deadline.
    pub fn due(self, date: chrono::DateTime<chrono::Utc>) -> Self {
        Self {
            due_date: Some(date),
            ..self
        }
    }

    /// Marks the item as completed.
    pub fn complete(&mut self) {
        self.status = StepStatus::Completed;
    }

    /// Returns `true` if a deadline is set, has already passed, and the item is
    /// not `Completed`.
    pub fn is_overdue(&self) -> bool {
        match self.due_date {
            Some(deadline) if self.status != StepStatus::Completed => {
                chrono::Utc::now() > deadline
            }
            _ => false,
        }
    }
}
/// A named collection of [`TodoItem`]s kept in insertion order.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TodoList {
    /// Display name of the list.
    pub name: String,
    /// Items in insertion order.
    pub items: Vec<TodoItem>,
}

impl TodoList {
    /// Creates an empty list called `name`.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            items: Vec::new(),
        }
    }

    /// Appends `item` to the end of the list.
    pub fn add(&mut self, item: TodoItem) {
        self.items.push(item);
    }

    /// Looks up an item by ID.
    pub fn get(&self, id: &str) -> Option<&TodoItem> {
        self.items.iter().find(|i| i.id == id)
    }

    /// Looks up an item by ID for in-place mutation.
    pub fn get_mut(&mut self, id: &str) -> Option<&mut TodoItem> {
        self.items.iter_mut().find(|i| i.id == id)
    }

    /// Removes and returns the item with `id`, keeping the remaining items in order.
    pub fn remove(&mut self, id: &str) -> Option<TodoItem> {
        let pos = self.items.iter().position(|i| i.id == id)?;
        Some(self.items.remove(pos))
    }

    /// Items whose status is exactly `Pending` (in-progress items are excluded).
    pub fn pending(&self) -> Vec<&TodoItem> {
        self.items.iter().filter(|i| i.status == StepStatus::Pending).collect()
    }

    /// Items whose status is `Completed`.
    pub fn completed(&self) -> Vec<&TodoItem> {
        self.items.iter().filter(|i| i.status == StepStatus::Completed).collect()
    }

    /// Items for which [`TodoItem::is_overdue`] returns `true`.
    pub fn overdue(&self) -> Vec<&TodoItem> {
        self.items.iter().filter(|i| i.is_overdue()).collect()
    }

    /// Items carrying `tag` (exact, case-sensitive match).
    pub fn by_tag(&self, tag: &str) -> Vec<&TodoItem> {
        // Compare as borrowed strings instead of allocating a `String` per item.
        self.items
            .iter()
            .filter(|i| i.tags.iter().any(|t| t == tag))
            .collect()
    }

    /// Items whose priority equals `priority`.
    pub fn by_priority(&self, priority: TodoPriority) -> Vec<&TodoItem> {
        self.items.iter().filter(|i| i.priority == priority).collect()
    }

    /// Number of items in the list.
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// Returns `true` if the list has no items.
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Percentage (0.0–100.0) of items that are `Completed`; 0.0 for an empty list.
    pub fn progress(&self) -> f64 {
        if self.items.is_empty() {
            return 0.0;
        }
        let completed = self.items.iter().filter(|i| i.status == StepStatus::Completed).count();
        completed as f64 / self.items.len() as f64 * 100.0
    }
}
/// In-memory store of [`Plan`]s keyed by plan ID, optionally backed by a JSON file.
#[derive(Debug, Default)]
pub struct PlanStorage {
    plans: HashMap<String, Plan>,
    /// Backing file for `persist`/`restore`; `None` keeps the storage memory-only.
    path: Option<PathBuf>,
}

impl PlanStorage {
    /// Creates an empty, memory-only store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty store that persists to / restores from `path`.
    pub fn with_path(path: impl Into<PathBuf>) -> Self {
        Self {
            plans: HashMap::new(),
            path: Some(path.into()),
        }
    }

    /// Inserts `plan`, replacing any stored plan with the same ID.
    pub fn save(&mut self, plan: Plan) {
        self.plans.insert(plan.id.clone(), plan);
    }

    /// Looks up a plan by ID.
    pub fn load(&self, id: &str) -> Option<&Plan> {
        self.plans.get(id)
    }

    /// Looks up a plan by ID for in-place mutation.
    pub fn load_mut(&mut self, id: &str) -> Option<&mut Plan> {
        self.plans.get_mut(id)
    }

    /// Removes and returns the plan with `id`.
    pub fn delete(&mut self, id: &str) -> Option<Plan> {
        self.plans.remove(id)
    }

    /// All stored plans, in unspecified (hash-map) order.
    pub fn list(&self) -> Vec<&Plan> {
        self.plans.values().collect()
    }

    /// Number of stored plans.
    pub fn count(&self) -> usize {
        self.plans.len()
    }

    /// Writes every plan to the backing file as pretty-printed JSON; a no-op when no
    /// path is configured.
    ///
    /// The JSON goes to a sibling `<path>.tmp` file, is flushed to disk, and is then
    /// renamed over `path`. An interrupted write therefore leaves the previous file
    /// intact instead of a truncated one that [`PlanStorage::restore`] cannot parse.
    ///
    /// # Errors
    /// Returns any I/O error, or a serialization error converted to `std::io::Error`.
    /// On failure the temporary file is removed on a best-effort basis.
    pub fn persist(&self) -> std::io::Result<()> {
        use std::io::Write;

        let path = match &self.path {
            Some(path) => path,
            None => return Ok(()),
        };
        let json = serde_json::to_string_pretty(&self.plans)?;
        let tmp = Self::temp_path(path);
        let written = std::fs::File::create(&tmp).and_then(|mut file| {
            file.write_all(json.as_bytes())?;
            file.sync_all()
        });
        if let Err(err) = written.and_then(|()| std::fs::rename(&tmp, path)) {
            // Best-effort cleanup; the original error is the one worth reporting.
            let _ = std::fs::remove_file(&tmp);
            return Err(err);
        }
        Ok(())
    }

    /// Replaces the in-memory plans with the contents of the backing file.
    ///
    /// A no-op when no path is configured or the file does not exist yet.
    ///
    /// # Errors
    /// Returns read errors other than "not found", and JSON parse errors converted
    /// to `std::io::Error`. On error the in-memory plans are left untouched.
    pub fn restore(&mut self) -> std::io::Result<()> {
        let path = match &self.path {
            Some(path) => path,
            None => return Ok(()),
        };
        // Treat "not found" from the read itself as "nothing saved yet", rather than
        // checking `exists()` first and racing with concurrent deletion.
        let json = match std::fs::read_to_string(path) {
            Ok(json) => json,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(()),
            Err(err) => return Err(err),
        };
        self.plans = serde_json::from_str(&json)?;
        Ok(())
    }

    /// Sibling path used for atomic writes: `path` with `.tmp` appended to its name.
    fn temp_path(path: &std::path::Path) -> PathBuf {
        let mut name = path.as_os_str().to_os_string();
        name.push(".tmp");
        PathBuf::from(name)
    }
}
/// Tool names classified as read-only (reading, listing, searching, fetching).
///
/// Checked by [`is_read_only_tool`]; a name absent from this list is not read-only.
pub const READ_ONLY_TOOLS: &[&str] = &[
    "read_file",
    "list_directory",
    "search_codebase",
    "search_files",
    "grep_search",
    "find_files",
    "web_search",
    "get_file_content",
    "list_files",
    "read_document",
    "search_web",
    "fetch_url",
    "get_context",
];

/// Tool names classified as restricted: file mutation, command execution,
/// git publishing, and package installation.
///
/// Checked by [`is_restricted_tool`].
pub const RESTRICTED_TOOLS: &[&str] = &[
    "write_file",
    "create_file",
    "delete_file",
    "execute_command",
    "run_command",
    "shell_command",
    "modify_file",
    "edit_file",
    "remove_file",
    "move_file",
    "copy_file",
    "mkdir",
    "rmdir",
    "git_commit",
    "git_push",
    "npm_install",
    "pip_install",
];

/// Tool names classified as research tools: web search/fetch providers plus
/// local file and codebase search.
///
/// Checked by [`is_research_tool`].
pub const RESEARCH_TOOLS: &[&str] = &[
    "web_search",
    "search_web",
    "duckduckgo_search",
    "tavily_search",
    "brave_search",
    "google_search",
    "read_url",
    "fetch_url",
    "read_file",
    "list_directory",
    "search_codebase",
    "grep_search",
    "find_files",
];

/// Exact, case-sensitive membership test shared by the `is_*_tool` predicates.
fn tool_listed(list: &[&str], name: &str) -> bool {
    list.iter().any(|&tool| tool == name)
}

/// Returns `true` if `name` appears in [`READ_ONLY_TOOLS`].
pub fn is_read_only_tool(name: &str) -> bool {
    tool_listed(READ_ONLY_TOOLS, name)
}

/// Returns `true` if `name` appears in [`RESTRICTED_TOOLS`].
pub fn is_restricted_tool(name: &str) -> bool {
    tool_listed(RESTRICTED_TOOLS, name)
}

/// Returns `true` if `name` appears in [`RESEARCH_TOOLS`].
pub fn is_research_tool(name: &str) -> bool {
    tool_listed(RESEARCH_TOOLS, name)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a percentage matches `expected` within a small tolerance.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 0.001);
    }

    #[test]
    fn test_plan_step_new() {
        let step = PlanStep::new("Test step");
        assert_eq!(step.description, "Test step");
        assert_eq!(step.status, StepStatus::Pending);
    }

    #[test]
    fn test_plan_step_lifecycle() {
        let mut step = PlanStep::new("Test step");

        step.start();
        assert_eq!(step.status, StepStatus::InProgress);

        step.complete(Some(String::from("Done")));
        assert_eq!(step.status, StepStatus::Completed);
        assert_eq!(step.output.as_deref(), Some("Done"));
    }

    #[test]
    fn test_plan_step_fail() {
        let mut step = PlanStep::new("Test step");
        step.fail("Something went wrong");
        assert_eq!(step.status, StepStatus::Failed);
        assert!(step.error.is_some());
    }

    #[test]
    fn test_plan_new() {
        let plan = Plan::new("Test plan");
        assert_eq!(plan.name, "Test plan");
        assert!(plan.steps.is_empty());
    }

    #[test]
    fn test_plan_add_steps() {
        let mut plan = Plan::new("Test plan");
        for label in &["Step 1", "Step 2"] {
            plan.add_step(PlanStep::new(*label));
        }
        assert_eq!(plan.step_count(), 2);
    }

    #[test]
    fn test_plan_progress() {
        let mut plan = Plan::new("Test plan");
        for label in &["Step 1", "Step 2"] {
            plan.add_step(PlanStep::new(*label));
        }
        assert_close(plan.progress(), 0.0);

        if let Some(first) = plan.steps.first_mut() {
            first.complete(None);
        }
        assert_close(plan.progress(), 50.0);
    }

    #[test]
    fn test_todo_item_new() {
        let item = TodoItem::new("Test item");
        assert_eq!(item.content, "Test item");
        assert_eq!(item.priority, TodoPriority::Medium);
    }

    #[test]
    fn test_todo_item_priority() {
        let item = TodoItem::new("Test item").priority(TodoPriority::High);
        assert_eq!(item.priority, TodoPriority::High);
    }

    #[test]
    fn test_todo_list_new() {
        let list = TodoList::new("My List");
        assert_eq!(list.name, "My List");
        assert!(list.is_empty());
    }

    #[test]
    fn test_todo_list_add() {
        let mut list = TodoList::new("My List");
        for text in &["Item 1", "Item 2"] {
            list.add(TodoItem::new(*text));
        }
        assert_eq!(list.len(), 2);
    }

    #[test]
    fn test_todo_list_progress() {
        let mut list = TodoList::new("My List");
        for text in &["Item 1", "Item 2"] {
            list.add(TodoItem::new(*text));
        }
        assert_close(list.progress(), 0.0);

        if let Some(first) = list.items.first_mut() {
            first.complete();
        }
        assert_close(list.progress(), 50.0);
    }

    #[test]
    fn test_plan_storage() {
        let mut storage = PlanStorage::new();
        let plan = Plan::new("Test plan");
        let id = plan.id.clone();

        storage.save(plan);
        assert_eq!(storage.count(), 1);

        let loaded = storage.load(&id).expect("saved plan should be retrievable");
        assert_eq!(loaded.name, "Test plan");
    }

    #[test]
    fn test_is_read_only_tool() {
        assert!(is_read_only_tool("read_file"));
        assert!(!is_read_only_tool("write_file"));
    }

    #[test]
    fn test_is_restricted_tool() {
        assert!(is_restricted_tool("write_file"));
        assert!(!is_restricted_tool("read_file"));
    }
}