#![recursion_limit = "512"]
mod error;
mod report;
#[macro_use]
extern crate lazy_static;
use deadqueue::limited::Queue;
use futures_util::{future::Fuse, pin_mut, select, Future, FutureExt};
use rand::{random, seq::SliceRandom, thread_rng};
use report::FailedReport;
use serde::{Deserialize, Serialize};
use std::sync::Mutex;
use std::time::{Duration, Instant};
use ttl_cache::TtlCache;
use url::Url;
pub use error::Error;
pub use report::NELReport;
/// How long a report whose upload failed waits, measured from its last
/// attempt, before `handle_reports` tries to deliver it again (one minute).
const RETRY_TIMEOUT: Duration = Duration::from_millis(60 * 1_000);
/// A NEL policy as cached per host by `nel_header`.
#[derive(Clone)]
struct NELPolicy {
    /// Probability (0.0..=1.0) that a report about a successful request is
    /// delivered; `choose_endpoint` samples against it.
    success_fraction: f32,
    /// Probability (0.0..=1.0) that a report about a failed request is
    /// delivered.
    failure_fraction: f32,
    /// Name of the `Report-To` endpoint group that receives the reports.
    report_to: String,
}
lazy_static! {
    /// Reports waiting to be delivered (at most 256). `submit_report` pushes
    /// without blocking and `handle_reports` drains it.
    static ref REPORT_QUEUE: Queue<NELReport> = Queue::new(256);

    /// NEL policies keyed by host; each entry lives for the `max_age` of the
    /// header that created it. Created with capacity 50.
    static ref NEL_POLICY_CACHE: Mutex<TtlCache<String, NELPolicy>> = {
        let per_host = TtlCache::new(50);
        Mutex::new(per_host)
    };

    /// Endpoint URLs of `Report-To` groups, keyed by `"{host}:{group}"`.
    /// Created with capacity 50.
    static ref GROUP_POLICY_CACHE: Mutex<TtlCache<String, Vec<String>>> = {
        let per_group = TtlCache::new(50);
        Mutex::new(per_group)
    };
}
/// Wire format of a `NEL` response header value (a JSON object).
///
/// Members not declared here are ignored during deserialization.
#[derive(Serialize, Deserialize)]
struct NelHeader {
    /// Name of the `Report-To` endpoint group reports are sent to.
    report_to: String,
    /// Policy lifetime in seconds; `nel_header` treats `0` as "remove the
    /// stored policy".
    max_age: u64,
    /// Defaults to `false` when absent.
    // NOTE(review): parsed but never copied into `NELPolicy` by `nel_header`,
    // so subdomain policies are not currently honoured.
    #[serde(default)]
    include_subdomains: bool,
    /// Fraction of successful requests to report; defaults to `0.0`.
    #[serde(default)]
    success_fraction: f32,
    /// Fraction of failed requests to report; defaults to `1.0` via
    /// `default_failure_fraction`.
    #[serde(default = "default_failure_fraction")]
    failure_fraction: f32,
}
/// Serde default for `NelHeader::failure_fraction`: report every failure.
const fn default_failure_fraction() -> f32 {
    1.0
}
/// Records (or clears) the NEL policy that `host` advertised in a `NEL`
/// response header.
///
/// `hdr` is the raw header value, expected to be a JSON object. Headers that
/// fail to parse, name an empty `report_to` group, or carry a sampling
/// fraction outside `0.0..=1.0` are ignored. A `max_age` of `0` removes any
/// policy stored for `host`. Otherwise the new policy replaces the previous
/// one and expires after `max_age` seconds, capped at roughly 100 years.
pub fn nel_header(host: &str, hdr: &str) {
    // `max_age` comes straight from the network. The TTL cache derives an
    // expiry `Instant` from the duration, and `Instant + Duration` panics
    // when the sum is not representable, so a value like `u64::MAX` would
    // crash the caller. No real policy lives anywhere near this cap.
    // (If the cache ever switches to checked arithmetic the cap is harmless.)
    const MAX_POLICY_AGE_SECS: u64 = 100 * 365 * 24 * 60 * 60;

    let parsed = match serde_json::from_str::<NelHeader>(hdr) {
        Ok(parsed) => parsed,
        Err(_) => return,
    };
    // The fractions are sampling probabilities; anything else invalidates
    // the whole header rather than being clamped.
    let valid = !parsed.report_to.is_empty()
        && (0.0..=1.0).contains(&parsed.success_fraction)
        && (0.0..=1.0).contains(&parsed.failure_fraction);
    if !valid {
        return;
    }
    // A poisoned lock is treated like any other failure: the header is dropped.
    if let Ok(mut guard) = NEL_POLICY_CACHE.lock() {
        if parsed.max_age == 0 {
            // max_age 0 withdraws a previously advertised policy.
            guard.remove(host);
        } else {
            let ttl = Duration::from_secs(parsed.max_age.min(MAX_POLICY_AGE_SECS));
            let policy = NELPolicy {
                report_to: parsed.report_to,
                success_fraction: parsed.success_fraction,
                failure_fraction: parsed.failure_fraction,
            };
            guard.insert(host.to_string(), policy, ttl);
        }
    }
}
#[derive(Serialize, Deserialize)]
struct ReportToHeader {
group: String,
max_age: u64,
endpoints: Vec<ReportEndpoint>,
}
/// One entry of a `Report-To` group's `endpoints` array.
#[derive(Serialize, Deserialize)]
struct ReportEndpoint {
    /// Collector URL. `report_to_header` only checks that it is non-empty;
    /// it is later handed unchanged to the `post` callback of `handle_reports`.
    url: String,
}
/// Records (or clears) an endpoint group that `host` advertised in a
/// `Report-To` response header.
///
/// `hdr` is the raw header value, expected to be a single JSON object.
/// Headers that fail to parse, have an empty group name, list no endpoints,
/// or contain an endpoint with an empty URL are ignored. The group is stored
/// under the key `"{host}:{group}"`. A `max_age` of `0` removes it. Otherwise
/// it replaces any previous entry and expires after `max_age` seconds, capped
/// at roughly 100 years.
pub fn report_to_header(host: &str, hdr: &str) {
    // `max_age` comes straight from the network. The TTL cache derives an
    // expiry `Instant` from the duration, and `Instant + Duration` panics
    // when the sum is not representable, so a value like `u64::MAX` would
    // crash the caller. No real group lives anywhere near this cap.
    // (If the cache ever switches to checked arithmetic the cap is harmless.)
    const MAX_GROUP_AGE_SECS: u64 = 100 * 365 * 24 * 60 * 60;

    let parsed = match serde_json::from_str::<ReportToHeader>(hdr) {
        Ok(parsed) => parsed,
        Err(_) => return,
    };
    let valid = !parsed.group.is_empty()
        && !parsed.endpoints.is_empty()
        && parsed.endpoints.iter().all(|ep| !ep.url.is_empty());
    if !valid {
        return;
    }
    // Must match the key format `choose_endpoint` looks groups up by.
    let key = format!("{}:{}", host, parsed.group);
    // A poisoned lock is treated like any other failure: the header is dropped.
    if let Ok(mut guard) = GROUP_POLICY_CACHE.lock() {
        if parsed.max_age == 0 {
            guard.remove(&key);
        } else {
            let ttl = Duration::from_secs(parsed.max_age.min(MAX_GROUP_AGE_SECS));
            // Move the URLs out of the parsed header instead of cloning them.
            let endpoints: Vec<String> =
                parsed.endpoints.into_iter().map(|ep| ep.url).collect();
            guard.insert(key, endpoints, ttl);
        }
    }
}
/// Hands `report` to the background sender without blocking.
///
/// Delivery is best-effort: the queue is bounded, and a report that does not
/// fit is discarded rather than making the caller wait.
pub fn submit_report(report: NELReport) {
    match REPORT_QUEUE.try_push(report) {
        Ok(()) => {}
        // Queue full: drop the report on purpose.
        Err(_discarded) => {}
    }
}
/// Drives delivery of queued NEL reports. The loop never exits.
///
/// The caller supplies the I/O primitives:
/// * `sleep` — expected to return a future that completes after the given
///   duration;
/// * `post` — called as `post(endpoint_url, serialized_report)`. Its future
///   resolving to `true` is treated as a successful upload.
///
/// New reports are taken from the global queue filled by `submit_report` and
/// posted once, subject to the policy's sampling fractions. A report whose
/// upload fails is retried `RETRY_TIMEOUT` after its last attempt, with
/// sampling disabled. It is re-queued after every further failure; this
/// function imposes no retry limit. Only one retry timer is armed at a time.
/// Other failed reports wait in a local queue of capacity 256, and failures
/// that do not fit are silently dropped. A report for which no endpoint can be
/// chosen counts as delivered.
///
/// Uploads are awaited inline, so a slow `post` delays both new reports and
/// retries.
pub async fn handle_reports<F, G, FFut, GFut>(sleep: F, post: G)
where
    F: Fn(Duration) -> FFut,
    G: Fn(String, String) -> GFut,
    FFut: Future<Output = ()>,
    GFut: Future<Output = bool>,
{
    // Both futures are fused and pinned once, then replaced in place with
    // `.set()` so `select!` can keep polling them across iterations.
    let pop = REPORT_QUEUE.pop().fuse();
    let failed_queue: Queue<FailedReport> = Queue::new(256);
    // A terminated `Fuse` never completes, so the retry arm stays dormant
    // until a timer is armed.
    let fail_timeout = Fuse::terminated();
    // Invariant: `Some` exactly when `fail_timeout` is armed. It holds the
    // report that timer belongs to.
    let mut next_failed: Option<FailedReport> = None;
    pin_mut!(pop, fail_timeout);
    loop {
        select! {
            report = pop => {
                let payload = report.serialize();
                // `true` enables sampling; `None` (no policy/endpoint or
                // sampled out) means there is nothing to retry.
                let success = match choose_endpoint(&report, true) {
                    Some(endpoint) => post(endpoint, payload).await,
                    None => true,
                };
                if !success {
                    let failed = FailedReport{
                        last_try: Instant::now(),
                        original: report,
                    };
                    if next_failed.is_none() {
                        // No retry pending: this report owns the timer.
                        fail_timeout.set(sleep(RETRY_TIMEOUT).fuse());
                        next_failed = Some(failed);
                    } else {
                        // Best-effort: dropped if the retry queue is full.
                        let _ = failed_queue.try_push(failed);
                    }
                }
                // Re-arm for the next incoming report.
                pop.set(REPORT_QUEUE.pop().fuse());
            },
            _ = fail_timeout => {
                // Cannot panic: this arm only fires while a timer is armed,
                // and by the invariant above `next_failed` is then `Some`.
                let report = &next_failed.as_ref().unwrap().original;
                let payload = report.serialize();
                // Retries skip sampling; the report was already selected.
                let success = match choose_endpoint(report, false) {
                    Some(endpoint) => post(endpoint, payload).await,
                    None => true,
                };
                if !success {
                    // Moves out of `next_failed`; it is reassigned on both
                    // paths below before any further use.
                    let _ = failed_queue.try_push(FailedReport{
                        last_try: Instant::now(),
                        original: next_failed.unwrap().original,
                    });
                }
                if let Some(failed) = failed_queue.try_pop() {
                    // Wait out whatever remains of this report's retry
                    // interval. If it is already overdue, fire after 10 ms.
                    let dur = RETRY_TIMEOUT
                        .checked_sub(Instant::now().duration_since(failed.last_try))
                        .unwrap_or_else(|| Duration::from_millis(10));
                    fail_timeout.set(sleep(dur).fuse());
                    next_failed = Some(failed)
                } else {
                    // Nothing left to retry: disarm the timer.
                    fail_timeout.set(Fuse::terminated());
                    next_failed = None;
                }
            },
        }
    }
}
/// Picks the collector URL `report` should be posted to, or `None` if it
/// should not be delivered.
///
/// The report's host is `report.host_override` when set, otherwise the host
/// of `report.url`. `None` is returned when:
/// * that host cannot be determined;
/// * no unexpired NEL policy is cached for it;
/// * no matching `Report-To` group is cached, or the group has no endpoints;
/// * a cache lock is poisoned;
/// * `evaluate_drop` is `true` and the report loses the policy's sampling
///   draw.
///
/// Otherwise one of the group's endpoints is chosen uniformly at random.
fn choose_endpoint(report: &NELReport, evaluate_drop: bool) -> Option<String> {
    let host = match &report.host_override {
        Some(host) => host.clone(),
        None => Url::parse(&report.url).ok()?.host_str()?.to_owned(),
    };

    // Copy out only what is needed (not the whole policy) so the lock is
    // released immediately. The key format must match `report_to_header`.
    let (group_key, success_fraction, failure_fraction) = {
        let guard = NEL_POLICY_CACHE.lock().ok()?;
        let policy = guard.get(&host)?;
        (
            format!("{}:{}", host, policy.report_to),
            policy.success_fraction,
            policy.failure_fraction,
        )
    };

    // Sampling happens before the group lookup, so dropped reports never
    // touch the group cache. `random::<f32>()` lies in [0, 1): a fraction of
    // 0.0 always drops and 1.0 never drops.
    if evaluate_drop {
        let fraction = if report.is_success() {
            success_fraction
        } else {
            failure_fraction
        };
        if random::<f32>() >= fraction {
            return None;
        }
    }

    // Pick while holding the lock instead of cloning the whole endpoint list.
    let guard = GROUP_POLICY_CACHE.lock().ok()?;
    let endpoints = guard.get(&group_key)?;
    endpoints.choose(&mut thread_rng()).cloned()
}