1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
/// Strategy used when context has to be compacted.
///
/// NOTE(review): the variant meanings below are inferred from their names — confirm
/// against the code that consumes this setting.
#[derive(Debug, Clone, Copy, Default)]
#[derive(PartialEq, Eq)]
pub enum CompactionStrategy {
    /// Presumably discards material outright. This is the `Default` variant.
    #[default]
    Drop,
    /// Presumably condenses material into a summary instead of discarding it.
    Summarize,
}
18
19#[derive(Debug, Clone)]
24pub struct ContextNeeds {
25 pub recall: bool,
27 pub pending_tasks: bool,
29 pub profile: bool,
31 pub summaries: bool,
33 pub outcomes: bool,
35 pub compact: CompactionStrategy,
37}
38
39impl Default for ContextNeeds {
40 fn default() -> Self {
41 Self {
42 recall: true,
43 pending_tasks: true,
44 profile: true,
45 summaries: true,
46 outcomes: true,
47 compact: CompactionStrategy::default(),
48 }
49 }
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ContextEntry {
55 pub role: String,
57 pub content: String,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize, Default)]
63pub struct McpServer {
64 pub name: String,
66 pub command: String,
68 pub args: Vec<String>,
70 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
72 pub env: HashMap<String, String>,
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct Toolbox {
81 pub name: String,
83 pub description: String,
85 #[serde(default = "default_object_schema")]
87 pub parameters: serde_json::Value,
88 pub command: String,
90 #[serde(default)]
92 pub args: Vec<String>,
93 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
95 pub env: HashMap<String, String>,
96 #[serde(default)]
101 pub network: bool,
102 #[serde(default, skip_serializing_if = "Vec::is_empty")]
107 pub env_passthrough: Vec<String>,
108 #[serde(default, skip_serializing_if = "Vec::is_empty")]
113 pub allowed_commands: Vec<String>,
114 #[serde(default, skip_serializing_if = "Vec::is_empty")]
116 pub search_hints: Vec<String>,
117}
118
/// Reports whether `command` is permitted by the `allowed` list.
///
/// Rules:
/// - An empty allowlist permits every command.
/// - An entry containing `/` must equal `command` exactly (no path normalisation).
/// - Any other entry is compared with the final `/`-separated component of `command`.
///   So `"git"` matches both `"git"` and `"/usr/bin/git"`, but not `"gitk"`.
pub fn command_matches_allowlist(allowed: &[String], command: &str) -> bool {
    // Text after the last '/', or the whole command when it has no '/'.
    let basename = command
        .rsplit_once('/')
        .map_or(command, |(_, tail)| tail);

    allowed.is_empty()
        || allowed.iter().any(|entry| {
            let candidate = if entry.contains('/') { command } else { basename };
            entry == candidate
        })
}
142
143fn default_object_schema() -> serde_json::Value {
144 serde_json::json!({"type": "object"})
145}
146
/// `skip_serializing_if` predicate: true when the flag is unset.
///
/// The by-reference parameter is dictated by serde, which passes the field as `&T`.
fn is_false(b: &bool) -> bool {
    !*b
}
150
151#[derive(Clone, Serialize, Deserialize)]
153pub struct Context {
154 pub system_prompt: String,
156 pub history: Vec<ContextEntry>,
158 pub current_message: String,
160 #[serde(default)]
162 pub mcp_servers: Vec<McpServer>,
163 #[serde(default, skip_serializing_if = "Vec::is_empty")]
165 pub toolboxes: Vec<Toolbox>,
166 #[serde(default, skip_serializing_if = "Option::is_none")]
168 pub max_turns: Option<u32>,
169 #[serde(default, skip_serializing_if = "Option::is_none")]
174 pub token_budget: Option<u64>,
175 #[serde(default, skip_serializing_if = "Option::is_none")]
177 pub allowed_tools: Option<Vec<String>>,
178 #[serde(default, skip_serializing_if = "Option::is_none")]
180 pub model: Option<String>,
181 #[serde(default, skip_serializing_if = "Option::is_none")]
183 pub session_id: Option<String>,
184 #[serde(default, skip_serializing_if = "Option::is_none")]
187 pub agent_name: Option<String>,
188 #[serde(skip)]
190 pub hook_runner: Option<std::sync::Arc<dyn crate::hooks::HookRunner>>,
191 #[serde(skip)]
194 pub permission_rules: Option<std::sync::Arc<crate::permissions::PermissionRules>>,
195 #[serde(default, skip_serializing_if = "is_false")]
201 pub extended_thinking: bool,
202}
203
204impl std::fmt::Debug for Context {
209 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210 f.debug_struct("Context")
211 .field("system_prompt", &self.system_prompt)
212 .field("history", &self.history)
213 .field("current_message", &self.current_message)
214 .field("mcp_servers", &self.mcp_servers)
215 .field("toolboxes", &self.toolboxes)
216 .field("max_turns", &self.max_turns)
217 .field("token_budget", &self.token_budget)
218 .field("allowed_tools", &self.allowed_tools)
219 .field("model", &self.model)
220 .field("session_id", &self.session_id)
221 .field("agent_name", &self.agent_name)
222 .field(
223 "hook_runner",
224 &self.hook_runner.as_ref().map(|_| "<runner>"),
225 )
226 .field(
227 "permission_rules",
228 &self.permission_rules.as_ref().map(|_| "<rules>"),
229 )
230 .field("extended_thinking", &self.extended_thinking)
231 .finish()
232 }
233}
234
235#[derive(Debug, Clone, Serialize, Deserialize)]
237pub struct ApiMessage {
238 pub role: String,
240 pub content: String,
242}
243
244impl Context {
245 pub fn new(message: &str) -> Self {
247 Self {
248 system_prompt: String::new(),
249 history: Vec::new(),
250 current_message: message.to_string(),
251 mcp_servers: Vec::new(),
252 toolboxes: Vec::new(),
253 max_turns: None,
254 token_budget: None,
255 allowed_tools: None,
256 model: None,
257 session_id: None,
258 agent_name: None,
259 hook_runner: None,
260 permission_rules: None,
261 extended_thinking: false,
262 }
263 }
264
265 pub fn with_hooks(mut self, runner: std::sync::Arc<dyn crate::hooks::HookRunner>) -> Self {
267 self.hook_runner = Some(runner);
268 self
269 }
270
271 pub fn to_prompt_string(&self) -> String {
277 if self.agent_name.is_some() {
278 return self.current_message.clone();
279 }
280
281 let mut parts = Vec::new();
282
283 if self.session_id.is_none() {
284 if !self.system_prompt.is_empty() {
285 parts.push(format!("[System]\n{}", self.system_prompt));
286 }
287 for entry in &self.history {
288 let role = if entry.role == "user" {
289 "User"
290 } else {
291 "Assistant"
292 };
293 parts.push(format!("[{}]\n{}", role, entry.content));
294 }
295 parts.push(format!("[User]\n{}", self.current_message));
296 } else {
297 if !self.system_prompt.is_empty() {
298 parts.push(format!(
299 "[User]\n{}\n\n{}",
300 self.system_prompt, self.current_message
301 ));
302 } else {
303 parts.push(format!("[User]\n{}", self.current_message));
304 }
305 }
306
307 parts.join("\n\n")
308 }
309
310 pub fn to_api_messages(&self) -> (String, Vec<ApiMessage>) {
315 let mut messages = Vec::with_capacity(self.history.len() + 1);
316
317 for entry in &self.history {
318 messages.push(ApiMessage {
319 role: entry.role.clone(),
320 content: entry.content.clone(),
321 });
322 }
323
324 messages.push(ApiMessage {
325 role: "user".to_string(),
326 content: self.current_message.clone(),
327 });
328
329 (self.system_prompt.clone(), messages)
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a history entry from string slices.
    fn entry(role: &str, content: &str) -> ContextEntry {
        ContextEntry {
            role: role.into(),
            content: content.into(),
        }
    }

    /// Converts string slices into the owned `Vec<String>` shape used by several APIs.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// Builds a `bash`-backed toolbox with every optional field left empty.
    fn toolbox(
        name: &str,
        description: &str,
        parameters: serde_json::Value,
        args: &[&str],
    ) -> Toolbox {
        Toolbox {
            name: name.into(),
            description: description.into(),
            parameters,
            command: "bash".into(),
            args: owned(args),
            env: HashMap::new(),
            network: false,
            env_passthrough: Vec::new(),
            allowed_commands: Vec::new(),
            search_hints: Vec::new(),
        }
    }

    #[test]
    fn test_context_new_defaults() {
        let ctx = Context::new("hello");
        assert!(ctx.system_prompt.is_empty());
        assert!(ctx.history.is_empty());
        assert!(ctx.mcp_servers.is_empty());
        assert!(ctx.toolboxes.is_empty());
        assert_eq!(ctx.current_message, "hello");
        assert!(ctx.session_id.is_none());
        assert!(ctx.agent_name.is_none());
        assert!(ctx.max_turns.is_none());
        assert!(ctx.token_budget.is_none());
        assert!(ctx.allowed_tools.is_none());
        assert!(ctx.model.is_none());
        assert!(ctx.hook_runner.is_none());
        assert!(ctx.permission_rules.is_none());
        assert!(!ctx.extended_thinking);
    }

    #[test]
    fn test_context_needs_default_enables_everything() {
        let needs = ContextNeeds::default();
        assert!(needs.recall);
        assert!(needs.pending_tasks);
        assert!(needs.profile);
        assert!(needs.summaries);
        assert!(needs.outcomes);
        assert_eq!(needs.compact, CompactionStrategy::Drop);
    }

    #[test]
    fn test_mcp_server_serde_round_trip() {
        let server = McpServer {
            name: "playwright".into(),
            command: "npx".into(),
            args: owned(&["@playwright/mcp", "--headless"]),
            env: HashMap::new(),
        };
        let json = serde_json::to_string(&server).unwrap();
        // An empty env map is left out of the serialized form.
        assert!(!json.contains("\"env\""));
        let deserialized: McpServer = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.name, "playwright");
        assert_eq!(deserialized.command, "npx");
        assert_eq!(deserialized.args, vec!["@playwright/mcp", "--headless"]);
        assert!(deserialized.env.is_empty());
    }

    #[test]
    fn test_mcp_server_env_round_trips_when_present() {
        let mut env = HashMap::new();
        env.insert("TOKEN".to_string(), "abc".to_string());
        let server = McpServer {
            name: "srv".into(),
            command: "run".into(),
            args: Vec::new(),
            env,
        };
        let json = serde_json::to_string(&server).unwrap();
        let deserialized: McpServer = serde_json::from_str(&json).unwrap();
        assert_eq!(
            deserialized.env.get("TOKEN").map(String::as_str),
            Some("abc")
        );
    }

    #[test]
    fn test_context_serde_without_optional_fields() {
        let json = r#"{"system_prompt":"test","history":[],"current_message":"hi"}"#;
        let ctx: Context = serde_json::from_str(json).unwrap();
        assert!(ctx.mcp_servers.is_empty());
        assert!(ctx.toolboxes.is_empty());
        assert!(ctx.session_id.is_none());
        assert!(ctx.agent_name.is_none());
        assert!(!ctx.extended_thinking);
    }

    #[test]
    fn test_to_api_messages_basic() {
        let ctx = Context::new("hello");
        let (system, messages) = ctx.to_api_messages();
        assert!(system.is_empty());
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "user");
        assert_eq!(messages[0].content, "hello");
    }

    #[test]
    fn test_to_api_messages_with_history() {
        let ctx = Context {
            system_prompt: "Be helpful.".into(),
            history: vec![entry("user", "Hi"), entry("assistant", "Hello!")],
            ..Context::new("How are you?")
        };
        let (system, messages) = ctx.to_api_messages();
        assert_eq!(system, "Be helpful.");
        let pairs: Vec<(&str, &str)> = messages
            .iter()
            .map(|m| (m.role.as_str(), m.content.as_str()))
            .collect();
        assert_eq!(
            pairs,
            vec![
                ("user", "Hi"),
                ("assistant", "Hello!"),
                ("user", "How are you?"),
            ]
        );
    }

    #[test]
    fn test_to_prompt_string_no_session() {
        let ctx = Context {
            system_prompt: "Be helpful.".into(),
            history: vec![entry("user", "Hi")],
            ..Context::new("How are you?")
        };
        assert_eq!(
            ctx.to_prompt_string(),
            "[System]\nBe helpful.\n\n[User]\nHi\n\n[User]\nHow are you?"
        );
    }

    #[test]
    fn test_to_prompt_string_labels_assistant_turns() {
        let ctx = Context {
            history: vec![entry("user", "Hi"), entry("assistant", "Hello!")],
            ..Context::new("How are you?")
        };
        // No system prompt, so no [System] section.
        assert_eq!(
            ctx.to_prompt_string(),
            "[User]\nHi\n\n[Assistant]\nHello!\n\n[User]\nHow are you?"
        );
    }

    #[test]
    fn test_to_prompt_string_with_session() {
        let ctx = Context {
            system_prompt: "Current time: 2026-03-06".into(),
            history: vec![entry("user", "Hi")],
            session_id: Some("sess-abc".into()),
            ..Context::new("How are you?")
        };
        // History is not replayed and the system prompt is folded into the user turn.
        assert_eq!(
            ctx.to_prompt_string(),
            "[User]\nCurrent time: 2026-03-06\n\nHow are you?"
        );
    }

    #[test]
    fn test_to_prompt_string_with_session_and_no_system_prompt() {
        let ctx = Context {
            session_id: Some("sess-abc".into()),
            ..Context::new("How are you?")
        };
        assert_eq!(ctx.to_prompt_string(), "[User]\nHow are you?");
    }

    #[test]
    fn test_to_prompt_string_with_agent_name() {
        let ctx = Context {
            system_prompt: "You are a build analyst...".into(),
            history: vec![entry("user", "prev")],
            agent_name: Some("build-analyst".into()),
            ..Context::new("Build me a task tracker.")
        };
        assert_eq!(ctx.to_prompt_string(), "Build me a task tracker.");
    }

    #[test]
    fn test_agent_name_takes_precedence_over_session_id() {
        let ctx = Context {
            system_prompt: "system".into(),
            session_id: Some("sess-456".into()),
            agent_name: Some("build-architect".into()),
            ..Context::new("Build something.")
        };
        assert_eq!(ctx.to_prompt_string(), "Build something.");
    }

    #[test]
    fn test_session_id_serde_round_trip() {
        let ctx = Context {
            system_prompt: "test".into(),
            session_id: Some("sess-123".into()),
            ..Context::new("hi")
        };
        let json = serde_json::to_string(&ctx).unwrap();
        let deserialized: Context = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.session_id.as_deref(), Some("sess-123"));
    }

    #[test]
    fn test_optional_fields_skipped_in_serialization() {
        let json = serde_json::to_string(&Context::new("hello")).unwrap();
        for key in [
            "session_id",
            "agent_name",
            "max_turns",
            "token_budget",
            "allowed_tools",
            "model",
            "toolboxes",
            "extended_thinking",
            "hook_runner",
            "permission_rules",
        ] {
            assert!(!json.contains(key), "unexpected key {} in {}", key, json);
        }
    }

    #[test]
    fn test_extended_thinking_serialized_only_when_true() {
        let mut ctx = Context::new("hi");
        ctx.extended_thinking = true;
        let json = serde_json::to_string(&ctx).unwrap();
        assert!(json.contains("\"extended_thinking\":true"));
        let deserialized: Context = serde_json::from_str(&json).unwrap();
        assert!(deserialized.extended_thinking);
    }

    #[test]
    fn test_debug_output_lists_fields() {
        let rendered = format!("{:?}", Context::new("hi"));
        assert!(rendered.starts_with("Context {"));
        assert!(rendered.contains("current_message: \"hi\""));
        assert!(rendered.contains("hook_runner: None"));
        assert!(rendered.contains("permission_rules: None"));
        assert!(rendered.contains("extended_thinking: false"));
    }

    #[test]
    fn test_toolbox_serde_round_trip() {
        let tb = toolbox(
            "lint",
            "Run linter on a file.",
            serde_json::json!({
                "type": "object",
                "properties": {"file": {"type": "string"}},
                "required": ["file"]
            }),
            &["scripts/lint.sh"],
        );
        let json = serde_json::to_string(&tb).unwrap();
        let deserialized: Toolbox = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.name, "lint");
        assert_eq!(deserialized.command, "bash");
        assert_eq!(deserialized.args, vec!["scripts/lint.sh"]);
        assert_eq!(deserialized.parameters, tb.parameters);
    }

    #[test]
    fn test_toolbox_default_parameters() {
        let json = r#"{"name":"test","description":"Test tool.","command":"echo"}"#;
        let tb: Toolbox = serde_json::from_str(json).unwrap();
        assert_eq!(tb.parameters, serde_json::json!({"type": "object"}));
        assert!(tb.args.is_empty());
        assert!(tb.env.is_empty());
        assert!(!tb.network);
        assert!(tb.env_passthrough.is_empty());
        assert!(tb.allowed_commands.is_empty());
        assert!(tb.search_hints.is_empty());
    }

    #[test]
    fn test_context_serde_with_toolboxes() {
        let mut ctx = Context::new("run lint");
        ctx.toolboxes.push(toolbox(
            "lint",
            "Lint a file.",
            serde_json::json!({"type": "object"}),
            &["lint.sh"],
        ));
        let json = serde_json::to_string(&ctx).unwrap();
        assert!(json.contains("toolboxes"));
        let deserialized: Context = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.toolboxes.len(), 1);
        assert_eq!(deserialized.toolboxes[0].name, "lint");
    }

    #[test]
    fn test_allowlist_empty_allows_everything() {
        assert!(command_matches_allowlist(&[], "rm"));
        assert!(command_matches_allowlist(&[], "/bin/rm"));
    }

    #[test]
    fn test_allowlist_bare_entry_matches_basename() {
        let allowed = owned(&["git"]);
        assert!(command_matches_allowlist(&allowed, "git"));
        assert!(command_matches_allowlist(&allowed, "/usr/bin/git"));
        assert!(!command_matches_allowlist(&allowed, "gitk"));
        // A trailing slash leaves an empty basename, which matches nothing.
        assert!(!command_matches_allowlist(&allowed, "/usr/bin/"));
    }

    #[test]
    fn test_allowlist_path_entry_requires_exact_match() {
        let allowed = owned(&["/usr/bin/git"]);
        assert!(command_matches_allowlist(&allowed, "/usr/bin/git"));
        assert!(!command_matches_allowlist(&allowed, "git"));
        assert!(!command_matches_allowlist(&allowed, "/tmp/evil/git"));
    }
}