use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::Arc;
use std::time::Duration;
use adler_core::{ExecutorOptions, Site, Username};
use async_stream::stream;
use axum::Json;
use axum::extract::{Path as AxumPath, State};
use axum::response::sse::{Event, KeepAlive, KeepAliveStream, Sse};
use futures::Stream;
use super::dto::{
AccessSummary, DisabledSiteSummary, Health, RefilterRequest, RefilterResponse, RetryRequest,
RetryResponse, ScanListEntry, ScanSnapshot, SessionName, SiteSummary, SitesResponse,
StartEvent, StartScanRequest, StartScanResponse,
};
use super::error::ApiError;
use super::filter::filter_catalog;
use crate::scan::{ScanHandle, ScanId};
use crate::state::AppState;
/// Liveness probe: always reports `ok: true` along with the crate version baked in at compile time.
pub(super) async fn health() -> Json<Health> {
    let version = env!("CARGO_PKG_VERSION");
    let body = Health { ok: true, version };
    Json(body)
}
pub(super) async fn list_sites(State(state): State<AppState>) -> Json<SitesResponse> {
let sites = state
.sites
.iter()
.map(SiteSummary::from)
.collect::<Vec<_>>();
let disabled = state
.catalog
.iter()
.filter(|s| s.disabled)
.map(DisabledSiteSummary::from)
.collect::<Vec<_>>();
Json(SitesResponse { sites, disabled })
}
/// Reports the client's egress summary and the names of its configured sessions.
pub(super) async fn list_access(State(state): State<AppState>) -> Json<AccessSummary> {
    let client = &state.client;
    let egress = client.egress_summary();
    let mut sessions = Vec::new();
    for name in client.session_names() {
        sessions.push(SessionName { name });
    }
    Json(AccessSummary {
        egress,
        sessions: sessions.into_iter().collect(),
    })
}
/// Lists every known scan, newest first.
///
/// In-memory scans (running or finished) take precedence. When `scans_dir` is configured,
/// persisted scans that are not in memory are added as `"finished"` entries.
///
/// For finished scans, `elapsed_ms` is the duration recorded in the scan's `FinishedScan`.
/// That is the same value `get_scan` returns and the one that gets persisted, so a scan
/// lists the same elapsed time before and after a restart. Only running scans use the
/// live `handle.elapsed()` clock.
pub(super) async fn list_scans(State(state): State<AppState>) -> Json<Vec<ScanListEntry>> {
    // Clone the handles out under the read lock so the lock is not held across the
    // `finished().await` calls below.
    let handles: Vec<(ScanId, ScanHandle)> = {
        let scans = state.scans.read().await;
        scans
            .iter()
            .map(|(id, h)| (id.clone(), h.clone()))
            .collect()
    };
    let mut by_id: HashMap<ScanId, ScanListEntry> = HashMap::with_capacity(handles.len());
    for (id, handle) in handles {
        let (status, elapsed_ms, summary) = match handle.finished().await {
            Some(f) => ("finished", f.elapsed_ms, Some(f.summary)),
            None => (
                "running",
                u64::try_from(handle.elapsed().as_millis()).unwrap_or(u64::MAX),
                None,
            ),
        };
        by_id.insert(
            id.clone(),
            ScanListEntry {
                scan_id: id,
                username: handle.username().to_owned(),
                site_count: handle.site_count(),
                started_at_ms: handle.created_at_ms(),
                elapsed_ms,
                status,
                summary,
            },
        );
    }
    if let Some(dir) = &state.scans_dir {
        for ps in crate::persist::load_all(dir).await {
            // In-memory entries win; only fill gaps from disk.
            by_id
                .entry(ps.scan_id.clone())
                .or_insert_with(|| ScanListEntry {
                    scan_id: ps.scan_id,
                    username: ps.username,
                    site_count: ps.site_count,
                    started_at_ms: ps.created_at_ms,
                    elapsed_ms: ps.elapsed_ms,
                    status: "finished",
                    summary: Some(ps.summary),
                });
        }
    }
    let mut entries: Vec<ScanListEntry> = by_id.into_values().collect();
    entries.sort_by_key(|e| std::cmp::Reverse(e.started_at_ms));
    Json(entries)
}
/// Starts a new scan for `req.username` over the catalog sites matching the request filter.
///
/// # Errors
/// - `invalid_username` if the username fails `Username::new` validation.
/// - `empty_site_filter` if no enabled site matches; any disabled sites that would have
///   matched are attached to the error.
/// - `unknown_egress` if `req.egress_names` names an egress the client does not know.
///
/// On success the scan is registered in `state`, spawned, and its ID returned.
pub(super) async fn start_scan(
    State(state): State<AppState>,
    Json(req): Json<StartScanRequest>,
) -> Result<Json<StartScanResponse>, ApiError> {
    let username = match Username::new(req.username.clone()) {
        Ok(valid) => valid,
        Err(e) => return Err(ApiError::bad_request("invalid_username", e.to_string())),
    };

    let sites = filter_catalog(&state.sites, &req);
    if sites.is_empty() {
        let disabled = disabled_matches(&state.catalog, &req);
        let message = empty_filter_message(disabled.is_empty());
        return Err(ApiError::bad_request("empty_site_filter", message)
            .with_disabled_matches(disabled));
    }

    if !req.egress_names.is_empty() {
        let pool: std::collections::HashSet<String> =
            state.client.egress_names().into_iter().collect();
        let unknown: Vec<&str> = req
            .egress_names
            .iter()
            .map(String::as_str)
            .filter(|name| !pool.contains(*name))
            .collect();
        if !unknown.is_empty() {
            return Err(ApiError::bad_request(
                "unknown_egress",
                format!("egress not in pool: {}", unknown.join(", ")),
            ));
        }
    }

    let options = {
        let mut opts = ExecutorOptions::default();
        if let Some(c) = req.concurrency {
            opts = opts.concurrency(c);
        }
        if let Some(d) = req.deadline_secs {
            opts = opts.deadline(Duration::from_secs(d));
        }
        opts
    };

    let id = ScanId::new();
    let site_count = sites.len();
    // NOTE(review): third argument appears to be a buffer capacity floored at 64 — confirm.
    let handle = ScanHandle::new(req.username.clone(), site_count, site_count.max(64));
    state.insert_scan(id.clone(), handle.clone()).await;

    let persist_ctx = state.scans_dir.as_ref().map(|dir| crate::scan::PersistContext {
        scan_id: id.clone(),
        dir: dir.clone(),
    });
    // Without an explicit egress subset the shared client is reused as-is.
    let scan_client: Arc<adler_core::Client> = if req.egress_names.is_empty() {
        Arc::clone(&state.client)
    } else {
        Arc::new(state.client.with_egress_subset(&req.egress_names))
    };

    let task = crate::scan::spawn(
        handle,
        scan_client,
        Arc::from(sites.into_boxed_slice()),
        username,
        options,
        persist_ctx,
    );
    state.register_scan_task(id.clone(), task).await;

    Ok(Json(StartScanResponse {
        scan_id: id,
        username: req.username,
        site_count,
    }))
}
/// Re-scopes a still-running scan to a new site filter.
///
/// Outcomes already recorded by the previous scan are carried over for sites that remain in
/// the new filter. The previous scan is then aborted, and a new scan is started that probes
/// only the sites without a carried outcome. The new scan reuses the previous scan's username.
///
/// Every fallible check (egress names, site filter, username) runs *before* the previous
/// scan is aborted. A rejected request must never leave the caller with a killed scan and
/// no replacement.
///
/// # Errors
/// - `scan_not_found` if no in-memory scan has this ID.
/// - `scan_finished` if the scan already completed.
/// - `unknown_egress`, `empty_site_filter`, `invalid_username`, checked in that order.
pub(super) async fn refilter_scan(
    State(state): State<AppState>,
    AxumPath(id): AxumPath<String>,
    Json(req): Json<RefilterRequest>,
) -> Result<Json<RefilterResponse>, ApiError> {
    let prev_id = ScanId::from(id);
    let prev_handle = state
        .get_scan(&prev_id)
        .await
        .ok_or_else(|| ApiError::not_found("scan_not_found", "no scan with that ID"))?;
    if prev_handle.is_finished_now() {
        return Err(ApiError::bad_request(
            "scan_finished",
            "scan has already finished; start a new one with POST /api/scan",
        ));
    }
    if !req.egress_names.is_empty() {
        let known: std::collections::HashSet<String> =
            state.client.egress_names().into_iter().collect();
        let unknown: Vec<&str> = req
            .egress_names
            .iter()
            .map(String::as_str)
            .filter(|n| !known.contains(*n))
            .collect();
        if !unknown.is_empty() {
            return Err(ApiError::bad_request(
                "unknown_egress",
                format!("egress not in pool: {}", unknown.join(", ")),
            ));
        }
    }
    let start_shape = StartScanRequest::from(&req);
    let new_sites = filter_catalog(&state.sites, &start_shape);
    if new_sites.is_empty() {
        let disabled = disabled_matches(&state.catalog, &start_shape);
        return Err(ApiError::bad_request(
            "empty_site_filter",
            empty_filter_message(disabled.is_empty()),
        )
        .with_disabled_matches(disabled));
    }
    // Validate the username before aborting anything. Previously this ran after
    // `abort_scan`, so a failure here destroyed the running scan and started nothing.
    let username_str = prev_handle.username().to_owned();
    let username = Username::new(username_str.clone())
        .map_err(|e| ApiError::bad_request("invalid_username", e.to_string()))?;
    let mut options = ExecutorOptions::default();
    if let Some(c) = req.concurrency {
        options = options.concurrency(c);
    }
    if let Some(d) = req.deadline_secs {
        options = options.deadline(Duration::from_secs(d));
    }

    // Keep only previous outcomes for sites that survive the new filter; probe the rest.
    let prev_outcomes = prev_handle.outcomes_snapshot().await;
    let new_site_names: std::collections::HashSet<String> =
        new_sites.iter().map(|s| s.name.clone()).collect();
    let carried: Vec<adler_core::CheckOutcome> = prev_outcomes
        .into_iter()
        .filter(|o| new_site_names.contains(&o.site))
        .collect();
    let carried_names: std::collections::HashSet<String> =
        carried.iter().map(|o| o.site.clone()).collect();
    let sites_to_probe: Vec<Site> = new_sites
        .iter()
        .filter(|s| !carried_names.contains(&s.name))
        .cloned()
        .collect();

    // Point of no return: nothing below can fail.
    state.abort_scan(&prev_id).await;

    let id = ScanId::new();
    let site_count = new_sites.len();
    let carried_outcomes = carried.len();
    let handle = ScanHandle::new(username_str, site_count, site_count.max(64));
    state.insert_scan(id.clone(), handle.clone()).await;
    handle.extend_outcomes(carried).await;
    let persist_ctx = state
        .scans_dir
        .as_ref()
        .map(|dir| crate::scan::PersistContext {
            scan_id: id.clone(),
            dir: dir.clone(),
        });
    let scan_client: Arc<adler_core::Client> = if req.egress_names.is_empty() {
        state.client.clone()
    } else {
        Arc::new(state.client.with_egress_subset(&req.egress_names))
    };
    let task = crate::scan::spawn(
        handle,
        scan_client,
        Arc::from(sites_to_probe.into_boxed_slice()),
        username,
        options,
        persist_ctx,
    );
    state.register_scan_task(id.clone(), task).await;
    Ok(Json(RefilterResponse {
        scan_id: id,
        derived_from: prev_id,
        carried_outcomes,
        site_count,
    }))
}
/// Converts the disabled catalog entries matching `req` (per `super::filter::disabled_matches`)
/// into API summaries, so an empty-filter error can say what was filtered out.
fn disabled_matches(catalog: &[Site], req: &StartScanRequest) -> Vec<DisabledSiteSummary> {
    let matched = super::filter::disabled_matches(catalog, req);
    let mut summaries = Vec::new();
    for entry in matched.iter() {
        summaries.push(DisabledSiteSummary::from(entry));
    }
    summaries
}
/// Picks the human-readable message for an empty site filter.
///
/// `disabled_empty` is true when no disabled site matched either. In that case the filter
/// matched nothing at all. Otherwise only disabled sites matched.
fn empty_filter_message(disabled_empty: bool) -> &'static str {
    match disabled_empty {
        true => "no sites match the requested filter",
        false => "no enabled sites match the requested filter",
    }
}
/// Returns a snapshot of one scan.
///
/// An in-memory scan yields `Finished` (with its `FinishedScan`) or `Running` (with the
/// elapsed time and outcomes recorded so far). Otherwise, if `scans_dir` is configured, the
/// persisted copy is returned as `Finished`.
///
/// # Errors
/// `scan_not_found` when the ID is neither in memory nor on disk.
pub(super) async fn get_scan(
    State(state): State<AppState>,
    AxumPath(id): AxumPath<String>,
) -> Result<Json<ScanSnapshot>, ApiError> {
    let scan_id = ScanId::from(id);

    if let Some(live) = state.get_scan(&scan_id).await {
        let snapshot = match live.finished().await {
            Some(finished) => ScanSnapshot::Finished {
                username: live.username().to_owned(),
                site_count: live.site_count(),
                finished,
            },
            None => {
                let username = live.username().to_owned();
                let site_count = live.site_count();
                let elapsed_ms = u64::try_from(live.elapsed().as_millis()).unwrap_or(u64::MAX);
                let partial = live.outcomes_snapshot().await;
                ScanSnapshot::Running {
                    username,
                    site_count,
                    elapsed_ms,
                    partial,
                }
            }
        };
        return Ok(Json(snapshot));
    }

    let archived = match &state.scans_dir {
        Some(dir) => crate::persist::load(dir, &scan_id).await,
        None => None,
    };
    match archived {
        Some(ps) => Ok(Json(ScanSnapshot::Finished {
            username: ps.username,
            site_count: ps.site_count,
            finished: crate::scan::FinishedScan {
                summary: ps.summary,
                outcomes: ps.outcomes,
                elapsed_ms: ps.elapsed_ms,
            },
        })),
        None => Err(ApiError::not_found("scan_not_found", "no scan with that ID")),
    }
}
/// Boxed, type-erased SSE event stream. `stream_scan` returns either the live stream or the
/// persisted-scan stream; erasing both to this one type lets either branch be returned.
type SseStream = std::pin::Pin<Box<dyn Stream<Item = Result<Event, Infallible>> + Send>>;
/// Re-runs the check for a single site of an existing scan and records the new outcome.
///
/// The username comes from the in-memory scan, or else from its persisted copy. After the
/// check:
/// - A live scan has its outcome replaced. If that scan has finished and `scans_dir` is
///   configured, it is re-persisted.
/// - A scan that exists only on disk is reloaded, patched (replace or append, then
///   recompute the summary), and saved again.
///
/// Persistence failures are logged and do not fail the request. The persisted scan is
/// deliberately reloaded right before patching rather than reusing the earlier load, which
/// keeps the read-modify-write window short compared with the network check.
///
/// # Errors
/// `scan_not_found`, `site_not_in_catalog` (case-insensitive name lookup), `invalid_username`.
pub(super) async fn retry_site(
    State(state): State<AppState>,
    AxumPath(id): AxumPath<String>,
    Json(req): Json<RetryRequest>,
) -> Result<Json<RetryResponse>, ApiError> {
    let scan_id = ScanId::from(id);
    let scan_missing = || ApiError::not_found("scan_not_found", "no scan with that ID");

    let username_raw: String = match state.get_scan(&scan_id).await {
        Some(live) => live.username().to_owned(),
        None => {
            let dir = state.scans_dir.as_ref().ok_or_else(scan_missing)?;
            let archived = crate::persist::load(dir, &scan_id)
                .await
                .ok_or_else(scan_missing)?;
            archived.username
        }
    };

    let site = state
        .sites
        .iter()
        .find(|candidate| candidate.name.eq_ignore_ascii_case(&req.site))
        .cloned()
        .ok_or_else(|| {
            ApiError::bad_request("site_not_in_catalog", "site not in current catalog")
        })?;
    let username = Username::new(username_raw)
        .map_err(|e| ApiError::bad_request("invalid_username", e.to_string()))?;

    let outcome = state.client.check(&site, &username).await;

    match state.get_scan(&scan_id).await {
        Some(live) => {
            live.replace_outcome(outcome.clone()).await;
            let finished = live.finished().await;
            if let (Some(finished), Some(dir)) = (finished, state.scans_dir.as_ref()) {
                let snapshot = crate::persist::PersistedScan::from_finished(
                    scan_id.clone(),
                    live.username().to_owned(),
                    live.site_count(),
                    live.created_at_ms(),
                    finished,
                );
                if let Err(err) = crate::persist::save(dir, &snapshot).await {
                    tracing::warn!(error = %err, scan_id = %scan_id, "failed to re-persist scan");
                }
            }
        }
        None => {
            if let Some(dir) = state.scans_dir.as_ref() {
                if let Some(mut archived) = crate::persist::load(dir, &scan_id).await {
                    match archived.outcomes.iter().position(|o| o.site == outcome.site) {
                        Some(slot) => archived.outcomes[slot] = outcome.clone(),
                        None => archived.outcomes.push(outcome.clone()),
                    }
                    archived.summary = crate::scan::Summary::from_outcomes(&archived.outcomes);
                    if let Err(err) = crate::persist::save(dir, &archived).await {
                        tracing::warn!(error = %err, scan_id = %scan_id, "failed to patch persisted scan");
                    }
                }
            }
        }
    }

    Ok(Json(RetryResponse { outcome }))
}
/// Opens a Server-Sent Events stream for a scan, with keep-alive enabled.
///
/// An in-memory scan gets the live stream (`scan_event_stream`). Otherwise, if `scans_dir`
/// is configured and holds the scan, its persisted outcomes are replayed
/// (`persisted_event_stream`).
///
/// # Errors
/// `scan_not_found` when the ID is neither in memory nor on disk.
pub(super) async fn stream_scan(
    State(state): State<AppState>,
    AxumPath(id): AxumPath<String>,
) -> Result<Sse<KeepAliveStream<SseStream>>, ApiError> {
    let scan_id = ScanId::from(id);
    let events: SseStream = match state.get_scan(&scan_id).await {
        Some(live) => Box::pin(scan_event_stream(live)),
        None => {
            let archived = match &state.scans_dir {
                Some(dir) => crate::persist::load(dir, &scan_id).await,
                None => None,
            };
            match archived {
                Some(ps) => Box::pin(persisted_event_stream(ps)),
                None => {
                    return Err(ApiError::not_found(
                        "scan_not_found",
                        "no scan with that ID",
                    ))
                }
            }
        }
    };
    Ok(Sse::new(events).keep_alive(KeepAlive::new()))
}
/// Replays a persisted (finished) scan as SSE events.
///
/// Emits `start`, then one `outcome` event per stored outcome, then `done` carrying the
/// `FinishedScan`. Events whose JSON serialization fails fall back to `Event::default()`.
///
/// The outcomes are moved into `finished` once and streamed by reference from there. This
/// avoids the full copy of every outcome that a separate clone would cost.
fn persisted_event_stream(
    ps: crate::persist::PersistedScan,
) -> impl Stream<Item = Result<Event, Infallible>> + Send {
    let username = ps.username;
    let finished = crate::scan::FinishedScan {
        summary: ps.summary,
        outcomes: ps.outcomes,
        elapsed_ms: ps.elapsed_ms,
    };
    stream! {
        yield Ok(Event::default()
            .event("start")
            .json_data(StartEvent { username })
            .unwrap_or_default());
        for o in &finished.outcomes {
            yield Ok(outcome_event(o));
        }
        yield Ok(Event::default()
            .event("done")
            .json_data(&finished)
            .unwrap_or_default());
    }
}
/// Live SSE stream for an in-memory scan.
///
/// Emits `start`, then replays all outcomes recorded so far. If the scan is still running,
/// it then forwards newly recorded outcomes whenever the scan's broadcast channel signals.
/// Once `wait_done` resolves or the channel closes, it drains any remaining outcomes and
/// emits `done` if the scan has a `FinishedScan`.
///
/// `last_index` counts the outcomes already emitted and only ever moves forward. Every
/// notification, regular or `Lagged`, is handled the same way: emit `snapshot[last_index..]`.
/// The old code sliced up to the notified index and set `last_index = idx + 1`. After a lag,
/// tokio's broadcast receiver resumes at the oldest retained message, so `idx` could be
/// below `last_index`. That made the range inverted (a panic) or moved `last_index`
/// backwards (duplicate events). Slicing with `get` also keeps an empty or shorter
/// snapshot from panicking.
fn scan_event_stream(scan: ScanHandle) -> impl Stream<Item = Result<Event, Infallible>> + Send {
    stream! {
        yield Ok(Event::default()
            .event("start")
            .json_data(StartEvent { username: scan.username().to_owned() })
            .unwrap_or_default());
        let history = scan.outcomes_snapshot().await;
        let mut last_index = history.len();
        for outcome in &history {
            yield Ok(outcome_event(outcome));
        }
        if scan.finished().await.is_none() {
            let mut rx = scan.subscribe();
            loop {
                tokio::select! {
                    biased;
                    () = scan.wait_done() => break,
                    recv = rx.recv() => match recv {
                        // The payload is not needed: the snapshot is the source of truth,
                        // and anything past `last_index` has not been emitted yet.
                        Ok(_) | Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
                            let snap = scan.outcomes_snapshot().await;
                            for outcome in snap.get(last_index..).unwrap_or_default() {
                                yield Ok(outcome_event(outcome));
                            }
                            last_index = last_index.max(snap.len());
                        }
                        Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
                    }
                }
            }
        }
        // Pick up anything recorded between the last notification and completion.
        let final_snap = scan.outcomes_snapshot().await;
        for outcome in final_snap.get(last_index..).unwrap_or_default() {
            yield Ok(outcome_event(outcome));
        }
        if let Some(finished) = scan.finished().await {
            yield Ok(Event::default()
                .event("done")
                .json_data(&finished)
                .unwrap_or_default());
        }
    }
}
/// Builds an `outcome` SSE event whose data is the JSON-serialized check outcome.
/// If serialization fails, an empty default event is returned instead.
fn outcome_event(outcome: &adler_core::CheckOutcome) -> Event {
    let tagged = Event::default().event("outcome");
    match tagged.json_data(outcome) {
        Ok(event) => event,
        Err(_) => Event::default(),
    }
}