use std::{collections::HashMap, sync::Arc};
use chrono::Utc;
use sqlx::{
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions},
SqlitePool,
};
use uuid::Uuid;
use crate::{
branch::store::BranchStore,
config::BranchConfig,
diff::extractor::DiffExtractor,
error::BranchResult,
types::{Branch, BranchMetrics, EntityType},
};
/// The kind of write operation recorded by `MetricsTracker::track_op`.
///
/// Each variant selects which per-branch entity counter is incremented.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpKind {
    /// Counted in `created_entity_count`.
    Insert,
    /// Counted in `updated_entity_count`.
    Update,
    /// Counted in `deleted_entity_count`.
    Delete,
}
/// Recomputes, stores and incrementally updates per-branch metrics.
pub struct MetricsTracker {
    /// Branch persistence used to load branches, count commits and save metrics.
    store: Arc<BranchStore>,
    /// Shared configuration: provides the workspace id for branch lookups and is
    /// passed to each `DiffExtractor` the tracker builds.
    config: Arc<BranchConfig>,
}
impl MetricsTracker {
pub fn new(store: Arc<BranchStore>, config: Arc<BranchConfig>) -> Self {
Self { store, config }
}
pub async fn refresh(&self, branch: &Branch) -> BranchResult<BranchMetrics> {
let pool = open_pool(&branch.db_path).await?;
let memory_record_count = count_table(&pool, EntityType::MemoryRecord.table_name()).await?;
let session_count = count_table(&pool, EntityType::Session.table_name()).await?;
let tool_output_count = count_table(&pool, EntityType::ToolOutput.table_name()).await?;
pool.close().await;
let bytes_on_disk = tokio::fs::metadata(&branch.db_path)
.await
.map(|m| m.len())
.unwrap_or(0);
let op_count = self.store.count_commits(branch.id).await?;
let divergence_score = if let Some(parent_id) = branch.parent_id {
match self.store.get(self.config.workspace_id, parent_id).await {
Ok(parent) => {
let extractor = DiffExtractor::new(Arc::clone(&self.config));
match extractor.diff(branch, &parent, None).await {
Ok(diff) => {
let base_count = (parent.metrics.memory_record_count
+ parent.metrics.session_count
+ parent.metrics.tool_output_count)
as u64;
crate::metrics::divergence::compute_score(&diff, base_count)
}
Err(_) => branch.metrics.divergence_score,
}
}
Err(_) => branch.metrics.divergence_score,
}
} else {
0.0
};
let metrics = BranchMetrics {
op_count,
memory_record_count,
session_count,
tool_output_count,
bytes_on_disk,
divergence_score,
created_entity_count: branch.metrics.created_entity_count,
updated_entity_count: branch.metrics.updated_entity_count,
deleted_entity_count: branch.metrics.deleted_entity_count,
last_activity_at: branch.metrics.last_activity_at,
};
self.store.update_metrics(branch.id, &metrics).await?;
Ok(metrics)
}
pub async fn refresh_all(
&self,
workspace_id: Uuid,
) -> BranchResult<HashMap<Uuid, BranchMetrics>> {
let branches = self.store.list(workspace_id, None).await?;
let mut result = HashMap::with_capacity(branches.len());
for branch in branches {
let metrics = self.refresh(&branch).await?;
result.insert(branch.id, metrics);
}
Ok(result)
}
pub async fn track_op(&self, branch_id: Uuid, op_kind: OpKind) -> BranchResult<()> {
let mut branch = self.store.get(self.config.workspace_id, branch_id).await?;
branch.metrics.op_count += 1;
branch.metrics.last_activity_at = Some(Utc::now());
match op_kind {
OpKind::Insert => branch.metrics.created_entity_count += 1,
OpKind::Update => branch.metrics.updated_entity_count += 1,
OpKind::Delete => branch.metrics.deleted_entity_count += 1,
}
self.store.update_metrics(branch_id, &branch.metrics).await
}
}
/// Opens a single-connection, read-only pool on the existing SQLite file at `path`.
///
/// No journal mode is requested. Switching a database into WAL mode requires
/// writing its header, which a read-only connection cannot do, so forcing WAL
/// here fails the connect for any database not already in WAL mode. A database
/// that already uses WAL keeps that mode persistently, and readers use it
/// automatically.
///
/// # Errors
///
/// Returns `BranchError::Database` if the file does not exist (it is never
/// created) or the connection cannot be established.
async fn open_pool(path: &std::path::Path) -> BranchResult<SqlitePool> {
    let options = SqliteConnectOptions::new()
        .filename(path)
        .create_if_missing(false)
        .read_only(true);
    SqlitePoolOptions::new()
        .max_connections(1)
        .connect_with(options)
        .await
        .map_err(crate::error::BranchError::Database)
}
/// Returns the number of rows in `table`, or `0` if no such table exists.
///
/// Existence is checked against `sqlite_master` first, so a missing table gives
/// `0` rather than a "no such table" error.
///
/// # Errors
///
/// Propagates any database error from either query.
async fn count_table(pool: &SqlitePool, table: &str) -> BranchResult<i64> {
    let exists: i64 =
        sqlx::query_scalar("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?")
            .bind(table)
            .fetch_one(pool)
            .await?;
    if exists == 0 {
        return Ok(0);
    }
    // Identifiers cannot be bound as parameters. Quote the name as an SQL identifier,
    // doubling any embedded double quotes, instead of splicing it in raw. This also
    // makes names that need quoting (spaces, keywords) resolve correctly.
    let sql = format!("SELECT COUNT(*) FROM \"{}\"", table.replace('"', "\"\""));
    let count: i64 = sqlx::query_scalar(&sql).fetch_one(pool).await?;
    Ok(count)
}