use std::pin::Pin;
use async_trait::async_trait;
use futures::Stream;
use tonic::{Request, Response, Status};
use turul_a2a_proto as pb;
use crate::grpc::error::a2a_to_status;
use crate::middleware::context::RequestContext;
use crate::router::{self, AppState, ListTasksQuery, PushConfigQuery};
/// Name of the gRPC metadata entry that may carry the tenant identifier.
///
/// `tenant_from` only consults this entry when the tenant field taken from
/// the request message is empty.
pub const TENANT_METADATA: &str = "x-tenant-id";
/// gRPC service object; implements `pb::grpc::A2aService` by delegating to
/// the `router::core_*` functions and the `crate::grpc::streaming` handlers.
pub struct GrpcService {
/// Application state; a clone of it is passed along on every RPC.
pub(crate) state: AppState,
}
impl GrpcService {
pub fn new(state: AppState) -> Self {
Self { state }
}
}
/// Resolves the tenant identifier for a request.
///
/// Precedence:
/// 1. `proto_tenant`, when it is non-empty;
/// 2. the `TENANT_METADATA` metadata entry, when it is present, converts to
///    a string via `to_str`, and is non-empty;
/// 3. otherwise the empty string.
pub(crate) fn tenant_from<T>(req: &Request<T>, proto_tenant: &str) -> String {
    // The tenant carried in the message body always wins over metadata.
    if !proto_tenant.is_empty() {
        return proto_tenant.to_owned();
    }

    req.metadata()
        .get(TENANT_METADATA)
        .and_then(|raw| raw.to_str().ok())
        .filter(|candidate| !candidate.is_empty())
        .map(str::to_owned)
        .unwrap_or_default()
}
/// Returns the owner string for a request.
///
/// The owner is read from the `RequestContext` stored in the request
/// extensions; when no context is attached the literal `"anonymous"` is used.
fn owner_from<T>(req: &Request<T>) -> String {
    match req.extensions().get::<RequestContext>() {
        Some(context) => context.identity.owner().to_string(),
        None => "anonymous".to_string(),
    }
}
/// Converts a `serde_json` failure into a `Status::internal` whose message
/// identifies the gRPC adapter's proto/JSON bridging as the source.
fn internal_from_json(err: serde_json::Error) -> Status {
    let message = format!("grpc adapter: proto/json mismatch: {err}");
    Status::internal(message)
}
/// Pinned, boxed, `Send + 'static` stream yielding
/// `Result<pb::StreamResponse, Status>` items.
///
/// Used as the associated stream type for `send_streaming_message` and
/// `subscribe_to_task`.
pub type BoxedStreamResponseStream =
Pin<Box<dyn Stream<Item = Result<pb::StreamResponse, Status>> + Send + 'static>>;
#[async_trait]
impl pb::grpc::A2aService for GrpcService {
/// Unary `SendMessage` RPC.
///
/// Serializes the request message to JSON, passes it to
/// `router::core_send_message`, and decodes the JSON value that comes back
/// into a `pb::SendMessageResponse`.
///
/// # Errors
///
/// * Errors from the router are converted with `a2a_to_status`.
/// * JSON encode/decode failures are reported through `internal_from_json`.
async fn send_message(
    &self,
    request: Request<pb::SendMessageRequest>,
) -> Result<Response<pb::SendMessageResponse>, Status> {
    let message = request.get_ref();
    let caller = owner_from(&request);
    let tenant_id = tenant_from(&request, &message.tenant);
    let json_body = serde_json::to_string(message).map_err(internal_from_json)?;

    let routed =
        router::core_send_message(self.state.clone(), &tenant_id, &caller, None, json_body)
            .await
            .map_err(a2a_to_status)?;

    // The first field of the router result is the JSON value to decode.
    serde_json::from_value::<pb::SendMessageResponse>(routed.0)
        .map(Response::new)
        .map_err(internal_from_json)
}
/// Unary `GetTask` RPC.
///
/// Forwards the task id and `history_length` from the request to
/// `router::core_get_task` and decodes the returned JSON value into a
/// `pb::Task`.
///
/// # Errors
///
/// * Errors from the router are converted with `a2a_to_status`.
/// * A JSON value that does not decode as `pb::Task` is reported through
///   `internal_from_json`.
async fn get_task(
    &self,
    request: Request<pb::GetTaskRequest>,
) -> Result<Response<pb::Task>, Status> {
    let lookup = request.get_ref();
    let caller = owner_from(&request);
    let tenant_id = tenant_from(&request, &lookup.tenant);

    let routed = router::core_get_task(
        self.state.clone(),
        &tenant_id,
        &caller,
        &lookup.id,
        lookup.history_length,
    )
    .await
    .map_err(a2a_to_status)?;

    serde_json::from_value::<pb::Task>(routed.0)
        .map(Response::new)
        .map_err(internal_from_json)
}
/// Unary `ListTasks` RPC.
///
/// Translates the proto request into a `ListTasksQuery`, calls
/// `router::core_list_tasks`, and decodes the returned JSON value into a
/// `pb::ListTasksResponse`.
///
/// Field mapping:
/// * empty `context_id` / `page_token` strings become `None`;
/// * a `status` that is not a known `pb::TaskState`, or that is
///   `Unspecified`, becomes `None` (no status filter);
/// * `page_size`, `history_length` and `include_artifacts` are forwarded
///   unchanged.
///
/// # Errors
///
/// * Errors from the router are converted with `a2a_to_status`.
/// * A JSON value that does not decode as `pb::ListTasksResponse` is
///   reported through `internal_from_json`.
async fn list_tasks(
    &self,
    request: Request<pb::ListTasksRequest>,
) -> Result<Response<pb::ListTasksResponse>, Status> {
    let owner = owner_from(&request);
    let req = request.get_ref();
    let tenant = tenant_from(&request, &req.tenant);

    // The query carries the state by its proto string name.
    let status = pb::TaskState::try_from(req.status)
        .ok()
        .filter(|s| *s != pb::TaskState::Unspecified)
        .map(|s| s.as_str_name().to_string());

    let query = ListTasksQuery {
        context_id: Some(req.context_id.clone()).filter(|s| !s.is_empty()),
        status,
        page_size: req.page_size,
        page_token: Some(req.page_token.clone()).filter(|s| !s.is_empty()),
        history_length: req.history_length,
        // Forward the caller's choice instead of always sending `None`,
        // which silently discarded it.
        // NOTE(review): assumes the generated proto field is an `Option`
        // matching the query field, like `page_size` / `history_length` —
        // confirm against the generated `pb::ListTasksRequest`.
        include_artifacts: req.include_artifacts,
    };

    let value = router::core_list_tasks(self.state.clone(), &tenant, &owner, &query)
        .await
        .map_err(a2a_to_status)?
        .0;
    let response: pb::ListTasksResponse =
        serde_json::from_value(value).map_err(internal_from_json)?;
    Ok(Response::new(response))
}
/// Unary `CancelTask` RPC.
///
/// Passes the task id from the request to `router::core_cancel_task` and
/// decodes the returned JSON value into a `pb::Task`.
///
/// # Errors
///
/// * Errors from the router are converted with `a2a_to_status`.
/// * A JSON value that does not decode as `pb::Task` is reported through
///   `internal_from_json`.
async fn cancel_task(
    &self,
    request: Request<pb::CancelTaskRequest>,
) -> Result<Response<pb::Task>, Status> {
    let target = request.get_ref();
    let caller = owner_from(&request);
    let tenant_id = tenant_from(&request, &target.tenant);

    let routed = router::core_cancel_task(self.state.clone(), &tenant_id, &caller, &target.id)
        .await
        .map_err(a2a_to_status)?;

    serde_json::from_value::<pb::Task>(routed.0)
        .map(Response::new)
        .map_err(internal_from_json)
}
/// Unary `CreateTaskPushNotificationConfig` RPC.
///
/// Rejects requests whose `task_id` is empty, then serializes the request
/// message to JSON, passes it to `router::core_create_push_config`, and
/// decodes the returned JSON value into a `pb::TaskPushNotificationConfig`.
///
/// # Errors
///
/// * `Status::invalid_argument` when `task_id` is empty.
/// * Errors from the router are converted with `a2a_to_status`.
/// * JSON encode/decode failures are reported through `internal_from_json`.
async fn create_task_push_notification_config(
    &self,
    request: Request<pb::TaskPushNotificationConfig>,
) -> Result<Response<pb::TaskPushNotificationConfig>, Status> {
    let submitted = request.get_ref();
    let caller = owner_from(&request);
    let tenant_id = tenant_from(&request, &submitted.tenant);

    // Validate before doing any serialization work.
    if submitted.task_id.is_empty() {
        return Err(Status::invalid_argument("push config task_id is required"));
    }

    let json_body = serde_json::to_string(submitted).map_err(internal_from_json)?;
    let routed = router::core_create_push_config(
        self.state.clone(),
        &tenant_id,
        &caller,
        &submitted.task_id,
        json_body,
    )
    .await
    .map_err(a2a_to_status)?;

    serde_json::from_value::<pb::TaskPushNotificationConfig>(routed.0)
        .map(Response::new)
        .map_err(internal_from_json)
}
/// Unary `GetTaskPushNotificationConfig` RPC.
///
/// Passes the request's `task_id` and config `id` to
/// `router::core_get_push_config` and decodes the returned JSON value into
/// a `pb::TaskPushNotificationConfig`.
///
/// # Errors
///
/// * Errors from the router are converted with `a2a_to_status`.
/// * A JSON value that does not decode as the response type is reported
///   through `internal_from_json`.
async fn get_task_push_notification_config(
    &self,
    request: Request<pb::GetTaskPushNotificationConfigRequest>,
) -> Result<Response<pb::TaskPushNotificationConfig>, Status> {
    let lookup = request.get_ref();
    let caller = owner_from(&request);
    let tenant_id = tenant_from(&request, &lookup.tenant);

    let routed = router::core_get_push_config(
        self.state.clone(),
        &tenant_id,
        &caller,
        &lookup.task_id,
        &lookup.id,
    )
    .await
    .map_err(a2a_to_status)?;

    serde_json::from_value::<pb::TaskPushNotificationConfig>(routed.0)
        .map(Response::new)
        .map_err(internal_from_json)
}
/// Unary `ListTaskPushNotificationConfigs` RPC.
///
/// Builds a `PushConfigQuery` from the request (a `page_size` that is not
/// greater than zero and an empty `page_token` both become `None`), calls
/// `router::core_list_push_configs`, and decodes the returned JSON value
/// into a `pb::ListTaskPushNotificationConfigsResponse`.
///
/// # Errors
///
/// * Errors from the router are converted with `a2a_to_status`.
/// * A JSON value that does not decode as the response type is reported
///   through `internal_from_json`.
async fn list_task_push_notification_configs(
    &self,
    request: Request<pb::ListTaskPushNotificationConfigsRequest>,
) -> Result<Response<pb::ListTaskPushNotificationConfigsResponse>, Status> {
    let listing = request.get_ref();
    let caller = owner_from(&request);
    let tenant_id = tenant_from(&request, &listing.tenant);

    let paging = PushConfigQuery {
        page_size: (listing.page_size > 0).then_some(listing.page_size),
        page_token: (!listing.page_token.is_empty()).then(|| listing.page_token.clone()),
    };

    let routed = router::core_list_push_configs(
        self.state.clone(),
        &tenant_id,
        &caller,
        &listing.task_id,
        &paging,
    )
    .await
    .map_err(a2a_to_status)?;

    serde_json::from_value::<pb::ListTaskPushNotificationConfigsResponse>(routed.0)
        .map(Response::new)
        .map_err(internal_from_json)
}
/// Unary `DeleteTaskPushNotificationConfig` RPC.
///
/// Passes the request's `task_id` and config `id` to
/// `router::core_delete_push_config`. Whatever the router returns on
/// success is discarded and an empty message is sent back.
///
/// # Errors
///
/// Errors from the router are converted with `a2a_to_status`.
async fn delete_task_push_notification_config(
    &self,
    request: Request<pb::DeleteTaskPushNotificationConfigRequest>,
) -> Result<Response<pb::pbjson_types::Empty>, Status> {
    let target = request.get_ref();
    let caller = owner_from(&request);
    let tenant_id = tenant_from(&request, &target.tenant);

    router::core_delete_push_config(
        self.state.clone(),
        &tenant_id,
        &caller,
        &target.task_id,
        &target.id,
    )
    .await
    .map(|_discarded| Response::new(pb::pbjson_types::Empty {}))
    .map_err(a2a_to_status)
}
async fn get_extended_agent_card(
&self,
_request: Request<pb::GetExtendedAgentCardRequest>,
) -> Result<Response<pb::AgentCard>, Status> {
match self.state.executor.extended_agent_card(None) {
Some(card) => Ok(Response::new(card)),
None => Err(a2a_to_status(
crate::error::A2aError::ExtendedAgentCardNotConfigured,
)),
}
}
/// Stream type returned by `send_streaming_message`.
type SendStreamingMessageStream = BoxedStreamResponseStream;

/// Server-streaming `SendStreamingMessage` RPC.
///
/// Serializes the request message to JSON and hands it, together with the
/// resolved tenant and owner, to
/// `crate::grpc::streaming::handle_send_streaming_message`; the stream it
/// produces is returned to the client.
///
/// # Errors
///
/// * A JSON encode failure is reported through `internal_from_json`.
/// * Any error from the streaming handler is propagated with `?`.
async fn send_streaming_message(
    &self,
    request: Request<pb::SendMessageRequest>,
) -> Result<Response<Self::SendStreamingMessageStream>, Status> {
    use crate::grpc::streaming::handle_send_streaming_message;

    let message = request.get_ref();
    let caller = owner_from(&request);
    let tenant_id = tenant_from(&request, &message.tenant);
    let json_body = serde_json::to_string(message).map_err(internal_from_json)?;

    let events =
        handle_send_streaming_message(self.state.clone(), tenant_id, caller, json_body).await?;
    Ok(Response::new(events))
}
/// Stream type returned by `subscribe_to_task`.
type SubscribeToTaskStream = BoxedStreamResponseStream;

/// Server-streaming `SubscribeToTask` RPC.
///
/// Reads the task id from the request message and an optional last-event
/// id from the `LAST_EVENT_ID_METADATA` metadata entry (ignored when the
/// entry is absent or does not convert via `to_str`), then delegates to
/// `crate::grpc::streaming::handle_subscribe_to_task`.
///
/// # Errors
///
/// Any error from the streaming handler is propagated with `?`.
async fn subscribe_to_task(
    &self,
    request: Request<pb::SubscribeToTaskRequest>,
) -> Result<Response<Self::SubscribeToTaskStream>, Status> {
    use crate::grpc::streaming::{handle_subscribe_to_task, LAST_EVENT_ID_METADATA};

    let caller = owner_from(&request);
    let tenant_id = tenant_from(&request, &request.get_ref().tenant);
    let subscribed_task = request.get_ref().id.clone();

    let resume_from = request
        .metadata()
        .get(LAST_EVENT_ID_METADATA)
        .and_then(|raw| raw.to_str().ok())
        .map(String::from);

    let events = handle_subscribe_to_task(
        self.state.clone(),
        tenant_id,
        caller,
        subscribed_task,
        resume_from,
    )
    .await?;
    Ok(Response::new(events))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `value` in a request; when `metadata_tenant` is given, it is
    /// stored under the `TENANT_METADATA` key.
    fn make_request<T>(value: T, metadata_tenant: Option<&str>) -> Request<T> {
        let mut built = Request::new(value);
        if let Some(tenant) = metadata_tenant {
            let header =
                tonic::metadata::MetadataValue::try_from(tenant).expect("ascii metadata");
            built.metadata_mut().insert(TENANT_METADATA, header);
        }
        built
    }

    /// A non-empty proto tenant takes precedence over the metadata entry.
    #[test]
    fn proto_tenant_wins_over_metadata() {
        let req = make_request((), Some("tenant-from-metadata"));
        let resolved = tenant_from(&req, "tenant-from-proto");
        assert_eq!(resolved, "tenant-from-proto");
    }

    /// With an empty proto tenant the metadata entry is used.
    #[test]
    fn metadata_fallback_when_proto_empty() {
        let req = make_request((), Some("tenant-from-metadata"));
        let resolved = tenant_from(&req, "");
        assert_eq!(resolved, "tenant-from-metadata");
    }

    /// No proto tenant and no metadata entry resolves to the empty string.
    #[test]
    fn empty_when_neither_set() {
        let req = make_request((), None);
        let resolved = tenant_from(&req, "");
        assert_eq!(resolved, "");
    }

    /// Conflicting values: the proto tenant is returned, metadata is ignored.
    #[test]
    fn metadata_ignored_when_proto_non_empty_even_on_conflict() {
        let req = make_request((), Some("tenant-B"));
        let resolved = tenant_from(&req, "tenant-A");
        assert_eq!(resolved, "tenant-A");
    }

    /// An empty metadata value is treated the same as a missing one.
    #[test]
    fn empty_proto_and_empty_metadata_yields_empty() {
        let req = make_request((), Some(""));
        let resolved = tenant_from(&req, "");
        assert_eq!(resolved, "");
    }
}