#[cfg(feature = "sqlx-storage")]
mod sqlx_tests {
use a2a_rs::adapter::storage::{DatabaseConfig, SqlxTaskStorage};
use a2a_rs::domain::TaskState;
use a2a_rs::port::{AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager};
use a2a_rs::{A2AError, TaskPushNotificationConfig};
use std::sync::Arc;
use uuid::Uuid;
/// Builds a [`SqlxTaskStorage`] pointed at the `sqlite::memory:` URL for use by the tests.
///
/// # Errors
///
/// Returns whatever [`A2AError`] `SqlxTaskStorage::new` produces.
async fn create_test_storage() -> Result<SqlxTaskStorage, A2AError> {
    let builder = DatabaseConfig::builder()
        .url("sqlite::memory:".to_string())
        .max_connections(1);
    let db_config = builder.build();

    // Only the URL is handed to the constructor; the connection limit set
    // above stays on the config value and is not passed along here.
    let db_url = &db_config.url;
    SqlxTaskStorage::new(db_url).await
}
/// Walks a single task through create -> Working -> Completed and checks the
/// state reported at every step, including a final read-back.
#[tokio::test]
async fn test_task_lifecycle() -> Result<(), Box<dyn std::error::Error>> {
    let store = create_test_storage().await?;
    let id = Uuid::new_v4().to_string();
    let ctx = "test-context";

    // The created task echoes its identifiers and is expected to be Submitted.
    let created = store.create_task(&id, ctx).await?;
    assert_eq!(created.id, id);
    assert_eq!(created.context_id, ctx);
    assert_eq!(created.status.state, TaskState::Submitted);

    // Existence lookups: the new id is present, an arbitrary id is not.
    let known = store.task_exists(&id).await?;
    assert!(known);
    let unknown = store.task_exists("non-existent").await?;
    assert!(!unknown);

    // Each status update returns the task with the new state applied.
    let in_progress = store
        .update_task_status(&id, TaskState::Working, None)
        .await?;
    assert_eq!(in_progress.status.state, TaskState::Working);

    let finished = store
        .update_task_status(&id, TaskState::Completed, None)
        .await?;
    assert_eq!(finished.status.state, TaskState::Completed);

    // Reading the task back (history limit 10) shows the final state.
    let fetched = store.get_task(&id, Some(10)).await?;
    assert_eq!(fetched.id, id);
    assert_eq!(fetched.status.state, TaskState::Completed);

    Ok(())
}
/// Cancels a task that is in the Working state and checks that both the
/// returned task and a later read report Canceled.
#[tokio::test]
async fn test_task_cancellation() -> Result<(), Box<dyn std::error::Error>> {
    let store = create_test_storage().await?;
    let id = Uuid::new_v4().to_string();

    // Set-up: create the task and move it to Working before cancelling.
    store.create_task(&id, "test-context").await?;
    store
        .update_task_status(&id, TaskState::Working, None)
        .await?;

    let canceled = store.cancel_task(&id).await?;
    assert_eq!(canceled.status.state, TaskState::Canceled);

    // Re-read without a history limit; the stored state must match.
    let reloaded = store.get_task(&id, None).await?;
    assert_eq!(reloaded.status.state, TaskState::Canceled);

    Ok(())
}
/// Expects `cancel_task` on an already Completed task to fail with
/// [`A2AError::TaskNotCancelable`].
#[tokio::test]
async fn test_cannot_cancel_completed_task() -> Result<(), Box<dyn std::error::Error>> {
    let store = create_test_storage().await?;
    let id = Uuid::new_v4().to_string();

    // Drive the task all the way to Completed.
    store.create_task(&id, "test-context").await?;
    store
        .update_task_status(&id, TaskState::Working, None)
        .await?;
    store
        .update_task_status(&id, TaskState::Completed, None)
        .await?;

    let outcome = store.cancel_task(&id).await;
    assert!(outcome.is_err());

    // Any other error variant is reported together with its Debug output.
    match &outcome {
        Err(A2AError::TaskNotCancelable(_)) => {}
        other => panic!("Expected TaskNotCancelable error, got: {:?}", other),
    }

    Ok(())
}
/// Creates the same task id twice and expects the second call to fail.
///
/// NOTE(review): the test expects [`A2AError::TaskNotFound`] for a duplicate
/// id, which is a surprising variant for "already exists" — confirm this is the
/// intended contract of the storage.
#[tokio::test]
async fn test_duplicate_task_creation() -> Result<(), Box<dyn std::error::Error>> {
    let store = create_test_storage().await?;
    let id = Uuid::new_v4().to_string();

    store.create_task(&id, "test-context").await?;

    // Second creation with the identical id and context.
    let second_attempt = store.create_task(&id, "test-context").await;
    assert!(second_attempt.is_err());

    match &second_attempt {
        Err(A2AError::TaskNotFound(_)) => {}
        other => panic!(
            "Expected TaskNotFound error for duplicate, got: {:?}",
            other
        ),
    }

    Ok(())
}
/// Applies four status transitions, then fetches the task once with a history
/// limit of 3 and once without a limit.
///
/// Only the success of each call is checked; the returned tasks are not
/// inspected.
#[tokio::test]
async fn test_task_history_limit() -> Result<(), Box<dyn std::error::Error>> {
    let store = create_test_storage().await?;
    let id = Uuid::new_v4().to_string();
    store.create_task(&id, "test-context").await?;

    // Transitions are applied strictly in this order.
    let transitions = [
        TaskState::Working,
        TaskState::InputRequired,
        TaskState::Working,
        TaskState::Completed,
    ];
    for next_state in transitions {
        store.update_task_status(&id, next_state, None).await?;
    }

    let _limited = store.get_task(&id, Some(3)).await?;
    let _unlimited = store.get_task(&id, None).await?;

    Ok(())
}
/// Sets, reads back and removes a push-notification config for one task, then
/// expects a further read to fail.
#[tokio::test]
async fn test_push_notifications() -> Result<(), Box<dyn std::error::Error>> {
    let store = create_test_storage().await?;
    let id = Uuid::new_v4().to_string();
    store.create_task(&id, "test-context").await?;

    let webhook = "https://example.com/webhook";
    let push_config = TaskPushNotificationConfig {
        tenant: String::new(),
        task_id: id.clone(),
        id: String::new(),
        url: webhook.to_string(),
        token: String::new(),
        authentication: None.into(),
        ..Default::default()
    };

    // The value returned by the setter carries the task id and URL we sent.
    let stored = store.set_task_notification(&push_config).await?;
    assert_eq!(stored.task_id, id);
    assert_eq!(stored.url, webhook);

    // So does the value returned by a subsequent lookup.
    let fetched = store.get_task_notification(&id).await?;
    assert_eq!(fetched.task_id, id);
    assert_eq!(fetched.url, webhook);

    // After removal the lookup is expected to return an error.
    store.remove_task_notification(&id).await?;
    assert!(store.get_task_notification(&id).await.is_err());

    Ok(())
}
/// Checks `DatabaseConfig` validation and database-type detection from the URL.
#[tokio::test]
async fn test_database_config() -> Result<(), Box<dyn std::error::Error>> {
    use a2a_rs::adapter::storage::DatabaseType;

    // A fully specified SQLite config validates.
    let sqlite_cfg = DatabaseConfig::builder()
        .url("sqlite:test.db".to_string())
        .max_connections(5)
        .timeout_seconds(10)
        .build();
    assert!(sqlite_cfg.validate().is_ok());

    // An empty URL is rejected by validation.
    let empty_cfg = DatabaseConfig::builder().url(String::new()).build();
    assert!(empty_cfg.validate().is_err());

    // The URL scheme determines the reported database type.
    assert_eq!(sqlite_cfg.database_type(), Some(DatabaseType::Sqlite));

    let pg_cfg = DatabaseConfig::builder()
        .url("postgres://localhost/test".to_string())
        .build();
    assert_eq!(pg_cfg.database_type(), Some(DatabaseType::Postgres));

    Ok(())
}
/// Exercises the streaming-subscriber calls on a task that has no subscribers.
#[tokio::test]
async fn test_streaming_subscribers() -> Result<(), Box<dyn std::error::Error>> {
    let store = create_test_storage().await?;
    let id = Uuid::new_v4().to_string();
    store.create_task(&id, "test-context").await?;

    // Nothing has subscribed, so the count is expected to be zero.
    assert_eq!(store.get_subscriber_count(&id).await?, 0);

    // Removing all subscribers of the task must succeed even with none present.
    store.remove_task_subscribers(&id).await?;

    // Removing a single subscription by id is expected to be unsupported.
    let outcome = store.remove_subscription("fake-id").await;
    assert!(matches!(outcome, Err(A2AError::UnsupportedOperation(_))));

    Ok(())
}
/// Runs ten spawned tasks against one shared storage, each creating a task and
/// moving it to Completed, then verifies the stored result for every id.
#[tokio::test]
async fn test_concurrent_operations() -> Result<(), Box<dyn std::error::Error>> {
    let shared = Arc::new(create_test_storage().await?);

    // Collecting the iterator is what spawns the workers; all ten are started
    // before any of them is awaited.
    let workers: Vec<_> = (0..10)
        .map(|n| {
            let store = Arc::clone(&shared);
            tokio::spawn(async move {
                let id = format!("concurrent-task-{}", n);
                let created = store.create_task(&id, "concurrent-context").await?;
                store
                    .update_task_status(&id, TaskState::Working, None)
                    .await?;
                store
                    .update_task_status(&id, TaskState::Completed, None)
                    .await?;
                Ok::<_, A2AError>(created)
            })
        })
        .collect();

    // Each worker hands back the task as returned by `create_task`, so the
    // state asserted here is Submitted rather than the final one.
    for worker in workers {
        let created = worker.await??;
        assert_eq!(created.status.state, TaskState::Submitted);
    }

    // The stored copies must all exist and be Completed.
    for n in 0..10 {
        let id = format!("concurrent-task-{}", n);
        assert!(shared.task_exists(&id).await?);
        let stored = shared.get_task(&id, None).await?;
        assert_eq!(stored.status.state, TaskState::Completed);
    }

    Ok(())
}
/// Constructs two storages from the same `sqlite::memory:` URL and expects both
/// constructions to succeed.
///
/// NOTE(review): presumably this guards against set-up failing when run more
/// than once; whether both instances share one database is not visible here —
/// confirm against `SqlxTaskStorage::new`.
#[tokio::test]
async fn test_database_migrations() -> Result<(), Box<dyn std::error::Error>> {
    let db_config = DatabaseConfig::builder()
        .url("sqlite::memory:".to_string())
        .build();
    let db_url = &db_config.url;

    // Both instances stay alive until the end of the test.
    let _first = SqlxTaskStorage::new(db_url).await?;
    let _second = SqlxTaskStorage::new(db_url).await?;

    Ok(())
}
/// Creates five tasks and lists them with default parameters, asserting the
/// reported total, the number of returned tasks and the page size of 50.
#[tokio::test]
async fn test_list_tasks_v3_basic() -> Result<(), Box<dyn std::error::Error>> {
    let storage = create_test_storage().await?;

    for i in 0..5 {
        let task_id = format!("task-{}", i);
        storage.create_task(&task_id, "test-context").await?;
    }

    // Default parameters: no filter and no explicit page size.
    let params = a2a_rs::domain::ListTasksParams::default();
    // The parameters are passed by shared reference.
    let result = storage.list_tasks_v3(&params).await?;

    assert_eq!(result.total_size, 5, "Should have 5 tasks");
    assert_eq!(result.tasks.len(), 5, "Should return 5 tasks");
    assert_eq!(result.page_size, 50, "Default page size should be 50");

    Ok(())
}
/// Exercises the `context_id` and `status` filters of `list_tasks_v3` and
/// asserts the expected totals for each.
#[tokio::test]
async fn test_list_tasks_v3_filtering() -> Result<(), Box<dyn std::error::Error>> {
    let storage = create_test_storage().await?;

    // Two tasks in context-a, one in context-b.
    storage.create_task("task-a-1", "context-a").await?;
    storage.create_task("task-a-2", "context-a").await?;
    storage.create_task("task-b-1", "context-b").await?;

    // task-a-1 becomes Working, task-a-2 Completed; task-b-1 is left untouched.
    storage
        .update_task_status("task-a-1", TaskState::Working, None)
        .await?;
    storage
        .update_task_status("task-a-2", TaskState::Completed, None)
        .await?;

    // Filter by context only.
    let params = a2a_rs::domain::ListTasksParams {
        context_id: Some("context-a".to_string()),
        ..Default::default()
    };
    let result = storage.list_tasks_v3(&params).await?;
    assert_eq!(result.total_size, 2, "Should have 2 tasks in context-a");

    // Filter by status only.
    let params = a2a_rs::domain::ListTasksParams {
        status: Some(TaskState::Working),
        ..Default::default()
    };
    let result = storage.list_tasks_v3(&params).await?;
    assert_eq!(result.total_size, 1, "Should have 1 working task");

    Ok(())
}
/// Creates ten tasks and requests two consecutive pages of three, following the
/// `next_page_token` returned by the first page.
#[tokio::test]
async fn test_list_tasks_v3_pagination() -> Result<(), Box<dyn std::error::Error>> {
    let storage = create_test_storage().await?;

    for i in 0..10 {
        storage
            .create_task(&format!("task-{}", i), "test-context")
            .await?;
    }

    // First page: no token supplied.
    let params = a2a_rs::domain::ListTasksParams {
        page_size: Some(3),
        ..Default::default()
    };
    let page1 = storage.list_tasks_v3(&params).await?;
    assert_eq!(page1.tasks.len(), 3, "Should return 3 tasks");
    assert!(
        !page1.next_page_token.is_empty(),
        "Should have next page token"
    );

    // Second page: continue from the token handed back by the first page.
    let params = a2a_rs::domain::ListTasksParams {
        page_size: Some(3),
        page_token: Some(page1.next_page_token.clone()),
        ..Default::default()
    };
    let page2 = storage.list_tasks_v3(&params).await?;
    assert_eq!(page2.tasks.len(), 3, "Should return 3 tasks");

    Ok(())
}
/// Stores one push-notification config, then fetches it by id, lists it and
/// deletes it through the parameter-struct based calls.
#[tokio::test]
async fn test_push_notification_config_v3_crud() -> Result<(), Box<dyn std::error::Error>> {
    let store = create_test_storage().await?;
    let id = Uuid::new_v4().to_string();
    store.create_task(&id, "test-context").await?;

    // Create.
    let push_config = TaskPushNotificationConfig {
        tenant: String::new(),
        task_id: id.clone(),
        id: "config-1".to_string(),
        url: "https://example.com/webhook".to_string(),
        token: "test-token".to_string(),
        authentication: None.into(),
        ..Default::default()
    };
    store.set_task_notification(&push_config).await?;

    // Read by config id.
    let lookup = a2a_rs::domain::GetTaskPushNotificationConfigParams {
        id: id.clone(),
        push_notification_config_id: Some("config-1".to_string()),
        metadata: None,
    };
    let fetched = store.get_push_notification_config(&lookup).await?;
    assert_eq!(fetched.url, "https://example.com/webhook");
    assert_eq!(fetched.token, "test-token");

    // List: exactly the one config stored above.
    let listing = a2a_rs::domain::ListTaskPushNotificationConfigsParams {
        id: id.clone(),
        metadata: None,
    };
    let before_delete = store.list_push_notification_configs(&listing).await?;
    assert_eq!(before_delete.len(), 1, "Should have 1 config");

    // Delete, then list again with the same parameters.
    let removal = a2a_rs::domain::DeleteTaskPushNotificationConfigParams {
        id: id.clone(),
        push_notification_config_id: "config-1".to_string(),
        metadata: None,
    };
    store.delete_push_notification_config(&removal).await?;

    let after_delete = store.list_push_notification_configs(&listing).await?;
    assert_eq!(after_delete.len(), 0, "Config should be deleted");

    Ok(())
}
/// Stores two push-notification configs with different ids for one task and
/// expects the listing to contain both.
#[tokio::test]
async fn test_push_notification_config_v3_multiple() -> Result<(), Box<dyn std::error::Error>> {
    let store = create_test_storage().await?;
    let id = Uuid::new_v4().to_string();
    store.create_task(&id, "test-context").await?;

    // Both configs differ only in config id, URL and token.
    let make_config = |config_id: &str, url: &str, token: &str| TaskPushNotificationConfig {
        tenant: String::new(),
        task_id: id.clone(),
        id: config_id.to_string(),
        url: url.to_string(),
        token: token.to_string(),
        authentication: None.into(),
        ..Default::default()
    };
    let first = make_config("config-1", "https://example.com/webhook1", "");
    let second = make_config("config-2", "https://example.com/webhook2", "token-2");

    store.set_task_notification(&first).await?;
    store.set_task_notification(&second).await?;

    let listing = a2a_rs::domain::ListTaskPushNotificationConfigsParams {
        id: id.clone(),
        metadata: None,
    };
    let stored = store.list_push_notification_configs(&listing).await?;
    assert_eq!(stored.len(), 2, "Should have 2 configs");

    Ok(())
}
}
/// Placeholder test compiled only when the `sqlx-storage` feature is disabled;
/// it just prints a notice that the storage tests were skipped.
#[cfg(not(feature = "sqlx-storage"))]
#[tokio::test]
async fn test_sqlx_not_available() {
    const NOTICE: &str = "SQLx storage tests skipped - feature not enabled";
    println!("{}", NOTICE);
}