use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::Arc;
use std::time::Duration;
use adler_core::{CheckOutcome, ExecutorOptions, Site, Username};
use async_stream::stream;
use axum::Json;
use axum::Router;
use axum::extract::{Path as AxumPath, State};
use axum::http::StatusCode;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use futures::Stream;
use serde::{Deserialize, Serialize};
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use crate::scan::{FinishedScan, ScanHandle, ScanId};
use crate::state::AppState;
/// Builds the HTTP API router.
///
/// Endpoints:
/// - `GET  /api/health`          – liveness probe
/// - `GET  /api/sites`           – the site catalog
/// - `GET  /api/scans`           – known scans (live and persisted)
/// - `POST /api/scan`            – start a scan
/// - `GET  /api/scan/:id`        – poll one scan
/// - `GET  /api/scan/:id/stream` – server-sent events for one scan
/// - `POST /api/scan/:id/retry`  – re-check a single site of a scan
///
/// Every route gets a fully permissive CORS layer (any origin, method and
/// header) plus HTTP request tracing.
pub fn router(state: AppState) -> Router {
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);
    let trace = TraceLayer::new_for_http();

    Router::new()
        // Catalog and status.
        .route("/api/health", get(health))
        .route("/api/sites", get(list_sites))
        .route("/api/scans", get(list_scans))
        // Scan lifecycle.
        .route("/api/scan", post(start_scan))
        .route("/api/scan/:id", get(get_scan))
        .route("/api/scan/:id/stream", get(stream_scan))
        .route("/api/scan/:id/retry", post(retry_site))
        // Same order as before: CORS first, then tracing (the later `.layer`
        // call wraps the earlier one).
        .layer(cors)
        .layer(trace)
        .with_state(state)
}
/// JSON body returned by `GET /api/health`.
#[derive(Serialize)]
struct Health {
    /// Always `true` when the handler runs.
    ok: bool,
    /// Crate version captured at compile time from `CARGO_PKG_VERSION`.
    version: &'static str,
}

/// Liveness probe: reports `ok: true` together with the crate version.
async fn health() -> Json<Health> {
    const VERSION: &str = env!("CARGO_PKG_VERSION");
    let body = Health {
        ok: true,
        version: VERSION,
    };
    Json(body)
}
#[derive(Serialize)]
struct SiteSummary {
name: String,
url: String,
tags: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
popularity: Option<u32>,
}
impl From<&Site> for SiteSummary {
fn from(s: &Site) -> Self {
Self {
name: s.name.clone(),
url: s.url.as_str().to_owned(),
tags: s.tags.clone(),
popularity: s.popularity,
}
}
}
async fn list_sites(State(state): State<AppState>) -> Json<Vec<SiteSummary>> {
Json(state.sites.iter().map(SiteSummary::from).collect())
}
/// One row of the `GET /api/scans` response, built either from a live scan
/// handle or from a persisted scan (see `list_scans`).
#[derive(Serialize)]
struct ScanListEntry {
scan_id: ScanId,
username: String,
site_count: usize,
/// Creation timestamp (`created_at_ms`); the list is sorted on this field,
/// largest first.
started_at_ms: u64,
elapsed_ms: u64,
/// Either `"running"` or `"finished"`.
status: &'static str,
/// Set only for finished scans; omitted from the JSON when absent.
#[serde(skip_serializing_if = "Option::is_none")]
summary: Option<crate::scan::Summary>,
}
/// `GET /api/scans`: every known scan, newest first.
///
/// Live (in-memory) scans are listed first. When a `scans_dir` is configured,
/// persisted scans fill in any IDs not already present, so a live handle
/// always takes precedence over its on-disk copy. Persisted entries are always
/// reported as `"finished"`.
async fn list_scans(State(state): State<AppState>) -> Json<Vec<ScanListEntry>> {
    // Copy the handles out in a single statement so the read guard is released
    // before any of the awaits below.
    let live: Vec<(ScanId, ScanHandle)> = state
        .scans
        .read()
        .await
        .iter()
        .map(|(scan_id, handle)| (scan_id.clone(), handle.clone()))
        .collect();

    let mut entries_by_id: HashMap<ScanId, ScanListEntry> = HashMap::with_capacity(live.len());

    for (scan_id, handle) in live {
        let finished = handle.finished().await;
        let status = match &finished {
            Some(_) => "finished",
            None => "running",
        };
        let entry = ScanListEntry {
            scan_id: scan_id.clone(),
            username: handle.username().to_owned(),
            site_count: handle.site_count(),
            started_at_ms: handle.created_at_ms(),
            elapsed_ms: u64::try_from(handle.elapsed().as_millis()).unwrap_or(u64::MAX),
            status,
            summary: finished.map(|done| done.summary),
        };
        entries_by_id.insert(scan_id, entry);
    }

    if let Some(dir) = state.scans_dir.as_ref() {
        for persisted in crate::persist::load_all(dir).await {
            // Never overwrite an entry that came from a live handle.
            entries_by_id
                .entry(persisted.scan_id.clone())
                .or_insert_with(|| ScanListEntry {
                    scan_id: persisted.scan_id,
                    username: persisted.username,
                    site_count: persisted.site_count,
                    started_at_ms: persisted.created_at_ms,
                    elapsed_ms: persisted.elapsed_ms,
                    status: "finished",
                    summary: Some(persisted.summary),
                });
        }
    }

    let mut entries: Vec<ScanListEntry> = entries_by_id.into_values().collect();
    // Stable sort, newest (largest start time) first.
    entries.sort_by(|a, b| b.started_at_ms.cmp(&a.started_at_ms));
    Json(entries)
}
/// JSON body of `POST /api/scan`.
///
/// Only `username` is required; every other field falls back to its default
/// when omitted. The site-selection fields are applied by `filter_catalog`.
#[derive(Debug, Deserialize, Default)]
struct StartScanRequest {
/// Username to check; validated with `Username::new` before the scan starts.
username: String,
/// If non-empty, the lower-cased site name must contain one of these
/// (lower-cased) substrings.
#[serde(default)]
only: Vec<String>,
/// Sites whose lower-cased name contains any of these (lower-cased)
/// substrings are dropped.
#[serde(default)]
exclude: Vec<String>,
/// If non-empty, a site must carry at least one of these tags (exact match).
#[serde(default)]
tag: Vec<String>,
/// Sites carrying any of these tags (exact match) are dropped.
#[serde(default)]
exclude_tag: Vec<String>,
/// Keep only sites with a popularity rank `<= top`, ordered by rank.
#[serde(default)]
top: Option<u32>,
/// When `false` (the default), sites tagged `"nsfw"` are dropped.
#[serde(default)]
nsfw: bool,
/// Passed to `ExecutorOptions::concurrency` when set.
#[serde(default)]
concurrency: Option<std::num::NonZeroUsize>,
/// Passed to `ExecutorOptions::deadline` as a `Duration` of this many
/// seconds when set.
#[serde(default)]
deadline_secs: Option<u64>,
}
/// Response of `POST /api/scan`.
#[derive(Serialize)]
struct StartScanResponse {
/// ID for `GET /api/scan/:id` and `GET /api/scan/:id/stream`.
scan_id: ScanId,
username: String,
/// Number of catalog sites selected for this scan.
site_count: usize,
}
/// Selects the catalog sites a scan request asks for.
///
/// A site is kept only if it passes every filter:
/// - `only`: when non-empty, the lower-cased site name contains at least one
///   lower-cased entry as a substring;
/// - `exclude`: the lower-cased name contains none of the lower-cased entries;
/// - `tag`: when non-empty, the site carries at least one listed tag (exact,
///   case-sensitive), so untagged sites drop out;
/// - `exclude_tag`: the site carries none of the listed tags;
/// - `nsfw`: unless set, sites tagged exactly `"nsfw"` are dropped.
///
/// With `top: Some(n)`, only sites whose popularity rank is `<= n` survive,
/// sorted by ascending rank. Otherwise catalog order is preserved.
fn filter_catalog(catalog: &[Site], req: &StartScanRequest) -> Vec<Site> {
    /// Lower-cases every entry of a filter list.
    fn lowercased(values: &[String]) -> Vec<String> {
        values.iter().map(|value| value.to_lowercase()).collect()
    }

    let wanted_names = lowercased(&req.only);
    let banned_names = lowercased(&req.exclude);
    let wanted_tags: std::collections::HashSet<&str> =
        req.tag.iter().map(String::as_str).collect();
    let banned_tags: std::collections::HashSet<&str> =
        req.exclude_tag.iter().map(String::as_str).collect();

    let is_selected = |site: &Site| -> bool {
        let name = site.name.to_lowercase();
        (wanted_names.is_empty() || wanted_names.iter().any(|w| name.contains(w.as_str())))
            && !banned_names.iter().any(|b| name.contains(b.as_str()))
            && (wanted_tags.is_empty()
                || site.tags.iter().any(|t| wanted_tags.contains(t.as_str())))
            && !site.tags.iter().any(|t| banned_tags.contains(t.as_str()))
            && (req.nsfw || !site.tags.iter().any(|t| t == "nsfw"))
    };

    let mut selected: Vec<Site> = catalog
        .iter()
        .filter(|&site| is_selected(site))
        .cloned()
        .collect();

    if let Some(limit) = req.top {
        selected.retain(|site| matches!(site.popularity, Some(rank) if rank <= limit));
        selected.sort_by_key(|site| site.popularity.unwrap_or(u32::MAX));
    }
    selected
}
async fn start_scan(
State(state): State<AppState>,
Json(req): Json<StartScanRequest>,
) -> Result<Json<StartScanResponse>, ApiError> {
let username = Username::new(req.username.clone())
.map_err(|e| ApiError::bad_request("invalid_username", e.to_string()))?;
let sites = filter_catalog(&state.sites, &req);
if sites.is_empty() {
return Err(ApiError::bad_request(
"empty_site_filter",
"no sites match the requested filter",
));
}
let mut options = ExecutorOptions::default();
if let Some(c) = req.concurrency {
options = options.concurrency(c);
}
if let Some(d) = req.deadline_secs {
options = options.deadline(Duration::from_secs(d));
}
let id = ScanId::new();
let site_count = sites.len();
let handle = ScanHandle::new(req.username.clone(), site_count, site_count.max(64));
state.insert_scan(id.clone(), handle.clone()).await;
let persist_ctx = state
.scans_dir
.as_ref()
.map(|dir| crate::scan::PersistContext {
scan_id: id.clone(),
dir: dir.clone(),
});
crate::scan::spawn(
handle,
state.client.clone(),
Arc::from(sites.into_boxed_slice()),
username,
options,
persist_ctx,
);
Ok(Json(StartScanResponse {
scan_id: id,
username: req.username,
site_count,
}))
}
/// Body of `GET /api/scan/:id`, internally tagged by a `status` field whose
/// value is `"running"` or `"finished"`.
#[derive(Serialize)]
#[serde(tag = "status", rename_all = "snake_case")]
enum ScanSnapshot {
/// Scan still in progress; `partial` holds the outcomes recorded so far.
Running {
username: String,
site_count: usize,
elapsed_ms: u64,
partial: Vec<adler_core::CheckOutcome>,
},
/// Scan complete; the `FinishedScan` fields (`summary`, `outcomes`,
/// `elapsed_ms`) are flattened into the same JSON object.
Finished {
username: String,
site_count: usize,
#[serde(flatten)]
finished: FinishedScan,
},
}
async fn get_scan(
State(state): State<AppState>,
AxumPath(id): AxumPath<String>,
) -> Result<Json<ScanSnapshot>, ApiError> {
let scan_id = ScanId::from(id);
if let Some(scan) = state.get_scan(&scan_id).await {
return Ok(match scan.finished().await {
Some(finished) => Json(ScanSnapshot::Finished {
username: scan.username().to_owned(),
site_count: scan.site_count(),
finished,
}),
None => Json(ScanSnapshot::Running {
username: scan.username().to_owned(),
site_count: scan.site_count(),
elapsed_ms: u64::try_from(scan.elapsed().as_millis()).unwrap_or(u64::MAX),
partial: scan.outcomes_snapshot().await,
}),
});
}
if let Some(dir) = &state.scans_dir {
if let Some(ps) = crate::persist::load(dir, &scan_id).await {
return Ok(Json(ScanSnapshot::Finished {
username: ps.username,
site_count: ps.site_count,
finished: crate::scan::FinishedScan {
summary: ps.summary,
outcomes: ps.outcomes,
elapsed_ms: ps.elapsed_ms,
},
}));
}
}
Err(ApiError::not_found(
"scan_not_found",
"no scan with that ID",
))
}
/// Boxed, type-erased SSE event stream, so that live and persisted scans can
/// be returned from the same handler (`stream_scan`).
type SseStream = std::pin::Pin<Box<dyn Stream<Item = Result<Event, Infallible>> + Send>>;
/// JSON body of `POST /api/scan/:id/retry`.
#[derive(Debug, Deserialize)]
struct RetryRequest {
/// Catalog site name; matched against the catalog ASCII-case-insensitively.
site: String,
}
/// Response of `POST /api/scan/:id/retry`: the freshly computed outcome.
#[derive(Serialize)]
struct RetryResponse {
outcome: CheckOutcome,
}
/// `POST /api/scan/:id/retry`: re-runs the check for one site of an existing
/// scan and records the new outcome.
///
/// The username comes from the live scan handle or, failing that, from the
/// persisted copy under `scans_dir`. The site is looked up (ASCII
/// case-insensitively) in the *current* catalog. After the check:
/// - live scan: the outcome is replaced on the handle; if the scan has
///   finished and `scans_dir` is set, the finished scan is saved again;
/// - otherwise: the persisted copy is reloaded, the outcome for the same site
///   is replaced (or appended), the summary is recomputed and the scan saved.
///
/// Persistence failures are logged with `tracing::warn!` and do not fail the
/// request.
///
/// # Errors
/// - `404 scan_not_found` if the scan is neither live nor persisted;
/// - `400 site_not_in_catalog` if no catalog site has that name;
/// - `400 invalid_username` if the stored username fails validation.
async fn retry_site(
    State(state): State<AppState>,
    AxumPath(id): AxumPath<String>,
    Json(req): Json<RetryRequest>,
) -> Result<Json<RetryResponse>, ApiError> {
    let scan_id = ScanId::from(id);
    let scan_not_found = || ApiError::not_found("scan_not_found", "no scan with that ID");

    // Resolve the username: live handle first, persisted copy second.
    let username_raw = match state.get_scan(&scan_id).await {
        Some(handle) => handle.username().to_owned(),
        None => {
            let dir = state.scans_dir.as_ref().ok_or_else(scan_not_found)?;
            crate::persist::load(dir, &scan_id)
                .await
                .ok_or_else(scan_not_found)?
                .username
        }
    };

    let site = state
        .sites
        .iter()
        .find(|candidate| candidate.name.eq_ignore_ascii_case(&req.site))
        .cloned()
        .ok_or_else(|| {
            ApiError::bad_request("site_not_in_catalog", "site not in current catalog")
        })?;

    let username = Username::new(username_raw)
        .map_err(|e| ApiError::bad_request("invalid_username", e.to_string()))?;

    let new_outcome = state.client.check(&site, &username).await;

    // NOTE(review): the scan is looked up again (and the persisted copy is
    // reloaded) only after the network check. This presumably keeps the
    // load→save window short for concurrent retries, so keep it that way.
    match state.get_scan(&scan_id).await {
        Some(handle) => {
            handle.replace_outcome(new_outcome.clone()).await;
            if let (Some(finished), Some(dir)) = (handle.finished().await, &state.scans_dir) {
                let snapshot = crate::persist::PersistedScan::from_finished(
                    scan_id.clone(),
                    handle.username().to_owned(),
                    handle.site_count(),
                    handle.created_at_ms(),
                    finished,
                );
                if let Err(err) = crate::persist::save(dir, &snapshot).await {
                    tracing::warn!(error = %err, scan_id = %scan_id, "failed to re-persist scan");
                }
            }
        }
        None => {
            if let Some(dir) = &state.scans_dir {
                if let Some(mut persisted) = crate::persist::load(dir, &scan_id).await {
                    match persisted
                        .outcomes
                        .iter_mut()
                        .find(|o| o.site == new_outcome.site)
                    {
                        Some(slot) => *slot = new_outcome.clone(),
                        None => persisted.outcomes.push(new_outcome.clone()),
                    }
                    persisted.summary = crate::scan::Summary::from_outcomes(&persisted.outcomes);
                    if let Err(err) = crate::persist::save(dir, &persisted).await {
                        tracing::warn!(error = %err, scan_id = %scan_id, "failed to patch persisted scan");
                    }
                }
            }
        }
    }

    Ok(Json(RetryResponse {
        outcome: new_outcome,
    }))
}
/// `GET /api/scan/:id/stream`: server-sent events for one scan.
///
/// A live scan is followed as it progresses (`scan_event_stream`). A scan that
/// exists only on disk is replayed in one go (`persisted_event_stream`). Both
/// variants send keep-alive comments using `KeepAlive::new()` defaults.
///
/// # Errors
/// `404 scan_not_found` when the ID is neither live nor persisted.
async fn stream_scan(
    State(state): State<AppState>,
    AxumPath(id): AxumPath<String>,
) -> Result<Sse<SseStream>, ApiError> {
    /// Wraps an erased event stream into an SSE response with keep-alives.
    fn into_sse(stream: SseStream) -> Sse<SseStream> {
        Sse::new(stream).keep_alive(KeepAlive::new())
    }

    let scan_id = ScanId::from(id);

    if let Some(scan) = state.get_scan(&scan_id).await {
        return Ok(into_sse(Box::pin(scan_event_stream(scan))));
    }

    let persisted = match state.scans_dir.as_ref() {
        Some(dir) => crate::persist::load(dir, &scan_id).await,
        None => None,
    };
    match persisted {
        Some(ps) => Ok(into_sse(Box::pin(persisted_event_stream(ps)))),
        None => Err(ApiError::not_found(
            "scan_not_found",
            "no scan with that ID",
        )),
    }
}
fn persisted_event_stream(
ps: crate::persist::PersistedScan,
) -> impl Stream<Item = Result<Event, Infallible>> + Send {
let username = ps.username.clone();
let outcomes = ps.outcomes.clone();
let finished = crate::scan::FinishedScan {
summary: ps.summary,
outcomes: ps.outcomes,
elapsed_ms: ps.elapsed_ms,
};
stream! {
yield Ok(Event::default()
.event("start")
.json_data(StartEvent { username })
.unwrap_or_default());
for o in &outcomes {
yield Ok(outcome_event(o));
}
yield Ok(Event::default()
.event("done")
.json_data(&finished)
.unwrap_or_default());
}
}
/// Streams a live scan as SSE events.
///
/// Emits a `start` event, then every outcome recorded so far. While the scan
/// is still running, it sends new outcomes as their indices arrive on the
/// scan's broadcast channel. The `biased` select checks `wait_done` before new
/// notifications. Once the scan is done (or the channel closes), it flushes
/// any remaining outcomes and, if the scan reports finished, a `done` event.
///
/// Invariant: `last_index` is the number of outcomes (from the front of the
/// outcome list) already sent. Every slice is clamped to the current snapshot,
/// so a stale notification cannot panic the stream. Stale here means
/// `idx < last_index`, which happens for notifications still queued after a
/// `Lagged` catch-up already sent those outcomes. Nor can a notification that
/// runs ahead of the snapshot, or any notification against an empty snapshot.
fn scan_event_stream(scan: ScanHandle) -> impl Stream<Item = Result<Event, Infallible>> {
    stream! {
        yield Ok(Event::default()
            .event("start")
            .json_data(StartEvent { username: scan.username().to_owned() })
            .unwrap_or_default());

        let history = scan.outcomes_snapshot().await;
        let mut last_index = history.len();
        for outcome in &history {
            yield Ok(outcome_event(outcome));
        }

        if scan.finished().await.is_none() {
            let mut rx = scan.subscribe();
            loop {
                tokio::select! {
                    biased;
                    () = scan.wait_done() => break,
                    recv = rx.recv() => match recv {
                        Ok(idx) => {
                            let snap = scan.outcomes_snapshot().await;
                            // Send everything up to and including `idx`, but never
                            // re-send (`idx < last_index`) or read past the snapshot.
                            let end = idx.saturating_add(1).min(snap.len());
                            if let Some(fresh) = snap.get(last_index..end) {
                                for outcome in fresh {
                                    yield Ok(outcome_event(outcome));
                                }
                                last_index = end;
                            }
                        }
                        Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
                            // Notifications were dropped; catch up from the snapshot.
                            let snap = scan.outcomes_snapshot().await;
                            if let Some(missed) = snap.get(last_index..) {
                                for outcome in missed {
                                    yield Ok(outcome_event(outcome));
                                }
                                last_index = snap.len();
                            }
                        }
                        Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
                    }
                }
            }
        }

        let final_snap = scan.outcomes_snapshot().await;
        if let Some(rest) = final_snap.get(last_index..) {
            for outcome in rest {
                yield Ok(outcome_event(outcome));
            }
        }
        if let Some(finished) = scan.finished().await {
            yield Ok(Event::default()
                .event("done")
                .json_data(&finished)
                .unwrap_or_default());
        }
    }
}
/// Wraps one check outcome as an SSE event named `outcome` with a JSON payload.
///
/// If serialisation fails, `Event::default()` is returned instead of an error.
fn outcome_event(outcome: &adler_core::CheckOutcome) -> Event {
    let named = Event::default().event("outcome");
    named.json_data(outcome).unwrap_or_default()
}

/// JSON payload of the SSE `start` event.
#[derive(Serialize)]
struct StartEvent {
    username: String,
}
/// Error returned by the API handlers.
///
/// The JSON body is `{"error": <code>, "message": <text>}`. The HTTP status
/// travels alongside it but is excluded from the body.
#[derive(Debug, Serialize)]
struct ApiError {
    #[serde(skip)]
    status: StatusCode,
    /// Stable machine-readable code, e.g. `"scan_not_found"`.
    error: &'static str,
    /// Human-readable explanation.
    message: String,
}

impl ApiError {
    /// Common constructor behind the status-specific helpers.
    fn with_status(status: StatusCode, code: &'static str, msg: impl Into<String>) -> Self {
        Self {
            status,
            error: code,
            message: msg.into(),
        }
    }

    /// A `400 Bad Request` error.
    fn bad_request(code: &'static str, msg: impl Into<String>) -> Self {
        Self::with_status(StatusCode::BAD_REQUEST, code, msg)
    }

    /// A `404 Not Found` error.
    fn not_found(code: &'static str, msg: impl Into<String>) -> Self {
        Self::with_status(StatusCode::NOT_FOUND, code, msg)
    }
}

impl IntoResponse for ApiError {
    fn into_response(self) -> Response {
        // `status` is `Copy`, so it is read before `self` moves into the body.
        (self.status, Json(self)).into_response()
    }
}
// In-process handler tests: requests go through the real router via
// `tower::ServiceExt::oneshot`, and wiremock servers stand in for the sites
// being checked.
#[cfg(test)]
mod tests {
use super::*;
use adler_core::{Client, KnownPresent, Signal, UrlTemplate};
use axum::body::{Body, to_bytes};
use axum::http::{Request, header};
use tower::ServiceExt;
use wiremock::matchers::{any, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
// Minimal catalog site with URL template `{base}/{segment}/{username}`.
// Status 200 is the "found" signal and 404 the "not found" signal. It has no
// tags and no popularity rank.
fn site(name: &str, base: &str, segment: &str) -> Site {
Site {
name: name.into(),
url: UrlTemplate::new(format!("{base}/{segment}/{{username}}")).unwrap(),
signals: vec![
Signal::StatusFound { codes: vec![200] },
Signal::StatusNotFound { codes: vec![404] },
],
known_present: None::<KnownPresent>,
known_absent: None,
extract: Vec::new(),
tags: Vec::new(),
request_headers: std::collections::BTreeMap::new(),
regex_check: None,
engine: None,
strip_bad_char: None,
request_method: adler_core::HttpMethod::Get,
request_body: None,
protection: Vec::new(),
disabled: false,
source: None,
popularity: None,
access: adler_core::AccessPolicy::default(),
}
}
// Router over two sites on one mock server: "A" answers 200 for `alice`,
// "B" answers 404. The mock server is returned so callers can keep it alive
// (they bind it to `_mock`).
async fn test_app() -> (Router, MockServer) {
let mock = MockServer::start().await;
Mock::given(any())
.and(path("/a/alice"))
.respond_with(ResponseTemplate::new(200))
.mount(&mock)
.await;
Mock::given(any())
.and(path("/b/alice"))
.respond_with(ResponseTemplate::new(404))
.mount(&mock)
.await;
let sites = vec![site("A", &mock.uri(), "a"), site("B", &mock.uri(), "b")];
let client = Client::builder()
.timeout(Duration::from_secs(2))
.min_request_interval(Duration::ZERO)
.build()
.unwrap();
let state = AppState::new(sites, client, 16);
(router(state), mock)
}
// GET /api/health responds 200 with `ok: true`.
#[tokio::test]
async fn health_returns_ok() {
let (app, _mock) = test_app().await;
let resp = app
.oneshot(
Request::builder()
.uri("/api/health")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = to_bytes(resp.into_body(), 1024).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(v["ok"], true);
}
// GET /api/sites lists both catalog sites in catalog order, with the URL
// still in template form.
#[tokio::test]
async fn list_sites_returns_summary() {
let (app, _mock) = test_app().await;
let resp = app
.oneshot(
Request::builder()
.uri("/api/sites")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = to_bytes(resp.into_body(), 4096).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(v.as_array().unwrap().len(), 2);
assert_eq!(v[0]["name"], "A");
assert!(v[0]["url"].as_str().unwrap().contains("{username}"));
}
// A username with surrounding spaces is rejected with 400 `invalid_username`.
#[tokio::test]
async fn start_scan_rejects_invalid_username() {
let (app, _mock) = test_app().await;
let resp = app
.oneshot(
Request::builder()
.method("POST")
.uri("/api/scan")
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(r#"{"username":" bad "}"#))
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let body = to_bytes(resp.into_body(), 1024).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(v["error"], "invalid_username");
}
// Starts a scan for `alice`, then polls (up to 50 x 100 ms) until it is
// finished. Expects one found ("A") and one not-found ("B").
#[tokio::test]
async fn start_then_poll_finishes_with_expected_counts() {
let (app, _mock) = test_app().await;
let resp = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/api/scan")
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(r#"{"username":"alice"}"#))
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = to_bytes(resp.into_body(), 4096).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
let scan_id = v["scan_id"].as_str().unwrap().to_owned();
assert_eq!(v["site_count"], 2);
for _ in 0..50 {
tokio::time::sleep(Duration::from_millis(100)).await;
let r = app
.clone()
.oneshot(
Request::builder()
.uri(format!("/api/scan/{scan_id}"))
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(r.status(), StatusCode::OK);
let body = to_bytes(r.into_body(), 16384).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
if v["status"] == "finished" {
assert_eq!(v["summary"]["found"], 1);
assert_eq!(v["summary"]["not_found"], 1);
assert_eq!(v["outcomes"].as_array().unwrap().len(), 2);
return;
}
}
panic!("scan did not finish within 5s");
}
// Unknown scan IDs yield 404 `scan_not_found`.
#[tokio::test]
async fn get_scan_404s_on_unknown_id() {
let (app, _mock) = test_app().await;
let resp = app
.oneshot(
Request::builder()
.uri("/api/scan/does-not-exist")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
let body = to_bytes(resp.into_body(), 1024).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(v["error"], "scan_not_found");
}
// Same as `site`, but with the given tags.
fn tagged_site(name: &str, base: &str, segment: &str, tags: &[&str]) -> Site {
let mut s = site(name, base, segment);
s.tags = tags.iter().map(|t| (*t).to_owned()).collect();
s
}
// `only` and `exclude` match case-insensitive substrings of the site name.
#[test]
fn filter_catalog_honours_only_exclude() {
let sites = vec![
site("GitHub", "http://x", "gh"),
site("GitLab", "http://x", "gl"),
site("Bitbucket", "http://x", "bb"),
];
let only = StartScanRequest {
only: vec!["git".into()],
..Default::default()
};
let names: Vec<_> = filter_catalog(&sites, &only)
.into_iter()
.map(|s| s.name)
.collect();
assert_eq!(names, vec!["GitHub", "GitLab"]);
let exclude = StartScanRequest {
exclude: vec!["lab".into()],
..Default::default()
};
let names: Vec<_> = filter_catalog(&sites, &exclude)
.into_iter()
.map(|s| s.name)
.collect();
assert_eq!(names, vec!["GitHub", "Bitbucket"]);
}
// Checks three filter rules:
// - `tag` requires a matching tag, so untagged "D" drops out;
// - sites tagged "nsfw" are dropped unless `nsfw: true`;
// - `exclude_tag` drops sites carrying a listed tag.
#[test]
fn filter_catalog_honours_tags_and_nsfw() {
let sites = vec![
tagged_site("A", "http://x", "a", &["social"]),
tagged_site("B", "http://x", "b", &["dev"]),
tagged_site("C", "http://x", "c", &["social", "nsfw"]),
tagged_site("D", "http://x", "d", &[]),
];
let only_social = StartScanRequest {
tag: vec!["social".into()],
..Default::default()
};
let names: Vec<_> = filter_catalog(&sites, &only_social)
.into_iter()
.map(|s| s.name)
.collect();
assert_eq!(names, vec!["A"]);
let with_nsfw = StartScanRequest {
tag: vec!["social".into()],
nsfw: true,
..Default::default()
};
let names: Vec<_> = filter_catalog(&sites, &with_nsfw)
.into_iter()
.map(|s| s.name)
.collect();
assert_eq!(names, vec!["A", "C"]);
let exclude_dev = StartScanRequest {
exclude_tag: vec!["dev".into()],
..Default::default()
};
let names: Vec<_> = filter_catalog(&sites, &exclude_dev)
.into_iter()
.map(|s| s.name)
.collect();
assert_eq!(names, vec!["A", "D"]);
}
// `top: 2` keeps ranks <= 2, ordered by rank; unranked "D" drops out.
#[test]
fn filter_catalog_top_sorts_by_popularity() {
let mut a = site("A", "http://x", "a");
a.popularity = Some(3);
let mut b = site("B", "http://x", "b");
b.popularity = Some(1);
let mut c = site("C", "http://x", "c");
c.popularity = Some(2);
let d = site("D", "http://x", "d"); let sites = vec![a, b, c, d];
let req = StartScanRequest {
top: Some(2),
..Default::default()
};
let names: Vec<_> = filter_catalog(&sites, &req)
.into_iter()
.map(|s| s.name)
.collect();
assert_eq!(names, vec!["B", "C"]);
}
// A `tag` filter in the request body limits the scan to matching sites.
#[tokio::test]
async fn start_scan_with_tag_filter_only_runs_matching_sites() {
let mock = MockServer::start().await;
Mock::given(any())
.and(path("/a/alice"))
.respond_with(ResponseTemplate::new(200))
.mount(&mock)
.await;
Mock::given(any())
.and(path("/b/alice"))
.respond_with(ResponseTemplate::new(404))
.mount(&mock)
.await;
let sites = vec![
tagged_site("A", &mock.uri(), "a", &["social"]),
tagged_site("B", &mock.uri(), "b", &["dev"]),
];
let client = Client::builder()
.timeout(Duration::from_secs(2))
.min_request_interval(Duration::ZERO)
.build()
.unwrap();
let state = AppState::new(sites, client, 16);
let app = router(state);
let resp = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/api/scan")
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(r#"{"username":"alice","tag":["social"]}"#))
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = to_bytes(resp.into_body(), 4096).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(v["site_count"], 1);
}
// A filter that selects no sites is rejected with 400 `empty_site_filter`.
#[tokio::test]
async fn empty_filter_returns_bad_request() {
let (app, _mock) = test_app().await;
let resp = app
.oneshot(
Request::builder()
.method("POST")
.uri("/api/scan")
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(
r#"{"username":"alice","only":["definitely-not-a-site"]}"#,
))
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let body = to_bytes(resp.into_body(), 1024).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(v["error"], "empty_site_filter");
}
// The mock answers 404 once and 200 afterwards. The initial scan records
// not-found; a retry then flips the outcome to found, and the scan summary
// reflects the change.
#[tokio::test]
async fn retry_flips_outcome_when_response_changes() {
let mock = MockServer::start().await;
Mock::given(any())
.and(path("/a/alice"))
.respond_with(ResponseTemplate::new(404))
.up_to_n_times(1)
.mount(&mock)
.await;
Mock::given(any())
.and(path("/a/alice"))
.respond_with(ResponseTemplate::new(200))
.mount(&mock)
.await;
let sites = vec![site("A", &mock.uri(), "a")];
let client = Client::builder()
.timeout(Duration::from_secs(2))
.min_request_interval(Duration::ZERO)
.build()
.unwrap();
let state = AppState::new(sites, client, 16);
let app = router(state);
let r = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/api/scan")
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(r#"{"username":"alice"}"#))
.unwrap(),
)
.await
.unwrap();
let body = to_bytes(r.into_body(), 4096).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
let scan_id = v["scan_id"].as_str().unwrap().to_owned();
let mut finished = false;
for _ in 0..60 {
tokio::time::sleep(Duration::from_millis(60)).await;
let r = app
.clone()
.oneshot(
Request::builder()
.uri(format!("/api/scan/{scan_id}"))
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
let body = to_bytes(r.into_body(), 8192).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
if v["status"] == "finished" {
assert_eq!(v["summary"]["not_found"], 1);
finished = true;
break;
}
}
assert!(finished, "scan did not finish");
let r = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri(format!("/api/scan/{scan_id}/retry"))
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(r#"{"site":"A"}"#))
.unwrap(),
)
.await
.unwrap();
assert_eq!(r.status(), StatusCode::OK);
let body = to_bytes(r.into_body(), 4096).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(v["outcome"]["site"], "A");
assert_eq!(v["outcome"]["kind"], "found");
let r = app
.oneshot(
Request::builder()
.uri(format!("/api/scan/{scan_id}"))
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
let body = to_bytes(r.into_body(), 16384).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(v["summary"]["found"], 1);
assert_eq!(v["summary"]["not_found"], 0);
}
// Retry returns 404 for an unknown scan ID, and 400 `site_not_in_catalog`
// (not 404) for a site missing from the catalog.
#[tokio::test]
async fn retry_404s_unknown_site_or_scan() {
let (app, _mock) = test_app().await;
let r = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/api/scan/nope/retry")
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(r#"{"site":"A"}"#))
.unwrap(),
)
.await
.unwrap();
assert_eq!(r.status(), StatusCode::NOT_FOUND);
let r = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/api/scan")
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(r#"{"username":"alice"}"#))
.unwrap(),
)
.await
.unwrap();
let body = to_bytes(r.into_body(), 4096).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
let scan_id = v["scan_id"].as_str().unwrap().to_owned();
let r = app
.oneshot(
Request::builder()
.method("POST")
.uri(format!("/api/scan/{scan_id}/retry"))
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(r#"{"site":"NoSuch"}"#))
.unwrap(),
)
.await
.unwrap();
assert_eq!(r.status(), StatusCode::BAD_REQUEST);
let body = to_bytes(r.into_body(), 1024).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(v["error"], "site_not_in_catalog");
}
// Two scans started a few milliseconds apart are listed newest-first. The
// check uses `>=` because their start times may be equal.
#[tokio::test]
async fn list_scans_returns_newest_first() {
let (app, _mock) = test_app().await;
for _ in 0..2 {
let r = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/api/scan")
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(r#"{"username":"alice"}"#))
.unwrap(),
)
.await
.unwrap();
assert_eq!(r.status(), StatusCode::OK);
tokio::time::sleep(Duration::from_millis(5)).await;
}
let resp = app
.oneshot(
Request::builder()
.uri("/api/scans")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = to_bytes(resp.into_body(), 4096).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
let arr = v.as_array().unwrap();
assert_eq!(arr.len(), 2);
assert!(
arr[0]["started_at_ms"].as_u64() >= arr[1]["started_at_ms"].as_u64(),
"scans must be newest-first",
);
}
}