use std::collections::HashMap;
use std::sync::Arc;
use axum::body::Bytes;
use axum::extract::State;
use axum::http::{header, HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::routing::post;
use axum::Router;
use serde::Serialize;
use serde_json::Value;
use trust_tasks_rs::{
discovery::DiscoveryRegistry, specs::trust_task_discovery::v0_1 as discovery, ErrorPayload,
ErrorResponse, Payload, RejectReason, StandardCode, TransportHandler, TrustTask,
};
use uuid::Uuid;
use crate::auth::{Auth, BearerAuth};
use crate::handler::HttpsHandler;
use crate::status::status_for_code;
/// Per-request identity information handed to every route handler.
///
/// The dispatch handler fills this in from the transport handler's view of
/// the two parties before invoking the route's dispatch closure.
#[derive(Debug, Clone)]
pub struct RequestContext {
/// Peer identifier as reported by the transport handler's `peer()`;
/// `None` when the transport handler reports no peer for this request.
pub authenticated_sender: Option<String>,
/// Local identifier as reported by the transport handler's `local()`;
/// `None` when no local identifier is available.
pub local: Option<String>,
}
/// Type-erased route handler stored in the routing table.
///
/// Receives the inbound document with its payload still as raw JSON plus the
/// per-request context, and returns either the response document as JSON or
/// the reason the request was rejected. `Send + Sync` so it can live inside
/// the shared server state.
type DispatchFn =
Box<dyn Fn(TrustTask<Value>, &RequestContext) -> Result<Value, RejectReason> + Send + Sync>;
/// One entry of the routing table: the handler registered for a routing key.
struct Route {
/// Type-erased closure that decodes the payload, runs the user handler and
/// serialises the response document.
dispatch: DispatchFn,
}
/// Immutable state shared (behind an `Arc`) by every request handled by the
/// router.
struct ServerState {
/// Identifier of this server, if configured; passed to the transport
/// handler and to basic document validation.
local_vid: Option<String>,
/// Resolves a bearer token from the `Authorization` header to a peer
/// identifier.
auth: Box<dyn Auth>,
/// Routing table keyed by the routing form of the document's type URI.
routes: HashMap<String, Route>,
}
/// Builder for [`HttpsServer`]; obtained from `HttpsServer::builder()`.
pub struct HttpsServerBuilder {
/// Identifier of this server, if one has been set.
local_vid: Option<String>,
/// Authenticator chosen by the caller; when `None`, `build` falls back to
/// `BearerAuth::new()`.
auth: Option<Box<dyn Auth>>,
/// Handlers registered so far, keyed by routing key.
routes: HashMap<String, Route>,
}
impl HttpsServerBuilder {
pub fn local_vid(mut self, vid: impl Into<String>) -> Self {
self.local_vid = Some(vid.into());
self
}
pub fn with_auth(mut self, auth: impl Auth) -> Self {
self.auth = Some(Box::new(auth));
self
}
pub fn on<P, Resp, F>(mut self, handler: F) -> Self
where
P: Payload + 'static,
Resp: Payload + Serialize + 'static,
F: Fn(&TrustTask<P>, &RequestContext) -> Result<Resp, RejectReason> + Send + Sync + 'static,
{
let dispatch: DispatchFn = Box::new(move |doc: TrustTask<Value>, ctx: &RequestContext| {
let typed = downcast::<P>(doc)?;
typed.enforce_audience_binding()?;
let response_payload = handler(&typed, ctx)?;
let new_id = format!("urn:uuid:{}", Uuid::new_v4());
let response_doc = typed.respond_with(new_id, response_payload);
Ok(serde_json::to_value(&response_doc).expect("response serialises (typed structs)"))
});
let key = P::type_uri().for_routing().to_string();
self.routes.insert(key, Route { dispatch });
self
}
pub fn with_discovery(self, registry: DiscoveryRegistry) -> Self {
self.on::<discovery::Payload, discovery::Response, _>(move |req, _ctx| {
Ok(registry.respond_to(&req.payload))
})
}
pub fn enable_discovery(self) -> Self {
let mut registry: DiscoveryRegistry = self.routes.keys().cloned().collect();
registry.register_payload::<discovery::Payload>();
self.with_discovery(registry)
}
pub fn build(self) -> HttpsServer {
let auth = self.auth.unwrap_or_else(|| Box::new(BearerAuth::new()));
HttpsServer {
state: Arc::new(ServerState {
local_vid: self.local_vid,
auth,
routes: self.routes,
}),
}
}
}
/// A configured Trust Task HTTPS server, ready to be turned into an axum
/// router or served directly.
pub struct HttpsServer {
/// Shared, read-only configuration and routing table.
state: Arc<ServerState>,
}
impl HttpsServer {
pub fn builder() -> HttpsServerBuilder {
HttpsServerBuilder {
local_vid: None,
auth: None,
routes: HashMap::new(),
}
}
pub fn into_router(self) -> Router {
Router::new()
.route("/trust-tasks", post(dispatch_handler))
.with_state(self.state)
}
pub async fn serve(self, addr: impl tokio::net::ToSocketAddrs) -> std::io::Result<()> {
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, self.into_router()).await
}
}
/// Handles one `POST /trust-tasks` request end to end.
///
/// Steps, in order: parse the body as a Trust Task document, resolve the
/// bearer token (if any) to a peer, check party consistency, run basic
/// validation, look up the route for the document's type, and dispatch.
/// Any failure is turned into an error response via `reject_response`.
async fn dispatch_handler(
    State(state): State<Arc<ServerState>>,
    headers: HeaderMap,
    body: Bytes,
) -> Response {
    // Nothing is known about the sender or the request yet, so a parse
    // failure is reported without a transport handler or request document.
    let doc = match serde_json::from_slice::<TrustTask<Value>>(&body) {
        Ok(parsed) => parsed,
        Err(e) => {
            let reason = RejectReason::MalformedRequest {
                reason: format!("body did not parse as a Trust Task document: {e}"),
            };
            return reject_response(None, None, reason);
        }
    };

    let peer_vid = extract_bearer(&headers).and_then(|token| state.auth.resolve(token));
    let handler = HttpsHandler::new(state.local_vid.clone(), peer_vid);

    // The resolved parties are not used further; the call is made for its
    // consistency check.
    let _resolved = match handler.resolve_parties(&doc) {
        Ok(parties) => parties,
        Err(consistency) => {
            return reject_response(Some(&handler), Some(&doc), consistency.into());
        }
    };

    let now = chrono::Utc::now();
    // An unset local identifier is passed to validation as the empty string.
    let local_vid = state.local_vid.as_deref().unwrap_or("");
    if let Err(reason) = doc.validate_basic(now, local_vid) {
        return reject_response(Some(&handler), Some(&doc), reason);
    }

    let routing_key = doc.type_uri.for_routing().to_string();
    let route = match state.routes.get(&routing_key) {
        Some(found) => found,
        None => {
            let reason = RejectReason::UnsupportedType {
                type_uri: routing_key,
            };
            return reject_response(Some(&handler), Some(&doc), reason);
        }
    };

    let ctx = RequestContext {
        authenticated_sender: handler.peer().map(str::to_owned),
        local: handler.local().map(str::to_owned),
    };

    // The dispatch closure consumes its document, so it gets a copy; the
    // original is still needed to address an error response.
    match (route.dispatch)(doc.clone(), &ctx) {
        Ok(success_body) => success_response(success_body),
        Err(reason) => reject_response(Some(&handler), Some(&doc), reason),
    }
}
/// Extracts the bearer token from the `Authorization` header.
///
/// Returns the token with surrounding whitespace trimmed, or `None` when the
/// header is absent, cannot be read as a string, or does not use the Bearer
/// scheme followed by a space.
///
/// The scheme name is compared case-insensitively, as HTTP authentication
/// scheme names are case-insensitive (RFC 9110 §11.1): `Bearer`, `bearer`
/// and `BEARER` are all accepted.
fn extract_bearer(headers: &HeaderMap) -> Option<&str> {
    let value = headers.get(header::AUTHORIZATION)?.to_str().ok()?;
    // Split at the first space: scheme on the left, credentials on the right.
    let (scheme, credentials) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    Some(credentials.trim())
}
/// Converts a document with a raw JSON payload into one whose payload is the
/// typed `P`, carrying every other field over unchanged.
///
/// # Errors
///
/// Returns `RejectReason::MalformedRequest` when the JSON payload does not
/// deserialise into `P`.
fn downcast<P: Payload>(doc: TrustTask<Value>) -> Result<TrustTask<P>, RejectReason> {
    let TrustTask {
        id,
        thread_id,
        type_uri,
        issuer,
        recipient,
        issued_at,
        expires_at,
        payload: raw_payload,
        context,
        proof,
        extra,
    } = doc;

    let payload = match serde_json::from_value::<P>(raw_payload) {
        Ok(typed) => typed,
        Err(e) => {
            return Err(RejectReason::MalformedRequest {
                reason: format!("payload does not match {}: {e}", P::TYPE_URI),
            });
        }
    };

    let typed_doc = TrustTask {
        id,
        thread_id,
        type_uri,
        issuer,
        recipient,
        issued_at,
        expires_at,
        payload,
        context,
        proof,
        extra,
    };
    Ok(typed_doc)
}
/// Wraps a serialised response document in a `200 OK` JSON HTTP response.
///
/// # Panics
///
/// Panics if `body` cannot be serialised to bytes.
fn success_response(body: Value) -> Response {
    let payload = serde_json::to_vec(&body).expect("serialise success body");
    let content_type = [(header::CONTENT_TYPE, "application/json")];
    (StatusCode::OK, content_type, payload).into_response()
}
fn reject_response(
handler: Option<&HttpsHandler>,
request: Option<&TrustTask<Value>>,
reason: RejectReason,
) -> Response {
let status = status_for_code(&reason.code().into());
let status = StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let error_doc = build_error_response(handler, request, reason);
let body = serde_json::to_vec(&error_doc).expect("serialise error response");
(status, [(header::CONTENT_TYPE, "application/json")], body).into_response()
}
fn build_error_response(
handler: Option<&HttpsHandler>,
request: Option<&TrustTask<Value>>,
reason: RejectReason,
) -> ErrorResponse {
let new_id = format!("urn:uuid:{}", Uuid::new_v4());
match (handler, request) {
(Some(h), Some(req)) => {
match h.reject(req, new_id.clone(), reason.clone()) {
Some(resp) => resp,
None => suppressed_error_response(&new_id, reason),
}
}
(_, Some(req)) => req.reject_with(new_id, reason),
_ => {
let mut doc = TrustTask::new(
new_id,
trust_tasks_rs::TypeUri::canonical("trust-task-error", 0, 1)
.expect("framework type URI"),
ErrorPayload::from(reason),
);
doc.issued_at = Some(chrono::Utc::now());
doc
}
}
}
fn suppressed_error_response(new_id: &str, reason: RejectReason) -> ErrorResponse {
let mut doc = TrustTask::new(
new_id.to_string(),
trust_tasks_rs::TypeUri::canonical("trust-task-error", 0, 1).expect("framework type URI"),
ErrorPayload::from(reason).with_message(
"identity_mismatch with no transport-authenticated sender — \
response not addressed (SPEC §8.1)",
),
);
doc.issued_at = Some(chrono::Utc::now());
doc
}
/// Compile-time check only: fails to build if `StandardCode` stops being
/// convertible into `TrustTaskCode`. Never called at runtime.
fn _verify_standard_code_into() {
    let _code: trust_tasks_rs::TrustTaskCode = Into::into(StandardCode::Expired);
}