use crate::{
activity_queue::{create_activity_queue, ActivityQueue},
error::Error,
http_signatures::sign_request,
protocol::verification::verify_domains_match,
traits::{Activity, Actor},
utils::validate_ip,
};
use async_trait::async_trait;
use bytes::Bytes;
use derive_builder::Builder;
use dyn_clone::{clone_trait_object, DynClone};
use moka::future::Cache;
use regex::Regex;
use reqwest::{redirect::Policy, Client, Request};
use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
use rsa::{pkcs8::DecodePrivateKey, RsaPrivateKey};
use serde::de::DeserializeOwned;
use std::{
ops::Deref,
sync::{
atomic::{AtomicU32, Ordering},
Arc,
OnceLock,
},
time::Duration,
};
use url::Url;
/// Configuration shared by the federation code in this crate.
///
/// Instances are created through [`FederationConfig::builder`]; the derived
/// `partial_build` step is private and wrapped by the builder's async `build`,
/// which also creates the activity queue.
#[derive(Builder, Clone)]
#[builder(build_fn(private, name = "partial_build"))]
pub struct FederationConfig<T: Clone> {
/// Domain of the local instance. `is_local_url` compares it against a URL's
/// host, with `:port` appended when the URL reports a port.
#[builder(setter(into))]
pub(crate) domain: String,
/// Application-specific data; exposed through `Deref` and `Data::app_data`.
pub(crate) app_data: T,
/// Defaults to 20.
/// NOTE(review): presumably the maximum number of HTTP fetches per incoming
/// request; the enforcing code is not in this file, so confirm.
#[builder(default = "20")]
pub(crate) http_fetch_limit: u32,
/// HTTP client; a clone of it is handed to the activity queue in `build`.
/// Defaults to `default_client()`.
#[builder(default = "default_client()")]
pub(crate) client: ClientWithMiddleware,
/// Debug mode, off by default. When off, `verify_url_valid` rejects explicit
/// ports and validates the resolved IP; when on, `Data::is_valid_ip` always
/// succeeds.
#[builder(default = "false")]
pub(crate) debug: bool,
/// Whether `http://` URLs pass `verify_url_valid`. Defaults to the value
/// given for `debug` (or `false` when `debug` was not set).
#[builder(default = "self.debug.unwrap_or(false)")]
pub(crate) allow_http_urls: bool,
/// Timeout passed to the activity queue in `build`. Defaults to 10 seconds.
#[builder(default = "Duration::from_secs(10)")]
pub(crate) request_timeout: Duration,
/// Extra, user-supplied URL check run as the last step of
/// `verify_url_valid`. The default verifier accepts every URL.
#[builder(default = "Box::new(DefaultUrlVerifier())")]
pub(crate) url_verifier: Box<dyn UrlVerifier + Sync>,
/// Flag forwarded to `sign_request` by `Data::sign_request`. Off by default.
#[builder(default = "false")]
pub(crate) http_signature_compat: bool,
/// Actor id and decoded private key used by `Data::sign_request`.
/// `None` by default; set through the custom `signed_fetch_actor` setter.
#[builder(default = "None", setter(custom))]
pub(crate) signed_fetch_actor: Option<Arc<(Url, RsaPrivateKey)>>,
/// Cache from URL to RSA private key, capped at 10000 entries by default.
/// The custom `actor_pkey_cache` setter replaces it with a cache of another size.
/// NOTE(review): the code that reads this cache is not in this file.
#[builder(
default = "Cache::builder().max_capacity(10000).build()",
setter(custom)
)]
pub(crate) actor_pkey_cache: Cache<Url, RsaPrivateKey>,
/// Activity queue; has no setter and is filled in by the builder's `build`.
#[builder(setter(skip))]
pub(crate) activity_queue: Option<Arc<ActivityQueue>>,
/// Worker count passed to `create_activity_queue`. Defaults to 0.
/// NOTE(review): the meaning of 0 is decided by the queue implementation.
#[builder(default = "0")]
pub(crate) queue_worker_count: usize,
/// Retry count passed to `create_activity_queue`. Defaults to 0.
/// NOTE(review): the meaning of 0 is decided by the queue implementation.
#[builder(default = "0")]
pub(crate) queue_retry_count: usize,
}
/// Returns the process-wide regex used to validate domain names.
///
/// The pattern accepts only ASCII letters, digits, `.` and `-`. It is compiled
/// on first use and reused afterwards.
///
/// # Panics
///
/// Panics if the hard-coded pattern fails to compile.
pub(crate) fn domain_regex() -> &'static Regex {
    const PATTERN: &str = r"^[a-zA-Z0-9.-]*$";
    static COMPILED: OnceLock<Regex> = OnceLock::new();

    COMPILED.get_or_init(|| {
        let compiled = Regex::new(PATTERN);
        compiled.expect("compile regex")
    })
}
impl<T: Clone> FederationConfig<T> {
pub fn builder() -> FederationConfigBuilder<T> {
FederationConfigBuilder::default()
}
pub(crate) async fn verify_url_and_domain<A, Datatype>(&self, activity: &A) -> Result<(), Error>
where
A: Activity<DataType = Datatype> + DeserializeOwned + Send + 'static,
{
verify_domains_match(activity.id(), activity.actor())?;
self.verify_url_valid(activity.id()).await?;
if self.is_local_url(activity.id()) {
return Err(Error::UrlVerificationError(
"Activity was sent from local instance",
));
}
Ok(())
}
pub fn to_request_data(&self) -> Data<T> {
Data {
config: self.clone(),
request_counter: Default::default(),
}
}
pub(crate) async fn verify_url_valid(&self, url: &Url) -> Result<(), Error> {
match url.scheme() {
"https" => {}
"http" => {
if !self.allow_http_urls {
return Err(Error::UrlVerificationError(
"Http urls are only allowed in debug mode",
));
}
}
_ => return Err(Error::UrlVerificationError("Invalid url scheme")),
};
if self.is_local_url(url) {
return Ok(());
}
let Some(domain) = url.domain() else {
return Err(Error::UrlVerificationError("Url must have a domain"));
};
if !domain_regex().is_match(domain) {
return Err(Error::UrlVerificationError("Invalid characters in domain"));
}
if !self.debug {
if url.port().is_some() {
return Err(Error::UrlVerificationError("Explicit port is not allowed"));
}
let allow_local = std::env::var("DANGER_FEDERATION_ALLOW_LOCAL_IP").is_ok();
if !allow_local && validate_ip(&url).await.is_err() {
return Err(Error::DomainResolveError(domain.to_string()));
}
}
if domain.ends_with('.') {
let mut url = url.clone();
let domain = &domain[0..domain.len() - 1];
url.set_host(Some(domain))?;
self.url_verifier.verify(&url).await?;
} else {
self.url_verifier.verify(url).await?;
}
Ok(())
}
pub(crate) fn is_local_url(&self, url: &Url) -> bool {
match url.host_str() {
Some(domain) => {
let domain = if let Some(port) = url.port() {
format!("{}:{}", domain, port)
} else {
domain.to_string()
};
domain == self.domain
}
None => false,
}
}
pub fn domain(&self) -> &str {
&self.domain
}
}
impl<T: Clone> FederationConfigBuilder<T> {
    /// Sets the actor whose id and private key are used to sign requests.
    ///
    /// # Panics
    ///
    /// Panics if the actor has no private key, or if the key cannot be decoded
    /// as PKCS#8 PEM.
    pub fn signed_fetch_actor<A: Actor>(&mut self, actor: &A) -> &mut Self {
        let pem = actor
            .private_key_pem()
            .expect("actor does not have a private key to sign with");
        let key = RsaPrivateKey::from_pkcs8_pem(&pem).expect("Could not decode PEM data");

        let signer = Arc::new((actor.id().clone(), key));
        // Outer `Some` marks the builder field as set; inner `Some` is the field value.
        self.signed_fetch_actor = Some(Some(signer));
        self
    }

    /// Replaces the private-key cache with a new one capped at `cache_size` entries.
    pub fn actor_pkey_cache(&mut self, cache_size: u64) -> &mut Self {
        let cache = Cache::builder().max_capacity(cache_size).build();
        self.actor_pkey_cache = Some(cache);
        self
    }

    /// Builds the config and attaches a newly created activity queue.
    ///
    /// # Errors
    ///
    /// Returns the derived builder's error when `partial_build` fails.
    pub async fn build(&mut self) -> Result<FederationConfig<T>, FederationConfigBuilderError> {
        let mut config = self.partial_build()?;

        let client = config.client.clone();
        let queue = create_activity_queue(
            client,
            config.queue_worker_count,
            config.queue_retry_count,
            config.request_timeout,
        );
        config.activity_queue = Some(Arc::new(queue));

        Ok(config)
    }
}
impl<T: Clone> Deref for FederationConfig<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.app_data
}
}
/// User-supplied check deciding whether a URL is acceptable.
///
/// `FederationConfig::verify_url_valid` calls [`UrlVerifier::verify`] as its
/// last step, after the built-in checks passed. URLs of the local instance
/// return early and never reach the verifier.
#[async_trait]
pub trait UrlVerifier: DynClone + Send {
/// Returns `Ok(())` to accept `url`, or an error to reject it.
async fn verify(&self, url: &Url) -> Result<(), Error>;
}
/// Verifier used when the builder is given none; it accepts every URL.
#[derive(Clone)]
struct DefaultUrlVerifier();
#[async_trait]
impl UrlVerifier for DefaultUrlVerifier {
/// Always succeeds; the URL is not inspected.
async fn verify(&self, _url: &Url) -> Result<(), Error> {
Ok(())
}
}
// Makes boxed `UrlVerifier` trait objects cloneable (via `dyn_clone`), which the
// derived `Clone` on `FederationConfig` needs for its `url_verifier` field.
clone_trait_object!(UrlVerifier);
/// A federation config paired with a request counter.
///
/// Created by `FederationConfig::to_request_data`; dereferences to the
/// application data `T`.
#[derive(Clone)]
pub struct Data<T: Clone> {
/// The federation config this value was created from (a clone of it).
pub(crate) config: FederationConfig<T>,
/// Counter read by `request_count`.
/// NOTE(review): the code that increments it is not in this file.
pub(crate) request_counter: RequestCounter,
}
impl<T: Clone> Data<T> {
pub fn app_data(&self) -> &T {
&self.config.app_data
}
pub fn domain(&self) -> &str {
&self.config.domain
}
pub fn reset_request_count(&self) -> Self {
Data {
config: self.config.clone(),
request_counter: Default::default(),
}
}
pub fn request_count(&self) -> u32 {
self.request_counter.0.load(Ordering::Relaxed)
}
pub async fn sign_request(&self, req: RequestBuilder, body: Bytes) -> Result<Request, Error> {
let (actor_id, private_key_pem) =
self.config
.signed_fetch_actor
.as_deref()
.ok_or(Error::Other(
"config value signed_fetch_actor is none".to_string(),
))?;
sign_request(
req,
actor_id,
body,
private_key_pem.clone(),
self.config.http_signature_compat,
)
.await
}
pub async fn is_valid_ip(&self, url: &Url) -> Result<(), Error> {
if self.config.debug {
return Ok(());
}
validate_ip(url).await
}
}
impl<T: Clone> Deref for Data<T> {
type Target = T;
fn deref(&self) -> &T {
&self.config.app_data
}
}
/// Atomic counter stored alongside the config in `Data`.
#[derive(Default)]
pub(crate) struct RequestCounter(pub(crate) AtomicU32);

/// Cloning snapshots the current value into an independent counter; later
/// changes to either counter are not seen by the other.
impl Clone for RequestCounter {
    fn clone(&self) -> Self {
        let Self(counter) = self;
        let snapshot = counter.load(Ordering::Relaxed);
        Self(AtomicU32::new(snapshot))
    }
}
/// Newtype wrapper around a [`FederationConfig`].
#[derive(Clone)]
pub struct FederationMiddleware<T: Clone>(pub(crate) FederationConfig<T>);

impl<T: Clone> FederationMiddleware<T> {
    /// Wraps `config` in a middleware value.
    pub fn new(config: FederationConfig<T>) -> Self {
        Self(config)
    }
}
fn default_client() -> ClientWithMiddleware {
let timeout = Duration::from_secs(10);
Client::builder()
.redirect(Policy::none())
.timeout(timeout)
.connect_timeout(timeout)
.build()
.unwrap_or_else(|_| Client::default())
.into()
}
#[cfg(test)]
// Tests may unwrap: a failure here should abort the test immediately.
#[allow(clippy::unwrap_used)]
mod test {
    use super::*;

    /// Builds a minimal config for `example.com` with `1` as application data.
    async fn config() -> FederationConfig<i32> {
        let mut builder = FederationConfig::<i32>::builder();
        let built = builder.domain("example.com").app_data(1).build().await;
        built.unwrap()
    }

    #[tokio::test]
    async fn test_url_is_local() -> Result<(), Error> {
        let config = config().await;
        // The IP-address case checks that a URL without a domain name is
        // handled and reported as non-local.
        let cases = [
            ("http://example.com", true),
            ("http://other.com", false),
            ("http://127.0.0.1", false),
        ];
        for (raw, expected) in cases {
            let url = Url::parse(raw)?;
            assert_eq!(config.is_local_url(&url), expected);
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_get_domain() {
        let cfg = config().await;
        assert_eq!(cfg.domain(), "example.com");
    }
}