use chrono::{DateTime, Utc};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::artifact::Artifact;
use crate::message::Message;
// A task: an identifier, its lifecycle `state`, the messages and artifacts
// attached to it, optional free-form `metadata`, and UTC timestamps.
// Serialized with camelCase keys (e.g. `contextId`, `createdAt`).
//
// Plain `//` comments are used on purpose: the `JsonSchema` derive copies `///`
// doc comments into the generated schema as descriptions.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct Task {
// Unique identifier; `Task::new` fills it with a random UUID v4 string.
pub id: String,
// Optional context this task belongs to; omitted from JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub context_id: Option<String>,
// Current lifecycle state. Prefer `Task::transition`, which rejects changes
// out of terminal states and refreshes `updated_at`.
pub state: TaskState,
// Messages attached to the task, in insertion order (`add_message` appends).
// Omitted from JSON when empty; a missing key deserializes to an empty Vec.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub messages: Vec<Message>,
// Artifacts attached to the task, in insertion order (`add_artifact` appends).
// Omitted from JSON when empty; a missing key deserializes to an empty Vec.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub artifacts: Vec<Artifact>,
// Arbitrary JSON metadata; omitted from JSON when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
// Creation time (UTC); set by `Task::new`, omitted from JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub created_at: Option<DateTime<Utc>>,
// Last-modification time (UTC); refreshed by `transition`, `add_message` and
// `add_artifact`. Omitted from JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub updated_at: Option<DateTime<Utc>>,
}
impl Task {
    /// Creates a new task in [`TaskState::Submitted`] with a random UUID v4 id,
    /// no context, no messages, artifacts or metadata, and both timestamps set
    /// to the current time.
    pub fn new() -> Self {
        // Read the clock once so a freshly created task has
        // `created_at == updated_at` (two `Utc::now()` calls could differ).
        let now = Utc::now();
        Self {
            id: Uuid::new_v4().to_string(),
            context_id: None,
            state: TaskState::Submitted,
            messages: Vec::new(),
            artifacts: Vec::new(),
            metadata: None,
            created_at: Some(now),
            updated_at: Some(now),
        }
    }

    /// Creates a new task (as [`Task::new`]) attached to `context_id`.
    pub fn with_context(context_id: impl Into<String>) -> Self {
        Self {
            context_id: Some(context_id.into()),
            ..Self::new()
        }
    }

    /// Returns `true` if the task is `Completed`, `Failed`, `Canceled` or
    /// `Rejected`; such a task accepts no further [`Task::transition`].
    #[must_use]
    pub fn is_terminal(&self) -> bool {
        matches!(
            self.state,
            TaskState::Completed | TaskState::Failed | TaskState::Canceled | TaskState::Rejected
        )
    }

    /// Returns `true` if the task is paused in `InputRequired` or `AuthRequired`.
    #[must_use]
    pub fn is_interrupted(&self) -> bool {
        matches!(
            self.state,
            TaskState::InputRequired | TaskState::AuthRequired
        )
    }

    /// Moves the task to `new_state` and refreshes `updated_at`.
    ///
    /// Every transition out of a non-terminal state is accepted, including a
    /// transition to the current state.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidTransition`] if the task is already terminal (see
    /// [`Task::is_terminal`]); the task is left unchanged in that case.
    pub fn transition(&mut self, new_state: TaskState) -> Result<(), InvalidTransition> {
        if self.is_terminal() {
            return Err(InvalidTransition {
                from: self.state.clone(),
                to: new_state,
            });
        }
        self.state = new_state;
        self.touch();
        Ok(())
    }

    /// Appends `message` to the task and refreshes `updated_at`.
    pub fn add_message(&mut self, message: Message) {
        self.messages.push(message);
        self.touch();
    }

    /// Appends `artifact` to the task and refreshes `updated_at`.
    pub fn add_artifact(&mut self, artifact: Artifact) {
        self.artifacts.push(artifact);
        self.touch();
    }

    /// Records a modification by setting `updated_at` to the current time.
    fn touch(&mut self) {
        self.updated_at = Some(Utc::now());
    }
}
impl Default for Task {
fn default() -> Self {
Self::new()
}
}
// Lifecycle state of a `Task`, serialized in SCREAMING_SNAKE_CASE
// (e.g. `InputRequired` <-> "INPUT_REQUIRED").
//
// Terminal states (a task in one of these rejects `Task::transition`):
//   Completed, Failed, Canceled, Rejected.
// Interrupted states (see `Task::is_interrupted`): InputRequired, AuthRequired.
//
// `Hash` is derived alongside `Eq` so states can key a `HashMap` / `HashSet`.
//
// Plain `//` comments are used on purpose: the `JsonSchema` derive copies `///`
// doc comments into the generated schema as descriptions.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum TaskState {
    Submitted,
    Working,
    Completed,
    Failed,
    Canceled,
    Rejected,
    InputRequired,
    AuthRequired,
}
impl TaskState {
    /// Returns the canonical SCREAMING_SNAKE_CASE name of this state, the same
    /// spelling as this type's serde representation (e.g. `"INPUT_REQUIRED"`).
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            TaskState::Submitted => "SUBMITTED",
            TaskState::Working => "WORKING",
            TaskState::Completed => "COMPLETED",
            TaskState::Failed => "FAILED",
            TaskState::Canceled => "CANCELED",
            TaskState::Rejected => "REJECTED",
            TaskState::InputRequired => "INPUT_REQUIRED",
            TaskState::AuthRequired => "AUTH_REQUIRED",
        }
    }
}

/// Formats the state as its canonical name (see [`TaskState::as_str`]).
impl std::fmt::Display for TaskState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write!` with a literal) honours width/fill/alignment
        // flags such as `{:<15}`, so states line up in tables and logs.
        f.pad(self.as_str())
    }
}
/// Error returned by `Task::transition` when the task is already in a terminal
/// state and therefore cannot move to another state.
///
/// Derives `Clone`/`PartialEq`/`Eq` so callers and tests can compare or keep
/// copies of the error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InvalidTransition {
    /// State the task was in when the transition was attempted.
    pub from: TaskState,
    /// State that was requested.
    pub to: TaskState,
}

impl std::fmt::Display for InvalidTransition {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "invalid task transition from {} to {}",
            self.from, self.to
        )
    }
}

impl std::error::Error for InvalidTransition {}
/// Optional filter and pagination parameters — presumably for listing tasks;
/// how they are applied is decided by whatever consumes this struct, not here.
///
/// Serialized with camelCase keys (`contextId`, `state`, `limit`, `cursor`);
/// `None` fields are omitted. `Default` yields all fields `None`.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TaskQueryParams {
/// Context id to match — assumed filter semantics, confirm against the handler.
#[serde(skip_serializing_if = "Option::is_none")]
pub context_id: Option<String>,
/// Task state to match — assumed filter semantics, confirm against the handler.
#[serde(skip_serializing_if = "Option::is_none")]
pub state: Option<TaskState>,
/// Presumably the maximum number of results per page; not enforced here.
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<u32>,
/// Presumably an opaque pagination cursor from a previous page — TODO confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub cursor: Option<String>,
}
/// Task-related event, serialized as an internally tagged JSON object whose
/// `type` key holds the camelCase variant name, e.g.
/// `{"type":"stateChanged","taskId":"…","state":"WORKING"}`.
///
/// A container-level `rename_all` on an enum renames only the *variants*; the
/// per-variant `rename_all` below is what makes the struct-variant fields
/// camelCase (`taskId`, `artifactId`), matching the `Task` wire format. Each
/// renamed field keeps an `alias` so the previous snake_case keys
/// (`task_id`, `artifact_id`) are still accepted when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase", tag = "type")]
pub enum TaskEvent {
    /// Carries a task id and a lifecycle `state` for that task.
    #[serde(rename_all = "camelCase")]
    StateChanged {
        #[serde(alias = "task_id")]
        task_id: String,
        state: TaskState,
    },
    /// Carries a task id and a `message` for that task.
    #[serde(rename_all = "camelCase")]
    MessageAdded {
        #[serde(alias = "task_id")]
        task_id: String,
        message: Message,
    },
    /// Carries a task id and an `artifact` for that task.
    #[serde(rename_all = "camelCase")]
    ArtifactAdded {
        #[serde(alias = "task_id")]
        task_id: String,
        artifact: Artifact,
    },
    /// Carries a piece of text (`chunk`) for the artifact `artifact_id` of a
    /// task — presumably streamed incrementally; confirm with the producer.
    #[serde(rename_all = "camelCase")]
    ArtifactChunk {
        #[serde(alias = "task_id")]
        task_id: String,
        #[serde(alias = "artifact_id")]
        artifact_id: String,
        chunk: String,
    },
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives a task Submitted -> Working -> InputRequired -> Working ->
    /// Completed, then checks that the terminal task refuses to move again.
    #[test]
    fn test_task_lifecycle() {
        let mut subject = Task::new();
        assert_eq!(subject.state, TaskState::Submitted);
        assert!(!subject.is_terminal());

        subject
            .transition(TaskState::Working)
            .expect("Submitted -> Working must be accepted");
        assert_eq!(subject.state, TaskState::Working);

        subject
            .transition(TaskState::InputRequired)
            .expect("Working -> InputRequired must be accepted");
        assert!(subject.is_interrupted());

        for step in [TaskState::Working, TaskState::Completed] {
            subject
                .transition(step)
                .expect("a non-terminal task must accept the transition");
        }
        assert!(subject.is_terminal());

        let refused = subject.transition(TaskState::Working);
        assert!(refused.is_err());
    }

    /// A new task survives a JSON round trip and its state is encoded as
    /// the SCREAMING_SNAKE_CASE name.
    #[test]
    fn test_task_serialization() {
        let original = Task::new();
        let encoded = serde_json::to_string(&original).expect("Task must serialize");
        assert!(encoded.contains("SUBMITTED"));

        let decoded = serde_json::from_str::<Task>(&encoded).expect("Task must deserialize");
        assert_eq!(decoded.state, TaskState::Submitted);
    }
}