cortexai_audit/backends/
async_logger.rs1use crate::error::AuditError;
7use crate::traits::{AuditConfig, AuditLogger};
8use crate::types::AuditEvent;
9use async_trait::async_trait;
10use std::sync::Arc;
11use tokio::sync::mpsc;
12use tokio::time::{interval, Duration};
13
14pub struct AsyncLogger {
19 sender: mpsc::Sender<LogCommand>,
20 name: String,
21}
22
23enum LogCommand {
24 Log(Box<AuditEvent>),
25 Flush,
26 Shutdown,
27}
28
29impl AsyncLogger {
30 pub fn new<L: AuditLogger + 'static>(inner: L, config: &AuditConfig) -> Self {
32 let (sender, receiver) = mpsc::channel(config.buffer_size);
33 let name = format!("async({})", inner.name());
34 let flush_interval = config.flush_interval_secs;
35
36 tokio::spawn(Self::background_task(
37 Arc::new(inner),
38 receiver,
39 flush_interval,
40 ));
41
42 Self { sender, name }
43 }
44
45 pub fn wrap<L: AuditLogger + 'static>(inner: L) -> Self {
47 Self::new(inner, &AuditConfig::default())
48 }
49
50 async fn background_task(
51 inner: Arc<dyn AuditLogger>,
52 mut receiver: mpsc::Receiver<LogCommand>,
53 flush_interval_secs: u64,
54 ) {
55 let mut flush_timer = interval(Duration::from_secs(flush_interval_secs));
56
57 loop {
58 tokio::select! {
59 cmd = receiver.recv() => {
60 match cmd {
61 Some(LogCommand::Log(event)) => {
62 if let Err(e) = inner.log(*event).await {
63 tracing::error!("Async audit log error: {}", e);
64 }
65 }
66 Some(LogCommand::Flush) => {
67 if let Err(e) = inner.flush().await {
68 tracing::error!("Async audit flush error: {}", e);
69 }
70 }
71 Some(LogCommand::Shutdown) | None => {
72 let _ = inner.flush().await;
74 break;
75 }
76 }
77 }
78 _ = flush_timer.tick() => {
79 if let Err(e) = inner.flush().await {
80 tracing::error!("Async audit periodic flush error: {}", e);
81 }
82 }
83 }
84 }
85
86 tracing::debug!("Async audit logger background task stopped");
87 }
88
89 pub async fn shutdown(&self) -> Result<(), AuditError> {
91 self.sender
92 .send(LogCommand::Shutdown)
93 .await
94 .map_err(|_| AuditError::ChannelSend)
95 }
96}
97
98#[async_trait]
99impl AuditLogger for AsyncLogger {
100 async fn log(&self, event: AuditEvent) -> Result<(), AuditError> {
101 self.sender
102 .send(LogCommand::Log(Box::new(event)))
103 .await
104 .map_err(|_| AuditError::ChannelSend)
105 }
106
107 async fn flush(&self) -> Result<(), AuditError> {
108 self.sender
109 .send(LogCommand::Flush)
110 .await
111 .map_err(|_| AuditError::ChannelSend)
112 }
113
114 fn name(&self) -> &str {
115 &self.name
116 }
117}
118
119impl Drop for AsyncLogger {
120 fn drop(&mut self) {
121 let _ = self.sender.try_send(LogCommand::Shutdown);
123 }
124}
125
/// Step-by-step configuration for an [`AsyncLogger`] that wraps `inner`.
pub struct AsyncLoggerBuilder<L> {
    /// Logger that the background task will forward commands to.
    inner: L,
    /// Capacity of the command channel between handle and background task.
    buffer_size: usize,
    /// Seconds between periodic background flushes.
    flush_interval_secs: u64,
}

133impl<L: AuditLogger + 'static> AsyncLoggerBuilder<L> {
134 pub fn new(inner: L) -> Self {
136 Self {
137 inner,
138 buffer_size: 1000,
139 flush_interval_secs: 5,
140 }
141 }
142
143 pub fn buffer_size(mut self, size: usize) -> Self {
145 self.buffer_size = size;
146 self
147 }
148
149 pub fn flush_interval(mut self, secs: u64) -> Self {
151 self.flush_interval_secs = secs;
152 self
153 }
154
155 pub fn build(self) -> AsyncLogger {
157 let config = AuditConfig {
158 buffer_size: self.buffer_size,
159 flush_interval_secs: self.flush_interval_secs,
160 ..Default::default()
161 };
162 AsyncLogger::new(self.inner, &config)
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::MemoryLogger;
    use std::sync::Arc;

    /// Drives the real `background_task` and checks that a logged event
    /// reaches the wrapped logger by the time the task has shut down.
    #[tokio::test]
    async fn test_async_logger_basic() {
        let memory = Arc::new(MemoryLogger::new());

        let (sender, receiver) = mpsc::channel::<LogCommand>(100);
        // Run the production worker loop against a shared MemoryLogger so
        // the test can inspect what it received. A long interval keeps
        // periodic flushes out of the way.
        let worker = tokio::spawn(AsyncLogger::background_task(
            Arc::clone(&memory),
            receiver,
            60,
        ));

        let async_logger = AsyncLogger {
            sender,
            name: "async(memory)".to_string(),
        };

        let event = AuditEvent::tool_call("test", serde_json::json!({}), true);
        async_logger.log(event).await.unwrap();
        async_logger.flush().await.unwrap();
        async_logger.shutdown().await.unwrap();

        // Commands are handled in FIFO order, so once the worker has exited
        // the event has been logged. This replaces a timing-dependent sleep.
        worker.await.unwrap();

        assert_eq!(memory.count().await, 1);
    }

    #[tokio::test]
    async fn test_async_logger_builder() {
        let memory = MemoryLogger::new();

        let logger = AsyncLoggerBuilder::new(memory)
            .buffer_size(500)
            .flush_interval(2)
            .build();

        assert!(logger.name().contains("async"));
    }

    #[tokio::test]
    async fn test_async_logger_wrap() {
        let memory = MemoryLogger::new();
        let logger = AsyncLogger::wrap(memory);

        assert!(logger.name().contains("async"));
        assert!(logger.name().contains("memory"));
    }
}