1use std::sync::{Arc, Mutex};
2
3use crate::llm::LlmRunner;
4use crate::llm::completion::{InputContent, InputItem, Role};
5use async_trait::async_trait;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::agent::error::AgentResult;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
14pub enum ContextStrategy {
15 Pinnable,
16 Compactable,
17}
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
21pub enum ContextRole {
22 System,
23 User,
24 Assistant,
25}
26
27#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
29pub enum ContextChunk {
30 Message {
31 strategy: ContextStrategy,
32 role: ContextRole,
33 content: String,
34 },
35 ToolCall {
36 strategy: ContextStrategy,
37 id: String,
38 name: String,
39 args: Value,
40 },
41 ToolResult {
42 strategy: ContextStrategy,
43 id: String,
44 result: Value,
45 },
46}
47
48impl ContextChunk {
49 pub fn system_text(strategy: ContextStrategy, content: impl Into<String>) -> Self {
50 Self::Message {
51 strategy,
52 role: ContextRole::System,
53 content: content.into(),
54 }
55 }
56
57 pub fn user_text(strategy: ContextStrategy, content: impl Into<String>) -> Self {
58 Self::Message {
59 strategy,
60 role: ContextRole::User,
61 content: content.into(),
62 }
63 }
64
65 pub fn assistant_text(strategy: ContextStrategy, content: impl Into<String>) -> Self {
66 Self::Message {
67 strategy,
68 role: ContextRole::Assistant,
69 content: content.into(),
70 }
71 }
72
73 pub fn from_input_item(
74 strategy: ContextStrategy,
75 item: InputItem,
76 ) -> Option<AgentResult<Self>> {
77 match item {
78 InputItem::Message { role, content } => {
79 let text = flatten_input_content(content);
80 let role = match role {
81 Role::System => ContextRole::System,
82 Role::User => ContextRole::User,
83 Role::Assistant => ContextRole::Assistant,
84 };
85 Some(Ok(Self::Message {
86 strategy,
87 role,
88 content: text,
89 }))
90 }
91 InputItem::ToolCall { call } => Some(Ok(Self::ToolCall {
92 strategy,
93 id: call.id,
94 name: call.name,
95 args: call.arguments,
96 })),
97 InputItem::ToolResult {
98 tool_use_id,
99 content,
100 } => Some(match serde_json::from_str::<Value>(&content) {
101 Ok(result) => Ok(Self::ToolResult {
102 strategy,
103 id: tool_use_id,
104 result,
105 }),
106 Err(_) => Ok(Self::ToolResult {
107 strategy,
108 id: tool_use_id,
109 result: Value::String(content),
110 }),
111 }),
112 }
113 }
114
115 pub fn to_input_item(&self) -> Option<AgentResult<InputItem>> {
116 match self {
117 ContextChunk::Message { role, content, .. } => Some(Ok(match role {
118 ContextRole::System => InputItem::system_text(content.clone()),
119 ContextRole::User => InputItem::user_text(content.clone()),
120 ContextRole::Assistant => InputItem::assistant_text(content.clone()),
121 })),
122 ContextChunk::ToolCall { id, name, args, .. } => Some(Ok(InputItem::tool_call(
123 id.clone(),
124 name.clone(),
125 args.clone(),
126 ))),
127 ContextChunk::ToolResult { id, result, .. } => Some(
128 serde_json::to_string(result)
129 .map(|content| InputItem::tool_result(id.clone(), content))
130 .map_err(|error| crate::agent::error::AgentError::Internal {
131 message: error.to_string(),
132 }),
133 ),
134 }
135 }
136}
137
138#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
140pub struct ContextWindow {
141 pub chunks: Vec<ContextChunk>,
142}
143
144impl ContextWindow {
145 pub fn new(chunks: Vec<ContextChunk>) -> Self {
146 Self { chunks }
147 }
148
149 pub fn to_input_items(&self) -> AgentResult<Vec<InputItem>> {
150 self.chunks
151 .iter()
152 .filter_map(|chunk| chunk.to_input_item())
153 .collect()
154 }
155}
156
157#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
159#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
160pub trait ContextProvider: Send + Sync {
161 async fn provide(&self) -> AgentResult<Vec<ContextChunk>>;
162}
163
164pub struct ContextManagerBuilder {
166 providers: Vec<Arc<dyn ContextProvider>>,
167}
168
169impl ContextManagerBuilder {
170 pub fn new() -> Self {
171 Self {
172 providers: Vec::new(),
173 }
174 }
175
176 pub fn add_provider<P>(mut self, provider: P) -> Self
177 where
178 P: ContextProvider + 'static,
179 {
180 self.providers.push(Arc::new(provider));
181 self
182 }
183
184 pub fn build(self) -> ContextManager {
185 ContextManager {
186 providers: self.providers,
187 history: Mutex::new(Vec::new()),
188 llm: Mutex::new(None),
189 }
190 }
191}
192
193impl Default for ContextManagerBuilder {
194 fn default() -> Self {
195 Self::new()
196 }
197}
198
199pub struct ContextManager {
201 providers: Vec<Arc<dyn ContextProvider>>,
202 history: Mutex<Vec<ContextChunk>>,
203 llm: Mutex<Option<Arc<LlmRunner>>>,
204}
205
206impl ContextManager {
207 pub fn builder() -> ContextManagerBuilder {
208 ContextManagerBuilder::new()
209 }
210
211 pub fn static_text(text: impl Into<String>) -> Self {
212 Self::builder()
213 .add_provider(StaticContextProvider::system_text(text))
214 .build()
215 }
216
217 pub fn new() -> Self {
218 Self::builder().build()
219 }
220
221 pub fn with_provider_arc(mut self, provider: Arc<dyn ContextProvider>) -> Self {
222 self.providers.push(provider);
223 self
224 }
225
226 pub fn attach_llm_runner(&self, llm: Arc<LlmRunner>) {
227 *self.llm.lock().expect("context llm") = Some(llm);
228 }
229
230 pub async fn push(&self, chunk: ContextChunk) -> AgentResult<()> {
231 self.history.lock().expect("context history").push(chunk);
232 Ok(())
233 }
234
235 pub async fn window(&self) -> AgentResult<ContextWindow> {
236 let mut chunks = Vec::new();
237 for provider in &self.providers {
238 chunks.extend(provider.provide().await?);
239 }
240 chunks.extend(self.history.lock().expect("context history").clone());
241 Ok(ContextWindow::new(chunks))
242 }
243
244 pub async fn history(&self) -> AgentResult<Vec<ContextChunk>> {
245 Ok(self.history.lock().expect("context history").clone())
246 }
247}
248
249impl Default for ContextManager {
250 fn default() -> Self {
251 Self::new()
252 }
253}
254
255pub struct StaticContextProvider {
257 chunks: Vec<ContextChunk>,
258}
259
260impl StaticContextProvider {
261 pub fn new(chunks: Vec<ContextChunk>) -> Self {
262 Self { chunks }
263 }
264
265 pub fn system_text(text: impl Into<String>) -> Self {
266 Self::new(vec![ContextChunk::system_text(
267 ContextStrategy::Pinnable,
268 text,
269 )])
270 }
271}
272
273#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
274#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
275impl ContextProvider for StaticContextProvider {
276 async fn provide(&self) -> AgentResult<Vec<ContextChunk>> {
277 Ok(self.chunks.clone())
278 }
279}
280
281fn flatten_input_content(content: Vec<InputContent>) -> String {
282 content
283 .into_iter()
284 .filter_map(|part| match part {
285 InputContent::Text { text } => Some(text),
286 InputContent::ImageUrl { .. } => None,
287 })
288 .collect::<Vec<_>>()
289 .join("\n")
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::error::AgentError;

    /// Provider that always errors, used to check error propagation from `window`.
    struct FailingProvider;

    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
    impl ContextProvider for FailingProvider {
        async fn provide(&self) -> AgentResult<Vec<ContextChunk>> {
            Err(AgentError::Internal {
                message: "provider failed".to_string(),
            })
        }
    }

    /// Compactable user message fixture.
    fn compactable_user(text: &str) -> ContextChunk {
        ContextChunk::user_text(ContextStrategy::Compactable, text)
    }

    /// Pinnable system message fixture.
    fn pinned_system(text: &str) -> ContextChunk {
        ContextChunk::system_text(ContextStrategy::Pinnable, text)
    }

    /// The `ping` tool call used throughout these tests.
    fn ping_call() -> ContextChunk {
        ContextChunk::ToolCall {
            strategy: ContextStrategy::Compactable,
            id: "call_1".to_string(),
            name: "ping".to_string(),
            args: serde_json::json!({ "value": "hello" }),
        }
    }

    /// JSON object tool result paired with `ping_call`.
    fn ok_result() -> ContextChunk {
        ContextChunk::ToolResult {
            strategy: ContextStrategy::Compactable,
            id: "call_1".to_string(),
            result: serde_json::json!({ "status": "ok" }),
        }
    }

    /// True when `item` is the lowered form of `ping_call()`.
    fn is_ping_call(item: &InputItem) -> bool {
        matches!(
            item,
            InputItem::ToolCall { call }
                if call.id == "call_1"
                    && call.name == "ping"
                    && call.arguments == serde_json::json!({ "value": "hello" })
        )
    }

    /// Converts `item` with the compactable strategy, asserting success.
    fn chunk_from(item: InputItem) -> ContextChunk {
        ContextChunk::from_input_item(ContextStrategy::Compactable, item)
            .expect("chunk")
            .expect("valid chunk")
    }

    /// Pushes a compactable user message into `manager`.
    async fn push_user(manager: &ContextManager, text: &str) {
        manager.push(compactable_user(text)).await.expect("push");
    }

    #[test]
    fn from_input_item_maps_message_roles_and_flattens_text_parts() {
        let parts = vec![
            InputContent::Text {
                text: "hello".to_string(),
            },
            InputContent::ImageUrl {
                url: "https://example.com/cat.png".to_string(),
            },
            InputContent::Text {
                text: "world".to_string(),
            },
        ];
        let chunk = chunk_from(InputItem::Message {
            role: Role::Assistant,
            content: parts,
        });

        let expected = ContextChunk::assistant_text(ContextStrategy::Compactable, "hello\nworld");
        assert_eq!(chunk, expected);
    }

    #[test]
    fn from_input_item_parses_json_tool_results() {
        let chunk = chunk_from(InputItem::tool_result("call_1", r#"{"status":"ok"}"#));
        assert_eq!(chunk, ok_result());
    }

    #[test]
    fn from_input_item_falls_back_to_string_for_non_json_tool_results() {
        let chunk = chunk_from(InputItem::tool_result("call_1", "plain text error"));
        let expected = ContextChunk::ToolResult {
            strategy: ContextStrategy::Compactable,
            id: "call_1".to_string(),
            result: Value::String("plain text error".to_string()),
        };
        assert_eq!(chunk, expected);
    }

    #[test]
    fn tool_result_chunk_round_trips_back_to_input_item() {
        let lowered = ok_result()
            .to_input_item()
            .expect("tool result lowers")
            .expect("valid item");

        assert!(matches!(
            lowered,
            InputItem::ToolResult { tool_use_id, content }
            if tool_use_id == "call_1" && content == r#"{"status":"ok"}"#
        ));
    }

    #[tokio::test]
    async fn static_provider_chunks_precede_history_in_window() {
        let manager = ContextManager::builder()
            .add_provider(StaticContextProvider::system_text("system prompt"))
            .build();
        push_user(&manager, "hello from user").await;

        let window = manager.window().await.expect("window");
        assert_eq!(
            window.chunks,
            vec![
                pinned_system("system prompt"),
                compactable_user("hello from user"),
            ]
        );
    }

    #[test]
    fn context_window_lowers_messages_tool_calls_and_tool_results() {
        let window = ContextWindow::new(vec![pinned_system("system"), ping_call(), ok_result()]);

        let items = window.to_input_items().expect("input items");
        assert_eq!(items.len(), 3);
        assert!(matches!(
            &items[0],
            InputItem::Message {
                role: Role::System,
                ..
            }
        ));
        assert!(is_ping_call(&items[1]));
        assert!(matches!(
            &items[2],
            InputItem::ToolResult { tool_use_id, .. } if tool_use_id == "call_1"
        ));
    }

    #[tokio::test]
    async fn multiple_providers_preserve_builder_order_before_history() {
        let manager = ContextManager::builder()
            .add_provider(StaticContextProvider::new(vec![pinned_system("system one")]))
            .add_provider(StaticContextProvider::new(vec![pinned_system("system two")]))
            .build();
        push_user(&manager, "hello from user").await;

        let window = manager.window().await.expect("window");
        assert_eq!(
            window.chunks,
            vec![
                pinned_system("system one"),
                pinned_system("system two"),
                compactable_user("hello from user"),
            ]
        );
    }

    #[tokio::test]
    async fn push_preserves_history_order_and_window_is_non_destructive() {
        let manager = ContextManager::new();
        let first = compactable_user("first");
        let second = ContextChunk::assistant_text(ContextStrategy::Compactable, "second");
        let both = vec![first.clone(), second.clone()];

        manager.push(first).await.expect("push first");
        manager.push(second).await.expect("push second");

        assert_eq!(manager.history().await.expect("history"), both);
        assert_eq!(manager.window().await.expect("window").chunks, both);
        // Reading the window must not consume or reorder the history.
        assert_eq!(manager.history().await.expect("history again"), both);
    }

    #[tokio::test]
    async fn static_text_builds_a_pinnable_system_message() {
        let window = ContextManager::static_text("hello system")
            .window()
            .await
            .expect("window");

        assert_eq!(window.chunks, vec![pinned_system("hello system")]);
    }

    #[tokio::test]
    async fn history_returns_only_session_history_not_provider_chunks() {
        let manager = ContextManager::builder()
            .add_provider(StaticContextProvider::system_text("system prompt"))
            .build();
        push_user(&manager, "hello from user").await;

        let history = manager.history().await.expect("history");
        assert_eq!(history, vec![compactable_user("hello from user")]);
    }

    #[tokio::test]
    async fn failing_provider_errors_window() {
        let manager = ContextManager::builder().add_provider(FailingProvider).build();

        let error = manager.window().await.expect_err("provider should fail");
        assert!(matches!(error, AgentError::Internal { message } if message == "provider failed"));
    }

    #[tokio::test]
    async fn tool_calls_are_preserved_in_history_and_lowered_into_window() {
        let manager = ContextManager::new();
        manager.push(ping_call()).await.expect("push");

        let history = manager.history().await.expect("history");
        assert_eq!(history.len(), 1);
        assert!(matches!(history[0], ContextChunk::ToolCall { .. }));

        let window = manager.window().await.expect("window");
        let input_items = window.to_input_items().expect("items");
        assert_eq!(input_items.len(), 1);
        assert!(is_ping_call(&input_items[0]));
    }
}