Skip to main content

kvlar_proxy/
proxy.rs

//! MCP proxy server implementation (TCP transport).
//!
//! Implements a TCP proxy that intercepts MCP JSON-RPC messages,
//! evaluates tool calls against the policy engine, and either
//! forwards allowed requests or blocks denied ones.

7use std::sync::Arc;
8
9use kvlar_audit::AuditLogger;
10use kvlar_core::Engine;
11use tokio::io::BufReader;
12use tokio::net::{TcpListener, TcpStream};
13use tokio::sync::{Mutex, RwLock};
14
15use crate::config::ProxyConfig;
16use crate::handler;
17use crate::health::ProxyStats;
18use crate::shutdown;
19
/// The MCP security proxy (TCP transport).
///
/// Listens for incoming MCP connections, intercepts tool call requests,
/// runs them through the policy engine, and forwards allowed requests
/// to the upstream MCP server.
pub struct McpProxy {
    /// The policy evaluation engine, shared with per-connection tasks.
    /// The `RwLock` lets `replace_engine` hot-swap policies while
    /// connection handlers evaluate under read locks.
    engine: Arc<RwLock<Engine>>,

    /// Proxy configuration (listen/upstream/health addresses and the
    /// fail-open flag, all read by `run`).
    config: ProxyConfig,

    /// Audit logger, shared across connection tasks behind an async
    /// mutex and flushed once on shutdown.
    audit: Arc<Mutex<AuditLogger>>,

    /// Runtime statistics for health endpoint.
    stats: Arc<ProxyStats>,
}
38
39impl McpProxy {
40    /// Creates a new proxy with the given engine and configuration.
41    pub fn new(engine: Engine, config: ProxyConfig) -> Self {
42        let audit = AuditLogger::default();
43        Self {
44            engine: Arc::new(RwLock::new(engine)),
45            config,
46            audit: Arc::new(Mutex::new(audit)),
47            stats: Arc::new(ProxyStats::new()),
48        }
49    }
50
51    /// Creates a new proxy with a custom audit logger.
52    pub fn with_audit(engine: Engine, config: ProxyConfig, audit: AuditLogger) -> Self {
53        Self {
54            engine: Arc::new(RwLock::new(engine)),
55            config,
56            audit: Arc::new(Mutex::new(audit)),
57            stats: Arc::new(ProxyStats::new()),
58        }
59    }
60
61    /// Returns a reference to the shared engine.
62    pub fn engine(&self) -> &Arc<RwLock<Engine>> {
63        &self.engine
64    }
65
66    /// Returns a reference to the proxy configuration.
67    pub fn config(&self) -> &ProxyConfig {
68        &self.config
69    }
70
71    /// Returns a reference to the proxy stats.
72    pub fn stats(&self) -> &Arc<ProxyStats> {
73        &self.stats
74    }
75
76    /// Replaces the engine with a new one (for hot-reload).
77    pub async fn replace_engine(&self, new_engine: Engine) {
78        let mut engine = self.engine.write().await;
79        *engine = new_engine;
80    }
81
82    /// Starts the proxy server with graceful shutdown.
83    ///
84    /// Listens for incoming connections and handles them concurrently.
85    /// On SIGTERM/SIGINT, stops accepting new connections and waits
86    /// for active connections to drain (up to 30s timeout).
87    pub async fn run(&self) -> Result<(), Box<dyn std::error::Error>> {
88        let listener = TcpListener::bind(&self.config.listen_addr).await?;
89        tracing::info!(addr = %self.config.listen_addr, "kvlar proxy listening");
90
91        // Start health check server if configured
92        if let Some(ref health_addr) = self.config.health_addr {
93            let stats = self.stats.clone();
94            let addr = health_addr.clone();
95            tokio::spawn(async move {
96                if let Err(e) = crate::health::run_health_server(&addr, stats).await {
97                    tracing::error!(error = %e, "health server error");
98                }
99            });
100        }
101
102        // Update stats with current policy info
103        {
104            let eng = self.engine.read().await;
105            self.stats
106                .set_policy_info(eng.policy_count() > 0, eng.rule_count() as u64);
107        }
108
109        // Install signal handlers
110        let shutdown_token = shutdown::signal_shutdown_token();
111
112        loop {
113            tokio::select! {
114                result = listener.accept() => {
115                    let (client_stream, client_addr) = result?;
116                    tracing::info!(client = %client_addr, "new connection");
117
118                    let engine = self.engine.clone();
119                    let upstream_addr = self.config.upstream_addr.clone();
120                    let audit = self.audit.clone();
121                    let fail_open = self.config.fail_open;
122
123                    tokio::spawn(async move {
124                        if let Err(e) =
125                            Self::handle_connection(client_stream, &upstream_addr, engine, audit, fail_open)
126                                .await
127                        {
128                            tracing::error!(client = %client_addr, error = %e, "connection error");
129                        }
130                    });
131                }
132                _ = shutdown_token.cancelled() => {
133                    tracing::info!("shutdown signal received, stopping accept loop");
134                    break;
135                }
136            }
137        }
138
139        // Flush audit log
140        {
141            let mut audit = self.audit.lock().await;
142            audit.flush();
143            tracing::info!("audit log flushed");
144        }
145
146        // Give active connections time to drain
147        tracing::info!("waiting for active connections to drain...");
148        tokio::time::sleep(std::time::Duration::from_millis(500)).await;
149
150        Ok(())
151    }
152
153    /// Handles a single client connection by delegating to the shared handler.
154    async fn handle_connection(
155        client_stream: TcpStream,
156        upstream_addr: &str,
157        engine: Arc<RwLock<Engine>>,
158        audit: Arc<Mutex<AuditLogger>>,
159        fail_open: bool,
160    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
161        let upstream_stream = TcpStream::connect(upstream_addr).await?;
162
163        let (client_read, client_write) = client_stream.into_split();
164        let (upstream_read, upstream_write) = upstream_stream.into_split();
165
166        let client_reader = BufReader::new(client_read);
167        let upstream_reader = BufReader::new(upstream_read);
168
169        handler::run_proxy_loop(
170            client_reader,
171            Arc::new(Mutex::new(client_write)),
172            upstream_reader,
173            Arc::new(Mutex::new(upstream_write)),
174            engine,
175            audit,
176            fail_open,
177            None,
178        )
179        .await
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh proxy exposes the default listen address via `config()`.
    #[test]
    fn test_proxy_creation() {
        let proxy = McpProxy::new(Engine::new(), ProxyConfig::default());
        assert_eq!(proxy.config().listen_addr, "127.0.0.1:9100");
    }

    /// `replace_engine` swaps the shared engine; readers observe the
    /// new policy set through the same handle.
    #[tokio::test]
    async fn test_proxy_replace_engine() {
        let proxy = McpProxy::new(Engine::new(), ProxyConfig::default());

        // A freshly constructed proxy has no policies loaded. The read
        // guard is dropped at the end of the statement.
        assert_eq!(proxy.engine().read().await.policy_count(), 0);

        let mut replacement = Engine::new();
        replacement
            .load_policy_yaml(
                r#"
name: test
description: test
version: "1"
rules:
  - id: deny-all
    description: deny everything
    match_on: {}
    effect:
      type: deny
      reason: "denied"
"#,
            )
            .unwrap();

        proxy.replace_engine(replacement).await;

        // The swapped-in engine is visible through the shared handle.
        assert_eq!(proxy.engine().read().await.policy_count(), 1);
    }
}
231}