use crate::{A2AError, A2AResult, data::message::Message, data::task::Task};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Metadata for one conversation: the group of tasks that share a `context_id`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ConversationContext {
    /// Identifier shared by every task that belongs to this conversation.
    pub context_id: String,
    /// Creation time as produced by `ConversationContext::new`: an RFC 3339 string when the
    /// `time-stamps` feature is enabled, otherwise whole seconds since the Unix epoch.
    pub created_at: String,
    /// Time of the most recent recorded activity, in the same format as `created_at`.
    pub last_active: String,
    /// Number of times activity was recorded through `update_activity`.
    pub task_count: u32,
}
impl ConversationContext {
    /// Starts a fresh context: both timestamps are set to "now" and no activity is recorded.
    pub fn new(context_id: String) -> Self {
        let stamp = Self::current_timestamp();
        Self {
            context_id,
            last_active: stamp.clone(),
            created_at: stamp,
            task_count: 0,
        }
    }

    /// Refreshes `last_active` and increments `task_count` by one.
    ///
    /// Every call increments the counter, whatever triggered it.
    pub fn update_activity(&mut self) {
        self.last_active = Self::current_timestamp();
        self.task_count += 1;
    }

    /// Current wall-clock time formatted as RFC 3339 (requires the `time-stamps` feature).
    #[cfg(feature = "time-stamps")]
    fn current_timestamp() -> String {
        chrono::Utc::now().to_rfc3339()
    }

    /// Current wall-clock time as whole seconds since the Unix epoch, in decimal.
    /// A system clock set before the epoch yields `"0"`.
    #[cfg(not(feature = "time-stamps"))]
    fn current_timestamp() -> String {
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default();
        since_epoch.as_secs().to_string()
    }
}
/// Persistence interface for A2A tasks and the conversation contexts that group them.
///
/// The `Send + Sync` bound lets a single store be shared between threads.
pub trait TaskStorage: Send + Sync {
    /// Saves `task`, replacing any stored task with the same id.
    fn store_task(&self, task: Task) -> A2AResult<()>;

    /// Returns the task with id `task_id`, or `Ok(None)` when none is stored.
    fn get_task(&self, task_id: &str) -> A2AResult<Option<Task>>;

    /// Replaces an already-stored task with `task`.
    fn update_task(&self, task: Task) -> A2AResult<()>;

    /// Returns every stored task.
    fn list_tasks(&self) -> A2AResult<Vec<Task>>;

    /// Removes the task with id `task_id`; `Ok(true)` when it was present.
    fn remove_task(&self, task_id: &str) -> A2AResult<bool>;

    /// Reports whether a task with id `task_id` is stored.
    fn task_exists(&self, task_id: &str) -> A2AResult<bool>;

    /// Returns every stored task whose `context_id` equals `context_id`.
    fn get_tasks_by_context(&self, context_id: &str) -> A2AResult<Vec<Task>>;

    /// Returns the most recently added task of `context_id`, if any.
    fn get_latest_task_in_context(&self, context_id: &str) -> A2AResult<Option<Task>>;

    /// Returns the message histories of the context's tasks, concatenated.
    fn get_context_history(&self, context_id: &str) -> A2AResult<Vec<Message>>;

    /// Returns the context with id `context_id`, creating an empty one when it is missing.
    fn get_or_create_context(&self, context_id: &str) -> A2AResult<ConversationContext>;

    /// Records activity on an existing context.
    fn update_context_activity(&self, context_id: &str) -> A2AResult<()>;

    /// Returns every known context.
    fn list_contexts(&self) -> A2AResult<Vec<ConversationContext>>;
}
/// Thread-safe, process-local [`TaskStorage`] backed by hash maps.
///
/// Every map sits behind an `Arc<RwLock<..>>`, so cloning the store is cheap and the
/// clone shares (rather than copies) the underlying state.
#[derive(Debug, Clone)]
pub struct InMemoryTaskStorage {
    /// Stored tasks keyed by task id.
    tasks: Arc<RwLock<HashMap<String, Task>>>,
    /// Conversation contexts keyed by context id.
    contexts: Arc<RwLock<HashMap<String, ConversationContext>>>,
    /// Per-context list of task ids, appended to by `store_task` (oldest first).
    task_order: Arc<RwLock<HashMap<String, Vec<String>>>>,
}
impl InMemoryTaskStorage {
pub fn new() -> Self {
Self {
tasks: Arc::new(RwLock::new(HashMap::new())),
contexts: Arc::new(RwLock::new(HashMap::new())),
task_order: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn task_count(&self) -> usize {
self.tasks.read().unwrap().len()
}
pub fn context_count(&self) -> usize {
self.contexts.read().unwrap().len()
}
pub fn clear(&self) {
self.tasks.write().unwrap().clear();
self.contexts.write().unwrap().clear();
self.task_order.write().unwrap().clear();
}
}
impl Default for InMemoryTaskStorage {
fn default() -> Self {
Self::new()
}
}
// Locking discipline: any method that needs both `tasks` and `task_order` acquires `tasks`
// first and `task_order` second. Taking them in the opposite order in one method while
// another method writes in this order can deadlock. `contexts` is only locked while neither
// of the other two locks is held.
impl TaskStorage for InMemoryTaskStorage {
    /// Inserts `task`, replacing any stored task with the same id.
    ///
    /// The id is appended to its context's ordering list only when the task is new to that
    /// context: either a brand-new id, or an existing id whose `context_id` changed. So
    /// re-storing a task in the same context (e.g. after a status change) does not duplicate
    /// its messages in [`TaskStorage::get_context_history`]. It also does not turn that task
    /// into the result of [`TaskStorage::get_latest_task_in_context`].
    ///
    /// Under the same condition, the context's activity is recorded, and the context is
    /// created first if it does not exist yet.
    ///
    /// # Errors
    /// Returns an internal error if a storage lock is poisoned.
    fn store_task(&self, task: Task) -> A2AResult<()> {
        let context_id = task.context_id.clone();
        // The task/order guards live only inside this scope, so both are released before
        // the context lock is taken below.
        let entered_context = {
            let mut tasks = self
                .tasks
                .write()
                .map_err(|_| A2AError::internal("Failed to acquire task storage lock"))?;
            let mut task_order = self
                .task_order
                .write()
                .map_err(|_| A2AError::internal("Failed to acquire task order lock"))?;
            let previous_context = tasks.get(&task.id).map(|old| old.context_id.as_str());
            let entered_context = previous_context != Some(context_id.as_str());
            if entered_context {
                // A task that moved contexts must disappear from its old context's list.
                if let Some(old_context) = previous_context {
                    unlink_from_task_order(&mut task_order, old_context, &task.id);
                }
                task_order
                    .entry(context_id.clone())
                    .or_default()
                    .push(task.id.clone());
            }
            tasks.insert(task.id.clone(), task);
            entered_context
        };
        if entered_context {
            let mut contexts = self
                .contexts
                .write()
                .map_err(|_| A2AError::internal("Failed to acquire context storage lock"))?;
            contexts
                .entry(context_id.clone())
                .or_insert_with(|| ConversationContext::new(context_id))
                .update_activity();
        }
        Ok(())
    }

    /// Returns a copy of the task with id `task_id`, or `Ok(None)` if it is not stored.
    ///
    /// # Errors
    /// Returns an internal error if the task lock is poisoned.
    fn get_task(&self, task_id: &str) -> A2AResult<Option<Task>> {
        let tasks = self
            .tasks
            .read()
            .map_err(|_| A2AError::internal("Failed to acquire task storage lock"))?;
        Ok(tasks.get(task_id).cloned())
    }

    /// Replaces an already-stored task with `task`.
    ///
    /// If `task.context_id` differs from the stored copy, the id is moved to the end of the
    /// new context's ordering list, keeping the per-context queries consistent. Context
    /// activity is not recorded here.
    ///
    /// # Errors
    /// Returns a `tasks/update` execution error if no task with `task.id` is stored, and an
    /// internal error if a lock is poisoned.
    fn update_task(&self, task: Task) -> A2AResult<()> {
        let mut tasks = self
            .tasks
            .write()
            .map_err(|_| A2AError::internal("Failed to acquire task storage lock"))?;
        match tasks.get_mut(&task.id) {
            Some(existing) => {
                if existing.context_id != task.context_id {
                    let mut task_order = self
                        .task_order
                        .write()
                        .map_err(|_| A2AError::internal("Failed to acquire task order lock"))?;
                    unlink_from_task_order(&mut task_order, &existing.context_id, &task.id);
                    task_order
                        .entry(task.context_id.clone())
                        .or_default()
                        .push(task.id.clone());
                }
                *existing = task;
                Ok(())
            }
            None => Err(A2AError::method_execution_failed(
                "tasks/update",
                &format!("Task not found: {}", &task.id),
            )),
        }
    }

    /// Returns copies of all stored tasks, in unspecified (hash-map) order.
    ///
    /// # Errors
    /// Returns an internal error if the task lock is poisoned.
    fn list_tasks(&self) -> A2AResult<Vec<Task>> {
        let tasks = self
            .tasks
            .read()
            .map_err(|_| A2AError::internal("Failed to acquire task storage lock"))?;
        Ok(tasks.values().cloned().collect())
    }

    /// Removes the task with id `task_id`, returning whether it was present.
    ///
    /// The id is also dropped from its context's ordering list, so
    /// [`TaskStorage::get_latest_task_in_context`] falls back to the previously added task.
    ///
    /// # Errors
    /// Returns an internal error if a lock is poisoned.
    fn remove_task(&self, task_id: &str) -> A2AResult<bool> {
        let mut tasks = self
            .tasks
            .write()
            .map_err(|_| A2AError::internal("Failed to acquire task storage lock"))?;
        let removed = match tasks.remove(task_id) {
            Some(task) => task,
            None => return Ok(false),
        };
        let mut task_order = self
            .task_order
            .write()
            .map_err(|_| A2AError::internal("Failed to acquire task order lock"))?;
        unlink_from_task_order(&mut task_order, &removed.context_id, task_id);
        Ok(true)
    }

    /// Reports whether a task with id `task_id` is stored.
    ///
    /// # Errors
    /// Returns an internal error if the task lock is poisoned.
    fn task_exists(&self, task_id: &str) -> A2AResult<bool> {
        let tasks = self
            .tasks
            .read()
            .map_err(|_| A2AError::internal("Failed to acquire task storage lock"))?;
        Ok(tasks.contains_key(task_id))
    }

    /// Returns copies of every task in `context_id`, in unspecified (hash-map) order.
    ///
    /// # Errors
    /// Returns an internal error if the task lock is poisoned.
    fn get_tasks_by_context(&self, context_id: &str) -> A2AResult<Vec<Task>> {
        let tasks = self
            .tasks
            .read()
            .map_err(|_| A2AError::internal("Failed to acquire task storage lock"))?;
        // Filter first so only matching tasks are cloned.
        Ok(tasks
            .values()
            .filter(|task| task.context_id == context_id)
            .cloned()
            .collect())
    }

    /// Returns the most recently added task of `context_id` that is still stored.
    ///
    /// # Errors
    /// Returns an internal error if a lock is poisoned.
    fn get_latest_task_in_context(&self, context_id: &str) -> A2AResult<Option<Task>> {
        let tasks = self
            .tasks
            .read()
            .map_err(|_| A2AError::internal("Failed to acquire task storage lock"))?;
        let task_order = self
            .task_order
            .read()
            .map_err(|_| A2AError::internal("Failed to acquire task order lock"))?;
        Ok(task_order
            .get(context_id)
            .and_then(|task_ids| task_ids.iter().rev().find_map(|id| tasks.get(id)))
            .cloned())
    }

    /// Concatenates the histories of the context's tasks, in the order the tasks were added
    /// to the context. Tasks without a history contribute nothing.
    ///
    /// # Errors
    /// Returns an internal error if a lock is poisoned.
    fn get_context_history(&self, context_id: &str) -> A2AResult<Vec<Message>> {
        let tasks = self
            .tasks
            .read()
            .map_err(|_| A2AError::internal("Failed to acquire task storage lock"))?;
        let task_order = self
            .task_order
            .read()
            .map_err(|_| A2AError::internal("Failed to acquire task order lock"))?;
        let messages: Vec<Message> = task_order
            .get(context_id)
            .into_iter()
            .flatten()
            .filter_map(|task_id| tasks.get(task_id))
            .filter_map(|task| task.history.as_ref())
            .flat_map(|history| history.iter().cloned())
            .collect();
        Ok(messages)
    }

    /// Returns the context `context_id`, inserting a fresh one (no activity yet) if missing.
    ///
    /// # Errors
    /// Returns an internal error if the context lock is poisoned.
    fn get_or_create_context(&self, context_id: &str) -> A2AResult<ConversationContext> {
        let mut contexts = self
            .contexts
            .write()
            .map_err(|_| A2AError::internal("Failed to acquire context storage lock"))?;
        let context = contexts
            .entry(context_id.to_string())
            .or_insert_with(|| ConversationContext::new(context_id.to_string()));
        Ok(context.clone())
    }

    /// Records activity on an existing context (bumps its timestamp and counter).
    ///
    /// # Errors
    /// Returns an `update_context_activity` execution error if the context is unknown, and
    /// an internal error if the context lock is poisoned.
    fn update_context_activity(&self, context_id: &str) -> A2AResult<()> {
        let mut contexts = self
            .contexts
            .write()
            .map_err(|_| A2AError::internal("Failed to acquire context storage lock"))?;
        if let Some(context) = contexts.get_mut(context_id) {
            context.update_activity();
            Ok(())
        } else {
            Err(A2AError::method_execution_failed(
                "update_context_activity",
                &format!("Context not found: {}", context_id),
            ))
        }
    }

    /// Returns copies of all known contexts, in unspecified (hash-map) order.
    ///
    /// # Errors
    /// Returns an internal error if the context lock is poisoned.
    fn list_contexts(&self) -> A2AResult<Vec<ConversationContext>> {
        let contexts = self
            .contexts
            .read()
            .map_err(|_| A2AError::internal("Failed to acquire context storage lock"))?;
        Ok(contexts.values().cloned().collect())
    }
}

/// Removes `task_id` from the ordering list of `context_id`, and drops that list once it is
/// empty so contexts without tasks leave no entry behind.
fn unlink_from_task_order(
    task_order: &mut HashMap<String, Vec<String>>,
    context_id: &str,
    task_id: &str,
) {
    if let Some(task_ids) = task_order.get_mut(context_id) {
        task_ids.retain(|id| id.as_str() != task_id);
        if task_ids.is_empty() {
            task_order.remove(context_id);
        }
    }
}
// Unit tests for `InMemoryTaskStorage` and its `TaskStorage` implementation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::task::TaskState;

    // Store / get / exists / list / remove round-trip for a single task.
    #[test]
    fn test_in_memory_storage_basic_operations() {
        let storage = InMemoryTaskStorage::new();
        let task = Task::new("test-context".to_string());
        let task_id = task.id.clone();
        storage.store_task(task.clone()).unwrap();
        assert_eq!(storage.task_count(), 1);
        let retrieved = storage.get_task(&task_id).unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, task_id);
        assert!(storage.task_exists(&task_id).unwrap());
        assert!(!storage.task_exists("non-existent").unwrap());
        let tasks = storage.list_tasks().unwrap();
        assert_eq!(tasks.len(), 1);
        assert_eq!(tasks[0].id, task_id);
        // The first removal finds the task; the second reports it is already gone.
        assert!(storage.remove_task(&task_id).unwrap());
        assert!(!storage.remove_task(&task_id).unwrap());
        assert_eq!(storage.task_count(), 0);
    }

    // `update_task` replaces the stored copy of an existing task.
    #[test]
    fn test_update_task() {
        let storage = InMemoryTaskStorage::new();
        let mut task = Task::new("test-context".to_string());
        let task_id = task.id.clone();
        storage.store_task(task.clone()).unwrap();
        task.update_status(TaskState::Working);
        storage.update_task(task.clone()).unwrap();
        let retrieved = storage.get_task(&task_id).unwrap().unwrap();
        assert_eq!(retrieved.status.state, TaskState::Working);
    }

    // `update_task` refuses a task that was never stored.
    #[test]
    fn test_update_nonexistent_task() {
        let storage = InMemoryTaskStorage::new();
        let task = Task::new("test-context".to_string());
        let result = storage.update_task(task);
        assert!(result.is_err());
    }

    // Unknown ids yield `Ok(None)` rather than an error.
    #[test]
    fn test_get_nonexistent_task() {
        let storage = InMemoryTaskStorage::new();
        let result = storage.get_task("non-existent").unwrap();
        assert!(result.is_none());
    }

    // `Default` produces the same empty store as `new`.
    #[test]
    fn test_default_constructor() {
        let storage = InMemoryTaskStorage::default();
        assert_eq!(storage.task_count(), 0);
        assert_eq!(storage.context_count(), 0);
        let storage_new = InMemoryTaskStorage::new();
        assert_eq!(storage.task_count(), storage_new.task_count());
        assert_eq!(storage.context_count(), storage_new.context_count());
    }

    // `clear` empties both the task map and the context map.
    #[test]
    fn test_clear_functionality() {
        let storage = InMemoryTaskStorage::new();
        let task1 = Task::new("context1".to_string());
        let task2 = Task::new("context2".to_string());
        let task3 = Task::new("context3".to_string());
        storage.store_task(task1).unwrap();
        storage.store_task(task2).unwrap();
        storage.store_task(task3).unwrap();
        // Each task has its own context id, so storing them creates three contexts.
        assert_eq!(storage.task_count(), 3);
        assert_eq!(storage.context_count(), 3);
        storage.clear();
        assert_eq!(storage.task_count(), 0);
        assert_eq!(storage.context_count(), 0);
        let tasks = storage.list_tasks().unwrap();
        assert!(tasks.is_empty());
    }

    // End-to-end: store, update, list, query, remove and clear across two contexts.
    #[test]
    fn test_comprehensive_storage_workflow() {
        let storage = InMemoryTaskStorage::new();
        let mut task1 = Task::new("workflow-test-1".to_string());
        let task2 = Task::new("workflow-test-2".to_string());
        let task1_id = task1.id.clone();
        let task2_id = task2.id.clone();
        storage.store_task(task1.clone()).unwrap();
        storage.store_task(task2.clone()).unwrap();
        assert_eq!(storage.task_count(), 2);
        assert_eq!(storage.context_count(), 2);
        task1.update_status(TaskState::Working);
        storage.update_task(task1.clone()).unwrap();
        let retrieved_task1 = storage.get_task(&task1_id).unwrap().unwrap();
        assert_eq!(retrieved_task1.status.state, TaskState::Working);
        let all_tasks = storage.list_tasks().unwrap();
        assert_eq!(all_tasks.len(), 2);
        assert!(storage.task_exists(&task1_id).unwrap());
        assert!(storage.task_exists(&task2_id).unwrap());
        assert!(!storage.task_exists("non-existent-id").unwrap());
        assert!(storage.remove_task(&task1_id).unwrap());
        assert!(!storage.task_exists(&task1_id).unwrap());
        assert_eq!(storage.task_count(), 1);
        assert!(!storage.remove_task(&task1_id).unwrap());
        storage.clear();
        assert_eq!(storage.task_count(), 0);
        assert_eq!(storage.context_count(), 0);
        assert!(!storage.task_exists(&task2_id).unwrap());
    }

    // Five threads each store and read back their own task through a shared `Arc`.
    #[test]
    fn test_concurrent_storage_operations() {
        use std::sync::Arc;
        use std::thread;
        let storage = Arc::new(InMemoryTaskStorage::new());
        let mut handles = vec![];
        for i in 0..5 {
            let storage_clone = Arc::clone(&storage);
            let handle = thread::spawn(move || {
                let task = Task::new(format!("concurrent-task-{}", i));
                let task_id = task.id.clone();
                storage_clone.store_task(task).unwrap();
                assert!(storage_clone.task_exists(&task_id).unwrap());
                let retrieved = storage_clone.get_task(&task_id).unwrap();
                assert!(retrieved.is_some());
                task_id
            });
            handles.push(handle);
        }
        // Every thread's task must be visible once all threads have joined.
        let task_ids: Vec<String> = handles.into_iter().map(|h| h.join().unwrap()).collect();
        assert_eq!(storage.task_count(), 5);
        let all_tasks = storage.list_tasks().unwrap();
        assert_eq!(all_tasks.len(), 5);
        for task_id in task_ids {
            assert!(storage.task_exists(&task_id).unwrap());
        }
    }

    // Failure paths on an empty store: updating an unknown task errors with a message that
    // contains its id, removing an unknown id returns `false`, and clearing is a no-op.
    #[test]
    fn test_edge_cases_and_error_conditions() {
        let storage = InMemoryTaskStorage::new();
        let non_existent_task = Task::new("non-existent-context".to_string());
        let update_result = storage.update_task(non_existent_task.clone());
        assert!(update_result.is_err());
        let error = update_result.unwrap_err();
        let error_message = format!("{}", error);
        assert!(error_message.contains(&non_existent_task.id));
        assert!(!storage.remove_task("definitely-not-there").unwrap());
        assert_eq!(storage.task_count(), 0);
        assert_eq!(storage.context_count(), 0);
        assert!(storage.list_tasks().unwrap().is_empty());
        assert!(!storage.task_exists("any-id").unwrap());
        storage.clear();
        assert_eq!(storage.task_count(), 0);
        assert_eq!(storage.context_count(), 0);
    }

    // Storing a task whose id already exists replaces it without adding a task or context.
    #[test]
    fn test_task_replacement_behavior() {
        let storage = InMemoryTaskStorage::new();
        let mut original_task = Task::new("replacement-test".to_string());
        let task_id = original_task.id.clone();
        storage.store_task(original_task.clone()).unwrap();
        assert_eq!(storage.task_count(), 1);
        assert_eq!(storage.context_count(), 1);
        original_task.update_status(TaskState::Completed);
        storage.store_task(original_task.clone()).unwrap();
        assert_eq!(storage.task_count(), 1);
        assert_eq!(storage.context_count(), 1);
        let retrieved = storage.get_task(&task_id).unwrap().unwrap();
        assert_eq!(retrieved.status.state, TaskState::Completed);
    }

    // Per-context queries: grouping by context, the most recently stored task, and each
    // context's `task_count` matching the number of tasks stored in it.
    #[test]
    fn test_context_based_operations() {
        let storage = InMemoryTaskStorage::new();
        let task1_ctx1 = Task::new("context-1".to_string());
        let task2_ctx1 = Task::new("context-1".to_string());
        let task1_ctx2 = Task::new("context-2".to_string());
        let _task1_ctx1_id = task1_ctx1.id.clone();
        let task2_ctx1_id = task2_ctx1.id.clone();
        let _task1_ctx2_id = task1_ctx2.id.clone();
        storage.store_task(task1_ctx1).unwrap();
        storage.store_task(task2_ctx1).unwrap();
        storage.store_task(task1_ctx2).unwrap();
        assert_eq!(storage.task_count(), 3);
        assert_eq!(storage.context_count(), 2);
        let ctx1_tasks = storage.get_tasks_by_context("context-1").unwrap();
        assert_eq!(ctx1_tasks.len(), 2);
        let ctx2_tasks = storage.get_tasks_by_context("context-2").unwrap();
        assert_eq!(ctx2_tasks.len(), 1);
        // task2_ctx1 was stored after task1_ctx1, so it is the latest in "context-1".
        let latest_ctx1 = storage.get_latest_task_in_context("context-1").unwrap();
        assert!(latest_ctx1.is_some());
        assert_eq!(latest_ctx1.unwrap().id, task2_ctx1_id);
        let contexts = storage.list_contexts().unwrap();
        assert_eq!(contexts.len(), 2);
        let ctx1_context = contexts
            .iter()
            .find(|c| c.context_id == "context-1")
            .unwrap();
        assert_eq!(ctx1_context.task_count, 2);
        let ctx2_context = contexts
            .iter()
            .find(|c| c.context_id == "context-2")
            .unwrap();
        assert_eq!(ctx2_context.task_count, 1);
    }

    // `get_context_history` concatenates task histories in the order the tasks were stored.
    #[test]
    fn test_conversation_history_aggregation() {
        use crate::data::message::{Message, MessageRole, Part};
        let storage = InMemoryTaskStorage::new();
        let mut task1 = Task::new("conversation-test".to_string());
        let mut task2 = Task::new("conversation-test".to_string());
        let msg1 = Message::with_id(
            "msg-1".to_string(),
            MessageRole::User,
            vec![Part::text("Hello")],
        );
        let msg2 = Message::with_id(
            "msg-2".to_string(),
            MessageRole::Agent,
            vec![Part::text("Hi there!")],
        );
        task1.add_to_history(msg1);
        task1.add_to_history(msg2);
        let msg3 = Message::with_id(
            "msg-3".to_string(),
            MessageRole::User,
            vec![Part::text("How are you?")],
        );
        let msg4 = Message::with_id(
            "msg-4".to_string(),
            MessageRole::Agent,
            vec![Part::text("I'm doing well!")],
        );
        task2.add_to_history(msg3);
        task2.add_to_history(msg4);
        storage.store_task(task1).unwrap();
        storage.store_task(task2).unwrap();
        let history = storage.get_context_history("conversation-test").unwrap();
        assert_eq!(history.len(), 4);
        assert_eq!(history[0].message_id, "msg-1");
        assert_eq!(history[1].message_id, "msg-2");
        assert_eq!(history[2].message_id, "msg-3");
        assert_eq!(history[3].message_id, "msg-4");
    }

    // `get_or_create_context` is idempotent. Storing a task and calling
    // `update_context_activity` each increment `task_count`. Unknown contexts cannot be
    // updated.
    #[test]
    fn test_context_lifecycle_management() {
        let storage = InMemoryTaskStorage::new();
        let context1 = storage.get_or_create_context("new-context").unwrap();
        assert_eq!(context1.context_id, "new-context");
        assert_eq!(context1.task_count, 0);
        assert_eq!(storage.context_count(), 1);
        let context1_again = storage.get_or_create_context("new-context").unwrap();
        assert_eq!(context1_again.context_id, "new-context");
        assert_eq!(storage.context_count(), 1);
        let task = Task::new("new-context".to_string());
        storage.store_task(task).unwrap();
        let updated_context = storage.get_or_create_context("new-context").unwrap();
        assert_eq!(updated_context.task_count, 1);
        // No new task is stored here, yet the counter still goes up.
        storage.update_context_activity("new-context").unwrap();
        let context_after_update = storage.get_or_create_context("new-context").unwrap();
        assert_eq!(context_after_update.task_count, 2);
        let result = storage.update_context_activity("non-existent");
        assert!(result.is_err());
    }

    // Queries on a context that was never used return empty results, not errors.
    #[test]
    fn test_empty_context_operations() {
        let storage = InMemoryTaskStorage::new();
        let tasks = storage.get_tasks_by_context("empty-context").unwrap();
        assert!(tasks.is_empty());
        let latest = storage.get_latest_task_in_context("empty-context").unwrap();
        assert!(latest.is_none());
        let history = storage.get_context_history("empty-context").unwrap();
        assert!(history.is_empty());
        let contexts = storage.list_contexts().unwrap();
        assert!(contexts.is_empty());
    }

    // Two completed tasks with explicit ids sharing one context: the later-stored task is
    // the latest, and the context counts both tasks.
    #[test]
    fn test_a2a_protocol_conversation_flow() {
        let storage = InMemoryTaskStorage::new();
        let mut task1 = Task::new("ctx-conversation-abc".to_string());
        task1.id = "task-boat-gen-123".to_string();
        task1.update_status(TaskState::Completed);
        let mut task2 = Task::new("ctx-conversation-abc".to_string());
        task2.id = "task-boat-color-456".to_string();
        task2.update_status(TaskState::Completed);
        storage.store_task(task1).unwrap();
        storage.store_task(task2).unwrap();
        assert_eq!(storage.task_count(), 2);
        assert_eq!(storage.context_count(), 1);
        let context_tasks = storage
            .get_tasks_by_context("ctx-conversation-abc")
            .unwrap();
        assert_eq!(context_tasks.len(), 2);
        let latest = storage
            .get_latest_task_in_context("ctx-conversation-abc")
            .unwrap();
        assert_eq!(latest.unwrap().id, "task-boat-color-456");
        let context = storage
            .get_or_create_context("ctx-conversation-abc")
            .unwrap();
        assert_eq!(context.task_count, 2);
        assert_eq!(context.context_id, "ctx-conversation-abc");
    }
}