1use std::sync::Arc;
8
9use kvlar_audit::AuditLogger;
10use kvlar_core::Engine;
11use tokio::io::BufReader;
12use tokio::net::{TcpListener, TcpStream};
13use tokio::sync::{Mutex, RwLock};
14
15use crate::config::ProxyConfig;
16use crate::handler;
17use crate::health::ProxyStats;
18use crate::shutdown;
19
/// TCP proxy placed in front of an upstream MCP server.
///
/// Each accepted client connection is relayed to `config.upstream_addr` and its
/// traffic is passed through `handler::run_proxy_loop` together with the shared
/// policy engine and audit logger.
pub struct McpProxy {
    /// Policy engine shared with every connection task. It sits behind an async
    /// `RwLock` so that `replace_engine` can swap it while connections are live.
    engine: Arc<RwLock<Engine>>,

    /// Listener, upstream, optional health-server and fail-open settings.
    config: ProxyConfig,

    /// Audit logger shared by all connection tasks; `run` flushes it on shutdown.
    audit: Arc<Mutex<AuditLogger>>,

    /// Runtime statistics handed to the optional health server.
    stats: Arc<ProxyStats>,
}
38
39impl McpProxy {
40 pub fn new(engine: Engine, config: ProxyConfig) -> Self {
42 let audit = AuditLogger::default();
43 Self {
44 engine: Arc::new(RwLock::new(engine)),
45 config,
46 audit: Arc::new(Mutex::new(audit)),
47 stats: Arc::new(ProxyStats::new()),
48 }
49 }
50
51 pub fn with_audit(engine: Engine, config: ProxyConfig, audit: AuditLogger) -> Self {
53 Self {
54 engine: Arc::new(RwLock::new(engine)),
55 config,
56 audit: Arc::new(Mutex::new(audit)),
57 stats: Arc::new(ProxyStats::new()),
58 }
59 }
60
61 pub fn engine(&self) -> &Arc<RwLock<Engine>> {
63 &self.engine
64 }
65
66 pub fn config(&self) -> &ProxyConfig {
68 &self.config
69 }
70
71 pub fn stats(&self) -> &Arc<ProxyStats> {
73 &self.stats
74 }
75
76 pub async fn replace_engine(&self, new_engine: Engine) {
78 let mut engine = self.engine.write().await;
79 *engine = new_engine;
80 }
81
82 pub async fn run(&self) -> Result<(), Box<dyn std::error::Error>> {
88 let listener = TcpListener::bind(&self.config.listen_addr).await?;
89 tracing::info!(addr = %self.config.listen_addr, "kvlar proxy listening");
90
91 if let Some(ref health_addr) = self.config.health_addr {
93 let stats = self.stats.clone();
94 let addr = health_addr.clone();
95 tokio::spawn(async move {
96 if let Err(e) = crate::health::run_health_server(&addr, stats).await {
97 tracing::error!(error = %e, "health server error");
98 }
99 });
100 }
101
102 {
104 let eng = self.engine.read().await;
105 self.stats
106 .set_policy_info(eng.policy_count() > 0, eng.rule_count() as u64);
107 }
108
109 let shutdown_token = shutdown::signal_shutdown_token();
111
112 loop {
113 tokio::select! {
114 result = listener.accept() => {
115 let (client_stream, client_addr) = result?;
116 tracing::info!(client = %client_addr, "new connection");
117
118 let engine = self.engine.clone();
119 let upstream_addr = self.config.upstream_addr.clone();
120 let audit = self.audit.clone();
121 let fail_open = self.config.fail_open;
122
123 tokio::spawn(async move {
124 if let Err(e) =
125 Self::handle_connection(client_stream, &upstream_addr, engine, audit, fail_open)
126 .await
127 {
128 tracing::error!(client = %client_addr, error = %e, "connection error");
129 }
130 });
131 }
132 _ = shutdown_token.cancelled() => {
133 tracing::info!("shutdown signal received, stopping accept loop");
134 break;
135 }
136 }
137 }
138
139 {
141 let mut audit = self.audit.lock().await;
142 audit.flush();
143 tracing::info!("audit log flushed");
144 }
145
146 tracing::info!("waiting for active connections to drain...");
148 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
149
150 Ok(())
151 }
152
153 async fn handle_connection(
155 client_stream: TcpStream,
156 upstream_addr: &str,
157 engine: Arc<RwLock<Engine>>,
158 audit: Arc<Mutex<AuditLogger>>,
159 fail_open: bool,
160 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
161 let upstream_stream = TcpStream::connect(upstream_addr).await?;
162
163 let (client_read, client_write) = client_stream.into_split();
164 let (upstream_read, upstream_write) = upstream_stream.into_split();
165
166 let client_reader = BufReader::new(client_read);
167 let upstream_reader = BufReader::new(upstream_read);
168
169 handler::run_proxy_loop(
170 client_reader,
171 Arc::new(Mutex::new(client_write)),
172 upstream_reader,
173 Arc::new(Mutex::new(upstream_write)),
174 engine,
175 audit,
176 fail_open,
177 None,
178 )
179 .await
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal policy document containing a single catch-all deny rule.
    const DENY_ALL_POLICY: &str = r#"
name: test
description: test
version: "1"
rules:
  - id: deny-all
    description: deny everything
    match_on: {}
    effect:
      type: deny
      reason: "denied"
"#;

    /// Builds a proxy around an empty engine and the default configuration.
    fn fresh_proxy() -> McpProxy {
        McpProxy::new(Engine::new(), ProxyConfig::default())
    }

    #[test]
    fn test_proxy_creation() {
        let proxy = fresh_proxy();
        assert_eq!(proxy.config().listen_addr, "127.0.0.1:9100");
    }

    #[tokio::test]
    async fn test_proxy_replace_engine() {
        let proxy = fresh_proxy();

        // The read guard is a temporary, so it is released before `replace_engine`
        // takes the write lock.
        assert_eq!(proxy.engine().read().await.policy_count(), 0);

        let mut replacement = Engine::new();
        replacement.load_policy_yaml(DENY_ALL_POLICY).unwrap();
        proxy.replace_engine(replacement).await;

        assert_eq!(proxy.engine().read().await.policy_count(), 1);
    }
}