Skip to main content

flake_edit/
walk.rs

1//! AST walking and manipulation for flake.nix files.
2
3mod context;
4mod error;
5mod inputs;
6mod node;
7mod outputs;
8
9use std::collections::HashMap;
10
11use rnix::{Root, SyntaxKind, SyntaxNode};
12
13use crate::change::Change;
14use crate::edit::{OutputChange, Outputs};
15use crate::input::Input;
16
17pub use context::Context;
18pub use error::WalkerError;
19
20use inputs::walk_inputs;
21use node::{
22    adjacent_whitespace_index, get_sibling_whitespace, insertion_index_after, make_quoted_string,
23    make_toplevel_flake_false_attr, make_toplevel_nested_follows_attr, make_toplevel_url_attr,
24    parse_node, substitute_child,
25};
26
27#[derive(Debug, Clone)]
28pub struct Walker {
29    pub root: SyntaxNode,
30    pub inputs: HashMap<String, Input>,
31    pub add_toplevel: bool,
32}
33
34impl<'a> Walker {
35    pub fn new(stream: &'a str) -> Self {
36        let root = Root::parse(stream).syntax();
37        Self {
38            root,
39            inputs: HashMap::new(),
40            add_toplevel: false,
41        }
42    }
43
44    /// Traverse the toplevel `flake.nix` file.
45    /// It should consist of three attribute keys:
46    /// - description
47    /// - inputs
48    /// - outputs
49    pub fn walk(&mut self, change: &Change) -> Result<Option<SyntaxNode>, WalkerError> {
50        let cst = self.root.clone();
51        if cst.kind() != SyntaxKind::NODE_ROOT {
52            return Err(WalkerError::NotARoot(cst.kind()));
53        }
54        self.walk_toplevel(cst, None, change)
55    }
56
57    /// Only walk the outputs attribute
58    pub(crate) fn list_outputs(&mut self) -> Result<Outputs, WalkerError> {
59        outputs::list_outputs(&self.root)
60    }
61
62    /// Only change the outputs attribute
63    pub(crate) fn change_outputs(
64        &mut self,
65        change: OutputChange,
66    ) -> Result<Option<SyntaxNode>, WalkerError> {
67        outputs::change_outputs(&self.root, change)
68    }
69
70    /// Traverse the toplevel `flake.nix` file.
71    fn walk_toplevel(
72        &mut self,
73        node: SyntaxNode,
74        ctx: Option<Context>,
75        change: &Change,
76    ) -> Result<Option<SyntaxNode>, WalkerError> {
77        let Some(attr_set) = node.first_child() else {
78            return Ok(None);
79        };
80
81        for toplevel in attr_set.children() {
82            if toplevel.kind() != SyntaxKind::NODE_ATTRPATH_VALUE {
83                return Err(WalkerError::UnexpectedNodeKind {
84                    expected: SyntaxKind::NODE_ATTRPATH_VALUE,
85                    found: toplevel.kind(),
86                });
87            }
88
89            for child in toplevel.children() {
90                let child_str = child.to_string();
91
92                if child_str == "description" {
93                    break;
94                }
95
96                if child_str == "inputs" {
97                    if let Some(result) = self.handle_inputs_attr(&toplevel, &child, &ctx, change) {
98                        return Ok(Some(result));
99                    }
100                    continue;
101                }
102
103                if child_str.starts_with("inputs") {
104                    if let Some(result) =
105                        self.handle_inputs_flat(&attr_set, &toplevel, &child, &ctx, change)
106                    {
107                        return Ok(Some(result));
108                    }
109                    continue;
110                }
111
112                if child_str == "outputs"
113                    && let Some(result) = self.handle_add_at_outputs(&attr_set, &toplevel, change)
114                {
115                    return Ok(Some(result));
116                }
117            }
118        }
119
120        // Handle follows for toplevel flat-style inputs (inputs.X.url = "...")
121        if let Change::Follows { input, target } = change
122            && let Some(nested_id) = input.follows()
123        {
124            let parent_id = input.input();
125            if self.inputs.contains_key(parent_id) {
126                return self.handle_follows_flat_toplevel(&attr_set, parent_id, nested_id, target);
127            }
128        }
129
130        Ok(None)
131    }
132
133    /// Handle adding follows to a toplevel flat-style input.
134    ///
135    /// Converts `inputs.crane.url = "github:...";` into:
136    /// ```nix
137    /// inputs.crane.url = "github:...";
138    /// inputs.crane.inputs.nixpkgs.follows = "nixpkgs";
139    /// ```
140    fn handle_follows_flat_toplevel(
141        &self,
142        attr_set: &SyntaxNode,
143        parent_id: &str,
144        nested_id: &str,
145        target: &str,
146    ) -> Result<Option<SyntaxNode>, WalkerError> {
147        let mut last_parent_attr: Option<SyntaxNode> = None;
148        let mut has_nested_block = false;
149
150        for toplevel in attr_set.children() {
151            if toplevel.kind() != SyntaxKind::NODE_ATTRPATH_VALUE {
152                continue;
153            }
154            let Some(attrpath) = toplevel
155                .children()
156                .find(|c| c.kind() == SyntaxKind::NODE_ATTRPATH)
157            else {
158                continue;
159            };
160            let idents: Vec<String> = attrpath.children().map(|c| c.to_string()).collect();
161
162            // Detect `inputs.{parent_id} = { ... }` block style
163            if idents.len() == 2
164                && idents[0] == "inputs"
165                && idents[1] == parent_id
166                && toplevel
167                    .children()
168                    .any(|c| c.kind() == SyntaxKind::NODE_ATTR_SET)
169            {
170                has_nested_block = true;
171            }
172
173            // Check for existing follows: inputs.{parent_id}.inputs.{nested_id}.follows
174            if idents.len() == 5
175                && idents[0] == "inputs"
176                && idents[1] == parent_id
177                && idents[2] == "inputs"
178                && idents[3] == nested_id
179                && idents[4] == "follows"
180            {
181                let value_node = attrpath.next_sibling();
182                let current_target = value_node
183                    .as_ref()
184                    .map(|v| v.to_string().trim_matches('"').to_string())
185                    .unwrap_or_default();
186
187                if current_target == target {
188                    // Same target, no-op
189                    return Ok(Some(parse_node(&attr_set.parent().unwrap().to_string())));
190                }
191                // Different target, retarget
192                if let Some(value) = value_node {
193                    let new_value = make_quoted_string(target);
194                    let new_toplevel = substitute_child(&toplevel, value.index(), &new_value);
195                    let green = attr_set
196                        .green()
197                        .replace_child(toplevel.index(), new_toplevel.green().into());
198                    return Ok(Some(parse_node(&attr_set.replace_with(green).to_string())));
199                }
200            }
201
202            // Track last inputs.{parent_id}.* attribute
203            if idents.len() >= 2 && idents[0] == "inputs" && idents[1] == parent_id {
204                last_parent_attr = Some(toplevel.clone());
205            }
206        }
207
208        // No existing follows, insert after the last parent attribute
209        // Only for flat-style inputs (not block-style which should be handled elsewhere)
210        if !has_nested_block && let Some(ref_child) = last_parent_attr {
211            let follows_node = make_toplevel_nested_follows_attr(parent_id, nested_id, target);
212            let insert_index = insertion_index_after(&ref_child);
213
214            let mut green = attr_set
215                .green()
216                .insert_child(insert_index, follows_node.green().into());
217
218            if let Some(whitespace) = get_sibling_whitespace(&ref_child) {
219                let ws_str = whitespace.to_string();
220                let normalized = if let Some(last_nl) = ws_str.rfind('\n') {
221                    &ws_str[last_nl..]
222                } else {
223                    &ws_str
224                };
225                let ws_node = parse_node(normalized);
226                green = green.insert_child(insert_index, ws_node.green().into());
227            }
228
229            return Ok(Some(parse_node(&attr_set.replace_with(green).to_string())));
230        }
231
232        Ok(None)
233    }
234
235    /// Handle `inputs = { ... }` attribute.
236    ///
237    /// `toplevel.replace_with()` propagates through NODE_ATTR_SET up to
238    /// NODE_ROOT, preserving any leading comments/trivia.
239    fn handle_inputs_attr(
240        &mut self,
241        toplevel: &SyntaxNode,
242        child: &SyntaxNode,
243        ctx: &Option<Context>,
244        change: &Change,
245    ) -> Option<SyntaxNode> {
246        let sibling = child.next_sibling()?;
247        let replacement = walk_inputs(&mut self.inputs, sibling.clone(), ctx, change)?;
248
249        let green = toplevel
250            .green()
251            .replace_child(sibling.index(), replacement.green().into());
252        let green = toplevel.replace_with(green);
253        Some(parse_node(&green.to_string()))
254    }
255
256    /// Handle flat-style `inputs.foo.url = "..."` attributes.
257    ///
258    /// For removals, builds the modified attr_set green and uses
259    /// `replace_with()` to propagate to NODE_ROOT.
260    /// For replacements, `toplevel.replace_with()` propagates naturally.
261    fn handle_inputs_flat(
262        &mut self,
263        attr_set: &SyntaxNode,
264        toplevel: &SyntaxNode,
265        child: &SyntaxNode,
266        ctx: &Option<Context>,
267        change: &Change,
268    ) -> Option<SyntaxNode> {
269        let replacement = walk_inputs(&mut self.inputs, child.clone(), ctx, change)?;
270
271        // If replacement is empty, remove the entire toplevel node and
272        // propagate through attr_set to NODE_ROOT.
273        if replacement.to_string().is_empty() {
274            let element: rnix::SyntaxElement = toplevel.clone().into();
275            let mut green = attr_set.green().remove_child(toplevel.index());
276            if let Some(ws_index) = adjacent_whitespace_index(&element) {
277                green = green.remove_child(ws_index);
278            }
279            return Some(parse_node(&attr_set.replace_with(green).to_string()));
280        }
281
282        let sibling = child.next_sibling()?;
283        let green = toplevel
284            .green()
285            .replace_child(sibling.index(), replacement.green().into());
286        let green = toplevel.replace_with(green);
287        Some(parse_node(&green.to_string()))
288    }
289
290    /// Handle adding inputs when we've reached `outputs` but have no inputs yet.
291    ///
292    /// Builds the modified attr_set green and uses `replace_with()` to
293    /// propagate to NODE_ROOT, preserving leading comments.
294    fn handle_add_at_outputs(
295        &mut self,
296        attr_set: &SyntaxNode,
297        toplevel: &SyntaxNode,
298        change: &Change,
299    ) -> Option<SyntaxNode> {
300        if !self.add_toplevel {
301            return None;
302        }
303
304        let Change::Add {
305            id: Some(id),
306            uri: Some(uri),
307            flake,
308        } = change
309        else {
310            return None;
311        };
312
313        if toplevel.index() == 0 {
314            return None;
315        }
316
317        // Find normalized whitespace (single newline + indent) by walking back
318        // from `outputs` through tokens. This handles comments between the last
319        // input and outputs correctly.
320        let ws_node = {
321            let mut ws: Option<SyntaxNode> = None;
322            let mut cursor = toplevel.prev_sibling_or_token();
323            while let Some(ref tok) = cursor {
324                if tok.kind() == SyntaxKind::TOKEN_WHITESPACE {
325                    let ws_str = tok.to_string();
326                    let normalized = if let Some(last_nl) = ws_str.rfind('\n') {
327                        &ws_str[last_nl..]
328                    } else {
329                        &ws_str
330                    };
331                    ws = Some(parse_node(normalized));
332                    break;
333                }
334                cursor = tok.prev_sibling_or_token();
335            }
336            ws
337        };
338
339        let addition = make_toplevel_url_attr(id, uri);
340        let insert_pos = toplevel.index() - 1;
341
342        let mut green = attr_set
343            .green()
344            .insert_child(insert_pos, addition.green().into());
345
346        if let Some(ref ws) = ws_node {
347            green = green.insert_child(insert_pos, ws.green().into());
348        }
349
350        // Add flake=false if needed
351        if !flake {
352            let no_flake = make_toplevel_flake_false_attr(id);
353            green = green.insert_child(toplevel.index() + 1, no_flake.green().into());
354
355            if let Some(ref ws) = ws_node {
356                green = green.insert_child(toplevel.index() + 1, ws.green().into());
357            }
358        }
359
360        Some(parse_node(&attr_set.replace_with(green).to_string()))
361    }
362}