1#![allow(clippy::collapsible_if, clippy::too_many_arguments)]
2
3use std::sync::Arc;
4
5use dome_core::{DomeError, McpMessage};
6use dome_ledger::{AuditEntry, Direction, Ledger};
7use dome_policy::{Identity as PolicyIdentity, SharedPolicyEngine};
8use dome_sentinel::{
9 AnonymousAuthenticator, ApiKeyAuthenticator, Authenticator, IdentityResolver, PskAuthenticator,
10 ResolverConfig,
11};
12use dome_throttle::{BudgetTracker, BudgetTrackerConfig, RateLimiter, RateLimiterConfig};
13use dome_transport::stdio::StdioTransport;
14use dome_ward::schema_pin::DriftSeverity;
15use dome_ward::{InjectionScanner, SchemaPinStore};
16
17use chrono::Utc;
18use serde_json::Value;
19use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
20use tokio::sync::Mutex;
21use tokio::time::{Duration, timeout};
22use tracing::{debug, error, info, warn};
23
24const MAX_LINE_SIZE: usize = 10 * 1024 * 1024; const CLIENT_READ_TIMEOUT: Duration = Duration::from_secs(300); use uuid::Uuid;
27
28enum InboundAction {
34 Forward(McpMessage),
36 Deny(McpMessage),
38}
39
40enum OutboundAction {
42 Forward(McpMessage),
44 Block(McpMessage),
46}
47
48struct OutboundContext {
50 first_tools_list: bool,
51 last_good_tools_result: Option<Value>,
52}
53
54impl OutboundContext {
55 fn new() -> Self {
56 Self {
57 first_tools_list: true,
58 last_good_tools_result: None,
59 }
60 }
61}
62
/// Feature switches for the interceptor chain.
#[derive(Debug, Clone)]
pub struct GateConfig {
    /// Evaluate inbound requests against the policy engine (when one is set).
    pub enforce_policy: bool,
    /// Scan inbound parameters and outbound results for injection patterns.
    pub enable_ward: bool,
    /// Pin the first tools result and check later ones for schema drift.
    pub enable_schema_pin: bool,
    /// Apply the rate limiter to inbound requests.
    pub enable_rate_limit: bool,
    /// Charge each inbound request against the principal's budget.
    pub enable_budget: bool,
    /// Passed to the identity resolver as its `allow_anonymous` setting.
    pub allow_anonymous: bool,
    /// Block (rather than only log and audit) outbound injection matches.
    pub block_outbound_injection: bool,
}

impl Default for GateConfig {
    /// Every protective stage is off; anonymous access is allowed.
    fn default() -> Self {
        let off = false;
        Self {
            allow_anonymous: true,
            enforce_policy: off,
            enable_ward: off,
            enable_schema_pin: off,
            enable_rate_limit: off,
            enable_budget: off,
            block_outbound_injection: off,
        }
    }
}
100
101pub struct Gate {
119 config: GateConfig,
120 resolver: IdentityResolver,
121 policy_engine: Option<SharedPolicyEngine>,
122 rate_limiter: Arc<RateLimiter>,
123 budget_tracker: Arc<BudgetTracker>,
124 injection_scanner: Arc<InjectionScanner>,
125 schema_store: Arc<Mutex<SchemaPinStore>>,
126 ledger: Arc<Mutex<Ledger>>,
127}
128
129impl std::fmt::Debug for Gate {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 f.debug_struct("Gate")
132 .field("config", &self.config)
133 .field("has_policy_engine", &self.policy_engine.is_some())
134 .finish_non_exhaustive()
135 }
136}
137
138impl Gate {
139 pub fn new(
146 config: GateConfig,
147 authenticators: Vec<Box<dyn Authenticator>>,
148 policy_engine: Option<SharedPolicyEngine>,
149 rate_limiter_config: RateLimiterConfig,
150 budget_config: BudgetTrackerConfig,
151 ledger: Ledger,
152 ) -> Self {
153 Self {
154 resolver: IdentityResolver::new(
155 authenticators,
156 ResolverConfig {
157 allow_anonymous: config.allow_anonymous,
158 },
159 ),
160 policy_engine,
161 rate_limiter: Arc::new(RateLimiter::new(rate_limiter_config)),
162 budget_tracker: Arc::new(BudgetTracker::new(budget_config)),
163 injection_scanner: Arc::new(InjectionScanner::new()),
164 schema_store: Arc::new(Mutex::new(SchemaPinStore::new())),
165 ledger: Arc::new(Mutex::new(ledger)),
166 config,
167 }
168 }
169
170 pub fn transparent(ledger: Ledger) -> Self {
172 Self::new(
173 GateConfig::default(),
174 vec![Box::new(AnonymousAuthenticator)],
175 None,
176 RateLimiterConfig::default(),
177 BudgetTrackerConfig::default(),
178 ledger,
179 )
180 }
181
182 fn into_proxy_state(self) -> Arc<ProxyState> {
184 Arc::new(ProxyState {
185 identity: Mutex::new(None),
186 resolver: self.resolver,
187 config: self.config,
188 policy: self.policy_engine,
189 rate_limiter: self.rate_limiter,
190 budget: self.budget_tracker,
191 scanner: self.injection_scanner,
192 schema_store: self.schema_store,
193 ledger: self.ledger,
194 })
195 }
196
197 pub async fn run_stdio(self, command: &str, args: &[&str]) -> Result<(), DomeError> {
199 let state = self.into_proxy_state();
200
201 let transport = StdioTransport::spawn(command, args).await?;
202 let (mut server_reader, mut server_writer, mut child) = transport.split();
203
204 let client_stdin = tokio::io::stdin();
205 let client_stdout = tokio::io::stdout();
206 let mut client_reader = BufReader::new(client_stdin);
207 let client_writer: Arc<Mutex<tokio::io::Stdout>> = Arc::new(Mutex::new(client_stdout));
208
209 info!("MCPDome proxy active -- interceptor chain armed");
210
211 let inbound_state = Arc::clone(&state);
212 let inbound_writer = Arc::clone(&client_writer);
213
214 let mut client_to_server = tokio::spawn(async move {
216 let mut line = String::new();
217 loop {
218 line.clear();
219 let read_result =
220 timeout(CLIENT_READ_TIMEOUT, client_reader.read_line(&mut line)).await;
221 let read_result = match read_result {
222 Ok(inner) => inner,
223 Err(_) => {
224 warn!("client read timed out");
225 break;
226 }
227 };
228 match read_result {
229 Ok(0) => {
230 info!("client closed stdin -- shutting down");
231 break;
232 }
233 Ok(_) => {
234 if line.len() > MAX_LINE_SIZE {
235 warn!(
236 size = line.len(),
237 max = MAX_LINE_SIZE,
238 "client message exceeds size limit, dropping"
239 );
240 let err_resp = McpMessage::error_response(
241 Value::Null,
242 -32600,
243 "Message too large",
244 );
245 if let Err(we) = write_to_client(&inbound_writer, &err_resp).await {
246 error!(%we, "failed to send size error to client");
247 break;
248 }
249 continue;
250 }
251 let trimmed = line.trim();
252 if trimmed.is_empty() {
253 continue;
254 }
255
256 match McpMessage::parse(trimmed) {
257 Ok(msg) => match inbound_state.process_inbound(msg).await {
258 InboundAction::Forward(msg) => {
259 if let Err(e) = server_writer.send(&msg).await {
260 error!(%e, "failed to forward to server");
261 break;
262 }
263 }
264 InboundAction::Deny(err_resp) => {
265 if let Err(we) =
266 write_to_client(&inbound_writer, &err_resp).await
267 {
268 error!(%we, "failed to send error to client");
269 break;
270 }
271 }
272 },
273 Err(e) => {
274 warn!(%e, raw = trimmed, "invalid JSON from client, dropping");
275 let err_resp = McpMessage::error_response(
276 Value::Null,
277 -32700,
278 "Parse error: invalid JSON",
279 );
280 if let Err(we) = write_to_client(&inbound_writer, &err_resp).await {
281 error!(%we, "failed to send parse error to client");
282 break;
283 }
284 }
285 }
286 }
287 Err(e) => {
288 error!(%e, "error reading from client");
289 break;
290 }
291 }
292 }
293 });
294
295 let outbound_state = Arc::clone(&state);
296 let outbound_writer = Arc::clone(&client_writer);
297
298 let mut server_to_client = tokio::spawn(async move {
300 let mut ctx = OutboundContext::new();
301 loop {
302 match server_reader.recv().await {
303 Ok(msg) => match outbound_state.process_outbound(msg, &mut ctx).await {
304 OutboundAction::Forward(msg) => {
305 if let Err(e) = write_to_client(&outbound_writer, &msg).await {
306 error!(%e, "failed to write to client");
307 break;
308 }
309 }
310 OutboundAction::Block(err_resp) => {
311 if let Err(e) = write_to_client(&outbound_writer, &err_resp).await {
312 error!(%e, "failed to send outbound error to client");
313 break;
314 }
315 }
316 },
317 Err(DomeError::Transport(ref e))
318 if e.kind() == std::io::ErrorKind::UnexpectedEof =>
319 {
320 info!("server closed stdout -- shutting down");
321 break;
322 }
323 Err(e) => {
324 error!(%e, "error reading from server");
325 break;
326 }
327 }
328 }
329 });
330
331 tokio::select! {
333 r = &mut client_to_server => {
334 if let Err(e) = r {
335 error!(%e, "client->server task panicked");
336 }
337 server_to_client.abort();
338 }
339 r = &mut server_to_client => {
340 if let Err(e) = r {
341 error!(%e, "server->client task panicked");
342 }
343 client_to_server.abort();
344 }
345 }
346
347 state.ledger.lock().await.flush();
349
350 match tokio::time::timeout(Duration::from_secs(5), child.wait()).await {
354 Ok(Ok(status)) => info!(%status, "upstream server exited"),
355 Ok(Err(e)) => warn!(%e, "error waiting for upstream server"),
356 Err(_) => {
357 warn!("upstream server did not exit within 5s, forcing termination");
358 let _ = child.kill().await;
359 }
360 }
361
362 info!("MCPDome proxy shut down");
363 Ok(())
364 }
365
366 #[cfg(feature = "http")]
369 pub async fn run_http(
370 self,
371 command: &str,
372 args: &[&str],
373 http_config: dome_transport::http::HttpTransportConfig,
374 ) -> Result<(), DomeError> {
375 let state = self.into_proxy_state();
376
377 let transport = StdioTransport::spawn(command, args).await?;
378 let (mut server_reader, mut server_writer, mut child) = transport.split();
379
380 let http = dome_transport::http::HttpTransport::start(http_config).await?;
381 let (mut http_reader, http_writer, http_handle) = http.split();
382 let http_writer = Arc::new(http_writer);
383
384 info!("MCPDome HTTP+SSE proxy active -- interceptor chain armed");
385
386 let inbound_state = Arc::clone(&state);
387 let inbound_http_writer = Arc::clone(&http_writer);
388
389 let mut client_to_server = tokio::spawn(async move {
391 loop {
392 match http_reader.recv().await {
393 Ok(msg) => match inbound_state.process_inbound(msg).await {
394 InboundAction::Forward(msg) => {
395 if let Err(e) = server_writer.send(&msg).await {
396 error!(%e, "failed to forward to server");
397 break;
398 }
399 }
400 InboundAction::Deny(err_resp) => {
401 if let Err(e) = inbound_http_writer.send(&err_resp).await {
402 warn!(%e, "failed to send error to HTTP client");
403 }
404 }
405 },
406 Err(e) => {
407 info!(%e, "HTTP client transport closed");
408 break;
409 }
410 }
411 }
412 });
413
414 let outbound_state = Arc::clone(&state);
415 let outbound_http_writer = Arc::clone(&http_writer);
416
417 let mut server_to_client = tokio::spawn(async move {
419 let mut ctx = OutboundContext::new();
420 loop {
421 match server_reader.recv().await {
422 Ok(msg) => match outbound_state.process_outbound(msg, &mut ctx).await {
423 OutboundAction::Forward(msg) => {
424 if let Err(e) = outbound_http_writer.send(&msg).await {
425 warn!(%e, "failed to send to HTTP client");
426 break;
427 }
428 }
429 OutboundAction::Block(err_resp) => {
430 if let Err(e) = outbound_http_writer.send(&err_resp).await {
431 warn!(%e, "failed to send outbound error to HTTP client");
432 break;
433 }
434 }
435 },
436 Err(DomeError::Transport(ref e))
437 if e.kind() == std::io::ErrorKind::UnexpectedEof =>
438 {
439 info!("server closed stdout -- shutting down");
440 break;
441 }
442 Err(e) => {
443 error!(%e, "error reading from server");
444 break;
445 }
446 }
447 }
448 });
449
450 tokio::select! {
452 r = &mut client_to_server => {
453 if let Err(e) = r {
454 error!(%e, "client->server task panicked");
455 }
456 server_to_client.abort();
457 }
458 r = &mut server_to_client => {
459 if let Err(e) = r {
460 error!(%e, "server->client task panicked");
461 }
462 client_to_server.abort();
463 }
464 }
465
466 state.ledger.lock().await.flush();
468
469 http_handle.shutdown().await;
471
472 match tokio::time::timeout(Duration::from_secs(5), child.wait()).await {
474 Ok(Ok(status)) => info!(%status, "upstream server exited"),
475 Ok(Err(e)) => warn!(%e, "error waiting for upstream server"),
476 Err(_) => {
477 warn!("upstream server did not exit within 5s, forcing termination");
478 let _ = child.kill().await;
479 }
480 }
481
482 info!("MCPDome HTTP+SSE proxy shut down");
483 Ok(())
484 }
485}
486
487struct ProxyState {
496 identity: Mutex<Option<dome_sentinel::Identity>>,
497 resolver: IdentityResolver,
498 config: GateConfig,
499 policy: Option<SharedPolicyEngine>,
500 rate_limiter: Arc<RateLimiter>,
501 budget: Arc<BudgetTracker>,
502 scanner: Arc<InjectionScanner>,
503 schema_store: Arc<Mutex<SchemaPinStore>>,
504 ledger: Arc<Mutex<Ledger>>,
505}
506
507impl ProxyState {
508 async fn process_inbound(&self, msg: McpMessage) -> InboundAction {
511 let start = std::time::Instant::now();
512 let method = msg.method.as_deref().unwrap_or("-").to_string();
513 let tool = msg.tool_name().map(String::from);
514 let request_id = Uuid::new_v4();
515
516 debug!(
517 method = method.as_str(),
518 id = ?msg.id,
519 tool = tool.as_deref().unwrap_or("-"),
520 "client -> server"
521 );
522
523 let mut msg = msg;
525 if method == "initialize" {
526 match self.resolver.resolve(&msg).await {
527 Ok(id) => {
528 info!(
529 principal = %id.principal,
530 method = %id.auth_method,
531 "identity resolved"
532 );
533 *self.identity.lock().await = Some(id);
534
535 msg = PskAuthenticator::strip_psk(&msg);
537 msg = ApiKeyAuthenticator::strip_api_key(&msg);
538 }
539 Err(e) => {
540 warn!(%e, "authentication failed");
541 let err_id = msg.id.clone().unwrap_or(Value::Null);
542 return InboundAction::Deny(McpMessage::error_response(
543 err_id,
544 -32600,
545 "Authentication failed",
546 ));
547 }
548 }
549 }
550
551 if method != "initialize" {
554 let identity_lock = self.identity.lock().await;
555 if identity_lock.is_none() {
556 drop(identity_lock);
557 warn!(method = %method, "request before initialize");
558 let err_id = msg.id.clone().unwrap_or(Value::Null);
559 return InboundAction::Deny(McpMessage::error_response(
560 err_id,
561 -32600,
562 "Session not initialized",
563 ));
564 }
565 drop(identity_lock);
566 }
567
568 let identity_lock = self.identity.lock().await;
569 let principal = identity_lock
570 .as_ref()
571 .map(|i| i.principal.clone())
572 .unwrap_or_else(|| "anonymous".to_string());
573 let labels = identity_lock
574 .as_ref()
575 .map(|i| i.labels.clone())
576 .unwrap_or_default();
577 drop(identity_lock);
578
579 let resource_name = msg.method_resource_name().unwrap_or("-").to_string();
581 let tool_name = resource_name.as_str();
582
583 let args = msg
584 .params
585 .as_ref()
586 .and_then(|p| p.get("arguments"))
587 .cloned()
588 .unwrap_or(Value::Null);
589
590 if self.config.enable_rate_limit {
592 let rl_tool = if tool_name != "-" {
593 Some(tool_name)
594 } else {
595 None
596 };
597 if let Err(e) = self.rate_limiter.check_rate_limit(&principal, rl_tool) {
598 warn!(%e, principal = %principal, method = %method, "rate limited");
599 record_audit(
600 &self.ledger,
601 request_id,
602 &principal,
603 Direction::Inbound,
604 &method,
605 tool.as_deref(),
606 "deny:rate_limit",
607 None,
608 start.elapsed().as_micros() as u64,
609 )
610 .await;
611 let err_id = msg.id.clone().unwrap_or(Value::Null);
612 return InboundAction::Deny(McpMessage::error_response(
613 err_id,
614 -32000,
615 "Rate limit exceeded",
616 ));
617 }
618 }
619
620 if self.config.enable_budget {
622 if let Err(e) = self.budget.try_spend(&principal, 1.0) {
623 warn!(%e, principal = %principal, "budget exhausted");
624 record_audit(
625 &self.ledger,
626 request_id,
627 &principal,
628 Direction::Inbound,
629 &method,
630 tool.as_deref(),
631 "deny:budget",
632 None,
633 start.elapsed().as_micros() as u64,
634 )
635 .await;
636 let err_id = msg.id.clone().unwrap_or(Value::Null);
637 return InboundAction::Deny(McpMessage::error_response(
638 err_id,
639 -32000,
640 "Budget exhausted",
641 ));
642 }
643 }
644
645 if self.config.enable_ward {
649 let scan_text = if method == "tools/call" {
650 serde_json::to_string(&args).unwrap_or_default()
651 } else if let Some(ref params) = msg.params {
652 serde_json::to_string(params).unwrap_or_default()
653 } else {
654 String::new()
655 };
656
657 if !scan_text.is_empty() {
658 let matches = self.scanner.scan_text(&scan_text);
659 if !matches.is_empty() {
660 let pattern_names: Vec<&str> =
661 matches.iter().map(|m| m.pattern_name.as_str()).collect();
662 warn!(
663 patterns = ?pattern_names,
664 method = %method,
665 tool = tool_name,
666 principal = %principal,
667 "injection detected"
668 );
669 record_audit(
670 &self.ledger,
671 request_id,
672 &principal,
673 Direction::Inbound,
674 &method,
675 tool.as_deref(),
676 &format!("deny:injection:{}", pattern_names.join(",")),
677 None,
678 start.elapsed().as_micros() as u64,
679 )
680 .await;
681 let err_id = msg.id.clone().unwrap_or(Value::Null);
682 return InboundAction::Deny(McpMessage::error_response(
683 err_id,
684 -32003,
685 "Request blocked: injection pattern detected",
686 ));
687 }
688 }
689 }
690
691 if self.config.enforce_policy {
693 if let Some(ref shared_engine) = self.policy {
694 let engine = shared_engine.load();
697
698 let policy_resource = if method == "tools/call" {
699 tool_name
700 } else {
701 method.as_str()
702 };
703 let policy_id = PolicyIdentity::new(principal.clone(), labels.iter().cloned());
704 let decision = engine.evaluate(&policy_id, policy_resource, &args);
705
706 if !decision.is_allowed() {
707 warn!(
708 rule_id = %decision.rule_id,
709 method = %method,
710 resource = policy_resource,
711 principal = %principal,
712 "policy denied"
713 );
714 record_audit(
715 &self.ledger,
716 request_id,
717 &principal,
718 Direction::Inbound,
719 &method,
720 tool.as_deref(),
721 &format!("deny:policy:{}", decision.rule_id),
722 Some(&decision.rule_id),
723 start.elapsed().as_micros() as u64,
724 )
725 .await;
726 let err_id = msg.id.clone().unwrap_or(Value::Null);
727 return InboundAction::Deny(McpMessage::error_response(
728 err_id,
729 -32003,
730 format!("Denied by policy: {}", decision.rule_id),
731 ));
732 }
733 }
734 }
735
736 record_audit(
738 &self.ledger,
739 request_id,
740 &principal,
741 Direction::Inbound,
742 &method,
743 tool.as_deref(),
744 "allow",
745 None,
746 start.elapsed().as_micros() as u64,
747 )
748 .await;
749
750 InboundAction::Forward(msg)
751 }
752
753 async fn process_outbound(&self, msg: McpMessage, ctx: &mut OutboundContext) -> OutboundAction {
756 let start = std::time::Instant::now();
757 let method = msg.method.as_deref().unwrap_or("-").to_string();
758 let outbound_request_id = Uuid::new_v4();
759
760 debug!(
761 method = method.as_str(),
762 id = ?msg.id,
763 "server -> client"
764 );
765
766 let mut forward_msg = msg;
767
768 if self.config.enable_schema_pin {
770 if let Some(result) = &forward_msg.result {
771 if result.get("tools").is_some() {
772 let mut store = self.schema_store.lock().await;
773 if ctx.first_tools_list {
774 store.pin_tools(result);
775 info!(pinned = store.len(), "schema pins established");
776 ctx.last_good_tools_result = Some(result.clone());
777 ctx.first_tools_list = false;
778 } else {
779 let drifts = store.verify_tools(result);
780 if !drifts.is_empty() {
781 let mut has_critical_or_high = false;
782 for drift in &drifts {
783 warn!(
784 tool = %drift.tool_name,
785 drift_type = ?drift.drift_type,
786 severity = ?drift.severity,
787 "schema drift detected"
788 );
789 if matches!(
790 drift.severity,
791 DriftSeverity::Critical | DriftSeverity::High
792 ) {
793 has_critical_or_high = true;
794 }
795 }
796
797 if has_critical_or_high {
798 warn!(
799 "critical/high schema drift detected -- blocking drifted tools/list"
800 );
801 record_audit(
802 &self.ledger,
803 outbound_request_id,
804 "server",
805 Direction::Outbound,
806 "tools/list",
807 None,
808 "deny:schema_drift",
809 None,
810 start.elapsed().as_micros() as u64,
811 )
812 .await;
813
814 if let Some(ref good_result) = ctx.last_good_tools_result {
815 forward_msg.result = Some(good_result.clone());
816 } else {
817 let err_id = forward_msg.id.clone().unwrap_or(Value::Null);
818 return OutboundAction::Block(McpMessage::error_response(
819 err_id,
820 -32003,
821 "Schema drift detected: tool definitions have been tampered with",
822 ));
823 }
824 }
825 }
826 }
827 }
828 }
829 }
830
831 if self.config.enable_ward {
833 if let Some(ref result) = forward_msg.result {
834 let scan_target = if let Some(content) = result.get("content") {
835 serde_json::to_string(content).unwrap_or_default()
836 } else {
837 serde_json::to_string(result).unwrap_or_default()
838 };
839
840 if !scan_target.is_empty() {
841 let matches = self.scanner.scan_text(&scan_target);
842 if !matches.is_empty() {
843 let pattern_names: Vec<&str> =
844 matches.iter().map(|m| m.pattern_name.as_str()).collect();
845 let decision = if self.config.block_outbound_injection {
846 "deny:outbound_injection"
847 } else {
848 "warn:outbound_injection"
849 };
850 warn!(
851 patterns = ?pattern_names,
852 direction = "outbound",
853 blocked = self.config.block_outbound_injection,
854 "injection detected in server response"
855 );
856 record_audit(
857 &self.ledger,
858 outbound_request_id,
859 "server",
860 Direction::Outbound,
861 &method,
862 None,
863 &format!("{}:{}", decision, pattern_names.join(",")),
864 None,
865 start.elapsed().as_micros() as u64,
866 )
867 .await;
868
869 if self.config.block_outbound_injection {
870 let err_id = forward_msg.id.clone().unwrap_or(Value::Null);
871 return OutboundAction::Block(McpMessage::error_response(
872 err_id,
873 -32005,
874 "Response blocked: injection pattern detected in server output",
875 ));
876 }
877 }
878 }
879 }
880 }
881
882 record_audit(
884 &self.ledger,
885 outbound_request_id,
886 "server",
887 Direction::Outbound,
888 &method,
889 None,
890 "forward",
891 None,
892 start.elapsed().as_micros() as u64,
893 )
894 .await;
895
896 OutboundAction::Forward(forward_msg)
897 }
898}
899
900async fn write_to_client(
906 writer: &Arc<Mutex<tokio::io::Stdout>>,
907 msg: &McpMessage,
908) -> Result<(), std::io::Error> {
909 match msg.to_json() {
910 Ok(json) => {
911 let mut out = json.into_bytes();
912 out.push(b'\n');
913 let mut w = writer.lock().await;
914 w.write_all(&out).await?;
915 w.flush().await?;
916 Ok(())
917 }
918 Err(e) => {
919 error!(%e, "failed to serialize message for client");
920 Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e))
921 }
922 }
923}
924
925async fn record_audit(
927 ledger: &Arc<Mutex<Ledger>>,
928 request_id: Uuid,
929 identity: &str,
930 direction: Direction,
931 method: &str,
932 tool: Option<&str>,
933 decision: &str,
934 rule_id: Option<&str>,
935 latency_us: u64,
936) {
937 let entry = AuditEntry {
938 seq: 0, timestamp: Utc::now(),
940 request_id,
941 identity: identity.to_string(),
942 direction,
943 method: method.to_string(),
944 tool: tool.map(String::from),
945 decision: decision.to_string(),
946 rule_id: rule_id.map(String::from),
947 latency_us,
948 prev_hash: String::new(), annotations: std::collections::HashMap::new(),
950 };
951
952 if let Err(e) = ledger.lock().await.record(entry) {
953 error!(%e, "failed to record audit entry");
954 }
955}
956
#[cfg(test)]
mod tests {
    use super::*;
    use dome_ledger::MemorySink;

    // ---- fixtures -------------------------------------------------------

    /// Ledger that records into a single in-memory sink.
    fn test_ledger() -> Ledger {
        Ledger::new(vec![Box::new(MemorySink::new())])
    }

    /// Ledger constructed with no sinks.
    fn empty_ledger() -> Ledger {
        Ledger::new(Vec::new())
    }

    /// `test_ledger()` wrapped the way `record_audit` takes it.
    fn shared_ledger() -> Arc<Mutex<Ledger>> {
        Arc::new(Mutex::new(test_ledger()))
    }

    /// Config with every protective flag on and anonymous access off.
    fn hardened_config() -> GateConfig {
        GateConfig {
            enforce_policy: true,
            enable_ward: true,
            enable_schema_pin: true,
            enable_rate_limit: true,
            enable_budget: true,
            allow_anonymous: false,
            block_outbound_injection: true,
        }
    }

    /// Asserts that `config` has every protective flag on and anonymous off.
    #[track_caller]
    fn assert_hardened(config: &GateConfig) {
        assert!(config.enforce_policy);
        assert!(config.enable_ward);
        assert!(config.enable_schema_pin);
        assert!(config.enable_rate_limit);
        assert!(config.enable_budget);
        assert!(!config.allow_anonymous);
        assert!(config.block_outbound_injection);
    }

    /// Gate with the given config and authenticators, no policy engine,
    /// default throttle settings and an empty ledger.
    fn gate_with(config: GateConfig, authenticators: Vec<Box<dyn Authenticator>>) -> Gate {
        Gate::new(
            config,
            authenticators,
            None,
            RateLimiterConfig::default(),
            BudgetTrackerConfig::default(),
            empty_ledger(),
        )
    }

    // ---- GateConfig -----------------------------------------------------

    #[test]
    fn gate_config_defaults_all_security_disabled() {
        let config = GateConfig::default();

        assert!(
            !config.enforce_policy,
            "enforce_policy should default to false"
        );
        assert!(!config.enable_ward, "enable_ward should default to false");
        assert!(
            !config.enable_schema_pin,
            "enable_schema_pin should default to false"
        );
        assert!(
            !config.enable_rate_limit,
            "enable_rate_limit should default to false"
        );
        assert!(
            !config.enable_budget,
            "enable_budget should default to false"
        );
        assert!(
            config.allow_anonymous,
            "allow_anonymous should default to true"
        );
        assert!(
            !config.block_outbound_injection,
            "block_outbound_injection should default to false"
        );
    }

    #[test]
    fn gate_config_with_all_enabled() {
        assert_hardened(&hardened_config());
    }

    #[test]
    fn gate_config_is_cloneable() {
        let original = GateConfig {
            enforce_policy: true,
            enable_ward: true,
            enable_schema_pin: false,
            enable_rate_limit: true,
            enable_budget: false,
            allow_anonymous: false,
            block_outbound_injection: true,
        };
        let cloned = original.clone();

        assert_eq!(cloned.enforce_policy, original.enforce_policy);
        assert_eq!(cloned.enable_ward, original.enable_ward);
        assert_eq!(cloned.enable_schema_pin, original.enable_schema_pin);
        assert_eq!(cloned.enable_rate_limit, original.enable_rate_limit);
        assert_eq!(cloned.enable_budget, original.enable_budget);
        assert_eq!(cloned.allow_anonymous, original.allow_anonymous);
        assert_eq!(
            cloned.block_outbound_injection,
            original.block_outbound_injection
        );
    }

    #[test]
    fn gate_config_is_debug_printable() {
        let rendered = format!("{:?}", GateConfig::default());

        for needle in ["GateConfig", "enforce_policy", "enable_ward"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn gate_config_ward_only_mode() {
        let config = GateConfig {
            enable_ward: true,
            ..GateConfig::default()
        };

        assert!(config.enable_ward);
        assert!(!config.enforce_policy, "policy should remain off");
        assert!(!config.enable_rate_limit, "rate limit should remain off");
        assert!(!config.enable_budget, "budget should remain off");
        assert!(!config.enable_schema_pin, "schema pin should remain off");
        assert!(config.allow_anonymous, "anonymous should remain on");
    }

    #[test]
    fn gate_config_full_security_mode() {
        assert_hardened(&hardened_config());
    }

    // ---- Gate construction ----------------------------------------------

    #[test]
    fn transparent_gate_has_correct_config_defaults() {
        let gate = Gate::transparent(empty_ledger());
        let config = &gate.config;

        assert!(
            !config.enforce_policy,
            "transparent gate should not enforce policy"
        );
        assert!(
            !config.enable_ward,
            "transparent gate should not enable ward"
        );
        assert!(
            !config.enable_schema_pin,
            "transparent gate should not enable schema pinning"
        );
        assert!(
            !config.enable_rate_limit,
            "transparent gate should not enable rate limiting"
        );
        assert!(
            !config.enable_budget,
            "transparent gate should not enable budget tracking"
        );
        assert!(
            config.allow_anonymous,
            "transparent gate should allow anonymous access"
        );
        assert!(
            !config.block_outbound_injection,
            "transparent gate should not block outbound injection"
        );
    }

    #[test]
    fn transparent_gate_has_no_policy_engine() {
        let gate = Gate::transparent(empty_ledger());
        assert!(
            gate.policy_engine.is_none(),
            "transparent gate should have no policy engine"
        );
    }

    #[test]
    fn gate_new_with_custom_config_preserves_flags() {
        let gate = gate_with(hardened_config(), vec![Box::new(AnonymousAuthenticator)]);
        assert_hardened(&gate.config);
    }

    #[test]
    fn gate_new_without_policy_engine_stores_none() {
        let gate = gate_with(
            GateConfig::default(),
            vec![Box::new(AnonymousAuthenticator)],
        );
        assert!(gate.policy_engine.is_none());
    }

    #[test]
    fn gate_is_debug_printable() {
        let rendered = format!("{:?}", Gate::transparent(empty_ledger()));

        for needle in ["Gate", "config", "has_policy_engine"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn gate_new_with_multiple_authenticators_does_not_panic() {
        let _gate = gate_with(
            GateConfig::default(),
            vec![
                Box::new(AnonymousAuthenticator),
                Box::new(AnonymousAuthenticator),
            ],
        );
    }

    #[test]
    fn gate_new_with_empty_authenticators_does_not_panic() {
        let _gate = gate_with(GateConfig::default(), vec![]);
    }

    // ---- constants ------------------------------------------------------

    #[test]
    fn max_line_size_is_ten_megabytes() {
        assert_eq!(MAX_LINE_SIZE, 10 * 1024 * 1024);
    }

    #[test]
    fn client_read_timeout_is_five_minutes() {
        assert_eq!(CLIENT_READ_TIMEOUT, Duration::from_secs(300));
    }

    // ---- record_audit ---------------------------------------------------

    #[tokio::test]
    async fn record_audit_creates_entry_with_correct_fields() {
        let ledger = shared_ledger();
        let request_id = Uuid::new_v4();

        record_audit(
            &ledger,
            request_id,
            "test-user",
            Direction::Inbound,
            "tools/call",
            Some("read_file"),
            "allow",
            Some("rule-1"),
            42,
        )
        .await;

        assert_eq!(
            ledger.lock().await.entry_count(),
            1,
            "should have recorded exactly one entry"
        );
    }

    #[tokio::test]
    async fn record_audit_with_no_tool_and_no_rule() {
        let ledger = shared_ledger();
        let request_id = Uuid::new_v4();

        record_audit(
            &ledger,
            request_id,
            "anonymous",
            Direction::Outbound,
            "initialize",
            None,
            "forward",
            None,
            0,
        )
        .await;

        assert_eq!(ledger.lock().await.entry_count(), 1);
    }

    #[tokio::test]
    async fn record_audit_multiple_entries_increment_count() {
        let ledger = shared_ledger();

        for i in 0..5 {
            record_audit(
                &ledger,
                Uuid::new_v4(),
                &format!("user-{i}"),
                Direction::Inbound,
                "tools/call",
                Some("test_tool"),
                "allow",
                None,
                i * 10,
            )
            .await;
        }

        assert_eq!(ledger.lock().await.entry_count(), 5);
    }

    #[tokio::test]
    async fn record_audit_deny_decisions_are_recorded() {
        let ledger = shared_ledger();

        // (identity, tool, decision, rule id, latency in microseconds)
        let denials = [
            (
                "malicious-user",
                "exec_command",
                "deny:policy:no-exec",
                Some("no-exec"),
                150,
            ),
            ("spammer", "spam_tool", "deny:rate_limit", None, 5),
            (
                "attacker",
                "read_file",
                "deny:injection:prompt_injection",
                None,
                200,
            ),
        ];

        for (identity, tool, decision, rule_id, latency_us) in denials {
            record_audit(
                &ledger,
                Uuid::new_v4(),
                identity,
                Direction::Inbound,
                "tools/call",
                Some(tool),
                decision,
                rule_id,
                latency_us,
            )
            .await;
        }

        assert_eq!(ledger.lock().await.entry_count(), 3);
    }

    #[tokio::test]
    async fn record_audit_outbound_schema_drift() {
        let ledger = shared_ledger();

        record_audit(
            &ledger,
            Uuid::new_v4(),
            "server",
            Direction::Outbound,
            "tools/list",
            None,
            "deny:schema_drift",
            None,
            75,
        )
        .await;

        assert_eq!(ledger.lock().await.entry_count(), 1);
    }

    #[tokio::test]
    async fn record_audit_entry_fields_populated_correctly() {
        let ledger = shared_ledger();
        let request_id = Uuid::new_v4();

        record_audit(
            &ledger,
            request_id,
            "psk:dev-key-1",
            Direction::Inbound,
            "tools/call",
            Some("read_file"),
            "deny:policy:no-read",
            Some("no-read"),
            999,
        )
        .await;

        assert_eq!(ledger.lock().await.entry_count(), 1);
    }

    #[tokio::test]
    async fn record_audit_handles_empty_identity() {
        let ledger = shared_ledger();

        record_audit(
            &ledger,
            Uuid::new_v4(),
            "",
            Direction::Inbound,
            "initialize",
            None,
            "allow",
            None,
            0,
        )
        .await;

        assert_eq!(
            ledger.lock().await.entry_count(),
            1,
            "empty identity should not prevent recording"
        );
    }

    #[tokio::test]
    async fn record_audit_handles_long_decision_string() {
        let ledger = shared_ledger();
        let long_decision = format!("deny:injection:{}", "pattern_name,".repeat(50));

        record_audit(
            &ledger,
            Uuid::new_v4(),
            "test-user",
            Direction::Inbound,
            "tools/call",
            Some("risky_tool"),
            &long_decision,
            None,
            500,
        )
        .await;

        assert_eq!(
            ledger.lock().await.entry_count(),
            1,
            "long decision strings should be accepted"
        );
    }
}