dome_gate/lib.rs
1#![allow(clippy::collapsible_if, clippy::too_many_arguments)]
2
3use std::sync::Arc;
4
5use dome_core::{McpMessage, DomeError};
6use dome_ledger::{AuditEntry, Direction, Ledger};
7use dome_policy::{Identity as PolicyIdentity, SharedPolicyEngine};
8use dome_sentinel::{AnonymousAuthenticator, Authenticator, IdentityResolver, PskAuthenticator, ResolverConfig};
9use dome_throttle::{BudgetTracker, BudgetTrackerConfig, RateLimiter, RateLimiterConfig};
10use dome_transport::stdio::StdioTransport;
11use dome_ward::{InjectionScanner, SchemaPinStore};
12use dome_ward::schema_pin::DriftSeverity;
13
14use chrono::Utc;
15use serde_json::Value;
16use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
17use tokio::sync::Mutex;
18use tracing::{debug, error, info, warn};
19use uuid::Uuid;
20
/// Feature switches for a [`Gate`] proxy.
///
/// The [`Default`] value describes a transparent proxy: every optional
/// interceptor is switched off and anonymous access is permitted.
#[derive(Debug, Clone)]
pub struct GateConfig {
    /// Enforce authorization policy (`false` = transparent pass-through mode).
    pub enforce_policy: bool,
    /// Run the Ward injection scanner over traffic.
    pub enable_ward: bool,
    /// Pin `tools/list` schemas and check later responses for drift.
    pub enable_schema_pin: bool,
    /// Apply the Throttle rate limiter.
    pub enable_rate_limit: bool,
    /// Apply the Throttle budget tracker.
    pub enable_budget: bool,
    /// Allow anonymous access (handed to the identity resolver).
    pub allow_anonymous: bool,
}

impl Default for GateConfig {
    fn default() -> Self {
        // Transparent mode: nothing is enforced and nobody is turned away.
        GateConfig {
            allow_anonymous: true,
            enforce_policy: false,
            enable_ward: false,
            enable_schema_pin: false,
            enable_rate_limit: false,
            enable_budget: false,
        }
    }
}
50
51/// The Gate -- MCPDome's core proxy loop with full interceptor chain.
52///
53/// Interceptor order (inbound, client -> server):
54/// 1. Sentinel -- authenticate on `initialize`, resolve identity
55/// 2. Throttle -- check rate limits and budget
56/// 3. Ward -- scan for injection patterns in tool arguments
57/// 4. Policy -- evaluate authorization rules
58/// 5. Ledger -- record the decision in the audit chain
59///
60/// Outbound (server -> client):
61/// 1. Schema Pin -- verify tools/list responses for drift (block Critical/High)
62/// 2. Ward -- scan outbound tool results for injection patterns
63/// 3. Ledger -- record outbound audit entry
64pub struct Gate {
65 config: GateConfig,
66 resolver: IdentityResolver,
67 policy_engine: Option<SharedPolicyEngine>,
68 rate_limiter: Arc<RateLimiter>,
69 budget_tracker: Arc<BudgetTracker>,
70 injection_scanner: Arc<InjectionScanner>,
71 schema_store: Arc<Mutex<SchemaPinStore>>,
72 ledger: Arc<Mutex<Ledger>>,
73}
74
75impl Gate {
76 /// Create a new Gate with full interceptor chain.
77 ///
78 /// The `policy_engine` parameter accepts an optional `SharedPolicyEngine`
79 /// (`Arc<ArcSwap<PolicyEngine>>`), which enables hot-reload. The gate reads
80 /// the current policy atomically on every request via `load()`, so swaps
81 /// performed by a [`PolicyWatcher`] are immediately visible without restart.
82 pub fn new(
83 config: GateConfig,
84 authenticators: Vec<Box<dyn Authenticator>>,
85 policy_engine: Option<SharedPolicyEngine>,
86 rate_limiter_config: RateLimiterConfig,
87 budget_config: BudgetTrackerConfig,
88 ledger: Ledger,
89 ) -> Self {
90 Self {
91 resolver: IdentityResolver::new(
92 authenticators,
93 ResolverConfig {
94 allow_anonymous: config.allow_anonymous,
95 },
96 ),
97 policy_engine,
98 rate_limiter: Arc::new(RateLimiter::new(rate_limiter_config)),
99 budget_tracker: Arc::new(BudgetTracker::new(budget_config)),
100 injection_scanner: Arc::new(InjectionScanner::new()),
101 schema_store: Arc::new(Mutex::new(SchemaPinStore::new())),
102 ledger: Arc::new(Mutex::new(ledger)),
103 config,
104 }
105 }
106
107 /// Create a transparent pass-through Gate (no security enforcement).
108 pub fn transparent(ledger: Ledger) -> Self {
109 Self::new(
110 GateConfig::default(),
111 vec![Box::new(AnonymousAuthenticator)],
112 None,
113 RateLimiterConfig::default(),
114 BudgetTrackerConfig::default(),
115 ledger,
116 )
117 }
118
119 /// Run the proxy.
120 pub async fn run_stdio(self, command: &str, args: &[&str]) -> Result<(), DomeError> {
121 let transport = StdioTransport::spawn(command, args).await?;
122 let (mut server_reader, mut server_writer, mut child) = transport.split();
123
124 let client_stdin = tokio::io::stdin();
125 let client_stdout = tokio::io::stdout();
126 let mut client_reader = BufReader::new(client_stdin);
127 let client_writer: Arc<Mutex<tokio::io::Stdout>> = Arc::new(Mutex::new(client_stdout));
128
129 info!("MCPDome proxy active -- interceptor chain armed");
130
131 // Shared state for the two tasks
132 let identity: Arc<Mutex<Option<dome_sentinel::Identity>>> = Arc::new(Mutex::new(None));
133
134 let gate_identity = Arc::clone(&identity);
135 let gate_resolver = self.resolver;
136 let gate_policy = self.policy_engine.clone();
137 let gate_rate_limiter = Arc::clone(&self.rate_limiter);
138 let gate_budget = Arc::clone(&self.budget_tracker);
139 let gate_scanner = Arc::clone(&self.injection_scanner);
140 let gate_ledger = Arc::clone(&self.ledger);
141 let gate_config = self.config.clone();
142 let gate_client_writer = Arc::clone(&client_writer);
143
144 // Client -> Server task (inbound interceptor chain)
145 let client_to_server = tokio::spawn(async move {
146 let mut line = String::new();
147 loop {
148 line.clear();
149 match client_reader.read_line(&mut line).await {
150 Ok(0) => {
151 info!("client closed stdin -- shutting down");
152 break;
153 }
154 Ok(_) => {
155 let trimmed = line.trim();
156 if trimmed.is_empty() {
157 continue;
158 }
159
160 match McpMessage::parse(trimmed) {
161 Ok(msg) => {
162 let start = std::time::Instant::now();
163 let method = msg.method.as_deref().unwrap_or("-").to_string();
164 let tool = msg.tool_name().map(String::from);
165 let request_id = Uuid::new_v4();
166
167 debug!(
168 method = method.as_str(),
169 id = ?msg.id,
170 tool = tool.as_deref().unwrap_or("-"),
171 "client -> server"
172 );
173
174 // ── 1. Sentinel: Authenticate on initialize ──
175 let mut msg = msg;
176 if method == "initialize" {
177 match gate_resolver.resolve(&msg).await {
178 Ok(id) => {
179 info!(
180 principal = %id.principal,
181 method = %id.auth_method,
182 "identity resolved"
183 );
184 *gate_identity.lock().await = Some(id);
185
186 // Fix 4: Strip PSK before forwarding so the
187 // downstream server never sees credentials.
188 msg = PskAuthenticator::strip_psk(&msg);
189 }
190 Err(e) => {
191 // Fix 1: On auth failure during initialize,
192 // send error response and do NOT forward.
193 warn!(%e, "authentication failed");
194 let err_id = msg.id.clone().unwrap_or(Value::Null);
195 let err_resp = McpMessage::error_response(
196 err_id,
197 -32600,
198 "Authentication failed",
199 );
200 if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
201 error!(%we, "failed to send auth error to client");
202 break;
203 }
204 continue;
205 }
206 }
207 }
208
209 // Fix 1: Block all non-initialize requests before
210 // the session has been initialized (identity resolved).
211 if method != "initialize" {
212 let identity_lock = gate_identity.lock().await;
213 if identity_lock.is_none() {
214 drop(identity_lock);
215 warn!(method = %method, "request before initialize");
216 let err_id = msg.id.clone().unwrap_or(Value::Null);
217 let err_resp = McpMessage::error_response(
218 err_id,
219 -32600,
220 "Session not initialized",
221 );
222 if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
223 error!(%we, "failed to send not-initialized error to client");
224 break;
225 }
226 continue;
227 }
228 drop(identity_lock);
229 }
230
231 let identity_lock = gate_identity.lock().await;
232 let principal = identity_lock
233 .as_ref()
234 .map(|i| i.principal.clone())
235 .unwrap_or_else(|| "anonymous".to_string());
236 let labels = identity_lock
237 .as_ref()
238 .map(|i| i.labels.clone())
239 .unwrap_or_default();
240 drop(identity_lock);
241
242 // Fix 2: Apply interceptor chain to ALL methods,
243 // not just tools/call.
244
245 // Extract tool-specific fields conditionally.
246 let tool_name = if method == "tools/call" {
247 tool.as_deref().unwrap_or("unknown")
248 } else {
249 "-"
250 };
251
252 let args = msg
253 .params
254 .as_ref()
255 .and_then(|p| p.get("arguments"))
256 .cloned()
257 .unwrap_or(Value::Null);
258
259 // ── 2. Throttle: Rate limit check ──
260 if gate_config.enable_rate_limit {
261 let rl_tool = if method == "tools/call" { Some(tool_name) } else { None };
262 if let Err(e) = gate_rate_limiter
263 .check_rate_limit(&principal, rl_tool)
264 {
265 warn!(%e, principal = %principal, method = %method, "rate limited");
266 record_audit(
267 &gate_ledger,
268 request_id,
269 &principal,
270 Direction::Inbound,
271 &method,
272 tool.as_deref(),
273 "deny:rate_limit",
274 None,
275 start.elapsed().as_micros() as u64,
276 )
277 .await;
278 // Fix 3: Send JSON-RPC error on rate limit deny.
279 let err_id = msg.id.clone().unwrap_or(Value::Null);
280 let err_resp = McpMessage::error_response(
281 err_id,
282 -32000,
283 "Rate limit exceeded",
284 );
285 if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
286 error!(%we, "failed to send rate limit error to client");
287 break;
288 }
289 continue;
290 }
291 }
292
293 // ── 2b. Throttle: Budget check ──
294 if gate_config.enable_budget {
295 if let Err(e) = gate_budget.try_spend(&principal, 1.0) {
296 warn!(%e, principal = %principal, "budget exhausted");
297 record_audit(
298 &gate_ledger,
299 request_id,
300 &principal,
301 Direction::Inbound,
302 &method,
303 tool.as_deref(),
304 "deny:budget",
305 None,
306 start.elapsed().as_micros() as u64,
307 )
308 .await;
309 // Fix 3: Send JSON-RPC error on budget deny.
310 let err_id = msg.id.clone().unwrap_or(Value::Null);
311 let err_resp = McpMessage::error_response(
312 err_id,
313 -32000,
314 "Budget exhausted",
315 );
316 if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
317 error!(%we, "failed to send budget error to client");
318 break;
319 }
320 continue;
321 }
322 }
323
324 // ── 3. Ward: Injection scanning ──
325 // Fix 8: Ward runs BEFORE policy so injection detection
326 // is applied regardless of authorization level.
327 if gate_config.enable_ward {
328 // Scan params/arguments if present.
329 let scan_text = if method == "tools/call" {
330 serde_json::to_string(&args).unwrap_or_default()
331 } else if let Some(ref params) = msg.params {
332 serde_json::to_string(params).unwrap_or_default()
333 } else {
334 String::new()
335 };
336
337 if !scan_text.is_empty() {
338 let matches = gate_scanner.scan_text(&scan_text);
339 if !matches.is_empty() {
340 let pattern_names: Vec<&str> = matches
341 .iter()
342 .map(|m| m.pattern_name.as_str())
343 .collect();
344 warn!(
345 patterns = ?pattern_names,
346 method = %method,
347 tool = tool_name,
348 principal = %principal,
349 "injection detected"
350 );
351 record_audit(
352 &gate_ledger,
353 request_id,
354 &principal,
355 Direction::Inbound,
356 &method,
357 tool.as_deref(),
358 &format!(
359 "deny:injection:{}",
360 pattern_names.join(",")
361 ),
362 None,
363 start.elapsed().as_micros() as u64,
364 )
365 .await;
366 // Fix 3: Send JSON-RPC error on injection deny.
367 let err_id = msg.id.clone().unwrap_or(Value::Null);
368 let err_resp = McpMessage::error_response(
369 err_id,
370 -32003,
371 "Request blocked: injection pattern detected",
372 );
373 if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
374 error!(%we, "failed to send injection error to client");
375 break;
376 }
377 continue;
378 }
379 }
380 }
381
382 // ── 4. Policy: Authorization ──
383 if gate_config.enforce_policy {
384 if let Some(ref shared_engine) = gate_policy {
385 // Load the current policy atomically. This is
386 // lock-free and picks up hot-reloaded changes
387 // immediately.
388 let engine = shared_engine.load();
389
390 // For tools/call, evaluate with the tool name;
391 // for other methods, evaluate with the method itself.
392 let policy_resource = if method == "tools/call" {
393 tool_name
394 } else {
395 method.as_str()
396 };
397 let policy_id = PolicyIdentity::new(
398 principal.clone(),
399 labels.iter().cloned(),
400 );
401 let decision =
402 engine.evaluate(&policy_id, policy_resource, &args);
403
404 if !decision.is_allowed() {
405 warn!(
406 rule_id = %decision.rule_id,
407 method = %method,
408 resource = policy_resource,
409 principal = %principal,
410 "policy denied"
411 );
412 record_audit(
413 &gate_ledger,
414 request_id,
415 &principal,
416 Direction::Inbound,
417 &method,
418 tool.as_deref(),
419 &format!("deny:policy:{}", decision.rule_id),
420 Some(&decision.rule_id),
421 start.elapsed().as_micros() as u64,
422 )
423 .await;
424 // Fix 3: Send JSON-RPC error on policy deny.
425 let err_id = msg.id.clone().unwrap_or(Value::Null);
426 let err_resp = McpMessage::error_response(
427 err_id,
428 -32003,
429 format!("Denied by policy: {}", decision.rule_id),
430 );
431 if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
432 error!(%we, "failed to send policy error to client");
433 break;
434 }
435 continue;
436 }
437 }
438 }
439
440 // ── 5. Ledger: Record allowed request ──
441 record_audit(
442 &gate_ledger,
443 request_id,
444 &principal,
445 Direction::Inbound,
446 &method,
447 tool.as_deref(),
448 "allow",
449 None,
450 start.elapsed().as_micros() as u64,
451 )
452 .await;
453
454 // Forward to server
455 if let Err(e) = server_writer.send(&msg).await {
456 error!(%e, "failed to forward to server");
457 break;
458 }
459 }
460 Err(e) => {
461 // Fix 7: Drop invalid JSON. Send a JSON-RPC parse
462 // error back to the client instead of forwarding.
463 warn!(%e, raw = trimmed, "invalid JSON from client, dropping");
464 let err_resp = McpMessage::error_response(
465 Value::Null,
466 -32700,
467 "Parse error: invalid JSON",
468 );
469 if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
470 error!(%we, "failed to send parse error to client");
471 break;
472 }
473 }
474 }
475 }
476 Err(e) => {
477 error!(%e, "error reading from client");
478 break;
479 }
480 }
481 }
482 });
483
484 let schema_store = Arc::clone(&self.schema_store);
485 let outbound_ledger = Arc::clone(&self.ledger);
486 let outbound_scanner = Arc::clone(&self.injection_scanner);
487 let outbound_config = self.config.clone();
488 let outbound_client_writer = Arc::clone(&client_writer);
489
490 // Server -> Client task (outbound interceptor chain)
491 let server_to_client = tokio::spawn(async move {
492 let mut first_tools_list = true;
493 // Cache the last known-good tools/list response for fallback
494 // when critical schema drift is detected.
495 let mut last_good_tools_result: Option<Value> = None;
496
497 loop {
498 match server_reader.recv().await {
499 Ok(msg) => {
500 let start = std::time::Instant::now();
501 let method = msg.method.as_deref().unwrap_or("-").to_string();
502 let outbound_request_id = Uuid::new_v4();
503
504 debug!(
505 method = method.as_str(),
506 id = ?msg.id,
507 "server -> client"
508 );
509
510 let mut forward_msg = msg;
511
512 // ── Schema Pin: Verify tools/list responses ──
513 if outbound_config.enable_schema_pin {
514 if let Some(result) = &forward_msg.result {
515 if result.get("tools").is_some() {
516 let mut store = schema_store.lock().await;
517 if first_tools_list {
518 store.pin_tools(result);
519 info!(pinned = store.len(), "schema pins established");
520 // Cache the first (known-good) tools result.
521 last_good_tools_result = Some(result.clone());
522 first_tools_list = false;
523 } else {
524 let drifts = store.verify_tools(result);
525 if !drifts.is_empty() {
526 let mut has_critical_or_high = false;
527 for drift in &drifts {
528 warn!(
529 tool = %drift.tool_name,
530 drift_type = ?drift.drift_type,
531 severity = ?drift.severity,
532 "schema drift detected"
533 );
534 if matches!(drift.severity, DriftSeverity::Critical | DriftSeverity::High) {
535 has_critical_or_high = true;
536 }
537 }
538
539 // Fix 5: Block Critical/High schema drift.
540 // Send the previously pinned (known-good)
541 // schema instead of the drifted response.
542 if has_critical_or_high {
543 warn!("critical/high schema drift detected -- blocking drifted tools/list");
544 // Fix 9: Record outbound drift block.
545 record_audit(
546 &outbound_ledger,
547 outbound_request_id,
548 "server",
549 Direction::Outbound,
550 "tools/list",
551 None,
552 "deny:schema_drift",
553 None,
554 start.elapsed().as_micros() as u64,
555 )
556 .await;
557
558 if let Some(ref good_result) = last_good_tools_result {
559 // Replace with the known-good schema.
560 forward_msg.result = Some(good_result.clone());
561 } else {
562 // No known-good schema available; send error.
563 let err_id = forward_msg.id.clone().unwrap_or(Value::Null);
564 let err_resp = McpMessage::error_response(
565 err_id,
566 -32003,
567 "Schema drift detected: tool definitions have been tampered with",
568 );
569 if let Err(we) = write_to_client(&outbound_client_writer, &err_resp).await {
570 error!(%we, "failed to send schema drift error to client");
571 break;
572 }
573 continue;
574 }
575 }
576 }
577 }
578 }
579 }
580 }
581
582 // ── Fix 6: Outbound response scanning ──
583 // Scan tool call results for injection patterns (log only).
584 if outbound_config.enable_ward {
585 if let Some(ref result) = forward_msg.result {
586 // Scan content arrays in tool call responses.
587 let scan_target = if let Some(content) = result.get("content") {
588 serde_json::to_string(content).unwrap_or_default()
589 } else {
590 serde_json::to_string(result).unwrap_or_default()
591 };
592
593 if !scan_target.is_empty() {
594 let matches = outbound_scanner.scan_text(&scan_target);
595 if !matches.is_empty() {
596 let pattern_names: Vec<&str> = matches
597 .iter()
598 .map(|m| m.pattern_name.as_str())
599 .collect();
600 warn!(
601 patterns = ?pattern_names,
602 direction = "outbound",
603 "suspicious content in server response"
604 );
605 // Fix 9: Record outbound ward warning.
606 record_audit(
607 &outbound_ledger,
608 outbound_request_id,
609 "server",
610 Direction::Outbound,
611 &method,
612 None,
613 &format!("warn:outbound_injection:{}", pattern_names.join(",")),
614 None,
615 start.elapsed().as_micros() as u64,
616 )
617 .await;
618 // Don't block by default, just log.
619 }
620 }
621 }
622 }
623
624 // ── Fix 9: Record outbound audit entry ──
625 record_audit(
626 &outbound_ledger,
627 outbound_request_id,
628 "server",
629 Direction::Outbound,
630 &method,
631 None,
632 "forward",
633 None,
634 start.elapsed().as_micros() as u64,
635 )
636 .await;
637
638 // Forward to client
639 if let Err(e) = write_to_client(&outbound_client_writer, &forward_msg).await {
640 error!(%e, "failed to write to client");
641 break;
642 }
643 }
644 Err(DomeError::Transport(ref e))
645 if e.kind() == std::io::ErrorKind::UnexpectedEof =>
646 {
647 info!("server closed stdout -- shutting down");
648 break;
649 }
650 Err(e) => {
651 error!(%e, "error reading from server");
652 break;
653 }
654 }
655 }
656 });
657
658 // Wait for either side to finish
659 tokio::select! {
660 r = client_to_server => {
661 if let Err(e) = r {
662 error!(%e, "client->server task panicked");
663 }
664 }
665 r = server_to_client => {
666 if let Err(e) = r {
667 error!(%e, "server->client task panicked");
668 }
669 }
670 }
671
672 // Flush audit log and clean up
673 self.ledger.lock().await.flush();
674 let _ = child.kill().await;
675 info!("MCPDome proxy shut down");
676
677 Ok(())
678 }
679}
680
681/// Write a McpMessage to the client's stdout, with newline and flush.
682async fn write_to_client(
683 writer: &Arc<Mutex<tokio::io::Stdout>>,
684 msg: &McpMessage,
685) -> Result<(), std::io::Error> {
686 match msg.to_json() {
687 Ok(json) => {
688 let mut out = json.into_bytes();
689 out.push(b'\n');
690 let mut w = writer.lock().await;
691 w.write_all(&out).await?;
692 w.flush().await?;
693 Ok(())
694 }
695 Err(e) => {
696 error!(%e, "failed to serialize message for client");
697 Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e))
698 }
699 }
700}
701
702/// Helper to record an audit entry.
703async fn record_audit(
704 ledger: &Arc<Mutex<Ledger>>,
705 request_id: Uuid,
706 identity: &str,
707 direction: Direction,
708 method: &str,
709 tool: Option<&str>,
710 decision: &str,
711 rule_id: Option<&str>,
712 latency_us: u64,
713) {
714 let entry = AuditEntry {
715 seq: 0, // set by ledger
716 timestamp: Utc::now(),
717 request_id,
718 identity: identity.to_string(),
719 direction,
720 method: method.to_string(),
721 tool: tool.map(String::from),
722 decision: decision.to_string(),
723 rule_id: rule_id.map(String::from),
724 latency_us,
725 prev_hash: String::new(), // set by ledger
726 annotations: std::collections::HashMap::new(),
727 };
728
729 if let Err(e) = ledger.lock().await.record(entry) {
730 error!(%e, "failed to record audit entry");
731 }
732}