kn_graph/
graph.rs

use std::cmp::max;
use std::collections::HashMap;
use std::convert::TryInto;
use std::fmt::{Debug, Display, Formatter};
use std::ops::Index;

use decorum::Total;
use itertools::{Itertools, zip_eq};
use ndarray::{ArrayView, IxDyn};
use rand::random;

use crate::cpu::{OperationError, OperationResult, run_cpu_const_operation};
use crate::dtype::{dispatch_dtensor, dispatch_dtype, DScalar, DTensor, DType, IntoDScalar, map_dscalar_pair, Tensor};
use crate::optimizer::recurse::heap_recurse;
use crate::shape;
use crate::shape::{Shape, Size};
use crate::wrap_debug::WrapDebug;

19/// The core graph datastructure.
20///
21/// This is a Directed Acyclic Graph (DAG) with values and their creating operations as nodes,
/// and input operands as edges. The data structure is append-only: values cannot be removed,
/// so they never become invalid.
///
/// This type implements the `Index<Value>` trait, so you can use `graph[value]` to get information about a value.
26///
27/// ```
28/// # use kn_graph::dtype::DType;
29/// use kn_graph::graph::*;
30/// # use kn_graph::shape;
31/// # use kn_graph::shape::*;
32/// // create a new graph
33/// let mut graph = Graph::new();
34///
35/// // define the inputs
36/// let x = graph.input(shape![Size::BATCH, 4, 8, 8], DType::F32);
37///
38/// // define constants
39/// let w_data = vec![0.5; 4 * 4 * 3 * 3];
40/// let w = graph.constant::<f32>(shape![4, 4, 3, 3], w_data);
41/// let b_data = vec![0.5; 4];
42/// let b = graph.constant::<f32>(shape![4, 1, 1], b_data);
43///
44/// // build operation graph
45/// let y0 = graph.conv(x, w, 1, 1, 1, 1);
46/// let y = graph.add(y0, b);
47///
48/// graph.output(y);
49///
50/// println!("{}", graph);
51/// ```
/// Results in output along the following lines (the `check` value is random, and the exact format may differ between versions):
53/// ```text
54/// Graph {
55///   check: 1504812640,
56///   input_shapes: [Shape(B x 4 x 8 x 8)],
57///   output_shapes: [Shape(B x 4 x 8 x 8)],
58///   inputs: [Value(0)],
59///   outputs: [Value(6)],
60///   values: [
61///     Value(0) = ValueInfo { shape: Shape(B x 4 x 8 x 8), operation: Input { index: 0 }, debug_id: "", non_output_uses: 1 },
62///     Value(1) = ValueInfo { shape: Shape(4 x 4 x 3 x 3), operation: Constant { data: [..; 144] }, debug_id: "", non_output_uses: 1 },
63///     Value(2) = ValueInfo { shape: Shape(4 x 1 x 1), operation: Constant { data: [0.5, 0.5, 0.5, 0.5] }, debug_id: "", non_output_uses: 1 },
64///     Value(3) = ValueInfo { shape: Shape(B x 4 x 8 x 8), operation: Conv { input: Value(0), filter: Value(1), details: ConvDetails { batch_size: Size(B), input_channels: 4, output_channels: 4, input_h: 8, input_w: 8, kernel_h: 3, kernel_w: 3, stride_y: 1, stride_x: 1, padding_y: 1, padding_x: 1, output_h: 8, output_w: 8 } }, debug_id: "", non_output_uses: 1 },
65///     Value(4) = ValueInfo { shape: Shape(1 x 4 x 1 x 1), operation: View { input: Value(2) }, debug_id: "", non_output_uses: 1 },
66///     Value(5) = ValueInfo { shape: Shape(B x 4 x 8 x 8), operation: Broadcast { input: Value(4) }, debug_id: "", non_output_uses: 1 },
67///     Value(6) = ValueInfo { shape: Shape(B x 4 x 8 x 8), operation: Binary { left: Value(3), right: Value(5), op: Add }, debug_id: "", non_output_uses: 0 },
68///   ],
69/// }
70/// ```
71#[derive(Clone)]
72// TODO override clone manually, replace check value
73// TODO think about two builder categories:
74//     * things that map directly to an operation, with all the type and shape checking
75//     * things that do optimizations, extra broadcasting, ...
76//   alternatively do extra checking in `self.push`?
77pub struct Graph {
78    check: u32,
79    values: Vec<ValueInfo>,
80    back_map: HashMap<(Shape, DType, Operation), usize>,
81    new_values: Vec<Value>,
82    inputs: Vec<Value>,
83    outputs: Vec<Value>,
84}
85
86/// A value in a [Graph].
87#[derive(Copy, Clone, Eq, PartialEq, Hash)]
88pub struct Value {
89    index: usize,
90    check: u32,
91}
92
93/// Information about a [Value], most importantly its shape and creating operation.
94#[derive(Debug, Clone, Eq, PartialEq)]
95pub struct ValueInfo {
96    pub shape: Shape,
97    pub dtype: DType,
98    pub operation: Operation,
99    pub debug_id: String,
100    non_output_uses: usize,
101}
102
103/// The core set of graph operations.
104/// Some attempt was made to keep operations orthogonal but flexible, so they can be composed easily.
105#[derive(Debug, Clone, Eq, PartialEq, Hash)]
106pub enum Operation {
107    /// A runtime-variable input.
108    Input { index: usize },
109    /// A constant built into the network.
110    Constant { tensor: WrapDebug<DTensor> },
111
112    //TODO maybe fuse a bunch of these operations into a single "Restride" operation?
113    /// View a value as a different shape.
114    View { input: Value },
115    /// Repeat along all axes with size 1 that don't match the output shape.
116    Broadcast { input: Value },
    /// Change the order of axes in the shape.
118    Permute { input: Value, permutation: Vec<usize> },
119    /// Slice along the given `axis` with range `start..end`.
120    Slice {
121        input: Value,
122        axis: usize,
123        range: SliceRange,
124    },
125    /// Flip the given axis.
126    Flip { input: Value, axis: usize },
127
    /// Gather values from `input` at the indices in `indices` along the given `axis`.
    /// `indices` must be a rank-1 tensor.
130    Gather { input: Value, axis: usize, indices: Value },
131
132    /// Concatenate values along an axis.
133    Concat { inputs: Vec<Value>, axis: usize },
134
135    /// 2D convolution.
136    Conv {
137        input: Value,
138        filter: Value,
139        details: ConvDetails,
140    },
141    /// (Batched) Matrix multiply.
142    /// If left has shape `[b, p, q]` and right has shape `[b, q, r]` the result has shape `[b, p, r]`.
143    MatMul { left: Value, right: Value },
144
145    /// Elementwise unary operation.
146    Unary { input: Value, op: UnaryOp },
147    /// Elementwise binary operation. Both operands must have the same shape.
148    Binary { left: Value, right: Value, op: BinaryOp },
149
150    /// Softmax along `axis`.
151    Softmax { input: Value, axis: usize },
152    /// Layernorm along `axis`.
153    Layernorm { input: Value, axis: usize, eps: Total<f32> },
154
155    /// Reduce along the given `axes` using `op`. The `axes` are removed from the shape.
156    Reduce {
157        input: Value,
158        axes: Vec<usize>,
159        op: ReduceOp,
160    },
161    // TODO "select"/"where" operation
162}
163
164#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
165pub struct SliceRange {
166    pub start: usize,
167    pub end: usize,
168    pub step: usize,
169}
170
171// TODO consider removing the compound operations (sigmoid, mish)
172//   alternatively check if either the CPU or CUDA implementations are faster/more accurate
173#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
174pub enum UnaryOp {
175    Abs,
176    Neg,
177    Sin,
178    Cos,
179    Exp,
180    Log,
181    Sqrt,
182    Sigmoid,
183    Tanh,
184    Erf,
185    Mish,
186    Softplus,
187
188    /// Cast to a different type.
189    /// When possible the value is preserved or at least approximated.
190    ValueCast(DType),
191    /// Cast to a different type.
192    /// The bit pattern is kept, so the value is not necessarily preserved.
193    /// The type before and after the cast must have the same size.
194    BitCast(DType),
195}
196
197#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
198pub enum BinaryOp {
199    Add,
200    Sub,
201    Mul,
202    Div,
203    Min,
204    Max,
205    Pow,
206}
207
208#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
209pub enum ReduceOp {
210    Sum,
211    // TODO remove mean and rely on operator fusion instead
212    //   definitely do this, it's getting pretty ugly in the planner
213    Mean,
214    Prod,
215    Max,
216    Min,
217}
218
219impl Operation {
220    pub fn inputs(&self) -> Vec<Value> {
221        match self {
222            Operation::Input { index: _ } => vec![],
223            Operation::Constant { tensor: _ } => vec![],
224            &Operation::View { input } => vec![input],
225            &Operation::Broadcast { input } => vec![input],
226            &Operation::Permute { input, permutation: _ } => vec![input],
227            &Operation::Slice {
228                input,
229                axis: _,
230                range: _,
231            } => vec![input],
232            &Operation::Flip { input, axis: _ } => vec![input],
233            &Operation::Gather {
234                input,
235                axis: _,
236                indices,
237            } => vec![input, indices],
238            Operation::Concat { inputs, axis: _ } => inputs.clone(),
239            &Operation::Conv {
240                input,
241                filter,
242                details: _,
243            } => vec![input, filter],
244            &Operation::MatMul { left, right } => vec![left, right],
245            &Operation::Unary { input, op: _ } => vec![input],
246            &Operation::Binary { left, right, op: _ } => vec![left, right],
247            &Operation::Softmax { input, axis: _ } => vec![input],
248            &Operation::Layernorm { input, axis: _, eps: _ } => vec![input],
249            &Operation::Reduce { input, axes: _, op: _ } => vec![input],
250        }
251    }
252
253    pub(crate) fn clone_map_inputs(&self, mut f: impl FnMut(Value) -> Value) -> Operation {
254        match self {
255            &Operation::Input { index } => Operation::Input { index },
256            &Operation::Constant { ref tensor } => Operation::Constant { tensor: tensor.clone() },
257            &Operation::View { input } => Operation::View { input: f(input) },
258            &Operation::Broadcast { input } => Operation::Broadcast { input: f(input) },
259            &Operation::Permute { input, ref permutation } => Operation::Permute {
260                input: f(input),
261                permutation: permutation.clone(),
262            },
263            &Operation::Slice { input, axis, range } => Operation::Slice {
264                input: f(input),
265                axis,
266                range,
267            },
268            &Operation::Flip { input, axis } => Operation::Flip { input: f(input), axis },
269            &Operation::Gather { input, axis, indices } => Operation::Gather {
270                input: f(input),
271                axis,
272                indices: f(indices),
273            },
274            &Operation::Concat { ref inputs, axis } => Operation::Concat {
275                inputs: inputs.iter().copied().map(f).collect(),
276                axis,
277            },
278            &Operation::Conv {
279                input,
280                filter,
281                details: conv_shape,
282            } => Operation::Conv {
283                input: f(input),
284                filter: f(filter),
285                details: conv_shape,
286            },
287            &Operation::MatMul { left, right } => Operation::MatMul {
288                left: f(left),
289                right: f(right),
290            },
291            &Operation::Unary { input, op } => Operation::Unary { input: f(input), op },
292            &Operation::Binary { left, right, op } => Operation::Binary {
293                left: f(left),
294                right: f(right),
295                op,
296            },
297            &Operation::Softmax { input, axis } => Operation::Softmax { input: f(input), axis },
298            &Operation::Layernorm { input, axis, eps } => Operation::Layernorm {
299                input: f(input),
300                axis,
301                eps,
302            },
303            &Operation::Reduce { input, ref axes, op } => Operation::Reduce {
304                input: f(input),
305                axes: axes.clone(),
306                op,
307            },
308        }
309    }
310}
311
312#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
313pub struct ConvDetails {
314    pub dtype: DType,
315    pub batch_size: Size,
316
317    pub input_channels: usize,
318    pub output_channels: usize,
319
320    pub input_h: usize,
321    pub input_w: usize,
322    pub kernel_h: usize,
323    pub kernel_w: usize,
324    pub stride_y: usize,
325    pub stride_x: usize,
326    pub padding_y: usize,
327    pub padding_x: usize,
328    pub output_h: usize,
329    pub output_w: usize,
330}
331
332impl ConvDetails {
333    pub fn input_shape(&self) -> Shape {
334        shape![self.batch_size, self.input_channels, self.input_h, self.input_w]
335    }
336
337    pub fn output_shape(&self) -> Shape {
338        shape![self.batch_size, self.output_channels, self.output_h, self.output_w]
339    }
340
341    pub fn keeps_spatial_shape(&self) -> bool {
342        (self.input_h == self.output_h) && (self.input_w == self.output_w)
343    }
344
345    pub fn has_stride(&self) -> bool {
346        self.stride_y != 1 || self.stride_x != 1
347    }
348
349    pub fn kernel_shape(&self) -> [usize; 4] {
350        [self.output_channels, self.input_channels, self.kernel_h, self.kernel_w]
351    }
352}
353
354impl Index<Value> for Graph {
355    type Output = ValueInfo;
356
357    fn index(&self, value: Value) -> &Self::Output {
358        self.check_contains(value);
359        &self.values[value.index]
360    }
361}
362
363impl Graph {
364    pub fn new() -> Self {
365        Graph {
366            check: random(),
367            values: vec![],
368            back_map: HashMap::new(),
369            new_values: vec![],
370            inputs: vec![],
371            outputs: vec![],
372        }
373    }
374
375    fn check_contains(&self, value: Value) {
376        assert_eq!(
377            value.check, self.check,
378            "Value {:?} does not belong to this graph",
379            value
380        );
381        assert!(value.index < self.values.len());
382    }
383
384    pub fn shape_dtype(&self, value: Value) -> (&Shape, DType) {
385        let info = &self[value];
386        (&info.shape, info.dtype)
387    }
388
389    /// Iterate over the values in this graph, in topological order,
390    /// which means that nodes will only be visited after all of their inputs have been visited.
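    ///
    /// A minimal sketch of walking a graph (here an empty one, so the loop body never runs):
    /// ```
    /// # use kn_graph::graph::Graph;
    /// let graph = Graph::new();
    /// for value in graph.values() {
    ///     println!("{:?} <- {:?}", value, graph[value].operation.inputs());
    /// }
    /// ```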
391    pub fn values(&self) -> impl Iterator<Item = Value> {
392        let check = self.check;
393        (0..self.values.len()).map(move |index| Value { index, check })
394    }
395
396    pub fn inputs(&self) -> &[Value] {
397        &self.inputs
398    }
399
400    pub fn input_shapes(&self) -> Vec<Shape> {
401        self.inputs().iter().map(|&v| self[v].shape.clone()).collect()
402    }
403
404    pub fn outputs(&self) -> &[Value] {
405        &self.outputs
406    }
407
408    pub fn output_shapes(&self) -> Vec<Shape> {
409        self.outputs().iter().map(|&v| self[v].shape.clone()).collect()
410    }
411
412    pub fn outputs_mut(&mut self) -> &mut Vec<Value> {
413        &mut self.outputs
414    }
415
416    pub fn is_hidden(&self, value: Value) -> bool {
417        self.check_contains(value);
418        !self.inputs.contains(&value) && !self.outputs.contains(&value)
419    }
420
421    pub fn is_hidden_with_uses(&self, value: Value, users: usize) -> bool {
422        self.is_hidden(value) && self[value].non_output_uses == users
423    }
424
425    pub fn is_const(&self, value: Value) -> bool {
426        let operation = &self[value].operation;
427        match *operation {
428            Operation::Input { .. } => false,
429            Operation::Constant { .. } => true,
430            _ => operation.inputs().into_iter().all(|input| self.is_const(input)),
431        }
432    }
433
434    /// Try to evaluate `value` as a constant.
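    ///
    /// A small sketch of the intent, assuming constant folding of a simple elementwise op succeeds on the CPU path:
    /// ```
    /// # use kn_graph::graph::Graph;
    /// # use kn_graph::shape;
    /// # use kn_graph::shape::*;
    /// let mut graph = Graph::new();
    /// let a = graph.constant::<f32>(shape![2], vec![1.0, 2.0]);
    /// let b = graph.constant::<f32>(shape![2], vec![3.0, 4.0]);
    /// let sum = graph.add(a, b);
    /// // a purely constant expression should evaluate to a concrete tensor
    /// assert!(graph.as_const(sum).is_some());
    /// ```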
435    pub fn as_const(&self, value: Value) -> Option<DTensor> {
436        // we have to use heap_recurse to avoid stack overflows
437
438        // TODO always immediately evaluate all possible values instead?
439        // TODO store this cache in the graph permanently?
440        //   this will all take a bunch of memory and time :(
441        let mut cache: HashMap<Value, OperationResult> = HashMap::new();
442
443        let f_cached = |curr| {
444            let mut missing_arg = None;
445
446            let res = run_cpu_const_operation(&self[curr], |arg| {
447                match cache.get(&arg) {
448                    // already evaluated
449                    Some(Ok(tensor)) => Ok(tensor.clone()),
450                    Some(&Err(err)) => Err(err),
451                    // not evaluated yet, bubble back to the top
452                    None => {
453                        missing_arg = Some(arg);
454                        //   the exact error used here doesn't matter
455                        Err(OperationError::MissingOperand)
456                    }
457                }
458            });
459
460            // continue bubbling
461            if let Some(missing_arg) = missing_arg {
462                assert_eq!(res, Err(OperationError::MissingOperand));
463                return Err(missing_arg);
464            }
465
466            let prev = cache.insert(curr, res.clone());
467            assert!(prev.is_none());
468
469            Ok(res)
470        };
471
472        let res = heap_recurse(value, f_cached);
473        res.ok()
474    }
475
476    /// Returns whether `value` is effectively a constant with every element equal to `expected`.
477    pub fn is_const_filled_with(&self, value: Value, expected: DScalar) -> bool {
478        self.as_single_const(value).map_or(false, |actual| expected == actual)
479    }
480
481    pub fn is_const_zero(&self, value: Value) -> bool {
482        self.is_const_filled_with(value, self[value].dtype.specials().zero)
483    }
484
485    pub fn is_const_one(&self, value: Value) -> bool {
486        self.is_const_filled_with(value, self[value].dtype.specials().one)
487    }
488
489    /// Returns `Some(f)` if `value` is effectively a constant with every element equal to `f`.
490    pub fn as_single_const(&self, value: Value) -> Option<DScalar> {
491        let info = &self[value];
492
493        match info.operation {
494            Operation::Input { .. } => None,
495            Operation::Constant { tensor: WrapDebug(ref tensor) } => dispatch_dtensor!(tensor, |_T, _f, tensor| {
496                let &e = tensor.iter().next()?;
497                tensor.iter().all(|&d| d == e).then(|| e.to_dscalar())
498            }),
499            Operation::View { input } => self.as_single_const(input),
500            Operation::Broadcast { input } => self.as_single_const(input),
501            Operation::Permute { input, permutation: _ } => self.as_single_const(input),
502            Operation::Slice {
503                input,
504                axis: _,
505                range: _,
506            } => self.as_single_const(input),
507            Operation::Flip { input, axis: _ } => self.as_single_const(input),
508            Operation::Gather {
509                input,
510                axis: _,
511                indices: _,
512            } => self.as_single_const(input),
513            Operation::Concat { ref inputs, axis: _ } => {
514                let f = self.as_single_const(*inputs.first()?)?;
515                inputs.iter().all(|&x| self.is_const_filled_with(x, f)).then(|| f)
516            }
517            Operation::Unary { input, op } => Some(op.map(self.as_single_const(input)?)),
518            Operation::Binary { left, right, op } => {
519                Some(op.map(self.as_single_const(left)?, self.as_single_const(right)?))
520            }
521            Operation::Conv { .. }
522            | Operation::MatMul { .. }
523            | Operation::Softmax { .. }
524            | Operation::Layernorm { .. }
525            | Operation::Reduce { .. } => None,
526        }
527    }
528
    /// Return all newly created values since the last call to `take_new_values`.
530    pub fn take_new_values(&mut self) -> Vec<Value> {
531        std::mem::take(&mut self.new_values)
532    }
533
534    #[must_use]
535    pub(crate) fn push(&mut self, shape: Shape, dtype: DType, operation: Operation) -> Value {
536        // TODO replace const computations, especially for simple ops like unary and binary?
537
538        let check = self.check;
539        let key = (shape.clone(), dtype, operation.clone());
540
541        match self.back_map.get(&key) {
542            Some(&index) => {
543                // found duplicate, reuse existing value
544                Value { index, check }
545            }
546            None => {
547                // no duplicate found
548                // check validness
549                for input in operation.inputs() {
550                    self.check_contains(input);
551                    self.values[input.index].non_output_uses += 1;
552                }
553
554                // push new value
555                let info = ValueInfo {
556                    shape,
557                    dtype,
558                    operation,
559                    non_output_uses: 0,
560                    debug_id: String::new(),
561                };
562
563                let index = self.values.len();
564                self.values.push(info);
565
566                let value = Value { index, check };
567                self.new_values.push(value);
568
569                self.back_map.insert(key, index);
570
571                value
572            }
573        }
574    }
575
576    /// Equivalent to `self[value].debug_id = id`,
577    /// but that would not work since there is intentionally no implementation of `IndexMut` for `Graph`.
578    pub fn set_debug_id(&mut self, value: Value, id: String) {
579        self.check_contains(value);
580        self.values[value.index].debug_id = id;
581    }
582
583    /// Declare a new input value.
584    #[must_use]
585    pub fn input(&mut self, shape: Shape, dtype: DType) -> Value {
586        let index = self.inputs.len();
587        let value = self.push(shape, dtype, Operation::Input { index });
588        self.inputs.push(value);
589        value
590    }
591
592    #[must_use]
593    pub fn constant_tensor(&mut self, tensor: DTensor) -> Value {
594        let shape = Shape::fixed(tensor.shape());
595        self.push(shape, tensor.dtype(), Operation::Constant { tensor: WrapDebug(tensor) })
596    }
597
598    #[must_use]
599    pub fn constant<T: IntoDScalar>(&mut self, shape: Shape, data: Vec<T>) -> Value {
600        let linear = T::vec_to_dtensor(data);
601        let shape = shape.unwrap_fixed("constant shape");
602        let tensor = linear.reshape(shape.dims.as_slice());
603        self.constant_tensor(tensor)
604    }
605
606    #[must_use]
607    pub fn scalar_dyn(&mut self, value: DScalar) -> Value {
608        self.constant_tensor(value.to_tensor())
609    }
610
611    #[must_use]
612    pub fn scalar<T: IntoDScalar>(&mut self, value: T) -> Value {
613        self.scalar_dyn(value.to_dscalar())
614    }
615
616    /// View an existing value as a new shape.
617    #[must_use]
618    pub fn view(&mut self, input: Value, new_shape: Shape) -> Value {
619        let (input_shape, dtype) = self.shape_dtype(input);
620        if &new_shape == input_shape {
621            return input;
622        }
623
624        assert_eq!(
625            input_shape.size(),
626            new_shape.size(),
627            "New shape {:?} must have the same size as old shape {:?}",
628            new_shape,
629            input_shape,
630        );
631
632        // only keep the last view operation
633        let inner_input = if let &Operation::View { input: inner_input } = &self[input].operation {
634            inner_input
635        } else {
636            input
637        };
638
639        self.push(new_shape, dtype, Operation::View { input: inner_input })
640    }
641
    /// Broadcast the `input` towards `new_shape`.
    /// Additional unit axes are inserted at the front and existing unit axes are repeated as necessary.
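    ///
    /// A minimal shape-only sketch:
    /// ```
    /// # use kn_graph::dtype::DType;
    /// # use kn_graph::graph::Graph;
    /// # use kn_graph::shape;
    /// # use kn_graph::shape::*;
    /// let mut graph = Graph::new();
    /// let x = graph.input(shape![4, 1], DType::F32);
    /// let y = graph.broadcast(x, shape![2, 4, 8]);
    /// assert_eq!(graph[y].shape, shape![2, 4, 8]);
    /// ```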
644    #[must_use]
645    pub fn broadcast(&mut self, input: Value, new_shape: Shape) -> Value {
646        let (input_shape, dtype) = self.shape_dtype(input);
647        let input_shape = input_shape.clone();
648
649        assert!(
650            input_shape.rank() <= new_shape.rank(),
651            "Cannot broadcast to a lower rank shape (from {:?} to {:?})",
652            input_shape,
653            new_shape
654        );
655
656        // pad with 1 axes
657        let view_shape = Shape::ones(new_shape.rank() - input_shape.rank()).concat(&input_shape);
658        let curr = self.view(input, view_shape.clone());
659
        // check that broadcasting is valid
661        for (&v, &n) in zip_eq(&view_shape.dims, &new_shape.dims) {
662            assert!(
663                v == n || v == Size::ONE,
664                "Cannot broadcast from {:?} to {:?} because of axis ({}, {})",
665                input_shape,
666                new_shape,
667                v,
668                n
669            );
670        }
671
672        // don't need to actually broadcast
673        if view_shape == new_shape {
674            return curr;
675        }
676
677        // do the actual broadcast
678        self.push(new_shape, dtype, Operation::Broadcast { input: curr })
679    }
680
681    pub fn repeat_unary(&mut self, input: Value, axis: usize, count: Size) -> Value {
682        let (input_shape, dtype) = self.shape_dtype(input);
683
684        assert_eq!(
685            input_shape[axis],
686            Size::ONE,
687            "Input shape {} does not have dim 1 for axis {}",
688            input_shape,
689            axis
690        );
691
692        // TODO fuse consecutive broadcast operations, maybe even view/broadcast/view if the axes are independent
693        // skip broadcast operation
694        if count == Size::ONE {
695            return input;
696        }
697
698        let new_shape = input_shape.replace(axis, shape![count]);
699        self.push(new_shape, dtype, Operation::Broadcast { input })
700    }
701
    /// View a value with a flattened shape.
    /// All axes starting from `start_axis` (inclusive) are flattened into a single axis.
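    ///
    /// A minimal shape-only sketch:
    /// ```
    /// # use kn_graph::dtype::DType;
    /// # use kn_graph::graph::Graph;
    /// # use kn_graph::shape;
    /// # use kn_graph::shape::*;
    /// let mut graph = Graph::new();
    /// let x = graph.input(shape![Size::BATCH, 4, 8, 8], DType::F32);
    /// let flat = graph.flatten(x, 1);
    /// assert_eq!(graph[flat].shape, shape![Size::BATCH, 256]);
    /// ```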
704    #[must_use]
705    pub fn flatten(&mut self, input: Value, start_axis: usize) -> Value {
706        let old_shape = &self[input].shape;
707        assert!(
708            start_axis <= old_shape.rank(),
709            "Flatten start axis {} out of bounds for {}",
710            start_axis,
711            old_shape,
712        );
713
714        let kept_dims = &old_shape.dims[..start_axis];
715        let flat_size = old_shape.dims[start_axis..].iter().copied().product();
716        let new_shape = Shape::new([kept_dims, &[flat_size]].concat());
717
718        self.view(input, new_shape)
719    }
720
    /// Change the order of axes in the shape.
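    ///
    /// A minimal shape-only sketch:
    /// ```
    /// # use kn_graph::dtype::DType;
    /// # use kn_graph::graph::Graph;
    /// # use kn_graph::shape;
    /// # use kn_graph::shape::*;
    /// let mut graph = Graph::new();
    /// let x = graph.input(shape![2, 3, 4], DType::F32);
    /// let y = graph.permute(x, vec![2, 0, 1]);
    /// assert_eq!(graph[y].shape, shape![4, 2, 3]);
    /// ```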
722    #[must_use]
723    pub fn permute(&mut self, input: Value, permutation: Vec<usize>) -> Value {
724        let input_info = &self[input];
725        let input_shape = &input_info.shape;
726
727        assert_eq!(
728            permutation.len(),
729            input_shape.rank(),
730            "Permutation rank must match input shape, got {:?} and {:?}",
731            permutation,
732            input_shape
733        );
734        assert!(
735            permutation.iter().all_unique(),
736            "Permutation cannot contain repeated axis, got {:?}",
737            permutation
738        );
739        assert!(
740            permutation.iter().all(|&i| i < input_shape.rank()),
741            "Permutation axis out of bounds, got {:?}",
742            permutation
743        );
744
745        // fuse consecutive permute operations
746        let (inner_input, full_permutation) = if let &Operation::Permute {
747            input: inner_input,
748            permutation: ref inner_permutation,
749        } = &self[input].operation
750        {
751            let combined = permutation.iter().map(|&i| inner_permutation[i]).collect();
752            (inner_input, combined)
753        } else {
754            (input, permutation)
755        };
756
757        let inner_input_shape = &self[inner_input].shape;
758        let result_dims = full_permutation.iter().map(|&i| inner_input_shape[i]).collect_vec();
759        let result_shape = Shape::new(result_dims);
760
761        self.push(
762            result_shape,
763            input_info.dtype,
764            Operation::Permute {
765                input: inner_input,
766                permutation: full_permutation,
767            },
768        )
769    }
770
771    /// Slice a value along an axis.
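    ///
    /// A minimal shape-only sketch using [SliceRange::simple]:
    /// ```
    /// # use kn_graph::dtype::DType;
    /// # use kn_graph::graph::*;
    /// # use kn_graph::shape;
    /// # use kn_graph::shape::*;
    /// let mut graph = Graph::new();
    /// let x = graph.input(shape![Size::BATCH, 10], DType::F32);
    /// let y = graph.slice(x, 1, SliceRange::simple(2, 6));
    /// assert_eq!(graph[y].shape, shape![Size::BATCH, 4]);
    /// ```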
772    #[must_use]
773    pub fn slice(&mut self, input: Value, axis: usize, range: SliceRange) -> Value {
774        let input_info = &self[input];
775        let input_shape = &input_info.shape;
776
777        input_shape.assert_has_axis(axis);
778
779        let input_size = input_shape.dims[axis].unwrap_fixed("Slice axis length");
780        range.assert_in_bounds(input_size);
781        let new_size = (range.end - range.start) / range.step;
782
783        // skip trivial slice
784        if range == SliceRange::new(0, input_size, 1) {
785            return input;
786        }
787
788        let new_shape = input_shape.replace(axis, shape![new_size]);
789        self.push(new_shape, input_info.dtype, Operation::Slice { input, axis, range })
790    }
791
    /// Index along a given axis.
    /// Similar to a slice with a size-1 interval, except that the resulting value doesn't keep the extra axis.
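    ///
    /// A minimal shape-only sketch:
    /// ```
    /// # use kn_graph::dtype::DType;
    /// # use kn_graph::graph::Graph;
    /// # use kn_graph::shape;
    /// # use kn_graph::shape::*;
    /// let mut graph = Graph::new();
    /// let x = graph.input(shape![Size::BATCH, 10, 8], DType::F32);
    /// let y = graph.index(x, 1, 3);
    /// assert_eq!(graph[y].shape, shape![Size::BATCH, 8]);
    /// ```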
794    #[must_use]
795    pub fn index(&mut self, input: Value, axis: usize, index: usize) -> Value {
796        let new_shape = self[input].shape.replace(axis, shape![]);
797        let sliced = self.slice(input, axis, SliceRange::single(index));
798        self.view(sliced, new_shape)
799    }
800
801    /// Flip the given `axis`.
802    pub fn flip(&mut self, input: Value, axis: usize) -> Value {
803        let input_info = &self[input];
804        let input_shape = input_info.shape.clone();
805
806        input_shape.assert_has_axis(axis);
807
808        self.push(input_shape, input_info.dtype, Operation::Flip { input, axis })
809    }
810
    /// Repeat `input` along a given `axis`, `count` times.
    ///
    /// The whole tensor is repeated as a block, similar to `torch.repeat` or `numpy.tile`:
    /// repeating `[a, b]` twice yields `[a, b, a, b]`.
    /// See also [repeat_interleave](Self::repeat_interleave).
816    pub fn repeat(&mut self, input: Value, axis: usize, count: Size) -> Value {
817        self.repeat_impl(input, axis, count, false)
818    }
819
    /// Repeat each element of `input` along a given `axis`, `count` times.
    ///
    /// Each element is repeated before moving on to the next one, similar to `torch.repeat_interleave` or `numpy.repeat`:
    /// repeating `[a, b]` twice yields `[a, a, b, b]`.
    /// See also [repeat](Self::repeat).
825    pub fn repeat_interleave(&mut self, input: Value, axis: usize, count: Size) -> Value {
826        self.repeat_impl(input, axis, count, true)
827    }
828
829    fn repeat_impl(&mut self, input: Value, axis: usize, count: Size, inner: bool) -> Value {
830        let input_shape = self[input].shape.clone();
831        input_shape.assert_has_axis(axis);
832
833        // do simpler repeat operation instead
834        // TODO would this not fuse away automatically?
835        if input_shape[axis] == Size::ONE {
836            return self.repeat_unary(input, axis, count);
837        }
838
839        let new_size = input_shape[axis] * count;
840        let dummy_axis = if inner { axis + 1 } else { axis };
841
842        // insert dummy axis, repeat dummy axis, flatten into main axis
843        let extra = self.view(input, input_shape.insert(dummy_axis, Size::ONE));
844        let broad = self.repeat_unary(extra, dummy_axis, count);
845        let result = self.view(broad, input_shape.replace(axis, shape![new_size]));
846
847        result
848    }
849
850    /// Index `input` along the given `axis` with indices given by `indices`.
851    ///
852    /// The `output` shape is the `input` shape with `axis` replaced by the shape of `indices`.
853    #[must_use]
854    pub fn gather(&mut self, input: Value, axis: usize, indices: Value) -> Value {
855        let (input_shape, dtype) = self.shape_dtype(input);
856        let (indices_shape, indices_dtype) = self.shape_dtype(indices);
857
858        input_shape.assert_has_axis(axis);
859        assert!(
860            indices_dtype.is_int(),
861            "Indices must be integers, got {:?}",
862            indices_dtype
863        );
864
865        let result_shape = input_shape.replace(axis, indices_shape.clone());
866        let result_shape_flat = input_shape.replace(axis, shape![indices_shape.size()]);
867
868        // we support arbitrary rank indices here, but the actual operation does not
869        let flat_indices = self.flatten(indices, 0);
870        let flat_size = self[flat_indices].shape.unwrap_1();
871
872        let result_flat = if let Some(index) = self.as_single_const(indices) {
873            // replace gather with simpler slice + repeat operators
874            let index: usize = index.unwrap_int().unwrap().try_into().unwrap();
875
876            let result_flat_single = self.slice(input, axis, SliceRange::single(index));
877            let result_flat = self.repeat(result_flat_single, axis, flat_size);
878
879            assert_eq!(self[result_flat].shape, result_shape_flat);
880            result_flat
881        } else {
882            // do a full gather operation
883            self.push(
884                result_shape_flat,
885                dtype,
886                Operation::Gather {
887                    input,
888                    axis,
889                    indices: flat_indices,
890                },
891            )
892        };
893
894        let result = self.view(result_flat, result_shape);
895        result
896    }
897
898    /// Concatenate `inputs` along `axis`.
899    /// `base_shape` can be provided to allow the result shape to be inferred in case `inputs` is empty.
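    ///
    /// A minimal shape-only sketch:
    /// ```
    /// # use kn_graph::dtype::DType;
    /// # use kn_graph::graph::Graph;
    /// # use kn_graph::shape;
    /// # use kn_graph::shape::*;
    /// let mut graph = Graph::new();
    /// let a = graph.input(shape![Size::BATCH, 4, 8], DType::F32);
    /// let b = graph.input(shape![Size::BATCH, 2, 8], DType::F32);
    /// let c = graph.concat(vec![a, b], 1, None, None);
    /// assert_eq!(graph[c].shape, shape![Size::BATCH, 6, 8]);
    /// ```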
900    #[must_use]
901    pub fn concat(
902        &mut self,
903        inputs: Vec<Value>,
904        axis: usize,
905        base_shape: Option<Shape>,
906        dtype: Option<DType>,
907    ) -> Value {
908        // TODO skip entire operation if the output is empty (generalize this to all operations?)
909
910        let base_shape = base_shape.unwrap_or_else(|| {
911            assert!(
912                !inputs.is_empty(),
913                "Cannot infer concatenation shape without any inputs"
914            );
915            self[inputs[0]].shape.replace(axis, shape![0])
916        });
917        let dtype = dtype.unwrap_or_else(|| {
918            assert!(
919                !inputs.is_empty(),
920                "Cannot infer concatenation dtype without any inputs"
921            );
922            self[inputs[0]].dtype
923        });
924
925        let size_along_axis = inputs
926            .iter()
927            .map(|&v| {
928                assert_eq!(
929                    self[v].shape.replace(axis, shape![0]),
930                    base_shape,
931                    "All concatenated values must match base shape on non-concatenated axes"
932                );
933                assert_eq!(self[v].dtype, dtype, "All concatenated values must have the same dtype");
934                self[v].shape.dims[axis]
935            })
936            .sum::<Option<Size>>()
937            .unwrap_or_else(|| {
938                let input_shapes = inputs.iter().map(|&v| &self[v].shape).collect_vec();
939                panic!("Could not add all concatenation sizes: {:?}", input_shapes);
940            });
941
942        let result_shape = base_shape.replace(axis, shape![size_along_axis]);
943
944        // drop empty inputs
945        let mut inputs = inputs;
946        inputs.retain(|&x| self[x].shape.size() != Size::ZERO);
947
948        // skip operation if there is only a single non-empty input
949        if inputs.len() == 1 {
950            return inputs[0];
951        }
952
953        self.push(result_shape, dtype, Operation::Concat { inputs, axis })
954    }
955
    /// Pad `input` along each axis by the corresponding entry in `pad_amount`, filling with `pad_value`.
    /// Each entry in `pad_amount` is a `(before, after)` pair.
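    ///
    /// A small sketch zero-padding the two spatial axes by one on each side
    /// (fixed shapes only, since the padding is built out of concatenations):
    /// ```
    /// # use kn_graph::dtype::DType;
    /// # use kn_graph::graph::Graph;
    /// # use kn_graph::shape;
    /// # use kn_graph::shape::*;
    /// let mut graph = Graph::new();
    /// let x = graph.input(shape![2, 3, 8, 8], DType::F32);
    /// let zero = graph.scalar::<f32>(0.0);
    /// let y = graph.pad(x, &[(0, 0), (0, 0), (1, 1), (1, 1)], zero);
    /// assert_eq!(graph[y].shape, shape![2, 3, 10, 10]);
    /// ```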
958    pub fn pad(&mut self, input: Value, pad_amount: &[(usize, usize)], pad_value: Value) -> Value {
959        let (input_shape, dtype) = self.shape_dtype(input);
960        let (pad_value_shape, pad_value_dtype) = self.shape_dtype(pad_value);
961
962        assert_eq!(input_shape.rank(), pad_amount.len(), "Padding length must match input rank");
963        assert_eq!(dtype, pad_value_dtype, "Padding value dtype must match input dtype");
964        assert_eq!(pad_value_shape, &Shape::SCALAR, "Padding value must be scalar");
965
966        // implemented using a bunch of concatenations
967        pad_amount.iter().enumerate().fold(input, |curr, (i, &(before, after))| {
968            let curr_shape = self[curr].shape.clone();
969            let before = self.broadcast(pad_value, curr_shape.replace(i, shape![before]));
970            let after = self.broadcast(pad_value, curr_shape.replace(i, shape![after]));
971            self.concat(vec![before, curr, after], i, None, None)
972        })
973    }
974
975    /// Apply 2D convolution.
976    #[must_use]
977    pub fn conv(
978        &mut self,
979        input: Value,
980        filter: Value,
981        stride_y: usize,
982        stride_x: usize,
983        padding_y: usize,
984        padding_x: usize,
985    ) -> Value {
986        let (input_shape, input_dtype) = self.shape_dtype(input);
987        let (filter_shape, filter_dtype) = self.shape_dtype(filter);
988        assert_eq!(
989            input_dtype, filter_dtype,
990            "Convolution input and filter must have the same dtype"
991        );
992        let dtype = input_dtype;
993
994        let [batch_size, in_c, in_h, in_w]: [Size; 4] = input_shape
995            .dims
996            .as_slice()
997            .try_into()
998            .expect("Convolution input must have rank 4");
999        let [out_c, in_c_check, k_h, k_w]: [Size; 4] = filter_shape
1000            .dims
1001            .as_slice()
1002            .try_into()
1003            .expect("Convolution filter must have rank 4");
1004
1005        // almost everything must be fixed, except for the batch size n
1006        let input_channels = in_c.unwrap_fixed("Conv input channels");
1007        let input_h = in_h.unwrap_fixed("Conv input height");
1008        let input_w = in_w.unwrap_fixed("Conv input width");
1009        let output_channels = out_c.unwrap_fixed("Conv output channels");
1010        let in_c_check = in_c_check.unwrap_fixed("Filter input channels");
1011        let kernel_h = k_h.unwrap_fixed("Conv kernel height");
1012        let kernel_w = k_w.unwrap_fixed("Conv kernel width");
1013
1014        assert_eq!(1, kernel_h % 2, "Kernel height must be odd, got {}", kernel_h);
1015        assert_eq!(1, kernel_w % 2, "Kernel width must be odd, got {}", kernel_w);
1016
1017        assert_eq!(input_channels, in_c_check, "Input channel mismatch");
1018
1019        let padded_input_h = input_h + 2 * padding_y;
1020        let padded_input_w = input_w + 2 * padding_x;
1021        assert!(
1022            padded_input_h >= kernel_h && padded_input_w >= kernel_w,
1023            "Kernel must fit inside of padded input"
1024        );
1025
1026        // operations are ordered to avoid underflow
1027        let output_h = (padded_input_h - (kernel_h - 1) - 1) / stride_y + 1;
1028        let output_w = (padded_input_w - (kernel_w - 1) - 1) / stride_x + 1;
1029        let output_shape = shape![batch_size, output_channels, output_h, output_w];
1030
1031        let details = ConvDetails {
1032            dtype,
1033            batch_size,
1034            input_channels,
1035            output_channels,
1036            input_h,
1037            input_w,
1038            kernel_h,
1039            kernel_w,
1040            stride_y,
1041            stride_x,
1042            padding_y,
1043            padding_x,
1044            output_h,
1045            output_w,
1046        };
1047        self.push(output_shape, input_dtype, Operation::Conv { input, details, filter })
1048    }
1049
1050    /// Apply a linear transformation.
1051    /// Input shape `[b, Ci]` and weight shape `[Co, Ci]` result in an output with shape `[b, Co]`.
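    ///
    /// A minimal shape-only sketch:
    /// ```
    /// # use kn_graph::dtype::DType;
    /// # use kn_graph::graph::Graph;
    /// # use kn_graph::shape;
    /// # use kn_graph::shape::*;
    /// let mut graph = Graph::new();
    /// let x = graph.input(shape![Size::BATCH, 16], DType::F32);
    /// let w = graph.constant::<f32>(shape![32, 16], vec![0.0; 32 * 16]);
    /// let y = graph.linear(x, w);
    /// assert_eq!(graph[y].shape, shape![Size::BATCH, 32]);
    /// ```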
1052    #[must_use]
1053    pub fn linear(&mut self, input: Value, weight: Value) -> Value {
1054        let weight_transposed = self.permute(weight, vec![1, 0]);
1055        self.mat_mul(input, weight_transposed)
1056    }
1057
1058    /// General matrix multiply, with broadcasting.
1059    ///
1060    /// * The last two axes should have shapes `[n, p]` and `[p, m]` and will result in an output shape `[n, m]`
1061    /// * The preceding axes are broadcast together and reappear in the output as-is.
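    ///
    /// A minimal shape-only sketch of the broadcasting behavior:
    /// ```
    /// # use kn_graph::dtype::DType;
    /// # use kn_graph::graph::Graph;
    /// # use kn_graph::shape;
    /// # use kn_graph::shape::*;
    /// let mut graph = Graph::new();
    /// let a = graph.input(shape![Size::BATCH, 8, 4, 6], DType::F32);
    /// let b = graph.input(shape![6, 5], DType::F32);
    /// let c = graph.mat_mul(a, b);
    /// assert_eq!(graph[c].shape, shape![Size::BATCH, 8, 4, 5]);
    /// ```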
1062    #[must_use]
1063    pub fn mat_mul(&mut self, left: Value, right: Value) -> Value {
1064        let left_shape = &self[left].shape;
1065        let right_shape = &self[right].shape;
1066
1067        assert!(
1068            left_shape.rank() >= 2 && right_shape.rank() >= 2,
1069            "Matmul operands must have rank >= 2, got shapes {} and {}",
1070            left_shape,
1071            right_shape
1072        );
1073
1074        let (left_head, left_tail) = left_shape.split(left_shape.rank() - 2);
1075        let (right_head, right_tail) = right_shape.split(right_shape.rank() - 2);
1076
1077        // check tails match
1078        let [m, n0] = left_tail.unwrap_2();
1079        let [n1, p] = right_tail.unwrap_2();
1080        assert_eq!(
1081            n0, n1,
1082            "Inner matmul dimension must match, got shapes {} and {}",
1083            left_shape, right_shape
1084        );
1085        let result_tail = shape![m, p];
1086
1087        // broadcast heads
1088        let result_head = broadcast_shape_symmetric(&left_head, &right_head);
1089        let batch_size = result_head.size();
1090        let left_broadcast = self.broadcast(left, result_head.clone().concat(&left_tail));
1091        let right_broadcast = self.broadcast(right, result_head.clone().concat(&right_tail));
1092
1093        // flatten for bmm
1094        let left_flat = self.view(left_broadcast, left_tail.insert(0, batch_size));
1095        let right_flat = self.view(right_broadcast, right_tail.insert(0, batch_size));
1096        let result_flat = self.batched_mat_mul(left_flat, right_flat);
1097
1098        // unflatten into final shape
1099        let result = self.view(result_flat, result_head.concat(&result_tail));
1100        result
1101    }
1102
1103    /// Batched matrix multiply, without any automatic broadcasting.
1104    /// Inputs must have shapes `[b, m, n]`, `[b, n, p]` and the result has shape `[b, m, p]`.
1105    #[must_use]
1106    pub fn batched_mat_mul(&mut self, left: Value, right: Value) -> Value {
1107        let (left_shape, left_dtype) = self.shape_dtype(left);
1108        let (right_shape, right_dtype) = self.shape_dtype(right);
1109        assert_eq!(left_dtype, right_dtype, "Matmul operands must have same dtype");
1110
1111        let [b0, m, n0] = left_shape.unwrap_3();
1112        let [b1, n1, p] = right_shape.unwrap_3();
1113
1114        assert!(
1115            b0 == b1 && n0 == n1,
1116            "Batched matmul dimension mismatch, got shapes {} and {}",
1117            left_shape,
1118            right_shape
1119        );
1120
1121        let result_shape = shape![b0, m, p];
1122        self.push(result_shape, left_dtype, Operation::MatMul { left, right })
1123    }
1124
1125    #[must_use]
1126    pub fn softmax(&mut self, input: Value, axis: usize) -> Value {
1127        let (input_shape, input_dtype) = self.shape_dtype(input);
1128        assert_eq!(input_dtype, DType::F32, "Softmax input must be f32");
1129        input_shape.assert_has_axis(axis);
1130
1131        let new_shape = input_shape.clone();
1132        self.push(new_shape, input_dtype, Operation::Softmax { input, axis })
1133    }
1134
1135    #[must_use]
1136    pub fn layernorm(&mut self, input: Value, axis: usize, eps: f32) -> Value {
1137        let (input_shape, input_dtype) = self.shape_dtype(input);
        assert_eq!(input_dtype, DType::F32, "Layernorm input must be f32");
1139        input_shape.assert_has_axis(axis);
1140
1141        let new_shape = input_shape.clone();
1142        self.push(
1143            new_shape,
1144            input_dtype,
1145            Operation::Layernorm {
1146                input,
1147                axis,
1148                eps: Total::from(eps),
1149            },
1150        )
1151    }
1152
    /// Reduce `input` along the given `axes` using `op`.
    /// The result shape is the same as the input shape but without the reduced axes.
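    ///
    /// A minimal shape-only sketch:
    /// ```
    /// # use kn_graph::dtype::DType;
    /// # use kn_graph::graph::*;
    /// # use kn_graph::shape;
    /// # use kn_graph::shape::*;
    /// let mut graph = Graph::new();
    /// let x = graph.input(shape![Size::BATCH, 4, 8], DType::F32);
    /// let y = graph.reduce(x, vec![1, 2], ReduceOp::Sum);
    /// assert_eq!(graph[y].shape, shape![Size::BATCH]);
    /// ```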
1155    #[must_use]
1156    pub fn reduce(&mut self, input: Value, axes: Vec<usize>, op: ReduceOp) -> Value {
1157        let (input_shape, dtype) = self.shape_dtype(input);
1158
1159        // check shape and dtype
1160        for &axis in &axes {
1161            input_shape.assert_has_axis(axis);
1162        }
1163        match op {
            ReduceOp::Mean => assert_eq!(dtype, DType::F32, "Mean reduction input must be f32"),
1165            ReduceOp::Sum | ReduceOp::Prod | ReduceOp::Max | ReduceOp::Min => {}
1166        }
1167
1168        // skip reduction
1169        if axes.is_empty() {
1170            return input;
1171        }
1172
1173        let new_shape = input_shape.replace_all(&axes, shape![]);
1174        self.push(new_shape, dtype, Operation::Reduce { input, axes, op })
1175    }
1176
1177    /// Elementwise sigmoid.
1178    #[must_use]
1179    pub fn sigmoid(&mut self, input: Value) -> Value {
1180        self.unary(UnaryOp::Sigmoid, input)
1181    }
1182
1183    /// Elementwise relu.
1184    #[must_use]
1185    pub fn relu(&mut self, input: Value) -> Value {
1186        let (_, dtype) = self.shape_dtype(input);
1187        let specials = dtype.specials();
1188        self.clamp_dyn(input, specials.zero, specials.max)
1189    }
1190
1191    /// Elementwise clamp.
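    ///
    /// A small sketch using the generic [clamp](Self::clamp) wrapper, clamping to `[0, 6]`; the output keeps the input shape:
    /// ```
    /// # use kn_graph::dtype::DType;
    /// # use kn_graph::graph::Graph;
    /// # use kn_graph::shape;
    /// # use kn_graph::shape::*;
    /// let mut graph = Graph::new();
    /// let x = graph.input(shape![Size::BATCH, 8], DType::F32);
    /// let y = graph.clamp::<f32>(x, 0.0, 6.0);
    /// assert_eq!(graph[y].shape, graph[x].shape);
    /// ```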
1192    #[must_use]
1193    pub fn clamp_dyn(&mut self, input: Value, min: DScalar, max: DScalar) -> Value {
1194        let (_, dtype) = self.shape_dtype(input);
1195        assert!(
1196            dtype == min.dtype() && dtype == max.dtype(),
1197            "Clamp bounds must match value type, got min={:?} and max={:?} for {:?}",
1198            min,
1199            max,
1200            dtype
1201        );
1202
1203        // careful, min/max are intentionally flipped to yield MAX(MIN(x, max), min)
1204        // these checks are redundant with the checks in binary, but we can skip constant allocation
1205        let mut curr = input;
1206        let specials = dtype.specials();
1207
1208        if max != specials.max {
1209            let max_value = self.scalar_dyn(max);
1210            curr = self.binary(BinaryOp::Min, curr, max_value);
1211        }
1212
1213        if min != specials.min {
1214            let min_value = self.scalar_dyn(min);
1215            curr = self.binary(BinaryOp::Max, curr, min_value);
1216        }
1217
1218        curr
1219    }
1220
1221    #[must_use]
1222    pub fn clamp<T: IntoDScalar>(&mut self, input: Value, min: T, max: T) -> Value {
1223        self.clamp_dyn(input, min.to_dscalar(), max.to_dscalar())
1224    }
1225
1226    #[must_use]
1227    pub fn add(&mut self, left: Value, right: Value) -> Value {
1228        self.binary(BinaryOp::Add, left, right)
1229    }
1230
1231    #[must_use]
1232    pub fn sub(&mut self, left: Value, right: Value) -> Value {
1233        self.binary(BinaryOp::Sub, left, right)
1234    }
1235
1236    #[must_use]
1237    pub fn mul(&mut self, left: Value, right: Value) -> Value {
1238        self.binary(BinaryOp::Mul, left, right)
1239    }
1240
1241    #[must_use]
1242    pub fn pow(&mut self, left: Value, right: Value) -> Value {
1243        self.binary(BinaryOp::Pow, left, right)
1244    }
1245
    /// Elementwise unary operation.
1247    #[must_use]
1248    pub fn unary(&mut self, op: UnaryOp, mut input: Value) -> Value {
1249        let (shape, input_dtype) = self.shape_dtype(input);
1250
1251        let output_dtype = match op.output_dtype(input_dtype) {
1252            Some(d) => d,
1253            None => panic!("Operation {:?} not supported on dtype {:?}", op, input_dtype),
1254        };
1255
1256        // skip cast to same type
1257        if let UnaryOp::ValueCast(_) | UnaryOp::BitCast(_) = op {
1258            if output_dtype == input_dtype {
1259                return input;
1260            }
1261        }
1262
1263        // skip to innermost bitcast value
1264        // TODO skip to innermost value for exact value casts, eg. for successive truncating int casts
1265        //    but be careful, this is tricky stuff!
1266        if let UnaryOp::BitCast(_) = op {
1267            while let &Operation::Unary {
1268                op: UnaryOp::BitCast(_),
1269                input: inner,
1270            } = &self[input].operation
1271            {
1272                input = inner;
1273            }
1274        }
1275
1276        self.push(shape.clone(), output_dtype, Operation::Unary { op, input })
1277    }
1278
    /// Compute an elementwise binary operation.
    /// The input shapes are broadcast together symmetrically, see [broadcast_shape_symmetric].
1281    #[must_use]
1282    pub fn binary(&mut self, op: BinaryOp, left: Value, right: Value) -> Value {
1283        // TODO move constants to the right hand side for binary operations add/mul/min/max
1284        //   also think about other normalizations!
1285        let (left_shape, left_dtype) = self.shape_dtype(left);
1286        let (right_shape, right_dtype) = self.shape_dtype(right);
1287
1288        let result_shape = broadcast_shape_symmetric(left_shape, right_shape);
1289        assert_eq!(
1290            left_dtype, right_dtype,
1291            "Binary operation {:?} requires matching dtypes, got {:?} and {:?}",
1292            op, left_dtype, right_dtype
1293        );
1294        let dtype = left_dtype;
1295
1296        // TODO expand this skipping to be symmetric (and to do const eval if both are known and small?)
1297        let skip = match op {
1298            BinaryOp::Sub | BinaryOp::Add => self.is_const_zero(right),
1299            BinaryOp::Mul | BinaryOp::Div | BinaryOp::Pow => self.is_const_one(right),
1300            BinaryOp::Min => self.is_const_filled_with(right, dtype.specials().max),
1301            BinaryOp::Max => self.is_const_filled_with(right, dtype.specials().min),
1302        };
1303        // TODO only skip after shape checking
1304        //   also check other functions
1305        if skip {
1306            return left;
1307        }
1308
1309        let left = self.broadcast(left, result_shape.clone());
1310        let right = self.broadcast(right, result_shape.clone());
1311
1312        self.push(result_shape, dtype, Operation::Binary { left, right, op })
1313    }
1314
1315    /// Computes the operations described by `graph` on the given inputs.
1316    ///
1317    /// This can be used to cleanly compose multiple graphs together.
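    ///
    /// A small sketch composing one graph into another:
    /// ```
    /// # use kn_graph::dtype::DType;
    /// # use kn_graph::graph::Graph;
    /// # use kn_graph::shape;
    /// # use kn_graph::shape::*;
    /// // inner graph: y = x * x
    /// let mut inner = Graph::new();
    /// let x = inner.input(shape![Size::BATCH, 4], DType::F32);
    /// let y = inner.mul(x, x);
    /// inner.output(y);
    ///
    /// // the outer graph calls the inner graph on its own input
    /// let mut outer = Graph::new();
    /// let input = outer.input(shape![Size::BATCH, 4], DType::F32);
    /// let outputs = outer.call(&inner, &[input]);
    /// assert_eq!(outputs.len(), 1);
    /// ```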
1318    #[must_use]
1319    pub fn call(&mut self, graph: &Graph, inputs: &[Value]) -> Vec<Value> {
1320        // check inputs
1321        assert_eq!(inputs.len(), graph.inputs.len(), "Wrong number of inputs");
1322        for (&input, &graph_input) in zip_eq(inputs, &graph.inputs) {
1323            assert_eq!(self[input].shape, graph[graph_input].shape, "Wrong input shape");
1324        }
1325
1326        let mut map = HashMap::new();
1327
1328        // map operations
1329        for graph_value in graph.values() {
1330            let graph_info = &graph[graph_value];
1331
1332            let shape = graph_info.shape.clone();
1333            let graph_operation = &graph_info.operation;
1334
1335            let value = if let &Operation::Input { index } = graph_operation {
1336                inputs[index]
1337            } else {
1338                let operation = graph_info.operation.clone_map_inputs(|p| *map.get(&p).unwrap());
1339                self.push(shape, graph_info.dtype, operation)
1340            };
1341
1342            map.insert(graph_value, value);
1343        }
1344
1345        // map outputs
1346        graph
1347            .outputs()
1348            .iter()
1349            .map(|graph_value| *map.get(graph_value).unwrap())
1350            .collect_vec()
1351    }
1352
1353    /// Register an existing value as an output
1354    pub fn output(&mut self, value: Value) {
1355        self.outputs.push(value);
1356    }
1357
1358    /// Register multiple values as output at once, in order.
1359    pub fn output_all(&mut self, values: &[Value]) {
1360        for &value in values {
1361            self.output(value)
1362        }
1363    }
1364
1365    // TODO variant that extracts the entire subgraph up to a given set of values, keeping the exact same inputs?
1366    /// Extract a small subgraph consisting of all values that go into `value`, up to a given `depth`.
1367    /// Values that exceed the depth are added as inputs.
1368    pub fn extract_subgraph(&self, value: Value, depth: u32) -> Graph {
1369        fn extract_impl(
1370            graph: &Graph,
1371            sub: &mut Graph,
1372            map: &mut HashMap<Value, Value>,
1373            old: Value,
1374            depth: u32,
1375        ) -> Value {
1376            // luckily we don't have to worry about cycles
1377            if let Some(&new) = map.get(&old) {
1378                return new;
1379            }
1380
1381            let &ValueInfo {
1382                ref shape,
1383                dtype,
1384                operation: ref old_op,
1385                ref debug_id,
1386                non_output_uses: _,
1387            } = &graph[old];
1388
1389            let new = if depth == 0 {
1390                // insert input
1391                sub.input(shape.clone(), dtype)
1392            } else {
1393                // insert operation and map operands
1394                let new_op = old_op.clone_map_inputs(|p| extract_impl(graph, sub, map, p, depth - 1));
1395                sub.push(shape.clone(), dtype, new_op)
1396            };
1397
1398            sub.set_debug_id(new, debug_id.clone());
1399            let prev = map.insert(old, new);
1400            assert_eq!(prev, None);
1401
1402            new
1403        }
1404
1405        let mut sub = Graph::new();
1406        let mut map = HashMap::new();
1407
1408        let new = extract_impl(self, &mut sub, &mut map, value, depth);
1409        sub.output(new);
1410
1411        sub
1412    }
1413
1414    /// Generate a set of dummy inputs that have the right shapes and dtypes and are all fully zero.
1415    /// This can be useful for some quick testing.
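    ///
    /// A minimal sketch:
    /// ```
    /// # use kn_graph::dtype::DType;
    /// # use kn_graph::graph::Graph;
    /// # use kn_graph::shape;
    /// # use kn_graph::shape::*;
    /// let mut graph = Graph::new();
    /// let _x = graph.input(shape![Size::BATCH, 4], DType::F32);
    /// // one zero tensor per input, with the batch size filled in as 8
    /// let dummy = graph.dummy_zero_inputs(8);
    /// assert_eq!(dummy.len(), 1);
    /// ```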
1416    pub fn dummy_zero_inputs(&self, batch_size: usize) -> Vec<DTensor> {
        // TODO add a random version? ofc both can break gather operations, but that's acceptable
1418        self.inputs()
1419            .iter()
1420            .map(|&v| {
1421                let dtype = self[v].dtype;
1422                dispatch_dtype!(dtype, |_T, _fs, ft| ft(Tensor::zeros(
1423                    self[v].shape.eval(batch_size).dims
1424                )))
1425            })
1426            .collect_vec()
1427    }
1428}
1429
/// This corresponds to [_multidirectional broadcasting_ in the ONNX spec](https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md#multidirectional-broadcasting).
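///
/// A minimal sketch of the expected behavior:
/// ```
/// # use kn_graph::graph::broadcast_shape_symmetric;
/// # use kn_graph::shape;
/// # use kn_graph::shape::*;
/// let left = shape![Size::BATCH, 4, 1];
/// let right = shape![8];
/// assert_eq!(broadcast_shape_symmetric(&left, &right), shape![Size::BATCH, 4, 8]);
/// ```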
1431pub fn broadcast_shape_symmetric(left: &Shape, right: &Shape) -> Shape {
1432    let rank = max(left.rank(), right.rank());
1433
1434    // pad with leading 1 axes
1435    let left = Shape::ones(rank - left.rank()).concat(&left);
1436    let right = Shape::ones(rank - right.rank()).concat(&right);
1437
1438    // decide the matching axes for both
1439    let result = zip_eq(&left.dims, &right.dims)
1440        .map(|(&l, &r)| match (l, r) {
1441            (Size::ONE, other) | (other, Size::ONE) => other,
1442            (any, other) if any == other => any,
1443            _ => panic!("Cannot broadcast {} and {} in shapes {} and {}", l, r, left, right),
1444        })
1445        .collect_vec();
1446
1447    Shape::new(result)
1448}
1449
pub fn broadcast_tensors_symmetric<'l, 'r, L, R>(
    left: &'l Tensor<L>,
    right: &'r Tensor<R>,
) -> (ArrayView<'l, L, IxDyn>, ArrayView<'r, R, IxDyn>) {
    let result_shape = broadcast_shape_symmetric(&Shape::fixed(left.shape()), &Shape::fixed(right.shape()));
    let result_shape = result_shape.as_fixed().unwrap().dims;

    let left = left.broadcast(result_shape.clone()).unwrap();
    let right = right.broadcast(result_shape).unwrap();

    (left, right)
}

impl Debug for Graph {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Graph")
            .field("inputs", &self.inputs().iter().map(|&v| &self[v].shape).collect_vec())
            .field("outputs", &self.outputs().iter().map(|&v| &self[v].shape).collect_vec())
            .finish_non_exhaustive()
    }
}

// TODO output nicer table with debug_id near the front
impl Display for Graph {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Graph {
            check,
            values,
            back_map: _,
            new_values: _,
            inputs,
            outputs,
        } = self;

        writeln!(f, "Graph {{")?;
        writeln!(f, "  check: {},", self.check)?;

        let input_shapes = self.inputs().iter().map(|&v| &self[v].shape).collect_vec();
        let output_shapes = self.outputs().iter().map(|&v| &self[v].shape).collect_vec();
        writeln!(f, "  input_shapes: {:?},", input_shapes)?;
        writeln!(f, "  output_shapes: {:?},", output_shapes)?;
        writeln!(f, "  inputs: {:?},", inputs)?;
        writeln!(f, "  outputs: {:?},", outputs)?;

        writeln!(f, "  values: [")?;
        for (i, info) in values.iter().enumerate() {
            writeln!(
                f,
                "    {:?} = {:?},",
                Value {
                    index: i,
                    check: *check,
                },
                info
            )?;
        }
        writeln!(f, "  ],")?;

        writeln!(f, "}}")?;

        Ok(())
    }
}

impl Value {
    pub fn index(self) -> usize {
        self.index
    }
}

impl Debug for Value {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Value { index, check } = self;
        if f.alternate() {
            write!(f, "Value {{ index: {}, check: {} }}", index, check)
        } else {
            write!(f, "Value({})", index)
        }
    }
}

impl Display for SliceRange {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        if self.step == 1 {
            write!(f, "{}:{}", self.start, self.end)
        } else {
            write!(f, "{}:{}:{}", self.start, self.end, self.step)
        }
    }
}

impl From<std::ops::Range<usize>> for SliceRange {
    fn from(range: std::ops::Range<usize>) -> Self {
        let std::ops::Range { start, end } = range;
        SliceRange::simple(start, end)
    }
}

// TODO switch to u64? no reason to stay stuck at u32 randomly
impl SliceRange {
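    /// Build a new range with the given `start`, `end` and `step`.
    ///
    /// Panics if the range is invalid, see [`SliceRange::assert_valid`]. A small sketch:
    /// ```
    /// # use kn_graph::graph::SliceRange;
    /// assert_eq!(SliceRange::new(0, 8, 2).to_string(), "0:8:2");
    /// assert_eq!(SliceRange::simple(0, 4).to_string(), "0:4");
    /// ```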
    pub fn new(start: usize, end: usize, step: usize) -> Self {
        let result = Self { start, end, step };
        result.assert_valid();
        result
    }

    pub fn simple(start: usize, end: usize) -> Self {
        Self::new(start, end, 1)
    }

    pub fn single(index: usize) -> Self {
        Self::new(index, index + 1, 1)
    }

    pub fn empty() -> Self {
        Self::new(0, 0, 1)
    }

    pub fn assert_valid(self) {
        assert!(
            self.end >= self.start,
            "Invalid range {:?}: bounds cannot be decreasing",
            self,
        );

        assert_ne!(self.step, 0, "Invalid range {:?}: step cannot be 0", self);

        assert_eq!(
            (self.end - self.start) % self.step,
            0,
            "Invalid range {:?}: bounds must differ by a multiple of step",
            self
        );
    }

    pub fn assert_in_bounds(self, size: usize) {
        self.assert_valid();

        assert!(
            self.start == self.end || (self.start < size && self.end - (self.step - 1) <= size),
            "{:?} out of bounds for axis of size {}",
            self,
            size
        )
    }
}

impl UnaryOp {
    pub const ALL: &'static [Self] = &[
        UnaryOp::Abs,
        UnaryOp::Neg,
        UnaryOp::Sin,
        UnaryOp::Cos,
        UnaryOp::Exp,
        UnaryOp::Log,
        UnaryOp::Sqrt,
        UnaryOp::Sigmoid,
        UnaryOp::Tanh,
        UnaryOp::Erf,
        UnaryOp::Mish,
        UnaryOp::Softplus,
    ];

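    /// The dtype produced by applying this operation to an input of dtype `x`,
    /// or `None` if the combination is not supported.
    ///
    /// A small sketch:
    /// ```
    /// # use kn_graph::dtype::DType;
    /// # use kn_graph::graph::UnaryOp;
    /// assert_eq!(UnaryOp::Exp.output_dtype(DType::F32), Some(DType::F32));
    /// assert_eq!(UnaryOp::Neg.output_dtype(DType::U32), None);
    /// ```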
    pub fn output_dtype(self, x: DType) -> Option<DType> {
        match self {
            UnaryOp::Abs | UnaryOp::Neg => {
                if x.is_signed() {
                    Some(x)
                } else {
                    None
                }
            }
            UnaryOp::Sin
            | UnaryOp::Cos
            | UnaryOp::Exp
            | UnaryOp::Log
            | UnaryOp::Sqrt
            | UnaryOp::Sigmoid
            | UnaryOp::Tanh
            | UnaryOp::Erf
            | UnaryOp::Mish
            | UnaryOp::Softplus => {
                if x.is_float() {
                    Some(x)
                } else {
                    None
                }
            }
            UnaryOp::ValueCast(y) => Some(y),
            UnaryOp::BitCast(y) => {
                if x.size() == y.size() {
                    Some(y)
                } else {
                    None
                }
            }
        }
    }

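    /// Apply this operation to a single scalar, evaluating it on the CPU.
    ///
    /// A small sketch, assuming `DScalar` implements `PartialEq`:
    /// ```
    /// # use kn_graph::dtype::DScalar;
    /// # use kn_graph::graph::UnaryOp;
    /// assert_eq!(UnaryOp::Neg.map(DScalar::f32(2.0)), DScalar::f32(-2.0));
    /// ```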
    pub fn map(self, x: DScalar) -> DScalar {
        macro_rules! map_float {
            ($x:expr, |$inner:ident| $result:expr) => {{
                use $crate::dtype::{DScalar, T32, T64};
                match $x {
                    DScalar::F32(T32($inner)) => DScalar::f32($result),
                    DScalar::F64(T64($inner)) => DScalar::f64($result),
                    _ => unreachable!("Invalid dtype of {:?} for float operation {:?}", $x, self),
                }
            }};
        }
        let y = match self {
            UnaryOp::Abs => {
                assert!(x.dtype().is_signed(), "Cannot take abs of unsigned scalar");
                match x {
                    DScalar::F32(x) => DScalar::f32(x.abs()),
                    DScalar::F64(x) => DScalar::f64(x.abs()),
                    DScalar::I8(x) => DScalar::I8(x.abs()),
                    DScalar::I16(x) => DScalar::I16(x.abs()),
                    DScalar::I32(x) => DScalar::I32(x.abs()),
                    DScalar::I64(x) => DScalar::I64(x.abs()),
                    DScalar::U8(_) | DScalar::U16(_) | DScalar::U32(_) | DScalar::U64(_) | DScalar::Bool(_) => {
                        unreachable!()
                    }
                }
            }
            UnaryOp::Neg => {
                assert!(x.dtype().is_signed(), "Cannot negate unsigned scalar");
                match x {
                    DScalar::F32(x) => DScalar::f32(-*x),
                    DScalar::F64(x) => DScalar::f64(-*x),
                    DScalar::I8(x) => DScalar::I8(-x),
                    DScalar::I16(x) => DScalar::I16(-x),
                    DScalar::I32(x) => DScalar::I32(-x),
                    DScalar::I64(x) => DScalar::I64(-x),
                    DScalar::U8(_) | DScalar::U16(_) | DScalar::U32(_) | DScalar::U64(_) | DScalar::Bool(_) => {
                        unreachable!()
                    }
                }
            }
            UnaryOp::Sin => map_float!(x, |x| x.sin()),
            UnaryOp::Cos => map_float!(x, |x| x.cos()),
            UnaryOp::Exp => map_float!(x, |x| x.exp()),
            UnaryOp::Log => map_float!(x, |x| x.ln()),
            UnaryOp::Sqrt => map_float!(x, |x| x.sqrt()),
            UnaryOp::Sigmoid => map_float!(x, |x| 1.0 / (1.0 + (-x).exp())),
            UnaryOp::Tanh => map_float!(x, |x| x.tanh()),
            UnaryOp::Erf => map_float!(x, |x| erf(x as f64) as _),
            UnaryOp::Mish => map_float!(x, |x| x * (x.exp().ln_1p().tanh())),
            UnaryOp::Softplus => map_float!(x, |x| (-x.abs()).exp().ln_1p() + x.max(0.0)),
            UnaryOp::ValueCast(to) => x.value_cast(to),
            UnaryOp::BitCast(to) => x.bit_cast(to).unwrap(),
        };

        debug_assert_eq!(self.output_dtype(x.dtype()), Some(y.dtype()));
        y
    }
}

impl BinaryOp {
    pub const ALL: &'static [Self] = &[
        BinaryOp::Add,
        BinaryOp::Sub,
        BinaryOp::Mul,
        BinaryOp::Div,
        BinaryOp::Pow,
        BinaryOp::Min,
        BinaryOp::Max,
    ];

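    /// Apply this operation to a pair of scalars, evaluating it on the CPU.
    ///
    /// A small sketch, assuming `DScalar` implements `PartialEq`:
    /// ```
    /// # use kn_graph::dtype::DScalar;
    /// # use kn_graph::graph::BinaryOp;
    /// assert_eq!(BinaryOp::Add.map(DScalar::f32(1.0), DScalar::f32(2.0)), DScalar::f32(3.0));
    /// assert_eq!(BinaryOp::Max.map_t(1.0f32, 2.0f32), 2.0);
    /// ```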
    pub fn map(self, left: DScalar, right: DScalar) -> DScalar {
        match self {
            BinaryOp::Add => map_dscalar_pair!(left, right, |left, right| left + right),
            BinaryOp::Sub => map_dscalar_pair!(left, right, |left, right| left - right),
            BinaryOp::Mul => map_dscalar_pair!(left, right, |left, right| left * right),
            BinaryOp::Div => map_dscalar_pair!(left, right, |left, right| left / right),
            // TODO support all types (including a mix) for pow?
            BinaryOp::Pow => DScalar::f32(left.unwrap_f32().unwrap().powf(right.unwrap_f32().unwrap())),
            BinaryOp::Min => map_dscalar_pair!(left, right, |left, right| left.min(right)),
            BinaryOp::Max => map_dscalar_pair!(left, right, |left, right| left.max(right)),
        }
    }

    pub fn map_t<T: IntoDScalar>(self, left: T, right: T) -> T {
        T::from_dscalar(self.map(left.to_dscalar(), right.to_dscalar())).unwrap()
    }
}

impl ReduceOp {
    pub const ALL: &'static [Self] = &[
        ReduceOp::Sum,
        ReduceOp::Mean,
        ReduceOp::Prod,
        ReduceOp::Min,
        ReduceOp::Max,
    ];

    pub fn identity(self, dtype: DType) -> DScalar {
        let specials = dtype.specials();
        match self {
            ReduceOp::Sum | ReduceOp::Mean => specials.zero,
            ReduceOp::Prod => specials.one,
            ReduceOp::Min => specials.max,
            ReduceOp::Max => specials.min,
        }
    }

    pub fn identity_t<T: IntoDScalar>(self) -> T {
        T::from_dscalar(self.identity(T::DTYPE)).unwrap()
    }

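    /// The [`BinaryOp`] that this reduction repeatedly applies, together with a flag that is
    /// `true` if the accumulated result should still be divided by the element count
    /// (which is only the case for [`ReduceOp::Mean`]).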
    pub fn operation(self) -> (BinaryOp, bool) {
        match self {
            ReduceOp::Sum => (BinaryOp::Add, false),
            ReduceOp::Mean => (BinaryOp::Add, true),
            ReduceOp::Prod => (BinaryOp::Mul, false),
            ReduceOp::Min => (BinaryOp::Min, false),
            ReduceOp::Max => (BinaryOp::Max, false),
        }
    }

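    /// Reduce an entire sequence of scalars on the CPU.
    ///
    /// A small sketch:
    /// ```
    /// # use kn_graph::graph::ReduceOp;
    /// assert_eq!(ReduceOp::Sum.reduce_t(vec![1.0f32, 2.0, 3.0]), 6.0);
    /// assert_eq!(ReduceOp::Mean.reduce_t(vec![1.0f32, 2.0, 3.0]), 2.0);
    /// assert_eq!(ReduceOp::Max.reduce_t(vec![1.0f32, 2.0, 3.0]), 3.0);
    /// ```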
    pub fn reduce_t<T: IntoDScalar>(self, seq: impl IntoIterator<Item = T>) -> T {
        let (op, is_mean) = self.operation();

        let mut count = 0;
        let total = seq.into_iter().fold(self.identity_t(), |acc, x| {
            count += 1;
            op.map_t(acc, x)
        });

        if is_mean {
            // TODO what to do here for non-float types?
            let total = total.to_dscalar().unwrap_f32().unwrap();
            T::from_dscalar(DScalar::f32(total / count as f32)).unwrap()
        } else {
            total
        }
    }
}

/// Formula and coefficients from <https://en.wikipedia.org/wiki/Error_function#Numerical_approximations>
/// (Abramowitz and Stegun); maximum error around `3e-7`.
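///
/// A small sanity check:
/// ```
/// # use kn_graph::graph::erf;
/// assert_eq!(erf(0.0), 0.0);
/// assert!((erf(1.0) - 0.8427).abs() < 1e-3);
/// assert!((erf(-1.0) + 0.8427).abs() < 1e-3);
/// ```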
pub fn erf(x: f64) -> f64 {
    // TODO find something that's even better for f64?
    let sign = x.signum();
    let x_abs = x.abs();

    const A: &[f64] = &[
        1.0,
        0.0705230784,
        0.0422820123,
        0.0092705272,
        0.0001520143,
        0.0002765672,
        0.0000430638,
    ];

    let d: f64 = A
        .iter()
        .copied()
        .enumerate()
        .map(|(i, a)| a * x_abs.powi(i as i32))
        .sum();
    let y_abs = 1.0 - 1.0 / d.powi(16);

    sign * y_abs
}