1use std::collections::HashMap;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::{Arc, LazyLock, Mutex};
10
11use crate::Result;
12use crate::graph::{Graph, Node, NodeId, OpKind, TensorMeta};
13use crate::types::{DType, Shape};
14
/// Borrowed view of one already-evaluated input tensor, handed to
/// [`Backend::eval_node`] together with its metadata.
pub struct NodeInput<'a> {
    /// Flat element buffer. Stored as `f32` even though `dtype` can vary —
    /// NOTE(review): presumably the reference path keeps all data in f32;
    /// confirm how non-f32 dtypes are represented.
    pub data: &'a [f32],
    /// Logical shape of the tensor.
    pub shape: &'a Shape,
    /// Element type this buffer represents.
    pub dtype: DType,
}
21
/// A kernel/device implementation that evaluates one graph node at a time.
///
/// `Send + Sync` is required because a single backend instance is shared
/// behind a [`Stream`] that may be used from multiple threads.
pub trait Backend: Send + Sync {
    /// Computes the output buffer for `op` from the already-materialized
    /// `inputs`. `output_meta` describes the expected result tensor; the
    /// returned buffer's length is not validated against it here.
    fn eval_node(
        &self,
        op: &OpKind,
        inputs: &[NodeInput<'_>],
        output_meta: &TensorMeta,
    ) -> Result<Vec<f32>>;
}
36
/// Lazily-evaluated computation stream: operations are recorded into
/// `graph` and only materialized on demand by `eval`, with results
/// memoized in `buffers`.
pub struct Stream {
    /// Recorded computation graph; node interning happens under this lock.
    graph: Mutex<Graph>,
    /// Backend that executes individual nodes.
    backend: Box<dyn Backend>,
    /// Memoization cache: node id -> evaluated output buffer.
    buffers: Mutex<HashMap<NodeId, Vec<f32>>>,
    /// Count of backend `eval_node` calls (tests use this to verify that
    /// memoization prevents recomputation).
    eval_calls: AtomicU64,
}
47
48impl Stream {
49 pub fn new(backend: Box<dyn Backend>) -> Self {
51 Self {
52 graph: Mutex::new(Graph::new()),
53 backend,
54 buffers: Mutex::new(HashMap::new()),
55 eval_calls: AtomicU64::new(0),
56 }
57 }
58
59 pub fn add_constant(&self, data: Vec<f32>, meta: TensorMeta) -> NodeId {
61 let mut graph = self.graph.lock().unwrap();
62 let id = graph.intern_node(
63 OpKind::Constant,
64 smallvec::SmallVec::new(),
65 meta,
66 Some(&data),
67 );
68 let mut buffers = self.buffers.lock().unwrap();
69 buffers.entry(id).or_insert(data);
70 id
71 }
72
73 pub fn add_op(
75 &self,
76 op: OpKind,
77 inputs: smallvec::SmallVec<[NodeId; 2]>,
78 meta: TensorMeta,
79 ) -> NodeId {
80 let mut graph = self.graph.lock().unwrap();
81 graph.intern_node(op, inputs, meta, None)
82 }
83
84 pub fn eval(&self, output: NodeId) -> Result<()> {
86 if self.buffers.lock().unwrap().contains_key(&output) {
88 return Ok(());
89 }
90
91 let order = {
93 let graph = self.graph.lock().unwrap();
94 graph.topo_sort(&[output])
95 };
96
97 for &node_id in &order {
99 if self.buffers.lock().unwrap().contains_key(&node_id) {
100 continue;
101 }
102
103 let node: Node = {
105 let graph = self.graph.lock().unwrap();
106 graph
107 .get(node_id)
108 .cloned()
109 .ok_or_else(|| crate::MlxError::InvalidArgument("missing graph node".into()))?
110 };
111
112 let input_metas: Vec<TensorMeta> = {
114 let graph = self.graph.lock().unwrap();
115 node.inputs
116 .iter()
117 .map(|&id| {
118 graph
119 .get(id)
120 .expect("input node missing from graph")
121 .meta
122 .clone()
123 })
124 .collect()
125 };
126
127 let input_buffers: Vec<Vec<f32>> = {
129 let buffers = self.buffers.lock().unwrap();
130 node.inputs
131 .iter()
132 .map(|&id| {
133 buffers
134 .get(&id)
135 .expect("input node should be evaluated before dependents")
136 .clone()
137 })
138 .collect()
139 };
140
141 let inputs: Vec<NodeInput<'_>> = input_buffers
142 .iter()
143 .zip(input_metas.iter())
144 .map(|(data, meta)| NodeInput {
145 data: data.as_slice(),
146 shape: &meta.shape,
147 dtype: meta.dtype,
148 })
149 .collect();
150
151 self.eval_calls.fetch_add(1, Ordering::Relaxed);
152 let result = self.backend.eval_node(&node.op, &inputs, &node.meta)?;
153
154 self.buffers.lock().unwrap().insert(node_id, result);
156 }
157
158 Ok(())
159 }
160
161 pub fn get_buffer(&self, id: NodeId) -> Option<Vec<f32>> {
163 self.buffers.lock().unwrap().get(&id).cloned()
164 }
165
166 pub fn get_node(&self, id: NodeId) -> Option<Node> {
168 self.graph.lock().unwrap().get(id).cloned()
169 }
170
171 pub fn topo_sort(&self, outputs: &[NodeId]) -> Vec<NodeId> {
173 self.graph.lock().unwrap().topo_sort(outputs)
174 }
175
176 pub fn eval_count(&self) -> u64 {
178 self.eval_calls.load(Ordering::Relaxed)
179 }
180
181 pub fn cache_len(&self) -> usize {
183 self.buffers.lock().unwrap().len()
184 }
185}
186
impl std::fmt::Debug for Stream {
    // Manual impl because `Box<dyn Backend>` has no `Debug`; all fields
    // are elided, producing `Stream { .. }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Stream").finish_non_exhaustive()
    }
}
192
/// Process-wide stream backed by the reference CPU kernels, created
/// lazily on first access.
static DEFAULT_STREAM: LazyLock<Arc<Stream>> =
    LazyLock::new(|| Arc::new(Stream::new(Box::new(crate::cpu_kernels::CpuRefBackend))));

/// Returns a handle to the shared default stream (cheap `Arc` clone;
/// all callers share one graph and one buffer cache).
pub fn default_stream() -> Arc<Stream> {
    Arc::clone(&DEFAULT_STREAM)
}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::TensorMeta;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Metadata for a length-2 f32 vector, used by most tests below.
    fn len2() -> TensorMeta {
        TensorMeta {
            shape: Shape::new(vec![2]),
            dtype: DType::F32,
        }
    }

    /// A stream with its own private graph and cache (not the shared
    /// default), so eval-count assertions are deterministic.
    fn isolated_stream() -> Stream {
        Stream::new(Box::new(crate::cpu_kernels::CpuRefBackend))
    }

    #[test]
    fn test_stream_constant() {
        let stream = default_stream();
        let meta = TensorMeta {
            shape: Shape::new(vec![3]),
            dtype: DType::F32,
        };
        let id = stream.add_constant(vec![1.0, 2.0, 3.0], meta);
        stream.eval(id).unwrap();
        assert_eq!(stream.get_buffer(id).unwrap(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_stream_add_op() {
        let stream = default_stream();
        let lhs = stream.add_constant(vec![1.0, 2.0], len2());
        let rhs = stream.add_constant(vec![3.0, 4.0], len2());
        let sum = stream.add_op(
            OpKind::Add,
            smallvec::SmallVec::from_vec(vec![lhs, rhs]),
            len2(),
        );
        stream.eval(sum).unwrap();
        assert_eq!(stream.get_buffer(sum).unwrap(), vec![4.0, 6.0]);
    }

    #[test]
    fn test_memoization_repeated_eval() {
        let s = isolated_stream();
        let x = s.add_constant(vec![1.0, 2.0], len2());
        let y = s.add_constant(vec![3.0, 4.0], len2());
        let sum = s.add_op(
            OpKind::Add,
            smallvec::SmallVec::from_vec(vec![x, y]),
            len2(),
        );

        s.eval(sum).unwrap();
        let first = s.eval_count();
        assert!(first > 0);

        s.eval(sum).unwrap();
        assert_eq!(s.eval_count(), first, "repeated eval should not recompute");
    }

    #[test]
    fn test_memoization_shared_subgraph() {
        let s = isolated_stream();
        let x = s.add_constant(vec![1.0, 2.0], len2());
        let y = s.add_constant(vec![3.0, 4.0], len2());
        let sum = s.add_op(
            OpKind::Add,
            smallvec::SmallVec::from_vec(vec![x, y]),
            len2(),
        );
        // `sum` feeds both sides of the multiply: a shared subgraph.
        let squared = s.add_op(
            OpKind::Mul,
            smallvec::SmallVec::from_vec(vec![sum, sum]),
            len2(),
        );

        s.eval(squared).unwrap();
        let after = s.eval_count();

        s.eval(squared).unwrap();
        assert_eq!(
            s.eval_count(),
            after,
            "repeated eval of d should not recompute"
        );

        s.eval(sum).unwrap();
        assert_eq!(
            s.eval_count(),
            after,
            "c should already be cached from d's eval"
        );
    }

    #[test]
    fn test_memoization_diamond() {
        let s = isolated_stream();
        let base = s.add_constant(vec![1.0, 2.0], len2());
        let left = s.add_op(
            OpKind::Add,
            smallvec::SmallVec::from_vec(vec![base, base]),
            len2(),
        );
        let right = s.add_op(
            OpKind::Mul,
            smallvec::SmallVec::from_vec(vec![base, base]),
            len2(),
        );
        let top = s.add_op(
            OpKind::Add,
            smallvec::SmallVec::from_vec(vec![left, right]),
            len2(),
        );

        s.eval(top).unwrap();
        assert_eq!(
            s.eval_count(),
            3,
            "diamond should require exactly 3 backend calls"
        );

        s.eval(top).unwrap();
        assert_eq!(
            s.eval_count(),
            3,
            "repeated eval should not increase call count"
        );
    }

    #[test]
    fn test_cache_len_updates() {
        let s = isolated_stream();
        assert_eq!(s.cache_len(), 0);

        let x = s.add_constant(vec![1.0, 2.0], len2());
        let y = s.add_constant(vec![3.0, 4.0], len2());
        assert_eq!(s.cache_len(), 2);

        let sum = s.add_op(
            OpKind::Add,
            smallvec::SmallVec::from_vec(vec![x, y]),
            len2(),
        );
        s.eval(sum).unwrap();
        assert_eq!(s.cache_len(), 3);
    }

    #[test]
    fn test_eval_is_memoized_no_recompute() {
        /// Wraps the reference backend and counts every eval_node call.
        struct SpyBackend {
            delegate: crate::cpu_kernels::CpuRefBackend,
            hits: Arc<AtomicUsize>,
        }

        impl Backend for SpyBackend {
            fn eval_node(
                &self,
                op: &OpKind,
                inputs: &[NodeInput<'_>],
                output_meta: &TensorMeta,
            ) -> Result<Vec<f32>> {
                self.hits.fetch_add(1, Ordering::Relaxed);
                self.delegate.eval_node(op, inputs, output_meta)
            }
        }

        let hits = Arc::new(AtomicUsize::new(0));
        let stream = Stream::new(Box::new(SpyBackend {
            delegate: crate::cpu_kernels::CpuRefBackend,
            hits: Arc::clone(&hits),
        }));

        let lhs = stream.add_constant(vec![1.0, 2.0], len2());
        let rhs = stream.add_constant(vec![3.0, 4.0], len2());

        let sum = stream.add_op(
            OpKind::Add,
            smallvec::SmallVec::from_vec(vec![lhs, rhs]),
            len2(),
        );
        let neg = stream.add_op(
            OpKind::Neg,
            smallvec::SmallVec::from_vec(vec![sum]),
            len2(),
        );

        stream.eval(neg).unwrap();
        let first = hits.load(Ordering::Relaxed);
        // Constants are pre-seeded, so only Add and Neg hit the backend.
        assert_eq!(first, 2);
        assert_eq!(stream.get_buffer(neg).unwrap(), vec![-4.0, -6.0]);

        stream.eval(neg).unwrap();
        let second = hits.load(Ordering::Relaxed);
        assert_eq!(first, second);

        stream.eval(sum).unwrap();
        let third = hits.load(Ordering::Relaxed);
        assert_eq!(second, third);
    }
}