use std::sync::Arc;
use kvlar_audit::AuditLogger;
use kvlar_core::Engine;
use tokio::io::BufReader;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex;
use crate::config::ProxyConfig;
use crate::handler;
/// TCP proxy that accepts client connections, opens a matching connection to
/// an upstream server, and hands both ends, together with the shared policy
/// engine and audit logger, to `handler::run_proxy_loop`.
pub struct McpProxy {
    /// Policy engine shared (via `Arc`) with every connection task; it can be
    /// swapped at runtime with `replace_engine`.
    engine: Arc<Mutex<Engine>>,
    /// Proxy settings; includes `listen_addr`, `upstream_addr` and `fail_open`.
    config: ProxyConfig,
    /// Audit logger shared (via `Arc`) with every connection task.
    audit: Arc<Mutex<AuditLogger>>,
}
impl McpProxy {
pub fn new(engine: Engine, config: ProxyConfig) -> Self {
let audit = AuditLogger::default();
Self {
engine: Arc::new(Mutex::new(engine)),
config,
audit: Arc::new(Mutex::new(audit)),
}
}
pub fn with_audit(engine: Engine, config: ProxyConfig, audit: AuditLogger) -> Self {
Self {
engine: Arc::new(Mutex::new(engine)),
config,
audit: Arc::new(Mutex::new(audit)),
}
}
pub fn engine(&self) -> &Arc<Mutex<Engine>> {
&self.engine
}
pub fn config(&self) -> &ProxyConfig {
&self.config
}
pub async fn replace_engine(&self, new_engine: Engine) {
let mut engine = self.engine.lock().await;
*engine = new_engine;
}
pub async fn run(&self) -> Result<(), Box<dyn std::error::Error>> {
let listener = TcpListener::bind(&self.config.listen_addr).await?;
tracing::info!(addr = %self.config.listen_addr, "kvlar proxy listening");
loop {
let (client_stream, client_addr) = listener.accept().await?;
tracing::info!(client = %client_addr, "new connection");
let engine = self.engine.clone();
let upstream_addr = self.config.upstream_addr.clone();
let audit = self.audit.clone();
let fail_open = self.config.fail_open;
tokio::spawn(async move {
if let Err(e) =
Self::handle_connection(client_stream, &upstream_addr, engine, audit, fail_open)
.await
{
tracing::error!(client = %client_addr, error = %e, "connection error");
}
});
}
}
async fn handle_connection(
client_stream: TcpStream,
upstream_addr: &str,
engine: Arc<Mutex<Engine>>,
audit: Arc<Mutex<AuditLogger>>,
fail_open: bool,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let upstream_stream = TcpStream::connect(upstream_addr).await?;
let (client_read, client_write) = client_stream.into_split();
let (upstream_read, upstream_write) = upstream_stream.into_split();
let client_reader = BufReader::new(client_read);
let upstream_reader = BufReader::new(upstream_read);
handler::run_proxy_loop(
client_reader,
Arc::new(Mutex::new(client_write)),
upstream_reader,
Arc::new(Mutex::new(upstream_write)),
engine,
audit,
fail_open,
)
.await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single-rule policy that denies everything. Nesting indentation is
    /// significant: `rules` holds a list of rule mappings, and `effect` is a
    /// nested mapping inside the rule.
    const DENY_ALL_POLICY: &str = r#"
name: test
description: test
version: "1"
rules:
  - id: deny-all
    description: deny everything
    match_on: {}
    effect:
      type: deny
      reason: "denied"
"#;

    #[test]
    fn test_proxy_creation() {
        let proxy = McpProxy::new(Engine::new(), ProxyConfig::default());
        assert_eq!(proxy.config().listen_addr, "127.0.0.1:9100");
    }

    #[test]
    fn test_proxy_with_audit() {
        let proxy = McpProxy::with_audit(
            Engine::new(),
            ProxyConfig::default(),
            AuditLogger::default(),
        );
        assert_eq!(proxy.config().listen_addr, "127.0.0.1:9100");
    }

    #[tokio::test]
    async fn test_proxy_replace_engine() {
        let proxy = McpProxy::new(Engine::new(), ProxyConfig::default());
        // The temporary lock guard is dropped at the end of this statement,
        // so `replace_engine` below does not wait on it.
        assert_eq!(proxy.engine().lock().await.policy_count(), 0);

        let mut new_engine = Engine::new();
        new_engine.load_policy_yaml(DENY_ALL_POLICY).unwrap();
        proxy.replace_engine(new_engine).await;

        assert_eq!(proxy.engine().lock().await.policy_count(), 1);
    }
}