use std::sync::{Arc, Mutex};

use rusqlite::{Connection, OptionalExtension, params};
use seshat_core::BranchId;

use super::{BranchRepository, lock_conn};
use crate::StorageError;
/// Branch name reported by `get_current_branch` when the `metadata` table holds
/// no `current_branch` entry yet.
const DEFAULT_BRANCH: &str = "main";
/// Key in the `metadata` table whose value is the currently selected branch;
/// written by `switch_branch`, read by `get_current_branch`.
const CURRENT_BRANCH_KEY: &str = "current_branch";
/// `BranchRepository` backed by a shared SQLite connection.
///
/// Cloning is cheap: clones share the same `Arc`-wrapped connection. The trait
/// methods acquire the connection through `lock_conn` before touching the
/// database.
#[derive(Clone, Debug)]
pub struct SqliteBranchRepository {
    /// Shared database handle; guarded by a mutex so clones can coexist.
    conn: Arc<Mutex<Connection>>,
}

impl SqliteBranchRepository {
    /// Wraps an already-opened shared connection. No queries are issued here.
    pub fn new(conn: Arc<Mutex<Connection>>) -> Self {
        SqliteBranchRepository { conn }
    }
}
impl BranchRepository for SqliteBranchRepository {
fn create_snapshot(
&self,
source_branch: &BranchId,
new_branch: &BranchId,
) -> Result<(), StorageError> {
let conn = lock_conn(&self.conn)?;
let tx = conn.unchecked_transaction()?;
tx.execute(
"INSERT OR IGNORE INTO branches (branch_id) VALUES (?1)",
params![source_branch.0],
)?;
tx.execute(
"INSERT INTO branches (branch_id, snapshot_source) VALUES (?1, ?2)
ON CONFLICT(branch_id) DO UPDATE SET snapshot_source = excluded.snapshot_source",
params![new_branch.0, source_branch.0],
)?;
tx.execute(
"INSERT INTO nodes (branch_id, nature, weight, confidence, adoption_count, total_count, description, ext_data)
SELECT ?1, nature, weight, confidence, adoption_count, total_count, description, ext_data
FROM nodes WHERE branch_id = ?2",
params![new_branch.0, source_branch.0],
)?;
tx.execute(
"INSERT INTO edges (source_id, target_id, edge_type, branch_id, weight, metadata)
SELECT source_id, target_id, edge_type, ?1, weight, metadata
FROM edges WHERE branch_id = ?2",
params![new_branch.0, source_branch.0],
)?;
tx.execute(
"INSERT INTO files_ir (branch_id, file_path, language, content_hash, ir_data, updated_at)
SELECT ?1, file_path, language, content_hash, ir_data, updated_at
FROM files_ir WHERE branch_id = ?2",
params![new_branch.0, source_branch.0],
)?;
tx.execute(
"INSERT INTO symbol_definitions (branch_id, symbol_name, file_path, line, end_line, kind, is_public, snippet)
SELECT ?1, symbol_name, file_path, line, end_line, kind, is_public, snippet
FROM symbol_definitions WHERE branch_id = ?2",
params![new_branch.0, source_branch.0],
)?;
tx.execute(
"INSERT INTO symbol_imports (branch_id, imported_name, importer_file)
SELECT ?1, imported_name, importer_file
FROM symbol_imports WHERE branch_id = ?2",
params![new_branch.0, source_branch.0],
)?;
tx.execute(
"INSERT INTO branch_metadata (branch_id, key, value, updated_at)
SELECT ?1, key, value, updated_at
FROM branch_metadata WHERE branch_id = ?2",
params![new_branch.0, source_branch.0],
)?;
tx.commit()?;
Ok(())
}
fn switch_branch(&self, branch_id: &BranchId) -> Result<(), StorageError> {
let conn = lock_conn(&self.conn)?;
let tx = conn.unchecked_transaction()?;
tx.execute(
"INSERT OR IGNORE INTO branches (branch_id) VALUES (?1)",
params![branch_id.0],
)?;
tx.execute(
"INSERT INTO metadata (key, value) VALUES (?1, ?2)
ON CONFLICT(key) DO UPDATE SET value = excluded.value",
params![CURRENT_BRANCH_KEY, branch_id.0],
)?;
tx.commit()?;
Ok(())
}
fn delete_branch(&self, branch_id: &BranchId) -> Result<(), StorageError> {
let conn = lock_conn(&self.conn)?;
let tx = conn.unchecked_transaction()?;
tx.execute(
"DELETE FROM edges WHERE branch_id = ?1",
params![branch_id.0],
)?;
tx.execute(
"DELETE FROM nodes WHERE branch_id = ?1",
params![branch_id.0],
)?;
tx.execute(
"DELETE FROM files_ir WHERE branch_id = ?1",
params![branch_id.0],
)?;
tx.execute(
"DELETE FROM symbol_definitions WHERE branch_id = ?1",
params![branch_id.0],
)?;
tx.execute(
"DELETE FROM symbol_imports WHERE branch_id = ?1",
params![branch_id.0],
)?;
tx.execute(
"DELETE FROM branches WHERE branch_id = ?1",
params![branch_id.0],
)?;
tx.commit()?;
Ok(())
}
fn list_branches(&self) -> Result<Vec<BranchId>, StorageError> {
let conn = lock_conn(&self.conn)?;
let mut stmt = conn.prepare("SELECT branch_id FROM branches ORDER BY branch_id")?;
let rows = stmt.query_map([], |row| {
let id: String = row.get(0)?;
Ok(BranchId(id))
})?;
rows.collect::<Result<Vec<_>, _>>().map_err(Into::into)
}
fn get_current_branch(&self) -> Result<BranchId, StorageError> {
let conn = lock_conn(&self.conn)?;
let result: Result<String, _> = conn.query_row(
"SELECT value FROM metadata WHERE key = ?1",
params![CURRENT_BRANCH_KEY],
|row| row.get(0),
);
match result {
Ok(branch) => Ok(BranchId(branch)),
Err(rusqlite::Error::QueryReturnedNoRows) => {
tracing::debug!("No current_branch in metadata, defaulting to 'main'");
Ok(BranchId(DEFAULT_BRANCH.to_string()))
}
Err(e) => Err(e.into()),
}
}
fn get_last_scanned_commit(
&self,
branch_id: &BranchId,
) -> Result<Option<String>, StorageError> {
let conn = lock_conn(&self.conn)?;
let result: Result<Option<String>, _> = conn.query_row(
"SELECT last_scanned_commit FROM branches WHERE branch_id = ?1",
params![branch_id.0],
|row| row.get(0),
);
match result {
Ok(commit) => Ok(commit),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(e.into()),
}
}
fn set_last_scanned_commit(
&self,
branch_id: &BranchId,
commit: &str,
) -> Result<(), StorageError> {
let conn = lock_conn(&self.conn)?;
conn.execute(
"INSERT INTO branches (branch_id, last_scanned_commit, last_scanned_at)
VALUES (?1, ?2, unixepoch())
ON CONFLICT(branch_id) DO UPDATE SET
last_scanned_commit = excluded.last_scanned_commit,
last_scanned_at = excluded.last_scanned_at",
params![branch_id.0, commit],
)?;
Ok(())
}
fn ensure_branch_exists(&self, branch_id: &BranchId) -> Result<(), StorageError> {
let conn = lock_conn(&self.conn)?;
conn.execute(
"INSERT OR IGNORE INTO branches (branch_id) VALUES (?1)",
params![branch_id.0],
)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    // Unit tests for `SqliteBranchRepository`, each run against its own
    // in-memory database.
    use super::*;
    use crate::Database;
    use crate::repository::file_ir_repository::SqliteFileIRRepository;
    use crate::repository::node_repository::SqliteNodeRepository;
    use crate::repository::{FileIRRepository, NodeRepository};
    use seshat_core::test_helpers::{make_knowledge_node, make_project_file};
    use seshat_core::{KnowledgeNature, Language};

    /// Builds branch, node and file-IR repositories sharing one fresh
    /// in-memory database.
    fn test_repos() -> (
        SqliteBranchRepository,
        SqliteNodeRepository,
        SqliteFileIRRepository,
    ) {
        let db = Database::open(":memory:").expect("in-memory DB");
        let conn = db.connection().clone();
        (
            SqliteBranchRepository::new(conn.clone()),
            SqliteNodeRepository::new(conn.clone()),
            SqliteFileIRRepository::new(conn),
        )
    }

    /// Reads the raw `(key, value, updated_at)` rows of `branch_metadata` for
    /// `branch_id`, ordered by key. Timestamps are included so snapshots can be
    /// compared row-for-row.
    fn branch_metadata_rows(
        repo: &SqliteBranchRepository,
        branch_id: &str,
    ) -> Vec<(String, String, i64)> {
        let conn = repo.conn.lock().unwrap();
        let mut stmt = conn
            .prepare(
                "SELECT key, value, updated_at FROM branch_metadata \
                 WHERE branch_id = ?1 ORDER BY key",
            )
            .unwrap();
        stmt.query_map(params![branch_id], |row| {
            Ok((row.get(0)?, row.get(1)?, row.get(2)?))
        })
        .unwrap()
        .collect::<Result<Vec<_>, _>>()
        .unwrap()
    }

    /// With nothing stored, the current branch is `main`.
    #[test]
    fn get_current_branch_default() {
        let (branch_repo, _, _) = test_repos();
        let current = branch_repo.get_current_branch().unwrap();
        assert_eq!(current, BranchId::from("main"));
    }

    /// `switch_branch` is observable through `get_current_branch`.
    #[test]
    fn switch_and_get_current_branch() {
        let (branch_repo, _, _) = test_repos();
        let feature = BranchId::from("feature-x");
        branch_repo.switch_branch(&feature).unwrap();
        let current = branch_repo.get_current_branch().unwrap();
        assert_eq!(current, feature);
    }

    /// A later `switch_branch` replaces the earlier one.
    #[test]
    fn switch_branch_overwrites() {
        let (branch_repo, _, _) = test_repos();
        branch_repo
            .switch_branch(&BranchId::from("branch-a"))
            .unwrap();
        branch_repo
            .switch_branch(&BranchId::from("branch-b"))
            .unwrap();
        let current = branch_repo.get_current_branch().unwrap();
        assert_eq!(current, BranchId::from("branch-b"));
    }

    /// Snapshot duplicates nodes and file IR into the target, leaving the
    /// source intact.
    #[test]
    fn create_snapshot_copies_nodes_and_files() {
        let (branch_repo, node_repo, file_repo) = test_repos();
        let main_branch = BranchId::from("main");
        let mut n1 = make_knowledge_node(KnowledgeNature::Convention, 0.9);
        n1.branch_id = main_branch.clone();
        node_repo.insert(&n1).unwrap();
        let mut n2 = make_knowledge_node(KnowledgeNature::Fact, 0.7);
        n2.branch_id = main_branch.clone();
        node_repo.insert(&n2).unwrap();
        let mut file = make_project_file(Language::Rust);
        file.path = "src/lib.rs".into();
        file.content_hash = "snap_hash".to_string();
        file_repo.upsert(&main_branch, &file, None).unwrap();
        let feature = BranchId::from("feature-snap");
        branch_repo.create_snapshot(&main_branch, &feature).unwrap();
        let main_nodes = node_repo.find_by_branch(&main_branch).unwrap();
        let feature_nodes = node_repo.find_by_branch(&feature).unwrap();
        assert_eq!(main_nodes.len(), 2);
        assert_eq!(feature_nodes.len(), 2);
        let feature_files = file_repo.get_by_branch(&feature).unwrap();
        assert_eq!(feature_files.len(), 1);
        assert_eq!(feature_files[0].content_hash, "snap_hash");
    }

    /// Snapshotting a branch with no data succeeds and yields an empty target.
    #[test]
    fn create_snapshot_empty_branch() {
        let (branch_repo, node_repo, _) = test_repos();
        let empty = BranchId::from("empty");
        let target = BranchId::from("copy-of-empty");
        branch_repo.create_snapshot(&empty, &target).unwrap();
        let nodes = node_repo.find_by_branch(&target).unwrap();
        assert!(nodes.is_empty());
    }

    /// `branch_metadata` rows, including `updated_at`, are copied verbatim and
    /// the source rows remain.
    #[test]
    fn create_snapshot_copies_branch_metadata() {
        use crate::repository::{BranchMetadataRepository, SqliteBranchMetadataRepository};
        let (branch_repo, _, _) = test_repos();
        let branch_meta = SqliteBranchMetadataRepository::new(branch_repo.conn.clone());
        let main_branch = BranchId::from("main");
        let feature = BranchId::from("feature-meta-snap");
        branch_repo.ensure_branch_exists(&main_branch).unwrap();
        branch_meta
            .set("main", "workspace_crates", r#"["crate_a","crate_b"]"#)
            .unwrap();
        branch_meta.set("main", "other_key", "other_value").unwrap();
        let source_rows = branch_metadata_rows(&branch_repo, "main");
        assert_eq!(source_rows.len(), 2, "test setup must seed two rows");
        branch_repo.create_snapshot(&main_branch, &feature).unwrap();
        let snapshot_kv = branch_meta.list(&feature.0).unwrap();
        assert_eq!(
            snapshot_kv,
            vec![
                ("other_key".to_string(), "other_value".to_string()),
                (
                    "workspace_crates".to_string(),
                    r#"["crate_a","crate_b"]"#.to_string()
                ),
            ]
        );
        let snapshot_rows = branch_metadata_rows(&branch_repo, &feature.0);
        assert_eq!(
            snapshot_rows, source_rows,
            "snapshotted branch_metadata must match source row-for-row"
        );
        let source_kv = branch_meta.list("main").unwrap();
        assert_eq!(source_kv.len(), 2);
    }

    /// A fresh database has no registered branches.
    #[test]
    fn list_branches_empty() {
        let (branch_repo, _, _) = test_repos();
        let branches = branch_repo.list_branches().unwrap();
        assert!(branches.is_empty());
    }

    /// Explicitly registered branches are listed.
    #[test]
    fn list_branches_with_data() {
        let (branch_repo, node_repo, file_repo) = test_repos();
        let main_branch = BranchId::from("main");
        let feature = BranchId::from("feature");
        branch_repo.ensure_branch_exists(&main_branch).unwrap();
        branch_repo.ensure_branch_exists(&feature).unwrap();
        let mut n = make_knowledge_node(KnowledgeNature::Fact, 0.5);
        n.branch_id = main_branch.clone();
        node_repo.insert(&n).unwrap();
        let mut f = make_project_file(Language::Python);
        f.path = "app.py".into();
        f.content_hash = "h".to_string();
        file_repo.upsert(&feature, &f, None).unwrap();
        let branches = branch_repo.list_branches().unwrap();
        assert_eq!(branches.len(), 2);
        assert!(branches.contains(&main_branch));
        assert!(branches.contains(&feature));
    }

    /// Data rows alone do not make a branch show up in `list_branches`.
    #[test]
    fn list_branches_reads_from_branches_table_not_nodes() {
        let (branch_repo, node_repo, file_repo) = test_repos();
        let ghost = BranchId::from("ghost-branch");
        let mut n = make_knowledge_node(KnowledgeNature::Fact, 0.4);
        n.branch_id = ghost.clone();
        node_repo.insert(&n).unwrap();
        let mut f = make_project_file(Language::Rust);
        f.path = "ghost.rs".into();
        f.content_hash = "ghost_hash".to_string();
        file_repo.upsert(&ghost, &f, None).unwrap();
        let branches = branch_repo.list_branches().unwrap();
        assert!(
            branches.is_empty(),
            "list_branches should ignore raw rows in nodes/files_ir, got {branches:?}"
        );
    }

    /// Deleting a branch removes its nodes and file IR.
    #[test]
    fn delete_branch() {
        let (branch_repo, node_repo, file_repo) = test_repos();
        let branch = BranchId::from("to-delete");
        let mut n = make_knowledge_node(KnowledgeNature::Observation, 0.6);
        n.branch_id = branch.clone();
        node_repo.insert(&n).unwrap();
        let mut f = make_project_file(Language::TypeScript);
        f.path = "index.ts".into();
        f.content_hash = "del_hash".to_string();
        file_repo.upsert(&branch, &f, None).unwrap();
        assert_eq!(node_repo.find_by_branch(&branch).unwrap().len(), 1);
        assert_eq!(file_repo.get_by_branch(&branch).unwrap().len(), 1);
        branch_repo.delete_branch(&branch).unwrap();
        assert!(node_repo.find_by_branch(&branch).unwrap().is_empty());
        assert!(file_repo.get_by_branch(&branch).unwrap().is_empty());
    }

    /// Deleting an unknown branch is not an error.
    #[test]
    fn delete_branch_no_data_succeeds() {
        let (branch_repo, _, _) = test_repos();
        branch_repo.delete_branch(&BranchId::from("ghost")).unwrap();
    }

    /// Deleting a snapshot leaves its source branch untouched.
    #[test]
    fn snapshot_and_delete_isolation() {
        let (branch_repo, node_repo, file_repo) = test_repos();
        let main_branch = BranchId::from("main");
        let mut n = make_knowledge_node(KnowledgeNature::Decision, 0.95);
        n.branch_id = main_branch.clone();
        node_repo.insert(&n).unwrap();
        let mut f = make_project_file(Language::Rust);
        f.path = "src/main.rs".into();
        f.content_hash = "iso_hash".to_string();
        file_repo.upsert(&main_branch, &f, None).unwrap();
        let snapshot = BranchId::from("snapshot");
        branch_repo
            .create_snapshot(&main_branch, &snapshot)
            .unwrap();
        branch_repo.delete_branch(&snapshot).unwrap();
        assert_eq!(node_repo.find_by_branch(&main_branch).unwrap().len(), 1);
        assert_eq!(file_repo.get_by_branch(&main_branch).unwrap().len(), 1);
        assert!(node_repo.find_by_branch(&snapshot).unwrap().is_empty());
    }

    /// Repeated registration produces a single registry row.
    #[test]
    fn ensure_branch_exists_is_idempotent() {
        let (branch_repo, _, _) = test_repos();
        let b = BranchId::from("idem");
        branch_repo.ensure_branch_exists(&b).unwrap();
        branch_repo.ensure_branch_exists(&b).unwrap();
        branch_repo.ensure_branch_exists(&b).unwrap();
        let branches = branch_repo.list_branches().unwrap();
        assert_eq!(branches, vec![b]);
    }

    /// Registration must not clear an already-recorded scan commit.
    #[test]
    fn ensure_branch_exists_does_not_overwrite_existing_metadata() {
        let (branch_repo, _, _) = test_repos();
        let b = BranchId::from("preserve-me");
        branch_repo.set_last_scanned_commit(&b, "abc1234").unwrap();
        branch_repo.ensure_branch_exists(&b).unwrap();
        let commit = branch_repo.get_last_scanned_commit(&b).unwrap();
        assert_eq!(commit.as_deref(), Some("abc1234"));
    }

    /// Unregistered branches report no scan commit.
    #[test]
    fn get_last_scanned_commit_returns_none_for_unknown_branch() {
        let (branch_repo, _, _) = test_repos();
        let result = branch_repo
            .get_last_scanned_commit(&BranchId::from("never-scanned"))
            .unwrap();
        assert!(result.is_none());
    }

    /// Registered-but-unscanned branches report no scan commit (NULL column).
    #[test]
    fn get_last_scanned_commit_returns_none_when_branch_exists_but_not_scanned() {
        let (branch_repo, _, _) = test_repos();
        let b = BranchId::from("registered-only");
        branch_repo.ensure_branch_exists(&b).unwrap();
        let result = branch_repo.get_last_scanned_commit(&b).unwrap();
        assert!(result.is_none());
    }

    /// A stored scan commit reads back unchanged.
    #[test]
    fn set_last_scanned_commit_round_trip() {
        let (branch_repo, _, _) = test_repos();
        let b = BranchId::from("round-trip");
        branch_repo.set_last_scanned_commit(&b, "deadbeef").unwrap();
        let read = branch_repo.get_last_scanned_commit(&b).unwrap();
        assert_eq!(read.as_deref(), Some("deadbeef"));
    }

    /// A second write replaces the first without duplicating the branch row.
    #[test]
    fn set_last_scanned_commit_upsert_overwrites_previous_value() {
        let (branch_repo, _, _) = test_repos();
        let b = BranchId::from("overwrite-me");
        branch_repo.set_last_scanned_commit(&b, "first00").unwrap();
        branch_repo.set_last_scanned_commit(&b, "secondf0").unwrap();
        let read = branch_repo.get_last_scanned_commit(&b).unwrap();
        assert_eq!(read.as_deref(), Some("secondf0"));
        let branches = branch_repo.list_branches().unwrap();
        assert_eq!(
            branches.iter().filter(|x| **x == b).count(),
            1,
            "UPSERT must not duplicate rows"
        );
    }

    /// Re-recording a scan refreshes `last_scanned_at`. The stored value is
    /// forced to 0 first so the check does not depend on wall-clock resolution.
    #[test]
    fn set_last_scanned_commit_bumps_last_scanned_at() {
        let (branch_repo, _, _) = test_repos();
        let b = BranchId::from("bump");
        branch_repo.set_last_scanned_commit(&b, "h1").unwrap();
        const PAST_TS: i64 = 0;
        {
            let conn = branch_repo.conn.lock().unwrap();
            conn.execute(
                "UPDATE branches SET last_scanned_at = ?1 WHERE branch_id = ?2",
                params![PAST_TS, b.0],
            )
            .unwrap();
        }
        branch_repo.set_last_scanned_commit(&b, "h2").unwrap();
        let conn = branch_repo.conn.lock().unwrap();
        let ts2: i64 = conn
            .query_row(
                "SELECT last_scanned_at FROM branches WHERE branch_id = ?1",
                params![b.0],
                |row| row.get(0),
            )
            .unwrap();
        assert!(
            ts2 > PAST_TS,
            "last_scanned_at must advance forward of any stored prior value; \
             got ts2={ts2}, PAST_TS={PAST_TS}"
        );
    }

    /// Snapshot registers both branches and records the source on the target.
    #[test]
    fn create_snapshot_registers_target_branch_with_snapshot_source() {
        let (branch_repo, _, _) = test_repos();
        let main_branch = BranchId::from("main");
        let snap = BranchId::from("snap-1");
        branch_repo.create_snapshot(&main_branch, &snap).unwrap();
        let listed = branch_repo.list_branches().unwrap();
        assert!(listed.contains(&main_branch), "source must be registered");
        assert!(listed.contains(&snap), "target must be registered");
        let conn = branch_repo.conn.lock().unwrap();
        let source: Option<String> = conn
            .query_row(
                "SELECT snapshot_source FROM branches WHERE branch_id = ?1",
                params![snap.0],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(source.as_deref(), Some("main"));
    }

    /// Deleting a branch also removes its registry row.
    #[test]
    fn delete_branch_removes_branches_row() {
        let (branch_repo, _, _) = test_repos();
        let b = BranchId::from("doomed");
        branch_repo.set_last_scanned_commit(&b, "abc").unwrap();
        assert!(branch_repo.list_branches().unwrap().contains(&b));
        branch_repo.delete_branch(&b).unwrap();
        assert!(
            !branch_repo.list_branches().unwrap().contains(&b),
            "delete_branch must drop the registry row"
        );
    }

    /// Switching to an unknown branch registers it.
    #[test]
    fn switch_branch_registers_branch_implicitly() {
        let (branch_repo, _, _) = test_repos();
        let b = BranchId::from("switched-only");
        branch_repo.switch_branch(&b).unwrap();
        assert!(branch_repo.list_branches().unwrap().contains(&b));
    }
}