use std::any::Any;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
tokio::task_local! {
    /// Metadata of the request associated with the current tokio task.
    ///
    /// Read by [`request_meta`], which falls back to an empty [`RequestMeta`] when this
    /// task-local has no value for the calling task.
    /// NOTE(review): the code that sets it (e.g. via `REQUEST_META.scope(...)`) is not
    /// visible in this section — confirm where requests enter the scope.
    pub static REQUEST_META: RequestMeta;
}
#[must_use]
pub fn request_meta() -> RequestMeta {
REQUEST_META.try_with(Clone::clone).unwrap_or_default()
}
/// Case-insensitive, multi-valued request metadata (for example, transport headers).
///
/// Keys are normalized with [`str::to_lowercase`] on every insertion and lookup, so
/// `get("X-Foo")` and `get("x-foo")` address the same entry. Each key maps to a list
/// of values.
#[derive(Debug, Clone, Default)]
pub struct RequestMeta {
    /// Lower-cased key -> values.
    kv: HashMap<String, Vec<String>>,
}

impl RequestMeta {
    /// Builds metadata from a map whose keys may use any letter case.
    ///
    /// Keys are lower-cased. When several source keys collapse to the same lower-cased
    /// key (e.g. `"Foo"` and `"foo"`), their values are concatenated rather than one
    /// silently replacing the other. The relative order of values coming from different
    /// colliding keys follows `src`'s iteration order, which `HashMap` leaves unspecified.
    #[must_use]
    pub fn new(src: HashMap<String, Vec<String>>) -> Self {
        let mut kv: HashMap<String, Vec<String>> = HashMap::with_capacity(src.len());
        for (key, values) in src {
            kv.entry(key.to_lowercase()).or_default().extend(values);
        }
        Self { kv }
    }

    /// Builds metadata from HTTP headers (only with the `server` feature).
    ///
    /// Repeated header lines for the same name are kept as separate values. Values for
    /// which `HeaderValue::to_str` fails are skipped.
    #[cfg(feature = "server")]
    #[must_use]
    pub fn from_header_map(headers: &axum::http::HeaderMap) -> Self {
        let mut kv: HashMap<String, Vec<String>> = HashMap::new();
        for (name, value) in headers {
            if let Ok(v) = value.to_str() {
                // `http` normalizes header names to lower case, so they already match
                // the key form used by `get`.
                kv.entry(name.as_str().to_owned())
                    .or_default()
                    .push(v.to_owned());
            }
        }
        Self { kv }
    }

    /// Metadata with no entries; same as `RequestMeta::default()`.
    #[must_use]
    pub fn empty() -> Self {
        Self::default()
    }

    /// Returns every value stored under `key` (matched case-insensitively), or `None`.
    #[must_use]
    pub fn get(&self, key: &str) -> Option<&[String]> {
        self.kv.get(&key.to_lowercase()).map(Vec::as_slice)
    }

    /// Makes `value` the only value for `key` (lower-cased), discarding previous values.
    pub fn set(&mut self, key: &str, value: impl Into<String>) {
        self.kv.insert(key.to_lowercase(), vec![value.into()]);
    }

    /// Iterates over `(lower-cased key, values)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &[String])> {
        self.kv.iter().map(|(k, v)| (k.as_str(), v.as_slice()))
    }

    /// Returns a copy of `self` extended with `additional`.
    ///
    /// Each key of `additional` (lower-cased) replaces any existing entry wholesale;
    /// values are not appended to the existing ones. Keys inside `additional` that
    /// differ only in case are first merged, exactly as in [`RequestMeta::new`].
    /// `self` is left unchanged.
    #[must_use]
    pub fn with(&self, additional: HashMap<String, Vec<String>>) -> Self {
        if additional.is_empty() {
            return self.clone();
        }
        let mut merged = self.kv.clone();
        // Normalize first so colliding keys in `additional` cannot overwrite each other
        // in iteration-order-dependent ways; `extend` then overrides existing entries.
        merged.extend(Self::new(additional).kv);
        Self { kv: merged }
    }
}
/// Identity of the party making a call.
///
/// Implementors must be `Send + Sync + Debug`, which allows sharing them as
/// `Arc<dyn User>` (as `CallContext::user` does).
pub trait User: Send + Sync + std::fmt::Debug {
    /// Display name of the user; may be empty.
    fn name(&self) -> &str;
    /// Whether this identity has been authenticated.
    fn authenticated(&self) -> bool;
}

/// A user with an established identity, identified by `username`.
#[derive(Debug, Clone)]
pub struct AuthenticatedUser {
    /// Name reported by [`User::name`].
    pub username: String,
}

impl AuthenticatedUser {
    /// Creates an authenticated user from anything convertible into a `String`.
    pub fn new(username: impl Into<String>) -> Self {
        let username: String = username.into();
        Self { username }
    }
}

impl User for AuthenticatedUser {
    /// Always `true`.
    fn authenticated(&self) -> bool {
        true
    }

    /// The stored `username`.
    fn name(&self) -> &str {
        self.username.as_str()
    }
}

/// Placeholder identity with an empty name that is never authenticated.
#[derive(Debug, Clone, Copy)]
pub struct UnauthenticatedUser;

impl User for UnauthenticatedUser {
    /// Always `false`.
    fn authenticated(&self) -> bool {
        false
    }

    /// Always the empty string.
    fn name(&self) -> &'static str {
        ""
    }
}
/// Per-call state: the method being invoked, its request metadata, the extensions
/// activated so far, and the caller's identity.
#[derive(Debug)]
pub struct CallContext {
    method: String,
    request_meta: RequestMeta,
    activated_extensions: Vec<String>,
    /// The caller's identity; [`CallContext::new`] starts with [`UnauthenticatedUser`].
    pub user: Arc<dyn User>,
}

impl CallContext {
    /// Creates a context for `method` carrying `meta`, with no activated extensions
    /// and an unauthenticated user.
    pub fn new(method: impl Into<String>, meta: RequestMeta) -> Self {
        Self {
            method: method.into(),
            request_meta: meta,
            activated_extensions: Vec::new(),
            user: Arc::new(UnauthenticatedUser),
        }
    }

    /// Name of the method being called.
    #[must_use]
    pub fn method(&self) -> &str {
        &self.method
    }

    /// Metadata the call was created with.
    #[must_use]
    pub const fn request_meta(&self) -> &RequestMeta {
        &self.request_meta
    }

    /// Extension URIs activated so far, in activation order (duplicates are kept).
    #[must_use]
    pub fn activated_extensions(&self) -> &[String] {
        &self.activated_extensions
    }

    /// Records `uri` as activated for this call.
    pub fn activate_extension(&mut self, uri: impl Into<String>) {
        self.activated_extensions.push(uri.into());
    }

    /// Whether `uri` has been activated (exact string match).
    #[must_use]
    pub fn is_extension_active(&self, uri: &str) -> bool {
        self.activated_extensions.iter().any(|e| e == uri)
    }

    /// Extension URIs the client requested via the `crate::SVC_PARAM_EXTENSIONS`
    /// metadata key.
    ///
    /// Each stored value is treated as a comma-separated list. Entries are trimmed, and
    /// empty entries are dropped, so `"a, b"` yields `["a", "b"]`. Order is preserved
    /// and duplicates are kept.
    #[must_use]
    pub fn requested_extension_uris(&self) -> Vec<String> {
        self.requested_extensions().map(str::to_owned).collect()
    }

    /// Whether the client requested `uri` (exact match against the entries produced
    /// by [`CallContext::requested_extension_uris`]).
    #[must_use]
    pub fn is_extension_requested(&self, uri: &str) -> bool {
        self.requested_extensions().any(|requested| requested == uri)
    }

    /// Lazily yields each requested extension URI, splitting comma-joined values.
    ///
    /// The extensions parameter is a list-valued field, and HTTP intermediaries may fold
    /// repeated header lines into one comma-separated line (RFC 9110 §5.3). Matching raw
    /// values would therefore miss URIs that share a line.
    fn requested_extensions(&self) -> impl Iterator<Item = &str> {
        self.request_meta
            .get(crate::SVC_PARAM_EXTENSIONS)
            .unwrap_or_default()
            .iter()
            .flat_map(|value| value.split(','))
            .map(str::trim)
            .filter(|uri| !uri.is_empty())
    }
}
/// Type-erased request payload.
///
/// Any `Send + 'static` value can be stored. Recover it with [`Request::downcast_ref`]
/// (borrowing) or [`Request::downcast`] (consuming).
pub struct Request {
    pub payload: Box<dyn Any + Send>,
}

impl Request {
    /// Boxes `payload` as a type-erased request.
    pub fn new<T: Send + 'static>(payload: T) -> Self {
        let erased: Box<dyn Any + Send> = Box::new(payload);
        Self { payload: erased }
    }

    /// Borrows the payload as `T`, or returns `None` if it holds a different type.
    #[must_use]
    pub fn downcast_ref<T: 'static>(&self) -> Option<&T> {
        (*self.payload).downcast_ref::<T>()
    }

    /// Takes the payload out as `T`.
    ///
    /// # Errors
    ///
    /// Returns the untouched request if the payload is not a `T`.
    pub fn downcast<T: 'static>(self) -> Result<T, Self> {
        self.payload
            .downcast::<T>()
            .map(|boxed| *boxed)
            .map_err(|payload| Self { payload })
    }
}

impl std::fmt::Debug for Request {
    /// Shows only the payload's `TypeId`; the payload itself need not be `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deref past the Box so `type_id` is dispatched to the payload rather than
        // reporting the type of `Box<dyn Any + Send>` itself.
        let payload_type = (*self.payload).type_id();
        f.debug_struct("Request")
            .field("payload_type", &payload_type)
            .finish()
    }
}
/// Type-erased call result holding an optional payload and an optional error.
///
/// [`Response::ok`] sets only the payload; [`Response::error`] sets only the error.
pub struct Response {
    pub payload: Option<Box<dyn Any + Send>>,
    pub err: Option<crate::error::A2AError>,
}

impl Response {
    /// Successful response carrying `payload`.
    pub fn ok<T: Send + 'static>(payload: T) -> Self {
        let erased: Box<dyn Any + Send> = Box::new(payload);
        Self {
            payload: Some(erased),
            err: None,
        }
    }

    /// Failed response carrying `err` and no payload.
    #[must_use]
    pub fn error(err: crate::error::A2AError) -> Self {
        Self {
            err: Some(err),
            payload: None,
        }
    }

    /// Borrows the payload as `T`; `None` when there is no payload or it is another type.
    #[must_use]
    pub fn downcast_ref<T: 'static>(&self) -> Option<&T> {
        self.payload
            .as_deref()
            .and_then(|inner| inner.downcast_ref::<T>())
    }
}

impl std::fmt::Debug for Response {
    /// Reports only which of payload/error are present, not their contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_payload = self.payload.is_some();
        let has_error = self.err.is_some();
        f.debug_struct("Response")
            .field("has_payload", &has_payload)
            .field("has_error", &has_error)
            .finish()
    }
}
/// Async hooks around a call that may inspect or modify it.
///
/// Both hooks return pinned, boxed, `Send` futures that borrow their arguments for `'a`
/// instead of being `async fn`. This keeps the trait usable as `dyn CallInterceptor`
/// (NOTE(review): presumably the motivation — confirm).
pub trait CallInterceptor: Send + Sync {
    /// Hook with mutable access to both the call context and the request, so it may
    /// rewrite either (e.g. via `CallContext::activate_extension`).
    ///
    /// NOTE(review): presumably invoked before the request reaches its handler, with an
    /// `Err` aborting the call — the caller is not visible here; confirm.
    fn before<'a>(
        &'a self,
        ctx: &'a mut CallContext,
        req: &'a mut Request,
    ) -> Pin<Box<dyn Future<Output = Result<(), crate::error::A2AError>> + Send + 'a>>;
    /// Hook with read-only access to the call context and mutable access to the
    /// response, so it may rewrite the response but not the context.
    ///
    /// NOTE(review): presumably invoked after the handler produced `resp`; how an `Err`
    /// is handled is decided by the (not visible) caller — confirm.
    fn after<'a>(
        &'a self,
        ctx: &'a CallContext,
        resp: &'a mut Response,
    ) -> Pin<Box<dyn Future<Output = Result<(), crate::error::A2AError>> + Send + 'a>>;
}
#[derive(Debug, Clone, Copy)]
pub struct PassthroughInterceptor;
impl CallInterceptor for PassthroughInterceptor {
fn before<'a>(
&'a self,
_ctx: &'a mut CallContext,
_req: &'a mut Request,
) -> Pin<Box<dyn Future<Output = Result<(), crate::error::A2AError>> + Send + 'a>> {
Box::pin(async { Ok(()) })
}
fn after<'a>(
&'a self,
_ctx: &'a CallContext,
_resp: &'a mut Response,
) -> Pin<Box<dyn Future<Output = Result<(), crate::error::A2AError>> + Send + 'a>> {
Box::pin(async { Ok(()) })
}
}