1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
/// Selects how context is compacted.
///
/// The default is [`CompactionStrategy::Drop`] (marked with `#[default]`). Nothing in
/// this file interprets the variants; how each one is applied is decided by the code
/// that consumes this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompactionStrategy {
    /// Default strategy.
    /// NOTE(review): presumably discards entries that no longer fit — confirm with the
    /// compaction code.
    #[default]
    Drop,
    /// NOTE(review): presumably replaces older entries with a summary — confirm with the
    /// compaction code.
    Summarize,
}
18
19#[derive(Debug, Clone)]
24pub struct ContextNeeds {
25 pub recall: bool,
27 pub pending_tasks: bool,
29 pub profile: bool,
31 pub summaries: bool,
33 pub outcomes: bool,
35 pub compact: CompactionStrategy,
37}
38
39impl Default for ContextNeeds {
40 fn default() -> Self {
41 Self {
42 recall: true,
43 pending_tasks: true,
44 profile: true,
45 summaries: true,
46 outcomes: true,
47 compact: CompactionStrategy::default(),
48 }
49 }
50}
51
/// One earlier turn of the conversation, stored in [`Context::history`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextEntry {
    /// Speaker role. Plain-text prompt rendering labels `"user"` as `[User]` and any
    /// other value as `[Assistant]`; API message conversion copies it unchanged.
    pub role: String,
    /// Text of the turn.
    pub content: String,
}
60
61#[derive(Debug, Clone, Serialize, Deserialize, Default)]
63pub struct McpServer {
64 pub name: String,
66 pub command: String,
68 pub args: Vec<String>,
70 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
72 pub env: HashMap<String, String>,
73}
74
/// A tool defined by an external command.
///
/// NOTE(review): the code that advertises and runs toolboxes is not in this file;
/// presumably the tool is invoked as `command` followed by `args`, with `env` added to
/// its environment — confirm against the executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Toolbox {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// JSON-Schema-style description of the tool's arguments (e.g. `type`,
    /// `properties`, `required`). Defaults to `{"type": "object"}` when absent.
    #[serde(default = "default_object_schema")]
    pub parameters: serde_json::Value,
    /// Command the tool runs (see the note above).
    pub command: String,
    /// Arguments for `command`; empty when absent from the input.
    #[serde(default)]
    pub args: Vec<String>,
    /// Extra environment variables; empty when absent, omitted from output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub env: HashMap<String, String>,
    /// Additional search terms; empty when absent, omitted from output when empty.
    /// NOTE(review): presumably used for tool discovery — confirm with the consumer.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub search_hints: Vec<String>,
}
100
101fn default_object_schema() -> serde_json::Value {
102 serde_json::json!({"type": "object"})
103}
104
/// `skip_serializing_if` predicate: returns `true` when the flag is `false`.
///
/// Takes `&bool` because serde passes the field to the predicate by reference.
fn is_false(b: &bool) -> bool {
    !*b
}
108
/// Prompt material plus per-request settings. It is rendered by
/// [`Context::to_prompt_string`] or split by [`Context::to_api_messages`].
///
/// `Debug` is implemented by hand, and `hook_runner` and `permission_rules` are printed
/// only as placeholders. Both are `#[serde(skip)]`, so they are never serialized and come
/// back as `None` after deserialization.
#[derive(Clone, Serialize, Deserialize)]
pub struct Context {
    /// System prompt text; may be empty.
    pub system_prompt: String,
    /// Earlier conversation turns, rendered in order before `current_message`.
    pub history: Vec<ContextEntry>,
    /// The newest user message; always rendered last.
    pub current_message: String,
    /// MCP servers for this request. Defaults to empty when absent from the input and is
    /// always serialized (unlike `toolboxes`).
    #[serde(default)]
    pub mcp_servers: Vec<McpServer>,
    /// Command-backed tools; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub toolboxes: Vec<Toolbox>,
    /// Optional turn limit; omitted from output when `None`.
    /// NOTE(review): not enforced in this file — confirm semantics with the consumer.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_turns: Option<u32>,
    /// Optional list of tool names; omitted from output when `None`.
    /// NOTE(review): presumably an allow-list where `None` means unrestricted — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_tools: Option<Vec<String>>,
    /// Optional model identifier; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// When set, `to_prompt_string` leaves out history and folds the system prompt into
    /// the user section. Omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// When set, `to_prompt_string` returns only `current_message`, taking precedence
    /// over `session_id`. Omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub agent_name: Option<String>,
    /// Hook runner attached via `with_hooks`; never serialized.
    #[serde(skip)]
    pub hook_runner: Option<std::sync::Arc<dyn crate::hooks::HookRunner>>,
    /// Permission rules; never serialized.
    #[serde(skip)]
    pub permission_rules: Option<std::sync::Arc<crate::permissions::PermissionRules>>,
    /// Extended-thinking flag; omitted from serialized output when `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub extended_thinking: bool,
}
152
153impl std::fmt::Debug for Context {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 f.debug_struct("Context")
160 .field("system_prompt", &self.system_prompt)
161 .field("history", &self.history)
162 .field("current_message", &self.current_message)
163 .field("mcp_servers", &self.mcp_servers)
164 .field("toolboxes", &self.toolboxes)
165 .field("max_turns", &self.max_turns)
166 .field("allowed_tools", &self.allowed_tools)
167 .field("model", &self.model)
168 .field("session_id", &self.session_id)
169 .field("agent_name", &self.agent_name)
170 .field(
171 "hook_runner",
172 &self.hook_runner.as_ref().map(|_| "<runner>"),
173 )
174 .field(
175 "permission_rules",
176 &self.permission_rules.as_ref().map(|_| "<rules>"),
177 )
178 .field("extended_thinking", &self.extended_thinking)
179 .finish()
180 }
181}
182
/// A role/content message as produced by [`Context::to_api_messages`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiMessage {
    /// Message role: copied from a history entry, or `"user"` for the current message.
    pub role: String,
    /// Message text.
    pub content: String,
}
191
192impl Context {
193 pub fn new(message: &str) -> Self {
195 Self {
196 system_prompt: String::new(),
197 history: Vec::new(),
198 current_message: message.to_string(),
199 mcp_servers: Vec::new(),
200 toolboxes: Vec::new(),
201 max_turns: None,
202 allowed_tools: None,
203 model: None,
204 session_id: None,
205 agent_name: None,
206 hook_runner: None,
207 permission_rules: None,
208 extended_thinking: false,
209 }
210 }
211
212 pub fn with_hooks(mut self, runner: std::sync::Arc<dyn crate::hooks::HookRunner>) -> Self {
214 self.hook_runner = Some(runner);
215 self
216 }
217
218 pub fn to_prompt_string(&self) -> String {
224 if self.agent_name.is_some() {
225 return self.current_message.clone();
226 }
227
228 let mut parts = Vec::new();
229
230 if self.session_id.is_none() {
231 if !self.system_prompt.is_empty() {
232 parts.push(format!("[System]\n{}", self.system_prompt));
233 }
234 for entry in &self.history {
235 let role = if entry.role == "user" {
236 "User"
237 } else {
238 "Assistant"
239 };
240 parts.push(format!("[{}]\n{}", role, entry.content));
241 }
242 parts.push(format!("[User]\n{}", self.current_message));
243 } else {
244 if !self.system_prompt.is_empty() {
245 parts.push(format!(
246 "[User]\n{}\n\n{}",
247 self.system_prompt, self.current_message
248 ));
249 } else {
250 parts.push(format!("[User]\n{}", self.current_message));
251 }
252 }
253
254 parts.join("\n\n")
255 }
256
257 pub fn to_api_messages(&self) -> (String, Vec<ApiMessage>) {
262 let mut messages = Vec::with_capacity(self.history.len() + 1);
263
264 for entry in &self.history {
265 messages.push(ApiMessage {
266 role: entry.role.clone(),
267 content: entry.content.clone(),
268 });
269 }
270
271 messages.push(ApiMessage {
272 role: "user".to_string(),
273 content: self.current_message.clone(),
274 });
275
276 (self.system_prompt.clone(), messages)
277 }
278}
279
#[cfg(test)]
mod tests {
    //! Unit tests for context construction, serde round-trips, and prompt rendering.

    use super::*;

    /// History entry fixture.
    fn entry(role: &str, content: &str) -> ContextEntry {
        ContextEntry {
            role: role.into(),
            content: content.into(),
        }
    }

    /// `Context::new(message)` with the system prompt and history replaced; every other
    /// field keeps the value `Context::new` gives it.
    fn context_with(system_prompt: &str, history: Vec<ContextEntry>, message: &str) -> Context {
        let mut ctx = Context::new(message);
        ctx.system_prompt = system_prompt.into();
        ctx.history = history;
        ctx
    }

    #[test]
    fn test_context_new_defaults() {
        let ctx = Context::new("hello");
        assert_eq!(ctx.current_message, "hello");
        assert!(ctx.system_prompt.is_empty());
        assert!(ctx.history.is_empty());
        assert!(ctx.mcp_servers.is_empty());
        assert!(ctx.toolboxes.is_empty());
        assert!(ctx.session_id.is_none());
        assert!(ctx.agent_name.is_none());
    }

    #[test]
    fn test_mcp_server_serde_round_trip() {
        let original = McpServer {
            name: "playwright".into(),
            command: "npx".into(),
            args: vec!["@playwright/mcp".into(), "--headless".into()],
            env: HashMap::new(),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: McpServer = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.name, "playwright");
        assert_eq!(decoded.args, vec!["@playwright/mcp", "--headless"]);
    }

    #[test]
    fn test_context_serde_without_optional_fields() {
        let json = r#"{"system_prompt":"test","history":[],"current_message":"hi"}"#;
        let ctx: Context = serde_json::from_str(json).unwrap();
        assert!(ctx.mcp_servers.is_empty());
        assert!(ctx.session_id.is_none());
        assert!(ctx.agent_name.is_none());
    }

    #[test]
    fn test_to_api_messages_basic() {
        let (system, messages) = Context::new("hello").to_api_messages();
        assert!(system.is_empty());
        assert_eq!(messages.len(), 1);
        let only = &messages[0];
        assert_eq!(only.role, "user");
        assert_eq!(only.content, "hello");
    }

    #[test]
    fn test_to_api_messages_with_history() {
        let ctx = context_with(
            "Be helpful.",
            vec![entry("user", "Hi"), entry("assistant", "Hello!")],
            "How are you?",
        );
        let (system, messages) = ctx.to_api_messages();
        assert_eq!(system, "Be helpful.");
        assert_eq!(messages.len(), 3);
    }

    #[test]
    fn test_to_prompt_string_no_session() {
        let ctx = context_with("Be helpful.", vec![entry("user", "Hi")], "How are you?");
        let prompt = ctx.to_prompt_string();
        assert!(prompt.contains("[System]\nBe helpful."));
        assert!(prompt.contains("[User]\nHi"));
        assert!(prompt.contains("[User]\nHow are you?"));
    }

    #[test]
    fn test_to_prompt_string_with_session() {
        let mut ctx = context_with(
            "Current time: 2026-03-06",
            vec![entry("user", "Hi")],
            "How are you?",
        );
        ctx.session_id = Some("sess-abc".into());
        let prompt = ctx.to_prompt_string();
        assert!(!prompt.contains("[System]"));
        assert!(prompt.contains("[User]\nCurrent time: 2026-03-06\n\nHow are you?"));
    }

    #[test]
    fn test_to_prompt_string_with_agent_name() {
        let mut ctx = context_with(
            "You are a build analyst...",
            vec![entry("user", "prev")],
            "Build me a task tracker.",
        );
        ctx.agent_name = Some("build-analyst".into());
        let prompt = ctx.to_prompt_string();
        assert_eq!(prompt, "Build me a task tracker.");
    }

    #[test]
    fn test_agent_name_takes_precedence_over_session_id() {
        let mut ctx = context_with("system", Vec::new(), "Build something.");
        ctx.session_id = Some("sess-456".into());
        ctx.agent_name = Some("build-architect".into());
        assert_eq!(ctx.to_prompt_string(), "Build something.");
    }

    #[test]
    fn test_session_id_serde_round_trip() {
        let mut ctx = context_with("test", Vec::new(), "hi");
        ctx.session_id = Some("sess-123".into());
        let json = serde_json::to_string(&ctx).unwrap();
        let deserialized: Context = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.session_id, Some("sess-123".into()));
    }

    #[test]
    fn test_optional_fields_skipped_in_serialization() {
        let json = serde_json::to_string(&Context::new("hello")).unwrap();
        for key in ["session_id", "agent_name", "max_turns", "toolboxes"] {
            assert!(!json.contains(key), "unexpected key {key} in {json}");
        }
    }

    #[test]
    fn test_toolbox_serde_round_trip() {
        let original = Toolbox {
            name: "lint".into(),
            description: "Run linter on a file.".into(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {"file": {"type": "string"}},
                "required": ["file"]
            }),
            command: "bash".into(),
            args: vec!["scripts/lint.sh".into()],
            env: HashMap::new(),
            search_hints: Vec::new(),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Toolbox = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.name, "lint");
        assert_eq!(decoded.command, "bash");
        assert_eq!(decoded.args, vec!["scripts/lint.sh"]);
    }

    #[test]
    fn test_toolbox_default_parameters() {
        let json = r#"{"name":"test","description":"Test tool.","command":"echo"}"#;
        let tb: Toolbox = serde_json::from_str(json).unwrap();
        assert_eq!(tb.parameters, serde_json::json!({"type": "object"}));
        assert!(tb.args.is_empty());
        assert!(tb.env.is_empty());
    }

    #[test]
    fn test_context_serde_with_toolboxes() {
        let mut ctx = Context::new("run lint");
        ctx.toolboxes.push(Toolbox {
            name: "lint".into(),
            description: "Lint a file.".into(),
            parameters: serde_json::json!({"type": "object"}),
            command: "bash".into(),
            args: vec!["lint.sh".into()],
            env: HashMap::new(),
            search_hints: Vec::new(),
        });
        let json = serde_json::to_string(&ctx).unwrap();
        assert!(json.contains("toolboxes"));
        let decoded: Context = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.toolboxes.len(), 1);
        assert_eq!(decoded.toolboxes[0].name, "lint");
    }
}
536}