dome_gate/lib.rs
1use std::sync::Arc;
2
3use dome_core::{DomeError, McpMessage};
4use dome_ledger::{AuditEntry, Direction, Ledger};
5use dome_policy::{Identity as PolicyIdentity, PolicyEngine};
6use dome_sentinel::{AnonymousAuthenticator, Authenticator, IdentityResolver, ResolverConfig};
7use dome_throttle::{BudgetTracker, BudgetTrackerConfig, RateLimiter, RateLimiterConfig};
8use dome_transport::stdio::StdioTransport;
9use dome_ward::{InjectionScanner, SchemaPinStore};
10
11use chrono::Utc;
12use serde_json::Value;
13use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::sync::Mutex;
15use tracing::{debug, error, info, warn};
16use uuid::Uuid;
17
/// Feature switches for the Gate proxy.
///
/// Each interceptor in the chain is toggled independently; with every switch
/// off the proxy simply relays traffic.
#[derive(Debug, Clone)]
pub struct GateConfig {
    /// Evaluate authorization rules before forwarding (`false` = transparent pass-through mode).
    pub enforce_policy: bool,
    /// Scan tool-call arguments for prompt-injection patterns.
    pub enable_ward: bool,
    /// Pin tool schemas and report drift in later tool listings.
    pub enable_schema_pin: bool,
    /// Apply rate limiting to tool calls.
    pub enable_rate_limit: bool,
    /// Charge tool calls against a per-principal budget.
    pub enable_budget: bool,
    /// Whether callers may proceed without authenticating (forwarded to the identity resolver).
    pub allow_anonymous: bool,
}
34
35impl Default for GateConfig {
36 fn default() -> Self {
37 Self {
38 enforce_policy: false,
39 enable_ward: false,
40 enable_schema_pin: false,
41 enable_rate_limit: false,
42 enable_budget: false,
43 allow_anonymous: true,
44 }
45 }
46}
47
48/// The Gate — MCPDome's core proxy loop with full interceptor chain.
49///
50/// Interceptor order (inbound, client → server):
51/// 1. Sentinel — authenticate on `initialize`, resolve identity
52/// 2. Throttle — check rate limits and budget
53/// 3. Policy — evaluate authorization rules
54/// 4. Ward — scan for injection patterns in tool arguments
55/// 5. Ledger — record the decision in the audit chain
56///
57/// Outbound (server → client):
58/// 1. Schema Pin — verify tools/list responses for drift
59/// 2. Ledger — record outbound audit entry
60pub struct Gate {
61 config: GateConfig,
62 resolver: IdentityResolver,
63 policy_engine: Option<PolicyEngine>,
64 rate_limiter: Arc<RateLimiter>,
65 budget_tracker: Arc<BudgetTracker>,
66 injection_scanner: Arc<InjectionScanner>,
67 schema_store: Arc<Mutex<SchemaPinStore>>,
68 ledger: Arc<Mutex<Ledger>>,
69}
70
71impl Gate {
72 /// Create a new Gate with full interceptor chain.
73 pub fn new(
74 config: GateConfig,
75 authenticators: Vec<Box<dyn Authenticator>>,
76 policy_engine: Option<PolicyEngine>,
77 rate_limiter_config: RateLimiterConfig,
78 budget_config: BudgetTrackerConfig,
79 ledger: Ledger,
80 ) -> Self {
81 Self {
82 resolver: IdentityResolver::new(
83 authenticators,
84 ResolverConfig {
85 allow_anonymous: config.allow_anonymous,
86 },
87 ),
88 policy_engine,
89 rate_limiter: Arc::new(RateLimiter::new(rate_limiter_config)),
90 budget_tracker: Arc::new(BudgetTracker::new(budget_config)),
91 injection_scanner: Arc::new(InjectionScanner::new()),
92 schema_store: Arc::new(Mutex::new(SchemaPinStore::new())),
93 ledger: Arc::new(Mutex::new(ledger)),
94 config,
95 }
96 }
97
98 /// Create a transparent pass-through Gate (no security enforcement).
99 pub fn transparent(ledger: Ledger) -> Self {
100 Self::new(
101 GateConfig::default(),
102 vec![Box::new(AnonymousAuthenticator)],
103 None,
104 RateLimiterConfig::default(),
105 BudgetTrackerConfig::default(),
106 ledger,
107 )
108 }
109
110 /// Run the proxy.
111 pub async fn run_stdio(self, command: &str, args: &[&str]) -> Result<(), DomeError> {
112 let transport = StdioTransport::spawn(command, args).await?;
113 let (mut server_reader, mut server_writer, mut child) = transport.split();
114
115 let client_stdin = tokio::io::stdin();
116 let client_stdout = tokio::io::stdout();
117 let mut client_reader = BufReader::new(client_stdin);
118 let mut client_writer = client_stdout;
119
120 info!("MCPDome proxy active — interceptor chain armed");
121
122 // Shared state for the two tasks
123 let identity: Arc<Mutex<Option<dome_sentinel::Identity>>> = Arc::new(Mutex::new(None));
124
125 let gate_identity = Arc::clone(&identity);
126 let gate_resolver = self.resolver;
127 let gate_policy = self.policy_engine;
128 let gate_rate_limiter = Arc::clone(&self.rate_limiter);
129 let gate_budget = Arc::clone(&self.budget_tracker);
130 let gate_scanner = Arc::clone(&self.injection_scanner);
131 let gate_ledger = Arc::clone(&self.ledger);
132 let gate_config = self.config.clone();
133
134 // Client → Server task (inbound interceptor chain)
135 let client_to_server = tokio::spawn(async move {
136 let mut line = String::new();
137 loop {
138 line.clear();
139 match client_reader.read_line(&mut line).await {
140 Ok(0) => {
141 info!("client closed stdin — shutting down");
142 break;
143 }
144 Ok(_) => {
145 let trimmed = line.trim();
146 if trimmed.is_empty() {
147 continue;
148 }
149
150 match McpMessage::parse(trimmed) {
151 Ok(msg) => {
152 let start = std::time::Instant::now();
153 let method = msg.method.as_deref().unwrap_or("-").to_string();
154 let tool = msg.tool_name().map(String::from);
155 let request_id = Uuid::new_v4();
156
157 debug!(
158 method = method.as_str(),
159 id = ?msg.id,
160 tool = tool.as_deref().unwrap_or("-"),
161 "client -> server"
162 );
163
164 // ── 1. Sentinel: Authenticate on initialize ──
165 if method == "initialize" {
166 match gate_resolver.resolve(&msg).await {
167 Ok(id) => {
168 info!(
169 principal = %id.principal,
170 method = %id.auth_method,
171 "identity resolved"
172 );
173 *gate_identity.lock().await = Some(id);
174 }
175 Err(e) => {
176 warn!(%e, "authentication failed");
177 // Auth failure logged; request still forwarded
178 // to let downstream server handle the handshake
179 // Still forward — let the server handle it
180 // but log the auth failure
181 }
182 }
183 }
184
185 let identity_lock = gate_identity.lock().await;
186 let principal = identity_lock
187 .as_ref()
188 .map(|i| i.principal.clone())
189 .unwrap_or_else(|| "anonymous".to_string());
190 let labels = identity_lock
191 .as_ref()
192 .map(|i| i.labels.clone())
193 .unwrap_or_default();
194 drop(identity_lock);
195
196 // Only intercept tools/call
197 if method == "tools/call" {
198 let tool_name = tool.as_deref().unwrap_or("unknown");
199 let args = msg
200 .params
201 .as_ref()
202 .and_then(|p| p.get("arguments"))
203 .cloned()
204 .unwrap_or(Value::Null);
205
206 // ── 2. Throttle: Rate limit check ──
207 if gate_config.enable_rate_limit {
208 if let Err(e) = gate_rate_limiter
209 .check_rate_limit(&principal, Some(tool_name))
210 {
211 warn!(%e, principal = %principal, tool = tool_name, "rate limited");
212 record_audit(
213 &gate_ledger,
214 request_id,
215 &principal,
216 Direction::Inbound,
217 &method,
218 tool.as_deref(),
219 "deny:rate_limit",
220 None,
221 start.elapsed().as_micros() as u64,
222 )
223 .await;
224 // Send error response back to client via server_writer
225 // For now, skip the request
226 continue;
227 }
228 }
229
230 // ── 2b. Throttle: Budget check ──
231 if gate_config.enable_budget {
232 if let Err(e) = gate_budget.try_spend(&principal, 1.0) {
233 warn!(%e, principal = %principal, "budget exhausted");
234 record_audit(
235 &gate_ledger,
236 request_id,
237 &principal,
238 Direction::Inbound,
239 &method,
240 tool.as_deref(),
241 "deny:budget",
242 None,
243 start.elapsed().as_micros() as u64,
244 )
245 .await;
246 continue;
247 }
248 }
249
250 // ── 3. Policy: Authorization ──
251 if gate_config.enforce_policy {
252 if let Some(ref engine) = gate_policy {
253 let policy_id = PolicyIdentity::new(
254 principal.clone(),
255 labels.iter().cloned(),
256 );
257 let decision = engine.evaluate(&policy_id, tool_name, &args);
258
259 if !decision.is_allowed() {
260 warn!(
261 rule_id = %decision.rule_id,
262 tool = tool_name,
263 principal = %principal,
264 "policy denied"
265 );
266 record_audit(
267 &gate_ledger,
268 request_id,
269 &principal,
270 Direction::Inbound,
271 &method,
272 tool.as_deref(),
273 &format!("deny:policy:{}", decision.rule_id),
274 Some(&decision.rule_id),
275 start.elapsed().as_micros() as u64,
276 )
277 .await;
278 continue;
279 }
280 }
281 }
282
283 // ── 4. Ward: Injection scanning ──
284 if gate_config.enable_ward {
285 let args_str = serde_json::to_string(&args).unwrap_or_default();
286 let matches = gate_scanner.scan_text(&args_str);
287 if !matches.is_empty() {
288 let pattern_names: Vec<&str> =
289 matches.iter().map(|m| m.pattern_name.as_str()).collect();
290 warn!(
291 patterns = ?pattern_names,
292 tool = tool_name,
293 principal = %principal,
294 "injection detected"
295 );
296 record_audit(
297 &gate_ledger,
298 request_id,
299 &principal,
300 Direction::Inbound,
301 &method,
302 tool.as_deref(),
303 &format!("deny:injection:{}", pattern_names.join(",")),
304 None,
305 start.elapsed().as_micros() as u64,
306 )
307 .await;
308 continue;
309 }
310 }
311 }
312
313 // ── 5. Ledger: Record allowed request ──
314 record_audit(
315 &gate_ledger,
316 request_id,
317 &principal,
318 Direction::Inbound,
319 &method,
320 tool.as_deref(),
321 "allow",
322 None,
323 start.elapsed().as_micros() as u64,
324 )
325 .await;
326
327 // Forward to server
328 if let Err(e) = server_writer.send(&msg).await {
329 error!(%e, "failed to forward to server");
330 break;
331 }
332 }
333 Err(e) => {
334 warn!(%e, raw = trimmed, "invalid JSON from client, forwarding raw");
335 let _ = server_writer
336 .send(&McpMessage {
337 jsonrpc: "2.0".to_string(),
338 id: None,
339 method: None,
340 params: None,
341 result: None,
342 error: None,
343 })
344 .await;
345 }
346 }
347 }
348 Err(e) => {
349 error!(%e, "error reading from client");
350 break;
351 }
352 }
353 }
354 });
355
356 let schema_store = Arc::clone(&self.schema_store);
357 let _outbound_ledger = Arc::clone(&self.ledger);
358 let outbound_config = self.config.clone();
359
360 // Server → Client task (outbound interceptor chain)
361 let server_to_client = tokio::spawn(async move {
362 let mut first_tools_list = true;
363
364 loop {
365 match server_reader.recv().await {
366 Ok(msg) => {
367 let method = msg.method.as_deref().unwrap_or("-").to_string();
368
369 debug!(
370 method = method.as_str(),
371 id = ?msg.id,
372 "server -> client"
373 );
374
375 // ── Schema Pin: Verify tools/list responses ──
376 if outbound_config.enable_schema_pin {
377 if let Some(result) = &msg.result {
378 if result.get("tools").is_some() {
379 let mut store = schema_store.lock().await;
380 if first_tools_list {
381 store.pin_tools(result);
382 info!(
383 pinned = store.len(),
384 "schema pins established"
385 );
386 first_tools_list = false;
387 } else {
388 let drifts = store.verify_tools(result);
389 if !drifts.is_empty() {
390 for drift in &drifts {
391 warn!(
392 tool = %drift.tool_name,
393 drift_type = ?drift.drift_type,
394 severity = ?drift.severity,
395 "schema drift detected"
396 );
397 }
398 }
399 }
400 }
401 }
402 }
403
404 // Forward to client
405 match msg.to_json() {
406 Ok(json) => {
407 let mut out = json.into_bytes();
408 out.push(b'\n');
409 if let Err(e) = client_writer.write_all(&out).await {
410 error!(%e, "failed to write to client");
411 break;
412 }
413 if let Err(e) = client_writer.flush().await {
414 error!(%e, "failed to flush to client");
415 break;
416 }
417 }
418 Err(e) => {
419 error!(%e, "failed to serialize server response");
420 }
421 }
422 }
423 Err(DomeError::Transport(ref e))
424 if e.kind() == std::io::ErrorKind::UnexpectedEof =>
425 {
426 info!("server closed stdout — shutting down");
427 break;
428 }
429 Err(e) => {
430 error!(%e, "error reading from server");
431 break;
432 }
433 }
434 }
435 });
436
437 // Wait for either side to finish
438 tokio::select! {
439 r = client_to_server => {
440 if let Err(e) = r {
441 error!(%e, "client->server task panicked");
442 }
443 }
444 r = server_to_client => {
445 if let Err(e) = r {
446 error!(%e, "server->client task panicked");
447 }
448 }
449 }
450
451 // Flush audit log and clean up
452 self.ledger.lock().await.flush();
453 let _ = child.kill().await;
454 info!("MCPDome proxy shut down");
455
456 Ok(())
457 }
458}
459
460/// Helper to record an audit entry.
461async fn record_audit(
462 ledger: &Arc<Mutex<Ledger>>,
463 request_id: Uuid,
464 identity: &str,
465 direction: Direction,
466 method: &str,
467 tool: Option<&str>,
468 decision: &str,
469 rule_id: Option<&str>,
470 latency_us: u64,
471) {
472 let entry = AuditEntry {
473 seq: 0, // set by ledger
474 timestamp: Utc::now(),
475 request_id,
476 identity: identity.to_string(),
477 direction,
478 method: method.to_string(),
479 tool: tool.map(String::from),
480 decision: decision.to_string(),
481 rule_id: rule_id.map(String::from),
482 latency_us,
483 prev_hash: String::new(), // set by ledger
484 annotations: std::collections::HashMap::new(),
485 };
486
487 if let Err(e) = ledger.lock().await.record(entry) {
488 error!(%e, "failed to record audit entry");
489 }
490}