Skip to main content

kvlar_proxy/
handler.rs

1//! Transport-agnostic proxy message handler.
2//!
3//! Contains the core bidirectional proxy loop that reads MCP messages,
4//! evaluates tool calls against the policy engine, and forwards or blocks
5//! them. This module is used by both TCP and stdio transports.
6
7use std::sync::Arc;
8
9use kvlar_audit::AuditLogger;
10use kvlar_audit::event::{AuditEvent, EventOutcome};
11use kvlar_core::{Action, ApprovalRequest, Decision, Engine};
12use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt};
13use tokio::sync::{Mutex, RwLock};
14
15use crate::approval::ApprovalBackend;
16use crate::mcp::{self, McpMessage};
17
18/// Runs the bidirectional proxy loop.
19///
20/// Reads MCP JSON-RPC messages from `client_reader`, evaluates tool calls
21/// against the policy engine, forwards allowed messages to `upstream_writer`,
22/// and sends deny/approval responses back through `client_writer`. Server
23/// responses from `upstream_reader` are forwarded back to `client_writer`.
24///
25/// If `approval_backend` is provided, `RequireApproval` decisions will be
26/// sent to the backend for human review. If not provided, they are denied
27/// by default (fail-closed).
28#[allow(clippy::too_many_arguments)]
29pub async fn run_proxy_loop<CR, CW, UR, UW>(
30    client_reader: CR,
31    client_writer: Arc<Mutex<CW>>,
32    upstream_reader: UR,
33    upstream_writer: Arc<Mutex<UW>>,
34    engine: Arc<RwLock<Engine>>,
35    audit: Arc<Mutex<AuditLogger>>,
36    _fail_open: bool,
37    approval_backend: Option<Arc<dyn ApprovalBackend>>,
38) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
39where
40    CR: AsyncBufRead + Unpin + Send + 'static,
41    CW: AsyncWrite + Unpin + Send + 'static,
42    UR: AsyncBufRead + Unpin + Send + 'static,
43    UW: AsyncWrite + Unpin + Send + 'static,
44{
45    // Client → Upstream (with policy enforcement)
46    let engine_clone = engine.clone();
47    let audit_clone = audit.clone();
48    let client_writer_clone = client_writer.clone();
49    let client_to_upstream = tokio::spawn(async move {
50        if let Err(e) = proxy_client_to_upstream(
51            client_reader,
52            client_writer_clone,
53            upstream_writer,
54            engine_clone,
55            audit_clone,
56            approval_backend,
57        )
58        .await
59        {
60            tracing::error!(error = %e, "client-to-upstream error");
61        }
62    });
63
64    // Upstream → Client (pass-through)
65    let upstream_to_client = tokio::spawn(async move {
66        if let Err(e) = proxy_upstream_to_client(upstream_reader, client_writer).await {
67            tracing::error!(error = %e, "upstream-to-client error");
68        }
69    });
70
71    let _ = tokio::join!(client_to_upstream, upstream_to_client);
72    Ok(())
73}
74
75/// Reads messages from the client, evaluates tool calls, and forwards or denies.
76#[allow(clippy::too_many_arguments)]
77async fn proxy_client_to_upstream<CR, CW, UW>(
78    mut client_reader: CR,
79    client_writer: Arc<Mutex<CW>>,
80    upstream_writer: Arc<Mutex<UW>>,
81    engine: Arc<RwLock<Engine>>,
82    audit: Arc<Mutex<AuditLogger>>,
83    approval_backend: Option<Arc<dyn ApprovalBackend>>,
84) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
85where
86    CR: AsyncBufRead + Unpin,
87    CW: AsyncWrite + Unpin,
88    UW: AsyncWrite + Unpin,
89{
90    let mut line = String::new();
91    loop {
92        line.clear();
93        match client_reader.read_line(&mut line).await {
94            Ok(0) => break, // EOF
95            Ok(_) => {
96                let trimmed = line.trim();
97                if trimmed.is_empty() {
98                    continue;
99                }
100
101                match McpMessage::parse(trimmed) {
102                    Ok(msg) => {
103                        if let Some(req) = msg.as_request()
104                            && let Some(tool_call) = req.extract_tool_call()
105                        {
106                            // Build action with tool arguments bridged to parameters
107                            let mut action =
108                                Action::new("tool_call", &tool_call.tool_name, "mcp-agent");
109                            if let Some(obj) = tool_call.arguments.as_object() {
110                                for (key, value) in obj {
111                                    action.parameters.insert(key.clone(), value.clone());
112                                }
113                            }
114
115                            // Evaluate against policy
116                            let eng = engine.read().await;
117                            let decision = eng.evaluate(&action);
118                            drop(eng);
119
120                            // Record audit event
121                            let (outcome, reason) = match &decision {
122                                Decision::Allow { .. } => (EventOutcome::Allowed, None),
123                                Decision::Deny { reason, .. } => {
124                                    (EventOutcome::Denied, Some(reason.clone()))
125                                }
126                                Decision::RequireApproval { reason, .. } => {
127                                    (EventOutcome::PendingApproval, Some(reason.clone()))
128                                }
129                            };
130
131                            let matched_rule = match &decision {
132                                Decision::Allow { matched_rule }
133                                | Decision::Deny { matched_rule, .. }
134                                | Decision::RequireApproval { matched_rule, .. } => {
135                                    matched_rule.clone()
136                                }
137                            };
138
139                            let mut event = AuditEvent::new(
140                                "tool_call",
141                                &tool_call.tool_name,
142                                "mcp-agent",
143                                outcome,
144                                &matched_rule,
145                            );
146                            if let Some(r) = &reason {
147                                event = event.with_reason(r);
148                            }
149                            event = event.with_parameters(tool_call.arguments.clone());
150                            let mut aud = audit.lock().await;
151                            aud.record(event);
152                            drop(aud);
153
154                            // Route based on decision
155                            match decision {
156                                Decision::Allow { .. } => {
157                                    tracing::info!(
158                                        tool = %tool_call.tool_name,
159                                        rule = %matched_rule,
160                                        "ALLOW"
161                                    );
162                                    let mut writer = upstream_writer.lock().await;
163                                    let _ = writer.write_all(line.as_bytes()).await;
164                                    let _ = writer.flush().await;
165                                }
166                                Decision::Deny { reason, .. } => {
167                                    tracing::warn!(
168                                        tool = %tool_call.tool_name,
169                                        rule = %matched_rule,
170                                        reason = %reason,
171                                        "DENY"
172                                    );
173                                    let request_id =
174                                        req.id.clone().unwrap_or(serde_json::json!(null));
175                                    let resp = mcp::deny_response(
176                                        request_id,
177                                        &reason,
178                                        &tool_call.tool_name,
179                                        &matched_rule,
180                                    );
181                                    if let Ok(json) = serde_json::to_string(&resp) {
182                                        let mut writer = client_writer.lock().await;
183                                        let _ = writer
184                                            .write_all(format!("{}\n", json).as_bytes())
185                                            .await;
186                                        let _ = writer.flush().await;
187                                    }
188                                }
189                                Decision::RequireApproval { reason, .. } => {
190                                    tracing::warn!(
191                                        tool = %tool_call.tool_name,
192                                        rule = %matched_rule,
193                                        reason = %reason,
194                                        "REQUIRE_APPROVAL"
195                                    );
196                                    let request_id =
197                                        req.id.clone().unwrap_or(serde_json::json!(null));
198
199                                    // If an approval backend is configured, request approval
200                                    if let Some(ref backend) = approval_backend {
201                                        let approval_req = ApprovalRequest::new(
202                                            &tool_call.tool_name,
203                                            tool_call.arguments.clone(),
204                                            &matched_rule,
205                                            &reason,
206                                            "mcp-agent",
207                                        );
208
209                                        match backend.request_approval(&approval_req).await {
210                                            Ok(kvlar_core::ApprovalResponse::Approved) => {
211                                                tracing::info!(
212                                                    tool = %tool_call.tool_name,
213                                                    rule = %matched_rule,
214                                                    "APPROVED (via webhook)"
215                                                );
216                                                let mut writer = upstream_writer.lock().await;
217                                                let _ = writer.write_all(line.as_bytes()).await;
218                                                let _ = writer.flush().await;
219                                                continue;
220                                            }
221                                            Ok(kvlar_core::ApprovalResponse::Denied {
222                                                reason: deny_reason,
223                                            }) => {
224                                                let final_reason =
225                                                    deny_reason.unwrap_or_else(|| {
226                                                        "denied by human reviewer".into()
227                                                    });
228                                                tracing::warn!(
229                                                    tool = %tool_call.tool_name,
230                                                    reason = %final_reason,
231                                                    "DENIED (by human reviewer)"
232                                                );
233                                                let resp = mcp::deny_response(
234                                                    request_id,
235                                                    &final_reason,
236                                                    &tool_call.tool_name,
237                                                    &matched_rule,
238                                                );
239                                                if let Ok(json) = serde_json::to_string(&resp) {
240                                                    let mut writer = client_writer.lock().await;
241                                                    let _ = writer
242                                                        .write_all(format!("{}\n", json).as_bytes())
243                                                        .await;
244                                                    let _ = writer.flush().await;
245                                                }
246                                                continue;
247                                            }
248                                            Err(e) => {
249                                                tracing::error!(
250                                                    tool = %tool_call.tool_name,
251                                                    error = %e,
252                                                    "approval backend error, denying"
253                                                );
254                                                // Fall through to default behavior below
255                                            }
256                                        }
257                                    }
258
259                                    // No approval backend or backend error — send approval-required response
260                                    let resp = mcp::approval_required_response(
261                                        request_id,
262                                        &reason,
263                                        &tool_call.tool_name,
264                                        &matched_rule,
265                                    );
266                                    if let Ok(json) = serde_json::to_string(&resp) {
267                                        let mut writer = client_writer.lock().await;
268                                        let _ = writer
269                                            .write_all(format!("{}\n", json).as_bytes())
270                                            .await;
271                                        let _ = writer.flush().await;
272                                    }
273                                }
274                            }
275                            continue;
276                        }
277                        // Non-tool-call requests: pass through
278                        let mut writer = upstream_writer.lock().await;
279                        let _ = writer.write_all(line.as_bytes()).await;
280                        let _ = writer.flush().await;
281                    }
282                    Err(e) => {
283                        // Malformed JSON-RPC — send parse error back to client
284                        tracing::warn!(error = %e, "malformed JSON-RPC message from client");
285                        let resp = mcp::parse_error_response(&e.to_string());
286                        if let Ok(json) = serde_json::to_string(&resp) {
287                            let mut writer = client_writer.lock().await;
288                            let _ = writer.write_all(format!("{}\n", json).as_bytes()).await;
289                            let _ = writer.flush().await;
290                        }
291                    }
292                }
293            }
294            Err(e) => {
295                tracing::debug!(error = %e, "client read error");
296                break;
297            }
298        }
299    }
300    Ok(())
301}
302
303/// Forwards all messages from upstream back to the client.
304///
305/// If the upstream disconnects (EOF) or errors, logs the event
306/// and exits gracefully without crashing the proxy.
307async fn proxy_upstream_to_client<UR, CW>(
308    mut upstream_reader: UR,
309    client_writer: Arc<Mutex<CW>>,
310) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
311where
312    UR: AsyncBufRead + Unpin,
313    CW: AsyncWrite + Unpin,
314{
315    let mut line = String::new();
316    loop {
317        line.clear();
318        match upstream_reader.read_line(&mut line).await {
319            Ok(0) => {
320                tracing::warn!("upstream server disconnected (EOF)");
321                break;
322            }
323            Ok(_) => {
324                let trimmed = line.trim();
325                if trimmed.is_empty() {
326                    continue;
327                }
328                let mut writer = client_writer.lock().await;
329                let _ = writer.write_all(line.as_bytes()).await;
330                let _ = writer.flush().await;
331            }
332            Err(e) => {
333                tracing::error!(error = %e, "upstream read error — connection may be broken");
334                break;
335            }
336        }
337    }
338    Ok(())
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use tokio::io::BufReader;

    /// Policy shared by most tests: deny `bash`/`shell`, require approval for
    /// `send_email`, allow `read_file`. Tools matching no rule are denied by the
    /// engine (see `test_default_deny_unmatched_tool`).
    const DEFAULT_POLICY_YAML: &str = r#"
name: test-default
description: Test policy
version: "1"
rules:
  - id: deny-shell
    description: Block shell execution
    match_on:
      resources: ["bash", "shell"]
    effect:
      type: deny
      reason: "Shell execution denied"
  - id: approve-email
    description: Require approval for email
    match_on:
      resources: ["send_email"]
    effect:
      type: require_approval
      reason: "Email requires approval"
  - id: allow-read
    description: Allow file reads
    match_on:
      resources: ["read_file"]
    effect:
      type: allow
"#;

    /// Policy for the parameter-bridging test: `read_file` is denied when the
    /// `path` argument starts with `/etc` and allowed otherwise.
    const PARAM_POLICY_YAML: &str = r#"
name: param-test
description: Test parameter bridging
version: "1"
rules:
  - id: deny-dangerous-path
    description: Deny access to /etc
    match_on:
      resources: ["read_file"]
      conditions:
        - field: path
          operator: starts_with
          value: "/etc"
    effect:
      type: deny
      reason: "Access to /etc is denied"
  - id: allow-read
    description: Allow other reads
    match_on:
      resources: ["read_file"]
    effect:
      type: allow
"#;

    /// Helper: create an engine with `yaml` loaded as its policy.
    fn engine_from_yaml(yaml: &str) -> Engine {
        let mut engine = Engine::new();
        engine.load_policy_yaml(yaml).unwrap();
        engine
    }

    /// Helper: create engine loaded with the default policy inline.
    fn engine_with_default_policy() -> Engine {
        engine_from_yaml(DEFAULT_POLICY_YAML)
    }

    /// Helper: run the proxy loop over in-memory buffers with no approval
    /// backend. Returns the bytes written to the client and to upstream, plus
    /// the audit logger so tests can inspect recorded events.
    async fn run_with_audit(
        client_input: &str,
        upstream_input: &str,
        engine: Engine,
    ) -> (Vec<u8>, Vec<u8>, Arc<Mutex<AuditLogger>>) {
        let client_reader = BufReader::new(Cursor::new(client_input.as_bytes().to_vec()));
        let client_output = Arc::new(Mutex::new(Vec::<u8>::new()));
        let upstream_reader = BufReader::new(Cursor::new(upstream_input.as_bytes().to_vec()));
        let upstream_output = Arc::new(Mutex::new(Vec::<u8>::new()));
        let audit = Arc::new(Mutex::new(AuditLogger::default()));

        run_proxy_loop(
            client_reader,
            client_output.clone(),
            upstream_reader,
            upstream_output.clone(),
            Arc::new(RwLock::new(engine)),
            audit.clone(),
            false,
            None,
        )
        .await
        .unwrap();

        let client_out = client_output.lock().await.clone();
        let upstream_out = upstream_output.lock().await.clone();
        (client_out, upstream_out, audit)
    }

    /// Helper: run the proxy loop with in-memory buffers and return outputs.
    async fn run_with_buffers(
        client_input: &str,
        upstream_input: &str,
        engine: Engine,
    ) -> (Vec<u8>, Vec<u8>) {
        let (client_out, upstream_out, _audit) =
            run_with_audit(client_input, upstream_input, engine).await;
        (client_out, upstream_out)
    }

    #[tokio::test]
    async fn test_allowed_tool_call_forwarded_to_upstream() {
        let msg = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/tmp/test.txt"}}}"#;
        let client_input = format!("{}\n", msg);

        let (client_out, upstream_out) =
            run_with_buffers(&client_input, "", engine_with_default_policy()).await;

        // Allowed tool call should be forwarded to upstream
        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(
            upstream_str.contains("read_file"),
            "allowed request should be forwarded to upstream"
        );

        // No deny response should be sent to client
        let client_str = String::from_utf8(client_out).unwrap();
        assert!(
            !client_str.contains("denied"),
            "allowed request should not produce a deny response"
        );
    }

    #[tokio::test]
    async fn test_denied_tool_call_blocked() {
        let msg = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"bash","arguments":{"command":"rm -rf /"}}}"#;
        let client_input = format!("{}\n", msg);

        let (client_out, upstream_out) =
            run_with_buffers(&client_input, "", engine_with_default_policy()).await;

        // Denied tool call should NOT be forwarded to upstream
        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(
            !upstream_str.contains("bash"),
            "denied request should not be forwarded"
        );

        // Deny response should be sent back to client as tool result with isError
        let client_str = String::from_utf8(client_out).unwrap();
        assert!(
            client_str.contains("BLOCKED BY KVLAR"),
            "client should get Kvlar deny response"
        );
        assert!(
            client_str.contains("Shell execution denied"),
            "deny response should contain the reason"
        );

        // Response should include the correct request ID and isError flag
        let resp: serde_json::Value = serde_json::from_str(client_str.trim()).unwrap();
        assert_eq!(resp["id"], 2);
        assert_eq!(resp["result"]["isError"], true);
    }

    #[tokio::test]
    async fn test_approval_required_tool_call_blocked() {
        let msg = r#"{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"send_email","arguments":{"to":"user@example.com"}}}"#;
        let client_input = format!("{}\n", msg);

        let (client_out, upstream_out) =
            run_with_buffers(&client_input, "", engine_with_default_policy()).await;

        // Should NOT be forwarded to upstream
        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(upstream_str.is_empty());

        // Client should get approval-required response as tool result
        let client_str = String::from_utf8(client_out).unwrap();
        assert!(client_str.contains("APPROVAL REQUIRED"));
        assert!(client_str.contains("Email requires approval"));
    }

    #[tokio::test]
    async fn test_non_tool_call_request_passthrough() {
        let msg = r#"{"jsonrpc":"2.0","id":4,"method":"resources/read","params":{"uri":"file:///tmp/test.txt"}}"#;
        let client_input = format!("{}\n", msg);

        let (client_out, upstream_out) =
            run_with_buffers(&client_input, "", engine_with_default_policy()).await;

        // Non-tool-call should pass through to upstream
        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(
            upstream_str.contains("resources/read"),
            "non-tool-call requests should pass through"
        );

        // No response sent to client
        let client_str = String::from_utf8(client_out).unwrap();
        assert!(client_str.is_empty());
    }

    #[tokio::test]
    async fn test_upstream_response_forwarded_to_client() {
        let upstream_resp =
            r#"{"jsonrpc":"2.0","id":1,"result":{"content":[{"type":"text","text":"hello"}]}}"#;
        let upstream_input = format!("{}\n", upstream_resp);

        let (client_out, _upstream_out) =
            run_with_buffers("", &upstream_input, engine_with_default_policy()).await;

        // Upstream response should be forwarded to client
        let client_str = String::from_utf8(client_out).unwrap();
        assert!(
            client_str.contains("hello"),
            "upstream response should be forwarded to client"
        );
    }

    #[tokio::test]
    async fn test_tool_args_bridged_to_action_parameters() {
        // PARAM_POLICY_YAML matches on conditions over the tool arguments, so these
        // outcomes only hold if arguments are bridged into action parameters.

        // Request with path=/etc/passwd should be DENIED
        let msg_denied = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/etc/passwd"}}}"#;
        let (client_out, upstream_out) = run_with_buffers(
            &format!("{}\n", msg_denied),
            "",
            engine_from_yaml(PARAM_POLICY_YAML),
        )
        .await;
        let client_str = String::from_utf8(client_out).unwrap();
        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(
            client_str.contains("BLOCKED BY KVLAR"),
            "should deny /etc access"
        );
        assert!(upstream_str.is_empty(), "should not forward denied request");

        // Request with path=/tmp/file.txt should be ALLOWED
        let msg_allowed = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/tmp/file.txt"}}}"#;
        let (_client_out2, upstream_out2) = run_with_buffers(
            &format!("{}\n", msg_allowed),
            "",
            engine_from_yaml(PARAM_POLICY_YAML),
        )
        .await;
        let upstream_str2 = String::from_utf8(upstream_out2).unwrap();
        assert!(
            upstream_str2.contains("read_file"),
            "should forward allowed request"
        );
    }

    #[tokio::test]
    async fn test_default_deny_unmatched_tool() {
        let msg = r#"{"jsonrpc":"2.0","id":5,"method":"tools/call","params":{"name":"unknown_tool","arguments":{}}}"#;
        let client_input = format!("{}\n", msg);

        let (client_out, upstream_out) =
            run_with_buffers(&client_input, "", engine_with_default_policy()).await;

        // Unmatched tool should be DENIED (fail-closed)
        let upstream_str = String::from_utf8(upstream_out).unwrap();
        assert!(
            upstream_str.is_empty(),
            "unmatched tool should not be forwarded"
        );

        let client_str = String::from_utf8(client_out).unwrap();
        assert!(
            client_str.contains("BLOCKED BY KVLAR"),
            "unmatched tool should be denied by Kvlar"
        );
    }

    #[tokio::test]
    async fn test_audit_records_created() {
        let msg = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"bash","arguments":{"command":"ls"}}}"#;
        let client_input = format!("{}\n", msg);

        let (_client_out, _upstream_out, audit) =
            run_with_audit(&client_input, "", engine_with_default_policy()).await;

        let aud = audit.lock().await;
        let events = aud.events();
        assert_eq!(events.len(), 1, "should record one audit event");
        assert_eq!(events[0].resource, "bash");
        assert_eq!(events[0].outcome, kvlar_audit::event::EventOutcome::Denied);
        assert!(
            events[0].parameters.is_some(),
            "audit event should include parameters"
        );
    }
}