Skip to main content

amql_engine/
transact.rs

//! Declarative multi-file transactional mutations.
//!
//! All operations are applied to an in-memory buffer, and files are flushed to
//! disk only after every operation has succeeded. If any operation fails,
//! nothing is written. The flush itself writes files one at a time and is not
//! atomic across files.
5
6use crate::error::AqlError;
7use crate::types::{ProjectRoot, RelativePath};
8use amql_mutate::{insert_source, remove_node, replace_node, InsertPosition, NodeRef};
9use rustc_hash::FxHashMap;
10use serde::{Deserialize, Serialize};
11
12/// A condition evaluated against the in-memory buffer during a transaction.
13#[non_exhaustive]
14#[derive(Debug, Clone, Serialize, Deserialize)]
15#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
16#[serde(tag = "type", rename_all = "snake_case")]
17pub enum TransactionCondition {
18    /// True if the node's current source text contains `text` as a substring.
19    ReadContains { node: NodeRef, text: String },
20}
21
22/// A single operation within a transaction.
23///
24/// Ops are applied sequentially to an in-memory buffer. When mutating the
25/// same file multiple times, order ops by descending `start_byte` to avoid
26/// byte-offset drift.
27#[non_exhaustive]
28#[derive(Debug, Clone, Serialize, Deserialize)]
29#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
30#[serde(tag = "type", rename_all = "snake_case")]
31pub enum TransactionOp {
32    /// Replace a node's source text.
33    Replace { node: NodeRef, source: String },
34    /// Insert source text relative to a target node.
35    Insert {
36        /// The node to insert relative to.
37        target: NodeRef,
38        /// Where to insert: `before`, `after`, or `into`.
39        position: InsertPosition,
40        /// Source text to insert.
41        source: String,
42    },
43    /// Remove a node.
44    Remove { node: NodeRef },
45    /// Execute `then` if `condition` holds against the current buffer, else `else`.
46    If {
47        /// Condition to evaluate against the current in-memory buffer.
48        condition: TransactionCondition,
49        /// Ops to run if the condition is true.
50        then: Vec<TransactionOp>,
51        /// Ops to run if the condition is false. Empty means no-op.
52        #[serde(default, rename = "else")]
53        else_: Vec<TransactionOp>,
54    },
55}
56
57/// Result of a successfully committed transaction.
58#[derive(Debug, Clone, Serialize)]
59#[serde(rename_all = "camelCase")]
60pub struct TransactionResult {
61    /// Files written to disk.
62    pub files_modified: Vec<RelativePath>,
63    /// Total mutation operations applied (excluding `If` branch nodes themselves).
64    pub ops_applied: usize,
65}
66
67/// Reject any `RelativePath` that could escape the project root.
68fn validate_path(file: &RelativePath) -> Result<(), AqlError> {
69    let s: &str = file;
70    if s.contains("..") || std::path::Path::new(s).is_absolute() {
71        return Err(AqlError::new(format!(
72            "Path '{s}' is not allowed: must be relative and must not contain '..'"
73        )));
74    }
75    Ok(())
76}
77
78/// Execute a declarative transaction: apply `ops` to an in-memory buffer, then
79/// flush to disk only if all ops succeed. On error nothing is written.
80///
81/// Note: the atomicity guarantee covers the mutation phase (buffer). The write
82/// phase flushes files one at a time — a disk failure mid-flush leaves some
83/// files written and some not. Use the WAL-backed variant for full crash safety.
84pub fn execute_transaction(
85    root: &ProjectRoot,
86    ops: Vec<TransactionOp>,
87) -> Result<TransactionResult, AqlError> {
88    validate_ops(&ops)?;
89    let mut buffer: FxHashMap<RelativePath, String> = FxHashMap::default();
90    let mut ops_applied: usize = 0;
91    run_ops(root, &ops, &mut buffer, &mut ops_applied)?;
92
93    let mut files_modified: Vec<RelativePath> = Vec::new();
94    for (file, content) in &buffer {
95        let path = root.join(AsRef::<std::path::Path>::as_ref(file));
96        std::fs::write(&path, content)
97            .map_err(|e| AqlError::new(format!("Failed to write {}: {e}", path.display())))?;
98        files_modified.push(file.clone());
99    }
100    files_modified.sort_by(|a, b| {
101        let a_str: &str = a;
102        let b_str: &str = b;
103        a_str.cmp(b_str)
104    });
105
106    Ok(TransactionResult {
107        files_modified,
108        ops_applied,
109    })
110}
111
112/// Validate all file paths in the op tree before execution begins.
113fn validate_ops(ops: &[TransactionOp]) -> Result<(), AqlError> {
114    for op in ops {
115        match op {
116            TransactionOp::Replace { node, .. } => validate_path(&node.file)?,
117            TransactionOp::Insert { target, .. } => validate_path(&target.file)?,
118            TransactionOp::Remove { node } => validate_path(&node.file)?,
119            TransactionOp::If {
120                condition,
121                then,
122                else_,
123            } => {
124                match condition {
125                    TransactionCondition::ReadContains { node, .. } => validate_path(&node.file)?,
126                }
127                validate_ops(then)?;
128                validate_ops(else_)?;
129            }
130        }
131    }
132    Ok(())
133}
134
135fn read_from_buffer(
136    root: &ProjectRoot,
137    buffer: &FxHashMap<RelativePath, String>,
138    file: &RelativePath,
139) -> Result<String, AqlError> {
140    if let Some(s) = buffer.get(file) {
141        return Ok(s.clone());
142    }
143    let path = root.join(AsRef::<std::path::Path>::as_ref(file));
144    std::fs::read_to_string(&path)
145        .map_err(|e| AqlError::new(format!("Failed to read {}: {e}", path.display())))
146}
147
148fn run_ops(
149    root: &ProjectRoot,
150    ops: &[TransactionOp],
151    buffer: &mut FxHashMap<RelativePath, String>,
152    ops_applied: &mut usize,
153) -> Result<(), AqlError> {
154    for op in ops {
155        match op {
156            TransactionOp::Replace { node, source } => {
157                let current = read_from_buffer(root, buffer, &node.file)?;
158                let result =
159                    replace_node(&current, &node.file, node, source).map_err(AqlError::from)?;
160                buffer.insert(node.file.clone(), result.source);
161                *ops_applied += 1;
162            }
163            TransactionOp::Insert {
164                target,
165                position,
166                source,
167            } => {
168                let current = read_from_buffer(root, buffer, &target.file)?;
169                let result = insert_source(&current, &target.file, target, *position, source)
170                    .map_err(AqlError::from)?;
171                buffer.insert(target.file.clone(), result.source);
172                *ops_applied += 1;
173            }
174            TransactionOp::Remove { node } => {
175                let current = read_from_buffer(root, buffer, &node.file)?;
176                let removal = remove_node(&current, node).map_err(AqlError::from)?;
177                buffer.insert(node.file.clone(), removal.result.source);
178                *ops_applied += 1;
179            }
180            TransactionOp::If {
181                condition,
182                then,
183                else_,
184            } => {
185                let branch_taken = eval_condition(root, buffer, condition)?;
186                let branch = if branch_taken { then } else { else_ };
187                run_ops(root, branch, buffer, ops_applied)?;
188            }
189        }
190    }
191    Ok(())
192}
193
194fn eval_condition(
195    root: &ProjectRoot,
196    buffer: &FxHashMap<RelativePath, String>,
197    condition: &TransactionCondition,
198) -> Result<bool, AqlError> {
199    match condition {
200        TransactionCondition::ReadContains { node, text } => {
201            let source = read_from_buffer(root, buffer, &node.file)?;
202            let node_text = source.get(node.start_byte..node.end_byte).ok_or_else(|| {
203                AqlError::new(format!(
204                    "Byte range {}..{} out of bounds for {}",
205                    node.start_byte, node.end_byte, node.file
206                ))
207            })?;
208            Ok(node_text.contains(text.as_str()))
209        }
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Create a temporary project containing a single `test.ts` holding `content`.
    ///
    /// Keep the returned `TempDir` alive for the whole test; it owns the directory.
    fn setup(content: &str) -> (TempDir, ProjectRoot, RelativePath) {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("test.ts"), content).unwrap();
        let root = ProjectRoot::from(dir.path());
        let rel = RelativePath::from("test.ts");
        (dir, root, rel)
    }

    /// Read `test.ts` back from disk.
    fn read_back(root: &ProjectRoot) -> String {
        std::fs::read_to_string(root.join(std::path::Path::new("test.ts"))).unwrap()
    }

    fn node_for_range(file: &RelativePath, start: usize, end: usize) -> NodeRef {
        // line/column are metadata only; mutation ops use start_byte/end_byte exclusively.
        NodeRef {
            file: file.clone(),
            start_byte: start,
            end_byte: end,
            kind: "identifier".into(),
            line: 1,
            column: 0,
            end_line: 1,
            end_column: 0,
        }
    }

    #[test]
    fn commits_replace_to_disk() {
        // Arrange
        let (_dir, root, rel) = setup("hello world");
        let node = node_for_range(&rel, 6, 11);
        let ops = vec![TransactionOp::Replace {
            node,
            source: "AQL".to_string(),
        }];

        // Act
        let result = execute_transaction(&root, ops).unwrap();

        // Assert
        assert_eq!(read_back(&root), "hello AQL", "file should be rewritten");
        assert_eq!(result.ops_applied, 1, "one op applied");
        assert_eq!(result.files_modified, vec![rel], "one file modified");
    }

    #[test]
    fn rolls_back_on_bad_range() {
        // Arrange
        let (_dir, root, rel) = setup("hello");
        let node = node_for_range(&rel, 100, 200); // out of bounds
        let ops = vec![TransactionOp::Replace {
            node,
            source: "x".to_string(),
        }];

        // Act
        let err = execute_transaction(&root, ops);

        // Assert
        assert!(err.is_err(), "should fail on bad range");
        assert_eq!(read_back(&root), "hello", "file must be unchanged after rollback");
    }

    #[test]
    fn if_condition_branches_correctly() {
        // Arrange
        let (_dir, root, rel) = setup("fn foo() {}");
        let node = node_for_range(&rel, 0, 11);
        let ops = vec![TransactionOp::If {
            condition: TransactionCondition::ReadContains {
                node: node.clone(),
                text: "foo".to_string(),
            },
            then: vec![TransactionOp::Replace {
                node,
                source: "fn bar() {}".to_string(),
            }],
            else_: vec![],
        }];

        // Act
        let result = execute_transaction(&root, ops).unwrap();

        // Assert
        assert_eq!(read_back(&root), "fn bar() {}", "then branch should have been taken");
        assert_eq!(result.ops_applied, 1, "one mutation applied");
    }

    #[test]
    fn if_condition_takes_else_branch_when_text_absent() {
        // Arrange
        let (_dir, root, rel) = setup("fn foo() {}");
        let node = node_for_range(&rel, 0, 11);
        let ops = vec![TransactionOp::If {
            condition: TransactionCondition::ReadContains {
                node: node.clone(),
                text: "baz".to_string(),
            },
            then: vec![TransactionOp::Replace {
                node: node.clone(),
                source: "fn bar() {}".to_string(),
            }],
            else_: vec![TransactionOp::Replace {
                node,
                source: "fn qux() {}".to_string(),
            }],
        }];

        // Act
        let result = execute_transaction(&root, ops).unwrap();

        // Assert
        assert_eq!(read_back(&root), "fn qux() {}", "else branch should have been taken");
        assert_eq!(result.ops_applied, 1, "only the else mutation counts");
    }

    #[test]
    fn false_condition_with_empty_else_writes_nothing() {
        // Arrange
        let (_dir, root, rel) = setup("fn foo() {}");
        let node = node_for_range(&rel, 0, 11);
        let ops = vec![TransactionOp::If {
            condition: TransactionCondition::ReadContains {
                node: node.clone(),
                text: "baz".to_string(),
            },
            then: vec![TransactionOp::Replace {
                node,
                source: "fn bar() {}".to_string(),
            }],
            else_: vec![],
        }];

        // Act
        let result = execute_transaction(&root, ops).unwrap();

        // Assert
        assert_eq!(read_back(&root), "fn foo() {}", "file must be untouched");
        assert_eq!(result.ops_applied, 0, "no mutation applied");
        assert!(result.files_modified.is_empty(), "no file written");
    }

    #[test]
    fn applies_sequential_replaces_in_descending_order() {
        // Arrange
        let (_dir, root, rel) = setup("hello world");
        let ops = vec![
            TransactionOp::Replace {
                node: node_for_range(&rel, 6, 11),
                source: "AQL".to_string(),
            },
            TransactionOp::Replace {
                node: node_for_range(&rel, 0, 5),
                source: "HELLO".to_string(),
            },
        ];

        // Act
        let result = execute_transaction(&root, ops).unwrap();

        // Assert
        assert_eq!(read_back(&root), "HELLO AQL", "both edits should land");
        assert_eq!(result.ops_applied, 2, "two ops applied");
        assert_eq!(result.files_modified, vec![rel], "file reported once");
    }

    #[test]
    fn rejects_path_escaping_root_without_touching_disk() {
        // Arrange
        let (_dir, root, rel) = setup("hello");
        let outside = RelativePath::from("../escape.ts");
        let ops = vec![
            TransactionOp::Replace {
                node: node_for_range(&rel, 0, 5),
                source: "bye".to_string(),
            },
            TransactionOp::Replace {
                node: node_for_range(&outside, 0, 1),
                source: "x".to_string(),
            },
        ];

        // Act
        let err = execute_transaction(&root, ops);

        // Assert
        assert!(err.is_err(), "a '..' path must be rejected");
        assert_eq!(read_back(&root), "hello", "no op may be committed");
    }
}