1use std::{sync::Arc, time::Instant};
4
5use chrono::Utc;
6use sqlx::{
7 sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions},
8 Arguments, SqlitePool,
9};
10use uuid::Uuid;
11
12use crate::{
13 branch::store::BranchStore,
14 commit::{
15 cherry::{CherryPick, EntitySelection},
16 validator::CommitValidator,
17 },
18 diff::extractor::fetch_all_entities,
19 error::{BranchError, BranchResult},
20 types::{CommitLogEntry, CommitResult, EntityType},
21};
22
23pub struct SelectiveCommit {
31 pub source_pool: SqlitePool,
33 pub target_pool: SqlitePool,
35 pub store: Arc<BranchStore>,
37 pub workspace_id: Uuid,
39}
40
41impl SelectiveCommit {
42 pub fn new(
44 source_pool: SqlitePool,
45 target_pool: SqlitePool,
46 store: Arc<BranchStore>,
47 workspace_id: Uuid,
48 ) -> Self {
49 Self {
50 source_pool,
51 target_pool,
52 store,
53 workspace_id,
54 }
55 }
56
57 pub async fn from_store(
59 store: Arc<BranchStore>,
60 source_id: Uuid,
61 target_id: Uuid,
62 workspace_id: Uuid,
63 ) -> BranchResult<Self> {
64 let source = store.get(workspace_id, source_id).await?;
65 let target = store.get(workspace_id, target_id).await?;
66 let source_pool = open_pool(&source.db_path, true).await?;
67 let target_pool = open_pool(&target.db_path, false).await?;
68 Ok(Self {
69 source_pool,
70 target_pool,
71 store,
72 workspace_id,
73 })
74 }
75
76 pub async fn commit(&self, cherry: &CherryPick) -> BranchResult<CommitResult> {
85 let started = Instant::now();
86
87 let source = self
89 .store
90 .get(self.workspace_id, cherry.source_branch_id)
91 .await?;
92 let target = self
93 .store
94 .get(self.workspace_id, cherry.target_branch_id)
95 .await?;
96
97 let validator = CommitValidator::new(self.source_pool.clone());
98 let report = validator.validate(cherry, &source, &target).await?;
99 if !report.ok {
100 return Err(BranchError::CommitValidationFailed {
101 branch_id: cherry.target_branch_id,
102 violations: report.violations,
103 });
104 }
105
106 let mut tx = self.target_pool.begin().await?;
107 let mut committed_entity_count = 0u32;
108 let mut fields_updated = 0u32;
109 let mut all_entity_ids: Vec<String> = Vec::new();
110
111 for sel in &cherry.entity_selections {
112 let source_map = fetch_all_entities(&self.source_pool, &sel.entity_type).await?;
113
114 let ids_to_process: Vec<&String> = if sel.entity_ids.is_empty() {
115 source_map.keys().collect()
116 } else {
117 sel.entity_ids.iter().collect()
118 };
119
120 for entity_id in ids_to_process {
121 let source_val = match source_map.get(entity_id) {
122 Some(v) => v.clone(),
123 None => continue,
124 };
125
126 let final_val = if let Some(fields) = &sel.fields {
127 let target_map =
129 fetch_all_entities(&self.target_pool, &sel.entity_type).await?;
130 let mut merged = target_map
131 .get(entity_id)
132 .cloned()
133 .unwrap_or_else(|| serde_json::Value::Object(serde_json::Map::new()));
134 if let (Some(merged_obj), Some(source_obj)) =
135 (merged.as_object_mut(), source_val.as_object())
136 {
137 for f in fields {
138 if let Some(v) = source_obj.get(f) {
139 merged_obj.insert(f.clone(), v.clone());
140 fields_updated += 1;
141 }
142 }
143 }
144 merged
145 } else {
146 let field_count = source_val.as_object().map(|o| o.len()).unwrap_or(0);
147 fields_updated += field_count as u32;
148 source_val
149 };
150
151 upsert_entity_tx(&mut tx, sel.entity_type.table_name(), entity_id, &final_val)
152 .await?;
153 committed_entity_count += 1;
154 all_entity_ids.push(entity_id.clone());
155 }
156 }
157
158 tx.commit().await?;
159
160 let entry = CommitLogEntry {
162 id: Uuid::new_v4(),
163 branch_id: cherry.target_branch_id,
164 entity_type: cherry
165 .entity_selections
166 .first()
167 .map(|s| s.entity_type.clone()),
168 entity_ids: all_entity_ids,
169 op_kind: "cherry_pick".to_string(),
170 committed_at: Utc::now(),
171 message: cherry.message.clone(),
172 };
173 self.store.insert_commit_log(&entry).await?;
174
175 Ok(CommitResult {
176 committed_entity_count,
177 fields_updated,
178 duration_ms: started.elapsed().as_millis() as u64,
179 target_branch_id: cherry.target_branch_id,
180 committed_at: entry.committed_at,
181 })
182 }
183
184 pub async fn commit_all(
188 &self,
189 source_branch_id: Uuid,
190 target_branch_id: Uuid,
191 ) -> BranchResult<CommitResult> {
192 let cherry = CherryPick {
193 source_branch_id,
194 target_branch_id,
195 entity_selections: vec![
196 EntitySelection {
197 entity_type: EntityType::MemoryRecord,
198 entity_ids: Vec::new(),
199 fields: None,
200 },
201 EntitySelection {
202 entity_type: EntityType::Session,
203 entity_ids: Vec::new(),
204 fields: None,
205 },
206 EntitySelection {
207 entity_type: EntityType::ToolOutput,
208 entity_ids: Vec::new(),
209 fields: None,
210 },
211 ],
212 message: Some("commit_all".to_string()),
213 };
214 self.commit(&cherry).await
215 }
216}
217
218async fn upsert_entity_tx(
221 tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
222 table: &str,
223 entity_id: &str,
224 value: &serde_json::Value,
225) -> BranchResult<()> {
226 let obj = match value.as_object() {
227 Some(o) => o,
228 None => return Ok(()),
229 };
230
231 let mut columns: Vec<String> = vec!["id".to_string()];
232 let mut values: Vec<Option<String>> = vec![Some(entity_id.to_string())];
233
234 for (k, v) in obj {
235 if k != "id" {
236 columns.push(k.clone());
237 values.push(json_to_str(v));
238 }
239 }
240
241 let col_list = columns.join(", ");
242 let placeholders = columns.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
243 let sql = format!("INSERT OR REPLACE INTO {table} ({col_list}) VALUES ({placeholders})");
244
245 let mut args = sqlx::sqlite::SqliteArguments::default();
246 for v in &values {
247 args.add(v.clone())
248 .map_err(|error| BranchError::InvalidConfig(format!("invalid sqlite arg: {error}")))?;
249 }
250 sqlx::query_with(&sql, args).execute(&mut **tx).await?;
251 Ok(())
252}
253
254fn json_to_str(v: &serde_json::Value) -> Option<String> {
255 match v {
256 serde_json::Value::Null => None,
257 serde_json::Value::Bool(b) => Some(if *b { "1" } else { "0" }.to_string()),
258 serde_json::Value::Number(n) => Some(n.to_string()),
259 serde_json::Value::String(s) => Some(s.clone()),
260 serde_json::Value::Array(_) | serde_json::Value::Object(_) => Some(v.to_string()),
261 }
262}
263
264async fn open_pool(path: &std::path::Path, read_only: bool) -> BranchResult<SqlitePool> {
265 SqlitePoolOptions::new()
266 .max_connections(2)
267 .connect_with(
268 SqliteConnectOptions::new()
269 .filename(path)
270 .create_if_missing(false)
271 .read_only(read_only)
272 .journal_mode(SqliteJournalMode::Wal),
273 )
274 .await
275 .map_err(BranchError::Database)
276}