use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::Arc;
use std::time::Duration;
use adler_core::{
ExecutorOptions, Site, Username, render_investigation_report_html,
render_investigation_report_markdown,
};
use async_stream::stream;
use axum::Json;
use axum::extract::{Path as AxumPath, Query, State};
use axum::http::header;
use axum::response::sse::{Event, KeepAlive, KeepAliveStream, Sse};
use axum::response::{IntoResponse, Response};
use futures::Stream;
use serde::Deserialize;
use super::dto::{
AccessSummary, DisabledSiteSummary, Health, RefilterRequest, RefilterResponse, RetryRequest,
RetryResponse, ScanListEntry, ScanSnapshot, SessionName, SiteSummary, SitesResponse,
StartEvent, StartScanRequest, StartScanResponse,
};
use super::error::ApiError;
use super::filter::filter_catalog;
use crate::persist::{PersistedDisabledMatch, PersistedScan, ScanRequestContext};
use crate::scan::{ScanHandle, ScanId};
use crate::state::AppState;
/// Liveness probe.
///
/// Always reports `ok: true` together with the crate version baked in at compile time
/// (`CARGO_PKG_VERSION`).
pub(super) async fn health() -> Json<Health> {
    let version = env!("CARGO_PKG_VERSION");
    Json(Health { ok: true, version })
}
/// Lists the sites a scan can target alongside the catalog entries that are disabled.
///
/// `sites` summarises every entry of `state.sites` (presumably the enabled subset of the
/// catalog — confirm in `AppState`), while `disabled` summarises each `state.catalog` entry
/// whose `disabled` flag is set.
pub(super) async fn list_sites(State(state): State<AppState>) -> Json<SitesResponse> {
    let sites: Vec<_> = state.sites.iter().map(SiteSummary::from).collect();
    let disabled: Vec<_> = state
        .catalog
        .iter()
        .filter_map(|site| site.disabled.then(|| DisabledSiteSummary::from(site)))
        .collect();
    Json(SitesResponse { sites, disabled })
}
/// Describes how the shared client reaches the network.
///
/// Returns the client's egress summary plus the name of every configured session.
pub(super) async fn list_access(State(state): State<AppState>) -> Json<AccessSummary> {
    let client = &state.client;
    let egress_overview = client.egress_summary();
    let session_list = client
        .session_names()
        .into_iter()
        .map(|name| SessionName { name })
        .collect();
    Json(AccessSummary {
        egress: egress_overview,
        sessions: session_list,
    })
}
/// Lists every known scan, newest first.
///
/// Scans registered in memory are reported with a live `running`/`finished` status. Scans
/// found only in the scans directory (when one is configured) are added as `finished`. When an
/// ID exists in both places, the in-memory entry wins.
pub(super) async fn list_scans(State(state): State<AppState>) -> Json<Vec<ScanListEntry>> {
    // Copy the handles out so the registry read lock is released before the awaits below.
    let live: Vec<(ScanId, ScanHandle)> = state
        .scans
        .read()
        .await
        .iter()
        .map(|(scan_id, handle)| (scan_id.clone(), handle.clone()))
        .collect();

    let mut merged: HashMap<ScanId, ScanListEntry> = HashMap::with_capacity(live.len());
    for (scan_id, handle) in live {
        let finished = handle.finished().await;
        let status = match &finished {
            Some(_) => "finished",
            None => "running",
        };
        let entry = ScanListEntry {
            scan_id: scan_id.clone(),
            username: handle.username().to_owned(),
            site_count: handle.site_count(),
            started_at_ms: handle.created_at_ms(),
            elapsed_ms: u64::try_from(handle.elapsed().as_millis()).unwrap_or(u64::MAX),
            status,
            summary: finished.map(|f| f.summary),
        };
        merged.insert(scan_id, entry);
    }

    if let Some(dir) = &state.scans_dir {
        for persisted in crate::persist::load_all(dir).await {
            merged
                .entry(persisted.scan_id.clone())
                .or_insert_with(|| ScanListEntry {
                    scan_id: persisted.scan_id,
                    username: persisted.username,
                    site_count: persisted.site_count,
                    started_at_ms: persisted.created_at_ms,
                    elapsed_ms: persisted.elapsed_ms,
                    status: "finished",
                    summary: Some(persisted.summary),
                });
        }
    }

    let mut entries: Vec<ScanListEntry> = merged.into_values().collect();
    // Descending by start time (stable, like sorting on `Reverse(started_at_ms)`).
    entries.sort_by(|a, b| b.started_at_ms.cmp(&a.started_at_ms));
    Json(entries)
}
/// Builds the timeline of finished scans (live and persisted) for one username.
///
/// # Errors
/// `invalid_username` when the path segment is not a valid [`Username`].
pub(super) async fn list_scan_timeline(
    State(state): State<AppState>,
    AxumPath(username): AxumPath<String>,
) -> Result<Json<crate::persist::ScanTimeline>, ApiError> {
    let validated = match Username::new(username) {
        Ok(name) => name,
        Err(err) => return Err(ApiError::bad_request("invalid_username", err.to_string())),
    };
    let history = load_finished_scans_for_username(&state, validated.as_str()).await;
    let mut timeline = crate::persist::build_scan_timeline(&history);
    // `build_scan_timeline` can leave the username blank (presumably when `history` is
    // empty); echo back the requested name in that case.
    if timeline.username.is_empty() {
        timeline.username = validated.as_str().to_owned();
    }
    Ok(Json(timeline))
}
/// Starts a new background scan of `req.username` over the sites selected by the filter.
///
/// Validation runs up front, in this order:
/// 1. The username.
/// 2. The site filter, which must select at least one site from `state.sites`.
/// 3. The requested egress names, when any are given; each must exist in the client's pool.
///
/// The scan is then inserted into `state`, spawned, and its task registered.
///
/// # Errors
/// - `invalid_username`
/// - `empty_site_filter` (with any matching disabled sites attached)
/// - `unknown_egress`
pub(super) async fn start_scan(
    State(state): State<AppState>,
    Json(req): Json<StartScanRequest>,
) -> Result<Json<StartScanResponse>, ApiError> {
    let username = match Username::new(req.username.clone()) {
        Ok(valid) => valid,
        Err(err) => return Err(ApiError::bad_request("invalid_username", err.to_string())),
    };

    let sites = filter_catalog(&state.sites, &req);
    if sites.is_empty() {
        let disabled = disabled_matches(&state.catalog, &req);
        let message = empty_filter_message(disabled.is_empty());
        return Err(
            ApiError::bad_request("empty_site_filter", message).with_disabled_matches(disabled),
        );
    }

    if !req.egress_names.is_empty() {
        let pool: std::collections::HashSet<String> =
            state.client.egress_names().into_iter().collect();
        let unknown: Vec<&str> = req
            .egress_names
            .iter()
            .map(String::as_str)
            .filter(|name| !pool.contains(*name))
            .collect();
        if !unknown.is_empty() {
            return Err(ApiError::bad_request(
                "unknown_egress",
                format!("egress not in pool: {}", unknown.join(", ")),
            ));
        }
    }

    let options = {
        let mut opts = ExecutorOptions::default();
        if let Some(limit) = req.concurrency {
            opts = opts.concurrency(limit);
        }
        if let Some(secs) = req.deadline_secs {
            opts = opts.deadline(Duration::from_secs(secs));
        }
        opts
    };

    let scan_id = ScanId::new();
    let site_count = sites.len();
    let handle = ScanHandle::new(req.username.clone(), site_count, site_count.max(64));
    state.insert_scan(scan_id.clone(), handle.clone()).await;

    // Persistence is opt-in: a context is only built when a scans directory is configured.
    let persist_ctx = state
        .scans_dir
        .clone()
        .map(|dir| crate::scan::PersistContext {
            scan_id: scan_id.clone(),
            dir,
            request_context: request_context(&req, &state.catalog, None),
        });

    // Either share the full client or restrict it to the requested egress subset.
    let scan_client: Arc<adler_core::Client> = if req.egress_names.is_empty() {
        Arc::clone(&state.client)
    } else {
        Arc::new(state.client.with_egress_subset(&req.egress_names))
    };

    let task = crate::scan::spawn(
        handle,
        scan_client,
        Arc::from(sites.into_boxed_slice()),
        username,
        options,
        persist_ctx,
    );
    state.register_scan_task(scan_id.clone(), task).await;

    Ok(Json(StartScanResponse {
        scan_id,
        username: req.username,
        site_count,
    }))
}
pub(super) async fn refilter_scan(
State(state): State<AppState>,
AxumPath(id): AxumPath<String>,
Json(req): Json<RefilterRequest>,
) -> Result<Json<RefilterResponse>, ApiError> {
let prev_id = ScanId::from(id);
let prev_handle = state
.get_scan(&prev_id)
.await
.ok_or_else(|| ApiError::not_found("scan_not_found", "no scan with that ID"))?;
if prev_handle.is_finished_now() {
return Err(ApiError::bad_request(
"scan_finished",
"scan has already finished; start a new one with POST /api/scan",
));
}
if !req.egress_names.is_empty() {
let known: std::collections::HashSet<String> =
state.client.egress_names().into_iter().collect();
let bad: Vec<&String> = req
.egress_names
.iter()
.filter(|n| !known.contains(n.as_str()))
.collect();
if !bad.is_empty() {
let names: Vec<&str> = bad.iter().map(|s| s.as_str()).collect();
return Err(ApiError::bad_request(
"unknown_egress",
format!("egress not in pool: {}", names.join(", ")),
));
}
}
let mut start_shape = StartScanRequest::from(&req);
let new_sites = filter_catalog(&state.sites, &start_shape);
if new_sites.is_empty() {
let disabled = disabled_matches(&state.catalog, &start_shape);
return Err(ApiError::bad_request(
"empty_site_filter",
empty_filter_message(disabled.is_empty()),
)
.with_disabled_matches(disabled));
}
let prev_outcomes = prev_handle.outcomes_snapshot().await;
let new_site_names: std::collections::HashSet<String> =
new_sites.iter().map(|s| s.name.clone()).collect();
let carried: Vec<adler_core::CheckOutcome> = prev_outcomes
.into_iter()
.filter(|o| new_site_names.contains(&o.site))
.collect();
let carried_names: std::collections::HashSet<String> =
carried.iter().map(|o| o.site.clone()).collect();
let sites_to_probe: Vec<Site> = new_sites
.iter()
.filter(|s| !carried_names.contains(&s.name))
.cloned()
.collect();
state.abort_scan(&prev_id).await;
let mut options = ExecutorOptions::default();
if let Some(c) = req.concurrency {
options = options.concurrency(c);
}
if let Some(d) = req.deadline_secs {
options = options.deadline(Duration::from_secs(d));
}
let username_str = prev_handle.username().to_owned();
start_shape.username = username_str.clone();
let username = Username::new(username_str.clone())
.map_err(|e| ApiError::bad_request("invalid_username", e.to_string()))?;
let id = ScanId::new();
let site_count = new_sites.len();
let handle = ScanHandle::new(username_str.clone(), site_count, site_count.max(64));
state.insert_scan(id.clone(), handle.clone()).await;
handle.extend_outcomes(carried.clone()).await;
let persist_ctx = state
.scans_dir
.as_ref()
.map(|dir| crate::scan::PersistContext {
scan_id: id.clone(),
dir: dir.clone(),
request_context: request_context(&start_shape, &state.catalog, Some(prev_id.clone())),
});
let scan_client: Arc<adler_core::Client> = if req.egress_names.is_empty() {
state.client.clone()
} else {
Arc::new(state.client.with_egress_subset(&req.egress_names))
};
let task = crate::scan::spawn(
handle,
scan_client,
Arc::from(sites_to_probe.into_boxed_slice()),
username,
options,
persist_ctx,
);
state.register_scan_task(id.clone(), task).await;
Ok(Json(RefilterResponse {
scan_id: id,
derived_from: prev_id,
carried_outcomes: carried.len(),
site_count,
}))
}
/// Summaries of the disabled catalog entries that the request's filter would have selected.
///
/// Used to explain an empty filter result to the client.
fn disabled_matches(catalog: &[Site], req: &StartScanRequest) -> Vec<DisabledSiteSummary> {
    let matched = super::filter::disabled_matches(catalog, req);
    matched.iter().map(DisabledSiteSummary::from).collect()
}
fn request_context(
req: &StartScanRequest,
catalog: &[Site],
derived_from: Option<ScanId>,
) -> ScanRequestContext {
ScanRequestContext {
username: req.username.clone(),
derived_from,
only: req.only.clone(),
exclude: req.exclude.clone(),
tag: req.tag.clone(),
exclude_tag: req.exclude_tag.clone(),
top: req.top,
nsfw: req.nsfw,
concurrency: req.concurrency.map(std::num::NonZeroUsize::get),
deadline_secs: req.deadline_secs,
egress_names: req.egress_names.clone(),
disabled_matches: super::filter::disabled_matches(catalog, req)
.iter()
.map(PersistedDisabledMatch::from)
.collect(),
}
}
/// Error text for a filter that selected no scannable sites.
///
/// When disabled sites did match (`disabled_empty == false`), the message points out that
/// only *enabled* sites are considered.
fn empty_filter_message(disabled_empty: bool) -> &'static str {
    match disabled_empty {
        true => "no sites match the requested filter",
        false => "no enabled sites match the requested filter",
    }
}
pub(super) async fn get_scan(
State(state): State<AppState>,
AxumPath(id): AxumPath<String>,
) -> Result<Json<ScanSnapshot>, ApiError> {
let scan_id = ScanId::from(id);
if let Some(scan) = state.get_scan(&scan_id).await {
return Ok(match scan.finished().await {
Some(finished) => Json(ScanSnapshot::Finished {
username: scan.username().to_owned(),
site_count: scan.site_count(),
finished,
}),
None => Json(ScanSnapshot::Running {
username: scan.username().to_owned(),
site_count: scan.site_count(),
elapsed_ms: u64::try_from(scan.elapsed().as_millis()).unwrap_or(u64::MAX),
partial: scan.outcomes_snapshot().await,
}),
});
}
if let Some(dir) = &state.scans_dir {
if let Some(mut ps) = crate::persist::load(dir, &scan_id).await {
let related_scans = crate::persist::load_all(dir).await;
crate::persist::apply_historical_confidence_overlay(&mut ps, &related_scans);
return Ok(Json(ScanSnapshot::Finished {
username: ps.username,
site_count: ps.site_count,
finished: crate::scan::FinishedScan {
summary: ps.summary,
outcomes: ps.outcomes,
identity_clusters: ps.identity_clusters,
elapsed_ms: ps.elapsed_ms,
},
}));
}
}
Err(ApiError::not_found(
"scan_not_found",
"no scan with that ID",
))
}
/// Query string accepted by `get_scan_report`.
#[derive(Debug, Deserialize)]
pub(super) struct ReportQuery {
    /// Requested output encoding: `json`, `markdown`, or `html`. When absent, `json` is used
    /// (see `ReportResponseFormat::parse`).
    format: Option<String>,
}
/// Renders the investigation report for a finished scan.
///
/// The report draws on every finished scan of the same username. It is returned as JSON,
/// Markdown, or HTML according to the `format` query parameter (default JSON).
///
/// # Errors
/// - `invalid_report_format`: the format is not recognised.
/// - Otherwise, any error from `load_finished_scan` (`scan_not_found`, `scan_not_finished`).
pub(super) async fn get_scan_report(
    State(state): State<AppState>,
    AxumPath(id): AxumPath<String>,
    Query(query): Query<ReportQuery>,
) -> Result<Response, ApiError> {
    // Reject a bad format before doing any loading work.
    let format = ReportResponseFormat::parse(query.format.as_deref())?;
    let scan_id = ScanId::from(id);
    let scan = load_finished_scan(&state, &scan_id).await?;
    let related = load_finished_scans_for_username(&state, &scan.username).await;
    let report = crate::persist::build_investigation_report(scan, &related);

    let response = match format {
        ReportResponseFormat::Json => Json(report).into_response(),
        ReportResponseFormat::Markdown => {
            let body = render_investigation_report_markdown(&report);
            ([(header::CONTENT_TYPE, "text/markdown; charset=utf-8")], body).into_response()
        }
        ReportResponseFormat::Html => {
            let body = render_investigation_report_html(&report);
            ([(header::CONTENT_TYPE, "text/html; charset=utf-8")], body).into_response()
        }
    };
    Ok(response)
}
/// Output encodings supported by the scan report endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReportResponseFormat {
    /// Structured JSON body (the default).
    Json,
    /// `text/markdown` rendering.
    Markdown,
    /// `text/html` rendering.
    Html,
}
impl ReportResponseFormat {
    /// Maps the optional `format` query value to a format; a missing value means JSON.
    ///
    /// Matching is exact (case-sensitive).
    ///
    /// # Errors
    /// `invalid_report_format` for any other value.
    fn parse(value: Option<&str>) -> Result<Self, ApiError> {
        let format = match value {
            None | Some("json") => Self::Json,
            Some("markdown") => Self::Markdown,
            Some("html") => Self::Html,
            Some(_) => {
                return Err(ApiError::bad_request(
                    "invalid_report_format",
                    "report format must be json, markdown, or html",
                ));
            }
        };
        Ok(format)
    }
}
/// Compares two finished scans and returns what changed from `from` to `to`.
///
/// Both IDs are resolved through `load_finished_scan`, which checks the in-memory registry
/// first and then the scans directory.
///
/// # Errors
/// Propagates that helper's `scan_not_found` / `scan_not_finished` errors for either scan.
pub(super) async fn diff_scans(
    State(state): State<AppState>,
    AxumPath((from, to)): AxumPath<(String, String)>,
) -> Result<Json<crate::persist::ScanDiff>, ApiError> {
    let from_id = ScanId::from(from);
    let to_id = ScanId::from(to);
    let previous = load_finished_scan(&state, &from_id).await?;
    let current = load_finished_scan(&state, &to_id).await?;
    // The original text read `¤t` here, an HTML-entity mangling of `&current`.
    Ok(Json(crate::persist::diff_scans(&previous, &current)))
}
/// Collects every finished scan belonging to `username`.
///
/// In-memory scans that have finished are converted with `PersistedScan::from_finished`.
/// Persisted scans from the scans directory (if configured) are added for IDs not already
/// present, so the in-memory copy wins on conflicts. The result order is unspecified.
async fn load_finished_scans_for_username(state: &AppState, username: &str) -> Vec<PersistedScan> {
    // Pick out this user's handles while holding the read lock, then release it before awaiting.
    let candidates: Vec<(ScanId, ScanHandle)> = {
        let registry = state.scans.read().await;
        registry
            .iter()
            .filter(|(_, handle)| handle.username() == username)
            .map(|(scan_id, handle)| (scan_id.clone(), handle.clone()))
            .collect()
    };

    let mut merged: HashMap<ScanId, PersistedScan> = HashMap::with_capacity(candidates.len());
    for (scan_id, handle) in candidates {
        let Some(finished) = handle.finished().await else {
            continue;
        };
        let snapshot = PersistedScan::from_finished(
            scan_id.clone(),
            handle.username().to_owned(),
            handle.site_count(),
            handle.created_at_ms(),
            finished,
        );
        merged.insert(scan_id, snapshot);
    }

    if let Some(dir) = &state.scans_dir {
        let on_disk = crate::persist::load_all(dir).await;
        for scan in on_disk.into_iter().filter(|scan| scan.username == username) {
            merged.entry(scan.scan_id.clone()).or_insert(scan);
        }
    }

    merged.into_values().collect()
}
/// Resolves a single finished scan by ID, preferring the in-memory registry over disk.
///
/// # Errors
/// - `scan_not_finished`: the scan is registered in memory but still running.
/// - `scan_not_found`: the scan is in neither the registry nor the scans directory.
async fn load_finished_scan(state: &AppState, scan_id: &ScanId) -> Result<PersistedScan, ApiError> {
    let Some(live) = state.get_scan(scan_id).await else {
        let persisted = if let Some(dir) = &state.scans_dir {
            crate::persist::load(dir, scan_id).await
        } else {
            None
        };
        return persisted
            .ok_or_else(|| ApiError::not_found("scan_not_found", "no scan with that ID"));
    };

    match live.finished().await {
        Some(finished) => Ok(PersistedScan::from_finished(
            scan_id.clone(),
            live.username().to_owned(),
            live.site_count(),
            live.created_at_ms(),
            finished,
        )),
        None => Err(ApiError::bad_request(
            "scan_not_finished",
            "scan is still running",
        )),
    }
}
/// Boxed, type-erased SSE event stream returned by `stream_scan`, so that live
/// (`scan_event_stream`) and persisted (`persisted_event_stream`) scans share one response type.
type SseStream = std::pin::Pin<Box<dyn Stream<Item = Result<Event, Infallible>> + Send>>;
pub(super) async fn retry_site(
State(state): State<AppState>,
AxumPath(id): AxumPath<String>,
Json(req): Json<RetryRequest>,
) -> Result<Json<RetryResponse>, ApiError> {
let scan_id = ScanId::from(id);
let username_raw: String = if let Some(handle) = state.get_scan(&scan_id).await {
handle.username().to_owned()
} else if let Some(dir) = &state.scans_dir {
if let Some(ps) = crate::persist::load(dir, &scan_id).await {
ps.username
} else {
return Err(ApiError::not_found(
"scan_not_found",
"no scan with that ID",
));
}
} else {
return Err(ApiError::not_found(
"scan_not_found",
"no scan with that ID",
));
};
let site = state
.sites
.iter()
.find(|s| s.name.eq_ignore_ascii_case(&req.site))
.cloned()
.ok_or_else(|| {
ApiError::bad_request("site_not_in_catalog", "site not in current catalog")
})?;
let username = Username::new(username_raw.clone())
.map_err(|e| ApiError::bad_request("invalid_username", e.to_string()))?;
let new_outcome = state.client.check(&site, &username).await;
if let Some(handle) = state.get_scan(&scan_id).await {
handle.replace_outcome(new_outcome.clone()).await;
if let (Some(finished), Some(dir)) = (handle.finished().await, &state.scans_dir) {
let existing_context = crate::persist::load(dir, &scan_id)
.await
.and_then(|scan| scan.request_context);
let mut snap = crate::persist::PersistedScan::from_finished(
scan_id.clone(),
handle.username().to_owned(),
handle.site_count(),
handle.created_at_ms(),
finished,
);
snap.request_context = existing_context;
if let Err(err) = crate::persist::save(dir, &snap).await {
tracing::warn!(error = %err, scan_id = %scan_id, "failed to re-persist scan");
}
}
} else if let Some(dir) = &state.scans_dir {
if let Some(mut ps) = crate::persist::load(dir, &scan_id).await {
if let Some(slot) = ps.outcomes.iter_mut().find(|o| o.site == new_outcome.site) {
*slot = new_outcome.clone();
} else {
ps.outcomes.push(new_outcome.clone());
}
ps.refresh_derived_fields();
if let Err(err) = crate::persist::save(dir, &ps).await {
tracing::warn!(error = %err, scan_id = %scan_id, "failed to patch persisted scan");
}
}
}
Ok(Json(RetryResponse {
outcome: new_outcome,
}))
}
pub(super) async fn stream_scan(
State(state): State<AppState>,
AxumPath(id): AxumPath<String>,
) -> Result<Sse<KeepAliveStream<SseStream>>, ApiError> {
let scan_id = ScanId::from(id);
if let Some(scan) = state.get_scan(&scan_id).await {
let stream: SseStream = Box::pin(scan_event_stream(scan));
return Ok(Sse::new(stream).keep_alive(KeepAlive::new()));
}
if let Some(dir) = &state.scans_dir {
if let Some(ps) = crate::persist::load(dir, &scan_id).await {
let stream: SseStream = Box::pin(persisted_event_stream(ps));
return Ok(Sse::new(stream).keep_alive(KeepAlive::new()));
}
}
Err(ApiError::not_found(
"scan_not_found",
"no scan with that ID",
))
}
/// Replays a persisted scan as an SSE stream.
///
/// Emits, in order:
/// 1. A `start` event.
/// 2. One `outcome` event per stored outcome.
/// 3. A final `done` event carrying the finished result.
///
/// Events whose JSON encoding fails are replaced by `Event::default()`.
fn persisted_event_stream(
    ps: crate::persist::PersistedScan,
) -> impl Stream<Item = Result<Event, Infallible>> + Send {
    let crate::persist::PersistedScan {
        username,
        summary,
        outcomes,
        identity_clusters,
        elapsed_ms,
        ..
    } = ps;
    let finished = crate::scan::FinishedScan {
        summary,
        outcomes: outcomes.clone(),
        identity_clusters,
        elapsed_ms,
    };
    stream! {
        let start = Event::default()
            .event("start")
            .json_data(StartEvent { username })
            .unwrap_or_default();
        yield Ok(start);
        for outcome in outcomes.iter() {
            yield Ok(outcome_event(outcome));
        }
        let done = Event::default()
            .event("done")
            .json_data(&finished)
            .unwrap_or_default();
        yield Ok(done);
    }
}
/// Builds the live SSE stream for a scan registered in memory.
///
/// The stream proceeds in three phases:
/// 1. Emit a `start` event and replay every outcome recorded so far.
/// 2. If the scan has not finished, forward new outcomes as its broadcast channel announces
///    them, until the scan signals completion or the channel closes.
/// 3. Flush any outcomes not yet forwarded, then emit `done` if a finished result exists.
///
/// `last_index` counts how many outcomes (a prefix of the outcome list) have already been
/// emitted. All slicing is clamped to the current snapshot, and `last_index` never moves
/// backwards. Stale or out-of-order indices therefore emit nothing instead of panicking. Such
/// indices include the retained backlog a broadcast receiver delivers after a `Lagged` error,
/// or an index for an earlier slot.
fn scan_event_stream(scan: ScanHandle) -> impl Stream<Item = Result<Event, Infallible>> {
    stream! {
        yield Ok(Event::default()
            .event("start")
            .json_data(StartEvent { username: scan.username().to_owned() })
            .unwrap_or_default());
        let history = scan.outcomes_snapshot().await;
        let mut last_index = history.len();
        for outcome in &history {
            yield Ok(outcome_event(outcome));
        }
        if scan.finished().await.is_none() {
            let mut rx = scan.subscribe();
            loop {
                tokio::select! {
                    biased;
                    () = scan.wait_done() => break,
                    recv = rx.recv() => match recv {
                        Ok(idx) => {
                            let snap = scan.outcomes_snapshot().await;
                            // Forward everything up to and including `idx`, clamped to the
                            // snapshot; an index at or behind `last_index` emits nothing.
                            let end = idx.saturating_add(1).min(snap.len());
                            if end > last_index {
                                for outcome in &snap[last_index..end] {
                                    yield Ok(outcome_event(outcome));
                                }
                                last_index = end;
                            }
                        }
                        Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
                            // Missed notifications: catch up from the snapshot directly.
                            let snap = scan.outcomes_snapshot().await;
                            for outcome in snap.get(last_index..).unwrap_or_default() {
                                yield Ok(outcome_event(outcome));
                            }
                            last_index = last_index.max(snap.len());
                        }
                        Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
                    }
                }
            }
        }
        // Outcomes recorded after the last notification we handled (or skipped by `biased`).
        let final_snap = scan.outcomes_snapshot().await;
        for outcome in final_snap.get(last_index..).unwrap_or_default() {
            yield Ok(outcome_event(outcome));
        }
        if let Some(finished) = scan.finished().await {
            yield Ok(Event::default()
                .event("done")
                .json_data(&finished)
                .unwrap_or_default());
        }
    }
}
/// Wraps one check outcome as an SSE `outcome` event with a JSON payload.
///
/// Falls back to `Event::default()` if the outcome cannot be serialised.
fn outcome_event(outcome: &adler_core::CheckOutcome) -> Event {
    let event = Event::default().event("outcome");
    event.json_data(outcome).unwrap_or_default()
}