1#![allow(unexpected_cfgs)]
16
17use std::collections::BinaryHeap;
50use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
51use std::sync::Arc;
52
53use crate::context_query::{
54 ContextSection, ContextSelectQuery, SectionContent, SimilarityQuery,
55 TruncationStrategy, OutputFormat, VectorIndex,
56};
57use crate::token_budget::TokenEstimator;
58use crate::soch_ql::SochValue;
59
/// A single unit of streamed context output.
///
/// `StreamingContextIter` yields these incrementally so callers can render
/// or forward context as it is produced instead of waiting for the whole
/// materialized string.
#[derive(Debug, Clone)]
pub enum SectionChunk {
    /// Marks the start of a section; emitted only when
    /// `StreamingConfig::include_headers` is set.
    SectionHeader {
        name: String,
        priority: i32,
        /// Estimated token cost of the rendered header line.
        estimated_tokens: usize,
    },

    /// A batch of tabular rows belonging to `section_name`.
    RowBlock {
        section_name: String,
        rows: Vec<Vec<SochValue>>,
        columns: Vec<String>,
        /// Tokens consumed from the rolling budget for this block.
        tokens: usize,
    },

    /// A batch of vector-search hits belonging to `section_name`.
    SearchResultBlock {
        section_name: String,
        results: Vec<StreamingSearchResult>,
        tokens: usize,
    },

    /// A slice of literal text content belonging to `section_name`.
    ContentBlock {
        section_name: String,
        content: String,
        tokens: usize,
    },

    /// Marks the end of one section's chunks.
    SectionComplete {
        name: String,
        /// Tokens this section consumed from the budget.
        total_tokens: usize,
        /// True when the section was cut short by budget exhaustion.
        truncated: bool,
    },

    /// Terminal chunk: always the last item the iterator yields.
    StreamComplete {
        total_tokens: usize,
        sections_included: Vec<String>,
        sections_dropped: Vec<String>,
    },

    /// A section-level or stream-level failure.
    Error {
        /// `None` when the error is not attributable to a single section.
        section_name: Option<String>,
        message: String,
    },
}
116
/// One vector-similarity hit in streaming form (id, score, raw content).
#[derive(Debug, Clone)]
pub struct StreamingSearchResult {
    pub id: String,
    /// Similarity score as reported by the backing `VectorIndex`.
    pub score: f32,
    pub content: String,
}
124
/// Tunables for streaming context execution.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Hard cap on tokens the whole stream may consume.
    pub token_limit: usize,

    /// Target chunk size, in tokens, for slicing literal content.
    pub chunk_size: usize,

    /// Emit a `SectionChunk::SectionHeader` before each section.
    pub include_headers: bool,

    /// Output format used when chunks are materialized to text.
    pub format: OutputFormat,

    /// How sections are cut when the budget runs out.
    pub truncation: TruncationStrategy,

    /// NOTE(review): not read by the executor in this file — confirm
    /// whether parallel execution is implemented elsewhere or pending.
    pub parallel_execution: bool,

    /// NOTE(review): not read by the executor in this file — presumably
    /// switches `TokenEstimator` to exact counting; confirm.
    pub exact_tokens: bool,
}
149
impl Default for StreamingConfig {
    /// Defaults sized for a typical 4k-token context window: ~256-token
    /// chunks, headers on, Soch format, tail-drop truncation, sequential
    /// execution, estimated (not exact) token counts.
    fn default() -> Self {
        Self {
            token_limit: 4096,
            chunk_size: 256,
            include_headers: true,
            format: OutputFormat::Soch,
            truncation: TruncationStrategy::TailDrop,
            parallel_execution: false,
            exact_tokens: false,
        }
    }
}
163
/// Thread-safe token budget shared by every section of a stream.
///
/// `try_consume` grants tokens until `limit` is reached; the final grant may
/// be partial. Once the limit is hit the budget latches into an exhausted
/// state and every further request is refused.
#[derive(Debug)]
pub struct RollingBudget {
    /// Maximum number of tokens that may ever be granted.
    limit: usize,

    /// Tokens granted so far.
    used: AtomicUsize,

    /// Latched to true once `used` reaches `limit`.
    exhausted: AtomicBool,
}

impl RollingBudget {
    /// Creates a fresh budget with nothing consumed.
    pub fn new(limit: usize) -> Self {
        Self {
            limit,
            used: AtomicUsize::new(0),
            exhausted: AtomicBool::new(false),
        }
    }

    /// Attempts to consume `tokens`, returning the number actually granted
    /// (possibly fewer than requested; 0 once the budget is exhausted).
    pub fn try_consume(&self, tokens: usize) -> usize {
        if self.exhausted.load(Ordering::Acquire) {
            return 0;
        }

        // fetch_update runs the same compare-exchange loop the original
        // wrote by hand; `granted` captures the partial grant computed on
        // the attempt that finally succeeded.
        let mut granted = 0;
        let outcome = self
            .used
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                let headroom = self.limit.saturating_sub(current);
                if headroom == 0 {
                    None
                } else {
                    granted = tokens.min(headroom);
                    Some(current + granted)
                }
            });

        match outcome {
            Ok(previous) => {
                // Latch the flag as soon as the limit is reached so later
                // callers can bail out without touching the counter.
                if previous + granted >= self.limit {
                    self.exhausted.store(true, Ordering::Release);
                }
                granted
            }
            Err(_) => {
                // No headroom left at all: latch and refuse, exactly like
                // the original loop's `remaining == 0` branch.
                self.exhausted.store(true, Ordering::Release);
                0
            }
        }
    }

    /// Tokens still available before the limit.
    pub fn remaining(&self) -> usize {
        let spent = self.used.load(Ordering::Acquire);
        self.limit.saturating_sub(spent)
    }

    /// True once the budget has been fully consumed.
    pub fn is_exhausted(&self) -> bool {
        self.exhausted.load(Ordering::Acquire)
    }

    /// Tokens consumed so far.
    pub fn used(&self) -> usize {
        self.used.load(Ordering::Acquire)
    }
}
239
/// A section queued for execution in the scheduler's heap.
#[derive(Debug, Clone)]
struct ScheduledSection {
    /// Section priority; LOWER values are served first (see `Ord` impl).
    priority: i32,

    /// Declaration order within the query; breaks priority ties
    /// (earlier sections win).
    index: usize,

    /// The section definition itself; not part of the heap ordering.
    section: ContextSection,
}
256
257impl Eq for ScheduledSection {}
258
259impl PartialEq for ScheduledSection {
260 fn eq(&self, other: &Self) -> bool {
261 self.priority == other.priority && self.index == other.index
262 }
263}
264
265impl Ord for ScheduledSection {
266 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
267 other.priority.cmp(&self.priority)
269 .then_with(|| other.index.cmp(&self.index))
270 }
271}
272
273impl PartialOrd for ScheduledSection {
274 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
275 Some(self.cmp(other))
276 }
277}
278
/// Executes a `ContextSelectQuery` as a stream of `SectionChunk`s under a
/// shared rolling token budget.
///
/// NOTE(review): the budget is created once per executor and never reset,
/// so a second `execute_streaming` call on the same executor starts with
/// whatever budget the first stream left over — confirm this is intended.
pub struct StreamingContextExecutor<V: VectorIndex> {
    /// Token-count estimator used for headers and content.
    estimator: TokenEstimator,

    /// Backing vector index used by `Search` sections.
    vector_index: Arc<V>,

    /// Budget shared by all sections; sized from `config.token_limit`.
    budget: Arc<RollingBudget>,

    /// Streaming tunables (chunk size, headers, format, ...).
    config: StreamingConfig,
}
297
298impl<V: VectorIndex> StreamingContextExecutor<V> {
299 pub fn new(
301 vector_index: Arc<V>,
302 config: StreamingConfig,
303 ) -> Self {
304 let budget = Arc::new(RollingBudget::new(config.token_limit));
305 Self {
306 estimator: TokenEstimator::new(),
307 vector_index,
308 budget,
309 config,
310 }
311 }
312
313 pub fn execute_streaming(
318 &self,
319 query: &ContextSelectQuery,
320 ) -> StreamingContextIter<'_, V> {
321 let mut priority_queue = BinaryHeap::new();
323 for (index, section) in query.sections.iter().enumerate() {
324 priority_queue.push(ScheduledSection {
325 priority: section.priority,
326 index,
327 section: section.clone(),
328 });
329 }
330
331 StreamingContextIter {
332 executor: self,
333 priority_queue,
334 current_section: None,
335 current_section_tokens: 0,
336 sections_included: Vec::new(),
337 sections_dropped: Vec::new(),
338 completed: false,
339 }
340 }
341
342 fn execute_section(
344 &self,
345 section: &ContextSection,
346 ) -> Vec<SectionChunk> {
347 let mut chunks = Vec::new();
348
349 if self.config.include_headers {
351 let header_tokens = self.estimator.estimate_text(&format!(
352 "## {} [priority={}]\n",
353 section.name, section.priority
354 ));
355
356 if self.budget.try_consume(header_tokens) > 0 {
357 chunks.push(SectionChunk::SectionHeader {
358 name: section.name.clone(),
359 priority: section.priority,
360 estimated_tokens: header_tokens,
361 });
362 } else {
363 return chunks; }
365 }
366
367 match §ion.content {
369 SectionContent::Literal { value } => {
370 self.execute_literal_section(section, value, &mut chunks);
371 }
372 SectionContent::Search { collection, query, top_k, min_score } => {
373 self.execute_search_section(section, collection, query, *top_k, *min_score, &mut chunks);
374 }
375 SectionContent::Get { path } => {
376 let content = format!("{}:**", path.to_path_string());
378 self.execute_literal_section(section, &content, &mut chunks);
379 }
380 SectionContent::Last { count, table, where_clause: _ } => {
381 let content = format!("{}[{}]:\n (recent entries)", table, count);
383 self.execute_literal_section(section, &content, &mut chunks);
384 }
385 SectionContent::Select { columns, table, where_clause: _, limit } => {
386 let content = format!(
388 "{}[{}]{{{}}}:\n (query results)",
389 table,
390 limit.unwrap_or(10),
391 columns.join(",")
392 );
393 self.execute_literal_section(section, &content, &mut chunks);
394 }
395 SectionContent::Variable { name } => {
396 let content = format!("${}", name);
397 self.execute_literal_section(section, &content, &mut chunks);
398 }
399 SectionContent::ToolRegistry { include, exclude: _, include_schema } => {
400 let content = if include.is_empty() {
401 format!("tools[*]{{schema={}}}", include_schema)
402 } else {
403 format!("tools[{}]{{schema={}}}", include.join(","), include_schema)
404 };
405 self.execute_literal_section(section, &content, &mut chunks);
406 }
407 SectionContent::ToolCalls { count, tool_filter, status_filter: _, include_outputs } => {
408 let filter_str = tool_filter.as_deref().unwrap_or("*");
409 let content = format!(
410 "tool_calls[{}]{{tool={},outputs={}}}",
411 count, filter_str, include_outputs
412 );
413 self.execute_literal_section(section, &content, &mut chunks);
414 }
415 }
416
417 chunks
418 }
419
420 fn execute_literal_section(
422 &self,
423 section: &ContextSection,
424 content: &str,
425 chunks: &mut Vec<SectionChunk>,
426 ) {
427 let _total_tokens = self.estimator.estimate_text(content);
429 let mut consumed = 0;
430 let mut offset = 0;
431 let content_bytes = content.as_bytes();
432
433 while offset < content_bytes.len() && !self.budget.is_exhausted() {
434 let approx_bytes = (self.config.chunk_size as f32 * 4.0) as usize;
436 let end = (offset + approx_bytes).min(content_bytes.len());
437
438 let break_point = if end < content_bytes.len() {
440 content[offset..end]
441 .rfind('\n')
442 .or_else(|| content[offset..end].rfind(' '))
443 .map(|p| offset + p + 1)
444 .unwrap_or(end)
445 } else {
446 end
447 };
448
449 let chunk_content = &content[offset..break_point];
450 let chunk_tokens = self.estimator.estimate_text(chunk_content);
451
452 let actual = self.budget.try_consume(chunk_tokens);
453 if actual == 0 {
454 break;
455 }
456
457 consumed += actual;
458 chunks.push(SectionChunk::ContentBlock {
459 section_name: section.name.clone(),
460 content: chunk_content.to_string(),
461 tokens: actual,
462 });
463
464 offset = break_point;
465 }
466
467 chunks.push(SectionChunk::SectionComplete {
469 name: section.name.clone(),
470 total_tokens: consumed,
471 truncated: offset < content_bytes.len(),
472 });
473 }
474
475 fn execute_search_section(
477 &self,
478 section: &ContextSection,
479 collection: &str,
480 query: &SimilarityQuery,
481 top_k: usize,
482 min_score: Option<f32>,
483 chunks: &mut Vec<SectionChunk>,
484 ) {
485 let results = match query {
487 SimilarityQuery::Embedding(embedding) => {
488 self.vector_index.search_by_embedding(collection, embedding, top_k, min_score)
489 }
490 SimilarityQuery::Text(text) => {
491 self.vector_index.search_by_text(collection, text, top_k, min_score)
492 }
493 SimilarityQuery::Variable(_) => {
494 Ok(Vec::new())
496 }
497 };
498
499 match results {
500 Ok(results) => {
501 let mut section_tokens = 0;
502 let mut batch = Vec::new();
503
504 for result in results {
505 if self.budget.is_exhausted() {
506 break;
507 }
508
509 let result_content = format!(
510 "[{:.3}] {}: {}\n",
511 result.score, result.id, result.content
512 );
513 let tokens = self.estimator.estimate_text(&result_content);
514
515 let actual = self.budget.try_consume(tokens);
516 if actual == 0 {
517 break;
518 }
519
520 section_tokens += actual;
521 batch.push(StreamingSearchResult {
522 id: result.id,
523 score: result.score,
524 content: result.content,
525 });
526
527 if batch.len() >= 5 {
529 chunks.push(SectionChunk::SearchResultBlock {
530 section_name: section.name.clone(),
531 results: std::mem::take(&mut batch),
532 tokens: section_tokens,
533 });
534 section_tokens = 0;
535 }
536 }
537
538 if !batch.is_empty() {
540 chunks.push(SectionChunk::SearchResultBlock {
541 section_name: section.name.clone(),
542 results: batch,
543 tokens: section_tokens,
544 });
545 }
546
547 chunks.push(SectionChunk::SectionComplete {
548 name: section.name.clone(),
549 total_tokens: section_tokens,
550 truncated: self.budget.is_exhausted(),
551 });
552 }
553 Err(e) => {
554 chunks.push(SectionChunk::Error {
555 section_name: Some(section.name.clone()),
556 message: e,
557 });
558 }
559 }
560 }
561}
562
/// Lazy iterator over a query's chunks; each `next()` call drives the
/// executor just far enough to produce one `SectionChunk`.
pub struct StreamingContextIter<'a, V: VectorIndex> {
    executor: &'a StreamingContextExecutor<V>,
    /// Sections not yet executed, popped in priority order.
    priority_queue: BinaryHeap<ScheduledSection>,
    /// Section currently being replayed: (scheduled entry, its pre-computed
    /// chunks, cursor into that chunk list).
    current_section: Option<(ScheduledSection, Vec<SectionChunk>, usize)>,
    /// Dead field — initialized by `execute_streaming` but never read.
    #[allow(dead_code)]
    current_section_tokens: usize,
    /// Names of sections whose `SectionComplete` has been emitted.
    sections_included: Vec<String>,
    /// Names of sections dropped (produced no chunks, or budget ran out).
    sections_dropped: Vec<String>,
    /// Set once the terminal `StreamComplete` has been yielded.
    completed: bool,
}
578
579impl<'a, V: VectorIndex> Iterator for StreamingContextIter<'a, V> {
580 type Item = SectionChunk;
581
582 fn next(&mut self) -> Option<Self::Item> {
583 if self.completed {
584 return None;
585 }
586
587 if self.executor.budget.is_exhausted() && self.current_section.is_none() {
589 while let Some(scheduled) = self.priority_queue.pop() {
591 self.sections_dropped.push(scheduled.section.name.clone());
592 }
593
594 self.completed = true;
595 return Some(SectionChunk::StreamComplete {
596 total_tokens: self.executor.budget.used(),
597 sections_included: std::mem::take(&mut self.sections_included),
598 sections_dropped: std::mem::take(&mut self.sections_dropped),
599 });
600 }
601
602 if let Some((_section, chunks, index)) = &mut self.current_section {
604 if *index < chunks.len() {
605 let chunk = chunks[*index].clone();
606 *index += 1;
607
608 if let SectionChunk::SectionComplete { name, .. } = &chunk {
610 self.sections_included.push(name.clone());
611 self.current_section = None;
612 }
613
614 return Some(chunk);
615 }
616 self.current_section = None;
617 }
618
619 if let Some(scheduled) = self.priority_queue.pop() {
621 let chunks = self.executor.execute_section(&scheduled.section);
622 if !chunks.is_empty() {
623 let first_chunk = chunks[0].clone();
624 self.current_section = Some((scheduled, chunks, 1));
625 return Some(first_chunk);
626 }
627 self.sections_dropped.push(scheduled.section.name.clone());
629 return self.next();
630 }
631
632 self.completed = true;
634 Some(SectionChunk::StreamComplete {
635 total_tokens: self.executor.budget.used(),
636 sections_included: std::mem::take(&mut self.sections_included),
637 sections_dropped: std::mem::take(&mut self.sections_dropped),
638 })
639 }
640}
641
#[cfg(feature = "async")]
pub mod async_stream {
    use super::*;
    use futures::Stream;
    // BUG FIX: `Pin`, `Context`, and `Poll` were used below without being
    // imported, so this module failed to compile whenever the `async`
    // feature was enabled.
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Adapter exposing a `StreamingContextIter` as a `futures::Stream`.
    ///
    /// The underlying iterator is synchronous, so `poll_next` never returns
    /// `Poll::Pending`; each poll produces the next chunk immediately.
    pub struct AsyncStreamingContext<V: VectorIndex> {
        iter: StreamingContextIter<'static, V>,
    }

    impl<V: VectorIndex> Stream for AsyncStreamingContext<V> {
        type Item = SectionChunk;

        fn poll_next(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            Poll::Ready(self.iter.next())
        }
    }
}
667
668pub fn create_streaming_executor<V: VectorIndex>(
674 vector_index: Arc<V>,
675 token_limit: usize,
676) -> StreamingContextExecutor<V> {
677 let config = StreamingConfig {
678 token_limit,
679 ..Default::default()
680 };
681 StreamingContextExecutor::new(vector_index, config)
682}
683
684pub fn collect_streaming_chunks<V: VectorIndex>(
686 executor: &StreamingContextExecutor<V>,
687 query: &ContextSelectQuery,
688) -> Vec<SectionChunk> {
689 executor.execute_streaming(query).collect()
690}
691
692pub fn materialize_context(chunks: &[SectionChunk], format: OutputFormat) -> String {
694 let mut output = String::new();
695
696 for chunk in chunks {
697 match chunk {
698 SectionChunk::SectionHeader { name, priority, .. } => {
699 match format {
700 OutputFormat::Soch => {
701 output.push_str(&format!("# {} [p={}]\n", name, priority));
702 }
703 OutputFormat::Markdown => {
704 output.push_str(&format!("## {}\n\n", name));
705 }
706 OutputFormat::Json => {
707 }
709 }
710 }
711 SectionChunk::ContentBlock { content, .. } => {
712 output.push_str(content);
713 }
714 SectionChunk::RowBlock { columns, rows, .. } => {
715 output.push_str(&format!("{{{}}}:\n", columns.join(",")));
717 for row in rows {
718 let values: Vec<String> = row.iter().map(|v| format!("{:?}", v)).collect();
719 output.push_str(&format!(" {}\n", values.join(",")));
720 }
721 }
722 SectionChunk::SearchResultBlock { results, .. } => {
723 for result in results {
724 output.push_str(&format!(
725 "[{:.3}] {}: {}\n",
726 result.score, result.id, result.content
727 ));
728 }
729 }
730 SectionChunk::SectionComplete { .. } => {
731 output.push('\n');
732 }
733 SectionChunk::StreamComplete { .. } => {
734 }
736 SectionChunk::Error { section_name, message } => {
737 let section = section_name.as_deref().unwrap_or("unknown");
738 output.push_str(&format!("# Error in {}: {}\n", section, message));
739 }
740 }
741 }
742
743 output
744}
745
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context_query::{
        ContextQueryOptions, SessionReference, PathExpression,
        VectorSearchResult, VectorIndexStats,
    };
    use std::collections::HashMap;

    /// Test double that returns a fixed result list regardless of query.
    struct MockVectorIndex {
        results: Vec<VectorSearchResult>,
    }

    impl VectorIndex for MockVectorIndex {
        fn search_by_embedding(
            &self,
            _collection: &str,
            _embedding: &[f32],
            k: usize,
            _min_score: Option<f32>,
        ) -> Result<Vec<VectorSearchResult>, String> {
            Ok(self.results.iter().take(k).cloned().collect())
        }

        fn search_by_text(
            &self,
            _collection: &str,
            _text: &str,
            k: usize,
            _min_score: Option<f32>,
        ) -> Result<Vec<VectorSearchResult>, String> {
            Ok(self.results.iter().take(k).cloned().collect())
        }

        fn stats(&self, _collection: &str) -> Option<VectorIndexStats> {
            Some(VectorIndexStats {
                vector_count: self.results.len(),
                dimension: 128,
                metric: "cosine".to_string(),
            })
        }
    }

    /// Consume/remaining/exhaustion semantics, including the partial grant
    /// at the limit boundary.
    #[test]
    fn test_rolling_budget() {
        let budget = RollingBudget::new(100);

        assert_eq!(budget.try_consume(30), 30);
        assert_eq!(budget.remaining(), 70);

        assert_eq!(budget.try_consume(50), 50);
        assert_eq!(budget.remaining(), 20);

        // Only 20 left: grant is partial and the budget latches exhausted.
        assert_eq!(budget.try_consume(30), 20);
        assert!(budget.is_exhausted());

        assert_eq!(budget.try_consume(10), 0);
    }

    /// A single literal section streams at least header + content +
    /// SectionComplete + StreamComplete and records the section as included.
    #[test]
    fn test_streaming_context_basic() {
        let mock_index = Arc::new(MockVectorIndex {
            results: vec![
                VectorSearchResult {
                    id: "doc1".to_string(),
                    score: 0.95,
                    content: "First document".to_string(),
                    metadata: HashMap::new(),
                },
                VectorSearchResult {
                    id: "doc2".to_string(),
                    score: 0.85,
                    content: "Second document".to_string(),
                    metadata: HashMap::new(),
                },
            ],
        });

        let executor = StreamingContextExecutor::new(
            mock_index,
            StreamingConfig {
                token_limit: 1000,
                ..Default::default()
            },
        );

        let query = ContextSelectQuery {
            output_name: "test".to_string(),
            session: SessionReference::None,
            options: ContextQueryOptions::default(),
            sections: vec![
                ContextSection {
                    name: "INTRO".to_string(),
                    priority: 0,
                    content: SectionContent::Literal {
                        value: "Welcome to the test context.".to_string(),
                    },
                    transform: None,
                },
            ],
        };

        let chunks: Vec<_> = executor.execute_streaming(&query).collect();

        // Header + at least one content block + SectionComplete + StreamComplete.
        assert!(chunks.len() >= 3);

        if let Some(SectionChunk::StreamComplete { sections_included, .. }) = chunks.last() {
            assert!(sections_included.contains(&"INTRO".to_string()));
        } else {
            panic!("Expected StreamComplete as last chunk");
        }
    }

    /// Sections are emitted in ascending priority-value order regardless of
    /// declaration order.
    #[test]
    fn test_priority_ordering() {
        let mock_index = Arc::new(MockVectorIndex { results: vec![] });

        let executor = StreamingContextExecutor::new(
            mock_index,
            StreamingConfig {
                token_limit: 10000,
                ..Default::default()
            },
        );

        let query = ContextSelectQuery {
            output_name: "test".to_string(),
            session: SessionReference::None,
            options: ContextQueryOptions::default(),
            sections: vec![
                ContextSection {
                    name: "LOW_PRIORITY".to_string(),
                    priority: 10,
                    content: SectionContent::Literal {
                        value: "Low priority content".to_string(),
                    },
                    transform: None,
                },
                ContextSection {
                    name: "HIGH_PRIORITY".to_string(),
                    priority: 0,
                    content: SectionContent::Literal {
                        value: "High priority content".to_string(),
                    },
                    transform: None,
                },
                ContextSection {
                    name: "MID_PRIORITY".to_string(),
                    priority: 5,
                    content: SectionContent::Literal {
                        value: "Mid priority content".to_string(),
                    },
                    transform: None,
                },
            ],
        };

        let chunks: Vec<_> = executor.execute_streaming(&query).collect();

        let headers: Vec<_> = chunks.iter()
            .filter_map(|c| match c {
                SectionChunk::SectionHeader { name, .. } => Some(name.clone()),
                _ => None,
            })
            .collect();

        assert_eq!(headers, vec!["HIGH_PRIORITY", "MID_PRIORITY", "LOW_PRIORITY"]);
    }

    /// With a tiny budget the stream still terminates cleanly and never
    /// reports more tokens than the limit allows.
    #[test]
    fn test_budget_exhaustion() {
        let mock_index = Arc::new(MockVectorIndex { results: vec![] });

        let executor = StreamingContextExecutor::new(
            mock_index,
            StreamingConfig {
                token_limit: 50,
                ..Default::default()
            },
        );

        let query = ContextSelectQuery {
            output_name: "test".to_string(),
            session: SessionReference::None,
            options: ContextQueryOptions::default(),
            sections: vec![
                ContextSection {
                    name: "FIRST".to_string(),
                    priority: 0,
                    content: SectionContent::Literal {
                        value: "This is a somewhat longer content that will consume budget.".to_string(),
                    },
                    transform: None,
                },
                ContextSection {
                    name: "SECOND".to_string(),
                    priority: 1,
                    content: SectionContent::Literal {
                        value: "This should be dropped.".to_string(),
                    },
                    transform: None,
                },
            ],
        };

        let chunks: Vec<_> = executor.execute_streaming(&query).collect();

        // BUG FIX: the old assertion ended in `|| true`, which made it
        // vacuous, and the `if let` silently passed when the last chunk was
        // not a StreamComplete. Assert what the budget actually guarantees.
        match chunks.last() {
            Some(SectionChunk::StreamComplete { total_tokens, .. }) => {
                assert!(
                    *total_tokens <= 50,
                    "budget overran the 50-token limit: {}",
                    total_tokens
                );
            }
            other => panic!("expected StreamComplete as last chunk, got {:?}", other),
        }
    }
}