Skip to main content

dome_gate/
lib.rs

1#![allow(clippy::collapsible_if, clippy::too_many_arguments)]
2
3use std::sync::Arc;
4
5use dome_core::{McpMessage, DomeError};
6use dome_ledger::{AuditEntry, Direction, Ledger};
7use dome_policy::{Identity as PolicyIdentity, PolicyEngine};
8use dome_sentinel::{AnonymousAuthenticator, Authenticator, IdentityResolver, PskAuthenticator, ResolverConfig};
9use dome_throttle::{BudgetTracker, BudgetTrackerConfig, RateLimiter, RateLimiterConfig};
10use dome_transport::stdio::StdioTransport;
11use dome_ward::{InjectionScanner, SchemaPinStore};
12use dome_ward::schema_pin::DriftSeverity;
13
14use chrono::Utc;
15use serde_json::Value;
16use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
17use tokio::sync::Mutex;
18use tracing::{debug, error, info, warn};
19use uuid::Uuid;
20
/// Configuration for the Gate proxy.
///
/// Each flag switches one stage of the interceptor chain on or off. The
/// stages themselves are always constructed; these flags only decide whether
/// the proxy loop consults them.
#[derive(Debug, Clone)]
pub struct GateConfig {
    /// Whether to enforce policy (false = transparent pass-through mode).
    ///
    /// Policy is only evaluated when this is `true` *and* the Gate was built
    /// with a policy engine.
    pub enforce_policy: bool,
    /// Whether to enable injection scanning.
    ///
    /// Gates both the inbound scan of request params/arguments and the
    /// outbound scan of server results.
    pub enable_ward: bool,
    /// Whether to enable schema pinning (verification of `tools/list`
    /// responses against the first one seen).
    pub enable_schema_pin: bool,
    /// Whether to enable rate limiting of inbound requests.
    pub enable_rate_limit: bool,
    /// Whether to enable budget tracking of inbound requests.
    pub enable_budget: bool,
    /// Whether to allow anonymous access. Forwarded to the identity
    /// resolver's configuration.
    pub allow_anonymous: bool,
}
37
38impl Default for GateConfig {
39    fn default() -> Self {
40        Self {
41            enforce_policy: false,
42            enable_ward: false,
43            enable_schema_pin: false,
44            enable_rate_limit: false,
45            enable_budget: false,
46            allow_anonymous: true,
47        }
48    }
49}
50
51/// The Gate -- MCPDome's core proxy loop with full interceptor chain.
52///
53/// Interceptor order (inbound, client -> server):
54///   1. Sentinel -- authenticate on `initialize`, resolve identity
55///   2. Throttle -- check rate limits and budget
56///   3. Ward    -- scan for injection patterns in tool arguments
57///   4. Policy  -- evaluate authorization rules
58///   5. Ledger  -- record the decision in the audit chain
59///
60/// Outbound (server -> client):
61///   1. Schema Pin -- verify tools/list responses for drift (block Critical/High)
62///   2. Ward       -- scan outbound tool results for injection patterns
63///   3. Ledger     -- record outbound audit entry
64pub struct Gate {
65    config: GateConfig,
66    resolver: IdentityResolver,
67    policy_engine: Option<PolicyEngine>,
68    rate_limiter: Arc<RateLimiter>,
69    budget_tracker: Arc<BudgetTracker>,
70    injection_scanner: Arc<InjectionScanner>,
71    schema_store: Arc<Mutex<SchemaPinStore>>,
72    ledger: Arc<Mutex<Ledger>>,
73}
74
75impl Gate {
76    /// Create a new Gate with full interceptor chain.
77    pub fn new(
78        config: GateConfig,
79        authenticators: Vec<Box<dyn Authenticator>>,
80        policy_engine: Option<PolicyEngine>,
81        rate_limiter_config: RateLimiterConfig,
82        budget_config: BudgetTrackerConfig,
83        ledger: Ledger,
84    ) -> Self {
85        Self {
86            resolver: IdentityResolver::new(
87                authenticators,
88                ResolverConfig {
89                    allow_anonymous: config.allow_anonymous,
90                },
91            ),
92            policy_engine,
93            rate_limiter: Arc::new(RateLimiter::new(rate_limiter_config)),
94            budget_tracker: Arc::new(BudgetTracker::new(budget_config)),
95            injection_scanner: Arc::new(InjectionScanner::new()),
96            schema_store: Arc::new(Mutex::new(SchemaPinStore::new())),
97            ledger: Arc::new(Mutex::new(ledger)),
98            config,
99        }
100    }
101
102    /// Create a transparent pass-through Gate (no security enforcement).
103    pub fn transparent(ledger: Ledger) -> Self {
104        Self::new(
105            GateConfig::default(),
106            vec![Box::new(AnonymousAuthenticator)],
107            None,
108            RateLimiterConfig::default(),
109            BudgetTrackerConfig::default(),
110            ledger,
111        )
112    }
113
114    /// Run the proxy.
115    pub async fn run_stdio(self, command: &str, args: &[&str]) -> Result<(), DomeError> {
116        let transport = StdioTransport::spawn(command, args).await?;
117        let (mut server_reader, mut server_writer, mut child) = transport.split();
118
119        let client_stdin = tokio::io::stdin();
120        let client_stdout = tokio::io::stdout();
121        let mut client_reader = BufReader::new(client_stdin);
122        let client_writer: Arc<Mutex<tokio::io::Stdout>> = Arc::new(Mutex::new(client_stdout));
123
124        info!("MCPDome proxy active -- interceptor chain armed");
125
126        // Shared state for the two tasks
127        let identity: Arc<Mutex<Option<dome_sentinel::Identity>>> = Arc::new(Mutex::new(None));
128
129        let gate_identity = Arc::clone(&identity);
130        let gate_resolver = self.resolver;
131        let gate_policy = self.policy_engine;
132        let gate_rate_limiter = Arc::clone(&self.rate_limiter);
133        let gate_budget = Arc::clone(&self.budget_tracker);
134        let gate_scanner = Arc::clone(&self.injection_scanner);
135        let gate_ledger = Arc::clone(&self.ledger);
136        let gate_config = self.config.clone();
137        let gate_client_writer = Arc::clone(&client_writer);
138
139        // Client -> Server task (inbound interceptor chain)
140        let client_to_server = tokio::spawn(async move {
141            let mut line = String::new();
142            loop {
143                line.clear();
144                match client_reader.read_line(&mut line).await {
145                    Ok(0) => {
146                        info!("client closed stdin -- shutting down");
147                        break;
148                    }
149                    Ok(_) => {
150                        let trimmed = line.trim();
151                        if trimmed.is_empty() {
152                            continue;
153                        }
154
155                        match McpMessage::parse(trimmed) {
156                            Ok(msg) => {
157                                let start = std::time::Instant::now();
158                                let method = msg.method.as_deref().unwrap_or("-").to_string();
159                                let tool = msg.tool_name().map(String::from);
160                                let request_id = Uuid::new_v4();
161
162                                debug!(
163                                    method = method.as_str(),
164                                    id = ?msg.id,
165                                    tool = tool.as_deref().unwrap_or("-"),
166                                    "client -> server"
167                                );
168
169                                // ── 1. Sentinel: Authenticate on initialize ──
170                                let mut msg = msg;
171                                if method == "initialize" {
172                                    match gate_resolver.resolve(&msg).await {
173                                        Ok(id) => {
174                                            info!(
175                                                principal = %id.principal,
176                                                method = %id.auth_method,
177                                                "identity resolved"
178                                            );
179                                            *gate_identity.lock().await = Some(id);
180
181                                            // Fix 4: Strip PSK before forwarding so the
182                                            // downstream server never sees credentials.
183                                            msg = PskAuthenticator::strip_psk(&msg);
184                                        }
185                                        Err(e) => {
186                                            // Fix 1: On auth failure during initialize,
187                                            // send error response and do NOT forward.
188                                            warn!(%e, "authentication failed");
189                                            let err_id = msg.id.clone().unwrap_or(Value::Null);
190                                            let err_resp = McpMessage::error_response(
191                                                err_id,
192                                                -32600,
193                                                "Authentication failed",
194                                            );
195                                            if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
196                                                error!(%we, "failed to send auth error to client");
197                                                break;
198                                            }
199                                            continue;
200                                        }
201                                    }
202                                }
203
204                                // Fix 1: Block all non-initialize requests before
205                                // the session has been initialized (identity resolved).
206                                if method != "initialize" {
207                                    let identity_lock = gate_identity.lock().await;
208                                    if identity_lock.is_none() {
209                                        drop(identity_lock);
210                                        warn!(method = %method, "request before initialize");
211                                        let err_id = msg.id.clone().unwrap_or(Value::Null);
212                                        let err_resp = McpMessage::error_response(
213                                            err_id,
214                                            -32600,
215                                            "Session not initialized",
216                                        );
217                                        if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
218                                            error!(%we, "failed to send not-initialized error to client");
219                                            break;
220                                        }
221                                        continue;
222                                    }
223                                    drop(identity_lock);
224                                }
225
226                                let identity_lock = gate_identity.lock().await;
227                                let principal = identity_lock
228                                    .as_ref()
229                                    .map(|i| i.principal.clone())
230                                    .unwrap_or_else(|| "anonymous".to_string());
231                                let labels = identity_lock
232                                    .as_ref()
233                                    .map(|i| i.labels.clone())
234                                    .unwrap_or_default();
235                                drop(identity_lock);
236
237                                // Fix 2: Apply interceptor chain to ALL methods,
238                                // not just tools/call.
239
240                                // Extract tool-specific fields conditionally.
241                                let tool_name = if method == "tools/call" {
242                                    tool.as_deref().unwrap_or("unknown")
243                                } else {
244                                    "-"
245                                };
246
247                                let args = msg
248                                    .params
249                                    .as_ref()
250                                    .and_then(|p| p.get("arguments"))
251                                    .cloned()
252                                    .unwrap_or(Value::Null);
253
254                                // ── 2. Throttle: Rate limit check ──
255                                if gate_config.enable_rate_limit {
256                                    let rl_tool = if method == "tools/call" { Some(tool_name) } else { None };
257                                    if let Err(e) = gate_rate_limiter
258                                        .check_rate_limit(&principal, rl_tool)
259                                    {
260                                        warn!(%e, principal = %principal, method = %method, "rate limited");
261                                        record_audit(
262                                            &gate_ledger,
263                                            request_id,
264                                            &principal,
265                                            Direction::Inbound,
266                                            &method,
267                                            tool.as_deref(),
268                                            "deny:rate_limit",
269                                            None,
270                                            start.elapsed().as_micros() as u64,
271                                        )
272                                        .await;
273                                        // Fix 3: Send JSON-RPC error on rate limit deny.
274                                        let err_id = msg.id.clone().unwrap_or(Value::Null);
275                                        let err_resp = McpMessage::error_response(
276                                            err_id,
277                                            -32000,
278                                            "Rate limit exceeded",
279                                        );
280                                        if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
281                                            error!(%we, "failed to send rate limit error to client");
282                                            break;
283                                        }
284                                        continue;
285                                    }
286                                }
287
288                                // ── 2b. Throttle: Budget check ──
289                                if gate_config.enable_budget {
290                                    if let Err(e) = gate_budget.try_spend(&principal, 1.0) {
291                                        warn!(%e, principal = %principal, "budget exhausted");
292                                        record_audit(
293                                            &gate_ledger,
294                                            request_id,
295                                            &principal,
296                                            Direction::Inbound,
297                                            &method,
298                                            tool.as_deref(),
299                                            "deny:budget",
300                                            None,
301                                            start.elapsed().as_micros() as u64,
302                                        )
303                                        .await;
304                                        // Fix 3: Send JSON-RPC error on budget deny.
305                                        let err_id = msg.id.clone().unwrap_or(Value::Null);
306                                        let err_resp = McpMessage::error_response(
307                                            err_id,
308                                            -32000,
309                                            "Budget exhausted",
310                                        );
311                                        if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
312                                            error!(%we, "failed to send budget error to client");
313                                            break;
314                                        }
315                                        continue;
316                                    }
317                                }
318
319                                // ── 3. Ward: Injection scanning ──
320                                // Fix 8: Ward runs BEFORE policy so injection detection
321                                // is applied regardless of authorization level.
322                                if gate_config.enable_ward {
323                                    // Scan params/arguments if present.
324                                    let scan_text = if method == "tools/call" {
325                                        serde_json::to_string(&args).unwrap_or_default()
326                                    } else if let Some(ref params) = msg.params {
327                                        serde_json::to_string(params).unwrap_or_default()
328                                    } else {
329                                        String::new()
330                                    };
331
332                                    if !scan_text.is_empty() {
333                                        let matches = gate_scanner.scan_text(&scan_text);
334                                        if !matches.is_empty() {
335                                            let pattern_names: Vec<&str> = matches
336                                                .iter()
337                                                .map(|m| m.pattern_name.as_str())
338                                                .collect();
339                                            warn!(
340                                                patterns = ?pattern_names,
341                                                method = %method,
342                                                tool = tool_name,
343                                                principal = %principal,
344                                                "injection detected"
345                                            );
346                                            record_audit(
347                                                &gate_ledger,
348                                                request_id,
349                                                &principal,
350                                                Direction::Inbound,
351                                                &method,
352                                                tool.as_deref(),
353                                                &format!(
354                                                    "deny:injection:{}",
355                                                    pattern_names.join(",")
356                                                ),
357                                                None,
358                                                start.elapsed().as_micros() as u64,
359                                            )
360                                            .await;
361                                            // Fix 3: Send JSON-RPC error on injection deny.
362                                            let err_id = msg.id.clone().unwrap_or(Value::Null);
363                                            let err_resp = McpMessage::error_response(
364                                                err_id,
365                                                -32003,
366                                                "Request blocked: injection pattern detected",
367                                            );
368                                            if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
369                                                error!(%we, "failed to send injection error to client");
370                                                break;
371                                            }
372                                            continue;
373                                        }
374                                    }
375                                }
376
377                                // ── 4. Policy: Authorization ──
378                                if gate_config.enforce_policy {
379                                    if let Some(ref engine) = gate_policy {
380                                        // For tools/call, evaluate with the tool name;
381                                        // for other methods, evaluate with the method itself.
382                                        let policy_resource = if method == "tools/call" {
383                                            tool_name
384                                        } else {
385                                            method.as_str()
386                                        };
387                                        let policy_id = PolicyIdentity::new(
388                                            principal.clone(),
389                                            labels.iter().cloned(),
390                                        );
391                                        let decision =
392                                            engine.evaluate(&policy_id, policy_resource, &args);
393
394                                        if !decision.is_allowed() {
395                                            warn!(
396                                                rule_id = %decision.rule_id,
397                                                method = %method,
398                                                resource = policy_resource,
399                                                principal = %principal,
400                                                "policy denied"
401                                            );
402                                            record_audit(
403                                                &gate_ledger,
404                                                request_id,
405                                                &principal,
406                                                Direction::Inbound,
407                                                &method,
408                                                tool.as_deref(),
409                                                &format!("deny:policy:{}", decision.rule_id),
410                                                Some(&decision.rule_id),
411                                                start.elapsed().as_micros() as u64,
412                                            )
413                                            .await;
414                                            // Fix 3: Send JSON-RPC error on policy deny.
415                                            let err_id = msg.id.clone().unwrap_or(Value::Null);
416                                            let err_resp = McpMessage::error_response(
417                                                err_id,
418                                                -32003,
419                                                format!("Denied by policy: {}", decision.rule_id),
420                                            );
421                                            if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
422                                                error!(%we, "failed to send policy error to client");
423                                                break;
424                                            }
425                                            continue;
426                                        }
427                                    }
428                                }
429
430                                // ── 5. Ledger: Record allowed request ──
431                                record_audit(
432                                    &gate_ledger,
433                                    request_id,
434                                    &principal,
435                                    Direction::Inbound,
436                                    &method,
437                                    tool.as_deref(),
438                                    "allow",
439                                    None,
440                                    start.elapsed().as_micros() as u64,
441                                )
442                                .await;
443
444                                // Forward to server
445                                if let Err(e) = server_writer.send(&msg).await {
446                                    error!(%e, "failed to forward to server");
447                                    break;
448                                }
449                            }
450                            Err(e) => {
451                                // Fix 7: Drop invalid JSON. Send a JSON-RPC parse
452                                // error back to the client instead of forwarding.
453                                warn!(%e, raw = trimmed, "invalid JSON from client, dropping");
454                                let err_resp = McpMessage::error_response(
455                                    Value::Null,
456                                    -32700,
457                                    "Parse error: invalid JSON",
458                                );
459                                if let Err(we) = write_to_client(&gate_client_writer, &err_resp).await {
460                                    error!(%we, "failed to send parse error to client");
461                                    break;
462                                }
463                            }
464                        }
465                    }
466                    Err(e) => {
467                        error!(%e, "error reading from client");
468                        break;
469                    }
470                }
471            }
472        });
473
474        let schema_store = Arc::clone(&self.schema_store);
475        let outbound_ledger = Arc::clone(&self.ledger);
476        let outbound_scanner = Arc::clone(&self.injection_scanner);
477        let outbound_config = self.config.clone();
478        let outbound_client_writer = Arc::clone(&client_writer);
479
480        // Server -> Client task (outbound interceptor chain)
481        let server_to_client = tokio::spawn(async move {
482            let mut first_tools_list = true;
483            // Cache the last known-good tools/list response for fallback
484            // when critical schema drift is detected.
485            let mut last_good_tools_result: Option<Value> = None;
486
487            loop {
488                match server_reader.recv().await {
489                    Ok(msg) => {
490                        let start = std::time::Instant::now();
491                        let method = msg.method.as_deref().unwrap_or("-").to_string();
492                        let outbound_request_id = Uuid::new_v4();
493
494                        debug!(
495                            method = method.as_str(),
496                            id = ?msg.id,
497                            "server -> client"
498                        );
499
500                        let mut forward_msg = msg;
501
502                        // ── Schema Pin: Verify tools/list responses ──
503                        if outbound_config.enable_schema_pin {
504                            if let Some(result) = &forward_msg.result {
505                                if result.get("tools").is_some() {
506                                    let mut store = schema_store.lock().await;
507                                    if first_tools_list {
508                                        store.pin_tools(result);
509                                        info!(pinned = store.len(), "schema pins established");
510                                        // Cache the first (known-good) tools result.
511                                        last_good_tools_result = Some(result.clone());
512                                        first_tools_list = false;
513                                    } else {
514                                        let drifts = store.verify_tools(result);
515                                        if !drifts.is_empty() {
516                                            let mut has_critical_or_high = false;
517                                            for drift in &drifts {
518                                                warn!(
519                                                    tool = %drift.tool_name,
520                                                    drift_type = ?drift.drift_type,
521                                                    severity = ?drift.severity,
522                                                    "schema drift detected"
523                                                );
524                                                if matches!(drift.severity, DriftSeverity::Critical | DriftSeverity::High) {
525                                                    has_critical_or_high = true;
526                                                }
527                                            }
528
529                                            // Fix 5: Block Critical/High schema drift.
530                                            // Send the previously pinned (known-good)
531                                            // schema instead of the drifted response.
532                                            if has_critical_or_high {
533                                                warn!("critical/high schema drift detected -- blocking drifted tools/list");
534                                                // Fix 9: Record outbound drift block.
535                                                record_audit(
536                                                    &outbound_ledger,
537                                                    outbound_request_id,
538                                                    "server",
539                                                    Direction::Outbound,
540                                                    "tools/list",
541                                                    None,
542                                                    "deny:schema_drift",
543                                                    None,
544                                                    start.elapsed().as_micros() as u64,
545                                                )
546                                                .await;
547
548                                                if let Some(ref good_result) = last_good_tools_result {
549                                                    // Replace with the known-good schema.
550                                                    forward_msg.result = Some(good_result.clone());
551                                                } else {
552                                                    // No known-good schema available; send error.
553                                                    let err_id = forward_msg.id.clone().unwrap_or(Value::Null);
554                                                    let err_resp = McpMessage::error_response(
555                                                        err_id,
556                                                        -32003,
557                                                        "Schema drift detected: tool definitions have been tampered with",
558                                                    );
559                                                    if let Err(we) = write_to_client(&outbound_client_writer, &err_resp).await {
560                                                        error!(%we, "failed to send schema drift error to client");
561                                                        break;
562                                                    }
563                                                    continue;
564                                                }
565                                            }
566                                        }
567                                    }
568                                }
569                            }
570                        }
571
572                        // ── Fix 6: Outbound response scanning ──
573                        // Scan tool call results for injection patterns (log only).
574                        if outbound_config.enable_ward {
575                            if let Some(ref result) = forward_msg.result {
576                                // Scan content arrays in tool call responses.
577                                let scan_target = if let Some(content) = result.get("content") {
578                                    serde_json::to_string(content).unwrap_or_default()
579                                } else {
580                                    serde_json::to_string(result).unwrap_or_default()
581                                };
582
583                                if !scan_target.is_empty() {
584                                    let matches = outbound_scanner.scan_text(&scan_target);
585                                    if !matches.is_empty() {
586                                        let pattern_names: Vec<&str> = matches
587                                            .iter()
588                                            .map(|m| m.pattern_name.as_str())
589                                            .collect();
590                                        warn!(
591                                            patterns = ?pattern_names,
592                                            direction = "outbound",
593                                            "suspicious content in server response"
594                                        );
595                                        // Fix 9: Record outbound ward warning.
596                                        record_audit(
597                                            &outbound_ledger,
598                                            outbound_request_id,
599                                            "server",
600                                            Direction::Outbound,
601                                            &method,
602                                            None,
603                                            &format!("warn:outbound_injection:{}", pattern_names.join(",")),
604                                            None,
605                                            start.elapsed().as_micros() as u64,
606                                        )
607                                        .await;
608                                        // Don't block by default, just log.
609                                    }
610                                }
611                            }
612                        }
613
614                        // ── Fix 9: Record outbound audit entry ──
615                        record_audit(
616                            &outbound_ledger,
617                            outbound_request_id,
618                            "server",
619                            Direction::Outbound,
620                            &method,
621                            None,
622                            "forward",
623                            None,
624                            start.elapsed().as_micros() as u64,
625                        )
626                        .await;
627
628                        // Forward to client
629                        if let Err(e) = write_to_client(&outbound_client_writer, &forward_msg).await {
630                            error!(%e, "failed to write to client");
631                            break;
632                        }
633                    }
634                    Err(DomeError::Transport(ref e))
635                        if e.kind() == std::io::ErrorKind::UnexpectedEof =>
636                    {
637                        info!("server closed stdout -- shutting down");
638                        break;
639                    }
640                    Err(e) => {
641                        error!(%e, "error reading from server");
642                        break;
643                    }
644                }
645            }
646        });
647
648        // Wait for either side to finish
649        tokio::select! {
650            r = client_to_server => {
651                if let Err(e) = r {
652                    error!(%e, "client->server task panicked");
653                }
654            }
655            r = server_to_client => {
656                if let Err(e) = r {
657                    error!(%e, "server->client task panicked");
658                }
659            }
660        }
661
662        // Flush audit log and clean up
663        self.ledger.lock().await.flush();
664        let _ = child.kill().await;
665        info!("MCPDome proxy shut down");
666
667        Ok(())
668    }
669}
670
671/// Write a McpMessage to the client's stdout, with newline and flush.
672async fn write_to_client(
673    writer: &Arc<Mutex<tokio::io::Stdout>>,
674    msg: &McpMessage,
675) -> Result<(), std::io::Error> {
676    match msg.to_json() {
677        Ok(json) => {
678            let mut out = json.into_bytes();
679            out.push(b'\n');
680            let mut w = writer.lock().await;
681            w.write_all(&out).await?;
682            w.flush().await?;
683            Ok(())
684        }
685        Err(e) => {
686            error!(%e, "failed to serialize message for client");
687            Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e))
688        }
689    }
690}
691
692/// Helper to record an audit entry.
693async fn record_audit(
694    ledger: &Arc<Mutex<Ledger>>,
695    request_id: Uuid,
696    identity: &str,
697    direction: Direction,
698    method: &str,
699    tool: Option<&str>,
700    decision: &str,
701    rule_id: Option<&str>,
702    latency_us: u64,
703) {
704    let entry = AuditEntry {
705        seq: 0, // set by ledger
706        timestamp: Utc::now(),
707        request_id,
708        identity: identity.to_string(),
709        direction,
710        method: method.to_string(),
711        tool: tool.map(String::from),
712        decision: decision.to_string(),
713        rule_id: rule_id.map(String::from),
714        latency_us,
715        prev_hash: String::new(), // set by ledger
716        annotations: std::collections::HashMap::new(),
717    };
718
719    if let Err(e) = ledger.lock().await.record(entry) {
720        error!(%e, "failed to record audit entry");
721    }
722}