1impl Decompose for ClassificationLayer{
2 fn compose(_decomposition:Self::Decomposition)->Self{Self::default()}
3 fn decompose(self){}
4 fn decompose_cloned(&self){}
5 type Decomposition=();
6}
7impl Decompose for RegressionLayer{
8 fn compose(_decomposition:Self::Decomposition)->Self{Self::default()}
9 fn decompose(self){}
10 fn decompose_cloned(&self){}
11 type Decomposition=();
12}
13impl MetricsRenderer for DontRender{
14 fn update_train(&mut self,_state:MetricState){}
15 fn update_valid(&mut self,_state:MetricState){}
16 fn render_train(&mut self,_item:TrainingProgress){}
17 fn render_valid(&mut self,_item:TrainingProgress){}
18}
19impl Op for ClassificationLayer{
20 type Output=ClassificationOutput<NdArray>;
21}
22impl Op for RegressionLayer{
23 type Output=RegressionOutput<NdArray>;
24}
25impl<A:AI<X,LossOutput<B>>,B:Backend,X> AI<X,ClassificationOutput<B>> for Classification<A>{
26 fn forward(&self,input:X)->ClassificationOutput<B>{self.layer.forward(self.inner.forward(input))}
27 fn forward_mut(&mut self,input:X)->ClassificationOutput<B>{self.layer.forward(self.inner.forward_mut(input))}
28}
29impl<A:AI<X,LossOutput<B>>,B:Backend,X> AI<X,RegressionOutput<B>> for Regression<A>{
30 fn forward(&self,input:X)->RegressionOutput<B>{self.layer.forward(self.inner.forward(input))}
31 fn forward_mut(&mut self,input:X)->RegressionOutput<B>{self.layer.forward(self.inner.forward_mut(input))}
32}
33impl<A:AutodiffBackend<InnerBackend=B>,B:Backend,W:'static+Wrappable<B=A>,Y:'static+ItemLazy+Send+Sync,Z:'static+ItemLazy+Send+Sync> Wrapped<W> where <Self as AutodiffModule<A>>::InnerModule:ValidStep<(Value<B>,Value<B>),Z>,Self:TrainStep<(Value<A>,Value<A>),Y>,W::Decomposition:AutodiffModule<A>,W::With<B>:Decompose<Decomposition=<W::Decomposition as AutodiffModule<A>>::InnerModule>+Op<Output=Z>,W:Op<Output=Y>,Y::ItemSync:Adaptor<LossInput<NdArray>>,Z::ItemSync:Adaptor<LossInput<NdArray>>{
34 pub fn train<I:'static+Clone+Debug+Into<(Value<A>,Value<A>)>+Send+Sync,J:'static+Clone+Debug+Into<(Value<B>,Value<B>)>+Send+Sync,O:Optimizer<Self,A>,S:LrScheduler,T:'static+Dataset<I>,V:'static+Dataset<J>>(self,config:&TrainConfig,optimizer:O,scheduler:S,train:T,valid:V)->Self where O::Record:'static,S::Record<A>:'static{
36 let batcher=BatchStacker;
37 let trainloader=DataLoaderBuilder::new(batcher).batch_size(config.batch_size).shuffle(random()).num_workers(config.workers).build(train);
38 let validloader=DataLoaderBuilder::new(batcher).batch_size(config.batch_size).shuffle(random()).num_workers(config.workers).build(valid);
39
40 create_folder(&config.artifact_directory).unwrap();
41 let builder=LearnerBuilder::new(&config.artifact_directory).metric_train_numeric(LossMetric::new()).metric_valid_numeric(LossMetric::new());
42 let builder=if config.checkpoints{builder.with_file_checkpointer(CompactRecorder::new())}else{builder};
43 let builder=if config.console_rendering{builder}else{builder.renderer(DontRender)};
44 let builder=builder.devices(vec![<W::B as Backend>::Device::default()]).num_epochs(config.epochs);
45 let builder=if config.summary{builder.summary()}else{builder};
46 let learner=builder.build(self,optimizer,scheduler);
47 learner.fit(trainloader,validloader)
48 }
49}
50impl<A:AutodiffBackend,W:AI<X,LossOutput<A>>+Wrappable<B=A>,X> TrainStep<X,ClassificationOutput<A>> for Wrapped<Classification<W>> where W::Decomposition:AutodiffModule<A>,W::With<A::InnerBackend>:Decompose<Decomposition=<W::Decomposition as AutodiffModule<A>>::InnerModule>{
51 fn step(&self,item:X)->TrainOutput<ClassificationOutput<A>>{
52 let output:ClassificationOutput<A>=self.forward(item);
53 TrainOutput::new(self,output.loss.backward(),output)
54 }
55}
56impl<A:AutodiffBackend,W:AI<X,LossOutput<A>>+Wrappable<B=A>,X> TrainStep<X,RegressionOutput<A>> for Wrapped<Regression<W>> where W::Decomposition:AutodiffModule<A>,W::With<A::InnerBackend>:Decompose<Decomposition=<W::Decomposition as AutodiffModule<A>>::InnerModule>{
57 fn step(&self,item:X)->TrainOutput<RegressionOutput<A>>{
58 let output:RegressionOutput<A>=self.forward(item);
59 TrainOutput::new(self,output.loss.backward(),output)
60 }
61}
62impl<A:AutodiffBackend,W:Wrappable<B=A>> AutodiffModule<A> for Wrapped<W> where W::Decomposition:AutodiffModule<A>,W::With<A::InnerBackend>:Decompose<Decomposition=<W::Decomposition as AutodiffModule<A>>::InnerModule>{
63 fn valid(&self)->Self::InnerModule{Wrapped::new(Decompose::compose(self.inner.decompose_cloned().valid()))}
64 type InnerModule=Wrapped<W::With<A::InnerBackend>>;
65}
66impl<A:Decompose> Decompose for Classification<A>{
67 fn compose(decomposition:Self::Decomposition)->Self{
68 Self{inner:A::compose(decomposition),layer:Default::default()}
69 }
70 fn decompose(self)->Self::Decomposition{self.inner.decompose()}
71 fn decompose_cloned(&self)->Self::Decomposition{self.inner.decompose_cloned()}
72 type Decomposition=A::Decomposition;
73}
74impl<A:Decompose> Decompose for Regression<A>{
75 fn compose(decomposition:Self::Decomposition)->Self{
76 Self{inner:A::compose(decomposition),layer:Default::default()}
77 }
78 fn decompose(self)->Self::Decomposition{self.inner.decompose()}
79 fn decompose_cloned(&self)->Self::Decomposition{self.inner.decompose_cloned()}
80 type Decomposition=A::Decomposition;
81}
82impl<A:Op<Output=Y>+Wrappable,Y> Op for Classification<A> where ClassificationLayer:AI<Y,ClassificationOutput<A::B>>{
83 type Output=ClassificationOutput<A::B>;
84}
85impl<A:Op<Output=Y>+Wrappable,Y> Op for Regression<A> where RegressionLayer:AI<Y,RegressionOutput<A::B>>{
86 type Output=RegressionOutput<A::B>;
87}
88impl<A:UnwrapInner> UnwrapInner for Classification<A>{
89 fn unwrap_inner(self)->Self::Inner{self.into_inner().unwrap_inner()}
90 type Inner=A::Inner;
91}
92impl<A:UnwrapInner> UnwrapInner for Regression<A>{
93 fn unwrap_inner(self)->Self::Inner{self.into_inner().unwrap_inner()}
94 type Inner=A::Inner;
95}
96impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>> Wrappable for (A,D){
97 type B=B;
98 type With<C:Backend>=(A::With<C>,D::With<C>);
99}
100impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>> Wrappable for (A,D,E){
101 type B=B;
102 type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>);
103}
104impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>> Wrappable for (A,D,E,F){
105 type B=B;
106 type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>);
107}
108impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>,G:Wrappable<B=B>> Wrappable for (A,D,E,F,G){
109 type B=B;
110 type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>,G::With<C>);
111}
112impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>,G:Wrappable<B=B>,H:Wrappable<B=B>> Wrappable for (A,D,E,F,G,H){
113 type B=B;
114 type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>,G::With<C>,H::With<C>);
115}
116impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>,G:Wrappable<B=B>,H:Wrappable<B=B>,I:Wrappable<B=B>> Wrappable for (A,D,E,F,G,H,I){
117 type B=B;
118 type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>,G::With<C>,H::With<C>,I::With<C>);
119}
120impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>,G:Wrappable<B=B>,H:Wrappable<B=B>,I:Wrappable<B=B>,J:Wrappable<B=B>> Wrappable for (A,D,E,F,G,H,I,J){
121 type B=B;
122 type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>,G::With<C>,H::With<C>,I::With<C>,J::With<C>);
123}
124impl<A:Wrappable<B=B>,B:Backend,X:Wrappable<B=B>,Y:Wrappable<B=B>> Wrappable for SetType<A,X,Y>{
125 type B=B;
126 type With<C:Backend>=SetType<A::With<C>,X::With<C>,Y::With<C>>;
127}
128impl<A> Classification<A>{
129 pub fn from_inner(inner:A)->Self where Classification<A>:Op{
131 Self{inner,layer:Default::default()}
132 }
133 pub fn inner(&self)->&A{&self.inner}
135 pub fn inner_mut(&mut self)->&mut A{&mut self.inner}
137 pub fn into_inner(self)->A{self.inner}
139 pub fn with_inner<B>(&self,inner:B)->Classification<B> where Classification<B>:Op{Classification::from_inner(inner)}
141}
142impl<A> Regression<A>{
143 pub fn from_inner(inner:A)->Self where Regression<A>:Op{
145 Self{inner,layer:Default::default()}
146 }
147 pub fn inner(&self)->&A{&self.inner}
149 pub fn inner_mut(&mut self)->&mut A{&mut self.inner}
151 pub fn into_inner(self)->A{self.inner}
153 pub fn with_inner<B>(&self,inner:B)->Regression<B> where Regression<B>:Op{Regression::from_inner(inner)}
155}
156impl<B:Backend,C:Backend,K:BasicOps<B>+BasicOps<C>+TensorKind<B>+TensorKind<C>,const N:usize> ToBackend<C> for Tensor<B,N,K>{
157 fn to_backend_device(self,device:&C::Device)->Self::OnBackend{
158 let data=self.to_data();
159 Tensor::from_data(data,device)
160 }
161 type OnBackend=Tensor<C,N,K>;
162}
163impl<B:Backend,C:Backend> ToBackend<C> for Value<B>{
164 fn to_backend_device(self,device:&C::Device)->Self::OnBackend{
165 match self{Value::B1(x)=>x.to_backend_device(device).into(),Value::B2(x)=>x.to_backend_device(device).into(),Value::B3(x)=>x.to_backend_device(device).into(),Value::B4(x)=>x.to_backend_device(device).into(),Value::B5(x)=>x.to_backend_device(device).into(),Value::B6(x)=>x.to_backend_device(device).into(),Value::B7(x)=>x.to_backend_device(device).into(),Value::B8(x)=>x.to_backend_device(device).into(),Value::F1(x)=>x.to_backend_device(device).into(),Value::F2(x)=>x.to_backend_device(device).into(),Value::F3(x)=>x.to_backend_device(device).into(),Value::F4(x)=>x.to_backend_device(device).into(),Value::F5(x)=>x.to_backend_device(device).into(),Value::F6(x)=>x.to_backend_device(device).into(),Value::F7(x)=>x.to_backend_device(device).into(),Value::F8(x)=>x.to_backend_device(device).into(),Value::I1(x)=>x.to_backend_device(device).into(),Value::I2(x)=>x.to_backend_device(device).into(),Value::I3(x)=>x.to_backend_device(device).into(),Value::I4(x)=>x.to_backend_device(device).into(),Value::I5(x)=>x.to_backend_device(device).into(),Value::I6(x)=>x.to_backend_device(device).into(),Value::I7(x)=>x.to_backend_device(device).into(),Value::I8(x)=>x.to_backend_device(device).into(),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.to_backend_device(device)).collect()}
166 }
167 type OnBackend=Value<C>;
168}
169impl<B:Backend,E:Into<(Value<B>,Value<B>)>> Batcher<B,E,(Value<B>,Value<B>)> for BatchStacker{
170 fn batch(&self,items:Vec<E>,_device:&<B as Backend>::Device)->(Value<B>,Value<B>){
171 let items=items.into_iter().map(Into::into);
172 let (input,target):(Vec<Value<B>>,Vec<Value<B>>)=items.unzip();
173 let (input,target)=(Value::Multi(input),Value::Multi(target));
174
175 (input.zip().stack(0),target.zip().stack(0))
176 }
177}
178impl<B:Backend,W:AI<X,LossOutput<B>>+Wrappable<B=B>,X> ValidStep<X,ClassificationOutput<B>> for Wrapped<Classification<W>> where W::Decomposition:Module<B>{
179 fn step(&self,item:X)->ClassificationOutput<B>{self.forward(item)}
180}
181impl<B:Backend,W:AI<X,LossOutput<B>>+Wrappable<B=B>,X> ValidStep<X,RegressionOutput<B>> for Wrapped<Regression<W>> where W::Decomposition:Module<B>{
182 fn step(&self,item:X)->RegressionOutput<B>{self.forward(item)}
183}
184impl<B:Backend,W:Wrappable<B=B>> Module<B> for Wrapped<W> where W::Decomposition:Module<B>{
185 fn collect_devices(&self,devices:Vec<<B as Backend>::Device>)->Vec<<B as Backend>::Device>{self.inner.decompose_cloned().collect_devices(devices)}
186 fn devices(&self)->Vec<<B as Backend>::Device>{self.inner.decompose_cloned().devices()}
187 fn fork(self,device:&<B as Backend>::Device)->Self{Self::new(W::compose(self.inner.decompose().fork(device)))}
188 fn into_record(self)->Self::Record{self.inner.decompose().into_record()}
189 fn load_file<F:FileRecorder<B>,P:Into<PathBuf>>(self,filepath:P,recorder:&F,device:&<B as Backend>::Device)->Result<Self,RecorderError>{self.inner.decompose().load_file(filepath,recorder,device).map(|a|Self::new(W::compose(a)))}
190 fn load_record(self,record:Self::Record)->Self{Self::new(W::compose(self.inner.decompose().load_record(record)))}
191 fn map<Mapper:ModuleMapper<B>>(self,mapper:&mut Mapper)->Self{Self::new(W::compose(self.inner.decompose().map(mapper)))}
192 fn num_params(&self)->usize{self.inner.decompose_cloned().num_params()}
193 fn quantize_weights(self,quantizer:&mut Quantizer)->Self{Self::new(W::compose(self.inner.decompose().quantize_weights(quantizer)))}
194 fn save_file<F:FileRecorder<B>,P:Into<PathBuf>>(self,filepath:P,recorder:&F)->Result<(),RecorderError>{self.inner.decompose().save_file(filepath,recorder)}
195 fn to_device(self,device:&<B as Backend>::Device)->Self{Self::new(W::compose(self.inner.decompose().to_device(device)))}
196 fn visit<Visitor:ModuleVisitor<B>>(&self,visitor:&mut Visitor){self.inner.decompose_cloned().visit(visitor)}
197 type Record=<W::Decomposition as Module<B>>::Record;
198}
199impl<B:Backend,X:Into<Y>,Y> AI<X,Y> for Identity<B>{
200 fn forward(&self,input:X)->Y{input.into()}
201}
202impl<B:Backend> AI<LossOutput<B>,ClassificationOutput<B>> for ClassificationLayer{
203 fn forward(&self,lossoutput:LossOutput<B>)->ClassificationOutput<B>{let loss=match lossoutput.loss(){Value::F1(x)=>x,Value::F2(x)=>x.mean(),Value::F3(x)=>x.mean(),Value::F4(x)=>x.mean(),Value::F5(x)=>x.mean(),Value::F6(x)=>x.mean(),Value::F7(x)=>x.mean(),Value::F8(x)=>x.mean(),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to classification output")};
205 let output=match lossoutput.output(){Value::F1(x)=>x.unsqueeze(),Value::F2(x)=>x,Value::F3(x)=>x.flatten(0,1),Value::F4(x)=>x.flatten(0,2),Value::F5(x)=>x.flatten(0,3),Value::F6(x)=>x.flatten(0,4),Value::F7(x)=>x.flatten(0,5),Value::F8(x)=>x.flatten(0,6),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to classification output")};
206 let target=match lossoutput.target(){Value::I1(x)=>x,Value::I2(x)=>x.flatten(0,1),Value::I3(x)=>x.flatten(0,2),Value::I4(x)=>x.flatten(0,3),Value::I5(x)=>x.flatten(0,4),Value::I6(x)=>x.flatten(0,5),Value::I7(x)=>x.flatten(0,6),Value::I8(x)=>x.flatten(0,7),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to classification output")};
207 ClassificationOutput::new(loss,output,target)
208 }
209}
210impl<B:Backend> AI<LossOutput<B>,RegressionOutput<B>> for RegressionLayer{
211 fn forward(&self,lossoutput:LossOutput<B>)->RegressionOutput<B>{
212 let loss=match lossoutput.loss(){Value::F1(x)=>x,Value::F2(x)=>x.mean(),Value::F3(x)=>x.mean(),Value::F4(x)=>x.mean(),Value::F5(x)=>x.mean(),Value::F6(x)=>x.mean(),Value::F7(x)=>x.mean(),Value::F8(x)=>x.mean(),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to regression output")};
213 let output=match lossoutput.output(){Value::F1(x)=>x.unsqueeze(),Value::F2(x)=>x,Value::F3(x)=>x.flatten(0,1),Value::F4(x)=>x.flatten(0,2),Value::F5(x)=>x.flatten(0,3),Value::F6(x)=>x.flatten(0,4),Value::F7(x)=>x.flatten(0,5),Value::F8(x)=>x.flatten(0,6),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to regression output")};
214 let target=match lossoutput.target(){Value::F1(x)=>x.unsqueeze(),Value::F2(x)=>x,Value::F3(x)=>x.flatten(0,1),Value::F4(x)=>x.flatten(0,2),Value::F5(x)=>x.flatten(0,3),Value::F6(x)=>x.flatten(0,4),Value::F7(x)=>x.flatten(0,5),Value::F8(x)=>x.flatten(0,6),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to regression output")};
215 RegressionOutput::new(loss,output,target)
216 }
217}
218impl<B:Backend> Decompose for Identity<B>{
219 fn compose(_decomposition:Self::Decomposition)->Self{new()}
220 fn decompose(self){}
221 fn decompose_cloned(&self){}
222 type Decomposition=();
223}
224impl<B:Backend> Op for Identity<B>{
225 type Output=();
226}
227impl<B:Backend> Wrappable for Identity<B>{
228 type B=B;
229 type With<C:Backend>=Identity<C>;
230}
231impl<B:Backend> Wrappable for Layer<B>{
232 type B=B;
233 type With<C:Backend>=Layer<C>;
234}
235impl<B:Backend> Wrappable for LossOutput<B>{
236 type B=B;
237 type With<C:Backend>=LossOutput<C>;
238}
239impl<B:Backend> Wrappable for Value<B>{
240 type B=B;
241 type With<C:Backend>=Value<C>;
242}
243impl<C:Backend,W:ToBackend<C,OnBackend=W::With<C>>+Wrappable> ToBackend<C> for Wrapped<W>{
244 fn to_backend_device(self,device:&C::Device)->Self::OnBackend{
245 Wrapped{inner:self.inner.to_backend_device(device)}
246 }
247 type OnBackend=Wrapped<W::With<C>>;
248}
249impl<T:?Sized+Op> Shortcuts for T{}
250impl<W:AI<X,Y>+Wrappable,X,Y> AI<X,Y> for Wrapped<W>{
251 fn forward(&self,input:X)->Y{self.inner.forward(input)}
252 fn forward_mut(&mut self,input:X)->Y{self.inner.forward_mut(input)}
253}
254impl<W:Op+Wrappable> Op for Wrapped<W>{
255 type Output=W::Output;
256}
257impl<W:UnwrapInner+Wrappable> UnwrapInner for Wrapped<W>{
258 fn unwrap_inner(self)->Self::Inner{self.into_inner().unwrap_inner()}
259 type Inner=W::Inner;
260}
261impl<W:Wrappable> Decompose for Wrapped<W>{
262 fn compose(decomposition:Self::Decomposition)->Self{Self::new(W::compose(decomposition))}
263 fn decompose(self)->Self::Decomposition{self.inner.decompose()}
264 fn decompose_cloned(&self)->Self::Decomposition{self.inner.decompose_cloned()}
265 type Decomposition=W::Decomposition;
266}
267impl<W:Wrappable> Display for Wrapped<W>{
268 fn fmt(&self,f:&mut std::fmt::Formatter<'_>)->Result<(),std::fmt::Error>{write!(f,"todo")}
269}
270impl<W:Wrappable> From<W> for Wrapped<W>{
271 fn from(value:W)->Self{Self::new(value)}
272}
273impl<W:Wrappable> ModuleDisplay for Wrapped<W> where W::Decomposition:ModuleDisplay{
274 fn custom_content(&self,content:Content)->Option<Content>{self.inner.decompose_cloned().custom_content(content)}
275 fn custom_settings(&self)->Option<DisplaySettings>{self.inner.decompose_cloned().custom_settings()}
276 fn format(&self,passed_settings:DisplaySettings)->String{self.inner.decompose_cloned().format(passed_settings)}
277}
278impl<W:Wrappable> ModuleDisplayDefault for Wrapped<W> where W::Decomposition:ModuleDisplayDefault{
279 fn content(&self,content:Content)->Option<Content>{self.inner.decompose_cloned().content(content)}
280 fn num_params(&self)->usize{self.inner.decompose_cloned().num_params()}
281}
282impl<W:Wrappable> Wrappable for Abs<W>{
283 type B=W::B;
284 type With<C:Backend>=Abs<W::With<C>>;
285}
286impl<W:Wrappable> Wrappable for AccQ<W>{
287 type B=W::B;
288 type With<C:Backend>=AccQ<W::With<C>>;
289}
290impl<W:Wrappable> Wrappable for Cat<W>{
291 type B=W::B;
292 type With<C:Backend>=Cat<W::With<C>>;
293}
294impl<W:Wrappable> Wrappable for Classification<W>{
295 type B=W::B;
296 type With<C:Backend>=Classification<W::With<C>>;
297}
298impl<W:Wrappable> Wrappable for CrossEntropy<W>{
299 type B=W::B;
300 type With<C:Backend>=CrossEntropy<W::With<C>>;
301}
302impl<W:Wrappable> Wrappable for Duplicate<W>{
303 type B=W::B;
304 type With<C:Backend>=Duplicate<W::With<C>>;
305}
306impl<W:Wrappable> Wrappable for Graph<W>{
307 type B=W::B;
308 type With<C:Backend>=Graph<W::With<C>>;
309}
310impl<W:Wrappable> Wrappable for Inner<W>{
311 type B=W::B;
312 type With<C:Backend>=Inner<W::With<C>>;
313}
314impl<W:Wrappable> Wrappable for Mean<W>{
315 type B=W::B;
316 type With<C:Backend>=Mean<W::With<C>>;
317}
318impl<W:Wrappable> Wrappable for SquaredError<W>{
319 type B=W::B;
320 type With<C:Backend>=SquaredError<W::With<C>>;
321}
322impl<W:Wrappable> Wrappable for Map<W>{
323 type B=W::B;
324 type With<C:Backend>=Map<W::With<C>>;
325}
326impl<W:Wrappable> Wrappable for Regression<W>{
327 type B=W::B;
328 type With<C:Backend>=Regression<W::With<C>>;
329}
330impl<W:Wrappable> Wrappable for Sequential<W>{
331 type B=W::B;
332 type With<C:Backend>=Sequential<W::With<C>>;
333}
334impl<W:Wrappable> Wrappable for Choose<W>{
335 type B=W::B;
336 type With<C:Backend>=Choose<W::With<C>>;
337}
338impl<W:Wrappable> Wrappable for Unvec<W>{
339 type B=W::B;
340 type With<C:Backend>=Unvec<W::With<C>>;
341}
342impl<W:Wrappable> Wrappable for Zip<W>{
343 type B=W::B;
344 type With<C:Backend>=Zip<W::With<C>>;
345}
346impl<W:Wrappable> Wrapped<W>{
347 pub fn inner(&self)->&W{&self.inner}
349 pub fn inner_mut(&mut self)->&mut W{&mut self.inner}
351 pub fn into_inner(self)->W{self.inner}
353 pub fn new(inner:W)->Self{
355 Self{inner}
356 }
357}
#[cfg(test)]
mod tests{
	use burn::{
		backend::Autodiff,data::dataset::InMemDataset,optim::SgdConfig
	};
	use super::*;

	/// End-to-end check that a small linear/relu graph trained through `Wrapped::train` learns XOR
	/// to within 0.1 on all four inputs.
	#[test]
	fn learn_xor(){
		type B=NdArray;
		type A=Autodiff<B>;
		let i0=Tensor::<A,1>::from_data(TensorData::new([0.0,0.0].to_vec(),[2]),&Default::default());
		let i1=Tensor::<A,1>::from_data(TensorData::new([0.0,1.0].to_vec(),[2]),&Default::default());
		let i2=Tensor::<A,1>::from_data(TensorData::new([1.0,0.0].to_vec(),[2]),&Default::default());
		let i3=Tensor::<A,1>::from_data(TensorData::new([1.0,1.0].to_vec(),[2]),&Default::default());
		let o0=Tensor::<A,1>::from_data(TensorData::new([0.0].to_vec(),[1]),&Default::default());
		let o1=Tensor::<A,1>::from_data(TensorData::new([1.0].to_vec(),[1]),&Default::default());
		let o2=Tensor::<A,1>::from_data(TensorData::new([1.0].to_vec(),[1]),&Default::default());
		let o3=Tensor::<A,1>::from_data(TensorData::new([0.0].to_vec(),[1]),&Default::default());

		// The four XOR cases repeated to 4000 samples; validation uses the same data on the inner backend.
		let dataset:Vec<(Tensor<A,1>,Tensor<A,1>)>=[(i0,o0),(i1,o1),(i2,o2),(i3,o3)].into_iter().cycle().take(4000).collect();
		let train=InMemDataset::new(dataset.clone().into_iter().map(|(i,o)|(Value::from(i),Value::from(o))).collect());
		let valid=InMemDataset::new(dataset.into_iter().map(|(i,o)|(Value::from(i.valid()),Value::from(o.valid()))).collect());
		let mut graph:Graph<Layer<A>>=Graph::new();
		graph.connect("input","x").with_clear(true).with(Layer::linear(true,2,10,1.0));
		graph.connect("x","y").with_clear(true).with(Layer::relu());
		graph.connect("y","output").with_clear(true).with(Layer::linear(false,10,1,1.0));

		let graph=Unvec(graph).wrap_inner().squared_error().set_type::<(Value<A>,Value<A>),LossOutput<A>>().regression().wrap();
		let graph=graph.train(&TrainConfig::new().with_checkpoints(false),SgdConfig::new().init(),0.01,train,valid);
		let graph=graph.valid().unwrap_inner();

		let inputval=Value::from(Tensor::<B,2>::from_data(TensorData::new([0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0].to_vec(),[4,2]),&Default::default()));
		let outputval=graph.forward(inputval);
		match outputval{
			Value::F2(o)=>{
				let target=Tensor::<B,2>::from_data(TensorData::new([0.0,1.0,1.0,0.0].to_vec(),[4,1]),&Default::default());
				let error=(target-o.clone()).abs().max().into_scalar();
				println!("{}",o);
				assert!(error<0.1,"max absolute error {error} is not below 0.1; output was {o}");
			}
			other=>panic!("expected a rank 2 float output from the trained graph, got {other:?}")
		}
	}
}
401mod layer;
402mod shape;
403mod value;
404pub fn apply_depthwise<B:Backend,F:FnMut(Value<B>)->Value<B>>(depth:usize,mut op:F,value:Value<B>)->Value<B>{
406 fn inner<B:Backend,F:FnMut(Value<B>)->Value<B>>(depth:usize,op:&mut F,value:Value<B>)->(Value<B>,usize){
407 let mut height=1;
408 let value=if value.is_multi(){
409 let value:Value<B>=value.into_iter().map(|v|{
410 let (v,h)=inner(depth,op,v);
411 height=height.max(h);
412 v
413 }).collect();
414 if value.len()==0{height=0}else{height+=1}
415 value
416 }else{
417 value
418 };
419 (if depth==height{op(value)}else{value},height)
420 }
421 inner(depth,&mut op,value).0
422}
423pub fn new<B:Backend>()->Identity<B>{
425 Identity{phantom:PhantomData}
426}
/// Batcher that converts each item into an `(input, target)` pair of values and stacks all inputs
/// and all targets into one batch value each.
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
pub struct BatchStacker;
/// Wraps a model whose forward pass yields a `LossOutput`, adding a head that converts it into
/// burn's `ClassificationOutput` so the wrapped model can be trained with burn's learner.
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
pub struct Classification<A>{
	/// The wrapped model.
	inner:A,
	/// Parameterless head applied to the wrapped model's `LossOutput`.
	layer:ClassificationLayer
}
/// Parameterless head that converts a `LossOutput` into a `ClassificationOutput`.
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
pub struct ClassificationLayer{
	/// Zero-sized private field. NOTE(review): presumably keeps the type from being built outside this module.
	seal:PhantomData<()>
}
/// Metrics renderer that discards all metric updates and progress reports. `Wrapped::train` installs
/// it when console rendering is disabled.
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
pub struct DontRender;
/// Operation that returns its input converted with `Into`; `B` only tags the backend.
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
pub struct Identity<B:Backend>{
	/// Marker for the backend type parameter.
	phantom:PhantomData<B>
}
/// Wraps a model whose forward pass yields a `LossOutput`, adding a head that converts it into
/// burn's `RegressionOutput` so the wrapped model can be trained with burn's learner.
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
pub struct Regression<A>{
	/// The wrapped model.
	inner:A,
	/// Parameterless head applied to the wrapped model's `LossOutput`.
	layer:RegressionLayer
}
/// Parameterless head that converts a `LossOutput` into a `RegressionOutput`.
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
pub struct RegressionLayer{
	/// Zero-sized private field. NOTE(review): presumably keeps the type from being built outside this module.
	seal:PhantomData<()>
}
/// Settings for `Wrapped::train`.
#[derive(Config,Debug)]
pub struct TrainConfig{
	/// Directory passed to burn's `LearnerBuilder`; it is created, with any missing parents, before training.
	#[config(default="String::from(\".artifact\")")]
	artifact_directory:String,
	/// Items per batch, used for both the training and the validation loader.
	#[config(default="16")]
	batch_size:usize,
	/// Whether to save checkpoints with a `CompactRecorder`.
	#[config(default="true")]
	checkpoints:bool,
	/// Whether to keep the learner's default metrics renderer; when false, `DontRender` replaces it.
	#[config(default="false")]
	console_rendering:bool,
	/// Number of training epochs.
	#[config(default="10")]
	epochs:usize,
	/// Whether to enable the learner's training summary.
	#[config(default="false")]
	summary:bool,
	/// Number of workers for each data loader.
	#[config(default="4")]
	workers:usize
}
/// Adapter that exposes a `Wrappable` value to burn's module and training traits by delegating to
/// the value's `Decompose::Decomposition`.
#[derive(Clone,Copy,Debug,Default)]
pub struct Wrapped<W:Wrappable>{
	/// The wrapped value.
	inner:W
}
469pub trait Shortcuts{
471 fn classification(self)->Classification<Self> where Classification<Self>:Op,Self:Sized{Classification::from_inner(self)}
473 fn regression(self)->Regression<Self> where Regression<Self>:Op,Self:Sized{Regression::from_inner(self)}
475 fn wrap(self)->Wrapped<Self> where Self:Wrappable{Wrapped::new(self)}
477}
478pub trait ToBackend<B:Backend>:Sized{
480 fn to_backend_device(self,device:&B::Device)->Self::OnBackend;
482 fn to_backend(self)->Self::OnBackend{self.to_backend_device(&Default::default())}
484 type OnBackend;
486}
/// A value whose backend can be switched at the type level.
///
/// `B` is the backend the value is currently built on. `With<C>` is the same structure on backend
/// `C`, and its bounds require three things:
/// * `With<C>` reports `C` as its backend;
/// * switching it to `C` again yields the same type;
/// * switching it back to `Self::B` yields `Self`.
pub trait Wrappable:Clone+Debug+Decompose+Send{
	/// Backend this value is built on.
	type B:Backend;
	/// This type rebuilt on backend `C`.
	type With<C:Backend>:Wrappable<B=C,With<C>=Self::With<C>>+Wrappable<B=C,With<Self::B>=Self>;
}
492pub use burn as lib;
493pub use layer::{Attention,AttentionConfig,AttentionMask,BiasConfig,Cache,Config,Layer,KQV,KQVConfig,PowerMaskInfo};
494pub use shape::{Kind,Reshape,Shape};
495pub use value::{LossOutput,Value};
496use burn::{
497 backend::NdArray,
498 data::{
499 dataset::Dataset,dataloader::{batcher::Batcher,DataLoaderBuilder}
500 },
501 lr_scheduler::LrScheduler,
502 module::{AutodiffModule,Content,DisplaySettings,ModuleDisplay,ModuleDisplayDefault,ModuleMapper,ModuleVisitor,Quantizer},
503 optim::Optimizer,
504 prelude::*,
505 record::{CompactRecorder,FileRecorder,RecorderError},
506 tensor::{BasicOps,TensorKind,backend::AutodiffBackend},
507 train::{
508 ClassificationOutput,LearnerBuilder,RegressionOutput,TrainOutput,TrainStep,ValidStep,metric::{Adaptor,ItemLazy,LossInput,LossMetric},renderer::{MetricState,MetricsRenderer,TrainingProgress}
509 }
510};
511use crate::{
512 AI,Decompose,Graph,Inner,Op,UnwrapInner,Unvec,
513 builtin::{
514 Duplicate,Map,Sequential,SetType,Zip,math::{Abs,Mean,SquaredError},reinforcement::AccQ,soft::{Choose,CrossEntropy},structural::Cat
515 },
516 ops::Stack as OpsStack
517};
518use rand::random;
519use serde::{Deserialize,Serialize};
520use std::{
521 fmt::{Debug,Display},fs::{create_dir_all as create_folder},marker::PhantomData,path::PathBuf
522};