use std::path::Path;
use anyhow::Context as _;
use surrealdb::{
Surreal,
engine::local::{Db, Mem, SurrealKv},
types::{SurrealValue, Value},
};
use crate::base::{Res, Visibility, Void};
/// SurrealDB namespace selected by `Store::init` before any query runs.
const NAMESPACE: &str = "conclave";
/// SurrealDB database selected by `Store::init` before any query runs.
const DATABASE: &str = "conclave";
/// Schema statements applied every time a store is opened.
///
/// Every statement uses `IF NOT EXISTS`, so re-applying the schema to an
/// existing store leaves already-defined indexes alone. The `UNIQUE` indexes
/// are what make the store reject duplicate usernames, machine pubkeys,
/// per-user machine names, channel names and invite tokens.
const SCHEMA: &str = "\
DEFINE INDEX IF NOT EXISTS user_username ON user FIELDS username UNIQUE;
DEFINE INDEX IF NOT EXISTS machine_pubkey ON machine FIELDS pubkey UNIQUE;
DEFINE INDEX IF NOT EXISTS machine_user_name ON machine FIELDS user, name UNIQUE;
DEFINE INDEX IF NOT EXISTS channel_name ON channel FIELDS name UNIQUE;
DEFINE INDEX IF NOT EXISTS invite_token ON invite FIELDS token UNIQUE;
";
/// A row of the `user` table (the record id is omitted on read).
#[derive(Debug, Clone, PartialEq, Eq, SurrealValue)]
pub struct UserRecord {
    /// Login name; unique across all users via the `user_username` index.
    pub username: String,
    /// Creation time as an RFC 3339 string, stamped by `Store::create_user`.
    pub created_at: String,
}
/// A row of the `machine` table (the record id is omitted on read).
#[derive(Debug, Clone, PartialEq, Eq, SurrealValue)]
pub struct MachineRecord {
    /// Owner of the machine, stored as a plain string (presumably a
    /// `UserRecord::username` — not enforced by the store).
    pub user: String,
    /// Machine name; unique per `user` via the `machine_user_name` index.
    pub name: String,
    /// Public key as passed to `Store::create_machine` (`pubkey_base64`);
    /// globally unique via the `machine_pubkey` index.
    pub pubkey: String,
    /// Registration time as an RFC 3339 string.
    pub added_at: String,
}
/// A row of the `channel` table (the record id is omitted on read).
#[derive(Debug, Clone, PartialEq, Eq, SurrealValue)]
pub struct ChannelRecord {
    /// Channel name; unique via the `channel_name` index.
    pub name: String,
    /// String form of a `Visibility` (from `Visibility::as_str`), e.g.
    /// `"private"` or `"public"`.
    pub visibility: String,
    /// Access list; `Store::create_channel` seeds it with the creator.
    /// Entries are presumably usernames — confirm against callers.
    pub acl: Vec<String>,
    /// Whoever created the channel, as passed to `Store::create_channel`.
    pub created_by: String,
    /// Creation time as an RFC 3339 string.
    pub created_at: String,
}
/// A row of the `invite` table (the record id is omitted on read).
#[derive(Debug, Clone, PartialEq, Eq, SurrealValue)]
pub struct InviteRecord {
    /// Channel the invite is for (presumably a `ChannelRecord::name`; it is
    /// not updated if the channel is later renamed).
    pub channel: String,
    /// Invite token; unique via the `invite_token` index.
    pub token: String,
    /// Remaining redemptions. `None` presumably means unlimited — TODO confirm
    /// with the code that redeems invites.
    pub uses_remaining: Option<i64>,
    /// Optional expiry timestamp; the store does not validate its format.
    pub expires_at: Option<String>,
    /// Whoever issued the invite.
    pub created_by: String,
}
// Parameter bags for the parameterised queries in `Store`. Each struct is
// handed to `.bind(...)`, and its field names mirror the `$`-variables in the
// query text that uses it (e.g. `ByToken::tok` is used with `$tok`), so keep
// the two in sync when renaming either.

/// Binds `$username`.
#[derive(SurrealValue)]
struct ByUsername {
    username: String,
}
/// Binds `$pubkey`.
#[derive(SurrealValue)]
struct ByPubkey {
    pubkey: String,
}
/// Binds `$user`.
#[derive(SurrealValue)]
struct ByUser {
    user: String,
}
/// Binds `$name`.
#[derive(SurrealValue)]
struct ByName {
    name: String,
}
/// Binds `$tok` (an invite token).
#[derive(SurrealValue)]
struct ByToken {
    tok: String,
}
/// Binds `$user` and `$name` (a machine is identified by both).
#[derive(SurrealValue)]
struct ByUserAndName {
    user: String,
    name: String,
}
/// Binds `$name` and the replacement `$acl`.
#[derive(SurrealValue)]
struct SetAcl {
    name: String,
    acl: Vec<String>,
}
/// Binds `$name` and the replacement `$visibility`.
#[derive(SurrealValue)]
struct SetVisibility {
    name: String,
    visibility: String,
}
/// Binds `$old` and `$new` channel names.
#[derive(SurrealValue)]
struct Rename {
    old: String,
    new: String,
}
/// Binds `$tok` and the new `$uses` count.
#[derive(SurrealValue)]
struct SetUses {
    tok: String,
    uses: i64,
}
/// Handle to the embedded SurrealDB database holding the `user`, `machine`,
/// `channel` and `invite` tables.
pub struct Store {
    /// Connection already scoped to [`NAMESPACE`]/[`DATABASE`] by `Store::init`.
    db: Surreal<Db>,
}
impl Store {
pub async fn open(path: &Path) -> Res<Self> {
let db = Surreal::new::<SurrealKv>(path.to_string_lossy().as_ref()).await.context("failed to open the embedded store")?;
Self::init(db).await
}
pub async fn open_in_memory() -> Res<Self> {
let db = Surreal::new::<Mem>(()).await.context("failed to open the in-memory store")?;
Self::init(db).await
}
async fn init(db: Surreal<Db>) -> Res<Self> {
db.use_ns(NAMESPACE).use_db(DATABASE).await.context("failed to select namespace/database")?;
db.query(SCHEMA).await.context("failed to apply schema")?.check().context("schema application reported an error")?;
Ok(Self { db })
}
async fn insert<T: SurrealValue>(&self, table: &str, record: T) -> Void {
let _created: Option<Value> = self.db.create(table.to_owned()).content(record).await.with_context(|| format!("failed to insert into `{table}`"))?;
Ok(())
}
pub async fn create_user(&self, username: &str) -> Res<UserRecord> {
let record = UserRecord {
username: username.to_owned(),
created_at: now_rfc3339(),
};
self.insert("user", record.clone()).await?;
Ok(record)
}
pub async fn get_user(&self, username: &str) -> Res<Option<UserRecord>> {
let mut response = self
.db
.query("SELECT * OMIT id FROM user WHERE username = $username")
.bind(ByUsername { username: username.to_owned() })
.await
.context("failed to query user")?;
let rows: Vec<UserRecord> = response.take(0).context("failed to decode user rows")?;
Ok(rows.into_iter().next())
}
pub async fn create_machine(&self, user: &str, name: &str, pubkey_base64: &str) -> Res<MachineRecord> {
let record = MachineRecord {
user: user.to_owned(),
name: name.to_owned(),
pubkey: pubkey_base64.to_owned(),
added_at: now_rfc3339(),
};
self.insert("machine", record.clone()).await?;
Ok(record)
}
pub async fn get_machine_by_pubkey(&self, pubkey_base64: &str) -> Res<Option<MachineRecord>> {
let mut response = self
.db
.query("SELECT * OMIT id FROM machine WHERE pubkey = $pubkey")
.bind(ByPubkey { pubkey: pubkey_base64.to_owned() })
.await
.context("failed to query machine")?;
let rows: Vec<MachineRecord> = response.take(0).context("failed to decode machine rows")?;
Ok(rows.into_iter().next())
}
pub async fn list_machines(&self, user: &str) -> Res<Vec<MachineRecord>> {
let mut response = self
.db
.query("SELECT * OMIT id FROM machine WHERE user = $user")
.bind(ByUser { user: user.to_owned() })
.await
.context("failed to list machines")?;
response.take(0).context("failed to decode machine rows")
}
pub async fn delete_machine(&self, user: &str, name: &str) -> Void {
self.db
.query("DELETE machine WHERE user = $user AND name = $name")
.bind(ByUserAndName {
user: user.to_owned(),
name: name.to_owned(),
})
.await
.context("failed to delete machine")?
.check()
.context("machine delete reported an error")?;
Ok(())
}
pub async fn create_channel(&self, name: &str, visibility: Visibility, created_by: &str) -> Res<ChannelRecord> {
let record = ChannelRecord {
name: name.to_owned(),
visibility: visibility.as_str().to_owned(),
acl: vec![created_by.to_owned()],
created_by: created_by.to_owned(),
created_at: now_rfc3339(),
};
self.insert("channel", record.clone()).await?;
Ok(record)
}
pub async fn get_channel(&self, name: &str) -> Res<Option<ChannelRecord>> {
let mut response = self
.db
.query("SELECT * OMIT id FROM channel WHERE name = $name")
.bind(ByName { name: name.to_owned() })
.await
.context("failed to query channel")?;
let rows: Vec<ChannelRecord> = response.take(0).context("failed to decode channel rows")?;
Ok(rows.into_iter().next())
}
pub async fn create_invite(&self, channel: &str, token: &str, uses_remaining: Option<i64>, expires_at: Option<String>, created_by: &str) -> Res<InviteRecord> {
let record = InviteRecord {
channel: channel.to_owned(),
token: token.to_owned(),
uses_remaining,
expires_at,
created_by: created_by.to_owned(),
};
self.insert("invite", record.clone()).await?;
Ok(record)
}
pub async fn get_invite(&self, token: &str) -> Res<Option<InviteRecord>> {
let mut response = self
.db
.query("SELECT * OMIT id FROM invite WHERE token = $tok")
.bind(ByToken { tok: token.to_owned() })
.await
.context("failed to query invite")?;
let rows: Vec<InviteRecord> = response.take(0).context("failed to decode invite rows")?;
Ok(rows.into_iter().next())
}
pub async fn list_channels(&self) -> Res<Vec<ChannelRecord>> {
let mut response = self.db.query("SELECT * OMIT id FROM channel").await.context("failed to list channels")?;
response.take(0).context("failed to decode channel rows")
}
pub async fn set_channel_acl(&self, name: &str, acl: &[String]) -> Void {
self.db
.query("UPDATE channel SET acl = $acl WHERE name = $name")
.bind(SetAcl { name: name.to_owned(), acl: acl.to_vec() })
.await
.context("failed to update channel acl")?
.check()
.context("channel acl update reported an error")?;
Ok(())
}
pub async fn set_channel_visibility(&self, name: &str, visibility: Visibility) -> Void {
self.db
.query("UPDATE channel SET visibility = $visibility WHERE name = $name")
.bind(SetVisibility {
name: name.to_owned(),
visibility: visibility.as_str().to_owned(),
})
.await
.context("failed to update channel visibility")?
.check()
.context("channel visibility update reported an error")?;
Ok(())
}
pub async fn rename_channel(&self, old: &str, new: &str) -> Void {
self.db
.query("UPDATE channel SET name = $new WHERE name = $old")
.bind(Rename { old: old.to_owned(), new: new.to_owned() })
.await
.context("failed to rename channel")?
.check()
.context("channel rename reported an error")?;
Ok(())
}
pub async fn delete_channel(&self, name: &str) -> Void {
self.db
.query("DELETE channel WHERE name = $name")
.bind(ByName { name: name.to_owned() })
.await
.context("failed to delete channel")?
.check()
.context("channel delete reported an error")?;
Ok(())
}
pub async fn set_invite_uses(&self, token: &str, uses_remaining: i64) -> Void {
self.db
.query("UPDATE invite SET uses_remaining = $uses WHERE token = $tok")
.bind(SetUses {
tok: token.to_owned(),
uses: uses_remaining,
})
.await
.context("failed to update invite uses")?
.check()
.context("invite uses update reported an error")?;
Ok(())
}
pub async fn delete_invite(&self, token: &str) -> Void {
self.db
.query("DELETE invite WHERE token = $tok")
.bind(ByToken { tok: token.to_owned() })
.await
.context("failed to delete invite")?
.check()
.context("invite delete reported an error")?;
Ok(())
}
pub async fn list_users(&self) -> Res<Vec<UserRecord>> {
let mut response = self.db.query("SELECT * OMIT id FROM user").await.context("failed to list users")?;
response.take(0).context("failed to decode user rows")
}
pub async fn delete_user(&self, username: &str) -> Void {
self.db
.query("DELETE user WHERE username = $username")
.bind(ByUsername { username: username.to_owned() })
.await
.context("failed to delete user")?
.check()
.context("user delete reported an error")?;
Ok(())
}
}
/// Returns the current UTC wall-clock time rendered as an RFC 3339 string.
///
/// Used to stamp the `created_at` / `added_at` fields of new records.
fn now_rfc3339() -> String {
    let current = chrono::Utc::now();
    current.to_rfc3339()
}
#[cfg(test)]
mod tests {
    // Tests assert on known-good fixtures, so panicking on `Err`/`None` is the
    // desired failure mode.
    #![allow(clippy::unwrap_used)]
    use super::*;
    use pretty_assertions::assert_eq;

    /// A fresh, isolated in-memory store per test.
    async fn store() -> Store {
        Store::open_in_memory().await.unwrap()
    }

    #[tokio::test]
    async fn user_create_and_fetch_round_trip() {
        let store = store().await;
        let created = store.create_user("aaron").await.unwrap();
        assert_eq!(store.get_user("aaron").await.unwrap(), Some(created));
        assert_eq!(store.get_user("nobody").await.unwrap(), None);
    }

    #[tokio::test]
    async fn duplicate_username_is_rejected() {
        let store = store().await;
        store.create_user("aaron").await.unwrap();
        assert!(store.create_user("aaron").await.is_err(), "the unique-username constraint must reject a duplicate");
    }

    #[tokio::test]
    async fn machine_pubkey_is_globally_unique() {
        let store = store().await;
        store.create_machine("aaron", "workstation", "PUBKEY-A").await.unwrap();
        assert!(store.create_machine("david", "desktop", "PUBKEY-A").await.is_err());
    }

    #[tokio::test]
    async fn machine_can_be_looked_up_by_pubkey() {
        let store = store().await;
        let created = store.create_machine("aaron", "workstation", "PUBKEY-A").await.unwrap();
        assert_eq!(store.get_machine_by_pubkey("PUBKEY-A").await.unwrap(), Some(created));
        assert_eq!(store.get_machine_by_pubkey("PUBKEY-Z").await.unwrap(), None);
    }

    #[tokio::test]
    async fn machine_name_is_unique_within_a_user_but_not_across_users() {
        let store = store().await;
        store.create_machine("aaron", "workstation", "PUBKEY-A").await.unwrap();
        assert!(store.create_machine("aaron", "workstation", "PUBKEY-B").await.is_err());
        store.create_machine("david", "workstation", "PUBKEY-C").await.unwrap();
    }

    #[tokio::test]
    async fn machines_list_and_delete_for_a_user() {
        let store = store().await;
        store.create_machine("aaron", "workstation", "PUBKEY-A").await.unwrap();
        store.create_machine("aaron", "sno-box", "PUBKEY-B").await.unwrap();
        assert_eq!(store.list_machines("aaron").await.unwrap().len(), 2);
        store.delete_machine("aaron", "sno-box").await.unwrap();
        let remaining = store.list_machines("aaron").await.unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].name, "workstation");
    }

    #[tokio::test]
    async fn machine_listing_and_deletion_are_scoped_to_one_user() {
        let store = store().await;
        store.create_machine("aaron", "workstation", "PUBKEY-A").await.unwrap();
        store.create_machine("david", "workstation", "PUBKEY-B").await.unwrap();
        let aarons = store.list_machines("aaron").await.unwrap();
        assert_eq!(aarons.len(), 1);
        assert_eq!(aarons[0].user, "aaron");
        assert!(store.list_machines("nobody").await.unwrap().is_empty());
        // Same machine name under another user must survive the delete.
        store.delete_machine("aaron", "workstation").await.unwrap();
        assert!(store.list_machines("aaron").await.unwrap().is_empty());
        assert_eq!(store.list_machines("david").await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn channel_create_fetch_and_unique_name() {
        let store = store().await;
        let created = store.create_channel("ops", Visibility::Private, "aaron").await.unwrap();
        assert_eq!(created.visibility, "private");
        assert_eq!(store.get_channel("ops").await.unwrap(), Some(created));
        assert!(store.create_channel("ops", Visibility::Public, "david").await.is_err());
    }

    #[tokio::test]
    async fn invite_create_fetch_and_unique_token() {
        let store = store().await;
        let created = store.create_invite("ops", "tok-123", Some(5), None, "aaron").await.unwrap();
        assert_eq!(store.get_invite("tok-123").await.unwrap(), Some(created));
        assert!(store.create_invite("ops", "tok-123", None, None, "aaron").await.is_err());
    }

    #[tokio::test]
    async fn channel_acl_can_be_replaced() {
        let store = store().await;
        store.create_channel("ops", Visibility::Private, "aaron").await.unwrap();
        store.set_channel_acl("ops", &["aaron".to_owned(), "david".to_owned()]).await.unwrap();
        assert_eq!(store.get_channel("ops").await.unwrap().unwrap().acl, vec!["aaron".to_owned(), "david".to_owned()]);
    }

    #[tokio::test]
    async fn channel_visibility_can_be_changed() {
        let store = store().await;
        store.create_channel("ops", Visibility::Private, "aaron").await.unwrap();
        store.set_channel_visibility("ops", Visibility::Public).await.unwrap();
        assert_eq!(store.get_channel("ops").await.unwrap().unwrap().visibility, "public");
    }

    #[tokio::test]
    async fn channel_rename_moves_the_record_and_respects_uniqueness() {
        let store = store().await;
        store.create_channel("ops", Visibility::Private, "aaron").await.unwrap();
        store.create_channel("taken", Visibility::Public, "aaron").await.unwrap();
        store.rename_channel("ops", "operations").await.unwrap();
        assert!(store.get_channel("ops").await.unwrap().is_none());
        assert!(store.get_channel("operations").await.unwrap().is_some());
        assert!(store.rename_channel("operations", "taken").await.is_err());
    }

    #[tokio::test]
    async fn channel_can_be_deleted_and_listed() {
        let store = store().await;
        store.create_channel("ops", Visibility::Private, "aaron").await.unwrap();
        store.create_channel("lobby", Visibility::Public, "aaron").await.unwrap();
        assert_eq!(store.list_channels().await.unwrap().len(), 2);
        store.delete_channel("ops").await.unwrap();
        let remaining = store.list_channels().await.unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].name, "lobby");
    }

    #[tokio::test]
    async fn invite_uses_can_be_decremented_and_revoked() {
        let store = store().await;
        store.create_invite("ops", "tok-123", Some(5), None, "aaron").await.unwrap();
        store.set_invite_uses("tok-123", 4).await.unwrap();
        assert_eq!(store.get_invite("tok-123").await.unwrap().unwrap().uses_remaining, Some(4));
        store.delete_invite("tok-123").await.unwrap();
        assert!(store.get_invite("tok-123").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn users_can_be_listed_and_deleted() {
        let store = store().await;
        store.create_user("aaron").await.unwrap();
        store.create_user("david").await.unwrap();
        assert_eq!(store.list_users().await.unwrap().len(), 2);
        store.delete_user("david").await.unwrap();
        let remaining = store.list_users().await.unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].username, "aaron");
    }
}