1use crate::prelude::{TensorExpr, TensorId};
6use crate::TensorBase;
7use nd::{Array, Dimension, IntoDimension, RawData, RawDataClone};
8use num::Float;
9use std::collections::HashMap;
10
/// Computes a 64-bit hash of a shape.
///
/// `dim` is converted with [`IntoDimension::into_dimension`] and every axis
/// length is fed, in order, into a fresh [`std::hash::DefaultHasher`]. Only the
/// individual axis lengths are hashed; the number of axes is not mixed in
/// separately.
///
/// NOTE(review): the standard library does not specify the algorithm behind
/// `DefaultHasher`, so the value should presumably not be persisted or compared
/// across toolchain versions — confirm against callers.
#[cfg(feature = "std")]
pub fn hash_dim<D>(dim: impl IntoDimension<Dim = D>) -> u64
where
    D: Dimension,
{
    use std::hash::{DefaultHasher, Hash, Hasher};

    let shape = dim.into_dimension();
    let mut hasher = DefaultHasher::new();
    // Hash element-by-element (not the slice as a whole) so the digest is
    // built purely from the axis lengths.
    shape
        .slice()
        .iter()
        .for_each(|axis| axis.hash(&mut hasher));
    hasher.finish()
}
25
26pub fn linarr<A, D>(dim: impl IntoDimension<Dim = D>) -> Array<A, D>
27where
28 A: Float,
29 D: Dimension,
30{
31 let dim = dim.into_dimension();
32 let dview = dim.as_array_view();
33 let n = dview.product();
34 Array::linspace(A::zero(), A::from(n).unwrap() - A::one(), n)
35 .into_shape(dim)
36 .expect("linspace err")
37}
38
39pub(crate) fn walk<S>(
40 scope: TensorBase<S>,
41 nodes: Vec<TensorBase<S>>,
42 visited: &mut HashMap<TensorId, bool>,
43) -> (bool, Vec<TensorBase<S>>)
44where
45 S: RawData + RawDataClone,
46{
47 if let Some(&tg) = visited.get(&scope.id()) {
48 return (tg, nodes);
49 }
50 let mut track = false;
52 let mut nodes = if scope.is_variable() {
54 track = true;
56 nodes
57 } else if let Some(op) = scope.op() {
58 match op {
59 TensorExpr::Binary { lhs, rhs, .. } => {
60 let (tg, nodes) = walk(*lhs.clone(), nodes, visited);
61 track |= tg;
62 let (tg, nodes) = walk(*rhs.clone(), nodes, visited);
63 track |= tg;
64 nodes
65 }
66 TensorExpr::Unary { recv, .. } => {
67 let (tg, nodes) = walk(*recv.clone(), nodes, visited);
68 track |= tg;
69 nodes
70 }
71 _ => nodes,
72 }
73 } else {
74 nodes
75 };
76 visited.insert(scope.id(), track);
77 if track {
78 nodes.push(scope);
79 }
80 (track, nodes)
81}