use context_compressor::{
estimate_messages_tokens_rough, mask_token, redact_sensitive_text, ContextCompressor,
Message,
};
use serde_json::json;
/// Pins the masking rules of `mask_token`: empty and short tokens are fully
/// hidden as `***`, while a long token keeps its first 6 and last 4 characters
/// joined by `...`.
#[test]
fn test_mask_token() {
    // (raw token, expected masked form)
    let cases = [
        ("", "***"),
        ("short", "***"),
        ("longtoken123456789", "longto...6789"),
    ];
    for (token, masked) in cases {
        assert_eq!(mask_token(token), masked, "token: {:?}", token);
    }
}
/// Checks that `redact_sensitive_text` masks every supported secret shape:
/// bare API keys, env-style assignments (unquoted, double- and single-quoted),
/// bearer tokens, Telegram-style bot tokens, database URL passwords, Discord
/// mentions and phone numbers.
#[test]
fn test_redact_sensitive_text() {
    // (input text, expected redacted text), checked in this order.
    let cases: [(&str, &str); 9] = [
        (
            "my openai key is sk-proj-1234567890abcdefghij",
            "my openai key is sk-pro...ghij",
        ),
        (
            "OPENAI_API_KEY=sk-proj12345678901234567890",
            "OPENAI_API_KEY=***",
        ),
        (
            "Authorization: Bearer mytoken12345678901234567",
            "Authorization: Bearer mytoke...4567",
        ),
        (
            "bot12345678:ABC-DEF-GHI-JKL-MNO-PQR-STU-VWX-YZ",
            "bot12345678:***",
        ),
        (
            "postgres://user:password123@host:5432/db",
            "postgres://user:***@host:5432/db",
        ),
        ("<@!123456789012345678>", "<@!***>"),
        ("+12345678901", "+123****8901"),
        (
            "OPENAI_API_KEY=\"sk-proj12345678901234567890\"",
            "OPENAI_API_KEY=\"***\"",
        ),
        (
            "OPENAI_API_KEY='sk-proj12345678901234567890'",
            "OPENAI_API_KEY='***'",
        ),
    ];

    for (input, expected) in cases {
        assert_eq!(
            redact_sensitive_text(input, true),
            expected,
            "input: {:?}",
            input
        );
    }
}
#[tokio::test]
async fn test_context_compressor_flow() {
let mock_callback = |prompt: String| async move {
assert!(prompt.contains("## Active Task"));
Ok("## Active Task\nUser asked: 'Build the Rust library'\n\n## Goal\nComplete the project".to_string())
};
let mut compressor = ContextCompressor::new(
8000,
mock_callback,
Some(0.50), Some(0), Some(2), Some(0.20), Some(false),
);
let messages = vec![
Message {
role: "system".to_string(),
content: json!("System prompt content"),
tool_calls: None,
tool_call_id: None,
},
Message {
role: "user".to_string(),
content: json!("Hello, please help me with task A"),
tool_calls: None,
tool_call_id: None,
},
Message {
role: "assistant".to_string(),
content: json!(""),
tool_calls: Some(json!([
{
"id": "call_123",
"type": "function",
"function": {
"name": "read_file",
"arguments": "{\"path\": \"src/main.rs\"}"
}
}
])),
tool_call_id: None,
},
Message {
role: "tool".to_string(),
content: json!("fn main() {\n println!(\"hello\");\n}"),
tool_calls: None,
tool_call_id: Some("call_123".to_string()),
},
Message {
role: "assistant".to_string(),
content: json!("I have read the file. It prints hello."),
tool_calls: None,
tool_call_id: None,
},
Message {
role: "user".to_string(),
content: json!("Now write a test for it."),
tool_calls: None,
tool_call_id: None,
},
];
let est = estimate_messages_tokens_rough(&messages);
assert!(est > 0);
let compressed = compressor.compress(&messages, Some(5000), None, true).await;
assert_eq!(compressed[0].role, "system");
let has_summary = compressed.iter().any(|msg| {
if let serde_json::Value::String(s) = &msg.content {
s.contains("[CONTEXT COMPACTION")
} else {
false
}
});
assert!(has_summary);
assert_eq!(
compressed.last().unwrap().content,
json!("Now write a test for it.")
);
}
/// Guards against pathological (e.g. catastrophic-backtracking) behaviour in the
/// redaction patterns: a 10 000-character bearer token made of a single repeated
/// character must be redacted within 200 ms and still be masked correctly.
#[test]
fn test_backtracking_resilience() {
    let repeat_count = 10000;
    let malicious_input = format!(
        "Authorization: Bearer {}",
        "a".repeat(repeat_count)
    );

    // Untimed warm-up call. Tests run in parallel and in unspecified order, so this
    // test may be the first caller of `redact_sensitive_text`; any one-time setup
    // done on first use (e.g. lazy pattern compilation, if the crate does that —
    // not visible from this file) must not be charged to the timed call below.
    let _ = redact_sensitive_text("Authorization: Bearer mytoken12345678901234567", true);

    let start = std::time::Instant::now();
    let redacted = redact_sensitive_text(&malicious_input, true);
    let duration = start.elapsed();

    assert!(
        duration < std::time::Duration::from_millis(200),
        "redacting a {}-char token took {:?}, expected < 200ms",
        repeat_count,
        duration
    );
    assert_eq!(redacted, "Authorization: Bearer aaaaaa...aaaa");
}