// block_graph/ops.rs

// TODO: macro-generate the primitive implementations
2impl Abs for f32{// TODO implement operations for result types
3	fn abs(self)->Self::Output{f32::abs(self)}
4	type Output=f32;
5}
6impl Rank for f32{
7	fn dynamic_rank(&self)->usize{0}
8	fn type_rank()->usize{0}
9}
10impl Squeeze for Vec<f32>{
11	fn squeeze(self,dim:i32)->Self::Output{
12		if dim!=-1&&dim!=0{panic!("squeeze dim out of bounds")}
13		if self.len()!=1{panic!("cannot squeeze a dim whose size is not 1")}
14		self[0]
15	}
16	type Output=f32;
17}
18impl SquaredError for f32{
19	fn squared_error(self,rhs:f32)->Self::Output{
20		let d=self-rhs;
21		d*d
22	}
23	type Output=f32;
24}
25impl Unsqueeze for f32{
26	fn unsqueeze(self,dim:i32)->UnsqueezeScalar<f32>{
27		if dim==-1||dim==0{UnsqueezeScalar(self)}else{panic!("unsqueeze dim out of bounds")}
28	}
29	type Output=UnsqueezeScalar<f32>;
30}
31impl<T:Rank> Rank for Vec<T>{
32	fn dynamic_rank(&self)->usize{self.first().map(T::dynamic_rank).unwrap_or_else(T::type_rank)+1}
33	fn type_rank()->usize{T::type_rank()+1}
34}
35impl<T:Squeeze> Squeeze for Vec<Vec<T>> where Vec<T>:Squeeze<Output=T>+Rank{
36	fn squeeze(self,mut dim:i32)->Self::Output{
37		let rank=self.rank() as i32;
38
39		if !(-rank..rank).contains(&dim){panic!("squeeze dim out of bounds")}
40		if dim==0||dim==-rank{
41			if self.len()!=1{panic!("cannot squeeze a dim whose size is not 1")}
42			self.into_iter().next().unwrap()
43		}else{
44			if dim>0{dim-=1}
45			self.into_iter().map(|x|x.squeeze(dim)).collect()
46		}
47	}
48	type Output=Vec<T>;
49}
50impl<T:Unsqueeze<Output=U>,U> Stack for Vec<T> where Vec<U>:Cat<Output=Vec<T>>{
51	fn stack(self,dim:i32)->Self::Output{
52		let unsqueezed:Vec<U>=self.into_iter().map(|x|x.unsqueeze(dim)).collect();
53		unsqueezed.cat(dim)
54	}
55	type Output=Self;
56}
57impl<T:Unsqueeze> Unsqueeze for Vec<T> where T::Output:Into<Vec<T>>,Vec<T>:Rank{
58	fn unsqueeze(self,mut dim:i32)->Self::Output{
59		let rank=self.rank() as i32;
60
61		if !(-rank..rank+1).contains(&dim){panic!("unsqueeze dim out of bounds")}
62		if dim==0||dim==-rank{return vec![self]}else if dim>0{dim-=1}
63		self.into_iter().map(|x|x.unsqueeze(dim).into()).collect()
64	}
65	type Output=Vec<Vec<T>>;
66}
67impl<T> From<UnsqueezeScalar<T>> for Vec<T>{
68	fn from(value:UnsqueezeScalar<T>)->Vec<T>{vec![value.0]}
69}
/// A scalar that has been given one extra size-1 axis; it converts into a one-element `Vec` through `From`.
#[derive(Clone,Copy,Debug,Default,Eq,Hash,Ord,PartialEq,PartialOrd)]
pub struct UnsqueezeScalar<T>(pub T);
/// Absolute-value operation.
pub trait Abs{
	/// the type produced by [`Abs::abs`]
	type Output;
	/// computes the absolute value, consuming `self`
	fn abs(self)->Self::Output;
	/// convenience alias of [`Abs::abs`] with a uniform name, intended for macro use
	fn _apply(self)->Self::Output where Self:Sized{Self::abs(self)}
}
/// Concatenation operation.
pub trait Cat{
	/// the type produced by [`Cat::cat`]
	type Output;
	/// concatenates the data along axis `dim`
	fn cat(self,dim:i32)->Self::Output;
	/// convenience alias of [`Cat::cat`] with a uniform name, intended for macro use
	fn _apply(self,dim:i32)->Self::Output where Self:Sized{Self::cat(self,dim)}
}
/// Flatten operation. `R` is the argument type chosen by each implementor
/// (presumably describing which dims to merge — confirm against implementors).
pub trait Flatten<R>{
	/// the type produced by [`Flatten::flatten`]
	type Output;
	/// flattens `self` as directed by `args`
	fn flatten(self,args:R)->Self::Output;
	/// convenience alias of [`Flatten::flatten`] with a uniform name, intended for macro use
	fn _apply(self,args:R)->Self::Output where Self:Sized{<Self as Flatten<R>>::flatten(self,args)}
}
/// Number of axes (the rank) of a tensor-like value.
pub trait Rank{
	/// rank of this particular value, which may inspect its contents (the `Vec` impl looks at its first element)
	fn dynamic_rank(&self)->usize;
	/// rank known from the type alone; types without a clear rank may return some default here.
	/// The `Self: Sized` bound keeps this receiver-less method from making `dyn Rank` unusable.
	fn type_rank()->usize where Self:Sized;
	/// the rank of this value; forwards to [`Rank::dynamic_rank`]
	fn rank(&self)->usize{Self::dynamic_rank(self)}
}
/// Reshape operation. `R` is the argument type chosen by each implementor
/// (presumably the target shape — confirm against implementors).
pub trait Reshape<R>{
	/// the type produced by [`Reshape::reshape`]
	type Output;
	/// reshapes `self` as directed by `args`
	fn reshape(self,args:R)->Self::Output;
	/// convenience alias of [`Reshape::reshape`] with a uniform name, intended for macro use
	fn _apply(self,args:R)->Self::Output where Self:Sized{<Self as Reshape<R>>::reshape(self,args)}
}
/// Squeeze operation: removes a size-1 axis, lowering the rank by one.
pub trait Squeeze{
	/// the type produced by [`Squeeze::squeeze`]
	type Output;
	/// removes the size-1 axis at `dim`
	fn squeeze(self,dim:i32)->Self::Output;
	/// convenience alias of [`Squeeze::squeeze`] with a uniform name, intended for macro use
	fn _apply(self,dim:i32)->Self::Output where Self:Sized{Self::squeeze(self,dim)}
}
/// Axis-swap operation (semantics follow the name; no implementation is visible in this file).
pub trait SwapDims{
	/// the type produced by [`SwapDims::swap_dims`]
	type Output;
	/// exchanges axis `a` with axis `b`
	fn swap_dims(self,a:i32,b:i32)->Self::Output;
	/// convenience alias of [`SwapDims::swap_dims`] with a uniform name, intended for macro use
	fn _apply(self,a:i32,b:i32)->Self::Output where Self:Sized{Self::swap_dims(self,a,b)}
}
/// Squared-error operation between `self` and a right-hand side of type `R` (defaults to `Self`).
pub trait SquaredError<R=Self>{
	/// the type produced by [`SquaredError::squared_error`]
	type Output;
	/// computes the squared difference between `self` and `rhs`
	fn squared_error(self,rhs:R)->Self::Output;
	/// convenience alias of [`SquaredError::squared_error`] with a uniform name, intended for macro use
	fn _apply(self,rhs:R)->Self::Output where Self:Sized{<Self as SquaredError<R>>::squared_error(self,rhs)}
}
/// Stack operation: joins a sequence of tensors along a new axis.
pub trait Stack{
	/// the type produced by [`Stack::stack`]
	type Output;
	/// stacks the data along axis `dim`
	fn stack(self,dim:i32)->Self::Output;
	/// convenience alias of [`Stack::stack`] with a uniform name, intended for macro use
	fn _apply(self,dim:i32)->Self::Output where Self:Sized{Self::stack(self,dim)}
}
/// Unsqueeze operation: inserts a new size-1 axis, raising the rank by one.
pub trait Unsqueeze{
	/// the type produced by [`Unsqueeze::unsqueeze`]
	type Output;
	/// inserts a size-1 axis at position `dim`
	fn unsqueeze(self,dim:i32)->Self::Output;
	/// convenience alias of [`Unsqueeze::unsqueeze`] with a uniform name, intended for macro use
	fn _apply(self,dim:i32)->Self::Output where Self:Sized{Self::unsqueeze(self,dim)}
}