Skip to main content

claw_branch/diff/
extractor.rs

1//! Entity diff extraction between two branch databases.
2
3use std::{
4    collections::{HashMap, HashSet},
5    sync::Arc,
6};
7
8use chrono::Utc;
9use sqlx::{
10    sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions},
11    Row, SqlitePool,
12};
13use uuid::Uuid;
14
15use crate::{
16    config::BranchConfig,
17    diff::scorer::score_divergence,
18    error::{BranchError, BranchResult},
19    types::{Branch, DiffKind, DiffResult, DiffStats, EntityDiff, EntityType, FieldDiff},
20};
21
22/// Opens branch databases and computes entity-level diffs across the ClawDB schema.
23pub struct DiffExtractor {
24    /// Configuration used for connection limits and timeouts.
25    pub config: Arc<BranchConfig>,
26}
27
28impl DiffExtractor {
29    /// Creates a new extractor with the given workspace config.
30    pub fn new(config: Arc<BranchConfig>) -> Self {
31        Self { config }
32    }
33
34    /// Computes the full diff between `branch_a` and `branch_b`.
35    ///
36    /// When `entity_types` is `None`, all three entity types are compared.
37    pub async fn diff(
38        &self,
39        branch_a: &Branch,
40        branch_b: &Branch,
41        entity_types: Option<&[EntityType]>,
42    ) -> BranchResult<DiffResult> {
43        let pool_a = open_pool(&branch_a.db_path).await?;
44        let pool_b = open_pool(&branch_b.db_path).await?;
45
46        let types: &[EntityType] = entity_types.unwrap_or(&[
47            EntityType::MemoryRecord,
48            EntityType::Session,
49            EntityType::ToolOutput,
50        ]);
51
52        let mut entity_diffs: Vec<EntityDiff> = Vec::new();
53        let mut stats = DiffStats::default();
54
55        for entity_type in types {
56            let map_a = fetch_all_entities(&pool_a, entity_type).await?;
57            let map_b = fetch_all_entities(&pool_b, entity_type).await?;
58
59            let all_ids: HashSet<&String> = map_a.keys().chain(map_b.keys()).collect();
60            stats.total_entities += all_ids.len() as u32;
61
62            for id in all_ids {
63                let ed = match (map_a.get(id), map_b.get(id)) {
64                    (Some(_), None) => {
65                        stats.removed += 1;
66                        EntityDiff {
67                            entity_id: id.clone(),
68                            entity_type: entity_type.clone(),
69                            diff_kind: DiffKind::Removed,
70                            field_diffs: Vec::new(),
71                        }
72                    }
73                    (None, Some(_)) => {
74                        stats.added += 1;
75                        EntityDiff {
76                            entity_id: id.clone(),
77                            entity_type: entity_type.clone(),
78                            diff_kind: DiffKind::Added,
79                            field_diffs: Vec::new(),
80                        }
81                    }
82                    (Some(va), Some(vb)) => {
83                        let ed = compare_entity_values(id, entity_type.clone(), va, vb);
84                        match ed.diff_kind {
85                            DiffKind::Modified => stats.modified += 1,
86                            DiffKind::Unchanged => stats.unchanged += 1,
87                            _ => {}
88                        }
89                        ed
90                    }
91                    (None, None) => unreachable!(),
92                };
93                entity_diffs.push(ed);
94            }
95        }
96
97        let divergence_score = score_divergence(&stats);
98
99        pool_a.close().await;
100        pool_b.close().await;
101
102        Ok(DiffResult {
103            branch_a_id: branch_a.id,
104            branch_b_id: branch_b.id,
105            compared_at: Utc::now(),
106            entity_diffs,
107            stats,
108            divergence_score,
109        })
110    }
111
112    /// Diffs a single named entity between two already-opened pools.
113    pub async fn diff_entity(
114        &self,
115        entity_id: &str,
116        entity_type: &EntityType,
117        pool_a: &SqlitePool,
118        pool_b: &SqlitePool,
119    ) -> BranchResult<EntityDiff> {
120        let map_a = fetch_all_entities(pool_a, entity_type).await?;
121        let map_b = fetch_all_entities(pool_b, entity_type).await?;
122        Ok(match (map_a.get(entity_id), map_b.get(entity_id)) {
123            (Some(_), None) => EntityDiff {
124                entity_id: entity_id.to_string(),
125                entity_type: entity_type.clone(),
126                diff_kind: DiffKind::Removed,
127                field_diffs: Vec::new(),
128            },
129            (None, Some(_)) => EntityDiff {
130                entity_id: entity_id.to_string(),
131                entity_type: entity_type.clone(),
132                diff_kind: DiffKind::Added,
133                field_diffs: Vec::new(),
134            },
135            (Some(va), Some(vb)) => compare_entity_values(entity_id, entity_type.clone(), va, vb),
136            (None, None) => EntityDiff {
137                entity_id: entity_id.to_string(),
138                entity_type: entity_type.clone(),
139                diff_kind: DiffKind::Unchanged,
140                field_diffs: Vec::new(),
141            },
142        })
143    }
144}
145
146/// Fetches all rows from an entity table as a `HashMap<entity_id, JSON object>`.
147///
148/// Uses `PRAGMA table_info` to discover column names, then builds each row into a
149/// `serde_json::Value::Object` via SQLite's `json_object()` function.
150pub async fn fetch_all_entities(
151    pool: &SqlitePool,
152    entity_type: &EntityType,
153) -> BranchResult<HashMap<String, serde_json::Value>> {
154    let table = entity_type.table_name();
155
156    // Discover column names dynamically.
157    let pragma_sql = format!("PRAGMA table_info({table})");
158    let col_rows = sqlx::query(&pragma_sql).fetch_all(pool).await?;
159    let columns: Vec<String> = col_rows
160        .iter()
161        .filter_map(|r| r.try_get::<String, _>("name").ok())
162        .collect();
163
164    if columns.is_empty() {
165        // Table does not exist in this snapshot — treat as empty.
166        return Ok(HashMap::new());
167    }
168
169    // Build SELECT id, json_object('col', col, …) AS __data FROM <table>
170    let json_args: String = columns
171        .iter()
172        .map(|c| format!("'{}', {}", c, c))
173        .collect::<Vec<_>>()
174        .join(", ");
175    let query_sql = format!("SELECT id, json_object({json_args}) AS __data FROM {table}");
176
177    let rows = sqlx::query(&query_sql).fetch_all(pool).await?;
178    let mut map = HashMap::with_capacity(rows.len());
179    for row in rows {
180        let id: String = row.try_get("id")?;
181        let data_str: String = row.try_get("__data")?;
182        let data: serde_json::Value = serde_json::from_str(&data_str)?;
183        map.insert(id, data);
184    }
185    Ok(map)
186}
187
188// ── Internal helpers ────────────────────────────────────────────────────────
189
190async fn open_pool(path: &std::path::Path) -> BranchResult<SqlitePool> {
191    SqlitePoolOptions::new()
192        .max_connections(1)
193        .connect_with(
194            SqliteConnectOptions::new()
195                .filename(path)
196                .create_if_missing(false)
197                .read_only(true)
198                .journal_mode(SqliteJournalMode::Wal),
199        )
200        .await
201        .map_err(BranchError::Database)
202}
203
204fn compare_entity_values(
205    entity_id: &str,
206    entity_type: EntityType,
207    a: &serde_json::Value,
208    b: &serde_json::Value,
209) -> EntityDiff {
210    let a_obj = a.as_object().cloned().unwrap_or_default();
211    let b_obj = b.as_object().cloned().unwrap_or_default();
212
213    let all_fields: HashSet<String> = a_obj.keys().chain(b_obj.keys()).cloned().collect();
214
215    let mut field_diffs: Vec<FieldDiff> = Vec::new();
216    for field in &all_fields {
217        let av = a_obj.get(field).cloned().unwrap_or(serde_json::Value::Null);
218        let bv = b_obj.get(field).cloned().unwrap_or(serde_json::Value::Null);
219        if av != bv {
220            field_diffs.push(FieldDiff {
221                field: field.clone(),
222                before: av,
223                after: bv,
224            });
225        }
226    }
227
228    let diff_kind = if field_diffs.is_empty() {
229        DiffKind::Unchanged
230    } else {
231        DiffKind::Modified
232    };
233
234    EntityDiff {
235        entity_id: entity_id.to_string(),
236        entity_type,
237        diff_kind,
238        field_diffs,
239    }
240}
241
242// Keep the old free-function for backward compat.
243/// Returns an empty diff placeholder.  Use [`DiffExtractor`] for real extraction.
244pub async fn extract_diff(branch_a_id: Uuid, branch_b_id: Uuid) -> BranchResult<DiffResult> {
245    let stats = DiffStats::default();
246    Ok(DiffResult {
247        branch_a_id,
248        branch_b_id,
249        compared_at: Utc::now(),
250        entity_diffs: Vec::new(),
251        divergence_score: score_divergence(&stats),
252        stats,
253    })
254}