use std::collections::HashMap;
use a2a_protocol_types::params::ListTasksParams;
use a2a_protocol_types::task::Task;
use crate::call_context::CallContext;
use crate::error::{ServerError, ServerResult};
use super::RequestHandler;
/// Checks that an identifier is non-blank and within the allowed size.
///
/// Surrounding whitespace is ignored for both checks, and the length is the
/// byte length (`str::len`) of the trimmed value. `name` is the field name
/// interpolated into the error message.
///
/// # Errors
///
/// Returns [`ServerError::InvalidParams`] when the trimmed value is empty or
/// when its length exceeds `max_length`.
pub(super) fn validate_id(raw: &str, name: &str, max_length: usize) -> ServerResult<()> {
    let message = match raw.trim().len() {
        // The emptiness check comes first so a blank ID is always reported as
        // blank, even when `max_length` is zero.
        0 => format!("{name} must not be empty or whitespace-only"),
        len if len > max_length => format!(
            "{name} exceeds maximum length (got {}, max {max_length})",
            len
        ),
        _ => return Ok(()),
    };
    Err(ServerError::InvalidParams(message))
}
/// Creates the [`CallContext`] for a request to `method`.
///
/// When `headers` is `Some`, a clone of the map is attached through
/// `with_http_headers`; when it is `None`, the freshly constructed context is
/// returned as-is.
pub(super) fn build_call_context(
    method: &str,
    headers: Option<&HashMap<String, String>>,
) -> CallContext {
    match headers {
        Some(map) => CallContext::new(method).with_http_headers(map.clone()),
        None => CallContext::new(method),
    }
}
impl RequestHandler {
    /// Page size requested from the task store when resolving a task by its
    /// context ID.
    const CONTEXT_LOOKUP_PAGE_SIZE: u32 = 10;

    /// Looks up a task associated with `context_id`.
    ///
    /// A single `list` request is issued to the task store, filtered by the
    /// context ID and the current tenant (the tenant filter is omitted when
    /// the current tenant is empty), with `page_size` set to
    /// [`Self::CONTEXT_LOOKUP_PAGE_SIZE`] and no page token. Only the tasks in
    /// that one response are inspected.
    ///
    /// # Returns
    ///
    /// * `Ok(Some(task))` — the first task in the response whose state is not
    ///   terminal, or, if every returned task is terminal, the first task in
    ///   the response.
    /// * `Ok(None)` — the response contained no tasks, or `context_id` is
    ///   longer than `self.limits.max_id_length`. The length is measured on
    ///   the raw string (it is not trimmed) and the store is not queried in
    ///   that case.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the task store's `list` call.
    pub(crate) async fn find_task_by_context(
        &self,
        context_id: &str,
    ) -> ServerResult<Option<Task>> {
        // Oversized IDs are treated as "not found" rather than as an error.
        if context_id.len() > self.limits.max_id_length {
            return Ok(None);
        }

        // An empty tenant means "no tenant filter".
        let tenant = Some(crate::store::tenant::TenantContext::current())
            .filter(|tenant| !tenant.is_empty());

        let params = ListTasksParams {
            tenant,
            context_id: Some(context_id.to_owned()),
            status: None,
            page_size: Some(Self::CONTEXT_LOOKUP_PAGE_SIZE),
            page_token: None,
            status_timestamp_after: None,
            include_artifacts: None,
            history_length: None,
        };
        let resp = self.task_store.list(&params).await?;

        // Prefer a task that is still in progress; remember the first terminal
        // task only as a fallback.
        let mut terminal_fallback: Option<Task> = None;
        for task in resp.tasks {
            if !task.status.state.is_terminal() {
                return Ok(Some(task));
            }
            if terminal_fallback.is_none() {
                terminal_fallback = Some(task);
            }
        }
        Ok(terminal_fallback)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `err` is `InvalidParams` and its message contains
    /// `needle`.
    fn is_invalid_params_containing(err: &ServerError, needle: &str) -> bool {
        matches!(err, ServerError::InvalidParams(msg) if msg.contains(needle))
    }

    // ── validate_id ─────────────────────────────────────────────────────────

    #[test]
    fn validate_id_accepts_normal_id() {
        let outcome = validate_id("task-123", "task_id", 1024);
        assert!(outcome.is_ok(), "a normal short ID should be accepted");
    }

    #[test]
    fn validate_id_rejects_empty_string() {
        let err = validate_id("", "task_id", 1024).unwrap_err();
        assert!(
            is_invalid_params_containing(&err, "empty"),
            "empty string should be rejected with InvalidParams: {err:?}"
        );
    }

    #[test]
    fn validate_id_rejects_whitespace_only() {
        let blank = " \t\n ";
        let err = validate_id(blank, "context_id", 1024).unwrap_err();
        assert!(
            is_invalid_params_containing(&err, "empty"),
            "whitespace-only string should be rejected: {err:?}"
        );
    }

    #[test]
    fn validate_id_rejects_exceeding_max_length() {
        let oversized = "a".repeat(2000);
        let err = validate_id(&oversized, "task_id", 1024).unwrap_err();
        assert!(
            is_invalid_params_containing(&err, "maximum length"),
            "overly long ID should be rejected: {err:?}"
        );
    }

    #[test]
    fn validate_id_accepts_exactly_max_length() {
        const LIMIT: usize = 128;
        let at_limit = "b".repeat(LIMIT);
        let outcome = validate_id(&at_limit, "task_id", LIMIT);
        assert!(
            outcome.is_ok(),
            "ID at exactly max length should be accepted"
        );
    }

    #[test]
    fn validate_id_trims_before_length_check() {
        // Padding must not count towards the limit.
        let outcome = validate_id(" abc ", "id", 3);
        assert!(outcome.is_ok(), "trimmed length (3) should pass a max of 3");
    }

    #[test]
    fn validate_id_includes_field_name_in_error() {
        let err = validate_id("", "my_field", 1024).unwrap_err();
        assert!(
            is_invalid_params_containing(&err, "my_field"),
            "error message should contain the field name: {err:?}"
        );
    }

    // ── build_call_context ──────────────────────────────────────────────────

    #[test]
    fn build_call_context_without_headers() {
        let ctx = build_call_context("message/send", None);
        assert_eq!(ctx.method(), "message/send", "method should be set");
        assert!(
            ctx.http_headers().is_empty(),
            "headers should be empty when None is passed"
        );
    }

    #[test]
    fn build_call_context_with_headers() {
        let headers = HashMap::from([
            ("authorization".to_owned(), "Bearer tok".to_owned()),
            ("x-request-id".to_owned(), "req-99".to_owned()),
        ]);

        let ctx = build_call_context("tasks/get", Some(&headers));
        assert_eq!(ctx.method(), "tasks/get");

        let authorization = ctx.http_headers().get("authorization").map(String::as_str);
        assert_eq!(
            authorization,
            Some("Bearer tok"),
            "headers should be cloned into the context"
        );

        let request_id = ctx.http_headers().get("x-request-id").map(String::as_str);
        assert_eq!(request_id, Some("req-99"));
    }

    #[test]
    fn build_call_context_with_empty_headers_map() {
        let no_headers: HashMap<String, String> = HashMap::new();
        let ctx = build_call_context("test", Some(&no_headers));
        assert!(
            ctx.http_headers().is_empty(),
            "an empty map should result in empty headers"
        );
    }

    // ── find_task_by_context ────────────────────────────────────────────────

    mod find_task_by_context_tests {
        use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};
        use crate::agent_executor;
        use crate::builder::RequestHandlerBuilder;
        use crate::handler::limits::HandlerLimits;

        /// Executor that does nothing; these tests only exercise store lookups.
        struct DummyExecutor;
        agent_executor!(DummyExecutor, |_ctx, _queue| async { Ok(()) });

        /// Builds a bare task with the given ID, context ID and state.
        fn make_task(id: &str, context_id: &str, state: TaskState) -> Task {
            let id = TaskId::new(String::from(id));
            let context_id = ContextId::new(context_id);
            let status = TaskStatus::new(state);
            Task {
                id,
                context_id,
                status,
                history: None,
                artifacts: None,
                metadata: None,
            }
        }

        #[tokio::test]
        async fn context_id_too_long_returns_none() {
            let limits = HandlerLimits::default().with_max_id_length(10);
            let handler = RequestHandlerBuilder::new(DummyExecutor)
                .with_handler_limits(limits)
                .build()
                .unwrap();

            let oversized = "a".repeat(11);
            let found = handler.find_task_by_context(&oversized).await.unwrap();
            assert!(
                found.is_none(),
                "find_task_by_context should return None for context_id exceeding max length"
            );
        }

        #[tokio::test]
        async fn context_id_within_limit_returns_none_for_missing() {
            let limits = HandlerLimits::default().with_max_id_length(100);
            let handler = RequestHandlerBuilder::new(DummyExecutor)
                .with_handler_limits(limits)
                .build()
                .unwrap();

            let found = handler
                .find_task_by_context("no-such-context")
                .await
                .unwrap();
            assert!(
                found.is_none(),
                "find_task_by_context should return None when no task matches the context"
            );
        }

        #[tokio::test]
        async fn prefers_non_terminal_over_terminal_task() {
            let limits = HandlerLimits::default().with_max_id_length(100);
            let handler = RequestHandlerBuilder::new(DummyExecutor)
                .with_handler_limits(limits)
                .build()
                .unwrap();

            // Saved in this order: the terminal task first, then the active one.
            let fixtures = [
                ("aaa-completed", TaskState::Completed),
                ("bbb-working", TaskState::Working),
            ];
            for (id, state) in fixtures {
                handler
                    .task_store
                    .save(make_task(id, "ctx-1", state))
                    .await
                    .unwrap();
            }

            let found = handler.find_task_by_context("ctx-1").await.unwrap();
            assert!(found.is_some(), "should find a task");
            let task = found.unwrap();
            assert_eq!(
                task.id.0, "bbb-working",
                "should prefer the non-terminal (Working) task over the terminal (Completed) one"
            );
        }

        #[tokio::test]
        async fn returns_terminal_task_when_no_non_terminal_exists() {
            let limits = HandlerLimits::default().with_max_id_length(100);
            let handler = RequestHandlerBuilder::new(DummyExecutor)
                .with_handler_limits(limits)
                .build()
                .unwrap();

            let only_task = make_task("task-done", "ctx-2", TaskState::Completed);
            handler.task_store.save(only_task).await.unwrap();

            let found = handler.find_task_by_context("ctx-2").await.unwrap();
            assert!(found.is_some(), "should still return a terminal task");
            assert_eq!(found.unwrap().id.0, "task-done");
        }

        #[tokio::test]
        async fn returns_first_non_terminal_when_multiple_exist() {
            let limits = HandlerLimits::default().with_max_id_length(100);
            let handler = RequestHandlerBuilder::new(DummyExecutor)
                .with_handler_limits(limits)
                .build()
                .unwrap();

            // One terminal task followed by two active ones, saved in this order.
            let fixtures = [
                ("aaa-failed", TaskState::Failed),
                ("bbb-submitted", TaskState::Submitted),
                ("ccc-working", TaskState::Working),
            ];
            for (id, state) in fixtures {
                handler
                    .task_store
                    .save(make_task(id, "ctx-3", state))
                    .await
                    .unwrap();
            }

            let found = handler.find_task_by_context("ctx-3").await.unwrap();
            let task = found.unwrap();
            assert!(
                !task.status.state.is_terminal(),
                "should return a non-terminal task, got {:?}",
                task.status.state
            );
            assert_eq!(
                task.id.0, "bbb-submitted",
                "should return the first non-terminal task in store order"
            );
        }
    }
}