// a2a_rs_server/task_store.rs
use a2a_rs_core::{ListTasksRequest, Task, TaskListResponse};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
9
10#[derive(Clone)]
12pub struct TaskStore {
13 tasks: Arc<RwLock<HashMap<String, Task>>>,
14}
15
16impl Default for TaskStore {
17 fn default() -> Self {
18 Self::new()
19 }
20}
21
22impl TaskStore {
23 pub fn new() -> Self {
25 Self {
26 tasks: Arc::new(RwLock::new(HashMap::new())),
27 }
28 }
29
30 pub async fn insert(&self, task: Task) {
32 let id = task.id.clone();
33 self.tasks.write().await.insert(id, task);
34 }
35
36 pub async fn get(&self, id: &str) -> Option<Task> {
38 self.tasks.read().await.get(id).cloned()
39 }
40
41 pub async fn get_flexible(&self, id: &str) -> Option<Task> {
45 let guard = self.tasks.read().await;
46
47 if let Some(task) = guard.get(id) {
49 return Some(task.clone());
50 }
51
52 let prefixed = format!("tasks/{}", id);
54 if let Some(task) = guard.get(&prefixed) {
55 return Some(task.clone());
56 }
57
58 if let Some(stripped) = id.strip_prefix("tasks/") {
60 if let Some(task) = guard.get(stripped) {
61 return Some(task.clone());
62 }
63 }
64
65 None
66 }
67
68 pub async fn update<F>(&self, id: &str, f: F) -> Option<Task>
70 where
71 F: FnOnce(&mut Task),
72 {
73 let mut guard = self.tasks.write().await;
74 if let Some(task) = guard.get_mut(id) {
75 f(task);
76 Some(task.clone())
77 } else {
78 None
79 }
80 }
81
82 pub async fn update_flexible<F, E>(&self, id: &str, f: F) -> Option<Result<Task, E>>
89 where
90 F: FnOnce(&mut Task) -> Result<(), E>,
91 {
92 let mut guard = self.tasks.write().await;
93
94 let key = if guard.contains_key(id) {
96 Some(id.to_string())
97 } else {
98 let prefixed = format!("tasks/{}", id);
99 if guard.contains_key(&prefixed) {
100 Some(prefixed)
101 } else if let Some(stripped) = id.strip_prefix("tasks/") {
102 if guard.contains_key(stripped) {
103 Some(stripped.to_string())
104 } else {
105 None
106 }
107 } else {
108 None
109 }
110 };
111
112 let key = key?;
113 let task = guard.get_mut(&key)?;
114
115 match f(task) {
116 Ok(()) => Some(Ok(task.clone())),
117 Err(e) => Some(Err(e)),
118 }
119 }
120
121 pub async fn remove(&self, id: &str) -> Option<Task> {
123 self.tasks.write().await.remove(id)
124 }
125
126 pub async fn list(&self) -> Vec<Task> {
128 self.tasks.read().await.values().cloned().collect()
129 }
130
131 pub async fn list_filtered(&self, params: &ListTasksRequest) -> TaskListResponse {
135 let guard = self.tasks.read().await;
136
137 let mut filtered: Vec<_> = guard
139 .values()
140 .filter(|task| {
141 if let Some(ref ctx) = params.context_id {
143 if task.context_id != *ctx {
144 return false;
145 }
146 }
147 if let Some(status) = params.status {
149 if task.status.state != status {
150 return false;
151 }
152 }
153 if let Some(after_ms) = params.status_timestamp_after {
155 if let Some(ref ts) = task.status.timestamp {
156 if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(ts) {
158 if dt.timestamp_millis() <= after_ms {
159 return false;
160 }
161 }
162 } else {
163 return false;
165 }
166 }
167 true
168 })
169 .cloned()
170 .collect();
171
172 let total_size = filtered.len() as u32;
173 let page_size = params.page_size.unwrap_or(50).min(100);
174
175 filtered.sort_by(|a, b| a.id.cmp(&b.id));
177
178 let offset: usize = params
180 .page_token
181 .as_ref()
182 .and_then(|t| t.parse().ok())
183 .unwrap_or(0);
184
185 let paginated: Vec<_> = filtered
186 .into_iter()
187 .skip(offset)
188 .take(page_size as usize)
189 .map(|mut task| {
190 if let Some(len) = params.history_length {
192 if let Some(ref mut history) = task.history {
193 let keep = len as usize;
194 if history.len() > keep {
195 *history = history.iter().rev().take(keep).cloned().collect();
196 history.reverse();
197 }
198 }
199 }
200 if params.include_artifacts == Some(false) {
202 task.artifacts = None;
203 }
204 task
205 })
206 .collect();
207
208 let next_offset = offset + paginated.len();
209 let next_page_token = if next_offset < total_size as usize {
210 next_offset.to_string()
211 } else {
212 String::new()
213 };
214
215 TaskListResponse {
216 tasks: paginated,
217 next_page_token,
218 page_size,
219 total_size,
220 }
221 }
222
223 pub async fn len(&self) -> usize {
225 self.tasks.read().await.len()
226 }
227
228 pub async fn is_empty(&self) -> bool {
230 self.tasks.read().await.is_empty()
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_rs_core::{TaskState, TaskStatus};

    /// Builds a minimal `Working` task in context `"ctx"` with the given id.
    fn make_task(id: &str) -> Task {
        let status = TaskStatus {
            state: TaskState::Working,
            message: None,
            timestamp: None,
        };
        Task {
            id: id.to_string(),
            context_id: "ctx".to_string(),
            status,
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn test_insert_and_get() {
        let store = TaskStore::new();
        store.insert(make_task("task-1")).await;

        let fetched = store
            .get("task-1")
            .await
            .expect("inserted task should be retrievable");
        assert_eq!(fetched.id, "task-1");
    }

    #[tokio::test]
    async fn test_get_flexible() {
        let store = TaskStore::new();
        store.insert(make_task("tasks/abc-123")).await;

        // Both the full key and the bare id must resolve to the stored task.
        for lookup in &["tasks/abc-123", "abc-123"] {
            assert!(
                store.get_flexible(lookup).await.is_some(),
                "lookup by {:?} failed",
                lookup
            );
        }
    }

    #[tokio::test]
    async fn test_update() {
        let store = TaskStore::new();
        store.insert(make_task("task-1")).await;

        let updated = store
            .update("task-1", |t| t.status.state = TaskState::Completed)
            .await
            .expect("existing task should be updated");
        assert_eq!(updated.status.state, TaskState::Completed);
    }

    #[tokio::test]
    async fn test_concurrent_inserts() {
        let store = Arc::new(TaskStore::new());

        let mut handles = Vec::with_capacity(100);
        for i in 0..100 {
            let store = Arc::clone(&store);
            handles.push(tokio::spawn(async move {
                store.insert(make_task(&format!("task-{}", i))).await;
            }));
        }

        for handle in handles {
            handle.await.expect("insert task panicked");
        }

        assert_eq!(store.len().await, 100);
    }

    #[tokio::test]
    async fn test_concurrent_reads_and_writes() {
        let store = Arc::new(TaskStore::new());
        for i in 0..10 {
            store.insert(make_task(&format!("task-{}", i))).await;
        }

        let mut handles = Vec::new();

        // 50 writers adding task-10 through task-59.
        handles.extend((10..60).map(|i| {
            let store = Arc::clone(&store);
            tokio::spawn(async move {
                store.insert(make_task(&format!("task-{}", i))).await;
            })
        }));

        // 10 readers of the pre-seeded tasks.
        handles.extend((0..10).map(|i| {
            let store = Arc::clone(&store);
            tokio::spawn(async move {
                let _ = store.get(&format!("task-{}", i)).await;
            })
        }));

        // 10 updaters completing the pre-seeded tasks.
        handles.extend((0..10).map(|i| {
            let store = Arc::clone(&store);
            tokio::spawn(async move {
                store
                    .update(&format!("task-{}", i), |t| {
                        t.status.state = TaskState::Completed
                    })
                    .await;
            })
        }));

        for handle in handles {
            handle.await.expect("spawned task panicked");
        }

        assert_eq!(store.len().await, 60);

        for i in 0..10 {
            let task = store
                .get(&format!("task-{}", i))
                .await
                .expect("seeded task should still exist");
            assert_eq!(task.status.state, TaskState::Completed);
        }
    }

    #[tokio::test]
    async fn test_concurrent_update_flexible() {
        let store = Arc::new(TaskStore::new());
        store.insert(make_task("tasks/shared-task")).await;

        let mut handles = Vec::with_capacity(50);
        for _ in 0..50 {
            let store = Arc::clone(&store);
            handles.push(tokio::spawn(async move {
                store
                    .update_flexible("shared-task", |t| -> Result<(), ()> {
                        t.context_id = "updated".to_string();
                        Ok(())
                    })
                    .await
            }));
        }

        for handle in handles {
            let outcome = handle.await.expect("update task panicked");
            assert!(
                matches!(outcome, Some(Ok(_))),
                "update_flexible should resolve the bare id"
            );
        }

        let task = store
            .get("tasks/shared-task")
            .await
            .expect("shared task should exist");
        assert_eq!(task.context_id, "updated");
    }
}