dome_gate/lib.rs
1#![allow(clippy::collapsible_if, clippy::too_many_arguments)]
2
3use std::sync::Arc;
4
5use dome_core::{McpMessage, DomeError};
6use dome_ledger::{AuditEntry, Direction, Ledger};
7use dome_policy::{Identity as PolicyIdentity, PolicyEngine};
8use dome_sentinel::{AnonymousAuthenticator, Authenticator, IdentityResolver, PskAuthenticator, ResolverConfig};
9use dome_throttle::{BudgetTracker, BudgetTrackerConfig, RateLimiter, RateLimiterConfig};
10use dome_transport::stdio::StdioTransport;
11use dome_ward::{InjectionScanner, SchemaPinStore};
12use dome_ward::schema_pin::DriftSeverity;
13
14use chrono::Utc;
15use serde_json::Value;
16use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
17use tokio::sync::Mutex;
18use tracing::{debug, error, info, warn};
19use uuid::Uuid;
20
/// Configuration for the Gate proxy.
///
/// Each flag switches one stage of the interceptor chain on or off. The
/// struct is plain data: it can be cloned freely and compared by value
/// (`PartialEq`/`Eq`), which makes it easy to assert on in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GateConfig {
    /// Whether to enforce policy (false = transparent pass-through mode).
    pub enforce_policy: bool,
    /// Whether to enable injection scanning (inbound requests and outbound
    /// responses).
    pub enable_ward: bool,
    /// Whether to enable schema pinning of `tools/list` responses.
    pub enable_schema_pin: bool,
    /// Whether to enable rate limiting.
    pub enable_rate_limit: bool,
    /// Whether to enable budget tracking.
    pub enable_budget: bool,
    /// Whether to allow anonymous access.
    pub allow_anonymous: bool,
}
37
38impl Default for GateConfig {
39 fn default() -> Self {
40 Self {
41 enforce_policy: false,
42 enable_ward: false,
43 enable_schema_pin: false,
44 enable_rate_limit: false,
45 enable_budget: false,
46 allow_anonymous: true,
47 }
48 }
49}
50
51/// The Gate -- MCPDome's core proxy loop with full interceptor chain.
52///
53/// Interceptor order (inbound, client -> server):
54/// 1. Sentinel -- authenticate on `initialize`, resolve identity
55/// 2. Throttle -- check rate limits and budget
56/// 3. Ward -- scan for injection patterns in tool arguments
57/// 4. Policy -- evaluate authorization rules
58/// 5. Ledger -- record the decision in the audit chain
59///
60/// Outbound (server -> client):
61/// 1. Schema Pin -- verify tools/list responses for drift (block Critical/High)
62/// 2. Ward -- scan outbound tool results for injection patterns
63/// 3. Ledger -- record outbound audit entry
64pub struct Gate {
65 config: GateConfig,
66 resolver: IdentityResolver,
67 policy_engine: Option<PolicyEngine>,
68 rate_limiter: Arc<RateLimiter>,
69 budget_tracker: Arc<BudgetTracker>,
70 injection_scanner: Arc<InjectionScanner>,
71 schema_store: Arc<Mutex<SchemaPinStore>>,
72 ledger: Arc<Mutex<Ledger>>,
73}
74
75impl Gate {
76 /// Create a new Gate with full interceptor chain.
77 pub fn new(
78 config: GateConfig,
79 authenticators: Vec<Box<dyn Authenticator>>,
80 policy_engine: Option<PolicyEngine>,
81 rate_limiter_config: RateLimiterConfig,
82 budget_config: BudgetTrackerConfig,
83 ledger: Ledger,
84 ) -> Self {
85 Self {
86 resolver: IdentityResolver::new(
87 authenticators,
88 ResolverConfig {
89 allow_anonymous: config.allow_anonymous,
90 },
91 ),
92 policy_engine,
93 rate_limiter: Arc::new(RateLimiter::new(rate_limiter_config)),
94 budget_tracker: Arc::new(BudgetTracker::new(budget_config)),
95 injection_scanner: Arc::new(InjectionScanner::new()),
96 schema_store: Arc::new(Mutex::new(SchemaPinStore::new())),
97 ledger: Arc::new(Mutex::new(ledger)),
98 config,
99 }
100 }
101
102 /// Create a transparent pass-through Gate (no security enforcement).
103 pub fn transparent(ledger: Ledger) -> Self {
104 Self::new(
105 GateConfig::default(),
106 vec![Box::new(AnonymousAuthenticator)],
107 None,
108 RateLimiterConfig::default(),
109 BudgetTrackerConfig::default(),
110 ledger,
111 )
112 }
113
114 /// Run the proxy.
115 pub async fn run_stdio(self, command: &str, args: &[&str]) -> Result<(), DomeError> {
116 let transport = StdioTransport::spawn(command, args).await?;
117 let (mut server_reader, mut server_writer, mut child) = transport.split();
118
119 let client_stdin = tokio::io::stdin();
120 let client_stdout = tokio::io::stdout();
121 let mut client_reader = BufReader::new(client_stdin);
122 let client_writer: Arc<Mutex<tokio::io::Stdout>> = Arc::new(Mutex::new(client_stdout));
123
124 info!("MCPDome proxy active -- interceptor chain armed");
125
126 // Shared state for the two tasks
127 let identity: Arc<Mutex<Option<dome_sentinel::Identity>>> = Arc::new(Mutex::new(None));
128
129 let gate_identity = Arc::clone(&identity);
130 let gate_resolver = self.resolver;
131 let gate_policy = self.policy_engine;
132 let gate_rate_limiter = Arc::clone(&self.rate_limiter);
133 let gate_budget = Arc::clone(&self.budget_tracker);
134 let gate_scanner = Arc::clone(&self.injection_scanner);
135 let gate_ledger = Arc::clone(&self.ledger);
136 let gate_config = self.config.clone();
137 let gate_client_writer = Arc::clone(&client_writer);
138
139 // Client -> Server task (inbound interceptor chain)
140 let client_to_server = tokio::spawn(async move {
141 let mut line = String::new();
142 loop {
143 line.clear();
144 match client_reader.read_line(&mut line).await {
145 Ok(0) => {
146 info!("client closed stdin -- shutting down");
147 break;
148 }
149 Ok(_) => {
150 let trimmed = line.trim();
151 if trimmed.is_empty() {
152 continue;
153 }
154
155 match McpMessage::parse(trimmed) {
156 Ok(msg) => {
157 let start = std::time::Instant::now();
158 let method = msg.method.as_deref().unwrap_or("-").to_string();
159 let tool = msg.tool_name().map(String::from);
160 let request_id = Uuid::new_v4();
161
162 debug!(
163 method = method.as_str(),
164 id = ?msg.id,
165 tool = tool.as_deref().unwrap_or("-"),
166 "client -> server"
167 );
168
169 // ── 1. Sentinel: Authenticate on initialize ──
170 let mut msg = msg;
171 if method == "initialize" {
172 match gate_resolver.resolve(&msg).await {
173 Ok(id) => {
174 info!(
175 principal = %id.principal,
176 method = %id.auth_method,
177 "identity resolved"
178 );
179 *gate_identity.lock().await = Some(id);
180
181 // Fix 4: Strip PSK before forwarding so the
182 // downstream server never sees credentials.
183 msg = PskAuthenticator::strip_psk(&msg);
184 }
185 Err(e) => {
186 // Fix 1: On auth failure during initialize,
187 // send error response and do NOT forward.
188 warn!(%e, "authentication failed");
189 let err_id = msg.id.clone().unwrap_or(Value::Null);
190 let err_resp = McpMessage::error_response(
191 err_id,
192 -32600,
193 "Authentication failed",
194 );
195 if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
196 error!(%we, "failed to send auth error to client");
197 break;
198 }
199 continue;
200 }
201 }
202 }
203
204 // Fix 1: Block all non-initialize requests before
205 // the session has been initialized (identity resolved).
206 if method != "initialize" {
207 let identity_lock = gate_identity.lock().await;
208 if identity_lock.is_none() {
209 drop(identity_lock);
210 warn!(method = %method, "request before initialize");
211 let err_id = msg.id.clone().unwrap_or(Value::Null);
212 let err_resp = McpMessage::error_response(
213 err_id,
214 -32600,
215 "Session not initialized",
216 );
217 if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
218 error!(%we, "failed to send not-initialized error to client");
219 break;
220 }
221 continue;
222 }
223 drop(identity_lock);
224 }
225
226 let identity_lock = gate_identity.lock().await;
227 let principal = identity_lock
228 .as_ref()
229 .map(|i| i.principal.clone())
230 .unwrap_or_else(|| "anonymous".to_string());
231 let labels = identity_lock
232 .as_ref()
233 .map(|i| i.labels.clone())
234 .unwrap_or_default();
235 drop(identity_lock);
236
237 // Fix 2: Apply interceptor chain to ALL methods,
238 // not just tools/call.
239
240 // Extract tool-specific fields conditionally.
241 let tool_name = if method == "tools/call" {
242 tool.as_deref().unwrap_or("unknown")
243 } else {
244 "-"
245 };
246
247 let args = msg
248 .params
249 .as_ref()
250 .and_then(|p| p.get("arguments"))
251 .cloned()
252 .unwrap_or(Value::Null);
253
254 // ── 2. Throttle: Rate limit check ──
255 if gate_config.enable_rate_limit {
256 let rl_tool = if method == "tools/call" { Some(tool_name) } else { None };
257 if let Err(e) = gate_rate_limiter
258 .check_rate_limit(&principal, rl_tool)
259 {
260 warn!(%e, principal = %principal, method = %method, "rate limited");
261 record_audit(
262 &gate_ledger,
263 request_id,
264 &principal,
265 Direction::Inbound,
266 &method,
267 tool.as_deref(),
268 "deny:rate_limit",
269 None,
270 start.elapsed().as_micros() as u64,
271 )
272 .await;
273 // Fix 3: Send JSON-RPC error on rate limit deny.
274 let err_id = msg.id.clone().unwrap_or(Value::Null);
275 let err_resp = McpMessage::error_response(
276 err_id,
277 -32000,
278 "Rate limit exceeded",
279 );
280 if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
281 error!(%we, "failed to send rate limit error to client");
282 break;
283 }
284 continue;
285 }
286 }
287
288 // ── 2b. Throttle: Budget check ──
289 if gate_config.enable_budget {
290 if let Err(e) = gate_budget.try_spend(&principal, 1.0) {
291 warn!(%e, principal = %principal, "budget exhausted");
292 record_audit(
293 &gate_ledger,
294 request_id,
295 &principal,
296 Direction::Inbound,
297 &method,
298 tool.as_deref(),
299 "deny:budget",
300 None,
301 start.elapsed().as_micros() as u64,
302 )
303 .await;
304 // Fix 3: Send JSON-RPC error on budget deny.
305 let err_id = msg.id.clone().unwrap_or(Value::Null);
306 let err_resp = McpMessage::error_response(
307 err_id,
308 -32000,
309 "Budget exhausted",
310 );
311 if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
312 error!(%we, "failed to send budget error to client");
313 break;
314 }
315 continue;
316 }
317 }
318
319 // ── 3. Ward: Injection scanning ──
320 // Fix 8: Ward runs BEFORE policy so injection detection
321 // is applied regardless of authorization level.
322 if gate_config.enable_ward {
323 // Scan params/arguments if present.
324 let scan_text = if method == "tools/call" {
325 serde_json::to_string(&args).unwrap_or_default()
326 } else if let Some(ref params) = msg.params {
327 serde_json::to_string(params).unwrap_or_default()
328 } else {
329 String::new()
330 };
331
332 if !scan_text.is_empty() {
333 let matches = gate_scanner.scan_text(&scan_text);
334 if !matches.is_empty() {
335 let pattern_names: Vec<&str> = matches
336 .iter()
337 .map(|m| m.pattern_name.as_str())
338 .collect();
339 warn!(
340 patterns = ?pattern_names,
341 method = %method,
342 tool = tool_name,
343 principal = %principal,
344 "injection detected"
345 );
346 record_audit(
347 &gate_ledger,
348 request_id,
349 &principal,
350 Direction::Inbound,
351 &method,
352 tool.as_deref(),
353 &format!(
354 "deny:injection:{}",
355 pattern_names.join(",")
356 ),
357 None,
358 start.elapsed().as_micros() as u64,
359 )
360 .await;
361 // Fix 3: Send JSON-RPC error on injection deny.
362 let err_id = msg.id.clone().unwrap_or(Value::Null);
363 let err_resp = McpMessage::error_response(
364 err_id,
365 -32003,
366 "Request blocked: injection pattern detected",
367 );
368 if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
369 error!(%we, "failed to send injection error to client");
370 break;
371 }
372 continue;
373 }
374 }
375 }
376
377 // ── 4. Policy: Authorization ──
378 if gate_config.enforce_policy {
379 if let Some(ref engine) = gate_policy {
380 // For tools/call, evaluate with the tool name;
381 // for other methods, evaluate with the method itself.
382 let policy_resource = if method == "tools/call" {
383 tool_name
384 } else {
385 method.as_str()
386 };
387 let policy_id = PolicyIdentity::new(
388 principal.clone(),
389 labels.iter().cloned(),
390 );
391 let decision =
392 engine.evaluate(&policy_id, policy_resource, &args);
393
394 if !decision.is_allowed() {
395 warn!(
396 rule_id = %decision.rule_id,
397 method = %method,
398 resource = policy_resource,
399 principal = %principal,
400 "policy denied"
401 );
402 record_audit(
403 &gate_ledger,
404 request_id,
405 &principal,
406 Direction::Inbound,
407 &method,
408 tool.as_deref(),
409 &format!("deny:policy:{}", decision.rule_id),
410 Some(&decision.rule_id),
411 start.elapsed().as_micros() as u64,
412 )
413 .await;
414 // Fix 3: Send JSON-RPC error on policy deny.
415 let err_id = msg.id.clone().unwrap_or(Value::Null);
416 let err_resp = McpMessage::error_response(
417 err_id,
418 -32003,
419 format!("Denied by policy: {}", decision.rule_id),
420 );
421 if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
422 error!(%we, "failed to send policy error to client");
423 break;
424 }
425 continue;
426 }
427 }
428 }
429
430 // ── 5. Ledger: Record allowed request ──
431 record_audit(
432 &gate_ledger,
433 request_id,
434 &principal,
435 Direction::Inbound,
436 &method,
437 tool.as_deref(),
438 "allow",
439 None,
440 start.elapsed().as_micros() as u64,
441 )
442 .await;
443
444 // Forward to server
445 if let Err(e) = server_writer.send(&msg).await {
446 error!(%e, "failed to forward to server");
447 break;
448 }
449 }
450 Err(e) => {
451 // Fix 7: Drop invalid JSON. Send a JSON-RPC parse
452 // error back to the client instead of forwarding.
453 warn!(%e, raw = trimmed, "invalid JSON from client, dropping");
454 let err_resp = McpMessage::error_response(
455 Value::Null,
456 -32700,
457 "Parse error: invalid JSON",
458 );
459 if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
460 error!(%we, "failed to send parse error to client");
461 break;
462 }
463 }
464 }
465 }
466 Err(e) => {
467 error!(%e, "error reading from client");
468 break;
469 }
470 }
471 }
472 });
473
474 let schema_store = Arc::clone(&self.schema_store);
475 let outbound_ledger = Arc::clone(&self.ledger);
476 let outbound_scanner = Arc::clone(&self.injection_scanner);
477 let outbound_config = self.config.clone();
478 let outbound_client_writer = Arc::clone(&client_writer);
479
480 // Server -> Client task (outbound interceptor chain)
481 let server_to_client = tokio::spawn(async move {
482 let mut first_tools_list = true;
483 // Cache the last known-good tools/list response for fallback
484 // when critical schema drift is detected.
485 let mut last_good_tools_result: Option<Value> = None;
486
487 loop {
488 match server_reader.recv().await {
489 Ok(msg) => {
490 let start = std::time::Instant::now();
491 let method = msg.method.as_deref().unwrap_or("-").to_string();
492 let outbound_request_id = Uuid::new_v4();
493
494 debug!(
495 method = method.as_str(),
496 id = ?msg.id,
497 "server -> client"
498 );
499
500 let mut forward_msg = msg;
501
502 // ── Schema Pin: Verify tools/list responses ──
503 if outbound_config.enable_schema_pin {
504 if let Some(result) = &forward_msg.result {
505 if result.get("tools").is_some() {
506 let mut store = schema_store.lock().await;
507 if first_tools_list {
508 store.pin_tools(result);
509 info!(pinned = store.len(), "schema pins established");
510 // Cache the first (known-good) tools result.
511 last_good_tools_result = Some(result.clone());
512 first_tools_list = false;
513 } else {
514 let drifts = store.verify_tools(result);
515 if !drifts.is_empty() {
516 let mut has_critical_or_high = false;
517 for drift in &drifts {
518 warn!(
519 tool = %drift.tool_name,
520 drift_type = ?drift.drift_type,
521 severity = ?drift.severity,
522 "schema drift detected"
523 );
524 if matches!(drift.severity, DriftSeverity::Critical | DriftSeverity::High) {
525 has_critical_or_high = true;
526 }
527 }
528
529 // Fix 5: Block Critical/High schema drift.
530 // Send the previously pinned (known-good)
531 // schema instead of the drifted response.
532 if has_critical_or_high {
533 warn!("critical/high schema drift detected -- blocking drifted tools/list");
534 // Fix 9: Record outbound drift block.
535 record_audit(
536 &outbound_ledger,
537 outbound_request_id,
538 "server",
539 Direction::Outbound,
540 "tools/list",
541 None,
542 "deny:schema_drift",
543 None,
544 start.elapsed().as_micros() as u64,
545 )
546 .await;
547
548 if let Some(ref good_result) = last_good_tools_result {
549 // Replace with the known-good schema.
550 forward_msg.result = Some(good_result.clone());
551 } else {
552 // No known-good schema available; send error.
553 let err_id = forward_msg.id.clone().unwrap_or(Value::Null);
554 let err_resp = McpMessage::error_response(
555 err_id,
556 -32003,
557 "Schema drift detected: tool definitions have been tampered with",
558 );
559 if let Err(we) = write_to_client(&outbound_client_writer, &err_resp).await {
560 error!(%we, "failed to send schema drift error to client");
561 break;
562 }
563 continue;
564 }
565 }
566 }
567 }
568 }
569 }
570 }
571
572 // ── Fix 6: Outbound response scanning ──
573 // Scan tool call results for injection patterns (log only).
574 if outbound_config.enable_ward {
575 if let Some(ref result) = forward_msg.result {
576 // Scan content arrays in tool call responses.
577 let scan_target = if let Some(content) = result.get("content") {
578 serde_json::to_string(content).unwrap_or_default()
579 } else {
580 serde_json::to_string(result).unwrap_or_default()
581 };
582
583 if !scan_target.is_empty() {
584 let matches = outbound_scanner.scan_text(&scan_target);
585 if !matches.is_empty() {
586 let pattern_names: Vec<&str> = matches
587 .iter()
588 .map(|m| m.pattern_name.as_str())
589 .collect();
590 warn!(
591 patterns = ?pattern_names,
592 direction = "outbound",
593 "suspicious content in server response"
594 );
595 // Fix 9: Record outbound ward warning.
596 record_audit(
597 &outbound_ledger,
598 outbound_request_id,
599 "server",
600 Direction::Outbound,
601 &method,
602 None,
603 &format!("warn:outbound_injection:{}", pattern_names.join(",")),
604 None,
605 start.elapsed().as_micros() as u64,
606 )
607 .await;
608 // Don't block by default, just log.
609 }
610 }
611 }
612 }
613
614 // ── Fix 9: Record outbound audit entry ──
615 record_audit(
616 &outbound_ledger,
617 outbound_request_id,
618 "server",
619 Direction::Outbound,
620 &method,
621 None,
622 "forward",
623 None,
624 start.elapsed().as_micros() as u64,
625 )
626 .await;
627
628 // Forward to client
629 if let Err(e) = write_to_client(&outbound_client_writer, &forward_msg).await {
630 error!(%e, "failed to write to client");
631 break;
632 }
633 }
634 Err(DomeError::Transport(ref e))
635 if e.kind() == std::io::ErrorKind::UnexpectedEof =>
636 {
637 info!("server closed stdout -- shutting down");
638 break;
639 }
640 Err(e) => {
641 error!(%e, "error reading from server");
642 break;
643 }
644 }
645 }
646 });
647
648 // Wait for either side to finish
649 tokio::select! {
650 r = client_to_server => {
651 if let Err(e) = r {
652 error!(%e, "client->server task panicked");
653 }
654 }
655 r = server_to_client => {
656 if let Err(e) = r {
657 error!(%e, "server->client task panicked");
658 }
659 }
660 }
661
662 // Flush audit log and clean up
663 self.ledger.lock().await.flush();
664 let _ = child.kill().await;
665 info!("MCPDome proxy shut down");
666
667 Ok(())
668 }
669}
670
671/// Write a McpMessage to the client's stdout, with newline and flush.
672async fn write_to_client(
673 writer: &Arc<Mutex<tokio::io::Stdout>>,
674 msg: &McpMessage,
675) -> Result<(), std::io::Error> {
676 match msg.to_json() {
677 Ok(json) => {
678 let mut out = json.into_bytes();
679 out.push(b'\n');
680 let mut w = writer.lock().await;
681 w.write_all(&out).await?;
682 w.flush().await?;
683 Ok(())
684 }
685 Err(e) => {
686 error!(%e, "failed to serialize message for client");
687 Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e))
688 }
689 }
690}
691
692/// Helper to record an audit entry.
693async fn record_audit(
694 ledger: &Arc<Mutex<Ledger>>,
695 request_id: Uuid,
696 identity: &str,
697 direction: Direction,
698 method: &str,
699 tool: Option<&str>,
700 decision: &str,
701 rule_id: Option<&str>,
702 latency_us: u64,
703) {
704 let entry = AuditEntry {
705 seq: 0, // set by ledger
706 timestamp: Utc::now(),
707 request_id,
708 identity: identity.to_string(),
709 direction,
710 method: method.to_string(),
711 tool: tool.map(String::from),
712 decision: decision.to_string(),
713 rule_id: rule_id.map(String::from),
714 latency_us,
715 prev_hash: String::new(), // set by ledger
716 annotations: std::collections::HashMap::new(),
717 };
718
719 if let Err(e) = ledger.lock().await.record(entry) {
720 error!(%e, "failed to record audit entry");
721 }
722}