1use std::collections::HashMap;
16
17use serde::{Deserialize, Serialize};
18use sqlx::{Row, Sqlite, SqlitePool, Transaction};
19use thiserror::Error;
20use tlog_tiles::{checkpoint::Checkpoint, tlog};
21use tracing::debug;
22
/// Errors returned by [`AuditLog`] operations.
#[derive(Debug, Error)]
pub enum AuditError {
    /// Underlying sqlx/SQLite failure (connection, query, transaction).
    #[error("database error: {0}")]
    Db(#[from] sqlx::Error),
    /// Merkle-tree computation or proof generation failed in the tlog layer.
    #[error("proof error: {0}")]
    Proof(#[from] tlog::Error),
    /// The checkpoint note could not be constructed for this origin/size/root.
    #[error("malformed checkpoint")]
    Checkpoint,
    /// No stored hash (or a wrong-length blob) at the given stored-hash index.
    #[error("hash not found at index {0}")]
    HashNotFound(u64),
    /// Input was not a valid 64-character hex-encoded hash.
    #[error("invalid hex hash: {0}")]
    InvalidHex(String),
    /// `latest_checkpoint` was called before any checkpoint was stored.
    #[error("no checkpoint exists yet")]
    NoCheckpoint,
}
38
/// One row of the `audit_checkpoints` table: a snapshot of the Merkle tree
/// at a given size, as produced by [`AuditLog::make_checkpoint`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditCheckpoint {
    /// Number of leaves covered by this checkpoint (SQLite INTEGER, hence i64).
    pub tree_size: i64,
    /// Lowercase hex encoding of the tree's root hash at `tree_size`.
    pub root_hash: String,
    /// Rendered checkpoint note body. NOTE(review): stored as an empty string
    /// when the note bytes are not valid UTF-8 — see `make_checkpoint_in_tx`.
    pub checkpoint_text: String,
    /// Creation time: decimal milliseconds since the Unix epoch, as a string.
    pub created_at: String,
}
48
49struct InMemHashReader(HashMap<u64, tlog::Hash>);
51
52impl tlog::HashReader for InMemHashReader {
53 fn read_hashes(&self, indexes: &[u64]) -> Result<Vec<tlog::Hash>, tlog::Error> {
54 indexes
55 .iter()
56 .map(|&i| self.0.get(&i).copied().ok_or(tlog::Error::IndexesNotInTree))
57 .collect()
58 }
59}
60
/// Handle to an append-only Merkle-tree audit log persisted in SQLite
/// (`audit_hashes` + `audit_checkpoints` tables).
///
/// `Clone` is cheap: `SqlitePool` is a shared handle to the same pool.
#[derive(Clone)]
pub struct AuditLog {
    /// Connection pool used for all reads, writes, and transactions.
    pool: SqlitePool,
    /// Origin line stamped into every checkpoint note this log produces.
    origin: String,
}
68
impl AuditLog {
    /// Creates a handle over `pool`; `origin` becomes the origin line of
    /// every checkpoint this log produces.
    pub fn new(pool: SqlitePool, origin: impl Into<String>) -> Self {
        Self {
            pool,
            origin: origin.into(),
        }
    }

    /// Appends the leaf for entry `log_index` in a fresh transaction,
    /// committing only on success.
    ///
    /// `leaf_hash_hex` must be a 64-character hex-encoded hash (see
    /// `hex_to_hash`).
    pub async fn append_leaf(&self, log_index: u64, leaf_hash_hex: &str) -> Result<(), AuditError> {
        let mut tx = self.pool.begin().await?;
        self.append_leaf_in_tx(&mut tx, log_index, leaf_hash_hex)
            .await?;
        tx.commit().await?;
        Ok(())
    }

    /// Appends the leaf for entry `log_index` inside the caller's
    /// transaction, also storing every interior Merkle node that becomes
    /// complete once this entry exists.
    ///
    /// Left siblings are read back from `audit_hashes`, so entries must be
    /// appended in index order for those reads to succeed.
    /// NOTE(review): nothing here validates that `log_index` is the next
    /// unused index — confirm callers guarantee ordering.
    pub async fn append_leaf_in_tx(
        &self,
        tx: &mut Transaction<'_, Sqlite>,
        log_index: u64,
        leaf_hash_hex: &str,
    ) -> Result<(), AuditError> {
        let leaf = hex_to_hash(leaf_hash_hex)?;
        let n = log_index;

        // The level-0 (leaf) hash is always stored, keyed by its tlog
        // stored-hash index.
        let mut to_store: Vec<(u64, tlog::Hash)> = vec![(tlog::stored_hash_index(0, n), leaf)];
        let mut current = leaf;
        let mut level = 0u8;

        // While the current node is a right child (low bit set at this
        // level), its parent just became complete: combine with the stored
        // left sibling and climb one level.
        while (n >> level) & 1 == 1 {
            let n_at_level = n >> level;
            let left_idx = tlog::stored_hash_index(level, n_at_level - 1);
            let left = self.read_hash_in_tx(&mut *tx, left_idx).await?;
            let parent = tlog::node_hash(left, current);
            level += 1;
            let parent_idx = tlog::stored_hash_index(level, n_at_level >> 1);
            to_store.push((parent_idx, parent));
            current = parent;
        }

        // INSERT OR REPLACE makes re-appending the same entry idempotent.
        for (idx, hash) in to_store {
            sqlx::query("INSERT OR REPLACE INTO audit_hashes (hash_index, hash) VALUES (?1, ?2)")
                .bind(idx as i64)
                .bind(hash.0.as_slice())
                .execute(&mut **tx)
                .await?;
        }

        debug!(log_index, "audit leaf appended");
        Ok(())
    }

    /// Computes and stores a checkpoint for `tree_size` in a fresh
    /// transaction, committing only on success.
    pub async fn make_checkpoint(&self, tree_size: u64) -> Result<AuditCheckpoint, AuditError> {
        let mut tx = self.pool.begin().await?;
        let cp = self.make_checkpoint_in_tx(&mut tx, tree_size).await?;
        tx.commit().await?;
        Ok(cp)
    }

    /// Computes the root hash for `tree_size`, renders a checkpoint note,
    /// and upserts the row into `audit_checkpoints` within `tx`.
    ///
    /// Loads every stored hash into memory to compute the root, so each call
    /// is O(total stored hashes).
    /// NOTE(review): a non-UTF-8 checkpoint body is silently stored as an
    /// empty string (`unwrap_or_default`) — confirm that is intended.
    pub async fn make_checkpoint_in_tx(
        &self,
        tx: &mut Transaction<'_, Sqlite>,
        tree_size: u64,
    ) -> Result<AuditCheckpoint, AuditError> {
        let reader = self.load_all_hashes_in_tx(&mut *tx).await?;
        let root = tlog::tree_hash(tree_size, &reader)?;
        let root_hex = hash_to_hex(&root);

        let cp = Checkpoint::new(&self.origin, tree_size, root, "")
            .map_err(|_| AuditError::Checkpoint)?;
        let cp_text = String::from_utf8(cp.to_bytes()).unwrap_or_default();

        let now = now_millis_string();
        sqlx::query(
            "INSERT OR REPLACE INTO audit_checkpoints
             (tree_size, root_hash, checkpoint_text, created_at)
             VALUES (?1, ?2, ?3, ?4)",
        )
        .bind(tree_size as i64)
        .bind(&root_hex)
        .bind(&cp_text)
        .bind(&now)
        .execute(&mut **tx)
        .await?;

        debug!(tree_size, root_hash = %root_hex, "checkpoint created");
        Ok(AuditCheckpoint {
            tree_size: tree_size as i64,
            root_hash: root_hex,
            checkpoint_text: cp_text,
            created_at: now,
        })
    }

    /// Returns the checkpoint with the largest tree size, or
    /// [`AuditError::NoCheckpoint`] if none has been stored yet.
    pub async fn latest_checkpoint(&self) -> Result<AuditCheckpoint, AuditError> {
        let row = sqlx::query(
            r#"
            SELECT tree_size, root_hash, checkpoint_text, created_at
            FROM audit_checkpoints
            ORDER BY tree_size DESC
            LIMIT 1
            "#,
        )
        .fetch_optional(&self.pool)
        .await?;

        row.map(|r| AuditCheckpoint {
            tree_size: r.get("tree_size"),
            root_hash: r.get("root_hash"),
            checkpoint_text: r.get("checkpoint_text"),
            created_at: r.get("created_at"),
        })
        .ok_or(AuditError::NoCheckpoint)
    }

    /// Builds a hex-encoded inclusion proof that entry `log_index` is
    /// contained in the tree of size `tree_size`.
    pub async fn inclusion_proof(
        &self,
        log_index: u64,
        tree_size: u64,
    ) -> Result<Vec<String>, AuditError> {
        let reader = self.load_all_hashes().await?;
        let proof = tlog::prove_record(tree_size, log_index, &reader)?;
        Ok(proof.iter().map(hash_to_hex).collect())
    }

    /// Builds a hex-encoded consistency proof that the tree of size
    /// `old_size` is a prefix of the tree of size `new_size`.
    pub async fn consistency_proof(
        &self,
        old_size: u64,
        new_size: u64,
    ) -> Result<Vec<String>, AuditError> {
        let reader = self.load_all_hashes().await?;
        let proof = tlog::prove_tree(new_size, old_size, &reader)?;
        Ok(proof.iter().map(hash_to_hex).collect())
    }

    /// Tree size of the most recent stored checkpoint, or 0 when no
    /// checkpoint exists yet.
    pub async fn tree_size(&self) -> Result<u64, AuditError> {
        let row =
            sqlx::query("SELECT tree_size FROM audit_checkpoints ORDER BY tree_size DESC LIMIT 1")
                .fetch_optional(&self.pool)
                .await?;
        Ok(row
            .map(|r| r.get::<i64, _>("tree_size") as u64)
            .unwrap_or(0))
    }

    /// Fetches the stored hash at tlog index `idx` within `tx`.
    async fn read_hash_in_tx(
        &self,
        tx: &mut Transaction<'_, Sqlite>,
        idx: u64,
    ) -> Result<tlog::Hash, AuditError> {
        let row = sqlx::query("SELECT hash FROM audit_hashes WHERE hash_index = ?1")
            .bind(idx as i64)
            .fetch_optional(&mut **tx)
            .await?;
        Self::row_to_hash(row, idx)
    }

    /// Converts an optional `hash` row into a 32-byte hash, treating both a
    /// missing row and a wrong-length blob as [`AuditError::HashNotFound`].
    fn row_to_hash(
        row: Option<sqlx::sqlite::SqliteRow>,
        idx: u64,
    ) -> Result<tlog::Hash, AuditError> {
        match row {
            Some(r) => {
                let bytes: Vec<u8> = r.get("hash");
                if bytes.len() != 32 {
                    return Err(AuditError::HashNotFound(idx));
                }
                let mut h = [0u8; 32];
                h.copy_from_slice(&bytes);
                Ok(tlog::Hash(h))
            }
            None => Err(AuditError::HashNotFound(idx)),
        }
    }

    /// Loads the entire `audit_hashes` table into an in-memory reader,
    /// reading through the pool (committed data only).
    async fn load_all_hashes(&self) -> Result<InMemHashReader, AuditError> {
        let rows = sqlx::query("SELECT hash_index, hash FROM audit_hashes ORDER BY hash_index ASC")
            .fetch_all(&self.pool)
            .await?;
        Ok(Self::rows_to_reader(rows))
    }

    /// Like [`Self::load_all_hashes`], but reads through `tx` so it also
    /// sees uncommitted writes made earlier in the same transaction.
    async fn load_all_hashes_in_tx(
        &self,
        tx: &mut Transaction<'_, Sqlite>,
    ) -> Result<InMemHashReader, AuditError> {
        let rows = sqlx::query("SELECT hash_index, hash FROM audit_hashes ORDER BY hash_index ASC")
            .fetch_all(&mut **tx)
            .await?;
        Ok(Self::rows_to_reader(rows))
    }

    /// Builds the in-memory index → hash map from raw rows.
    ///
    /// NOTE(review): `copy_from_slice` panics if a stored blob is not
    /// exactly 32 bytes — unlike `row_to_hash`, which reports that case as
    /// an error. Consider validating lengths here too.
    fn rows_to_reader(rows: Vec<sqlx::sqlite::SqliteRow>) -> InMemHashReader {
        let map = rows
            .into_iter()
            .map(|r| {
                let idx: i64 = r.get("hash_index");
                let bytes: Vec<u8> = r.get("hash");
                let mut h = [0u8; 32];
                h.copy_from_slice(&bytes);
                (idx as u64, tlog::Hash(h))
            })
            .collect();
        InMemHashReader(map)
    }
}
293
294fn hex_to_hash(hex: &str) -> Result<tlog::Hash, AuditError> {
295 if hex.len() != 64 {
296 return Err(AuditError::InvalidHex(hex.to_string()));
297 }
298 let mut bytes = [0u8; 32];
299 for i in 0..32 {
300 bytes[i] = u8::from_str_radix(&hex[i * 2..i * 2 + 2], 16)
301 .map_err(|_| AuditError::InvalidHex(hex.to_string()))?;
302 }
303 Ok(tlog::Hash(bytes))
304}
305
306fn hash_to_hex(h: &tlog::Hash) -> String {
307 const LUT: &[u8; 16] = b"0123456789abcdef";
308 let mut out = String::with_capacity(64);
309 for &b in h.0.iter() {
310 out.push(LUT[(b >> 4) as usize] as char);
311 out.push(LUT[(b & 0x0f) as usize] as char);
312 }
313 out
314}
315
/// Current wall-clock time as milliseconds since the Unix epoch, rendered
/// as a decimal string. Falls back to `"0"` if the system clock reports a
/// time before the epoch.
fn now_millis_string() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};

    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_millis())
        .unwrap_or(0);
    millis.to_string()
}