use std::{
fs,
path::PathBuf,
sync::atomic::{AtomicU64, Ordering},
time::{SystemTime, UNIX_EPOCH},
};
use crate::{
BuiltinProvider, ContentBlock, Message, Role,
agent::{AgentConfig, CompactionConfig, TaskConfig},
provider::{ContentBlockDelta, ContentBlockStart, ProviderError, ProviderEvent},
runtime::{
Runtime, SqliteRuntimeStore, TaskItem, TaskStatus, TaskStore, task::TASK_REMINDER_TEXT,
},
};
use super::support::{ScriptedProvider, erroring_stream, model_info, ok_stream};
/// Checks that a task created through the `task_create` tool shows up in the
/// agent snapshot, can be loaded back from the store, and is visible to an
/// agent spawned by a second runtime built on the same store and config.
#[tokio::test]
async fn task_updates_snapshot_and_persists_for_new_agents() {
    let model = model_info("model", BuiltinProvider::Anthropic);
    let tasks_dir = temp_tasks_dir("persist");
    let store = temp_store("persist");

    // Scripted replies for the first agent: a `task_create` tool call, then plain text.
    let first_script = vec![
        task_tool_stream(
            "tool-1",
            "task_create",
            r#"{"subject":"Plan work","owner":"agent-a"}"#,
        ),
        text_stream("created"),
    ];
    let first_provider = ScriptedProvider::new(
        BuiltinProvider::Anthropic,
        vec![model.clone()],
        first_script,
    );
    let first_runtime = Runtime::builder()
        .with_store(store.clone())
        .with_provider_instance(first_provider)
        .build()
        .expect("build runtime");

    let config = task_config(tasks_dir.clone());
    let mut agent = first_runtime
        .spawn_with_config("agent", model.clone(), config.clone())
        .expect("spawn agent");
    let prompt = vec![ContentBlock::Text {
        text: String::from("start"),
    }];
    agent.send(prompt).await.expect("send");

    // The snapshot must contain exactly the task described by the tool input.
    let expected_task = TaskItem {
        id: 1,
        subject: String::from("Plan work"),
        description: String::new(),
        status: TaskStatus::Pending,
        blocked_by: Vec::new(),
        blocks: Vec::new(),
        owner: String::from("agent-a"),
        working_directory: None,
    };
    assert_eq!(agent.watch_snapshot().borrow().tasks, vec![expected_task]);

    // The same task must be readable straight from the store.
    let persisted = store
        .load_tasks(tasks_dir.as_path())
        .expect("load persisted tasks");
    assert_eq!(persisted.len(), 1);

    // A second runtime sharing the store: its freshly spawned agent starts with the task.
    let second_provider = ScriptedProvider::new(
        BuiltinProvider::Anthropic,
        vec![model.clone()],
        vec![text_stream("ok")],
    );
    let second_runtime = Runtime::builder()
        .with_store(store)
        .with_provider_instance(second_provider)
        .build()
        .expect("build runtime");
    let second_agent = second_runtime
        .spawn_with_config("other", model, config)
        .expect("spawn other agent");
    assert_eq!(second_agent.watch_snapshot().borrow().tasks.len(), 1);
}
/// Checks which recorded provider requests carry the task reminder.
///
/// With `reminder_threshold: 3`, one `task_create` call is followed by text-only
/// replies. Requests 0 and 3 must carry the plain system prompt, while requests
/// 4 and 5 must carry it prefixed with `TASK_REMINDER_TEXT` and a blank line.
#[tokio::test]
async fn task_reminder_is_injected_after_three_rounds_without_task_tools() {
    const BASE_SYSTEM: &str = "Base system prompt";

    let model = model_info("model", BuiltinProvider::Anthropic);
    let tasks_dir = temp_tasks_dir("reminder");

    // One tool call plus its closing text, then four text-only rounds.
    let mut script = vec![
        task_tool_stream("tool-1", "task_create", r#"{"subject":"Plan work"}"#),
        text_stream("created"),
    ];
    script.extend((1..=4).map(|round| text_stream(&format!("round {round}"))));

    let provider = ScriptedProvider::new(BuiltinProvider::Anthropic, vec![model.clone()], script);
    // Keep a handle so the recorded requests can be inspected after the runtime takes the provider.
    let provider_handle = provider.clone();
    let runtime = Runtime::builder()
        .with_provider_instance(provider)
        .build()
        .expect("build runtime");

    let config = AgentConfig {
        system: Some(BASE_SYSTEM.to_string()),
        task: TaskConfig {
            tasks_dir,
            reminder_threshold: 3,
        },
        ..Default::default()
    };
    let mut agent = runtime
        .spawn_with_config("agent", model, config)
        .expect("spawn agent");

    let create_prompt = vec![ContentBlock::Text {
        text: String::from("set task"),
    }];
    agent.send(create_prompt).await.expect("create task");

    for round in 1..=4 {
        let round_prompt = vec![ContentBlock::Text {
            text: format!("round {round}"),
        }];
        agent.send(round_prompt).await.expect("send round");
    }

    let requests = provider_handle.recorded_requests().await;
    assert_eq!(requests.len(), 6);

    // Before the threshold is reached the system prompt is passed through untouched.
    for index in [0usize, 3] {
        assert_eq!(requests[index].system.as_deref(), Some(BASE_SYSTEM));
    }

    // Afterwards the reminder text is placed in front of the base prompt.
    let expected_system = format!("{TASK_REMINDER_TEXT}\n\nBase system prompt");
    for index in [4usize, 5] {
        assert_eq!(
            requests[index].system.as_deref(),
            Some(expected_system.as_str())
        );
    }
}
/// Checks that a failed run leaves no task state behind.
///
/// The scripted provider first answers with a `task_create` tool call and then
/// with a stream that errors right after `MessageStarted`. The test asserts that
/// `send` fails and that the history, the snapshot tasks, and the stored tasks
/// are all empty afterwards.
#[tokio::test]
async fn task_state_rolls_back_when_run_fails() {
    let model = model_info("model", BuiltinProvider::Anthropic);
    let tasks_dir = temp_tasks_dir("rollback");
    let store = temp_store("rollback");

    let create_reply = task_tool_stream("tool-1", "task_create", r#"{"subject":"Plan work"}"#);
    // Second reply: the stream starts a message and then fails.
    let failing_reply = erroring_stream(
        vec![ProviderEvent::MessageStarted {
            id: String::from("msg-fail"),
            model: model.id.clone(),
            role: Role::Assistant,
        }],
        ProviderError::MalformedStream(String::from("boom")),
    );
    let provider = ScriptedProvider::new(
        BuiltinProvider::Anthropic,
        vec![model.clone()],
        vec![create_reply, failing_reply],
    );

    let runtime = Runtime::builder()
        .with_store(store.clone())
        .with_provider_instance(provider)
        .build()
        .expect("build runtime");
    let mut agent = runtime
        .spawn_with_config("agent", model, task_config(tasks_dir.clone()))
        .expect("spawn agent");

    let prompt = vec![ContentBlock::Text {
        text: String::from("create task"),
    }];
    let outcome = agent.send(prompt).await;
    assert!(outcome.is_err());

    // Nothing from the failed run may remain in memory...
    assert!(agent.history().is_empty());
    assert!(agent.watch_snapshot().borrow().tasks.is_empty());

    // ...or in the store.
    let persisted = store
        .load_tasks(tasks_dir.as_path())
        .expect("load rolled-back tasks");
    assert!(persisted.is_empty());
}
#[tokio::test]
async fn task_survives_auto_compaction() {
let model = model_info("model", BuiltinProvider::Anthropic);
let tasks_dir = temp_tasks_dir("compact");
let provider = ScriptedProvider::new(
BuiltinProvider::Anthropic,
vec![model.clone()],
vec![
task_tool_stream("tool-1", "task_create", r#"{"subject":"Plan work"}"#),
text_stream("created"),
text_stream("summary"),
text_stream("after compact"),
],
);
let runtime = Runtime::builder()
.with_provider_instance(provider)
.build()
.expect("build runtime");
let mut agent = runtime
.spawn_with_config(
"agent",
model,
AgentConfig {
task: TaskConfig {
tasks_dir,
reminder_threshold: 3,
},
compaction: CompactionConfig {
auto_compact_threshold_tokens: Some(500),
..CompactionConfig::default()
},
..Default::default()
},
)
.expect("spawn agent");
agent
.send(vec![ContentBlock::Text {
text: "create task".to_string(),
}])
.await
.expect("create task");
agent
.send(vec![ContentBlock::Text {
text: "trigger compact ".repeat(100),
}])
.await
.expect("trigger compact");
assert_eq!(agent.watch_snapshot().borrow().tasks.len(), 1);
assert!(agent.history().iter().any(|message| {
matches!(
message,
Message {
role: Role::User,
content,
} if matches!(content.first(), Some(ContentBlock::Text { text }) if text.contains("[Compaction summary]"))
)
}));
}
/// Builds an otherwise-default `AgentConfig` whose task settings point at
/// `tasks_dir` and use a reminder threshold of 3.
fn task_config(tasks_dir: PathBuf) -> AgentConfig {
    let task = TaskConfig {
        tasks_dir,
        reminder_threshold: 3,
    };
    AgentConfig {
        task,
        ..Default::default()
    }
}
/// Builds a successful scripted stream holding one assistant message with a
/// single tool-use block.
///
/// The message id is `msg-{tool_id}` and the model is `"model"`. The block at
/// index 0 is a tool use named `tool_name` with id `tool_id`, and `input_json`
/// is delivered as a single input-JSON delta.
fn task_tool_stream(
    tool_id: &str,
    tool_name: &str,
    input_json: &str,
) -> super::support::StreamScript {
    let message_started = ProviderEvent::MessageStarted {
        id: format!("msg-{tool_id}"),
        model: String::from("model"),
        role: Role::Assistant,
    };
    let block_started = ProviderEvent::ContentBlockStarted {
        index: 0,
        kind: ContentBlockStart::ToolUse {
            id: tool_id.to_owned(),
            name: tool_name.to_owned(),
        },
    };
    let block_delta = ProviderEvent::ContentBlockDelta {
        index: 0,
        delta: ContentBlockDelta::ToolUseInputJson(input_json.to_owned()),
    };

    let events = vec![
        message_started,
        block_started,
        block_delta,
        ProviderEvent::ContentBlockStopped { index: 0 },
        ProviderEvent::MessageStopped,
    ];
    ok_stream(events)
}
/// Builds a successful scripted stream holding one assistant message with a
/// single text block.
///
/// The message id is `msg-{text}` and the model is `"model"`. The block at
/// index 0 receives `text` as a single text delta.
fn text_stream(text: &str) -> super::support::StreamScript {
    let message_started = ProviderEvent::MessageStarted {
        id: format!("msg-{text}"),
        model: String::from("model"),
        role: Role::Assistant,
    };
    let block_started = ProviderEvent::ContentBlockStarted {
        index: 0,
        kind: ContentBlockStart::Text,
    };
    let block_delta = ProviderEvent::ContentBlockDelta {
        index: 0,
        delta: ContentBlockDelta::Text(text.to_owned()),
    };

    let events = vec![
        message_started,
        block_started,
        block_delta,
        ProviderEvent::ContentBlockStopped { index: 0 },
        ProviderEvent::MessageStopped,
    ];
    ok_stream(events)
}
/// Counter appended to temp names so that two fixtures created within the same
/// clock reading still get different names.
static NEXT_TEMP_ID: AtomicU64 = AtomicU64::new(1);

/// Creates a new directory under the system temp dir and returns its path.
///
/// The directory is named `mentra-task-runtime-{label}-{nanos}-{id}`, where
/// `nanos` is the current time since the Unix epoch in nanoseconds and `id`
/// comes from `NEXT_TEMP_ID`.
///
/// # Panics
///
/// Panics if the system clock is before the Unix epoch or the directory cannot
/// be created.
fn temp_tasks_dir(label: &str) -> PathBuf {
    let sequence = NEXT_TEMP_ID.fetch_add(1, Ordering::Relaxed);
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system time");
    let nanos = since_epoch.as_nanos();

    let dir_name = format!("mentra-task-runtime-{label}-{nanos}-{sequence}");
    let mut dir = std::env::temp_dir();
    dir.push(dir_name);

    fs::create_dir_all(&dir).expect("create temp dir");
    dir
}
/// Returns a `SqliteRuntimeStore` constructed with a fresh path in the system
/// temp dir.
///
/// The file name is `mentra-task-runtime-store-{label}-{nanos}-{id}.sqlite`,
/// where `nanos` is the current time since the Unix epoch in nanoseconds and
/// `id` comes from `NEXT_TEMP_ID`.
///
/// # Panics
///
/// Panics if the system clock is before the Unix epoch.
fn temp_store(label: &str) -> SqliteRuntimeStore {
    let sequence = NEXT_TEMP_ID.fetch_add(1, Ordering::Relaxed);
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system time");
    let nanos = since_epoch.as_nanos();

    let file_name = format!("mentra-task-runtime-store-{label}-{nanos}-{sequence}.sqlite");
    let db_path = std::env::temp_dir().join(file_name);
    SqliteRuntimeStore::new(db_path)
}