a2a_rs_server/
task_store.rs1use a2a_rs_core::{ListTasksRequest, Task, TaskListResponse};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10#[derive(Clone)]
12pub struct TaskStore {
13 tasks: Arc<RwLock<HashMap<String, Task>>>,
14}
15
16impl Default for TaskStore {
17 fn default() -> Self {
18 Self::new()
19 }
20}
21
22impl TaskStore {
23 pub fn new() -> Self {
25 Self {
26 tasks: Arc::new(RwLock::new(HashMap::new())),
27 }
28 }
29
30 pub async fn insert(&self, task: Task) {
32 let id = task.id.clone();
33 self.tasks.write().await.insert(id, task);
34 }
35
36 pub async fn get(&self, id: &str) -> Option<Task> {
38 self.tasks.read().await.get(id).cloned()
39 }
40
41 pub async fn get_flexible(&self, id: &str) -> Option<Task> {
45 let guard = self.tasks.read().await;
46
47 if let Some(task) = guard.get(id) {
49 return Some(task.clone());
50 }
51
52 let prefixed = format!("tasks/{}", id);
54 if let Some(task) = guard.get(&prefixed) {
55 return Some(task.clone());
56 }
57
58 if let Some(stripped) = id.strip_prefix("tasks/") {
60 if let Some(task) = guard.get(stripped) {
61 return Some(task.clone());
62 }
63 }
64
65 None
66 }
67
68 pub async fn update<F>(&self, id: &str, f: F) -> Option<Task>
70 where
71 F: FnOnce(&mut Task),
72 {
73 let mut guard = self.tasks.write().await;
74 if let Some(task) = guard.get_mut(id) {
75 f(task);
76 Some(task.clone())
77 } else {
78 None
79 }
80 }
81
82 pub async fn update_flexible<F, E>(&self, id: &str, f: F) -> Option<Result<Task, E>>
89 where
90 F: FnOnce(&mut Task) -> Result<(), E>,
91 {
92 let mut guard = self.tasks.write().await;
93
94 let key = if guard.contains_key(id) {
96 Some(id.to_string())
97 } else {
98 let prefixed = format!("tasks/{}", id);
99 if guard.contains_key(&prefixed) {
100 Some(prefixed)
101 } else if let Some(stripped) = id.strip_prefix("tasks/") {
102 if guard.contains_key(stripped) {
103 Some(stripped.to_string())
104 } else {
105 None
106 }
107 } else {
108 None
109 }
110 };
111
112 let key = key?;
113 let task = guard.get_mut(&key)?;
114
115 match f(task) {
116 Ok(()) => Some(Ok(task.clone())),
117 Err(e) => Some(Err(e)),
118 }
119 }
120
121 pub async fn remove(&self, id: &str) -> Option<Task> {
123 self.tasks.write().await.remove(id)
124 }
125
126 pub async fn list(&self) -> Vec<Task> {
128 self.tasks.read().await.values().cloned().collect()
129 }
130
131 pub async fn list_filtered(&self, params: &ListTasksRequest) -> TaskListResponse {
135 let guard = self.tasks.read().await;
136
137 let mut filtered: Vec<_> = guard
139 .values()
140 .filter(|task| {
141 if let Some(ref ctx) = params.context_id {
143 if task.context_id != *ctx {
144 return false;
145 }
146 }
147 if let Some(status) = params.status {
149 if task.status.state != status {
150 return false;
151 }
152 }
153 if let Some(after_ms) = params.status_timestamp_after {
155 if let Some(ref ts) = task.status.timestamp {
156 if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(ts) {
158 if dt.timestamp_millis() <= after_ms {
159 return false;
160 }
161 }
162 } else {
163 return false;
165 }
166 }
167 true
168 })
169 .cloned()
170 .collect();
171
172 let total_size = filtered.len() as u32;
173 let page_size = params.page_size.unwrap_or(50).min(100);
174
175 filtered.sort_by(|a, b| a.id.cmp(&b.id));
177
178 let offset: usize = params
180 .page_token
181 .as_ref()
182 .and_then(|t| t.parse().ok())
183 .unwrap_or(0);
184
185 let paginated: Vec<_> = filtered
186 .into_iter()
187 .skip(offset)
188 .take(page_size as usize)
189 .map(|mut task| {
190 match params.history_length {
192 Some(0) => {
193 task.history = None;
194 }
195 Some(n) => {
196 if let Some(ref mut history) = task.history {
197 let keep = n as usize;
198 if history.len() > keep {
199 *history = history.split_off(history.len() - keep);
200 }
201 }
202 }
203 None => {}
204 }
205 if params.include_artifacts == Some(false) {
207 task.artifacts = None;
208 }
209 task
210 })
211 .collect();
212
213 let next_offset = offset + paginated.len();
214 let next_page_token = if next_offset < total_size as usize {
215 next_offset.to_string()
216 } else {
217 String::new()
218 };
219
220 TaskListResponse {
221 tasks: paginated,
222 next_page_token,
223 page_size,
224 total_size,
225 }
226 }
227
228 pub async fn len(&self) -> usize {
230 self.tasks.read().await.len()
231 }
232
233 pub async fn is_empty(&self) -> bool {
235 self.tasks.read().await.is_empty()
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_rs_core::{TaskState, TaskStatus};

    /// Builds a minimal task with the given id: context `"ctx"`, state `Working`,
    /// and no timestamp, history, artifacts, or metadata.
    fn make_task(id: &str) -> Task {
        Task {
            kind: "task".to_string(),
            id: id.to_string(),
            context_id: "ctx".to_string(),
            status: TaskStatus {
                state: TaskState::Working,
                message: None,
                timestamp: None,
            },
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn test_insert_and_get() {
        let store = TaskStore::new();
        let task = make_task("task-1");

        store.insert(task.clone()).await;

        let retrieved = store.get("task-1").await;
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, "task-1");
    }

    #[tokio::test]
    async fn test_get_flexible() {
        let store = TaskStore::new();
        let task = make_task("tasks/abc-123");

        store.insert(task).await;

        // Stored with the prefix: found by the full name...
        assert!(store.get_flexible("tasks/abc-123").await.is_some());
        // ...and by the bare id.
        assert!(store.get_flexible("abc-123").await.is_some());
    }

    #[tokio::test]
    async fn test_get_flexible_strips_prefix() {
        let store = TaskStore::new();
        store.insert(make_task("abc-123")).await;

        // Stored bare: a `tasks/`-prefixed lookup falls back to the stripped id.
        assert!(store.get_flexible("tasks/abc-123").await.is_some());
        assert!(store.get_flexible("abc-123").await.is_some());
        assert!(store.get_flexible("missing").await.is_none());
    }

    #[tokio::test]
    async fn test_update() {
        let store = TaskStore::new();
        let task = make_task("task-1");
        store.insert(task).await;

        let updated = store
            .update("task-1", |t| {
                t.status.state = TaskState::Completed;
            })
            .await;

        assert!(updated.is_some());
        assert_eq!(updated.unwrap().status.state, TaskState::Completed);

        // Unknown ids are reported as None rather than inserted.
        assert!(store.update("missing", |_| {}).await.is_none());
    }

    #[tokio::test]
    async fn test_update_flexible_missing_and_error() {
        let store = TaskStore::new();
        store.insert(make_task("tasks/t1")).await;

        let missing = store
            .update_flexible("nope", |_t| -> Result<(), &'static str> { Ok(()) })
            .await;
        assert!(missing.is_none());

        // The closure's error is propagated unchanged.
        let failed = store
            .update_flexible("t1", |_t| -> Result<(), &'static str> { Err("rejected") })
            .await;
        assert_eq!(failed.map(|r| r.err()), Some(Some("rejected")));
    }

    #[tokio::test]
    async fn test_concurrent_inserts() {
        let store = Arc::new(TaskStore::new());

        let handles: Vec<_> = (0..100)
            .map(|i| {
                let store = store.clone();
                tokio::spawn(async move {
                    store.insert(make_task(&format!("task-{}", i))).await;
                })
            })
            .collect();

        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(store.len().await, 100);
    }

    #[tokio::test]
    async fn test_concurrent_reads_and_writes() {
        let store = Arc::new(TaskStore::new());

        // Seed tasks 0..10 before any concurrency starts.
        for i in 0..10 {
            store.insert(make_task(&format!("task-{}", i))).await;
        }

        let mut handles = Vec::new();

        // Writers: insert 50 new tasks.
        for i in 10..60 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                store.insert(make_task(&format!("task-{}", i))).await;
            }));
        }

        // Readers: read the seeded tasks.
        for i in 0..10 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                let _ = store.get(&format!("task-{}", i)).await;
            }));
        }

        // Updaters: complete the seeded tasks.
        for i in 0..10 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                store
                    .update(&format!("task-{}", i), |t| {
                        t.status.state = TaskState::Completed;
                    })
                    .await;
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        // 10 seeded + 50 inserted concurrently.
        assert_eq!(store.len().await, 60);

        // Every seeded task saw its update.
        for i in 0..10 {
            let task = store.get(&format!("task-{}", i)).await.unwrap();
            assert_eq!(task.status.state, TaskState::Completed);
        }
    }

    #[tokio::test]
    async fn test_concurrent_update_flexible() {
        let store = Arc::new(TaskStore::new());
        store.insert(make_task("tasks/shared-task")).await;

        let handles: Vec<_> = (0..50)
            .map(|_| {
                let store = store.clone();
                tokio::spawn(async move {
                    store
                        .update_flexible("shared-task", |t| -> Result<(), ()> {
                            t.context_id = "updated".to_string();
                            Ok(())
                        })
                        .await
                })
            })
            .collect();

        for h in handles {
            let result = h.await.unwrap();
            assert!(result.is_some());
            assert!(result.unwrap().is_ok());
        }

        // The bare id resolved to the prefixed key, so the stored task changed.
        let task = store.get("tasks/shared-task").await.unwrap();
        assert_eq!(task.context_id, "updated");
    }
}