Skip to main content

vyre_foundation/visit/
traits.rs

1use crate::error::Result;
2use crate::ir_inner::model::expr::{Expr, GeneratorRef, Ident};
3use crate::ir_inner::model::generated::Node;
4use crate::ir_inner::model::node::NodeExtension;
5use crate::visit::VisitOrder;
6use smallvec::SmallVec;
7use std::ops::ControlFlow;
8
9/// Anything that can be lowered to a target representation.
10///
11/// Backends implement this trait for their target. The IR does not know
12/// what targets exist  -  it only knows that calling `.lower(&mut ctx)`
13/// walks the structure through the visitor contract.
14///
15/// # Errors
16///
17/// Backends report structured errors through their own context type.
18pub trait Lowerable<Ctx: ?Sized> {
19    /// Visit this IR structure and emit into the backend-specific context.
20    ///
21    /// # Errors
22    ///
23    /// Returns the backend context's structured error when lowering cannot
24    /// represent this IR structure.
25    fn lower(&self, ctx: &mut Ctx) -> Result<()>;
26}
27
28/// Anything that can be executed against a runtime environment.
29///
30/// The reference interpreter and each backend implement this trait. Two
31/// `Evaluatable` implementations that produce the same output for the
32/// same input + environment are certifiably equivalent under the
33/// conform contract.
34pub trait Evaluatable<Env: ?Sized> {
35    /// The value type the evaluator produces (typically `Value` for the
36    /// reference interpreter, a typed handle for GPU backends).
37    type Value;
38
39    /// Evaluate this IR structure against the environment.
40    ///
41    /// # Errors
42    ///
43    /// Returns the evaluator's structured error when the environment cannot
44    /// execute this IR structure.
45    fn evaluate(&self, env: &mut Env) -> Result<Self::Value>;
46}
47
48/// Visitor over [`Node`] trees.
49///
50/// Implementors must handle every core node variant explicitly. Like
51/// [`crate::visit::ExprVisitor`], this trait is abstract-by-default so
52/// adding a new node variant forces downstream code to make a conscious
53/// decision.
54///
55/// Traversal order is explicit:
56/// - [`visit_node_preorder`] visits the current node before nested nodes.
57/// - [`visit_node_postorder`] visits nested nodes before the current node.
58///
59/// `NodeVisitor` traverses node structure only. If a visitor also needs
60/// to recurse into node-owned expressions, it should pair this trait
61/// with [`crate::visit::ExprVisitor`] and call the expression entry
62/// points from the relevant node hooks.
63pub trait NodeVisitor {
64    /// Break payload returned when traversal short-circuits.
65    type Break;
66
67    /// Variable declaration.
68    fn visit_let(&mut self, node: &Node, name: &Ident, value: &Expr) -> ControlFlow<Self::Break>;
69    /// Variable assignment.
70    fn visit_assign(&mut self, node: &Node, name: &Ident, value: &Expr)
71        -> ControlFlow<Self::Break>;
72    /// Buffer store.
73    fn visit_store(
74        &mut self,
75        node: &Node,
76        buffer: &Ident,
77        index: &Expr,
78        value: &Expr,
79    ) -> ControlFlow<Self::Break>;
80    /// Conditional branch.
81    fn visit_if(
82        &mut self,
83        node: &Node,
84        cond: &Expr,
85        then_nodes: &[Node],
86        otherwise: &[Node],
87    ) -> ControlFlow<Self::Break>;
88    /// Counted loop.
89    fn visit_loop(
90        &mut self,
91        node: &Node,
92        var: &Ident,
93        from: &Expr,
94        to: &Expr,
95        body: &[Node],
96    ) -> ControlFlow<Self::Break>;
97    /// Indirect dispatch source.
98    fn visit_indirect_dispatch(
99        &mut self,
100        node: &Node,
101        count_buffer: &Ident,
102        count_offset: u64,
103    ) -> ControlFlow<Self::Break>;
104    /// Async load node.
105    fn visit_async_load(
106        &mut self,
107        node: &Node,
108        source: &Ident,
109        destination: &Ident,
110        offset: &Expr,
111        size: &Expr,
112        tag: &Ident,
113    ) -> ControlFlow<Self::Break>;
114    /// Async store node.
115    fn visit_async_store(
116        &mut self,
117        node: &Node,
118        source: &Ident,
119        destination: &Ident,
120        offset: &Expr,
121        size: &Expr,
122        tag: &Ident,
123    ) -> ControlFlow<Self::Break>;
124    /// Async wait node.
125    fn visit_async_wait(&mut self, node: &Node, tag: &Ident) -> ControlFlow<Self::Break>;
126    /// Trap node.
127    fn visit_trap(&mut self, node: &Node, address: &Expr, tag: &Ident) -> ControlFlow<Self::Break>;
128    /// Resume node.
129    fn visit_resume(&mut self, node: &Node, tag: &Ident) -> ControlFlow<Self::Break>;
130    /// Return node.
131    fn visit_return(&mut self, node: &Node) -> ControlFlow<Self::Break>;
132    /// Barrier node.
133    fn visit_barrier(&mut self, node: &Node) -> ControlFlow<Self::Break>;
134    /// Distributed collective node.
135    fn visit_collective(&mut self, node: &Node) -> ControlFlow<Self::Break> {
136        let _ = node;
137        ControlFlow::Continue(())
138    }
139    /// Block node.
140    fn visit_block(&mut self, node: &Node, body: &[Node]) -> ControlFlow<Self::Break>;
141    /// Region wrapper node.
142    fn visit_region(
143        &mut self,
144        node: &Node,
145        generator: &Ident,
146        source_region: &Option<GeneratorRef>,
147        body: &[Node],
148    ) -> ControlFlow<Self::Break>;
149    /// Downstream opaque node extension.
150    fn visit_opaque_node(
151        &mut self,
152        node: &Node,
153        extension: &dyn NodeExtension,
154    ) -> ControlFlow<Self::Break>;
155
156    /// Recursively walk this node's nested node children using the requested order.
157    fn walk_children_default(&mut self, node: &Node, order: VisitOrder) -> ControlFlow<Self::Break>
158    where
159        Self: Sized,
160    {
161        walk_node_children_default(self, node, order)
162    }
163}
164
165/// Visit a node tree in pre-order.
166pub fn visit_node<V: NodeVisitor>(visitor: &mut V, node: &Node) -> ControlFlow<V::Break> {
167    visit_node_preorder(visitor, node)
168}
169
170/// Visit a node tree in pre-order without recursive stack growth.
171pub fn visit_node_preorder<V: NodeVisitor>(visitor: &mut V, node: &Node) -> ControlFlow<V::Break> {
172    let mut stack = SmallVec::<[&Node; 32]>::new();
173    stack.push(node);
174    while let Some(current) = stack.pop() {
175        dispatch_node(visitor, current)?;
176        match current {
177            Node::If {
178                then, otherwise, ..
179            } => {
180                for n in otherwise.iter().rev() {
181                    stack.push(n);
182                }
183                for n in then.iter().rev() {
184                    stack.push(n);
185                }
186            }
187            Node::Loop { body, .. } | Node::Block(body) => {
188                for n in body.iter().rev() {
189                    stack.push(n);
190                }
191            }
192            Node::Region { body, .. } => {
193                for n in body.iter().rev() {
194                    stack.push(n);
195                }
196            }
197            _ => {}
198        }
199    }
200    ControlFlow::Continue(())
201}
202
203/// Visit a node tree in post-order without recursive stack growth.
204pub fn visit_node_postorder<V: NodeVisitor>(visitor: &mut V, node: &Node) -> ControlFlow<V::Break> {
205    enum Task<'a> {
206        Visit(&'a Node),
207        Dispatch(&'a Node),
208    }
209    let mut stack = SmallVec::<[Task<'_>; 32]>::new();
210    stack.push(Task::Visit(node));
211    while let Some(task) = stack.pop() {
212        match task {
213            Task::Visit(n) => {
214                stack.push(Task::Dispatch(n));
215                match n {
216                    Node::If {
217                        then, otherwise, ..
218                    } => {
219                        for child in otherwise.iter().rev() {
220                            stack.push(Task::Visit(child));
221                        }
222                        for child in then.iter().rev() {
223                            stack.push(Task::Visit(child));
224                        }
225                    }
226                    Node::Loop { body, .. } | Node::Block(body) => {
227                        for child in body.iter().rev() {
228                            stack.push(Task::Visit(child));
229                        }
230                    }
231                    Node::Region { body, .. } => {
232                        for child in body.iter().rev() {
233                            stack.push(Task::Visit(child));
234                        }
235                    }
236                    _ => {}
237                }
238            }
239            Task::Dispatch(n) => {
240                dispatch_node(visitor, n)?;
241            }
242        }
243    }
244    ControlFlow::Continue(())
245}
246
247/// Walk only the nested node children of `node`, leaving the current node to the caller.
248pub fn walk_node_children_default<V: NodeVisitor>(
249    visitor: &mut V,
250    node: &Node,
251    order: VisitOrder,
252) -> ControlFlow<V::Break> {
253    match node {
254        Node::If {
255            then, otherwise, ..
256        } => {
257            for child in then {
258                visit_node_with_order(visitor, child, order)?;
259            }
260            for child in otherwise {
261                visit_node_with_order(visitor, child, order)?;
262            }
263        }
264        Node::Loop { body, .. } | Node::Block(body) => {
265            for child in body {
266                visit_node_with_order(visitor, child, order)?;
267            }
268        }
269        Node::Region { body, .. } => {
270            for child in body.iter() {
271                visit_node_with_order(visitor, child, order)?;
272            }
273        }
274        _ => {}
275    }
276    ControlFlow::Continue(())
277}
278
279fn visit_node_with_order<V: NodeVisitor>(
280    visitor: &mut V,
281    node: &Node,
282    order: VisitOrder,
283) -> ControlFlow<V::Break> {
284    match order {
285        VisitOrder::Preorder => visit_node_preorder(visitor, node),
286        VisitOrder::Postorder => visit_node_postorder(visitor, node),
287    }
288}
289
290pub(crate) fn dispatch_node<V: NodeVisitor>(visitor: &mut V, node: &Node) -> ControlFlow<V::Break> {
291    match node {
292        Node::Let { name, value } => visitor.visit_let(node, name, value),
293        Node::Assign { name, value } => visitor.visit_assign(node, name, value),
294        Node::Store {
295            buffer,
296            index,
297            value,
298        } => visitor.visit_store(node, buffer, index, value),
299        Node::If {
300            cond,
301            then,
302            otherwise,
303        } => visitor.visit_if(node, cond, then, otherwise),
304        Node::Loop {
305            var,
306            from,
307            to,
308            body,
309        } => visitor.visit_loop(node, var, from, to, body),
310        Node::IndirectDispatch {
311            count_buffer,
312            count_offset,
313        } => visitor.visit_indirect_dispatch(node, count_buffer, *count_offset),
314        Node::AsyncLoad {
315            source,
316            destination,
317            offset,
318            size,
319            tag,
320        } => visitor.visit_async_load(node, source, destination, offset, size, tag),
321        Node::AsyncStore {
322            source,
323            destination,
324            offset,
325            size,
326            tag,
327        } => visitor.visit_async_store(node, source, destination, offset, size, tag),
328        Node::AsyncWait { tag } => visitor.visit_async_wait(node, tag),
329        Node::Trap { address, tag } => visitor.visit_trap(node, address, tag),
330        Node::Resume { tag } => visitor.visit_resume(node, tag),
331        Node::AllReduce { .. }
332        | Node::AllGather { .. }
333        | Node::ReduceScatter { .. }
334        | Node::Broadcast { .. } => visitor.visit_collective(node),
335        Node::Return => visitor.visit_return(node),
336        Node::Barrier { .. } => visitor.visit_barrier(node),
337        Node::Block(body) => visitor.visit_block(node, body),
338        Node::Region {
339            generator,
340            source_region,
341            body,
342        } => visitor.visit_region(node, generator, source_region, body),
343        Node::Opaque(extension) => visitor.visit_opaque_node(node, extension.as_ref()),
344    }
345}