use r2d2::Pool;
use r2d2_sqlite::SqliteConnectionManager;
use rusqlite::{params, Result};
/// One row of the `documents` table.
///
/// Timestamps are plain strings written by the caller. The SQL in this module
/// compares them lexicographically against `%Y-%m-%dT%H:%M:%SZ` values, so callers
/// presumably write that format — confirm at call sites.
#[derive(Debug, Clone)]
pub struct DocumentRecord {
    /// Primary key.
    pub id: String,
    /// Lookup key; `UNIQUE NOT NULL` in the schema.
    pub slug: String,
    /// Document title.
    pub title: String,
    /// Document body exactly as stored.
    pub raw_content: String,
    /// Theme name; the column defaults to `'clean'`.
    pub theme: String,
    /// Optional password value, stored verbatim. NOTE(review): confirm callers hash it first.
    pub password: Option<String>,
    /// Optional short description.
    pub description: Option<String>,
    /// Creation timestamp.
    pub created_at: String,
    /// Expiry timestamp; `None` means the document never expires.
    pub expires_at: Option<String>,
    /// Last-modification timestamp.
    pub updated_at: String,
}
/// One row of the `tokens` table (API access tokens).
#[derive(Debug, Clone)]
pub struct TokenRecord {
    /// Primary key.
    pub id: String,
    /// Human-readable token name; `UNIQUE NOT NULL` in the schema.
    pub name: String,
    /// Stored hash of the token secret.
    pub hash: String,
    /// Creation timestamp.
    pub created_at: String,
    /// Timestamp of the most recent use, if any.
    pub last_used: Option<String>,
    /// Whether the token has been revoked (stored as `0`/`1`).
    pub revoked: bool,
    /// Lookup prefix. Rows where this is `None` are the ones returned as
    /// "legacy" tokens by `Db::get_legacy_active_tokens`.
    pub prefix: Option<String>,
}
/// One row of the `audit_log` table. Derives `serde::Serialize`.
///
/// Every column is `NOT NULL` in the schema, so no field is optional.
#[derive(Debug, Clone, serde::Serialize)]
pub struct AuditEntry {
    /// Primary key.
    pub id: String,
    /// When the action happened; `Db::list_audit_entries` sorts newest-first on it.
    pub timestamp: String,
    /// Free-form action label. NOTE(review): confirm the set of values with callers.
    pub action: String,
    /// Presumably the slug of the document that was acted on — confirm with callers.
    pub slug: String,
    /// Presumably the name of the API token that performed the action — confirm.
    pub token_name: String,
    /// IP address recorded for the request.
    pub ip_address: String,
}
/// Handle to the application's SQLite database.
///
/// Wraps an r2d2 pool of `rusqlite` connections. Each method checks out its own
/// connection for the duration of the call. Cloning presumably yields another
/// handle to the same pool (r2d2 `Pool` is a shared handle) — confirm if it matters.
#[derive(Clone)]
pub struct Db {
    /// Connection pool; built in `Db::open`.
    pool: Pool<SqliteConnectionManager>,
}
fn pool_err(e: r2d2::Error) -> rusqlite::Error {
rusqlite::Error::SqliteFailure(
rusqlite::ffi::Error::new(rusqlite::ffi::SQLITE_CANTOPEN),
Some(format!("connection pool error: {e}")),
)
}
impl Db {
pub fn open(path: &str) -> Result<Self> {
let manager = SqliteConnectionManager::file(path).with_init(|conn| {
conn.execute_batch("PRAGMA journal_mode=WAL;")?;
conn.busy_timeout(std::time::Duration::from_secs(5))?;
Ok(())
});
let pool = Pool::builder()
.max_size(8)
.build(manager)
.map_err(pool_err)?;
let db = Db { pool };
db.initialize_schema()?;
db.migrate()?;
Ok(db)
}
fn initialize_schema(&self) -> Result<()> {
let conn = self.pool.get().map_err(pool_err)?;
conn.execute_batch(
"CREATE TABLE IF NOT EXISTS documents (
id TEXT PRIMARY KEY,
slug TEXT UNIQUE NOT NULL,
title TEXT NOT NULL,
raw_content TEXT NOT NULL,
theme TEXT NOT NULL DEFAULT 'clean',
password TEXT,
description TEXT,
created_at TEXT NOT NULL,
expires_at TEXT,
updated_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_documents_slug ON documents(slug);
CREATE TABLE IF NOT EXISTS tokens (
id TEXT PRIMARY KEY,
name TEXT UNIQUE NOT NULL,
hash TEXT NOT NULL,
created_at TEXT NOT NULL,
last_used TEXT,
revoked INTEGER NOT NULL DEFAULT 0,
prefix TEXT
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_tokens_prefix ON tokens(prefix);
CREATE TABLE IF NOT EXISTS audit_log (
id TEXT PRIMARY KEY,
timestamp TEXT NOT NULL,
action TEXT NOT NULL,
slug TEXT NOT NULL,
token_name TEXT NOT NULL,
ip_address TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_audit_log_timestamp ON audit_log(timestamp);
CREATE INDEX IF NOT EXISTS idx_audit_log_slug ON audit_log(slug);
CREATE TABLE IF NOT EXISTS oauth_clients (
client_id TEXT PRIMARY KEY,
client_name TEXT NOT NULL,
redirect_uris TEXT NOT NULL,
grant_types TEXT NOT NULL,
response_types TEXT NOT NULL,
token_endpoint_auth_method TEXT NOT NULL,
created_at TEXT NOT NULL,
provisioned INTEGER NOT NULL DEFAULT 0,
client_secret TEXT
);
CREATE TABLE IF NOT EXISTS oauth_auth_codes (
code TEXT PRIMARY KEY,
client_id TEXT NOT NULL,
redirect_uri TEXT NOT NULL,
expires_at TEXT NOT NULL,
code_challenge TEXT NOT NULL,
resource TEXT,
scope TEXT
);
CREATE INDEX IF NOT EXISTS idx_oauth_auth_codes_expires_at ON oauth_auth_codes(expires_at);
CREATE TABLE IF NOT EXISTS oauth_access_tokens (
token TEXT PRIMARY KEY,
client_id TEXT NOT NULL,
scope TEXT,
expires_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_oauth_access_tokens_expires_at ON oauth_access_tokens(expires_at);
CREATE TABLE IF NOT EXISTS oauth_refresh_tokens (
token TEXT PRIMARY KEY,
client_id TEXT NOT NULL,
access_token TEXT NOT NULL,
scope TEXT,
expires_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_oauth_refresh_tokens_expires_at ON oauth_refresh_tokens(expires_at);",
)?;
Ok(())
}
fn migrate(&self) -> Result<()> {
let conn = self.pool.get().map_err(pool_err)?;
let mut stmt = conn.prepare("PRAGMA table_info(documents)")?;
let columns: Vec<String> = stmt
.query_map([], |row| row.get::<_, String>(1))?
.filter_map(|r| r.ok())
.collect();
drop(stmt);
if !columns.contains(&"theme".to_string()) {
conn.execute_batch(
"ALTER TABLE documents ADD COLUMN theme TEXT NOT NULL DEFAULT 'clean';",
)?;
}
if !columns.contains(&"password".to_string()) {
conn.execute_batch("ALTER TABLE documents ADD COLUMN password TEXT;")?;
}
if !columns.contains(&"description".to_string()) {
conn.execute_batch("ALTER TABLE documents ADD COLUMN description TEXT;")?;
}
if !columns.contains(&"expires_at".to_string()) {
conn.execute_batch("ALTER TABLE documents ADD COLUMN expires_at TEXT;")?;
}
conn.execute_batch(
"CREATE INDEX IF NOT EXISTS idx_documents_expires_at ON documents(expires_at);",
)?;
let mut token_stmt = conn.prepare("PRAGMA table_info(tokens)")?;
let token_columns: Vec<String> = token_stmt
.query_map([], |row| row.get::<_, String>(1))?
.filter_map(|r| r.ok())
.collect();
drop(token_stmt);
if !token_columns.contains(&"prefix".to_string()) {
conn.execute_batch("ALTER TABLE tokens ADD COLUMN prefix TEXT;")?;
}
conn.execute_batch(
"CREATE UNIQUE INDEX IF NOT EXISTS idx_tokens_prefix ON tokens(prefix);",
)?;
conn.execute_batch(
"CREATE TABLE IF NOT EXISTS audit_log (
id TEXT PRIMARY KEY,
timestamp TEXT NOT NULL,
action TEXT NOT NULL,
slug TEXT NOT NULL,
token_name TEXT NOT NULL,
ip_address TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_audit_log_timestamp ON audit_log(timestamp);
CREATE INDEX IF NOT EXISTS idx_audit_log_slug ON audit_log(slug);",
)?;
conn.execute_batch(
"CREATE TABLE IF NOT EXISTS oauth_clients (
client_id TEXT PRIMARY KEY,
client_name TEXT NOT NULL,
redirect_uris TEXT NOT NULL,
grant_types TEXT NOT NULL,
response_types TEXT NOT NULL,
token_endpoint_auth_method TEXT NOT NULL,
created_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS oauth_auth_codes (
code TEXT PRIMARY KEY,
client_id TEXT NOT NULL,
redirect_uri TEXT NOT NULL,
expires_at TEXT NOT NULL,
code_challenge TEXT NOT NULL,
resource TEXT,
scope TEXT
);
CREATE INDEX IF NOT EXISTS idx_oauth_auth_codes_expires_at ON oauth_auth_codes(expires_at);
CREATE TABLE IF NOT EXISTS oauth_access_tokens (
token TEXT PRIMARY KEY,
client_id TEXT NOT NULL,
scope TEXT,
expires_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_oauth_access_tokens_expires_at ON oauth_access_tokens(expires_at);
CREATE TABLE IF NOT EXISTS oauth_refresh_tokens (
token TEXT PRIMARY KEY,
client_id TEXT NOT NULL,
access_token TEXT NOT NULL,
scope TEXT,
expires_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_oauth_refresh_tokens_expires_at ON oauth_refresh_tokens(expires_at);",
)?;
let mut client_stmt = conn.prepare("PRAGMA table_info(oauth_clients)")?;
let client_columns: Vec<String> = client_stmt
.query_map([], |row| row.get::<_, String>(1))?
.filter_map(|r| r.ok())
.collect();
drop(client_stmt);
if !client_columns.contains(&"provisioned".to_string()) {
conn.execute_batch(
"ALTER TABLE oauth_clients ADD COLUMN provisioned INTEGER NOT NULL DEFAULT 0;",
)?;
}
if !client_columns.contains(&"client_secret".to_string()) {
conn.execute_batch("ALTER TABLE oauth_clients ADD COLUMN client_secret TEXT;")?;
}
Ok(())
}
pub fn ping(&self) -> Result<()> {
let conn = self.pool.get().map_err(pool_err)?;
conn.execute_batch("SELECT 1;")?;
Ok(())
}
pub fn insert_document(&self, doc: &DocumentRecord) -> Result<()> {
let conn = self.pool.get().map_err(pool_err)?;
conn.execute(
"INSERT INTO documents (id, slug, title, raw_content, theme, password, description, created_at, expires_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
params![
doc.id,
doc.slug,
doc.title,
doc.raw_content,
doc.theme,
doc.password,
doc.description,
doc.created_at,
doc.expires_at,
doc.updated_at,
],
)?;
Ok(())
}
pub fn update_document(&self, slug: &str, doc: &DocumentRecord) -> Result<bool> {
let conn = self.pool.get().map_err(pool_err)?;
let rows = conn.execute(
"UPDATE documents SET title = ?1, raw_content = ?2, theme = ?3, password = ?4,
description = ?5, expires_at = ?6, updated_at = ?7
WHERE slug = ?8",
params![
doc.title,
doc.raw_content,
doc.theme,
doc.password,
doc.description,
doc.expires_at,
doc.updated_at,
slug,
],
)?;
Ok(rows > 0)
}
pub fn delete_by_slug(&self, slug: &str) -> Result<bool> {
let conn = self.pool.get().map_err(pool_err)?;
let rows = conn.execute("DELETE FROM documents WHERE slug = ?1", params![slug])?;
Ok(rows > 0)
}
pub fn get_by_slug(&self, slug: &str) -> Result<Option<DocumentRecord>> {
let conn = self.pool.get().map_err(pool_err)?;
let mut stmt = conn.prepare(
"SELECT id, slug, title, raw_content, theme, password, description, created_at, expires_at, updated_at
FROM documents WHERE slug = ?1",
)?;
let mut rows = stmt.query(params![slug])?;
match rows.next()? {
None => Ok(None),
Some(row) => Ok(Some(DocumentRecord {
id: row.get(0)?,
slug: row.get(1)?,
title: row.get(2)?,
raw_content: row.get(3)?,
theme: row.get(4)?,
password: row.get(5)?,
description: row.get(6)?,
created_at: row.get(7)?,
expires_at: row.get(8)?,
updated_at: row.get(9)?,
})),
}
}
pub fn delete_expired_older_than(&self, now: &str, days: u32) -> Result<usize> {
let conn = self.pool.get().map_err(pool_err)?;
let rows = conn.execute(
"DELETE FROM documents \
WHERE expires_at IS NOT NULL \
AND expires_at < datetime(?1, printf('-%d days', ?2))",
params![now, days],
)?;
Ok(rows)
}
pub fn insert_token(&self, token: &TokenRecord) -> Result<()> {
let conn = self.pool.get().map_err(pool_err)?;
conn.execute(
"INSERT INTO tokens (id, name, hash, created_at, last_used, revoked, prefix)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
params![
token.id,
token.name,
token.hash,
token.created_at,
token.last_used,
token.revoked as i32,
token.prefix,
],
)?;
Ok(())
}
pub fn get_token_by_prefix(&self, prefix: &str) -> Result<Option<TokenRecord>> {
let conn = self.pool.get().map_err(pool_err)?;
let mut stmt = conn.prepare(
"SELECT id, name, hash, created_at, last_used, revoked, prefix
FROM tokens WHERE prefix = ?1 AND revoked = 0
LIMIT 1",
)?;
let mut rows = stmt.query(params![prefix])?;
match rows.next()? {
None => Ok(None),
Some(row) => Ok(Some(TokenRecord {
id: row.get(0)?,
name: row.get(1)?,
hash: row.get(2)?,
created_at: row.get(3)?,
last_used: row.get(4)?,
revoked: row.get::<_, i32>(5)? != 0,
prefix: row.get(6)?,
})),
}
}
pub fn get_legacy_active_tokens(&self) -> Result<Vec<TokenRecord>> {
let conn = self.pool.get().map_err(pool_err)?;
let mut stmt = conn.prepare(
"SELECT id, name, hash, created_at, last_used, revoked, prefix
FROM tokens WHERE revoked = 0 AND prefix IS NULL",
)?;
let tokens = stmt
.query_map([], |row| {
Ok(TokenRecord {
id: row.get(0)?,
name: row.get(1)?,
hash: row.get(2)?,
created_at: row.get(3)?,
last_used: row.get(4)?,
revoked: row.get::<_, i32>(5)? != 0,
prefix: row.get(6)?,
})
})?
.filter_map(|r| {
r.map_err(|e| tracing::warn!("Failed to deserialize token row: {}", e))
.ok()
})
.collect();
Ok(tokens)
}
pub fn list_tokens(&self) -> Result<Vec<TokenRecord>> {
let conn = self.pool.get().map_err(pool_err)?;
let mut stmt = conn.prepare(
"SELECT id, name, hash, created_at, last_used, revoked, prefix
FROM tokens ORDER BY created_at DESC",
)?;
let tokens = stmt
.query_map([], |row| {
Ok(TokenRecord {
id: row.get(0)?,
name: row.get(1)?,
hash: row.get(2)?,
created_at: row.get(3)?,
last_used: row.get(4)?,
revoked: row.get::<_, i32>(5)? != 0,
prefix: row.get(6)?,
})
})?
.filter_map(|r| {
r.map_err(|e| tracing::warn!("Failed to deserialize token row: {}", e))
.ok()
})
.collect();
Ok(tokens)
}
pub fn revoke_token(&self, name: &str) -> Result<bool> {
let conn = self.pool.get().map_err(pool_err)?;
let rows = conn.execute(
"UPDATE tokens SET revoked = 1 WHERE name = ?1 AND revoked = 0",
params![name],
)?;
Ok(rows > 0)
}
pub fn touch_token(&self, id: &str, now: &str) -> Result<()> {
let conn = self.pool.get().map_err(pool_err)?;
conn.execute(
"UPDATE tokens SET last_used = ?1 WHERE id = ?2",
params![now, id],
)?;
Ok(())
}
pub fn token_name_exists(&self, name: &str) -> Result<bool> {
let conn = self.pool.get().map_err(pool_err)?;
let mut stmt = conn.prepare("SELECT COUNT(*) FROM tokens WHERE name = ?1")?;
let count: i64 = stmt.query_row(params![name], |row| row.get(0))?;
Ok(count > 0)
}
pub fn list_documents(&self, limit: u32, offset: u32) -> Result<(Vec<DocumentSummary>, u64)> {
let capped_limit = limit.min(100);
let conn = self.pool.get().map_err(pool_err)?;
let total: u64 = {
let mut stmt = conn.prepare(
"SELECT COUNT(*) FROM documents \
WHERE expires_at IS NULL OR expires_at > strftime('%Y-%m-%dT%H:%M:%SZ', 'now')",
)?;
stmt.query_row([], |row| row.get::<_, i64>(0))
.map(|n| n as u64)?
};
let mut stmt = conn.prepare(
"SELECT slug, title, description, created_at, expires_at \
FROM documents \
WHERE expires_at IS NULL OR expires_at > strftime('%Y-%m-%dT%H:%M:%SZ', 'now') \
ORDER BY created_at DESC \
LIMIT ?1 OFFSET ?2",
)?;
let docs = stmt
.query_map(params![capped_limit, offset], |row| {
Ok(DocumentSummary {
slug: row.get(0)?,
title: row.get(1)?,
description: row.get(2)?,
created_at: row.get(3)?,
expires_at: row.get(4)?,
})
})?
.filter_map(|r| r.ok())
.collect();
Ok((docs, total))
}
pub fn insert_audit_entry(&self, entry: &AuditEntry) -> Result<()> {
let conn = self.pool.get().map_err(pool_err)?;
conn.execute(
"INSERT INTO audit_log (id, timestamp, action, slug, token_name, ip_address)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![
entry.id,
entry.timestamp,
entry.action,
entry.slug,
entry.token_name,
entry.ip_address,
],
)?;
Ok(())
}
pub fn list_audit_entries(&self, limit: u32, offset: u32) -> Result<(Vec<AuditEntry>, u64)> {
let capped_limit = limit.min(100);
let conn = self.pool.get().map_err(pool_err)?;
let total: u64 = {
let mut stmt = conn.prepare("SELECT COUNT(*) FROM audit_log")?;
stmt.query_row([], |row| row.get::<_, i64>(0))
.map(|n| n as u64)?
};
let mut stmt = conn.prepare(
"SELECT id, timestamp, action, slug, token_name, ip_address \
FROM audit_log \
ORDER BY timestamp DESC \
LIMIT ?1 OFFSET ?2",
)?;
let entries = stmt
.query_map(params![capped_limit, offset], |row| {
Ok(AuditEntry {
id: row.get(0)?,
timestamp: row.get(1)?,
action: row.get(2)?,
slug: row.get(3)?,
token_name: row.get(4)?,
ip_address: row.get(5)?,
})
})?
.filter_map(|r| match r {
Ok(e) => Some(e),
Err(e) => {
tracing::warn!("Audit row deserialization failed: {e}");
None
}
})
.collect();
Ok((entries, total))
}
}
/// One row of the `oauth_clients` table.
///
/// The list-valued fields are stored as single strings. The tests store JSON arrays
/// such as `["code"]`; confirm that encoding with the writers.
#[derive(Debug, Clone)]
pub struct OAuthClientRow {
    /// Primary key.
    pub client_id: String,
    /// Display name of the client.
    pub client_name: String,
    /// Allowed redirect URIs, encoded as one string.
    pub redirect_uris: String,
    /// Allowed grant types, encoded as one string.
    pub grant_types: String,
    /// Allowed response types, encoded as one string.
    pub response_types: String,
    /// Token endpoint authentication method name (e.g. `"none"`, `"client_secret_post"`).
    pub token_endpoint_auth_method: String,
    /// Creation timestamp; non-provisioned clients are reaped based on it.
    pub created_at: String,
    /// `true` for clients exempt from the age-based reaper (stored as `0`/`1`).
    pub provisioned: bool,
    /// Optional client secret, stored verbatim.
    pub client_secret: Option<String>,
}
/// One row of the `oauth_auth_codes` table: a pending authorization code.
#[derive(Debug, Clone)]
pub struct AuthCodeRow {
    /// The code value itself (primary key).
    pub code: String,
    /// Client the code was issued to.
    pub client_id: String,
    /// Redirect URI associated with the code.
    pub redirect_uri: String,
    /// Expiry timestamp; rows older than the reaper's `now` are deleted.
    pub expires_at: String,
    /// Stored code challenge. Presumably a PKCE challenge — confirm with the OAuth handler.
    pub code_challenge: String,
    /// Optional resource value associated with the request.
    pub resource: Option<String>,
    /// Optional requested scope.
    pub scope: Option<String>,
}
/// One row of the `oauth_access_tokens` table.
#[derive(Debug, Clone)]
pub struct AccessTokenRow {
    /// The token value (primary key).
    pub token: String,
    /// Client the token was issued to.
    pub client_id: String,
    /// Optional granted scope.
    pub scope: Option<String>,
    /// Expiry timestamp; rows older than the reaper's `now` are deleted.
    pub expires_at: String,
}
/// One row of the `oauth_refresh_tokens` table.
#[derive(Debug, Clone)]
pub struct RefreshTokenRow {
    /// The refresh token value (primary key).
    pub token: String,
    /// Client the token was issued to.
    pub client_id: String,
    /// Access token recorded alongside this refresh token.
    pub access_token: String,
    /// Optional granted scope.
    pub scope: Option<String>,
    /// Expiry timestamp; rows older than the reaper's `now` are deleted.
    pub expires_at: String,
}
impl Db {
pub fn insert_oauth_client(&self, row: &OAuthClientRow) -> Result<()> {
let conn = self.pool.get().map_err(pool_err)?;
conn.execute(
"INSERT INTO oauth_clients
(client_id, client_name, redirect_uris, grant_types, response_types,
token_endpoint_auth_method, created_at, provisioned, client_secret)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
params![
row.client_id,
row.client_name,
row.redirect_uris,
row.grant_types,
row.response_types,
row.token_endpoint_auth_method,
row.created_at,
row.provisioned as i32,
row.client_secret,
],
)?;
Ok(())
}
pub fn get_oauth_client(&self, client_id: &str) -> Result<Option<OAuthClientRow>> {
let conn = self.pool.get().map_err(pool_err)?;
let mut stmt = conn.prepare(
"SELECT client_id, client_name, redirect_uris, grant_types, response_types,
token_endpoint_auth_method, created_at, provisioned, client_secret
FROM oauth_clients WHERE client_id = ?1",
)?;
let mut rows = stmt.query(params![client_id])?;
match rows.next()? {
None => Ok(None),
Some(row) => Ok(Some(OAuthClientRow {
client_id: row.get(0)?,
client_name: row.get(1)?,
redirect_uris: row.get(2)?,
grant_types: row.get(3)?,
response_types: row.get(4)?,
token_endpoint_auth_method: row.get(5)?,
created_at: row.get(6)?,
provisioned: row.get::<_, i32>(7)? != 0,
client_secret: row.get(8)?,
})),
}
}
pub fn count_active_oauth_clients(&self, cutoff: &str) -> Result<i64> {
let conn = self.pool.get().map_err(pool_err)?;
let mut stmt = conn.prepare("SELECT COUNT(*) FROM oauth_clients WHERE created_at >= ?1")?;
let count: i64 = stmt.query_row(params![cutoff], |row| row.get(0))?;
Ok(count)
}
pub fn delete_expired_oauth_clients(&self, cutoff: &str) -> Result<usize> {
let conn = self.pool.get().map_err(pool_err)?;
let rows = conn.execute(
"DELETE FROM oauth_clients WHERE created_at < ?1 AND provisioned = 0",
params![cutoff],
)?;
Ok(rows)
}
pub fn list_provisioned_clients(&self) -> Result<Vec<OAuthClientRow>> {
let conn = self.pool.get().map_err(pool_err)?;
let mut stmt = conn.prepare(
"SELECT client_id, client_name, redirect_uris, grant_types, response_types,
token_endpoint_auth_method, created_at, provisioned, client_secret
FROM oauth_clients WHERE provisioned = 1
ORDER BY created_at DESC",
)?;
let clients = stmt
.query_map([], |row| {
Ok(OAuthClientRow {
client_id: row.get(0)?,
client_name: row.get(1)?,
redirect_uris: row.get(2)?,
grant_types: row.get(3)?,
response_types: row.get(4)?,
token_endpoint_auth_method: row.get(5)?,
created_at: row.get(6)?,
provisioned: row.get::<_, i32>(7)? != 0,
client_secret: row.get(8)?,
})
})?
.filter_map(|r| {
r.map_err(|e| tracing::warn!("Failed to deserialize oauth_client row: {}", e))
.ok()
})
.collect();
Ok(clients)
}
pub fn revoke_provisioned_client(&self, client_id: &str) -> Result<bool> {
let conn = self.pool.get().map_err(pool_err)?;
let tx = conn.unchecked_transaction()?;
tx.execute(
"DELETE FROM oauth_access_tokens WHERE client_id = ?1",
params![client_id],
)?;
tx.execute(
"DELETE FROM oauth_refresh_tokens WHERE client_id = ?1",
params![client_id],
)?;
tx.execute(
"DELETE FROM oauth_auth_codes WHERE client_id = ?1",
params![client_id],
)?;
let rows = tx.execute(
"DELETE FROM oauth_clients WHERE client_id = ?1 AND provisioned = 1",
params![client_id],
)?;
tx.commit()?;
Ok(rows > 0)
}
pub fn insert_auth_code(&self, row: &AuthCodeRow) -> Result<()> {
let conn = self.pool.get().map_err(pool_err)?;
conn.execute(
"INSERT INTO oauth_auth_codes
(code, client_id, redirect_uri, expires_at, code_challenge, resource, scope)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
params![
row.code,
row.client_id,
row.redirect_uri,
row.expires_at,
row.code_challenge,
row.resource,
row.scope,
],
)?;
Ok(())
}
pub fn take_auth_code(&self, code: &str) -> Result<Option<AuthCodeRow>> {
let conn = self.pool.get().map_err(pool_err)?;
let tx = conn.unchecked_transaction()?;
let row = {
let mut stmt = tx.prepare(
"SELECT code, client_id, redirect_uri, expires_at, code_challenge, resource, scope
FROM oauth_auth_codes WHERE code = ?1",
)?;
let mut rows = stmt.query(params![code])?;
match rows.next()? {
None => None,
Some(r) => Some(AuthCodeRow {
code: r.get(0)?,
client_id: r.get(1)?,
redirect_uri: r.get(2)?,
expires_at: r.get(3)?,
code_challenge: r.get(4)?,
resource: r.get(5)?,
scope: r.get(6)?,
}),
}
};
if row.is_some() {
tx.execute(
"DELETE FROM oauth_auth_codes WHERE code = ?1",
params![code],
)?;
}
tx.commit()?;
Ok(row)
}
pub fn delete_expired_auth_codes(&self, now: &str) -> Result<usize> {
let conn = self.pool.get().map_err(pool_err)?;
let rows = conn.execute(
"DELETE FROM oauth_auth_codes WHERE expires_at < ?1",
params![now],
)?;
Ok(rows)
}
pub fn insert_access_token(&self, row: &AccessTokenRow) -> Result<()> {
let conn = self.pool.get().map_err(pool_err)?;
conn.execute(
"INSERT INTO oauth_access_tokens (token, client_id, scope, expires_at)
VALUES (?1, ?2, ?3, ?4)",
params![row.token, row.client_id, row.scope, row.expires_at],
)?;
Ok(())
}
pub fn get_access_token(&self, token: &str) -> Result<Option<AccessTokenRow>> {
let conn = self.pool.get().map_err(pool_err)?;
let mut stmt = conn.prepare(
"SELECT token, client_id, scope, expires_at
FROM oauth_access_tokens WHERE token = ?1",
)?;
let mut rows = stmt.query(params![token])?;
match rows.next()? {
None => Ok(None),
Some(row) => Ok(Some(AccessTokenRow {
token: row.get(0)?,
client_id: row.get(1)?,
scope: row.get(2)?,
expires_at: row.get(3)?,
})),
}
}
pub fn delete_expired_access_tokens(&self, now: &str) -> Result<usize> {
let conn = self.pool.get().map_err(pool_err)?;
let rows = conn.execute(
"DELETE FROM oauth_access_tokens WHERE expires_at < ?1",
params![now],
)?;
Ok(rows)
}
pub fn insert_refresh_token(&self, row: &RefreshTokenRow) -> Result<()> {
let conn = self.pool.get().map_err(pool_err)?;
conn.execute(
"INSERT INTO oauth_refresh_tokens (token, client_id, access_token, scope, expires_at)
VALUES (?1, ?2, ?3, ?4, ?5)",
params![
row.token,
row.client_id,
row.access_token,
row.scope,
row.expires_at,
],
)?;
Ok(())
}
pub fn take_refresh_token(&self, token: &str) -> Result<Option<RefreshTokenRow>> {
let conn = self.pool.get().map_err(pool_err)?;
let tx = conn.unchecked_transaction()?;
let row = {
let mut stmt = tx.prepare(
"SELECT token, client_id, access_token, scope, expires_at
FROM oauth_refresh_tokens WHERE token = ?1",
)?;
let mut rows = stmt.query(params![token])?;
match rows.next()? {
None => None,
Some(r) => Some(RefreshTokenRow {
token: r.get(0)?,
client_id: r.get(1)?,
access_token: r.get(2)?,
scope: r.get(3)?,
expires_at: r.get(4)?,
}),
}
};
if row.is_some() {
tx.execute(
"DELETE FROM oauth_refresh_tokens WHERE token = ?1",
params![token],
)?;
}
tx.commit()?;
Ok(row)
}
pub fn delete_expired_refresh_tokens(&self, now: &str) -> Result<usize> {
let conn = self.pool.get().map_err(pool_err)?;
let rows = conn.execute(
"DELETE FROM oauth_refresh_tokens WHERE expires_at < ?1",
params![now],
)?;
Ok(rows)
}
}
/// Lightweight projection of a document, as returned by `Db::list_documents`.
///
/// Derives `serde::Serialize`. The field order below is the serialized key order.
#[derive(Debug, Clone, serde::Serialize)]
pub struct DocumentSummary {
    /// Lookup key of the document.
    pub slug: String,
    /// Document title.
    pub title: String,
    /// Optional short description.
    pub description: Option<String>,
    /// Creation timestamp; listings are sorted newest-first on it.
    pub created_at: String,
    /// Expiry timestamp; `None` means the document never expires.
    pub expires_at: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference "current time" shared by the reaper tests.
    const NOW: &str = "2025-06-01T00:00:00Z";
    /// A timestamp well before `NOW`.
    const PAST: &str = "2020-01-01T00:00:00Z";
    /// A timestamp well after `NOW`.
    const FAR_FUTURE: &str = "2099-01-01T00:00:00Z";
    /// Redirect-URI list used by every provisioned-client fixture.
    const CLAUDE_REDIRECTS: &str = r#"["https://claude.ai/api/mcp/auth_callback"]"#;

    /// Opens a fresh in-memory database with the full schema applied.
    fn fresh_db() -> Db {
        Db::open(":memory:").expect("in-memory db should open")
    }

    /// Builds a non-expiring document whose id equals its slug.
    fn make_doc(slug: &str, created_at: &str) -> DocumentRecord {
        let stamp = created_at.to_string();
        DocumentRecord {
            id: slug.into(),
            slug: slug.into(),
            title: format!("Title for {slug}"),
            raw_content: format!("# {slug}\n\nContent."),
            theme: "clean".into(),
            password: None,
            description: Some(format!("Desc for {slug}")),
            created_at: stamp.clone(),
            expires_at: None,
            updated_at: stamp,
        }
    }

    /// Builds a description-less document created on 2025-01-01 that expires at `expires_at`.
    fn expiring_doc(id: &str, slug: &str, title: &str, expires_at: &str) -> DocumentRecord {
        let created = "2025-01-01T00:00:00Z";
        DocumentRecord {
            id: id.into(),
            slug: slug.into(),
            title: title.into(),
            raw_content: format!("# {title}"),
            theme: "clean".into(),
            password: None,
            description: None,
            created_at: created.into(),
            expires_at: Some(expires_at.into()),
            updated_at: created.into(),
        }
    }

    /// Builds a dynamically registered (non-provisioned, secret-less) client.
    fn dynamic_client(client_id: &str, client_name: &str, created_at: &str) -> OAuthClientRow {
        OAuthClientRow {
            client_id: client_id.into(),
            client_name: client_name.into(),
            redirect_uris: "[]".into(),
            grant_types: "[]".into(),
            response_types: "[]".into(),
            token_endpoint_auth_method: "none".into(),
            created_at: created_at.into(),
            provisioned: false,
            client_secret: None,
        }
    }

    /// Builds a provisioned client that authenticates with `client_secret_post`.
    fn provisioned_client(
        client_id: &str,
        client_name: &str,
        created_at: &str,
        secret: &str,
    ) -> OAuthClientRow {
        OAuthClientRow {
            client_id: client_id.into(),
            client_name: client_name.into(),
            redirect_uris: CLAUDE_REDIRECTS.into(),
            grant_types: r#"["authorization_code"]"#.into(),
            response_types: r#"["code"]"#.into(),
            token_endpoint_auth_method: "client_secret_post".into(),
            created_at: created_at.into(),
            provisioned: true,
            client_secret: Some(secret.into()),
        }
    }

    /// Builds an authorization code for client `c1`.
    fn auth_code_row(code: &str, expires_at: &str) -> AuthCodeRow {
        AuthCodeRow {
            code: code.into(),
            client_id: "c1".into(),
            redirect_uri: "https://example.com/cb".into(),
            expires_at: expires_at.into(),
            code_challenge: "challenge".into(),
            resource: None,
            scope: None,
        }
    }

    /// Builds a scope-less access token.
    fn access_token_row(token: &str, client_id: &str, expires_at: &str) -> AccessTokenRow {
        AccessTokenRow {
            token: token.into(),
            client_id: client_id.into(),
            scope: None,
            expires_at: expires_at.into(),
        }
    }

    /// Builds a scope-less refresh token for client `c1`.
    fn refresh_token_row(token: &str, paired_access: &str, expires_at: &str) -> RefreshTokenRow {
        RefreshTokenRow {
            token: token.into(),
            client_id: "c1".into(),
            access_token: paired_access.into(),
            scope: None,
            expires_at: expires_at.into(),
        }
    }

    #[test]
    fn insert_and_get_document() {
        let db = fresh_db();
        let original = make_doc("test-slug", "2024-01-01T00:00:00Z");
        db.insert_document(&original).expect("insert should succeed");

        let loaded = db
            .get_by_slug("test-slug")
            .expect("query should succeed")
            .expect("document should exist");

        assert_eq!(loaded.id, original.id);
        assert_eq!(loaded.slug, original.slug);
        assert_eq!(loaded.title, original.title);
        assert_eq!(loaded.raw_content, original.raw_content);
        assert_eq!(loaded.theme, original.theme);
        assert_eq!(loaded.password, original.password);
        assert_eq!(loaded.description, original.description);
        assert_eq!(loaded.created_at, original.created_at);
        assert_eq!(loaded.expires_at, original.expires_at);
        assert_eq!(loaded.updated_at, original.updated_at);
    }

    #[test]
    fn list_documents_pagination() {
        let db = fresh_db();
        for day in 1..=5u32 {
            let doc = make_doc(&format!("slug-{day:02}"), &format!("2024-01-{day:02}T00:00:00Z"));
            db.insert_document(&doc).expect("insert");
        }

        let (first, total) = db.list_documents(2, 0).expect("list page 1");
        assert_eq!(total, 5, "total count should be 5");
        assert_eq!(first.len(), 2, "page 1 should have 2 docs");
        assert_eq!(first[0].slug, "slug-05", "newest first");
        assert_eq!(first[1].slug, "slug-04");

        let (second, _) = db.list_documents(2, 2).expect("list page 2");
        let second_slugs: Vec<&str> = second.iter().map(|d| d.slug.as_str()).collect();
        assert_eq!(second_slugs, ["slug-03", "slug-02"]);

        let (third, _) = db.list_documents(2, 4).expect("list page 3");
        let third_slugs: Vec<&str> = third.iter().map(|d| d.slug.as_str()).collect();
        assert_eq!(third_slugs, ["slug-01"]);
    }

    #[test]
    fn reaper_auth_codes() {
        let db = fresh_db();
        db.insert_auth_code(&auth_code_row("expired-code", PAST))
            .expect("insert expired auth code");
        db.insert_auth_code(&auth_code_row("future-code", FAR_FUTURE))
            .expect("insert future auth code");

        let deleted = db.delete_expired_auth_codes(NOW).expect("reaper should succeed");
        assert_eq!(deleted, 1, "reaper should delete exactly 1 expired code");

        let gone = db.take_auth_code("expired-code").expect("take should not error");
        assert!(gone.is_none(), "expired code must be gone after reaper");
        let kept = db.take_auth_code("future-code").expect("take should not error");
        assert!(kept.is_some(), "future code must survive reaper");
    }

    #[test]
    fn reaper_access_tokens() {
        let db = fresh_db();
        db.insert_access_token(&access_token_row("expired-at", "c1", PAST))
            .expect("insert expired access token");
        db.insert_access_token(&access_token_row("future-at", "c1", FAR_FUTURE))
            .expect("insert future access token");

        let deleted = db.delete_expired_access_tokens(NOW).expect("reaper should succeed");
        assert_eq!(deleted, 1, "reaper should delete exactly 1 expired access token");

        let gone = db.get_access_token("expired-at").expect("lookup ok");
        assert!(gone.is_none(), "expired access token must be gone after reaper");
        let kept = db.get_access_token("future-at").expect("lookup ok");
        assert!(kept.is_some(), "future access token must survive reaper");
    }

    #[test]
    fn reaper_refresh_tokens() {
        let db = fresh_db();
        db.insert_refresh_token(&refresh_token_row("expired-rt", "at1", PAST))
            .expect("insert expired refresh token");
        db.insert_refresh_token(&refresh_token_row("future-rt", "at2", FAR_FUTURE))
            .expect("insert future refresh token");

        let deleted = db.delete_expired_refresh_tokens(NOW).expect("reaper should succeed");
        assert_eq!(deleted, 1, "reaper should delete exactly 1 expired refresh token");

        let gone = db.take_refresh_token("expired-rt").expect("take ok");
        assert!(gone.is_none(), "expired refresh token must be gone after reaper");
        let kept = db.take_refresh_token("future-rt").expect("take ok");
        assert!(kept.is_some(), "future refresh token must survive reaper");
    }

    #[test]
    fn reaper_oauth_clients() {
        let db = fresh_db();
        db.insert_oauth_client(&dynamic_client("old-client", "old", PAST))
            .expect("insert old client");
        db.insert_oauth_client(&dynamic_client("new-client", "new", FAR_FUTURE))
            .expect("insert new client");

        let deleted = db.delete_expired_oauth_clients(NOW).expect("reaper should succeed");
        assert_eq!(deleted, 1, "reaper should delete exactly 1 expired client");

        let old = db.get_oauth_client("old-client").expect("lookup ok");
        assert!(old.is_none(), "old client must be gone after reaper");
        let new = db.get_oauth_client("new-client").expect("lookup ok");
        assert!(new.is_some(), "new client must survive reaper");
    }

    #[test]
    fn reaper_spares_provisioned_clients() {
        let db = fresh_db();
        db.insert_oauth_client(&provisioned_client("prov-client", "provisioned", PAST, "secret"))
            .expect("insert provisioned client");
        db.insert_oauth_client(&dynamic_client("dynamic-client", "dynamic", PAST))
            .expect("insert dynamic client");

        let deleted = db.delete_expired_oauth_clients(NOW).expect("reaper should succeed");
        assert_eq!(
            deleted, 1,
            "reaper should delete exactly 1 (the non-provisioned client)"
        );

        let prov = db.get_oauth_client("prov-client").expect("lookup ok");
        assert!(
            prov.is_some(),
            "provisioned client must survive reaper regardless of age"
        );
        let dyn_client = db.get_oauth_client("dynamic-client").expect("lookup ok");
        assert!(dyn_client.is_none(), "non-provisioned old client must be reaped");
    }

    #[test]
    fn list_provisioned_clients_only_provisioned() {
        let db = fresh_db();
        db.insert_oauth_client(&provisioned_client(
            "prov-1",
            "Provisioned One",
            "2025-01-01T00:00:00Z",
            "secret-1",
        ))
        .expect("insert prov-1");
        let dynamic = OAuthClientRow {
            redirect_uris: r#"["https://example.com/cb"]"#.into(),
            grant_types: r#"["authorization_code"]"#.into(),
            response_types: r#"["code"]"#.into(),
            ..dynamic_client("dyn-1", "Dynamic One", "2025-01-02T00:00:00Z")
        };
        db.insert_oauth_client(&dynamic).expect("insert dyn-1");

        let listed = db.list_provisioned_clients().expect("list should succeed");
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].client_id, "prov-1");
        assert!(listed[0].provisioned);
    }

    #[test]
    fn revoke_provisioned_client_cascades() {
        let db = fresh_db();
        db.insert_oauth_client(&provisioned_client(
            "prov-revoke",
            "ToRevoke",
            "2025-01-01T00:00:00Z",
            "supersecret",
        ))
        .expect("insert client");
        let scoped = AccessTokenRow {
            scope: Some("mcp:tools".into()),
            ..access_token_row("at-for-revoke", "prov-revoke", FAR_FUTURE)
        };
        db.insert_access_token(&scoped).expect("seed access token");

        let revoked = db
            .revoke_provisioned_client("prov-revoke")
            .expect("revoke should succeed");
        assert!(revoked, "revoke should return true");

        let client = db.get_oauth_client("prov-revoke").expect("lookup ok");
        assert!(client.is_none(), "client must be deleted after revoke");
        let at = db.get_access_token("at-for-revoke").expect("lookup ok");
        assert!(at.is_none(), "access token must be deleted after client revoke");
    }

    #[test]
    fn revoke_provisioned_client_not_found() {
        let db = fresh_db();
        let outcome = db
            .revoke_provisioned_client("does-not-exist")
            .expect("should not error");
        assert!(!outcome, "should return false for non-existent client");
    }

    #[test]
    fn reaper_documents() {
        let db = fresh_db();
        let stale = expiring_doc("doc-expired", "expired-doc", "Expired", "2025-05-22T00:00:00Z");
        let fresh = expiring_doc("doc-future", "future-doc", "Future", FAR_FUTURE);
        db.insert_document(&stale).expect("insert expired doc");
        db.insert_document(&fresh).expect("insert future doc");

        let deleted = db
            .delete_expired_older_than(NOW, 5)
            .expect("reaper should succeed");
        assert_eq!(deleted, 1, "reaper should delete exactly 1 expired document");

        let gone = db.get_by_slug("expired-doc").expect("lookup ok");
        assert!(gone.is_none(), "expired doc must be gone after reaper");
        let kept = db.get_by_slug("future-doc").expect("lookup ok");
        assert!(kept.is_some(), "future doc must survive reaper");
    }

    #[test]
    fn token_crud() {
        let db = fresh_db();
        let token = TokenRecord {
            id: "tok-id-1".into(),
            name: "my-token".into(),
            hash: "fakehash".into(),
            created_at: "2024-01-01T00:00:00Z".into(),
            last_used: None,
            revoked: false,
            prefix: Some("tok12345".into()),
        };
        db.insert_token(&token).expect("insert token");

        let before = db.list_tokens().expect("list tokens");
        assert_eq!(before.len(), 1);
        assert_eq!(before[0].name, "my-token");
        assert!(!before[0].revoked, "should not be revoked yet");

        assert!(
            db.revoke_token("my-token").expect("revoke"),
            "revoke should return true on first call"
        );
        assert!(
            !db.revoke_token("my-token").expect("revoke again"),
            "revoking an already-revoked token returns false"
        );

        let after = db.list_tokens().expect("list after revoke");
        assert_eq!(after.len(), 1);
        assert!(after[0].revoked, "should be revoked now");

        let found = db
            .get_token_by_prefix("tok12345")
            .expect("prefix lookup should not error");
        assert!(
            found.is_none(),
            "revoked token should not be returned by prefix lookup"
        );
    }
}