use std::cmp::max;
use std::collections::HashMap;
use std::convert::TryInto;
use std::fmt::{Debug, Display, Formatter};
use std::ops::Index;

use decorum::Total;
use itertools::{Itertools, zip_eq};
use ndarray::{ArrayView, IxDyn};
use rand::random;

use crate::cpu::{OperationError, OperationResult, run_cpu_const_operation};
use crate::dtype::{dispatch_dtensor, dispatch_dtype, DScalar, DTensor, DType, IntoDScalar, map_dscalar_pair, Tensor};
use crate::optimizer::recurse::heap_recurse;
use crate::shape;
use crate::shape::{Shape, Size};
use crate::wrap_debug::WrapDebug;

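/// A computation graph: an append-only DAG of [Value]s, each defined by an
/// [Operation] applied to earlier values. Identical operations are deduplicated
/// through `back_map`, and the random `check` tag catches values that are used
/// with a graph other than the one that created them.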
#[derive(Clone)]
pub struct Graph {
    check: u32,
    values: Vec<ValueInfo>,
    back_map: HashMap<(Shape, DType, Operation), usize>,
    new_values: Vec<Value>,
    inputs: Vec<Value>,
    outputs: Vec<Value>,
}

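/// A handle to a value in a [Graph]. The embedded `check` tag ties it to the
/// graph that created it, so accidentally mixing graphs panics instead of
/// silently aliasing indices.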
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub struct Value {
    index: usize,
    check: u32,
}

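/// Everything the graph knows about a value: its shape, dtype, the operation
/// that defines it, a free-form debug id, and how many operations (not graph
/// outputs) use it.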
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ValueInfo {
    pub shape: Shape,
    pub dtype: DType,
    pub operation: Operation,
    pub debug_id: String,
    non_output_uses: usize,
}

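/// The core set of operations. Higher-level utilities on [Graph]
/// (eg. [Graph::flatten], [Graph::repeat], [Graph::pad]) are built from these.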
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum Operation {
    Input { index: usize },
    Constant { tensor: WrapDebug<DTensor> },

    View { input: Value },
    Broadcast { input: Value },
    Permute { input: Value, permutation: Vec<usize> },
    Slice {
        input: Value,
        axis: usize,
        range: SliceRange,
    },
    Flip { input: Value, axis: usize },

    Gather { input: Value, axis: usize, indices: Value },

    Concat { inputs: Vec<Value>, axis: usize },

    Conv {
        input: Value,
        filter: Value,
        details: ConvDetails,
    },
    MatMul { left: Value, right: Value },

    Unary { input: Value, op: UnaryOp },
    Binary { left: Value, right: Value, op: BinaryOp },

    Softmax { input: Value, axis: usize },
    Layernorm { input: Value, axis: usize, eps: Total<f32> },

    Reduce {
        input: Value,
        axes: Vec<usize>,
        op: ReduceOp,
    },
}

#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct SliceRange {
    pub start: usize,
    pub end: usize,
    pub step: usize,
}

#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum UnaryOp {
    Abs,
    Neg,
    Sin,
    Cos,
    Exp,
    Log,
    Sqrt,
    Sigmoid,
    Tanh,
    Erf,
    Mish,
    Softplus,

    ValueCast(DType),
    BitCast(DType),
}

#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum BinaryOp {
    Add,
    Sub,
    Mul,
    Div,
    Min,
    Max,
    Pow,
}

#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum ReduceOp {
    Sum,
    Mean,
    Prod,
    Max,
    Min,
}

impl Operation {
    pub fn inputs(&self) -> Vec<Value> {
        match self {
            Operation::Input { index: _ } => vec![],
            Operation::Constant { tensor: _ } => vec![],
            &Operation::View { input } => vec![input],
            &Operation::Broadcast { input } => vec![input],
            &Operation::Permute { input, permutation: _ } => vec![input],
            &Operation::Slice {
                input,
                axis: _,
                range: _,
            } => vec![input],
            &Operation::Flip { input, axis: _ } => vec![input],
            &Operation::Gather {
                input,
                axis: _,
                indices,
            } => vec![input, indices],
            Operation::Concat { inputs, axis: _ } => inputs.clone(),
            &Operation::Conv {
                input,
                filter,
                details: _,
            } => vec![input, filter],
            &Operation::MatMul { left, right } => vec![left, right],
            &Operation::Unary { input, op: _ } => vec![input],
            &Operation::Binary { left, right, op: _ } => vec![left, right],
            &Operation::Softmax { input, axis: _ } => vec![input],
            &Operation::Layernorm { input, axis: _, eps: _ } => vec![input],
            &Operation::Reduce { input, axes: _, op: _ } => vec![input],
        }
    }

    pub(crate) fn clone_map_inputs(&self, mut f: impl FnMut(Value) -> Value) -> Operation {
        match self {
            &Operation::Input { index } => Operation::Input { index },
            &Operation::Constant { ref tensor } => Operation::Constant { tensor: tensor.clone() },
            &Operation::View { input } => Operation::View { input: f(input) },
            &Operation::Broadcast { input } => Operation::Broadcast { input: f(input) },
            &Operation::Permute { input, ref permutation } => Operation::Permute {
                input: f(input),
                permutation: permutation.clone(),
            },
            &Operation::Slice { input, axis, range } => Operation::Slice {
                input: f(input),
                axis,
                range,
            },
            &Operation::Flip { input, axis } => Operation::Flip { input: f(input), axis },
            &Operation::Gather { input, axis, indices } => Operation::Gather {
                input: f(input),
                axis,
                indices: f(indices),
            },
            &Operation::Concat { ref inputs, axis } => Operation::Concat {
                inputs: inputs.iter().copied().map(f).collect(),
                axis,
            },
            &Operation::Conv {
                input,
                filter,
                details: conv_shape,
            } => Operation::Conv {
                input: f(input),
                filter: f(filter),
                details: conv_shape,
            },
            &Operation::MatMul { left, right } => Operation::MatMul {
                left: f(left),
                right: f(right),
            },
            &Operation::Unary { input, op } => Operation::Unary { input: f(input), op },
            &Operation::Binary { left, right, op } => Operation::Binary {
                left: f(left),
                right: f(right),
                op,
            },
            &Operation::Softmax { input, axis } => Operation::Softmax { input: f(input), axis },
            &Operation::Layernorm { input, axis, eps } => Operation::Layernorm {
                input: f(input),
                axis,
                eps,
            },
            &Operation::Reduce { input, ref axes, op } => Operation::Reduce {
                input: f(input),
                axes: axes.clone(),
                op,
            },
        }
    }
}

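/// The full parametrization of a 2D convolution,
/// with `h`/`y` the vertical axis and `w`/`x` the horizontal axis.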
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct ConvDetails {
    pub dtype: DType,
    pub batch_size: Size,

    pub input_channels: usize,
    pub output_channels: usize,

    pub input_h: usize,
    pub input_w: usize,
    pub kernel_h: usize,
    pub kernel_w: usize,
    pub stride_y: usize,
    pub stride_x: usize,
    pub padding_y: usize,
    pub padding_x: usize,
    pub output_h: usize,
    pub output_w: usize,
}

impl ConvDetails {
    pub fn input_shape(&self) -> Shape {
        shape![self.batch_size, self.input_channels, self.input_h, self.input_w]
    }

    pub fn output_shape(&self) -> Shape {
        shape![self.batch_size, self.output_channels, self.output_h, self.output_w]
    }

    pub fn keeps_spatial_shape(&self) -> bool {
        (self.input_h == self.output_h) && (self.input_w == self.output_w)
    }

    pub fn has_stride(&self) -> bool {
        self.stride_y != 1 || self.stride_x != 1
    }

    pub fn kernel_shape(&self) -> [usize; 4] {
        [self.output_channels, self.input_channels, self.kernel_h, self.kernel_w]
    }
}

impl Index<Value> for Graph {
    type Output = ValueInfo;

    fn index(&self, value: Value) -> &Self::Output {
        self.check_contains(value);
        &self.values[value.index]
    }
}

impl Graph {
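    /// Create a new, empty graph.
    ///
    /// A minimal usage sketch (illustrative only, not compiled as a doctest;
    /// the shapes and dtype here are arbitrary):
    ///
    /// ```ignore
    /// let mut graph = Graph::new();
    /// let x = graph.input(shape![8, 16], DType::F32);
    /// let w = graph.constant::<f32>(shape![4, 16], vec![0.0; 4 * 16]);
    /// let y = graph.linear(x, w);
    /// let y = graph.relu(y);
    /// graph.output(y);
    /// ```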
    pub fn new() -> Self {
        Graph {
            check: random(),
            values: vec![],
            back_map: HashMap::new(),
            new_values: vec![],
            inputs: vec![],
            outputs: vec![],
        }
    }

    fn check_contains(&self, value: Value) {
        assert_eq!(
            value.check, self.check,
            "Value {:?} does not belong to this graph",
            value
        );
        assert!(value.index < self.values.len());
    }

    pub fn shape_dtype(&self, value: Value) -> (&Shape, DType) {
        let info = &self[value];
        (&info.shape, info.dtype)
    }

    pub fn values(&self) -> impl Iterator<Item = Value> {
        let check = self.check;
        (0..self.values.len()).map(move |index| Value { index, check })
    }

    pub fn inputs(&self) -> &[Value] {
        &self.inputs
    }

    pub fn input_shapes(&self) -> Vec<Shape> {
        self.inputs().iter().map(|&v| self[v].shape.clone()).collect()
    }

    pub fn outputs(&self) -> &[Value] {
        &self.outputs
    }

    pub fn output_shapes(&self) -> Vec<Shape> {
        self.outputs().iter().map(|&v| self[v].shape.clone()).collect()
    }

    pub fn outputs_mut(&mut self) -> &mut Vec<Value> {
        &mut self.outputs
    }

    pub fn is_hidden(&self, value: Value) -> bool {
        self.check_contains(value);
        !self.inputs.contains(&value) && !self.outputs.contains(&value)
    }

    pub fn is_hidden_with_uses(&self, value: Value, users: usize) -> bool {
        self.is_hidden(value) && self[value].non_output_uses == users
    }

    pub fn is_const(&self, value: Value) -> bool {
        let operation = &self[value].operation;
        match *operation {
            Operation::Input { .. } => false,
            Operation::Constant { .. } => true,
            _ => operation.inputs().into_iter().all(|input| self.is_const(input)),
        }
    }

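    /// Try to evaluate `value` as a constant by running its operations on the CPU.
    /// Returns `None` if the value (transitively) depends on an input or if evaluation fails.
    /// Recursion happens on the heap (via `heap_recurse`), so deep graphs don't overflow the stack.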
    pub fn as_const(&self, value: Value) -> Option<DTensor> {
        let mut cache: HashMap<Value, OperationResult> = HashMap::new();

        let f_cached = |curr| {
            let mut missing_arg = None;

            let res = run_cpu_const_operation(&self[curr], |arg| {
                match cache.get(&arg) {
                    Some(Ok(tensor)) => Ok(tensor.clone()),
                    Some(&Err(err)) => Err(err),
                    None => {
                        // Report `arg` as missing so `heap_recurse` visits it first.
                        missing_arg = Some(arg);
                        Err(OperationError::MissingOperand)
                    }
                }
            });

            if let Some(missing_arg) = missing_arg {
                assert_eq!(res, Err(OperationError::MissingOperand));
                return Err(missing_arg);
            }

            let prev = cache.insert(curr, res.clone());
            assert!(prev.is_none());

            Ok(res)
        };

        let res = heap_recurse(value, f_cached);
        res.ok()
    }

    pub fn is_const_filled_with(&self, value: Value, expected: DScalar) -> bool {
        self.as_single_const(value).map_or(false, |actual| expected == actual)
    }

    pub fn is_const_zero(&self, value: Value) -> bool {
        self.is_const_filled_with(value, self[value].dtype.specials().zero)
    }

    pub fn is_const_one(&self, value: Value) -> bool {
        self.is_const_filled_with(value, self[value].dtype.specials().one)
    }

    pub fn as_single_const(&self, value: Value) -> Option<DScalar> {
        let info = &self[value];

        match info.operation {
            Operation::Input { .. } => None,
            Operation::Constant { tensor: WrapDebug(ref tensor) } => dispatch_dtensor!(tensor, |_T, _f, tensor| {
                let &e = tensor.iter().next()?;
                tensor.iter().all(|&d| d == e).then(|| e.to_dscalar())
            }),
            Operation::View { input } => self.as_single_const(input),
            Operation::Broadcast { input } => self.as_single_const(input),
            Operation::Permute { input, permutation: _ } => self.as_single_const(input),
            Operation::Slice {
                input,
                axis: _,
                range: _,
            } => self.as_single_const(input),
            Operation::Flip { input, axis: _ } => self.as_single_const(input),
            Operation::Gather {
                input,
                axis: _,
                indices: _,
            } => self.as_single_const(input),
            Operation::Concat { ref inputs, axis: _ } => {
                let f = self.as_single_const(*inputs.first()?)?;
                inputs.iter().all(|&x| self.is_const_filled_with(x, f)).then(|| f)
            }
            Operation::Unary { input, op } => Some(op.map(self.as_single_const(input)?)),
            Operation::Binary { left, right, op } => {
                Some(op.map(self.as_single_const(left)?, self.as_single_const(right)?))
            }
            Operation::Conv { .. }
            | Operation::MatMul { .. }
            | Operation::Softmax { .. }
            | Operation::Layernorm { .. }
            | Operation::Reduce { .. } => None,
        }
    }

    pub fn take_new_values(&mut self) -> Vec<Value> {
        std::mem::take(&mut self.new_values)
    }

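    /// Push a new value into the graph. If a value with the same
    /// `(shape, dtype, operation)` already exists it is reused instead,
    /// keeping the graph deduplicated.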
    #[must_use]
    pub(crate) fn push(&mut self, shape: Shape, dtype: DType, operation: Operation) -> Value {
        let check = self.check;
        let key = (shape.clone(), dtype, operation.clone());

        match self.back_map.get(&key) {
            Some(&index) => Value { index, check },
            None => {
                for input in operation.inputs() {
                    self.check_contains(input);
                    self.values[input.index].non_output_uses += 1;
                }

                let info = ValueInfo {
                    shape,
                    dtype,
                    operation,
                    non_output_uses: 0,
                    debug_id: String::new(),
                };

                let index = self.values.len();
                self.values.push(info);

                let value = Value { index, check };
                self.new_values.push(value);

                self.back_map.insert(key, index);

                value
            }
        }
    }

    pub fn set_debug_id(&mut self, value: Value, id: String) {
        self.check_contains(value);
        self.values[value.index].debug_id = id;
    }

    #[must_use]
    pub fn input(&mut self, shape: Shape, dtype: DType) -> Value {
        let index = self.inputs.len();
        let value = self.push(shape, dtype, Operation::Input { index });
        self.inputs.push(value);
        value
    }

    #[must_use]
    pub fn constant_tensor(&mut self, tensor: DTensor) -> Value {
        let shape = Shape::fixed(tensor.shape());
        self.push(shape, tensor.dtype(), Operation::Constant { tensor: WrapDebug(tensor) })
    }

    #[must_use]
    pub fn constant<T: IntoDScalar>(&mut self, shape: Shape, data: Vec<T>) -> Value {
        let linear = T::vec_to_dtensor(data);
        let shape = shape.unwrap_fixed("constant shape");
        let tensor = linear.reshape(shape.dims.as_slice());
        self.constant_tensor(tensor)
    }

    #[must_use]
    pub fn scalar_dyn(&mut self, value: DScalar) -> Value {
        self.constant_tensor(value.to_tensor())
    }

    #[must_use]
    pub fn scalar<T: IntoDScalar>(&mut self, value: T) -> Value {
        self.scalar_dyn(value.to_dscalar())
    }

    #[must_use]
    pub fn view(&mut self, input: Value, new_shape: Shape) -> Value {
        let (input_shape, dtype) = self.shape_dtype(input);
        if &new_shape == input_shape {
            return input;
        }

        assert_eq!(
            input_shape.size(),
            new_shape.size(),
            "New shape {:?} must have the same size as old shape {:?}",
            new_shape,
            input_shape,
        );

        // Fuse consecutive views: a view of a view is just a view of the original input.
        let inner_input = if let &Operation::View { input: inner_input } = &self[input].operation {
            inner_input
        } else {
            input
        };

        self.push(new_shape, dtype, Operation::View { input: inner_input })
    }

    #[must_use]
    pub fn broadcast(&mut self, input: Value, new_shape: Shape) -> Value {
        let (input_shape, dtype) = self.shape_dtype(input);
        let input_shape = input_shape.clone();

        assert!(
            input_shape.rank() <= new_shape.rank(),
            "Cannot broadcast to a lower rank shape (from {:?} to {:?})",
            input_shape,
            new_shape
        );

        // First view up to the full rank by prepending size-1 axes, then broadcast.
        let view_shape = Shape::ones(new_shape.rank() - input_shape.rank()).concat(&input_shape);
        let curr = self.view(input, view_shape.clone());

        for (&v, &n) in zip_eq(&view_shape.dims, &new_shape.dims) {
            assert!(
                v == n || v == Size::ONE,
                "Cannot broadcast from {:?} to {:?} because of axis ({}, {})",
                input_shape,
                new_shape,
                v,
                n
            );
        }

        if view_shape == new_shape {
            return curr;
        }

        self.push(new_shape, dtype, Operation::Broadcast { input: curr })
    }

    pub fn repeat_unary(&mut self, input: Value, axis: usize, count: Size) -> Value {
        let (input_shape, dtype) = self.shape_dtype(input);

        assert_eq!(
            input_shape[axis],
            Size::ONE,
            "Input shape {} does not have dim 1 for axis {}",
            input_shape,
            axis
        );

        if count == Size::ONE {
            return input;
        }

        let new_shape = input_shape.replace(axis, shape![count]);
        self.push(new_shape, dtype, Operation::Broadcast { input })
    }

    #[must_use]
    pub fn flatten(&mut self, input: Value, start_axis: usize) -> Value {
        let old_shape = &self[input].shape;
        assert!(
            start_axis <= old_shape.rank(),
            "Flatten start axis {} out of bounds for {}",
            start_axis,
            old_shape,
        );

        let kept_dims = &old_shape.dims[..start_axis];
        let flat_size = old_shape.dims[start_axis..].iter().copied().product();
        let new_shape = Shape::new([kept_dims, &[flat_size]].concat());

        self.view(input, new_shape)
    }

    #[must_use]
    pub fn permute(&mut self, input: Value, permutation: Vec<usize>) -> Value {
        let input_info = &self[input];
        let input_shape = &input_info.shape;

        assert_eq!(
            permutation.len(),
            input_shape.rank(),
            "Permutation rank must match input shape, got {:?} and {:?}",
            permutation,
            input_shape
        );
        assert!(
            permutation.iter().all_unique(),
            "Permutation cannot contain repeated axis, got {:?}",
            permutation
        );
        assert!(
            permutation.iter().all(|&i| i < input_shape.rank()),
            "Permutation axis out of bounds, got {:?}",
            permutation
        );

        // Fuse consecutive permutes into a single combined permutation.
        let (inner_input, full_permutation) = if let &Operation::Permute {
            input: inner_input,
            permutation: ref inner_permutation,
        } = &self[input].operation
        {
            let combined = permutation.iter().map(|&i| inner_permutation[i]).collect();
            (inner_input, combined)
        } else {
            (input, permutation)
        };

        let inner_input_shape = &self[inner_input].shape;
        let result_dims = full_permutation.iter().map(|&i| inner_input_shape[i]).collect_vec();
        let result_shape = Shape::new(result_dims);

        self.push(
            result_shape,
            input_info.dtype,
            Operation::Permute {
                input: inner_input,
                permutation: full_permutation,
            },
        )
    }

    #[must_use]
    pub fn slice(&mut self, input: Value, axis: usize, range: SliceRange) -> Value {
        let input_info = &self[input];
        let input_shape = &input_info.shape;

        input_shape.assert_has_axis(axis);

        let input_size = input_shape.dims[axis].unwrap_fixed("Slice axis length");
        range.assert_in_bounds(input_size);
        let new_size = (range.end - range.start) / range.step;

        // Slicing the whole axis is a no-op.
        if range == SliceRange::new(0, input_size, 1) {
            return input;
        }

        let new_shape = input_shape.replace(axis, shape![new_size]);
        self.push(new_shape, input_info.dtype, Operation::Slice { input, axis, range })
    }

    #[must_use]
    pub fn index(&mut self, input: Value, axis: usize, index: usize) -> Value {
        let new_shape = self[input].shape.replace(axis, shape![]);
        let sliced = self.slice(input, axis, SliceRange::single(index));
        self.view(sliced, new_shape)
    }

    pub fn flip(&mut self, input: Value, axis: usize) -> Value {
        let input_info = &self[input];
        let input_shape = input_info.shape.clone();

        input_shape.assert_has_axis(axis);

        self.push(input_shape, input_info.dtype, Operation::Flip { input, axis })
    }

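    /// Repeat the tensor `count` times along `axis`, tiling it as whole blocks:
    /// along that axis, `[a, b]` becomes `[a, b, a, b]`.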
    pub fn repeat(&mut self, input: Value, axis: usize, count: Size) -> Value {
        self.repeat_impl(input, axis, count, false)
    }

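    /// Repeat each element `count` times along `axis`:
    /// along that axis, `[a, b]` becomes `[a, a, b, b]`.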
    pub fn repeat_interleave(&mut self, input: Value, axis: usize, count: Size) -> Value {
        self.repeat_impl(input, axis, count, true)
    }

    fn repeat_impl(&mut self, input: Value, axis: usize, count: Size, inner: bool) -> Value {
        let input_shape = self[input].shape.clone();
        input_shape.assert_has_axis(axis);

        // A plain broadcast is enough if the axis has size 1.
        if input_shape[axis] == Size::ONE {
            return self.repeat_unary(input, axis, count);
        }

        let new_size = input_shape[axis] * count;
        let dummy_axis = if inner { axis + 1 } else { axis };

        // Insert a dummy size-1 axis, broadcast along it, then flatten it back into `axis`.
        let extra = self.view(input, input_shape.insert(dummy_axis, Size::ONE));
        let broad = self.repeat_unary(extra, dummy_axis, count);
        let result = self.view(broad, input_shape.replace(axis, shape![new_size]));

        result
    }

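    /// Gather slices of `input` along `axis`, selected by the integer tensor `indices`.
    /// The result shape is the input shape with `axis` replaced by the shape of `indices`.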
    #[must_use]
    pub fn gather(&mut self, input: Value, axis: usize, indices: Value) -> Value {
        let (input_shape, dtype) = self.shape_dtype(input);
        let (indices_shape, indices_dtype) = self.shape_dtype(indices);

        input_shape.assert_has_axis(axis);
        assert!(
            indices_dtype.is_int(),
            "Indices must be integers, got {:?}",
            indices_dtype
        );

        let result_shape = input_shape.replace(axis, indices_shape.clone());
        let result_shape_flat = input_shape.replace(axis, shape![indices_shape.size()]);

        // The gather operation itself only works on flat (rank-1) indices.
        let flat_indices = self.flatten(indices, 0);
        let flat_size = self[flat_indices].shape.unwrap_1();

        let result_flat = if let Some(index) = self.as_single_const(indices) {
            // A constant index is lowered to slice + repeat instead.
            let index: usize = index.unwrap_int().unwrap().try_into().unwrap();

            let result_flat_single = self.slice(input, axis, SliceRange::single(index));
            let result_flat = self.repeat(result_flat_single, axis, flat_size);

            assert_eq!(self[result_flat].shape, result_shape_flat);
            result_flat
        } else {
            self.push(
                result_shape_flat,
                dtype,
                Operation::Gather {
                    input,
                    axis,
                    indices: flat_indices,
                },
            )
        };

        let result = self.view(result_flat, result_shape);
        result
    }

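    /// Concatenate `inputs` along `axis`. `base_shape` (the shared shape with the
    /// concatenated axis set to 0) and `dtype` are inferred from the inputs when
    /// `None`, and are only required when `inputs` is empty.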
    #[must_use]
    pub fn concat(
        &mut self,
        inputs: Vec<Value>,
        axis: usize,
        base_shape: Option<Shape>,
        dtype: Option<DType>,
    ) -> Value {
        let base_shape = base_shape.unwrap_or_else(|| {
            assert!(
                !inputs.is_empty(),
                "Cannot infer concatenation shape without any inputs"
            );
            self[inputs[0]].shape.replace(axis, shape![0])
        });
        let dtype = dtype.unwrap_or_else(|| {
            assert!(
                !inputs.is_empty(),
                "Cannot infer concatenation dtype without any inputs"
            );
            self[inputs[0]].dtype
        });

        let size_along_axis = inputs
            .iter()
            .map(|&v| {
                assert_eq!(
                    self[v].shape.replace(axis, shape![0]),
                    base_shape,
                    "All concatenated values must match base shape on non-concatenated axes"
                );
                assert_eq!(self[v].dtype, dtype, "All concatenated values must have the same dtype");
                self[v].shape.dims[axis]
            })
            .sum::<Option<Size>>()
            .unwrap_or_else(|| {
                let input_shapes = inputs.iter().map(|&v| &self[v].shape).collect_vec();
                panic!("Could not add all concatenation sizes: {:?}", input_shapes);
            });

        let result_shape = base_shape.replace(axis, shape![size_along_axis]);

        // Empty inputs don't contribute anything, skip them.
        let mut inputs = inputs;
        inputs.retain(|&x| self[x].shape.size() != Size::ZERO);

        if inputs.len() == 1 {
            return inputs[0];
        }

        self.push(result_shape, dtype, Operation::Concat { inputs, axis })
    }

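    /// Pad `input` with the scalar `pad_value`, where `pad_amount[i] = (before, after)`
    /// for axis `i`. Built entirely from broadcasts of the pad value and concatenations.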
    pub fn pad(&mut self, input: Value, pad_amount: &[(usize, usize)], pad_value: Value) -> Value {
        let (input_shape, dtype) = self.shape_dtype(input);
        let (pad_value_shape, pad_value_dtype) = self.shape_dtype(pad_value);

        assert_eq!(input_shape.rank(), pad_amount.len(), "Padding length must match input rank");
        assert_eq!(dtype, pad_value_dtype, "Padding value dtype must match input dtype");
        assert_eq!(pad_value_shape, &Shape::SCALAR, "Padding value must be scalar");

        pad_amount.iter().enumerate().fold(input, |curr, (i, &(before, after))| {
            let curr_shape = self[curr].shape.clone();
            let before = self.broadcast(pad_value, curr_shape.replace(i, shape![before]));
            let after = self.broadcast(pad_value, curr_shape.replace(i, shape![after]));
            self.concat(vec![before, curr, after], i, None, None)
        })
    }

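    /// 2D convolution, with `input` shaped `[batch, in_channels, h, w]` and
    /// `filter` shaped `[out_channels, in_channels, kernel_h, kernel_w]`.
    /// Kernel sizes must be odd, and the kernel must fit inside the padded input.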
    #[must_use]
    pub fn conv(
        &mut self,
        input: Value,
        filter: Value,
        stride_y: usize,
        stride_x: usize,
        padding_y: usize,
        padding_x: usize,
    ) -> Value {
        let (input_shape, input_dtype) = self.shape_dtype(input);
        let (filter_shape, filter_dtype) = self.shape_dtype(filter);
        assert_eq!(
            input_dtype, filter_dtype,
            "Convolution input and filter must have the same dtype"
        );
        let dtype = input_dtype;

        let [batch_size, in_c, in_h, in_w]: [Size; 4] = input_shape
            .dims
            .as_slice()
            .try_into()
            .expect("Convolution input must have rank 4");
        let [out_c, in_c_check, k_h, k_w]: [Size; 4] = filter_shape
            .dims
            .as_slice()
            .try_into()
            .expect("Convolution filter must have rank 4");

        // Everything except the batch size must be fixed.
        let input_channels = in_c.unwrap_fixed("Conv input channels");
        let input_h = in_h.unwrap_fixed("Conv input height");
        let input_w = in_w.unwrap_fixed("Conv input width");
        let output_channels = out_c.unwrap_fixed("Conv output channels");
        let in_c_check = in_c_check.unwrap_fixed("Filter input channels");
        let kernel_h = k_h.unwrap_fixed("Conv kernel height");
        let kernel_w = k_w.unwrap_fixed("Conv kernel width");

        assert_eq!(1, kernel_h % 2, "Kernel height must be odd, got {}", kernel_h);
        assert_eq!(1, kernel_w % 2, "Kernel width must be odd, got {}", kernel_w);

        assert_eq!(input_channels, in_c_check, "Input channel mismatch");

        let padded_input_h = input_h + 2 * padding_y;
        let padded_input_w = input_w + 2 * padding_x;
        assert!(
            padded_input_h >= kernel_h && padded_input_w >= kernel_w,
            "Kernel must fit inside of padded input"
        );

        let output_h = (padded_input_h - (kernel_h - 1) - 1) / stride_y + 1;
        let output_w = (padded_input_w - (kernel_w - 1) - 1) / stride_x + 1;
        let output_shape = shape![batch_size, output_channels, output_h, output_w];

        let details = ConvDetails {
            dtype,
            batch_size,
            input_channels,
            output_channels,
            input_h,
            input_w,
            kernel_h,
            kernel_w,
            stride_y,
            stride_x,
            padding_y,
            padding_x,
            output_h,
            output_w,
        };
        self.push(output_shape, input_dtype, Operation::Conv { input, details, filter })
    }

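    /// Fully-connected layer: computes `input * weight^T`, with `input` shaped
    /// `[batch, in_features]` and `weight` shaped `[out_features, in_features]`.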
    #[must_use]
    pub fn linear(&mut self, input: Value, weight: Value) -> Value {
        let weight_transposed = self.permute(weight, vec![1, 0]);
        self.mat_mul(input, weight_transposed)
    }

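    /// Broadcasting matrix multiply, following numpy semantics: the last two axes
    /// are multiplied and all leading axes are broadcast together, then flattened
    /// into a single batch axis for [Self::batched_mat_mul].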
    #[must_use]
    pub fn mat_mul(&mut self, left: Value, right: Value) -> Value {
        let left_shape = &self[left].shape;
        let right_shape = &self[right].shape;

        assert!(
            left_shape.rank() >= 2 && right_shape.rank() >= 2,
            "Matmul operands must have rank >= 2, got shapes {} and {}",
            left_shape,
            right_shape
        );

        let (left_head, left_tail) = left_shape.split(left_shape.rank() - 2);
        let (right_head, right_tail) = right_shape.split(right_shape.rank() - 2);

        let [m, n0] = left_tail.unwrap_2();
        let [n1, p] = right_tail.unwrap_2();
        assert_eq!(
            n0, n1,
            "Inner matmul dimension must match, got shapes {} and {}",
            left_shape, right_shape
        );
        let result_tail = shape![m, p];

        // Broadcast the leading axes together, then flatten them into a single batch axis.
        let result_head = broadcast_shape_symmetric(&left_head, &right_head);
        let batch_size = result_head.size();
        let left_broadcast = self.broadcast(left, result_head.clone().concat(&left_tail));
        let right_broadcast = self.broadcast(right, result_head.clone().concat(&right_tail));

        let left_flat = self.view(left_broadcast, left_tail.insert(0, batch_size));
        let right_flat = self.view(right_broadcast, right_tail.insert(0, batch_size));
        let result_flat = self.batched_mat_mul(left_flat, right_flat);

        // Unflatten the batch axis back into the broadcast head.
        let result = self.view(result_flat, result_head.concat(&result_tail));
        result
    }

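    /// Batched matrix multiply: `[b, m, n] * [b, n, p] -> [b, m, p]`, without broadcasting.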
    #[must_use]
    pub fn batched_mat_mul(&mut self, left: Value, right: Value) -> Value {
        let (left_shape, left_dtype) = self.shape_dtype(left);
        let (right_shape, right_dtype) = self.shape_dtype(right);
        assert_eq!(left_dtype, right_dtype, "Matmul operands must have same dtype");

        let [b0, m, n0] = left_shape.unwrap_3();
        let [b1, n1, p] = right_shape.unwrap_3();

        assert!(
            b0 == b1 && n0 == n1,
            "Batched matmul dimension mismatch, got shapes {} and {}",
            left_shape,
            right_shape
        );

        let result_shape = shape![b0, m, p];
        self.push(result_shape, left_dtype, Operation::MatMul { left, right })
    }

    #[must_use]
    pub fn softmax(&mut self, input: Value, axis: usize) -> Value {
        let (input_shape, input_dtype) = self.shape_dtype(input);
        assert_eq!(input_dtype, DType::F32, "Softmax input must be f32");
        input_shape.assert_has_axis(axis);

        let new_shape = input_shape.clone();
        self.push(new_shape, input_dtype, Operation::Softmax { input, axis })
    }

    #[must_use]
    pub fn layernorm(&mut self, input: Value, axis: usize, eps: f32) -> Value {
        let (input_shape, input_dtype) = self.shape_dtype(input);
        assert_eq!(input_dtype, DType::F32, "Layernorm input must be f32");
        input_shape.assert_has_axis(axis);

        let new_shape = input_shape.clone();
        self.push(
            new_shape,
            input_dtype,
            Operation::Layernorm {
                input,
                axis,
                eps: Total::from(eps),
            },
        )
    }

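    /// Reduce `input` along the given `axes` with `op`, removing those axes from
    /// the shape. An empty `axes` list returns the input unchanged.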
    #[must_use]
    pub fn reduce(&mut self, input: Value, axes: Vec<usize>, op: ReduceOp) -> Value {
        let (input_shape, dtype) = self.shape_dtype(input);

        for &axis in &axes {
            input_shape.assert_has_axis(axis);
        }
        match op {
            ReduceOp::Mean => assert_eq!(dtype, DType::F32, "Mean reduction input must be f32"),
            ReduceOp::Sum | ReduceOp::Prod | ReduceOp::Max | ReduceOp::Min => {}
        }

        if axes.is_empty() {
            return input;
        }

        let new_shape = input_shape.replace_all(&axes, shape![]);
        self.push(new_shape, dtype, Operation::Reduce { input, axes, op })
    }

    #[must_use]
    pub fn sigmoid(&mut self, input: Value) -> Value {
        self.unary(UnaryOp::Sigmoid, input)
    }

    #[must_use]
    pub fn relu(&mut self, input: Value) -> Value {
        let (_, dtype) = self.shape_dtype(input);
        let specials = dtype.specials();
        self.clamp_dyn(input, specials.zero, specials.max)
    }

    #[must_use]
    pub fn clamp_dyn(&mut self, input: Value, min: DScalar, max: DScalar) -> Value {
        let (_, dtype) = self.shape_dtype(input);
        assert!(
            dtype == min.dtype() && dtype == max.dtype(),
            "Clamp bounds must match value type, got min={:?} and max={:?} for {:?}",
            min,
            max,
            dtype
        );

        // Only clamp against bounds that are actually restrictive.
        let mut curr = input;
        let specials = dtype.specials();

        if max != specials.max {
            let max_value = self.scalar_dyn(max);
            curr = self.binary(BinaryOp::Min, curr, max_value);
        }

        if min != specials.min {
            let min_value = self.scalar_dyn(min);
            curr = self.binary(BinaryOp::Max, curr, min_value);
        }

        curr
    }

    #[must_use]
    pub fn clamp<T: IntoDScalar>(&mut self, input: Value, min: T, max: T) -> Value {
        self.clamp_dyn(input, min.to_dscalar(), max.to_dscalar())
    }

    #[must_use]
    pub fn add(&mut self, left: Value, right: Value) -> Value {
        self.binary(BinaryOp::Add, left, right)
    }

    #[must_use]
    pub fn sub(&mut self, left: Value, right: Value) -> Value {
        self.binary(BinaryOp::Sub, left, right)
    }

    #[must_use]
    pub fn mul(&mut self, left: Value, right: Value) -> Value {
        self.binary(BinaryOp::Mul, left, right)
    }

    #[must_use]
    pub fn pow(&mut self, left: Value, right: Value) -> Value {
        self.binary(BinaryOp::Pow, left, right)
    }

    #[must_use]
    pub fn unary(&mut self, op: UnaryOp, mut input: Value) -> Value {
        let (shape, input_dtype) = self.shape_dtype(input);

        let output_dtype = match op.output_dtype(input_dtype) {
            Some(d) => d,
            None => panic!("Operation {:?} not supported on dtype {:?}", op, input_dtype),
        };

        // Casts to the same dtype are no-ops.
        if let UnaryOp::ValueCast(_) | UnaryOp::BitCast(_) = op {
            if output_dtype == input_dtype {
                return input;
            }
        }

        // Only the innermost input of a chain of bitcasts matters, skip the others.
        if let UnaryOp::BitCast(_) = op {
            while let &Operation::Unary {
                op: UnaryOp::BitCast(_),
                input: inner,
            } = &self[input].operation
            {
                input = inner;
            }
        }

        self.push(shape.clone(), output_dtype, Operation::Unary { op, input })
    }

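    /// Element-wise binary operation with symmetric broadcasting.
    /// Both operands must have the same dtype.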
    #[must_use]
    pub fn binary(&mut self, op: BinaryOp, left: Value, right: Value) -> Value {
        let (left_shape, left_dtype) = self.shape_dtype(left);
        let (right_shape, right_dtype) = self.shape_dtype(right);

        let result_shape = broadcast_shape_symmetric(left_shape, right_shape);
        assert_eq!(
            left_dtype, right_dtype,
            "Binary operation {:?} requires matching dtypes, got {:?} and {:?}",
            op, left_dtype, right_dtype
        );
        let dtype = left_dtype;

        // Skip operations with an identity right operand (add 0, mul 1, ...).
        let skip = match op {
            BinaryOp::Sub | BinaryOp::Add => self.is_const_zero(right),
            BinaryOp::Mul | BinaryOp::Div | BinaryOp::Pow => self.is_const_one(right),
            BinaryOp::Min => self.is_const_filled_with(right, dtype.specials().max),
            BinaryOp::Max => self.is_const_filled_with(right, dtype.specials().min),
        };
        if skip {
            return left;
        }

        let left = self.broadcast(left, result_shape.clone());
        let right = self.broadcast(right, result_shape.clone());

        self.push(result_shape, dtype, Operation::Binary { left, right, op })
    }

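    /// Inline the given `graph` into this one, wiring its inputs to `inputs`,
    /// and return the values corresponding to its outputs.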
    #[must_use]
    pub fn call(&mut self, graph: &Graph, inputs: &[Value]) -> Vec<Value> {
        assert_eq!(inputs.len(), graph.inputs.len(), "Wrong number of inputs");
        for (&input, &graph_input) in zip_eq(inputs, &graph.inputs) {
            assert_eq!(self[input].shape, graph[graph_input].shape, "Wrong input shape");
        }

        let mut map = HashMap::new();

        // Values are stored in topological order, so a single forward pass
        // visits every operand before its users.
        for graph_value in graph.values() {
            let graph_info = &graph[graph_value];

            let shape = graph_info.shape.clone();
            let graph_operation = &graph_info.operation;

            let value = if let &Operation::Input { index } = graph_operation {
                inputs[index]
            } else {
                let operation = graph_info.operation.clone_map_inputs(|p| *map.get(&p).unwrap());
                self.push(shape, graph_info.dtype, operation)
            };

            map.insert(graph_value, value);
        }

        graph
            .outputs()
            .iter()
            .map(|graph_value| *map.get(graph_value).unwrap())
            .collect_vec()
    }

    pub fn output(&mut self, value: Value) {
        self.outputs.push(value);
    }

    pub fn output_all(&mut self, values: &[Value]) {
        for &value in values {
            self.output(value)
        }
    }

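    /// Extract the subgraph that computes `value`, following operands at most
    /// `depth` levels deep; anything deeper is replaced by a fresh input.
    /// The returned graph has `value` as its single output.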
    pub fn extract_subgraph(&self, value: Value, depth: u32) -> Graph {
        fn extract_impl(
            graph: &Graph,
            sub: &mut Graph,
            map: &mut HashMap<Value, Value>,
            old: Value,
            depth: u32,
        ) -> Value {
            if let Some(&new) = map.get(&old) {
                return new;
            }

            let &ValueInfo {
                ref shape,
                dtype,
                operation: ref old_op,
                ref debug_id,
                non_output_uses: _,
            } = &graph[old];

            let new = if depth == 0 {
                // Cut the graph off here, replacing the value with a fresh input.
                sub.input(shape.clone(), dtype)
            } else {
                let new_op = old_op.clone_map_inputs(|p| extract_impl(graph, sub, map, p, depth - 1));
                sub.push(shape.clone(), dtype, new_op)
            };

            sub.set_debug_id(new, debug_id.clone());
            let prev = map.insert(old, new);
            assert_eq!(prev, None);

            new
        }

        let mut sub = Graph::new();
        let mut map = HashMap::new();

        let new = extract_impl(self, &mut sub, &mut map, value, depth);
        sub.output(new);

        sub
    }

    pub fn dummy_zero_inputs(&self, batch_size: usize) -> Vec<DTensor> {
        self.inputs()
            .iter()
            .map(|&v| {
                let dtype = self[v].dtype;
                dispatch_dtype!(dtype, |_T, _fs, ft| ft(Tensor::zeros(
                    self[v].shape.eval(batch_size).dims
                )))
            })
            .collect_vec()
    }
}

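/// Compute the shape both `left` and `right` broadcast to, following numpy rules:
/// the shapes are right-aligned by prepending size-1 axes, and every axis pair
/// must either match or contain a 1.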
pub fn broadcast_shape_symmetric(left: &Shape, right: &Shape) -> Shape {
    let rank = max(left.rank(), right.rank());

    // Prepend size-1 axes until both shapes have the same rank.
    let left = Shape::ones(rank - left.rank()).concat(left);
    let right = Shape::ones(rank - right.rank()).concat(right);

    let result = zip_eq(&left.dims, &right.dims)
        .map(|(&l, &r)| match (l, r) {
            (Size::ONE, other) | (other, Size::ONE) => other,
            (any, other) if any == other => any,
            _ => panic!("Cannot broadcast {} and {} in shapes {} and {}", l, r, left, right),
        })
        .collect_vec();

    Shape::new(result)
}

pub fn broadcast_tensors_symmetric<'l, 'r, L, R>(
    left: &'l Tensor<L>,
    right: &'r Tensor<R>,
) -> (ArrayView<'l, L, IxDyn>, ArrayView<'r, R, IxDyn>) {
    let result_shape = broadcast_shape_symmetric(&Shape::fixed(left.shape()), &Shape::fixed(right.shape()));
    let result_shape = result_shape.as_fixed().unwrap().dims;

    let left = left.broadcast(result_shape.clone()).unwrap();
    let right = right.broadcast(result_shape).unwrap();

    (left, right)
}

impl Debug for Graph {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Graph")
            .field("inputs", &self.inputs().iter().map(|&v| &self[v].shape).collect_vec())
            .field("outputs", &self.outputs().iter().map(|&v| &self[v].shape).collect_vec())
            .finish_non_exhaustive()
    }
}

impl Display for Graph {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Graph {
            check,
            values,
            back_map: _,
            new_values: _,
            inputs,
            outputs,
        } = self;

        writeln!(f, "Graph {{")?;
        writeln!(f, "  check: {},", self.check)?;

        let input_shapes = self.inputs().iter().map(|&v| &self[v].shape).collect_vec();
        let output_shapes = self.outputs().iter().map(|&v| &self[v].shape).collect_vec();
        writeln!(f, "  input_shapes: {:?},", input_shapes)?;
        writeln!(f, "  output_shapes: {:?},", output_shapes)?;
        writeln!(f, "  inputs: {:?},", inputs)?;
        writeln!(f, "  outputs: {:?},", outputs)?;

        writeln!(f, "  values: [")?;
        for (i, info) in values.iter().enumerate() {
            writeln!(
                f,
                "    {:?} = {:?},",
                Value {
                    index: i,
                    check: *check,
                },
                info
            )?;
        }
        writeln!(f, "  ],")?;

        writeln!(f, "}}")?;

        Ok(())
    }
}

impl Value {
    pub fn index(self) -> usize {
        self.index
    }
}

impl Debug for Value {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Value { index, check } = self;
        if f.alternate() {
            write!(f, "Value {{ index: {}, check: {} }}", index, check)
        } else {
            write!(f, "Value({})", index)
        }
    }
}

impl Display for SliceRange {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        if self.step == 1 {
            write!(f, "{}:{}", self.start, self.end)
        } else {
            write!(f, "{}:{}:{}", self.start, self.end, self.step)
        }
    }
}

impl From<std::ops::Range<usize>> for SliceRange {
    fn from(range: std::ops::Range<usize>) -> Self {
        let std::ops::Range { start, end } = range;
        SliceRange::simple(start, end)
    }
}

impl SliceRange {
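    /// Build the range `start..end` with the given `step`.
    /// `step` must be nonzero and `end - start` a multiple of `step`;
    /// for example, `SliceRange::new(0, 8, 2)` selects indices 0, 2, 4 and 6.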
    pub fn new(start: usize, end: usize, step: usize) -> Self {
        let result = Self { start, end, step };
        result.assert_valid();
        result
    }

    pub fn simple(start: usize, end: usize) -> Self {
        Self::new(start, end, 1)
    }

    pub fn single(index: usize) -> Self {
        Self::new(index, index + 1, 1)
    }

    pub fn empty() -> Self {
        Self::new(0, 0, 1)
    }

    pub fn assert_valid(self) {
        assert!(
            self.end >= self.start,
            "Invalid range {:?}: bounds cannot be decreasing",
            self,
        );

        assert_ne!(self.step, 0, "Invalid range {:?}: step cannot be 0", self);

        assert_eq!(
            (self.end - self.start) % self.step,
            0,
            "Invalid range {:?}: bounds must differ by a multiple of step",
            self
        );
    }

    pub fn assert_in_bounds(self, size: usize) {
        self.assert_valid();

        assert!(
            self.start == self.end || (self.start < size && self.end - (self.step - 1) <= size),
            "{:?} out of bounds for axis of size {}",
            self,
            size
        )
    }
}

impl UnaryOp {
    pub const ALL: &'static [Self] = &[
        UnaryOp::Abs,
        UnaryOp::Neg,
        UnaryOp::Sin,
        UnaryOp::Cos,
        UnaryOp::Exp,
        UnaryOp::Log,
        UnaryOp::Sqrt,
        UnaryOp::Sigmoid,
        UnaryOp::Tanh,
        UnaryOp::Erf,
        UnaryOp::Mish,
        UnaryOp::Softplus,
    ];

    pub fn output_dtype(self, x: DType) -> Option<DType> {
        match self {
            UnaryOp::Abs | UnaryOp::Neg => {
                if x.is_signed() {
                    Some(x)
                } else {
                    None
                }
            }
            UnaryOp::Sin
            | UnaryOp::Cos
            | UnaryOp::Exp
            | UnaryOp::Log
            | UnaryOp::Sqrt
            | UnaryOp::Sigmoid
            | UnaryOp::Tanh
            | UnaryOp::Erf
            | UnaryOp::Mish
            | UnaryOp::Softplus => {
                if x.is_float() {
                    Some(x)
                } else {
                    None
                }
            }
            UnaryOp::ValueCast(y) => Some(y),
            UnaryOp::BitCast(y) => {
                // A bitcast is only valid between dtypes of the same size.
                if x.size() == y.size() {
                    Some(y)
                } else {
                    None
                }
            }
        }
    }

    pub fn map(self, x: DScalar) -> DScalar {
        macro_rules! map_float {
            ($x:expr, |$inner:ident| $result:expr) => {{
                use $crate::dtype::{DScalar, T32, T64};
                match $x {
                    DScalar::F32(T32($inner)) => DScalar::f32($result),
                    DScalar::F64(T64($inner)) => DScalar::f64($result),
                    _ => unreachable!("Invalid dtype of {:?} for float operation {:?}", $x, self),
                }
            }};
        }
        let y = match self {
            UnaryOp::Abs => {
                assert!(x.dtype().is_signed(), "Cannot take abs of unsigned scalar");
                match x {
                    DScalar::F32(x) => DScalar::f32(x.abs()),
                    DScalar::F64(x) => DScalar::f64(x.abs()),
                    DScalar::I8(x) => DScalar::I8(x.abs()),
                    DScalar::I16(x) => DScalar::I16(x.abs()),
                    DScalar::I32(x) => DScalar::I32(x.abs()),
                    DScalar::I64(x) => DScalar::I64(x.abs()),
                    DScalar::U8(_) | DScalar::U16(_) | DScalar::U32(_) | DScalar::U64(_) | DScalar::Bool(_) => {
                        unreachable!()
                    }
                }
            }
            UnaryOp::Neg => {
                assert!(x.dtype().is_signed(), "Cannot negate unsigned scalar");
                match x {
                    DScalar::F32(x) => DScalar::f32(-*x),
                    DScalar::F64(x) => DScalar::f64(-*x),
                    DScalar::I8(x) => DScalar::I8(-x),
                    DScalar::I16(x) => DScalar::I16(-x),
                    DScalar::I32(x) => DScalar::I32(-x),
                    DScalar::I64(x) => DScalar::I64(-x),
                    DScalar::U8(_) | DScalar::U16(_) | DScalar::U32(_) | DScalar::U64(_) | DScalar::Bool(_) => {
                        unreachable!()
                    }
                }
            }
            UnaryOp::Sin => map_float!(x, |x| x.sin()),
            UnaryOp::Cos => map_float!(x, |x| x.cos()),
            UnaryOp::Exp => map_float!(x, |x| x.exp()),
            UnaryOp::Log => map_float!(x, |x| x.ln()),
            UnaryOp::Sqrt => map_float!(x, |x| x.sqrt()),
            UnaryOp::Sigmoid => map_float!(x, |x| 1.0 / (1.0 + (-x).exp())),
            UnaryOp::Tanh => map_float!(x, |x| x.tanh()),
            UnaryOp::Erf => map_float!(x, |x| erf(x as f64) as _),
            UnaryOp::Mish => map_float!(x, |x| x * (x.exp().ln_1p().tanh())),
            UnaryOp::Softplus => map_float!(x, |x| (-x.abs()).exp().ln_1p() + x.max(0.0)),
            UnaryOp::ValueCast(to) => x.value_cast(to),
            UnaryOp::BitCast(to) => x.bit_cast(to).unwrap(),
        };

        debug_assert_eq!(self.output_dtype(x.dtype()), Some(y.dtype()));
        y
    }
}

impl BinaryOp {
    pub const ALL: &'static [Self] = &[
        BinaryOp::Add,
        BinaryOp::Sub,
        BinaryOp::Mul,
        BinaryOp::Div,
        BinaryOp::Pow,
        BinaryOp::Min,
        BinaryOp::Max,
    ];

    pub fn map(self, left: DScalar, right: DScalar) -> DScalar {
        match self {
            BinaryOp::Add => map_dscalar_pair!(left, right, |left, right| left + right),
            BinaryOp::Sub => map_dscalar_pair!(left, right, |left, right| left - right),
            BinaryOp::Mul => map_dscalar_pair!(left, right, |left, right| left * right),
            BinaryOp::Div => map_dscalar_pair!(left, right, |left, right| left / right),
            BinaryOp::Pow => DScalar::f32(left.unwrap_f32().unwrap().powf(right.unwrap_f32().unwrap())),
            BinaryOp::Min => map_dscalar_pair!(left, right, |left, right| left.min(right)),
            BinaryOp::Max => map_dscalar_pair!(left, right, |left, right| left.max(right)),
        }
    }

    pub fn map_t<T: IntoDScalar>(self, left: T, right: T) -> T {
        T::from_dscalar(self.map(left.to_dscalar(), right.to_dscalar())).unwrap()
    }
}

impl ReduceOp {
    pub const ALL: &'static [Self] = &[
        ReduceOp::Sum,
        ReduceOp::Mean,
        ReduceOp::Prod,
        ReduceOp::Min,
        ReduceOp::Max,
    ];

    pub fn identity(self, dtype: DType) -> DScalar {
        let specials = dtype.specials();
        match self {
            ReduceOp::Sum | ReduceOp::Mean => specials.zero,
            ReduceOp::Prod => specials.one,
            ReduceOp::Min => specials.max,
            ReduceOp::Max => specials.min,
        }
    }

    pub fn identity_t<T: IntoDScalar>(self) -> T {
        T::from_dscalar(self.identity(T::DTYPE)).unwrap()
    }

    /// The binary operation this reduction repeatedly applies, and whether the
    /// result should be divided by the element count afterwards.
    pub fn operation(self) -> (BinaryOp, bool) {
        match self {
            ReduceOp::Sum => (BinaryOp::Add, false),
            ReduceOp::Mean => (BinaryOp::Add, true),
            ReduceOp::Prod => (BinaryOp::Mul, false),
            ReduceOp::Min => (BinaryOp::Min, false),
            ReduceOp::Max => (BinaryOp::Max, false),
        }
    }

    pub fn reduce_t<T: IntoDScalar>(self, seq: impl IntoIterator<Item = T>) -> T {
        let (op, is_mean) = self.operation();

        let mut count = 0;
        let total = seq.into_iter().fold(self.identity_t(), |acc, x| {
            count += 1;
            op.map_t(acc, x)
        });

        if is_mean {
            // Mean is only supported for f32.
            let total = total.to_dscalar().unwrap_f32().unwrap();
            T::from_dscalar(DScalar::f32(total / count as f32)).unwrap()
        } else {
            total
        }
    }
}

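/// Approximation of the error function `erf`, computed as
/// `sign(x) * (1 - 1/(1 + a1*|x| + ... + a6*|x|^6)^16)`. The coefficients look
/// like those of the Abramowitz & Stegun rational approximation 7.1.28
/// (absolute error on the order of `3e-7`), though that attribution is an
/// assumption rather than something stated in this file.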
pub fn erf(x: f64) -> f64 {
    // erf is odd, so it's enough to approximate it for non-negative x.
    let sign = x.signum();
    let x_abs = x.abs();

    const A: &[f64] = &[
        1.0,
        0.0705230784,
        0.0422820123,
        0.0092705272,
        0.0001520143,
        0.0002765672,
        0.0000430638,
    ];

    // d = 1 + a1*x + a2*x^2 + ... + a6*x^6, then erf(x) ~= 1 - 1/d^16.
    let d: f64 = A
        .iter()
        .copied()
        .enumerate()
        .map(|(i, a)| a * x_abs.powi(i as i32))
        .sum();
    let y_abs = 1.0 - 1.0 / d.powi(16);

    sign * y_abs
}