Skip to main content

dome_gate/
lib.rs

1#![allow(clippy::collapsible_if, clippy::too_many_arguments)]
2
3use std::sync::Arc;
4
5use dome_core::{DomeError, McpMessage};
6use dome_ledger::{AuditEntry, Direction, Ledger};
7use dome_policy::{Identity as PolicyIdentity, SharedPolicyEngine};
8use dome_sentinel::{
9    AnonymousAuthenticator, ApiKeyAuthenticator, Authenticator, IdentityResolver, PskAuthenticator,
10    ResolverConfig,
11};
12use dome_throttle::{BudgetTracker, BudgetTrackerConfig, RateLimiter, RateLimiterConfig};
13use dome_transport::stdio::StdioTransport;
14use dome_ward::schema_pin::DriftSeverity;
15use dome_ward::{InjectionScanner, SchemaPinStore};
16
17use chrono::Utc;
18use serde_json::Value;
19use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
20use tokio::sync::Mutex;
21use tokio::time::{Duration, timeout};
22use tracing::{debug, error, info, warn};
23
/// Largest client frame (one JSON-RPC line, terminator included) the proxy accepts: 10 MiB.
const MAX_LINE_SIZE: usize = 10 << 20;
/// How long the proxy waits for the next client line before ending the session: 5 minutes.
const CLIENT_READ_TIMEOUT: Duration = Duration::from_secs(5 * 60);
26use uuid::Uuid;
27
28// ---------------------------------------------------------------------------
29// Interceptor chain action types
30// ---------------------------------------------------------------------------
31
32/// Result of processing an inbound message through the interceptor chain.
33enum InboundAction {
34    /// Forward the (possibly modified) message to the upstream server.
35    Forward(McpMessage),
36    /// Send this error response back to the client (do NOT forward).
37    Deny(McpMessage),
38}
39
40/// Result of processing an outbound message through the interceptor chain.
41enum OutboundAction {
42    /// Forward the (possibly modified) message to the client.
43    Forward(McpMessage),
44    /// Block the message; send this error response to the client instead.
45    Block(McpMessage),
46}
47
48/// Per-session state for the outbound interceptor chain.
49struct OutboundContext {
50    first_tools_list: bool,
51    last_good_tools_result: Option<Value>,
52}
53
54impl OutboundContext {
55    fn new() -> Self {
56        Self {
57            first_tools_list: true,
58            last_good_tools_result: None,
59        }
60    }
61}
62
63// ---------------------------------------------------------------------------
64// Gate configuration
65// ---------------------------------------------------------------------------
66
/// Feature switches for the Gate proxy.
///
/// Every interceptor stage is off by default and anonymous access is allowed,
/// which yields a transparent pass-through proxy.
#[derive(Debug, Clone)]
pub struct GateConfig {
    /// Evaluate requests against the policy engine, when one is configured.
    pub enforce_policy: bool,
    /// Scan traffic for injection patterns (Ward).
    pub enable_ward: bool,
    /// Pin tool schemas from the first `tools/list` result and check later
    /// ones for drift.
    pub enable_schema_pin: bool,
    /// Apply rate limits per principal (and per tool where one is named).
    pub enable_rate_limit: bool,
    /// Charge each request against the principal's budget.
    pub enable_budget: bool,
    /// Let the identity resolver fall back to anonymous access.
    pub allow_anonymous: bool,
    /// Block, rather than only log, outbound responses that contain injection
    /// patterns. Off by default.
    pub block_outbound_injection: bool,
}

impl Default for GateConfig {
    /// Transparent mode: all enforcement disabled, anonymous access allowed.
    fn default() -> Self {
        let off = false;
        Self {
            allow_anonymous: true,
            enforce_policy: off,
            enable_ward: off,
            enable_schema_pin: off,
            enable_rate_limit: off,
            enable_budget: off,
            block_outbound_injection: off,
        }
    }
}
100
101// ---------------------------------------------------------------------------
102// Gate — public API
103// ---------------------------------------------------------------------------
104
105/// The Gate -- MCPDome's core proxy loop with full interceptor chain.
106///
107/// Interceptor order (inbound, client -> server):
108///   1. Sentinel -- authenticate on `initialize`, resolve identity
109///   2. Throttle -- check rate limits and budget
110///   3. Ward    -- scan for injection patterns in tool arguments
111///   4. Policy  -- evaluate authorization rules
112///   5. Ledger  -- record the decision in the audit chain
113///
114/// Outbound (server -> client):
115///   1. Schema Pin -- verify tools/list responses for drift (block Critical/High)
116///   2. Ward       -- scan outbound tool results for injection patterns
117///   3. Ledger     -- record outbound audit entry
118pub struct Gate {
119    config: GateConfig,
120    resolver: IdentityResolver,
121    policy_engine: Option<SharedPolicyEngine>,
122    rate_limiter: Arc<RateLimiter>,
123    budget_tracker: Arc<BudgetTracker>,
124    injection_scanner: Arc<InjectionScanner>,
125    schema_store: Arc<Mutex<SchemaPinStore>>,
126    ledger: Arc<Mutex<Ledger>>,
127}
128
129impl std::fmt::Debug for Gate {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        f.debug_struct("Gate")
132            .field("config", &self.config)
133            .field("has_policy_engine", &self.policy_engine.is_some())
134            .finish_non_exhaustive()
135    }
136}
137
138impl Gate {
139    /// Create a new Gate with full interceptor chain.
140    ///
141    /// The `policy_engine` parameter accepts an optional `SharedPolicyEngine`
142    /// (`Arc<ArcSwap<PolicyEngine>>`), which enables hot-reload. The gate reads
143    /// the current policy atomically on every request via `load()`, so swaps
144    /// performed by a [`PolicyWatcher`] are immediately visible without restart.
145    pub fn new(
146        config: GateConfig,
147        authenticators: Vec<Box<dyn Authenticator>>,
148        policy_engine: Option<SharedPolicyEngine>,
149        rate_limiter_config: RateLimiterConfig,
150        budget_config: BudgetTrackerConfig,
151        ledger: Ledger,
152    ) -> Self {
153        Self {
154            resolver: IdentityResolver::new(
155                authenticators,
156                ResolverConfig {
157                    allow_anonymous: config.allow_anonymous,
158                },
159            ),
160            policy_engine,
161            rate_limiter: Arc::new(RateLimiter::new(rate_limiter_config)),
162            budget_tracker: Arc::new(BudgetTracker::new(budget_config)),
163            injection_scanner: Arc::new(InjectionScanner::new()),
164            schema_store: Arc::new(Mutex::new(SchemaPinStore::new())),
165            ledger: Arc::new(Mutex::new(ledger)),
166            config,
167        }
168    }
169
170    /// Create a transparent pass-through Gate (no security enforcement).
171    pub fn transparent(ledger: Ledger) -> Self {
172        Self::new(
173            GateConfig::default(),
174            vec![Box::new(AnonymousAuthenticator)],
175            None,
176            RateLimiterConfig::default(),
177            BudgetTrackerConfig::default(),
178            ledger,
179        )
180    }
181
182    /// Convert the Gate into a shared ProxyState for the proxy loop.
183    fn into_proxy_state(self) -> Arc<ProxyState> {
184        Arc::new(ProxyState {
185            identity: Mutex::new(None),
186            resolver: self.resolver,
187            config: self.config,
188            policy: self.policy_engine,
189            rate_limiter: self.rate_limiter,
190            budget: self.budget_tracker,
191            scanner: self.injection_scanner,
192            schema_store: self.schema_store,
193            ledger: self.ledger,
194        })
195    }
196
197    /// Run the proxy over stdio (stdin/stdout for the client, child process for the server).
198    pub async fn run_stdio(self, command: &str, args: &[&str]) -> Result<(), DomeError> {
199        let state = self.into_proxy_state();
200
201        let transport = StdioTransport::spawn(command, args).await?;
202        let (mut server_reader, mut server_writer, mut child) = transport.split();
203
204        let client_stdin = tokio::io::stdin();
205        let client_stdout = tokio::io::stdout();
206        let mut client_reader = BufReader::new(client_stdin);
207        let client_writer: Arc<Mutex<tokio::io::Stdout>> = Arc::new(Mutex::new(client_stdout));
208
209        info!("MCPDome proxy active -- interceptor chain armed");
210
211        let inbound_state = Arc::clone(&state);
212        let inbound_writer = Arc::clone(&client_writer);
213
214        // Client -> Server task (inbound interceptor chain)
215        let mut client_to_server = tokio::spawn(async move {
216            let mut line = String::new();
217            loop {
218                line.clear();
219                let read_result =
220                    timeout(CLIENT_READ_TIMEOUT, client_reader.read_line(&mut line)).await;
221                let read_result = match read_result {
222                    Ok(inner) => inner,
223                    Err(_) => {
224                        warn!("client read timed out");
225                        break;
226                    }
227                };
228                match read_result {
229                    Ok(0) => {
230                        info!("client closed stdin -- shutting down");
231                        break;
232                    }
233                    Ok(_) => {
234                        if line.len() > MAX_LINE_SIZE {
235                            warn!(
236                                size = line.len(),
237                                max = MAX_LINE_SIZE,
238                                "client message exceeds size limit, dropping"
239                            );
240                            let err_resp = McpMessage::error_response(
241                                Value::Null,
242                                -32600,
243                                "Message too large",
244                            );
245                            if let Err(we) = write_to_client(&inbound_writer, &err_resp).await {
246                                error!(%we, "failed to send size error to client");
247                                break;
248                            }
249                            continue;
250                        }
251                        let trimmed = line.trim();
252                        if trimmed.is_empty() {
253                            continue;
254                        }
255
256                        match McpMessage::parse(trimmed) {
257                            Ok(msg) => match inbound_state.process_inbound(msg).await {
258                                InboundAction::Forward(msg) => {
259                                    if let Err(e) = server_writer.send(&msg).await {
260                                        error!(%e, "failed to forward to server");
261                                        break;
262                                    }
263                                }
264                                InboundAction::Deny(err_resp) => {
265                                    if let Err(we) =
266                                        write_to_client(&inbound_writer, &err_resp).await
267                                    {
268                                        error!(%we, "failed to send error to client");
269                                        break;
270                                    }
271                                }
272                            },
273                            Err(e) => {
274                                warn!(%e, raw = trimmed, "invalid JSON from client, dropping");
275                                let err_resp = McpMessage::error_response(
276                                    Value::Null,
277                                    -32700,
278                                    "Parse error: invalid JSON",
279                                );
280                                if let Err(we) = write_to_client(&inbound_writer, &err_resp).await {
281                                    error!(%we, "failed to send parse error to client");
282                                    break;
283                                }
284                            }
285                        }
286                    }
287                    Err(e) => {
288                        error!(%e, "error reading from client");
289                        break;
290                    }
291                }
292            }
293        });
294
295        let outbound_state = Arc::clone(&state);
296        let outbound_writer = Arc::clone(&client_writer);
297
298        // Server -> Client task (outbound interceptor chain)
299        let mut server_to_client = tokio::spawn(async move {
300            let mut ctx = OutboundContext::new();
301            loop {
302                match server_reader.recv().await {
303                    Ok(msg) => match outbound_state.process_outbound(msg, &mut ctx).await {
304                        OutboundAction::Forward(msg) => {
305                            if let Err(e) = write_to_client(&outbound_writer, &msg).await {
306                                error!(%e, "failed to write to client");
307                                break;
308                            }
309                        }
310                        OutboundAction::Block(err_resp) => {
311                            if let Err(e) = write_to_client(&outbound_writer, &err_resp).await {
312                                error!(%e, "failed to send outbound error to client");
313                                break;
314                            }
315                        }
316                    },
317                    Err(DomeError::Transport(ref e))
318                        if e.kind() == std::io::ErrorKind::UnexpectedEof =>
319                    {
320                        info!("server closed stdout -- shutting down");
321                        break;
322                    }
323                    Err(e) => {
324                        error!(%e, "error reading from server");
325                        break;
326                    }
327                }
328            }
329        });
330
331        // Wait for either side to finish, then abort the other.
332        tokio::select! {
333            r = &mut client_to_server => {
334                if let Err(e) = r {
335                    error!(%e, "client->server task panicked");
336                }
337                server_to_client.abort();
338            }
339            r = &mut server_to_client => {
340                if let Err(e) = r {
341                    error!(%e, "server->client task panicked");
342                }
343                client_to_server.abort();
344            }
345        }
346
347        // Flush audit log.
348        state.ledger.lock().await.flush();
349
350        // Graceful child termination: the child's stdin pipe was closed when
351        // the writer task ended/was aborted, so the child should see EOF and
352        // exit. Wait up to 5 seconds, then force-kill as a last resort.
353        match tokio::time::timeout(Duration::from_secs(5), child.wait()).await {
354            Ok(Ok(status)) => info!(%status, "upstream server exited"),
355            Ok(Err(e)) => warn!(%e, "error waiting for upstream server"),
356            Err(_) => {
357                warn!("upstream server did not exit within 5s, forcing termination");
358                let _ = child.kill().await;
359            }
360        }
361
362        info!("MCPDome proxy shut down");
363        Ok(())
364    }
365
366    /// Run the proxy over HTTP+SSE (HTTP server for the client, child process
367    /// for the upstream MCP server).
368    #[cfg(feature = "http")]
369    pub async fn run_http(
370        self,
371        command: &str,
372        args: &[&str],
373        http_config: dome_transport::http::HttpTransportConfig,
374    ) -> Result<(), DomeError> {
375        let state = self.into_proxy_state();
376
377        let transport = StdioTransport::spawn(command, args).await?;
378        let (mut server_reader, mut server_writer, mut child) = transport.split();
379
380        let http = dome_transport::http::HttpTransport::start(http_config).await?;
381        let (mut http_reader, http_writer, http_handle) = http.split();
382        let http_writer = Arc::new(http_writer);
383
384        info!("MCPDome HTTP+SSE proxy active -- interceptor chain armed");
385
386        let inbound_state = Arc::clone(&state);
387        let inbound_http_writer = Arc::clone(&http_writer);
388
389        // Client -> Server task (inbound interceptor chain via HTTP)
390        let mut client_to_server = tokio::spawn(async move {
391            loop {
392                match http_reader.recv().await {
393                    Ok(msg) => match inbound_state.process_inbound(msg).await {
394                        InboundAction::Forward(msg) => {
395                            if let Err(e) = server_writer.send(&msg).await {
396                                error!(%e, "failed to forward to server");
397                                break;
398                            }
399                        }
400                        InboundAction::Deny(err_resp) => {
401                            if let Err(e) = inbound_http_writer.send(&err_resp).await {
402                                warn!(%e, "failed to send error to HTTP client");
403                            }
404                        }
405                    },
406                    Err(e) => {
407                        info!(%e, "HTTP client transport closed");
408                        break;
409                    }
410                }
411            }
412        });
413
414        let outbound_state = Arc::clone(&state);
415        let outbound_http_writer = Arc::clone(&http_writer);
416
417        // Server -> Client task (outbound interceptor chain via HTTP)
418        let mut server_to_client = tokio::spawn(async move {
419            let mut ctx = OutboundContext::new();
420            loop {
421                match server_reader.recv().await {
422                    Ok(msg) => match outbound_state.process_outbound(msg, &mut ctx).await {
423                        OutboundAction::Forward(msg) => {
424                            if let Err(e) = outbound_http_writer.send(&msg).await {
425                                warn!(%e, "failed to send to HTTP client");
426                                break;
427                            }
428                        }
429                        OutboundAction::Block(err_resp) => {
430                            if let Err(e) = outbound_http_writer.send(&err_resp).await {
431                                warn!(%e, "failed to send outbound error to HTTP client");
432                                break;
433                            }
434                        }
435                    },
436                    Err(DomeError::Transport(ref e))
437                        if e.kind() == std::io::ErrorKind::UnexpectedEof =>
438                    {
439                        info!("server closed stdout -- shutting down");
440                        break;
441                    }
442                    Err(e) => {
443                        error!(%e, "error reading from server");
444                        break;
445                    }
446                }
447            }
448        });
449
450        // Wait for either side to finish, then abort the other.
451        tokio::select! {
452            r = &mut client_to_server => {
453                if let Err(e) = r {
454                    error!(%e, "client->server task panicked");
455                }
456                server_to_client.abort();
457            }
458            r = &mut server_to_client => {
459                if let Err(e) = r {
460                    error!(%e, "server->client task panicked");
461                }
462                client_to_server.abort();
463            }
464        }
465
466        // Flush audit log.
467        state.ledger.lock().await.flush();
468
469        // Shut down HTTP server.
470        http_handle.shutdown().await;
471
472        // Graceful child termination.
473        match tokio::time::timeout(Duration::from_secs(5), child.wait()).await {
474            Ok(Ok(status)) => info!(%status, "upstream server exited"),
475            Ok(Err(e)) => warn!(%e, "error waiting for upstream server"),
476            Err(_) => {
477                warn!("upstream server did not exit within 5s, forcing termination");
478                let _ = child.kill().await;
479            }
480        }
481
482        info!("MCPDome HTTP+SSE proxy shut down");
483        Ok(())
484    }
485}
486
487// ---------------------------------------------------------------------------
488// ProxyState — shared interceptor chain state
489// ---------------------------------------------------------------------------
490
491/// Internal shared state for the proxy loop. Created from a [`Gate`] at the
492/// start of a proxy session. Holds all interceptor chain components behind
493/// appropriate synchronization primitives so both the inbound (client → server)
494/// and outbound (server → client) tasks can share it via `Arc`.
495struct ProxyState {
496    identity: Mutex<Option<dome_sentinel::Identity>>,
497    resolver: IdentityResolver,
498    config: GateConfig,
499    policy: Option<SharedPolicyEngine>,
500    rate_limiter: Arc<RateLimiter>,
501    budget: Arc<BudgetTracker>,
502    scanner: Arc<InjectionScanner>,
503    schema_store: Arc<Mutex<SchemaPinStore>>,
504    ledger: Arc<Mutex<Ledger>>,
505}
506
507impl ProxyState {
508    /// Process an inbound (client → server) message through the full
509    /// interceptor chain: Sentinel → Throttle → Ward → Policy → Ledger.
510    async fn process_inbound(&self, msg: McpMessage) -> InboundAction {
511        let start = std::time::Instant::now();
512        let method = msg.method.as_deref().unwrap_or("-").to_string();
513        let tool = msg.tool_name().map(String::from);
514        let request_id = Uuid::new_v4();
515
516        debug!(
517            method = method.as_str(),
518            id = ?msg.id,
519            tool = tool.as_deref().unwrap_or("-"),
520            "client -> server"
521        );
522
523        // ── 1. Sentinel: Authenticate on initialize ──
524        let mut msg = msg;
525        if method == "initialize" {
526            match self.resolver.resolve(&msg).await {
527                Ok(id) => {
528                    info!(
529                        principal = %id.principal,
530                        method = %id.auth_method,
531                        "identity resolved"
532                    );
533                    *self.identity.lock().await = Some(id);
534
535                    // Strip all credential fields before forwarding.
536                    msg = PskAuthenticator::strip_psk(&msg);
537                    msg = ApiKeyAuthenticator::strip_api_key(&msg);
538                }
539                Err(e) => {
540                    warn!(%e, "authentication failed");
541                    let err_id = msg.id.clone().unwrap_or(Value::Null);
542                    return InboundAction::Deny(McpMessage::error_response(
543                        err_id,
544                        -32600,
545                        "Authentication failed",
546                    ));
547                }
548            }
549        }
550
551        // Block all non-initialize requests before the session has been
552        // initialized (identity resolved).
553        if method != "initialize" {
554            let identity_lock = self.identity.lock().await;
555            if identity_lock.is_none() {
556                drop(identity_lock);
557                warn!(method = %method, "request before initialize");
558                let err_id = msg.id.clone().unwrap_or(Value::Null);
559                return InboundAction::Deny(McpMessage::error_response(
560                    err_id,
561                    -32600,
562                    "Session not initialized",
563                ));
564            }
565            drop(identity_lock);
566        }
567
568        let identity_lock = self.identity.lock().await;
569        let principal = identity_lock
570            .as_ref()
571            .map(|i| i.principal.clone())
572            .unwrap_or_else(|| "anonymous".to_string());
573        let labels = identity_lock
574            .as_ref()
575            .map(|i| i.labels.clone())
576            .unwrap_or_default();
577        drop(identity_lock);
578
579        // Extract the method-specific resource name for policy evaluation.
580        let resource_name = msg.method_resource_name().unwrap_or("-").to_string();
581        let tool_name = resource_name.as_str();
582
583        let args = msg
584            .params
585            .as_ref()
586            .and_then(|p| p.get("arguments"))
587            .cloned()
588            .unwrap_or(Value::Null);
589
590        // ── 2. Throttle: Rate limit check ──
591        if self.config.enable_rate_limit {
592            let rl_tool = if tool_name != "-" {
593                Some(tool_name)
594            } else {
595                None
596            };
597            if let Err(e) = self.rate_limiter.check_rate_limit(&principal, rl_tool) {
598                warn!(%e, principal = %principal, method = %method, "rate limited");
599                record_audit(
600                    &self.ledger,
601                    request_id,
602                    &principal,
603                    Direction::Inbound,
604                    &method,
605                    tool.as_deref(),
606                    "deny:rate_limit",
607                    None,
608                    start.elapsed().as_micros() as u64,
609                )
610                .await;
611                let err_id = msg.id.clone().unwrap_or(Value::Null);
612                return InboundAction::Deny(McpMessage::error_response(
613                    err_id,
614                    -32000,
615                    "Rate limit exceeded",
616                ));
617            }
618        }
619
620        // ── 2b. Throttle: Budget check ──
621        if self.config.enable_budget {
622            if let Err(e) = self.budget.try_spend(&principal, 1.0) {
623                warn!(%e, principal = %principal, "budget exhausted");
624                record_audit(
625                    &self.ledger,
626                    request_id,
627                    &principal,
628                    Direction::Inbound,
629                    &method,
630                    tool.as_deref(),
631                    "deny:budget",
632                    None,
633                    start.elapsed().as_micros() as u64,
634                )
635                .await;
636                let err_id = msg.id.clone().unwrap_or(Value::Null);
637                return InboundAction::Deny(McpMessage::error_response(
638                    err_id,
639                    -32000,
640                    "Budget exhausted",
641                ));
642            }
643        }
644
645        // ── 3. Ward: Injection scanning ──
646        // Ward runs BEFORE policy so injection detection is applied regardless
647        // of authorization level.
648        if self.config.enable_ward {
649            let scan_text = if method == "tools/call" {
650                serde_json::to_string(&args).unwrap_or_default()
651            } else if let Some(ref params) = msg.params {
652                serde_json::to_string(params).unwrap_or_default()
653            } else {
654                String::new()
655            };
656
657            if !scan_text.is_empty() {
658                let matches = self.scanner.scan_text(&scan_text);
659                if !matches.is_empty() {
660                    let pattern_names: Vec<&str> =
661                        matches.iter().map(|m| m.pattern_name.as_str()).collect();
662                    warn!(
663                        patterns = ?pattern_names,
664                        method = %method,
665                        tool = tool_name,
666                        principal = %principal,
667                        "injection detected"
668                    );
669                    record_audit(
670                        &self.ledger,
671                        request_id,
672                        &principal,
673                        Direction::Inbound,
674                        &method,
675                        tool.as_deref(),
676                        &format!("deny:injection:{}", pattern_names.join(",")),
677                        None,
678                        start.elapsed().as_micros() as u64,
679                    )
680                    .await;
681                    let err_id = msg.id.clone().unwrap_or(Value::Null);
682                    return InboundAction::Deny(McpMessage::error_response(
683                        err_id,
684                        -32003,
685                        "Request blocked: injection pattern detected",
686                    ));
687                }
688            }
689        }
690
691        // ── 4. Policy: Authorization ──
692        if self.config.enforce_policy {
693            if let Some(ref shared_engine) = self.policy {
694                // Load the current policy atomically. This is lock-free and
695                // picks up hot-reloaded changes immediately.
696                let engine = shared_engine.load();
697
698                let policy_resource = if method == "tools/call" {
699                    tool_name
700                } else {
701                    method.as_str()
702                };
703                let policy_id = PolicyIdentity::new(principal.clone(), labels.iter().cloned());
704                let decision = engine.evaluate(&policy_id, policy_resource, &args);
705
706                if !decision.is_allowed() {
707                    warn!(
708                        rule_id = %decision.rule_id,
709                        method = %method,
710                        resource = policy_resource,
711                        principal = %principal,
712                        "policy denied"
713                    );
714                    record_audit(
715                        &self.ledger,
716                        request_id,
717                        &principal,
718                        Direction::Inbound,
719                        &method,
720                        tool.as_deref(),
721                        &format!("deny:policy:{}", decision.rule_id),
722                        Some(&decision.rule_id),
723                        start.elapsed().as_micros() as u64,
724                    )
725                    .await;
726                    let err_id = msg.id.clone().unwrap_or(Value::Null);
727                    return InboundAction::Deny(McpMessage::error_response(
728                        err_id,
729                        -32003,
730                        format!("Denied by policy: {}", decision.rule_id),
731                    ));
732                }
733            }
734        }
735
736        // ── 5. Ledger: Record allowed request ──
737        record_audit(
738            &self.ledger,
739            request_id,
740            &principal,
741            Direction::Inbound,
742            &method,
743            tool.as_deref(),
744            "allow",
745            None,
746            start.elapsed().as_micros() as u64,
747        )
748        .await;
749
750        InboundAction::Forward(msg)
751    }
752
753    /// Process an outbound (server → client) message through the outbound
754    /// interceptor chain: Schema Pin → Ward → Ledger.
755    async fn process_outbound(&self, msg: McpMessage, ctx: &mut OutboundContext) -> OutboundAction {
756        let start = std::time::Instant::now();
757        let method = msg.method.as_deref().unwrap_or("-").to_string();
758        let outbound_request_id = Uuid::new_v4();
759
760        debug!(
761            method = method.as_str(),
762            id = ?msg.id,
763            "server -> client"
764        );
765
766        let mut forward_msg = msg;
767
768        // ── Schema Pin: Verify tools/list responses ──
769        if self.config.enable_schema_pin {
770            if let Some(result) = &forward_msg.result {
771                if result.get("tools").is_some() {
772                    let mut store = self.schema_store.lock().await;
773                    if ctx.first_tools_list {
774                        store.pin_tools(result);
775                        info!(pinned = store.len(), "schema pins established");
776                        ctx.last_good_tools_result = Some(result.clone());
777                        ctx.first_tools_list = false;
778                    } else {
779                        let drifts = store.verify_tools(result);
780                        if !drifts.is_empty() {
781                            let mut has_critical_or_high = false;
782                            for drift in &drifts {
783                                warn!(
784                                    tool = %drift.tool_name,
785                                    drift_type = ?drift.drift_type,
786                                    severity = ?drift.severity,
787                                    "schema drift detected"
788                                );
789                                if matches!(
790                                    drift.severity,
791                                    DriftSeverity::Critical | DriftSeverity::High
792                                ) {
793                                    has_critical_or_high = true;
794                                }
795                            }
796
797                            if has_critical_or_high {
798                                warn!(
799                                    "critical/high schema drift detected -- blocking drifted tools/list"
800                                );
801                                record_audit(
802                                    &self.ledger,
803                                    outbound_request_id,
804                                    "server",
805                                    Direction::Outbound,
806                                    "tools/list",
807                                    None,
808                                    "deny:schema_drift",
809                                    None,
810                                    start.elapsed().as_micros() as u64,
811                                )
812                                .await;
813
814                                if let Some(ref good_result) = ctx.last_good_tools_result {
815                                    forward_msg.result = Some(good_result.clone());
816                                } else {
817                                    let err_id = forward_msg.id.clone().unwrap_or(Value::Null);
818                                    return OutboundAction::Block(McpMessage::error_response(
819                                        err_id,
820                                        -32003,
821                                        "Schema drift detected: tool definitions have been tampered with",
822                                    ));
823                                }
824                            }
825                        }
826                    }
827                }
828            }
829        }
830
831        // ── Outbound response scanning ──
832        if self.config.enable_ward {
833            if let Some(ref result) = forward_msg.result {
834                let scan_target = if let Some(content) = result.get("content") {
835                    serde_json::to_string(content).unwrap_or_default()
836                } else {
837                    serde_json::to_string(result).unwrap_or_default()
838                };
839
840                if !scan_target.is_empty() {
841                    let matches = self.scanner.scan_text(&scan_target);
842                    if !matches.is_empty() {
843                        let pattern_names: Vec<&str> =
844                            matches.iter().map(|m| m.pattern_name.as_str()).collect();
845                        let decision = if self.config.block_outbound_injection {
846                            "deny:outbound_injection"
847                        } else {
848                            "warn:outbound_injection"
849                        };
850                        warn!(
851                            patterns = ?pattern_names,
852                            direction = "outbound",
853                            blocked = self.config.block_outbound_injection,
854                            "injection detected in server response"
855                        );
856                        record_audit(
857                            &self.ledger,
858                            outbound_request_id,
859                            "server",
860                            Direction::Outbound,
861                            &method,
862                            None,
863                            &format!("{}:{}", decision, pattern_names.join(",")),
864                            None,
865                            start.elapsed().as_micros() as u64,
866                        )
867                        .await;
868
869                        if self.config.block_outbound_injection {
870                            let err_id = forward_msg.id.clone().unwrap_or(Value::Null);
871                            return OutboundAction::Block(McpMessage::error_response(
872                                err_id,
873                                -32005,
874                                "Response blocked: injection pattern detected in server output",
875                            ));
876                        }
877                    }
878                }
879            }
880        }
881
882        // Record outbound audit entry.
883        record_audit(
884            &self.ledger,
885            outbound_request_id,
886            "server",
887            Direction::Outbound,
888            &method,
889            None,
890            "forward",
891            None,
892            start.elapsed().as_micros() as u64,
893        )
894        .await;
895
896        OutboundAction::Forward(forward_msg)
897    }
898}
899
900// ---------------------------------------------------------------------------
901// Helpers
902// ---------------------------------------------------------------------------
903
904/// Write a McpMessage to the client's stdout, with newline and flush.
905async fn write_to_client(
906    writer: &Arc<Mutex<tokio::io::Stdout>>,
907    msg: &McpMessage,
908) -> Result<(), std::io::Error> {
909    match msg.to_json() {
910        Ok(json) => {
911            let mut out = json.into_bytes();
912            out.push(b'\n');
913            let mut w = writer.lock().await;
914            w.write_all(&out).await?;
915            w.flush().await?;
916            Ok(())
917        }
918        Err(e) => {
919            error!(%e, "failed to serialize message for client");
920            Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e))
921        }
922    }
923}
924
925/// Helper to record an audit entry.
926async fn record_audit(
927    ledger: &Arc<Mutex<Ledger>>,
928    request_id: Uuid,
929    identity: &str,
930    direction: Direction,
931    method: &str,
932    tool: Option<&str>,
933    decision: &str,
934    rule_id: Option<&str>,
935    latency_us: u64,
936) {
937    let entry = AuditEntry {
938        seq: 0, // set by ledger
939        timestamp: Utc::now(),
940        request_id,
941        identity: identity.to_string(),
942        direction,
943        method: method.to_string(),
944        tool: tool.map(String::from),
945        decision: decision.to_string(),
946        rule_id: rule_id.map(String::from),
947        latency_us,
948        prev_hash: String::new(), // set by ledger
949        annotations: std::collections::HashMap::new(),
950    };
951
952    if let Err(e) = ledger.lock().await.record(entry) {
953        error!(%e, "failed to record audit entry");
954    }
955}
956
#[cfg(test)]
mod tests {
    use super::*;
    use dome_ledger::MemorySink;

    // -----------------------------------------------------------------------
    // Helpers
    // -----------------------------------------------------------------------

    /// Create a Ledger backed by a MemorySink.
    ///
    /// No handle to the sink is kept, so tests using this ledger can observe
    /// only `entry_count()`, not the recorded entries themselves.
    fn test_ledger() -> Ledger {
        Ledger::new(vec![Box::new(MemorySink::new())])
    }

    /// Create a Ledger with no sinks (sufficient when we only care about
    /// entry count, not sink contents).
    fn empty_ledger() -> Ledger {
        Ledger::new(vec![])
    }

    /// `test_ledger()` wrapped in the `Arc<Mutex<_>>` that `record_audit` takes.
    fn shared_test_ledger() -> Arc<Mutex<Ledger>> {
        Arc::new(Mutex::new(test_ledger()))
    }

    /// Every security feature on, anonymous access off.
    ///
    /// Written out field-by-field (no `..GateConfig::default()`) so that adding
    /// a field to `GateConfig` fails to compile here and prompts an update of
    /// the tests that rely on this helper.
    fn full_security_config() -> GateConfig {
        GateConfig {
            enforce_policy: true,
            enable_ward: true,
            enable_schema_pin: true,
            enable_rate_limit: true,
            enable_budget: true,
            allow_anonymous: false,
            block_outbound_injection: true,
        }
    }

    /// Assert that `config` matches `full_security_config()` flag-for-flag.
    fn assert_full_security(config: &GateConfig) {
        assert!(config.enforce_policy, "enforce_policy should be on");
        assert!(config.enable_ward, "enable_ward should be on");
        assert!(config.enable_schema_pin, "enable_schema_pin should be on");
        assert!(config.enable_rate_limit, "enable_rate_limit should be on");
        assert!(config.enable_budget, "enable_budget should be on");
        assert!(!config.allow_anonymous, "allow_anonymous should be off");
        assert!(
            config.block_outbound_injection,
            "block_outbound_injection should be on"
        );
    }

    // -----------------------------------------------------------------------
    // GateConfig defaults
    // -----------------------------------------------------------------------

    #[test]
    fn gate_config_defaults_all_security_disabled() {
        let config = GateConfig::default();

        assert!(
            !config.enforce_policy,
            "enforce_policy should default to false"
        );
        assert!(!config.enable_ward, "enable_ward should default to false");
        assert!(
            !config.enable_schema_pin,
            "enable_schema_pin should default to false"
        );
        assert!(
            !config.enable_rate_limit,
            "enable_rate_limit should default to false"
        );
        assert!(
            !config.enable_budget,
            "enable_budget should default to false"
        );
        assert!(
            config.allow_anonymous,
            "allow_anonymous should default to true"
        );
        assert!(
            !config.block_outbound_injection,
            "block_outbound_injection should default to false"
        );
    }

    #[test]
    fn gate_config_with_all_enabled() {
        let config = full_security_config();
        assert_full_security(&config);
    }

    #[test]
    fn gate_config_is_cloneable() {
        let original = GateConfig {
            enable_schema_pin: false,
            enable_budget: false,
            ..full_security_config()
        };
        let cloned = original.clone();

        assert_eq!(cloned.enforce_policy, original.enforce_policy);
        assert_eq!(cloned.enable_ward, original.enable_ward);
        assert_eq!(cloned.enable_schema_pin, original.enable_schema_pin);
        assert_eq!(cloned.enable_rate_limit, original.enable_rate_limit);
        assert_eq!(cloned.enable_budget, original.enable_budget);
        assert_eq!(cloned.allow_anonymous, original.allow_anonymous);
        assert_eq!(
            cloned.block_outbound_injection,
            original.block_outbound_injection
        );
    }

    #[test]
    fn gate_config_is_debug_printable() {
        let config = GateConfig::default();
        let debug_output = format!("{:?}", config);

        assert!(debug_output.contains("GateConfig"));
        assert!(debug_output.contains("enforce_policy"));
        assert!(debug_output.contains("enable_ward"));
    }

    // -----------------------------------------------------------------------
    // Gate::transparent
    // -----------------------------------------------------------------------

    #[test]
    fn transparent_gate_has_correct_config_defaults() {
        let gate = Gate::transparent(empty_ledger());

        assert!(
            !gate.config.enforce_policy,
            "transparent gate should not enforce policy"
        );
        assert!(
            !gate.config.enable_ward,
            "transparent gate should not enable ward"
        );
        assert!(
            !gate.config.enable_schema_pin,
            "transparent gate should not enable schema pinning"
        );
        assert!(
            !gate.config.enable_rate_limit,
            "transparent gate should not enable rate limiting"
        );
        assert!(
            !gate.config.enable_budget,
            "transparent gate should not enable budget tracking"
        );
        assert!(
            gate.config.allow_anonymous,
            "transparent gate should allow anonymous access"
        );
        assert!(
            !gate.config.block_outbound_injection,
            "transparent gate should not block outbound injection"
        );
    }

    #[test]
    fn transparent_gate_has_no_policy_engine() {
        let gate = Gate::transparent(empty_ledger());
        assert!(
            gate.policy_engine.is_none(),
            "transparent gate should have no policy engine"
        );
    }

    // -----------------------------------------------------------------------
    // Gate::new with custom config
    // -----------------------------------------------------------------------

    #[test]
    fn gate_new_with_custom_config_preserves_flags() {
        let gate = Gate::new(
            full_security_config(),
            vec![Box::new(AnonymousAuthenticator)],
            None,
            RateLimiterConfig::default(),
            BudgetTrackerConfig::default(),
            empty_ledger(),
        );

        assert_full_security(&gate.config);
    }

    #[test]
    fn gate_new_without_policy_engine_stores_none() {
        let gate = Gate::new(
            GateConfig::default(),
            vec![Box::new(AnonymousAuthenticator)],
            None,
            RateLimiterConfig::default(),
            BudgetTrackerConfig::default(),
            empty_ledger(),
        );

        assert!(gate.policy_engine.is_none());
    }

    // -----------------------------------------------------------------------
    // Gate Debug impl
    // -----------------------------------------------------------------------

    #[test]
    fn gate_is_debug_printable() {
        let gate = Gate::transparent(empty_ledger());
        let debug_output = format!("{:?}", gate);

        assert!(debug_output.contains("Gate"));
        assert!(debug_output.contains("config"));
        assert!(debug_output.contains("has_policy_engine"));
    }

    // -----------------------------------------------------------------------
    // record_audit
    //
    // The ledger exposes only `entry_count()` here, so these tests check that
    // each call produces exactly one entry; they do not inspect entry fields.
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn record_audit_with_tool_and_rule_records_one_entry() {
        let ledger = shared_test_ledger();

        record_audit(
            &ledger,
            Uuid::new_v4(),
            "test-user",
            Direction::Inbound,
            "tools/call",
            Some("read_file"),
            "allow",
            Some("rule-1"),
            42,
        )
        .await;

        let ledger_guard = ledger.lock().await;
        assert_eq!(
            ledger_guard.entry_count(),
            1,
            "should have recorded exactly one entry"
        );
    }

    #[tokio::test]
    async fn record_audit_with_no_tool_and_no_rule() {
        let ledger = shared_test_ledger();

        record_audit(
            &ledger,
            Uuid::new_v4(),
            "anonymous",
            Direction::Outbound,
            "initialize",
            None,
            "forward",
            None,
            0,
        )
        .await;

        let ledger_guard = ledger.lock().await;
        assert_eq!(ledger_guard.entry_count(), 1);
    }

    #[tokio::test]
    async fn record_audit_multiple_entries_increment_count() {
        let ledger = shared_test_ledger();

        for i in 0..5 {
            record_audit(
                &ledger,
                Uuid::new_v4(),
                &format!("user-{i}"),
                Direction::Inbound,
                "tools/call",
                Some("test_tool"),
                "allow",
                None,
                i * 10,
            )
            .await;
        }

        let ledger_guard = ledger.lock().await;
        assert_eq!(ledger_guard.entry_count(), 5);
    }

    #[tokio::test]
    async fn record_audit_deny_decisions_are_recorded() {
        let ledger = shared_test_ledger();

        record_audit(
            &ledger,
            Uuid::new_v4(),
            "malicious-user",
            Direction::Inbound,
            "tools/call",
            Some("exec_command"),
            "deny:policy:no-exec",
            Some("no-exec"),
            150,
        )
        .await;

        record_audit(
            &ledger,
            Uuid::new_v4(),
            "spammer",
            Direction::Inbound,
            "tools/call",
            Some("spam_tool"),
            "deny:rate_limit",
            None,
            5,
        )
        .await;

        record_audit(
            &ledger,
            Uuid::new_v4(),
            "attacker",
            Direction::Inbound,
            "tools/call",
            Some("read_file"),
            "deny:injection:prompt_injection",
            None,
            200,
        )
        .await;

        let ledger_guard = ledger.lock().await;
        assert_eq!(ledger_guard.entry_count(), 3);
    }

    #[tokio::test]
    async fn record_audit_outbound_schema_drift() {
        let ledger = shared_test_ledger();

        record_audit(
            &ledger,
            Uuid::new_v4(),
            "server",
            Direction::Outbound,
            "tools/list",
            None,
            "deny:schema_drift",
            None,
            75,
        )
        .await;

        let ledger_guard = ledger.lock().await;
        assert_eq!(ledger_guard.entry_count(), 1);
    }

    #[tokio::test]
    async fn record_audit_psk_identity_deny_records_one_entry() {
        let ledger = shared_test_ledger();

        record_audit(
            &ledger,
            Uuid::new_v4(),
            "psk:dev-key-1",
            Direction::Inbound,
            "tools/call",
            Some("read_file"),
            "deny:policy:no-read",
            Some("no-read"),
            999,
        )
        .await;

        let guard = ledger.lock().await;
        assert_eq!(guard.entry_count(), 1);
    }

    #[tokio::test]
    async fn record_audit_handles_empty_identity() {
        let ledger = shared_test_ledger();

        record_audit(
            &ledger,
            Uuid::new_v4(),
            "",
            Direction::Inbound,
            "initialize",
            None,
            "allow",
            None,
            0,
        )
        .await;

        let guard = ledger.lock().await;
        assert_eq!(
            guard.entry_count(),
            1,
            "empty identity should not prevent recording"
        );
    }

    #[tokio::test]
    async fn record_audit_handles_long_decision_string() {
        let ledger = shared_test_ledger();
        let long_decision = format!("deny:injection:{}", "pattern_name,".repeat(50));

        record_audit(
            &ledger,
            Uuid::new_v4(),
            "test-user",
            Direction::Inbound,
            "tools/call",
            Some("risky_tool"),
            &long_decision,
            None,
            500,
        )
        .await;

        let guard = ledger.lock().await;
        assert_eq!(
            guard.entry_count(),
            1,
            "long decision strings should be accepted"
        );
    }

    // -----------------------------------------------------------------------
    // Constants
    // -----------------------------------------------------------------------

    #[test]
    fn max_line_size_is_ten_megabytes() {
        assert_eq!(MAX_LINE_SIZE, 10 * 1024 * 1024);
    }

    #[test]
    fn client_read_timeout_is_five_minutes() {
        assert_eq!(CLIENT_READ_TIMEOUT, Duration::from_secs(300));
    }

    // -----------------------------------------------------------------------
    // GateConfig partial overrides (common patterns)
    // -----------------------------------------------------------------------

    #[test]
    fn gate_config_ward_only_mode() {
        let config = GateConfig {
            enable_ward: true,
            ..GateConfig::default()
        };

        assert!(config.enable_ward);
        assert!(!config.enforce_policy, "policy should remain off");
        assert!(!config.enable_rate_limit, "rate limit should remain off");
        assert!(!config.enable_budget, "budget should remain off");
        assert!(!config.enable_schema_pin, "schema pin should remain off");
        assert!(config.allow_anonymous, "anonymous should remain on");
    }

    #[test]
    fn gate_config_full_security_mode() {
        // Full security mode must flip every single default.
        let full = full_security_config();
        let default = GateConfig::default();

        assert_full_security(&full);
        assert_ne!(full.enforce_policy, default.enforce_policy);
        assert_ne!(full.enable_ward, default.enable_ward);
        assert_ne!(full.enable_schema_pin, default.enable_schema_pin);
        assert_ne!(full.enable_rate_limit, default.enable_rate_limit);
        assert_ne!(full.enable_budget, default.enable_budget);
        assert_ne!(full.allow_anonymous, default.allow_anonymous);
        assert_ne!(
            full.block_outbound_injection,
            default.block_outbound_injection
        );
    }

    // -----------------------------------------------------------------------
    // Gate construction does not panic
    // -----------------------------------------------------------------------

    #[test]
    fn gate_new_with_multiple_authenticators_does_not_panic() {
        let _gate = Gate::new(
            GateConfig::default(),
            vec![
                Box::new(AnonymousAuthenticator),
                Box::new(AnonymousAuthenticator),
            ],
            None,
            RateLimiterConfig::default(),
            BudgetTrackerConfig::default(),
            empty_ledger(),
        );
    }

    #[test]
    fn gate_new_with_empty_authenticators_does_not_panic() {
        let _gate = Gate::new(
            GateConfig::default(),
            vec![],
            None,
            RateLimiterConfig::default(),
            BudgetTrackerConfig::default(),
            empty_ledger(),
        );
    }
}