use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::Arc;
use crate::error::Result;
use crate::typed_id::SessionId;
#[cfg(feature = "openapi")]
use utoipa::ToSchema;
/// Progress payload reported for a task. This is a type alias, not a new type: task
/// progress is exactly `crate::background::BackgroundProgress`.
pub type TaskProgress = crate::background::BackgroundProgress;
// Named values for `SessionTask::kind`. The field itself is a free-form `String`;
// `find_task_executor` compares it against `TaskExecutor::kind()` to pick an executor.
pub const TASK_KIND_SUBAGENT: &str = "subagent";
pub const TASK_KIND_EXTERNAL_AGENT: &str = "external_agent";
pub const TASK_KIND_BACKGROUND_TOOL: &str = "background_tool";
pub const TASK_KIND_MONITOR: &str = "monitor";
pub fn generate_task_id() -> String {
format!("task_{}", uuid::Uuid::now_v7().simple())
}
pub fn generate_task_message_id() -> String {
format!("tmsg_{}", uuid::Uuid::now_v7().simple())
}
/// Lifecycle state of a [`SessionTask`].
///
/// The wire form is snake_case (`"queued"`, `"running"`, `"awaiting_input"`,
/// `"succeeded"`, `"failed"`, `"canceled"`). `Display` emits the same names and
/// [`SessionTaskState::parse`] accepts them. The last three variants are terminal
/// (see [`SessionTaskState::is_terminal`]), and [`apply_task_update`] never moves a task
/// out of a terminal state.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum SessionTaskState {
/// Not started yet. This is the serde default for [`CreateSessionTask::state`].
Queued,
/// In progress.
Running,
/// Paused; recording an input request on a non-terminal task forces this state.
AwaitingInput,
/// Terminal: completed successfully.
Succeeded,
/// Terminal: completed unsuccessfully.
Failed,
/// Terminal: canceled.
Canceled,
}
impl SessionTaskState {
    /// Returns `true` for the final states `Succeeded`, `Failed`, and `Canceled`.
    ///
    /// Written as an exhaustive match so that adding a variant forces a decision here.
    pub fn is_terminal(&self) -> bool {
        match self {
            Self::Queued | Self::Running | Self::AwaitingInput => false,
            Self::Succeeded | Self::Failed | Self::Canceled => true,
        }
    }

    /// Parses the exact, case-sensitive snake_case name of a state (the same names that
    /// `Display` produces). Returns `None` for any other input.
    pub fn parse(s: &str) -> Option<Self> {
        let parsed = match s {
            "queued" => Self::Queued,
            "running" => Self::Running,
            "awaiting_input" => Self::AwaitingInput,
            "succeeded" => Self::Succeeded,
            "failed" => Self::Failed,
            "canceled" => Self::Canceled,
            _ => return None,
        };
        Some(parsed)
    }
}
impl std::fmt::Display for SessionTaskState {
    /// Writes the snake_case name of the state, which matches its serde representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Queued => "queued",
            Self::Running => "running",
            Self::AwaitingInput => "awaiting_input",
            Self::Succeeded => "succeeded",
            Self::Failed => "failed",
            Self::Canceled => "canceled",
        })
    }
}
/// Lenient conversion from a state name.
///
/// Names recognized by [`SessionTaskState::parse`] map exactly as `parse` maps them. Any
/// other input, including the empty string, falls back to [`SessionTaskState::Queued`].
/// Call `parse` directly when an unknown name must be detected rather than defaulted.
impl From<&str> for SessionTaskState {
    fn from(s: &str) -> Self {
        // Delegate to `parse` so the name table exists in exactly one place and the two
        // conversions cannot drift apart when a state is added or renamed.
        Self::parse(s).unwrap_or(Self::Queued)
    }
}
/// Wake policy attached to a task. It is serialized as `"silent"`, `"on_terminal"`, or
/// `"on_activity"`, and defaults to `Silent`.
// NOTE(review): nothing in this module reads the policy. Judging by the names, it
// presumably decides whether the owning session is woken never, when the task reaches a
// terminal state, or on any task activity. Confirm against the consumer.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum TaskWakePolicy {
#[default]
Silent,
OnTerminal,
OnActivity,
}
/// Input requested by a task, stored in [`SessionTask::input_request`].
///
/// Applying one through [`apply_task_update`] moves a non-terminal task to
/// [`SessionTaskState::AwaitingInput`]. Terminal tasks ignore it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct TaskInputRequest {
/// Identifier of this request (e.g. `"req_1"`).
pub id: String,
/// Prompt text for the request (e.g. `"Approve?"`).
pub prompt: String,
/// Optional free-form JSON, omitted from output when `None`.
// NOTE(review): presumably describes the expected answer; its shape is not defined here.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[cfg_attr(feature = "openapi", schema(value_type = Object))]
pub expected: Option<Value>,
}
/// Error details recorded on a task: a short category in `kind` (e.g. `"orphaned"`) plus
/// a human-readable `message`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct TaskError {
/// Short error category.
pub kind: String,
/// Error description.
pub message: String,
}
/// An artifact produced by a task.
///
/// `artifact_type` is serialized under the key `"type"`. `path` and `url` are optional
/// and omitted from output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct TaskArtifact {
/// Artifact name.
pub name: String,
/// Artifact type label; its wire name is `"type"`.
#[serde(rename = "type")]
pub artifact_type: String,
/// Path of the artifact, if it has one.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub path: Option<String>,
/// URL of the artifact, if it has one.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub url: Option<String>,
}
/// Cross-references from a task to related entities.
///
/// Every field is optional and omitted from output when unset or empty.
/// [`apply_task_update`] merges incoming links into the existing ones rather than
/// replacing them.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct TaskLinks {
/// Related child session, if any (presumably the session a subagent task runs in).
#[serde(default, skip_serializing_if = "Option::is_none")]
#[cfg_attr(feature = "openapi", schema(value_type = Option<String>))]
pub child_session_id: Option<SessionId>,
/// Identifier of a corresponding task on another system, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub remote_task_id: Option<String>,
/// Related resource ids. The merge in `apply_task_update` never appends a duplicate.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub resource_ids: Vec<String>,
}
impl TaskLinks {
    /// Returns `true` when no link of any kind is set. `SessionTask::links` uses this as
    /// its serde `skip_serializing_if` predicate.
    pub fn is_empty(&self) -> bool {
        // Exhaustive destructuring: a new field will not compile until it is considered here.
        let Self {
            child_session_id,
            remote_task_id,
            resource_ids,
        } = self;
        resource_ids.is_empty() && remote_task_id.is_none() && child_session_id.is_none()
    }
}
/// A unit of work tracked against a session, dispatched by its `kind` (for example, one
/// of the `TASK_KIND_*` constants).
///
/// Built with [`new_session_task`] and mutated through [`apply_task_update`]. Optional
/// fields are omitted from serialized output when unset.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct SessionTask {
/// Task id. Generated as `task_<uuid>` by [`generate_task_id`] unless supplied at creation.
pub id: String,
/// Session the task belongs to.
#[cfg_attr(feature = "openapi", schema(value_type = String))]
pub session_id: SessionId,
/// Executor kind; [`find_task_executor`] matches it against [`TaskExecutor::kind`].
pub kind: String,
/// Human-readable name.
pub display_name: String,
/// Kind-specific parameters as free-form JSON; `null` when absent on deserialization.
#[serde(default)]
#[cfg_attr(feature = "openapi", schema(value_type = Object))]
pub spec: Value,
/// Current lifecycle state.
pub state: SessionTaskState,
/// Optional free-text detail accompanying the state.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub state_detail: Option<String>,
/// Latest progress snapshot, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub progress: Option<TaskProgress>,
/// Pending input request. `apply_task_update` clears it whenever the task moves to a
/// state other than `AwaitingInput`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub input_request: Option<TaskInputRequest>,
/// When cancellation was requested, if it was.
// NOTE(review): not written anywhere in this module; presumably set by
// `SessionTaskRegistry::request_cancel` implementations.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub cancel_requested_at: Option<DateTime<Utc>>,
/// Short result summary. It can still be filled in after the task is terminal.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub summary: Option<String>,
/// Path of the full result (see [`task_result_path`] for the conventional location).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub result_path: Option<String>,
/// Artifacts produced so far.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub artifacts: Vec<TaskArtifact>,
/// Error details, if any were recorded.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error: Option<TaskError>,
/// Attempt counter. It is 1 for a new task, and also when absent on deserialization.
#[serde(default = "default_attempt")]
pub attempt: i32,
/// Identifier of the worker reported for the task, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub worker_id: Option<String>,
/// Last reported heartbeat time, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub heartbeat_at: Option<DateTime<Utc>>,
/// Links to related entities; omitted from output when [`TaskLinks::is_empty`].
#[serde(default, skip_serializing_if = "TaskLinks::is_empty")]
pub links: TaskLinks,
/// Wake policy; `Silent` when absent on deserialization.
#[serde(default)]
pub wake_policy: TaskWakePolicy,
/// Creation time.
pub created_at: DateTime<Utc>,
/// Set when the task first leaves `Queued`, or at creation if created in another state.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub started_at: Option<DateTime<Utc>>,
/// Set when the task first enters a terminal state.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub finished_at: Option<DateTime<Utc>>,
/// Time of the most recently applied update.
pub updated_at: DateTime<Utc>,
}
/// Serde default for `SessionTask::attempt`: a task without a recorded attempt is on
/// its first one.
fn default_attempt() -> i32 {
    const FIRST_ATTEMPT: i32 = 1;
    FIRST_ATTEMPT
}
/// Input for creating a task; [`new_session_task`] turns it into a [`SessionTask`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateSessionTask {
/// Session the new task belongs to.
pub session_id: SessionId,
/// Explicit task id. When `None`, `new_session_task` generates one.
#[serde(default)]
pub id: Option<String>,
/// Executor kind (see the `TASK_KIND_*` constants).
pub kind: String,
/// Human-readable name.
pub display_name: String,
/// Kind-specific parameters; `null` when absent on deserialization.
#[serde(default)]
pub spec: Value,
/// Initial state; `Queued` when absent on deserialization. Any non-queued initial
/// state stamps `started_at`, and a terminal one also stamps `finished_at`.
#[serde(default = "default_queued")]
pub state: SessionTaskState,
/// Initial links.
#[serde(default)]
pub links: TaskLinks,
/// Wake policy; `Silent` when absent on deserialization.
#[serde(default)]
pub wake_policy: TaskWakePolicy,
}
/// Serde default for `CreateSessionTask::state`: new tasks start out queued.
fn default_queued() -> SessionTaskState {
    const INITIAL_STATE: SessionTaskState = SessionTaskState::Queued;
    INITIAL_STATE
}
/// Partial update that [`apply_task_update`] merges into a [`SessionTask`].
///
/// A `None` field leaves the corresponding task field untouched. Fields with special
/// rules are noted individually.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SessionTaskUpdate {
/// New state. On a terminal task, a *different* state causes the whole update to be
/// dropped. On a non-terminal task, `input_request` overrides it.
pub state: Option<SessionTaskState>,
/// Replaces `state_detail`.
pub state_detail: Option<String>,
/// Replaces `progress`.
pub progress: Option<TaskProgress>,
/// Stores a pending input request and forces `AwaitingInput` (non-terminal tasks only).
pub input_request: Option<TaskInputRequest>,
/// Replaces `summary`.
pub summary: Option<String>,
/// Replaces `result_path`.
pub result_path: Option<String>,
/// Replaces the *entire* artifact list; it does not append.
pub artifacts: Option<Vec<TaskArtifact>>,
/// Replaces `error`.
pub error: Option<TaskError>,
/// Merged into the existing links: ids overwrite only when set, resource ids are appended.
pub links: Option<TaskLinks>,
/// Replaces `worker_id`.
pub worker_id: Option<String>,
/// Replaces `heartbeat_at`.
pub heartbeat_at: Option<DateTime<Utc>>,
}
/// Optional criteria for [`SessionTaskRegistry::list`].
// NOTE(review): matching is implemented by registry backends, not here. Presumably a
// `None` field places no constraint and all set fields must match; confirm.
#[derive(Debug, Clone, Default)]
pub struct SessionTaskFilter {
/// Restrict results to this task kind.
pub kind: Option<String>,
/// Restrict results to this state.
pub state: Option<SessionTaskState>,
}
pub fn apply_task_update(task: &mut SessionTask, update: SessionTaskUpdate, now: DateTime<Utc>) {
let was_terminal = task.state.is_terminal();
if was_terminal
&& let Some(state) = update.state
&& state != task.state
{
return;
}
let mut next_state = update.state;
if update.input_request.is_some() && !was_terminal {
next_state = Some(SessionTaskState::AwaitingInput);
}
if let Some(input_request) = update.input_request
&& !was_terminal
{
task.input_request = Some(input_request);
}
if let Some(state) = next_state
&& !was_terminal
&& task.state != state
{
if task.state == SessionTaskState::Queued && state != SessionTaskState::Queued {
task.started_at.get_or_insert(now);
}
if state.is_terminal() {
task.finished_at.get_or_insert(now);
}
if state != SessionTaskState::AwaitingInput {
task.input_request = None;
}
task.state = state;
}
if let Some(detail) = update.state_detail {
task.state_detail = Some(detail);
}
if let Some(progress) = update.progress {
task.progress = Some(progress);
}
if let Some(summary) = update.summary {
task.summary = Some(summary);
}
if let Some(result_path) = update.result_path {
task.result_path = Some(result_path);
}
if let Some(artifacts) = update.artifacts {
task.artifacts = artifacts;
}
if let Some(error) = update.error {
task.error = Some(error);
}
if let Some(links) = update.links {
if links.child_session_id.is_some() {
task.links.child_session_id = links.child_session_id;
}
if links.remote_task_id.is_some() {
task.links.remote_task_id = links.remote_task_id;
}
for id in links.resource_ids {
if !task.links.resource_ids.contains(&id) {
task.links.resource_ids.push(id);
}
}
}
if let Some(worker_id) = update.worker_id {
task.worker_id = Some(worker_id);
}
if let Some(heartbeat_at) = update.heartbeat_at {
task.heartbeat_at = Some(heartbeat_at);
}
task.updated_at = now;
}
pub fn new_session_task(input: CreateSessionTask, now: DateTime<Utc>) -> SessionTask {
let state = input.state;
SessionTask {
id: input.id.unwrap_or_else(generate_task_id),
session_id: input.session_id,
kind: input.kind,
display_name: input.display_name,
spec: input.spec,
state,
state_detail: None,
progress: None,
input_request: None,
cancel_requested_at: None,
summary: None,
result_path: None,
artifacts: Vec::new(),
error: None,
attempt: 1,
worker_id: None,
heartbeat_at: None,
links: input.links,
wake_policy: input.wake_policy,
created_at: now,
started_at: if state == SessionTaskState::Queued {
None
} else {
Some(now)
},
finished_at: if state.is_terminal() { Some(now) } else { None },
updated_at: now,
}
}
/// Direction of a task message, serialized as `"inbound"` or `"outbound"`.
// NOTE(review): presumably relative to the task. `TaskExecutor::deliver` handles
// "inbound messages", so inbound = to the task and outbound = from the task. Confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum TaskMessageDirection {
Inbound,
Outbound,
}
impl std::fmt::Display for TaskMessageDirection {
    /// Writes the lowercase name, matching the serde representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Inbound => "inbound",
            Self::Outbound => "outbound",
        };
        f.write_str(name)
    }
}
impl From<&str> for TaskMessageDirection {
    /// Lenient conversion: only the exact string `"outbound"` yields `Outbound`. Every
    /// other input, including unknown values, is treated as `Inbound`.
    fn from(s: &str) -> Self {
        if s == "outbound" {
            Self::Outbound
        } else {
            Self::Inbound
        }
    }
}
/// One piece of a task message's content.
///
/// Serialized internally tagged on `"type"`: either `{"type":"text","text":...}` or
/// `{"type":"data","data":...}`. [`task_message_text`] renders only the text parts.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TaskMessagePart {
/// Plain text.
Text {
text: String,
},
/// Arbitrary JSON payload.
Data {
#[cfg_attr(feature = "openapi", schema(value_type = Object))]
data: Value,
},
}
impl TaskMessagePart {
    /// Convenience constructor for a [`TaskMessagePart::Text`] part.
    pub fn text(text: impl Into<String>) -> Self {
        let text: String = text.into();
        Self::Text { text }
    }
}
/// A message stored in a task's history, as returned by
/// [`SessionTaskRegistry::record_message`] and [`SessionTaskRegistry::list_messages`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct TaskMessage {
/// Message id (presumably produced by [`generate_task_message_id`]; confirm per backend).
pub id: String,
/// Id of the task the message belongs to.
pub task_id: String,
/// Direction of the message relative to the task.
pub direction: TaskMessageDirection,
/// Message body as an ordered list of parts.
pub content: Vec<TaskMessagePart>,
/// Id of whatever this message replies to, if anything; omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub in_reply_to: Option<String>,
/// Time the message was recorded.
pub created_at: DateTime<Utc>,
}
/// Message payload submitted to [`SessionTaskRegistry::record_message`].
///
/// It carries only the caller-supplied parts of a [`TaskMessage`]. The `id`, `task_id`,
/// and `created_at` fields are not part of the payload and are presumably assigned by
/// the registry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NewTaskMessage {
/// Direction of the message relative to the task.
pub direction: TaskMessageDirection,
/// Message body as an ordered list of parts.
pub content: Vec<TaskMessagePart>,
/// Optional id this message replies to.
#[serde(default)]
pub in_reply_to: Option<String>,
}
impl NewTaskMessage {
pub fn inbound_text(text: impl Into<String>) -> Self {
Self {
direction: TaskMessageDirection::Inbound,
content: vec![TaskMessagePart::text(text)],
in_reply_to: None,
}
}
pub fn outbound_text(text: impl Into<String>) -> Self {
Self {
direction: TaskMessageDirection::Outbound,
content: vec![TaskMessagePart::text(text)],
in_reply_to: None,
}
}
}
/// Renders the text parts of a message, in order, joined by `'\n'`.
///
/// `Data` parts are skipped entirely and leave no blank line. Empty text parts still
/// count as entries, so they do contribute a separator. Empty `content` yields `""`.
pub fn task_message_text(content: &[TaskMessagePart]) -> String {
    let mut rendered = String::new();
    let mut separator = "";
    for part in content {
        match part {
            TaskMessagePart::Text { text } => {
                rendered.push_str(separator);
                rendered.push_str(text);
                separator = "\n";
            }
            TaskMessagePart::Data { .. } => {}
        }
    }
    rendered
}
/// Storage for session tasks and their message histories. Every operation is keyed by
/// a session id plus a task id.
// NOTE(review): the `Option` results presumably use `None` for "no such task in this
// session"; `RegistryTaskSink::artifact` treats `get` returning `None` that way.
#[async_trait]
pub trait SessionTaskRegistry: Send + Sync {
/// Creates a task from `input` and returns it.
async fn create(&self, input: CreateSessionTask) -> Result<SessionTask>;
/// Applies a partial update to one task and returns the updated task.
// NOTE(review): presumably merged with `apply_task_update` semantics; confirm per backend.
async fn update(
&self,
session_id: SessionId,
task_id: &str,
update: SessionTaskUpdate,
) -> Result<Option<SessionTask>>;
/// Fetches one task.
async fn get(&self, session_id: SessionId, task_id: &str) -> Result<Option<SessionTask>>;
/// Lists a session's tasks, optionally narrowed by `filter`.
async fn list(
&self,
session_id: SessionId,
filter: Option<&SessionTaskFilter>,
) -> Result<Vec<SessionTask>>;
/// Records a cancellation request for a task and returns the task.
// NOTE(review): presumably sets `SessionTask::cancel_requested_at`, which nothing in
// this module writes.
async fn request_cancel(
&self,
session_id: SessionId,
task_id: &str,
) -> Result<Option<SessionTask>>;
/// Appends `message` to a task's history and returns the stored [`TaskMessage`].
async fn record_message(
&self,
session_id: SessionId,
task_id: &str,
message: NewTaskMessage,
) -> Result<TaskMessage>;
/// Lists a task's messages; `limit` presumably caps how many are returned.
async fn list_messages(
&self,
session_id: SessionId,
task_id: &str,
limit: Option<u32>,
) -> Result<Vec<TaskMessage>>;
}
#[async_trait]
pub trait TaskExecutor: Send + Sync {
fn kind(&self) -> &str;
async fn start(&self, task: &SessionTask, context: &crate::traits::ToolContext) -> Result<()> {
let _ = (task, context);
Err(crate::error::AgentLoopError::tool(format!(
"task kind '{}' does not support start via the registry",
self.kind()
)))
}
async fn deliver(
&self,
task: &SessionTask,
message: &TaskMessage,
context: &crate::traits::ToolContext,
) -> Result<()> {
let _ = (task, message, context);
Err(crate::error::AgentLoopError::tool(format!(
"task kind '{}' does not accept inbound messages",
self.kind()
)))
}
async fn cancel(&self, task: &SessionTask, context: &crate::traits::ToolContext) -> Result<()>;
async fn reconcile(
&self,
task: &SessionTask,
context: &crate::traits::ToolContext,
) -> Result<()> {
let _ = (task, context);
Ok(())
}
}
/// Registration record for a [`TaskExecutor`], collected with the `inventory` crate.
///
/// [`find_task_executor`] calls `executor` for each registered plugin while searching.
/// Constructors therefore run on every lookup and should be cheap.
pub struct TaskExecutorPlugin {
/// Constructor for the executor instance.
pub executor: fn() -> Arc<dyn TaskExecutor>,
}
// Declares `TaskExecutorPlugin` as an inventory-collected type. Entries are registered
// elsewhere (presumably with `inventory::submit!`) and enumerated via `inventory::iter`.
inventory::collect!(TaskExecutorPlugin);
/// Returns a freshly constructed executor whose [`TaskExecutor::kind`] equals `kind`, or
/// `None` when no registered plugin handles it.
///
/// Plugins are tried in `inventory` iteration order, and the first match wins.
/// Constructors of the plugins tried before the match are invoked too.
pub fn find_task_executor(kind: &str) -> Option<Arc<dyn TaskExecutor>> {
    for plugin in inventory::iter::<TaskExecutorPlugin> {
        let candidate = (plugin.executor)();
        if candidate.kind() == kind {
            return Some(candidate);
        }
    }
    None
}
/// Reporting channel through which task-side code publishes state, progress, output,
/// messages, input requests, and artifacts for a single task (the methods take no task
/// id).
///
/// [`RegistryTaskSink`] is the implementation in this module; it persists reports
/// through a [`SessionTaskRegistry`].
#[async_trait]
pub trait TaskSink: Send + Sync {
/// Reports a new state, with optional free-text detail.
async fn state(&self, state: SessionTaskState, detail: Option<String>) -> Result<()>;
/// Reports the latest progress snapshot.
async fn progress(&self, progress: TaskProgress) -> Result<()>;
/// Emits a chunk of output (`delta`) on the stream named `stream`.
async fn output(&self, stream: &str, delta: &str) -> Result<()>;
/// Posts a message to the task's message history.
async fn post(&self, message: NewTaskMessage) -> Result<()>;
/// Asks for input.
async fn request_input(&self, request: TaskInputRequest) -> Result<()>;
/// Records an output artifact.
async fn artifact(&self, artifact: TaskArtifact) -> Result<()>;
}
/// [`TaskSink`] that writes every report through a [`SessionTaskRegistry`]. It is bound
/// to one task, identified by `session_id` plus `task_id`.
pub struct RegistryTaskSink {
registry: Arc<dyn SessionTaskRegistry>,
session_id: SessionId,
task_id: String,
}
impl RegistryTaskSink {
    /// Creates a sink that reports into `registry` on behalf of task `task_id` in session
    /// `session_id`.
    pub fn new(
        registry: Arc<dyn SessionTaskRegistry>,
        session_id: SessionId,
        task_id: String,
    ) -> Self {
        Self {
            task_id,
            session_id,
            registry,
        }
    }
}
#[async_trait]
impl TaskSink for RegistryTaskSink {
async fn state(&self, state: SessionTaskState, detail: Option<String>) -> Result<()> {
self.registry
.update(
self.session_id,
&self.task_id,
SessionTaskUpdate {
state: Some(state),
state_detail: detail,
..Default::default()
},
)
.await?;
Ok(())
}
async fn progress(&self, progress: TaskProgress) -> Result<()> {
self.registry
.update(
self.session_id,
&self.task_id,
SessionTaskUpdate {
progress: Some(progress),
..Default::default()
},
)
.await?;
Ok(())
}
async fn output(&self, _stream: &str, _delta: &str) -> Result<()> {
Ok(())
}
async fn post(&self, message: NewTaskMessage) -> Result<()> {
self.registry
.record_message(self.session_id, &self.task_id, message)
.await?;
Ok(())
}
async fn request_input(&self, request: TaskInputRequest) -> Result<()> {
self.registry
.update(
self.session_id,
&self.task_id,
SessionTaskUpdate {
input_request: Some(request),
..Default::default()
},
)
.await?;
Ok(())
}
async fn artifact(&self, artifact: TaskArtifact) -> Result<()> {
let Some(task) = self.registry.get(self.session_id, &self.task_id).await? else {
return Ok(());
};
let mut artifacts = task.artifacts;
artifacts.push(artifact);
self.registry
.update(
self.session_id,
&self.task_id,
SessionTaskUpdate {
artifacts: Some(artifacts),
..Default::default()
},
)
.await?;
Ok(())
}
}
/// Root of the virtual-filesystem area that holds per-task files.
const TASKS_VFS_ROOT: &str = "/.tasks";

/// Directory for the files of task `task_id`: `/.tasks/<task_id>`.
// NOTE(review): `task_id` is interpolated verbatim. Ids from `generate_task_id` are safe,
// but a caller-supplied id containing `/` or `..` would escape the task's directory.
pub fn task_vfs_dir(task_id: &str) -> String {
    format!("{TASKS_VFS_ROOT}/{task_id}")
}

/// Conventional location of a task's result document: `result.json` inside
/// [`task_vfs_dir`], i.e. `/.tasks/<task_id>/result.json`.
pub fn task_result_path(task_id: &str) -> String {
    // Derived from `task_vfs_dir` so both paths always share the same root.
    format!("{}/result.json", task_vfs_dir(task_id))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh queued background-tool task with default links and a silent wake policy.
    fn task() -> SessionTask {
        new_session_task(
            CreateSessionTask {
                session_id: SessionId::new(),
                id: None,
                kind: TASK_KIND_BACKGROUND_TOOL.to_string(),
                display_name: "Test".to_string(),
                spec: serde_json::json!({}),
                state: SessionTaskState::Queued,
                links: TaskLinks::default(),
                wake_policy: TaskWakePolicy::Silent,
            },
            Utc::now(),
        )
    }

    #[test]
    fn create_generates_prefixed_id() {
        let t = task();
        assert!(t.id.starts_with("task_"));
        assert_eq!(t.state, SessionTaskState::Queued);
        assert!(t.started_at.is_none());
    }

    #[test]
    fn creating_in_terminal_state_stamps_both_timestamps() {
        let now = Utc::now();
        let t = new_session_task(
            CreateSessionTask {
                session_id: SessionId::new(),
                id: Some("task_fixed".to_string()),
                kind: TASK_KIND_MONITOR.to_string(),
                display_name: "Done".to_string(),
                spec: Value::Null,
                state: SessionTaskState::Succeeded,
                links: TaskLinks::default(),
                wake_policy: TaskWakePolicy::default(),
            },
            now,
        );
        assert_eq!(t.id, "task_fixed");
        assert_eq!(t.attempt, 1);
        assert_eq!(t.started_at, Some(now));
        assert_eq!(t.finished_at, Some(now));
    }

    #[test]
    fn first_transition_out_of_queued_stamps_started_at() {
        let mut t = task();
        let now = Utc::now();
        apply_task_update(
            &mut t,
            SessionTaskUpdate {
                state: Some(SessionTaskState::Running),
                ..Default::default()
            },
            now,
        );
        assert_eq!(t.state, SessionTaskState::Running);
        assert_eq!(t.started_at, Some(now));
        assert!(t.finished_at.is_none());
    }

    #[test]
    fn terminal_transition_stamps_finished_at_and_is_final() {
        let mut t = task();
        let now = Utc::now();
        apply_task_update(
            &mut t,
            SessionTaskUpdate {
                state: Some(SessionTaskState::Succeeded),
                summary: Some("done".to_string()),
                ..Default::default()
            },
            now,
        );
        assert_eq!(t.state, SessionTaskState::Succeeded);
        assert_eq!(t.finished_at, Some(now));
        // A conflicting terminal state is rejected along with everything else in the update.
        apply_task_update(
            &mut t,
            SessionTaskUpdate {
                state: Some(SessionTaskState::Failed),
                error: Some(TaskError {
                    kind: "orphaned".to_string(),
                    message: "stale".to_string(),
                }),
                ..Default::default()
            },
            Utc::now(),
        );
        assert_eq!(t.state, SessionTaskState::Succeeded);
        assert!(t.error.is_none());
        // Repeating the same terminal state still lets other fields through.
        apply_task_update(
            &mut t,
            SessionTaskUpdate {
                state: Some(SessionTaskState::Succeeded),
                result_path: Some("/.tasks/x/result.json".to_string()),
                ..Default::default()
            },
            Utc::now(),
        );
        assert_eq!(t.result_path.as_deref(), Some("/.tasks/x/result.json"));
        // State-less updates can enrich a finished task.
        apply_task_update(
            &mut t,
            SessionTaskUpdate {
                summary: Some("enriched".to_string()),
                ..Default::default()
            },
            Utc::now(),
        );
        assert_eq!(t.summary.as_deref(), Some("enriched"));
    }

    #[test]
    fn input_request_forces_awaiting_input_and_clears_on_resume() {
        let mut t = task();
        apply_task_update(
            &mut t,
            SessionTaskUpdate {
                input_request: Some(TaskInputRequest {
                    id: "req_1".to_string(),
                    prompt: "Approve?".to_string(),
                    expected: None,
                }),
                ..Default::default()
            },
            Utc::now(),
        );
        assert_eq!(t.state, SessionTaskState::AwaitingInput);
        assert!(t.input_request.is_some());
        apply_task_update(
            &mut t,
            SessionTaskUpdate {
                state: Some(SessionTaskState::Running),
                ..Default::default()
            },
            Utc::now(),
        );
        assert_eq!(t.state, SessionTaskState::Running);
        assert!(t.input_request.is_none());
    }

    #[test]
    fn terminal_task_ignores_input_request() {
        let mut t = task();
        apply_task_update(
            &mut t,
            SessionTaskUpdate {
                state: Some(SessionTaskState::Canceled),
                ..Default::default()
            },
            Utc::now(),
        );
        apply_task_update(
            &mut t,
            SessionTaskUpdate {
                input_request: Some(TaskInputRequest {
                    id: "req_late".to_string(),
                    prompt: "Too late?".to_string(),
                    expected: None,
                }),
                ..Default::default()
            },
            Utc::now(),
        );
        assert_eq!(t.state, SessionTaskState::Canceled);
        assert!(t.input_request.is_none());
    }

    #[test]
    fn links_merge_without_duplicates() {
        let mut t = task();
        let child = SessionId::new();
        apply_task_update(
            &mut t,
            SessionTaskUpdate {
                links: Some(TaskLinks {
                    child_session_id: Some(child),
                    remote_task_id: None,
                    resource_ids: vec!["res_1".to_string()],
                }),
                ..Default::default()
            },
            Utc::now(),
        );
        apply_task_update(
            &mut t,
            SessionTaskUpdate {
                links: Some(TaskLinks {
                    child_session_id: None,
                    remote_task_id: Some("rt_1".to_string()),
                    resource_ids: vec!["res_1".to_string(), "res_2".to_string()],
                }),
                ..Default::default()
            },
            Utc::now(),
        );
        assert_eq!(t.links.child_session_id, Some(child));
        assert_eq!(t.links.remote_task_id.as_deref(), Some("rt_1"));
        assert_eq!(t.links.resource_ids, vec!["res_1", "res_2"]);
    }

    #[test]
    fn state_names_round_trip_and_match_serde() {
        let all = [
            SessionTaskState::Queued,
            SessionTaskState::Running,
            SessionTaskState::AwaitingInput,
            SessionTaskState::Succeeded,
            SessionTaskState::Failed,
            SessionTaskState::Canceled,
        ];
        for state in all {
            let name = state.to_string();
            assert_eq!(SessionTaskState::parse(&name), Some(state));
            assert_eq!(SessionTaskState::from(name.as_str()), state);
            assert_eq!(serde_json::to_value(state).unwrap(), Value::String(name));
        }
        assert_eq!(SessionTaskState::parse("bogus"), None);
        assert_eq!(SessionTaskState::from("bogus"), SessionTaskState::Queued);
    }

    #[test]
    fn direction_names_round_trip() {
        for direction in [TaskMessageDirection::Inbound, TaskMessageDirection::Outbound] {
            assert_eq!(
                TaskMessageDirection::from(direction.to_string().as_str()),
                direction
            );
        }
        assert_eq!(
            TaskMessageDirection::from("sideways"),
            TaskMessageDirection::Inbound
        );
    }

    #[test]
    fn message_text_rendering() {
        let msg = NewTaskMessage::outbound_text("hello");
        assert_eq!(task_message_text(&msg.content), "hello");
    }

    #[test]
    fn message_text_skips_data_parts() {
        let content = vec![
            TaskMessagePart::text("a"),
            TaskMessagePart::Data {
                data: serde_json::json!({"k": 1}),
            },
            TaskMessagePart::text("b"),
        ];
        assert_eq!(task_message_text(&content), "a\nb");
        assert_eq!(task_message_text(&[]), "");
    }

    #[test]
    fn result_path_lives_under_task_dir() {
        assert_eq!(task_vfs_dir("task_1"), "/.tasks/task_1");
        assert_eq!(task_result_path("task_1"), "/.tasks/task_1/result.json");
    }
}