1use crate::schema::{Artifact, Message};
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use thiserror::Error;
9
10#[derive(Debug, Error)]
12pub enum TaskError {
13 #[error("Invalid state transition from {from} to {to}")]
15 InvalidStateTransition { from: TaskStatus, to: TaskStatus },
16}
17
/// Identifier of a [`Task`]; a plain owned string (UUID v4 text when generated by `Task::new`).
pub type TaskId = std::string::String;
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct Task {
27 pub id: TaskId,
29 pub thread_id: String,
31 pub status: TaskStatus,
33 pub message: Message,
35 #[serde(default)]
37 pub messages: Vec<Message>,
38 #[serde(default)]
40 pub artifacts: Vec<Artifact>,
41 pub created_at: DateTime<Utc>,
43 pub updated_at: DateTime<Utc>,
45 #[serde(skip_serializing_if = "Option::is_none")]
47 pub error: Option<String>,
48 #[serde(skip_serializing_if = "Option::is_none")]
50 pub metadata: Option<serde_json::Value>,
51}
52
53impl Task {
54 pub fn new(thread_id: impl Into<String>, message: Message) -> Self {
58 let now = Utc::now();
59 Self {
60 id: uuid::Uuid::new_v4().to_string(),
61 thread_id: thread_id.into(),
62 status: TaskStatus::Pending,
63 message: message.clone(),
64 messages: vec![message], artifacts: Vec::new(),
66 created_at: now,
67 updated_at: now,
68 error: None,
69 metadata: None,
70 }
71 }
72
73 pub fn with_id(id: impl Into<String>, thread_id: impl Into<String>, message: Message) -> Self {
77 let now = Utc::now();
78 Self {
79 id: id.into(),
80 thread_id: thread_id.into(),
81 status: TaskStatus::Pending,
82 message: message.clone(),
83 messages: vec![message], artifacts: Vec::new(),
85 created_at: now,
86 updated_at: now,
87 error: None,
88 metadata: None,
89 }
90 }
91
92 pub fn is_pending(&self) -> bool {
94 self.status == TaskStatus::Pending
95 }
96
97 pub fn is_running(&self) -> bool {
99 self.status == TaskStatus::Running
100 }
101
102 pub fn is_completed(&self) -> bool {
104 matches!(
105 self.status,
106 TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
107 )
108 }
109
110 pub fn is_success(&self) -> bool {
112 self.status == TaskStatus::Completed
113 }
114
115 pub fn is_failed(&self) -> bool {
117 self.status == TaskStatus::Failed
118 }
119
120 pub fn start(&mut self) -> Result<(), TaskError> {
126 if self.status != TaskStatus::Pending {
127 return Err(TaskError::InvalidStateTransition {
128 from: self.status,
129 to: TaskStatus::Running,
130 });
131 }
132 self.status = TaskStatus::Running;
133 self.updated_at = Utc::now();
134 Ok(())
135 }
136
137 pub fn complete(&mut self) -> Result<(), TaskError> {
143 if self.status != TaskStatus::Running {
144 return Err(TaskError::InvalidStateTransition {
145 from: self.status,
146 to: TaskStatus::Completed,
147 });
148 }
149 self.status = TaskStatus::Completed;
150 self.updated_at = Utc::now();
151 Ok(())
152 }
153
154 pub fn fail(&mut self, error: impl Into<String>) -> Result<(), TaskError> {
160 if self.status != TaskStatus::Running {
161 return Err(TaskError::InvalidStateTransition {
162 from: self.status,
163 to: TaskStatus::Failed,
164 });
165 }
166 self.status = TaskStatus::Failed;
167 self.error = Some(error.into());
168 self.updated_at = Utc::now();
169 Ok(())
170 }
171
172 pub fn cancel(&mut self) -> Result<(), TaskError> {
180 if self.is_completed() {
181 return Err(TaskError::InvalidStateTransition {
182 from: self.status,
183 to: TaskStatus::Cancelled,
184 });
185 }
186 self.status = TaskStatus::Cancelled;
187 self.updated_at = Utc::now();
188 Ok(())
189 }
190
191 pub fn force_status(&mut self, status: TaskStatus) {
195 self.status = status;
196 self.updated_at = Utc::now();
197 }
198
199 pub fn add_message(&mut self, message: Message) {
201 self.messages.push(message);
202 self.updated_at = Utc::now();
203 }
204
205 pub fn add_artifact(&mut self, artifact: Artifact) {
207 self.artifacts.push(artifact);
208 self.updated_at = Utc::now();
209 }
210
211 pub fn set_metadata(&mut self, metadata: serde_json::Value) {
213 self.metadata = Some(metadata);
214 self.updated_at = Utc::now();
215 }
216}
217
218#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
220#[serde(rename_all = "lowercase")]
221pub enum TaskStatus {
222 #[default]
224 Pending,
225 Running,
227 Completed,
229 Failed,
231 Cancelled,
233}
234
235impl std::fmt::Display for TaskStatus {
236 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237 match self {
238 TaskStatus::Pending => write!(f, "pending"),
239 TaskStatus::Running => write!(f, "running"),
240 TaskStatus::Completed => write!(f, "completed"),
241 TaskStatus::Failed => write!(f, "failed"),
242 TaskStatus::Cancelled => write!(f, "cancelled"),
243 }
244 }
245}
246
247#[derive(Debug, Clone, Serialize, Deserialize)]
249pub struct TaskResult {
250 pub task_id: TaskId,
252 pub status: TaskStatus,
254 pub messages: Vec<Message>,
256 pub artifacts: Vec<Artifact>,
258 #[serde(skip_serializing_if = "Option::is_none")]
260 pub error: Option<String>,
261 #[serde(skip_serializing_if = "Option::is_none")]
263 pub duration_ms: Option<u64>,
264}
265
266impl TaskResult {
267 pub fn success(task_id: impl Into<String>, messages: Vec<Message>) -> Self {
269 Self {
270 task_id: task_id.into(),
271 status: TaskStatus::Completed,
272 messages,
273 artifacts: Vec::new(),
274 error: None,
275 duration_ms: None,
276 }
277 }
278
279 pub fn failure(task_id: impl Into<String>, error: impl Into<String>) -> Self {
281 Self {
282 task_id: task_id.into(),
283 status: TaskStatus::Failed,
284 messages: Vec::new(),
285 artifacts: Vec::new(),
286 error: Some(error.into()),
287 duration_ms: None,
288 }
289 }
290
291 pub fn with_artifacts(mut self, artifacts: Vec<Artifact>) -> Self {
293 self.artifacts = artifacts;
294 self
295 }
296
297 pub fn with_duration(mut self, duration_ms: u64) -> Self {
299 self.duration_ms = Some(duration_ms);
300 self
301 }
302
303 pub fn is_success(&self) -> bool {
305 self.status == TaskStatus::Completed
306 }
307
308 pub fn is_failure(&self) -> bool {
310 self.status == TaskStatus::Failed
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::Message;

    #[test]
    fn test_task_creation() {
        let task = Task::new("thread-1", Message::user("Hello"));
        assert!(task.is_pending());
        assert!(!task.is_running());
        assert!(!task.is_completed());
    }

    #[test]
    fn test_task_includes_original_message() {
        let task = Task::new("thread-1", Message::user("Hello"));

        assert_eq!(task.messages.len(), 1);
        assert_eq!(task.messages[0].text_content(), "Hello");
        assert_eq!(task.message.text_content(), "Hello");
    }

    #[test]
    fn test_task_with_id_includes_original_message() {
        let task = Task::with_id("my-id", "thread-1", Message::user("Test"));

        assert_eq!(task.id, "my-id");
        assert_eq!(task.messages.len(), 1);
        assert_eq!(task.messages[0].text_content(), "Test");
    }

    #[test]
    fn test_new_tasks_get_distinct_ids() {
        let a = Task::new("thread-1", Message::user("Hello"));
        let b = Task::new("thread-1", Message::user("Hello"));
        assert_ne!(a.id, b.id);
        assert_eq!(a.created_at, a.updated_at);
    }

    #[test]
    fn test_task_lifecycle() {
        let mut task = Task::new("thread-1", Message::user("Hello"));

        task.start().unwrap();
        assert!(task.is_running());
        assert!(task.updated_at >= task.created_at);

        task.complete().unwrap();
        assert!(task.is_completed());
        assert!(task.is_success());
    }

    #[test]
    fn test_task_failure() {
        let mut task = Task::new("thread-1", Message::user("Hello"));

        task.start().unwrap();
        task.fail("Something went wrong").unwrap();

        assert!(task.is_failed());
        assert!(task.is_completed());
        assert_eq!(task.error, Some("Something went wrong".to_string()));
    }

    #[test]
    fn test_fail_requires_running() {
        let mut task = Task::new("thread-1", Message::user("Hello"));

        let result = task.fail("too early");
        assert!(
            matches!(
                result,
                Err(TaskError::InvalidStateTransition {
                    from: TaskStatus::Pending,
                    to: TaskStatus::Failed,
                })
            ),
            "unexpected result: {result:?}"
        );
        // A rejected transition must leave the task untouched.
        assert_eq!(task.status, TaskStatus::Pending);
        assert_eq!(task.error, None);
    }

    #[test]
    fn test_invalid_state_transition_start_from_running() {
        let mut task = Task::new("thread-1", Message::user("Hello"));
        task.start().unwrap();

        let result = task.start();
        assert!(
            matches!(
                result,
                Err(TaskError::InvalidStateTransition {
                    from: TaskStatus::Running,
                    to: TaskStatus::Running,
                })
            ),
            "Expected InvalidStateTransition error, got {result:?}"
        );
    }

    #[test]
    fn test_invalid_state_transition_complete_from_pending() {
        let mut task = Task::new("thread-1", Message::user("Hello"));

        let result = task.complete();
        assert!(result.is_err());
        assert_eq!(task.status, TaskStatus::Pending);
    }

    #[test]
    fn test_invalid_state_transition_complete_after_cancel() {
        let mut task = Task::new("thread-1", Message::user("Hello"));
        task.cancel().unwrap();

        let result = task.complete();
        assert!(result.is_err());
        assert_eq!(task.status, TaskStatus::Cancelled);
    }

    #[test]
    fn test_cancel_from_pending() {
        let mut task = Task::new("thread-1", Message::user("Hello"));
        assert!(task.cancel().is_ok());
        assert_eq!(task.status, TaskStatus::Cancelled);
    }

    #[test]
    fn test_cancel_from_running() {
        let mut task = Task::new("thread-1", Message::user("Hello"));
        task.start().unwrap();
        assert!(task.cancel().is_ok());
        assert_eq!(task.status, TaskStatus::Cancelled);
    }

    #[test]
    fn test_cannot_cancel_completed_task() {
        let mut task = Task::new("thread-1", Message::user("Hello"));
        task.start().unwrap();
        task.complete().unwrap();

        let result = task.cancel();
        assert!(result.is_err());
    }

    #[test]
    fn test_cannot_cancel_failed_task() {
        let mut task = Task::new("thread-1", Message::user("Hello"));
        task.start().unwrap();
        task.fail("boom").unwrap();

        assert!(task.cancel().is_err());
        assert_eq!(task.status, TaskStatus::Failed);
    }

    #[test]
    fn test_cannot_cancel_twice() {
        let mut task = Task::new("thread-1", Message::user("Hello"));
        task.cancel().unwrap();
        assert!(task.cancel().is_err());
    }

    #[test]
    fn test_force_status() {
        let mut task = Task::new("thread-1", Message::user("Hello"));
        task.force_status(TaskStatus::Completed);
        assert_eq!(task.status, TaskStatus::Completed);
    }

    #[test]
    fn test_task_result_success() {
        let result = TaskResult::success("task-1", vec![Message::agent("Done!")]);
        assert!(result.is_success());
        assert!(!result.is_failure());
    }

    #[test]
    fn test_task_result_failure() {
        let result = TaskResult::failure("task-1", "Error occurred");
        assert!(result.is_failure());
        assert_eq!(result.error, Some("Error occurred".to_string()));
    }

    #[test]
    fn test_task_result_builders() {
        let result = TaskResult::failure("task-1", "boom")
            .with_duration(42)
            .with_artifacts(Vec::new());

        assert_eq!(result.task_id, "task-1");
        assert_eq!(result.duration_ms, Some(42));
        assert!(result.artifacts.is_empty());
        assert!(result.messages.is_empty());
        assert!(result.is_failure());
    }

    #[test]
    fn test_status_display() {
        assert_eq!(TaskStatus::Pending.to_string(), "pending");
        assert_eq!(TaskStatus::Running.to_string(), "running");
        assert_eq!(TaskStatus::Completed.to_string(), "completed");
        assert_eq!(TaskStatus::Failed.to_string(), "failed");
        assert_eq!(TaskStatus::Cancelled.to_string(), "cancelled");
    }

    #[test]
    fn test_status_serde_matches_display() {
        // The serde (`rename_all = "lowercase"`) and Display spellings must agree.
        for status in [
            TaskStatus::Pending,
            TaskStatus::Running,
            TaskStatus::Completed,
            TaskStatus::Failed,
            TaskStatus::Cancelled,
        ] {
            let json = serde_json::to_string(&status).unwrap();
            assert_eq!(json, format!("\"{status}\""));
            let back: TaskStatus = serde_json::from_str(&json).unwrap();
            assert_eq!(back, status);
        }
    }
}