Skip to main content

ash_rpc/audit_logging/
processor.rs

1//! `MessageProcessor` wrapper that automatically logs security audit events.
2
3use super::{AuditBackend, AuditEvent, AuditEventType, AuditIntegrity, AuditResult, AuditSeverity};
4use crate::{Message, MessageProcessor, ProcessorCapabilities, Response, auth::ConnectionContext};
5use async_trait::async_trait;
6use std::sync::Arc;
7
8/// Wraps `MessageProcessor` to automatically log requests, responses, and security events
9pub struct AuditProcessor {
10    inner: Arc<dyn MessageProcessor + Send + Sync>,
11    backend: Arc<dyn AuditBackend>,
12    integrity: Arc<dyn AuditIntegrity>,
13    connection_context: Option<Arc<ConnectionContext>>,
14}
15
16impl AuditProcessor {
17    /// Create a new audit processor builder
18    pub fn builder(processor: Arc<dyn MessageProcessor + Send + Sync>) -> AuditProcessorBuilder {
19        AuditProcessorBuilder {
20            processor,
21            backend: Arc::new(super::StdoutAuditBackend),
22            integrity: Arc::new(super::NoIntegrity),
23            connection_context: None,
24        }
25    }
26
27    /// Log an audit event with integrity metadata
28    fn log_event(&self, mut event: AuditEvent) {
29        // Add integrity metadata
30        self.integrity.add_integrity(&mut event);
31
32        // Write to backend
33        self.backend.log_audit(&event);
34    }
35
36    /// Create audit event from request message
37    fn create_request_event(&self, message: &Message) -> Option<AuditEvent> {
38        match message {
39            Message::Request(req) => {
40                let mut event = AuditEvent::builder()
41                    .event_type(AuditEventType::MethodInvocation)
42                    .method(&req.method)
43                    .result(AuditResult::Success) // Will be updated based on response
44                    .severity(AuditSeverity::Info);
45
46                // Add correlation ID if present
47                if let Some(ref id) = req.id {
48                    event = event.correlation_id(id.to_string());
49                }
50
51                // Add connection context if available
52                if let Some(ref ctx) = self.connection_context {
53                    if let Some(addr) = ctx.remote_addr {
54                        event = event.remote_addr(addr);
55                    }
56
57                    // Try to extract principal from context
58                    if let Some(user_id) = ctx.get::<String>("user_id") {
59                        event = event.principal(user_id);
60                    } else if let Some(api_key) = ctx.get::<String>("api_key") {
61                        event = event.principal(format!("api_key:{api_key}"));
62                    }
63                }
64
65                // Sanitize and add parameters (avoid logging sensitive data)
66                if let Some(ref params) = req.params {
67                    // For security, we only log the structure, not the full content
68                    event = event.metadata("params_type", params.clone());
69                }
70
71                Some(event.build())
72            }
73            Message::Notification(notif) => {
74                let mut event = AuditEvent::builder()
75                    .event_type(AuditEventType::MethodInvocation)
76                    .method(&notif.method)
77                    .result(AuditResult::Success)
78                    .severity(AuditSeverity::Info)
79                    .metadata("notification", true);
80
81                // Add connection context if available
82                if let Some(ref ctx) = self.connection_context
83                    && let Some(addr) = ctx.remote_addr
84                {
85                    event = event.remote_addr(addr);
86                }
87
88                Some(event.build())
89            }
90            Message::Response(_) => {
91                // We don't audit raw response messages
92                None
93            }
94        }
95    }
96
97    /// Create audit event from response
98    fn create_response_event(&self, message: &Message, response: Option<&Response>) -> AuditEvent {
99        let method = match message {
100            Message::Request(req) => Some(req.method.as_str()),
101            Message::Notification(notif) => Some(notif.method.as_str()),
102            Message::Response(_) => None,
103        };
104
105        let correlation_id = match message {
106            Message::Request(req) => req.id.as_ref().map(std::string::ToString::to_string),
107            _ => None,
108        };
109
110        let mut event_builder = AuditEvent::builder()
111            .event_type(AuditEventType::MethodInvocation)
112            .correlation_id(correlation_id.unwrap_or_default());
113
114        if let Some(m) = method {
115            event_builder = event_builder.method(m);
116        }
117
118        // Add connection context
119        if let Some(ref ctx) = self.connection_context {
120            if let Some(addr) = ctx.remote_addr {
121                event_builder = event_builder.remote_addr(addr);
122            }
123
124            if let Some(user_id) = ctx.get::<String>("user_id") {
125                event_builder = event_builder.principal(user_id);
126            }
127        }
128
129        // Determine result based on response
130        if let Some(resp) = response {
131            if resp.is_success() {
132                event_builder = event_builder.result(AuditResult::Success);
133            } else {
134                event_builder = event_builder
135                    .result(AuditResult::Failure)
136                    .severity(AuditSeverity::Warning);
137
138                if let Some(ref error) = resp.error {
139                    event_builder = event_builder
140                        .error(&error.message)
141                        .metadata("error_code", error.code);
142                }
143            }
144        } else {
145            // No response (notification or error)
146            event_builder = event_builder.result(AuditResult::Success);
147        }
148
149        event_builder.build()
150    }
151}
152
153#[async_trait]
154impl MessageProcessor for AuditProcessor {
155    async fn process_message(&self, message: Message) -> Option<Response> {
156        // Log incoming request
157        if let Some(request_event) = self.create_request_event(&message) {
158            self.log_event(request_event);
159        }
160
161        // Process the message
162        let response = self.inner.process_message(message.clone()).await;
163
164        // Log response
165        let response_event = self.create_response_event(&message, response.as_ref());
166        self.log_event(response_event);
167
168        response
169    }
170
171    fn get_capabilities(&self) -> ProcessorCapabilities {
172        self.inner.get_capabilities()
173    }
174}
175
176/// Builder for creating audit processors
177pub struct AuditProcessorBuilder {
178    processor: Arc<dyn MessageProcessor + Send + Sync>,
179    backend: Arc<dyn AuditBackend>,
180    integrity: Arc<dyn AuditIntegrity>,
181    connection_context: Option<Arc<ConnectionContext>>,
182}
183
184impl AuditProcessorBuilder {
185    /// Set the audit backend
186    #[must_use]
187    pub fn with_backend(mut self, backend: Arc<dyn AuditBackend>) -> Self {
188        self.backend = backend;
189        self
190    }
191
192    /// Set the integrity mechanism
193    #[must_use]
194    pub fn with_integrity(mut self, integrity: Arc<dyn AuditIntegrity>) -> Self {
195        self.integrity = integrity;
196        self
197    }
198
199    /// Set the connection context for extracting principal and metadata
200    #[must_use]
201    pub fn with_connection_context(mut self, context: Arc<ConnectionContext>) -> Self {
202        self.connection_context = Some(context);
203        self
204    }
205
206    /// Build the audit processor
207    #[must_use]
208    pub fn build(self) -> AuditProcessor {
209        AuditProcessor {
210            inner: self.processor,
211            backend: self.backend,
212            integrity: self.integrity,
213            connection_context: self.connection_context,
214        }
215    }
216}
217
218/// Log authentication/authorization events
219pub fn log_auth_event(
220    backend: &dyn AuditBackend,
221    integrity: &dyn AuditIntegrity,
222    method: &str,
223    ctx: &ConnectionContext,
224    allowed: bool,
225) {
226    let mut event = AuditEvent::builder()
227        .event_type(AuditEventType::AuthorizationCheck)
228        .method(method)
229        .result(if allowed {
230            AuditResult::Success
231        } else {
232            AuditResult::Denied
233        })
234        .severity(if allowed {
235            AuditSeverity::Info
236        } else {
237            AuditSeverity::Critical
238        });
239
240    if let Some(addr) = ctx.remote_addr {
241        event = event.remote_addr(addr);
242    }
243
244    if let Some(user_id) = ctx.get::<String>("user_id") {
245        event = event.principal(user_id);
246    }
247
248    let mut evt = event.build();
249    integrity.add_integrity(&mut evt);
250    backend.log_audit(&evt);
251}
252
253/// Log security policy violations (rate limits, banned IPs, etc.)
254pub fn log_security_violation(
255    backend: &dyn AuditBackend,
256    integrity: &dyn AuditIntegrity,
257    violation_type: &str,
258    remote_addr: Option<std::net::SocketAddr>,
259    principal: Option<&str>,
260) {
261    let mut event = AuditEvent::builder()
262        .event_type(AuditEventType::SecurityViolation)
263        .result(AuditResult::Violation)
264        .severity(AuditSeverity::Critical)
265        .metadata("violation_type", violation_type);
266
267    if let Some(addr) = remote_addr {
268        event = event.remote_addr(addr);
269    }
270
271    if let Some(p) = principal {
272        event = event.principal(p);
273    }
274
275    let mut evt = event.build();
276    integrity.add_integrity(&mut evt);
277    backend.log_audit(&evt);
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MethodRegistry, RequestBuilder};

    /// Smoke test: a request flows through the audit wrapper without panicking.
    #[tokio::test]
    async fn test_audit_processor() {
        let wrapped: Arc<dyn MessageProcessor + Send + Sync> =
            Arc::new(MethodRegistry::new(vec![]));

        let audited = AuditProcessor::builder(wrapped)
            .with_backend(Arc::new(super::super::NoopAuditBackend))
            .with_integrity(Arc::new(super::super::NoIntegrity))
            .build();

        let incoming = Message::Request(
            RequestBuilder::new("test_method")
                .id(serde_json::json!(1))
                .build(),
        );

        let _ = audited.process_message(incoming).await;
    }
}