1use std::sync::Arc;
8
9use kvlar_audit::AuditLogger;
10use kvlar_audit::event::{AuditEvent, EventOutcome};
11use kvlar_core::{Action, ApprovalRequest, Decision, Engine};
12use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt};
13use tokio::sync::{Mutex, RwLock};
14
15use crate::approval::ApprovalBackend;
16use crate::mcp::{self, McpMessage};
17
18#[allow(clippy::too_many_arguments)]
29pub async fn run_proxy_loop<CR, CW, UR, UW>(
30 client_reader: CR,
31 client_writer: Arc<Mutex<CW>>,
32 upstream_reader: UR,
33 upstream_writer: Arc<Mutex<UW>>,
34 engine: Arc<RwLock<Engine>>,
35 audit: Arc<Mutex<AuditLogger>>,
36 _fail_open: bool,
37 approval_backend: Option<Arc<dyn ApprovalBackend>>,
38) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
39where
40 CR: AsyncBufRead + Unpin + Send + 'static,
41 CW: AsyncWrite + Unpin + Send + 'static,
42 UR: AsyncBufRead + Unpin + Send + 'static,
43 UW: AsyncWrite + Unpin + Send + 'static,
44{
45 let engine_clone = engine.clone();
47 let audit_clone = audit.clone();
48 let client_writer_clone = client_writer.clone();
49 let client_to_upstream = tokio::spawn(async move {
50 if let Err(e) = proxy_client_to_upstream(
51 client_reader,
52 client_writer_clone,
53 upstream_writer,
54 engine_clone,
55 audit_clone,
56 approval_backend,
57 )
58 .await
59 {
60 tracing::error!(error = %e, "client-to-upstream error");
61 }
62 });
63
64 let upstream_to_client = tokio::spawn(async move {
66 if let Err(e) = proxy_upstream_to_client(upstream_reader, client_writer).await {
67 tracing::error!(error = %e, "upstream-to-client error");
68 }
69 });
70
71 let _ = tokio::join!(client_to_upstream, upstream_to_client);
72 Ok(())
73}
74
75#[allow(clippy::too_many_arguments)]
77async fn proxy_client_to_upstream<CR, CW, UW>(
78 mut client_reader: CR,
79 client_writer: Arc<Mutex<CW>>,
80 upstream_writer: Arc<Mutex<UW>>,
81 engine: Arc<RwLock<Engine>>,
82 audit: Arc<Mutex<AuditLogger>>,
83 approval_backend: Option<Arc<dyn ApprovalBackend>>,
84) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
85where
86 CR: AsyncBufRead + Unpin,
87 CW: AsyncWrite + Unpin,
88 UW: AsyncWrite + Unpin,
89{
90 let mut line = String::new();
91 loop {
92 line.clear();
93 match client_reader.read_line(&mut line).await {
94 Ok(0) => break, Ok(_) => {
96 let trimmed = line.trim();
97 if trimmed.is_empty() {
98 continue;
99 }
100
101 match McpMessage::parse(trimmed) {
102 Ok(msg) => {
103 if let Some(req) = msg.as_request()
104 && let Some(tool_call) = req.extract_tool_call()
105 {
106 let mut action =
108 Action::new("tool_call", &tool_call.tool_name, "mcp-agent");
109 if let Some(obj) = tool_call.arguments.as_object() {
110 for (key, value) in obj {
111 action.parameters.insert(key.clone(), value.clone());
112 }
113 }
114
115 let eng = engine.read().await;
117 let decision = eng.evaluate(&action);
118 drop(eng);
119
120 let (outcome, reason) = match &decision {
122 Decision::Allow { .. } => (EventOutcome::Allowed, None),
123 Decision::Deny { reason, .. } => {
124 (EventOutcome::Denied, Some(reason.clone()))
125 }
126 Decision::RequireApproval { reason, .. } => {
127 (EventOutcome::PendingApproval, Some(reason.clone()))
128 }
129 };
130
131 let matched_rule = match &decision {
132 Decision::Allow { matched_rule }
133 | Decision::Deny { matched_rule, .. }
134 | Decision::RequireApproval { matched_rule, .. } => {
135 matched_rule.clone()
136 }
137 };
138
139 let mut event = AuditEvent::new(
140 "tool_call",
141 &tool_call.tool_name,
142 "mcp-agent",
143 outcome,
144 &matched_rule,
145 );
146 if let Some(r) = &reason {
147 event = event.with_reason(r);
148 }
149 event = event.with_parameters(tool_call.arguments.clone());
150 let mut aud = audit.lock().await;
151 aud.record(event);
152 drop(aud);
153
154 match decision {
156 Decision::Allow { .. } => {
157 tracing::info!(
158 tool = %tool_call.tool_name,
159 rule = %matched_rule,
160 "ALLOW"
161 );
162 let mut writer = upstream_writer.lock().await;
163 let _ = writer.write_all(line.as_bytes()).await;
164 let _ = writer.flush().await;
165 }
166 Decision::Deny { reason, .. } => {
167 tracing::warn!(
168 tool = %tool_call.tool_name,
169 rule = %matched_rule,
170 reason = %reason,
171 "DENY"
172 );
173 let request_id =
174 req.id.clone().unwrap_or(serde_json::json!(null));
175 let resp = mcp::deny_response(
176 request_id,
177 &reason,
178 &tool_call.tool_name,
179 &matched_rule,
180 );
181 if let Ok(json) = serde_json::to_string(&resp) {
182 let mut writer = client_writer.lock().await;
183 let _ = writer
184 .write_all(format!("{}\n", json).as_bytes())
185 .await;
186 let _ = writer.flush().await;
187 }
188 }
189 Decision::RequireApproval { reason, .. } => {
190 tracing::warn!(
191 tool = %tool_call.tool_name,
192 rule = %matched_rule,
193 reason = %reason,
194 "REQUIRE_APPROVAL"
195 );
196 let request_id =
197 req.id.clone().unwrap_or(serde_json::json!(null));
198
199 if let Some(ref backend) = approval_backend {
201 let approval_req = ApprovalRequest::new(
202 &tool_call.tool_name,
203 tool_call.arguments.clone(),
204 &matched_rule,
205 &reason,
206 "mcp-agent",
207 );
208
209 match backend.request_approval(&approval_req).await {
210 Ok(kvlar_core::ApprovalResponse::Approved) => {
211 tracing::info!(
212 tool = %tool_call.tool_name,
213 rule = %matched_rule,
214 "APPROVED (via webhook)"
215 );
216 let mut writer = upstream_writer.lock().await;
217 let _ = writer.write_all(line.as_bytes()).await;
218 let _ = writer.flush().await;
219 continue;
220 }
221 Ok(kvlar_core::ApprovalResponse::Denied {
222 reason: deny_reason,
223 }) => {
224 let final_reason =
225 deny_reason.unwrap_or_else(|| {
226 "denied by human reviewer".into()
227 });
228 tracing::warn!(
229 tool = %tool_call.tool_name,
230 reason = %final_reason,
231 "DENIED (by human reviewer)"
232 );
233 let resp = mcp::deny_response(
234 request_id,
235 &final_reason,
236 &tool_call.tool_name,
237 &matched_rule,
238 );
239 if let Ok(json) = serde_json::to_string(&resp) {
240 let mut writer = client_writer.lock().await;
241 let _ = writer
242 .write_all(format!("{}\n", json).as_bytes())
243 .await;
244 let _ = writer.flush().await;
245 }
246 continue;
247 }
248 Err(e) => {
249 tracing::error!(
250 tool = %tool_call.tool_name,
251 error = %e,
252 "approval backend error, denying"
253 );
254 }
256 }
257 }
258
259 let resp = mcp::approval_required_response(
261 request_id,
262 &reason,
263 &tool_call.tool_name,
264 &matched_rule,
265 );
266 if let Ok(json) = serde_json::to_string(&resp) {
267 let mut writer = client_writer.lock().await;
268 let _ = writer
269 .write_all(format!("{}\n", json).as_bytes())
270 .await;
271 let _ = writer.flush().await;
272 }
273 }
274 }
275 continue;
276 }
277 let mut writer = upstream_writer.lock().await;
279 let _ = writer.write_all(line.as_bytes()).await;
280 let _ = writer.flush().await;
281 }
282 Err(e) => {
283 tracing::warn!(error = %e, "malformed JSON-RPC message from client");
285 let resp = mcp::parse_error_response(&e.to_string());
286 if let Ok(json) = serde_json::to_string(&resp) {
287 let mut writer = client_writer.lock().await;
288 let _ = writer.write_all(format!("{}\n", json).as_bytes()).await;
289 let _ = writer.flush().await;
290 }
291 }
292 }
293 }
294 Err(e) => {
295 tracing::debug!(error = %e, "client read error");
296 break;
297 }
298 }
299 }
300 Ok(())
301}
302
303async fn proxy_upstream_to_client<UR, CW>(
308 mut upstream_reader: UR,
309 client_writer: Arc<Mutex<CW>>,
310) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
311where
312 UR: AsyncBufRead + Unpin,
313 CW: AsyncWrite + Unpin,
314{
315 let mut line = String::new();
316 loop {
317 line.clear();
318 match upstream_reader.read_line(&mut line).await {
319 Ok(0) => {
320 tracing::warn!("upstream server disconnected (EOF)");
321 break;
322 }
323 Ok(_) => {
324 let trimmed = line.trim();
325 if trimmed.is_empty() {
326 continue;
327 }
328 let mut writer = client_writer.lock().await;
329 let _ = writer.write_all(line.as_bytes()).await;
330 let _ = writer.flush().await;
331 }
332 Err(e) => {
333 tracing::error!(error = %e, "upstream read error — connection may be broken");
334 break;
335 }
336 }
337 }
338 Ok(())
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use tokio::io::BufReader;

    /// Policy used by most tests: deny shell tools, require approval for
    /// email, allow file reads.
    const DEFAULT_POLICY: &str = r#"
name: test-default
description: Test policy
version: "1"
rules:
  - id: deny-shell
    description: Block shell execution
    match_on:
      resources: ["bash", "shell"]
    effect:
      type: deny
      reason: "Shell execution denied"
  - id: approve-email
    description: Require approval for email
    match_on:
      resources: ["send_email"]
    effect:
      type: require_approval
      reason: "Email requires approval"
  - id: allow-read
    description: Allow file reads
    match_on:
      resources: ["read_file"]
    effect:
      type: allow
"#;

    /// Policy whose deny rule depends on the `path` tool argument.
    const PATH_POLICY: &str = r#"
name: param-test
description: Test parameter bridging
version: "1"
rules:
  - id: deny-dangerous-path
    description: Deny access to /etc
    match_on:
      resources: ["read_file"]
      conditions:
        - field: path
          operator: starts_with
          value: "/etc"
    effect:
      type: deny
      reason: "Access to /etc is denied"
  - id: allow-read
    description: Allow other reads
    match_on:
      resources: ["read_file"]
    effect:
      type: allow
"#;

    /// Builds an engine with the given YAML policy loaded; panics if the
    /// policy does not load.
    fn engine_from_yaml(yaml: &str) -> Engine {
        let mut engine = Engine::new();
        engine.load_policy_yaml(yaml).unwrap();
        engine
    }

    /// Engine loaded with `DEFAULT_POLICY`.
    fn engine_with_default_policy() -> Engine {
        engine_from_yaml(DEFAULT_POLICY)
    }

    /// Runs the proxy over in-memory buffers with a caller-supplied audit
    /// logger and returns `(bytes written to client, bytes written to upstream)`.
    async fn run_with_audit(
        client_input: &str,
        upstream_input: &str,
        engine: Engine,
        audit: Arc<Mutex<AuditLogger>>,
    ) -> (Vec<u8>, Vec<u8>) {
        let client_reader = BufReader::new(Cursor::new(client_input.as_bytes().to_vec()));
        let upstream_reader = BufReader::new(Cursor::new(upstream_input.as_bytes().to_vec()));
        let client_output = Arc::new(Mutex::new(Vec::<u8>::new()));
        let upstream_output = Arc::new(Mutex::new(Vec::<u8>::new()));

        run_proxy_loop(
            client_reader,
            client_output.clone(),
            upstream_reader,
            upstream_output.clone(),
            Arc::new(RwLock::new(engine)),
            audit,
            false,
            None,
        )
        .await
        .unwrap();

        let client_out = client_output.lock().await.clone();
        let upstream_out = upstream_output.lock().await.clone();
        (client_out, upstream_out)
    }

    /// Runs the proxy over in-memory buffers with a fresh default audit
    /// logger and no approval backend.
    async fn run_with_buffers(
        client_input: &str,
        upstream_input: &str,
        engine: Engine,
    ) -> (Vec<u8>, Vec<u8>) {
        let audit = Arc::new(Mutex::new(AuditLogger::default()));
        run_with_audit(client_input, upstream_input, engine, audit).await
    }

    #[tokio::test]
    async fn test_allowed_tool_call_forwarded_to_upstream() {
        let msg = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/tmp/test.txt"}}}"#;
        let client_input = format!("{}\n", msg);

        let (client_out, upstream_out) =
            run_with_buffers(&client_input, "", engine_with_default_policy()).await;

        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(
            upstream_str.contains("read_file"),
            "allowed request should be forwarded to upstream"
        );

        let client_str = String::from_utf8(client_out).unwrap();
        assert!(
            !client_str.contains("denied"),
            "allowed request should not produce a deny response"
        );
    }

    #[tokio::test]
    async fn test_denied_tool_call_blocked() {
        let msg = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"bash","arguments":{"command":"rm -rf /"}}}"#;
        let client_input = format!("{}\n", msg);

        let (client_out, upstream_out) =
            run_with_buffers(&client_input, "", engine_with_default_policy()).await;

        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(
            !upstream_str.contains("bash"),
            "denied request should not be forwarded"
        );

        let client_str = String::from_utf8(client_out).unwrap();
        assert!(
            client_str.contains("BLOCKED BY KVLAR"),
            "client should get Kvlar deny response"
        );
        assert!(
            client_str.contains("Shell execution denied"),
            "deny response should contain the reason"
        );

        // The deny response must be valid JSON that echoes the request id.
        let resp: serde_json::Value = serde_json::from_str(client_str.trim()).unwrap();
        assert_eq!(resp["id"], 2);
        assert_eq!(resp["result"]["isError"], true);
    }

    #[tokio::test]
    async fn test_approval_required_tool_call_blocked() {
        let msg = r#"{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"send_email","arguments":{"to":"user@example.com"}}}"#;
        let client_input = format!("{}\n", msg);

        let (client_out, upstream_out) =
            run_with_buffers(&client_input, "", engine_with_default_policy()).await;

        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(upstream_str.is_empty());

        let client_str = String::from_utf8(client_out).unwrap();
        assert!(client_str.contains("APPROVAL REQUIRED"));
        assert!(client_str.contains("Email requires approval"));
    }

    #[tokio::test]
    async fn test_non_tool_call_request_passthrough() {
        let msg = r#"{"jsonrpc":"2.0","id":4,"method":"resources/read","params":{"uri":"file:///tmp/test.txt"}}"#;
        let client_input = format!("{}\n", msg);

        let (client_out, upstream_out) =
            run_with_buffers(&client_input, "", engine_with_default_policy()).await;

        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(
            upstream_str.contains("resources/read"),
            "non-tool-call requests should pass through"
        );

        let client_str = String::from_utf8(client_out).unwrap();
        assert!(client_str.is_empty());
    }

    #[tokio::test]
    async fn test_upstream_response_forwarded_to_client() {
        let upstream_resp =
            r#"{"jsonrpc":"2.0","id":1,"result":{"content":[{"type":"text","text":"hello"}]}}"#;
        let upstream_input = format!("{}\n", upstream_resp);

        let (client_out, _upstream_out) =
            run_with_buffers("", &upstream_input, engine_with_default_policy()).await;

        let client_str = String::from_utf8(client_out).unwrap();
        assert!(
            client_str.contains("hello"),
            "upstream response should be forwarded to client"
        );
    }

    #[tokio::test]
    async fn test_tool_args_bridged_to_action_parameters() {
        // A path under /etc must hit the conditional deny rule.
        let msg_denied = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/etc/passwd"}}}"#;
        let (client_out, upstream_out) = run_with_buffers(
            &format!("{}\n", msg_denied),
            "",
            engine_from_yaml(PATH_POLICY),
        )
        .await;
        let client_str = String::from_utf8(client_out).unwrap();
        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(
            client_str.contains("BLOCKED BY KVLAR"),
            "should deny /etc access"
        );
        assert!(upstream_str.is_empty(), "should not forward denied request");

        // The same tool with a different path falls through to the allow rule.
        let msg_allowed = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/tmp/file.txt"}}}"#;
        let (_client_out2, upstream_out2) = run_with_buffers(
            &format!("{}\n", msg_allowed),
            "",
            engine_from_yaml(PATH_POLICY),
        )
        .await;
        let upstream_str2 = String::from_utf8(upstream_out2).unwrap();
        assert!(
            upstream_str2.contains("read_file"),
            "should forward allowed request"
        );
    }

    #[tokio::test]
    async fn test_default_deny_unmatched_tool() {
        let msg = r#"{"jsonrpc":"2.0","id":5,"method":"tools/call","params":{"name":"unknown_tool","arguments":{}}}"#;
        let client_input = format!("{}\n", msg);

        let (client_out, upstream_out) =
            run_with_buffers(&client_input, "", engine_with_default_policy()).await;

        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(
            upstream_str.is_empty(),
            "unmatched tool should not be forwarded"
        );

        let client_str = String::from_utf8(client_out).unwrap();
        assert!(
            client_str.contains("BLOCKED BY KVLAR"),
            "unmatched tool should be denied by Kvlar"
        );
    }

    #[tokio::test]
    async fn test_audit_records_created() {
        let msg = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"bash","arguments":{"command":"ls"}}}"#;
        let client_input = format!("{}\n", msg);

        // Keep a handle on the audit logger so it can be inspected afterwards.
        let audit = Arc::new(Mutex::new(AuditLogger::default()));
        run_with_audit(
            &client_input,
            "",
            engine_with_default_policy(),
            audit.clone(),
        )
        .await;

        let aud = audit.lock().await;
        let events = aud.events();
        assert_eq!(events.len(), 1, "should record one audit event");
        assert_eq!(events[0].resource, "bash");
        assert_eq!(events[0].outcome, kvlar_audit::event::EventOutcome::Denied);
        assert!(
            events[0].parameters.is_some(),
            "audit event should include parameters"
        );
    }
}