1use crate::error::AqlError;
7use crate::types::{ProjectRoot, RelativePath};
8use amql_mutate::{insert_source, remove_node, replace_node, InsertPosition, NodeRef};
9use rustc_hash::FxHashMap;
10use serde::{Deserialize, Serialize};
11
12#[non_exhaustive]
14#[derive(Debug, Clone, Serialize, Deserialize)]
15#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
16#[serde(tag = "type", rename_all = "snake_case")]
17pub enum TransactionCondition {
18 ReadContains { node: NodeRef, text: String },
20}
21
22#[non_exhaustive]
28#[derive(Debug, Clone, Serialize, Deserialize)]
29#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
30#[serde(tag = "type", rename_all = "snake_case")]
31pub enum TransactionOp {
32 Replace { node: NodeRef, source: String },
34 Insert {
36 target: NodeRef,
38 position: InsertPosition,
40 source: String,
42 },
43 Remove { node: NodeRef },
45 If {
47 condition: TransactionCondition,
49 then: Vec<TransactionOp>,
51 #[serde(default, rename = "else")]
53 else_: Vec<TransactionOp>,
54 },
55}
56
57#[derive(Debug, Clone, Serialize)]
59#[serde(rename_all = "camelCase")]
60pub struct TransactionResult {
61 pub files_modified: Vec<RelativePath>,
63 pub ops_applied: usize,
65}
66
67fn validate_path(file: &RelativePath) -> Result<(), AqlError> {
69 let s: &str = file;
70 if s.contains("..") || std::path::Path::new(s).is_absolute() {
71 return Err(AqlError::new(format!(
72 "Path '{s}' is not allowed: must be relative and must not contain '..'"
73 )));
74 }
75 Ok(())
76}
77
78pub fn execute_transaction(
85 root: &ProjectRoot,
86 ops: Vec<TransactionOp>,
87) -> Result<TransactionResult, AqlError> {
88 validate_ops(&ops)?;
89 let mut buffer: FxHashMap<RelativePath, String> = FxHashMap::default();
90 let mut ops_applied: usize = 0;
91 run_ops(root, &ops, &mut buffer, &mut ops_applied)?;
92
93 let mut files_modified: Vec<RelativePath> = Vec::new();
94 for (file, content) in &buffer {
95 let path = root.join(AsRef::<std::path::Path>::as_ref(file));
96 std::fs::write(&path, content)
97 .map_err(|e| AqlError::new(format!("Failed to write {}: {e}", path.display())))?;
98 files_modified.push(file.clone());
99 }
100 files_modified.sort_by(|a, b| {
101 let a_str: &str = a;
102 let b_str: &str = b;
103 a_str.cmp(b_str)
104 });
105
106 Ok(TransactionResult {
107 files_modified,
108 ops_applied,
109 })
110}
111
112fn validate_ops(ops: &[TransactionOp]) -> Result<(), AqlError> {
114 for op in ops {
115 match op {
116 TransactionOp::Replace { node, .. } => validate_path(&node.file)?,
117 TransactionOp::Insert { target, .. } => validate_path(&target.file)?,
118 TransactionOp::Remove { node } => validate_path(&node.file)?,
119 TransactionOp::If {
120 condition,
121 then,
122 else_,
123 } => {
124 match condition {
125 TransactionCondition::ReadContains { node, .. } => validate_path(&node.file)?,
126 }
127 validate_ops(then)?;
128 validate_ops(else_)?;
129 }
130 }
131 }
132 Ok(())
133}
134
135fn read_from_buffer(
136 root: &ProjectRoot,
137 buffer: &FxHashMap<RelativePath, String>,
138 file: &RelativePath,
139) -> Result<String, AqlError> {
140 if let Some(s) = buffer.get(file) {
141 return Ok(s.clone());
142 }
143 let path = root.join(AsRef::<std::path::Path>::as_ref(file));
144 std::fs::read_to_string(&path)
145 .map_err(|e| AqlError::new(format!("Failed to read {}: {e}", path.display())))
146}
147
148fn run_ops(
149 root: &ProjectRoot,
150 ops: &[TransactionOp],
151 buffer: &mut FxHashMap<RelativePath, String>,
152 ops_applied: &mut usize,
153) -> Result<(), AqlError> {
154 for op in ops {
155 match op {
156 TransactionOp::Replace { node, source } => {
157 let current = read_from_buffer(root, buffer, &node.file)?;
158 let result =
159 replace_node(¤t, &node.file, node, source).map_err(AqlError::from)?;
160 buffer.insert(node.file.clone(), result.source);
161 *ops_applied += 1;
162 }
163 TransactionOp::Insert {
164 target,
165 position,
166 source,
167 } => {
168 let current = read_from_buffer(root, buffer, &target.file)?;
169 let result = insert_source(¤t, &target.file, target, *position, source)
170 .map_err(AqlError::from)?;
171 buffer.insert(target.file.clone(), result.source);
172 *ops_applied += 1;
173 }
174 TransactionOp::Remove { node } => {
175 let current = read_from_buffer(root, buffer, &node.file)?;
176 let removal = remove_node(¤t, node).map_err(AqlError::from)?;
177 buffer.insert(node.file.clone(), removal.result.source);
178 *ops_applied += 1;
179 }
180 TransactionOp::If {
181 condition,
182 then,
183 else_,
184 } => {
185 let branch_taken = eval_condition(root, buffer, condition)?;
186 let branch = if branch_taken { then } else { else_ };
187 run_ops(root, branch, buffer, ops_applied)?;
188 }
189 }
190 }
191 Ok(())
192}
193
194fn eval_condition(
195 root: &ProjectRoot,
196 buffer: &FxHashMap<RelativePath, String>,
197 condition: &TransactionCondition,
198) -> Result<bool, AqlError> {
199 match condition {
200 TransactionCondition::ReadContains { node, text } => {
201 let source = read_from_buffer(root, buffer, &node.file)?;
202 let node_text = source.get(node.start_byte..node.end_byte).ok_or_else(|| {
203 AqlError::new(format!(
204 "Byte range {}..{} out of bounds for {}",
205 node.start_byte, node.end_byte, node.file
206 ))
207 })?;
208 Ok(node_text.contains(text.as_str()))
209 }
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Name of the single fixture file every test operates on.
    const FIXTURE: &str = "test.ts";

    /// Creates a temporary project containing one file with `content`.
    ///
    /// The `TempDir` is returned so the directory lives as long as the test.
    fn setup(content: &str) -> (TempDir, ProjectRoot, RelativePath) {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join(FIXTURE), content.as_bytes()).unwrap();
        let root = ProjectRoot::from(dir.path());
        (dir, root, RelativePath::from(FIXTURE))
    }

    /// Builds a `NodeRef` covering `start..end` in `file`; the remaining
    /// fields are fixed placeholder values.
    fn node_for_range(file: &RelativePath, start: usize, end: usize) -> NodeRef {
        NodeRef {
            file: file.clone(),
            start_byte: start,
            end_byte: end,
            kind: "identifier".into(),
            line: 1,
            column: 0,
            end_line: 1,
            end_column: 0,
        }
    }

    /// Reads the fixture file back from disk.
    fn read_fixture(root: &ProjectRoot) -> String {
        std::fs::read_to_string(root.join(std::path::Path::new(FIXTURE))).unwrap()
    }

    #[test]
    fn commits_replace_to_disk() {
        let (_dir, root, rel) = setup("hello world");
        let ops = vec![TransactionOp::Replace {
            node: node_for_range(&rel, 6, 11),
            source: "AQL".to_string(),
        }];

        let result = execute_transaction(&root, ops).unwrap();

        assert_eq!(read_fixture(&root), "hello AQL", "file should be rewritten");
        assert_eq!(result.ops_applied, 1, "one op applied");
        assert_eq!(result.files_modified, vec![rel], "one file modified");
    }

    #[test]
    fn rolls_back_on_bad_range() {
        let (_dir, root, rel) = setup("hello");
        // The range lies far beyond the end of the five-byte file.
        let ops = vec![TransactionOp::Replace {
            node: node_for_range(&rel, 100, 200),
            source: "x".to_string(),
        }];

        let err = execute_transaction(&root, ops);

        assert!(err.is_err(), "should fail on bad range");
        assert_eq!(
            read_fixture(&root),
            "hello",
            "file must be unchanged after rollback"
        );
    }

    #[test]
    fn if_condition_branches_correctly() {
        let (_dir, root, rel) = setup("fn foo() {}");
        let node = node_for_range(&rel, 0, 11);
        let ops = vec![TransactionOp::If {
            condition: TransactionCondition::ReadContains {
                node: node.clone(),
                text: "foo".to_string(),
            },
            then: vec![TransactionOp::Replace {
                node,
                source: "fn bar() {}".to_string(),
            }],
            else_: vec![],
        }];

        let result = execute_transaction(&root, ops).unwrap();

        assert_eq!(
            read_fixture(&root),
            "fn bar() {}",
            "then branch should have been taken"
        );
        assert_eq!(result.ops_applied, 1, "one mutation applied");
    }
}