use crate::database::MonocleDatabase;
use crate::lens::inspect::{
DataRefreshSummary, InspectDataSection, InspectLens, InspectQueryOptions, InspectResult,
};
use crate::server::handler::{WsContext, WsError, WsMethod, WsRequest, WsResult};
use crate::server::op_sink::WsOpSink;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::sync::Arc;
/// Request parameters for the `inspect.query` method.
///
/// Every field except `queries` is optional and is treated as absent when it
/// is missing from the incoming payload.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InspectQueryParams {
/// Query strings handed to the lens. May be empty only when `country` is
/// set; `InspectQueryHandler::validate` rejects the request otherwise.
pub queries: Vec<String>,
/// Optional interpretation applied to every query: `asn`, `prefix`, or
/// `name` (compared case-insensitively). When absent, the handler calls the
/// lens's generic `query` method instead.
#[serde(default)]
pub query_type: Option<String>,
/// Optional list of section names to include. The special name `all`
/// selects every section; names are compared case-insensitively.
#[serde(default)]
pub select: Option<Vec<String>>,
/// Overrides `InspectQueryOptions::max_roas` when set.
#[serde(default)]
pub max_roas: Option<usize>,
/// Overrides `InspectQueryOptions::max_prefixes` when set.
#[serde(default)]
pub max_prefixes: Option<usize>,
/// Overrides `InspectQueryOptions::max_neighbors` when set.
#[serde(default)]
pub max_neighbors: Option<usize>,
/// Overrides `InspectQueryOptions::max_search_results` when set.
#[serde(default)]
pub max_search_results: Option<usize>,
/// Country filter. When set, the handler runs a country query and does not
/// use `queries` or `query_type`.
#[serde(default)]
pub country: Option<String>,
}
/// Progress payload sent through the sink for each reported data source.
///
/// The handlers in this file emit the stages `"refreshed"` and `"skipped"`.
#[derive(Debug, Clone, Serialize)]
pub struct InspectDataRefreshProgress {
/// Short stage label for this progress event.
pub stage: String,
/// Human-readable message copied from the refresh summary entry.
pub message: String,
/// Name of the data source this event refers to; omitted from the
/// serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub source: Option<String>,
/// Count reported by the refresh summary entry, if any; omitted from the
/// serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub count: Option<usize>,
}
/// Final result payload of the `inspect.query` method.
#[derive(Debug, Clone, Serialize)]
pub struct InspectQueryResponse {
/// Copy of the refresh summary's `any_refreshed` flag.
pub data_refreshed: bool,
/// Full refresh summary; only populated when `data_refreshed` is true and
/// omitted from the serialized output otherwise.
#[serde(skip_serializing_if = "Option::is_none")]
pub refresh_summary: Option<DataRefreshSummary>,
/// The lens query result.
pub result: InspectResult,
}
/// Stateless handler type for the `inspect.query` method (see its `WsMethod` impl).
pub struct InspectQueryHandler;
#[async_trait]
impl WsMethod for InspectQueryHandler {
const METHOD: &'static str = "inspect.query";
const IS_STREAMING: bool = true;
type Params = InspectQueryParams;
fn validate(params: &Self::Params) -> WsResult<()> {
if params.queries.is_empty() && params.country.is_none() {
return Err(WsError::invalid_params(
"At least one query or country filter is required",
));
}
if let Some(ref qt) = params.query_type {
match qt.to_lowercase().as_str() {
"asn" | "prefix" | "name" => {}
_ => {
return Err(WsError::invalid_params(format!(
"Invalid query_type: {}. Use 'asn', 'prefix', or 'name'",
qt
)));
}
}
}
if let Some(ref sections) = params.select {
for section in sections {
let s_lower = section.to_lowercase();
if s_lower != "all" && InspectDataSection::from_str(&s_lower).is_none() {
return Err(WsError::invalid_params(format!(
"Invalid section: {}. Available: {}",
section,
InspectDataSection::all_names().join(", ")
)));
}
}
}
Ok(())
}
async fn handle(
ctx: Arc<WsContext>,
_req: WsRequest,
params: Self::Params,
sink: WsOpSink,
) -> WsResult<()> {
let options = build_query_options(¶ms);
let (refresh_summary, result): (DataRefreshSummary, InspectResult) = {
let db = MonocleDatabase::open_in_dir(ctx.data_dir())
.map_err(|e| WsError::internal(format!("Failed to open database: {}", e)))?;
let lens = InspectLens::new(&db, &ctx.config);
let refresh_summary = lens
.ensure_data_available()
.map_err(|e| WsError::operation_failed(format!("Failed to ensure data: {}", e)))?;
let result = if let Some(ref country) = params.country {
lens.query_by_country(country, &options)
.map_err(|e| WsError::operation_failed(e.to_string()))?
} else if let Some(ref query_type) = params.query_type {
match query_type.to_lowercase().as_str() {
"asn" => lens
.query_as_asn(¶ms.queries, &options)
.map_err(|e| WsError::operation_failed(e.to_string()))?,
"prefix" => lens
.query_as_prefix(¶ms.queries, &options)
.map_err(|e| WsError::operation_failed(e.to_string()))?,
"name" => lens
.query_as_name(¶ms.queries, &options)
.map_err(|e| WsError::operation_failed(e.to_string()))?,
_ => {
return Err(WsError::invalid_params(format!(
"Invalid query_type: {}",
query_type
)));
}
}
} else {
lens.query(¶ms.queries, &options)
.map_err(|e| WsError::operation_failed(e.to_string()))?
};
(refresh_summary, result)
};
if refresh_summary.any_refreshed {
for refresh in &refresh_summary.sources {
if refresh.refreshed {
sink.send_progress(InspectDataRefreshProgress {
stage: "refreshed".to_string(),
message: refresh.message.clone(),
source: Some(refresh.source.clone()),
count: refresh.count,
})
.await
.map_err(|e| WsError::internal(e.to_string()))?;
}
}
}
let response = InspectQueryResponse {
data_refreshed: refresh_summary.any_refreshed,
refresh_summary: if refresh_summary.any_refreshed {
Some(refresh_summary)
} else {
None
},
result,
};
sink.send_result(response)
.await
.map_err(|e| WsError::internal(e.to_string()))?;
Ok(())
}
}
/// Translates request parameters into an [`InspectQueryOptions`] value.
///
/// Starts from `InspectQueryOptions::default()` and then:
/// - resolves `select` names (case-insensitively, with `all` expanding to
///   every section) into a set; names that do not resolve are skipped, and
///   an empty resulting set leaves the default selection untouched,
/// - replaces each `max_*` limit for which the request carries a value.
fn build_query_options(params: &InspectQueryParams) -> InspectQueryOptions {
    let mut options = InspectQueryOptions::default();

    if let Some(requested) = &params.select {
        let mut chosen = HashSet::new();
        for name in requested.iter().map(|raw| raw.to_lowercase()) {
            if name == "all" {
                chosen.extend(InspectDataSection::all());
                continue;
            }
            if let Some(section) = InspectDataSection::from_str(&name) {
                chosen.insert(section);
            }
        }
        if !chosen.is_empty() {
            options.select = Some(chosen);
        }
    }

    // Each limit keeps its default unless the request supplied a value.
    options.max_roas = params.max_roas.unwrap_or(options.max_roas);
    options.max_prefixes = params.max_prefixes.unwrap_or(options.max_prefixes);
    options.max_neighbors = params.max_neighbors.unwrap_or(options.max_neighbors);
    options.max_search_results = params
        .max_search_results
        .unwrap_or(options.max_search_results);

    options
}
/// Request parameters for the `inspect.refresh` method.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct InspectRefreshParams {
/// Defaults to `false` when missing from the payload.
/// NOTE(review): the `inspect.refresh` handler in this file binds its
/// params as `_params` and never reads this flag — confirm whether a forced
/// refresh is meant to be supported.
#[serde(default)]
pub force: bool,
}
/// Final result payload of the `inspect.refresh` method.
#[derive(Debug, Clone, Serialize)]
pub struct InspectRefreshResponse {
/// Summary returned by the lens's `ensure_data_available` call.
pub summary: DataRefreshSummary,
}
/// Stateless handler type for the `inspect.refresh` method (see its `WsMethod` impl).
pub struct InspectRefreshHandler;
#[async_trait]
impl WsMethod for InspectRefreshHandler {
    const METHOD: &'static str = "inspect.refresh";
    const IS_STREAMING: bool = true;

    type Params = InspectRefreshParams;

    /// Ensures the lens data is available and reports the outcome.
    ///
    /// Sends one progress event per source in the summary (stage
    /// `"refreshed"` or `"skipped"`), followed by the final
    /// [`InspectRefreshResponse`].
    ///
    /// NOTE(review): the request params (including `force`) are not read
    /// here — confirm whether a forced refresh should be wired through.
    ///
    /// # Errors
    ///
    /// Returns an internal error when the database cannot be opened or the
    /// sink rejects a message, and an operation-failed error when the data
    /// check fails.
    async fn handle(
        ctx: Arc<WsContext>,
        _req: WsRequest,
        _params: Self::Params,
        sink: WsOpSink,
    ) -> WsResult<()> {
        // The database and lens live only inside this block, so both are
        // dropped before any `.await` below.
        let summary: DataRefreshSummary = {
            let database = MonocleDatabase::open_in_dir(ctx.data_dir())
                .map_err(|e| WsError::internal(format!("Failed to open database: {}", e)))?;
            let inspect = InspectLens::new(&database, &ctx.config);
            inspect
                .ensure_data_available()
                .map_err(|e| WsError::operation_failed(format!("Failed to refresh data: {}", e)))?
        };

        for entry in &summary.sources {
            let stage = if entry.refreshed {
                "refreshed"
            } else {
                "skipped"
            };
            let progress = InspectDataRefreshProgress {
                stage: stage.to_string(),
                message: entry.message.clone(),
                source: Some(entry.source.clone()),
                count: entry.count,
            };
            sink.send_progress(progress)
                .await
                .map_err(|e| WsError::internal(e.to_string()))?;
        }

        sink.send_result(InspectRefreshResponse { summary })
            .await
            .map_err(|e| WsError::internal(e.to_string()))?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds params carrying `queries` with every optional field unset.
    fn params_for(queries: &[&str]) -> InspectQueryParams {
        InspectQueryParams {
            queries: queries.iter().map(|q| (*q).to_owned()).collect(),
            query_type: None,
            select: None,
            max_roas: None,
            max_prefixes: None,
            max_neighbors: None,
            max_search_results: None,
            country: None,
        }
    }

    // Only `queries` is required; every other field falls back to `None`.
    #[test]
    fn test_inspect_query_params_default() {
        let params: InspectQueryParams = serde_json::from_str(r#"{"queries": ["13335"]}"#).unwrap();
        assert_eq!(params.queries, vec!["13335"]);
        assert!(params.query_type.is_none());
        assert!(params.select.is_none());
        assert!(params.max_roas.is_none());
    }

    #[test]
    fn test_inspect_query_params_full() {
        let params: InspectQueryParams = serde_json::from_str(
            r#"{
                "queries": ["13335", "1.1.1.0/24"],
                "query_type": "asn",
                "select": ["core", "connectivity"],
                "max_roas": 5,
                "max_neighbors": 10
            }"#,
        )
        .unwrap();
        assert_eq!(params.queries.len(), 2);
        assert_eq!(params.query_type, Some("asn".to_string()));
        assert_eq!(
            params.select,
            Some(vec!["core".to_string(), "connectivity".to_string()])
        );
        assert_eq!(params.max_roas, Some(5));
        assert_eq!(params.max_neighbors, Some(10));
    }

    // No queries and no country filter must be rejected.
    #[test]
    fn test_inspect_query_validation_empty_queries() {
        let params = params_for(&[]);
        let result = InspectQueryHandler::validate(&params);
        assert!(result.is_err());
    }

    // A country filter alone is enough to pass validation.
    #[test]
    fn test_inspect_query_validation_with_country() {
        let params = InspectQueryParams {
            country: Some("US".to_string()),
            ..params_for(&[])
        };
        let result = InspectQueryHandler::validate(&params);
        assert!(result.is_ok());
    }

    #[test]
    fn test_inspect_query_validation_invalid_query_type() {
        let params = InspectQueryParams {
            query_type: Some("invalid".to_string()),
            ..params_for(&["13335"])
        };
        let result = InspectQueryHandler::validate(&params);
        assert!(result.is_err());
    }

    #[test]
    fn test_inspect_query_validation_invalid_section() {
        let params = InspectQueryParams {
            select: Some(vec!["invalid_section".to_string()]),
            ..params_for(&["13335"])
        };
        let result = InspectQueryHandler::validate(&params);
        assert!(result.is_err());
    }

    // With nothing overridden, the options carry the library defaults.
    #[test]
    fn test_build_query_options_defaults() {
        let params = params_for(&["13335"]);
        let options = build_query_options(&params);
        assert!(options.select.is_none());
        assert_eq!(options.max_roas, 10);
        assert_eq!(options.max_prefixes, 10);
        assert_eq!(options.max_neighbors, 5);
        assert_eq!(options.max_search_results, 20);
    }

    #[test]
    fn test_build_query_options_with_select() {
        let params = InspectQueryParams {
            select: Some(vec!["basic".to_string(), "rpki".to_string()]),
            max_roas: Some(100),
            ..params_for(&["13335"])
        };
        let options = build_query_options(&params);
        assert!(options.select.is_some());
        let select = options.select.unwrap();
        assert!(select.contains(&InspectDataSection::Basic));
        assert!(select.contains(&InspectDataSection::Rpki));
        assert_eq!(options.max_roas, 100);
    }

    #[test]
    fn test_inspect_refresh_params_default() {
        let params: InspectRefreshParams = serde_json::from_str(r#"{}"#).unwrap();
        assert!(!params.force);
    }

    #[test]
    fn test_data_refresh_progress_serialization() {
        let progress = InspectDataRefreshProgress {
            stage: "refreshing".to_string(),
            message: "Refreshing RPKI data...".to_string(),
            source: Some("rpki".to_string()),
            count: Some(1000),
        };
        let json = serde_json::to_string(&progress).unwrap();
        assert!(json.contains("\"stage\":\"refreshing\""));
        assert!(json.contains("\"source\":\"rpki\""));
        assert!(json.contains("\"count\":1000"));
    }
}