block_graph/burn/
shape.rs

1impl AsRef<Self> for Shape{//TODO more reref stuff
2	fn as_ref(&self)->&Self{self}
3}
4impl Default for Reshape{
5	fn default()->Self{Self::Recursive(Vec::new())}
6}
7impl Default for Shape{
8	fn default()->Self{Self::Recursive(Vec::new())}
9}
10impl From<[isize;1]> for Reshape{
11	fn from(dims:[isize;1])->Self{R1(dims)}
12}
13impl From<[isize;2]> for Reshape{
14	fn from(dims:[isize;2])->Self{R2(dims)}
15}
16impl From<[isize;3]> for Reshape{
17	fn from(dims:[isize;3])->Self{R3(dims)}
18}
19impl From<[isize;4]> for Reshape{
20	fn from(dims:[isize;4])->Self{R4(dims)}
21}
22impl From<[isize;5]> for Reshape{
23	fn from(dims:[isize;5])->Self{R5(dims)}
24}
25impl From<[isize;6]> for Reshape{
26	fn from(dims:[isize;6])->Self{R6(dims)}
27}
28impl From<[isize;7]> for Reshape{
29	fn from(dims:[isize;7])->Self{R7(dims)}
30}
31impl From<[isize;8]> for Reshape{
32	fn from(dims:[isize;8])->Self{R8(dims)}
33}
34impl From<[usize;1]> for Reshape{
35	fn from(dims:[usize;1])->Self{R1(dims.map(|d|d as isize))}
36}
37impl From<[usize;2]> for Reshape{
38	fn from(dims:[usize;2])->Self{R2(dims.map(|d|d as isize))}
39}
40impl From<[usize;3]> for Reshape{
41	fn from(dims:[usize;3])->Self{R3(dims.map(|d|d as isize))}
42}
43impl From<[usize;4]> for Reshape{
44	fn from(dims:[usize;4])->Self{R4(dims.map(|d|d as isize))}
45}
46impl From<[usize;5]> for Reshape{
47	fn from(dims:[usize;5])->Self{R5(dims.map(|d|d as isize))}
48}
49impl From<[usize;6]> for Reshape{
50	fn from(dims:[usize;6])->Self{R6(dims.map(|d|d as isize))}
51}
52impl From<[usize;7]> for Reshape{
53	fn from(dims:[usize;7])->Self{R7(dims.map(|d|d as isize))}
54}
55impl From<[usize;8]> for Reshape{
56	fn from(dims:[usize;8])->Self{R8(dims.map(|d|d as isize))}
57}
58impl Reshape{
59	/// counts the recursive depth
60	pub fn depth(&self)->usize{
61		match self{
62			R1(_)=>1,
63			R2(_)=>1,
64			R3(_)=>1,
65			R4(_)=>1,
66			R5(_)=>1,
67			R6(_)=>1,
68			R7(_)=>1,
69			R8(_)=>1,
70			Reshape::Recursive(v)=>v.iter().map(Reshape::depth).max().unwrap_or(0)
71		}
72	}
73	/// converts to the eight dimensional array type by extending with ones. The original data will be placed according to the alignment. Multi and incompatible types will be all ones
74	pub fn to_array(self,alignment:Alignment)->[isize;8]{
75		let mut result=[1;8];
76		let slice=match &self{R1(x)=>x.as_slice(),R2(x)=>x.as_slice(),R3(x)=>x.as_slice(),R4(x)=>x.as_slice(),R5(x)=>x.as_slice(),R6(x)=>x.as_slice(),R7(x)=>x.as_slice(),R8(x)=>x.as_slice(),Reshape::Recursive(_r)=>return result};
77		let l=slice.len();
78		match alignment{Alignment::Center=>result[4-l/2..][..l].copy_from_slice(slice),Alignment::Left=>result[..l].copy_from_slice(slice),Alignment::Right=>result[8-l..].copy_from_slice(slice)}
79		result
80	}
81}
82impl Shape{
83	/// counts the number of components if possible. returns none if incompatible or if a non recursive multi shape of more than 0 tensors
84	pub fn count(&self)->Option<usize>{
85		match self{
86			Shape::Incompatible(_e)=>None,
87			Shape::Multi(n)=>if *n==0{Some(0)}else{None},
88			Shape::Recursive(v)=>{
89				let mut s=0;
90				for v in v{s+=v.count()?}
91				Some(s)
92			},
93			X1(x)=>Some(x.iter().product()),
94			X2(x)=>Some(x.iter().product()),
95			X3(x)=>Some(x.iter().product()),
96			X4(x)=>Some(x.iter().product()),
97			X5(x)=>Some(x.iter().product()),
98			X6(x)=>Some(x.iter().product()),
99			X7(x)=>Some(x.iter().product()),
100			X8(x)=>Some(x.iter().product())
101		}
102	}
103	/// converts to the eight dimensional array type by extending with ones. The original data will be placed according to the alignment. Multi and incompatible types will be all ones
104	pub fn to_array(self,alignment:Alignment)->[usize;8]{
105		let mut result=[1;8];
106		let slice=match &self{Shape::Incompatible(_e)=>return result,Shape::Multi(_v)=>return result,Shape::Recursive(_r)=>return result,X1(x)=>x.as_slice(),X2(x)=>x.as_slice(),X3(x)=>x.as_slice(),X4(x)=>x.as_slice(),X5(x)=>x.as_slice(),X6(x)=>x.as_slice(),X7(x)=>x.as_slice(),X8(x)=>x.as_slice()};
107		let l=slice.len();
108		match alignment{Alignment::Center=>result[4-l/2..][..l].copy_from_slice(slice),Alignment::Left=>result[..l].copy_from_slice(slice),Alignment::Right=>result[8-l..].copy_from_slice(slice)}
109		result
110	}
111}
112#[derive(Clone,Copy,Debug,Eq,PartialEq,Deserialize,Serialize)]
113/// enumerates kinds for values
114pub enum Kind{Bool,Float,Incompatible,Int,Multi}
115#[derive(Clone,Debug,Deserialize,Serialize)]
116/// value reshaping arguments
117pub enum Reshape{R1([isize;1]),R2([isize;2]),R3([isize;3]),R4([isize;4]),R5([isize;5]),R6([isize;6]),R7([isize;7]),R8([isize;8]),Recursive(Vec<Reshape>)}
118#[derive(Clone,Debug,Deserialize,Serialize)]// TODO eq that doesn't include the payload of incompatible
119/// tensor shapes for Value
120pub enum Shape{Incompatible(String),Multi(usize),Recursive(Vec<Shape>),X1([usize;1]),X2([usize;2]),X3([usize;3]),X4([usize;4]),X5([usize;5]),X6([usize;6]),X7([usize;7]),X8([usize;8])}
121use Reshape::{R1,R2,R3,R4,R5,R6,R7,R8};
122use Shape::{X1,X2,X3,X4,X5,X6,X7,X8};
123use crate::builtin::Alignment;
124use serde::{Deserialize,Serialize};