1cat_like!(CatLayer,Cat);
2cat_like!(SqueezeLayer,Squeeze);
3cat_like!(StackLayer,Stack);
4cat_like!(UnsqueezeLayer,Unsqueeze);
5flat_like!(FlattenLayer,Flatten);
6flat_like!(ReshapeLayer,Reshape);
/// Generates a layer type whose only parameter is an `i32` dimension, together with a
/// wrapper type that applies that layer to the output of an inner model.
///
/// `cat_like!(Layer, Wrap)` generates every piece. `cat_like!(@cmd ... Layer, Wrap)`
/// generates only the listed pieces, where each `@cmd` is one of:
///
/// * `@ai` - `AI` impls for `Wrap<A>` and `Layer`;
/// * `@declare` - the `Layer` and `Wrap<A>` struct declarations;
/// * `@decompose` - `Decompose` impls for both types;
/// * `@impl` - inherent accessors/builders plus `UnwrapInner` and `IntoSequence` impls;
/// * `@op` - `Op` impls for both types.
///
/// An unknown `@cmd` is reported with `compile_error!`.
///
/// `Wrap` must also be the name of a trait in `crate::ops` that provides an `_apply`
/// method taking the `i32` dimension; the `@ai` and `@op` pieces refer to it by that path.
macro_rules! cat_like{
    // AI impls: the wrapper runs `inner` and feeds its output into `layer`; the bare
    // layer calls the matching `crate::ops` trait's `_apply` with its dimension and
    // converts the result with `Into`.
    (@ai $layer:ident,$wrap:ident)=>{
        impl<A:AI<X,Y>+Op<Output=Y>,X,Y,Z> AI<X,Z> for $wrap<A> where $layer:AI<Y,Z>{
            fn forward(&self,input:X)->Z{self.layer.forward(self.inner.forward(input))}
            fn forward_mut(&mut self,input:X)->Z{self.layer.forward_mut(self.inner.forward_mut(input))}
        }
        impl<X:crate::ops::$wrap,Y> AI<X,Y> for $layer where X::Output:Into<Y>{
            fn forward(&self,input:X)->Y{input._apply(self.dim).into()}
        }
    };
    // Struct declarations.
    (@declare $layer:ident,$wrap:ident)=>{
        /// Layer whose only state is the `i32` dimension passed to the operation.
        #[derive(Clone,Copy,Debug,Default,Deserialize,Eq,Hash,PartialEq,Serialize)]
        pub struct $layer{dim:i32}
        /// Wraps an inner model and applies the layer to the inner model's output.
        #[derive(Clone,Copy,Debug,Default,Deserialize,Eq,Hash,PartialEq,Serialize)]
        pub struct $wrap<A>{inner:A,layer:$layer}
    };
    // Decompose impls: the layer decomposes to its bare `i32` dimension, the wrapper to
    // the pair `(inner decomposition, layer decomposition)`.
    (@decompose $layer:ident,$wrap:ident)=>{
        impl Decompose for $layer{
            fn compose(dim:Self::Decomposition)->Self{
                Self{dim}
            }
            fn decompose(self)->Self::Decomposition{self.dim}
            fn decompose_cloned(&self)->Self::Decomposition{self.dim}
            type Decomposition=i32;
        }
        impl<A:Decompose> Decompose for $wrap<A>{
            fn compose((inner,layer):Self::Decomposition)->Self{
                Self{inner:A::compose(inner),layer:$layer::compose(layer)}
            }
            fn decompose(self)->Self::Decomposition{(self.inner.decompose(),self.layer.decompose())}
            fn decompose_cloned(&self)->Self::Decomposition{(self.inner.decompose_cloned(),self.layer.decompose_cloned())}
            type Decomposition=(A::Decomposition,<$layer as Decompose>::Decomposition);
        }
    };
    // Inherent methods, UnwrapInner and IntoSequence.
    (@impl $layer:ident,$wrap:ident)=>{
        impl $layer{
            /// Returns the dimension argument.
            pub fn get_dim(&self)->i32{self.dim}
            /// Creates a layer with the given dimension (all other state defaulted).
            pub fn new(dim:i32)->Self{Self::default().with_dim(dim)}
            /// Sets the dimension argument in place.
            pub fn set_dim(&mut self,dim:i32){self.dim=dim}
            /// Returns this layer with the dimension argument replaced.
            pub fn with_dim(mut self,dim:i32)->Self{
                self.dim=dim;
                self
            }
        }
        // Unwrapping delegates to the inner model, discarding this wrapper's layer.
        impl<A:UnwrapInner> UnwrapInner for $wrap<A>{
            fn unwrap_inner(self)->A::Inner{self.into_inner().unwrap_inner()}
            type Inner=A::Inner;
        }
        impl<A> $wrap<A>{
            /// Returns the dimension argument of the wrapped layer.
            pub fn get_dim(&self)->i32{self.layer.dim}
            /// Borrows the inner model.
            pub fn inner(&self)->&A{&self.inner}
            /// Mutably borrows the inner model.
            pub fn inner_mut(&mut self)->&mut A{&mut self.inner}
            /// Consumes the wrapper and returns the inner model.
            pub fn into_inner(self)->A{self.inner}
            /// Wraps `inner` with a layer using dimension `dim`.
            pub fn new(dim:i32,inner:A)->Self where Self:Op{
                Self{inner,layer:$layer::new(dim)}
            }
            /// Sets the wrapped layer's dimension argument in place.
            pub fn set_dim(&mut self,dim:i32){self.layer.dim=dim}
            /// Returns this wrapper with the layer's dimension argument replaced.
            pub fn with_dim(mut self,dim:i32)->Self{
                self.layer.dim=dim;
                self
            }
            /// Replaces the inner model, keeping this wrapper's layer configuration.
            pub fn with_inner<B>(self,inner:B)->$wrap<B> where $wrap<B>:Op{
                $wrap{inner,layer:self.layer}
            }
        }
        // A bare layer converts into a one-element sequence of `M`.
        impl<M:AI<M::Output,M::Output>+Op> IntoSequence<M> for $layer where $layer:Into<M>{
            fn into_sequence(self)->Sequential<Vec<M>>{vec![self.into()].sequential()}
        }
    };
    // Op impls. NOTE(review): the bare layer always reports `Vec<f32>` as its output type;
    // the wrapper's output is whatever the `crate::ops` trait yields for the inner output.
    (@op $layer:ident,$wrap:ident)=>{
        impl Op for $layer{
            type Output=Vec<f32>;
        }
        impl<A:Op<Output=Y>,Y:crate::ops::$wrap<Output=Z>,Z> Op for $wrap<A> where $layer:AI<Y,Z>{
            type Output=Z;
        }
    };
    // No commands: generate every piece.
    ($layer:ident,$wrap:ident)=>{
        cat_like!(@ai @declare @decompose @impl @op $layer,$wrap);
    };
    // A single command that none of the arms above recognised. Without this arm the
    // dispatcher below would re-emit the identical invocation forever and fail with
    // "recursion limit reached" instead of naming the bad command. Invocations with two
    // or more commands cannot match here (`$layer:ident` rejects the second `@`).
    (@$unknown:tt $layer:ident,$wrap:ident)=>{
        compile_error!(concat!("cat_like!: unknown command `@",stringify!($unknown),"`"));
    };
    // Several commands: expand each one separately.
    ($(@$command:tt)* $layer:ident,$wrap:ident)=>{
        $(cat_like!(@$command $layer,$wrap);)*
    };
}
/// Generates a layer type parameterized by an arbitrary cloneable argument value `R`,
/// together with a wrapper type that applies that layer to the output of an inner model.
///
/// `flat_like!(Layer, Wrap)` generates every piece. `flat_like!(@cmd ... Layer, Wrap)`
/// generates only the listed pieces, where each `@cmd` is one of:
///
/// * `@ai` - `AI` impls for `Wrap<A, R>` and `Layer<R>`;
/// * `@declare` - the `Layer<R>` and `Wrap<A, R>` struct declarations;
/// * `@decompose` - `Decompose` impls for both types;
/// * `@impl` - inherent accessors/builders plus `UnwrapInner` and `IntoSequence` impls;
/// * `@op` - `Op` impls for both types.
///
/// An unknown `@cmd` is reported with `compile_error!`.
///
/// `Wrap` must also be the name of a generic trait `crate::ops::Wrap<R>` whose `_apply`
/// method takes the (cloned) argument value; the `@ai` and `@op` pieces refer to it by
/// that path.
macro_rules! flat_like{
    // AI impls: the wrapper runs `inner` and feeds its output into `layer`; the bare
    // layer calls the matching `crate::ops` trait's `_apply` with its arguments and
    // converts the result with `Into`.
    (@ai $layer:ident,$wrap:ident)=>{
        impl<A:AI<X,Y>+Op<Output=Y>,R:Clone,X,Y,Z> AI<X,Z> for $wrap<A,R> where $layer<R>:AI<Y,Z>{
            fn forward(&self,input:X)->Z{self.layer.forward(self.inner.forward(input))}
            fn forward_mut(&mut self,input:X)->Z{self.layer.forward_mut(self.inner.forward_mut(input))}
        }
        impl<X:crate::ops::$wrap<R>,R:Clone,Y> AI<X,Y> for $layer<R> where X::Output:Into<Y>{
            // `forward` only borrows `self`, so the stored arguments are cloned per call.
            fn forward(&self,input:X)->Y{input._apply(self.args.clone()).into()}
        }
    };
    // Struct declarations.
    (@declare $layer:ident,$wrap:ident)=>{
        /// Layer whose only state is the argument value passed to the operation.
        #[derive(Clone,Copy,Debug,Default,Deserialize,Eq,Hash,PartialEq,Serialize)]
        pub struct $layer<R:Clone>{args:R}
        /// Wraps an inner model and applies the layer to the inner model's output.
        #[derive(Clone,Copy,Debug,Default,Deserialize,Eq,Hash,PartialEq,Serialize)]
        pub struct $wrap<A,R:Clone>{inner:A,layer:$layer<R>}
    };
    // Decompose impls: the layer decomposes exactly like its argument value, the wrapper
    // to the pair `(inner decomposition, layer decomposition)`.
    (@decompose $layer:ident,$wrap:ident)=>{
        impl<R:Clone+Decompose> Decompose for $layer<R>{
            fn compose(args:Self::Decomposition)->Self{
                Self{args:R::compose(args)}
            }
            fn decompose(self)->Self::Decomposition{self.args.decompose()}
            fn decompose_cloned(&self)->Self::Decomposition{self.args.decompose_cloned()}
            type Decomposition=R::Decomposition;
        }
        impl<A:Decompose,R:Clone+Decompose> Decompose for $wrap<A,R>{
            fn compose((inner,layer):Self::Decomposition)->Self{
                Self{inner:A::compose(inner),layer:$layer::compose(layer)}
            }
            fn decompose(self)->Self::Decomposition{(self.inner.decompose(),self.layer.decompose())}
            fn decompose_cloned(&self)->Self::Decomposition{(self.inner.decompose_cloned(),self.layer.decompose_cloned())}
            type Decomposition=(A::Decomposition,<$layer<R> as Decompose>::Decomposition);
        }
    };
    // Inherent methods, UnwrapInner and IntoSequence.
    (@impl $layer:ident,$wrap:ident)=>{
        impl<R:Clone> $layer<R>{
            /// Borrows the argument value.
            pub fn args(&self)->&R{&self.args}
            /// Mutably borrows the argument value.
            pub fn args_mut(&mut self)->&mut R{&mut self.args}
            /// Creates a layer holding `args`.
            pub fn new(args:R)->Self{
                Self{args}
            }
            /// Returns this layer with the argument value replaced.
            pub fn with_args(mut self,args:R)->Self{
                self.args=args;
                self
            }
        }
        // Unwrapping delegates to the inner model, discarding this wrapper's layer.
        impl<A:UnwrapInner,R:Clone> UnwrapInner for $wrap<A,R>{
            fn unwrap_inner(self)->A::Inner{self.into_inner().unwrap_inner()}
            type Inner=A::Inner;
        }
        impl<A,R:Clone> $wrap<A,R>{
            /// Borrows the wrapped layer's argument value.
            pub fn args(&self)->&R{&self.layer.args}
            /// Mutably borrows the wrapped layer's argument value.
            pub fn args_mut(&mut self)->&mut R{&mut self.layer.args}
            /// Borrows the inner model.
            pub fn inner(&self)->&A{&self.inner}
            /// Mutably borrows the inner model.
            pub fn inner_mut(&mut self)->&mut A{&mut self.inner}
            /// Consumes the wrapper and returns the inner model.
            pub fn into_inner(self)->A{self.inner}
            /// Wraps `inner` with a layer holding `args`.
            pub fn new(args:R,inner:A)->Self where Self:Op{
                Self{inner,layer:$layer::new(args)}
            }
            /// Returns this wrapper with the layer's argument value replaced.
            pub fn with_args(mut self,args:R)->Self{
                self.layer.args=args;
                self
            }
            /// Replaces the inner model, keeping this wrapper's layer configuration.
            pub fn with_inner<B>(self,inner:B)->$wrap<B,R> where $wrap<B,R>:Op{
                $wrap{inner,layer:self.layer}
            }
        }
        // A bare layer converts into a one-element sequence of `M`.
        impl<M:AI<M::Output,M::Output>+Op,R:Clone> IntoSequence<M> for $layer<R> where $layer<R>:Into<M>{
            fn into_sequence(self)->Sequential<Vec<M>>{vec![self.into()].sequential()}
        }
    };
    // Op impls. NOTE(review): the bare layer always reports `Vec<f32>` as its output type;
    // the wrapper's output is whatever the `crate::ops` trait yields for the inner output.
    (@op $layer:ident,$wrap:ident)=>{
        impl<R:Clone> Op for $layer<R>{
            type Output=Vec<f32>;
        }
        impl<A:Op<Output=Y>,R:Clone,Y:crate::ops::$wrap<R,Output=Z>,Z> Op for $wrap<A,R> where $layer<R>:AI<Y,Z>{
            type Output=Z;
        }
    };
    // No commands: generate every piece.
    ($layer:ident,$wrap:ident)=>{
        flat_like!(@ai @declare @decompose @impl @op $layer,$wrap);
    };
    // A single command that none of the arms above recognised. Without this arm the
    // dispatcher below would re-emit the identical invocation forever and fail with
    // "recursion limit reached" instead of naming the bad command. Invocations with two
    // or more commands cannot match here (`$layer:ident` rejects the second `@`).
    (@$unknown:tt $layer:ident,$wrap:ident)=>{
        compile_error!(concat!("flat_like!: unknown command `@",stringify!($unknown),"`"));
    };
    // Several commands: expand each one separately.
    ($(@$command:tt)* $layer:ident,$wrap:ident)=>{
        $(flat_like!(@$command $layer,$wrap);)*
    };
}
206use {cat_like,flat_like};
207use crate::{AI,Decompose,IntoSequence,Op,UnwrapInner,builtin::Sequential};
208use serde::{Deserialize,Serialize};