Skip to main content

dome_gate/
lib.rs

1use std::sync::Arc;
2
3use dome_core::{DomeError, McpMessage};
4use dome_ledger::{AuditEntry, Direction, Ledger};
5use dome_policy::{Identity as PolicyIdentity, SharedPolicyEngine};
6use dome_sentinel::{
7    AnonymousAuthenticator, ApiKeyAuthenticator, Authenticator, IdentityResolver, PskAuthenticator,
8    ResolverConfig,
9};
10use dome_throttle::{BudgetTracker, BudgetTrackerConfig, RateLimiter, RateLimiterConfig};
11use dome_transport::stdio::StdioTransport;
12use dome_ward::schema_pin::DriftSeverity;
13use dome_ward::{InjectionScanner, SchemaPinStore};
14
15use chrono::Utc;
16use serde_json::Value;
17use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
18use tokio::sync::Mutex;
19use tokio::time::{Duration, timeout};
20use tracing::{debug, error, info, warn};
21
/// Largest client line, in bytes, that the stdio read loop accepts. Longer
/// lines are answered with a "Message too large" error and dropped.
const MAX_LINE_SIZE: usize = 10 * 1024 * 1024; // 10 MB
/// How long the stdio read loop waits for the next client line before it
/// logs a timeout and stops reading.
const CLIENT_READ_TIMEOUT: Duration = Duration::from_secs(300); // 5 minutes
24use uuid::Uuid;
25
// ---------------------------------------------------------------------------
// Interceptor chain action types
// ---------------------------------------------------------------------------
29
/// Result of processing an inbound message through the interceptor chain.
///
/// Returned by `ProxyState::process_inbound`; the client -> server task
/// matches on it to decide where the carried message is written.
enum InboundAction {
    /// Forward the (possibly modified) message to the upstream server.
    Forward(McpMessage),
    /// Send this error response back to the client (do NOT forward).
    Deny(McpMessage),
}
37
/// Result of processing an outbound message through the interceptor chain.
///
/// Returned by `ProxyState::process_outbound`. In both variants the carried
/// message is what the server -> client task writes to the client.
enum OutboundAction {
    /// Forward the (possibly modified) message to the client.
    Forward(McpMessage),
    /// Block the message; send this error response to the client instead.
    Block(McpMessage),
}
45
46/// Per-session state for the outbound interceptor chain.
47struct OutboundContext {
48    first_tools_list: bool,
49    last_good_tools_result: Option<Value>,
50}
51
52impl OutboundContext {
53    fn new() -> Self {
54        Self {
55            first_tools_list: true,
56            last_good_tools_result: None,
57        }
58    }
59}
60
// ---------------------------------------------------------------------------
// Gate configuration
// ---------------------------------------------------------------------------
64
/// Feature switches for the Gate proxy.
///
/// Each flag turns one interceptor on or off. [`Default`] yields a
/// pass-through configuration: every interceptor disabled and anonymous
/// access allowed.
#[derive(Debug, Clone)]
pub struct GateConfig {
    /// Whether to enforce policy (false = transparent pass-through mode).
    pub enforce_policy: bool,
    /// Whether to enable injection scanning.
    pub enable_ward: bool,
    /// Whether to enable schema pinning.
    pub enable_schema_pin: bool,
    /// Whether to enable rate limiting.
    pub enable_rate_limit: bool,
    /// Whether to enable budget tracking.
    pub enable_budget: bool,
    /// Whether to allow anonymous access.
    pub allow_anonymous: bool,
    /// Whether to block outbound responses that contain injection patterns.
    /// When false (default), outbound injection is logged but not blocked.
    pub block_outbound_injection: bool,
}

impl Default for GateConfig {
    /// Pass-through defaults: all enforcement off, anonymous access on.
    fn default() -> Self {
        // Every interceptor is opt-in.
        const DISABLED: bool = false;

        Self {
            allow_anonymous: true,
            enforce_policy: DISABLED,
            enable_ward: DISABLED,
            enable_schema_pin: DISABLED,
            enable_rate_limit: DISABLED,
            enable_budget: DISABLED,
            block_outbound_injection: DISABLED,
        }
    }
}
98
// ---------------------------------------------------------------------------
// Gate — public API
// ---------------------------------------------------------------------------
102
103/// The Gate -- Thunder Dome's core proxy loop with full interceptor chain.
104///
105/// Interceptor order (inbound, client -> server):
106///   1. Sentinel -- authenticate on `initialize`, resolve identity
107///   2. Throttle -- check rate limits and budget
108///   3. Ward    -- scan for injection patterns in tool arguments
109///   4. Policy  -- evaluate authorization rules
110///   5. Ledger  -- record the decision in the audit chain
111///
112/// Outbound (server -> client):
113///   1. Schema Pin -- verify tools/list responses for drift (block Critical/High)
114///   2. Ward       -- scan outbound tool results for injection patterns
115///   3. Ledger     -- record outbound audit entry
116pub struct Gate {
117    config: GateConfig,
118    resolver: IdentityResolver,
119    policy_engine: Option<SharedPolicyEngine>,
120    rate_limiter: Arc<RateLimiter>,
121    budget_tracker: Arc<BudgetTracker>,
122    injection_scanner: Arc<InjectionScanner>,
123    schema_store: Arc<Mutex<SchemaPinStore>>,
124    ledger: Arc<Mutex<Ledger>>,
125}
126
127impl std::fmt::Debug for Gate {
128    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129        f.debug_struct("Gate")
130            .field("config", &self.config)
131            .field("has_policy_engine", &self.policy_engine.is_some())
132            .finish_non_exhaustive()
133    }
134}
135
136impl Gate {
137    /// Create a new Gate with full interceptor chain.
138    ///
139    /// The `policy_engine` parameter accepts an optional `SharedPolicyEngine`
140    /// (`Arc<ArcSwap<PolicyEngine>>`), which enables hot-reload. The gate reads
141    /// the current policy atomically on every request via `load()`, so swaps
142    /// performed by a [`PolicyWatcher`] are immediately visible without restart.
143    pub fn new(
144        config: GateConfig,
145        authenticators: Vec<Box<dyn Authenticator>>,
146        policy_engine: Option<SharedPolicyEngine>,
147        rate_limiter_config: RateLimiterConfig,
148        budget_config: BudgetTrackerConfig,
149        ledger: Ledger,
150    ) -> Self {
151        Self {
152            resolver: IdentityResolver::new(
153                authenticators,
154                ResolverConfig {
155                    allow_anonymous: config.allow_anonymous,
156                },
157            ),
158            policy_engine,
159            rate_limiter: Arc::new(RateLimiter::new(rate_limiter_config)),
160            budget_tracker: Arc::new(BudgetTracker::new(budget_config)),
161            injection_scanner: Arc::new(InjectionScanner::new()),
162            schema_store: Arc::new(Mutex::new(SchemaPinStore::new())),
163            ledger: Arc::new(Mutex::new(ledger)),
164            config,
165        }
166    }
167
168    /// Create a transparent pass-through Gate (no security enforcement).
169    pub fn transparent(ledger: Ledger) -> Self {
170        Self::new(
171            GateConfig::default(),
172            vec![Box::new(AnonymousAuthenticator)],
173            None,
174            RateLimiterConfig::default(),
175            BudgetTrackerConfig::default(),
176            ledger,
177        )
178    }
179
180    /// Convert the Gate into a shared ProxyState for the proxy loop.
181    fn into_proxy_state(self) -> Arc<ProxyState> {
182        Arc::new(ProxyState {
183            identity: Mutex::new(None),
184            resolver: self.resolver,
185            config: self.config,
186            policy: self.policy_engine,
187            rate_limiter: self.rate_limiter,
188            budget: self.budget_tracker,
189            scanner: self.injection_scanner,
190            schema_store: self.schema_store,
191            ledger: self.ledger,
192        })
193    }
194
195    /// Run the proxy over stdio (stdin/stdout for the client, child process for the server).
196    pub async fn run_stdio(self, command: &str, args: &[&str]) -> Result<(), DomeError> {
197        let state = self.into_proxy_state();
198
199        let transport = StdioTransport::spawn(command, args).await?;
200        let (mut server_reader, mut server_writer, child) = transport.split();
201
202        let client_stdin = tokio::io::stdin();
203        let client_stdout = tokio::io::stdout();
204        let mut client_reader = BufReader::new(client_stdin);
205        let client_writer: Arc<Mutex<tokio::io::Stdout>> = Arc::new(Mutex::new(client_stdout));
206
207        info!("Thunder Dome proxy active -- interceptor chain armed");
208
209        let inbound_state = Arc::clone(&state);
210        let inbound_writer = Arc::clone(&client_writer);
211
212        // Client -> Server task (inbound interceptor chain)
213        let mut client_to_server = tokio::spawn(async move {
214            let mut line = String::new();
215            loop {
216                line.clear();
217                let read_result =
218                    timeout(CLIENT_READ_TIMEOUT, client_reader.read_line(&mut line)).await;
219                let read_result = match read_result {
220                    Ok(inner) => inner,
221                    Err(_) => {
222                        warn!("client read timed out");
223                        break;
224                    }
225                };
226                match read_result {
227                    Ok(0) => {
228                        info!("client closed stdin -- shutting down");
229                        break;
230                    }
231                    Ok(_) => {
232                        if line.len() > MAX_LINE_SIZE {
233                            warn!(
234                                size = line.len(),
235                                max = MAX_LINE_SIZE,
236                                "client message exceeds size limit, dropping"
237                            );
238                            let err_resp = McpMessage::error_response(
239                                Value::Null,
240                                -32600,
241                                "Message too large",
242                            );
243                            if let Err(we) = write_to_client(&inbound_writer, &err_resp).await {
244                                error!(%we, "failed to send size error to client");
245                                break;
246                            }
247                            continue;
248                        }
249                        let trimmed = line.trim();
250                        if trimmed.is_empty() {
251                            continue;
252                        }
253
254                        match McpMessage::parse(trimmed) {
255                            Ok(msg) => match inbound_state.process_inbound(msg).await {
256                                InboundAction::Forward(msg) => {
257                                    if let Err(e) = server_writer.send(&msg).await {
258                                        error!(%e, "failed to forward to server");
259                                        break;
260                                    }
261                                }
262                                InboundAction::Deny(err_resp) => {
263                                    if let Err(we) =
264                                        write_to_client(&inbound_writer, &err_resp).await
265                                    {
266                                        error!(%we, "failed to send error to client");
267                                        break;
268                                    }
269                                }
270                            },
271                            Err(e) => {
272                                warn!(%e, raw = trimmed, "invalid JSON from client, dropping");
273                                let err_resp = McpMessage::error_response(
274                                    Value::Null,
275                                    -32700,
276                                    "Parse error: invalid JSON",
277                                );
278                                if let Err(we) = write_to_client(&inbound_writer, &err_resp).await {
279                                    error!(%we, "failed to send parse error to client");
280                                    break;
281                                }
282                            }
283                        }
284                    }
285                    Err(e) => {
286                        error!(%e, "error reading from client");
287                        break;
288                    }
289                }
290            }
291        });
292
293        let outbound_state = Arc::clone(&state);
294        let outbound_writer = Arc::clone(&client_writer);
295
296        // Server -> Client task (outbound interceptor chain)
297        let mut server_to_client = tokio::spawn(async move {
298            let mut ctx = OutboundContext::new();
299            loop {
300                match server_reader.recv().await {
301                    Ok(msg) => match outbound_state.process_outbound(msg, &mut ctx).await {
302                        OutboundAction::Forward(msg) => {
303                            if let Err(e) = write_to_client(&outbound_writer, &msg).await {
304                                error!(%e, "failed to write to client");
305                                break;
306                            }
307                        }
308                        OutboundAction::Block(err_resp) => {
309                            if let Err(e) = write_to_client(&outbound_writer, &err_resp).await {
310                                error!(%e, "failed to send outbound error to client");
311                                break;
312                            }
313                        }
314                    },
315                    Err(DomeError::Transport(ref e))
316                        if e.kind() == std::io::ErrorKind::UnexpectedEof =>
317                    {
318                        info!("server closed stdout -- shutting down");
319                        break;
320                    }
321                    Err(e) => {
322                        error!(%e, "error reading from server");
323                        break;
324                    }
325                }
326            }
327        });
328
329        // Wait for either side to finish, then abort the other.
330        select_and_abort(&mut client_to_server, &mut server_to_client).await;
331
332        // Flush audit log.
333        state.ledger.lock().await.flush();
334
335        // Graceful child termination: the child's stdin pipe was closed when
336        // the writer task ended/was aborted, so the child should see EOF and
337        // exit. Wait up to 5 seconds, then force-kill as a last resort.
338        shutdown_child(child).await;
339
340        info!("Thunder Dome proxy shut down");
341        Ok(())
342    }
343
344    /// Run the proxy over HTTP+SSE (HTTP server for the client, child process
345    /// for the upstream MCP server).
346    #[cfg(feature = "http")]
347    pub async fn run_http(
348        self,
349        command: &str,
350        args: &[&str],
351        http_config: dome_transport::http::HttpTransportConfig,
352    ) -> Result<(), DomeError> {
353        let state = self.into_proxy_state();
354
355        let transport = StdioTransport::spawn(command, args).await?;
356        let (mut server_reader, mut server_writer, child) = transport.split();
357
358        let http = dome_transport::http::HttpTransport::start(http_config).await?;
359        let (mut http_reader, http_writer, http_handle) = http.split();
360        let http_writer = Arc::new(http_writer);
361
362        info!("Thunder Dome HTTP+SSE proxy active -- interceptor chain armed");
363
364        let inbound_state = Arc::clone(&state);
365        let inbound_http_writer = Arc::clone(&http_writer);
366
367        // Client -> Server task (inbound interceptor chain via HTTP)
368        let mut client_to_server = tokio::spawn(async move {
369            loop {
370                match http_reader.recv().await {
371                    Ok(msg) => match inbound_state.process_inbound(msg).await {
372                        InboundAction::Forward(msg) => {
373                            if let Err(e) = server_writer.send(&msg).await {
374                                error!(%e, "failed to forward to server");
375                                break;
376                            }
377                        }
378                        InboundAction::Deny(err_resp) => {
379                            if let Err(e) = inbound_http_writer.send(&err_resp).await {
380                                warn!(%e, "failed to send error to HTTP client");
381                            }
382                        }
383                    },
384                    Err(e) => {
385                        info!(%e, "HTTP client transport closed");
386                        break;
387                    }
388                }
389            }
390        });
391
392        let outbound_state = Arc::clone(&state);
393        let outbound_http_writer = Arc::clone(&http_writer);
394
395        // Server -> Client task (outbound interceptor chain via HTTP)
396        let mut server_to_client = tokio::spawn(async move {
397            let mut ctx = OutboundContext::new();
398            loop {
399                match server_reader.recv().await {
400                    Ok(msg) => match outbound_state.process_outbound(msg, &mut ctx).await {
401                        OutboundAction::Forward(msg) => {
402                            if let Err(e) = outbound_http_writer.send(&msg).await {
403                                warn!(%e, "failed to send to HTTP client");
404                                break;
405                            }
406                        }
407                        OutboundAction::Block(err_resp) => {
408                            if let Err(e) = outbound_http_writer.send(&err_resp).await {
409                                warn!(%e, "failed to send outbound error to HTTP client");
410                                break;
411                            }
412                        }
413                    },
414                    Err(DomeError::Transport(ref e))
415                        if e.kind() == std::io::ErrorKind::UnexpectedEof =>
416                    {
417                        info!("server closed stdout -- shutting down");
418                        break;
419                    }
420                    Err(e) => {
421                        error!(%e, "error reading from server");
422                        break;
423                    }
424                }
425            }
426        });
427
428        // Wait for either side to finish, then abort the other.
429        select_and_abort(&mut client_to_server, &mut server_to_client).await;
430
431        // Flush audit log.
432        state.ledger.lock().await.flush();
433
434        // Shut down HTTP server.
435        http_handle.shutdown().await;
436
437        // Graceful child termination.
438        shutdown_child(child).await;
439
440        info!("Thunder Dome HTTP+SSE proxy shut down");
441        Ok(())
442    }
443}
444
// ---------------------------------------------------------------------------
// ProxyState — shared interceptor chain state
// ---------------------------------------------------------------------------
448
/// Internal shared state for the proxy loop. Created from a [`Gate`] at the
/// start of a proxy session. Holds all interceptor chain components behind
/// appropriate synchronization primitives so both the inbound (client → server)
/// and outbound (server → client) tasks can share it via `Arc`.
struct ProxyState {
    /// Identity resolved for this session; `None` until an `initialize`
    /// request has authenticated successfully.
    identity: Mutex<Option<dome_sentinel::Identity>>,
    /// Resolves the identity from an `initialize` message.
    resolver: IdentityResolver,
    /// Flags selecting which interceptors run.
    config: GateConfig,
    /// Policy engine consulted when `config.enforce_policy` is set; the
    /// current policy is loaded on every request.
    policy: Option<SharedPolicyEngine>,
    /// Rate limiter checked when `config.enable_rate_limit` is set.
    rate_limiter: Arc<RateLimiter>,
    /// Budget tracker charged when `config.enable_budget` is set.
    budget: Arc<BudgetTracker>,
    /// Injection scanner used when `config.enable_ward` is set.
    scanner: Arc<InjectionScanner>,
    /// Pinned tool schemas checked when `config.enable_schema_pin` is set.
    schema_store: Arc<Mutex<SchemaPinStore>>,
    /// Audit ledger that records interceptor decisions.
    ledger: Arc<Mutex<Ledger>>,
}
464
465impl ProxyState {
466    /// Process an inbound (client → server) message through the full
467    /// interceptor chain: Sentinel → Throttle → Ward → Policy → Ledger.
468    async fn process_inbound(&self, msg: McpMessage) -> InboundAction {
469        let start = std::time::Instant::now();
470        let method = msg.method.as_deref().unwrap_or("-").to_string();
471        let tool = msg.tool_name().map(String::from);
472        let request_id = Uuid::new_v4();
473
474        debug!(
475            method = method.as_str(),
476            id = ?msg.id,
477            tool = tool.as_deref().unwrap_or("-"),
478            "client -> server"
479        );
480
481        // ── 1. Sentinel: Authenticate on initialize ──
482        let mut msg = msg;
483        if method == "initialize" {
484            match self.resolver.resolve(&msg).await {
485                Ok(id) => {
486                    info!(
487                        principal = %id.principal,
488                        method = %id.auth_method,
489                        "identity resolved"
490                    );
491                    *self.identity.lock().await = Some(id);
492
493                    // Strip all credential fields before forwarding.
494                    msg = PskAuthenticator::strip_psk(&msg);
495                    msg = ApiKeyAuthenticator::strip_api_key(&msg);
496                }
497                Err(e) => {
498                    warn!(%e, "authentication failed");
499                    let err_id = msg.id.clone().unwrap_or(Value::Null);
500                    return InboundAction::Deny(McpMessage::error_response(
501                        err_id,
502                        -32600,
503                        "Authentication failed",
504                    ));
505                }
506            }
507        }
508
509        // Block all non-initialize requests before the session has been
510        // initialized (identity resolved).
511        if method != "initialize" {
512            let identity_lock = self.identity.lock().await;
513            if identity_lock.is_none() {
514                drop(identity_lock);
515                warn!(method = %method, "request before initialize");
516                let err_id = msg.id.clone().unwrap_or(Value::Null);
517                return InboundAction::Deny(McpMessage::error_response(
518                    err_id,
519                    -32600,
520                    "Session not initialized",
521                ));
522            }
523            drop(identity_lock);
524        }
525
526        let identity_lock = self.identity.lock().await;
527        let principal = identity_lock
528            .as_ref()
529            .map(|i| i.principal.clone())
530            .unwrap_or_else(|| "anonymous".to_string());
531        let labels = identity_lock
532            .as_ref()
533            .map(|i| i.labels.clone())
534            .unwrap_or_default();
535        drop(identity_lock);
536
537        // Extract the method-specific resource name for policy evaluation.
538        let resource_name = msg.method_resource_name().unwrap_or("-").to_string();
539        let tool_name = resource_name.as_str();
540
541        let args = msg
542            .params
543            .as_ref()
544            .and_then(|p| p.get("arguments"))
545            .cloned()
546            .unwrap_or(Value::Null);
547
548        // ── 2. Throttle: Rate limit check ──
549        if self.config.enable_rate_limit {
550            let rl_tool = if tool_name != "-" {
551                Some(tool_name)
552            } else {
553                None
554            };
555            if let Err(e) = self.rate_limiter.check_rate_limit(&principal, rl_tool) {
556                warn!(%e, principal = %principal, method = %method, "rate limited");
557                record_audit(
558                    &self.ledger,
559                    AuditParams {
560                        request_id,
561                        identity: &principal,
562                        direction: Direction::Inbound,
563                        method: &method,
564                        tool: tool.as_deref(),
565                        decision: "deny:rate_limit",
566                        rule_id: None,
567                        latency_us: start.elapsed().as_micros() as u64,
568                    },
569                )
570                .await;
571                let err_id = msg.id.clone().unwrap_or(Value::Null);
572                return InboundAction::Deny(McpMessage::error_response(
573                    err_id,
574                    -32000,
575                    "Rate limit exceeded",
576                ));
577            }
578        }
579
580        // ── 2b. Throttle: Budget check ──
581        if self.config.enable_budget
582            && let Err(e) = self.budget.try_spend(&principal, 1.0)
583        {
584            warn!(%e, principal = %principal, "budget exhausted");
585            record_audit(
586                &self.ledger,
587                AuditParams {
588                    request_id,
589                    identity: &principal,
590                    direction: Direction::Inbound,
591                    method: &method,
592                    tool: tool.as_deref(),
593                    decision: "deny:budget",
594                    rule_id: None,
595                    latency_us: start.elapsed().as_micros() as u64,
596                },
597            )
598            .await;
599            let err_id = msg.id.clone().unwrap_or(Value::Null);
600            return InboundAction::Deny(McpMessage::error_response(
601                err_id,
602                -32000,
603                "Budget exhausted",
604            ));
605        }
606
607        // ── 3. Ward: Injection scanning ──
608        // Ward runs BEFORE policy so injection detection is applied regardless
609        // of authorization level. Uses scan_json_value to extract and scan
610        // each leaf string individually, preventing evasion via JSON encoding
611        // or injection text split across JSON structural boundaries.
612        if self.config.enable_ward {
613            let scan_target = if method == "tools/call" {
614                args.clone()
615            } else if let Some(ref params) = msg.params {
616                params.clone()
617            } else {
618                Value::Null
619            };
620
621            if !scan_target.is_null() {
622                let scan_result = dome_ward::scan_json_value(&self.scanner, &scan_target);
623                if !scan_result.pattern_matches.is_empty() {
624                    let pattern_names: Vec<&str> = scan_result
625                        .pattern_matches
626                        .iter()
627                        .map(|m| m.pattern_name.as_str())
628                        .collect();
629                    warn!(
630                        patterns = ?pattern_names,
631                        method = %method,
632                        tool = tool_name,
633                        principal = %principal,
634                        "injection detected"
635                    );
636                    record_audit(
637                        &self.ledger,
638                        AuditParams {
639                            request_id,
640                            identity: &principal,
641                            direction: Direction::Inbound,
642                            method: &method,
643                            tool: tool.as_deref(),
644                            decision: &format!("deny:injection:{}", pattern_names.join(",")),
645                            rule_id: None,
646                            latency_us: start.elapsed().as_micros() as u64,
647                        },
648                    )
649                    .await;
650                    let err_id = msg.id.clone().unwrap_or(Value::Null);
651                    return InboundAction::Deny(McpMessage::error_response(
652                        err_id,
653                        -32003,
654                        "Request blocked: injection pattern detected",
655                    ));
656                }
657            }
658        }
659
660        // ── 4. Policy: Authorization ──
661        if self.config.enforce_policy
662            && let Some(ref shared_engine) = self.policy
663        {
664            // Load the current policy atomically. This is lock-free and
665            // picks up hot-reloaded changes immediately.
666            let engine = shared_engine.load();
667
668            let policy_resource = if method == "tools/call" {
669                tool_name
670            } else {
671                method.as_str()
672            };
673            let policy_id = PolicyIdentity::new(principal.clone(), labels.iter().cloned());
674            let decision = engine.evaluate(&policy_id, policy_resource, &args);
675
676            if !decision.is_allowed() {
677                warn!(
678                    rule_id = %decision.rule_id,
679                    method = %method,
680                    resource = policy_resource,
681                    principal = %principal,
682                    "policy denied"
683                );
684                record_audit(
685                    &self.ledger,
686                    AuditParams {
687                        request_id,
688                        identity: &principal,
689                        direction: Direction::Inbound,
690                        method: &method,
691                        tool: tool.as_deref(),
692                        decision: &format!("deny:policy:{}", decision.rule_id),
693                        rule_id: Some(&decision.rule_id),
694                        latency_us: start.elapsed().as_micros() as u64,
695                    },
696                )
697                .await;
698                let err_id = msg.id.clone().unwrap_or(Value::Null);
699                return InboundAction::Deny(McpMessage::error_response(
700                    err_id,
701                    -32003,
702                    format!("Denied by policy: {}", decision.rule_id),
703                ));
704            }
705        }
706
707        // ── 5. Ledger: Record allowed request ──
708        record_audit(
709            &self.ledger,
710            AuditParams {
711                request_id,
712                identity: &principal,
713                direction: Direction::Inbound,
714                method: &method,
715                tool: tool.as_deref(),
716                decision: "allow",
717                rule_id: None,
718                latency_us: start.elapsed().as_micros() as u64,
719            },
720        )
721        .await;
722
723        InboundAction::Forward(msg)
724    }
725
726    /// Process an outbound (server → client) message through the outbound
727    /// interceptor chain: Schema Pin → Ward → Ledger.
728    async fn process_outbound(&self, msg: McpMessage, ctx: &mut OutboundContext) -> OutboundAction {
729        let start = std::time::Instant::now();
730        let method = msg.method.as_deref().unwrap_or("-").to_string();
731        let outbound_request_id = Uuid::new_v4();
732
733        debug!(
734            method = method.as_str(),
735            id = ?msg.id,
736            "server -> client"
737        );
738
739        let mut forward_msg = msg;
740
741        // ── Schema Pin: Verify tools/list responses ──
742        if self.config.enable_schema_pin
743            && let Some(result) = &forward_msg.result
744            && result.get("tools").is_some()
745        {
746            let mut store = self.schema_store.lock().await;
747            if ctx.first_tools_list {
748                store.pin_tools(result);
749                info!(pinned = store.len(), "schema pins established");
750                ctx.last_good_tools_result = Some(result.clone());
751                ctx.first_tools_list = false;
752            } else {
753                let drifts = store.verify_tools(result);
754                if !drifts.is_empty() {
755                    let has_critical_or_high = drifts.iter().any(|drift| {
756                        warn!(
757                            tool = %drift.tool_name,
758                            drift_type = ?drift.drift_type,
759                            severity = ?drift.severity,
760                            "schema drift detected"
761                        );
762                        matches!(
763                            drift.severity,
764                            DriftSeverity::Critical | DriftSeverity::High
765                        )
766                    });
767
768                    if has_critical_or_high {
769                        warn!("critical/high schema drift detected -- blocking drifted tools/list");
770                        record_audit(
771                            &self.ledger,
772                            AuditParams {
773                                request_id: outbound_request_id,
774                                identity: "server",
775                                direction: Direction::Outbound,
776                                method: "tools/list",
777                                tool: None,
778                                decision: "deny:schema_drift",
779                                rule_id: None,
780                                latency_us: start.elapsed().as_micros() as u64,
781                            },
782                        )
783                        .await;
784
785                        if let Some(ref good_result) = ctx.last_good_tools_result {
786                            forward_msg.result = Some(good_result.clone());
787                        } else {
788                            let err_id = forward_msg.id.clone().unwrap_or(Value::Null);
789                            return OutboundAction::Block(McpMessage::error_response(
790                                err_id,
791                                -32003,
792                                "Schema drift detected: tool definitions have been tampered with",
793                            ));
794                        }
795                    }
796                }
797            }
798        }
799
800        // ── Outbound response scanning ──
801        // Uses scan_json_value to extract and scan each leaf string
802        // individually, matching the inbound scanning approach.
803        if self.config.enable_ward
804            && let Some(ref result) = forward_msg.result
805        {
806            let scan_value = if let Some(content) = result.get("content") {
807                content.clone()
808            } else {
809                result.clone()
810            };
811
812            let scan_result = dome_ward::scan_json_value(&self.scanner, &scan_value);
813            if !scan_result.pattern_matches.is_empty() {
814                let pattern_names: Vec<&str> = scan_result
815                    .pattern_matches
816                    .iter()
817                    .map(|m| m.pattern_name.as_str())
818                    .collect();
819                let decision = if self.config.block_outbound_injection {
820                    "deny:outbound_injection"
821                } else {
822                    "warn:outbound_injection"
823                };
824                warn!(
825                    patterns = ?pattern_names,
826                    direction = "outbound",
827                    blocked = self.config.block_outbound_injection,
828                    "injection detected in server response"
829                );
830                record_audit(
831                    &self.ledger,
832                    AuditParams {
833                        request_id: outbound_request_id,
834                        identity: "server",
835                        direction: Direction::Outbound,
836                        method: &method,
837                        tool: None,
838                        decision: &format!("{}:{}", decision, pattern_names.join(",")),
839                        rule_id: None,
840                        latency_us: start.elapsed().as_micros() as u64,
841                    },
842                )
843                .await;
844
845                if self.config.block_outbound_injection {
846                    let err_id = forward_msg.id.clone().unwrap_or(Value::Null);
847                    return OutboundAction::Block(McpMessage::error_response(
848                        err_id,
849                        -32005,
850                        "Response blocked: injection pattern detected in server output",
851                    ));
852                }
853            }
854        }
855
856        // Record outbound audit entry.
857        record_audit(
858            &self.ledger,
859            AuditParams {
860                request_id: outbound_request_id,
861                identity: "server",
862                direction: Direction::Outbound,
863                method: &method,
864                tool: None,
865                decision: "forward",
866                rule_id: None,
867                latency_us: start.elapsed().as_micros() as u64,
868            },
869        )
870        .await;
871
872        OutboundAction::Forward(forward_msg)
873    }
874}
875
876// ---------------------------------------------------------------------------
877// Helpers
878// ---------------------------------------------------------------------------
879
880/// Race two proxy tasks; abort the survivor when one finishes.
881///
882/// Both `run_stdio` and `run_http` need the same select-and-abort
883/// pattern to ensure that when one direction closes, the other is
884/// cleaned up deterministically. Extracted to avoid duplication.
885async fn select_and_abort(
886    client_to_server: &mut tokio::task::JoinHandle<()>,
887    server_to_client: &mut tokio::task::JoinHandle<()>,
888) {
889    tokio::select! {
890        r = &mut *client_to_server => {
891            if let Err(e) = r {
892                error!(%e, "client->server task panicked");
893            }
894            server_to_client.abort();
895        }
896        r = &mut *server_to_client => {
897            if let Err(e) = r {
898                error!(%e, "server->client task panicked");
899            }
900            client_to_server.abort();
901        }
902    }
903}
904
905/// Wait up to 5 seconds for the upstream child process to exit gracefully,
906/// then force-kill it if it does not.
907async fn shutdown_child(mut child: tokio::process::Child) {
908    match tokio::time::timeout(Duration::from_secs(5), child.wait()).await {
909        Ok(Ok(status)) => info!(%status, "upstream server exited"),
910        Ok(Err(e)) => warn!(%e, "error waiting for upstream server"),
911        Err(_) => {
912            warn!("upstream server did not exit within 5s, forcing termination");
913            let _ = child.kill().await;
914        }
915    }
916}
917
918/// Write a McpMessage to the client's stdout, with newline and flush.
919async fn write_to_client(
920    writer: &Arc<Mutex<tokio::io::Stdout>>,
921    msg: &McpMessage,
922) -> Result<(), std::io::Error> {
923    match msg.to_json() {
924        Ok(json) => {
925            let mut out = json.into_bytes();
926            out.push(b'\n');
927            let mut w = writer.lock().await;
928            w.write_all(&out).await?;
929            w.flush().await?;
930            Ok(())
931        }
932        Err(e) => {
933            error!(%e, "failed to serialize message for client");
934            Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e))
935        }
936    }
937}
938
939/// Parameters for recording a single audit entry.
940///
941/// Groups the 9 fields that were previously passed individually to avoid
942/// long argument lists and make call sites self-documenting.
943struct AuditParams<'a> {
944    request_id: Uuid,
945    identity: &'a str,
946    direction: Direction,
947    method: &'a str,
948    tool: Option<&'a str>,
949    decision: &'a str,
950    rule_id: Option<&'a str>,
951    latency_us: u64,
952}
953
954/// Record a single audit entry in the ledger.
955async fn record_audit(ledger: &Arc<Mutex<Ledger>>, params: AuditParams<'_>) {
956    let entry = AuditEntry {
957        seq: 0, // assigned by ledger
958        timestamp: Utc::now(),
959        request_id: params.request_id,
960        identity: params.identity.to_string(),
961        direction: params.direction,
962        method: params.method.to_string(),
963        tool: params.tool.map(String::from),
964        decision: params.decision.to_string(),
965        rule_id: params.rule_id.map(String::from),
966        latency_us: params.latency_us,
967        prev_hash: String::new(), // assigned by ledger
968        annotations: std::collections::HashMap::new(),
969    };
970
971    if let Err(e) = ledger.lock().await.record(entry) {
972        error!(%e, "failed to record audit entry");
973    }
974}
975
#[cfg(test)]
mod tests {
    use super::*;
    use dome_ledger::MemorySink;

    // -----------------------------------------------------------------------
    // Shared fixtures
    // -----------------------------------------------------------------------

    /// Ledger with a single `MemorySink` attached.
    fn test_ledger() -> Ledger {
        Ledger::new(vec![Box::new(MemorySink::new())])
    }

    /// Ledger with no sinks at all; enough for tests that never look at
    /// recorded output.
    fn empty_ledger() -> Ledger {
        Ledger::new(vec![])
    }

    /// `test_ledger()` wrapped the way `record_audit` expects it.
    fn shared_ledger() -> Arc<Mutex<Ledger>> {
        Arc::new(Mutex::new(test_ledger()))
    }

    /// A `GateConfig` with every security feature switched on and anonymous
    /// access switched off.
    fn fully_enabled_config() -> GateConfig {
        GateConfig {
            enforce_policy: true,
            enable_ward: true,
            enable_schema_pin: true,
            enable_rate_limit: true,
            enable_budget: true,
            allow_anonymous: false,
            block_outbound_injection: true,
        }
    }

    // -----------------------------------------------------------------------
    // GateConfig: defaults, construction, derives
    // -----------------------------------------------------------------------

    #[test]
    fn gate_config_defaults_all_security_disabled() {
        let GateConfig {
            enforce_policy,
            enable_ward,
            enable_schema_pin,
            enable_rate_limit,
            enable_budget,
            allow_anonymous,
            block_outbound_injection,
        } = GateConfig::default();

        assert!(!enforce_policy, "enforce_policy should default to false");
        assert!(!enable_ward, "enable_ward should default to false");
        assert!(
            !enable_schema_pin,
            "enable_schema_pin should default to false"
        );
        assert!(
            !enable_rate_limit,
            "enable_rate_limit should default to false"
        );
        assert!(!enable_budget, "enable_budget should default to false");
        assert!(allow_anonymous, "allow_anonymous should default to true");
        assert!(
            !block_outbound_injection,
            "block_outbound_injection should default to false"
        );
    }

    #[test]
    fn gate_config_with_all_enabled() {
        let config = fully_enabled_config();

        assert!(config.enforce_policy);
        assert!(config.enable_ward);
        assert!(config.enable_schema_pin);
        assert!(config.enable_rate_limit);
        assert!(config.enable_budget);
        assert!(!config.allow_anonymous);
        assert!(config.block_outbound_injection);
    }

    #[test]
    fn gate_config_is_cloneable() {
        // A deliberately mixed configuration so a clone that reset or flipped
        // flags would be caught.
        let template = GateConfig {
            enable_schema_pin: false,
            enable_budget: false,
            ..fully_enabled_config()
        };
        let duplicate = template.clone();

        assert_eq!(duplicate.enforce_policy, template.enforce_policy);
        assert_eq!(duplicate.enable_ward, template.enable_ward);
        assert_eq!(duplicate.enable_schema_pin, template.enable_schema_pin);
        assert_eq!(duplicate.enable_rate_limit, template.enable_rate_limit);
        assert_eq!(duplicate.enable_budget, template.enable_budget);
        assert_eq!(duplicate.allow_anonymous, template.allow_anonymous);
        assert_eq!(
            duplicate.block_outbound_injection,
            template.block_outbound_injection
        );
    }

    #[test]
    fn gate_config_is_debug_printable() {
        let rendered = format!("{:?}", GateConfig::default());

        assert!(rendered.contains("GateConfig"));
        assert!(rendered.contains("enforce_policy"));
        assert!(rendered.contains("enable_ward"));
    }

    // -----------------------------------------------------------------------
    // Gate::transparent
    // -----------------------------------------------------------------------

    #[test]
    fn transparent_gate_has_correct_config_defaults() {
        let gate = Gate::transparent(empty_ledger());
        let config = &gate.config;

        assert!(
            !config.enforce_policy,
            "transparent gate should not enforce policy"
        );
        assert!(
            !config.enable_ward,
            "transparent gate should not enable ward"
        );
        assert!(
            !config.enable_schema_pin,
            "transparent gate should not enable schema pinning"
        );
        assert!(
            !config.enable_rate_limit,
            "transparent gate should not enable rate limiting"
        );
        assert!(
            !config.enable_budget,
            "transparent gate should not enable budget tracking"
        );
        assert!(
            config.allow_anonymous,
            "transparent gate should allow anonymous access"
        );
        assert!(
            !config.block_outbound_injection,
            "transparent gate should not block outbound injection"
        );
    }

    #[test]
    fn transparent_gate_has_no_policy_engine() {
        let gate = Gate::transparent(empty_ledger());

        assert!(
            gate.policy_engine.is_none(),
            "transparent gate should have no policy engine"
        );
    }

    // -----------------------------------------------------------------------
    // Gate::new
    // -----------------------------------------------------------------------

    #[test]
    fn gate_new_with_custom_config_preserves_flags() {
        let gate = Gate::new(
            fully_enabled_config(),
            vec![Box::new(AnonymousAuthenticator)],
            None,
            RateLimiterConfig::default(),
            BudgetTrackerConfig::default(),
            empty_ledger(),
        );
        let config = &gate.config;

        assert!(config.enforce_policy);
        assert!(config.enable_ward);
        assert!(config.enable_schema_pin);
        assert!(config.enable_rate_limit);
        assert!(config.enable_budget);
        assert!(!config.allow_anonymous);
        assert!(config.block_outbound_injection);
    }

    #[test]
    fn gate_new_without_policy_engine_stores_none() {
        let gate = Gate::new(
            GateConfig::default(),
            vec![Box::new(AnonymousAuthenticator)],
            None,
            RateLimiterConfig::default(),
            BudgetTrackerConfig::default(),
            empty_ledger(),
        );

        assert!(gate.policy_engine.is_none());
    }

    // -----------------------------------------------------------------------
    // Gate Debug output
    // -----------------------------------------------------------------------

    #[test]
    fn gate_is_debug_printable() {
        let rendered = format!("{:?}", Gate::transparent(empty_ledger()));

        assert!(rendered.contains("Gate"));
        assert!(rendered.contains("config"));
        assert!(rendered.contains("has_policy_engine"));
    }

    // -----------------------------------------------------------------------
    // record_audit
    //
    // NOTE(review): every test in this group asserts only on the ledger's
    // entry count. The tests whose names mention "fields" do not inspect the
    // recorded entry; consider reading it back from the sink.
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn record_audit_creates_entry_with_correct_fields() {
        let ledger = shared_ledger();

        record_audit(
            &ledger,
            AuditParams {
                request_id: Uuid::new_v4(),
                identity: "test-user",
                direction: Direction::Inbound,
                method: "tools/call",
                tool: Some("read_file"),
                decision: "allow",
                rule_id: Some("rule-1"),
                latency_us: 42,
            },
        )
        .await;

        assert_eq!(
            ledger.lock().await.entry_count(),
            1,
            "should have recorded exactly one entry"
        );
    }

    #[tokio::test]
    async fn record_audit_with_no_tool_and_no_rule() {
        let ledger = shared_ledger();

        record_audit(
            &ledger,
            AuditParams {
                request_id: Uuid::new_v4(),
                identity: "anonymous",
                direction: Direction::Outbound,
                method: "initialize",
                tool: None,
                decision: "forward",
                rule_id: None,
                latency_us: 0,
            },
        )
        .await;

        assert_eq!(ledger.lock().await.entry_count(), 1);
    }

    #[tokio::test]
    async fn record_audit_multiple_entries_increment_count() {
        let ledger = shared_ledger();

        for n in 0..5u64 {
            let who = format!("user-{n}");
            record_audit(
                &ledger,
                AuditParams {
                    request_id: Uuid::new_v4(),
                    identity: &who,
                    direction: Direction::Inbound,
                    method: "tools/call",
                    tool: Some("test_tool"),
                    decision: "allow",
                    rule_id: None,
                    latency_us: n * 10,
                },
            )
            .await;
        }

        assert_eq!(ledger.lock().await.entry_count(), 5);
    }

    #[tokio::test]
    async fn record_audit_deny_decisions_are_recorded() {
        let ledger = shared_ledger();

        // (identity, tool, decision, rule_id, latency_us) for each denial.
        let denials: [(&str, &str, &str, Option<&str>, u64); 3] = [
            (
                "malicious-user",
                "exec_command",
                "deny:policy:no-exec",
                Some("no-exec"),
                150,
            ),
            ("spammer", "spam_tool", "deny:rate_limit", None, 5),
            (
                "attacker",
                "read_file",
                "deny:injection:prompt_injection",
                None,
                200,
            ),
        ];

        for (identity, tool, decision, rule_id, latency_us) in denials {
            record_audit(
                &ledger,
                AuditParams {
                    request_id: Uuid::new_v4(),
                    identity,
                    direction: Direction::Inbound,
                    method: "tools/call",
                    tool: Some(tool),
                    decision,
                    rule_id,
                    latency_us,
                },
            )
            .await;
        }

        assert_eq!(ledger.lock().await.entry_count(), 3);
    }

    #[tokio::test]
    async fn record_audit_outbound_schema_drift() {
        let ledger = shared_ledger();

        record_audit(
            &ledger,
            AuditParams {
                request_id: Uuid::new_v4(),
                identity: "server",
                direction: Direction::Outbound,
                method: "tools/list",
                tool: None,
                decision: "deny:schema_drift",
                rule_id: None,
                latency_us: 75,
            },
        )
        .await;

        assert_eq!(ledger.lock().await.entry_count(), 1);
    }

    // -----------------------------------------------------------------------
    // Module constants
    // -----------------------------------------------------------------------

    #[test]
    fn max_line_size_is_ten_megabytes() {
        assert_eq!(MAX_LINE_SIZE, 10 * 1024 * 1024);
    }

    #[test]
    fn client_read_timeout_is_five_minutes() {
        assert_eq!(CLIENT_READ_TIMEOUT, Duration::from_secs(300));
    }

    // -----------------------------------------------------------------------
    // GateConfig: common override patterns
    // -----------------------------------------------------------------------

    #[test]
    fn gate_config_ward_only_mode() {
        let config = GateConfig {
            enable_ward: true,
            ..GateConfig::default()
        };

        assert!(config.enable_ward);
        assert!(!config.enforce_policy, "policy should remain off");
        assert!(!config.enable_rate_limit, "rate limit should remain off");
        assert!(!config.enable_budget, "budget should remain off");
        assert!(!config.enable_schema_pin, "schema pin should remain off");
        assert!(config.allow_anonymous, "anonymous should remain on");
    }

    #[test]
    fn gate_config_full_security_mode() {
        let config = fully_enabled_config();

        // All protections on, anonymous access off.
        assert!(config.enforce_policy);
        assert!(config.enable_ward);
        assert!(config.enable_schema_pin);
        assert!(config.enable_rate_limit);
        assert!(config.enable_budget);
        assert!(!config.allow_anonymous);
        assert!(config.block_outbound_injection);
    }

    // -----------------------------------------------------------------------
    // Gate construction must not panic
    // -----------------------------------------------------------------------

    #[test]
    fn gate_new_with_multiple_authenticators_does_not_panic() {
        let _gate = Gate::new(
            GateConfig::default(),
            vec![
                Box::new(AnonymousAuthenticator),
                Box::new(AnonymousAuthenticator),
            ],
            None,
            RateLimiterConfig::default(),
            BudgetTrackerConfig::default(),
            empty_ledger(),
        );
    }

    #[test]
    fn gate_new_with_empty_authenticators_does_not_panic() {
        let _gate = Gate::new(
            GateConfig::default(),
            vec![],
            None,
            RateLimiterConfig::default(),
            BudgetTrackerConfig::default(),
            empty_ledger(),
        );
    }

    // -----------------------------------------------------------------------
    // record_audit: unusual parameter values
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn record_audit_entry_fields_populated_correctly() {
        // Backed by a MemorySink so the entry is actually handed to a sink.
        let ledger = shared_ledger();

        record_audit(
            &ledger,
            AuditParams {
                request_id: Uuid::new_v4(),
                identity: "psk:dev-key-1",
                direction: Direction::Inbound,
                method: "tools/call",
                tool: Some("read_file"),
                decision: "deny:policy:no-read",
                rule_id: Some("no-read"),
                latency_us: 999,
            },
        )
        .await;

        // Only the count is checked here.
        assert_eq!(ledger.lock().await.entry_count(), 1);
    }

    #[tokio::test]
    async fn record_audit_handles_empty_identity() {
        let ledger = shared_ledger();

        record_audit(
            &ledger,
            AuditParams {
                request_id: Uuid::new_v4(),
                identity: "",
                direction: Direction::Inbound,
                method: "initialize",
                tool: None,
                decision: "allow",
                rule_id: None,
                latency_us: 0,
            },
        )
        .await;

        assert_eq!(
            ledger.lock().await.entry_count(),
            1,
            "empty identity should not prevent recording"
        );
    }

    #[tokio::test]
    async fn record_audit_handles_long_decision_string() {
        let ledger = shared_ledger();
        let verbose_decision = format!("deny:injection:{}", "pattern_name,".repeat(50));

        record_audit(
            &ledger,
            AuditParams {
                request_id: Uuid::new_v4(),
                identity: "test-user",
                direction: Direction::Inbound,
                method: "tools/call",
                tool: Some("risky_tool"),
                decision: &verbose_decision,
                rule_id: None,
                latency_us: 500,
            },
        )
        .await;

        assert_eq!(
            ledger.lock().await.entry_count(),
            1,
            "long decision strings should be accepted"
        );
    }
}