1impl Decompose for ClassificationLayer{
2 fn compose(_decomposition:Self::Decomposition)->Self{Self::default()}
3 fn decompose(self){}
4 fn decompose_cloned(&self){}
5 type Decomposition=();
6}
7impl Decompose for RegressionLayer{
8 fn compose(_decomposition:Self::Decomposition)->Self{Self::default()}
9 fn decompose(self){}
10 fn decompose_cloned(&self){}
11 type Decomposition=();
12}
13impl MetricsRenderer for DontRender{
14 fn manual_close(&mut self){}
15}
16impl MetricsRendererEvaluation for DontRender{
17 fn render_test(&mut self,_item:EvaluationProgress){}
18 fn update_test(&mut self,_name:EvaluationName,_state:MetricState){}
19}
20impl MetricsRendererTraining for DontRender{
21 fn render_train(&mut self,_item:TrainingProgress){}
22 fn render_valid(&mut self,_item:TrainingProgress){}
23 fn update_train(&mut self,_state:MetricState){}
24 fn update_valid(&mut self,_state:MetricState){}
25}
26impl Op for ClassificationLayer{
27 type Output=ClassificationOutput<NdArray>;
28}
29impl Op for RegressionLayer{
30 type Output=RegressionOutput<NdArray>;
31}
32impl<A:AI<X,LossOutput<B>>,B:Backend,X> AI<X,ClassificationOutput<B>> for Classification<A>{
33 fn forward(&self,input:X)->ClassificationOutput<B>{self.layer.forward(self.inner.forward(input))}
34 fn forward_mut(&mut self,input:X)->ClassificationOutput<B>{self.layer.forward(self.inner.forward_mut(input))}
35}
36impl<A:AI<X,LossOutput<B>>,B:Backend,X> AI<X,RegressionOutput<B>> for Regression<A>{
37 fn forward(&self,input:X)->RegressionOutput<B>{self.layer.forward(self.inner.forward(input))}
38 fn forward_mut(&mut self,input:X)->RegressionOutput<B>{self.layer.forward(self.inner.forward_mut(input))}
39}
40impl<A:AutodiffBackend<InnerBackend=B>,B:Backend,W:'static+Wrappable<B=A>,Y:'static+ItemLazy+Send+Sync,Z:'static+ItemLazy+Send+Sync> Wrapped<W> where <Self as AutodiffModule<A>>::InnerModule:ValidStep<(Value<B>,Value<B>),Z>,Self:TrainStep<(Value<A>,Value<A>),Y>,W::Decomposition:AutodiffModule<A>,W::With<B>:Decompose<Decomposition=<W::Decomposition as AutodiffModule<A>>::InnerModule>+Op<Output=Z>,W:Op<Output=Y>,Y::ItemSync:Adaptor<LossInput<NdArray>>,Z::ItemSync:Adaptor<LossInput<NdArray>>{
41 pub fn train<I:'static+Clone+Debug+Into<(Value<A>,Value<A>)>+Send+Sync,J:'static+Clone+Debug+Into<(Value<B>,Value<B>)>+Send+Sync,O:'static+Optimizer<Self,A>,S:'static+LrScheduler,T:'static+Dataset<I>,V:'static+Dataset<J>>(self,config:&TrainConfig,optimizer:O,scheduler:S,train:T,valid:V)->Wrapped<W::With<B>> where O::Record:'static,S::Record<A>:'static{
43 let batcher=BatchStacker;
44 let trainloader=DataLoaderBuilder::new(batcher).batch_size(config.batch_size).shuffle(random()).num_workers(config.workers).build(train);
45 let validloader=DataLoaderBuilder::new(batcher).batch_size(config.batch_size).shuffle(random()).num_workers(config.workers).build(valid);
46
47 create_folder(&config.artifact_directory).unwrap();
48 let builder=LearnerBuilder::new(&config.artifact_directory).metric_train_numeric(LossMetric::new()).metric_valid_numeric(LossMetric::new());
49 let builder=if config.checkpoints{builder.with_file_checkpointer(CompactRecorder::new())}else{builder};
50 let builder=if config.console_rendering{builder}else{builder.renderer(DontRender)};
51 let builder=builder.learning_strategy(LearningStrategy::SingleDevice(<W::B as Backend>::Device::default())).num_epochs(config.epochs);
52 let builder=if config.summary{builder.summary()}else{builder};
53 let learner=builder.build(self,optimizer,scheduler);
54 learner.fit(trainloader,validloader).model
55 }
56}
57impl<A:AutodiffBackend,W:AI<X,LossOutput<A>>+Wrappable<B=A>,X> TrainStep<X,ClassificationOutput<A>> for Wrapped<Classification<W>> where W::Decomposition:AutodiffModule<A>,W::With<A::InnerBackend>:Decompose<Decomposition=<W::Decomposition as AutodiffModule<A>>::InnerModule>{
58 fn step(&self,item:X)->TrainOutput<ClassificationOutput<A>>{
59 let output:ClassificationOutput<A>=self.forward(item);
60 TrainOutput::new(self,output.loss.backward(),output)
61 }
62}
63impl<A:AutodiffBackend,W:AI<X,LossOutput<A>>+Wrappable<B=A>,X> TrainStep<X,RegressionOutput<A>> for Wrapped<Regression<W>> where W::Decomposition:AutodiffModule<A>,W::With<A::InnerBackend>:Decompose<Decomposition=<W::Decomposition as AutodiffModule<A>>::InnerModule>{
64 fn step(&self,item:X)->TrainOutput<RegressionOutput<A>>{
65 let output:RegressionOutput<A>=self.forward(item);
66 TrainOutput::new(self,output.loss.backward(),output)
67 }
68}
69impl<A:AutodiffBackend,W:Wrappable<B=A>> AutodiffModule<A> for Wrapped<W> where W::Decomposition:AutodiffModule<A>,W::With<A::InnerBackend>:Decompose<Decomposition=<W::Decomposition as AutodiffModule<A>>::InnerModule>{
70 fn valid(&self)->Self::InnerModule{Wrapped::new(Decompose::compose(self.inner.decompose_cloned().valid()))}
71 type InnerModule=Wrapped<W::With<A::InnerBackend>>;
72}
73impl<A:Decompose> Decompose for Classification<A>{
74 fn compose(decomposition:Self::Decomposition)->Self{
75 Self{inner:A::compose(decomposition),layer:Default::default()}
76 }
77 fn decompose(self)->Self::Decomposition{self.inner.decompose()}
78 fn decompose_cloned(&self)->Self::Decomposition{self.inner.decompose_cloned()}
79 type Decomposition=A::Decomposition;
80}
81impl<A:Decompose> Decompose for Regression<A>{
82 fn compose(decomposition:Self::Decomposition)->Self{
83 Self{inner:A::compose(decomposition),layer:Default::default()}
84 }
85 fn decompose(self)->Self::Decomposition{self.inner.decompose()}
86 fn decompose_cloned(&self)->Self::Decomposition{self.inner.decompose_cloned()}
87 type Decomposition=A::Decomposition;
88}
89impl<A:Op<Output=Y>+Wrappable,Y> Op for Classification<A> where ClassificationLayer:AI<Y,ClassificationOutput<A::B>>{
90 type Output=ClassificationOutput<A::B>;
91}
92impl<A:Op<Output=Y>+Wrappable,Y> Op for Regression<A> where RegressionLayer:AI<Y,RegressionOutput<A::B>>{
93 type Output=RegressionOutput<A::B>;
94}
95impl<A:UnwrapInner> UnwrapInner for Classification<A>{
96 fn unwrap_inner(self)->Self::Inner{self.into_inner().unwrap_inner()}
97 type Inner=A::Inner;
98}
99impl<A:UnwrapInner> UnwrapInner for Regression<A>{
100 fn unwrap_inner(self)->Self::Inner{self.into_inner().unwrap_inner()}
101 type Inner=A::Inner;
102}
103impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>> Wrappable for (A,D){
104 type B=B;
105 type With<C:Backend>=(A::With<C>,D::With<C>);
106}
107impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>> Wrappable for (A,D,E){
108 type B=B;
109 type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>);
110}
111impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>> Wrappable for (A,D,E,F){
112 type B=B;
113 type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>);
114}
115impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>,G:Wrappable<B=B>> Wrappable for (A,D,E,F,G){
116 type B=B;
117 type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>,G::With<C>);
118}
119impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>,G:Wrappable<B=B>,H:Wrappable<B=B>> Wrappable for (A,D,E,F,G,H){
120 type B=B;
121 type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>,G::With<C>,H::With<C>);
122}
123impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>,G:Wrappable<B=B>,H:Wrappable<B=B>,I:Wrappable<B=B>> Wrappable for (A,D,E,F,G,H,I){
124 type B=B;
125 type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>,G::With<C>,H::With<C>,I::With<C>);
126}
127impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>,G:Wrappable<B=B>,H:Wrappable<B=B>,I:Wrappable<B=B>,J:Wrappable<B=B>> Wrappable for (A,D,E,F,G,H,I,J){
128 type B=B;
129 type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>,G::With<C>,H::With<C>,I::With<C>,J::With<C>);
130}
131impl<A:Wrappable<B=B>,B:Backend,X:Wrappable<B=B>,Y:Wrappable<B=B>> Wrappable for SetType<A,X,Y>{
132 type B=B;
133 type With<C:Backend>=SetType<A::With<C>,X::With<C>,Y::With<C>>;
134}
135impl<A> Classification<A>{
136 pub fn from_inner(inner:A)->Self where Classification<A>:Op{
138 Self{inner,layer:Default::default()}
139 }
140 pub fn inner(&self)->&A{&self.inner}
142 pub fn inner_mut(&mut self)->&mut A{&mut self.inner}
144 pub fn into_inner(self)->A{self.inner}
146 pub fn with_inner<B>(&self,inner:B)->Classification<B> where Classification<B>:Op{Classification::from_inner(inner)}
148}
149impl<A> Regression<A>{
150 pub fn from_inner(inner:A)->Self where Regression<A>:Op{
152 Self{inner,layer:Default::default()}
153 }
154 pub fn inner(&self)->&A{&self.inner}
156 pub fn inner_mut(&mut self)->&mut A{&mut self.inner}
158 pub fn into_inner(self)->A{self.inner}
160 pub fn with_inner<B>(&self,inner:B)->Regression<B> where Regression<B>:Op{Regression::from_inner(inner)}
162}
163impl<B:Backend,C:Backend,K:BasicOps<B>+BasicOps<C>+TensorKind<B>+TensorKind<C>,const N:usize> ToBackend<C> for Tensor<B,N,K>{
164 fn to_backend_device(self,device:&C::Device)->Self::OnBackend{
165 let data=self.to_data();
166 Tensor::from_data(data,device)
167 }
168 type OnBackend=Tensor<C,N,K>;
169}
170impl<B:Backend,C:Backend> ToBackend<C> for Value<B>{
171 fn to_backend_device(self,device:&C::Device)->Self::OnBackend{
172 match self{Value::B1(x)=>x.to_backend_device(device).into(),Value::B2(x)=>x.to_backend_device(device).into(),Value::B3(x)=>x.to_backend_device(device).into(),Value::B4(x)=>x.to_backend_device(device).into(),Value::B5(x)=>x.to_backend_device(device).into(),Value::B6(x)=>x.to_backend_device(device).into(),Value::B7(x)=>x.to_backend_device(device).into(),Value::B8(x)=>x.to_backend_device(device).into(),Value::F1(x)=>x.to_backend_device(device).into(),Value::F2(x)=>x.to_backend_device(device).into(),Value::F3(x)=>x.to_backend_device(device).into(),Value::F4(x)=>x.to_backend_device(device).into(),Value::F5(x)=>x.to_backend_device(device).into(),Value::F6(x)=>x.to_backend_device(device).into(),Value::F7(x)=>x.to_backend_device(device).into(),Value::F8(x)=>x.to_backend_device(device).into(),Value::I1(x)=>x.to_backend_device(device).into(),Value::I2(x)=>x.to_backend_device(device).into(),Value::I3(x)=>x.to_backend_device(device).into(),Value::I4(x)=>x.to_backend_device(device).into(),Value::I5(x)=>x.to_backend_device(device).into(),Value::I6(x)=>x.to_backend_device(device).into(),Value::I7(x)=>x.to_backend_device(device).into(),Value::I8(x)=>x.to_backend_device(device).into(),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.to_backend_device(device)).collect()}
173 }
174 type OnBackend=Value<C>;
175}
176impl<B:Backend,E:Into<(Value<B>,Value<B>)>> Batcher<B,E,(Value<B>,Value<B>)> for BatchStacker{
177 fn batch(&self,items:Vec<E>,_device:&<B as Backend>::Device)->(Value<B>,Value<B>){
178 let items=items.into_iter().map(Into::into);
179 let (input,target):(Vec<Value<B>>,Vec<Value<B>>)=items.unzip();
180 let (input,target)=(Value::Multi(input),Value::Multi(target));
181
182 (input.zip().stack(0),target.zip().stack(0))
183 }
184}
185impl<B:Backend,W:AI<X,LossOutput<B>>+Wrappable<B=B>,X> ValidStep<X,ClassificationOutput<B>> for Wrapped<Classification<W>> where W::Decomposition:Module<B>{
186 fn step(&self,item:X)->ClassificationOutput<B>{self.forward(item)}
187}
188impl<B:Backend,W:AI<X,LossOutput<B>>+Wrappable<B=B>,X> ValidStep<X,RegressionOutput<B>> for Wrapped<Regression<W>> where W::Decomposition:Module<B>{
189 fn step(&self,item:X)->RegressionOutput<B>{self.forward(item)}
190}
191impl<B:Backend,W:Wrappable<B=B>> Module<B> for Wrapped<W> where W::Decomposition:Module<B>{
192 fn collect_devices(&self,devices:Vec<<B as Backend>::Device>)->Vec<<B as Backend>::Device>{self.inner.decompose_cloned().collect_devices(devices)}
193 fn devices(&self)->Vec<<B as Backend>::Device>{self.inner.decompose_cloned().devices()}
194 fn fork(self,device:&<B as Backend>::Device)->Self{Self::new(W::compose(self.inner.decompose().fork(device)))}
195 fn into_record(self)->Self::Record{self.inner.decompose().into_record()}
196 fn load_file<F:FileRecorder<B>,P:Into<PathBuf>>(self,filepath:P,recorder:&F,device:&<B as Backend>::Device)->Result<Self,RecorderError>{self.inner.decompose().load_file(filepath,recorder,device).map(|a|Self::new(W::compose(a)))}
197 fn load_record(self,record:Self::Record)->Self{Self::new(W::compose(self.inner.decompose().load_record(record)))}
198 fn map<Mapper:ModuleMapper<B>>(self,mapper:&mut Mapper)->Self{Self::new(W::compose(self.inner.decompose().map(mapper)))}
199 fn num_params(&self)->usize{self.inner.decompose_cloned().num_params()}
200 fn quantize_weights(self,quantizer:&mut Quantizer)->Self{Self::new(W::compose(self.inner.decompose().quantize_weights(quantizer)))}
201 fn save_file<F:FileRecorder<B>,P:Into<PathBuf>>(self,filepath:P,recorder:&F)->Result<(),RecorderError>{self.inner.decompose().save_file(filepath,recorder)}
202 fn to_device(self,device:&<B as Backend>::Device)->Self{Self::new(W::compose(self.inner.decompose().to_device(device)))}
203 fn visit<Visitor:ModuleVisitor<B>>(&self,visitor:&mut Visitor){self.inner.decompose_cloned().visit(visitor)}
204 type Record=<W::Decomposition as Module<B>>::Record;
205}
206impl<B:Backend,X:Into<Y>,Y> AI<X,Y> for Identity<B>{
207 fn forward(&self,input:X)->Y{input.into()}
208}
209impl<B:Backend> AI<LossOutput<B>,ClassificationOutput<B>> for ClassificationLayer{
210 fn forward(&self,lossoutput:LossOutput<B>)->ClassificationOutput<B>{let loss=match lossoutput.loss(){Value::F1(x)=>x,Value::F2(x)=>x.mean(),Value::F3(x)=>x.mean(),Value::F4(x)=>x.mean(),Value::F5(x)=>x.mean(),Value::F6(x)=>x.mean(),Value::F7(x)=>x.mean(),Value::F8(x)=>x.mean(),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to classification output")};
212 let output=match lossoutput.output(){Value::F1(x)=>x.unsqueeze(),Value::F2(x)=>x,Value::F3(x)=>x.flatten(0,1),Value::F4(x)=>x.flatten(0,2),Value::F5(x)=>x.flatten(0,3),Value::F6(x)=>x.flatten(0,4),Value::F7(x)=>x.flatten(0,5),Value::F8(x)=>x.flatten(0,6),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to classification output")};
213 let target=match lossoutput.target(){Value::I1(x)=>x,Value::I2(x)=>x.flatten(0,1),Value::I3(x)=>x.flatten(0,2),Value::I4(x)=>x.flatten(0,3),Value::I5(x)=>x.flatten(0,4),Value::I6(x)=>x.flatten(0,5),Value::I7(x)=>x.flatten(0,6),Value::I8(x)=>x.flatten(0,7),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to classification output")};
214 ClassificationOutput::new(loss,output,target)
215 }
216}
217impl<B:Backend> AI<LossOutput<B>,RegressionOutput<B>> for RegressionLayer{
218 fn forward(&self,lossoutput:LossOutput<B>)->RegressionOutput<B>{
219 let loss=match lossoutput.loss(){Value::F1(x)=>x,Value::F2(x)=>x.mean(),Value::F3(x)=>x.mean(),Value::F4(x)=>x.mean(),Value::F5(x)=>x.mean(),Value::F6(x)=>x.mean(),Value::F7(x)=>x.mean(),Value::F8(x)=>x.mean(),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to regression output")};
220 let output=match lossoutput.output(){Value::F1(x)=>x.unsqueeze(),Value::F2(x)=>x,Value::F3(x)=>x.flatten(0,1),Value::F4(x)=>x.flatten(0,2),Value::F5(x)=>x.flatten(0,3),Value::F6(x)=>x.flatten(0,4),Value::F7(x)=>x.flatten(0,5),Value::F8(x)=>x.flatten(0,6),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to regression output")};
221 let target=match lossoutput.target(){Value::F1(x)=>x.unsqueeze(),Value::F2(x)=>x,Value::F3(x)=>x.flatten(0,1),Value::F4(x)=>x.flatten(0,2),Value::F5(x)=>x.flatten(0,3),Value::F6(x)=>x.flatten(0,4),Value::F7(x)=>x.flatten(0,5),Value::F8(x)=>x.flatten(0,6),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to regression output")};
222 RegressionOutput::new(loss,output,target)
223 }
224}
225impl<B:Backend> Decompose for Identity<B>{
226 fn compose(_decomposition:Self::Decomposition)->Self{new()}
227 fn decompose(self){}
228 fn decompose_cloned(&self){}
229 type Decomposition=();
230}
231impl<B:Backend> Op for Identity<B>{
232 type Output=();
233}
234impl<B:Backend> Wrappable for Identity<B>{
235 type B=B;
236 type With<C:Backend>=Identity<C>;
237}
238impl<B:Backend> Wrappable for Layer<B>{
239 type B=B;
240 type With<C:Backend>=Layer<C>;
241}
242impl<B:Backend> Wrappable for LossOutput<B>{
243 type B=B;
244 type With<C:Backend>=LossOutput<C>;
245}
246impl<B:Backend> Wrappable for Value<B>{
247 type B=B;
248 type With<C:Backend>=Value<C>;
249}
250impl<C:Backend,W:ToBackend<C,OnBackend=W::With<C>>+Wrappable> ToBackend<C> for Wrapped<W>{
251 fn to_backend_device(self,device:&C::Device)->Self::OnBackend{
252 Wrapped{inner:self.inner.to_backend_device(device)}
253 }
254 type OnBackend=Wrapped<W::With<C>>;
255}
256impl<T:?Sized+Op> Shortcuts for T{}
257impl<W:AI<X,Y>+Wrappable,X,Y> AI<X,Y> for Wrapped<W>{
258 fn forward(&self,input:X)->Y{self.inner.forward(input)}
259 fn forward_mut(&mut self,input:X)->Y{self.inner.forward_mut(input)}
260}
261impl<W:Op+Wrappable> Op for Wrapped<W>{
262 type Output=W::Output;
263}
264impl<W:UnwrapInner+Wrappable> UnwrapInner for Wrapped<W>{
265 fn unwrap_inner(self)->Self::Inner{self.into_inner().unwrap_inner()}
266 type Inner=W::Inner;
267}
268impl<W:Wrappable> Decompose for Wrapped<W>{
269 fn compose(decomposition:Self::Decomposition)->Self{Self::new(W::compose(decomposition))}
270 fn decompose(self)->Self::Decomposition{self.inner.decompose()}
271 fn decompose_cloned(&self)->Self::Decomposition{self.inner.decompose_cloned()}
272 type Decomposition=W::Decomposition;
273}
274impl<W:Wrappable> Display for Wrapped<W>{
275 fn fmt(&self,f:&mut std::fmt::Formatter<'_>)->Result<(),std::fmt::Error>{write!(f,"todo")}
276}
277impl<W:Wrappable> From<W> for Wrapped<W>{
278 fn from(value:W)->Self{Self::new(value)}
279}
280impl<W:Wrappable> ModuleDisplay for Wrapped<W> where W::Decomposition:ModuleDisplay{
281 fn custom_content(&self,content:Content)->Option<Content>{self.inner.decompose_cloned().custom_content(content)}
282 fn custom_settings(&self)->Option<DisplaySettings>{self.inner.decompose_cloned().custom_settings()}
283 fn format(&self,passed_settings:DisplaySettings)->String{self.inner.decompose_cloned().format(passed_settings)}
284}
285impl<W:Wrappable> ModuleDisplayDefault for Wrapped<W> where W::Decomposition:ModuleDisplayDefault{
286 fn content(&self,content:Content)->Option<Content>{self.inner.decompose_cloned().content(content)}
287 fn num_params(&self)->usize{self.inner.decompose_cloned().num_params()}
288}
289impl<W:Wrappable> Wrappable for Abs<W>{
290 type B=W::B;
291 type With<C:Backend>=Abs<W::With<C>>;
292}
293impl<W:Wrappable> Wrappable for AccQ<W>{
294 type B=W::B;
295 type With<C:Backend>=AccQ<W::With<C>>;
296}
297impl<W:Wrappable> Wrappable for Cat<W>{
298 type B=W::B;
299 type With<C:Backend>=Cat<W::With<C>>;
300}
301impl<W:Wrappable> Wrappable for Classification<W>{
302 type B=W::B;
303 type With<C:Backend>=Classification<W::With<C>>;
304}
305impl<W:Wrappable> Wrappable for CrossEntropy<W>{
306 type B=W::B;
307 type With<C:Backend>=CrossEntropy<W::With<C>>;
308}
309impl<W:Wrappable> Wrappable for Duplicate<W>{
310 type B=W::B;
311 type With<C:Backend>=Duplicate<W::With<C>>;
312}
313impl<W:Wrappable> Wrappable for Graph<W>{
314 type B=W::B;
315 type With<C:Backend>=Graph<W::With<C>>;
316}
317impl<W:Wrappable> Wrappable for Inner<W>{
318 type B=W::B;
319 type With<C:Backend>=Inner<W::With<C>>;
320}
321impl<W:Wrappable> Wrappable for Mean<W>{
322 type B=W::B;
323 type With<C:Backend>=Mean<W::With<C>>;
324}
325impl<W:Wrappable> Wrappable for SquaredError<W>{
326 type B=W::B;
327 type With<C:Backend>=SquaredError<W::With<C>>;
328}
329impl<W:Wrappable> Wrappable for Map<W>{
330 type B=W::B;
331 type With<C:Backend>=Map<W::With<C>>;
332}
333impl<W:Wrappable> Wrappable for Regression<W>{
334 type B=W::B;
335 type With<C:Backend>=Regression<W::With<C>>;
336}
337impl<W:Wrappable> Wrappable for Sequential<W>{
338 type B=W::B;
339 type With<C:Backend>=Sequential<W::With<C>>;
340}
341impl<W:Wrappable> Wrappable for Choose<W>{
342 type B=W::B;
343 type With<C:Backend>=Choose<W::With<C>>;
344}
345impl<W:Wrappable> Wrappable for Unvec<W>{
346 type B=W::B;
347 type With<C:Backend>=Unvec<W::With<C>>;
348}
349impl<W:Wrappable> Wrappable for Zip<W>{
350 type B=W::B;
351 type With<C:Backend>=Zip<W::With<C>>;
352}
353impl<W:Wrappable> Wrapped<W>{
354 pub fn inner(&self)->&W{&self.inner}
356 pub fn inner_mut(&mut self)->&mut W{&mut self.inner}
358 pub fn into_inner(self)->W{self.inner}
360 pub fn new(inner:W)->Self{
362 Self{inner}
363 }
364}
#[cfg(test)]
mod tests{
	/// End-to-end check: a 2-10-1 ReLU network trained as a regression model should fit XOR.
	#[test]
	fn learn_xor(){
		type B=NdArray;
		type A=Autodiff<B>;
		let i0=Tensor::<A,1>::from_data(TensorData::new([0.0,0.0].to_vec(),[2]),&Default::default());
		let i1=Tensor::<A,1>::from_data(TensorData::new([0.0,1.0].to_vec(),[2]),&Default::default());
		let i2=Tensor::<A,1>::from_data(TensorData::new([1.0,0.0].to_vec(),[2]),&Default::default());
		let i3=Tensor::<A,1>::from_data(TensorData::new([1.0,1.0].to_vec(),[2]),&Default::default());
		let o0=Tensor::<A,1>::from_data(TensorData::new([0.0].to_vec(),[1]),&Default::default());
		let o1=Tensor::<A,1>::from_data(TensorData::new([1.0].to_vec(),[1]),&Default::default());
		let o2=Tensor::<A,1>::from_data(TensorData::new([1.0].to_vec(),[1]),&Default::default());
		let o3=Tensor::<A,1>::from_data(TensorData::new([0.0].to_vec(),[1]),&Default::default());

		let dataset:Vec<(Tensor<A,1>,Tensor<A,1>)>=[(i0,o0),(i1,o1),(i2,o2),(i3,o3)].into_iter().cycle().take(4000).collect();
		let train=InMemDataset::new(dataset.clone().into_iter().map(|(i,o)|(Value::from(i),Value::from(o))).collect());
		let valid=InMemDataset::new(dataset.into_iter().map(|(i,o)|(Value::from(i.valid()),Value::from(o.valid()))).collect());
		let mut graph:Graph<Layer<A>>=Graph::new();
		graph.connect("input","x").with_clear(true).with(Layer::linear(true,2,10,1.0));
		graph.connect("x","y").with_clear(true).with(Layer::relu());
		graph.connect("y","output").with_clear(true).with(Layer::linear(false,10,1,1.0));

		let graph=Unvec(graph).wrap_inner().squared_error().set_type::<(Value<A>,Value<A>),LossOutput<A>>().regression().wrap();
		let graph=graph.train(&TrainConfig::new().with_checkpoints(false),SgdConfig::new().init(),0.01,train,valid);
		let graph=graph.unwrap_inner();

		let inputval=Value::from(Tensor::<B,2>::from_data(TensorData::new([0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0].to_vec(),[4,2]),&Default::default()));
		let outputval=graph.forward(inputval);
		let Value::F2(o)=outputval else{panic!("expected the trained graph to return a rank 2 float value")};
		let target=Tensor::<B,2>::from_data(TensorData::new([0.0,1.0,1.0,0.0].to_vec(),[4,1]),&Default::default());
		let error=(target-o.clone()).abs().max().into_scalar();
		assert!(error<0.1,"max absolute error {error} is not below 0.1; output: {o}");
	}
	use burn::{
		backend::Autodiff,data::dataset::InMemDataset,optim::SgdConfig
	};
	use super::*;
}
408mod layer;
409mod shape;
410mod value;
411pub fn apply_depthwise<B:Backend,F:FnMut(Value<B>)->Value<B>>(depth:usize,mut op:F,value:Value<B>)->Value<B>{
413 fn inner<B:Backend,F:FnMut(Value<B>)->Value<B>>(depth:usize,op:&mut F,value:Value<B>)->(Value<B>,usize){
414 let mut height=1;
415 let value=if value.is_multi(){
416 let value:Value<B>=value.into_iter().map(|v|{
417 let (v,h)=inner(depth,op,v);
418 height=height.max(h);
419 v
420 }).collect();
421 if value.len()==0{height=0}else{height+=1}
422 value
423 }else{
424 value
425 };
426 (if depth==height{op(value)}else{value},height)
427 }
428 inner(depth,&mut op,value).0
429}
430pub fn new<B:Backend>()->Identity<B>{
432 Identity{phantom:PhantomData}
433}
434#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
435pub struct BatchStacker;
437#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
438pub struct Classification<A>{inner:A,layer:ClassificationLayer}
440#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
441pub struct ClassificationLayer{seal:PhantomData<()>}
443#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
444pub struct DontRender;
446#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
447pub struct Identity<B:Backend>{phantom:PhantomData<B>}
449#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
450pub struct Regression<A>{inner:A,layer:RegressionLayer}
452#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
453pub struct RegressionLayer{seal:PhantomData<()>}
/// Options consumed by `Wrapped::train`.
#[derive(Config,Debug)]
pub struct TrainConfig{
	/// Directory handed to `LearnerBuilder::new`; `train` creates it if it does not exist.
	#[config(default="String::from(\".artifact\")")]
	artifact_directory:String,
	/// Batch size used by both the training and the validation data loaders.
	#[config(default="16")]
	batch_size:usize,
	/// When true, a file checkpointer using `CompactRecorder` is attached to the learner.
	#[config(default="true")]
	checkpoints:bool,
	/// When false, progress output is discarded by the `DontRender` renderer.
	#[config(default="false")]
	console_rendering:bool,
	/// Number of training epochs.
	#[config(default="10")]
	epochs:usize,
	/// When true, the learner builder's `summary()` option is enabled.
	#[config(default="false")]
	summary:bool,
	/// Number of worker threads for each data loader.
	#[config(default="4")]
	workers:usize
}
/// Newtype around a `Wrappable` model that adapts it to burn's module and training traits
/// (`Module`, `AutodiffModule`, `TrainStep`, `ValidStep`). It does this by delegating to the model's decomposition.
#[derive(Clone,Copy,Debug,Default)]
pub struct Wrapped<W:Wrappable>{inner:W}
476pub trait Shortcuts{
478 fn classification(self)->Classification<Self> where Classification<Self>:Op,Self:Sized{Classification::from_inner(self)}
480 fn regression(self)->Regression<Self> where Regression<Self>:Op,Self:Sized{Regression::from_inner(self)}
482 fn wrap(self)->Wrapped<Self> where Self:Wrappable{Wrapped::new(self)}
484}
485pub trait ToBackend<B:Backend>:Sized{
487 fn to_backend_device(self,device:&B::Device)->Self::OnBackend;
489 fn to_backend(self)->Self::OnBackend{self.to_backend_device(&Default::default())}
491 type OnBackend;
493}
/// A model type that has a counterpart on every burn backend.
///
/// `B` is the backend this value uses, and `With<C>` is the corresponding type for backend `C`.
/// The bounds on `With` require the mapping to be consistent:
/// - `With<C>` is itself `Wrappable` with backend `C`;
/// - mapping it to `C` again yields the same type;
/// - mapping it back to `Self::B` yields `Self`.
pub trait Wrappable:Clone+Debug+Decompose+Send{
	/// Backend this value uses.
	type B:Backend;
	/// This type's counterpart on backend `C`.
	type With<C:Backend>:Wrappable<B=C,With<C>=Self::With<C>>+Wrappable<B=C,With<Self::B>=Self>;
}
499pub use burn as lib;
500pub use layer::{Attention,AttentionConfig,AttentionMask,BiasConfig,Cache,Config,Layer,KQV,KQVConfig,PowerMaskInfo};
501pub use shape::{Kind,Reshape,Shape};
502pub use value::{LossOutput,Value};
503use burn::{
504 backend::NdArray,
505 data::{
506 dataset::Dataset,dataloader::{batcher::Batcher,DataLoaderBuilder}
507 },
508 lr_scheduler::LrScheduler,
509 module::{AutodiffModule,Content,DisplaySettings,ModuleDisplay,ModuleDisplayDefault,ModuleMapper,ModuleVisitor,Quantizer},
510 optim::Optimizer,
511 prelude::*,
512 record::{CompactRecorder,FileRecorder,RecorderError},
513 tensor::{BasicOps,TensorKind,backend::AutodiffBackend},
514 train::{
515 ClassificationOutput,LearningStrategy,LearnerBuilder,RegressionOutput,TrainOutput,TrainStep,ValidStep,metric::{Adaptor,ItemLazy,LossInput,LossMetric},renderer::{EvaluationName,EvaluationProgress,MetricsRenderer,MetricState,MetricsRendererEvaluation,MetricsRendererTraining,TrainingProgress}
516 }
517};
518use crate::{
519 AI,Decompose,Graph,Inner,Op,UnwrapInner,Unvec,
520 builtin::{
521 Duplicate,Map,Sequential,SetType,Zip,math::{Abs,Mean,SquaredError},reinforcement::AccQ,soft::{Choose,CrossEntropy},structural::Cat
522 },
523 ops::Stack as OpsStack
524};
525use rand::random;
526use serde::{Deserialize,Serialize};
527use std::{
528 fmt::{Debug,Display},fs::{create_dir_all as create_folder},marker::PhantomData,path::PathBuf
529};