Skip to main content

dome_gate/
lib.rs

1use std::sync::Arc;
2
3use dome_core::{DomeError, McpMessage};
4use dome_ledger::{AuditEntry, Direction, Ledger};
5use dome_policy::{Identity as PolicyIdentity, PolicyEngine};
6use dome_sentinel::{AnonymousAuthenticator, Authenticator, IdentityResolver, ResolverConfig};
7use dome_throttle::{BudgetTracker, BudgetTrackerConfig, RateLimiter, RateLimiterConfig};
8use dome_transport::stdio::StdioTransport;
9use dome_ward::{InjectionScanner, SchemaPinStore};
10
11use chrono::Utc;
12use serde_json::Value;
13use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::sync::Mutex;
15use tracing::{debug, error, info, warn};
16use uuid::Uuid;
17
/// Feature switches for the Gate proxy.
///
/// Every interceptor in the proxy loop is individually toggled here, so the
/// same binary can run anywhere between a transparent pass-through and a
/// fully enforcing gateway. The type is comparable (`PartialEq`/`Eq`) so
/// callers and tests can check which configuration is in effect.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GateConfig {
    /// Evaluate the policy engine for `tools/call` requests
    /// (`false` = transparent pass-through mode).
    pub enforce_policy: bool,
    /// Scan `tools/call` arguments for injection patterns.
    pub enable_ward: bool,
    /// Pin tool schemas from the first tools listing and report later drift.
    pub enable_schema_pin: bool,
    /// Apply per-principal rate limiting to `tools/call` requests.
    pub enable_rate_limit: bool,
    /// Charge each `tools/call` request against the principal's budget.
    pub enable_budget: bool,
    /// Whether to allow anonymous access; forwarded to the identity resolver.
    pub allow_anonymous: bool,
}
34
35impl Default for GateConfig {
36    fn default() -> Self {
37        Self {
38            enforce_policy: false,
39            enable_ward: false,
40            enable_schema_pin: false,
41            enable_rate_limit: false,
42            enable_budget: false,
43            allow_anonymous: true,
44        }
45    }
46}
47
48/// The Gate — MCPDome's core proxy loop with full interceptor chain.
49///
50/// Interceptor order (inbound, client → server):
51///   1. Sentinel — authenticate on `initialize`, resolve identity
52///   2. Throttle — check rate limits and budget
53///   3. Policy  — evaluate authorization rules
54///   4. Ward    — scan for injection patterns in tool arguments
55///   5. Ledger  — record the decision in the audit chain
56///
57/// Outbound (server → client):
58///   1. Schema Pin — verify tools/list responses for drift
59///   2. Ledger     — record outbound audit entry
60pub struct Gate {
61    config: GateConfig,
62    resolver: IdentityResolver,
63    policy_engine: Option<PolicyEngine>,
64    rate_limiter: Arc<RateLimiter>,
65    budget_tracker: Arc<BudgetTracker>,
66    injection_scanner: Arc<InjectionScanner>,
67    schema_store: Arc<Mutex<SchemaPinStore>>,
68    ledger: Arc<Mutex<Ledger>>,
69}
70
71impl Gate {
72    /// Create a new Gate with full interceptor chain.
73    pub fn new(
74        config: GateConfig,
75        authenticators: Vec<Box<dyn Authenticator>>,
76        policy_engine: Option<PolicyEngine>,
77        rate_limiter_config: RateLimiterConfig,
78        budget_config: BudgetTrackerConfig,
79        ledger: Ledger,
80    ) -> Self {
81        Self {
82            resolver: IdentityResolver::new(
83                authenticators,
84                ResolverConfig {
85                    allow_anonymous: config.allow_anonymous,
86                },
87            ),
88            policy_engine,
89            rate_limiter: Arc::new(RateLimiter::new(rate_limiter_config)),
90            budget_tracker: Arc::new(BudgetTracker::new(budget_config)),
91            injection_scanner: Arc::new(InjectionScanner::new()),
92            schema_store: Arc::new(Mutex::new(SchemaPinStore::new())),
93            ledger: Arc::new(Mutex::new(ledger)),
94            config,
95        }
96    }
97
98    /// Create a transparent pass-through Gate (no security enforcement).
99    pub fn transparent(ledger: Ledger) -> Self {
100        Self::new(
101            GateConfig::default(),
102            vec![Box::new(AnonymousAuthenticator)],
103            None,
104            RateLimiterConfig::default(),
105            BudgetTrackerConfig::default(),
106            ledger,
107        )
108    }
109
110    /// Run the proxy.
111    pub async fn run_stdio(self, command: &str, args: &[&str]) -> Result<(), DomeError> {
112        let transport = StdioTransport::spawn(command, args).await?;
113        let (mut server_reader, mut server_writer, mut child) = transport.split();
114
115        let client_stdin = tokio::io::stdin();
116        let client_stdout = tokio::io::stdout();
117        let mut client_reader = BufReader::new(client_stdin);
118        let mut client_writer = client_stdout;
119
120        info!("MCPDome proxy active — interceptor chain armed");
121
122        // Shared state for the two tasks
123        let identity: Arc<Mutex<Option<dome_sentinel::Identity>>> = Arc::new(Mutex::new(None));
124
125        let gate_identity = Arc::clone(&identity);
126        let gate_resolver = self.resolver;
127        let gate_policy = self.policy_engine;
128        let gate_rate_limiter = Arc::clone(&self.rate_limiter);
129        let gate_budget = Arc::clone(&self.budget_tracker);
130        let gate_scanner = Arc::clone(&self.injection_scanner);
131        let gate_ledger = Arc::clone(&self.ledger);
132        let gate_config = self.config.clone();
133
134        // Client → Server task (inbound interceptor chain)
135        let client_to_server = tokio::spawn(async move {
136            let mut line = String::new();
137            loop {
138                line.clear();
139                match client_reader.read_line(&mut line).await {
140                    Ok(0) => {
141                        info!("client closed stdin — shutting down");
142                        break;
143                    }
144                    Ok(_) => {
145                        let trimmed = line.trim();
146                        if trimmed.is_empty() {
147                            continue;
148                        }
149
150                        match McpMessage::parse(trimmed) {
151                            Ok(msg) => {
152                                let start = std::time::Instant::now();
153                                let method = msg.method.as_deref().unwrap_or("-").to_string();
154                                let tool = msg.tool_name().map(String::from);
155                                let request_id = Uuid::new_v4();
156
157                                debug!(
158                                    method = method.as_str(),
159                                    id = ?msg.id,
160                                    tool = tool.as_deref().unwrap_or("-"),
161                                    "client -> server"
162                                );
163
164                                // ── 1. Sentinel: Authenticate on initialize ──
165                                if method == "initialize" {
166                                    match gate_resolver.resolve(&msg).await {
167                                        Ok(id) => {
168                                            info!(
169                                                principal = %id.principal,
170                                                method = %id.auth_method,
171                                                "identity resolved"
172                                            );
173                                            *gate_identity.lock().await = Some(id);
174                                        }
175                                        Err(e) => {
176                                            warn!(%e, "authentication failed");
177                                            // Auth failure logged; request still forwarded
178                                            // to let downstream server handle the handshake
179                                            // Still forward — let the server handle it
180                                            // but log the auth failure
181                                        }
182                                    }
183                                }
184
185                                let identity_lock = gate_identity.lock().await;
186                                let principal = identity_lock
187                                    .as_ref()
188                                    .map(|i| i.principal.clone())
189                                    .unwrap_or_else(|| "anonymous".to_string());
190                                let labels = identity_lock
191                                    .as_ref()
192                                    .map(|i| i.labels.clone())
193                                    .unwrap_or_default();
194                                drop(identity_lock);
195
196                                // Only intercept tools/call
197                                if method == "tools/call" {
198                                    let tool_name = tool.as_deref().unwrap_or("unknown");
199                                    let args = msg
200                                        .params
201                                        .as_ref()
202                                        .and_then(|p| p.get("arguments"))
203                                        .cloned()
204                                        .unwrap_or(Value::Null);
205
206                                    // ── 2. Throttle: Rate limit check ──
207                                    if gate_config.enable_rate_limit {
208                                        if let Err(e) = gate_rate_limiter
209                                            .check_rate_limit(&principal, Some(tool_name))
210                                        {
211                                            warn!(%e, principal = %principal, tool = tool_name, "rate limited");
212                                            record_audit(
213                                                &gate_ledger,
214                                                request_id,
215                                                &principal,
216                                                Direction::Inbound,
217                                                &method,
218                                                tool.as_deref(),
219                                                "deny:rate_limit",
220                                                None,
221                                                start.elapsed().as_micros() as u64,
222                                            )
223                                            .await;
224                                            // Send error response back to client via server_writer
225                                            // For now, skip the request
226                                            continue;
227                                        }
228                                    }
229
230                                    // ── 2b. Throttle: Budget check ──
231                                    if gate_config.enable_budget {
232                                        if let Err(e) = gate_budget.try_spend(&principal, 1.0) {
233                                            warn!(%e, principal = %principal, "budget exhausted");
234                                            record_audit(
235                                                &gate_ledger,
236                                                request_id,
237                                                &principal,
238                                                Direction::Inbound,
239                                                &method,
240                                                tool.as_deref(),
241                                                "deny:budget",
242                                                None,
243                                                start.elapsed().as_micros() as u64,
244                                            )
245                                            .await;
246                                            continue;
247                                        }
248                                    }
249
250                                    // ── 3. Policy: Authorization ──
251                                    if gate_config.enforce_policy {
252                                        if let Some(ref engine) = gate_policy {
253                                            let policy_id = PolicyIdentity::new(
254                                                principal.clone(),
255                                                labels.iter().cloned(),
256                                            );
257                                            let decision = engine.evaluate(&policy_id, tool_name, &args);
258
259                                            if !decision.is_allowed() {
260                                                warn!(
261                                                    rule_id = %decision.rule_id,
262                                                    tool = tool_name,
263                                                    principal = %principal,
264                                                    "policy denied"
265                                                );
266                                                record_audit(
267                                                    &gate_ledger,
268                                                    request_id,
269                                                    &principal,
270                                                    Direction::Inbound,
271                                                    &method,
272                                                    tool.as_deref(),
273                                                    &format!("deny:policy:{}", decision.rule_id),
274                                                    Some(&decision.rule_id),
275                                                    start.elapsed().as_micros() as u64,
276                                                )
277                                                .await;
278                                                continue;
279                                            }
280                                        }
281                                    }
282
283                                    // ── 4. Ward: Injection scanning ──
284                                    if gate_config.enable_ward {
285                                        let args_str = serde_json::to_string(&args).unwrap_or_default();
286                                        let matches = gate_scanner.scan_text(&args_str);
287                                        if !matches.is_empty() {
288                                            let pattern_names: Vec<&str> =
289                                                matches.iter().map(|m| m.pattern_name.as_str()).collect();
290                                            warn!(
291                                                patterns = ?pattern_names,
292                                                tool = tool_name,
293                                                principal = %principal,
294                                                "injection detected"
295                                            );
296                                            record_audit(
297                                                &gate_ledger,
298                                                request_id,
299                                                &principal,
300                                                Direction::Inbound,
301                                                &method,
302                                                tool.as_deref(),
303                                                &format!("deny:injection:{}", pattern_names.join(",")),
304                                                None,
305                                                start.elapsed().as_micros() as u64,
306                                            )
307                                            .await;
308                                            continue;
309                                        }
310                                    }
311                                }
312
313                                // ── 5. Ledger: Record allowed request ──
314                                record_audit(
315                                    &gate_ledger,
316                                    request_id,
317                                    &principal,
318                                    Direction::Inbound,
319                                    &method,
320                                    tool.as_deref(),
321                                    "allow",
322                                    None,
323                                    start.elapsed().as_micros() as u64,
324                                )
325                                .await;
326
327                                // Forward to server
328                                if let Err(e) = server_writer.send(&msg).await {
329                                    error!(%e, "failed to forward to server");
330                                    break;
331                                }
332                            }
333                            Err(e) => {
334                                warn!(%e, raw = trimmed, "invalid JSON from client, forwarding raw");
335                                let _ = server_writer
336                                    .send(&McpMessage {
337                                        jsonrpc: "2.0".to_string(),
338                                        id: None,
339                                        method: None,
340                                        params: None,
341                                        result: None,
342                                        error: None,
343                                    })
344                                    .await;
345                            }
346                        }
347                    }
348                    Err(e) => {
349                        error!(%e, "error reading from client");
350                        break;
351                    }
352                }
353            }
354        });
355
356        let schema_store = Arc::clone(&self.schema_store);
357        let _outbound_ledger = Arc::clone(&self.ledger);
358        let outbound_config = self.config.clone();
359
360        // Server → Client task (outbound interceptor chain)
361        let server_to_client = tokio::spawn(async move {
362            let mut first_tools_list = true;
363
364            loop {
365                match server_reader.recv().await {
366                    Ok(msg) => {
367                        let method = msg.method.as_deref().unwrap_or("-").to_string();
368
369                        debug!(
370                            method = method.as_str(),
371                            id = ?msg.id,
372                            "server -> client"
373                        );
374
375                        // ── Schema Pin: Verify tools/list responses ──
376                        if outbound_config.enable_schema_pin {
377                            if let Some(result) = &msg.result {
378                                if result.get("tools").is_some() {
379                                    let mut store = schema_store.lock().await;
380                                    if first_tools_list {
381                                        store.pin_tools(result);
382                                        info!(
383                                            pinned = store.len(),
384                                            "schema pins established"
385                                        );
386                                        first_tools_list = false;
387                                    } else {
388                                        let drifts = store.verify_tools(result);
389                                        if !drifts.is_empty() {
390                                            for drift in &drifts {
391                                                warn!(
392                                                    tool = %drift.tool_name,
393                                                    drift_type = ?drift.drift_type,
394                                                    severity = ?drift.severity,
395                                                    "schema drift detected"
396                                                );
397                                            }
398                                        }
399                                    }
400                                }
401                            }
402                        }
403
404                        // Forward to client
405                        match msg.to_json() {
406                            Ok(json) => {
407                                let mut out = json.into_bytes();
408                                out.push(b'\n');
409                                if let Err(e) = client_writer.write_all(&out).await {
410                                    error!(%e, "failed to write to client");
411                                    break;
412                                }
413                                if let Err(e) = client_writer.flush().await {
414                                    error!(%e, "failed to flush to client");
415                                    break;
416                                }
417                            }
418                            Err(e) => {
419                                error!(%e, "failed to serialize server response");
420                            }
421                        }
422                    }
423                    Err(DomeError::Transport(ref e))
424                        if e.kind() == std::io::ErrorKind::UnexpectedEof =>
425                    {
426                        info!("server closed stdout — shutting down");
427                        break;
428                    }
429                    Err(e) => {
430                        error!(%e, "error reading from server");
431                        break;
432                    }
433                }
434            }
435        });
436
437        // Wait for either side to finish
438        tokio::select! {
439            r = client_to_server => {
440                if let Err(e) = r {
441                    error!(%e, "client->server task panicked");
442                }
443            }
444            r = server_to_client => {
445                if let Err(e) = r {
446                    error!(%e, "server->client task panicked");
447                }
448            }
449        }
450
451        // Flush audit log and clean up
452        self.ledger.lock().await.flush();
453        let _ = child.kill().await;
454        info!("MCPDome proxy shut down");
455
456        Ok(())
457    }
458}
459
460/// Helper to record an audit entry.
461async fn record_audit(
462    ledger: &Arc<Mutex<Ledger>>,
463    request_id: Uuid,
464    identity: &str,
465    direction: Direction,
466    method: &str,
467    tool: Option<&str>,
468    decision: &str,
469    rule_id: Option<&str>,
470    latency_us: u64,
471) {
472    let entry = AuditEntry {
473        seq: 0, // set by ledger
474        timestamp: Utc::now(),
475        request_id,
476        identity: identity.to_string(),
477        direction,
478        method: method.to_string(),
479        tool: tool.map(String::from),
480        decision: decision.to_string(),
481        rule_id: rule_id.map(String::from),
482        latency_us,
483        prev_hash: String::new(), // set by ledger
484        annotations: std::collections::HashMap::new(),
485    };
486
487    if let Err(e) = ledger.lock().await.record(entry) {
488        error!(%e, "failed to record audit entry");
489    }
490}