claw-branch 0.1.2

Fork, simulate, and merge engine for ClawDB agents.
Documentation
//! SQLite-backed persistence for branch metadata, metrics, and commit history.

use std::path::Path;

use chrono::{DateTime, Utc};
use sqlx::{
    sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions},
    Row, SqlitePool,
};
use uuid::Uuid;

use crate::{
    error::{BranchError, BranchResult},
    types::{Branch, BranchMetrics, BranchStatus, CommitLogEntry, EntityType},
};

/// Stores branch registry state in a SQLite database.
///
/// Every query issued by the store's methods runs through the wrapped pool.
#[derive(Debug, Clone)]
pub struct BranchStore {
    /// Connection pool used for all registry queries.
    pool: SqlitePool,
}

impl BranchStore {
    /// Opens the registry database at `db_path`, creating it when absent, and runs migrations.
    ///
    /// The parent directory of `db_path` is created first if the path has one. The
    /// connection uses WAL journaling and the pool is capped at five connections.
    ///
    /// # Errors
    ///
    /// Fails when the directory cannot be created, the database cannot be opened, or a
    /// migration from `./migrations` does not apply.
    pub async fn new(db_path: &Path) -> BranchResult<Self> {
        // SQLite creates the file itself, but not the directories leading to it.
        if let Some(directory) = db_path.parent() {
            tokio::fs::create_dir_all(directory).await?;
        }

        let connect_options = SqliteConnectOptions::new()
            .filename(db_path)
            .create_if_missing(true)
            .journal_mode(SqliteJournalMode::Wal);
        let pool_options = SqlitePoolOptions::new().max_connections(5);
        let pool = pool_options.connect_with(connect_options).await?;

        let migrator = sqlx::migrate!("./migrations");
        migrator.run(&pool).await?;

        Ok(Self { pool })
    }

    /// Returns a shared reference to the underlying SQLite pool.
    ///
    /// The pool is borrowed from the store; no new pool or connection is created here.
    pub fn pool(&self) -> &SqlitePool {
        &self.pool
    }

    /// Persists a new branch row and then writes its initial metrics row.
    ///
    /// The two writes are separate statements on the pool and are not wrapped in a
    /// transaction: if the metrics upsert fails, the branch row has already been stored.
    ///
    /// # Errors
    ///
    /// Returns [`BranchError::BranchAlreadyExists`] carrying `branch.name` when SQLite
    /// reports a `UNIQUE constraint failed` message for the branch insert. Metadata
    /// serialization failures and all other database errors are propagated.
    pub async fn insert(&self, branch: &Branch) -> BranchResult<()> {
        const INSERT_BRANCH_SQL: &str = "INSERT INTO branches (id, name, slug, workspace_id, parent_id, status, db_path, snapshot_path, forked_from_cursor, description, metadata, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";

        let metadata_json = serde_json::to_string(&branch.metadata)?;
        let parent_id = branch.parent_id.map(|parent| parent.to_string());
        // Paths are stored as text; non-UTF-8 bytes are replaced lossily.
        let db_path = branch.db_path.to_string_lossy().into_owned();
        let snapshot_path = branch.snapshot_path.to_string_lossy().into_owned();

        let insert_result = sqlx::query(INSERT_BRANCH_SQL)
            .bind(branch.id.to_string())
            .bind(&branch.name)
            .bind(&branch.slug)
            .bind(branch.workspace_id.to_string())
            .bind(parent_id)
            .bind(branch.status.to_storage())
            .bind(db_path)
            .bind(snapshot_path)
            .bind(&branch.forked_from_cursor)
            .bind(&branch.description)
            .bind(metadata_json)
            .bind(branch.created_at.to_rfc3339())
            .bind(branch.updated_at.to_rfc3339())
            .execute(&self.pool)
            .await;
        insert_result.map_err(map_unique_name_error(&branch.name))?;

        self.update_metrics(branch.id, &branch.metrics).await
    }

    /// Retrieves a branch by id within a workspace.
    pub async fn get(&self, workspace_id: Uuid, id: Uuid) -> BranchResult<Branch> {
        let row = sqlx::query(
            "SELECT b.id, b.name, b.slug, b.workspace_id, b.parent_id, b.status, b.db_path, b.snapshot_path, b.forked_from_cursor, b.description, b.metadata, b.created_at, b.updated_at, m.op_count, m.memory_record_count, m.session_count, m.tool_output_count, m.bytes_on_disk, m.divergence_score, m.created_entity_count, m.updated_entity_count, m.deleted_entity_count, m.last_activity_at FROM branches b LEFT JOIN branch_metrics m ON m.branch_id = b.id WHERE b.workspace_id = ? AND b.id = ?",
        )
        .bind(workspace_id.to_string())
        .bind(id.to_string())
        .fetch_optional(&self.pool)
        .await?;

        row.map(branch_from_row)
            .transpose()?
            .ok_or(BranchError::BranchNotFound(id))
    }

    /// Retrieves a branch by workspace and slug.
    pub async fn get_by_slug(&self, workspace_id: Uuid, slug: &str) -> BranchResult<Branch> {
        let row = sqlx::query(
            "SELECT b.id, b.name, b.slug, b.workspace_id, b.parent_id, b.status, b.db_path, b.snapshot_path, b.forked_from_cursor, b.description, b.metadata, b.created_at, b.updated_at, m.op_count, m.memory_record_count, m.session_count, m.tool_output_count, m.bytes_on_disk, m.divergence_score, m.created_entity_count, m.updated_entity_count, m.deleted_entity_count, m.last_activity_at FROM branches b LEFT JOIN branch_metrics m ON m.branch_id = b.id WHERE b.workspace_id = ? AND b.slug = ?",
        )
        .bind(workspace_id.to_string())
        .bind(slug)
        .fetch_optional(&self.pool)
        .await?;

        row.map(branch_from_row)
            .transpose()?
            .ok_or_else(|| BranchError::BranchAlreadyExists(slug.to_string()))
    }

    /// Lists branches for a workspace, optionally filtered by status kind.
    pub async fn list(
        &self,
        workspace_id: Uuid,
        status: Option<BranchStatus>,
    ) -> BranchResult<Vec<Branch>> {
        let rows = sqlx::query(
            "SELECT b.id, b.name, b.slug, b.workspace_id, b.parent_id, b.status, b.db_path, b.snapshot_path, b.forked_from_cursor, b.description, b.metadata, b.created_at, b.updated_at, m.op_count, m.memory_record_count, m.session_count, m.tool_output_count, m.bytes_on_disk, m.divergence_score, m.created_entity_count, m.updated_entity_count, m.deleted_entity_count, m.last_activity_at FROM branches b LEFT JOIN branch_metrics m ON m.branch_id = b.id WHERE b.workspace_id = ? ORDER BY b.created_at ASC",
        )
        .bind(workspace_id.to_string())
        .fetch_all(&self.pool)
        .await?;
        let mut branches = rows
            .into_iter()
            .map(branch_from_row)
            .collect::<BranchResult<Vec<_>>>()?;

        if let Some(expected_status) = status {
            let expected_kind = expected_status.kind();
            branches.retain(|branch| branch.status.kind() == expected_kind);
        }

        Ok(branches)
    }

    /// Sets the stored status of branch `id` and stamps `updated_at` with the current time.
    ///
    /// The update is keyed on the branch id only; no workspace filter is applied.
    ///
    /// # Errors
    ///
    /// Returns [`BranchError::BranchNotFound`] when the update touches no row. Database
    /// errors are propagated.
    pub async fn update_status(&self, id: Uuid, status: BranchStatus) -> BranchResult<()> {
        let outcome = sqlx::query("UPDATE branches SET status = ?, updated_at = ? WHERE id = ?")
            .bind(status.to_storage())
            .bind(Utc::now().to_rfc3339())
            .bind(id.to_string())
            .execute(&self.pool)
            .await?;

        match outcome.rows_affected() {
            0 => Err(BranchError::BranchNotFound(id)),
            _ => Ok(()),
        }
    }

    /// Inserts the metrics row for branch `id`, or overwrites it when one already exists.
    ///
    /// All metric columns are replaced with the values from `metrics`, and `refreshed_at`
    /// is set to the current time. Unsigned counters are stored after an `as i64` cast.
    ///
    /// # Errors
    ///
    /// Propagates any database error raised by the upsert.
    pub async fn update_metrics(&self, id: Uuid, metrics: &BranchMetrics) -> BranchResult<()> {
        const UPSERT_METRICS_SQL: &str = "INSERT INTO branch_metrics (branch_id, op_count, memory_record_count, session_count, tool_output_count, bytes_on_disk, divergence_score, created_entity_count, updated_entity_count, deleted_entity_count, last_activity_at, refreshed_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(branch_id) DO UPDATE SET op_count = excluded.op_count, memory_record_count = excluded.memory_record_count, session_count = excluded.session_count, tool_output_count = excluded.tool_output_count, bytes_on_disk = excluded.bytes_on_disk, divergence_score = excluded.divergence_score, created_entity_count = excluded.created_entity_count, updated_entity_count = excluded.updated_entity_count, deleted_entity_count = excluded.deleted_entity_count, last_activity_at = excluded.last_activity_at, refreshed_at = excluded.refreshed_at";

        let last_activity = metrics.last_activity_at.map(|stamp| stamp.to_rfc3339());

        let upsert = sqlx::query(UPSERT_METRICS_SQL)
            .bind(id.to_string())
            .bind(metrics.op_count as i64)
            .bind(metrics.memory_record_count)
            .bind(metrics.session_count)
            .bind(metrics.tool_output_count)
            .bind(metrics.bytes_on_disk as i64)
            .bind(metrics.divergence_score)
            .bind(metrics.created_entity_count as i64)
            .bind(metrics.updated_entity_count as i64)
            .bind(metrics.deleted_entity_count as i64)
            .bind(last_activity)
            .bind(Utc::now().to_rfc3339());

        upsert.execute(&self.pool).await?;
        Ok(())
    }

    /// Deletes a branch registry row and its metrics entry atomically.
    ///
    /// Both `DELETE` statements run inside a single transaction, so the registry is never
    /// left with a branch row whose metrics were removed (or the reverse) when one of the
    /// statements fails. The delete is keyed on the branch id only; no workspace filter is
    /// applied.
    ///
    /// # Errors
    ///
    /// Returns [`BranchError::BranchNotFound`] when no branch row has the given id; in that
    /// case the transaction is dropped without committing, which rolls it back. Database
    /// errors are propagated and likewise leave the transaction uncommitted.
    pub async fn delete(&self, id: Uuid) -> BranchResult<()> {
        let branch_id = id.to_string();
        let mut tx = self.pool.begin().await?;

        // Metrics go first, matching the original statement order.
        sqlx::query("DELETE FROM branch_metrics WHERE branch_id = ?")
            .bind(branch_id.as_str())
            .execute(&mut *tx)
            .await?;
        let result = sqlx::query("DELETE FROM branches WHERE id = ?")
            .bind(branch_id.as_str())
            .execute(&mut *tx)
            .await?;

        if result.rows_affected() == 0 {
            // Returning without `commit` drops `tx`, discarding the metrics delete as well.
            return Err(BranchError::BranchNotFound(id));
        }

        tx.commit().await?;
        Ok(())
    }

    /// Returns how many branch rows belong to `workspace_id`, regardless of status.
    ///
    /// # Errors
    ///
    /// Propagates database errors and failures to read the `count` column.
    pub async fn count(&self, workspace_id: Uuid) -> BranchResult<u64> {
        let total = sqlx::query("SELECT COUNT(*) AS count FROM branches WHERE workspace_id = ?")
            .bind(workspace_id.to_string())
            .fetch_one(&self.pool)
            .await?
            .try_get::<i64, _>("count")?;
        Ok(total as u64)
    }

    /// Returns how many branches in `workspace_id` have the stored status `active` or
    /// `dormant`.
    ///
    /// The comparison is made in SQL against those two literal status strings.
    ///
    /// # Errors
    ///
    /// Propagates database errors and failures to read the `count` column.
    pub async fn count_active(&self, workspace_id: Uuid) -> BranchResult<u64> {
        const COUNT_ACTIVE_SQL: &str = "SELECT COUNT(*) AS count FROM branches WHERE workspace_id = ? AND (status = 'active' OR status = 'dormant')";

        let total = sqlx::query(COUNT_ACTIVE_SQL)
            .bind(workspace_id.to_string())
            .fetch_one(&self.pool)
            .await?
            .try_get::<i64, _>("count")?;
        Ok(total as u64)
    }

    /// Appends one entry to the `branch_commits` table.
    ///
    /// `entity_ids` is stored as a JSON string, the optional entity type as its string
    /// form, and `committed_at` as an RFC 3339 timestamp.
    ///
    /// # Errors
    ///
    /// Propagates JSON serialization failures and database errors.
    pub async fn insert_commit_log(&self, entry: &CommitLogEntry) -> BranchResult<()> {
        const INSERT_COMMIT_SQL: &str = "INSERT INTO branch_commits (id, branch_id, entity_type, entity_ids, op_kind, committed_at, message) VALUES (?, ?, ?, ?, ?, ?, ?)";

        let entity_ids_json = serde_json::to_string(&entry.entity_ids)?;
        let entity_type = entry.entity_type.as_ref().map(EntityType::as_str);

        let statement = sqlx::query(INSERT_COMMIT_SQL)
            .bind(entry.id.to_string())
            .bind(entry.branch_id.to_string())
            .bind(entity_type)
            .bind(entity_ids_json)
            .bind(&entry.op_kind)
            .bind(entry.committed_at.to_rfc3339())
            .bind(&entry.message);

        statement.execute(&self.pool).await?;
        Ok(())
    }

    /// Lists recent commit log entries for a branch.
    pub async fn list_commits(
        &self,
        branch_id: Uuid,
        limit: u32,
    ) -> BranchResult<Vec<CommitLogEntry>> {
        let rows = sqlx::query(
            "SELECT id, branch_id, entity_type, entity_ids, op_kind, committed_at, message FROM branch_commits WHERE branch_id = ? ORDER BY committed_at DESC LIMIT ?",
        )
        .bind(branch_id.to_string())
        .bind(limit as i64)
        .fetch_all(&self.pool)
        .await?;

        rows.into_iter().map(commit_log_from_row).collect()
    }

    /// Returns how many commit log rows are recorded for `branch_id`.
    ///
    /// # Errors
    ///
    /// Propagates database errors and failures to read the `count` column.
    pub async fn count_commits(&self, branch_id: Uuid) -> BranchResult<u64> {
        let total = sqlx::query("SELECT COUNT(*) AS count FROM branch_commits WHERE branch_id = ?")
            .bind(branch_id.to_string())
            .fetch_one(&self.pool)
            .await?
            .try_get::<i64, _>("count")?;
        Ok(total as u64)
    }

    /// Finds a branch by its human-readable name within the workspace.
    pub async fn get_by_name(&self, workspace_id: Uuid, name: &str) -> BranchResult<Branch> {
        let row = sqlx::query(
            "SELECT b.id, b.name, b.slug, b.workspace_id, b.parent_id, b.status, b.db_path, b.snapshot_path, b.forked_from_cursor, b.description, b.metadata, b.created_at, b.updated_at, m.op_count, m.memory_record_count, m.session_count, m.tool_output_count, m.bytes_on_disk, m.divergence_score, m.created_entity_count, m.updated_entity_count, m.deleted_entity_count, m.last_activity_at FROM branches b LEFT JOIN branch_metrics m ON m.branch_id = b.id WHERE b.workspace_id = ? AND b.name = ?",
        )
        .bind(workspace_id.to_string())
        .bind(name)
        .fetch_optional(&self.pool)
        .await?;

        row.map(branch_from_row)
            .transpose()?
            .ok_or_else(|| BranchError::BranchAlreadyExists(name.to_string()))
    }
}

fn map_unique_name_error(name: &str) -> impl FnOnce(sqlx::Error) -> BranchError + '_ {
    move |error| match &error {
        sqlx::Error::Database(database_error)
            if database_error
                .message()
                .contains("UNIQUE constraint failed") =>
        {
            BranchError::BranchAlreadyExists(name.to_string())
        }
        _ => BranchError::Database(error),
    }
}

fn branch_from_row(row: sqlx::sqlite::SqliteRow) -> BranchResult<Branch> {
    let metadata = row.try_get::<String, _>("metadata")?;
    let status = row.try_get::<String, _>("status")?;
    Ok(Branch {
        id: parse_uuid(row.try_get("id")?)?,
        name: row.try_get("name")?,
        slug: row.try_get("slug")?,
        workspace_id: parse_uuid(row.try_get("workspace_id")?)?,
        parent_id: row
            .try_get::<Option<String>, _>("parent_id")?
            .map(parse_uuid)
            .transpose()?,
        status: BranchStatus::from_storage(&status).ok_or_else(|| {
            BranchError::Serialization(serde_json::Error::io(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("invalid branch status: {status}"),
            )))
        })?,
        db_path: row.try_get::<String, _>("db_path")?.into(),
        snapshot_path: row.try_get::<String, _>("snapshot_path")?.into(),
        forked_from_cursor: row.try_get("forked_from_cursor")?,
        description: row.try_get("description")?,
        metadata: serde_json::from_str(&metadata)?,
        created_at: parse_datetime(row.try_get("created_at")?)?,
        updated_at: parse_datetime(row.try_get("updated_at")?)?,
        metrics: BranchMetrics {
            op_count: row
                .try_get::<Option<i64>, _>("op_count")?
                .unwrap_or_default() as u64,
            memory_record_count: row
                .try_get::<Option<i64>, _>("memory_record_count")?
                .unwrap_or_default(),
            session_count: row
                .try_get::<Option<i64>, _>("session_count")?
                .unwrap_or_default(),
            tool_output_count: row
                .try_get::<Option<i64>, _>("tool_output_count")?
                .unwrap_or_default(),
            bytes_on_disk: row
                .try_get::<Option<i64>, _>("bytes_on_disk")?
                .unwrap_or_default() as u64,
            divergence_score: row
                .try_get::<Option<f64>, _>("divergence_score")?
                .unwrap_or_default(),
            created_entity_count: row
                .try_get::<Option<i64>, _>("created_entity_count")?
                .unwrap_or_default() as u64,
            updated_entity_count: row
                .try_get::<Option<i64>, _>("updated_entity_count")?
                .unwrap_or_default() as u64,
            deleted_entity_count: row
                .try_get::<Option<i64>, _>("deleted_entity_count")?
                .unwrap_or_default() as u64,
            last_activity_at: row
                .try_get::<Option<String>, _>("last_activity_at")?
                .map(parse_datetime)
                .transpose()?,
        },
    })
}

/// Decodes one `branch_commits` row into a [`CommitLogEntry`].
///
/// Ids and the commit timestamp are parsed from text and `entity_ids` from JSON. A NULL
/// `entity_type` yields `None`; a non-NULL value is passed to `EntityType::parse`, and
/// whatever that returns is stored as-is, so an unrecognised value does not raise an error.
///
/// # Errors
///
/// Propagates column-read errors, JSON errors, and id / timestamp parse failures.
fn commit_log_from_row(row: sqlx::sqlite::SqliteRow) -> BranchResult<CommitLogEntry> {
    let entity_ids_json: String = row.try_get("entity_ids")?;

    Ok(CommitLogEntry {
        id: parse_uuid(row.try_get("id")?)?,
        branch_id: parse_uuid(row.try_get("branch_id")?)?,
        entity_type: match row.try_get::<Option<String>, _>("entity_type")? {
            Some(label) => EntityType::parse(label.as_str()),
            None => None,
        },
        entity_ids: serde_json::from_str(&entity_ids_json)?,
        op_kind: row.try_get("op_kind")?,
        committed_at: parse_datetime(row.try_get("committed_at")?)?,
        message: row.try_get("message")?,
    })
}

fn parse_uuid(value: String) -> BranchResult<Uuid> {
    Uuid::parse_str(&value).map_err(|error| {
        BranchError::Serialization(serde_json::Error::io(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            error,
        )))
    })
}

fn parse_datetime(value: String) -> BranchResult<DateTime<Utc>> {
    value.parse::<DateTime<Utc>>().map_err(|error| {
        BranchError::Serialization(serde_json::Error::io(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            error,
        )))
    })
}