use std::sync::Arc;
use actpub_webfinger::Jrd;
use axum::Router;
use axum::extract::{Query, State};
use axum::http::{StatusCode, header};
use axum::response::IntoResponse;
use axum::routing::get;
use serde::Deserialize;
/// Value sent in the `Access-Control-Allow-Origin` header on WebFinger responses.
const CORS_ALLOW_ANY_ORIGIN: &str = "*";
/// Media type used as the `Content-Type` of successful WebFinger (JRD) responses.
pub const JRD_CONTENT_TYPE: &str = "application/jrd+json";
pub trait WebFingerResolver: Send + Sync + 'static {
fn resolve(&self, resource: String)
-> impl Future<Output = Result<Option<Jrd>, String>> + Send;
}
/// Query string accepted by the WebFinger endpoint.
///
/// Query parameters other than `resource` are ignored when deserialising.
#[derive(Deserialize, Debug)]
struct WebFingerQuery {
    /// The `resource` query parameter identifying the subject being looked up.
    resource: String,
}
/// Builds a [`Router`] that serves `GET /.well-known/webfinger` using `resolver`.
///
/// The resolver is wrapped in an [`Arc`] and installed as the router's state,
/// so the returned `Router` needs no further state to be supplied.
pub fn webfinger_router<R: WebFingerResolver>(resolver: R) -> Router {
    let shared_resolver = Arc::new(resolver);
    Router::new()
        .route("/.well-known/webfinger", get(handle::<R>))
        .with_state(shared_resolver)
}
async fn handle<R>(
State(resolver): State<Arc<R>>,
Query(q): Query<WebFingerQuery>,
) -> impl IntoResponse
where
R: WebFingerResolver,
{
let cors = (header::ACCESS_CONTROL_ALLOW_ORIGIN, CORS_ALLOW_ANY_ORIGIN);
match resolver.resolve(q.resource).await {
Ok(Some(jrd)) => match serde_json::to_vec(&jrd) {
Ok(body) => (
StatusCode::OK,
[(header::CONTENT_TYPE, JRD_CONTENT_TYPE), cors],
body,
)
.into_response(),
Err(err) => {
tracing::error!(target: "actpub::axum::webfinger", %err, "JRD serialise failed");
(StatusCode::INTERNAL_SERVER_ERROR, [cors]).into_response()
}
},
Ok(None) => (StatusCode::NOT_FOUND, [cors]).into_response(),
Err(err) => {
tracing::warn!(target: "actpub::axum::webfinger", reason = %err, "resolver failed");
(StatusCode::INTERNAL_SERVER_ERROR, [cors]).into_response()
}
}
}
#[cfg(test)]
mod tests {
    use actpub_webfinger::{Jrd, JrdLink, rels};
    use axum::body::Body;
    use axum::http::{Method, Request, StatusCode};
    use http_body_util::BodyExt;
    use serde_json::Value;
    use tower::ServiceExt;

    use super::*;

    /// Request target for a subject the test resolver knows about.
    const ALICE_URI: &str = "/.well-known/webfinger?resource=acct:alice@example.com";
    /// Request target for a subject nobody knows about.
    const GHOST_URI: &str = "/.well-known/webfinger?resource=acct:ghost@example.com";

    /// Resolver stub that answers every lookup with the same canned result.
    struct StaticResolver(Option<Jrd>);

    impl WebFingerResolver for StaticResolver {
        #[allow(
            unknown_lints,
            clippy::unused_async_trait_impl,
            reason = "trait definition requires async but mock implementation has no await"
        )]
        async fn resolve(&self, _resource: String) -> Result<Option<Jrd>, String> {
            Ok(self.0.clone())
        }
    }

    /// A JRD for `acct:alice@example.com` carrying one ActivityPub actor link.
    fn alice_jrd() -> Jrd {
        let actor_link = JrdLink::builder(rels::ACTIVITYPUB_ACTOR)
            .href("https://example.com/users/alice".parse().unwrap())
            .media_type("application/activity+json")
            .build();
        Jrd::builder("acct:alice@example.com")
            .alias("https://example.com/@alice")
            .link(actor_link)
            .build()
    }

    /// Sends one `GET uri` through a fresh router backed by `resolver`.
    async fn send_get(resolver: StaticResolver, uri: &str) -> axum::response::Response {
        let request = Request::builder()
            .method(Method::GET)
            .uri(uri)
            .body(Body::empty())
            .unwrap();
        webfinger_router(resolver).oneshot(request).await.unwrap()
    }

    /// Returns header `name` as text, or `None` if absent or not visible ASCII.
    fn header_value(resp: &axum::response::Response, name: header::HeaderName) -> Option<&str> {
        resp.headers().get(name).and_then(|value| value.to_str().ok())
    }

    #[tokio::test]
    async fn router_returns_jrd_for_known_resource() {
        let resp = send_get(StaticResolver(Some(alice_jrd())), ALICE_URI).await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(header_value(&resp, header::CONTENT_TYPE), Some(JRD_CONTENT_TYPE));

        let raw = resp.into_body().collect().await.unwrap().to_bytes();
        let json: Value = serde_json::from_slice(&raw).unwrap();
        assert_eq!(json["subject"], serde_json::json!("acct:alice@example.com"));
    }

    #[tokio::test]
    async fn router_returns_404_for_unknown_resource() {
        let resp = send_get(StaticResolver(None), GHOST_URI).await;
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn router_returns_400_when_resource_query_param_is_missing() {
        let resp = send_get(StaticResolver(None), "/.well-known/webfinger").await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn router_emits_cors_header_on_success() {
        let resp = send_get(StaticResolver(Some(alice_jrd())), ALICE_URI).await;
        assert_eq!(
            header_value(&resp, header::ACCESS_CONTROL_ALLOW_ORIGIN),
            Some("*"),
        );
    }

    #[tokio::test]
    async fn router_emits_cors_header_on_404() {
        let resp = send_get(StaticResolver(None), GHOST_URI).await;
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
        assert_eq!(
            header_value(&resp, header::ACCESS_CONTROL_ALLOW_ORIGIN),
            Some("*"),
        );
    }
}