use std::path::{Path, PathBuf};
use super::DbPool;
use crate::error::StorageError;
/// Fixed id of the built-in account that [`ensure_default_account`] seeds.
///
/// Several helpers in this module special-case this id:
/// - [`delete_account`] refuses to archive it.
/// - [`get_role`] reports `"admin"` for every actor on it.
/// - [`account_data_dir`] maps it to the data-directory root rather than an `accounts/<id>`
///   subdirectory.
/// - [`read_active_account_id`] falls back to it when no sentinel file can be read.
pub const DEFAULT_ACCOUNT_ID: &str = "00000000-0000-0000-0000-000000000000";
/// One row of the `accounts` table.
///
/// Columns are mapped by name through `sqlx::FromRow`; the queries in this module select
/// with `SELECT *`. Timestamps are kept as text. Writes in this module set `updated_at`
/// with `strftime('%Y-%m-%dT%H:%M:%SZ', 'now')`. `created_at` is never written here
/// (presumably a schema default — NOTE(review): confirm in the migration).
#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize, serde::Deserialize)]
pub struct Account {
    /// Primary identifier; equals [`DEFAULT_ACCOUNT_ID`] for the built-in account.
    pub id: String,
    /// Display label (`"Default"` for the built-in account).
    pub label: String,
    /// X user id, if one has been stored via [`update_account`]; `None` after creation.
    pub x_user_id: Option<String>,
    /// X username, if stored; `None` after creation.
    pub x_username: Option<String>,
    /// X display name, if stored; `None` after creation.
    pub x_display_name: Option<String>,
    /// X avatar URL, if stored; `None` after creation.
    pub x_avatar_url: Option<String>,
    /// JSON text with per-account configuration overrides. Accounts made by
    /// [`create_account`] start from the `NEW_ACCOUNT_OVERRIDES` template.
    pub config_overrides: String,
    /// Optional stored token-file location for this account.
    pub token_path: Option<String>,
    /// Lifecycle status. This module writes `"active"` and `"archived"`, but
    /// [`update_account`] accepts any string.
    pub status: String,
    /// Creation timestamp (text).
    pub created_at: String,
    /// Last-modification timestamp (text).
    pub updated_at: String,
}
/// One row of `account_roles`: the role an actor holds on an account.
///
/// At most one role is kept per `(account_id, actor)` pair; [`set_role`] upserts on that
/// pair. Actors used in this module are `"dashboard"` and `"mcp"`. The role `"admin"` is
/// seeded, and [`set_role`] accepts arbitrary role strings.
#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize, serde::Deserialize)]
pub struct AccountRole {
    /// Id of the account the role applies to.
    pub account_id: String,
    /// Who holds the role (e.g. `"dashboard"`, `"mcp"`).
    pub actor: String,
    /// Role name (e.g. `"admin"`).
    pub role: String,
    /// Creation timestamp (text); not written by this module.
    pub created_at: String,
}
/// Seed the built-in default account and its admin grants.
///
/// Inserts the [`DEFAULT_ACCOUNT_ID`] row (label `Default`, status `active`). It then grants
/// `admin` on it to the `dashboard` and `mcp` actors. Every statement is
/// `INSERT OR IGNORE`, so calling this more than once is harmless. The statements run
/// one at a time, not inside a transaction.
///
/// # Errors
///
/// Returns [`StorageError::Query`] for the first statement that fails; the remaining
/// statements are not attempted.
pub async fn ensure_default_account(pool: &DbPool) -> Result<(), StorageError> {
    // Each statement takes the default account id as its only bind parameter.
    const SEED_STATEMENTS: [&str; 3] = [
        "INSERT OR IGNORE INTO accounts (id, label, status) \
         VALUES (?, 'Default', 'active')",
        "INSERT OR IGNORE INTO account_roles (account_id, actor, role) \
         VALUES (?, 'dashboard', 'admin')",
        "INSERT OR IGNORE INTO account_roles (account_id, actor, role) \
         VALUES (?, 'mcp', 'admin')",
    ];

    for statement in SEED_STATEMENTS {
        sqlx::query(statement)
            .bind(DEFAULT_ACCOUNT_ID)
            .execute(pool)
            .await
            .map_err(|source| StorageError::Query { source })?;
    }
    Ok(())
}
/// List every account whose status is `active`, ordered by `created_at` ascending.
///
/// Archived (or otherwise non-active) accounts are excluded. Use [`get_account`] to fetch
/// a single account regardless of status.
///
/// # Errors
///
/// Returns [`StorageError::Query`] if the query fails.
pub async fn list_accounts(pool: &DbPool) -> Result<Vec<Account>, StorageError> {
    const SQL: &str = "SELECT * FROM accounts WHERE status = 'active' ORDER BY created_at";
    let rows = sqlx::query_as::<_, Account>(SQL).fetch_all(pool).await;
    rows.map_err(|source| StorageError::Query { source })
}
/// Fetch one account by id, whatever its status (archived accounts are included).
///
/// Returns `Ok(None)` when no row has that id.
///
/// # Errors
///
/// Returns [`StorageError::Query`] if the query fails.
pub async fn get_account(pool: &DbPool, id: &str) -> Result<Option<Account>, StorageError> {
    let lookup = sqlx::query_as::<_, Account>("SELECT * FROM accounts WHERE id = ?").bind(id);
    match lookup.fetch_optional(pool).await {
        Ok(found) => Ok(found),
        Err(source) => Err(StorageError::Query { source }),
    }
}
/// Initial `config_overrides` JSON stored for every account created by
/// [`create_account`].
///
/// Every `business` field is present but blank (an empty string, an empty array, or
/// `null`), and `targets` is an empty array.
/// NOTE(review): whether these blanks override or defer to the base configuration
/// depends on the config merge logic, which is not in this file.
const NEW_ACCOUNT_OVERRIDES: &str = r#"{
"business": {
"product_name": "",
"product_keywords": [],
"product_description": "",
"product_url": null,
"target_audience": "",
"competitor_keywords": [],
"industry_topics": [],
"brand_voice": null,
"reply_style": null,
"content_style": null,
"persona_opinions": [],
"persona_experiences": [],
"content_pillars": []
},
"targets": []
}"#;
/// Create an account with the given `id` and `label`, then grant the `dashboard` actor
/// `admin` on it.
///
/// The row starts with the blank `NEW_ACCOUNT_OVERRIDES` document and a fresh
/// `updated_at`. `status` is left to the column default (the tests expect new accounts
/// to appear in [`list_accounts`], i.e. to be `active`).
///
/// The insert and the role grant are two separate statements, not a transaction. If the
/// grant fails, the account row remains without a dashboard role.
///
/// Returns `id` as an owned `String`.
///
/// # Errors
///
/// Returns [`StorageError::Query`] if the insert fails, for example on a duplicate id
/// (assuming `id` is a unique key). Any error from [`set_role`] is propagated.
pub async fn create_account(pool: &DbPool, id: &str, label: &str) -> Result<String, StorageError> {
    let insert = sqlx::query(
        "INSERT INTO accounts (id, label, config_overrides, updated_at) \
         VALUES (?, ?, ?, strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))",
    )
    .bind(id)
    .bind(label)
    .bind(NEW_ACCOUNT_OVERRIDES);

    if let Err(source) = insert.execute(pool).await {
        return Err(StorageError::Query { source });
    }

    // The dashboard actor always gets admin on newly created accounts.
    set_role(pool, id, "dashboard", "admin").await?;
    Ok(String::from(id))
}
/// Partial update for [`update_account`].
///
/// Every `Some` field is written to its column; every `None` field is left unchanged.
/// A column cannot be reset to `NULL` through this struct. Build it with struct-update
/// syntax, e.g. `UpdateAccountParams { label: Some("x"), ..Default::default() }`.
#[derive(Debug, Default)]
pub struct UpdateAccountParams<'a> {
    /// New value for `label`.
    pub label: Option<&'a str>,
    /// New value for `x_user_id`.
    pub x_user_id: Option<&'a str>,
    /// New value for `x_username`.
    pub x_username: Option<&'a str>,
    /// New value for `x_display_name`.
    pub x_display_name: Option<&'a str>,
    /// New value for `x_avatar_url`.
    pub x_avatar_url: Option<&'a str>,
    /// New `config_overrides` JSON text. It is stored verbatim and not validated here.
    pub config_overrides: Option<&'a str>,
    /// New value for `token_path`.
    pub token_path: Option<&'a str>,
    /// New `status`. It is stored verbatim; no check against known statuses.
    pub status: Option<&'a str>,
}
/// Apply a partial update to account `id` and always refresh `updated_at`.
///
/// Only the `Some` fields of `params` are written. The `SET` clause is assembled at
/// runtime, but only from the fixed column names listed below. Every value goes through
/// a numbered bind parameter (`?1`, `?2`, …), never through string interpolation.
///
/// Updating an id that does not exist is not an error; the statement simply matches no
/// rows.
///
/// # Errors
///
/// Returns [`StorageError::Query`] if the statement fails.
pub async fn update_account(
    pool: &DbPool,
    id: &str,
    params: UpdateAccountParams<'_>,
) -> Result<(), StorageError> {
    // Column order fixes both the SET order and the bind-parameter numbering.
    let columns: [(&str, Option<&str>); 8] = [
        ("label", params.label),
        ("x_user_id", params.x_user_id),
        ("x_username", params.x_username),
        ("x_display_name", params.x_display_name),
        ("x_avatar_url", params.x_avatar_url),
        ("config_overrides", params.config_overrides),
        ("token_path", params.token_path),
        ("status", params.status),
    ];
    let changed: Vec<(&str, &str)> = columns
        .iter()
        .filter_map(|&(column, value)| value.map(|v| (column, v)))
        .collect();

    let mut assignments = Vec::with_capacity(changed.len() + 1);
    assignments.push("updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now')".to_string());
    for (position, (column, _)) in changed.iter().enumerate() {
        assignments.push(format!("{} = ?{}", column, position + 1));
    }
    // The id is always bound last, after every changed column.
    let sql = format!(
        "UPDATE accounts SET {} WHERE id = ?{}",
        assignments.join(", "),
        changed.len() + 1
    );

    let mut query = sqlx::query(&sql);
    for &(_, value) in &changed {
        query = query.bind(value);
    }
    query
        .bind(id)
        .execute(pool)
        .await
        .map_err(|source| StorageError::Query { source })?;
    Ok(())
}
pub async fn delete_account(pool: &DbPool, id: &str) -> Result<(), StorageError> {
if id == DEFAULT_ACCOUNT_ID {
return Err(StorageError::Query {
source: sqlx::Error::Protocol("cannot delete the default account".into()),
});
}
sqlx::query(
"UPDATE accounts SET status = 'archived', \
updated_at = strftime('%Y-%m-%dT%H:%M:%SZ', 'now') WHERE id = ?",
)
.bind(id)
.execute(pool)
.await
.map_err(|e| StorageError::Query { source: e })?;
Ok(())
}
/// Whether an account with this id exists **and** has status `active`.
///
/// Archived accounts report `false`.
///
/// # Errors
///
/// Returns [`StorageError::Query`] if the query fails.
pub async fn account_exists(pool: &DbPool, id: &str) -> Result<bool, StorageError> {
    // `COUNT(*)` without `GROUP BY` always yields exactly one row, so `fetch_one` is the
    // right fetcher; there is no "no row" case to default away.
    let (count,): (i64,) =
        sqlx::query_as("SELECT COUNT(*) FROM accounts WHERE id = ? AND status = 'active'")
            .bind(id)
            .fetch_one(pool)
            .await
            .map_err(|e| StorageError::Query { source: e })?;
    Ok(count > 0)
}
/// Look up the role `actor` holds on `account_id`.
///
/// The default account is special-cased. Every actor, even one with no `account_roles`
/// row, is reported as `"admin"`, and the database is not queried. For any other
/// account, returns `Ok(None)` when no role row exists for the pair.
///
/// # Errors
///
/// Returns [`StorageError::Query`] if the lookup query fails.
pub async fn get_role(
    pool: &DbPool,
    account_id: &str,
    actor: &str,
) -> Result<Option<String>, StorageError> {
    if account_id != DEFAULT_ACCOUNT_ID {
        let stored: Option<(String,)> =
            sqlx::query_as("SELECT role FROM account_roles WHERE account_id = ? AND actor = ?")
                .bind(account_id)
                .bind(actor)
                .fetch_optional(pool)
                .await
                .map_err(|source| StorageError::Query { source })?;
        return Ok(stored.map(|(role,)| role));
    }
    Ok(Some(String::from("admin")))
}
/// Grant `role` to `actor` on `account_id`, replacing any role the actor already held.
///
/// This is an upsert on `(account_id, actor)`. SQLite requires the `ON CONFLICT` target
/// to match a unique index, so the schema must have one on that pair.
///
/// # Errors
///
/// Returns [`StorageError::Query`] if the statement fails.
pub async fn set_role(
    pool: &DbPool,
    account_id: &str,
    actor: &str,
    role: &str,
) -> Result<(), StorageError> {
    let upsert = sqlx::query(
        "INSERT INTO account_roles (account_id, actor, role) VALUES (?, ?, ?) \
         ON CONFLICT(account_id, actor) DO UPDATE SET role = excluded.role",
    );
    match upsert.bind(account_id).bind(actor).bind(role).execute(pool).await {
        Ok(_) => Ok(()),
        Err(source) => Err(StorageError::Query { source }),
    }
}
/// Revoke whatever role `actor` holds on `account_id`.
///
/// Removing a grant that does not exist is not an error. For the default account,
/// [`get_role`] still reports `"admin"` regardless of the stored rows.
///
/// # Errors
///
/// Returns [`StorageError::Query`] if the delete fails.
pub async fn remove_role(pool: &DbPool, account_id: &str, actor: &str) -> Result<(), StorageError> {
    let outcome = sqlx::query("DELETE FROM account_roles WHERE account_id = ? AND actor = ?")
        .bind(account_id)
        .bind(actor)
        .execute(pool)
        .await;
    outcome
        .map(|_done| ())
        .map_err(|source| StorageError::Query { source })
}
/// List every stored role grant for `account_id`, sorted by actor name.
///
/// Only rows present in `account_roles` are returned. The implicit admin grant that
/// [`get_role`] reports for the default account is not synthesized here.
///
/// # Errors
///
/// Returns [`StorageError::Query`] if the query fails.
pub async fn list_roles(pool: &DbPool, account_id: &str) -> Result<Vec<AccountRole>, StorageError> {
    const SQL: &str = "SELECT * FROM account_roles WHERE account_id = ? ORDER BY actor";
    let roles = sqlx::query_as::<_, AccountRole>(SQL)
        .bind(account_id)
        .fetch_all(pool)
        .await
        .map_err(|source| StorageError::Query { source })?;
    Ok(roles)
}
/// Directory that holds an account's files under `data_dir`.
///
/// [`DEFAULT_ACCOUNT_ID`] maps to `data_dir` itself. Every other account maps to
/// `data_dir/accounts/<account_id>`. This only computes a path; nothing is created.
///
/// NOTE(review): `account_id` is joined verbatim. `Path::join` with an absolute path
/// replaces the base, and `..` components are not rejected. Callers should only pass
/// trusted or validated ids, such as generated UUIDs.
pub fn account_data_dir(data_dir: &Path, account_id: &str) -> PathBuf {
    match account_id {
        DEFAULT_ACCOUNT_ID => data_dir.to_path_buf(),
        other => {
            let mut dir = data_dir.join("accounts");
            dir.push(other);
            dir
        }
    }
}
/// Path of the scraper session file for `account_id`: `scraper_session.json` inside
/// [`account_data_dir`].
pub fn account_scraper_session_path(data_dir: &Path, account_id: &str) -> PathBuf {
    let mut session_file = account_data_dir(data_dir, account_id);
    session_file.push("scraper_session.json");
    session_file
}
/// Default path of the token file for `account_id`: `tokens.json` inside
/// [`account_data_dir`].
pub fn account_token_path(data_dir: &Path, account_id: &str) -> PathBuf {
    let mut token_file = account_data_dir(data_dir, account_id);
    token_file.push("tokens.json");
    token_file
}
pub fn get_active_account_id() -> String {
use crate::startup::data_dir;
read_active_account_id(&data_dir())
}
/// Read the active account id from the `active_account` sentinel file in `dir`.
///
/// Surrounding whitespace (such as a trailing newline) is trimmed. Falls back to
/// [`DEFAULT_ACCOUNT_ID`] in two cases:
/// - the file cannot be read (missing, unreadable, or not UTF-8);
/// - the file contains nothing but whitespace. Without this, an empty id would reach
///   callers such as [`account_data_dir`] as if it were a real account.
///
/// This never fails: any read error is deliberately treated as "no selection".
pub fn read_active_account_id(dir: &Path) -> String {
    std::fs::read_to_string(dir.join("active_account"))
        .ok()
        .map(|content| content.trim().to_owned())
        .filter(|id| !id.is_empty())
        .unwrap_or_else(|| DEFAULT_ACCOUNT_ID.to_string())
}
pub fn set_active_account_id(account_id: &str) -> Result<(), std::io::Error> {
use crate::startup::data_dir;
write_active_account_id(&data_dir(), account_id)
}
/// Persist `account_id` as the active account in the `active_account` sentinel file in
/// `dir`. Creates `dir` (and its parents) if needed.
///
/// The id is first written to a temporary sibling file, which is then renamed over the
/// sentinel. A reader, or a restart after a crash mid-write, therefore sees either the
/// previous id or the new one, never an empty or partially written file. Writing the
/// sentinel in place would truncate it first.
///
/// # Errors
///
/// Returns any I/O error from creating `dir`, writing the temporary file, or renaming
/// it into place. On failure the temporary file is removed on a best-effort basis.
pub fn write_active_account_id(dir: &Path, account_id: &str) -> Result<(), std::io::Error> {
    std::fs::create_dir_all(dir)?;
    let sentinel = dir.join("active_account");
    // Per-process name so writers in different processes don't share a staging file.
    let staging = dir.join(format!(".active_account.{}.tmp", std::process::id()));
    let result = std::fs::write(&staging, account_id)
        .and_then(|()| std::fs::rename(&staging, &sentinel));
    if result.is_err() {
        // Best effort: don't leave a stray staging file; the original error is returned.
        let _ = std::fs::remove_file(&staging);
    }
    result
}
/// Unit tests for the account store and the per-account path/sentinel helpers.
///
/// Database tests each start from a fresh pool returned by `init_test_db`. Filesystem
/// tests use uniquely named directories under the OS temp dir and remove them afterwards.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::init_test_db;

    // ---- Account CRUD ------------------------------------------------------

    // Relies on `init_test_db` seeding the default account (presumably via
    // `ensure_default_account` — not visible in this file).
    #[tokio::test]
    async fn default_account_seeded() {
        let pool = init_test_db().await.expect("init db");
        let account = get_account(&pool, DEFAULT_ACCOUNT_ID)
            .await
            .expect("get")
            .expect("should exist");
        assert_eq!(account.label, "Default");
        assert_eq!(account.status, "active");
    }

    #[tokio::test]
    async fn create_and_list_accounts() {
        let pool = init_test_db().await.expect("init db");
        let id = uuid::Uuid::new_v4().to_string();
        create_account(&pool, &id, "Test Account")
            .await
            .expect("create");
        let accounts = list_accounts(&pool).await.expect("list");
        assert!(accounts.iter().any(|a| a.id == id));
    }

    #[tokio::test]
    async fn update_account_fields() {
        let pool = init_test_db().await.expect("init db");
        let id = uuid::Uuid::new_v4().to_string();
        create_account(&pool, &id, "Original")
            .await
            .expect("create");
        update_account(
            &pool,
            &id,
            UpdateAccountParams {
                label: Some("Updated"),
                x_user_id: Some("12345"),
                x_username: Some("testuser"),
                x_display_name: Some("Test User"),
                x_avatar_url: Some("https://pbs.twimg.com/profile_images/test.jpg"),
                ..Default::default()
            },
        )
        .await
        .expect("update");
        let account = get_account(&pool, &id).await.expect("get").expect("found");
        assert_eq!(account.label, "Updated");
        assert_eq!(account.x_user_id.as_deref(), Some("12345"));
        assert_eq!(account.x_username.as_deref(), Some("testuser"));
        assert_eq!(account.x_display_name.as_deref(), Some("Test User"));
        assert_eq!(
            account.x_avatar_url.as_deref(),
            Some("https://pbs.twimg.com/profile_images/test.jpg")
        );
    }

    // Deletion is a soft delete: hidden from `list_accounts`, row kept as `archived`.
    #[tokio::test]
    async fn delete_archives_account() {
        let pool = init_test_db().await.expect("init db");
        let id = uuid::Uuid::new_v4().to_string();
        create_account(&pool, &id, "ToDelete")
            .await
            .expect("create");
        delete_account(&pool, &id).await.expect("delete");
        let accounts = list_accounts(&pool).await.expect("list");
        assert!(!accounts.iter().any(|a| a.id == id));
        let account = get_account(&pool, &id).await.expect("get").expect("found");
        assert_eq!(account.status, "archived");
    }

    #[tokio::test]
    async fn cannot_delete_default_account() {
        let pool = init_test_db().await.expect("init db");
        let result = delete_account(&pool, DEFAULT_ACCOUNT_ID).await;
        assert!(result.is_err());
    }

    // ---- Roles -------------------------------------------------------------

    // `get_role` short-circuits to admin for the default account, whatever the actor.
    #[tokio::test]
    async fn default_account_grants_admin_to_all() {
        let pool = init_test_db().await.expect("init db");
        let role = get_role(&pool, DEFAULT_ACCOUNT_ID, "anyone")
            .await
            .expect("get role");
        assert_eq!(role.as_deref(), Some("admin"));
    }

    // Exercises set (insert + upsert), list, and remove. The expected count of 2 is the
    // auto-granted dashboard admin plus the mcp grant.
    #[tokio::test]
    async fn role_crud() {
        let pool = init_test_db().await.expect("init db");
        let id = uuid::Uuid::new_v4().to_string();
        create_account(&pool, &id, "RoleTest")
            .await
            .expect("create");
        let role = get_role(&pool, &id, "dashboard").await.expect("get");
        assert_eq!(role.as_deref(), Some("admin"));
        set_role(&pool, &id, "mcp", "viewer").await.expect("set");
        let role = get_role(&pool, &id, "mcp").await.expect("get");
        assert_eq!(role.as_deref(), Some("viewer"));
        set_role(&pool, &id, "mcp", "approver").await.expect("set");
        let role = get_role(&pool, &id, "mcp").await.expect("get");
        assert_eq!(role.as_deref(), Some("approver"));
        let roles = list_roles(&pool, &id).await.expect("list");
        assert_eq!(roles.len(), 2);
        remove_role(&pool, &id, "mcp").await.expect("remove");
        let role = get_role(&pool, &id, "mcp").await.expect("get");
        assert!(role.is_none());
    }

    #[tokio::test]
    async fn account_exists_check() {
        let pool = init_test_db().await.expect("init db");
        assert!(account_exists(&pool, DEFAULT_ACCOUNT_ID)
            .await
            .expect("check"));
        assert!(!account_exists(&pool, "nonexistent").await.expect("check"));
    }

    // ---- Path helpers (pure; no filesystem access) --------------------------

    #[test]
    fn account_data_dir_default() {
        let base = std::env::temp_dir().join(".tuitbot");
        let result = account_data_dir(&base, DEFAULT_ACCOUNT_ID);
        assert_eq!(result, base);
    }

    #[test]
    fn account_data_dir_other() {
        let base = std::env::temp_dir().join(".tuitbot");
        let result = account_data_dir(&base, "abc-123");
        assert_eq!(result, base.join("accounts").join("abc-123"));
    }

    #[test]
    fn scraper_session_path_default() {
        let base = std::env::temp_dir().join(".tuitbot");
        let result = account_scraper_session_path(&base, DEFAULT_ACCOUNT_ID);
        assert_eq!(result, base.join("scraper_session.json"));
    }

    #[test]
    fn scraper_session_path_other() {
        let base = std::env::temp_dir().join(".tuitbot");
        let result = account_scraper_session_path(&base, "abc-123");
        assert_eq!(
            result,
            base.join("accounts")
                .join("abc-123")
                .join("scraper_session.json")
        );
    }

    #[test]
    fn token_path_default() {
        let base = std::env::temp_dir().join(".tuitbot");
        let result = account_token_path(&base, DEFAULT_ACCOUNT_ID);
        assert_eq!(result, base.join("tokens.json"));
    }

    #[test]
    fn token_path_other() {
        let base = std::env::temp_dir().join(".tuitbot");
        let result = account_token_path(&base, "abc-123");
        assert_eq!(
            result,
            base.join("accounts").join("abc-123").join("tokens.json")
        );
    }

    // ---- Further account store behavior ------------------------------------

    // `INSERT OR IGNORE` makes re-seeding a no-op rather than an error.
    #[tokio::test]
    async fn ensure_default_account_idempotent() {
        let pool = init_test_db().await.expect("init db");
        ensure_default_account(&pool).await.expect("first call");
        ensure_default_account(&pool).await.expect("second call");
        let account = get_account(&pool, DEFAULT_ACCOUNT_ID)
            .await
            .expect("get")
            .expect("exists");
        assert_eq!(account.label, "Default");
    }

    // Overrides are stored verbatim; no JSON normalization happens on write.
    #[tokio::test]
    async fn update_account_config_overrides() {
        let pool = init_test_db().await.expect("init db");
        let id = uuid::Uuid::new_v4().to_string();
        create_account(&pool, &id, "ConfigTest")
            .await
            .expect("create");
        let overrides_json = r#"{"business":{"product_name":"TestProd"}}"#;
        update_account(
            &pool,
            &id,
            UpdateAccountParams {
                config_overrides: Some(overrides_json),
                ..Default::default()
            },
        )
        .await
        .expect("update");
        let account = get_account(&pool, &id).await.expect("get").expect("found");
        assert_eq!(account.config_overrides, overrides_json);
    }

    #[tokio::test]
    async fn update_account_token_path_and_status() {
        let pool = init_test_db().await.expect("init db");
        let id = uuid::Uuid::new_v4().to_string();
        create_account(&pool, &id, "TokenTest")
            .await
            .expect("create");
        update_account(
            &pool,
            &id,
            UpdateAccountParams {
                token_path: Some("/data/tokens.json"),
                status: Some("paused"),
                ..Default::default()
            },
        )
        .await
        .expect("update");
        let account = get_account(&pool, &id).await.expect("get").expect("found");
        assert_eq!(account.token_path.as_deref(), Some("/data/tokens.json"));
        assert_eq!(account.status, "paused");
    }

    #[tokio::test]
    async fn account_exists_after_archive() {
        let pool = init_test_db().await.expect("init db");
        let id = uuid::Uuid::new_v4().to_string();
        create_account(&pool, &id, "ArchiveCheck")
            .await
            .expect("create");
        assert!(account_exists(&pool, &id).await.expect("check"));
        delete_account(&pool, &id).await.expect("archive");
        assert!(!account_exists(&pool, &id)
            .await
            .expect("check after archive"));
    }

    // Unlike the default account, other accounts have no implicit grants.
    #[tokio::test]
    async fn role_for_non_default_account_unknown_actor() {
        let pool = init_test_db().await.expect("init db");
        let id = uuid::Uuid::new_v4().to_string();
        create_account(&pool, &id, "RoleCheck")
            .await
            .expect("create");
        let role = get_role(&pool, &id, "unknown_actor")
            .await
            .expect("get role");
        assert!(role.is_none());
    }

    #[tokio::test]
    async fn create_account_auto_grants_dashboard_admin() {
        let pool = init_test_db().await.expect("init db");
        let id = uuid::Uuid::new_v4().to_string();
        create_account(&pool, &id, "AutoGrant")
            .await
            .expect("create");
        let roles = list_roles(&pool, &id).await.expect("list");
        assert!(roles
            .iter()
            .any(|r| r.actor == "dashboard" && r.role == "admin"));
    }

    #[tokio::test]
    async fn new_account_has_blank_overrides() {
        let pool = init_test_db().await.expect("init db");
        let id = uuid::Uuid::new_v4().to_string();
        create_account(&pool, &id, "OverrideCheck")
            .await
            .expect("create");
        let account = get_account(&pool, &id).await.expect("get").expect("found");
        let overrides: serde_json::Value =
            serde_json::from_str(&account.config_overrides).expect("parse");
        assert_eq!(overrides["business"]["product_name"].as_str(), Some(""));
    }

    // `None` fields in UpdateAccountParams must leave their columns untouched.
    #[tokio::test]
    async fn update_account_only_provided_fields() {
        let pool = init_test_db().await.expect("init db");
        let id = uuid::Uuid::new_v4().to_string();
        create_account(&pool, &id, "PartialUpdate")
            .await
            .expect("create");
        update_account(
            &pool,
            &id,
            UpdateAccountParams {
                label: Some("NewLabel"),
                ..Default::default()
            },
        )
        .await
        .expect("update");
        let account = get_account(&pool, &id).await.expect("get").expect("found");
        assert_eq!(account.label, "NewLabel");
        assert!(account.x_user_id.is_none());
        assert!(account.x_username.is_none());
    }

    // ---- Active-account sentinel file ---------------------------------------

    #[test]
    fn read_active_account_missing_sentinel_returns_default() {
        let tmpdir = std::env::temp_dir().join(format!("tuitbot_test_{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&tmpdir).expect("create tmpdir");
        let active = read_active_account_id(&tmpdir);
        assert_eq!(active, DEFAULT_ACCOUNT_ID);
        let _ = std::fs::remove_dir_all(&tmpdir);
    }

    // The directory is not pre-created: `write_active_account_id` must create it.
    #[test]
    fn write_and_read_active_account_roundtrip() {
        let tmpdir = std::env::temp_dir().join(format!("tuitbot_test_{}", uuid::Uuid::new_v4()));
        let test_id = "abc-123-def";
        write_active_account_id(&tmpdir, test_id).expect("write");
        let active = read_active_account_id(&tmpdir);
        assert_eq!(active, test_id);
        let _ = std::fs::remove_dir_all(&tmpdir);
    }

    #[test]
    fn read_active_account_trims_whitespace() {
        let tmpdir = std::env::temp_dir().join(format!("tuitbot_test_{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&tmpdir).expect("create tmpdir");
        let sentinel = tmpdir.join("active_account");
        std::fs::write(&sentinel, "  abc-123  \n").expect("write");
        let active = read_active_account_id(&tmpdir);
        assert_eq!(active, "abc-123");
        let _ = std::fs::remove_dir_all(&tmpdir);
    }

    #[test]
    fn read_active_account_nonexistent_dir_returns_default() {
        let tmpdir = std::env::temp_dir().join(format!("tuitbot_noexist_{}", uuid::Uuid::new_v4()));
        let active = read_active_account_id(&tmpdir);
        assert_eq!(active, DEFAULT_ACCOUNT_ID);
    }
}