1use std::collections::HashMap;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use tokio::sync::Mutex; use crate::adapter::business::push_notification::{
12 PushNotificationRegistry, PushNotificationSender,
13};
14
15#[cfg(feature = "http-client")]
16use crate::adapter::business::push_notification::HttpPushNotificationSender;
17#[cfg(not(feature = "http-client"))]
18use crate::adapter::business::push_notification::NoopPushNotificationSender;
19use crate::domain::{
20 A2AError, ContextId, Message, Task, TaskId, TaskPushNotificationConfig, TaskState,
21 VersionedTask,
22};
23use crate::port::{
24 AsyncNotificationManager, AsyncPushNotifier, AsyncTaskLifecycle, AsyncTaskQuery,
25 AsyncTaskVersioning,
26};
27
28pub struct InMemoryTaskStorage {
37 pub(crate) tasks: Arc<Mutex<HashMap<String, Task>>>,
39 pub(crate) versions: Arc<Mutex<HashMap<String, u64>>>,
45 pub(crate) push_notification_registry: Arc<PushNotificationRegistry>,
47}
48
49impl InMemoryTaskStorage {
50 pub fn new() -> Self {
52 #[cfg(feature = "http-client")]
54 let push_sender = HttpPushNotificationSender::new();
55 #[cfg(not(feature = "http-client"))]
56 let push_sender = NoopPushNotificationSender;
57
58 let push_registry = PushNotificationRegistry::new(push_sender);
59
60 Self {
61 tasks: Arc::new(Mutex::new(HashMap::new())),
62 versions: Arc::new(Mutex::new(HashMap::new())),
63 push_notification_registry: Arc::new(push_registry),
64 }
65 }
66
67 pub fn with_push_sender(push_sender: impl PushNotificationSender + 'static) -> Self {
69 let push_registry = PushNotificationRegistry::new(push_sender);
70
71 Self {
72 tasks: Arc::new(Mutex::new(HashMap::new())),
73 versions: Arc::new(Mutex::new(HashMap::new())),
74 push_notification_registry: Arc::new(push_registry),
75 }
76 }
77
78 async fn bump_version(&self, task_id: &str) -> u64 {
82 let mut versions = self.versions.lock().await;
83 let v = versions.entry(task_id.to_string()).or_insert(0);
84 *v += 1;
85 *v
86 }
87
88 pub fn push_notifier(&self) -> Arc<dyn AsyncPushNotifier> {
95 self.push_notification_registry.clone()
96 }
97}
98
99impl Default for InMemoryTaskStorage {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105#[async_trait]
106impl AsyncTaskLifecycle for InMemoryTaskStorage {
107 async fn create(&self, id: &TaskId, context_id: &ContextId) -> Result<Task, A2AError> {
108 let task_id = id.as_str();
109 let context_id = context_id.as_str();
110 let mut tasks_guard = self.tasks.lock().await;
111
112 if tasks_guard.contains_key(task_id) {
113 return Err(A2AError::TaskNotFound(format!(
114 "Task {} already exists",
115 task_id
116 )));
117 }
118
119 let task = Task::new(task_id.to_string(), context_id.to_string());
120 tasks_guard.insert(task_id.to_string(), task.clone());
121 self.bump_version(task_id).await; Ok(task)
124 }
125
126 async fn update_status(
127 &self,
128 id: &TaskId,
129 state: TaskState,
130 message: Option<Message>,
131 ) -> Result<Task, A2AError> {
132 let task_id = id.as_str();
133 let mut tasks_guard = self.tasks.lock().await;
134
135 let task = tasks_guard
136 .get_mut(task_id)
137 .ok_or_else(|| A2AError::TaskNotFound(task_id.to_string()))?;
138
139 task.update_status(state, message);
141 let updated = task.clone();
142 self.bump_version(task_id).await;
143
144 Ok(updated)
148 }
149
150 async fn exists(&self, id: &TaskId) -> Result<bool, A2AError> {
151 let task_id = id.as_str();
152 let tasks_guard = self.tasks.lock().await;
153 Ok(tasks_guard.contains_key(task_id))
154 }
155
156 async fn get(&self, id: &TaskId, history_length: Option<u32>) -> Result<Task, A2AError> {
157 let task_id = id.as_str();
158 let task = {
160 let tasks_guard = self.tasks.lock().await;
161
162 let Some(task) = tasks_guard.get(task_id) else {
163 return Err(A2AError::TaskNotFound(task_id.to_string()));
164 };
165
166 task.with_limited_history(history_length)
168 }; Ok(task)
171 }
172
173 async fn cancel(&self, id: &TaskId) -> Result<Task, A2AError> {
174 let task_id = id.as_str();
175 let mut tasks_guard = self.tasks.lock().await;
176
177 let Some(task) = tasks_guard.get(task_id) else {
178 return Err(A2AError::TaskNotFound(task_id.to_string()));
179 };
180
181 let mut updated_task = task.clone();
182
183 if updated_task.status.state != TaskState::Working {
185 return Err(A2AError::TaskNotCancelable(format!(
186 "Task {} is in state {:?} and cannot be canceled",
187 task_id, updated_task.status.state
188 )));
189 }
190
191 let cancel_message = Message {
193 role: ::buffa::EnumValue::from(crate::domain::Role::Agent),
194 parts: vec![crate::domain::Part::text(format!(
195 "Task {} canceled.",
196 task_id
197 ))],
198 message_id: uuid::Uuid::new_v4().to_string(),
199 task_id: task_id.to_string(),
200 context_id: updated_task.context_id.clone(),
201 ..Default::default()
202 };
203
204 updated_task.update_status(TaskState::Canceled, Some(cancel_message));
206 tasks_guard.insert(task_id.to_string(), updated_task.clone());
207 self.bump_version(task_id).await;
208
209 Ok(updated_task)
212 }
213}
214
215#[async_trait]
216impl AsyncTaskVersioning for InMemoryTaskStorage {
217 async fn version(&self, id: &TaskId) -> Result<u64, A2AError> {
218 let task_id = id.as_str();
219 let tasks_guard = self.tasks.lock().await;
220 if !tasks_guard.contains_key(task_id) {
221 return Err(A2AError::TaskNotFound(task_id.to_string()));
222 }
223 let versions = self.versions.lock().await;
224 Ok(versions.get(task_id).copied().unwrap_or(0))
225 }
226
227 async fn get_versioned(
228 &self,
229 id: &TaskId,
230 history_length: Option<u32>,
231 ) -> Result<VersionedTask, A2AError> {
232 let task_id = id.as_str();
233 let tasks_guard = self.tasks.lock().await;
234 let Some(task) = tasks_guard.get(task_id) else {
235 return Err(A2AError::TaskNotFound(task_id.to_string()));
236 };
237 let task = task.with_limited_history(history_length);
238 let versions = self.versions.lock().await;
239 let version = versions.get(task_id).copied().unwrap_or(0);
240 Ok(VersionedTask::new(task, version))
241 }
242
243 async fn update_status_checked(
244 &self,
245 id: &TaskId,
246 expected: u64,
247 state: TaskState,
248 message: Option<Message>,
249 ) -> Result<VersionedTask, A2AError> {
250 let task_id = id.as_str();
251 let mut tasks_guard = self.tasks.lock().await;
254 let task = tasks_guard
255 .get_mut(task_id)
256 .ok_or_else(|| A2AError::TaskNotFound(task_id.to_string()))?;
257 let mut versions = self.versions.lock().await;
258 let current = versions.get(task_id).copied().unwrap_or(0);
259 if current != expected {
260 return Err(A2AError::VersionConflict {
261 id: task_id.to_string(),
262 expected,
263 actual: current,
264 });
265 }
266 task.update_status(state, message);
267 let new_version = current + 1;
268 versions.insert(task_id.to_string(), new_version);
269 Ok(VersionedTask::new(task.clone(), new_version))
270 }
271}
272
273#[async_trait]
274impl AsyncTaskQuery for InMemoryTaskStorage {
275 async fn list(
276 &self,
277 params: &crate::domain::ListTasksParams,
278 ) -> Result<crate::domain::ListTasksResult, A2AError> {
279 use crate::domain::ListTasksResult;
280
281 let tasks_guard = self.tasks.lock().await;
282
283 let mut filtered_tasks: Vec<_> = tasks_guard
285 .values()
286 .filter(|task| {
287 if let Some(ref context_id) = params.context_id {
289 if &task.context_id != context_id {
290 return false;
291 }
292 }
293
294 if let Some(ref status) = params.status {
296 if &task.status.state != status {
297 return false;
298 }
299 }
300
301 if let Some(status_timestamp_after) = ¶ms.status_timestamp_after {
303 if let Ok(after_dt) =
304 chrono::DateTime::parse_from_rfc3339(status_timestamp_after)
305 {
306 let after_utc = after_dt.with_timezone(&chrono::Utc);
307 if let Some(timestamp) = task.status.timestamp_utc() {
308 if timestamp <= after_utc {
309 return false;
310 }
311 }
312 }
313 }
314
315 true
316 })
317 .cloned()
318 .collect();
319
320 filtered_tasks.sort_by(|a, b| {
322 let a_time = a
323 .status
324 .timestamp_utc()
325 .map(|t| t.timestamp_millis())
326 .unwrap_or(0);
327 let b_time = b
328 .status
329 .timestamp_utc()
330 .map(|t| t.timestamp_millis())
331 .unwrap_or(0);
332 b_time.cmp(&a_time)
333 });
334
335 let total_size = filtered_tasks.len() as i32;
336
337 let page_size = params.page_size.unwrap_or(50).clamp(1, 100) as usize;
339 let page_start = if let Some(ref token) = params.page_token {
340 token.parse::<usize>().unwrap_or(0)
342 } else {
343 0
344 };
345
346 let page_end = (page_start + page_size).min(filtered_tasks.len());
347 let has_more = page_end < filtered_tasks.len();
348
349 let mut page_tasks: Vec<_> = filtered_tasks[page_start..page_end].to_vec();
351
352 let history_length = params.history_length.unwrap_or(0);
354 for task in &mut page_tasks {
355 *task = task.with_limited_history(Some(history_length as u32));
356
357 if !params.include_artifacts.unwrap_or(false) {
359 task.artifacts.clear();
360 }
361 }
362
363 let next_page_token = if has_more {
365 page_end.to_string()
366 } else {
367 String::new()
368 };
369
370 Ok(ListTasksResult {
371 tasks: page_tasks,
372 total_size,
373 page_size: page_size as i32,
374 next_page_token,
375 })
376 }
377}
378
379#[async_trait]
384impl AsyncNotificationManager for InMemoryTaskStorage {
385 async fn set_config(
386 &self,
387 config: &TaskPushNotificationConfig,
388 ) -> Result<TaskPushNotificationConfig, A2AError> {
389 #[cfg(feature = "tracing")]
390 tracing::info!(
391 task_id = %config.task_id,
392 url = %config.url,
393 "🚀 Registering push notification config for task"
394 );
395
396 self.push_notification_registry
398 .register(&config.task_id, config.clone())
399 .await?;
400
401 #[cfg(feature = "tracing")]
402 tracing::info!(
403 task_id = %config.task_id,
404 "✅ Push notification config registered successfully"
405 );
406
407 Ok(config.clone())
408 }
409
410 async fn get_config(
411 &self,
412 params: &crate::domain::GetTaskPushNotificationConfigParams,
413 ) -> Result<TaskPushNotificationConfig, A2AError> {
414 match self
415 .push_notification_registry
416 .get_config(¶ms.id)
417 .await?
418 {
419 Some(config) => Ok(config),
420 None => Err(A2AError::PushNotificationNotSupported),
421 }
422 }
423
424 async fn list_configs(
425 &self,
426 params: &crate::domain::ListTaskPushNotificationConfigsParams,
427 ) -> Result<Vec<TaskPushNotificationConfig>, A2AError> {
428 match self
431 .push_notification_registry
432 .get_config(¶ms.id)
433 .await?
434 {
435 Some(config) => Ok(vec![config]),
436 None => Ok(vec![]),
437 }
438 }
439
440 async fn delete_config(
441 &self,
442 params: &crate::domain::DeleteTaskPushNotificationConfigParams,
443 ) -> Result<(), A2AError> {
444 self.push_notification_registry
447 .unregister(¶ms.id)
448 .await?;
449 Ok(())
450 }
451}
452
453impl Clone for InMemoryTaskStorage {
454 fn clone(&self) -> Self {
455 Self {
456 tasks: self.tasks.clone(),
457 versions: self.versions.clone(),
458 push_notification_registry: self.push_notification_registry.clone(),
459 }
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::ContextId;

    /// Parses a test task id; panics on malformed fixtures.
    fn tid(raw: &str) -> TaskId {
        raw.parse().unwrap()
    }

    /// Parses a test context id; panics on malformed fixtures.
    fn cid(raw: &str) -> ContextId {
        raw.parse().unwrap()
    }

    #[tokio::test]
    async fn versioning_tracks_and_guards_mutations() {
        let store = InMemoryTaskStorage::new();
        let id = tid("t1");

        // Creation counts as the first mutation.
        store.create(&id, &cid("c1")).await.unwrap();
        assert_eq!(store.version(&id).await.unwrap(), 1);

        // An unchecked status update still advances the version.
        store
            .update_status(&id, TaskState::Working, None)
            .await
            .unwrap();
        let snapshot = store.get_versioned(&id, None).await.unwrap();
        assert_eq!(snapshot.version, 2);

        // A stale expected version is rejected and the task is left untouched.
        let conflict = store
            .update_status_checked(&id, 1, TaskState::Completed, None)
            .await
            .unwrap_err();
        assert!(matches!(
            conflict,
            A2AError::VersionConflict {
                expected: 1,
                actual: 2,
                ..
            }
        ));
        let current = store.get(&id, None).await.unwrap();
        assert_eq!(current.status.state, TaskState::Working);

        // The matching version is accepted and bumps the counter again.
        let applied = store
            .update_status_checked(&id, 2, TaskState::Completed, None)
            .await
            .unwrap();
        assert_eq!(applied.version, 3);
        assert_eq!(applied.task.status.state, TaskState::Completed);
    }
}