1use std::{
4 collections::{HashMap, HashSet},
5 sync::Arc,
6};
7
8use chrono::Utc;
9use sqlx::{
10 sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions},
11 Row, SqlitePool,
12};
13use uuid::Uuid;
14
15use crate::{
16 config::BranchConfig,
17 diff::scorer::score_divergence,
18 error::{BranchError, BranchResult},
19 types::{Branch, DiffKind, DiffResult, DiffStats, EntityDiff, EntityType, FieldDiff},
20};
21
22pub struct DiffExtractor {
24 pub config: Arc<BranchConfig>,
26}
27
28impl DiffExtractor {
29 pub fn new(config: Arc<BranchConfig>) -> Self {
31 Self { config }
32 }
33
34 pub async fn diff(
38 &self,
39 branch_a: &Branch,
40 branch_b: &Branch,
41 entity_types: Option<&[EntityType]>,
42 ) -> BranchResult<DiffResult> {
43 let pool_a = open_pool(&branch_a.db_path).await?;
44 let pool_b = open_pool(&branch_b.db_path).await?;
45
46 let types: &[EntityType] = entity_types.unwrap_or(&[
47 EntityType::MemoryRecord,
48 EntityType::Session,
49 EntityType::ToolOutput,
50 ]);
51
52 let mut entity_diffs: Vec<EntityDiff> = Vec::new();
53 let mut stats = DiffStats::default();
54
55 for entity_type in types {
56 let map_a = fetch_all_entities(&pool_a, entity_type).await?;
57 let map_b = fetch_all_entities(&pool_b, entity_type).await?;
58
59 let all_ids: HashSet<&String> = map_a.keys().chain(map_b.keys()).collect();
60 stats.total_entities += all_ids.len() as u32;
61
62 for id in all_ids {
63 let ed = match (map_a.get(id), map_b.get(id)) {
64 (Some(_), None) => {
65 stats.removed += 1;
66 EntityDiff {
67 entity_id: id.clone(),
68 entity_type: entity_type.clone(),
69 diff_kind: DiffKind::Removed,
70 field_diffs: Vec::new(),
71 }
72 }
73 (None, Some(_)) => {
74 stats.added += 1;
75 EntityDiff {
76 entity_id: id.clone(),
77 entity_type: entity_type.clone(),
78 diff_kind: DiffKind::Added,
79 field_diffs: Vec::new(),
80 }
81 }
82 (Some(va), Some(vb)) => {
83 let ed = compare_entity_values(id, entity_type.clone(), va, vb);
84 match ed.diff_kind {
85 DiffKind::Modified => stats.modified += 1,
86 DiffKind::Unchanged => stats.unchanged += 1,
87 _ => {}
88 }
89 ed
90 }
91 (None, None) => unreachable!(),
92 };
93 entity_diffs.push(ed);
94 }
95 }
96
97 let divergence_score = score_divergence(&stats);
98
99 pool_a.close().await;
100 pool_b.close().await;
101
102 Ok(DiffResult {
103 branch_a_id: branch_a.id,
104 branch_b_id: branch_b.id,
105 compared_at: Utc::now(),
106 entity_diffs,
107 stats,
108 divergence_score,
109 })
110 }
111
112 pub async fn diff_entity(
114 &self,
115 entity_id: &str,
116 entity_type: &EntityType,
117 pool_a: &SqlitePool,
118 pool_b: &SqlitePool,
119 ) -> BranchResult<EntityDiff> {
120 let map_a = fetch_all_entities(pool_a, entity_type).await?;
121 let map_b = fetch_all_entities(pool_b, entity_type).await?;
122 Ok(match (map_a.get(entity_id), map_b.get(entity_id)) {
123 (Some(_), None) => EntityDiff {
124 entity_id: entity_id.to_string(),
125 entity_type: entity_type.clone(),
126 diff_kind: DiffKind::Removed,
127 field_diffs: Vec::new(),
128 },
129 (None, Some(_)) => EntityDiff {
130 entity_id: entity_id.to_string(),
131 entity_type: entity_type.clone(),
132 diff_kind: DiffKind::Added,
133 field_diffs: Vec::new(),
134 },
135 (Some(va), Some(vb)) => compare_entity_values(entity_id, entity_type.clone(), va, vb),
136 (None, None) => EntityDiff {
137 entity_id: entity_id.to_string(),
138 entity_type: entity_type.clone(),
139 diff_kind: DiffKind::Unchanged,
140 field_diffs: Vec::new(),
141 },
142 })
143 }
144}
145
146pub async fn fetch_all_entities(
151 pool: &SqlitePool,
152 entity_type: &EntityType,
153) -> BranchResult<HashMap<String, serde_json::Value>> {
154 let table = entity_type.table_name();
155
156 let pragma_sql = format!("PRAGMA table_info({table})");
158 let col_rows = sqlx::query(&pragma_sql).fetch_all(pool).await?;
159 let columns: Vec<String> = col_rows
160 .iter()
161 .filter_map(|r| r.try_get::<String, _>("name").ok())
162 .collect();
163
164 if columns.is_empty() {
165 return Ok(HashMap::new());
167 }
168
169 let json_args: String = columns
171 .iter()
172 .map(|c| format!("'{}', {}", c, c))
173 .collect::<Vec<_>>()
174 .join(", ");
175 let query_sql = format!("SELECT id, json_object({json_args}) AS __data FROM {table}");
176
177 let rows = sqlx::query(&query_sql).fetch_all(pool).await?;
178 let mut map = HashMap::with_capacity(rows.len());
179 for row in rows {
180 let id: String = row.try_get("id")?;
181 let data_str: String = row.try_get("__data")?;
182 let data: serde_json::Value = serde_json::from_str(&data_str)?;
183 map.insert(id, data);
184 }
185 Ok(map)
186}
187
188async fn open_pool(path: &std::path::Path) -> BranchResult<SqlitePool> {
191 SqlitePoolOptions::new()
192 .max_connections(1)
193 .connect_with(
194 SqliteConnectOptions::new()
195 .filename(path)
196 .create_if_missing(false)
197 .read_only(true)
198 .journal_mode(SqliteJournalMode::Wal),
199 )
200 .await
201 .map_err(BranchError::Database)
202}
203
204fn compare_entity_values(
205 entity_id: &str,
206 entity_type: EntityType,
207 a: &serde_json::Value,
208 b: &serde_json::Value,
209) -> EntityDiff {
210 let a_obj = a.as_object().cloned().unwrap_or_default();
211 let b_obj = b.as_object().cloned().unwrap_or_default();
212
213 let all_fields: HashSet<String> = a_obj.keys().chain(b_obj.keys()).cloned().collect();
214
215 let mut field_diffs: Vec<FieldDiff> = Vec::new();
216 for field in &all_fields {
217 let av = a_obj.get(field).cloned().unwrap_or(serde_json::Value::Null);
218 let bv = b_obj.get(field).cloned().unwrap_or(serde_json::Value::Null);
219 if av != bv {
220 field_diffs.push(FieldDiff {
221 field: field.clone(),
222 before: av,
223 after: bv,
224 });
225 }
226 }
227
228 let diff_kind = if field_diffs.is_empty() {
229 DiffKind::Unchanged
230 } else {
231 DiffKind::Modified
232 };
233
234 EntityDiff {
235 entity_id: entity_id.to_string(),
236 entity_type,
237 diff_kind,
238 field_diffs,
239 }
240}
241
242pub async fn extract_diff(branch_a_id: Uuid, branch_b_id: Uuid) -> BranchResult<DiffResult> {
245 let stats = DiffStats::default();
246 Ok(DiffResult {
247 branch_a_id,
248 branch_b_id,
249 compared_at: Utc::now(),
250 entity_diffs: Vec::new(),
251 divergence_score: score_divergence(&stats),
252 stats,
253 })
254}