// velesdb-server 1.6.0
//
// REST API server for VelesDB vector database.
// Documentation
//! SSE streaming graph traversal handler (EPIC-058 US-003).
//!
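//! Provides a Server-Sent Events endpoint that delivers graph traversal
//! results as a sequence of incremental events. Note that the traversal
//! currently runs to completion before the first event is sent, so the
//! results are held in memory rather than produced lazily.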

use axum::{
    extract::{Path, Query, State},
    response::sse::{Event, KeepAlive, Sse},
};
use futures::stream::{self, Stream};
use std::convert::Infallible;
use std::time::Instant;

use std::sync::Arc;

use crate::AppState;

use super::types::{
    StreamDoneEvent, StreamErrorEvent, StreamNodeEvent, StreamStatsEvent, StreamTraverseParams,
    TraversalResultItem,
};

/// Interval (in nodes) between periodic stats events.
///
/// A `stats` event is emitted after every `STATS_INTERVAL`-th `node` event,
/// i.e. whenever the running node count is a multiple of this value.
const STATS_INTERVAL: usize = 100;

/// Stream graph traversal results via SSE.
///
/// Yields events:
/// - `node`: Each node reached during traversal
/// - `stats`: Periodic statistics (every [`STATS_INTERVAL`] nodes)
/// - `done`: Traversal completed
/// - `error`: If an error occurs
///
/// `relationship_types` is a comma-separated list; entries are trimmed and
/// blank entries (empty parameter, doubled or trailing commas) are ignored.
/// `algorithm` is matched case-insensitively: `dfs` selects depth-first
/// traversal, any other value falls back to breadth-first.
#[utoipa::path(
    get,
    path = "/collections/{name}/graph/traverse/stream",
    tag = "graph",
    params(
        ("name" = String, Path, description = "Collection name"),
        StreamTraverseParams
    ),
    responses(
        (status = 200, description = "SSE stream of traversal events (node, stats, done, error)")
    )
)]
#[allow(clippy::unused_async)]
pub async fn stream_traverse(
    State(state): State<Arc<AppState>>,
    Path(collection): Path<String>,
    Query(params): Query<StreamTraverseParams>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    use velesdb_core::collection::graph::TraversalConfig;

    // Started before the traversal so `elapsed_ms` in the emitted events
    // includes the traversal time itself.
    let start_time = Instant::now();

    // Drop blank segments: without the filter, `?relationship_types=` or a
    // trailing comma would inject an empty-string relationship type into the
    // traversal filter.
    let rel_types: Vec<String> = params
        .relationship_types
        .map(|raw| {
            raw.split(',')
                .map(str::trim)
                .filter(|name| !name.is_empty())
                .map(String::from)
                .collect()
        })
        .unwrap_or_default();

    let traversal_result: Result<Vec<TraversalResultItem>, String> =
        match state.db.get_graph_collection(&collection) {
            None => Err(format!(
                "Collection '{collection}' not found or is not a graph collection."
            )),
            Some(coll) => {
                let config = TraversalConfig::with_range(1, params.max_depth)
                    .with_limit(params.limit)
                    .with_rel_types(rel_types);

                // Anything other than "dfs" (case-insensitive) uses BFS.
                let raw = match params.algorithm.to_lowercase().as_str() {
                    "dfs" => coll.traverse_dfs(params.start_node, &config),
                    _ => coll.traverse_bfs(params.start_node, &config),
                };

                // Convert core results into the API-level item type.
                Ok(raw
                    .into_iter()
                    .map(|r| TraversalResultItem {
                        target_id: r.target_id,
                        depth: r.depth,
                        path: r.path,
                    })
                    .collect())
            }
        };

    // All events are materialised up front and then replayed as a stream.
    let events = build_sse_events(traversal_result, start_time);
    Sse::new(stream::iter(events)).keep_alive(KeepAlive::default())
}

/// Converts a traversal result into a sequence of SSE events.
///
/// Extracted to keep the handler thin and the logic testable.
fn build_sse_events(
    traversal_result: Result<Vec<TraversalResultItem>, String>,
    start_time: Instant,
) -> Vec<Result<Event, Infallible>> {
    match traversal_result {
        Ok(results) => build_success_events(results, start_time),
        Err(e) => build_error_events(e),
    }
}

/// Builds the SSE event sequence for a successful traversal.
///
/// Emits one `node` event per result (in the order given), a `stats` event
/// after every [`STATS_INTERVAL`]-th node, and a final `done` event carrying
/// the total node count, the deepest depth seen and the elapsed time since
/// `start_time`.
///
/// If a payload fails to serialize, the event is still emitted with `{}` as
/// its data so the stream shape stays intact.
fn build_success_events(
    results: Vec<TraversalResultItem>,
    start_time: Instant,
) -> Vec<Result<Event, Infallible>> {
    let total = results.len();
    let mut max_depth: u32 = 0;

    // Exact event count: one `node` per result, one `stats` per completed
    // STATS_INTERVAL batch, plus the trailing `done` event. Reserving the
    // exact size avoids a reallocation once stats events start appearing.
    let mut events: Vec<Result<Event, Infallible>> =
        Vec::with_capacity(total + total / STATS_INTERVAL + 1);

    for (i, item) in results.into_iter().enumerate() {
        max_depth = max_depth.max(item.depth);

        let node_event = StreamNodeEvent {
            id: item.target_id,
            depth: item.depth,
            path: item.path,
        };
        let event_data = serde_json::to_string(&node_event).unwrap_or_else(|_| "{}".to_string());
        events.push(Ok(Event::default().event("node").data(event_data)));

        // `i` is zero-based, so `i + 1` is the number of nodes emitted so far.
        let visited = i + 1;
        if visited % STATS_INTERVAL == 0 {
            let stats_event = StreamStatsEvent {
                nodes_visited: visited,
                elapsed_ms: elapsed_ms(start_time),
            };
            let stats_data =
                serde_json::to_string(&stats_event).unwrap_or_else(|_| "{}".to_string());
            events.push(Ok(Event::default().event("stats").data(stats_data)));
        }
    }

    let done_event = StreamDoneEvent {
        total_nodes: total,
        max_depth_reached: max_depth,
        elapsed_ms: elapsed_ms(start_time),
    };
    let done_data = serde_json::to_string(&done_event).unwrap_or_else(|_| "{}".to_string());
    events.push(Ok(Event::default().event("done").data(done_data)));

    events
}

fn build_error_events(error: String) -> Vec<Result<Event, Infallible>> {
    let error_event = StreamErrorEvent { error };
    let error_data = serde_json::to_string(&error_event).unwrap_or_else(|_| "{}".to_string());
    vec![Ok(Event::default().event("error").data(error_data))]
}

/// Returns elapsed milliseconds since `start_time`.
///
/// `Duration::as_millis` yields a `u128`; the value is converted with a
/// checked conversion and saturates at `u64::MAX` instead of silently
/// truncating. `u64::MAX` milliseconds is roughly 584 million years, so the
/// saturating branch is never expected to be taken in practice.
#[inline]
fn elapsed_ms(start_time: Instant) -> u64 {
    u64::try_from(start_time.elapsed().as_millis()).unwrap_or(u64::MAX)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A node event serializes with its id and depth present in the JSON.
    #[test]
    fn test_stream_node_event_serialize() {
        let encoded = serde_json::to_string(&StreamNodeEvent {
            id: 123,
            depth: 2,
            path: vec![1, 2],
        })
        .expect("should serialize");

        assert!(encoded.contains("123"));
        assert!(encoded.contains("\"depth\":2"));
    }

    /// A done event serializes with its total and the `max_depth_reached` key.
    #[test]
    fn test_stream_done_event_serialize() {
        let encoded = serde_json::to_string(&StreamDoneEvent {
            total_nodes: 100,
            max_depth_reached: 5,
            elapsed_ms: 150,
        })
        .expect("should serialize");

        assert!(encoded.contains("100"));
        assert!(encoded.contains("max_depth_reached"));
    }

    /// An error event carries its message through serialization.
    #[test]
    fn test_stream_error_event_serialize() {
        let message = "Collection not found".to_string();
        let encoded =
            serde_json::to_string(&StreamErrorEvent { error: message }).expect("should serialize");

        assert!(encoded.contains("Collection not found"));
    }

    /// The error path produces exactly one SSE event.
    #[test]
    fn test_build_error_events_returns_single_error() {
        let produced = build_error_events("test error".to_string());
        assert_eq!(produced.len(), 1);
    }

    /// After a short sleep the helper reports a plausible non-trivial duration.
    #[test]
    fn test_elapsed_ms_returns_reasonable_value() {
        let pause = std::time::Duration::from_millis(10);
        let start = Instant::now();
        std::thread::sleep(pause);

        let ms = elapsed_ms(start);
        assert!(ms >= 5, "elapsed should be at least 5ms, got {ms}");
    }
}