Skip to main content

agents/agent/
context.rs

1use std::sync::{Arc, Mutex};
2
3use crate::llm::LlmRunner;
4use crate::llm::completion::{InputContent, InputItem, Role};
5use async_trait::async_trait;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::agent::error::AgentResult;
11
/// Strategy hint for how a context chunk should be retained.
// NOTE(review): nothing in this module acts on the strategy yet. It is stored on
// every chunk and ignored when lowering (`ContextChunk::to_input_item` matches it
// away with `..`). Presumably `Pinnable` chunks (e.g. the static system prompt,
// see `StaticContextProvider::system_text`) are meant to be kept verbatim, while
// `Compactable` chunks may be summarized or dropped. Confirm against the
// compaction code before relying on this.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub enum ContextStrategy {
    Pinnable,
    Compactable,
}
18
/// Role attached to a context message chunk.
// Mapped one-to-one from the `llm::completion::Role` variants handled in
// `ContextChunk::from_input_item`. It is lowered back through
// `InputItem::{system,user,assistant}_text` in `ContextChunk::to_input_item`.
// Tool traffic is not a role here; it uses the dedicated `ContextChunk::ToolCall`
// and `ContextChunk::ToolResult` variants instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub enum ContextRole {
    System,
    User,
    Assistant,
}
26
/// One item in an agent context window.
// No `#[serde(...)]` attributes are present, so serde's default externally-tagged
// representation applies, e.g. `{"Message": {"strategy": ..., "role": ..., "content": ...}}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub enum ContextChunk {
    // A text message. Multi-part model input is flattened to text when converted
    // by `from_input_item` (image parts are dropped).
    Message {
        strategy: ContextStrategy,
        role: ContextRole,
        content: String,
    },
    // A tool invocation requested by the model. `args` holds the call's JSON
    // arguments. Presumably `id` matches the `id` of the `ToolResult` answering
    // this call. Verify with the agent loop.
    ToolCall {
        strategy: ContextStrategy,
        id: String,
        name: String,
        args: Value,
    },
    // The output of a tool call. `from_input_item` stores JSON content parsed,
    // and stores anything that is not valid JSON as a `Value::String`.
    ToolResult {
        strategy: ContextStrategy,
        id: String,
        result: Value,
    },
}
47
48impl ContextChunk {
49    pub fn system_text(strategy: ContextStrategy, content: impl Into<String>) -> Self {
50        Self::Message {
51            strategy,
52            role: ContextRole::System,
53            content: content.into(),
54        }
55    }
56
57    pub fn user_text(strategy: ContextStrategy, content: impl Into<String>) -> Self {
58        Self::Message {
59            strategy,
60            role: ContextRole::User,
61            content: content.into(),
62        }
63    }
64
65    pub fn assistant_text(strategy: ContextStrategy, content: impl Into<String>) -> Self {
66        Self::Message {
67            strategy,
68            role: ContextRole::Assistant,
69            content: content.into(),
70        }
71    }
72
73    pub fn from_input_item(
74        strategy: ContextStrategy,
75        item: InputItem,
76    ) -> Option<AgentResult<Self>> {
77        match item {
78            InputItem::Message { role, content } => {
79                let text = flatten_input_content(content);
80                let role = match role {
81                    Role::System => ContextRole::System,
82                    Role::User => ContextRole::User,
83                    Role::Assistant => ContextRole::Assistant,
84                };
85                Some(Ok(Self::Message {
86                    strategy,
87                    role,
88                    content: text,
89                }))
90            }
91            InputItem::ToolCall { call } => Some(Ok(Self::ToolCall {
92                strategy,
93                id: call.id,
94                name: call.name,
95                args: call.arguments,
96            })),
97            InputItem::ToolResult {
98                tool_use_id,
99                content,
100            } => Some(match serde_json::from_str::<Value>(&content) {
101                Ok(result) => Ok(Self::ToolResult {
102                    strategy,
103                    id: tool_use_id,
104                    result,
105                }),
106                Err(_) => Ok(Self::ToolResult {
107                    strategy,
108                    id: tool_use_id,
109                    result: Value::String(content),
110                }),
111            }),
112        }
113    }
114
115    pub fn to_input_item(&self) -> Option<AgentResult<InputItem>> {
116        match self {
117            ContextChunk::Message { role, content, .. } => Some(Ok(match role {
118                ContextRole::System => InputItem::system_text(content.clone()),
119                ContextRole::User => InputItem::user_text(content.clone()),
120                ContextRole::Assistant => InputItem::assistant_text(content.clone()),
121            })),
122            ContextChunk::ToolCall { id, name, args, .. } => Some(Ok(InputItem::tool_call(
123                id.clone(),
124                name.clone(),
125                args.clone(),
126            ))),
127            ContextChunk::ToolResult { id, result, .. } => Some(
128                serde_json::to_string(result)
129                    .map(|content| InputItem::tool_result(id.clone(), content))
130                    .map_err(|error| crate::agent::error::AgentError::Internal {
131                        message: error.to_string(),
132                    }),
133            ),
134        }
135    }
136}
137
138/// Materialized context window ready to be lowered into model input items.
139#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
140pub struct ContextWindow {
141    pub chunks: Vec<ContextChunk>,
142}
143
144impl ContextWindow {
145    pub fn new(chunks: Vec<ContextChunk>) -> Self {
146        Self { chunks }
147    }
148
149    pub fn to_input_items(&self) -> AgentResult<Vec<InputItem>> {
150        self.chunks
151            .iter()
152            .filter_map(|chunk| chunk.to_input_item())
153            .collect()
154    }
155}
156
/// Source of additional context chunks for an agent.
///
/// [`ContextManager::window`] awaits every registered provider, one after another
/// in registration order, each time a window is built. It places the returned
/// chunks before the session history. An `Err` from any provider aborts the whole
/// window.
// On wasm32 the trait is expanded with `async_trait(?Send)`, so the future
// returned by `provide` is not required to be `Send` there.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait ContextProvider: Send + Sync {
    /// Returns the chunks this provider contributes to the next context window.
    async fn provide(&self) -> AgentResult<Vec<ContextChunk>>;
}
163
164/// Builder for [`ContextManager`].
165pub struct ContextManagerBuilder {
166    providers: Vec<Arc<dyn ContextProvider>>,
167}
168
169impl ContextManagerBuilder {
170    pub fn new() -> Self {
171        Self {
172            providers: Vec::new(),
173        }
174    }
175
176    pub fn add_provider<P>(mut self, provider: P) -> Self
177    where
178        P: ContextProvider + 'static,
179    {
180        self.providers.push(Arc::new(provider));
181        self
182    }
183
184    pub fn build(self) -> ContextManager {
185        ContextManager {
186            providers: self.providers,
187            history: Mutex::new(Vec::new()),
188            llm: Mutex::new(None),
189        }
190    }
191}
192
193impl Default for ContextManagerBuilder {
194    fn default() -> Self {
195        Self::new()
196    }
197}
198
199/// Composes static providers and conversation history into a context window.
200pub struct ContextManager {
201    providers: Vec<Arc<dyn ContextProvider>>,
202    history: Mutex<Vec<ContextChunk>>,
203    llm: Mutex<Option<Arc<LlmRunner>>>,
204}
205
206impl ContextManager {
207    pub fn builder() -> ContextManagerBuilder {
208        ContextManagerBuilder::new()
209    }
210
211    pub fn static_text(text: impl Into<String>) -> Self {
212        Self::builder()
213            .add_provider(StaticContextProvider::system_text(text))
214            .build()
215    }
216
217    pub fn new() -> Self {
218        Self::builder().build()
219    }
220
221    pub fn with_provider_arc(mut self, provider: Arc<dyn ContextProvider>) -> Self {
222        self.providers.push(provider);
223        self
224    }
225
226    pub fn attach_llm_runner(&self, llm: Arc<LlmRunner>) {
227        *self.llm.lock().expect("context llm") = Some(llm);
228    }
229
230    pub async fn push(&self, chunk: ContextChunk) -> AgentResult<()> {
231        self.history.lock().expect("context history").push(chunk);
232        Ok(())
233    }
234
235    pub async fn window(&self) -> AgentResult<ContextWindow> {
236        let mut chunks = Vec::new();
237        for provider in &self.providers {
238            chunks.extend(provider.provide().await?);
239        }
240        chunks.extend(self.history.lock().expect("context history").clone());
241        Ok(ContextWindow::new(chunks))
242    }
243
244    pub async fn history(&self) -> AgentResult<Vec<ContextChunk>> {
245        Ok(self.history.lock().expect("context history").clone())
246    }
247}
248
249impl Default for ContextManager {
250    fn default() -> Self {
251        Self::new()
252    }
253}
254
255/// Simple context provider backed by a fixed list of chunks.
256pub struct StaticContextProvider {
257    chunks: Vec<ContextChunk>,
258}
259
260impl StaticContextProvider {
261    pub fn new(chunks: Vec<ContextChunk>) -> Self {
262        Self { chunks }
263    }
264
265    pub fn system_text(text: impl Into<String>) -> Self {
266        Self::new(vec![ContextChunk::system_text(
267            ContextStrategy::Pinnable,
268            text,
269        )])
270    }
271}
272
273#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
274#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
275impl ContextProvider for StaticContextProvider {
276    async fn provide(&self) -> AgentResult<Vec<ContextChunk>> {
277        Ok(self.chunks.clone())
278    }
279}
280
281fn flatten_input_content(content: Vec<InputContent>) -> String {
282    content
283        .into_iter()
284        .filter_map(|part| match part {
285            InputContent::Text { text } => Some(text),
286            InputContent::ImageUrl { .. } => None,
287        })
288        .collect::<Vec<_>>()
289        .join("\n")
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::error::AgentError;

    /// Provider that always fails; used to check that provider errors abort `window`.
    struct FailingProvider;

    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
    impl ContextProvider for FailingProvider {
        async fn provide(&self) -> AgentResult<Vec<ContextChunk>> {
            Err(AgentError::Internal {
                message: "provider failed".to_string(),
            })
        }
    }

    // The role is preserved, the image part is dropped, and the text parts are
    // joined with '\n'.
    #[test]
    fn from_input_item_maps_message_roles_and_flattens_text_parts() {
        let item = InputItem::Message {
            role: Role::Assistant,
            content: vec![
                InputContent::Text {
                    text: "hello".to_string(),
                },
                InputContent::ImageUrl {
                    url: "https://example.com/cat.png".to_string(),
                },
                InputContent::Text {
                    text: "world".to_string(),
                },
            ],
        };

        let chunk = ContextChunk::from_input_item(ContextStrategy::Compactable, item)
            .expect("chunk")
            .expect("valid chunk");

        assert_eq!(
            chunk,
            ContextChunk::assistant_text(ContextStrategy::Compactable, "hello\nworld")
        );
    }

    // JSON tool output is stored as structured JSON, not as a string.
    #[test]
    fn from_input_item_parses_json_tool_results() {
        let chunk = ContextChunk::from_input_item(
            ContextStrategy::Compactable,
            InputItem::tool_result("call_1", r#"{"status":"ok"}"#),
        )
        .expect("chunk")
        .expect("valid chunk");

        assert_eq!(
            chunk,
            ContextChunk::ToolResult {
                strategy: ContextStrategy::Compactable,
                id: "call_1".to_string(),
                result: serde_json::json!({ "status": "ok" }),
            }
        );
    }

    // Non-JSON tool output is kept verbatim as a JSON string value.
    #[test]
    fn from_input_item_falls_back_to_string_for_non_json_tool_results() {
        let chunk = ContextChunk::from_input_item(
            ContextStrategy::Compactable,
            InputItem::tool_result("call_1", "plain text error"),
        )
        .expect("chunk")
        .expect("valid chunk");

        assert_eq!(
            chunk,
            ContextChunk::ToolResult {
                strategy: ContextStrategy::Compactable,
                id: "call_1".to_string(),
                result: Value::String("plain text error".to_string()),
            }
        );
    }

    // A structured result lowers back to compact JSON text under the same id.
    #[test]
    fn tool_result_chunk_round_trips_back_to_input_item() {
        let item = ContextChunk::ToolResult {
            strategy: ContextStrategy::Compactable,
            id: "call_1".to_string(),
            result: serde_json::json!({ "status": "ok" }),
        }
        .to_input_item()
        .expect("tool result lowers")
        .expect("valid item");

        assert!(matches!(
            item,
            InputItem::ToolResult { tool_use_id, content }
                if tool_use_id == "call_1" && content == r#"{"status":"ok"}"#
        ));
    }

    // Provider chunks always come first in the window, followed by pushed history.
    #[tokio::test]
    async fn static_provider_chunks_precede_history_in_window() {
        let manager = ContextManager::builder()
            .add_provider(StaticContextProvider::system_text("system prompt"))
            .build();

        manager
            .push(ContextChunk::user_text(
                ContextStrategy::Compactable,
                "hello from user",
            ))
            .await
            .expect("push");

        let window = manager.window().await.expect("window");
        assert_eq!(
            window.chunks,
            vec![
                ContextChunk::system_text(ContextStrategy::Pinnable, "system prompt"),
                ContextChunk::user_text(ContextStrategy::Compactable, "hello from user"),
            ]
        );
    }

    // Lowering preserves chunk order and maps each chunk kind to its InputItem variant.
    #[test]
    fn context_window_lowers_messages_tool_calls_and_tool_results() {
        let window = ContextWindow::new(vec![
            ContextChunk::system_text(ContextStrategy::Pinnable, "system"),
            ContextChunk::ToolCall {
                strategy: ContextStrategy::Compactable,
                id: "call_1".to_string(),
                name: "ping".to_string(),
                args: serde_json::json!({ "value": "hello" }),
            },
            ContextChunk::ToolResult {
                strategy: ContextStrategy::Compactable,
                id: "call_1".to_string(),
                result: serde_json::json!({ "status": "ok" }),
            },
        ]);

        let items = window.to_input_items().expect("input items");
        assert_eq!(items.len(), 3);
        assert!(matches!(
            &items[0],
            InputItem::Message {
                role: Role::System,
                ..
            }
        ));
        assert!(matches!(
            &items[1],
            InputItem::ToolCall { call }
                if call.id == "call_1"
                    && call.name == "ping"
                    && call.arguments == serde_json::json!({ "value": "hello" })
        ));
        assert!(matches!(
            &items[2],
            InputItem::ToolResult { tool_use_id, .. } if tool_use_id == "call_1"
        ));
    }

    // Providers contribute in the order they were added to the builder.
    #[tokio::test]
    async fn multiple_providers_preserve_builder_order_before_history() {
        let manager = ContextManager::builder()
            .add_provider(StaticContextProvider::new(vec![ContextChunk::system_text(
                ContextStrategy::Pinnable,
                "system one",
            )]))
            .add_provider(StaticContextProvider::new(vec![ContextChunk::system_text(
                ContextStrategy::Pinnable,
                "system two",
            )]))
            .build();

        manager
            .push(ContextChunk::user_text(
                ContextStrategy::Compactable,
                "hello from user",
            ))
            .await
            .expect("push");

        let window = manager.window().await.expect("window");
        assert_eq!(
            window.chunks,
            vec![
                ContextChunk::system_text(ContextStrategy::Pinnable, "system one"),
                ContextChunk::system_text(ContextStrategy::Pinnable, "system two"),
                ContextChunk::user_text(ContextStrategy::Compactable, "hello from user"),
            ]
        );
    }

    // Building a window only snapshots history; it must not drain or reorder it.
    #[tokio::test]
    async fn push_preserves_history_order_and_window_is_non_destructive() {
        let manager = ContextManager::new();
        let first = ContextChunk::user_text(ContextStrategy::Compactable, "first");
        let second = ContextChunk::assistant_text(ContextStrategy::Compactable, "second");

        manager.push(first.clone()).await.expect("push first");
        manager.push(second.clone()).await.expect("push second");

        let history = manager.history().await.expect("history");
        assert_eq!(history, vec![first.clone(), second.clone()]);

        let window = manager.window().await.expect("window");
        assert_eq!(window.chunks, vec![first.clone(), second.clone()]);

        let history_again = manager.history().await.expect("history again");
        assert_eq!(history_again, vec![first, second]);
    }

    // `static_text` is shorthand for a single pinnable system-message provider.
    #[tokio::test]
    async fn static_text_builds_a_pinnable_system_message() {
        let manager = ContextManager::static_text("hello system");
        let window = manager.window().await.expect("window");

        assert_eq!(
            window.chunks,
            vec![ContextChunk::system_text(
                ContextStrategy::Pinnable,
                "hello system",
            )]
        );
    }

    // `history()` exposes pushed chunks only, never provider output.
    #[tokio::test]
    async fn history_returns_only_session_history_not_provider_chunks() {
        let manager = ContextManager::builder()
            .add_provider(StaticContextProvider::system_text("system prompt"))
            .build();

        manager
            .push(ContextChunk::user_text(
                ContextStrategy::Compactable,
                "hello from user",
            ))
            .await
            .expect("push");

        let history = manager.history().await.expect("history");
        assert_eq!(
            history,
            vec![ContextChunk::user_text(
                ContextStrategy::Compactable,
                "hello from user",
            )]
        );
    }

    // A provider error is propagated unchanged out of `window`.
    #[tokio::test]
    async fn failing_provider_errors_window() {
        let manager = ContextManager::builder()
            .add_provider(FailingProvider)
            .build();

        let error = manager.window().await.expect_err("provider should fail");
        assert!(matches!(error, AgentError::Internal { message } if message == "provider failed"));
    }

    // Tool calls survive push -> history -> window -> lowering with all fields intact.
    #[tokio::test]
    async fn tool_calls_are_preserved_in_history_and_lowered_into_window() {
        let manager = ContextManager::new();

        manager
            .push(ContextChunk::ToolCall {
                strategy: ContextStrategy::Compactable,
                id: "call_1".to_string(),
                name: "ping".to_string(),
                args: serde_json::json!({ "value": "hello" }),
            })
            .await
            .expect("push");

        let history = manager.history().await.expect("history");
        assert_eq!(history.len(), 1);
        assert!(matches!(history[0], ContextChunk::ToolCall { .. }));

        let input_items = manager
            .window()
            .await
            .expect("window")
            .to_input_items()
            .expect("items");
        assert_eq!(input_items.len(), 1);
        assert!(matches!(
            &input_items[0],
            InputItem::ToolCall { call }
                if call.id == "call_1"
                    && call.name == "ping"
                    && call.arguments == serde_json::json!({ "value": "hello" })
        ));
    }
}