a3s_code_core/context/
static_provider.rs1use super::{ContextItem, ContextProvider, ContextQuery, ContextResult};
4use crate::text::truncate_utf8;
5
6#[derive(Debug, Clone)]
9pub struct StaticContextProvider {
10 name: String,
11 items: Vec<ContextItem>,
12}
13
14impl StaticContextProvider {
15 pub fn new(name: impl Into<String>) -> Self {
16 Self {
17 name: name.into(),
18 items: Vec::new(),
19 }
20 }
21
22 pub fn from_items(
23 name: impl Into<String>,
24 items: impl IntoIterator<Item = ContextItem>,
25 ) -> Self {
26 Self {
27 name: name.into(),
28 items: items.into_iter().collect(),
29 }
30 }
31
32 pub fn with_item(mut self, item: ContextItem) -> Self {
33 self.items.push(item);
34 self
35 }
36}
37
38#[async_trait::async_trait]
39impl ContextProvider for StaticContextProvider {
40 fn name(&self) -> &str {
41 &self.name
42 }
43
44 async fn query(&self, query: &ContextQuery) -> anyhow::Result<ContextResult> {
45 let mut result = ContextResult::new(&self.name);
46 let max_results = query.max_results.max(1);
47 let max_tokens = query.max_tokens.max(1);
48 let mut total_tokens = 0usize;
49
50 for item in self.items.iter().take(max_results) {
51 let item_tokens = estimated_tokens(item);
52 if total_tokens + item_tokens > max_tokens {
53 result.truncated = true;
54
55 if result.items.is_empty() {
56 result.add_item(truncate_item(item, max_tokens));
57 }
58
59 break;
60 }
61
62 let mut item = item.clone();
63 if item.token_count == 0 {
64 item.token_count = item_tokens;
65 }
66 total_tokens += item_tokens;
67 result.add_item(item);
68 }
69
70 if self.items.len() > max_results {
71 result.truncated = true;
72 }
73
74 Ok(result)
75 }
76}
77
78fn estimated_tokens(item: &ContextItem) -> usize {
79 if item.token_count > 0 {
80 item.token_count
81 } else {
82 item.content.split_whitespace().count().max(1)
83 }
84}
85
86fn truncate_item(item: &ContextItem, max_tokens: usize) -> ContextItem {
87 let max_bytes = max_tokens.saturating_mul(4).max(1);
88 let mut truncated = item.clone();
89 let shown = truncate_utf8(&item.content, max_bytes).trim_end();
90 truncated.content = format!("{shown}\n\n[context truncated]");
91 truncated.token_count = max_tokens;
92 truncated
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ContextType;

    /// A single small item is returned as-is, with a token count filled in.
    #[tokio::test]
    async fn returns_static_items() {
        let item = ContextItem::new("agents_md", ContextType::Resource, "Follow local rules")
            .with_relevance(0.95);
        let provider = StaticContextProvider::new("static").with_item(item);
        let query = ContextQuery::new("prompt");

        let result = provider.query(&query).await.unwrap();

        assert_eq!(result.provider, "static");
        assert_eq!(result.items.len(), 1);
        let first = &result.items[0];
        assert_eq!(first.id, "agents_md");
        assert!(first.token_count > 0);
        assert!(!result.truncated);
    }

    /// Items beyond `max_results` are dropped and the result is flagged.
    #[tokio::test]
    async fn respects_result_limit() {
        let items = [
            ContextItem::new("a", ContextType::Resource, "a").with_token_count(1),
            ContextItem::new("b", ContextType::Resource, "b").with_token_count(1),
        ];
        let provider = StaticContextProvider::from_items("static", items);
        let query = ContextQuery::new("prompt").with_max_results(1);

        let result = provider.query(&query).await.unwrap();

        assert!(result.truncated);
        assert_eq!(result.items.len(), 1);
    }

    /// A lone item larger than the token budget is shortened, not dropped.
    #[tokio::test]
    async fn truncates_oversized_single_item() {
        let oversized = ContextItem::new(
            "large",
            ContextType::Resource,
            "one two three four five six",
        )
        .with_token_count(6);
        let provider = StaticContextProvider::new("static").with_item(oversized);
        let query = ContextQuery::new("prompt").with_max_tokens(2);

        let result = provider.query(&query).await.unwrap();

        assert!(result.truncated);
        assert_eq!(result.items.len(), 1);
        let first = &result.items[0];
        assert_eq!(first.token_count, 2);
        assert!(first.content.contains("[context truncated]"));
    }
}