#![doc = include_str!("../../docs/07_fetching_data.md")]
use crate::{
config::Data,
error::{Error, Error::ParseFetchedObject},
extract_id,
http_signatures::sign_request,
reqwest_shim::ResponseExt,
FEDERATION_CONTENT_TYPE,
};
use bytes::Bytes;
use http::{header::LOCATION, HeaderValue, StatusCode};
use serde::de::DeserializeOwned;
use std::sync::atomic::Ordering;
use tracing::info;
use url::Url;
pub mod collection_id;
pub mod object_id;
pub mod webfinger;
/// Outcome of fetching a remote object over HTTP, as produced by the fetch functions in this
/// module.
pub struct FetchObjectResponse<Kind> {
/// The response body deserialized from JSON into `Kind`.
pub object: Kind,
/// URL reported by the HTTP response the body was read from. This can differ from the URL
/// that was originally requested when the request was redirected.
pub url: Url,
/// Raw `Content-Type` header of the response, or `None` if the header was absent.
content_type: Option<HeaderValue>,
/// Id extracted from the response body via `extract_id`, or `None` if extraction failed.
object_id: Option<Url>,
}
pub async fn fetch_object_http<T: Clone, Kind: DeserializeOwned>(
url: &Url,
data: &Data<T>,
) -> Result<FetchObjectResponse<Kind>, Error> {
static FETCH_CONTENT_TYPE: HeaderValue = HeaderValue::from_static(FEDERATION_CONTENT_TYPE);
const VALID_RESPONSE_CONTENT_TYPES: [&str; 3] = [
FEDERATION_CONTENT_TYPE, r#"application/ld+json; profile="https://www.w3.org/ns/activitystreams""#, r#"application/activity+json; charset=utf-8"#, ];
let res = fetch_object_http_with_accept(url, data, &FETCH_CONTENT_TYPE, false).await?;
let content_type = res
.content_type
.as_ref()
.and_then(|c| Some(c.to_str().ok()?.to_lowercase()))
.ok_or(Error::FetchInvalidContentType(res.url.clone()))?;
if !VALID_RESPONSE_CONTENT_TYPES.contains(&content_type.as_str()) {
return Err(Error::FetchInvalidContentType(res.url));
}
if res.object_id.as_ref() != Some(&res.url) {
if let Some(res_object_id) = res.object_id {
data.config.verify_url_valid(&res_object_id).await?;
if res_object_id.domain() == res.url.domain() {
return Box::pin(fetch_object_http(&res_object_id, data)).await;
}
}
return Err(Error::FetchWrongId(res.url));
}
if data.config.is_local_url(&res.url) {
return Err(Error::NotFound);
}
Ok(res)
}
/// Send one HTTP GET for `url` and deserialize the JSON response body into `Kind`.
///
/// # Parameters
///
/// - `url`: address to fetch; it is passed through `verify_url_valid` before anything is sent.
/// - `data`: request context providing the configuration, HTTP client and request counter.
/// - `content_type`: value sent in the `Accept` request header.
/// - `recursive`: `true` when this call is itself the result of following a `Location` header.
///   In that case a further `Location` header is ignored, so at most one hop is followed here.
///
/// # Behavior
///
/// Every call increments `data.request_counter`; once the count exceeds
/// `config.http_fetch_limit` the call fails without sending a request. If
/// `config.signed_fetch_actor` is set, the request is signed with that actor's key.
///
/// # Errors
///
/// - [`Error::RequestLimit`] when the fetch limit is exceeded.
/// - [`Error::ObjectDeleted`] when the server answers `410 Gone`.
/// - `ParseFetchedObject` when the body cannot be deserialized into `Kind`. If the body is also
///   not valid UTF-8, the UTF-8 conversion error is returned instead.
/// - Errors from URL verification, request signing, the HTTP client, reading the body, or
///   resolving the `Location` header are propagated.
async fn fetch_object_http_with_accept<T: Clone, Kind: DeserializeOwned>(
    url: &Url,
    data: &Data<T>,
    content_type: &HeaderValue,
    recursive: bool,
) -> Result<FetchObjectResponse<Kind>, Error> {
    let config = &data.config;
    config.verify_url_valid(url).await?;
    info!("Fetching remote object {}", url);

    // `fetch_add` returns the value from before the increment, so add one to obtain the count
    // that includes this request.
    let counter = data.request_counter.fetch_add(1, Ordering::SeqCst) + 1;
    if counter > config.http_fetch_limit {
        return Err(Error::RequestLimit);
    }

    let req = config
        .client
        .get(url.as_str())
        .header("Accept", content_type)
        .timeout(config.request_timeout);

    let res = if let Some((actor_id, private_key_pem)) = config.signed_fetch_actor.as_deref() {
        // Signed fetch: sign with an empty body, then execute the signed request.
        let req = sign_request(
            req,
            actor_id,
            Bytes::new(),
            private_key_pem.clone(),
            config.http_signature_compat,
        )
        .await?;
        config.client.execute(req).await?
    } else {
        req.send().await?
    };

    // Follow a single redirect through one recursive call.
    // NOTE(review): the `Location` header is honored regardless of the response status code.
    let location = res.headers().get(LOCATION).and_then(|l| l.to_str().ok());
    if let (Some(location), false) = (location, recursive) {
        // `Location` may be a relative reference, so resolve it against the URL of this
        // response. An absolute value resolves to itself.
        let location = res.url().join(location)?;
        return Box::pin(fetch_object_http_with_accept(
            &location,
            data,
            content_type,
            true,
        ))
        .await;
    }

    if res.status() == StatusCode::GONE {
        return Err(Error::ObjectDeleted(url.clone()));
    }

    // Capture URL and headers first, because reading the body consumes the response.
    let url = res.url().clone();
    let content_type = res.headers().get("Content-Type").cloned();
    let text = res.bytes_limited().await?;
    // A body without an extractable id is not an error here; callers decide how to treat `None`.
    let object_id = extract_id(&text).ok();

    match serde_json::from_slice(&text) {
        Ok(object) => Ok(FetchObjectResponse {
            object,
            url,
            content_type,
            object_id,
        }),
        // Keep the raw body text in the error so the failure can be diagnosed.
        Err(e) => Err(ParseFetchedObject(
            e,
            url,
            String::from_utf8(Vec::from(text))?,
        )),
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::{
        config::FederationConfig,
        traits::tests::{DbConnection, Person},
    };

    /// With the fetch limit set to zero, the very first fetch must fail with
    /// `Error::RequestLimit`.
    #[tokio::test]
    async fn test_request_limit() -> Result<(), Error> {
        let config = FederationConfig::builder()
            .domain("example.com")
            .app_data(DbConnection)
            .http_fetch_limit(0)
            .build()
            .await
            .unwrap();
        let data = config.to_request_data();

        let target = Url::parse("https://example.net/").map_err(Error::UrlParse)?;
        let outcome = fetch_object_http::<_, Person>(&target, &data).await;

        assert_eq!(outcome.err(), Some(Error::RequestLimit));
        Ok(())
    }
}