mod send;
pub mod stream;
mod tasks;
use crate::a2a::types::*;
use crate::brain::agent::service::AgentService;
use crate::services::ServiceContext;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
pub type TaskStore = Arc<RwLock<HashMap<String, Task>>>;
pub type CancelStore = Arc<RwLock<HashMap<String, CancellationToken>>>;
pub fn new_task_store() -> TaskStore {
Arc::new(RwLock::new(HashMap::new()))
}
pub fn new_cancel_store() -> CancelStore {
Arc::new(RwLock::new(HashMap::new()))
}
pub async fn dispatch(
req: JsonRpcRequest,
store: TaskStore,
cancel_store: CancelStore,
agent_service: Arc<AgentService>,
service_context: ServiceContext,
) -> JsonRpcResponse {
match req.method.as_str() {
"message/send" => {
send::handle_send_message(
req.id,
req.params,
store,
cancel_store,
agent_service,
service_context,
)
.await
}
"tasks/get" => tasks::handle_get_task(req.id, req.params, store).await,
"tasks/cancel" => {
tasks::handle_cancel_task(
req.id,
req.params,
store,
cancel_store,
&service_context.pool(),
)
.await
}
_ => JsonRpcResponse::error(
req.id,
error_codes::METHOD_NOT_FOUND,
format!("Method not found: {}", req.method),
),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_get_task_not_found() {
let store = new_task_store();
let req = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
method: "tasks/get".to_string(),
params: serde_json::json!({"id": "nonexistent"}),
id: serde_json::json!(2),
};
let resp = tasks::handle_get_task(req.id, req.params, store).await;
assert!(resp.error.is_some());
assert_eq!(resp.error.as_ref().expect("err").code, -32001);
}
#[tokio::test]
async fn test_unknown_method() {
use crate::a2a::test_helpers::helpers;
let store = new_task_store();
let cancel_store = new_cancel_store();
let agent = helpers::placeholder_agent_service().await;
let ctx = helpers::placeholder_service_context().await;
let req = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
method: "unknown/method".to_string(),
params: serde_json::json!({}),
id: serde_json::json!(99),
};
let resp = dispatch(req, store, cancel_store, agent, ctx).await;
assert!(resp.error.is_some());
assert_eq!(resp.error.as_ref().expect("err").code, -32601);
}
}