1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
/// Strategy for shrinking conversation history when it has to be compacted.
///
/// [`CompactionStrategy::Drop`] is the default variant.
/// NOTE(review): the code that applies a strategy is not visible here; the
/// variant notes below are inferred from their names — confirm them there.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum CompactionStrategy {
    /// Default strategy (presumably discards older history — confirm).
    #[default]
    Drop,
    /// Presumably replaces older history with a summary — confirm.
    Summarize,
}
18
19#[derive(Debug, Clone)]
24pub struct ContextNeeds {
25 pub recall: bool,
27 pub pending_tasks: bool,
29 pub profile: bool,
31 pub summaries: bool,
33 pub outcomes: bool,
35 pub compact: CompactionStrategy,
37}
38
39impl Default for ContextNeeds {
40 fn default() -> Self {
41 Self {
42 recall: true,
43 pending_tasks: true,
44 profile: true,
45 summaries: true,
46 outcomes: true,
47 compact: CompactionStrategy::default(),
48 }
49 }
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ContextEntry {
55 pub role: String,
57 pub content: String,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize, Default)]
63pub struct McpServer {
64 pub name: String,
66 pub command: String,
68 pub args: Vec<String>,
70 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
72 pub env: HashMap<String, String>,
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct Toolbox {
81 pub name: String,
83 pub description: String,
85 #[serde(default = "default_object_schema")]
87 pub parameters: serde_json::Value,
88 pub command: String,
90 #[serde(default)]
92 pub args: Vec<String>,
93 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
95 pub env: HashMap<String, String>,
96 #[serde(default)]
101 pub network: bool,
102 #[serde(default, skip_serializing_if = "Vec::is_empty")]
107 pub env_passthrough: Vec<String>,
108 #[serde(default, skip_serializing_if = "Vec::is_empty")]
113 pub allowed_commands: Vec<String>,
114 #[serde(default, skip_serializing_if = "Vec::is_empty")]
116 pub search_hints: Vec<String>,
117}
118
/// Report whether `command` is permitted by the `allowed` list.
///
/// * An empty list permits every command.
/// * An entry containing `/` must equal `command` exactly (no path
///   normalisation is performed).
/// * Any other entry is compared with the basename of `command`, meaning the
///   text after its last `/`. So `"ls"` permits both `ls` and `/bin/ls`.
pub fn command_matches_allowlist(allowed: &[String], command: &str) -> bool {
    // Text after the final '/', or the whole command when there is none.
    let basename = match command.rsplit_once('/') {
        Some((_, tail)) => tail,
        None => command,
    };
    allowed.is_empty()
        || allowed.iter().any(|entry| {
            let candidate = if entry.contains('/') { command } else { basename };
            entry.as_str() == candidate
        })
}
142
143fn default_object_schema() -> serde_json::Value {
144 serde_json::json!({"type": "object"})
145}
146
/// `skip_serializing_if` predicate: true when the flag is unset, so `false`
/// booleans are left out of serialized output.
fn is_false(b: &bool) -> bool {
    !*b
}
150
151#[derive(Clone, Serialize, Deserialize)]
153pub struct Context {
154 pub system_prompt: String,
156 pub history: Vec<ContextEntry>,
158 pub current_message: String,
160 #[serde(default)]
162 pub mcp_servers: Vec<McpServer>,
163 #[serde(default, skip_serializing_if = "Vec::is_empty")]
165 pub toolboxes: Vec<Toolbox>,
166 #[serde(default, skip_serializing_if = "Option::is_none")]
168 pub max_turns: Option<u32>,
169 #[serde(default, skip_serializing_if = "Option::is_none")]
171 pub allowed_tools: Option<Vec<String>>,
172 #[serde(default, skip_serializing_if = "Option::is_none")]
174 pub model: Option<String>,
175 #[serde(default, skip_serializing_if = "Option::is_none")]
177 pub session_id: Option<String>,
178 #[serde(default, skip_serializing_if = "Option::is_none")]
181 pub agent_name: Option<String>,
182 #[serde(skip)]
184 pub hook_runner: Option<std::sync::Arc<dyn crate::hooks::HookRunner>>,
185 #[serde(skip)]
188 pub permission_rules: Option<std::sync::Arc<crate::permissions::PermissionRules>>,
189 #[serde(default, skip_serializing_if = "is_false")]
195 pub extended_thinking: bool,
196}
197
198impl std::fmt::Debug for Context {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 f.debug_struct("Context")
205 .field("system_prompt", &self.system_prompt)
206 .field("history", &self.history)
207 .field("current_message", &self.current_message)
208 .field("mcp_servers", &self.mcp_servers)
209 .field("toolboxes", &self.toolboxes)
210 .field("max_turns", &self.max_turns)
211 .field("allowed_tools", &self.allowed_tools)
212 .field("model", &self.model)
213 .field("session_id", &self.session_id)
214 .field("agent_name", &self.agent_name)
215 .field(
216 "hook_runner",
217 &self.hook_runner.as_ref().map(|_| "<runner>"),
218 )
219 .field(
220 "permission_rules",
221 &self.permission_rules.as_ref().map(|_| "<rules>"),
222 )
223 .field("extended_thinking", &self.extended_thinking)
224 .finish()
225 }
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct ApiMessage {
231 pub role: String,
233 pub content: String,
235}
236
237impl Context {
238 pub fn new(message: &str) -> Self {
240 Self {
241 system_prompt: String::new(),
242 history: Vec::new(),
243 current_message: message.to_string(),
244 mcp_servers: Vec::new(),
245 toolboxes: Vec::new(),
246 max_turns: None,
247 allowed_tools: None,
248 model: None,
249 session_id: None,
250 agent_name: None,
251 hook_runner: None,
252 permission_rules: None,
253 extended_thinking: false,
254 }
255 }
256
257 pub fn with_hooks(mut self, runner: std::sync::Arc<dyn crate::hooks::HookRunner>) -> Self {
259 self.hook_runner = Some(runner);
260 self
261 }
262
263 pub fn to_prompt_string(&self) -> String {
269 if self.agent_name.is_some() {
270 return self.current_message.clone();
271 }
272
273 let mut parts = Vec::new();
274
275 if self.session_id.is_none() {
276 if !self.system_prompt.is_empty() {
277 parts.push(format!("[System]\n{}", self.system_prompt));
278 }
279 for entry in &self.history {
280 let role = if entry.role == "user" {
281 "User"
282 } else {
283 "Assistant"
284 };
285 parts.push(format!("[{}]\n{}", role, entry.content));
286 }
287 parts.push(format!("[User]\n{}", self.current_message));
288 } else {
289 if !self.system_prompt.is_empty() {
290 parts.push(format!(
291 "[User]\n{}\n\n{}",
292 self.system_prompt, self.current_message
293 ));
294 } else {
295 parts.push(format!("[User]\n{}", self.current_message));
296 }
297 }
298
299 parts.join("\n\n")
300 }
301
302 pub fn to_api_messages(&self) -> (String, Vec<ApiMessage>) {
307 let mut messages = Vec::with_capacity(self.history.len() + 1);
308
309 for entry in &self.history {
310 messages.push(ApiMessage {
311 role: entry.role.clone(),
312 content: entry.content.clone(),
313 });
314 }
315
316 messages.push(ApiMessage {
317 role: "user".to_string(),
318 content: self.current_message.clone(),
319 });
320
321 (self.system_prompt.clone(), messages)
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `Context` from a system prompt, `(role, content)` history pairs
    /// and the current message. All other fields keep `Context::new` defaults.
    fn ctx_with(system_prompt: &str, history: &[(&str, &str)], message: &str) -> Context {
        let mut ctx = Context::new(message);
        ctx.system_prompt = system_prompt.into();
        ctx.history = history
            .iter()
            .map(|&(role, content)| ContextEntry {
                role: role.into(),
                content: content.into(),
            })
            .collect();
        ctx
    }

    /// Build a `bash`-backed toolbox that runs `script`. All optional fields
    /// are left empty.
    fn bash_toolbox(
        name: &str,
        description: &str,
        parameters: serde_json::Value,
        script: &str,
    ) -> Toolbox {
        Toolbox {
            name: name.into(),
            description: description.into(),
            parameters,
            command: "bash".into(),
            args: vec![script.into()],
            env: HashMap::new(),
            network: false,
            env_passthrough: Vec::new(),
            allowed_commands: Vec::new(),
            search_hints: Vec::new(),
        }
    }

    #[test]
    fn test_context_new_defaults() {
        let ctx = Context::new("hello");
        assert!(ctx.system_prompt.is_empty());
        assert!(ctx.history.is_empty());
        assert!(ctx.mcp_servers.is_empty());
        assert!(ctx.toolboxes.is_empty());
        assert_eq!(ctx.current_message, "hello");
        assert!(ctx.max_turns.is_none());
        assert!(ctx.allowed_tools.is_none());
        assert!(ctx.model.is_none());
        assert!(ctx.session_id.is_none());
        assert!(ctx.agent_name.is_none());
        assert!(ctx.hook_runner.is_none());
        assert!(ctx.permission_rules.is_none());
        assert!(!ctx.extended_thinking);
    }

    #[test]
    fn test_context_needs_default_enables_everything() {
        let needs = ContextNeeds::default();
        assert!(needs.recall);
        assert!(needs.pending_tasks);
        assert!(needs.profile);
        assert!(needs.summaries);
        assert!(needs.outcomes);
        assert_eq!(needs.compact, CompactionStrategy::Drop);
    }

    #[test]
    fn test_command_matches_allowlist_empty_allows_everything() {
        assert!(command_matches_allowlist(&[], "rm"));
        assert!(command_matches_allowlist(&[], "/usr/bin/anything"));
    }

    #[test]
    fn test_command_matches_allowlist_bare_and_path_entries() {
        let allowed = vec!["ls".to_string(), "/usr/bin/git".to_string()];
        // Bare entries match the basename, wherever the binary lives.
        assert!(command_matches_allowlist(&allowed, "ls"));
        assert!(command_matches_allowlist(&allowed, "/bin/ls"));
        // Path entries require an exact match.
        assert!(command_matches_allowlist(&allowed, "/usr/bin/git"));
        assert!(!command_matches_allowlist(&allowed, "git"));
        assert!(!command_matches_allowlist(&allowed, "/usr/local/bin/git"));
        assert!(!command_matches_allowlist(&allowed, "cat"));
    }

    #[test]
    fn test_mcp_server_serde_round_trip() {
        let server = McpServer {
            name: "playwright".into(),
            command: "npx".into(),
            args: vec!["@playwright/mcp".into(), "--headless".into()],
            env: HashMap::new(),
        };
        let json = serde_json::to_string(&server).unwrap();
        let deserialized: McpServer = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.name, "playwright");
        assert_eq!(deserialized.args, vec!["@playwright/mcp", "--headless"]);
    }

    #[test]
    fn test_context_serde_without_optional_fields() {
        let json = r#"{"system_prompt":"test","history":[],"current_message":"hi"}"#;
        let ctx: Context = serde_json::from_str(json).unwrap();
        assert!(ctx.mcp_servers.is_empty());
        assert!(ctx.session_id.is_none());
        assert!(ctx.agent_name.is_none());
        assert!(!ctx.extended_thinking);
    }

    #[test]
    fn test_to_api_messages_basic() {
        let ctx = Context::new("hello");
        let (system, messages) = ctx.to_api_messages();
        assert!(system.is_empty());
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "user");
        assert_eq!(messages[0].content, "hello");
    }

    #[test]
    fn test_to_api_messages_with_history() {
        let ctx = ctx_with(
            "Be helpful.",
            &[("user", "Hi"), ("assistant", "Hello!")],
            "How are you?",
        );
        let (system, messages) = ctx.to_api_messages();
        assert_eq!(system, "Be helpful.");
        let pairs: Vec<(&str, &str)> = messages
            .iter()
            .map(|m| (m.role.as_str(), m.content.as_str()))
            .collect();
        assert_eq!(
            pairs,
            vec![
                ("user", "Hi"),
                ("assistant", "Hello!"),
                ("user", "How are you?"),
            ]
        );
    }

    #[test]
    fn test_to_prompt_string_no_session() {
        let ctx = ctx_with("Be helpful.", &[("user", "Hi")], "How are you?");
        assert_eq!(
            ctx.to_prompt_string(),
            "[System]\nBe helpful.\n\n[User]\nHi\n\n[User]\nHow are you?"
        );
    }

    #[test]
    fn test_to_prompt_string_labels_non_user_roles_as_assistant() {
        let ctx = ctx_with("", &[("user", "Hi"), ("assistant", "Hello!")], "Bye");
        assert_eq!(
            ctx.to_prompt_string(),
            "[User]\nHi\n\n[Assistant]\nHello!\n\n[User]\nBye"
        );
    }

    #[test]
    fn test_to_prompt_string_with_session() {
        let mut ctx = ctx_with("Current time: 2026-03-06", &[("user", "Hi")], "How are you?");
        ctx.session_id = Some("sess-abc".into());
        // No [System] section and no history: the system prompt is folded
        // into the single user message.
        assert_eq!(
            ctx.to_prompt_string(),
            "[User]\nCurrent time: 2026-03-06\n\nHow are you?"
        );
    }

    #[test]
    fn test_to_prompt_string_with_session_and_empty_system_prompt() {
        let mut ctx = ctx_with("", &[("user", "Hi")], "How are you?");
        ctx.session_id = Some("sess-abc".into());
        assert_eq!(ctx.to_prompt_string(), "[User]\nHow are you?");
    }

    #[test]
    fn test_to_prompt_string_with_agent_name() {
        let mut ctx = ctx_with(
            "You are a build analyst...",
            &[("user", "prev")],
            "Build me a task tracker.",
        );
        ctx.agent_name = Some("build-analyst".into());
        assert_eq!(ctx.to_prompt_string(), "Build me a task tracker.");
    }

    #[test]
    fn test_agent_name_takes_precedence_over_session_id() {
        let mut ctx = ctx_with("system", &[], "Build something.");
        ctx.session_id = Some("sess-456".into());
        ctx.agent_name = Some("build-architect".into());
        assert_eq!(ctx.to_prompt_string(), "Build something.");
    }

    #[test]
    fn test_session_id_serde_round_trip() {
        let mut ctx = ctx_with("test", &[], "hi");
        ctx.session_id = Some("sess-123".into());
        let json = serde_json::to_string(&ctx).unwrap();
        let deserialized: Context = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.session_id, Some("sess-123".into()));
    }

    #[test]
    fn test_optional_fields_skipped_in_serialization() {
        let ctx = Context::new("hello");
        let json = serde_json::to_string(&ctx).unwrap();
        for key in [
            "session_id",
            "agent_name",
            "max_turns",
            "allowed_tools",
            "model",
            "toolboxes",
            "extended_thinking",
        ] {
            assert!(!json.contains(key), "{key} should be skipped: {json}");
        }
    }

    #[test]
    fn test_extended_thinking_serialized_when_true() {
        let mut ctx = Context::new("hello");
        ctx.extended_thinking = true;
        let json = serde_json::to_string(&ctx).unwrap();
        assert!(json.contains(r#""extended_thinking":true"#));
        let deserialized: Context = serde_json::from_str(&json).unwrap();
        assert!(deserialized.extended_thinking);
    }

    #[test]
    fn test_toolbox_serde_round_trip() {
        let tb = bash_toolbox(
            "lint",
            "Run linter on a file.",
            serde_json::json!({
                "type": "object",
                "properties": {"file": {"type": "string"}},
                "required": ["file"]
            }),
            "scripts/lint.sh",
        );
        let json = serde_json::to_string(&tb).unwrap();
        let deserialized: Toolbox = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.name, "lint");
        assert_eq!(deserialized.command, "bash");
        assert_eq!(deserialized.args, vec!["scripts/lint.sh"]);
        assert_eq!(deserialized.parameters, tb.parameters);
    }

    #[test]
    fn test_toolbox_default_parameters() {
        let json = r#"{"name":"test","description":"Test tool.","command":"echo"}"#;
        let tb: Toolbox = serde_json::from_str(json).unwrap();
        assert_eq!(tb.parameters, serde_json::json!({"type": "object"}));
        assert!(tb.args.is_empty());
        assert!(tb.env.is_empty());
        assert!(!tb.network);
    }

    #[test]
    fn test_context_serde_with_toolboxes() {
        let mut ctx = Context::new("run lint");
        ctx.toolboxes.push(bash_toolbox(
            "lint",
            "Lint a file.",
            serde_json::json!({"type": "object"}),
            "lint.sh",
        ));
        let json = serde_json::to_string(&ctx).unwrap();
        assert!(json.contains("toolboxes"));
        let deserialized: Context = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.toolboxes.len(), 1);
        assert_eq!(deserialized.toolboxes[0].name, "lint");
    }
}