a2a_rs_server/
task_store.rs1use a2a_rs_core::{ListTasksRequest, Task, TaskListResponse};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10#[derive(Clone)]
12pub struct TaskStore {
13 tasks: Arc<RwLock<HashMap<String, Task>>>,
14}
15
16impl Default for TaskStore {
17 fn default() -> Self {
18 Self::new()
19 }
20}
21
22impl TaskStore {
23 pub fn new() -> Self {
25 Self {
26 tasks: Arc::new(RwLock::new(HashMap::new())),
27 }
28 }
29
30 pub async fn insert(&self, task: Task) {
32 let id = task.id.clone();
33 self.tasks.write().await.insert(id, task);
34 }
35
36 pub async fn get(&self, id: &str) -> Option<Task> {
38 self.tasks.read().await.get(id).cloned()
39 }
40
41 pub async fn get_flexible(&self, id: &str) -> Option<Task> {
45 let guard = self.tasks.read().await;
46
47 if let Some(task) = guard.get(id) {
49 return Some(task.clone());
50 }
51
52 let prefixed = format!("tasks/{}", id);
54 if let Some(task) = guard.get(&prefixed) {
55 return Some(task.clone());
56 }
57
58 if let Some(stripped) = id.strip_prefix("tasks/") {
60 if let Some(task) = guard.get(stripped) {
61 return Some(task.clone());
62 }
63 }
64
65 None
66 }
67
68 pub async fn update<F>(&self, id: &str, f: F) -> Option<Task>
70 where
71 F: FnOnce(&mut Task),
72 {
73 let mut guard = self.tasks.write().await;
74 if let Some(task) = guard.get_mut(id) {
75 f(task);
76 Some(task.clone())
77 } else {
78 None
79 }
80 }
81
82 pub async fn update_flexible<F, E>(&self, id: &str, f: F) -> Option<Result<Task, E>>
89 where
90 F: FnOnce(&mut Task) -> Result<(), E>,
91 {
92 let mut guard = self.tasks.write().await;
93
94 let key = if guard.contains_key(id) {
96 Some(id.to_string())
97 } else {
98 let prefixed = format!("tasks/{}", id);
99 if guard.contains_key(&prefixed) {
100 Some(prefixed)
101 } else if let Some(stripped) = id.strip_prefix("tasks/") {
102 if guard.contains_key(stripped) {
103 Some(stripped.to_string())
104 } else {
105 None
106 }
107 } else {
108 None
109 }
110 };
111
112 let key = key?;
113 let task = guard.get_mut(&key)?;
114
115 match f(task) {
116 Ok(()) => Some(Ok(task.clone())),
117 Err(e) => Some(Err(e)),
118 }
119 }
120
121 pub async fn remove(&self, id: &str) -> Option<Task> {
123 self.tasks.write().await.remove(id)
124 }
125
126 pub async fn list(&self) -> Vec<Task> {
128 self.tasks.read().await.values().cloned().collect()
129 }
130
131 pub async fn list_filtered(&self, params: &ListTasksRequest) -> TaskListResponse {
135 let guard = self.tasks.read().await;
136
137 let mut filtered: Vec<_> = guard
139 .values()
140 .filter(|task| {
141 if let Some(ref ctx) = params.context_id {
143 if task.context_id != *ctx {
144 return false;
145 }
146 }
147 if let Some(status) = params.status {
149 if task.status.state != status {
150 return false;
151 }
152 }
153 if let Some(after_ms) = params.status_timestamp_after {
155 if let Some(ref ts) = task.status.timestamp {
156 if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(ts) {
158 if dt.timestamp_millis() <= after_ms {
159 return false;
160 }
161 }
162 } else {
163 return false;
165 }
166 }
167 true
168 })
169 .cloned()
170 .collect();
171
172 let total_size = filtered.len() as u32;
173 let page_size = params.page_size.unwrap_or(50).min(100);
174
175 filtered.sort_by(|a, b| a.id.cmp(&b.id));
177
178 let offset: usize = params
180 .page_token
181 .as_ref()
182 .and_then(|t| t.parse().ok())
183 .unwrap_or(0);
184
185 let paginated: Vec<_> = filtered
186 .into_iter()
187 .skip(offset)
188 .take(page_size as usize)
189 .map(|mut task| {
190 if let Some(len) = params.history_length {
192 if let Some(ref mut history) = task.history {
193 let keep = len as usize;
194 if history.len() > keep {
195 *history = history.iter().rev().take(keep).cloned().collect();
196 history.reverse();
197 }
198 }
199 }
200 if params.include_artifacts == Some(false) {
202 task.artifacts = None;
203 }
204 task
205 })
206 .collect();
207
208 let next_offset = offset + paginated.len();
209 let next_page_token = if next_offset < total_size as usize {
210 next_offset.to_string()
211 } else {
212 String::new()
213 };
214
215 TaskListResponse {
216 tasks: paginated,
217 next_page_token,
218 page_size,
219 total_size,
220 }
221 }
222
223 pub async fn len(&self) -> usize {
225 self.tasks.read().await.len()
226 }
227
228 pub async fn is_empty(&self) -> bool {
230 self.tasks.read().await.is_empty()
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_rs_core::{TaskState, TaskStatus};

    /// Builds a minimal `Working` task with the given id in context `"ctx"`.
    fn make_task(id: &str) -> Task {
        Task {
            kind: "task".to_string(),
            id: id.to_string(),
            context_id: "ctx".to_string(),
            status: TaskStatus {
                state: TaskState::Working,
                message: None,
                timestamp: None,
            },
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn test_insert_and_get() {
        let store = TaskStore::new();
        store.insert(make_task("task-1")).await;

        let found = store
            .get("task-1")
            .await
            .expect("inserted task should be retrievable");
        assert_eq!(found.id, "task-1");
    }

    #[tokio::test]
    async fn test_get_flexible() {
        let store = TaskStore::new();
        store.insert(make_task("tasks/abc-123")).await;

        // Both the full resource name and the bare id resolve to the task.
        for lookup in ["tasks/abc-123", "abc-123"] {
            assert!(
                store.get_flexible(lookup).await.is_some(),
                "lookup {:?} failed",
                lookup
            );
        }
    }

    #[tokio::test]
    async fn test_update() {
        let store = TaskStore::new();
        store.insert(make_task("task-1")).await;

        let updated = store
            .update("task-1", |t| t.status.state = TaskState::Completed)
            .await
            .expect("task-1 exists");
        assert_eq!(updated.status.state, TaskState::Completed);
    }

    #[tokio::test]
    async fn test_concurrent_inserts() {
        let store = Arc::new(TaskStore::new());

        let mut handles = Vec::with_capacity(100);
        for n in 0..100 {
            let store = Arc::clone(&store);
            handles.push(tokio::spawn(async move {
                store.insert(make_task(&format!("task-{}", n))).await;
            }));
        }
        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(store.len().await, 100);
    }

    #[tokio::test]
    async fn test_concurrent_reads_and_writes() {
        let store = Arc::new(TaskStore::new());
        for n in 0..10 {
            store.insert(make_task(&format!("task-{}", n))).await;
        }

        let mut handles = Vec::new();

        // Writers adding 50 fresh tasks...
        handles.extend((10..60).map(|n| {
            let store = Arc::clone(&store);
            tokio::spawn(async move {
                store.insert(make_task(&format!("task-{}", n))).await;
            })
        }));

        // ...readers of the seeded tasks...
        handles.extend((0..10).map(|n| {
            let store = Arc::clone(&store);
            tokio::spawn(async move {
                let _ = store.get(&format!("task-{}", n)).await;
            })
        }));

        // ...and updaters completing the seeded tasks.
        handles.extend((0..10).map(|n| {
            let store = Arc::clone(&store);
            tokio::spawn(async move {
                store
                    .update(&format!("task-{}", n), |t| {
                        t.status.state = TaskState::Completed
                    })
                    .await;
            })
        }));

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(store.len().await, 60);
        for n in 0..10 {
            let task = store.get(&format!("task-{}", n)).await.unwrap();
            assert_eq!(task.status.state, TaskState::Completed);
        }
    }

    #[tokio::test]
    async fn test_concurrent_update_flexible() {
        let store = Arc::new(TaskStore::new());
        store.insert(make_task("tasks/shared-task")).await;

        let mut handles = Vec::with_capacity(50);
        for _ in 0..50 {
            let store = Arc::clone(&store);
            handles.push(tokio::spawn(async move {
                store
                    .update_flexible("shared-task", |t| -> Result<(), ()> {
                        t.context_id = "updated".to_string();
                        Ok(())
                    })
                    .await
            }));
        }

        for handle in handles {
            let outcome = handle.await.unwrap();
            assert!(matches!(outcome, Some(Ok(_))));
        }

        let task = store.get("tasks/shared-task").await.unwrap();
        assert_eq!(task.context_id, "updated");
    }
}