1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
/// Selects how compaction is carried out.
///
/// NOTE(review): the code that acts on this value lives outside this file, so
/// the exact effect of each variant should be confirmed there.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CompactionStrategy {
    /// The default strategy (the variant marked `#[default]`).
    #[default]
    Drop,
    /// The non-default strategy; presumably summarizes instead of discarding —
    /// confirm against the consumer.
    Summarize,
}
18
19pub struct ContextNeeds {
24 pub recall: bool,
26 pub pending_tasks: bool,
28 pub profile: bool,
30 pub summaries: bool,
32 pub outcomes: bool,
34 pub compact: CompactionStrategy,
36}
37
38impl Default for ContextNeeds {
39 fn default() -> Self {
40 Self {
41 recall: true,
42 pending_tasks: true,
43 profile: true,
44 summaries: true,
45 outcomes: true,
46 compact: CompactionStrategy::default(),
47 }
48 }
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct ContextEntry {
54 pub role: String,
56 pub content: String,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize, Default)]
62pub struct McpServer {
63 pub name: String,
65 pub command: String,
67 pub args: Vec<String>,
69 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
71 pub env: HashMap<String, String>,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct Toolbox {
80 pub name: String,
82 pub description: String,
84 #[serde(default = "default_object_schema")]
86 pub parameters: serde_json::Value,
87 pub command: String,
89 #[serde(default)]
91 pub args: Vec<String>,
92 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
94 pub env: HashMap<String, String>,
95 #[serde(default, skip_serializing_if = "Vec::is_empty")]
97 pub search_hints: Vec<String>,
98}
99
100fn default_object_schema() -> serde_json::Value {
101 serde_json::json!({"type": "object"})
102}
103
/// Returns `true` when the referenced flag is `false`.
///
/// Used as a `skip_serializing_if` predicate in this file, which is why it
/// takes `&bool` rather than `bool`: serde passes the field by reference.
fn is_false(b: &bool) -> bool {
    !*b
}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct Context {
111 pub system_prompt: String,
113 pub history: Vec<ContextEntry>,
115 pub current_message: String,
117 #[serde(default)]
119 pub mcp_servers: Vec<McpServer>,
120 #[serde(default, skip_serializing_if = "Vec::is_empty")]
122 pub toolboxes: Vec<Toolbox>,
123 #[serde(default, skip_serializing_if = "Option::is_none")]
125 pub max_turns: Option<u32>,
126 #[serde(default, skip_serializing_if = "Option::is_none")]
128 pub allowed_tools: Option<Vec<String>>,
129 #[serde(default, skip_serializing_if = "Option::is_none")]
131 pub model: Option<String>,
132 #[serde(default, skip_serializing_if = "Option::is_none")]
134 pub session_id: Option<String>,
135 #[serde(default, skip_serializing_if = "Option::is_none")]
138 pub agent_name: Option<String>,
139 #[serde(skip)]
141 pub hook_runner: Option<std::sync::Arc<dyn crate::hooks::HookRunner>>,
142 #[serde(skip)]
145 pub permission_rules: Option<std::sync::Arc<crate::permissions::PermissionRules>>,
146 #[serde(default, skip_serializing_if = "is_false")]
149 pub extended_thinking: bool,
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct ApiMessage {
155 pub role: String,
157 pub content: String,
159}
160
161impl Context {
162 pub fn new(message: &str) -> Self {
164 Self {
165 system_prompt: String::new(),
166 history: Vec::new(),
167 current_message: message.to_string(),
168 mcp_servers: Vec::new(),
169 toolboxes: Vec::new(),
170 max_turns: None,
171 allowed_tools: None,
172 model: None,
173 session_id: None,
174 agent_name: None,
175 hook_runner: None,
176 permission_rules: None,
177 extended_thinking: false,
178 }
179 }
180
181 pub fn with_hooks(mut self, runner: std::sync::Arc<dyn crate::hooks::HookRunner>) -> Self {
183 self.hook_runner = Some(runner);
184 self
185 }
186
187 pub fn to_prompt_string(&self) -> String {
193 if self.agent_name.is_some() {
194 return self.current_message.clone();
195 }
196
197 let mut parts = Vec::new();
198
199 if self.session_id.is_none() {
200 if !self.system_prompt.is_empty() {
201 parts.push(format!("[System]\n{}", self.system_prompt));
202 }
203 for entry in &self.history {
204 let role = if entry.role == "user" {
205 "User"
206 } else {
207 "Assistant"
208 };
209 parts.push(format!("[{}]\n{}", role, entry.content));
210 }
211 parts.push(format!("[User]\n{}", self.current_message));
212 } else {
213 if !self.system_prompt.is_empty() {
214 parts.push(format!(
215 "[User]\n{}\n\n{}",
216 self.system_prompt, self.current_message
217 ));
218 } else {
219 parts.push(format!("[User]\n{}", self.current_message));
220 }
221 }
222
223 parts.join("\n\n")
224 }
225
226 pub fn to_api_messages(&self) -> (String, Vec<ApiMessage>) {
231 let mut messages = Vec::with_capacity(self.history.len() + 1);
232
233 for entry in &self.history {
234 messages.push(ApiMessage {
235 role: entry.role.clone(),
236 content: entry.content.clone(),
237 });
238 }
239
240 messages.push(ApiMessage {
241 role: "user".to_string(),
242 content: self.current_message.clone(),
243 });
244
245 (self.system_prompt.clone(), messages)
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single history turn.
    fn turn(role: &str, content: &str) -> ContextEntry {
        ContextEntry {
            role: role.to_string(),
            content: content.to_string(),
        }
    }

    /// Builds a context from the three text inputs the prompt tests vary;
    /// every other field keeps the value `Context::new` gives it.
    fn context_with(system_prompt: &str, history: Vec<ContextEntry>, message: &str) -> Context {
        let mut ctx = Context::new(message);
        ctx.system_prompt = system_prompt.to_string();
        ctx.history = history;
        ctx
    }

    #[test]
    fn test_context_new_defaults() {
        let ctx = Context::new("hello");
        assert!(ctx.system_prompt.is_empty());
        assert!(ctx.history.is_empty());
        assert!(ctx.mcp_servers.is_empty());
        assert!(ctx.toolboxes.is_empty());
        assert_eq!(ctx.current_message, "hello");
        assert!(ctx.session_id.is_none());
        assert!(ctx.agent_name.is_none());
    }

    #[test]
    fn test_mcp_server_serde_round_trip() {
        let original = McpServer {
            name: "playwright".to_string(),
            command: "npx".to_string(),
            args: vec!["@playwright/mcp".to_string(), "--headless".to_string()],
            env: HashMap::new(),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: McpServer = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.name, "playwright");
        assert_eq!(decoded.args, vec!["@playwright/mcp", "--headless"]);
    }

    #[test]
    fn test_context_serde_without_optional_fields() {
        let input = r#"{"system_prompt":"test","history":[],"current_message":"hi"}"#;
        let ctx: Context = serde_json::from_str(input).unwrap();
        assert!(ctx.mcp_servers.is_empty());
        assert!(ctx.session_id.is_none());
        assert!(ctx.agent_name.is_none());
    }

    #[test]
    fn test_to_api_messages_basic() {
        let (system, messages) = Context::new("hello").to_api_messages();
        assert!(system.is_empty());
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "user");
        assert_eq!(messages[0].content, "hello");
    }

    #[test]
    fn test_to_api_messages_with_history() {
        let ctx = context_with(
            "Be helpful.",
            vec![turn("user", "Hi"), turn("assistant", "Hello!")],
            "How are you?",
        );
        let (system, messages) = ctx.to_api_messages();
        assert_eq!(system, "Be helpful.");
        assert_eq!(messages.len(), 3);
    }

    #[test]
    fn test_to_prompt_string_no_session() {
        let ctx = context_with("Be helpful.", vec![turn("user", "Hi")], "How are you?");
        let prompt = ctx.to_prompt_string();
        assert!(prompt.contains("[System]\nBe helpful."));
        assert!(prompt.contains("[User]\nHi"));
        assert!(prompt.contains("[User]\nHow are you?"));
    }

    #[test]
    fn test_to_prompt_string_with_session() {
        let mut ctx = context_with(
            "Current time: 2026-03-06",
            vec![turn("user", "Hi")],
            "How are you?",
        );
        ctx.session_id = Some("sess-abc".to_string());
        let prompt = ctx.to_prompt_string();
        assert!(!prompt.contains("[System]"));
        assert!(prompt.contains("[User]\nCurrent time: 2026-03-06\n\nHow are you?"));
    }

    #[test]
    fn test_to_prompt_string_with_agent_name() {
        let mut ctx = context_with(
            "You are a build analyst...",
            vec![turn("user", "prev")],
            "Build me a task tracker.",
        );
        ctx.agent_name = Some("build-analyst".to_string());
        let prompt = ctx.to_prompt_string();
        assert_eq!(prompt, "Build me a task tracker.");
    }

    #[test]
    fn test_agent_name_takes_precedence_over_session_id() {
        let mut ctx = context_with("system", Vec::new(), "Build something.");
        ctx.session_id = Some("sess-456".to_string());
        ctx.agent_name = Some("build-architect".to_string());
        assert_eq!(ctx.to_prompt_string(), "Build something.");
    }

    #[test]
    fn test_session_id_serde_round_trip() {
        let mut ctx = context_with("test", Vec::new(), "hi");
        ctx.session_id = Some("sess-123".to_string());
        let encoded = serde_json::to_string(&ctx).unwrap();
        let decoded: Context = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.session_id, Some("sess-123".into()));
    }

    #[test]
    fn test_optional_fields_skipped_in_serialization() {
        let encoded = serde_json::to_string(&Context::new("hello")).unwrap();
        for absent in ["session_id", "agent_name", "max_turns", "toolboxes"] {
            assert!(!encoded.contains(absent));
        }
    }

    #[test]
    fn test_toolbox_serde_round_trip() {
        let original = Toolbox {
            name: "lint".to_string(),
            description: "Run linter on a file.".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {"file": {"type": "string"}},
                "required": ["file"]
            }),
            command: "bash".to_string(),
            args: vec!["scripts/lint.sh".to_string()],
            env: HashMap::new(),
            search_hints: Vec::new(),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Toolbox = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.name, "lint");
        assert_eq!(decoded.command, "bash");
        assert_eq!(decoded.args, vec!["scripts/lint.sh"]);
    }

    #[test]
    fn test_toolbox_default_parameters() {
        let input = r#"{"name":"test","description":"Test tool.","command":"echo"}"#;
        let toolbox: Toolbox = serde_json::from_str(input).unwrap();
        assert_eq!(toolbox.parameters, serde_json::json!({"type": "object"}));
        assert!(toolbox.args.is_empty());
        assert!(toolbox.env.is_empty());
    }

    #[test]
    fn test_context_serde_with_toolboxes() {
        let mut ctx = Context::new("run lint");
        ctx.toolboxes.push(Toolbox {
            name: "lint".to_string(),
            description: "Lint a file.".to_string(),
            parameters: serde_json::json!({"type": "object"}),
            command: "bash".to_string(),
            args: vec!["lint.sh".to_string()],
            env: HashMap::new(),
            search_hints: Vec::new(),
        });
        let encoded = serde_json::to_string(&ctx).unwrap();
        assert!(encoded.contains("toolboxes"));
        let decoded: Context = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.toolboxes.len(), 1);
        assert_eq!(decoded.toolboxes[0].name, "lint");
    }
}