use axum::{
extract::{Path, Query, State},
response::sse::{Event, KeepAlive, Sse},
};
use futures::stream::{self, Stream};
use std::convert::Infallible;
use std::time::Instant;
use std::sync::Arc;
use crate::AppState;
use super::types::{
StreamDoneEvent, StreamErrorEvent, StreamNodeEvent, StreamStatsEvent, StreamTraverseParams,
TraversalResultItem,
};
/// Number of `node` events emitted between consecutive `stats` progress events.
const STATS_INTERVAL: usize = 100;
/// Stream a graph traversal as Server-Sent Events.
///
/// The traversal starts at `start_node` and is limited by `max_depth` and
/// `limit`. The response is a sequence of `node` events (one per reached
/// node), interleaved with periodic `stats` events and terminated by a single
/// `done` event. If the collection does not exist or is not a graph
/// collection, the stream consists of a single `error` event.
///
/// `relationship_types` is a comma-separated list; entries are trimmed and
/// blank entries are ignored. `algorithm` selects `dfs` (case-insensitive);
/// any other value uses breadth-first traversal.
#[utoipa::path(
    get,
    path = "/collections/{name}/graph/traverse/stream",
    tag = "graph",
    params(
        ("name" = String, Path, description = "Collection name"),
        StreamTraverseParams
    ),
    responses(
        (status = 200, description = "SSE stream of traversal events (node, stats, done, error)")
    )
)]
// Nothing is awaited in the body; the function presumably stays `async`
// because it is registered as an axum handler.
#[allow(clippy::unused_async)]
pub async fn stream_traverse(
    State(state): State<Arc<AppState>>,
    Path(collection): Path<String>,
    Query(params): Query<StreamTraverseParams>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    use velesdb_core::collection::graph::TraversalConfig;

    // Started before the traversal so the reported timings include it.
    let start_time = Instant::now();

    // Drop blank entries (empty parameter, doubled or trailing commas) so they
    // do not end up as an empty-string relationship type in the filter list;
    // such input then behaves like an absent parameter.
    let rel_types: Vec<String> = params
        .relationship_types
        .map(|s| {
            s.split(',')
                .map(str::trim)
                .filter(|t| !t.is_empty())
                .map(str::to_owned)
                .collect()
        })
        .unwrap_or_default();

    // NOTE(review): the traversal runs synchronously and to completion before
    // the first event is sent; consider off-loading it if traversals can be
    // long-running.
    let traversal_result: Result<Vec<TraversalResultItem>, String> =
        match state.db.get_graph_collection(&collection) {
            None => Err(format!(
                "Collection '{collection}' not found or is not a graph collection."
            )),
            Some(coll) => {
                let config = TraversalConfig::with_range(1, params.max_depth)
                    .with_limit(params.limit)
                    .with_rel_types(rel_types);
                // Anything other than "dfs" falls back to breadth-first.
                let raw = match params.algorithm.to_lowercase().as_str() {
                    "dfs" => coll.traverse_dfs(params.start_node, &config),
                    _ => coll.traverse_bfs(params.start_node, &config),
                };
                Ok(raw
                    .into_iter()
                    .map(|r| TraversalResultItem {
                        target_id: r.target_id,
                        depth: r.depth,
                        path: r.path,
                    })
                    .collect())
            }
        };

    let events = build_sse_events(traversal_result, start_time);
    Sse::new(stream::iter(events)).keep_alive(KeepAlive::default())
}
fn build_sse_events(
traversal_result: Result<Vec<TraversalResultItem>, String>,
start_time: Instant,
) -> Vec<Result<Event, Infallible>> {
match traversal_result {
Ok(results) => build_success_events(results, start_time),
Err(e) => build_error_events(e),
}
}
/// Builds the event sequence for a successful traversal.
///
/// Emits one `node` event per result in order, a `stats` event after every
/// `STATS_INTERVAL` nodes, and a final `done` event carrying the total node
/// count, the deepest depth seen and the elapsed time since `start_time`.
///
/// If a payload fails to serialize, the event is still emitted with `{}` as
/// its data.
fn build_success_events(
    results: Vec<TraversalResultItem>,
    start_time: Instant,
) -> Vec<Result<Event, Infallible>> {
    let total = results.len();
    // Exact final size: one `node` per result, one `stats` per completed
    // interval, plus the closing `done` event.
    let mut events: Vec<Result<Event, Infallible>> =
        Vec::with_capacity(total + total / STATS_INTERVAL + 1);
    let mut max_depth: u32 = 0;

    for (index, item) in results.into_iter().enumerate() {
        let visited = index + 1;
        max_depth = max_depth.max(item.depth);

        let node_event = StreamNodeEvent {
            id: item.target_id,
            depth: item.depth,
            path: item.path,
        };
        let node_data = serde_json::to_string(&node_event).unwrap_or_else(|_| "{}".to_string());
        events.push(Ok(Event::default().event("node").data(node_data)));

        // Periodic progress report after every full interval of nodes.
        if visited % STATS_INTERVAL == 0 {
            let stats_event = StreamStatsEvent {
                nodes_visited: visited,
                elapsed_ms: elapsed_ms(start_time),
            };
            let stats_data =
                serde_json::to_string(&stats_event).unwrap_or_else(|_| "{}".to_string());
            events.push(Ok(Event::default().event("stats").data(stats_data)));
        }
    }

    let done_event = StreamDoneEvent {
        total_nodes: total,
        max_depth_reached: max_depth,
        elapsed_ms: elapsed_ms(start_time),
    };
    let done_data = serde_json::to_string(&done_event).unwrap_or_else(|_| "{}".to_string());
    events.push(Ok(Event::default().event("done").data(done_data)));

    events
}
fn build_error_events(error: String) -> Vec<Result<Event, Infallible>> {
let error_event = StreamErrorEvent { error };
let error_data = serde_json::to_string(&error_event).unwrap_or_else(|_| "{}".to_string());
vec![Ok(Event::default().event("error").data(error_data))]
}
/// Returns the number of whole milliseconds elapsed since `start_time`.
///
/// `Duration::as_millis` yields a `u128`; a value that does not fit in a
/// `u64` saturates to `u64::MAX` instead of being silently truncated.
#[inline]
fn elapsed_ms(start_time: Instant) -> u64 {
    u64::try_from(start_time.elapsed().as_millis()).unwrap_or(u64::MAX)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` traversal results, all at the given `depth`.
    fn sample_results(count: usize, depth: u32) -> Vec<TraversalResultItem> {
        (0..count)
            .map(|_| TraversalResultItem {
                target_id: 1,
                depth,
                path: Vec::new(),
            })
            .collect()
    }

    #[test]
    fn test_stream_node_event_serialize() {
        let event = StreamNodeEvent {
            id: 123,
            depth: 2,
            path: vec![1, 2],
        };
        let json = serde_json::to_string(&event).expect("should serialize");
        assert!(json.contains("123"));
        assert!(json.contains("\"depth\":2"));
    }

    #[test]
    fn test_stream_done_event_serialize() {
        let event = StreamDoneEvent {
            total_nodes: 100,
            max_depth_reached: 5,
            elapsed_ms: 150,
        };
        let json = serde_json::to_string(&event).expect("should serialize");
        assert!(json.contains("100"));
        assert!(json.contains("max_depth_reached"));
    }

    #[test]
    fn test_stream_error_event_serialize() {
        let event = StreamErrorEvent {
            error: "Collection not found".to_string(),
        };
        let json = serde_json::to_string(&event).expect("should serialize");
        assert!(json.contains("Collection not found"));
    }

    #[test]
    fn test_build_error_events_returns_single_error() {
        let events = build_error_events("test error".to_string());
        assert_eq!(events.len(), 1);
    }

    #[test]
    fn test_build_sse_events_error_yields_single_event() {
        let events = build_sse_events(Err("boom".to_string()), Instant::now());
        assert_eq!(events.len(), 1);
    }

    #[test]
    fn test_build_success_events_empty_emits_only_done() {
        let events = build_success_events(Vec::new(), Instant::now());
        assert_eq!(events.len(), 1);
    }

    #[test]
    fn test_build_success_events_below_interval_has_no_stats() {
        // One short of the interval: only node events plus the final `done`.
        let count = STATS_INTERVAL - 1;
        let events = build_success_events(sample_results(count, 1), Instant::now());
        assert_eq!(events.len(), count + 1);
    }

    #[test]
    fn test_build_success_events_emits_stats_every_interval() {
        // Two full intervals plus a partial one: two `stats` events expected.
        let count = 2 * STATS_INTERVAL + STATS_INTERVAL / 2;
        let events = build_success_events(sample_results(count, 3), Instant::now());
        assert_eq!(events.len(), count + 2 + 1);
        assert!(events.iter().all(Result::is_ok));
    }

    #[test]
    fn test_elapsed_ms_returns_reasonable_value() {
        let start = Instant::now();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let ms = elapsed_ms(start);
        assert!(ms >= 5, "elapsed should be at least 5ms, got {ms}");
    }
}