concision_linear/impls/params/
impl_from.rs

1/*
2    Appellation: impl_from <module>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use crate::{Biased, Features, NodeBase, Pair, ParamsBase, Unbiased};
6#[cfg(all(feature = "alloc", no_std))]
7use alloc::vec;
8use core::marker::PhantomData;
9use nd::prelude::*;
10use nd::{Data, DataMut, DataOwned, OwnedRepr, RawData, RemoveAxis};
11#[cfg(feature = "std")]
12use std::vec;
13
14impl<A, S, D, E> IntoIterator for ParamsBase<S, D, Biased>
15where
16    A: Clone,
17    D: Dimension<Smaller = E> + RemoveAxis,
18    S: Data<Elem = A>,
19    E: RemoveAxis,
20{
21    type Item = (Array<A, E>, Array<A, E::Smaller>);
22    type IntoIter = vec::IntoIter<Self::Item>;
23
24    fn into_iter(self) -> Self::IntoIter {
25        let axis = Axis(0);
26        self.weights()
27            .axis_iter(axis)
28            .zip(self.bias().axis_iter(axis))
29            .map(|(w, b)| (w.to_owned(), b.to_owned()))
30            .collect::<Vec<_>>()
31            .into_iter()
32    }
33}
34
35impl<A, S, D, E> IntoIterator for ParamsBase<S, D, Unbiased>
36where
37    A: Clone,
38    D: Dimension<Smaller = E> + RemoveAxis,
39    S: Data<Elem = A>,
40    E: RemoveAxis,
41{
42    type Item = Array<A, E>;
43    type IntoIter = vec::IntoIter<Self::Item>;
44
45    fn into_iter(self) -> Self::IntoIter {
46        self.weights()
47            .axis_iter(Axis(0))
48            .map(|w| w.to_owned())
49            .collect::<Vec<_>>()
50            .into_iter()
51    }
52}
53
54impl<A, S> FromIterator<(Array1<A>, Option<Array0<A>>)> for ParamsBase<S, Ix2>
55where
56    A: Clone + Default,
57    S: DataOwned<Elem = A> + DataMut,
58{
59    fn from_iter<I: IntoIterator<Item = (Array1<A>, Option<Array0<A>>)>>(nodes: I) -> Self {
60        let nodes = nodes.into_iter().collect::<Vec<_>>();
61        let mut iter = nodes.iter();
62        let node = iter.next().unwrap();
63        let shape = Features::new(node.0.len(), nodes.len());
64        let mut params = ParamsBase::new(shape);
65        params.set_node(0, node.clone());
66        for (i, node) in iter.into_iter().enumerate() {
67            params.set_node(i + 1, node.clone());
68        }
69        params
70    }
71}
72
/// Placeholder for generating `From` implementations.
///
/// The public arm takes a comma-separated list of types and forwards each
/// one to the internal `@impl` arm. That arm currently has an empty body, so
/// every invocation expands to nothing.
macro_rules! impl_from {
    // Internal rule: one type in, no items out (not implemented yet).
    (@impl $kind:ty) => {};
    // Entry point: dispatch every listed type to the internal rule.
    ($($kind:ty),*) => {
        $(impl_from!(@impl $kind);)*
    };
}

// Expands to nothing; the type is only parsed, never resolved.
impl_from!(ArrayBase<S, D::Smaller>);
84
85impl<A> From<(Array1<A>, A)> for ParamsBase<OwnedRepr<A>, Ix1, Biased>
86where
87    A: Clone,
88{
89    fn from((weights, bias): (Array1<A>, A)) -> Self {
90        let bias = ArrayBase::from_elem((), bias);
91        Self {
92            bias: Some(bias),
93            weight: weights,
94            _mode: PhantomData,
95        }
96    }
97}
98impl<A, K> From<(Array1<A>, Option<A>)> for ParamsBase<OwnedRepr<A>, Ix1, K>
99where
100    A: Clone,
101{
102    fn from((weights, bias): (Array1<A>, Option<A>)) -> Self {
103        Self {
104            bias: bias.map(|b| ArrayBase::from_elem((), b)),
105            weight: weights,
106            _mode: PhantomData,
107        }
108    }
109}
110
111impl<A, S, D, K> From<NodeBase<S, D, D::Smaller>> for ParamsBase<S, D, K>
112where
113    D: RemoveAxis,
114    S: RawData<Elem = A>,
115{
116    fn from((weights, bias): NodeBase<S, D, D::Smaller>) -> Self {
117        Self {
118            bias,
119            weight: weights,
120            _mode: PhantomData::<K>,
121        }
122    }
123}
124
125impl<A, S, D> From<Pair<ArrayBase<S, D>, ArrayBase<S, D::Smaller>>> for ParamsBase<S, D, Biased>
126where
127    D: RemoveAxis,
128    S: RawData<Elem = A>,
129{
130    fn from((weights, bias): Pair<ArrayBase<S, D>, ArrayBase<S, D::Smaller>>) -> Self {
131        Self {
132            bias: Some(bias),
133            weight: weights,
134            _mode: PhantomData::<Biased>,
135        }
136    }
137}