// mlx_core/backend.rs
//! Backend trait and Stream — pluggable compute engine for tensor evaluation.
//!
//! A `Backend` knows how to execute a single graph node (op + inputs → output).
//! A `Stream` binds a `Backend` to a lazy computation `Graph`, managing
//! materialized buffers and evaluation scheduling.

use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, LazyLock, Mutex};

use crate::Result;
use crate::graph::{Graph, Node, NodeId, OpKind, TensorMeta};
use crate::types::{DType, Shape};

/// Materialized input data passed to a backend for evaluation.
///
/// Borrows the stream's cached buffer and the node's metadata for the
/// duration of a single `Backend::eval_node` call.
pub struct NodeInput<'a> {
    /// Flat element buffer; storage in this file is always `f32` slices.
    pub data: &'a [f32],
    /// Logical shape of the tensor this buffer represents.
    pub shape: &'a Shape,
    /// Logical element type (carried alongside the `f32` storage).
    pub dtype: DType,
}
21
/// Pluggable compute backend.
///
/// Backends evaluate individual graph nodes. The `Stream` handles scheduling
/// (topological sort) and buffer management; the backend only needs to
/// implement the actual kernel dispatch.
///
/// Implementors must be `Send + Sync` because a `Stream` is shared across
/// threads (e.g. via the process-wide default stream).
pub trait Backend: Send + Sync {
    /// Evaluate a single op node given its materialized inputs.
    ///
    /// `inputs` follow the order of the node's input edges. The returned
    /// buffer becomes the node's cached result; it should describe a tensor
    /// matching `output_meta`.
    fn eval_node(
        &self,
        op: &OpKind,
        inputs: &[NodeInput<'_>],
        output_meta: &TensorMeta,
    ) -> Result<Vec<f32>>;
}
36
/// A computation stream binding a graph to a backend.
///
/// Operations on tensors add nodes to the stream's graph lazily.
/// Calling `eval()` topologically sorts and evaluates pending nodes.
pub struct Stream {
    /// Lazy computation graph; nodes are appended/interned as ops are recorded.
    graph: Mutex<Graph>,
    /// Compute engine that executes individual nodes.
    backend: Box<dyn Backend>,
    /// Memoization cache: materialized result buffer per node id.
    buffers: Mutex<HashMap<NodeId, Vec<f32>>>,
    /// Running count of backend `eval_node` calls (tests use this to verify
    /// that memoization prevents recomputation).
    eval_calls: AtomicU64,
}
47
48impl Stream {
49    /// Create a new stream with the given backend.
50    pub fn new(backend: Box<dyn Backend>) -> Self {
51        Self {
52            graph: Mutex::new(Graph::new()),
53            backend,
54            buffers: Mutex::new(HashMap::new()),
55            eval_calls: AtomicU64::new(0),
56        }
57    }
58
59    /// Add a constant node (data already known).
60    pub fn add_constant(&self, data: Vec<f32>, meta: TensorMeta) -> NodeId {
61        let mut graph = self.graph.lock().unwrap();
62        let id = graph.intern_node(
63            OpKind::Constant,
64            smallvec::SmallVec::new(),
65            meta,
66            Some(&data),
67        );
68        let mut buffers = self.buffers.lock().unwrap();
69        buffers.entry(id).or_insert(data);
70        id
71    }
72
73    /// Add an operation node to the graph.
74    pub fn add_op(
75        &self,
76        op: OpKind,
77        inputs: smallvec::SmallVec<[NodeId; 2]>,
78        meta: TensorMeta,
79    ) -> NodeId {
80        let mut graph = self.graph.lock().unwrap();
81        graph.intern_node(op, inputs, meta, None)
82    }
83
84    /// Evaluate all nodes needed to materialize the given output.
85    pub fn eval(&self, output: NodeId) -> Result<()> {
86        // Already materialized?
87        if self.buffers.lock().unwrap().contains_key(&output) {
88            return Ok(());
89        }
90
91        // Topo-sort the subgraph rooted at `output`.
92        let order = {
93            let graph = self.graph.lock().unwrap();
94            graph.topo_sort(&[output])
95        };
96
97        // Evaluate each node in order. Never hold both locks simultaneously.
98        for &node_id in &order {
99            if self.buffers.lock().unwrap().contains_key(&node_id) {
100                continue;
101            }
102
103            // Step 1: get node info (graph lock only).
104            let node: Node = {
105                let graph = self.graph.lock().unwrap();
106                graph
107                    .get(node_id)
108                    .cloned()
109                    .ok_or_else(|| crate::MlxError::InvalidArgument("missing graph node".into()))?
110            };
111
112            // Step 2: get input metadata from graph (graph lock only).
113            let input_metas: Vec<TensorMeta> = {
114                let graph = self.graph.lock().unwrap();
115                node.inputs
116                    .iter()
117                    .map(|&id| {
118                        graph
119                            .get(id)
120                            .expect("input node missing from graph")
121                            .meta
122                            .clone()
123                    })
124                    .collect()
125            };
126
127            // Step 3: gather input data + run backend (buffers lock only for reads).
128            let input_buffers: Vec<Vec<f32>> = {
129                let buffers = self.buffers.lock().unwrap();
130                node.inputs
131                    .iter()
132                    .map(|&id| {
133                        buffers
134                            .get(&id)
135                            .expect("input node should be evaluated before dependents")
136                            .clone()
137                    })
138                    .collect()
139            };
140
141            let inputs: Vec<NodeInput<'_>> = input_buffers
142                .iter()
143                .zip(input_metas.iter())
144                .map(|(data, meta)| NodeInput {
145                    data: data.as_slice(),
146                    shape: &meta.shape,
147                    dtype: meta.dtype,
148                })
149                .collect();
150
151            self.eval_calls.fetch_add(1, Ordering::Relaxed);
152            let result = self.backend.eval_node(&node.op, &inputs, &node.meta)?;
153
154            // Step 4: store result (buffers lock only for write).
155            self.buffers.lock().unwrap().insert(node_id, result);
156        }
157
158        Ok(())
159    }
160
161    /// Get materialized buffer data for a node (must call eval first).
162    pub fn get_buffer(&self, id: NodeId) -> Option<Vec<f32>> {
163        self.buffers.lock().unwrap().get(&id).cloned()
164    }
165
166    /// Get a clone of a graph node by ID.
167    pub fn get_node(&self, id: NodeId) -> Option<Node> {
168        self.graph.lock().unwrap().get(id).cloned()
169    }
170
171    /// Topological sort of the subgraph rooted at the given outputs.
172    pub fn topo_sort(&self, outputs: &[NodeId]) -> Vec<NodeId> {
173        self.graph.lock().unwrap().topo_sort(outputs)
174    }
175
176    /// Number of times the backend's `eval_node` has been called.
177    pub fn eval_count(&self) -> u64 {
178        self.eval_calls.load(Ordering::Relaxed)
179    }
180
181    /// Number of materialized buffers currently cached.
182    pub fn cache_len(&self) -> usize {
183        self.buffers.lock().unwrap().len()
184    }
185}
186
187impl std::fmt::Debug for Stream {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        f.debug_struct("Stream").finish_non_exhaustive()
190    }
191}
192
/// The default stream using the built-in CPU reference backend.
///
/// Created lazily on first access and shared process-wide; `default_stream`
/// hands out `Arc` clones of this single instance.
static DEFAULT_STREAM: LazyLock<Arc<Stream>> =
    LazyLock::new(|| Arc::new(Stream::new(Box::new(crate::cpu_kernels::CpuRefBackend))));

/// Get the default computation stream.
///
/// All callers share the same underlying `Stream` — and therefore the same
/// graph, buffer cache, and eval-call counter.
pub fn default_stream() -> Arc<Stream> {
    Arc::clone(&DEFAULT_STREAM)
}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::TensorMeta;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Build a stream with its own graph and cache so each test is isolated
    /// from the process-global default stream (tests run concurrently).
    fn fresh_stream() -> Stream {
        Stream::new(Box::new(crate::cpu_kernels::CpuRefBackend))
    }

    /// Metadata for a 1-D f32 tensor of length 2.
    fn meta2() -> TensorMeta {
        TensorMeta {
            shape: Shape::new(vec![2]),
            dtype: DType::F32,
        }
    }

    #[test]
    fn test_default_stream_is_shared() {
        // `default_stream` hands out clones of one process-wide instance.
        assert!(Arc::ptr_eq(&default_stream(), &default_stream()));
    }

    #[test]
    fn test_stream_constant() {
        // Use a fresh stream (not the global default) so state from other
        // tests can never leak into this one.
        let stream = fresh_stream();
        let id = stream.add_constant(
            vec![1.0, 2.0, 3.0],
            TensorMeta {
                shape: Shape::new(vec![3]),
                dtype: DType::F32,
            },
        );
        stream.eval(id).unwrap();
        assert_eq!(stream.get_buffer(id).unwrap(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_stream_add_op() {
        let stream = fresh_stream();
        let a = stream.add_constant(vec![1.0, 2.0], meta2());
        let b = stream.add_constant(vec![3.0, 4.0], meta2());
        let c = stream.add_op(OpKind::Add, smallvec::SmallVec::from_slice(&[a, b]), meta2());
        stream.eval(c).unwrap();
        assert_eq!(stream.get_buffer(c).unwrap(), vec![4.0, 6.0]);
    }

    #[test]
    fn test_memoization_repeated_eval() {
        let s = fresh_stream();
        let a = s.add_constant(vec![1.0, 2.0], meta2());
        let b = s.add_constant(vec![3.0, 4.0], meta2());
        let c = s.add_op(OpKind::Add, smallvec::SmallVec::from_slice(&[a, b]), meta2());

        s.eval(c).unwrap();
        let count_after_first = s.eval_count();
        assert!(count_after_first > 0);

        s.eval(c).unwrap();
        assert_eq!(
            s.eval_count(),
            count_after_first,
            "repeated eval should not recompute"
        );
    }

    #[test]
    fn test_memoization_shared_subgraph() {
        let s = fresh_stream();
        let a = s.add_constant(vec![1.0, 2.0], meta2());
        let b = s.add_constant(vec![3.0, 4.0], meta2());
        let c = s.add_op(OpKind::Add, smallvec::SmallVec::from_slice(&[a, b]), meta2());
        let d = s.add_op(OpKind::Mul, smallvec::SmallVec::from_slice(&[c, c]), meta2());

        s.eval(d).unwrap();
        let count_after_d = s.eval_count();

        // Re-eval d — fully cached.
        s.eval(d).unwrap();
        assert_eq!(
            s.eval_count(),
            count_after_d,
            "repeated eval of d should not recompute"
        );

        // Eval c separately — already cached from d's eval.
        s.eval(c).unwrap();
        assert_eq!(
            s.eval_count(),
            count_after_d,
            "c should already be cached from d's eval"
        );
    }

    #[test]
    fn test_memoization_diamond() {
        let s = fresh_stream();
        let a = s.add_constant(vec![1.0, 2.0], meta2());
        // diamond: b = a+a, c = a*a, d = b+c
        let b = s.add_op(OpKind::Add, smallvec::SmallVec::from_slice(&[a, a]), meta2());
        let c = s.add_op(OpKind::Mul, smallvec::SmallVec::from_slice(&[a, a]), meta2());
        let d = s.add_op(OpKind::Add, smallvec::SmallVec::from_slice(&[b, c]), meta2());

        s.eval(d).unwrap();
        // a is a constant (pre-cached), so only b, c, d need backend calls = 3
        assert_eq!(
            s.eval_count(),
            3,
            "diamond should require exactly 3 backend calls"
        );

        s.eval(d).unwrap();
        assert_eq!(
            s.eval_count(),
            3,
            "repeated eval should not increase call count"
        );
    }

    #[test]
    fn test_cache_len_updates() {
        let s = fresh_stream();
        assert_eq!(s.cache_len(), 0);

        // Constants are cached immediately on add, before any eval.
        let a = s.add_constant(vec![1.0, 2.0], meta2());
        let b = s.add_constant(vec![3.0, 4.0], meta2());
        assert_eq!(s.cache_len(), 2);

        let c = s.add_op(OpKind::Add, smallvec::SmallVec::from_slice(&[a, b]), meta2());
        s.eval(c).unwrap();
        assert_eq!(s.cache_len(), 3);
    }

    #[test]
    fn test_eval_is_memoized_no_recompute() {
        // Wraps the CPU backend and counts how many nodes it is asked to run.
        struct CountingBackend {
            inner: crate::cpu_kernels::CpuRefBackend,
            calls: Arc<AtomicUsize>,
        }

        impl Backend for CountingBackend {
            fn eval_node(
                &self,
                op: &OpKind,
                inputs: &[NodeInput<'_>],
                output_meta: &TensorMeta,
            ) -> Result<Vec<f32>> {
                self.calls.fetch_add(1, Ordering::Relaxed);
                self.inner.eval_node(op, inputs, output_meta)
            }
        }

        let calls = Arc::new(AtomicUsize::new(0));
        let stream = Stream::new(Box::new(CountingBackend {
            inner: crate::cpu_kernels::CpuRefBackend,
            calls: Arc::clone(&calls),
        }));

        let a = stream.add_constant(vec![1.0, 2.0], meta2());
        let b = stream.add_constant(vec![3.0, 4.0], meta2());

        // Two op nodes: add then neg
        let add = stream.add_op(OpKind::Add, smallvec::SmallVec::from_slice(&[a, b]), meta2());
        let out = stream.add_op(OpKind::Neg, smallvec::SmallVec::from_slice(&[add]), meta2());

        stream.eval(out).unwrap();
        let after_first = calls.load(Ordering::Relaxed);
        assert_eq!(after_first, 2);
        assert_eq!(stream.get_buffer(out).unwrap(), vec![-4.0, -6.0]);

        // Repeated eval should not call into the backend again.
        stream.eval(out).unwrap();
        let after_second = calls.load(Ordering::Relaxed);
        assert_eq!(after_first, after_second);

        // Evaluating already-materialized intermediates should also be a no-op.
        stream.eval(add).unwrap();
        let after_third = calls.load(Ordering::Relaxed);
        assert_eq!(after_second, after_third);
    }
}
449}