1use std::sync::Arc;
2
3use dome_core::{DomeError, McpMessage};
4use dome_ledger::{AuditEntry, Direction, Ledger};
5use dome_policy::{Identity as PolicyIdentity, SharedPolicyEngine};
6use dome_sentinel::{
7 AnonymousAuthenticator, ApiKeyAuthenticator, Authenticator, IdentityResolver, PskAuthenticator,
8 ResolverConfig,
9};
10use dome_throttle::{BudgetTracker, BudgetTrackerConfig, RateLimiter, RateLimiterConfig};
11use dome_transport::stdio::StdioTransport;
12use dome_ward::schema_pin::DriftSeverity;
13use dome_ward::{InjectionScanner, SchemaPinStore};
14
15use chrono::Utc;
16use serde_json::Value;
17use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
18use tokio::sync::Mutex;
19use tokio::time::{Duration, timeout};
20use tracing::{debug, error, info, warn};
21
22const MAX_LINE_SIZE: usize = 10 * 1024 * 1024; const CLIENT_READ_TIMEOUT: Duration = Duration::from_secs(300); use uuid::Uuid;
25
/// Verdict of the inbound (client -> server) interceptor chain for one message.
enum InboundAction {
    /// Send this message on to the upstream server. For `initialize` it has already had
    /// PSK / API-key credentials stripped.
    Forward(McpMessage),
    /// Do not forward; write this JSON-RPC error response back to the client instead.
    Deny(McpMessage),
}
37
/// Verdict of the outbound (server -> client) interceptor chain for one message.
enum OutboundAction {
    /// Deliver this message to the client. After critical/high schema drift its `result`
    /// may have been replaced by the last pinned-good tools list.
    Forward(McpMessage),
    /// Suppress the server's message and deliver this JSON-RPC error response instead.
    Block(McpMessage),
}
45
/// Schema-pinning state for the server -> client direction, created and owned by the
/// task that processes server output.
struct OutboundContext {
    /// `true` until the first result carrying a `tools` key has been seen and pinned.
    first_tools_list: bool,
    /// The pinned (first-seen) tools result; substituted for lists with critical/high drift.
    last_good_tools_result: Option<Value>,
}
51
52impl OutboundContext {
53 fn new() -> Self {
54 Self {
55 first_tools_list: true,
56 last_good_tools_result: None,
57 }
58 }
59}
60
/// Feature flags selecting which interceptor stages a [`Gate`] runs.
///
/// The `Default` value disables every security stage and allows anonymous callers.
#[derive(Debug, Clone)]
pub struct GateConfig {
    /// Evaluate each inbound request with the policy engine (only if one was supplied).
    pub enforce_policy: bool,
    /// Scan inbound request parameters and outbound results for injection patterns.
    pub enable_ward: bool,
    /// Pin the first server result containing a `tools` list and check later ones for drift.
    pub enable_schema_pin: bool,
    /// Apply the rate limiter to inbound requests, keyed by principal and resource name.
    pub enable_rate_limit: bool,
    /// Spend one budget unit per inbound request for the calling principal.
    pub enable_budget: bool,
    /// Forwarded to the identity resolver as `ResolverConfig::allow_anonymous`; presumably
    /// permits unauthenticated callers — confirm in `dome_sentinel`.
    pub allow_anonymous: bool,
    /// When outbound injection is detected (requires `enable_ward`), replace the server's
    /// response with an error instead of only logging and auditing it.
    pub block_outbound_injection: bool,
}
84
85impl Default for GateConfig {
86 fn default() -> Self {
87 Self {
88 enforce_policy: false,
89 enable_ward: false,
90 enable_schema_pin: false,
91 enable_rate_limit: false,
92 enable_budget: false,
93 allow_anonymous: true,
94 block_outbound_injection: false,
95 }
96 }
97}
98
/// The Thunder Dome gateway: configuration plus every component used to vet MCP traffic
/// between a client and an upstream server.
///
/// Built with [`Gate::new`] or [`Gate::transparent`] and consumed by `run_stdio`
/// (or `run_http` with the `http` feature).
pub struct Gate {
    /// Feature flags selecting which interceptor stages run.
    config: GateConfig,
    /// Resolves the caller's identity when an `initialize` request arrives.
    resolver: IdentityResolver,
    /// Policy engine; only consulted when `config.enforce_policy` is set.
    policy_engine: Option<SharedPolicyEngine>,
    /// Rate limiter used when `config.enable_rate_limit` is set.
    rate_limiter: Arc<RateLimiter>,
    /// Per-principal budget used when `config.enable_budget` is set.
    budget_tracker: Arc<BudgetTracker>,
    /// Injection scanner used by the ward stage in both directions.
    injection_scanner: Arc<InjectionScanner>,
    /// Pinned tool schemas used when `config.enable_schema_pin` is set.
    schema_store: Arc<Mutex<SchemaPinStore>>,
    /// Audit ledger shared by both proxy directions.
    ledger: Arc<Mutex<Ledger>>,
}
126
127impl std::fmt::Debug for Gate {
128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 f.debug_struct("Gate")
130 .field("config", &self.config)
131 .field("has_policy_engine", &self.policy_engine.is_some())
132 .finish_non_exhaustive()
133 }
134}
135
136impl Gate {
137 pub fn new(
144 config: GateConfig,
145 authenticators: Vec<Box<dyn Authenticator>>,
146 policy_engine: Option<SharedPolicyEngine>,
147 rate_limiter_config: RateLimiterConfig,
148 budget_config: BudgetTrackerConfig,
149 ledger: Ledger,
150 ) -> Self {
151 Self {
152 resolver: IdentityResolver::new(
153 authenticators,
154 ResolverConfig {
155 allow_anonymous: config.allow_anonymous,
156 },
157 ),
158 policy_engine,
159 rate_limiter: Arc::new(RateLimiter::new(rate_limiter_config)),
160 budget_tracker: Arc::new(BudgetTracker::new(budget_config)),
161 injection_scanner: Arc::new(InjectionScanner::new()),
162 schema_store: Arc::new(Mutex::new(SchemaPinStore::new())),
163 ledger: Arc::new(Mutex::new(ledger)),
164 config,
165 }
166 }
167
168 pub fn transparent(ledger: Ledger) -> Self {
170 Self::new(
171 GateConfig::default(),
172 vec![Box::new(AnonymousAuthenticator)],
173 None,
174 RateLimiterConfig::default(),
175 BudgetTrackerConfig::default(),
176 ledger,
177 )
178 }
179
180 fn into_proxy_state(self) -> Arc<ProxyState> {
182 Arc::new(ProxyState {
183 identity: Mutex::new(None),
184 resolver: self.resolver,
185 config: self.config,
186 policy: self.policy_engine,
187 rate_limiter: self.rate_limiter,
188 budget: self.budget_tracker,
189 scanner: self.injection_scanner,
190 schema_store: self.schema_store,
191 ledger: self.ledger,
192 })
193 }
194
195 pub async fn run_stdio(self, command: &str, args: &[&str]) -> Result<(), DomeError> {
197 let state = self.into_proxy_state();
198
199 let transport = StdioTransport::spawn(command, args).await?;
200 let (mut server_reader, mut server_writer, child) = transport.split();
201
202 let client_stdin = tokio::io::stdin();
203 let client_stdout = tokio::io::stdout();
204 let mut client_reader = BufReader::new(client_stdin);
205 let client_writer: Arc<Mutex<tokio::io::Stdout>> = Arc::new(Mutex::new(client_stdout));
206
207 info!("Thunder Dome proxy active -- interceptor chain armed");
208
209 let inbound_state = Arc::clone(&state);
210 let inbound_writer = Arc::clone(&client_writer);
211
212 let mut client_to_server = tokio::spawn(async move {
214 let mut line = String::new();
215 loop {
216 line.clear();
217 let read_result =
218 timeout(CLIENT_READ_TIMEOUT, client_reader.read_line(&mut line)).await;
219 let read_result = match read_result {
220 Ok(inner) => inner,
221 Err(_) => {
222 warn!("client read timed out");
223 break;
224 }
225 };
226 match read_result {
227 Ok(0) => {
228 info!("client closed stdin -- shutting down");
229 break;
230 }
231 Ok(_) => {
232 if line.len() > MAX_LINE_SIZE {
233 warn!(
234 size = line.len(),
235 max = MAX_LINE_SIZE,
236 "client message exceeds size limit, dropping"
237 );
238 let err_resp = McpMessage::error_response(
239 Value::Null,
240 -32600,
241 "Message too large",
242 );
243 if let Err(we) = write_to_client(&inbound_writer, &err_resp).await {
244 error!(%we, "failed to send size error to client");
245 break;
246 }
247 continue;
248 }
249 let trimmed = line.trim();
250 if trimmed.is_empty() {
251 continue;
252 }
253
254 match McpMessage::parse(trimmed) {
255 Ok(msg) => match inbound_state.process_inbound(msg).await {
256 InboundAction::Forward(msg) => {
257 if let Err(e) = server_writer.send(&msg).await {
258 error!(%e, "failed to forward to server");
259 break;
260 }
261 }
262 InboundAction::Deny(err_resp) => {
263 if let Err(we) =
264 write_to_client(&inbound_writer, &err_resp).await
265 {
266 error!(%we, "failed to send error to client");
267 break;
268 }
269 }
270 },
271 Err(e) => {
272 warn!(%e, raw = trimmed, "invalid JSON from client, dropping");
273 let err_resp = McpMessage::error_response(
274 Value::Null,
275 -32700,
276 "Parse error: invalid JSON",
277 );
278 if let Err(we) = write_to_client(&inbound_writer, &err_resp).await {
279 error!(%we, "failed to send parse error to client");
280 break;
281 }
282 }
283 }
284 }
285 Err(e) => {
286 error!(%e, "error reading from client");
287 break;
288 }
289 }
290 }
291 });
292
293 let outbound_state = Arc::clone(&state);
294 let outbound_writer = Arc::clone(&client_writer);
295
296 let mut server_to_client = tokio::spawn(async move {
298 let mut ctx = OutboundContext::new();
299 loop {
300 match server_reader.recv().await {
301 Ok(msg) => match outbound_state.process_outbound(msg, &mut ctx).await {
302 OutboundAction::Forward(msg) => {
303 if let Err(e) = write_to_client(&outbound_writer, &msg).await {
304 error!(%e, "failed to write to client");
305 break;
306 }
307 }
308 OutboundAction::Block(err_resp) => {
309 if let Err(e) = write_to_client(&outbound_writer, &err_resp).await {
310 error!(%e, "failed to send outbound error to client");
311 break;
312 }
313 }
314 },
315 Err(DomeError::Transport(ref e))
316 if e.kind() == std::io::ErrorKind::UnexpectedEof =>
317 {
318 info!("server closed stdout -- shutting down");
319 break;
320 }
321 Err(e) => {
322 error!(%e, "error reading from server");
323 break;
324 }
325 }
326 }
327 });
328
329 select_and_abort(&mut client_to_server, &mut server_to_client).await;
331
332 state.ledger.lock().await.flush();
334
335 shutdown_child(child).await;
339
340 info!("Thunder Dome proxy shut down");
341 Ok(())
342 }
343
344 #[cfg(feature = "http")]
347 pub async fn run_http(
348 self,
349 command: &str,
350 args: &[&str],
351 http_config: dome_transport::http::HttpTransportConfig,
352 ) -> Result<(), DomeError> {
353 let state = self.into_proxy_state();
354
355 let transport = StdioTransport::spawn(command, args).await?;
356 let (mut server_reader, mut server_writer, child) = transport.split();
357
358 let http = dome_transport::http::HttpTransport::start(http_config).await?;
359 let (mut http_reader, http_writer, http_handle) = http.split();
360 let http_writer = Arc::new(http_writer);
361
362 info!("Thunder Dome HTTP+SSE proxy active -- interceptor chain armed");
363
364 let inbound_state = Arc::clone(&state);
365 let inbound_http_writer = Arc::clone(&http_writer);
366
367 let mut client_to_server = tokio::spawn(async move {
369 loop {
370 match http_reader.recv().await {
371 Ok(msg) => match inbound_state.process_inbound(msg).await {
372 InboundAction::Forward(msg) => {
373 if let Err(e) = server_writer.send(&msg).await {
374 error!(%e, "failed to forward to server");
375 break;
376 }
377 }
378 InboundAction::Deny(err_resp) => {
379 if let Err(e) = inbound_http_writer.send(&err_resp).await {
380 warn!(%e, "failed to send error to HTTP client");
381 }
382 }
383 },
384 Err(e) => {
385 info!(%e, "HTTP client transport closed");
386 break;
387 }
388 }
389 }
390 });
391
392 let outbound_state = Arc::clone(&state);
393 let outbound_http_writer = Arc::clone(&http_writer);
394
395 let mut server_to_client = tokio::spawn(async move {
397 let mut ctx = OutboundContext::new();
398 loop {
399 match server_reader.recv().await {
400 Ok(msg) => match outbound_state.process_outbound(msg, &mut ctx).await {
401 OutboundAction::Forward(msg) => {
402 if let Err(e) = outbound_http_writer.send(&msg).await {
403 warn!(%e, "failed to send to HTTP client");
404 break;
405 }
406 }
407 OutboundAction::Block(err_resp) => {
408 if let Err(e) = outbound_http_writer.send(&err_resp).await {
409 warn!(%e, "failed to send outbound error to HTTP client");
410 break;
411 }
412 }
413 },
414 Err(DomeError::Transport(ref e))
415 if e.kind() == std::io::ErrorKind::UnexpectedEof =>
416 {
417 info!("server closed stdout -- shutting down");
418 break;
419 }
420 Err(e) => {
421 error!(%e, "error reading from server");
422 break;
423 }
424 }
425 }
426 });
427
428 select_and_abort(&mut client_to_server, &mut server_to_client).await;
430
431 state.ledger.lock().await.flush();
433
434 http_handle.shutdown().await;
436
437 shutdown_child(child).await;
439
440 info!("Thunder Dome HTTP+SSE proxy shut down");
441 Ok(())
442 }
443}
444
/// Runtime state shared (via `Arc`) by the inbound and outbound proxy tasks of one run.
struct ProxyState {
    /// Identity stored by the most recent successful `initialize`; `None` until then.
    /// Every non-`initialize` request is denied while this is `None`.
    /// NOTE(review): one slot per proxy run — under `run_http` this appears to be shared
    /// by all HTTP clients; confirm that is intended.
    identity: Mutex<Option<dome_sentinel::Identity>>,
    /// Resolves identities from `initialize` requests.
    resolver: IdentityResolver,
    /// Feature flags selecting which interceptor stages run.
    config: GateConfig,
    /// Policy engine; only consulted when `config.enforce_policy` is set.
    policy: Option<SharedPolicyEngine>,
    /// Rate limiter for the inbound rate-limit stage.
    rate_limiter: Arc<RateLimiter>,
    /// Budget tracker for the inbound budget stage.
    budget: Arc<BudgetTracker>,
    /// Injection scanner used by the ward stage in both directions.
    scanner: Arc<InjectionScanner>,
    /// Pinned tool schemas for outbound drift detection.
    schema_store: Arc<Mutex<SchemaPinStore>>,
    /// Audit ledger; written for every inbound and outbound decision.
    ledger: Arc<Mutex<Ledger>>,
}
464
465impl ProxyState {
466 async fn process_inbound(&self, msg: McpMessage) -> InboundAction {
469 let start = std::time::Instant::now();
470 let method = msg.method.as_deref().unwrap_or("-").to_string();
471 let tool = msg.tool_name().map(String::from);
472 let request_id = Uuid::new_v4();
473
474 debug!(
475 method = method.as_str(),
476 id = ?msg.id,
477 tool = tool.as_deref().unwrap_or("-"),
478 "client -> server"
479 );
480
481 let mut msg = msg;
483 if method == "initialize" {
484 match self.resolver.resolve(&msg).await {
485 Ok(id) => {
486 info!(
487 principal = %id.principal,
488 method = %id.auth_method,
489 "identity resolved"
490 );
491 *self.identity.lock().await = Some(id);
492
493 msg = PskAuthenticator::strip_psk(&msg);
495 msg = ApiKeyAuthenticator::strip_api_key(&msg);
496 }
497 Err(e) => {
498 warn!(%e, "authentication failed");
499 let err_id = msg.id.clone().unwrap_or(Value::Null);
500 return InboundAction::Deny(McpMessage::error_response(
501 err_id,
502 -32600,
503 "Authentication failed",
504 ));
505 }
506 }
507 }
508
509 if method != "initialize" {
512 let identity_lock = self.identity.lock().await;
513 if identity_lock.is_none() {
514 drop(identity_lock);
515 warn!(method = %method, "request before initialize");
516 let err_id = msg.id.clone().unwrap_or(Value::Null);
517 return InboundAction::Deny(McpMessage::error_response(
518 err_id,
519 -32600,
520 "Session not initialized",
521 ));
522 }
523 drop(identity_lock);
524 }
525
526 let identity_lock = self.identity.lock().await;
527 let principal = identity_lock
528 .as_ref()
529 .map(|i| i.principal.clone())
530 .unwrap_or_else(|| "anonymous".to_string());
531 let labels = identity_lock
532 .as_ref()
533 .map(|i| i.labels.clone())
534 .unwrap_or_default();
535 drop(identity_lock);
536
537 let resource_name = msg.method_resource_name().unwrap_or("-").to_string();
539 let tool_name = resource_name.as_str();
540
541 let args = msg
542 .params
543 .as_ref()
544 .and_then(|p| p.get("arguments"))
545 .cloned()
546 .unwrap_or(Value::Null);
547
548 if self.config.enable_rate_limit {
550 let rl_tool = if tool_name != "-" {
551 Some(tool_name)
552 } else {
553 None
554 };
555 if let Err(e) = self.rate_limiter.check_rate_limit(&principal, rl_tool) {
556 warn!(%e, principal = %principal, method = %method, "rate limited");
557 record_audit(
558 &self.ledger,
559 AuditParams {
560 request_id,
561 identity: &principal,
562 direction: Direction::Inbound,
563 method: &method,
564 tool: tool.as_deref(),
565 decision: "deny:rate_limit",
566 rule_id: None,
567 latency_us: start.elapsed().as_micros() as u64,
568 },
569 )
570 .await;
571 let err_id = msg.id.clone().unwrap_or(Value::Null);
572 return InboundAction::Deny(McpMessage::error_response(
573 err_id,
574 -32000,
575 "Rate limit exceeded",
576 ));
577 }
578 }
579
580 if self.config.enable_budget
582 && let Err(e) = self.budget.try_spend(&principal, 1.0)
583 {
584 warn!(%e, principal = %principal, "budget exhausted");
585 record_audit(
586 &self.ledger,
587 AuditParams {
588 request_id,
589 identity: &principal,
590 direction: Direction::Inbound,
591 method: &method,
592 tool: tool.as_deref(),
593 decision: "deny:budget",
594 rule_id: None,
595 latency_us: start.elapsed().as_micros() as u64,
596 },
597 )
598 .await;
599 let err_id = msg.id.clone().unwrap_or(Value::Null);
600 return InboundAction::Deny(McpMessage::error_response(
601 err_id,
602 -32000,
603 "Budget exhausted",
604 ));
605 }
606
607 if self.config.enable_ward {
613 let scan_target = if method == "tools/call" {
614 args.clone()
615 } else if let Some(ref params) = msg.params {
616 params.clone()
617 } else {
618 Value::Null
619 };
620
621 if !scan_target.is_null() {
622 let scan_result = dome_ward::scan_json_value(&self.scanner, &scan_target);
623 if !scan_result.pattern_matches.is_empty() {
624 let pattern_names: Vec<&str> = scan_result
625 .pattern_matches
626 .iter()
627 .map(|m| m.pattern_name.as_str())
628 .collect();
629 warn!(
630 patterns = ?pattern_names,
631 method = %method,
632 tool = tool_name,
633 principal = %principal,
634 "injection detected"
635 );
636 record_audit(
637 &self.ledger,
638 AuditParams {
639 request_id,
640 identity: &principal,
641 direction: Direction::Inbound,
642 method: &method,
643 tool: tool.as_deref(),
644 decision: &format!("deny:injection:{}", pattern_names.join(",")),
645 rule_id: None,
646 latency_us: start.elapsed().as_micros() as u64,
647 },
648 )
649 .await;
650 let err_id = msg.id.clone().unwrap_or(Value::Null);
651 return InboundAction::Deny(McpMessage::error_response(
652 err_id,
653 -32003,
654 "Request blocked: injection pattern detected",
655 ));
656 }
657 }
658 }
659
660 if self.config.enforce_policy
662 && let Some(ref shared_engine) = self.policy
663 {
664 let engine = shared_engine.load();
667
668 let policy_resource = if method == "tools/call" {
669 tool_name
670 } else {
671 method.as_str()
672 };
673 let policy_id = PolicyIdentity::new(principal.clone(), labels.iter().cloned());
674 let decision = engine.evaluate(&policy_id, policy_resource, &args);
675
676 if !decision.is_allowed() {
677 warn!(
678 rule_id = %decision.rule_id,
679 method = %method,
680 resource = policy_resource,
681 principal = %principal,
682 "policy denied"
683 );
684 record_audit(
685 &self.ledger,
686 AuditParams {
687 request_id,
688 identity: &principal,
689 direction: Direction::Inbound,
690 method: &method,
691 tool: tool.as_deref(),
692 decision: &format!("deny:policy:{}", decision.rule_id),
693 rule_id: Some(&decision.rule_id),
694 latency_us: start.elapsed().as_micros() as u64,
695 },
696 )
697 .await;
698 let err_id = msg.id.clone().unwrap_or(Value::Null);
699 return InboundAction::Deny(McpMessage::error_response(
700 err_id,
701 -32003,
702 format!("Denied by policy: {}", decision.rule_id),
703 ));
704 }
705 }
706
707 record_audit(
709 &self.ledger,
710 AuditParams {
711 request_id,
712 identity: &principal,
713 direction: Direction::Inbound,
714 method: &method,
715 tool: tool.as_deref(),
716 decision: "allow",
717 rule_id: None,
718 latency_us: start.elapsed().as_micros() as u64,
719 },
720 )
721 .await;
722
723 InboundAction::Forward(msg)
724 }
725
726 async fn process_outbound(&self, msg: McpMessage, ctx: &mut OutboundContext) -> OutboundAction {
729 let start = std::time::Instant::now();
730 let method = msg.method.as_deref().unwrap_or("-").to_string();
731 let outbound_request_id = Uuid::new_v4();
732
733 debug!(
734 method = method.as_str(),
735 id = ?msg.id,
736 "server -> client"
737 );
738
739 let mut forward_msg = msg;
740
741 if self.config.enable_schema_pin
743 && let Some(result) = &forward_msg.result
744 && result.get("tools").is_some()
745 {
746 let mut store = self.schema_store.lock().await;
747 if ctx.first_tools_list {
748 store.pin_tools(result);
749 info!(pinned = store.len(), "schema pins established");
750 ctx.last_good_tools_result = Some(result.clone());
751 ctx.first_tools_list = false;
752 } else {
753 let drifts = store.verify_tools(result);
754 if !drifts.is_empty() {
755 let has_critical_or_high = drifts.iter().any(|drift| {
756 warn!(
757 tool = %drift.tool_name,
758 drift_type = ?drift.drift_type,
759 severity = ?drift.severity,
760 "schema drift detected"
761 );
762 matches!(
763 drift.severity,
764 DriftSeverity::Critical | DriftSeverity::High
765 )
766 });
767
768 if has_critical_or_high {
769 warn!("critical/high schema drift detected -- blocking drifted tools/list");
770 record_audit(
771 &self.ledger,
772 AuditParams {
773 request_id: outbound_request_id,
774 identity: "server",
775 direction: Direction::Outbound,
776 method: "tools/list",
777 tool: None,
778 decision: "deny:schema_drift",
779 rule_id: None,
780 latency_us: start.elapsed().as_micros() as u64,
781 },
782 )
783 .await;
784
785 if let Some(ref good_result) = ctx.last_good_tools_result {
786 forward_msg.result = Some(good_result.clone());
787 } else {
788 let err_id = forward_msg.id.clone().unwrap_or(Value::Null);
789 return OutboundAction::Block(McpMessage::error_response(
790 err_id,
791 -32003,
792 "Schema drift detected: tool definitions have been tampered with",
793 ));
794 }
795 }
796 }
797 }
798 }
799
800 if self.config.enable_ward
804 && let Some(ref result) = forward_msg.result
805 {
806 let scan_value = if let Some(content) = result.get("content") {
807 content.clone()
808 } else {
809 result.clone()
810 };
811
812 let scan_result = dome_ward::scan_json_value(&self.scanner, &scan_value);
813 if !scan_result.pattern_matches.is_empty() {
814 let pattern_names: Vec<&str> = scan_result
815 .pattern_matches
816 .iter()
817 .map(|m| m.pattern_name.as_str())
818 .collect();
819 let decision = if self.config.block_outbound_injection {
820 "deny:outbound_injection"
821 } else {
822 "warn:outbound_injection"
823 };
824 warn!(
825 patterns = ?pattern_names,
826 direction = "outbound",
827 blocked = self.config.block_outbound_injection,
828 "injection detected in server response"
829 );
830 record_audit(
831 &self.ledger,
832 AuditParams {
833 request_id: outbound_request_id,
834 identity: "server",
835 direction: Direction::Outbound,
836 method: &method,
837 tool: None,
838 decision: &format!("{}:{}", decision, pattern_names.join(",")),
839 rule_id: None,
840 latency_us: start.elapsed().as_micros() as u64,
841 },
842 )
843 .await;
844
845 if self.config.block_outbound_injection {
846 let err_id = forward_msg.id.clone().unwrap_or(Value::Null);
847 return OutboundAction::Block(McpMessage::error_response(
848 err_id,
849 -32005,
850 "Response blocked: injection pattern detected in server output",
851 ));
852 }
853 }
854 }
855
856 record_audit(
858 &self.ledger,
859 AuditParams {
860 request_id: outbound_request_id,
861 identity: "server",
862 direction: Direction::Outbound,
863 method: &method,
864 tool: None,
865 decision: "forward",
866 rule_id: None,
867 latency_us: start.elapsed().as_micros() as u64,
868 },
869 )
870 .await;
871
872 OutboundAction::Forward(forward_msg)
873 }
874}
875
876async fn select_and_abort(
886 client_to_server: &mut tokio::task::JoinHandle<()>,
887 server_to_client: &mut tokio::task::JoinHandle<()>,
888) {
889 tokio::select! {
890 r = &mut *client_to_server => {
891 if let Err(e) = r {
892 error!(%e, "client->server task panicked");
893 }
894 server_to_client.abort();
895 }
896 r = &mut *server_to_client => {
897 if let Err(e) = r {
898 error!(%e, "server->client task panicked");
899 }
900 client_to_server.abort();
901 }
902 }
903}
904
905async fn shutdown_child(mut child: tokio::process::Child) {
908 match tokio::time::timeout(Duration::from_secs(5), child.wait()).await {
909 Ok(Ok(status)) => info!(%status, "upstream server exited"),
910 Ok(Err(e)) => warn!(%e, "error waiting for upstream server"),
911 Err(_) => {
912 warn!("upstream server did not exit within 5s, forcing termination");
913 let _ = child.kill().await;
914 }
915 }
916}
917
918async fn write_to_client(
920 writer: &Arc<Mutex<tokio::io::Stdout>>,
921 msg: &McpMessage,
922) -> Result<(), std::io::Error> {
923 match msg.to_json() {
924 Ok(json) => {
925 let mut out = json.into_bytes();
926 out.push(b'\n');
927 let mut w = writer.lock().await;
928 w.write_all(&out).await?;
929 w.flush().await?;
930 Ok(())
931 }
932 Err(e) => {
933 error!(%e, "failed to serialize message for client");
934 Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e))
935 }
936 }
937}
938
/// Borrowed inputs for one [`record_audit`] call, i.e. one ledger entry.
struct AuditParams<'a> {
    /// Id generated once per processed message (shared by all entries for that message).
    request_id: Uuid,
    /// Principal the entry is attributed to; outbound entries in this file use "server".
    identity: &'a str,
    /// Which proxy direction produced the entry.
    direction: Direction,
    /// JSON-RPC method name, or "-" when the message has none.
    method: &'a str,
    /// Tool name for tool calls, if any.
    tool: Option<&'a str>,
    /// Decision label, e.g. "allow", "forward", "deny:rate_limit", "deny:policy:<rule>".
    decision: &'a str,
    /// Policy rule responsible for the decision, when one applies.
    rule_id: Option<&'a str>,
    /// Microseconds spent in the interceptor chain before the entry was recorded.
    latency_us: u64,
}
953
954async fn record_audit(ledger: &Arc<Mutex<Ledger>>, params: AuditParams<'_>) {
956 let entry = AuditEntry {
957 seq: 0, timestamp: Utc::now(),
959 request_id: params.request_id,
960 identity: params.identity.to_string(),
961 direction: params.direction,
962 method: params.method.to_string(),
963 tool: params.tool.map(String::from),
964 decision: params.decision.to_string(),
965 rule_id: params.rule_id.map(String::from),
966 latency_us: params.latency_us,
967 prev_hash: String::new(), annotations: std::collections::HashMap::new(),
969 };
970
971 if let Err(e) = ledger.lock().await.record(entry) {
972 error!(%e, "failed to record audit entry");
973 }
974}
975
/// Unit tests for the gate configuration, construction, and audit recording.
#[cfg(test)]
mod tests {
    use super::*;
    use dome_ledger::MemorySink;

    /// Ledger backed by a single in-memory sink.
    fn test_ledger() -> Ledger {
        Ledger::new(vec![Box::new(MemorySink::new())])
    }

    /// Ledger with no sinks, for tests that never inspect audit output.
    fn empty_ledger() -> Ledger {
        Ledger::new(vec![])
    }

    // ---- GateConfig ----

    /// The default configuration disables every security stage and allows anonymous callers.
    #[test]
    fn gate_config_defaults_all_security_disabled() {
        let config = GateConfig::default();

        assert!(
            !config.enforce_policy,
            "enforce_policy should default to false"
        );
        assert!(!config.enable_ward, "enable_ward should default to false");
        assert!(
            !config.enable_schema_pin,
            "enable_schema_pin should default to false"
        );
        assert!(
            !config.enable_rate_limit,
            "enable_rate_limit should default to false"
        );
        assert!(
            !config.enable_budget,
            "enable_budget should default to false"
        );
        assert!(
            config.allow_anonymous,
            "allow_anonymous should default to true"
        );
        assert!(
            !config.block_outbound_injection,
            "block_outbound_injection should default to false"
        );
    }

    /// Explicitly set flags are stored as given.
    #[test]
    fn gate_config_with_all_enabled() {
        let config = GateConfig {
            enforce_policy: true,
            enable_ward: true,
            enable_schema_pin: true,
            enable_rate_limit: true,
            enable_budget: true,
            allow_anonymous: false,
            block_outbound_injection: true,
        };

        assert!(config.enforce_policy);
        assert!(config.enable_ward);
        assert!(config.enable_schema_pin);
        assert!(config.enable_rate_limit);
        assert!(config.enable_budget);
        assert!(!config.allow_anonymous);
        assert!(config.block_outbound_injection);
    }

    /// `Clone` copies every flag.
    #[test]
    fn gate_config_is_cloneable() {
        let original = GateConfig {
            enforce_policy: true,
            enable_ward: true,
            enable_schema_pin: false,
            enable_rate_limit: true,
            enable_budget: false,
            allow_anonymous: false,
            block_outbound_injection: true,
        };
        let cloned = original.clone();

        assert_eq!(cloned.enforce_policy, original.enforce_policy);
        assert_eq!(cloned.enable_ward, original.enable_ward);
        assert_eq!(cloned.enable_schema_pin, original.enable_schema_pin);
        assert_eq!(cloned.enable_rate_limit, original.enable_rate_limit);
        assert_eq!(cloned.enable_budget, original.enable_budget);
        assert_eq!(cloned.allow_anonymous, original.allow_anonymous);
        assert_eq!(
            cloned.block_outbound_injection,
            original.block_outbound_injection
        );
    }

    /// The derived `Debug` output names the type and its fields.
    #[test]
    fn gate_config_is_debug_printable() {
        let config = GateConfig::default();
        let debug_output = format!("{:?}", config);

        assert!(debug_output.contains("GateConfig"));
        assert!(debug_output.contains("enforce_policy"));
        assert!(debug_output.contains("enable_ward"));
    }

    // ---- Gate construction ----

    /// `Gate::transparent` uses the all-off default configuration.
    #[test]
    fn transparent_gate_has_correct_config_defaults() {
        let gate = Gate::transparent(empty_ledger());

        assert!(
            !gate.config.enforce_policy,
            "transparent gate should not enforce policy"
        );
        assert!(
            !gate.config.enable_ward,
            "transparent gate should not enable ward"
        );
        assert!(
            !gate.config.enable_schema_pin,
            "transparent gate should not enable schema pinning"
        );
        assert!(
            !gate.config.enable_rate_limit,
            "transparent gate should not enable rate limiting"
        );
        assert!(
            !gate.config.enable_budget,
            "transparent gate should not enable budget tracking"
        );
        assert!(
            gate.config.allow_anonymous,
            "transparent gate should allow anonymous access"
        );
        assert!(
            !gate.config.block_outbound_injection,
            "transparent gate should not block outbound injection"
        );
    }

    /// `Gate::transparent` attaches no policy engine.
    #[test]
    fn transparent_gate_has_no_policy_engine() {
        let gate = Gate::transparent(empty_ledger());
        assert!(
            gate.policy_engine.is_none(),
            "transparent gate should have no policy engine"
        );
    }

    /// `Gate::new` stores the supplied configuration unchanged.
    #[test]
    fn gate_new_with_custom_config_preserves_flags() {
        let config = GateConfig {
            enforce_policy: true,
            enable_ward: true,
            enable_schema_pin: true,
            enable_rate_limit: true,
            enable_budget: true,
            allow_anonymous: false,
            block_outbound_injection: true,
        };

        let gate = Gate::new(
            config,
            vec![Box::new(AnonymousAuthenticator)],
            None,
            RateLimiterConfig::default(),
            BudgetTrackerConfig::default(),
            empty_ledger(),
        );

        assert!(gate.config.enforce_policy);
        assert!(gate.config.enable_ward);
        assert!(gate.config.enable_schema_pin);
        assert!(gate.config.enable_rate_limit);
        assert!(gate.config.enable_budget);
        assert!(!gate.config.allow_anonymous);
        assert!(gate.config.block_outbound_injection);
    }

    /// Passing `None` for the policy engine leaves it unset.
    #[test]
    fn gate_new_without_policy_engine_stores_none() {
        let gate = Gate::new(
            GateConfig::default(),
            vec![Box::new(AnonymousAuthenticator)],
            None,
            RateLimiterConfig::default(),
            BudgetTrackerConfig::default(),
            empty_ledger(),
        );

        assert!(gate.policy_engine.is_none());
    }

    /// The hand-written `Debug` impl shows the config and the policy-engine flag.
    #[test]
    fn gate_is_debug_printable() {
        let gate = Gate::transparent(empty_ledger());
        let debug_output = format!("{:?}", gate);

        assert!(debug_output.contains("Gate"));
        assert!(debug_output.contains("config"));
        assert!(debug_output.contains("has_policy_engine"));
    }

    // ---- record_audit ----

    /// One call records exactly one entry.
    #[tokio::test]
    async fn record_audit_creates_entry_with_correct_fields() {
        let ledger = Arc::new(Mutex::new(test_ledger()));
        let request_id = Uuid::new_v4();

        record_audit(
            &ledger,
            AuditParams {
                request_id,
                identity: "test-user",
                direction: Direction::Inbound,
                method: "tools/call",
                tool: Some("read_file"),
                decision: "allow",
                rule_id: Some("rule-1"),
                latency_us: 42,
            },
        )
        .await;

        // NOTE(review): despite the test name, only the entry count is asserted.
        let ledger_guard = ledger.lock().await;
        assert_eq!(
            ledger_guard.entry_count(),
            1,
            "should have recorded exactly one entry"
        );
    }

    /// Optional `tool` and `rule_id` may both be absent.
    #[tokio::test]
    async fn record_audit_with_no_tool_and_no_rule() {
        let ledger = Arc::new(Mutex::new(test_ledger()));
        let request_id = Uuid::new_v4();

        record_audit(
            &ledger,
            AuditParams {
                request_id,
                identity: "anonymous",
                direction: Direction::Outbound,
                method: "initialize",
                tool: None,
                decision: "forward",
                rule_id: None,
                latency_us: 0,
            },
        )
        .await;

        let ledger_guard = ledger.lock().await;
        assert_eq!(ledger_guard.entry_count(), 1);
    }

    /// Repeated calls each add one entry.
    #[tokio::test]
    async fn record_audit_multiple_entries_increment_count() {
        let ledger = Arc::new(Mutex::new(test_ledger()));

        for i in 0..5 {
            let identity = format!("user-{i}");
            record_audit(
                &ledger,
                AuditParams {
                    request_id: Uuid::new_v4(),
                    identity: &identity,
                    direction: Direction::Inbound,
                    method: "tools/call",
                    tool: Some("test_tool"),
                    decision: "allow",
                    rule_id: None,
                    latency_us: i * 10,
                },
            )
            .await;
        }

        let ledger_guard = ledger.lock().await;
        assert_eq!(ledger_guard.entry_count(), 5);
    }

    /// The deny decision labels used by the inbound chain are all recordable.
    #[tokio::test]
    async fn record_audit_deny_decisions_are_recorded() {
        let ledger = Arc::new(Mutex::new(test_ledger()));

        record_audit(
            &ledger,
            AuditParams {
                request_id: Uuid::new_v4(),
                identity: "malicious-user",
                direction: Direction::Inbound,
                method: "tools/call",
                tool: Some("exec_command"),
                decision: "deny:policy:no-exec",
                rule_id: Some("no-exec"),
                latency_us: 150,
            },
        )
        .await;

        record_audit(
            &ledger,
            AuditParams {
                request_id: Uuid::new_v4(),
                identity: "spammer",
                direction: Direction::Inbound,
                method: "tools/call",
                tool: Some("spam_tool"),
                decision: "deny:rate_limit",
                rule_id: None,
                latency_us: 5,
            },
        )
        .await;

        record_audit(
            &ledger,
            AuditParams {
                request_id: Uuid::new_v4(),
                identity: "attacker",
                direction: Direction::Inbound,
                method: "tools/call",
                tool: Some("read_file"),
                decision: "deny:injection:prompt_injection",
                rule_id: None,
                latency_us: 200,
            },
        )
        .await;

        let ledger_guard = ledger.lock().await;
        assert_eq!(ledger_guard.entry_count(), 3);
    }

    /// The outbound schema-drift denial shape is recordable.
    #[tokio::test]
    async fn record_audit_outbound_schema_drift() {
        let ledger = Arc::new(Mutex::new(test_ledger()));

        record_audit(
            &ledger,
            AuditParams {
                request_id: Uuid::new_v4(),
                identity: "server",
                direction: Direction::Outbound,
                method: "tools/list",
                tool: None,
                decision: "deny:schema_drift",
                rule_id: None,
                latency_us: 75,
            },
        )
        .await;

        let ledger_guard = ledger.lock().await;
        assert_eq!(ledger_guard.entry_count(), 1);
    }

    // ---- Constants ----

    /// Pins the client frame size limit.
    #[test]
    fn max_line_size_is_ten_megabytes() {
        assert_eq!(MAX_LINE_SIZE, 10 * 1024 * 1024);
    }

    /// Pins the client idle timeout.
    #[test]
    fn client_read_timeout_is_five_minutes() {
        assert_eq!(CLIENT_READ_TIMEOUT, Duration::from_secs(300));
    }

    // ---- Configuration presets ----

    /// Struct-update syntax enables only the named stage.
    #[test]
    fn gate_config_ward_only_mode() {
        let config = GateConfig {
            enable_ward: true,
            ..GateConfig::default()
        };

        assert!(config.enable_ward);
        assert!(!config.enforce_policy, "policy should remain off");
        assert!(!config.enable_rate_limit, "rate limit should remain off");
        assert!(!config.enable_budget, "budget should remain off");
        assert!(!config.enable_schema_pin, "schema pin should remain off");
        assert!(config.allow_anonymous, "anonymous should remain on");
    }

    /// Fully locked-down preset (same shape as `gate_config_with_all_enabled`).
    #[test]
    fn gate_config_full_security_mode() {
        let config = GateConfig {
            enforce_policy: true,
            enable_ward: true,
            enable_schema_pin: true,
            enable_rate_limit: true,
            enable_budget: true,
            allow_anonymous: false,
            block_outbound_injection: true,
        };

        assert!(config.enforce_policy);
        assert!(config.enable_ward);
        assert!(config.enable_schema_pin);
        assert!(config.enable_rate_limit);
        assert!(config.enable_budget);
        assert!(!config.allow_anonymous);
        assert!(config.block_outbound_injection);
    }

    // ---- Construction edge cases ----

    /// Duplicate authenticators are accepted at construction time.
    #[test]
    fn gate_new_with_multiple_authenticators_does_not_panic() {
        let _gate = Gate::new(
            GateConfig::default(),
            vec![
                Box::new(AnonymousAuthenticator),
                Box::new(AnonymousAuthenticator),
            ],
            None,
            RateLimiterConfig::default(),
            BudgetTrackerConfig::default(),
            empty_ledger(),
        );
    }

    /// An empty authenticator list is accepted at construction time.
    #[test]
    fn gate_new_with_empty_authenticators_does_not_panic() {
        let _gate = Gate::new(
            GateConfig::default(),
            vec![],
            None,
            RateLimiterConfig::default(),
            BudgetTrackerConfig::default(),
            empty_ledger(),
        );
    }

    // ---- record_audit edge cases ----

    /// Records a policy denial for a PSK-derived principal.
    #[tokio::test]
    async fn record_audit_entry_fields_populated_correctly() {
        let ledger = Arc::new(Mutex::new(Ledger::new(vec![Box::new(MemorySink::new())])));
        let request_id = Uuid::new_v4();

        record_audit(
            &ledger,
            AuditParams {
                request_id,
                identity: "psk:dev-key-1",
                direction: Direction::Inbound,
                method: "tools/call",
                tool: Some("read_file"),
                decision: "deny:policy:no-read",
                rule_id: Some("no-read"),
                latency_us: 999,
            },
        )
        .await;

        // NOTE(review): despite the test name, individual fields are not inspected.
        let guard = ledger.lock().await;
        assert_eq!(guard.entry_count(), 1);
    }

    /// An empty identity string does not prevent recording.
    #[tokio::test]
    async fn record_audit_handles_empty_identity() {
        let ledger = Arc::new(Mutex::new(test_ledger()));

        record_audit(
            &ledger,
            AuditParams {
                request_id: Uuid::new_v4(),
                identity: "",
                direction: Direction::Inbound,
                method: "initialize",
                tool: None,
                decision: "allow",
                rule_id: None,
                latency_us: 0,
            },
        )
        .await;

        let guard = ledger.lock().await;
        assert_eq!(
            guard.entry_count(),
            1,
            "empty identity should not prevent recording"
        );
    }

    /// Long decision labels (many injection pattern names) are accepted.
    #[tokio::test]
    async fn record_audit_handles_long_decision_string() {
        let ledger = Arc::new(Mutex::new(test_ledger()));
        let long_decision = format!("deny:injection:{}", "pattern_name,".repeat(50));

        record_audit(
            &ledger,
            AuditParams {
                request_id: Uuid::new_v4(),
                identity: "test-user",
                direction: Direction::Inbound,
                method: "tools/call",
                tool: Some("risky_tool"),
                decision: &long_decision,
                rule_id: None,
                latency_us: 500,
            },
        )
        .await;

        let guard = ledger.lock().await;
        assert_eq!(
            guard.entry_count(),
            1,
            "long decision strings should be accepted"
        );
    }
}