use jiff::Timestamp;
use miette::{IntoDiagnostic, Result};
use redb::{ReadableDatabase, ReadableTable};
use serde::{Deserialize, Serialize};
use tracing::trace;
use uuid::Uuid;
/// One recorded query in the audit history.
///
/// Entries are serialized to JSON and stored as the value of the history table, keyed by a
/// microsecond timestamp.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditEntry {
    /// The query text exactly as it was submitted.
    pub query: String,
    /// Database user taken from the REPL state when the entry was recorded.
    pub db_user: String,
    /// System user taken from the REPL state when the entry was recorded.
    pub sys_user: String,
    /// Whether the REPL was in write mode when the entry was recorded.
    pub writemode: bool,
    /// Tailscale peers captured at record time; left out of the JSON when empty and
    /// treated as empty when missing on read.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tailscale: Vec<super::tailscale::TailscalePeer>,
    /// `ots` value copied from the REPL state (NOTE(review): looks like an operator name
    /// in the tests — confirm); omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub ots: Option<String>,
    /// UUID of the working-info instance, if one was active; omitted from the JSON when
    /// `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub instance_id: Option<Uuid>,
    /// `false` for queries recorded while the REPL state marks input as coming from a
    /// snippet or include; deserializes as `true` when the field is absent.
    #[serde(default = "default_recall")]
    pub recall: bool,
}
/// Serde fallback for `AuditEntry::recall`: a stored entry without the field is
/// considered recallable.
fn default_recall() -> bool { true }
/// An [`AuditEntry`] paired with a rendered timestamp, for output.
///
/// Serializes as one flat object: the `ts` key followed by the entry's own fields.
#[derive(Debug, Clone, Serialize)]
pub struct AuditEntryWithTimestamp {
    /// Textual form of the entry's timestamp.
    pub ts: String,
    /// The wrapped entry; `flatten` inlines its fields next to `ts`.
    #[serde(flatten)]
    pub entry: AuditEntry,
}
impl AuditEntryWithTimestamp {
    /// Pairs `entry` with `timestamp_micros` rendered as a [`Timestamp`] string.
    ///
    /// `timestamp_micros` is microseconds since the Unix epoch, the same unit
    /// `add_entry_with_recall` uses for history keys.
    ///
    /// If the value cannot be represented as a [`Timestamp`], `ts` falls back to the raw
    /// microsecond count instead of panicking. This covers values above `i64::MAX` and
    /// values outside the range jiff supports.
    pub fn from_entry_and_timestamp(entry: AuditEntry, timestamp_micros: u64) -> Self {
        // `as i64` would silently wrap huge values to negative instants; convert checked.
        let ts = i64::try_from(timestamp_micros)
            .ok()
            .and_then(|micros| Timestamp::from_microsecond(micros).ok())
            .map_or_else(|| timestamp_micros.to_string(), |timestamp| timestamp.to_string());
        Self { ts, entry }
    }
}
impl super::Audit {
pub(crate) fn get_entry(&self, timestamp: u64) -> Result<Option<AuditEntry>> {
let read_txn = self.db.begin_read().into_diagnostic()?;
let table = match read_txn.open_table(super::HISTORY_TABLE) {
Ok(table) => table,
Err(_) => return Ok(None),
};
let json = match table.get(timestamp).into_diagnostic()? {
Some(json) => json,
None => return Ok(None),
};
let entry = serde_json::from_str(json.value()).into_diagnostic()?;
Ok(Some(entry))
}
pub fn add_entry(&mut self, query: String) -> Result<()> {
let state = self.repl_state.lock().unwrap();
let recall = !state.from_snippet_or_include;
drop(state);
self.add_entry_with_recall(query, recall)
}
pub fn add_entry_with_recall(&mut self, query: String, recall: bool) -> Result<()> {
trace!("adding audit entry");
let tailscale = super::tailscale::get_active_peers()
.ok()
.unwrap_or_default();
let state = self.repl_state.lock().unwrap();
let instance_id = self.working_info.as_ref().map(|info| info.uuid);
let entry = AuditEntry {
query,
db_user: state.db_user.clone(),
sys_user: state.sys_user.clone(),
writemode: state.write_mode,
tailscale,
ots: state.ots.clone(),
instance_id,
recall,
};
drop(state);
let json = serde_json::to_string(&entry).into_diagnostic()?;
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.into_diagnostic()?
.as_micros() as u64;
let write_txn = self.db.begin_write().into_diagnostic()?;
{
let mut history_table = write_txn
.open_table(super::HISTORY_TABLE)
.into_diagnostic()?;
history_table
.insert(timestamp, json.as_str())
.into_diagnostic()?;
}
write_txn.commit().into_diagnostic()?;
self.hist_index_push(timestamp)?;
Ok(())
}
pub fn list(&self) -> Result<Vec<(u64, AuditEntry)>> {
let read_txn = self.db.begin_read().into_diagnostic()?;
let table = read_txn
.open_table(super::HISTORY_TABLE)
.into_diagnostic()?;
let mut entries = Vec::new();
for item in table.iter().into_diagnostic()? {
let (timestamp, json) = item.into_diagnostic()?;
let entry: AuditEntry = serde_json::from_str(json.value()).into_diagnostic()?;
entries.push((timestamp.value(), entry));
}
Ok(entries)
}
}
#[cfg(test)]
mod tests {
    use crate::audit::*;

    /// Entries come back from `list` in insertion order with the REPL state captured at
    /// record time.
    #[test]
    fn test_audit_roundtrip() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("test.redb");
        let mut audit = Audit::open_empty(db_path).unwrap();
        audit.add_entry("SELECT 1;".to_string()).unwrap();
        audit.add_entry("SELECT 2;".to_string()).unwrap();
        {
            let mut state = audit.repl_state.lock().unwrap();
            state.db_user = "dbuser".to_string();
            state.sys_user = "testuser".to_string();
            state.write_mode = true;
            state.ots = Some("John Doe".to_string());
        }
        audit.add_entry("INSERT INTO foo;".to_string()).unwrap();
        let entries = audit.list().unwrap();
        assert_eq!(entries.len(), 3);
        assert_eq!(entries[0].1.query, "SELECT 1;");
        assert_eq!(entries[1].1.query, "SELECT 2;");
        assert_eq!(entries[2].1.query, "INSERT INTO foo;");
        assert!(entries[2].1.writemode);
        assert_eq!(entries[2].1.db_user, "dbuser");
        assert_eq!(entries[2].1.sys_user, "testuser");
        assert_eq!(entries[2].1.ots, Some("John Doe".to_string()));
        assert_eq!(entries[0].1.instance_id, None);
        assert_eq!(entries[1].1.instance_id, None);
        assert_eq!(entries[2].1.instance_id, None);
    }

    /// `get_entry` returns the stored entry for an existing key and `None` for a missing one.
    #[test]
    fn test_audit_get_entry() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("test.redb");
        let mut audit = Audit::open_empty(db_path).unwrap();
        audit.add_entry("SELECT 1;".to_string()).unwrap();
        let entries = audit.list().unwrap();
        assert_eq!(entries.len(), 1);
        let timestamp = entries[0].0;
        let fetched = audit
            .get_entry(timestamp)
            .unwrap()
            .expect("entry listed by `list` must be retrievable by key");
        assert_eq!(fetched.query, "SELECT 1;");
        // Only one entry exists, so the neighbouring key must be vacant.
        assert!(audit.get_entry(timestamp + 1).unwrap().is_none());
    }

    /// The explicit `recall` flag passed to `add_entry_with_recall` is persisted as given.
    #[test]
    fn test_audit_recall_flag() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("test.redb");
        let mut audit = Audit::open_empty(db_path).unwrap();
        audit
            .add_entry_with_recall("SELECT 1;".to_string(), false)
            .unwrap();
        audit
            .add_entry_with_recall("SELECT 2;".to_string(), true)
            .unwrap();
        let entries = audit.list().unwrap();
        assert_eq!(entries.len(), 2);
        assert!(!entries[0].1.recall);
        assert!(entries[1].1.recall);
    }

    /// Opening via `Audit::open` (with working info) stamps entries with an instance id.
    #[test]
    fn test_audit_instance_id() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db_dir = temp_dir.path().join("audit_dir");
        std::fs::create_dir(&db_dir).unwrap();
        temp_env::with_var("HOME", Some(db_dir.to_str().unwrap()), || {
            let mut audit = Audit::open(
                &db_dir,
                std::sync::Arc::new(std::sync::Mutex::new(crate::repl::ReplState::new())),
            )
            .unwrap();
            audit.add_entry("SELECT 1;".to_string()).unwrap();
            let entries = audit.list().unwrap();
            assert_eq!(entries.len(), 1);
            assert!(entries[0].1.instance_id.is_some());
        });
    }
}