1use crate::{GraphRAGError, GraphRAGResult, SparqlEngineTrait, Triple};
8use std::collections::{HashSet, VecDeque};
9use std::sync::Arc;
10
/// Tuning knobs for streaming subgraph retrieval.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Maximum number of triples packed into a single batch.
    pub max_triples_per_batch: usize,
    /// Timeout budget in milliseconds.
    /// NOTE(review): not enforced by `run_bfs_expansion` in this file — confirm intended use.
    pub timeout_ms: u64,
    /// Maximum BFS expansion depth beyond the seed query (0 = seed results only).
    pub max_depth: u8,
    /// When true, duplicate (subject, predicate, object) triples are dropped.
    pub deduplicate: bool,
    /// Hard cap on total triples delivered across all batches (0 = unlimited).
    pub max_total_triples: usize,
}
27
28impl Default for StreamConfig {
29 fn default() -> Self {
30 Self {
31 max_triples_per_batch: 500,
32 timeout_ms: 30_000,
33 max_depth: 3,
34 deduplicate: true,
35 max_total_triples: 50_000,
36 }
37 }
38}
39
/// One chunk of triples produced by the streaming retriever.
#[derive(Debug, Clone)]
pub struct SubgraphBatch {
    /// The triples contained in this batch.
    pub triples: Vec<Triple>,
    /// True only on the last batch of the stream.
    pub is_final: bool,
    /// Zero-based position of this batch within the stream.
    pub batch_id: usize,
    /// BFS depth at which these triples were discovered (0 = seed query).
    pub current_depth: u8,
}
54
/// Pull-based stream over batches that were materialized up front.
pub struct SubgraphStream {
    // All batches, produced eagerly by the BFS expansion.
    batches: Vec<SubgraphBatch>,
    // Index of the next batch that `next_batch` will hand out.
    next_idx: usize,
}
62
63impl SubgraphStream {
64 fn new(batches: Vec<SubgraphBatch>) -> Self {
65 Self {
66 batches,
67 next_idx: 0,
68 }
69 }
70
71 pub fn next_batch(&mut self) -> Option<SubgraphBatch> {
73 if self.next_idx < self.batches.len() {
74 let batch = self.batches[self.next_idx].clone();
75 self.next_idx += 1;
76 Some(batch)
77 } else {
78 None
79 }
80 }
81
82 pub fn collect_all(mut self) -> Vec<Triple> {
84 let mut out = Vec::new();
85 while let Some(batch) = self.next_batch() {
86 out.extend(batch.triples);
87 }
88 out
89 }
90
91 pub fn batch_count(&self) -> usize {
93 self.batches.len()
94 }
95}
96
/// Retrieves subgraphs from a SPARQL engine as a batched stream.
pub struct StreamingSubgraphRetriever<S: SparqlEngineTrait> {
    // Shared engine used to execute CONSTRUCT queries.
    engine: Arc<S>,
    // Stored defaults; note `retrieve_stream` takes its own config argument.
    config: StreamConfig,
}
105
106impl<S: SparqlEngineTrait + 'static> StreamingSubgraphRetriever<S> {
107 pub fn new(engine: Arc<S>, config: StreamConfig) -> Self {
109 Self { engine, config }
110 }
111
112 pub fn with_defaults(engine: Arc<S>) -> Self {
114 Self::new(engine, StreamConfig::default())
115 }
116
117 pub fn retrieve_stream(
126 &self,
127 query: &str,
128 config: &StreamConfig,
129 ) -> GraphRAGResult<SubgraphStream> {
130 let rt = tokio::runtime::Handle::try_current()
133 .map_err(|_| GraphRAGError::InternalError("No Tokio runtime available".to_string()))?;
134
135 let engine = Arc::clone(&self.engine);
136 let query_owned = query.to_string();
137 let config_owned = config.clone();
138
139 let batches = rt.block_on(run_bfs_expansion(engine, &query_owned, &config_owned))?;
141
142 Ok(SubgraphStream::new(batches))
143 }
144}
145
146async fn run_bfs_expansion<S: SparqlEngineTrait>(
150 engine: Arc<S>,
151 initial_query: &str,
152 config: &StreamConfig,
153) -> GraphRAGResult<Vec<SubgraphBatch>> {
154 let mut batches: Vec<SubgraphBatch> = Vec::new();
155 let mut seen: HashSet<(String, String, String)> = HashSet::new();
156 let mut total_delivered: usize = 0;
157
158 let initial_triples = engine.construct(initial_query).await?;
160 let initial_triples = deduplicate_if(initial_triples, config.deduplicate, &mut seen);
161
162 let mut frontier: VecDeque<String> = VecDeque::new();
164 for t in &initial_triples {
165 if t.object.starts_with("http") {
166 frontier.push_back(t.object.clone());
167 }
168 }
169
170 let (new_batches, delivered) = package_into_batches(
172 initial_triples,
173 0,
174 config.max_triples_per_batch,
175 config.max_total_triples,
176 total_delivered,
177 &mut batches,
178 );
179 let _ = new_batches;
180 total_delivered += delivered;
181
182 if config.max_total_triples > 0 && total_delivered >= config.max_total_triples {
183 mark_last_batch(&mut batches);
184 return Ok(batches);
185 }
186
187 for depth in 1..=config.max_depth {
189 if frontier.is_empty() {
190 break;
191 }
192
193 let current_frontier: Vec<String> = frontier.drain(..).collect();
194 let mut depth_triples: Vec<Triple> = Vec::new();
195
196 for entity_uri in ¤t_frontier {
197 let expand_query = build_entity_expand_query(entity_uri, 1);
198 let raw = engine.construct(&expand_query).await?;
199 let filtered = deduplicate_if(raw, config.deduplicate, &mut seen);
200 for t in &filtered {
201 if t.object.starts_with("http") {
202 frontier.push_back(t.object.clone());
203 }
204 }
205 depth_triples.extend(filtered);
206
207 if config.max_total_triples > 0
208 && total_delivered + depth_triples.len() >= config.max_total_triples
209 {
210 break;
211 }
212 }
213
214 let (_, delivered) = package_into_batches(
215 depth_triples,
216 depth,
217 config.max_triples_per_batch,
218 config.max_total_triples,
219 total_delivered,
220 &mut batches,
221 );
222 total_delivered += delivered;
223
224 if config.max_total_triples > 0 && total_delivered >= config.max_total_triples {
225 break;
226 }
227 }
228
229 mark_last_batch(&mut batches);
230 Ok(batches)
231}
232
233fn mark_last_batch(batches: &mut [SubgraphBatch]) {
235 if let Some(last) = batches.last_mut() {
236 last.is_final = true;
237 }
238}
239
240fn deduplicate_if(
242 triples: Vec<Triple>,
243 dedup: bool,
244 seen: &mut HashSet<(String, String, String)>,
245) -> Vec<Triple> {
246 if !dedup {
247 return triples;
248 }
249 triples
250 .into_iter()
251 .filter(|t| seen.insert((t.subject.clone(), t.predicate.clone(), t.object.clone())))
252 .collect()
253}
254
255fn package_into_batches(
258 triples: Vec<Triple>,
259 depth: u8,
260 batch_size: usize,
261 max_total: usize,
262 already_delivered: usize,
263 out: &mut Vec<SubgraphBatch>,
264) -> (usize, usize) {
265 let mut remaining = triples;
266 if max_total > 0 && already_delivered + remaining.len() > max_total {
267 remaining.truncate(max_total - already_delivered);
268 }
269
270 let mut total_delivered = 0usize;
271 let mut batches_created = 0usize;
272 let mut offset = 0usize;
273
274 while offset < remaining.len() {
275 let end = (offset + batch_size).min(remaining.len());
276 let chunk: Vec<Triple> = remaining[offset..end].to_vec();
277 let chunk_len = chunk.len();
278 let batch_id = out.len();
279
280 out.push(SubgraphBatch {
281 triples: chunk,
282 is_final: false, batch_id,
284 current_depth: depth,
285 });
286
287 total_delivered += chunk_len;
288 batches_created += 1;
289 offset = end;
290 }
291
292 (batches_created, total_delivered)
293}
294
/// Builds a CONSTRUCT query returning every triple in which `entity_uri`
/// appears as subject (outgoing edges) or object (incoming edges).
/// `_hops` is accepted for interface compatibility; only a single-hop
/// expansion is generated.
fn build_entity_expand_query(entity_uri: &str, _hops: usize) -> String {
    let outgoing = format!("<{entity_uri}> ?p ?o .");
    let incoming = format!("?s ?p2 <{entity_uri}> .");
    format!("CONSTRUCT {{ {outgoing} {incoming} }}\nWHERE {{ {{ {outgoing} }} UNION {{ {incoming} }} }}")
}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{GraphRAGResult, SparqlEngineTrait, Triple};
    use async_trait::async_trait;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Engine stub whose `construct` always returns the same fixed triples.
    struct MockEngine {
        triples: Vec<Triple>,
    }

    impl MockEngine {
        fn new(triples: Vec<Triple>) -> Arc<Self> {
            Arc::new(Self { triples })
        }
    }

    #[async_trait]
    impl SparqlEngineTrait for MockEngine {
        async fn select(&self, _query: &str) -> GraphRAGResult<Vec<HashMap<String, String>>> {
            Ok(vec![])
        }
        async fn ask(&self, _query: &str) -> GraphRAGResult<bool> {
            Ok(false)
        }
        async fn construct(&self, _query: &str) -> GraphRAGResult<Vec<Triple>> {
            Ok(self.triples.clone())
        }
    }

    /// Builds `n` distinct triples whose subjects and objects look like IRIs,
    /// so they are eligible for frontier expansion.
    fn make_triples(n: usize) -> Vec<Triple> {
        (0..n)
            .map(|i| Triple::new(format!("http://s/{i}"), "http://p", format!("http://o/{i}")))
            .collect()
    }

    // NOTE: a `run` helper that spun up a dedicated Runtime was removed here;
    // it was never called (all async tests use #[tokio::test]) and only
    // produced a dead-code warning.

    #[test]
    fn test_stream_config_defaults() {
        let cfg = StreamConfig::default();
        assert_eq!(cfg.max_triples_per_batch, 500);
        assert_eq!(cfg.timeout_ms, 30_000);
        assert_eq!(cfg.max_depth, 3);
        assert!(cfg.deduplicate);
        assert_eq!(cfg.max_total_triples, 50_000);
    }

    #[test]
    fn test_subgraph_batch_fields() {
        let batch = SubgraphBatch {
            triples: make_triples(5),
            is_final: true,
            batch_id: 2,
            current_depth: 1,
        };
        assert_eq!(batch.triples.len(), 5);
        assert!(batch.is_final);
        assert_eq!(batch.batch_id, 2);
        assert_eq!(batch.current_depth, 1);
    }

    #[test]
    fn test_stream_collect_all() {
        let batches = vec![
            SubgraphBatch {
                triples: make_triples(3),
                is_final: false,
                batch_id: 0,
                current_depth: 0,
            },
            SubgraphBatch {
                triples: make_triples(2),
                is_final: true,
                batch_id: 1,
                current_depth: 1,
            },
        ];
        let stream = SubgraphStream::new(batches);
        let all = stream.collect_all();
        assert_eq!(all.len(), 5);
    }

    #[test]
    fn test_stream_next_batch_exhaustion() {
        let batches = vec![SubgraphBatch {
            triples: make_triples(1),
            is_final: true,
            batch_id: 0,
            current_depth: 0,
        }];
        let mut stream = SubgraphStream::new(batches);
        assert!(stream.next_batch().is_some());
        assert!(stream.next_batch().is_none());
    }

    #[test]
    fn test_stream_batch_count() {
        let batches = (0..5)
            .map(|i| SubgraphBatch {
                triples: make_triples(1),
                is_final: i == 4,
                batch_id: i,
                current_depth: 0,
            })
            .collect();
        let stream = SubgraphStream::new(batches);
        assert_eq!(stream.batch_count(), 5);
    }

    #[tokio::test]
    async fn test_bfs_depth0_basic() {
        let triples = make_triples(10);
        let engine: Arc<MockEngine> = MockEngine::new(triples);
        let config = StreamConfig {
            max_triples_per_batch: 100,
            max_depth: 0,
            deduplicate: false,
            max_total_triples: 0,
            ..Default::default()
        };
        let batches = run_bfs_expansion(engine, "CONSTRUCT {}", &config)
            .await
            .expect("should succeed");
        let total: usize = batches.iter().map(|b| b.triples.len()).sum();
        assert_eq!(total, 10);
        assert!(batches.last().expect("should succeed").is_final);
    }

    #[tokio::test]
    async fn test_bfs_max_total_cap() {
        let engine = MockEngine::new(make_triples(200));
        let config = StreamConfig {
            max_triples_per_batch: 100,
            max_depth: 0,
            deduplicate: false,
            max_total_triples: 50,
            ..Default::default()
        };
        let batches = run_bfs_expansion(engine, "CONSTRUCT {}", &config)
            .await
            .expect("should succeed");
        let total: usize = batches.iter().map(|b| b.triples.len()).sum();
        assert!(total <= 50);
    }

    #[tokio::test]
    async fn test_bfs_deduplication() {
        let triple = Triple::new("http://s", "http://p", "http://o");
        let engine = MockEngine::new(vec![triple; 50]);
        let config = StreamConfig {
            max_triples_per_batch: 100,
            max_depth: 0,
            deduplicate: true,
            max_total_triples: 0,
            ..Default::default()
        };
        let batches = run_bfs_expansion(engine, "CONSTRUCT {}", &config)
            .await
            .expect("should succeed");
        let total: usize = batches.iter().map(|b| b.triples.len()).sum();
        assert_eq!(total, 1);
    }

    #[tokio::test]
    async fn test_bfs_no_deduplication_counts_all() {
        let triple = Triple::new("http://s", "http://p", "http://o");
        let engine = MockEngine::new(vec![triple; 20]);
        let config = StreamConfig {
            max_triples_per_batch: 100,
            max_depth: 0,
            deduplicate: false,
            max_total_triples: 0,
            ..Default::default()
        };
        let batches = run_bfs_expansion(engine, "CONSTRUCT {}", &config)
            .await
            .expect("should succeed");
        let total: usize = batches.iter().map(|b| b.triples.len()).sum();
        assert_eq!(total, 20);
    }

    #[test]
    fn test_package_into_batches_splits_correctly() {
        let triples = make_triples(25);
        let mut out: Vec<SubgraphBatch> = Vec::new();
        let (batches_created, delivered) = package_into_batches(triples, 0, 10, 0, 0, &mut out);
        assert_eq!(batches_created, 3);
        assert_eq!(delivered, 25);
        assert_eq!(out.len(), 3);
    }

    #[test]
    fn test_package_into_batches_respects_max_total() {
        let triples = make_triples(100);
        let mut out: Vec<SubgraphBatch> = Vec::new();
        let (_, delivered) = package_into_batches(triples, 0, 50, 30, 0, &mut out);
        assert!(delivered <= 30);
        let total: usize = out.iter().map(|b| b.triples.len()).sum();
        assert!(total <= 30);
    }

    #[test]
    fn test_mark_last_batch_sets_is_final() {
        let mut batches = vec![
            SubgraphBatch {
                triples: vec![],
                is_final: false,
                batch_id: 0,
                current_depth: 0,
            },
            SubgraphBatch {
                triples: vec![],
                is_final: false,
                batch_id: 1,
                current_depth: 0,
            },
        ];
        mark_last_batch(&mut batches);
        assert!(!batches[0].is_final);
        assert!(batches[1].is_final);
    }

    #[test]
    fn test_build_entity_expand_query_contains_uri() {
        let q = build_entity_expand_query("http://example.org/e", 1);
        assert!(q.contains("http://example.org/e"));
        assert!(q.contains("CONSTRUCT"));
    }

    #[test]
    fn test_deduplicate_if_disabled() {
        let triples = vec![
            Triple::new("http://s", "http://p", "http://o"),
            Triple::new("http://s", "http://p", "http://o"),
        ];
        let mut seen = HashSet::new();
        let result = deduplicate_if(triples, false, &mut seen);
        assert_eq!(result.len(), 2);
        // With dedup disabled, the seen set must not be touched.
        assert!(seen.is_empty());
    }

    #[test]
    fn test_deduplicate_if_enabled_removes_dupes() {
        let triples = vec![
            Triple::new("http://s", "http://p", "http://o"),
            Triple::new("http://s", "http://p", "http://o"),
            Triple::new("http://s2", "http://p", "http://o"),
        ];
        let mut seen = HashSet::new();
        let result = deduplicate_if(triples, true, &mut seen);
        assert_eq!(result.len(), 2);
    }

    #[tokio::test]
    async fn test_retriever_with_defaults() {
        let engine = MockEngine::new(make_triples(5));
        let retriever = StreamingSubgraphRetriever::with_defaults(engine);
        assert_eq!(retriever.config.max_depth, 3);
    }

    #[tokio::test]
    async fn test_bfs_empty_engine_returns_empty() {
        let engine = MockEngine::new(vec![]);
        let config = StreamConfig {
            max_triples_per_batch: 10,
            max_depth: 0,
            deduplicate: false,
            max_total_triples: 0,
            ..Default::default()
        };
        let batches = run_bfs_expansion(engine, "CONSTRUCT {}", &config)
            .await
            .expect("should succeed");
        let total: usize = batches.iter().map(|b| b.triples.len()).sum();
        assert_eq!(total, 0);
    }

    #[tokio::test]
    async fn test_bfs_batch_ids_sequential() {
        let engine = MockEngine::new(make_triples(30));
        let config = StreamConfig {
            max_triples_per_batch: 10,
            max_depth: 0,
            deduplicate: false,
            max_total_triples: 0,
            ..Default::default()
        };
        let batches = run_bfs_expansion(engine, "CONSTRUCT {}", &config)
            .await
            .expect("should succeed");
        for (expected_id, batch) in batches.iter().enumerate() {
            assert_eq!(batch.batch_id, expected_id);
        }
    }

    #[tokio::test]
    async fn test_bfs_only_last_batch_is_final() {
        let engine = MockEngine::new(make_triples(25));
        let config = StreamConfig {
            max_triples_per_batch: 10,
            max_depth: 0,
            deduplicate: false,
            max_total_triples: 0,
            ..Default::default()
        };
        let batches = run_bfs_expansion(engine, "CONSTRUCT {}", &config)
            .await
            .expect("should succeed");
        for (i, batch) in batches.iter().enumerate() {
            if i < batches.len() - 1 {
                assert!(!batch.is_final, "Batch {i} should not be final");
            } else {
                assert!(batch.is_final, "Last batch should be final");
            }
        }
    }

    #[tokio::test]
    async fn test_bfs_depth0_batches_have_depth_zero() {
        let engine = MockEngine::new(make_triples(10));
        let config = StreamConfig {
            max_triples_per_batch: 5,
            max_depth: 0,
            deduplicate: false,
            max_total_triples: 0,
            ..Default::default()
        };
        let batches = run_bfs_expansion(engine, "CONSTRUCT {}", &config)
            .await
            .expect("should succeed");
        for batch in &batches {
            assert_eq!(batch.current_depth, 0);
        }
    }

    #[test]
    fn test_retriever_new_config() {
        let engine = MockEngine::new(vec![]);
        let config = StreamConfig {
            max_triples_per_batch: 42,
            ..Default::default()
        };
        let retriever = StreamingSubgraphRetriever::new(engine, config);
        assert_eq!(retriever.config.max_triples_per_batch, 42);
    }
}