1use chrono::{DateTime, Utc};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use uuid::Uuid;
19
20use crate::artifact::Artifact;
21use crate::message::Message;
22
23#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
25#[serde(rename_all = "camelCase")]
26pub struct Task {
27 pub id: String,
29
30 #[serde(skip_serializing_if = "Option::is_none")]
32 pub context_id: Option<String>,
33
34 pub state: TaskState,
36
37 #[serde(default, skip_serializing_if = "Vec::is_empty")]
39 pub messages: Vec<Message>,
40
41 #[serde(default, skip_serializing_if = "Vec::is_empty")]
43 pub artifacts: Vec<Artifact>,
44
45 #[serde(default, skip_serializing_if = "Option::is_none")]
47 pub metadata: Option<serde_json::Value>,
48
49 #[serde(skip_serializing_if = "Option::is_none")]
51 pub created_at: Option<DateTime<Utc>>,
52
53 #[serde(skip_serializing_if = "Option::is_none")]
55 pub updated_at: Option<DateTime<Utc>>,
56}
57
58impl Task {
59 pub fn new() -> Self {
61 Self {
62 id: Uuid::new_v4().to_string(),
63 context_id: None,
64 state: TaskState::Submitted,
65 messages: Vec::new(),
66 artifacts: Vec::new(),
67 metadata: None,
68 created_at: Some(Utc::now()),
69 updated_at: Some(Utc::now()),
70 }
71 }
72
73 pub fn with_context(context_id: impl Into<String>) -> Self {
75 Self {
76 context_id: Some(context_id.into()),
77 ..Self::new()
78 }
79 }
80
81 pub fn is_terminal(&self) -> bool {
83 matches!(
84 self.state,
85 TaskState::Completed | TaskState::Failed | TaskState::Canceled | TaskState::Rejected
86 )
87 }
88
89 pub fn is_interrupted(&self) -> bool {
91 matches!(
92 self.state,
93 TaskState::InputRequired | TaskState::AuthRequired
94 )
95 }
96
97 pub fn transition(&mut self, new_state: TaskState) -> Result<(), InvalidTransition> {
99 if self.is_terminal() {
100 return Err(InvalidTransition {
101 from: self.state.clone(),
102 to: new_state,
103 });
104 }
105 self.state = new_state;
106 self.updated_at = Some(Utc::now());
107 Ok(())
108 }
109
110 pub fn add_message(&mut self, message: Message) {
112 self.messages.push(message);
113 self.updated_at = Some(Utc::now());
114 }
115
116 pub fn add_artifact(&mut self, artifact: Artifact) {
118 self.artifacts.push(artifact);
119 self.updated_at = Some(Utc::now());
120 }
121}
122
123impl Default for Task {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
131#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
132pub enum TaskState {
133 Submitted,
135
136 Working,
138
139 Completed,
141
142 Failed,
144
145 Canceled,
147
148 Rejected,
150
151 InputRequired,
153
154 AuthRequired,
156}
157
158impl std::fmt::Display for TaskState {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 match self {
161 TaskState::Submitted => write!(f, "SUBMITTED"),
162 TaskState::Working => write!(f, "WORKING"),
163 TaskState::Completed => write!(f, "COMPLETED"),
164 TaskState::Failed => write!(f, "FAILED"),
165 TaskState::Canceled => write!(f, "CANCELED"),
166 TaskState::Rejected => write!(f, "REJECTED"),
167 TaskState::InputRequired => write!(f, "INPUT_REQUIRED"),
168 TaskState::AuthRequired => write!(f, "AUTH_REQUIRED"),
169 }
170 }
171}
172
173#[derive(Debug)]
175pub struct InvalidTransition {
176 pub from: TaskState,
177 pub to: TaskState,
178}
179
180impl std::fmt::Display for InvalidTransition {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 write!(
183 f,
184 "invalid task transition from {} to {}",
185 self.from, self.to
186 )
187 }
188}
189
190impl std::error::Error for InvalidTransition {}
191
192#[derive(Debug, Default, Clone, Serialize, Deserialize)]
194#[serde(rename_all = "camelCase")]
195pub struct TaskQueryParams {
196 #[serde(skip_serializing_if = "Option::is_none")]
198 pub context_id: Option<String>,
199
200 #[serde(skip_serializing_if = "Option::is_none")]
202 pub state: Option<TaskState>,
203
204 #[serde(skip_serializing_if = "Option::is_none")]
206 pub limit: Option<u32>,
207
208 #[serde(skip_serializing_if = "Option::is_none")]
210 pub cursor: Option<String>,
211}
212
213#[derive(Debug, Clone, Serialize, Deserialize)]
215#[serde(rename_all = "camelCase", tag = "type")]
216pub enum TaskEvent {
217 StateChanged { task_id: String, state: TaskState },
219
220 MessageAdded { task_id: String, message: Message },
222
223 ArtifactAdded { task_id: String, artifact: Artifact },
225
226 ArtifactChunk {
228 task_id: String,
229 artifact_id: String,
230 chunk: String,
231 },
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Every state, in declaration order.
    fn all_states() -> Vec<TaskState> {
        vec![
            TaskState::Submitted,
            TaskState::Working,
            TaskState::Completed,
            TaskState::Failed,
            TaskState::Canceled,
            TaskState::Rejected,
            TaskState::InputRequired,
            TaskState::AuthRequired,
        ]
    }

    /// States from which `Task::transition` must refuse to move.
    fn terminal_states() -> Vec<TaskState> {
        vec![
            TaskState::Completed,
            TaskState::Failed,
            TaskState::Canceled,
            TaskState::Rejected,
        ]
    }

    #[test]
    fn test_task_lifecycle() {
        let mut task = Task::new();
        assert_eq!(task.state, TaskState::Submitted);
        assert!(!task.is_terminal());

        task.transition(TaskState::Working).unwrap();
        assert_eq!(task.state, TaskState::Working);

        task.transition(TaskState::InputRequired).unwrap();
        assert!(task.is_interrupted());

        task.transition(TaskState::Working).unwrap();
        task.transition(TaskState::Completed).unwrap();
        assert!(task.is_terminal());

        assert!(task.transition(TaskState::Working).is_err());
    }

    #[test]
    fn test_with_context_sets_only_context() {
        let task = Task::with_context("ctx-1");
        assert_eq!(task.context_id.as_deref(), Some("ctx-1"));
        assert_eq!(task.state, TaskState::Submitted);
        assert!(task.messages.is_empty());
        assert!(task.artifacts.is_empty());
        assert!(task.metadata.is_none());
    }

    #[test]
    fn test_terminal_states_reject_transitions() {
        for terminal in terminal_states() {
            let mut task = Task::new();
            task.state = terminal;
            assert!(task.is_terminal());
            assert!(!task.is_interrupted());

            let before = task.updated_at;
            let err = task.transition(TaskState::Working).unwrap_err();
            assert_eq!(err.from, task.state);
            assert_eq!(err.to, TaskState::Working);
            // A rejected transition must leave the task untouched.
            assert_eq!(task.updated_at, before);
        }
    }

    #[test]
    fn test_invalid_transition_message() {
        let mut task = Task::new();
        task.transition(TaskState::Rejected).unwrap();
        let err = task.transition(TaskState::Submitted).unwrap_err();
        assert_eq!(
            err.to_string(),
            "invalid task transition from REJECTED to SUBMITTED"
        );
    }

    #[test]
    fn test_display_matches_serde_names() {
        for state in all_states() {
            let json = serde_json::to_string(&state).unwrap();
            assert_eq!(json, format!("\"{}\"", state));
            let back: TaskState = serde_json::from_str(&json).unwrap();
            assert_eq!(back, state);
        }
    }

    #[test]
    fn test_task_serialization() {
        let task = Task::new();
        let json = serde_json::to_string(&task).unwrap();
        assert!(json.contains("SUBMITTED"));
        // camelCase field names; empty/absent fields are skipped.
        assert!(json.contains("\"createdAt\""));
        assert!(!json.contains("contextId"));
        assert!(!json.contains("messages"));
        assert!(!json.contains("artifacts"));
        assert!(!json.contains("metadata"));

        let parsed: Task = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.state, TaskState::Submitted);
        assert_eq!(parsed.id, task.id);
    }
}