use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Selects which compaction strategy to use; [`CompactionStrategy::Drop`] is the default.
///
/// NOTE(review): what is being compacted, and when, is decided by code that is
/// not visible in this block. The per-variant notes below are inferred from the
/// variant names only — confirm against the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompactionStrategy {
    /// Default variant. Presumably discards the affected content outright.
    #[default]
    Drop,
    /// Presumably replaces the affected content with a summary.
    Summarize,
}
/// A set of on/off switches plus a compaction strategy describing what a
/// caller wants included when context is assembled.
///
/// The `Default` implementation in this file turns every switch on and uses
/// the default [`CompactionStrategy`].
///
/// NOTE(review): the exact meaning of each switch is defined by whatever code
/// reads this struct, which is not visible here; the field notes are inferred
/// from the field names — confirm.
#[derive(Debug, Clone)]
pub struct ContextNeeds {
    /// Presumably: include recalled items.
    pub recall: bool,
    /// Presumably: include pending tasks.
    pub pending_tasks: bool,
    /// Presumably: include profile information.
    pub profile: bool,
    /// Presumably: include summaries.
    pub summaries: bool,
    /// Presumably: include outcomes.
    pub outcomes: bool,
    /// Compaction strategy to apply.
    pub compact: CompactionStrategy,
}
impl Default for ContextNeeds {
fn default() -> Self {
Self {
recall: true,
pending_tasks: true,
profile: true,
summaries: true,
outcomes: true,
compact: CompactionStrategy::default(),
}
}
}
/// One element of `Context::history`: a role label and the associated text.
///
/// Both fields are required when deserializing (neither has a serde default).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextEntry {
    /// Who produced the entry. `Context::to_prompt_string` renders the exact
    /// value `"user"` as a `[User]` turn and every other value as `[Assistant]`;
    /// `Context::to_api_messages` passes the value through unchanged.
    pub role: String,
    /// The text of the entry.
    pub content: String,
}
/// Description of an MCP server entry: a name, a command with arguments, and
/// environment variables.
///
/// NOTE(review): how (and whether) the command is launched is decided by code
/// outside this file section — confirm against the consumer.
///
/// Serde behavior: `name`, `command` and `args` must all be present when
/// deserializing; `env` may be omitted and is left out of serialized output
/// when empty.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct McpServer {
    /// Identifier for this server entry.
    pub name: String,
    /// Command associated with the server (presumably the executable to run).
    pub command: String,
    /// Arguments accompanying `command`. Required on deserialization: unlike
    /// `env`, this field carries no `#[serde(default)]`.
    pub args: Vec<String>,
    /// Environment variable names mapped to values; optional on input and
    /// skipped on output when the map is empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub env: HashMap<String, String>,
}
/// Description of a tool backed by a command: its name, human-readable
/// description, parameter schema, and the command line associated with it.
///
/// NOTE(review): how the command is invoked and how `search_hints` are used is
/// decided by code not visible here — confirm against the consumer.
///
/// Serde behavior: `name`, `description` and `command` are required; every
/// other field falls back to a default when missing from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Toolbox {
    /// Tool name.
    pub name: String,
    /// Free-text description of the tool.
    pub description: String,
    /// JSON value describing the tool's parameters. When absent from the
    /// input it is filled by `default_object_schema`, i.e. `{"type": "object"}`.
    #[serde(default = "default_object_schema")]
    pub parameters: serde_json::Value,
    /// Command associated with the tool (presumably the executable to run).
    pub command: String,
    /// Arguments accompanying `command`; defaults to an empty list.
    #[serde(default)]
    pub args: Vec<String>,
    /// Environment variable names mapped to values; defaults to empty and is
    /// skipped on output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub env: HashMap<String, String>,
    /// Optional hint strings; defaults to empty and is skipped on output when
    /// empty. Presumably used to help locate this tool in a search — confirm.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub search_hints: Vec<String>,
}
/// Serde default for `Toolbox::parameters`: the JSON object `{"type": "object"}`.
fn default_object_schema() -> serde_json::Value {
    // Build the single-key object explicitly rather than through `json!`.
    let mut schema = serde_json::Map::new();
    schema.insert("type".to_owned(), serde_json::Value::from("object"));
    serde_json::Value::Object(schema)
}
/// Returns `true` exactly when the referenced flag is `false`.
///
/// The parameter is a `&bool` rather than a `bool` because this function is
/// named in a serde `skip_serializing_if` attribute (on
/// `Context::extended_thinking`), and serde calls such predicates with a
/// reference to the field.
fn is_false(b: &bool) -> bool {
    let flag = *b;
    !flag
}
/// Prompt material (system prompt, history, current message) bundled with tool
/// configuration and optional run settings.
///
/// Serde behavior: `system_prompt`, `history` and `current_message` are
/// required on input; every other field is optional. `hook_runner` and
/// `permission_rules` are never serialized and always come back as `None`
/// after deserialization.
///
/// `Debug` is implemented by hand elsewhere in this file instead of being
/// derived.
#[derive(Clone, Serialize, Deserialize)]
pub struct Context {
    /// System prompt text; may be empty.
    pub system_prompt: String,
    /// Earlier conversation entries, oldest first as far as the rendering
    /// helpers are concerned (they emit entries in vector order).
    pub history: Vec<ContextEntry>,
    /// The message for the current turn.
    pub current_message: String,
    /// MCP server entries; defaults to empty on input and is always
    /// serialized, even when empty.
    #[serde(default)]
    pub mcp_servers: Vec<McpServer>,
    /// Tool descriptions; defaults to empty and is skipped on output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub toolboxes: Vec<Toolbox>,
    /// Optional turn limit (presumably a cap on agent turns — confirm);
    /// skipped on output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_turns: Option<u32>,
    /// Optional list of tool names (presumably an allow-list — confirm);
    /// skipped on output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allowed_tools: Option<Vec<String>>,
    /// Optional model identifier; skipped on output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Optional session identifier. When set (and `agent_name` is not),
    /// `to_prompt_string` leaves out the `[System]` section and the history
    /// and folds the system prompt into the single `[User]` turn.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Optional agent name. When set, `to_prompt_string` returns only
    /// `current_message`, regardless of `session_id`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub agent_name: Option<String>,
    /// Optional shared hook runner; excluded from (de)serialization.
    #[serde(skip)]
    pub hook_runner: Option<std::sync::Arc<dyn crate::hooks::HookRunner>>,
    /// Optional shared permission rules; excluded from (de)serialization.
    #[serde(skip)]
    pub permission_rules: Option<std::sync::Arc<crate::permissions::PermissionRules>>,
    /// Extended-thinking flag (presumably forwarded to the model backend —
    /// confirm); defaults to `false` and is skipped on output when `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub extended_thinking: bool,
}
impl std::fmt::Debug for Context {
    /// Formats every field of `Context` as a debug struct.
    ///
    /// `hook_runner` and `permission_rules` are not formatted themselves: when
    /// present they appear as `Some("<runner>")` / `Some("<rules>")`, otherwise
    /// as `None`. (Presumably the hook runner trait object has no `Debug`
    /// implementation — confirm.)
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Stand-ins that only reveal whether the shared handles are set.
        let runner_marker = self.hook_runner.is_some().then_some("<runner>");
        let rules_marker = self.permission_rules.is_some().then_some("<rules>");

        let mut out = f.debug_struct("Context");
        out.field("system_prompt", &self.system_prompt);
        out.field("history", &self.history);
        out.field("current_message", &self.current_message);
        out.field("mcp_servers", &self.mcp_servers);
        out.field("toolboxes", &self.toolboxes);
        out.field("max_turns", &self.max_turns);
        out.field("allowed_tools", &self.allowed_tools);
        out.field("model", &self.model);
        out.field("session_id", &self.session_id);
        out.field("agent_name", &self.agent_name);
        out.field("hook_runner", &runner_marker);
        out.field("permission_rules", &rules_marker);
        out.field("extended_thinking", &self.extended_thinking);
        out.finish()
    }
}
/// A role/content pair as produced by `Context::to_api_messages`.
///
/// Both fields are required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiMessage {
    /// Role label; `to_api_messages` copies it from the history entry, and
    /// uses `"user"` for the current message.
    pub role: String,
    /// Message text.
    pub content: String,
}
impl Context {
    /// Creates a context whose only populated field is `current_message`.
    ///
    /// Strings and lists start empty, every `Option` is `None`, and
    /// `extended_thinking` is `false`.
    pub fn new(message: &str) -> Self {
        Self {
            system_prompt: String::new(),
            history: Vec::new(),
            current_message: message.to_owned(),
            mcp_servers: Vec::new(),
            toolboxes: Vec::new(),
            max_turns: None,
            allowed_tools: None,
            model: None,
            session_id: None,
            agent_name: None,
            hook_runner: None,
            permission_rules: None,
            extended_thinking: false,
        }
    }

    /// Builder-style setter: stores `runner` in `hook_runner` and returns the
    /// updated context.
    pub fn with_hooks(mut self, runner: std::sync::Arc<dyn crate::hooks::HookRunner>) -> Self {
        self.hook_runner = Some(runner);
        self
    }

    /// Renders the context as one prompt string.
    ///
    /// The shape depends on which identifiers are set, checked in this order:
    ///
    /// 1. `agent_name` is set: the result is `current_message` verbatim; the
    ///    system prompt, history and `session_id` are ignored.
    /// 2. `session_id` is set: the result is a single `[User]` section holding
    ///    the system prompt (if non-empty), a blank line, and the current
    ///    message. History is not included.
    /// 3. Otherwise: an optional `[System]` section, one section per history
    ///    entry, and a final `[User]` section, separated by blank lines.
    ///
    /// In case 3 a history entry whose role is exactly `"user"` is labelled
    /// `[User]`; every other role is labelled `[Assistant]`.
    pub fn to_prompt_string(&self) -> String {
        if self.agent_name.is_some() {
            return self.current_message.clone();
        }

        if self.session_id.is_some() {
            // Exactly one section is produced here, so no join is needed.
            return if self.system_prompt.is_empty() {
                format!("[User]\n{}", self.current_message)
            } else {
                format!(
                    "[User]\n{}\n\n{}",
                    self.system_prompt, self.current_message
                )
            };
        }

        // At most: one system section + one per history entry + the current message.
        let mut parts = Vec::with_capacity(self.history.len() + 2);
        if !self.system_prompt.is_empty() {
            parts.push(format!("[System]\n{}", self.system_prompt));
        }
        parts.extend(self.history.iter().map(|entry| {
            let role = if entry.role == "user" {
                "User"
            } else {
                "Assistant"
            };
            format!("[{}]\n{}", role, entry.content)
        }));
        parts.push(format!("[User]\n{}", self.current_message));
        parts.join("\n\n")
    }

    /// Splits the context into a system prompt and an ordered message list.
    ///
    /// Returns a clone of `system_prompt` together with one [`ApiMessage`] per
    /// history entry (role and content copied unchanged, in order) followed by
    /// the current message with role `"user"`.
    pub fn to_api_messages(&self) -> (String, Vec<ApiMessage>) {
        let mut messages = Vec::with_capacity(self.history.len() + 1);
        messages.extend(self.history.iter().map(|entry| ApiMessage {
            role: entry.role.clone(),
            content: entry.content.clone(),
        }));
        messages.push(ApiMessage {
            role: "user".to_string(),
            content: self.current_message.clone(),
        });
        (self.system_prompt.clone(), messages)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a history entry.
    fn entry(role: &str, content: &str) -> ContextEntry {
        ContextEntry {
            role: role.into(),
            content: content.into(),
        }
    }

    // Fixtures below use struct-update syntax over `Context::new`, so a test
    // only names the fields it cares about and new `Context` fields do not
    // require touching every test.

    /// `Context::new` sets only the current message.
    #[test]
    fn test_context_new_defaults() {
        let ctx = Context::new("hello");
        assert!(ctx.system_prompt.is_empty());
        assert!(ctx.history.is_empty());
        assert!(ctx.mcp_servers.is_empty());
        assert!(ctx.toolboxes.is_empty());
        assert_eq!(ctx.current_message, "hello");
        assert!(ctx.max_turns.is_none());
        assert!(ctx.allowed_tools.is_none());
        assert!(ctx.model.is_none());
        assert!(ctx.session_id.is_none());
        assert!(ctx.agent_name.is_none());
        assert!(ctx.hook_runner.is_none());
        assert!(ctx.permission_rules.is_none());
        assert!(!ctx.extended_thinking);
    }

    /// An `McpServer` survives a JSON round trip.
    #[test]
    fn test_mcp_server_serde_round_trip() {
        let server = McpServer {
            name: "playwright".into(),
            command: "npx".into(),
            args: vec!["@playwright/mcp".into(), "--headless".into()],
            env: HashMap::new(),
        };
        let json = serde_json::to_string(&server).unwrap();
        let deserialized: McpServer = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.name, "playwright");
        assert_eq!(deserialized.args, vec!["@playwright/mcp", "--headless"]);
    }

    /// Only the three required fields are needed to deserialize a `Context`.
    #[test]
    fn test_context_serde_without_optional_fields() {
        let json = r#"{"system_prompt":"test","history":[],"current_message":"hi"}"#;
        let ctx: Context = serde_json::from_str(json).unwrap();
        assert!(ctx.mcp_servers.is_empty());
        assert!(ctx.session_id.is_none());
        assert!(ctx.agent_name.is_none());
    }

    /// With no history, the API messages are just the current user message.
    #[test]
    fn test_to_api_messages_basic() {
        let ctx = Context::new("hello");
        let (system, messages) = ctx.to_api_messages();
        assert!(system.is_empty());
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "user");
        assert_eq!(messages[0].content, "hello");
    }

    /// History entries come first, in order, followed by the current message.
    #[test]
    fn test_to_api_messages_with_history() {
        let ctx = Context {
            system_prompt: "Be helpful.".into(),
            history: vec![entry("user", "Hi"), entry("assistant", "Hello!")],
            ..Context::new("How are you?")
        };
        let (system, messages) = ctx.to_api_messages();
        assert_eq!(system, "Be helpful.");
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].role, "user");
        assert_eq!(messages[0].content, "Hi");
        assert_eq!(messages[1].role, "assistant");
        assert_eq!(messages[1].content, "Hello!");
        assert_eq!(messages[2].role, "user");
        assert_eq!(messages[2].content, "How are you?");
    }

    /// Without a session or agent, the prompt has system, history and user sections.
    #[test]
    fn test_to_prompt_string_no_session() {
        let ctx = Context {
            system_prompt: "Be helpful.".into(),
            history: vec![entry("user", "Hi")],
            ..Context::new("How are you?")
        };
        let prompt = ctx.to_prompt_string();
        assert!(prompt.contains("[System]\nBe helpful."));
        assert!(prompt.contains("[User]\nHi"));
        assert!(prompt.contains("[User]\nHow are you?"));
        assert_eq!(
            prompt,
            "[System]\nBe helpful.\n\n[User]\nHi\n\n[User]\nHow are you?"
        );
    }

    /// With a session, the system prompt is folded into the single user section.
    #[test]
    fn test_to_prompt_string_with_session() {
        let ctx = Context {
            system_prompt: "Current time: 2026-03-06".into(),
            history: vec![entry("user", "Hi")],
            session_id: Some("sess-abc".into()),
            ..Context::new("How are you?")
        };
        let prompt = ctx.to_prompt_string();
        assert!(!prompt.contains("[System]"));
        assert!(prompt.contains("[User]\nCurrent time: 2026-03-06\n\nHow are you?"));
        assert_eq!(prompt, "[User]\nCurrent time: 2026-03-06\n\nHow are you?");
    }

    /// With an agent name, only the current message is emitted.
    #[test]
    fn test_to_prompt_string_with_agent_name() {
        let ctx = Context {
            system_prompt: "You are a build analyst...".into(),
            history: vec![entry("user", "prev")],
            agent_name: Some("build-analyst".into()),
            ..Context::new("Build me a task tracker.")
        };
        let prompt = ctx.to_prompt_string();
        assert_eq!(prompt, "Build me a task tracker.");
    }

    /// The agent-name shape wins even when a session id is also set.
    #[test]
    fn test_agent_name_takes_precedence_over_session_id() {
        let ctx = Context {
            system_prompt: "system".into(),
            session_id: Some("sess-456".into()),
            agent_name: Some("build-architect".into()),
            ..Context::new("Build something.")
        };
        assert_eq!(ctx.to_prompt_string(), "Build something.");
    }

    /// `session_id` survives a JSON round trip.
    #[test]
    fn test_session_id_serde_round_trip() {
        let ctx = Context {
            system_prompt: "test".into(),
            session_id: Some("sess-123".into()),
            ..Context::new("hi")
        };
        let json = serde_json::to_string(&ctx).unwrap();
        let deserialized: Context = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.session_id, Some("sess-123".into()));
    }

    /// Unset optional fields and skipped fields do not appear in the JSON.
    #[test]
    fn test_optional_fields_skipped_in_serialization() {
        let ctx = Context::new("hello");
        let json = serde_json::to_string(&ctx).unwrap();
        assert!(!json.contains("session_id"));
        assert!(!json.contains("agent_name"));
        assert!(!json.contains("max_turns"));
        assert!(!json.contains("toolboxes"));
        assert!(!json.contains("extended_thinking"));
        assert!(!json.contains("hook_runner"));
        assert!(!json.contains("permission_rules"));
    }

    /// A `Toolbox` survives a JSON round trip.
    #[test]
    fn test_toolbox_serde_round_trip() {
        let tb = Toolbox {
            name: "lint".into(),
            description: "Run linter on a file.".into(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {"file": {"type": "string"}},
                "required": ["file"]
            }),
            command: "bash".into(),
            args: vec!["scripts/lint.sh".into()],
            env: HashMap::new(),
            search_hints: Vec::new(),
        };
        let json = serde_json::to_string(&tb).unwrap();
        let deserialized: Toolbox = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.name, "lint");
        assert_eq!(deserialized.command, "bash");
        assert_eq!(deserialized.args, vec!["scripts/lint.sh"]);
    }

    /// Missing optional `Toolbox` fields fall back to their defaults.
    #[test]
    fn test_toolbox_default_parameters() {
        let json = r#"{"name":"test","description":"Test tool.","command":"echo"}"#;
        let tb: Toolbox = serde_json::from_str(json).unwrap();
        assert_eq!(tb.parameters, serde_json::json!({"type": "object"}));
        assert!(tb.args.is_empty());
        assert!(tb.env.is_empty());
        assert!(tb.search_hints.is_empty());
    }

    /// Non-empty `toolboxes` are serialized and restored.
    #[test]
    fn test_context_serde_with_toolboxes() {
        let mut ctx = Context::new("run lint");
        ctx.toolboxes.push(Toolbox {
            name: "lint".into(),
            description: "Lint a file.".into(),
            parameters: serde_json::json!({"type": "object"}),
            command: "bash".into(),
            args: vec!["lint.sh".into()],
            env: HashMap::new(),
            search_hints: Vec::new(),
        });
        let json = serde_json::to_string(&ctx).unwrap();
        assert!(json.contains("toolboxes"));
        let deserialized: Context = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.toolboxes.len(), 1);
        assert_eq!(deserialized.toolboxes[0].name, "lint");
    }
}