use std::sync::Arc;
use kvlar_audit::AuditLogger;
use kvlar_audit::event::{AuditEvent, EventOutcome};
use kvlar_core::{Action, Decision, Engine};
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::Mutex;
use crate::mcp::{self, McpMessage};
pub async fn run_proxy_loop<CR, CW, UR, UW>(
client_reader: CR,
client_writer: Arc<Mutex<CW>>,
upstream_reader: UR,
upstream_writer: Arc<Mutex<UW>>,
engine: Arc<Mutex<Engine>>,
audit: Arc<Mutex<AuditLogger>>,
_fail_open: bool,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
CR: AsyncBufRead + Unpin + Send + 'static,
CW: AsyncWrite + Unpin + Send + 'static,
UR: AsyncBufRead + Unpin + Send + 'static,
UW: AsyncWrite + Unpin + Send + 'static,
{
let engine_clone = engine.clone();
let audit_clone = audit.clone();
let client_writer_clone = client_writer.clone();
let client_to_upstream = tokio::spawn(async move {
if let Err(e) = proxy_client_to_upstream(
client_reader,
client_writer_clone,
upstream_writer,
engine_clone,
audit_clone,
)
.await
{
tracing::error!(error = %e, "client-to-upstream error");
}
});
let upstream_to_client = tokio::spawn(async move {
if let Err(e) = proxy_upstream_to_client(upstream_reader, client_writer).await {
tracing::error!(error = %e, "upstream-to-client error");
}
});
let _ = tokio::join!(client_to_upstream, upstream_to_client);
Ok(())
}
/// Client → upstream direction: enforces the policy on `tools/call` requests.
///
/// Reads newline-delimited messages from `client_reader` until EOF or a read
/// error. For each non-blank line:
/// - If it parses as an MCP request carrying a tool call, an [`Action`] of type
///   `"tool_call"` for agent `"mcp-agent"` is built, with the tool's object
///   arguments copied into `action.parameters`. The action is evaluated by
///   `engine`, and the decision is recorded in `audit`.
///   - An allowed call is forwarded unchanged to `upstream_writer`.
///   - A denied or approval-required call is never forwarded; a synthesized
///     JSON-RPC response is written to `client_writer` instead.
/// - Anything else is forwarded unchanged to `upstream_writer`.
///
/// # Errors
/// Returns an error if writing to or flushing either side fails (for example,
/// the peer went away), or if a synthesized response cannot be serialized. A
/// read error on `client_reader` is logged at debug level and ends the loop
/// with `Ok(())`.
async fn proxy_client_to_upstream<CR, CW, UW>(
    mut client_reader: CR,
    client_writer: Arc<Mutex<CW>>,
    upstream_writer: Arc<Mutex<UW>>,
    engine: Arc<Mutex<Engine>>,
    audit: Arc<Mutex<AuditLogger>>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
    CR: AsyncBufRead + Unpin,
    CW: AsyncWrite + Unpin,
    UW: AsyncWrite + Unpin,
{
    let mut line = String::new();
    loop {
        line.clear();
        match client_reader.read_line(&mut line).await {
            Ok(0) => break,
            Ok(_) => {}
            Err(e) => {
                tracing::debug!(error = %e, "client read error");
                break;
            }
        }

        let trimmed = line.trim();
        if trimmed.is_empty() {
            continue;
        }

        if let Ok(msg) = McpMessage::parse(trimmed)
            && let Some(req) = msg.as_request()
            && let Some(tool_call) = req.extract_tool_call()
        {
            let mut action = Action::new("tool_call", &tool_call.tool_name, "mcp-agent");
            // Expose tool arguments to policy conditions (e.g. `field: path`).
            if let Some(args) = tool_call.arguments.as_object() {
                for (key, value) in args {
                    action.parameters.insert(key.clone(), value.clone());
                }
            }

            // The engine lock is held only for the duration of the evaluation.
            let decision = engine.lock().await.evaluate(&action);

            let (outcome, reason, matched_rule) = match &decision {
                Decision::Allow { matched_rule } => {
                    (EventOutcome::Allowed, None, matched_rule.clone())
                }
                Decision::Deny {
                    reason,
                    matched_rule,
                    ..
                } => (EventOutcome::Denied, Some(reason.clone()), matched_rule.clone()),
                Decision::RequireApproval {
                    reason,
                    matched_rule,
                    ..
                } => (
                    EventOutcome::PendingApproval,
                    Some(reason.clone()),
                    matched_rule.clone(),
                ),
            };

            let mut event = AuditEvent::new(
                "tool_call",
                &tool_call.tool_name,
                "mcp-agent",
                outcome,
                &matched_rule,
            );
            if let Some(r) = &reason {
                event = event.with_reason(r);
            }
            event = event.with_parameters(tool_call.arguments.clone());
            audit.lock().await.record(event);

            match decision {
                Decision::Allow { .. } => {
                    tracing::info!(
                        tool = %tool_call.tool_name,
                        rule = %matched_rule,
                        "ALLOW"
                    );
                    write_and_flush(&upstream_writer, line.as_bytes()).await?;
                }
                Decision::Deny { reason, .. } => {
                    tracing::warn!(
                        tool = %tool_call.tool_name,
                        rule = %matched_rule,
                        reason = %reason,
                        "DENY"
                    );
                    let resp = mcp::deny_response(
                        req.id.clone().unwrap_or(serde_json::json!(null)),
                        &reason,
                        &tool_call.tool_name,
                        &matched_rule,
                    );
                    let json = serde_json::to_string(&resp)?;
                    write_and_flush(&client_writer, format!("{json}\n").as_bytes()).await?;
                }
                Decision::RequireApproval { reason, .. } => {
                    tracing::warn!(
                        tool = %tool_call.tool_name,
                        rule = %matched_rule,
                        reason = %reason,
                        "REQUIRE_APPROVAL"
                    );
                    let resp = mcp::approval_required_response(
                        req.id.clone().unwrap_or(serde_json::json!(null)),
                        &reason,
                        &tool_call.tool_name,
                        &matched_rule,
                    );
                    let json = serde_json::to_string(&resp)?;
                    write_and_flush(&client_writer, format!("{json}\n").as_bytes()).await?;
                }
            }
            continue;
        }

        // Everything that is not a recognised tool-call request is forwarded
        // verbatim, including lines that fail to parse.
        // NOTE(review): forwarding unparseable input is a fail-open path.
        // Anything `McpMessage::parse` rejects reaches upstream without a policy
        // check; for example, JSON-RPC batch arrays may fall here, depending on
        // the parser. Consider failing closed here unless fail-open is
        // configured.
        write_and_flush(&upstream_writer, line.as_bytes()).await?;
    }
    Ok(())
}

/// Writes `bytes` to the shared writer and flushes it.
///
/// The lock is held across both operations, so another task sharing the same
/// writer cannot interleave its output with this write.
async fn write_and_flush<W>(writer: &Mutex<W>, bytes: &[u8]) -> std::io::Result<()>
where
    W: AsyncWrite + Unpin,
{
    let mut guard = writer.lock().await;
    guard.write_all(bytes).await?;
    guard.flush().await
}
/// Upstream → client direction: relays upstream output verbatim.
///
/// Copies every non-blank line read from `upstream_reader`, including its line
/// terminator, to `client_writer` until EOF or a read error. Blank lines are
/// dropped. The client writer lock is held across write and flush, so a relayed
/// line cannot interleave with a response written by the other direction.
///
/// # Errors
/// Returns an error if writing to or flushing the client fails (for example,
/// the client went away), so the task stops instead of silently discarding
/// upstream output. A read error on `upstream_reader` is logged at debug level
/// and ends the loop with `Ok(())`.
async fn proxy_upstream_to_client<UR, CW>(
    mut upstream_reader: UR,
    client_writer: Arc<Mutex<CW>>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
where
    UR: AsyncBufRead + Unpin,
    CW: AsyncWrite + Unpin,
{
    let mut line = String::new();
    loop {
        line.clear();
        match upstream_reader.read_line(&mut line).await {
            Ok(0) => break,
            Ok(_) if line.trim().is_empty() => continue,
            Ok(_) => {
                let mut writer = client_writer.lock().await;
                writer.write_all(line.as_bytes()).await?;
                writer.flush().await?;
            }
            Err(e) => {
                tracing::debug!(error = %e, "upstream read error");
                break;
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use tokio::io::BufReader;

    /// Policy with one rule per decision kind: `bash`/`shell` are denied,
    /// `send_email` requires approval and `read_file` is allowed. Any other tool
    /// matches no rule.
    const DEFAULT_POLICY_YAML: &str = r#"
name: test-default
description: Test policy
version: "1"
rules:
  - id: deny-shell
    description: Block shell execution
    match_on:
      resources: ["bash", "shell"]
    effect:
      type: deny
      reason: "Shell execution denied"
  - id: approve-email
    description: Require approval for email
    match_on:
      resources: ["send_email"]
    effect:
      type: require_approval
      reason: "Email requires approval"
  - id: allow-read
    description: Allow file reads
    match_on:
      resources: ["read_file"]
    effect:
      type: allow
"#;

    /// Policy whose deny rule only matches `read_file` when the `path` argument
    /// starts with `/etc`. It is used to check that tool arguments reach the
    /// engine as action parameters.
    const PARAM_POLICY_YAML: &str = r#"
name: param-test
description: Test parameter bridging
version: "1"
rules:
  - id: deny-dangerous-path
    description: Deny access to /etc
    match_on:
      resources: ["read_file"]
      conditions:
        - field: path
          operator: starts_with
          value: "/etc"
    effect:
      type: deny
      reason: "Access to /etc is denied"
  - id: allow-read
    description: Allow other reads
    match_on:
      resources: ["read_file"]
    effect:
      type: allow
"#;

    /// Builds a fresh engine with `yaml` loaded; panics if the policy is invalid.
    fn engine_with_policy(yaml: &str) -> Engine {
        let mut engine = Engine::new();
        engine.load_policy_yaml(yaml).unwrap();
        engine
    }

    fn engine_with_default_policy() -> Engine {
        engine_with_policy(DEFAULT_POLICY_YAML)
    }

    fn engine_with_param_policy() -> Engine {
        engine_with_policy(PARAM_POLICY_YAML)
    }

    /// Runs the proxy over in-memory buffers and returns
    /// `(bytes written to the client, bytes written to upstream)`.
    async fn run_with_buffers(
        client_input: &str,
        upstream_input: &str,
        engine: Engine,
    ) -> (Vec<u8>, Vec<u8>) {
        let client_reader = BufReader::new(Cursor::new(client_input.as_bytes().to_vec()));
        let client_output = Arc::new(Mutex::new(Vec::<u8>::new()));
        let upstream_reader = BufReader::new(Cursor::new(upstream_input.as_bytes().to_vec()));
        let upstream_output = Arc::new(Mutex::new(Vec::<u8>::new()));
        let audit = AuditLogger::default();
        run_proxy_loop(
            client_reader,
            client_output.clone(),
            upstream_reader,
            upstream_output.clone(),
            Arc::new(Mutex::new(engine)),
            Arc::new(Mutex::new(audit)),
            false,
        )
        .await
        .unwrap();
        let client_out = client_output.lock().await.clone();
        let upstream_out = upstream_output.lock().await.clone();
        (client_out, upstream_out)
    }

    #[tokio::test]
    async fn test_allowed_tool_call_forwarded_to_upstream() {
        let msg = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/tmp/test.txt"}}}"#;
        let client_input = format!("{}\n", msg);
        let (client_out, upstream_out) =
            run_with_buffers(&client_input, "", engine_with_default_policy()).await;
        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(
            upstream_str.contains("read_file"),
            "allowed request should be forwarded to upstream"
        );
        let client_str = String::from_utf8(client_out).unwrap();
        assert!(
            !client_str.contains("denied"),
            "allowed request should not produce a deny response"
        );
    }

    #[tokio::test]
    async fn test_denied_tool_call_blocked() {
        let msg = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"bash","arguments":{"command":"rm -rf /"}}}"#;
        let client_input = format!("{}\n", msg);
        let (client_out, upstream_out) =
            run_with_buffers(&client_input, "", engine_with_default_policy()).await;
        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(
            !upstream_str.contains("bash"),
            "denied request should not be forwarded"
        );
        let client_str = String::from_utf8(client_out).unwrap();
        assert!(
            client_str.contains("BLOCKED BY KVLAR"),
            "client should get Kvlar deny response"
        );
        assert!(
            client_str.contains("Shell execution denied"),
            "deny response should contain the reason"
        );
        let resp: serde_json::Value = serde_json::from_str(client_str.trim()).unwrap();
        assert_eq!(resp["id"], 2);
        assert_eq!(resp["result"]["isError"], true);
    }

    #[tokio::test]
    async fn test_approval_required_tool_call_blocked() {
        let msg = r#"{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"send_email","arguments":{"to":"user@example.com"}}}"#;
        let client_input = format!("{}\n", msg);
        let (client_out, upstream_out) =
            run_with_buffers(&client_input, "", engine_with_default_policy()).await;
        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(
            upstream_str.is_empty(),
            "approval-required request should not be forwarded"
        );
        let client_str = String::from_utf8(client_out).unwrap();
        assert!(
            client_str.contains("APPROVAL REQUIRED"),
            "client should get an approval-required response"
        );
        assert!(
            client_str.contains("Email requires approval"),
            "approval response should contain the reason"
        );
    }

    #[tokio::test]
    async fn test_non_tool_call_request_passthrough() {
        let msg = r#"{"jsonrpc":"2.0","id":4,"method":"resources/read","params":{"uri":"file:///tmp/test.txt"}}"#;
        let client_input = format!("{}\n", msg);
        let (client_out, upstream_out) =
            run_with_buffers(&client_input, "", engine_with_default_policy()).await;
        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(
            upstream_str.contains("resources/read"),
            "non-tool-call requests should pass through"
        );
        let client_str = String::from_utf8(client_out).unwrap();
        assert!(
            client_str.is_empty(),
            "pass-through requests should not produce a client response"
        );
    }

    #[tokio::test]
    async fn test_upstream_response_forwarded_to_client() {
        let upstream_resp =
            r#"{"jsonrpc":"2.0","id":1,"result":{"content":[{"type":"text","text":"hello"}]}}"#;
        let upstream_input = format!("{}\n", upstream_resp);
        let (client_out, _upstream_out) =
            run_with_buffers("", &upstream_input, engine_with_default_policy()).await;
        let client_str = String::from_utf8(client_out).unwrap();
        assert!(
            client_str.contains("hello"),
            "upstream response should be forwarded to client"
        );
    }

    #[tokio::test]
    async fn test_tool_args_bridged_to_action_parameters() {
        // A path under /etc trips the parameter condition and is denied.
        let msg_denied = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/etc/passwd"}}}"#;
        let (client_out, upstream_out) =
            run_with_buffers(&format!("{}\n", msg_denied), "", engine_with_param_policy()).await;
        let client_str = String::from_utf8(client_out).unwrap();
        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(
            client_str.contains("BLOCKED BY KVLAR"),
            "should deny /etc access"
        );
        assert!(upstream_str.is_empty(), "should not forward denied request");

        // Any other path falls through to the allow rule.
        let msg_allowed = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/tmp/file.txt"}}}"#;
        let (_client_out2, upstream_out2) =
            run_with_buffers(&format!("{}\n", msg_allowed), "", engine_with_param_policy()).await;
        let upstream_str2 = String::from_utf8(upstream_out2).unwrap();
        assert!(
            upstream_str2.contains("read_file"),
            "should forward allowed request"
        );
    }

    #[tokio::test]
    async fn test_default_deny_unmatched_tool() {
        let msg = r#"{"jsonrpc":"2.0","id":5,"method":"tools/call","params":{"name":"unknown_tool","arguments":{}}}"#;
        let client_input = format!("{}\n", msg);
        let (client_out, upstream_out) =
            run_with_buffers(&client_input, "", engine_with_default_policy()).await;
        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(
            upstream_str.is_empty(),
            "unmatched tool should not be forwarded"
        );
        let client_str = String::from_utf8(client_out).unwrap();
        assert!(
            client_str.contains("BLOCKED BY KVLAR"),
            "unmatched tool should be denied by Kvlar"
        );
    }

    #[tokio::test]
    async fn test_audit_records_created() {
        let msg = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"bash","arguments":{"command":"ls"}}}"#;
        let client_input = format!("{}\n", msg);
        let client_reader = BufReader::new(Cursor::new(client_input.as_bytes().to_vec()));
        let client_output = Arc::new(Mutex::new(Vec::<u8>::new()));
        let upstream_reader = BufReader::new(Cursor::new(Vec::<u8>::new()));
        let upstream_output = Arc::new(Mutex::new(Vec::<u8>::new()));
        let audit = Arc::new(Mutex::new(AuditLogger::default()));
        run_proxy_loop(
            client_reader,
            client_output,
            upstream_reader,
            upstream_output,
            Arc::new(Mutex::new(engine_with_default_policy())),
            audit.clone(),
            false,
        )
        .await
        .unwrap();
        let aud = audit.lock().await;
        let events = aud.events();
        assert_eq!(events.len(), 1, "should record one audit event");
        assert_eq!(events[0].resource, "bash");
        assert_eq!(events[0].outcome, kvlar_audit::event::EventOutcome::Denied);
        assert!(
            events[0].parameters.is_some(),
            "audit event should include parameters"
        );
    }
}