Skip to main content

a3s_code_core/context/
static_provider.rs

1//! Static context provider for session-local context items.
2
3use super::{ContextItem, ContextProvider, ContextQuery, ContextResult};
4use crate::text::truncate_utf8;
5
6/// Provides pre-built context items through the same retrieval pipeline as
7/// external providers.
8#[derive(Debug, Clone)]
9pub struct StaticContextProvider {
10    name: String,
11    items: Vec<ContextItem>,
12}
13
14impl StaticContextProvider {
15    pub fn new(name: impl Into<String>) -> Self {
16        Self {
17            name: name.into(),
18            items: Vec::new(),
19        }
20    }
21
22    pub fn from_items(
23        name: impl Into<String>,
24        items: impl IntoIterator<Item = ContextItem>,
25    ) -> Self {
26        Self {
27            name: name.into(),
28            items: items.into_iter().collect(),
29        }
30    }
31
32    pub fn with_item(mut self, item: ContextItem) -> Self {
33        self.items.push(item);
34        self
35    }
36}
37
38#[async_trait::async_trait]
39impl ContextProvider for StaticContextProvider {
40    fn name(&self) -> &str {
41        &self.name
42    }
43
44    async fn query(&self, query: &ContextQuery) -> anyhow::Result<ContextResult> {
45        let mut result = ContextResult::new(&self.name);
46        let max_results = query.max_results.max(1);
47        let max_tokens = query.max_tokens.max(1);
48        let mut total_tokens = 0usize;
49
50        for item in self.items.iter().take(max_results) {
51            let item_tokens = estimated_tokens(item);
52            if total_tokens + item_tokens > max_tokens {
53                result.truncated = true;
54
55                if result.items.is_empty() {
56                    result.add_item(truncate_item(item, max_tokens));
57                }
58
59                break;
60            }
61
62            let mut item = item.clone();
63            if item.token_count == 0 {
64                item.token_count = item_tokens;
65            }
66            total_tokens += item_tokens;
67            result.add_item(item);
68        }
69
70        if self.items.len() > max_results {
71            result.truncated = true;
72        }
73
74        Ok(result)
75    }
76}
77
78fn estimated_tokens(item: &ContextItem) -> usize {
79    if item.token_count > 0 {
80        item.token_count
81    } else {
82        item.content.split_whitespace().count().max(1)
83    }
84}
85
86fn truncate_item(item: &ContextItem, max_tokens: usize) -> ContextItem {
87    let max_bytes = max_tokens.saturating_mul(4).max(1);
88    let mut truncated = item.clone();
89    let shown = truncate_utf8(&item.content, max_bytes).trim_end();
90    truncated.content = format!("{shown}\n\n[context truncated]");
91    truncated.token_count = max_tokens;
92    truncated
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::ContextType;

    #[tokio::test]
    async fn returns_static_items() {
        let provider = StaticContextProvider::new("static").with_item(
            ContextItem::new("agents_md", ContextType::Resource, "Follow local rules")
                .with_relevance(0.95),
        );

        let result = provider.query(&ContextQuery::new("prompt")).await.unwrap();

        assert_eq!(result.provider, "static");
        assert_eq!(result.items.len(), 1);
        assert_eq!(result.items[0].id, "agents_md");
        assert!(result.items[0].token_count > 0);
        assert!(!result.truncated);
    }

    #[tokio::test]
    async fn respects_result_limit() {
        let provider = StaticContextProvider::from_items(
            "static",
            [
                ContextItem::new("a", ContextType::Resource, "a").with_token_count(1),
                ContextItem::new("b", ContextType::Resource, "b").with_token_count(1),
            ],
        );

        let result = provider
            .query(&ContextQuery::new("prompt").with_max_results(1))
            .await
            .unwrap();

        assert_eq!(result.items.len(), 1);
        assert!(result.truncated);
    }

    #[tokio::test]
    async fn zero_result_limit_is_treated_as_one() {
        let provider = StaticContextProvider::from_items(
            "static",
            [
                ContextItem::new("a", ContextType::Resource, "a").with_token_count(1),
                ContextItem::new("b", ContextType::Resource, "b").with_token_count(1),
            ],
        );

        let result = provider
            .query(&ContextQuery::new("prompt").with_max_results(0))
            .await
            .unwrap();

        assert_eq!(result.items.len(), 1);
        assert_eq!(result.items[0].id, "a");
        assert!(result.truncated);
    }

    #[tokio::test]
    async fn truncates_oversized_single_item() {
        let provider = StaticContextProvider::new("static").with_item(
            ContextItem::new(
                "large",
                ContextType::Resource,
                "one two three four five six",
            )
            .with_token_count(6),
        );

        let result = provider
            .query(&ContextQuery::new("prompt").with_max_tokens(2))
            .await
            .unwrap();

        assert_eq!(result.items.len(), 1);
        assert_eq!(result.items[0].token_count, 2);
        assert!(result.items[0].content.contains("[context truncated]"));
        assert!(result.truncated);
    }

    #[tokio::test]
    async fn stops_at_first_item_exceeding_token_budget() {
        let provider = StaticContextProvider::from_items(
            "static",
            [
                ContextItem::new("a", ContextType::Resource, "a").with_token_count(2),
                ContextItem::new("b", ContextType::Resource, "b").with_token_count(2),
                ContextItem::new("c", ContextType::Resource, "c").with_token_count(2),
            ],
        );

        let result = provider
            .query(
                &ContextQuery::new("prompt")
                    .with_max_results(10)
                    .with_max_tokens(5),
            )
            .await
            .unwrap();

        // "a" and "b" fit (4 of 5 tokens); "c" would overflow and is dropped
        // rather than truncated because the result is already non-empty.
        assert_eq!(result.items.len(), 2);
        assert_eq!(result.items[0].id, "a");
        assert_eq!(result.items[1].id, "b");
        assert_eq!(result.items[1].token_count, 2);
        assert!(!result.items[1].content.contains("[context truncated]"));
        assert!(result.truncated);
    }
}