Skip to main content

oxirs_graphrag/retrieval/
streaming_subgraph.rs

1//! Incremental streaming subgraph retrieval using BFS expansion
2//!
3//! `StreamingSubgraphRetriever` retrieves large subgraphs batch by batch,
4//! expanding outward from a SPARQL query in breadth-first order.
5//! Each batch yields a `SubgraphBatch` with metadata and a `is_final` flag.
6
7use crate::{GraphRAGError, GraphRAGResult, SparqlEngineTrait, Triple};
8use std::collections::{HashSet, VecDeque};
9use std::sync::Arc;
10
11// ── Configuration ─────────────────────────────────────────────────────────────
12
/// Configuration for incremental subgraph streaming.
///
/// All limits are read during BFS expansion and batch packaging; `0` means
/// "unlimited" for `max_total_triples`.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Maximum triples returned in a single batch
    pub max_triples_per_batch: usize,
    /// Timeout budget in milliseconds (0 = no limit, enforced per-batch by caller)
    // NOTE(review): nothing in this file reads `timeout_ms`; enforcement is
    // delegated to the caller, as the doc comment states.
    pub timeout_ms: u64,
    /// Maximum BFS expansion depth
    pub max_depth: u8,
    /// Deduplicate triples across batches
    pub deduplicate: bool,
    /// Maximum total triples to deliver (0 = unlimited)
    pub max_total_triples: usize,
}
27
28impl Default for StreamConfig {
29    fn default() -> Self {
30        Self {
31            max_triples_per_batch: 500,
32            timeout_ms: 30_000,
33            max_depth: 3,
34            deduplicate: true,
35            max_total_triples: 50_000,
36        }
37    }
38}
39
40// ── Batch type ─────────────────────────────────────────────────────────────────
41
/// A single batch of triples from the streaming retriever.
///
/// Batches are produced in delivery order; `batch_id` is assigned across the
/// whole stream (not restarted per depth), and exactly one batch carries
/// `is_final = true`.
#[derive(Debug, Clone)]
pub struct SubgraphBatch {
    /// Triples in this batch
    pub triples: Vec<Triple>,
    /// Whether this is the final batch (no more data follows)
    pub is_final: bool,
    /// Zero-based batch sequence number
    pub batch_id: usize,
    /// Current BFS depth at which these triples were collected
    pub current_depth: u8,
}
54
55// ── Stream handle ──────────────────────────────────────────────────────────────
56
/// Handle returned by `retrieve_stream`. Yields `SubgraphBatch` values synchronously.
///
/// All batches are materialized before the handle is created; this is a
/// cursor over that in-memory list, not a live stream.
pub struct SubgraphStream {
    // Fully materialized batches, in delivery order.
    batches: Vec<SubgraphBatch>,
    // Index of the next batch `next_batch` will hand out.
    next_idx: usize,
}
62
63impl SubgraphStream {
64    fn new(batches: Vec<SubgraphBatch>) -> Self {
65        Self {
66            batches,
67            next_idx: 0,
68        }
69    }
70
71    /// Retrieve the next batch, or `None` when exhausted.
72    pub fn next_batch(&mut self) -> Option<SubgraphBatch> {
73        if self.next_idx < self.batches.len() {
74            let batch = self.batches[self.next_idx].clone();
75            self.next_idx += 1;
76            Some(batch)
77        } else {
78            None
79        }
80    }
81
82    /// Collect all batches into a flat `Vec<Triple>`.
83    pub fn collect_all(mut self) -> Vec<Triple> {
84        let mut out = Vec::new();
85        while let Some(batch) = self.next_batch() {
86            out.extend(batch.triples);
87        }
88        out
89    }
90
91    /// Total number of batches available
92    pub fn batch_count(&self) -> usize {
93        self.batches.len()
94    }
95}
96
97// ── StreamingSubgraphRetriever ─────────────────────────────────────────────────
98
/// Incrementally retrieves a subgraph by running SPARQL CONSTRUCT queries
/// layer by layer (BFS expansion) and packaging results into fixed-size batches.
pub struct StreamingSubgraphRetriever<S: SparqlEngineTrait> {
    // Shared handle to the SPARQL engine; cheaply cloned per retrieval.
    engine: Arc<S>,
    // Stored default configuration. NOTE(review): `retrieve_stream` takes an
    // explicit config parameter and does not consult this field.
    config: StreamConfig,
}
105
106impl<S: SparqlEngineTrait + 'static> StreamingSubgraphRetriever<S> {
107    /// Create a new retriever with the given config.
108    pub fn new(engine: Arc<S>, config: StreamConfig) -> Self {
109        Self { engine, config }
110    }
111
112    /// Create with default config.
113    pub fn with_defaults(engine: Arc<S>) -> Self {
114        Self::new(engine, StreamConfig::default())
115    }
116
117    /// Start streaming for the given SPARQL CONSTRUCT query.
118    ///
119    /// The query is used as the initial seed: its results form depth 0.
120    /// Then each distinct object entity in those results is expanded for
121    /// subsequent depths, up to `config.max_depth`.
122    ///
123    /// This is a synchronous method; it runs a Tokio `block_in_place` internally
124    /// (or the caller must be in a Tokio runtime context).
125    pub fn retrieve_stream(
126        &self,
127        query: &str,
128        config: &StreamConfig,
129    ) -> GraphRAGResult<SubgraphStream> {
130        // We need an async executor; use a blocking thread via tokio::task::block_in_place
131        // or, since tests always run inside Tokio, call a helper.
132        let rt = tokio::runtime::Handle::try_current()
133            .map_err(|_| GraphRAGError::InternalError("No Tokio runtime available".to_string()))?;
134
135        let engine = Arc::clone(&self.engine);
136        let query_owned = query.to_string();
137        let config_owned = config.clone();
138
139        // Run the BFS expansion under the existing Tokio handle
140        let batches = rt.block_on(run_bfs_expansion(engine, &query_owned, &config_owned))?;
141
142        Ok(SubgraphStream::new(batches))
143    }
144}
145
146// ── BFS expansion ─────────────────────────────────────────────────────────────
147
148/// Run BFS expansion up to `config.max_depth`, returning packaged `SubgraphBatch`es.
149async fn run_bfs_expansion<S: SparqlEngineTrait>(
150    engine: Arc<S>,
151    initial_query: &str,
152    config: &StreamConfig,
153) -> GraphRAGResult<Vec<SubgraphBatch>> {
154    let mut batches: Vec<SubgraphBatch> = Vec::new();
155    let mut seen: HashSet<(String, String, String)> = HashSet::new();
156    let mut total_delivered: usize = 0;
157
158    // Depth 0: run the initial CONSTRUCT query
159    let initial_triples = engine.construct(initial_query).await?;
160    let initial_triples = deduplicate_if(initial_triples, config.deduplicate, &mut seen);
161
162    // Collect frontier entities from depth-0 results (objects that could be expanded)
163    let mut frontier: VecDeque<String> = VecDeque::new();
164    for t in &initial_triples {
165        if t.object.starts_with("http") {
166            frontier.push_back(t.object.clone());
167        }
168    }
169
170    // Package depth-0 triples into batches
171    let (new_batches, delivered) = package_into_batches(
172        initial_triples,
173        0,
174        config.max_triples_per_batch,
175        config.max_total_triples,
176        total_delivered,
177        &mut batches,
178    );
179    let _ = new_batches;
180    total_delivered += delivered;
181
182    if config.max_total_triples > 0 && total_delivered >= config.max_total_triples {
183        mark_last_batch(&mut batches);
184        return Ok(batches);
185    }
186
187    // Depths 1..max_depth: expand frontier entities
188    for depth in 1..=config.max_depth {
189        if frontier.is_empty() {
190            break;
191        }
192
193        let current_frontier: Vec<String> = frontier.drain(..).collect();
194        let mut depth_triples: Vec<Triple> = Vec::new();
195
196        for entity_uri in &current_frontier {
197            let expand_query = build_entity_expand_query(entity_uri, 1);
198            let raw = engine.construct(&expand_query).await?;
199            let filtered = deduplicate_if(raw, config.deduplicate, &mut seen);
200            for t in &filtered {
201                if t.object.starts_with("http") {
202                    frontier.push_back(t.object.clone());
203                }
204            }
205            depth_triples.extend(filtered);
206
207            if config.max_total_triples > 0
208                && total_delivered + depth_triples.len() >= config.max_total_triples
209            {
210                break;
211            }
212        }
213
214        let (_, delivered) = package_into_batches(
215            depth_triples,
216            depth,
217            config.max_triples_per_batch,
218            config.max_total_triples,
219            total_delivered,
220            &mut batches,
221        );
222        total_delivered += delivered;
223
224        if config.max_total_triples > 0 && total_delivered >= config.max_total_triples {
225            break;
226        }
227    }
228
229    mark_last_batch(&mut batches);
230    Ok(batches)
231}
232
233/// Mark the last batch in the list as final.
234fn mark_last_batch(batches: &mut [SubgraphBatch]) {
235    if let Some(last) = batches.last_mut() {
236        last.is_final = true;
237    }
238}
239
240/// Deduplicate triples if enabled, updating the seen set in place.
241fn deduplicate_if(
242    triples: Vec<Triple>,
243    dedup: bool,
244    seen: &mut HashSet<(String, String, String)>,
245) -> Vec<Triple> {
246    if !dedup {
247        return triples;
248    }
249    triples
250        .into_iter()
251        .filter(|t| seen.insert((t.subject.clone(), t.predicate.clone(), t.object.clone())))
252        .collect()
253}
254
255/// Pack `triples` into batches of `batch_size`, respecting `max_total`.
256/// Returns (number of batches created, number of triples delivered).
257fn package_into_batches(
258    triples: Vec<Triple>,
259    depth: u8,
260    batch_size: usize,
261    max_total: usize,
262    already_delivered: usize,
263    out: &mut Vec<SubgraphBatch>,
264) -> (usize, usize) {
265    let mut remaining = triples;
266    if max_total > 0 && already_delivered + remaining.len() > max_total {
267        remaining.truncate(max_total - already_delivered);
268    }
269
270    let mut total_delivered = 0usize;
271    let mut batches_created = 0usize;
272    let mut offset = 0usize;
273
274    while offset < remaining.len() {
275        let end = (offset + batch_size).min(remaining.len());
276        let chunk: Vec<Triple> = remaining[offset..end].to_vec();
277        let chunk_len = chunk.len();
278        let batch_id = out.len();
279
280        out.push(SubgraphBatch {
281            triples: chunk,
282            is_final: false, // will be set by mark_last_batch
283            batch_id,
284            current_depth: depth,
285        });
286
287        total_delivered += chunk_len;
288        batches_created += 1;
289        offset = end;
290    }
291
292    (batches_created, total_delivered)
293}
294
/// Build a 1-hop CONSTRUCT expansion query for a single entity.
///
/// The query gathers both outgoing (`<e> ?p ?o`) and incoming
/// (`?s ?p2 <e>`) edges. `_hops` is reserved for future multi-hop
/// templates and is currently ignored.
fn build_entity_expand_query(entity_uri: &str, _hops: usize) -> String {
    let outgoing = format!("<{entity_uri}> ?p ?o .");
    let incoming = format!("?s ?p2 <{entity_uri}> .");
    format!(
        "CONSTRUCT {{ {outgoing} {incoming} }}\nWHERE {{ {{ {outgoing} }} UNION {{ {incoming} }} }}"
    )
}
303
304// ── Tests ─────────────────────────────────────────────────────────────────────
305
306#[cfg(test)]
307mod tests {
308    use super::*;
309    use crate::{GraphRAGResult, SparqlEngineTrait, Triple};
310    use async_trait::async_trait;
311    use std::collections::HashMap;
312    use std::sync::Arc;
313
    /// Minimal mock SPARQL engine: every `construct` call returns the same
    /// canned triple list, regardless of the query text.
    struct MockEngine {
        // Canned result set echoed by `construct`.
        triples: Vec<Triple>,
    }
318
    impl MockEngine {
        /// Build the mock pre-wrapped in `Arc`, matching the retriever's API.
        fn new(triples: Vec<Triple>) -> Arc<Self> {
            Arc::new(Self { triples })
        }
    }
324
    #[async_trait]
    impl SparqlEngineTrait for MockEngine {
        /// SELECT is unused by the retriever under test; return no rows.
        async fn select(&self, _query: &str) -> GraphRAGResult<Vec<HashMap<String, String>>> {
            Ok(vec![])
        }
        /// ASK is unused by the retriever under test; always false.
        async fn ask(&self, _query: &str) -> GraphRAGResult<bool> {
            Ok(false)
        }
        /// Ignores the query text and echoes the canned triples.
        async fn construct(&self, _query: &str) -> GraphRAGResult<Vec<Triple>> {
            Ok(self.triples.clone())
        }
    }
337
338    fn make_triples(n: usize) -> Vec<Triple> {
339        (0..n)
340            .map(|i| Triple::new(format!("http://s/{i}"), "http://p", format!("http://o/{i}")))
341            .collect()
342    }
343
    /// Drive a future to completion on a freshly created runtime.
    // NOTE(review): appears unused by the tests below (they use
    // `#[tokio::test]` instead); kept as an ad-hoc synchronous driver.
    fn run<F: std::future::Future>(f: F) -> F::Output {
        tokio::runtime::Runtime::new()
            .expect("should succeed")
            .block_on(f)
    }
349
350    // ── StreamConfig default tests ──────────────────────────────────────────
351
352    #[test]
353    fn test_stream_config_defaults() {
354        let cfg = StreamConfig::default();
355        assert_eq!(cfg.max_triples_per_batch, 500);
356        assert_eq!(cfg.timeout_ms, 30_000);
357        assert_eq!(cfg.max_depth, 3);
358        assert!(cfg.deduplicate);
359        assert_eq!(cfg.max_total_triples, 50_000);
360    }
361
362    // ── SubgraphBatch field tests ───────────────────────────────────────────
363
364    #[test]
365    fn test_subgraph_batch_fields() {
366        let batch = SubgraphBatch {
367            triples: make_triples(5),
368            is_final: true,
369            batch_id: 2,
370            current_depth: 1,
371        };
372        assert_eq!(batch.triples.len(), 5);
373        assert!(batch.is_final);
374        assert_eq!(batch.batch_id, 2);
375        assert_eq!(batch.current_depth, 1);
376    }
377
378    // ── SubgraphStream collect_all ──────────────────────────────────────────
379
380    #[test]
381    fn test_stream_collect_all() {
382        let batches = vec![
383            SubgraphBatch {
384                triples: make_triples(3),
385                is_final: false,
386                batch_id: 0,
387                current_depth: 0,
388            },
389            SubgraphBatch {
390                triples: make_triples(2),
391                is_final: true,
392                batch_id: 1,
393                current_depth: 1,
394            },
395        ];
396        let stream = SubgraphStream::new(batches);
397        let all = stream.collect_all();
398        assert_eq!(all.len(), 5);
399    }
400
401    // ── SubgraphStream next_batch ───────────────────────────────────────────
402
403    #[test]
404    fn test_stream_next_batch_exhaustion() {
405        let batches = vec![SubgraphBatch {
406            triples: make_triples(1),
407            is_final: true,
408            batch_id: 0,
409            current_depth: 0,
410        }];
411        let mut stream = SubgraphStream::new(batches);
412        assert!(stream.next_batch().is_some());
413        assert!(stream.next_batch().is_none());
414    }
415
416    // ── SubgraphStream batch_count ──────────────────────────────────────────
417
418    #[test]
419    fn test_stream_batch_count() {
420        let batches = (0..5)
421            .map(|i| SubgraphBatch {
422                triples: make_triples(1),
423                is_final: i == 4,
424                batch_id: i,
425                current_depth: 0,
426            })
427            .collect();
428        let stream = SubgraphStream::new(batches);
429        assert_eq!(stream.batch_count(), 5);
430    }
431
432    // ── BFS expansion: depth 0 ──────────────────────────────────────────────
433
434    #[tokio::test]
435    async fn test_bfs_depth0_basic() {
436        let triples = make_triples(10);
437        let engine: Arc<MockEngine> = MockEngine::new(triples);
438        let config = StreamConfig {
439            max_triples_per_batch: 100,
440            max_depth: 0,
441            deduplicate: false,
442            max_total_triples: 0,
443            ..Default::default()
444        };
445        let batches = run_bfs_expansion(engine, "CONSTRUCT {}", &config)
446            .await
447            .expect("should succeed");
448        // Only the initial query at depth 0
449        let total: usize = batches.iter().map(|b| b.triples.len()).sum();
450        assert_eq!(total, 10);
451        assert!(batches.last().expect("should succeed").is_final);
452    }
453
454    // ── BFS expansion: max_total_triples cap ────────────────────────────────
455
456    #[tokio::test]
457    async fn test_bfs_max_total_cap() {
458        let engine = MockEngine::new(make_triples(200));
459        let config = StreamConfig {
460            max_triples_per_batch: 100,
461            max_depth: 0,
462            deduplicate: false,
463            max_total_triples: 50,
464            ..Default::default()
465        };
466        let batches = run_bfs_expansion(engine, "CONSTRUCT {}", &config)
467            .await
468            .expect("should succeed");
469        let total: usize = batches.iter().map(|b| b.triples.len()).sum();
470        assert!(total <= 50);
471    }
472
473    // ── BFS expansion: deduplication ───────────────────────────────────────
474
475    #[tokio::test]
476    async fn test_bfs_deduplication() {
477        let triple = Triple::new("http://s", "http://p", "http://o");
478        let engine = MockEngine::new(vec![triple; 50]);
479        let config = StreamConfig {
480            max_triples_per_batch: 100,
481            max_depth: 0,
482            deduplicate: true,
483            max_total_triples: 0,
484            ..Default::default()
485        };
486        let batches = run_bfs_expansion(engine, "CONSTRUCT {}", &config)
487            .await
488            .expect("should succeed");
489        let total: usize = batches.iter().map(|b| b.triples.len()).sum();
490        assert_eq!(total, 1);
491    }
492
493    // ── BFS expansion: no deduplication ───────────────────────────────────
494
495    #[tokio::test]
496    async fn test_bfs_no_deduplication_counts_all() {
497        let triple = Triple::new("http://s", "http://p", "http://o");
498        let engine = MockEngine::new(vec![triple; 20]);
499        let config = StreamConfig {
500            max_triples_per_batch: 100,
501            max_depth: 0,
502            deduplicate: false,
503            max_total_triples: 0,
504            ..Default::default()
505        };
506        let batches = run_bfs_expansion(engine, "CONSTRUCT {}", &config)
507            .await
508            .expect("should succeed");
509        let total: usize = batches.iter().map(|b| b.triples.len()).sum();
510        assert_eq!(total, 20);
511    }
512
513    // ── Package into batches: batch_size splits ─────────────────────────────
514
515    #[test]
516    fn test_package_into_batches_splits_correctly() {
517        let triples = make_triples(25);
518        let mut out: Vec<SubgraphBatch> = Vec::new();
519        let (batches_created, delivered) = package_into_batches(triples, 0, 10, 0, 0, &mut out);
520        assert_eq!(batches_created, 3); // 10 + 10 + 5
521        assert_eq!(delivered, 25);
522        assert_eq!(out.len(), 3);
523    }
524
525    // ── Package into batches: max_total truncation ──────────────────────────
526
527    #[test]
528    fn test_package_into_batches_respects_max_total() {
529        let triples = make_triples(100);
530        let mut out: Vec<SubgraphBatch> = Vec::new();
531        let (_, delivered) = package_into_batches(triples, 0, 50, 30, 0, &mut out);
532        assert!(delivered <= 30);
533        let total: usize = out.iter().map(|b| b.triples.len()).sum();
534        assert!(total <= 30);
535    }
536
537    // ── mark_last_batch ─────────────────────────────────────────────────────
538
539    #[test]
540    fn test_mark_last_batch_sets_is_final() {
541        let mut batches = vec![
542            SubgraphBatch {
543                triples: vec![],
544                is_final: false,
545                batch_id: 0,
546                current_depth: 0,
547            },
548            SubgraphBatch {
549                triples: vec![],
550                is_final: false,
551                batch_id: 1,
552                current_depth: 0,
553            },
554        ];
555        mark_last_batch(&mut batches);
556        assert!(!batches[0].is_final);
557        assert!(batches[1].is_final);
558    }
559
560    // ── build_entity_expand_query ───────────────────────────────────────────
561
562    #[test]
563    fn test_build_entity_expand_query_contains_uri() {
564        let q = build_entity_expand_query("http://example.org/e", 1);
565        assert!(q.contains("http://example.org/e"));
566        assert!(q.contains("CONSTRUCT"));
567    }
568
569    // ── deduplicate_if ──────────────────────────────────────────────────────
570
571    #[test]
572    fn test_deduplicate_if_disabled() {
573        let triples = vec![
574            Triple::new("http://s", "http://p", "http://o"),
575            Triple::new("http://s", "http://p", "http://o"),
576        ];
577        let mut seen = HashSet::new();
578        let result = deduplicate_if(triples, false, &mut seen);
579        assert_eq!(result.len(), 2);
580        assert!(seen.is_empty()); // Not updated when disabled
581    }
582
583    #[test]
584    fn test_deduplicate_if_enabled_removes_dupes() {
585        let triples = vec![
586            Triple::new("http://s", "http://p", "http://o"),
587            Triple::new("http://s", "http://p", "http://o"),
588            Triple::new("http://s2", "http://p", "http://o"),
589        ];
590        let mut seen = HashSet::new();
591        let result = deduplicate_if(triples, true, &mut seen);
592        assert_eq!(result.len(), 2);
593    }
594
595    // ── StreamingSubgraphRetriever with_defaults ────────────────────────────
596
597    #[tokio::test]
598    async fn test_retriever_with_defaults() {
599        let engine = MockEngine::new(make_triples(5));
600        let retriever = StreamingSubgraphRetriever::with_defaults(engine);
601        assert_eq!(retriever.config.max_depth, 3);
602    }
603
604    // ── Empty engine returns empty stream ───────────────────────────────────
605
606    #[tokio::test]
607    async fn test_bfs_empty_engine_returns_empty() {
608        let engine = MockEngine::new(vec![]);
609        let config = StreamConfig {
610            max_triples_per_batch: 10,
611            max_depth: 0,
612            deduplicate: false,
613            max_total_triples: 0,
614            ..Default::default()
615        };
616        let batches = run_bfs_expansion(engine, "CONSTRUCT {}", &config)
617            .await
618            .expect("should succeed");
619        let total: usize = batches.iter().map(|b| b.triples.len()).sum();
620        assert_eq!(total, 0);
621    }
622
623    // ── Batch IDs are sequential ────────────────────────────────────────────
624
625    #[tokio::test]
626    async fn test_bfs_batch_ids_sequential() {
627        let engine = MockEngine::new(make_triples(30));
628        let config = StreamConfig {
629            max_triples_per_batch: 10,
630            max_depth: 0,
631            deduplicate: false,
632            max_total_triples: 0,
633            ..Default::default()
634        };
635        let batches = run_bfs_expansion(engine, "CONSTRUCT {}", &config)
636            .await
637            .expect("should succeed");
638        for (expected_id, batch) in batches.iter().enumerate() {
639            assert_eq!(batch.batch_id, expected_id);
640        }
641    }
642
643    // ── Only last batch has is_final = true ─────────────────────────────────
644
645    #[tokio::test]
646    async fn test_bfs_only_last_batch_is_final() {
647        let engine = MockEngine::new(make_triples(25));
648        let config = StreamConfig {
649            max_triples_per_batch: 10,
650            max_depth: 0,
651            deduplicate: false,
652            max_total_triples: 0,
653            ..Default::default()
654        };
655        let batches = run_bfs_expansion(engine, "CONSTRUCT {}", &config)
656            .await
657            .expect("should succeed");
658        for (i, batch) in batches.iter().enumerate() {
659            if i < batches.len() - 1 {
660                assert!(!batch.is_final, "Batch {i} should not be final");
661            } else {
662                assert!(batch.is_final, "Last batch should be final");
663            }
664        }
665    }
666
667    // ── depth-0 batches have current_depth = 0 ──────────────────────────────
668
669    #[tokio::test]
670    async fn test_bfs_depth0_batches_have_depth_zero() {
671        let engine = MockEngine::new(make_triples(10));
672        let config = StreamConfig {
673            max_triples_per_batch: 5,
674            max_depth: 0,
675            deduplicate: false,
676            max_total_triples: 0,
677            ..Default::default()
678        };
679        let batches = run_bfs_expansion(engine, "CONSTRUCT {}", &config)
680            .await
681            .expect("should succeed");
682        for batch in &batches {
683            assert_eq!(batch.current_depth, 0);
684        }
685    }
686
687    // ── StreamingSubgraphRetriever::new sets config correctly ───────────────
688
689    #[test]
690    fn test_retriever_new_config() {
691        let engine = MockEngine::new(vec![]);
692        let config = StreamConfig {
693            max_triples_per_batch: 42,
694            ..Default::default()
695        };
696        let retriever = StreamingSubgraphRetriever::new(engine, config);
697        assert_eq!(retriever.config.max_triples_per_batch, 42);
698    }
699}