use std::{fmt, str::FromStr};
use sqlx::SqlitePool;
use super::Error;
/// Result alias used throughout this module, with the parent module's [`Error`]
/// as the error type.
pub type Result<T> = std::result::Result<T, Error>;
/// DNS record types that can be stored as local records.
///
/// The canonical textual form (see [`RecordType::as_str`]) is what
/// [`SqliteLocalRecordRepo`] writes to the `record_type` column.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RecordType {
    /// Address record; the tests store IPv4 literals such as `"192.168.1.1"`.
    A,
    /// Address record; the tests store IPv6 literals such as `"fd00::1"`.
    Aaaa,
}

impl RecordType {
    /// Returns the canonical upper-case spelling: `"A"` or `"AAAA"`.
    ///
    /// This is the exact string the [`FromStr`] impl accepts back.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::A => "A",
            Self::Aaaa => "AAAA",
        }
    }
}

impl fmt::Display for RecordType {
    /// Writes the canonical spelling returned by [`RecordType::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
impl FromStr for RecordType {
type Err = Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"A" => Ok(Self::A),
"AAAA" => Ok(Self::Aaaa),
other => Err(Error::Decode(format!(
"unknown record_type value: {other:?}"
))),
}
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct LocalRecord {
pub id: i64,
pub name: String,
pub record_type: RecordType,
pub value: String,
pub ttl: u32,
}
#[derive(Debug, Clone)]
pub struct NewLocalRecord {
pub name: String,
pub record_type: RecordType,
pub value: String,
pub ttl: u32,
}
/// Raw `local_records` row as fetched by the `SELECT` in
/// `SqliteLocalRecordRepo::list`, before conversion to [`LocalRecord`].
///
/// Field names must match that query's column names/aliases, because
/// `sqlx::query_as!` maps columns to fields by name.
struct LocalRecordRow {
    id: i64,
    name: String,
    // Textual record type; parsed into `RecordType` by the `TryFrom` impl.
    record_type: String,
    value: String,
    // Stored as an SQLite integer (i64); narrowed to `u32` by the `TryFrom` impl.
    ttl: i64,
}
impl TryFrom<LocalRecordRow> for LocalRecord {
type Error = Error;
fn try_from(row: LocalRecordRow) -> Result<Self> {
let record_type: RecordType = row.record_type.parse()?;
let ttl = u32::try_from(row.ttl).map_err(|_| {
Error::Decode(format!(
"ttl value {} is out of u32 range for record {:?}",
row.ttl, row.name
))
})?;
Ok(LocalRecord {
id: row.id,
name: row.name,
record_type,
value: row.value,
ttl,
})
}
}
/// Canonicalizes an owner name for storage.
///
/// The steps are:
/// 1. Remove at most one trailing `.`.
/// 2. Lower-case ASCII letters only; non-ASCII characters are left unchanged.
/// 3. Append a single `.`.
///
/// So `"Router.Home.LAN"` and `"router.home.lan."` both become
/// `"router.home.lan."`. Because only one dot is stripped, `"a.."` stays `"a.."`.
fn normalize_local_name(name: &str) -> String {
    let bare = match name.strip_suffix('.') {
        Some(without_dot) => without_dot,
        None => name,
    };
    let mut canonical = String::with_capacity(bare.len() + 1);
    canonical.push_str(bare);
    canonical.make_ascii_lowercase();
    canonical.push('.');
    canonical
}
/// Storage operations for user-defined local DNS records.
// `async_fn_in_trait` warns on public traits because callers cannot require the
// returned futures to be `Send`. NOTE(review): the allow presumes no caller needs to
// spawn these futures on a multi-threaded executor — confirm.
#[allow(async_fn_in_trait)]
pub trait LocalRecordRepository {
    /// Stores a new record and returns it as stored, including its assigned id.
    async fn add(&self, record: NewLocalRecord) -> Result<LocalRecord>;
    /// Deletes the record with the given id.
    async fn remove(&self, id: i64) -> Result<()>;
    /// Returns the stored records.
    async fn list(&self) -> Result<Vec<LocalRecord>>;
    /// Returns all stored records; `SqliteLocalRecordRepo` implements this by
    /// delegating to `list`.
    async fn load_all(&self) -> Result<Vec<LocalRecord>>;
}
/// [`LocalRecordRepository`] backed by a SQLite connection pool.
///
/// `Clone` clones the `SqlitePool` handle. Per the sqlx `Pool` docs, clones share
/// the same underlying connections rather than opening new ones.
#[derive(Debug, Clone)]
pub struct SqliteLocalRecordRepo {
    pool: SqlitePool,
}

impl SqliteLocalRecordRepo {
    /// Creates a repository that runs all of its queries against `pool`.
    pub fn new(pool: SqlitePool) -> Self {
        Self { pool }
    }
}
impl LocalRecordRepository for SqliteLocalRecordRepo {
    /// Inserts `record` into `local_records` and returns the stored record.
    ///
    /// The name is canonicalized with [`normalize_local_name`] before insertion, so
    /// the returned and stored name is ASCII-lowercased with a trailing dot. The
    /// value is stored exactly as given.
    ///
    /// # Errors
    /// Database failures are propagated via `?`. This includes the constraint
    /// violation raised when a record with the same normalized name and type
    /// already exists.
    async fn add(&self, record: NewLocalRecord) -> Result<LocalRecord> {
        let name = normalize_local_name(&record.name);
        let record_type = record.record_type.as_str();
        // Lossless widening to SQLite's i64 integers; `From` (unlike `as`) makes it
        // explicit that no truncation is possible.
        let ttl = i64::from(record.ttl);
        let id = sqlx::query!(
            r#"INSERT INTO local_records (name, record_type, value, ttl)
VALUES (?, ?, ?, ?)
RETURNING id"#,
            name,
            record_type,
            record.value,
            ttl,
        )
        .fetch_one(&self.pool)
        .await?
        .id;

        Ok(LocalRecord {
            id,
            name,
            record_type: record.record_type,
            value: record.value,
            ttl: record.ttl,
        })
    }

    /// Deletes the record with the given `id`.
    ///
    /// The affected-row count is not inspected, so removing an id that does not
    /// exist succeeds as a no-op.
    ///
    /// # Errors
    /// Propagates database failures.
    async fn remove(&self, id: i64) -> Result<()> {
        sqlx::query!("DELETE FROM local_records WHERE id = ?", id)
            .execute(&self.pool)
            .await?;
        Ok(())
    }

    /// Returns every stored record, ordered by name and then by record type.
    ///
    /// # Errors
    /// Propagates database failures. If any row cannot be decoded (an unknown
    /// `record_type` string, or a `ttl` outside the `u32` range), the whole call
    /// fails with that row's decode error.
    async fn list(&self) -> Result<Vec<LocalRecord>> {
        let rows = sqlx::query_as!(
            LocalRecordRow,
            r#"SELECT
id AS "id!",
name,
record_type,
value,
ttl AS "ttl!"
FROM local_records
ORDER BY name, record_type"#
        )
        .fetch_all(&self.pool)
        .await?;
        // Collecting into `Result` stops at the first row that fails to decode.
        rows.into_iter().map(LocalRecord::try_from).collect()
    }

    /// Returns every stored record. For this backend it is identical to
    /// [`list`](Self::list), including the ordering.
    async fn load_all(&self) -> Result<Vec<LocalRecord>> {
        self.list().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::Db;
    use tempfile::TempDir;

    /// Opens a repository on a fresh database file inside a new temporary directory.
    ///
    /// The `TempDir` guard is returned so the caller keeps it alive for the whole
    /// test; dropping it deletes the directory and the database in it.
    async fn open_repo() -> (TempDir, SqliteLocalRecordRepo) {
        let dir = TempDir::new().expect("temp dir");
        let path = dir.path().join("test.db");
        let db = Db::connect(&path).await.expect("connect");
        let repo = SqliteLocalRecordRepo::new(db.pool().clone());
        (dir, repo)
    }

    /// Shorthand constructor for a [`NewLocalRecord`].
    fn new_record(name: &str, record_type: RecordType, value: &str, ttl: u32) -> NewLocalRecord {
        NewLocalRecord {
            name: name.to_owned(),
            record_type,
            value: value.to_owned(),
            ttl,
        }
    }

    #[test]
    fn record_type_display() {
        assert_eq!(RecordType::A.to_string(), "A");
        assert_eq!(RecordType::Aaaa.to_string(), "AAAA");
    }

    #[test]
    fn record_type_from_str_valid() {
        assert_eq!("A".parse::<RecordType>().unwrap(), RecordType::A);
        assert_eq!("AAAA".parse::<RecordType>().unwrap(), RecordType::Aaaa);
    }

    #[test]
    fn record_type_display_round_trips_through_from_str() {
        // Whatever `add` writes (via `as_str`) must be readable by `list` (via `FromStr`).
        for rt in [RecordType::A, RecordType::Aaaa] {
            let parsed: RecordType = rt
                .to_string()
                .parse()
                .expect("own Display output must parse");
            assert_eq!(parsed, rt);
        }
    }

    #[test]
    fn record_type_from_str_invalid() {
        let err = "CNAME"
            .parse::<RecordType>()
            .expect_err("invalid record type must fail");
        let msg = err.to_string();
        assert!(
            msg.contains("CNAME"),
            "error must mention the bad value: {msg}"
        );
    }

    #[test]
    fn record_type_from_str_case_sensitive() {
        assert!("a".parse::<RecordType>().is_err());
        assert!("aaaa".parse::<RecordType>().is_err());
        assert!("Aaaa".parse::<RecordType>().is_err());
    }

    #[test]
    fn normalize_plain_domain() {
        assert_eq!(normalize_local_name("Router.Home.LAN"), "router.home.lan.");
    }

    #[test]
    fn normalize_already_lowercase_with_dot() {
        assert_eq!(normalize_local_name("router.home.lan."), "router.home.lan.");
    }

    #[test]
    fn normalize_wildcard_name() {
        assert_eq!(normalize_local_name("*.Home.LAN"), "*.home.lan.");
    }

    #[test]
    fn normalize_wildcard_already_normalized() {
        assert_eq!(normalize_local_name("*.home.lan."), "*.home.lan.");
    }

    #[tokio::test]
    async fn add_then_list_round_trips() {
        let (_dir, repo) = open_repo().await;
        let inserted = repo
            .add(new_record(
                "router.home.lan",
                RecordType::A,
                "192.168.1.1",
                300,
            ))
            .await
            .expect("add A record");
        assert!(inserted.id > 0);
        assert_eq!(inserted.name, "router.home.lan.");
        assert_eq!(inserted.record_type, RecordType::A);
        assert_eq!(inserted.value, "192.168.1.1");
        assert_eq!(inserted.ttl, 300);
        let records = repo.list().await.expect("list");
        assert_eq!(records.len(), 1);
        assert_eq!(records[0], inserted);
    }

    #[tokio::test]
    async fn name_is_normalized_on_insert() {
        let (_dir, repo) = open_repo().await;
        let inserted = repo
            .add(new_record(
                "Router.Home.LAN",
                RecordType::A,
                "192.168.1.1",
                300,
            ))
            .await
            .expect("add");
        assert_eq!(inserted.name, "router.home.lan.", "name must be normalized");
        let records = repo.list().await.expect("list");
        assert_eq!(records[0].name, "router.home.lan.");
    }

    #[tokio::test]
    async fn same_wildcard_name_a_and_aaaa_both_succeed() {
        let (_dir, repo) = open_repo().await;
        let a_rec = repo
            .add(new_record(
                "*.home.lan",
                RecordType::A,
                "192.168.1.100",
                300,
            ))
            .await
            .expect("add A record");
        let aaaa_rec = repo
            .add(new_record("*.home.lan", RecordType::Aaaa, "fd00::1", 300))
            .await
            .expect("add AAAA record for same wildcard name");
        assert_eq!(a_rec.name, "*.home.lan.");
        assert_eq!(aaaa_rec.name, "*.home.lan.");
        assert_ne!(a_rec.id, aaaa_rec.id);
        assert_eq!(a_rec.record_type, RecordType::A);
        assert_eq!(aaaa_rec.record_type, RecordType::Aaaa);
        let records = repo.list().await.expect("list");
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].record_type, RecordType::A);
        assert_eq!(records[1].record_type, RecordType::Aaaa);
    }

    #[tokio::test]
    async fn duplicate_name_type_insert_returns_error() {
        let (_dir, repo) = open_repo().await;
        repo.add(new_record("*.home.lan", RecordType::A, "192.168.1.1", 300))
            .await
            .expect("first add");
        let err = repo
            .add(new_record("*.home.lan", RecordType::A, "10.0.0.1", 300))
            .await
            .expect_err("duplicate (name, record_type) must error");
        assert!(
            matches!(err, Error::Sqlx(_)),
            "expected Sqlx error for UNIQUE violation, got {err:?}"
        );
    }

    #[tokio::test]
    async fn names_differing_only_in_case_or_trailing_dot_collide() {
        // Normalization happens before the insert, so these two spellings map to
        // the same stored name and must hit the same uniqueness constraint.
        let (_dir, repo) = open_repo().await;
        repo.add(new_record(
            "Router.Home.LAN",
            RecordType::A,
            "192.168.1.1",
            300,
        ))
        .await
        .expect("first add");
        let err = repo
            .add(new_record(
                "router.home.lan.",
                RecordType::A,
                "192.168.1.2",
                300,
            ))
            .await
            .expect_err("normalized duplicate must error");
        assert!(
            matches!(err, Error::Sqlx(_)),
            "expected Sqlx error for UNIQUE violation, got {err:?}"
        );
        let records = repo.list().await.expect("list");
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].value, "192.168.1.1", "original record must survive");
    }

    #[tokio::test]
    async fn remove_deletes_record() {
        let (_dir, repo) = open_repo().await;
        let inserted = repo
            .add(new_record(
                "router.home.lan",
                RecordType::A,
                "192.168.1.1",
                300,
            ))
            .await
            .expect("add");
        repo.remove(inserted.id).await.expect("remove");
        let records = repo.list().await.expect("list after remove");
        assert!(records.is_empty(), "record must be gone after remove");
    }

    #[tokio::test]
    async fn remove_leaves_other_records() {
        let (_dir, repo) = open_repo().await;
        let a_rec = repo
            .add(new_record(
                "router.home.lan",
                RecordType::A,
                "192.168.1.1",
                300,
            ))
            .await
            .expect("add A");
        let aaaa_rec = repo
            .add(new_record("router.home.lan", RecordType::Aaaa, "fd00::1", 300))
            .await
            .expect("add AAAA");
        repo.remove(a_rec.id).await.expect("remove A");
        let records = repo.list().await.expect("list after remove");
        assert_eq!(records, vec![aaaa_rec], "only the targeted record may be deleted");
    }

    #[tokio::test]
    async fn remove_nonexistent_is_noop() {
        let (_dir, repo) = open_repo().await;
        repo.remove(9999)
            .await
            .expect("remove non-existent must not error");
    }

    #[tokio::test]
    async fn load_all_returns_all_records() {
        let (_dir, repo) = open_repo().await;
        repo.add(new_record(
            "router.home.lan",
            RecordType::A,
            "192.168.1.1",
            300,
        ))
        .await
        .expect("add A");
        repo.add(new_record(
            "router.home.lan",
            RecordType::Aaaa,
            "fd00::1",
            300,
        ))
        .await
        .expect("add AAAA");
        repo.add(new_record(
            "nas.home.lan",
            RecordType::A,
            "192.168.1.2",
            600,
        ))
        .await
        .expect("add nas");
        let all = repo.load_all().await.expect("load_all");
        assert_eq!(all.len(), 3);
        // Ordered by name, then record type.
        let keys: Vec<(&str, RecordType)> = all
            .iter()
            .map(|r| (r.name.as_str(), r.record_type))
            .collect();
        assert_eq!(
            keys,
            [
                ("nas.home.lan.", RecordType::A),
                ("router.home.lan.", RecordType::A),
                ("router.home.lan.", RecordType::Aaaa),
            ]
        );
        assert_eq!(all, repo.list().await.expect("list"), "load_all must match list");
    }

    #[tokio::test]
    async fn wildcard_name_round_trips_faithfully() {
        let (_dir, repo) = open_repo().await;
        let inserted = repo
            .add(new_record(
                "*.WILD.example.COM",
                RecordType::A,
                "1.2.3.4",
                120,
            ))
            .await
            .expect("add wildcard");
        assert_eq!(
            inserted.name, "*.wild.example.com.",
            "wildcard name must normalize to lowercase with trailing dot"
        );
        assert_eq!(inserted.value, "1.2.3.4");
        assert_eq!(inserted.ttl, 120);
        let records = repo.list().await.expect("list");
        assert_eq!(records[0].name, "*.wild.example.com.");
    }

    #[tokio::test]
    async fn type_value_ttl_are_stored_faithfully() {
        let (_dir, repo) = open_repo().await;
        let inserted = repo
            .add(new_record(
                "host.local",
                RecordType::Aaaa,
                "2001:db8::1",
                7200,
            ))
            .await
            .expect("add AAAA");
        assert_eq!(inserted.record_type, RecordType::Aaaa);
        assert_eq!(inserted.value, "2001:db8::1");
        assert_eq!(inserted.ttl, 7200);
        // `add` builds its return value from the input; read back from the
        // database to prove the fields were actually persisted.
        let records = repo.list().await.expect("list");
        assert_eq!(records.len(), 1);
        assert_eq!(records[0], inserted);
    }
}