Skip to main content

claw_branch/commit/
selective.rs

1//! Selective entity commit operations between branches.
2
3use std::{sync::Arc, time::Instant};
4
5use chrono::Utc;
6use sqlx::{
7    sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions},
8    Arguments, SqlitePool,
9};
10use uuid::Uuid;
11
12use crate::{
13    branch::store::BranchStore,
14    commit::{
15        cherry::{CherryPick, EntitySelection},
16        validator::CommitValidator,
17    },
18    diff::extractor::fetch_all_entities,
19    error::{BranchError, BranchResult},
20    types::{CommitLogEntry, CommitResult, EntityType},
21};
22
23/// Applies cherry-picked entity changes from a source branch into a target branch.
24///
25/// # Example
26/// ```rust,ignore
27/// let committer = SelectiveCommit::new(source_pool, target_pool, store);
28/// let result = committer.commit(&cherry).await?;
29/// ```
30pub struct SelectiveCommit {
31    /// The pool connected to the source branch database.
32    pub source_pool: SqlitePool,
33    /// The pool connected to the target branch database.
34    pub target_pool: SqlitePool,
35    /// Shared branch registry for commit log recording.
36    pub store: Arc<BranchStore>,
37    /// Workspace identifier used to scope branch registry lookups.
38    pub workspace_id: Uuid,
39}
40
41impl SelectiveCommit {
42    /// Creates a new selective commit executor with pre-opened pools.
43    pub fn new(
44        source_pool: SqlitePool,
45        target_pool: SqlitePool,
46        store: Arc<BranchStore>,
47        workspace_id: Uuid,
48    ) -> Self {
49        Self {
50            source_pool,
51            target_pool,
52            store,
53            workspace_id,
54        }
55    }
56
57    /// Opens a `SelectiveCommit` from branch metadata (opens pools internally).
58    pub async fn from_store(
59        store: Arc<BranchStore>,
60        source_id: Uuid,
61        target_id: Uuid,
62        workspace_id: Uuid,
63    ) -> BranchResult<Self> {
64        let source = store.get(workspace_id, source_id).await?;
65        let target = store.get(workspace_id, target_id).await?;
66        let source_pool = open_pool(&source.db_path, true).await?;
67        let target_pool = open_pool(&target.db_path, false).await?;
68        Ok(Self {
69            source_pool,
70            target_pool,
71            store,
72            workspace_id,
73        })
74    }
75
76    /// Applies a cherry-pick to the target branch within a single transaction.
77    ///
78    /// Steps:
79    /// 1. Validate the selection via [`CommitValidator`]
80    /// 2. Open a write transaction on target
81    /// 3. Fetch and upsert selected entities (field-filtered if requested)
82    /// 4. Record a commit log entry
83    /// 5. Return [`CommitResult`]
84    pub async fn commit(&self, cherry: &CherryPick) -> BranchResult<CommitResult> {
85        let started = Instant::now();
86
87        // Load source and target branches for validation.
88        let source = self
89            .store
90            .get(self.workspace_id, cherry.source_branch_id)
91            .await?;
92        let target = self
93            .store
94            .get(self.workspace_id, cherry.target_branch_id)
95            .await?;
96
97        let validator = CommitValidator::new(self.source_pool.clone());
98        let report = validator.validate(cherry, &source, &target).await?;
99        if !report.ok {
100            return Err(BranchError::CommitValidationFailed {
101                branch_id: cherry.target_branch_id,
102                violations: report.violations,
103            });
104        }
105
106        let mut tx = self.target_pool.begin().await?;
107        let mut committed_entity_count = 0u32;
108        let mut fields_updated = 0u32;
109        let mut all_entity_ids: Vec<String> = Vec::new();
110
111        for sel in &cherry.entity_selections {
112            let source_map = fetch_all_entities(&self.source_pool, &sel.entity_type).await?;
113
114            let ids_to_process: Vec<&String> = if sel.entity_ids.is_empty() {
115                source_map.keys().collect()
116            } else {
117                sel.entity_ids.iter().collect()
118            };
119
120            for entity_id in ids_to_process {
121                let source_val = match source_map.get(entity_id) {
122                    Some(v) => v.clone(),
123                    None => continue,
124                };
125
126                let final_val = if let Some(fields) = &sel.fields {
127                    // Merge only specified fields into existing target entity.
128                    let target_map =
129                        fetch_all_entities(&self.target_pool, &sel.entity_type).await?;
130                    let mut merged = target_map
131                        .get(entity_id)
132                        .cloned()
133                        .unwrap_or_else(|| serde_json::Value::Object(serde_json::Map::new()));
134                    if let (Some(merged_obj), Some(source_obj)) =
135                        (merged.as_object_mut(), source_val.as_object())
136                    {
137                        for f in fields {
138                            if let Some(v) = source_obj.get(f) {
139                                merged_obj.insert(f.clone(), v.clone());
140                                fields_updated += 1;
141                            }
142                        }
143                    }
144                    merged
145                } else {
146                    let field_count = source_val.as_object().map(|o| o.len()).unwrap_or(0);
147                    fields_updated += field_count as u32;
148                    source_val
149                };
150
151                upsert_entity_tx(&mut tx, sel.entity_type.table_name(), entity_id, &final_val)
152                    .await?;
153                committed_entity_count += 1;
154                all_entity_ids.push(entity_id.clone());
155            }
156        }
157
158        tx.commit().await?;
159
160        // Record commit log entry.
161        let entry = CommitLogEntry {
162            id: Uuid::new_v4(),
163            branch_id: cherry.target_branch_id,
164            entity_type: cherry
165                .entity_selections
166                .first()
167                .map(|s| s.entity_type.clone()),
168            entity_ids: all_entity_ids,
169            op_kind: "cherry_pick".to_string(),
170            committed_at: Utc::now(),
171            message: cherry.message.clone(),
172        };
173        self.store.insert_commit_log(&entry).await?;
174
175        Ok(CommitResult {
176            committed_entity_count,
177            fields_updated,
178            duration_ms: started.elapsed().as_millis() as u64,
179            target_branch_id: cherry.target_branch_id,
180            committed_at: entry.committed_at,
181        })
182    }
183
184    /// Promotes all divergent entities from `source_branch_id` into `target_branch_id`.
185    ///
186    /// This is a full merge-commit that uses all entity types.
187    pub async fn commit_all(
188        &self,
189        source_branch_id: Uuid,
190        target_branch_id: Uuid,
191    ) -> BranchResult<CommitResult> {
192        let cherry = CherryPick {
193            source_branch_id,
194            target_branch_id,
195            entity_selections: vec![
196                EntitySelection {
197                    entity_type: EntityType::MemoryRecord,
198                    entity_ids: Vec::new(),
199                    fields: None,
200                },
201                EntitySelection {
202                    entity_type: EntityType::Session,
203                    entity_ids: Vec::new(),
204                    fields: None,
205                },
206                EntitySelection {
207                    entity_type: EntityType::ToolOutput,
208                    entity_ids: Vec::new(),
209                    fields: None,
210                },
211            ],
212            message: Some("commit_all".to_string()),
213        };
214        self.commit(&cherry).await
215    }
216}
217
218// ── SQL helpers ──────────────────────────────────────────────────────────────
219
220async fn upsert_entity_tx(
221    tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
222    table: &str,
223    entity_id: &str,
224    value: &serde_json::Value,
225) -> BranchResult<()> {
226    let obj = match value.as_object() {
227        Some(o) => o,
228        None => return Ok(()),
229    };
230
231    let mut columns: Vec<String> = vec!["id".to_string()];
232    let mut values: Vec<Option<String>> = vec![Some(entity_id.to_string())];
233
234    for (k, v) in obj {
235        if k != "id" {
236            columns.push(k.clone());
237            values.push(json_to_str(v));
238        }
239    }
240
241    let col_list = columns.join(", ");
242    let placeholders = columns.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
243    let sql = format!("INSERT OR REPLACE INTO {table} ({col_list}) VALUES ({placeholders})");
244
245    let mut args = sqlx::sqlite::SqliteArguments::default();
246    for v in &values {
247        args.add(v.clone())
248            .map_err(|error| BranchError::InvalidConfig(format!("invalid sqlite arg: {error}")))?;
249    }
250    sqlx::query_with(&sql, args).execute(&mut **tx).await?;
251    Ok(())
252}
253
254fn json_to_str(v: &serde_json::Value) -> Option<String> {
255    match v {
256        serde_json::Value::Null => None,
257        serde_json::Value::Bool(b) => Some(if *b { "1" } else { "0" }.to_string()),
258        serde_json::Value::Number(n) => Some(n.to_string()),
259        serde_json::Value::String(s) => Some(s.clone()),
260        serde_json::Value::Array(_) | serde_json::Value::Object(_) => Some(v.to_string()),
261    }
262}
263
264async fn open_pool(path: &std::path::Path, read_only: bool) -> BranchResult<SqlitePool> {
265    SqlitePoolOptions::new()
266        .max_connections(2)
267        .connect_with(
268            SqliteConnectOptions::new()
269                .filename(path)
270                .create_if_missing(false)
271                .read_only(read_only)
272                .journal_mode(SqliteJournalMode::Wal),
273        )
274        .await
275        .map_err(BranchError::Database)
276}