1use std::collections::HashMap;
16
17use serde::{Deserialize, Serialize};
18use sqlx::{Row, Sqlite, SqlitePool, Transaction};
19use thiserror::Error;
20use tlog_tiles::{checkpoint::Checkpoint, tlog};
21use tracing::debug;
22
23use crate::signing::SigningKey;
24
25#[derive(Debug, Error)]
26pub enum AuditError {
27 #[error("database error: {0}")]
28 Db(#[from] sqlx::Error),
29 #[error("proof error: {0}")]
30 Proof(#[from] tlog::Error),
31 #[error("malformed checkpoint")]
32 Checkpoint,
33 #[error("hash not found at index {0}")]
34 HashNotFound(u64),
35 #[error("invalid hex hash: {0}")]
36 InvalidHex(String),
37 #[error("no checkpoint exists yet")]
38 NoCheckpoint,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct AuditCheckpoint {
44 pub tree_size: i64,
45 pub root_hash: String,
46 pub checkpoint_text: String,
48 pub created_at: String,
49}
50
51struct InMemHashReader(HashMap<u64, tlog::Hash>);
53
54impl tlog::HashReader for InMemHashReader {
55 fn read_hashes(&self, indexes: &[u64]) -> Result<Vec<tlog::Hash>, tlog::Error> {
56 indexes
57 .iter()
58 .map(|&i| self.0.get(&i).copied().ok_or(tlog::Error::IndexesNotInTree))
59 .collect()
60 }
61}
62
63#[derive(Clone)]
66pub struct AuditLog {
67 pool: SqlitePool,
68 origin: String,
69 signing_key: Option<SigningKey>,
70}
71
72impl AuditLog {
73 pub fn new(
74 pool: SqlitePool,
75 origin: impl Into<String>,
76 signing_key: Option<SigningKey>,
77 ) -> Self {
78 Self {
79 pool,
80 origin: origin.into(),
81 signing_key,
82 }
83 }
84
85 pub async fn append_leaf(&self, log_index: u64, leaf_hash_hex: &str) -> Result<(), AuditError> {
89 let mut tx = self.pool.begin().await?;
90 self.append_leaf_in_tx(&mut tx, log_index, leaf_hash_hex)
91 .await?;
92 tx.commit().await?;
93 Ok(())
94 }
95
96 pub async fn append_leaf_in_tx(
98 &self,
99 tx: &mut Transaction<'_, Sqlite>,
100 log_index: u64,
101 leaf_hash_hex: &str,
102 ) -> Result<(), AuditError> {
103 let leaf = hex_to_hash(leaf_hash_hex)?;
104 let n = log_index;
105
106 let mut to_store: Vec<(u64, tlog::Hash)> = vec![(tlog::stored_hash_index(0, n), leaf)];
107 let mut current = leaf;
108 let mut level = 0u8;
109
110 while (n >> level) & 1 == 1 {
111 let n_at_level = n >> level;
112 let left_idx = tlog::stored_hash_index(level, n_at_level - 1);
113 let left = self.read_hash_in_tx(&mut *tx, left_idx).await?;
114 let parent = tlog::node_hash(left, current);
115 level += 1;
116 let parent_idx = tlog::stored_hash_index(level, n_at_level >> 1);
117 to_store.push((parent_idx, parent));
118 current = parent;
119 }
120
121 for (idx, hash) in to_store {
122 sqlx::query("INSERT OR REPLACE INTO audit_hashes (hash_index, hash) VALUES (?1, ?2)")
123 .bind(idx as i64)
124 .bind(hash.0.as_slice())
125 .execute(&mut **tx)
126 .await?;
127 }
128
129 debug!(log_index, "audit leaf appended");
130 Ok(())
131 }
132
133 pub async fn make_checkpoint(&self, tree_size: u64) -> Result<AuditCheckpoint, AuditError> {
136 let mut tx = self.pool.begin().await?;
137 let cp = self.make_checkpoint_in_tx(&mut tx, tree_size).await?;
138 tx.commit().await?;
139 Ok(cp)
140 }
141
142 pub async fn make_checkpoint_in_tx(
144 &self,
145 tx: &mut Transaction<'_, Sqlite>,
146 tree_size: u64,
147 ) -> Result<AuditCheckpoint, AuditError> {
148 let reader = self.load_all_hashes_in_tx(&mut *tx).await?;
149 let root = tlog::tree_hash(tree_size, &reader)?;
150 let root_hex = hash_to_hex(&root);
151
152 let extension = if let Some(ref sk) = self.signing_key {
154 let body = format!("{}\n{}\n{}\n", self.origin, tree_size, root);
155 sk.sign_checkpoint(body.as_bytes())
156 } else {
157 String::new()
158 };
159 let cp = Checkpoint::new(&self.origin, tree_size, root, &extension)
160 .map_err(|_| AuditError::Checkpoint)?;
161 let cp_text = String::from_utf8(cp.to_bytes()).unwrap_or_default();
162
163 let now = now_millis_string();
164 sqlx::query(
165 "INSERT OR REPLACE INTO audit_checkpoints
166 (tree_size, root_hash, checkpoint_text, created_at)
167 VALUES (?1, ?2, ?3, ?4)",
168 )
169 .bind(tree_size as i64)
170 .bind(&root_hex)
171 .bind(&cp_text)
172 .bind(&now)
173 .execute(&mut **tx)
174 .await?;
175
176 debug!(tree_size, root_hash = %root_hex, "checkpoint created");
177 Ok(AuditCheckpoint {
178 tree_size: tree_size as i64,
179 root_hash: root_hex,
180 checkpoint_text: cp_text,
181 created_at: now,
182 })
183 }
184
185 pub async fn latest_checkpoint(&self) -> Result<AuditCheckpoint, AuditError> {
187 let row = sqlx::query(
188 r#"
189 SELECT tree_size, root_hash, checkpoint_text, created_at
190 FROM audit_checkpoints
191 ORDER BY tree_size DESC
192 LIMIT 1
193 "#,
194 )
195 .fetch_optional(&self.pool)
196 .await?;
197
198 row.map(|r| AuditCheckpoint {
199 tree_size: r.get("tree_size"),
200 root_hash: r.get("root_hash"),
201 checkpoint_text: r.get("checkpoint_text"),
202 created_at: r.get("created_at"),
203 })
204 .ok_or(AuditError::NoCheckpoint)
205 }
206
207 pub async fn inclusion_proof(
212 &self,
213 log_index: u64,
214 tree_size: u64,
215 ) -> Result<Vec<String>, AuditError> {
216 let reader = self.load_all_hashes().await?;
217 let proof = tlog::prove_record(tree_size, log_index, &reader)?;
218 Ok(proof.iter().map(hash_to_hex).collect())
219 }
220
221 pub async fn consistency_proof(
226 &self,
227 old_size: u64,
228 new_size: u64,
229 ) -> Result<Vec<String>, AuditError> {
230 let reader = self.load_all_hashes().await?;
231 let proof = tlog::prove_tree(new_size, old_size, &reader)?;
232 Ok(proof.iter().map(hash_to_hex).collect())
233 }
234
235 pub async fn tree_size(&self) -> Result<u64, AuditError> {
237 let row =
241 sqlx::query("SELECT tree_size FROM audit_checkpoints ORDER BY tree_size DESC LIMIT 1")
242 .fetch_optional(&self.pool)
243 .await?;
244 Ok(row
245 .map(|r| r.get::<i64, _>("tree_size") as u64)
246 .unwrap_or(0))
247 }
248
249 pub async fn ensure_checkpoint(&self, event_count: u64) -> Result<(), AuditError> {
253 if event_count == 0 {
254 return Ok(());
255 }
256 let current = self.tree_size().await?;
257 if current >= event_count {
258 return Ok(()); }
260 self.make_checkpoint(event_count).await?;
261 Ok(())
262 }
263
264 async fn read_hash_in_tx(
265 &self,
266 tx: &mut Transaction<'_, Sqlite>,
267 idx: u64,
268 ) -> Result<tlog::Hash, AuditError> {
269 let row = sqlx::query("SELECT hash FROM audit_hashes WHERE hash_index = ?1")
270 .bind(idx as i64)
271 .fetch_optional(&mut **tx)
272 .await?;
273 Self::row_to_hash(row, idx)
274 }
275
276 fn row_to_hash(
277 row: Option<sqlx::sqlite::SqliteRow>,
278 idx: u64,
279 ) -> Result<tlog::Hash, AuditError> {
280 match row {
281 Some(r) => {
282 let bytes: Vec<u8> = r.get("hash");
283 if bytes.len() != 32 {
284 return Err(AuditError::HashNotFound(idx));
285 }
286 let mut h = [0u8; 32];
287 h.copy_from_slice(&bytes);
288 Ok(tlog::Hash(h))
289 }
290 None => Err(AuditError::HashNotFound(idx)),
291 }
292 }
293
294 async fn load_all_hashes(&self) -> Result<InMemHashReader, AuditError> {
295 let rows = sqlx::query("SELECT hash_index, hash FROM audit_hashes ORDER BY hash_index ASC")
296 .fetch_all(&self.pool)
297 .await?;
298 Ok(Self::rows_to_reader(rows))
299 }
300
301 async fn load_all_hashes_in_tx(
302 &self,
303 tx: &mut Transaction<'_, Sqlite>,
304 ) -> Result<InMemHashReader, AuditError> {
305 let rows = sqlx::query("SELECT hash_index, hash FROM audit_hashes ORDER BY hash_index ASC")
306 .fetch_all(&mut **tx)
307 .await?;
308 Ok(Self::rows_to_reader(rows))
309 }
310
311 fn rows_to_reader(rows: Vec<sqlx::sqlite::SqliteRow>) -> InMemHashReader {
312 let map = rows
313 .into_iter()
314 .map(|r| {
315 let idx: i64 = r.get("hash_index");
316 let bytes: Vec<u8> = r.get("hash");
317 let mut h = [0u8; 32];
318 h.copy_from_slice(&bytes);
319 (idx as u64, tlog::Hash(h))
320 })
321 .collect();
322 InMemHashReader(map)
323 }
324}
325
326fn hex_to_hash(hex: &str) -> Result<tlog::Hash, AuditError> {
327 if hex.len() != 64 {
328 return Err(AuditError::InvalidHex(hex.to_string()));
329 }
330 let mut bytes = [0u8; 32];
331 for i in 0..32 {
332 bytes[i] = u8::from_str_radix(&hex[i * 2..i * 2 + 2], 16)
333 .map_err(|_| AuditError::InvalidHex(hex.to_string()))?;
334 }
335 Ok(tlog::Hash(bytes))
336}
337
338fn hash_to_hex(h: &tlog::Hash) -> String {
339 const LUT: &[u8; 16] = b"0123456789abcdef";
340 let mut out = String::with_capacity(64);
341 for &b in h.0.iter() {
342 out.push(LUT[(b >> 4) as usize] as char);
343 out.push(LUT[(b & 0x0f) as usize] as char);
344 }
345 out
346}
347
/// Current wall-clock time as decimal milliseconds since the Unix epoch.
///
/// Returns `"0"` if the system clock reads earlier than the epoch.
fn now_millis_string() -> String {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis().to_string(),
        Err(_) => std::time::Duration::default().as_millis().to_string(),
    }
}