use std::path::Path;
use std::str::FromStr;
use async_trait::async_trait;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use sqlx::SqlitePool;
use thiserror::Error;
use super::types::{ClickEvent, LinkId, MsgId, OpenEvent};
/// Errors returned by [`TrackingStore`] operations and by [`open_pool`].
#[derive(Debug, Error)]
pub enum TrackingStoreError {
/// A failure reported by `sqlx`; `#[from]` lets `?` convert it automatically.
#[error("tracking sqlx: {0}")]
Sqlx(#[from] sqlx::Error),
/// A setup failure carried as a plain message. In this file it is only produced by
/// `open_pool` when the connection string is rejected.
#[error("tracking migration: {0}")]
Migration(String),
}
/// Storage interface for tracked links and the open/click events recorded against them.
///
/// Every operation receives a tenant id, either directly or as a field of the event.
/// Implementors must be `Send + Sync`. The implementation in this file is
/// [`SqliteTrackingStore`]; behaviour described below as "the SQLite implementation"
/// is what that type does, not a requirement stated by the trait itself.
#[async_trait]
pub trait TrackingStore: Send + Sync {
/// Records that `link_id` inside message `msg_id` points at `original_url`.
///
/// The SQLite implementation overwrites an existing row with the same
/// `(tenant_id, msg_id, link_id)` key.
async fn register_link(
&self,
tenant_id: &str,
msg_id: &MsgId,
link_id: &LinkId,
original_url: &str,
created_at_ms: i64,
) -> Result<(), TrackingStoreError>;
/// Returns the URL registered for `(tenant_id, msg_id, link_id)`, or `None` when no
/// such link is known.
async fn lookup_link(
&self,
tenant_id: &str,
msg_id: &MsgId,
link_id: &LinkId,
) -> Result<Option<String>, TrackingStoreError>;
/// Persists one open event.
async fn record_open(&self, event: &OpenEvent) -> Result<(), TrackingStoreError>;
/// Persists one click event.
async fn record_click(&self, event: &ClickEvent) -> Result<(), TrackingStoreError>;
/// Returns how many open events are stored for `msg_id` under `tenant_id`.
async fn count_opens(&self, tenant_id: &str, msg_id: &MsgId)
-> Result<u64, TrackingStoreError>;
/// Returns click events stored for `msg_id` under `tenant_id`.
///
/// The SQLite implementation orders them by ascending `clicked_at_ms` and returns at
/// most 1000 rows.
async fn list_clicks(
&self,
tenant_id: &str,
msg_id: &MsgId,
) -> Result<Vec<ClickEvent>, TrackingStoreError>;
/// Returns `(link, click count)` pairs for `msg_id` under `tenant_id`.
///
/// The SQLite implementation orders the pairs by descending count.
async fn count_clicks_by_link(
&self,
tenant_id: &str,
msg_id: &MsgId,
) -> Result<Vec<(LinkId, u64)>, TrackingStoreError>;
/// Removes all data stored for `tenant_id` and returns the number of rows deleted.
async fn delete_by_tenant(&self, tenant_id: &str) -> Result<u64, TrackingStoreError>;
}
/// Schema script executed by `open_pool` each time it is called.
///
/// Every statement uses `IF NOT EXISTS`, so running the script against a database that
/// already has these objects leaves them in place. It defines:
/// - `tracking_link`: one row per `(tenant_id, msg_id, link_id)` (the primary key),
///   holding the original URL and a creation timestamp;
/// - `tracking_open` and `tracking_click`: event tables with no primary key, indexed on
///   `(tenant_id, msg_id)`; clicks are additionally indexed on
///   `(tenant_id, msg_id, link_id)`.
const MIGRATION_SQL: &str = r#"
CREATE TABLE IF NOT EXISTS tracking_link (
tenant_id TEXT NOT NULL,
msg_id TEXT NOT NULL,
link_id TEXT NOT NULL,
original_url TEXT NOT NULL,
created_at_ms INTEGER NOT NULL,
PRIMARY KEY (tenant_id, msg_id, link_id)
);
CREATE TABLE IF NOT EXISTS tracking_open (
tenant_id TEXT NOT NULL,
msg_id TEXT NOT NULL,
opened_at_ms INTEGER NOT NULL,
ip_hash TEXT,
ua_hash TEXT
);
CREATE INDEX IF NOT EXISTS idx_tracking_open_tenant_msg
ON tracking_open(tenant_id, msg_id);
CREATE TABLE IF NOT EXISTS tracking_click (
tenant_id TEXT NOT NULL,
msg_id TEXT NOT NULL,
link_id TEXT NOT NULL,
clicked_at_ms INTEGER NOT NULL,
ip_hash TEXT,
ua_hash TEXT
);
CREATE INDEX IF NOT EXISTS idx_tracking_click_tenant_msg
ON tracking_click(tenant_id, msg_id);
CREATE INDEX IF NOT EXISTS idx_tracking_click_tenant_msg_link
ON tracking_click(tenant_id, msg_id, link_id);
"#;
/// Opens (creating it if missing) the SQLite database at `path` and applies
/// [`MIGRATION_SQL`].
///
/// The literal path `:memory:` selects an in-memory database. Any other value is used
/// verbatim as a filesystem path; its parent directory is created on a best-effort basis.
/// The returned pool holds at most two connections.
///
/// # Errors
///
/// Returns [`TrackingStoreError::Sqlx`] when connecting or running the schema script
/// fails, and [`TrackingStoreError::Migration`] if the in-memory connection string is
/// rejected.
pub async fn open_pool(path: impl AsRef<Path>) -> Result<SqlitePool, TrackingStoreError> {
    let path = path.as_ref();

    let base_opts = if path.as_os_str() == ":memory:" {
        SqliteConnectOptions::from_str("sqlite::memory:")
            .map_err(|e| TrackingStoreError::Migration(e.to_string()))?
    } else {
        if let Some(parent) = path.parent() {
            // Best effort: if the directory cannot be created, the connect call below
            // reports the real failure.
            let _ = std::fs::create_dir_all(parent);
        }
        // Hand the path to sqlx as-is. Splicing it into a `sqlite://` URL would
        // percent-decode it and treat everything after a `?` as query parameters, and
        // the lossy string conversion would mangle non-UTF-8 paths.
        SqliteConnectOptions::new().filename(path)
    };
    let opts = base_opts.create_if_missing(true);

    let pool = SqlitePoolOptions::new()
        .max_connections(2)
        .connect_with(opts)
        .await?;

    // WAL is an optimisation only; a failure here is deliberately ignored.
    sqlx::query("PRAGMA journal_mode=WAL")
        .execute(&pool)
        .await
        .ok();

    sqlx::query(MIGRATION_SQL).execute(&pool).await?;
    Ok(pool)
}
/// [`TrackingStore`] implementation backed by a SQLite connection pool.
pub struct SqliteTrackingStore {
/// Pool that every query issued by this store is executed against.
pool: SqlitePool,
}
impl SqliteTrackingStore {
pub fn new(pool: SqlitePool) -> Self {
Self { pool }
}
pub fn pool(&self) -> &SqlitePool {
&self.pool
}
}
#[async_trait]
impl TrackingStore for SqliteTrackingStore {
    /// Stores `original_url` under the `(tenant_id, msg_id, link_id)` key.
    ///
    /// Uses `INSERT OR REPLACE`, so registering the same key again overwrites the URL
    /// and creation timestamp instead of failing on the primary key.
    async fn register_link(
        &self,
        tenant_id: &str,
        msg_id: &MsgId,
        link_id: &LinkId,
        original_url: &str,
        created_at_ms: i64,
    ) -> Result<(), TrackingStoreError> {
        sqlx::query(
            "INSERT OR REPLACE INTO tracking_link
(tenant_id, msg_id, link_id, original_url, created_at_ms)
VALUES (?, ?, ?, ?, ?)",
        )
        .bind(tenant_id)
        .bind(msg_id.as_str())
        .bind(link_id.as_str())
        .bind(original_url)
        .bind(created_at_ms)
        .execute(&self.pool)
        .await?;
        Ok(())
    }

    /// Returns the URL registered for the key, or `None` when no row matches all three
    /// of `tenant_id`, `msg_id` and `link_id`.
    async fn lookup_link(
        &self,
        tenant_id: &str,
        msg_id: &MsgId,
        link_id: &LinkId,
    ) -> Result<Option<String>, TrackingStoreError> {
        let row: Option<(String,)> = sqlx::query_as(
            "SELECT original_url FROM tracking_link
WHERE tenant_id = ? AND msg_id = ? AND link_id = ?",
        )
        .bind(tenant_id)
        .bind(msg_id.as_str())
        .bind(link_id.as_str())
        .fetch_optional(&self.pool)
        .await?;
        Ok(row.map(|(url,)| url))
    }

    /// Appends one row to `tracking_open`. The table has no uniqueness constraint, so
    /// recording the same event twice yields two rows.
    async fn record_open(&self, event: &OpenEvent) -> Result<(), TrackingStoreError> {
        sqlx::query(
            "INSERT INTO tracking_open
(tenant_id, msg_id, opened_at_ms, ip_hash, ua_hash)
VALUES (?, ?, ?, ?, ?)",
        )
        .bind(&event.tenant_id)
        .bind(event.msg_id.as_str())
        .bind(event.opened_at_ms)
        .bind(event.ip_hash.as_deref())
        .bind(event.ua_hash.as_deref())
        .execute(&self.pool)
        .await?;
        Ok(())
    }

    /// Appends one row to `tracking_click`. The table has no uniqueness constraint, so
    /// recording the same event twice yields two rows.
    async fn record_click(&self, event: &ClickEvent) -> Result<(), TrackingStoreError> {
        sqlx::query(
            "INSERT INTO tracking_click
(tenant_id, msg_id, link_id, clicked_at_ms, ip_hash, ua_hash)
VALUES (?, ?, ?, ?, ?, ?)",
        )
        .bind(&event.tenant_id)
        .bind(event.msg_id.as_str())
        .bind(event.link_id.as_str())
        .bind(event.clicked_at_ms)
        .bind(event.ip_hash.as_deref())
        .bind(event.ua_hash.as_deref())
        .execute(&self.pool)
        .await?;
        Ok(())
    }

    /// Counts the open rows stored for `msg_id` under `tenant_id`.
    async fn count_opens(
        &self,
        tenant_id: &str,
        msg_id: &MsgId,
    ) -> Result<u64, TrackingStoreError> {
        let (n,): (i64,) = sqlx::query_as(
            "SELECT COUNT(*) FROM tracking_open
WHERE tenant_id = ? AND msg_id = ?",
        )
        .bind(tenant_id)
        .bind(msg_id.as_str())
        .fetch_one(&self.pool)
        .await?;
        // The count is decoded as `i64`; clamp at zero so the cast cannot wrap.
        Ok(n.max(0) as u64)
    }

    /// Lists click events for `msg_id` under `tenant_id`, oldest first.
    ///
    /// At most 1000 rows are returned (`LIMIT 1000` in the query); further rows are
    /// silently omitted. The `tenant_id` of each returned event is copied from the
    /// argument rather than read back from the table.
    async fn list_clicks(
        &self,
        tenant_id: &str,
        msg_id: &MsgId,
    ) -> Result<Vec<ClickEvent>, TrackingStoreError> {
        let rows: Vec<(String, String, i64, Option<String>, Option<String>)> = sqlx::query_as(
            "SELECT msg_id, link_id, clicked_at_ms, ip_hash, ua_hash
FROM tracking_click
WHERE tenant_id = ? AND msg_id = ?
ORDER BY clicked_at_ms ASC
LIMIT 1000",
        )
        .bind(tenant_id)
        .bind(msg_id.as_str())
        .fetch_all(&self.pool)
        .await?;
        Ok(rows
            .into_iter()
            .map(|(m, l, t, ip, ua)| ClickEvent {
                tenant_id: tenant_id.to_string(),
                msg_id: MsgId(m),
                link_id: LinkId(l),
                clicked_at_ms: t,
                ip_hash: ip,
                ua_hash: ua,
            })
            .collect())
    }

    /// Returns one `(link, click count)` pair per link of `msg_id` that has at least one
    /// click, ordered by descending count.
    async fn count_clicks_by_link(
        &self,
        tenant_id: &str,
        msg_id: &MsgId,
    ) -> Result<Vec<(LinkId, u64)>, TrackingStoreError> {
        let rows: Vec<(String, i64)> = sqlx::query_as(
            "SELECT link_id, COUNT(*) FROM tracking_click
WHERE tenant_id = ? AND msg_id = ?
GROUP BY link_id
ORDER BY 2 DESC",
        )
        .bind(tenant_id)
        .bind(msg_id.as_str())
        .fetch_all(&self.pool)
        .await?;
        Ok(rows
            .into_iter()
            .map(|(l, n)| (LinkId(l), n.max(0) as u64))
            .collect())
    }

    /// Deletes every link, open and click row belonging to `tenant_id` and returns the
    /// total number of rows removed.
    ///
    /// The three deletes run inside a single transaction, so a failure part-way through
    /// rolls everything back instead of leaving the tenant partially erased.
    async fn delete_by_tenant(&self, tenant_id: &str) -> Result<u64, TrackingStoreError> {
        const DELETE_STATEMENTS: [&str; 3] = [
            "DELETE FROM tracking_link WHERE tenant_id = ?",
            "DELETE FROM tracking_open WHERE tenant_id = ?",
            "DELETE FROM tracking_click WHERE tenant_id = ?",
        ];

        let mut tx = self.pool.begin().await?;
        let mut total: u64 = 0;
        for sql in DELETE_STATEMENTS {
            // An early return via `?` drops `tx`, which rolls the transaction back.
            let outcome = sqlx::query(sql)
                .bind(tenant_id)
                .execute(&mut *tx)
                .await?;
            total += outcome.rows_affected();
        }
        tx.commit().await?;
        Ok(total)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store on a brand-new in-memory database with the schema applied.
    async fn fresh() -> SqliteTrackingStore {
        let pool = open_pool(":memory:").await.unwrap();
        SqliteTrackingStore::new(pool)
    }

    #[tokio::test]
    async fn register_and_lookup_link() {
        let s = fresh().await;
        let m = MsgId::new("m1");
        let l = LinkId::new("L0");
        s.register_link("acme", &m, &l, "https://acme.com/x", 1)
            .await
            .unwrap();
        let got = s.lookup_link("acme", &m, &l).await.unwrap();
        assert_eq!(got.as_deref(), Some("https://acme.com/x"));
    }

    #[tokio::test]
    async fn lookup_misses_cross_tenant() {
        let s = fresh().await;
        let m = MsgId::new("m1");
        let l = LinkId::new("L0");
        s.register_link("acme", &m, &l, "https://acme.com/x", 1)
            .await
            .unwrap();
        let got = s.lookup_link("globex", &m, &l).await.unwrap();
        assert!(got.is_none());
    }

    #[tokio::test]
    async fn register_link_is_idempotent() {
        let s = fresh().await;
        let m = MsgId::new("m1");
        let l = LinkId::new("L0");
        s.register_link("acme", &m, &l, "https://a.com/", 1)
            .await
            .unwrap();
        s.register_link("acme", &m, &l, "https://b.com/", 2)
            .await
            .unwrap();
        // The second registration replaces the first rather than erroring.
        let got = s.lookup_link("acme", &m, &l).await.unwrap();
        assert_eq!(got.as_deref(), Some("https://b.com/"));
    }

    #[tokio::test]
    async fn record_open_and_count() {
        let s = fresh().await;
        let m = MsgId::new("m1");
        let ev = OpenEvent {
            tenant_id: "acme".into(),
            msg_id: m.clone(),
            opened_at_ms: 1,
            ip_hash: Some("ipa".into()),
            ua_hash: Some("uaa".into()),
        };
        s.record_open(&ev).await.unwrap();
        s.record_open(&ev).await.unwrap();
        s.record_open(&ev).await.unwrap();
        assert_eq!(s.count_opens("acme", &m).await.unwrap(), 3);
        assert_eq!(s.count_opens("globex", &m).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn record_click_and_list() {
        let s = fresh().await;
        let m = MsgId::new("m1");
        let l0 = LinkId::new("L0");
        let l1 = LinkId::new("L1");
        for (l, t) in [(&l0, 10), (&l1, 20), (&l0, 30)] {
            let ev = ClickEvent {
                tenant_id: "acme".into(),
                msg_id: m.clone(),
                link_id: l.clone(),
                clicked_at_ms: t,
                ip_hash: None,
                ua_hash: None,
            };
            s.record_click(&ev).await.unwrap();
        }
        let all = s.list_clicks("acme", &m).await.unwrap();
        assert_eq!(all.len(), 3);
        assert_eq!(all[0].clicked_at_ms, 10);
        assert_eq!(all[2].clicked_at_ms, 30);
        let by_link = s.count_clicks_by_link("acme", &m).await.unwrap();
        assert_eq!(by_link[0], (LinkId::new("L0"), 2));
        assert_eq!(by_link[1], (LinkId::new("L1"), 1));
    }

    #[tokio::test]
    async fn delete_by_tenant_cascades() {
        let s = fresh().await;
        let m = MsgId::new("m1");
        let l = LinkId::new("L0");
        s.register_link("acme", &m, &l, "https://a.com/", 1)
            .await
            .unwrap();
        s.record_open(&OpenEvent {
            tenant_id: "acme".into(),
            msg_id: m.clone(),
            opened_at_ms: 1,
            ip_hash: None,
            ua_hash: None,
        })
        .await
        .unwrap();
        s.record_click(&ClickEvent {
            tenant_id: "acme".into(),
            msg_id: m.clone(),
            link_id: l.clone(),
            clicked_at_ms: 1,
            ip_hash: None,
            ua_hash: None,
        })
        .await
        .unwrap();
        s.register_link("globex", &m, &l, "https://g.com/", 1)
            .await
            .unwrap();

        let n = s.delete_by_tenant("acme").await.unwrap();
        assert_eq!(n, 3, "1 link + 1 open + 1 click");

        // The deleted tenant's data must be gone from all three tables.
        assert!(s.lookup_link("acme", &m, &l).await.unwrap().is_none());
        assert_eq!(s.count_opens("acme", &m).await.unwrap(), 0);
        assert!(s.list_clicks("acme", &m).await.unwrap().is_empty());

        // Other tenants are untouched.
        let g = s.lookup_link("globex", &m, &l).await.unwrap();
        assert_eq!(g.as_deref(), Some("https://g.com/"));
    }

    #[tokio::test]
    async fn migrations_idempotent() {
        // Each `open_pool(":memory:")` call yields a separate database, so the schema
        // script is re-applied to the SAME pool to exercise idempotency.
        let s = fresh().await;
        let m = MsgId::new("m1");
        let l = LinkId::new("L0");
        s.register_link("acme", &m, &l, "https://acme.com/x", 1)
            .await
            .unwrap();

        sqlx::query(MIGRATION_SQL).execute(s.pool()).await.unwrap();

        // Re-running the migration must neither fail nor drop existing rows.
        let got = s.lookup_link("acme", &m, &l).await.unwrap();
        assert_eq!(got.as_deref(), Some("https://acme.com/x"));
    }
}