block_graph/burn/
shape.rs

1impl AsRef<Self> for Shape{//TODO more reref stuff
2	fn as_ref(&self)->&Self{self}
3}
4impl Default for Reshape{
5	fn default()->Self{Self::Recursive(Vec::new())}
6}
7impl Default for Shape{
8	fn default()->Self{Self::Recursive(Vec::new())}
9}
10impl From<[usize;1]> for Reshape{
11	fn from(dims:[usize;1])->Self{R1(dims.map(|d|d as isize))}
12}
13impl From<[usize;2]> for Reshape{
14	fn from(dims:[usize;2])->Self{R2(dims.map(|d|d as isize))}
15}
16impl From<[usize;3]> for Reshape{
17	fn from(dims:[usize;3])->Self{R3(dims.map(|d|d as isize))}
18}
19impl From<[usize;4]> for Reshape{
20	fn from(dims:[usize;4])->Self{R4(dims.map(|d|d as isize))}
21}
22impl From<[usize;5]> for Reshape{
23	fn from(dims:[usize;5])->Self{R5(dims.map(|d|d as isize))}
24}
25impl From<[usize;6]> for Reshape{
26	fn from(dims:[usize;6])->Self{R6(dims.map(|d|d as isize))}
27}
28impl From<[usize;7]> for Reshape{
29	fn from(dims:[usize;7])->Self{R7(dims.map(|d|d as isize))}
30}
31impl From<[usize;8]> for Reshape{
32	fn from(dims:[usize;8])->Self{R8(dims.map(|d|d as isize))}
33}
34impl Reshape{
35	/// counts the recursive depth
36	pub fn depth(&self)->usize{
37		match self{
38			R1(_)=>1,
39			R2(_)=>1,
40			R3(_)=>1,
41			R4(_)=>1,
42			R5(_)=>1,
43			R6(_)=>1,
44			R7(_)=>1,
45			R8(_)=>1,
46			Reshape::Recursive(v)=>v.iter().map(Reshape::depth).max().unwrap_or(0)
47		}
48	}
49	/// converts to the eight dimensional array type by extending with ones. The original data will be placed according to the alignment. Multi and incompatible types will be all ones
50	pub fn to_array(self,alignment:Alignment)->[isize;8]{
51		let mut result=[1;8];
52		let slice=match &self{R1(x)=>x.as_slice(),R2(x)=>x.as_slice(),R3(x)=>x.as_slice(),R4(x)=>x.as_slice(),R5(x)=>x.as_slice(),R6(x)=>x.as_slice(),R7(x)=>x.as_slice(),R8(x)=>x.as_slice(),Reshape::Recursive(_r)=>return result};
53		let l=slice.len();
54		match alignment{Alignment::Center=>result[4-l/2..][..l].copy_from_slice(slice),Alignment::Left=>result[..l].copy_from_slice(slice),Alignment::Right=>result[8-l..].copy_from_slice(slice)}
55		result
56	}
57}
58impl Shape{
59	/// counts the number of components if possible. returns none if incompatible or if a non recursive multi shape of more than 0 tensors
60	pub fn count(&self)->Option<usize>{
61		match self{
62			Shape::Incompatible(_e)=>None,
63			Shape::Multi(n)=>if *n==0{Some(0)}else{None},
64			Shape::Recursive(v)=>{
65				let mut s=0;
66				for v in v{s+=v.count()?}
67				Some(s)
68			},
69			X1(x)=>Some(x.iter().product()),
70			X2(x)=>Some(x.iter().product()),
71			X3(x)=>Some(x.iter().product()),
72			X4(x)=>Some(x.iter().product()),
73			X5(x)=>Some(x.iter().product()),
74			X6(x)=>Some(x.iter().product()),
75			X7(x)=>Some(x.iter().product()),
76			X8(x)=>Some(x.iter().product())
77		}
78	}
79	/// converts to the eight dimensional array type by extending with ones. The original data will be placed according to the alignment. Multi and incompatible types will be all ones
80	pub fn to_array(self,alignment:Alignment)->[usize;8]{
81		let mut result=[1;8];
82		let slice=match &self{Shape::Incompatible(_e)=>return result,Shape::Multi(_v)=>return result,Shape::Recursive(_r)=>return result,X1(x)=>x.as_slice(),X2(x)=>x.as_slice(),X3(x)=>x.as_slice(),X4(x)=>x.as_slice(),X5(x)=>x.as_slice(),X6(x)=>x.as_slice(),X7(x)=>x.as_slice(),X8(x)=>x.as_slice()};
83		let l=slice.len();
84		match alignment{Alignment::Center=>result[4-l/2..][..l].copy_from_slice(slice),Alignment::Left=>result[..l].copy_from_slice(slice),Alignment::Right=>result[8-l..].copy_from_slice(slice)}
85		result
86	}
87}
88#[derive(Clone,Copy,Debug,Eq,PartialEq,Deserialize,Serialize)]
89/// enumerates kinds for values
90pub enum Kind{Bool,Float,Incompatible,Int,Multi}
91#[derive(Clone,Debug,Deserialize,Serialize)]
92/// value reshaping arguments
93pub enum Reshape{R1([isize;1]),R2([isize;2]),R3([isize;3]),R4([isize;4]),R5([isize;5]),R6([isize;6]),R7([isize;7]),R8([isize;8]),Recursive(Vec<Reshape>)}
94#[derive(Clone,Debug,Deserialize,Serialize)]// TODO eq that doesn't include the payload of incompatible
95/// tensor shapes for Value
96pub enum Shape{Incompatible(String),Multi(usize),Recursive(Vec<Shape>),X1([usize;1]),X2([usize;2]),X3([usize;3]),X4([usize;4]),X5([usize;5]),X6([usize;6]),X7([usize;7]),X8([usize;8])}
97use Reshape::{R1,R2,R3,R4,R5,R6,R7,R8};
98use Shape::{X1,X2,X3,X4,X5,X6,X7,X8};
99use crate::builtin::Alignment;
100use serde::{Deserialize,Serialize};