1use super::{
4 cursor::{PaginatedResult, PaginationCursor},
5 StorageError, TaskFilter, TaskStorage, TaskUpdate,
6};
7use crate::task::{A2ATask, TaskStatus, TaskType};
8use async_trait::async_trait;
9use octocrab::{models::issues::Issue, Octocrab};
10use tracing::{debug, info};
11
12pub struct GitHubTaskStorage {
14 client: Octocrab,
15 repo_owner: String,
16 repo_name: String,
17}
18
19impl GitHubTaskStorage {
20 pub fn new(token: String, repo_owner: String, repo_name: String) -> Result<Self, StorageError> {
27 let client = Octocrab::builder()
28 .personal_token(token)
29 .build()
30 .map_err(|e| StorageError::Auth(e.to_string()))?;
31
32 Ok(Self {
33 client,
34 repo_owner,
35 repo_name,
36 })
37 }
38
39 fn issue_to_task(&self, issue: Issue) -> Result<A2ATask, StorageError> {
53 let status = issue
55 .labels
56 .iter()
57 .find_map(|label| {
58 let name = label.name.as_str();
59 name.strip_prefix("a2a:").and_then(|status| {
60 match status {
61 "submitted" => Some(TaskStatus::Submitted),
62 "working" => Some(TaskStatus::Working),
63 "completed" => Some(TaskStatus::Completed),
64 "failed" => Some(TaskStatus::Failed),
65 "cancelled" => Some(TaskStatus::Cancelled),
66 "pending" => Some(TaskStatus::Submitted),
68 "in-progress" => Some(TaskStatus::Working),
69 _ => None,
70 }
71 })
72 })
73 .unwrap_or(TaskStatus::Submitted);
74
75 let task_type = issue
77 .labels
78 .iter()
79 .find_map(|label| {
80 let name = label.name.as_str();
81 name.strip_prefix("a2a:").and_then(|task_type| match task_type {
82 "codegen" => Some(TaskType::CodeGeneration),
83 "review" => Some(TaskType::CodeReview),
84 "testing" => Some(TaskType::Testing),
85 "deployment" => Some(TaskType::Deployment),
86 "documentation" => Some(TaskType::Documentation),
87 "analysis" => Some(TaskType::Analysis),
88 _ => None,
89 })
90 })
91 .unwrap_or(TaskType::Analysis);
92
93 let agent = issue.assignee.as_ref().map(|a| a.login.clone());
95
96 let retry_count = issue
98 .labels
99 .iter()
100 .find_map(|label| {
101 label
102 .name
103 .strip_prefix("retry:")
104 .and_then(|count_str| count_str.parse::<u32>().ok())
105 })
106 .unwrap_or(0);
107
108 Ok(A2ATask {
109 id: issue.number,
110 title: issue.title,
111 description: issue.body.unwrap_or_default(),
112 status,
113 task_type,
114 agent,
115 context_id: None, priority: 3, retry_count,
118 created_at: issue.created_at,
119 updated_at: issue.updated_at,
120 issue_url: issue.html_url.to_string(),
121 })
122 }
123}
124
125#[async_trait]
126impl TaskStorage for GitHubTaskStorage {
127 async fn save_task(&self, task: A2ATask) -> Result<u64, StorageError> {
140 info!("Creating GitHub Issue for task: {}", task.title);
141
142 let issue = self
143 .client
144 .issues(&self.repo_owner, &self.repo_name)
145 .create(&task.title)
146 .body(&task.description)
147 .labels(vec![task.status.to_label(), task.task_type.to_label()])
148 .send()
149 .await?;
150
151 debug!("Created GitHub Issue #{}: {}", issue.number, issue.html_url);
152
153 Ok(issue.number)
154 }
155
156 async fn get_task(&self, id: u64) -> Result<Option<A2ATask>, StorageError> {
166 debug!("Fetching task #{} from GitHub", id);
167
168 match self.client.issues(&self.repo_owner, &self.repo_name).get(id).await {
169 Ok(issue) => {
170 let task = self.issue_to_task(issue)?;
171 Ok(Some(task))
172 },
173 Err(octocrab::Error::GitHub { source, .. }) if source.message.contains("Not Found") => {
174 debug!("Task #{} not found", id);
175 Ok(None)
176 },
177 Err(e) => Err(StorageError::from(e)),
178 }
179 }
180
181 async fn list_tasks(&self, filter: TaskFilter) -> Result<Vec<A2ATask>, StorageError> {
196 debug!("Listing tasks with filter: {:?}", filter);
197
198 let issues = self.client.issues(&self.repo_owner, &self.repo_name);
200 let per_page = filter.limit.unwrap_or(30) as u8;
201
202 let page = if let Some(status) = filter.status {
204 let label = status.to_label();
205 let labels = vec![label.clone()];
206 debug!("Applying API-level status filter: {}", label);
207
208 issues.list().labels(&labels).per_page(per_page).send().await?
209 } else {
210 issues.list().per_page(per_page).send().await?
211 };
212
213 let all_tasks: Result<Vec<A2ATask>, StorageError> =
215 page.items.into_iter().map(|issue| self.issue_to_task(issue)).collect();
216
217 let mut tasks = all_tasks?;
218
219 if let Some(ref context_id) = filter.context_id {
223 tasks.retain(|t| t.context_id.as_ref() == Some(context_id));
224 }
225
226 if let Some(ref agent) = filter.agent {
227 tasks.retain(|t| t.agent.as_ref() == Some(agent));
228 }
229
230 if let Some(after) = filter.last_updated_after {
231 tasks.retain(|t| t.updated_at > after);
232 }
233
234 Ok(tasks)
235 }
236
237 async fn list_tasks_paginated(
274 &self,
275 filter: TaskFilter,
276 ) -> Result<PaginatedResult<A2ATask>, StorageError> {
277 debug!("Listing tasks with pagination: {:?}", filter);
278
279 let limit = filter.limit.unwrap_or(50).min(100);
281
282 let cursor = filter
284 .cursor
285 .as_ref()
286 .map(|c| PaginationCursor::decode(c))
287 .transpose()
288 .map_err(|e| StorageError::Other(format!("Invalid cursor: {}", e)))?;
289
290 let issues = self.client.issues(&self.repo_owner, &self.repo_name);
292
293 let fetch_count = (limit + 1) as u8;
295
296 let page = if let Some(status) = filter.status {
298 let label = status.to_label();
299 let labels = vec![label.clone()];
300 debug!("Applying API-level status filter: {}", label);
301
302 issues.list().labels(&labels).per_page(fetch_count).send().await?
303 } else {
304 issues.list().per_page(fetch_count).send().await?
305 };
306
307 let all_tasks: Result<Vec<A2ATask>, StorageError> =
309 page.items.into_iter().map(|issue| self.issue_to_task(issue)).collect();
310
311 let mut tasks = all_tasks?;
312
313 if let Some(ref c) = cursor {
315 match c.direction {
316 super::cursor::Direction::Forward => {
317 tasks.retain(|t| {
319 t.updated_at > c.last_updated
320 || (t.updated_at == c.last_updated && t.id > c.last_id)
321 });
322 },
323 super::cursor::Direction::Backward => {
324 tasks.retain(|t| {
326 t.updated_at < c.last_updated
327 || (t.updated_at == c.last_updated && t.id < c.last_id)
328 });
329 tasks.reverse();
331 },
332 }
333 }
334
335 if let Some(ref context_id) = filter.context_id {
337 tasks.retain(|t| t.context_id.as_ref() == Some(context_id));
338 }
339
340 if let Some(ref agent) = filter.agent {
341 tasks.retain(|t| t.agent.as_ref() == Some(agent));
342 }
343
344 if let Some(after) = filter.last_updated_after {
345 tasks.retain(|t| t.updated_at > after);
346 }
347
348 tasks.sort_by(|a, b| b.updated_at.cmp(&a.updated_at).then_with(|| b.id.cmp(&a.id)));
350
351 let has_more = tasks.len() > limit;
353 if has_more {
354 tasks.truncate(limit);
355 }
356
357 let next_cursor =
359 if has_more && !tasks.is_empty() {
360 let last = tasks.last().unwrap();
361 let cursor = PaginationCursor::forward(last.id, last.updated_at);
362 Some(cursor.encode().map_err(|e| {
363 StorageError::Other(format!("Failed to encode next cursor: {}", e))
364 })?)
365 } else {
366 None
367 };
368
369 let previous_cursor = if !tasks.is_empty() && cursor.is_some() {
370 let first = tasks.first().unwrap();
371 let cursor = PaginationCursor::backward(first.id, first.updated_at);
372 Some(cursor.encode().map_err(|e| {
373 StorageError::Other(format!("Failed to encode previous cursor: {}", e))
374 })?)
375 } else {
376 None
377 };
378
379 Ok(PaginatedResult::new(tasks, next_cursor, previous_cursor, has_more))
380 }
381
382 async fn update_task(&self, id: u64, update: TaskUpdate) -> Result<(), StorageError> {
401 info!("Updating task #{}", id);
402
403 if let Some(description) = update.description {
405 self.client
406 .issues(&self.repo_owner, &self.repo_name)
407 .update(id)
408 .body(&description)
409 .send()
410 .await?;
411
412 debug!("Updated task #{} description", id);
413 }
414
415 if update.status.is_some() || update.retry_count.is_some() {
417 let issue = self.client.issues(&self.repo_owner, &self.repo_name).get(id).await?;
419
420 let mut new_labels: Vec<String> = issue
422 .labels
423 .iter()
424 .filter(|label| {
425 !label.name.starts_with("a2a:pending")
426 && !label.name.starts_with("a2a:in-progress")
427 && !label.name.starts_with("a2a:completed")
428 && !label.name.starts_with("a2a:failed")
429 && !label.name.starts_with("a2a:blocked")
430 && !label.name.starts_with("retry:")
431 })
432 .map(|label| label.name.clone())
433 .collect();
434
435 if let Some(new_status) = update.status {
437 new_labels.push(new_status.to_label());
438 debug!("Updated task #{} status to {:?}", id, new_status);
439 }
440
441 if let Some(retry_count) = update.retry_count {
443 new_labels.push(format!("retry:{}", retry_count));
444 debug!("Updated task #{} retry_count to {}", id, retry_count);
445 }
446
447 self.client
449 .issues(&self.repo_owner, &self.repo_name)
450 .update(id)
451 .labels(&new_labels)
452 .send()
453 .await?;
454 }
455
456 Ok(())
457 }
458
459 async fn delete_task(&self, id: u64) -> Result<(), StorageError> {
471 info!("Closing task #{} (GitHub Issues don't support deletion)", id);
472
473 use octocrab::models::IssueState;
475
476 self.client
477 .issues(&self.repo_owner, &self.repo_name)
478 .update(id)
479 .state(IssueState::Closed)
480 .send()
481 .await?;
482
483 debug!("Closed task #{}", id);
484
485 Ok(())
486 }
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::cursor::{Direction, PaginationCursor};
    use chrono::Utc;

    /// Building the storage must succeed without contacting GitHub.
    #[tokio::test]
    async fn test_github_storage_construction() {
        let token = "fake_token".to_string();
        let owner = "owner".to_string();
        let repo = "repo".to_string();

        let storage = GitHubTaskStorage::new(token, owner, repo);

        assert!(storage.is_ok());
    }

    /// A cursor survives an encode/decode round trip unchanged.
    #[test]
    fn test_cursor_encode_decode_roundtrip() {
        let now = Utc::now();
        let original = PaginationCursor::forward(123, now);

        let wire = original.encode().unwrap();
        let restored = PaginationCursor::decode(&wire).unwrap();

        assert_eq!(original.last_id, restored.last_id);
        assert_eq!(original.last_updated, restored.last_updated);
        assert_eq!(original.direction, restored.direction);
    }

    /// `forward` stores its arguments and marks the cursor as forward.
    #[test]
    fn test_forward_cursor_creation() {
        let now = Utc::now();
        let forward = PaginationCursor::forward(456, now);

        assert_eq!(forward.last_id, 456);
        assert_eq!(forward.last_updated, now);
        assert_eq!(forward.direction, Direction::Forward);
    }

    /// `backward` stores its arguments and marks the cursor as backward.
    #[test]
    fn test_backward_cursor_creation() {
        let now = Utc::now();
        let backward = PaginationCursor::backward(789, now);

        assert_eq!(backward.last_id, 789);
        assert_eq!(backward.last_updated, now);
        assert_eq!(backward.direction, Direction::Backward);
    }

    /// A middle page exposes its items, both cursors and the `has_more` flag.
    #[test]
    fn test_paginated_result_structure() {
        let next = Some("next_cursor".to_string());
        let previous = Some("prev_cursor".to_string());

        let page = PaginatedResult::new(vec![1, 2, 3], next, previous, true);

        assert_eq!(page.items.len(), 3);
        assert!(page.next_cursor.is_some());
        assert!(page.previous_cursor.is_some());
        assert!(page.has_more);
    }

    /// A final page carries no cursors and reports that nothing more follows.
    #[test]
    fn test_paginated_result_last_page() {
        let page: PaginatedResult<i32> = PaginatedResult::new(vec![1, 2], None, None, false);

        assert_eq!(page.items.len(), 2);
        assert!(page.next_cursor.is_none());
        assert!(page.previous_cursor.is_none());
        assert!(!page.has_more);
    }
}