1#![allow(unexpected_cfgs)]
19
20use std::collections::BinaryHeap;
53use std::sync::Arc;
54use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
55
56use crate::context_query::{
57 ContextSection, ContextSelectQuery, OutputFormat, SectionContent, SimilarityQuery,
58 TruncationStrategy, VectorIndex,
59};
60use crate::soch_ql::SochValue;
61use crate::token_budget::TokenEstimator;
62
63#[derive(Debug, Clone)]
69pub enum SectionChunk {
70 SectionHeader {
72 name: String,
73 priority: i32,
74 estimated_tokens: usize,
75 },
76
77 RowBlock {
79 section_name: String,
80 rows: Vec<Vec<SochValue>>,
81 columns: Vec<String>,
82 tokens: usize,
83 },
84
85 SearchResultBlock {
87 section_name: String,
88 results: Vec<StreamingSearchResult>,
89 tokens: usize,
90 },
91
92 ContentBlock {
94 section_name: String,
95 content: String,
96 tokens: usize,
97 },
98
99 SectionComplete {
101 name: String,
102 total_tokens: usize,
103 truncated: bool,
104 },
105
106 StreamComplete {
108 total_tokens: usize,
109 sections_included: Vec<String>,
110 sections_dropped: Vec<String>,
111 },
112
113 Error {
115 section_name: Option<String>,
116 message: String,
117 },
118}
119
/// A single vector-search hit as carried inside a
/// `SectionChunk::SearchResultBlock`.
#[derive(Debug, Clone)]
pub struct StreamingSearchResult {
    /// Identifier of the matched item.
    pub id: String,
    /// Similarity score reported by the vector index.
    pub score: f32,
    /// Text content of the matched item.
    pub content: String,
}
127
128#[derive(Debug, Clone)]
130pub struct StreamingConfig {
131 pub token_limit: usize,
133
134 pub chunk_size: usize,
136
137 pub include_headers: bool,
139
140 pub format: OutputFormat,
142
143 pub truncation: TruncationStrategy,
145
146 pub parallel_execution: bool,
148
149 pub exact_tokens: bool,
151}
152
153impl Default for StreamingConfig {
154 fn default() -> Self {
155 Self {
156 token_limit: 4096,
157 chunk_size: 256,
158 include_headers: true,
159 format: OutputFormat::Soch,
160 truncation: TruncationStrategy::TailDrop,
161 parallel_execution: false,
162 exact_tokens: false,
163 }
164 }
165}
166
/// A token budget that is drawn down incrementally and can be shared between
/// threads through `&self`.
///
/// The amount used is tracked in an atomic counter; a separate atomic flag
/// latches once the limit has been reached so later callers can bail out with
/// a single load.
#[derive(Debug)]
pub struct RollingBudget {
    /// Maximum number of tokens that may ever be granted.
    limit: usize,

    /// Number of tokens granted so far; never exceeds `limit`.
    used: AtomicUsize,

    /// Latched to `true` once nothing is left to grant.
    exhausted: AtomicBool,
}

impl RollingBudget {
    /// Creates a budget that can grant at most `limit` tokens.
    ///
    /// A budget with a limit of zero has nothing to grant and is therefore
    /// reported as exhausted from the start.
    pub fn new(limit: usize) -> Self {
        Self {
            limit,
            used: AtomicUsize::new(0),
            exhausted: AtomicBool::new(limit == 0),
        }
    }

    /// Tries to take `tokens` from the budget and returns how many were
    /// actually granted.
    ///
    /// The grant is clamped to whatever remains, so the result may be smaller
    /// than `tokens`. It is `0` when the budget is already exhausted, and also
    /// when `tokens` itself is `0`.
    pub fn try_consume(&self, tokens: usize) -> usize {
        if self.exhausted.load(Ordering::Acquire) {
            return 0;
        }

        let limit = self.limit;
        // `fetch_update` retries the closure until the compare-exchange wins,
        // handing it the freshest `used` value on every attempt.
        let outcome = self
            .used
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                let remaining = limit.saturating_sub(current);
                (remaining > 0).then(|| current + tokens.min(remaining))
            });

        match outcome {
            Ok(previous) => {
                // Recompute the grant from the value the update was applied to.
                let granted = tokens.min(limit.saturating_sub(previous));
                if previous + granted >= limit {
                    self.exhausted.store(true, Ordering::Release);
                }
                granted
            }
            Err(_) => {
                // Another consumer drained the budget first.
                self.exhausted.store(true, Ordering::Release);
                0
            }
        }
    }

    /// Returns how many tokens can still be granted.
    pub fn remaining(&self) -> usize {
        self.limit.saturating_sub(self.used.load(Ordering::Acquire))
    }

    /// Returns `true` once the budget has nothing left to grant.
    pub fn is_exhausted(&self) -> bool {
        self.exhausted.load(Ordering::Acquire)
    }

    /// Returns how many tokens have been granted so far.
    pub fn used(&self) -> usize {
        self.used.load(Ordering::Acquire)
    }
}
242
243#[derive(Debug, Clone)]
249struct ScheduledSection {
250 priority: i32,
252
253 index: usize,
255
256 section: ContextSection,
258}
259
260impl Eq for ScheduledSection {}
261
262impl PartialEq for ScheduledSection {
263 fn eq(&self, other: &Self) -> bool {
264 self.priority == other.priority && self.index == other.index
265 }
266}
267
268impl Ord for ScheduledSection {
269 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
270 other
272 .priority
273 .cmp(&self.priority)
274 .then_with(|| other.index.cmp(&self.index))
275 }
276}
277
278impl PartialOrd for ScheduledSection {
279 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
280 Some(self.cmp(other))
281 }
282}
283
284pub struct StreamingContextExecutor<V: VectorIndex> {
290 estimator: TokenEstimator,
292
293 vector_index: Arc<V>,
295
296 budget: Arc<RollingBudget>,
298
299 config: StreamingConfig,
301}
302
303impl<V: VectorIndex> StreamingContextExecutor<V> {
304 pub fn new(vector_index: Arc<V>, config: StreamingConfig) -> Self {
306 let budget = Arc::new(RollingBudget::new(config.token_limit));
307 Self {
308 estimator: TokenEstimator::new(),
309 vector_index,
310 budget,
311 config,
312 }
313 }
314
315 pub fn execute_streaming(&self, query: &ContextSelectQuery) -> StreamingContextIter<'_, V> {
320 let mut priority_queue = BinaryHeap::new();
322 for (index, section) in query.sections.iter().enumerate() {
323 priority_queue.push(ScheduledSection {
324 priority: section.priority,
325 index,
326 section: section.clone(),
327 });
328 }
329
330 StreamingContextIter {
331 executor: self,
332 priority_queue,
333 current_section: None,
334 current_section_tokens: 0,
335 sections_included: Vec::new(),
336 sections_dropped: Vec::new(),
337 completed: false,
338 }
339 }
340
341 fn execute_section(&self, section: &ContextSection) -> Vec<SectionChunk> {
343 let mut chunks = Vec::new();
344
345 if self.config.include_headers {
347 let header_tokens = self.estimator.estimate_text(&format!(
348 "## {} [priority={}]\n",
349 section.name, section.priority
350 ));
351
352 if self.budget.try_consume(header_tokens) > 0 {
353 chunks.push(SectionChunk::SectionHeader {
354 name: section.name.clone(),
355 priority: section.priority,
356 estimated_tokens: header_tokens,
357 });
358 } else {
359 return chunks; }
361 }
362
363 match §ion.content {
365 SectionContent::Literal { value } => {
366 self.execute_literal_section(section, value, &mut chunks);
367 }
368 SectionContent::Search {
369 collection,
370 query,
371 top_k,
372 min_score,
373 } => {
374 self.execute_search_section(
375 section,
376 collection,
377 query,
378 *top_k,
379 *min_score,
380 &mut chunks,
381 );
382 }
383 SectionContent::Get { path } => {
384 let content = format!("{}:**", path.to_path_string());
386 self.execute_literal_section(section, &content, &mut chunks);
387 }
388 SectionContent::Last {
389 count,
390 table,
391 where_clause: _,
392 } => {
393 let content = format!("{}[{}]:\n (recent entries)", table, count);
395 self.execute_literal_section(section, &content, &mut chunks);
396 }
397 SectionContent::Select {
398 columns,
399 table,
400 where_clause: _,
401 limit,
402 } => {
403 let content = format!(
405 "{}[{}]{{{}}}:\n (query results)",
406 table,
407 limit.unwrap_or(10),
408 columns.join(",")
409 );
410 self.execute_literal_section(section, &content, &mut chunks);
411 }
412 SectionContent::Variable { name } => {
413 let content = format!("${}", name);
414 self.execute_literal_section(section, &content, &mut chunks);
415 }
416 SectionContent::ToolRegistry {
417 include,
418 exclude: _,
419 include_schema,
420 } => {
421 let content = if include.is_empty() {
422 format!("tools[*]{{schema={}}}", include_schema)
423 } else {
424 format!("tools[{}]{{schema={}}}", include.join(","), include_schema)
425 };
426 self.execute_literal_section(section, &content, &mut chunks);
427 }
428 SectionContent::ToolCalls {
429 count,
430 tool_filter,
431 status_filter: _,
432 include_outputs,
433 } => {
434 let filter_str = tool_filter.as_deref().unwrap_or("*");
435 let content = format!(
436 "tool_calls[{}]{{tool={},outputs={}}}",
437 count, filter_str, include_outputs
438 );
439 self.execute_literal_section(section, &content, &mut chunks);
440 }
441 }
442
443 chunks
444 }
445
446 fn execute_literal_section(
448 &self,
449 section: &ContextSection,
450 content: &str,
451 chunks: &mut Vec<SectionChunk>,
452 ) {
453 let _total_tokens = self.estimator.estimate_text(content);
455 let mut consumed = 0;
456 let mut offset = 0;
457 let content_bytes = content.as_bytes();
458
459 while offset < content_bytes.len() && !self.budget.is_exhausted() {
460 let approx_bytes = (self.config.chunk_size as f32 * 4.0) as usize;
462 let end = (offset + approx_bytes).min(content_bytes.len());
463
464 let break_point = if end < content_bytes.len() {
466 content[offset..end]
467 .rfind('\n')
468 .or_else(|| content[offset..end].rfind(' '))
469 .map(|p| offset + p + 1)
470 .unwrap_or(end)
471 } else {
472 end
473 };
474
475 let chunk_content = &content[offset..break_point];
476 let chunk_tokens = self.estimator.estimate_text(chunk_content);
477
478 let actual = self.budget.try_consume(chunk_tokens);
479 if actual == 0 {
480 break;
481 }
482
483 consumed += actual;
484 chunks.push(SectionChunk::ContentBlock {
485 section_name: section.name.clone(),
486 content: chunk_content.to_string(),
487 tokens: actual,
488 });
489
490 offset = break_point;
491 }
492
493 chunks.push(SectionChunk::SectionComplete {
495 name: section.name.clone(),
496 total_tokens: consumed,
497 truncated: offset < content_bytes.len(),
498 });
499 }
500
501 fn execute_search_section(
503 &self,
504 section: &ContextSection,
505 collection: &str,
506 query: &SimilarityQuery,
507 top_k: usize,
508 min_score: Option<f32>,
509 chunks: &mut Vec<SectionChunk>,
510 ) {
511 let results = match query {
513 SimilarityQuery::Embedding(embedding) => self
514 .vector_index
515 .search_by_embedding(collection, embedding, top_k, min_score),
516 SimilarityQuery::Text(text) => self
517 .vector_index
518 .search_by_text(collection, text, top_k, min_score),
519 SimilarityQuery::Variable(_) => {
520 Ok(Vec::new())
522 }
523 };
524
525 match results {
526 Ok(results) => {
527 let mut section_tokens = 0;
528 let mut batch = Vec::new();
529
530 for result in results {
531 if self.budget.is_exhausted() {
532 break;
533 }
534
535 let result_content =
536 format!("[{:.3}] {}: {}\n", result.score, result.id, result.content);
537 let tokens = self.estimator.estimate_text(&result_content);
538
539 let actual = self.budget.try_consume(tokens);
540 if actual == 0 {
541 break;
542 }
543
544 section_tokens += actual;
545 batch.push(StreamingSearchResult {
546 id: result.id,
547 score: result.score,
548 content: result.content,
549 });
550
551 if batch.len() >= 5 {
553 chunks.push(SectionChunk::SearchResultBlock {
554 section_name: section.name.clone(),
555 results: std::mem::take(&mut batch),
556 tokens: section_tokens,
557 });
558 section_tokens = 0;
559 }
560 }
561
562 if !batch.is_empty() {
564 chunks.push(SectionChunk::SearchResultBlock {
565 section_name: section.name.clone(),
566 results: batch,
567 tokens: section_tokens,
568 });
569 }
570
571 chunks.push(SectionChunk::SectionComplete {
572 name: section.name.clone(),
573 total_tokens: section_tokens,
574 truncated: self.budget.is_exhausted(),
575 });
576 }
577 Err(e) => {
578 chunks.push(SectionChunk::Error {
579 section_name: Some(section.name.clone()),
580 message: e,
581 });
582 }
583 }
584 }
585}
586
587pub struct StreamingContextIter<'a, V: VectorIndex> {
593 executor: &'a StreamingContextExecutor<V>,
594 priority_queue: BinaryHeap<ScheduledSection>,
595 current_section: Option<(ScheduledSection, Vec<SectionChunk>, usize)>,
596 #[allow(dead_code)]
597 current_section_tokens: usize,
598 sections_included: Vec<String>,
599 sections_dropped: Vec<String>,
600 completed: bool,
601}
602
603impl<'a, V: VectorIndex> Iterator for StreamingContextIter<'a, V> {
604 type Item = SectionChunk;
605
606 fn next(&mut self) -> Option<Self::Item> {
607 if self.completed {
608 return None;
609 }
610
611 if self.executor.budget.is_exhausted() && self.current_section.is_none() {
613 while let Some(scheduled) = self.priority_queue.pop() {
615 self.sections_dropped.push(scheduled.section.name.clone());
616 }
617
618 self.completed = true;
619 return Some(SectionChunk::StreamComplete {
620 total_tokens: self.executor.budget.used(),
621 sections_included: std::mem::take(&mut self.sections_included),
622 sections_dropped: std::mem::take(&mut self.sections_dropped),
623 });
624 }
625
626 if let Some((_section, chunks, index)) = &mut self.current_section {
628 if *index < chunks.len() {
629 let chunk = chunks[*index].clone();
630 *index += 1;
631
632 if let SectionChunk::SectionComplete { name, .. } = &chunk {
634 self.sections_included.push(name.clone());
635 self.current_section = None;
636 }
637
638 return Some(chunk);
639 }
640 self.current_section = None;
641 }
642
643 if let Some(scheduled) = self.priority_queue.pop() {
645 let chunks = self.executor.execute_section(&scheduled.section);
646 if !chunks.is_empty() {
647 let first_chunk = chunks[0].clone();
648 self.current_section = Some((scheduled, chunks, 1));
649 return Some(first_chunk);
650 }
651 self.sections_dropped.push(scheduled.section.name.clone());
653 return self.next();
654 }
655
656 self.completed = true;
658 Some(SectionChunk::StreamComplete {
659 total_tokens: self.executor.budget.used(),
660 sections_included: std::mem::take(&mut self.sections_included),
661 sections_dropped: std::mem::take(&mut self.sections_dropped),
662 })
663 }
664}
665
/// `futures::Stream` adapter over the synchronous streaming iterator.
///
/// Only compiled when the `async` feature is enabled.
#[cfg(feature = "async")]
pub mod async_stream {
    use super::*;
    use futures::Stream;
    // Needed by the `poll_next` signature below; the parent module does not
    // import them, so the glob import above cannot supply them.
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Exposes a `StreamingContextIter` as a `Stream` of `SectionChunk`s.
    pub struct AsyncStreamingContext<V: VectorIndex> {
        /// The wrapped iterator; it must borrow its executor for `'static`.
        iter: StreamingContextIter<'static, V>,
    }

    impl<V: VectorIndex> AsyncStreamingContext<V> {
        /// Wraps `iter` so it can be consumed as a `Stream`.
        pub fn new(iter: StreamingContextIter<'static, V>) -> Self {
            Self { iter }
        }
    }

    impl<V: VectorIndex> Stream for AsyncStreamingContext<V> {
        type Item = SectionChunk;

        /// Advances the wrapped iterator by one step.
        ///
        /// The step runs synchronously inside the poll, so the result is
        /// always `Poll::Ready` and the waker in `_cx` is never used.
        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            Poll::Ready(self.iter.next())
        }
    }
}
688
689pub fn create_streaming_executor<V: VectorIndex>(
695 vector_index: Arc<V>,
696 token_limit: usize,
697) -> StreamingContextExecutor<V> {
698 let config = StreamingConfig {
699 token_limit,
700 ..Default::default()
701 };
702 StreamingContextExecutor::new(vector_index, config)
703}
704
705pub fn collect_streaming_chunks<V: VectorIndex>(
707 executor: &StreamingContextExecutor<V>,
708 query: &ContextSelectQuery,
709) -> Vec<SectionChunk> {
710 executor.execute_streaming(query).collect()
711}
712
713pub fn materialize_context(chunks: &[SectionChunk], format: OutputFormat) -> String {
715 let mut output = String::new();
716
717 for chunk in chunks {
718 match chunk {
719 SectionChunk::SectionHeader { name, priority, .. } => {
720 match format {
721 OutputFormat::Soch => {
722 output.push_str(&format!("# {} [p={}]\n", name, priority));
723 }
724 OutputFormat::Markdown => {
725 output.push_str(&format!("## {}\n\n", name));
726 }
727 OutputFormat::Json => {
728 }
730 }
731 }
732 SectionChunk::ContentBlock { content, .. } => {
733 output.push_str(content);
734 }
735 SectionChunk::RowBlock { columns, rows, .. } => {
736 output.push_str(&format!("{{{}}}:\n", columns.join(",")));
738 for row in rows {
739 let values: Vec<String> = row.iter().map(|v| format!("{:?}", v)).collect();
740 output.push_str(&format!(" {}\n", values.join(",")));
741 }
742 }
743 SectionChunk::SearchResultBlock { results, .. } => {
744 for result in results {
745 output.push_str(&format!(
746 "[{:.3}] {}: {}\n",
747 result.score, result.id, result.content
748 ));
749 }
750 }
751 SectionChunk::SectionComplete { .. } => {
752 output.push('\n');
753 }
754 SectionChunk::StreamComplete { .. } => {
755 }
757 SectionChunk::Error {
758 section_name,
759 message,
760 } => {
761 let section = section_name.as_deref().unwrap_or("unknown");
762 output.push_str(&format!("# Error in {}: {}\n", section, message));
763 }
764 }
765 }
766
767 output
768}
769
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context_query::{
        ContextQueryOptions, PathExpression, SessionReference, VectorIndexStats, VectorSearchResult,
    };
    use std::collections::HashMap;

    /// Vector index stub that answers every search with the first `k` of its
    /// canned results.
    struct MockVectorIndex {
        results: Vec<VectorSearchResult>,
    }

    impl MockVectorIndex {
        /// Shared answer for both search flavours.
        fn first_k(&self, k: usize) -> Vec<VectorSearchResult> {
            self.results.iter().take(k).cloned().collect()
        }
    }

    impl VectorIndex for MockVectorIndex {
        fn search_by_embedding(
            &self,
            _collection: &str,
            _embedding: &[f32],
            k: usize,
            _min_score: Option<f32>,
        ) -> Result<Vec<VectorSearchResult>, String> {
            Ok(self.first_k(k))
        }

        fn search_by_text(
            &self,
            _collection: &str,
            _text: &str,
            k: usize,
            _min_score: Option<f32>,
        ) -> Result<Vec<VectorSearchResult>, String> {
            Ok(self.first_k(k))
        }

        fn stats(&self, _collection: &str) -> Option<VectorIndexStats> {
            Some(VectorIndexStats {
                vector_count: self.results.len(),
                dimension: 128,
                metric: "cosine".to_string(),
            })
        }
    }

    /// Builds an executor over a mock index with the given token limit.
    fn executor_with(
        results: Vec<VectorSearchResult>,
        token_limit: usize,
    ) -> StreamingContextExecutor<MockVectorIndex> {
        StreamingContextExecutor::new(
            Arc::new(MockVectorIndex { results }),
            StreamingConfig {
                token_limit,
                ..Default::default()
            },
        )
    }

    /// Builds a literal section with no transform.
    fn literal_section(name: &str, priority: i32, value: &str) -> ContextSection {
        ContextSection {
            name: name.to_string(),
            priority,
            content: SectionContent::Literal {
                value: value.to_string(),
            },
            transform: None,
        }
    }

    /// Wraps `sections` in a query with default options and no session.
    fn query_with(sections: Vec<ContextSection>) -> ContextSelectQuery {
        ContextSelectQuery {
            output_name: "test".to_string(),
            session: SessionReference::None,
            options: ContextQueryOptions::default(),
            sections,
        }
    }

    /// Builds a canned search hit without metadata.
    fn search_hit(id: &str, score: f32, content: &str) -> VectorSearchResult {
        VectorSearchResult {
            id: id.to_string(),
            score,
            content: content.to_string(),
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_rolling_budget() {
        let budget = RollingBudget::new(100);

        assert_eq!(budget.try_consume(30), 30);
        assert_eq!(budget.remaining(), 70);

        assert_eq!(budget.try_consume(50), 50);
        assert_eq!(budget.remaining(), 20);

        // Only 20 tokens are left, so the grant is clamped.
        assert_eq!(budget.try_consume(30), 20);
        assert!(budget.is_exhausted());

        assert_eq!(budget.try_consume(10), 0);
    }

    #[test]
    fn test_streaming_context_basic() {
        let executor = executor_with(
            vec![
                search_hit("doc1", 0.95, "First document"),
                search_hit("doc2", 0.85, "Second document"),
            ],
            1000,
        );
        let query = query_with(vec![literal_section(
            "INTRO",
            0,
            "Welcome to the test context.",
        )]);

        let chunks: Vec<_> = executor.execute_streaming(&query).collect();

        assert!(chunks.len() >= 3);

        match chunks.last() {
            Some(SectionChunk::StreamComplete {
                sections_included, ..
            }) => {
                assert!(sections_included.contains(&"INTRO".to_string()));
            }
            _ => panic!("Expected StreamComplete as last chunk"),
        }
    }

    #[test]
    fn test_priority_ordering() {
        let executor = executor_with(vec![], 10000);
        let query = query_with(vec![
            literal_section("LOW_PRIORITY", 10, "Low priority content"),
            literal_section("HIGH_PRIORITY", 0, "High priority content"),
            literal_section("MID_PRIORITY", 5, "Mid priority content"),
        ]);

        let chunks: Vec<_> = executor.execute_streaming(&query).collect();

        let mut headers = Vec::new();
        for chunk in &chunks {
            if let SectionChunk::SectionHeader { name, .. } = chunk {
                headers.push(name.clone());
            }
        }

        assert_eq!(
            headers,
            vec!["HIGH_PRIORITY", "MID_PRIORITY", "LOW_PRIORITY"]
        );
    }

    #[test]
    fn test_budget_exhaustion() {
        let executor = executor_with(vec![], 50);
        let query = query_with(vec![
            literal_section(
                "FIRST",
                0,
                "This is a somewhat longer content that will consume budget.",
            ),
            literal_section("SECOND", 1, "This should be dropped."),
        ]);

        let chunks: Vec<_> = executor.execute_streaming(&query).collect();

        assert!(matches!(
            chunks.last(),
            Some(SectionChunk::StreamComplete { .. })
        ));
    }
}