use reqwest::redirect::Policy;
use reqwest::{Client, ClientBuilder, header};
use tracing::debug;
use url::Url;
use crate::{Discovery, Error, NodeInfo, Version};
pub const DEFAULT_MAX_BODY_BYTES: u64 = 64 * 1024;
pub fn recommended_client() -> Result<Client, Error> {
Ok(ClientBuilder::new()
.redirect(Policy::custom(|attempt| {
const MAX_REDIRECTS: usize = 2;
if attempt.previous().len() >= MAX_REDIRECTS {
return attempt.error("too many redirects");
}
let origin = attempt.previous().first().unwrap_or_else(|| attempt.url());
if origin.host_str() == attempt.url().host_str()
&& origin.scheme() == attempt.url().scheme()
{
attempt.follow()
} else {
attempt.error("cross-origin redirect forbidden for NodeInfo")
}
}))
.build()?)
}
pub async fn fetch_discovery(host: &Url, client: &Client) -> Result<Discovery, Error> {
fetch_discovery_with_limit(host, client, DEFAULT_MAX_BODY_BYTES).await
}
pub async fn fetch_discovery_with_limit(
host: &Url,
client: &Client,
max_body_bytes: u64,
) -> Result<Discovery, Error> {
let url = host.join(crate::WELL_KNOWN_PATH)?;
debug!(%url, max_body_bytes, "fetching NodeInfo discovery");
let body = request_capped(client, url, max_body_bytes).await?;
Ok(serde_json::from_slice(&body)?)
}
pub async fn fetch(host: &Url, version: Version, client: &Client) -> Result<NodeInfo, Error> {
fetch_with_limit(host, version, client, DEFAULT_MAX_BODY_BYTES).await
}
pub async fn fetch_with_limit(
host: &Url,
version: Version,
client: &Client,
max_body_bytes: u64,
) -> Result<NodeInfo, Error> {
let discovery = fetch_discovery_with_limit(host, client, max_body_bytes).await?;
let link = discovery
.find_link(version)
.ok_or(Error::VersionNotAdvertised {
requested: version.as_str(),
})?;
if !same_origin(host, &link.href) {
return Err(Error::CrossOriginHref {
discovery: host.clone(),
href: link.href.clone(),
});
}
debug!(url = %link.href, max_body_bytes, "fetching NodeInfo document");
let body = request_capped(client, link.href.clone(), max_body_bytes).await?;
Ok(serde_json::from_slice(&body)?)
}
fn same_origin(a: &Url, b: &Url) -> bool {
a.scheme().eq_ignore_ascii_case(b.scheme())
&& match (a.host_str(), b.host_str()) {
(Some(ha), Some(hb)) => ha.eq_ignore_ascii_case(hb),
_ => false,
}
&& a.port_or_known_default() == b.port_or_known_default()
}
async fn request_capped(client: &Client, url: Url, max_body_bytes: u64) -> Result<Vec<u8>, Error> {
let response = client
.get(url)
.header(header::ACCEPT, "application/json")
.send()
.await?;
let status = response.status();
if !status.is_success() {
return Err(Error::BadStatus(status.as_u16()));
}
read_capped(response, max_body_bytes).await
}
async fn read_capped(
mut response: reqwest::Response,
max_body_bytes: u64,
) -> Result<Vec<u8>, Error> {
let mut acc: Vec<u8> = Vec::new();
while let Some(chunk) = response.chunk().await? {
if max_body_bytes > 0 && (acc.len() as u64 + chunk.len() as u64) > max_body_bytes {
return Err(Error::ResponseTooLarge(max_body_bytes));
}
acc.extend_from_slice(&chunk);
}
Ok(acc)
}
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use serde_json::json;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
use super::*;
#[tokio::test]
async fn end_to_end_fetch_via_mock_server() {
let server = MockServer::start().await;
let base: Url = server.uri().parse().unwrap();
let nodeinfo_url = format!("{base}nodeinfo/2.1");
Mock::given(method("GET"))
.and(path("/.well-known/nodeinfo"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"links": [{
"rel": "http://nodeinfo.diaspora.software/ns/schema/2.1",
"href": nodeinfo_url
}]
})))
.mount(&server)
.await;
Mock::given(method("GET"))
.and(path("/nodeinfo/2.1"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"version": "2.1",
"software": { "name": "mock-server", "version": "0.1.0" },
"protocols": ["activitypub"],
"openRegistrations": false,
"usage": {}
})))
.mount(&server)
.await;
let client = Client::new();
let info = fetch(&base, Version::V2_1, &client).await.unwrap();
assert_eq!(info.version, Version::V2_1);
assert_eq!(info.software.name, "mock-server");
}
#[tokio::test]
async fn fetch_refuses_cross_origin_href_advertised_by_discovery() {
let primary = MockServer::start().await;
let attacker = MockServer::start().await;
let attacker_nodeinfo = format!("{}/nodeinfo/2.1", attacker.uri());
Mock::given(method("GET"))
.and(path("/.well-known/nodeinfo"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"links": [{
"rel": "http://nodeinfo.diaspora.software/ns/schema/2.1",
"href": attacker_nodeinfo
}]
})))
.mount(&primary)
.await;
let client = Client::new();
let base: Url = primary.uri().parse().unwrap();
let err = fetch(&base, Version::V2_1, &client)
.await
.expect_err("cross-origin href must be refused");
assert!(
matches!(err, Error::CrossOriginHref { .. }),
"expected CrossOriginHref, got {err:?}",
);
}
#[tokio::test]
async fn oversized_discovery_body_is_rejected() {
let server = MockServer::start().await;
let base: Url = server.uri().parse().unwrap();
let padding = "x".repeat(128 * 1024);
Mock::given(method("GET"))
.and(path("/.well-known/nodeinfo"))
.respond_with(ResponseTemplate::new(200).set_body_raw(
format!(r#"{{"links":[],"padding":"{padding}"}}"#).into_bytes(),
"application/json",
))
.mount(&server)
.await;
let client = Client::new();
let err = fetch_discovery(&base, &client)
.await
.expect_err("oversized body must be rejected");
assert!(matches!(
err,
Error::ResponseTooLarge(DEFAULT_MAX_BODY_BYTES)
));
}
#[tokio::test]
async fn recommended_client_rejects_cross_origin_redirect() {
let primary = MockServer::start().await;
let attacker = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/.well-known/nodeinfo"))
.respond_with(ResponseTemplate::new(302).insert_header(
"Location",
format!("{}/.well-known/nodeinfo", attacker.uri()),
))
.mount(&primary)
.await;
let client = recommended_client().expect("client builds");
let base: Url = primary.uri().parse().unwrap();
fetch_discovery(&base, &client)
.await
.expect_err("cross-origin redirect must fail");
}
#[tokio::test]
async fn missing_version_returns_specific_error() {
let server = MockServer::start().await;
let base: Url = server.uri().parse().unwrap();
Mock::given(method("GET"))
.and(path("/.well-known/nodeinfo"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"links": [{
"rel": "http://nodeinfo.diaspora.software/ns/schema/2.0",
"href": format!("{base}nodeinfo/2.0")
}]
})))
.mount(&server)
.await;
let client = Client::new();
let err = fetch(&base, Version::V2_1, &client).await.unwrap_err();
assert!(matches!(
err,
Error::VersionNotAdvertised { requested: "2.1" }
));
}
}