block_graph/burn/
mod.rs

/// `ClassificationLayer` holds no state, so it decomposes to the unit type.
impl Decompose for ClassificationLayer{
	fn compose(_decomposition:Self::Decomposition)->Self{Self::default()}
	fn decompose(self){}
	fn decompose_cloned(&self){}
	type Decomposition=();
}
/// `RegressionLayer` holds no state, so it decomposes to the unit type.
impl Decompose for RegressionLayer{
	fn compose(_decomposition:Self::Decomposition)->Self{Self::default()}
	fn decompose(self){}
	fn decompose_cloned(&self){}
	type Decomposition=();
}
/// silent renderer: closing is a no-op because nothing was ever displayed
impl MetricsRenderer for DontRender{
	fn manual_close(&mut self){}
}
/// evaluation progress and metric updates are discarded
impl MetricsRendererEvaluation for DontRender{
    fn render_test(&mut self,_item:EvaluationProgress){}
    fn update_test(&mut self,_name:EvaluationName,_state:MetricState){}
}
/// training/validation progress and metric updates are discarded
impl MetricsRendererTraining for DontRender{
    fn render_train(&mut self,_item:TrainingProgress){}
    fn render_valid(&mut self,_item:TrainingProgress){}
	fn update_train(&mut self,_state:MetricState){}
    fn update_valid(&mut self,_state:MetricState){}
}
/// the classification layer produces burn classification output on the `NdArray` backend
impl Op for ClassificationLayer{
	type Output=ClassificationOutput<NdArray>;
}
/// the regression layer produces burn regression output on the `NdArray` backend
impl Op for RegressionLayer{
	type Output=RegressionOutput<NdArray>;
}
/// forward the input through the wrapped model, then convert its loss output to classification output
impl<A:AI<X,LossOutput<B>>,B:Backend,X> AI<X,ClassificationOutput<B>> for Classification<A>{
	fn forward(&self,input:X)->ClassificationOutput<B>{self.layer.forward(self.inner.forward(input))}
	fn forward_mut(&mut self,input:X)->ClassificationOutput<B>{self.layer.forward(self.inner.forward_mut(input))}
}
/// forward the input through the wrapped model, then convert its loss output to regression output
impl<A:AI<X,LossOutput<B>>,B:Backend,X> AI<X,RegressionOutput<B>> for Regression<A>{
	fn forward(&self,input:X)->RegressionOutput<B>{self.layer.forward(self.inner.forward(input))}
	fn forward_mut(&mut self,input:X)->RegressionOutput<B>{self.layer.forward(self.inner.forward_mut(input))}
}
impl<A:AutodiffBackend<InnerBackend=B>,B:Backend,W:'static+Wrappable<B=A>,Y:'static+ItemLazy+Send+Sync,Z:'static+ItemLazy+Send+Sync> Wrapped<W> where <Self as AutodiffModule<A>>::InnerModule:ValidStep<(Value<B>,Value<B>),Z>,Self:TrainStep<(Value<A>,Value<A>),Y>,W::Decomposition:AutodiffModule<A>,W::With<B>:Decompose<Decomposition=<W::Decomposition as AutodiffModule<A>>::InnerModule>+Op<Output=Z>,W:Op<Output=Y>,Y::ItemSync:Adaptor<LossInput<NdArray>>,Z::ItemSync:Adaptor<LossInput<NdArray>>{
	/// trains the model on `train` (validating on `valid`) according to `config`,
	/// returning the fitted model on the inner (non-autodiff) backend `B`
	pub fn train<I:'static+Clone+Debug+Into<(Value<A>,Value<A>)>+Send+Sync,J:'static+Clone+Debug+Into<(Value<B>,Value<B>)>+Send+Sync,O:'static+Optimizer<Self,A>,S:'static+LrScheduler,T:'static+Dataset<I>,V:'static+Dataset<J>>(self,config:&TrainConfig,optimizer:O,scheduler:S,train:T,valid:V)->Wrapped<W::With<B>> where O::Record:'static,S::Record<A>:'static{
		// BatchStacker is Copy, so the same batcher value feeds both loaders
		let batcher=BatchStacker;
		let trainloader=DataLoaderBuilder::new(batcher).batch_size(config.batch_size).shuffle(random()).num_workers(config.workers).build(train);
		let validloader=DataLoaderBuilder::new(batcher).batch_size(config.batch_size).shuffle(random()).num_workers(config.workers).build(valid);

		// the artifact directory must exist before the learner writes checkpoints/logs into it
		create_folder(&config.artifact_directory).unwrap();
		let builder=LearnerBuilder::new(&config.artifact_directory).metric_train_numeric(LossMetric::new()).metric_valid_numeric(LossMetric::new());
		// optional features are toggled by rebinding the builder per config flag
		let builder=if config.checkpoints{builder.with_file_checkpointer(CompactRecorder::new())}else{builder};
		let builder=if config.console_rendering{builder}else{builder.renderer(DontRender)};
		let builder=builder.learning_strategy(LearningStrategy::SingleDevice(<W::B as Backend>::Device::default())).num_epochs(config.epochs);
		let builder=if config.summary{builder.summary()}else{builder};
		let learner=builder.build(self,optimizer,scheduler);
		learner.fit(trainloader,validloader).model
	}
}
/// a training step runs the forward pass and backpropagates through the classification loss
impl<A:AutodiffBackend,W:AI<X,LossOutput<A>>+Wrappable<B=A>,X> TrainStep<X,ClassificationOutput<A>> for Wrapped<Classification<W>> where W::Decomposition:AutodiffModule<A>,W::With<A::InnerBackend>:Decompose<Decomposition=<W::Decomposition as AutodiffModule<A>>::InnerModule>{
	fn step(&self,item:X)->TrainOutput<ClassificationOutput<A>>{
		let output:ClassificationOutput<A>=self.forward(item);
		TrainOutput::new(self,output.loss.backward(),output)
	}
}
/// a training step runs the forward pass and backpropagates through the regression loss
impl<A:AutodiffBackend,W:AI<X,LossOutput<A>>+Wrappable<B=A>,X> TrainStep<X,RegressionOutput<A>> for Wrapped<Regression<W>> where W::Decomposition:AutodiffModule<A>,W::With<A::InnerBackend>:Decompose<Decomposition=<W::Decomposition as AutodiffModule<A>>::InnerModule>{
	fn step(&self,item:X)->TrainOutput<RegressionOutput<A>>{
		let output:RegressionOutput<A>=self.forward(item);
		TrainOutput::new(self,output.loss.backward(),output)
	}
}
/// the valid (inference) module is obtained by decomposing, taking the burn-native
/// valid module of the decomposition, and recomposing on the inner backend
impl<A:AutodiffBackend,W:Wrappable<B=A>> AutodiffModule<A> for Wrapped<W> where W::Decomposition:AutodiffModule<A>,W::With<A::InnerBackend>:Decompose<Decomposition=<W::Decomposition as AutodiffModule<A>>::InnerModule>{
	fn valid(&self)->Self::InnerModule{Wrapped::new(Decompose::compose(self.inner.decompose_cloned().valid()))}
	type InnerModule=Wrapped<W::With<A::InnerBackend>>;
}
/// delegates decomposition to the inner value; the stateless layer is rebuilt from `Default`
impl<A:Decompose> Decompose for Classification<A>{
	fn compose(decomposition:Self::Decomposition)->Self{
		Self{inner:A::compose(decomposition),layer:Default::default()}
	}
	fn decompose(self)->Self::Decomposition{self.inner.decompose()}
	fn decompose_cloned(&self)->Self::Decomposition{self.inner.decompose_cloned()}
	type Decomposition=A::Decomposition;
}
/// delegates decomposition to the inner value; the stateless layer is rebuilt from `Default`
impl<A:Decompose> Decompose for Regression<A>{
	fn compose(decomposition:Self::Decomposition)->Self{
		Self{inner:A::compose(decomposition),layer:Default::default()}
	}
	fn decompose(self)->Self::Decomposition{self.inner.decompose()}
	fn decompose_cloned(&self)->Self::Decomposition{self.inner.decompose_cloned()}
	type Decomposition=A::Decomposition;
}
/// the wrapper's output is the classification output on the inner value's backend
impl<A:Op<Output=Y>+Wrappable,Y> Op for Classification<A> where ClassificationLayer:AI<Y,ClassificationOutput<A::B>>{
	type Output=ClassificationOutput<A::B>;
}
/// the wrapper's output is the regression output on the inner value's backend
impl<A:Op<Output=Y>+Wrappable,Y> Op for Regression<A> where RegressionLayer:AI<Y,RegressionOutput<A::B>>{
	type Output=RegressionOutput<A::B>;
}
/// unwrapping recurses through the inner value
impl<A:UnwrapInner> UnwrapInner for Classification<A>{
	fn unwrap_inner(self)->Self::Inner{self.into_inner().unwrap_inner()}
	type Inner=A::Inner;
}
/// unwrapping recurses through the inner value
impl<A:UnwrapInner> UnwrapInner for Regression<A>{
	fn unwrap_inner(self)->Self::Inner{self.into_inner().unwrap_inner()}
	type Inner=A::Inner;
}
// Wrappable for tuples (arities 2 through 8) and SetType: all elements must share
// the same backend, and the backend is rewritten elementwise by `With<C>`.
impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>> Wrappable for (A,D){
	type B=B;
	type With<C:Backend>=(A::With<C>,D::With<C>);
}
impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>> Wrappable for (A,D,E){
	type B=B;
	type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>);
}
impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>> Wrappable for (A,D,E,F){
	type B=B;
	type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>);
}
impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>,G:Wrappable<B=B>> Wrappable for (A,D,E,F,G){
	type B=B;
	type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>,G::With<C>);
}
impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>,G:Wrappable<B=B>,H:Wrappable<B=B>> Wrappable for (A,D,E,F,G,H){
	type B=B;
	type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>,G::With<C>,H::With<C>);
}
impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>,G:Wrappable<B=B>,H:Wrappable<B=B>,I:Wrappable<B=B>> Wrappable for (A,D,E,F,G,H,I){
	type B=B;
	type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>,G::With<C>,H::With<C>,I::With<C>);
}
impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>,G:Wrappable<B=B>,H:Wrappable<B=B>,I:Wrappable<B=B>,J:Wrappable<B=B>> Wrappable for (A,D,E,F,G,H,I,J){
	type B=B;
	type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>,G::With<C>,H::With<C>,I::With<C>,J::With<C>);
}
impl<A:Wrappable<B=B>,B:Backend,X:Wrappable<B=B>,Y:Wrappable<B=B>> Wrappable for SetType<A,X,Y>{
	type B=B;
	type With<C:Backend>=SetType<A::With<C>,X::With<C>,Y::With<C>>;
}
135impl<A> Classification<A>{
136	/// creates from the inner value
137	pub fn from_inner(inner:A)->Self where Classification<A>:Op{
138		Self{inner,layer:Default::default()}
139	}
140	/// references the inner value
141	pub fn inner(&self)->&A{&self.inner}
142	/// references the inner value
143	pub fn inner_mut(&mut self)->&mut A{&mut self.inner}
144	/// converts into the inner value
145	pub fn into_inner(self)->A{self.inner}
146	/// replaces the inner value
147	pub fn with_inner<B>(&self,inner:B)->Classification<B> where Classification<B>:Op{Classification::from_inner(inner)}
148}
149impl<A> Regression<A>{
150	/// creates from the inner value
151	pub fn from_inner(inner:A)->Self where Regression<A>:Op{
152		Self{inner,layer:Default::default()}
153	}
154	/// references the inner value
155	pub fn inner(&self)->&A{&self.inner}
156	/// references the inner value
157	pub fn inner_mut(&mut self)->&mut A{&mut self.inner}
158	/// converts into the inner value
159	pub fn into_inner(self)->A{self.inner}
160	/// replaces the inner value
161	pub fn with_inner<B>(&self,inner:B)->Regression<B> where Regression<B>:Op{Regression::from_inner(inner)}
162}
163impl<B:Backend,C:Backend,K:BasicOps<B>+BasicOps<C>+TensorKind<B>+TensorKind<C>,const N:usize> ToBackend<C> for Tensor<B,N,K>{
164	fn to_backend_device(self,device:&C::Device)->Self::OnBackend{
165		let data=self.to_data();
166		Tensor::from_data(data,device)
167	}
168	type OnBackend=Tensor<C,N,K>;
169}
/// moves a dynamically typed value to another backend by converting each tensor
/// variant individually; multi values recurse elementwise and incompatible values carry over
impl<B:Backend,C:Backend> ToBackend<C> for Value<B>{
	fn to_backend_device(self,device:&C::Device)->Self::OnBackend{
		match self{Value::B1(x)=>x.to_backend_device(device).into(),Value::B2(x)=>x.to_backend_device(device).into(),Value::B3(x)=>x.to_backend_device(device).into(),Value::B4(x)=>x.to_backend_device(device).into(),Value::B5(x)=>x.to_backend_device(device).into(),Value::B6(x)=>x.to_backend_device(device).into(),Value::B7(x)=>x.to_backend_device(device).into(),Value::B8(x)=>x.to_backend_device(device).into(),Value::F1(x)=>x.to_backend_device(device).into(),Value::F2(x)=>x.to_backend_device(device).into(),Value::F3(x)=>x.to_backend_device(device).into(),Value::F4(x)=>x.to_backend_device(device).into(),Value::F5(x)=>x.to_backend_device(device).into(),Value::F6(x)=>x.to_backend_device(device).into(),Value::F7(x)=>x.to_backend_device(device).into(),Value::F8(x)=>x.to_backend_device(device).into(),Value::I1(x)=>x.to_backend_device(device).into(),Value::I2(x)=>x.to_backend_device(device).into(),Value::I3(x)=>x.to_backend_device(device).into(),Value::I4(x)=>x.to_backend_device(device).into(),Value::I5(x)=>x.to_backend_device(device).into(),Value::I6(x)=>x.to_backend_device(device).into(),Value::I7(x)=>x.to_backend_device(device).into(),Value::I8(x)=>x.to_backend_device(device).into(),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.to_backend_device(device)).collect()}
	}
	type OnBackend=Value<C>;
}
176impl<B:Backend,E:Into<(Value<B>,Value<B>)>> Batcher<B,E,(Value<B>,Value<B>)> for BatchStacker{
177	fn batch(&self,items:Vec<E>,_device:&<B as Backend>::Device)->(Value<B>,Value<B>){
178		let items=items.into_iter().map(Into::into);
179		let (input,target):(Vec<Value<B>>,Vec<Value<B>>)=items.unzip();
180		let (input,target)=(Value::Multi(input),Value::Multi(target));
181
182		(input.zip().stack(0),target.zip().stack(0))
183	}
184}
/// a validation step is just the forward pass (no gradients on the inner backend)
impl<B:Backend,W:AI<X,LossOutput<B>>+Wrappable<B=B>,X> ValidStep<X,ClassificationOutput<B>> for Wrapped<Classification<W>> where W::Decomposition:Module<B>{
	fn step(&self,item:X)->ClassificationOutput<B>{self.forward(item)}
}
/// a validation step is just the forward pass (no gradients on the inner backend)
impl<B:Backend,W:AI<X,LossOutput<B>>+Wrappable<B=B>,X> ValidStep<X,RegressionOutput<B>> for Wrapped<Regression<W>> where W::Decomposition:Module<B>{
	fn step(&self,item:X)->RegressionOutput<B>{self.forward(item)}
}
/// implements burn's `Module` by delegating every operation to the inner value's
/// decomposition: by-value methods decompose, apply, and recompose; by-reference
/// methods operate on a cloned decomposition
impl<B:Backend,W:Wrappable<B=B>> Module<B> for Wrapped<W> where W::Decomposition:Module<B>{
	fn collect_devices(&self,devices:Vec<<B as Backend>::Device>)->Vec<<B as Backend>::Device>{self.inner.decompose_cloned().collect_devices(devices)}
	fn devices(&self)->Vec<<B as Backend>::Device>{self.inner.decompose_cloned().devices()}
	fn fork(self,device:&<B as Backend>::Device)->Self{Self::new(W::compose(self.inner.decompose().fork(device)))}
	fn into_record(self)->Self::Record{self.inner.decompose().into_record()}
	fn load_file<F:FileRecorder<B>,P:Into<PathBuf>>(self,filepath:P,recorder:&F,device:&<B as Backend>::Device)->Result<Self,RecorderError>{self.inner.decompose().load_file(filepath,recorder,device).map(|a|Self::new(W::compose(a)))}
	fn load_record(self,record:Self::Record)->Self{Self::new(W::compose(self.inner.decompose().load_record(record)))}
	fn map<Mapper:ModuleMapper<B>>(self,mapper:&mut Mapper)->Self{Self::new(W::compose(self.inner.decompose().map(mapper)))}
	fn num_params(&self)->usize{self.inner.decompose_cloned().num_params()}
	fn quantize_weights(self,quantizer:&mut Quantizer)->Self{Self::new(W::compose(self.inner.decompose().quantize_weights(quantizer)))}
	fn save_file<F:FileRecorder<B>,P:Into<PathBuf>>(self,filepath:P,recorder:&F)->Result<(),RecorderError>{self.inner.decompose().save_file(filepath,recorder)}
	fn to_device(self,device:&<B as Backend>::Device)->Self{Self::new(W::compose(self.inner.decompose().to_device(device)))}
	fn visit<Visitor:ModuleVisitor<B>>(&self,visitor:&mut Visitor){self.inner.decompose_cloned().visit(visitor)}
	type Record=<W::Decomposition as Module<B>>::Record;
}
/// the identity operation just converts its input into the requested output type
impl<B:Backend,X:Into<Y>,Y> AI<X,Y> for Identity<B>{
	fn forward(&self,input:X)->Y{input.into()}
}
209impl<B:Backend> AI<LossOutput<B>,ClassificationOutput<B>> for ClassificationLayer{
210	fn forward(&self,lossoutput:LossOutput<B>)->ClassificationOutput<B>{//TODO make work for multi
211		let loss=match lossoutput.loss(){Value::F1(x)=>x,Value::F2(x)=>x.mean(),Value::F3(x)=>x.mean(),Value::F4(x)=>x.mean(),Value::F5(x)=>x.mean(),Value::F6(x)=>x.mean(),Value::F7(x)=>x.mean(),Value::F8(x)=>x.mean(),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to classification output")};
212		let output=match lossoutput.output(){Value::F1(x)=>x.unsqueeze(),Value::F2(x)=>x,Value::F3(x)=>x.flatten(0,1),Value::F4(x)=>x.flatten(0,2),Value::F5(x)=>x.flatten(0,3),Value::F6(x)=>x.flatten(0,4),Value::F7(x)=>x.flatten(0,5),Value::F8(x)=>x.flatten(0,6),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to classification output")};
213		let target=match lossoutput.target(){Value::I1(x)=>x,Value::I2(x)=>x.flatten(0,1),Value::I3(x)=>x.flatten(0,2),Value::I4(x)=>x.flatten(0,3),Value::I5(x)=>x.flatten(0,4),Value::I6(x)=>x.flatten(0,5),Value::I7(x)=>x.flatten(0,6),Value::I8(x)=>x.flatten(0,7),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to classification output")};
214		ClassificationOutput::new(loss,output,target)
215	}
216}
/// converts a generic loss output into burn's `RegressionOutput`:
/// the loss is reduced to a scalar; output and float targets are flattened to 2d
impl<B:Backend> AI<LossOutput<B>,RegressionOutput<B>> for RegressionLayer{
	fn forward(&self,lossoutput:LossOutput<B>)->RegressionOutput<B>{
		// scalar loss: float losses of rank >1 are mean reduced
		let loss=match lossoutput.loss(){Value::F1(x)=>x,Value::F2(x)=>x.mean(),Value::F3(x)=>x.mean(),Value::F4(x)=>x.mean(),Value::F5(x)=>x.mean(),Value::F6(x)=>x.mean(),Value::F7(x)=>x.mean(),Value::F8(x)=>x.mean(),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to regression output")};
		// output as 2d: rank 1 gains a leading batch axis, higher ranks flatten the leading axes
		let output=match lossoutput.output(){Value::F1(x)=>x.unsqueeze(),Value::F2(x)=>x,Value::F3(x)=>x.flatten(0,1),Value::F4(x)=>x.flatten(0,2),Value::F5(x)=>x.flatten(0,3),Value::F6(x)=>x.flatten(0,4),Value::F7(x)=>x.flatten(0,5),Value::F8(x)=>x.flatten(0,6),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to regression output")};
		// targets get the same 2d shaping as the output
		let target=match lossoutput.target(){Value::F1(x)=>x.unsqueeze(),Value::F2(x)=>x,Value::F3(x)=>x.flatten(0,1),Value::F4(x)=>x.flatten(0,2),Value::F5(x)=>x.flatten(0,3),Value::F6(x)=>x.flatten(0,4),Value::F7(x)=>x.flatten(0,5),Value::F8(x)=>x.flatten(0,6),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non floats to regression output")};
		RegressionOutput::new(loss,output,target)
	}
}
/// `Identity` is stateless: composing rebuilds it via the module-level `new()` helper
impl<B:Backend> Decompose for Identity<B>{
	fn compose(_decomposition:Self::Decomposition)->Self{new()}
	fn decompose(self){}
	fn decompose_cloned(&self){}
	type Decomposition=();
}
/// identity produces no output type of its own
impl<B:Backend> Op for Identity<B>{
	type Output=();
}
/// identity is trivially rewrappable onto any backend
impl<B:Backend> Wrappable for Identity<B>{
	type B=B;
	type With<C:Backend>=Identity<C>;
}
// Wrappable for the leaf types: the backend parameter is rewritten directly.
impl<B:Backend> Wrappable for Layer<B>{
	type B=B;
	type With<C:Backend>=Layer<C>;
}
impl<B:Backend> Wrappable for LossOutput<B>{
	type B=B;
	type With<C:Backend>=LossOutput<C>;
}
impl<B:Backend> Wrappable for Value<B>{
	type B=B;
	type With<C:Backend>=Value<C>;
}
/// moving a wrapped value to another backend moves the inner value and rewraps it
impl<C:Backend,W:ToBackend<C,OnBackend=W::With<C>>+Wrappable> ToBackend<C> for Wrapped<W>{
	fn to_backend_device(self,device:&C::Device)->Self::OnBackend{
		Wrapped{inner:self.inner.to_backend_device(device)}
	}
	type OnBackend=Wrapped<W::With<C>>;
}
/// every `Op` automatically gets the chained-method shortcuts
impl<T:?Sized+Op> Shortcuts for T{}
/// forwarding delegates straight to the inner value
impl<W:AI<X,Y>+Wrappable,X,Y> AI<X,Y> for Wrapped<W>{
	fn forward(&self,input:X)->Y{self.inner.forward(input)}
	fn forward_mut(&mut self,input:X)->Y{self.inner.forward_mut(input)}
}
/// the wrapped output type is the inner output type
impl<W:Op+Wrappable> Op for Wrapped<W>{
	type Output=W::Output;
}
/// unwrapping recurses through the inner value
impl<W:UnwrapInner+Wrappable> UnwrapInner for Wrapped<W>{
	fn unwrap_inner(self)->Self::Inner{self.into_inner().unwrap_inner()}
	type Inner=W::Inner;
}
/// decomposition passes straight through the wrapper
impl<W:Wrappable> Decompose for Wrapped<W>{
	fn compose(decomposition:Self::Decomposition)->Self{Self::new(W::compose(decomposition))}
	fn decompose(self)->Self::Decomposition{self.inner.decompose()}
	fn decompose_cloned(&self)->Self::Decomposition{self.inner.decompose_cloned()}
	type Decomposition=W::Decomposition;
}
/// NOTE(review): placeholder implementation — always writes the literal "todo"
/// instead of a real rendering of the wrapped value
impl<W:Wrappable> Display for Wrapped<W>{
    fn fmt(&self,f:&mut std::fmt::Formatter<'_>)->Result<(),std::fmt::Error>{write!(f,"todo")}
}
/// any wrappable value converts into its wrapper
impl<W:Wrappable> From<W> for Wrapped<W>{
	fn from(value:W)->Self{Self::new(value)}
}
/// display settings/content are taken from a cloned decomposition of the inner value
impl<W:Wrappable> ModuleDisplay for Wrapped<W> where W::Decomposition:ModuleDisplay{
	fn custom_content(&self,content:Content)->Option<Content>{self.inner.decompose_cloned().custom_content(content)}
	fn custom_settings(&self)->Option<DisplaySettings>{self.inner.decompose_cloned().custom_settings()}
	fn format(&self,passed_settings:DisplaySettings)->String{self.inner.decompose_cloned().format(passed_settings)}
}
/// default display content is taken from a cloned decomposition of the inner value
impl<W:Wrappable> ModuleDisplayDefault for Wrapped<W> where W::Decomposition:ModuleDisplayDefault{
	fn content(&self,content:Content)->Option<Content>{self.inner.decompose_cloned().content(content)}
	fn num_params(&self)->usize{self.inner.decompose_cloned().num_params()}
}
// Wrappable for every combinator wrapper: the backend is the inner value's backend,
// and `With<C>` rewrites the backend structurally through the inner value.
impl<W:Wrappable> Wrappable for Abs<W>{
	type B=W::B;
	type With<C:Backend>=Abs<W::With<C>>;
}
impl<W:Wrappable> Wrappable for AccQ<W>{
	type B=W::B;
	type With<C:Backend>=AccQ<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Cat<W>{
	type B=W::B;
	type With<C:Backend>=Cat<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Classification<W>{
	type B=W::B;
	type With<C:Backend>=Classification<W::With<C>>;
}
impl<W:Wrappable> Wrappable for CrossEntropy<W>{
	type B=W::B;
	type With<C:Backend>=CrossEntropy<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Duplicate<W>{
	type B=W::B;
	type With<C:Backend>=Duplicate<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Graph<W>{
	type B=W::B;
	type With<C:Backend>=Graph<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Inner<W>{
	type B=W::B;
	type With<C:Backend>=Inner<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Mean<W>{
	type B=W::B;
	type With<C:Backend>=Mean<W::With<C>>;
}
impl<W:Wrappable> Wrappable for SquaredError<W>{
	type B=W::B;
	type With<C:Backend>=SquaredError<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Map<W>{
	type B=W::B;
	type With<C:Backend>=Map<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Regression<W>{
	type B=W::B;
	type With<C:Backend>=Regression<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Sequential<W>{
	type B=W::B;
	type With<C:Backend>=Sequential<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Choose<W>{
	type B=W::B;
	type With<C:Backend>=Choose<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Unvec<W>{
	type B=W::B;
	type With<C:Backend>=Unvec<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Zip<W>{
	type B=W::B;
	type With<C:Backend>=Zip<W::With<C>>;
}
353impl<W:Wrappable> Wrapped<W>{
354	/// references the inner value
355	pub fn inner(&self)->&W{&self.inner}
356	/// references the inner value
357	pub fn inner_mut(&mut self)->&mut W{&mut self.inner}
358	/// unwraps the inner value
359	pub fn into_inner(self)->W{self.inner}
360	/// creates a new wrapped value
361	pub fn new(inner:W)->Self{
362		Self{inner}
363	}
364}
#[cfg(test)]
mod tests{
	/// end-to-end check: a small linear-relu-linear graph wrapped for regression
	/// should learn XOR to within 0.1 absolute error
	#[test]
	fn learn_xor(){
		type B=NdArray;
		type A=Autodiff<B>;
		// the four XOR input rows and their targets
		let i0=Tensor::<A,1>::from_data(TensorData::new([0.0,0.0].to_vec(),[2]),&Default::default());
		let i1=Tensor::<A,1>::from_data(TensorData::new([0.0,1.0].to_vec(),[2]),&Default::default());
		let i2=Tensor::<A,1>::from_data(TensorData::new([1.0,0.0].to_vec(),[2]),&Default::default());
		let i3=Tensor::<A,1>::from_data(TensorData::new([1.0,1.0].to_vec(),[2]),&Default::default());
		let o0=Tensor::<A,1>::from_data(TensorData::new([0.0].to_vec(),[1]),&Default::default());
		let o1=Tensor::<A,1>::from_data(TensorData::new([1.0].to_vec(),[1]),&Default::default());
		let o2=Tensor::<A,1>::from_data(TensorData::new([1.0].to_vec(),[1]),&Default::default());
		let o3=Tensor::<A,1>::from_data(TensorData::new([0.0].to_vec(),[1]),&Default::default());

		// cycle the 4 examples to 4000 samples; validation uses the non-autodiff backend
		let dataset:Vec<(Tensor<A,1>,Tensor<A,1>)>=[(i0,o0),(i1,o1),(i2,o2),(i3,o3)].into_iter().cycle().take(4000).collect();
		let train=InMemDataset::new(dataset.clone().into_iter().map(|(i,o)|(Value::from(i),Value::from(o))).collect());
		let valid=InMemDataset::new(dataset.into_iter().map(|(i,o)|(Value::from(i.valid()),Value::from(o.valid()))).collect());
		// input -> linear(2,10) -> relu -> linear(10,1) -> output
		let mut graph:Graph<Layer<A>>=Graph::new();
		graph.connect("input","x").with_clear(true).with(Layer::linear(true,2,10,1.0));
		graph.connect("x","y").with_clear(true).with(Layer::relu());
		graph.connect("y","output").with_clear(true).with(Layer::linear(false,10,1,1.0));

		// wrap as a regression model and train with SGD at lr 0.01
		let graph=Unvec(graph).wrap_inner().squared_error().set_type::<(Value<A>,Value<A>),LossOutput<A>>().regression().wrap();
		let graph=graph.train(&TrainConfig::new().with_checkpoints(false),SgdConfig::new().init(),0.01,train,valid);
		let graph=graph.unwrap_inner();

		// evaluate all four XOR rows in one batch on the plain backend
		let inputval=Value::from(Tensor::<B,2>::from_data(TensorData::new([0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0].to_vec(),[4,2]),&Default::default()));
		let outputval=graph.forward(inputval);
		if let Value::F2(o)=outputval{
			let target=Tensor::<B,2>::from_data(TensorData::new([0.0,1.0,1.0,0.0].to_vec(),[4,1]),&Default::default());
			let error=(target-o.clone()).abs().max();
			println!("{}",o);
			// worst-case prediction must be within 0.1 of the target
			assert!(error.into_scalar()<0.1);
		}else{
			panic!("h");
		}
	}
	use burn::{
		backend::Autodiff,data::dataset::InMemDataset,optim::SgdConfig
	};
	use super::*;
}
408mod layer;
409mod shape;
410mod value;
/// helper function for applying operations that apply to a specific depth of multiple structure such that wrapping multiple appropriate inputs with a multi outputs the output of the function applied to all inputs. 0 depth for empty, 1 for single, 2+ for multi
pub fn apply_depthwise<B:Backend,F:FnMut(Value<B>)->Value<B>>(depth:usize,mut op:F,value:Value<B>)->Value<B>{
	// recursive worker: returns the (possibly transformed) value together with its
	// height (0 = empty multi, 1 = leaf, otherwise 1 + max child height);
	// `op` is applied to every node whose height equals `depth`
	fn inner<B:Backend,F:FnMut(Value<B>)->Value<B>>(depth:usize,op:&mut F,value:Value<B>)->(Value<B>,usize){
		let mut height=1;
		let value=if value.is_multi(){
			// recurse into children first, tracking the tallest child
			let value:Value<B>=value.into_iter().map(|v|{
				let (v,h)=inner(depth,op,v);
				height=height.max(h);
				v
			}).collect();
			// empty multi counts as height 0; otherwise this node is one above its children
			if value.len()==0{height=0}else{height+=1}
			value
		}else{
			value
		};
		(if depth==height{op(value)}else{value},height)
	}
	inner(depth,&mut op,value).0
}
430/// starts the building of an ai structure in chained method style from an identity operation
431pub fn new<B:Backend>()->Identity<B>{
432	Identity{phantom:PhantomData}
433}
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
/// batcher that stacks things
pub struct BatchStacker;
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
/// wrapper for converting loss to classification output
// inner: the wrapped model; layer: stateless converter applied to its loss output
pub struct Classification<A>{inner:A,layer:ClassificationLayer}
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
/// layer for converting loss to classification output
// seal: private marker field so the layer is built through Default
pub struct ClassificationLayer{seal:PhantomData<()>}
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
/// metrics renderer implementation that doesn't actually do anything
pub struct DontRender;
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
/// identity version that knows what backend
pub struct Identity<B:Backend>{phantom:PhantomData<B>}
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
/// wrapper for converting loss to regression output
// inner: the wrapped model; layer: stateless converter applied to its loss output
pub struct Regression<A>{inner:A,layer:RegressionLayer}
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
/// layer for converting loss to regression output
// seal: private marker field so the layer is built through Default
pub struct RegressionLayer{seal:PhantomData<()>}
#[derive(Config,Debug)]
/// configuration for convenient training through the wrapper
pub struct TrainConfig{
	/// directory where checkpoints and logs are written
	#[config(default="String::from(\".artifact\")")]
	artifact_directory:String,
	/// number of samples per training batch
	#[config(default="16")]
	batch_size:usize,
	/// whether to save file checkpoints during training
	#[config(default="true")]
	checkpoints:bool,
	/// whether to render training metrics to the console
	#[config(default="false")]
	console_rendering:bool,
	/// number of passes over the training data
	#[config(default="10")]
	epochs:usize,
	/// whether to print a summary after training
	#[config(default="false")]
	summary:bool,
	/// number of data-loader worker threads
	#[config(default="4")]
	workers:usize
}
#[derive(Clone,Copy,Debug,Default)]
/// wraps in a burn wrapper
pub struct Wrapped<W:Wrappable>{inner:W}
/// chained method shortcut trait
pub trait Shortcuts{
	/// wraps in a classification wrapper
	fn classification(self)->Classification<Self> where Classification<Self>:Op,Self:Sized{Classification::from_inner(self)}
	/// wraps in a regression wrapper
	fn regression(self)->Regression<Self> where Regression<Self>:Op,Self:Sized{Regression::from_inner(self)}
	/// wraps in a burn wrapper
	fn wrap(self)->Wrapped<Self> where Self:Wrappable{Wrapped::new(self)}
}
/// trait for switching the backend of a module
pub trait ToBackend<B:Backend>:Sized{
	/// moves the module to the backend with the device
	fn to_backend_device(self,device:&B::Device)->Self::OnBackend;
	/// moves the module to the backend with the default device
	fn to_backend(self)->Self::OnBackend{self.to_backend_device(&Default::default())}
	/// the type on the new backend
	type OnBackend;
}
/// higher kinded type trait to allow rewrapping burn modules in different backends to implement some wrapper features
pub trait Wrappable:Clone+Debug+Decompose+Send{
	/// the backend this value is currently on
	type B:Backend;
	/// the same structure rewritten onto backend `C`; the bounds make the rewrite invertible
	type With<C:Backend>:Wrappable<B=C,With<C>=Self::With<C>>+Wrappable<B=C,With<Self::B>=Self>;
}
499pub use burn as lib;
500pub use layer::{Attention,AttentionConfig,AttentionMask,BiasConfig,Cache,Config,Layer,KQV,KQVConfig,PowerMaskInfo};
501pub use shape::{Kind,Reshape,Shape};
502pub use value::{LossOutput,Value};
503use burn::{
504	backend::NdArray,
505	data::{
506		dataset::Dataset,dataloader::{batcher::Batcher,DataLoaderBuilder}
507	},
508	lr_scheduler::LrScheduler,
509	module::{AutodiffModule,Content,DisplaySettings,ModuleDisplay,ModuleDisplayDefault,ModuleMapper,ModuleVisitor,Quantizer},
510	optim::Optimizer,
511	prelude::*,
512	record::{CompactRecorder,FileRecorder,RecorderError},
513	tensor::{BasicOps,TensorKind,backend::AutodiffBackend},
514	train::{
515		ClassificationOutput,LearningStrategy,LearnerBuilder,RegressionOutput,TrainOutput,TrainStep,ValidStep,metric::{Adaptor,ItemLazy,LossInput,LossMetric},renderer::{EvaluationName,EvaluationProgress,MetricsRenderer,MetricState,MetricsRendererEvaluation,MetricsRendererTraining,TrainingProgress}
516	}
517};
518use crate::{
519	AI,Decompose,Graph,Inner,Op,UnwrapInner,Unvec,
520	builtin::{
521		Duplicate,Map,Sequential,SetType,Zip,math::{Abs,Mean,SquaredError},reinforcement::AccQ,soft::{Choose,CrossEntropy},structural::Cat
522	},
523	ops::Stack as OpsStack
524};
525use rand::random;
526use serde::{Deserialize,Serialize};
527use std::{
528	fmt::{Debug,Display},fs::{create_dir_all as create_folder},marker::PhantomData,path::PathBuf
529};