ndtensor/
utils.rs

1/*
2    Appellation: utils <module>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use crate::prelude::{TensorExpr, TensorId};
6use crate::TensorBase;
7use nd::{Array, Dimension, IntoDimension, RawData, RawDataClone};
8use num::Float;
9use std::collections::HashMap;
10
/// Computes a `u64` hash of the axis lengths of `dim` using
/// [`DefaultHasher`](std::hash::DefaultHasher).
///
/// The dimension is first converted with `into_dimension`. Each axis length is then fed to
/// the hasher in order, so the result depends on both the lengths and their order.
///
/// NOTE(review): std does not guarantee that `DefaultHasher`'s algorithm stays the same
/// across Rust releases. Avoid persisting the returned value.
#[cfg(feature = "std")]
pub fn hash_dim<D>(dim: impl IntoDimension<Dim = D>) -> u64
where
    D: Dimension,
{
    use std::hash::{DefaultHasher, Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    dim.into_dimension()
        .slice()
        .iter()
        .for_each(|axis| axis.hash(&mut hasher));
    hasher.finish()
}
25
26pub fn linarr<A, D>(dim: impl IntoDimension<Dim = D>) -> Array<A, D>
27where
28    A: Float,
29    D: Dimension,
30{
31    let dim = dim.into_dimension();
32    let dview = dim.as_array_view();
33    let n = dview.product();
34    Array::linspace(A::zero(), A::from(n).unwrap() - A::one(), n)
35        .into_shape(dim)
36        .expect("linspace err")
37}
38
39pub(crate) fn walk<S>(
40    scope: TensorBase<S>,
41    nodes: Vec<TensorBase<S>>,
42    visited: &mut HashMap<TensorId, bool>,
43) -> (bool, Vec<TensorBase<S>>)
44where
45    S: RawData + RawDataClone,
46{
47    if let Some(&tg) = visited.get(&scope.id()) {
48        return (tg, nodes);
49    }
50    // track the gradient of the current node
51    let mut track = false;
52    // recursively call on the children nodes
53    let mut nodes = if scope.is_variable() {
54        // Do not call recursively on the "leaf" nodes.
55        track = true;
56        nodes
57    } else if let Some(op) = scope.op() {
58        match op {
59            TensorExpr::Binary { lhs, rhs, .. } => {
60                let (tg, nodes) = walk(*lhs.clone(), nodes, visited);
61                track |= tg;
62                let (tg, nodes) = walk(*rhs.clone(), nodes, visited);
63                track |= tg;
64                nodes
65            }
66            TensorExpr::Unary { recv, .. } => {
67                let (tg, nodes) = walk(*recv.clone(), nodes, visited);
68                track |= tg;
69                nodes
70            }
71            _ => nodes,
72        }
73    } else {
74        nodes
75    };
76    visited.insert(scope.id(), track);
77    if track {
78        nodes.push(scope);
79    }
80    (track, nodes)
81}