1use kernex_core::error::KernexError;
4use sqlx::SqlitePool;
5use tracing::debug;
6use uuid::Uuid;
7
/// One audit record describing a single request and how it was handled.
///
/// Each field maps one-to-one onto a column of the `audit_log` table written by
/// `AuditLogger::log`; the row's `id` is generated at write time and is not part
/// of this struct.
#[derive(Debug, Clone)]
pub struct AuditEntry {
    /// Channel name for the request (e.g. `"api"`).
    pub channel: String,
    /// Identifier of the sender.
    pub sender_id: String,
    /// Optional display name of the sender.
    pub sender_name: Option<String>,
    /// Input text being audited; a prefix of it also appears in the debug log line.
    pub input_text: String,
    /// Response text, if one was produced.
    pub output_text: Option<String>,
    /// Provider that handled the request, if any (e.g. `"claude-code"`).
    pub provider_used: Option<String>,
    /// Model name, if any (e.g. `"sonnet"`).
    pub model: Option<String>,
    /// Processing time in milliseconds, if recorded.
    pub processing_ms: Option<i64>,
    /// Outcome of the request; persisted via [`AuditStatus::as_str`].
    pub status: AuditStatus,
    /// Reason the request was denied. NOTE(review): presumably only set when
    /// `status` is [`AuditStatus::Denied`] — nothing here enforces that.
    pub denial_reason: Option<String>,
}

/// Outcome recorded for an audited request.
///
/// Persisted in the `status` column as a lowercase string; see [`AuditStatus::as_str`].
/// The enum is fieldless, so it is cheap to `Copy` and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuditStatus {
    /// Stored as `"ok"`.
    Ok,
    /// Stored as `"error"`.
    Error,
    /// Stored as `"denied"`.
    Denied,
}

impl AuditStatus {
    /// Returns the exact string written to the `status` column for this outcome.
    fn as_str(&self) -> &'static str {
        match self {
            Self::Ok => "ok",
            Self::Error => "error",
            Self::Denied => "denied",
        }
    }
}
38
39#[derive(Clone)]
41pub struct AuditLogger {
42 pool: SqlitePool,
43}
44
45impl AuditLogger {
46 pub fn new(pool: SqlitePool) -> Self {
48 Self { pool }
49 }
50
51 pub async fn log(&self, entry: &AuditEntry) -> Result<(), KernexError> {
53 let id = Uuid::new_v4().to_string();
54
55 sqlx::query(
56 "INSERT INTO audit_log \
57 (id, channel, sender_id, sender_name, input_text, output_text, \
58 provider_used, model, processing_ms, status, denial_reason) \
59 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
60 )
61 .bind(&id)
62 .bind(&entry.channel)
63 .bind(&entry.sender_id)
64 .bind(&entry.sender_name)
65 .bind(&entry.input_text)
66 .bind(&entry.output_text)
67 .bind(&entry.provider_used)
68 .bind(&entry.model)
69 .bind(entry.processing_ms)
70 .bind(entry.status.as_str())
71 .bind(&entry.denial_reason)
72 .execute(&self.pool)
73 .await
74 .map_err(|e| KernexError::Store(format!("audit log write failed: {e}")))?;
75
76 debug!(
77 "audit: {} {} [{}] {}",
78 entry.channel,
79 entry.sender_id,
80 entry.status.as_str(),
81 truncate(&entry.input_text, 80)
82 );
83
84 Ok(())
85 }
86}
87
/// Returns the longest prefix of `s` that is at most `max` bytes long and ends on a
/// UTF-8 character boundary, so slicing can never panic mid-character.
///
/// If `max` falls inside a multi-byte character, the cut moves back to that
/// character's start; the result may therefore be shorter than `max`, or empty.
fn truncate(s: &str, max: usize) -> &str {
    if s.len() <= max {
        return s;
    }
    // Stable equivalent of `str::floor_char_boundary`, which older stable toolchains
    // lack. Here `max < s.len()`, so every probed index is in range, and index 0 is
    // always a boundary, so the loop terminates without underflow (at most 3 steps,
    // since a UTF-8 character is at most 4 bytes).
    let mut end = max;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    &s[..end]
}
95
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
    use sqlx::Row;
    use std::str::FromStr;

    /// Builds an in-memory SQLite pool and applies the audit-log migration to it.
    ///
    /// NOTE(review): presumably capped at one connection so every query in a test
    /// sees the same in-memory database the migration ran against.
    async fn test_pool() -> SqlitePool {
        let opts = SqliteConnectOptions::from_str("sqlite::memory:")
            .unwrap()
            .create_if_missing(true);
        let pool = SqlitePoolOptions::new()
            .max_connections(1)
            .connect_with(opts)
            .await
            .unwrap();
        sqlx::raw_sql(include_str!("../migrations/002_audit_log.sql"))
            .execute(&pool)
            .await
            .unwrap();
        pool
    }

    #[tokio::test]
    async fn test_audit_logger_log_inserts_entry() {
        let pool = test_pool().await;
        let logger = AuditLogger::new(pool.clone());

        let entry = AuditEntry {
            channel: "api".to_string(),
            sender_id: "user42".to_string(),
            sender_name: Some("Alice".to_string()),
            input_text: "hello kernex".to_string(),
            output_text: Some("hi there".to_string()),
            provider_used: Some("claude-code".to_string()),
            model: Some("sonnet".to_string()),
            processing_ms: Some(123),
            status: AuditStatus::Ok,
            denial_reason: None,
        };

        logger.log(&entry).await.unwrap();

        let row = sqlx::query("SELECT channel, sender_id, sender_name, input_text, output_text, provider_used, model, processing_ms, status, denial_reason FROM audit_log LIMIT 1")
            .fetch_one(&pool)
            .await
            .unwrap();

        assert_eq!(row.get::<String, _>("channel"), "api");
        assert_eq!(row.get::<String, _>("sender_id"), "user42");
        assert_eq!(
            row.get::<Option<String>, _>("sender_name"),
            Some("Alice".to_string())
        );
        assert_eq!(row.get::<String, _>("input_text"), "hello kernex");
        assert_eq!(
            row.get::<Option<String>, _>("output_text"),
            Some("hi there".to_string())
        );
        assert_eq!(
            row.get::<Option<String>, _>("provider_used"),
            Some("claude-code".to_string())
        );
        assert_eq!(
            row.get::<Option<String>, _>("model"),
            Some("sonnet".to_string())
        );
        assert_eq!(row.get::<Option<i64>, _>("processing_ms"), Some(123));
        assert_eq!(row.get::<String, _>("status"), "ok");
        assert_eq!(
            row.get::<Option<String>, _>("denial_reason"),
            None::<String>
        );
    }

    /// A denied request with every optional field absent must round-trip as NULLs
    /// plus the `"denied"` status and its reason.
    #[tokio::test]
    async fn test_audit_logger_log_records_denial() {
        let pool = test_pool().await;
        let logger = AuditLogger::new(pool.clone());

        let entry = AuditEntry {
            channel: "api".to_string(),
            sender_id: "intruder".to_string(),
            sender_name: None,
            input_text: "let me in".to_string(),
            output_text: None,
            provider_used: None,
            model: None,
            processing_ms: None,
            status: AuditStatus::Denied,
            denial_reason: Some("not on allowlist".to_string()),
        };

        logger.log(&entry).await.unwrap();

        let row = sqlx::query("SELECT sender_name, output_text, provider_used, model, processing_ms, status, denial_reason FROM audit_log LIMIT 1")
            .fetch_one(&pool)
            .await
            .unwrap();

        assert_eq!(row.get::<Option<String>, _>("sender_name"), None::<String>);
        assert_eq!(row.get::<Option<String>, _>("output_text"), None::<String>);
        assert_eq!(row.get::<Option<String>, _>("provider_used"), None::<String>);
        assert_eq!(row.get::<Option<String>, _>("model"), None::<String>);
        assert_eq!(row.get::<Option<i64>, _>("processing_ms"), None::<i64>);
        assert_eq!(row.get::<String, _>("status"), "denied");
        assert_eq!(
            row.get::<Option<String>, _>("denial_reason"),
            Some("not on allowlist".to_string())
        );
    }

    /// Logging the same entry twice must create two rows with distinct generated ids.
    #[tokio::test]
    async fn test_audit_logger_log_generates_unique_ids() {
        let pool = test_pool().await;
        let logger = AuditLogger::new(pool.clone());

        let entry = AuditEntry {
            channel: "api".to_string(),
            sender_id: "user42".to_string(),
            sender_name: None,
            input_text: "ping".to_string(),
            output_text: None,
            provider_used: None,
            model: None,
            processing_ms: None,
            status: AuditStatus::Error,
            denial_reason: None,
        };

        logger.log(&entry).await.unwrap();
        logger.log(&entry).await.unwrap();

        let row = sqlx::query("SELECT COUNT(*), COUNT(DISTINCT id) FROM audit_log")
            .fetch_one(&pool)
            .await
            .unwrap();

        assert_eq!(row.get::<i64, _>(0), 2);
        assert_eq!(row.get::<i64, _>(1), 2);
    }

    #[test]
    fn test_truncate_ascii() {
        assert_eq!(truncate("hello", 10), "hello");
        assert_eq!(truncate("hello", 5), "hello");
        assert_eq!(truncate("hello world", 5), "hello");
        assert_eq!(truncate("", 0), "");
    }

    #[test]
    fn test_truncate_multibyte() {
        // Each Cyrillic letter is 2 bytes, so byte 5 lies inside the third letter
        // and the cut must fall back to byte 4.
        let s = "\u{041f}\u{0440}\u{0438}\u{0432}\u{0435}\u{0442} \u{043c}\u{0438}\u{0440}!";
        assert_eq!(truncate(s, 5), "\u{041f}\u{0440}");
        assert_eq!(truncate(s, 4), "\u{041f}\u{0440}");
    }

    #[test]
    fn test_truncate_emoji() {
        // The 4-byte emoji occupies bytes 3..7; any cut inside it falls back to byte 3.
        let s = "Hi \u{1f389} there";
        assert_eq!(truncate(s, 4), "Hi ");
        assert_eq!(truncate(s, 6), "Hi ");
        assert_eq!(truncate(s, 7), "Hi \u{1f389}");
        assert_eq!(truncate("\u{1f389}", 3), "");
    }
}