use axum::Json;
use axum::extract::FromRequestParts;
use axum::response::{IntoResponse, Response as AxumResponse};
use axum_extra::extract::{Query, QueryRejection};
use http::header::{self, HOST};
use http::request::Parts;
use http::uri::InvalidUri;
use http::{HeaderValue, StatusCode};
use tracing::trace;
use crate::{Rel, WebFingerRequest, WebFingerResponse};
/// Media type attached as the `Content-Type` of every WebFinger (JRD) response.
const JRD_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/jrd+json");

/// Serializes a [`WebFingerResponse`] as a JSON body labelled with the JRD media type.
impl IntoResponse for WebFingerResponse {
    fn into_response(self) -> AxumResponse {
        let mut response = Json(self).into_response();
        // `insert` overwrites any `Content-Type` already present on the JSON
        // response, so the JRD media type is what the client sees.
        response
            .headers_mut()
            .insert(header::CONTENT_TYPE, JRD_CONTENT_TYPE);
        response
    }
}
/// Query-string parameters accepted by the WebFinger endpoint.
#[derive(serde::Deserialize, Debug)]
struct RequestParams {
    /// Value of the required `resource` parameter; deserialization fails without it.
    resource: String,
    /// Values of the optional `rel` parameter; empty when none are supplied.
    #[serde(default)]
    rel: Vec<String>,
}
/// Reasons a request could not be extracted as a [`WebFingerRequest`].
///
/// Every variant is turned into a `400 Bad Request` response with a short
/// explanatory message body.
#[derive(Debug)]
pub enum Rejection {
    /// The query string could not be deserialized; holds the rendered error message.
    InvalidQueryString(String),
    /// Neither the request URI nor the `Host` header supplied a usable host.
    MissingHost,
    /// The `resource` parameter could not be parsed as a URI.
    InvalidResource(InvalidUri),
}
/// Renders a [`Rejection`] as `400 Bad Request` with a human-readable message body.
impl IntoResponse for Rejection {
    fn into_response(self) -> AxumResponse {
        let message = match self {
            Rejection::MissingHost => "missing host".to_string(),
            // Already a fully rendered message: move it through instead of re-formatting.
            Rejection::InvalidQueryString(message) => message,
            Rejection::InvalidResource(e) => format!("invalid resource: {e}"),
        };
        (StatusCode::BAD_REQUEST, message).into_response()
    }
}
impl From<QueryRejection> for Rejection {
fn from(rejection: QueryRejection) -> Self {
Rejection::InvalidQueryString(rejection.to_string())
}
}
impl<S: Send + Sync> FromRequestParts<S> for WebFingerRequest {
type Rejection = Rejection;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
trace!("request parts: {:?}", parts);
let host = parts
.uri
.host()
.or_else(|| parts.headers.get(HOST).and_then(|host| host.to_str().ok()))
.map(str::to_string)
.ok_or(Rejection::MissingHost)?;
let query = Query::<RequestParams>::from_request_parts(parts, state).await?;
let resource = query.resource.parse().map_err(Rejection::InvalidResource)?;
let rels = query.rel.clone().into_iter().map(Rel::from).collect();
Ok(WebFingerRequest {
host,
resource,
rels,
})
}
}
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use axum::routing::get;
    use http::{Request, Response};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    use super::*;
    use crate::WELL_KNOWN_PATH;

    type Result<T = (), E = Box<dyn std::error::Error>> = std::result::Result<T, E>;

    /// Resource used by the happy-path tests.
    const VALID_RESOURCE: &str = "acct:carol@example.com";

    /// Body the test handler is expected to produce for `VALID_RESOURCE`.
    const VALID_BODY: &str = r#"{"subject":"acct:carol@example.com","links":[]}"#;

    /// Reads an entire response body as UTF-8 text.
    trait IntoText {
        async fn into_text(self) -> Result<String>;
    }

    impl IntoText for Response<Body> {
        async fn into_text(self) -> Result<String> {
            let collected = self.into_body().collect().await?;
            let text = String::from_utf8(collected.to_bytes().to_vec())?;
            Ok(text)
        }
    }

    /// Router exposing only the WebFinger endpoint.
    fn app() -> axum::Router {
        axum::Router::new().route(WELL_KNOWN_PATH, get(webfinger))
    }

    /// Minimal handler: echoes the parsed resource back as the subject.
    async fn webfinger(request: WebFingerRequest) -> impl IntoResponse {
        let subject = request.resource.to_string();
        WebFingerResponse::builder(subject).build()
    }

    /// Sends `request` through a fresh app, asserts the status, and returns the body text.
    async fn call(request: Request<Body>, expected: StatusCode) -> Result<String> {
        let response = app().oneshot(request).await?;
        assert_eq!(response.status(), expected, "{response:?}");
        response.into_text().await
    }

    #[tokio::test]
    async fn valid_request() -> Result {
        let target = format!("https://example.com{WELL_KNOWN_PATH}?resource={VALID_RESOURCE}");
        let request = Request::builder().uri(target).body(Body::empty())?;
        let body = call(request, StatusCode::OK).await?;
        assert_eq!(body, VALID_BODY);
        Ok(())
    }

    #[tokio::test]
    async fn valid_request_with_host_header() -> Result {
        let target = format!("{WELL_KNOWN_PATH}?resource={VALID_RESOURCE}");
        let request = Request::builder()
            .uri(target)
            .header(HOST, "example.com")
            .body(Body::empty())?;
        let body = call(request, StatusCode::OK).await?;
        assert_eq!(body, VALID_BODY);
        Ok(())
    }

    #[tokio::test]
    async fn request_with_no_host() -> Result {
        let target = format!("{WELL_KNOWN_PATH}?resource={VALID_RESOURCE}");
        let request = Request::builder().uri(target).body(Body::empty())?;
        let body = call(request, StatusCode::BAD_REQUEST).await?;
        assert_eq!(body, "missing host");
        Ok(())
    }

    #[tokio::test]
    async fn request_with_missing_resource() -> Result {
        let request = Request::builder()
            .uri(WELL_KNOWN_PATH)
            .header(HOST, "example.com")
            .body(Body::empty())?;
        let body = call(request, StatusCode::BAD_REQUEST).await?;
        assert_eq!(
            body,
            "Failed to deserialize query string: missing field `resource`",
        );
        Ok(())
    }

    #[tokio::test]
    async fn request_with_invalid_resource() -> Result {
        let target = format!("https://example.com{WELL_KNOWN_PATH}?resource=%");
        let request = Request::builder().uri(target).body(Body::empty())?;
        let body = call(request, StatusCode::BAD_REQUEST).await?;
        assert_eq!(body, "invalid resource: invalid authority");
        Ok(())
    }
}