block_graph/burn/mod.rs

impl Decompose for ClassificationLayer{
	fn compose(_decomposition:Self::Decomposition)->Self{Self::default()}
	fn decompose(self){}
	fn decompose_cloned(&self){}
	type Decomposition=();
}
impl Decompose for RegressionLayer{
	fn compose(_decomposition:Self::Decomposition)->Self{Self::default()}
	fn decompose(self){}
	fn decompose_cloned(&self){}
	type Decomposition=();
}
impl MetricsRenderer for DontRender{
	fn update_train(&mut self,_state:MetricState){}
	fn update_valid(&mut self,_state:MetricState){}
	fn render_train(&mut self,_item:TrainingProgress){}
	fn render_valid(&mut self,_item:TrainingProgress){}
}
impl Op for ClassificationLayer{
	type Output=ClassificationOutput<NdArray>;
}
impl Op for RegressionLayer{
	type Output=RegressionOutput<NdArray>;
}
impl<A:AI<X,LossOutput<B>>,B:Backend,X> AI<X,ClassificationOutput<B>> for Classification<A>{
	fn forward(&self,input:X)->ClassificationOutput<B>{self.layer.forward(self.inner.forward(input))}
	fn forward_mut(&mut self,input:X)->ClassificationOutput<B>{self.layer.forward(self.inner.forward_mut(input))}
}
impl<A:AI<X,LossOutput<B>>,B:Backend,X> AI<X,RegressionOutput<B>> for Regression<A>{
	fn forward(&self,input:X)->RegressionOutput<B>{self.layer.forward(self.inner.forward(input))}
	fn forward_mut(&mut self,input:X)->RegressionOutput<B>{self.layer.forward(self.inner.forward_mut(input))}
}
impl<A:AutodiffBackend<InnerBackend=B>,B:Backend,W:'static+Wrappable<B=A>,Y:'static+ItemLazy+Send+Sync,Z:'static+ItemLazy+Send+Sync> Wrapped<W> where <Self as AutodiffModule<A>>::InnerModule:ValidStep<(Value<B>,Value<B>),Z>,Self:TrainStep<(Value<A>,Value<A>),Y>,W::Decomposition:AutodiffModule<A>,W::With<B>:Decompose<Decomposition=<W::Decomposition as AutodiffModule<A>>::InnerModule>+Op<Output=Z>,W:Op<Output=Y>,Y::ItemSync:Adaptor<LossInput<NdArray>>,Z::ItemSync:Adaptor<LossInput<NdArray>>{
	/// trains the model
	pub fn train<I:'static+Clone+Debug+Into<(Value<A>,Value<A>)>+Send+Sync,J:'static+Clone+Debug+Into<(Value<B>,Value<B>)>+Send+Sync,O:Optimizer<Self,A>,S:LrScheduler,T:'static+Dataset<I>,V:'static+Dataset<J>>(self,config:&TrainConfig,optimizer:O,scheduler:S,train:T,valid:V)->Self where O::Record:'static,S::Record<A>:'static{
		let batcher=BatchStacker;
		let trainloader=DataLoaderBuilder::new(batcher).batch_size(config.batch_size).shuffle(random()).num_workers(config.workers).build(train);
		let validloader=DataLoaderBuilder::new(batcher).batch_size(config.batch_size).shuffle(random()).num_workers(config.workers).build(valid);

		create_folder(&config.artifact_directory).unwrap();
		let builder=LearnerBuilder::new(&config.artifact_directory).metric_train_numeric(LossMetric::new()).metric_valid_numeric(LossMetric::new());
		let builder=if config.checkpoints{builder.with_file_checkpointer(CompactRecorder::new())}else{builder};
		let builder=if config.console_rendering{builder}else{builder.renderer(DontRender)};
		let builder=builder.devices(vec![<W::B as Backend>::Device::default()]).num_epochs(config.epochs);
		let builder=if config.summary{builder.summary()}else{builder};
		let learner=builder.build(self,optimizer,scheduler);
		learner.fit(trainloader,validloader)
	}
}
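// usage sketch (hedged; mirrors the learn_xor test below): once a model is wrapped,
// training is one call, with a plain learning rate doubling as the scheduler:
// let trained=model.train(&TrainConfig::new(),SgdConfig::new().init(),0.01,train_set,valid_set);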
impl<A:AutodiffBackend,W:AI<X,LossOutput<A>>+Wrappable<B=A>,X> TrainStep<X,ClassificationOutput<A>> for Wrapped<Classification<W>> where W::Decomposition:AutodiffModule<A>,W::With<A::InnerBackend>:Decompose<Decomposition=<W::Decomposition as AutodiffModule<A>>::InnerModule>{
	fn step(&self,item:X)->TrainOutput<ClassificationOutput<A>>{
		let output:ClassificationOutput<A>=self.forward(item);
		TrainOutput::new(self,output.loss.backward(),output)
	}
}
impl<A:AutodiffBackend,W:AI<X,LossOutput<A>>+Wrappable<B=A>,X> TrainStep<X,RegressionOutput<A>> for Wrapped<Regression<W>> where W::Decomposition:AutodiffModule<A>,W::With<A::InnerBackend>:Decompose<Decomposition=<W::Decomposition as AutodiffModule<A>>::InnerModule>{
	fn step(&self,item:X)->TrainOutput<RegressionOutput<A>>{
		let output:RegressionOutput<A>=self.forward(item);
		TrainOutput::new(self,output.loss.backward(),output)
	}
}
impl<A:AutodiffBackend,W:Wrappable<B=A>> AutodiffModule<A> for Wrapped<W> where W::Decomposition:AutodiffModule<A>,W::With<A::InnerBackend>:Decompose<Decomposition=<W::Decomposition as AutodiffModule<A>>::InnerModule>{
	fn valid(&self)->Self::InnerModule{Wrapped::new(Decompose::compose(self.inner.decompose_cloned().valid()))}
	type InnerModule=Wrapped<W::With<A::InnerBackend>>;
}
impl<A:Decompose> Decompose for Classification<A>{
	fn compose(decomposition:Self::Decomposition)->Self{
		Self{inner:A::compose(decomposition),layer:Default::default()}
	}
	fn decompose(self)->Self::Decomposition{self.inner.decompose()}
	fn decompose_cloned(&self)->Self::Decomposition{self.inner.decompose_cloned()}
	type Decomposition=A::Decomposition;
}
impl<A:Decompose> Decompose for Regression<A>{
	fn compose(decomposition:Self::Decomposition)->Self{
		Self{inner:A::compose(decomposition),layer:Default::default()}
	}
	fn decompose(self)->Self::Decomposition{self.inner.decompose()}
	fn decompose_cloned(&self)->Self::Decomposition{self.inner.decompose_cloned()}
	type Decomposition=A::Decomposition;
}
impl<A:Op<Output=Y>+Wrappable,Y> Op for Classification<A> where ClassificationLayer:AI<Y,ClassificationOutput<A::B>>{
	type Output=ClassificationOutput<A::B>;
}
impl<A:Op<Output=Y>+Wrappable,Y> Op for Regression<A> where RegressionLayer:AI<Y,RegressionOutput<A::B>>{
	type Output=RegressionOutput<A::B>;
}
impl<A:UnwrapInner> UnwrapInner for Classification<A>{
	fn unwrap_inner(self)->Self::Inner{self.into_inner().unwrap_inner()}
	type Inner=A::Inner;
}
impl<A:UnwrapInner> UnwrapInner for Regression<A>{
	fn unwrap_inner(self)->Self::Inner{self.into_inner().unwrap_inner()}
	type Inner=A::Inner;
}
impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>> Wrappable for (A,D){
	type B=B;
	type With<C:Backend>=(A::With<C>,D::With<C>);
}
impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>> Wrappable for (A,D,E){
	type B=B;
	type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>);
}
impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>> Wrappable for (A,D,E,F){
	type B=B;
	type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>);
}
impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>,G:Wrappable<B=B>> Wrappable for (A,D,E,F,G){
	type B=B;
	type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>,G::With<C>);
}
impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>,G:Wrappable<B=B>,H:Wrappable<B=B>> Wrappable for (A,D,E,F,G,H){
	type B=B;
	type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>,G::With<C>,H::With<C>);
}
impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>,G:Wrappable<B=B>,H:Wrappable<B=B>,I:Wrappable<B=B>> Wrappable for (A,D,E,F,G,H,I){
	type B=B;
	type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>,G::With<C>,H::With<C>,I::With<C>);
}
impl<A:Wrappable<B=B>,B:Backend,D:Wrappable<B=B>,E:Wrappable<B=B>,F:Wrappable<B=B>,G:Wrappable<B=B>,H:Wrappable<B=B>,I:Wrappable<B=B>,J:Wrappable<B=B>> Wrappable for (A,D,E,F,G,H,I,J){
	type B=B;
	type With<C:Backend>=(A::With<C>,D::With<C>,E::With<C>,F::With<C>,G::With<C>,H::With<C>,I::With<C>,J::With<C>);
}
impl<A:Wrappable<B=B>,B:Backend,X:Wrappable<B=B>,Y:Wrappable<B=B>> Wrappable for SetType<A,X,Y>{
	type B=B;
	type With<C:Backend>=SetType<A::With<C>,X::With<C>,Y::With<C>>;
}
impl<A> Classification<A>{
	/// creates from the inner value
	pub fn from_inner(inner:A)->Self where Classification<A>:Op{
		Self{inner,layer:Default::default()}
	}
	/// references the inner value
	pub fn inner(&self)->&A{&self.inner}
	/// mutably references the inner value
	pub fn inner_mut(&mut self)->&mut A{&mut self.inner}
	/// converts into the inner value
	pub fn into_inner(self)->A{self.inner}
	/// creates a new wrapper around a different inner value
	pub fn with_inner<B>(&self,inner:B)->Classification<B> where Classification<B>:Op{Classification::from_inner(inner)}
}
impl<A> Regression<A>{
	/// creates from the inner value
	pub fn from_inner(inner:A)->Self where Regression<A>:Op{
		Self{inner,layer:Default::default()}
	}
	/// references the inner value
	pub fn inner(&self)->&A{&self.inner}
	/// mutably references the inner value
	pub fn inner_mut(&mut self)->&mut A{&mut self.inner}
	/// converts into the inner value
	pub fn into_inner(self)->A{self.inner}
	/// creates a new wrapper around a different inner value
	pub fn with_inner<B>(&self,inner:B)->Regression<B> where Regression<B>:Op{Regression::from_inner(inner)}
}
impl<B:Backend,C:Backend,K:BasicOps<B>+BasicOps<C>+TensorKind<B>+TensorKind<C>,const N:usize> ToBackend<C> for Tensor<B,N,K>{
	fn to_backend_device(self,device:&C::Device)->Self::OnBackend{
		let data=self.to_data();
		Tensor::from_data(data,device)
	}
	type OnBackend=Tensor<C,N,K>;
}
impl<B:Backend,C:Backend> ToBackend<C> for Value<B>{
	fn to_backend_device(self,device:&C::Device)->Self::OnBackend{
		match self{Value::B1(x)=>x.to_backend_device(device).into(),Value::B2(x)=>x.to_backend_device(device).into(),Value::B3(x)=>x.to_backend_device(device).into(),Value::B4(x)=>x.to_backend_device(device).into(),Value::B5(x)=>x.to_backend_device(device).into(),Value::B6(x)=>x.to_backend_device(device).into(),Value::B7(x)=>x.to_backend_device(device).into(),Value::B8(x)=>x.to_backend_device(device).into(),Value::F1(x)=>x.to_backend_device(device).into(),Value::F2(x)=>x.to_backend_device(device).into(),Value::F3(x)=>x.to_backend_device(device).into(),Value::F4(x)=>x.to_backend_device(device).into(),Value::F5(x)=>x.to_backend_device(device).into(),Value::F6(x)=>x.to_backend_device(device).into(),Value::F7(x)=>x.to_backend_device(device).into(),Value::F8(x)=>x.to_backend_device(device).into(),Value::I1(x)=>x.to_backend_device(device).into(),Value::I2(x)=>x.to_backend_device(device).into(),Value::I3(x)=>x.to_backend_device(device).into(),Value::I4(x)=>x.to_backend_device(device).into(),Value::I5(x)=>x.to_backend_device(device).into(),Value::I6(x)=>x.to_backend_device(device).into(),Value::I7(x)=>x.to_backend_device(device).into(),Value::I8(x)=>x.to_backend_device(device).into(),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.to_backend_device(device)).collect()}
	}
	type OnBackend=Value<C>;
}
impl<B:Backend,E:Into<(Value<B>,Value<B>)>> Batcher<B,E,(Value<B>,Value<B>)> for BatchStacker{
	fn batch(&self,items:Vec<E>,_device:&<B as Backend>::Device)->(Value<B>,Value<B>){
		let items=items.into_iter().map(Into::into);
		let (input,target):(Vec<Value<B>>,Vec<Value<B>>)=items.unzip();
		let (input,target)=(Value::Multi(input),Value::Multi(target));

		(input.zip().stack(0),target.zip().stack(0))
	}
}
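// note: BatchStacker collapses a Vec of (input,target) items into a single (input,target)
// pair by wrapping each side in Value::Multi, zipping, and stacking along dimension 0, so
// every dataset item becomes one row of the batch.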
impl<B:Backend,W:AI<X,LossOutput<B>>+Wrappable<B=B>,X> ValidStep<X,ClassificationOutput<B>> for Wrapped<Classification<W>> where W::Decomposition:Module<B>{
	fn step(&self,item:X)->ClassificationOutput<B>{self.forward(item)}
}
impl<B:Backend,W:AI<X,LossOutput<B>>+Wrappable<B=B>,X> ValidStep<X,RegressionOutput<B>> for Wrapped<Regression<W>> where W::Decomposition:Module<B>{
	fn step(&self,item:X)->RegressionOutput<B>{self.forward(item)}
}
impl<B:Backend,W:Wrappable<B=B>> Module<B> for Wrapped<W> where W::Decomposition:Module<B>{
	fn collect_devices(&self,devices:Vec<<B as Backend>::Device>)->Vec<<B as Backend>::Device>{self.inner.decompose_cloned().collect_devices(devices)}
	fn devices(&self)->Vec<<B as Backend>::Device>{self.inner.decompose_cloned().devices()}
	fn fork(self,device:&<B as Backend>::Device)->Self{Self::new(W::compose(self.inner.decompose().fork(device)))}
	fn into_record(self)->Self::Record{self.inner.decompose().into_record()}
	fn load_file<F:FileRecorder<B>,P:Into<PathBuf>>(self,filepath:P,recorder:&F,device:&<B as Backend>::Device)->Result<Self,RecorderError>{self.inner.decompose().load_file(filepath,recorder,device).map(|a|Self::new(W::compose(a)))}
	fn load_record(self,record:Self::Record)->Self{Self::new(W::compose(self.inner.decompose().load_record(record)))}
	fn map<Mapper:ModuleMapper<B>>(self,mapper:&mut Mapper)->Self{Self::new(W::compose(self.inner.decompose().map(mapper)))}
	fn num_params(&self)->usize{self.inner.decompose_cloned().num_params()}
	fn quantize_weights(self,quantizer:&mut Quantizer)->Self{Self::new(W::compose(self.inner.decompose().quantize_weights(quantizer)))}
	fn save_file<F:FileRecorder<B>,P:Into<PathBuf>>(self,filepath:P,recorder:&F)->Result<(),RecorderError>{self.inner.decompose().save_file(filepath,recorder)}
	fn to_device(self,device:&<B as Backend>::Device)->Self{Self::new(W::compose(self.inner.decompose().to_device(device)))}
	fn visit<Visitor:ModuleVisitor<B>>(&self,visitor:&mut Visitor){self.inner.decompose_cloned().visit(visitor)}
	type Record=<W::Decomposition as Module<B>>::Record;
}
impl<B:Backend,X:Into<Y>,Y> AI<X,Y> for Identity<B>{
	fn forward(&self,input:X)->Y{input.into()}
}
impl<B:Backend> AI<LossOutput<B>,ClassificationOutput<B>> for ClassificationLayer{
	fn forward(&self,lossoutput:LossOutput<B>)->ClassificationOutput<B>{//TODO make work for multi
		let loss=match lossoutput.loss(){Value::F1(x)=>x,Value::F2(x)=>x.mean(),Value::F3(x)=>x.mean(),Value::F4(x)=>x.mean(),Value::F5(x)=>x.mean(),Value::F6(x)=>x.mean(),Value::F7(x)=>x.mean(),Value::F8(x)=>x.mean(),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non float losses to classification output")};
		let output=match lossoutput.output(){Value::F1(x)=>x.unsqueeze(),Value::F2(x)=>x,Value::F3(x)=>x.flatten(0,1),Value::F4(x)=>x.flatten(0,2),Value::F5(x)=>x.flatten(0,3),Value::F6(x)=>x.flatten(0,4),Value::F7(x)=>x.flatten(0,5),Value::F8(x)=>x.flatten(0,6),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non float outputs to classification output")};
		let target=match lossoutput.target(){Value::I1(x)=>x,Value::I2(x)=>x.flatten(0,1),Value::I3(x)=>x.flatten(0,2),Value::I4(x)=>x.flatten(0,3),Value::I5(x)=>x.flatten(0,4),Value::I6(x)=>x.flatten(0,5),Value::I7(x)=>x.flatten(0,6),Value::I8(x)=>x.flatten(0,7),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non integer targets to classification output")};
		ClassificationOutput::new(loss,output,target)
	}
}
impl<B:Backend> AI<LossOutput<B>,RegressionOutput<B>> for RegressionLayer{
	fn forward(&self,lossoutput:LossOutput<B>)->RegressionOutput<B>{
		let loss=match lossoutput.loss(){Value::F1(x)=>x,Value::F2(x)=>x.mean(),Value::F3(x)=>x.mean(),Value::F4(x)=>x.mean(),Value::F5(x)=>x.mean(),Value::F6(x)=>x.mean(),Value::F7(x)=>x.mean(),Value::F8(x)=>x.mean(),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non float losses to regression output")};
		let output=match lossoutput.output(){Value::F1(x)=>x.unsqueeze(),Value::F2(x)=>x,Value::F3(x)=>x.flatten(0,1),Value::F4(x)=>x.flatten(0,2),Value::F5(x)=>x.flatten(0,3),Value::F6(x)=>x.flatten(0,4),Value::F7(x)=>x.flatten(0,5),Value::F8(x)=>x.flatten(0,6),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non float outputs to regression output")};
		let target=match lossoutput.target(){Value::F1(x)=>x.unsqueeze(),Value::F2(x)=>x,Value::F3(x)=>x.flatten(0,1),Value::F4(x)=>x.flatten(0,2),Value::F5(x)=>x.flatten(0,3),Value::F6(x)=>x.flatten(0,4),Value::F7(x)=>x.flatten(0,5),Value::F8(x)=>x.flatten(0,6),Value::Incompatible(e)=>panic!("{e}"),_=>panic!("cannot convert non float targets to regression output")};
		RegressionOutput::new(loss,output,target)
	}
}
impl<B:Backend> Decompose for Identity<B>{
	fn compose(_decomposition:Self::Decomposition)->Self{new()}
	fn decompose(self){}
	fn decompose_cloned(&self){}
	type Decomposition=();
}
impl<B:Backend> Op for Identity<B>{
	type Output=();
}
impl<B:Backend> Wrappable for Identity<B>{
	type B=B;
	type With<C:Backend>=Identity<C>;
}
impl<B:Backend> Wrappable for Layer<B>{
	type B=B;
	type With<C:Backend>=Layer<C>;
}
impl<B:Backend> Wrappable for LossOutput<B>{
	type B=B;
	type With<C:Backend>=LossOutput<C>;
}
impl<B:Backend> Wrappable for Value<B>{
	type B=B;
	type With<C:Backend>=Value<C>;
}
impl<C:Backend,W:ToBackend<C,OnBackend=W::With<C>>+Wrappable> ToBackend<C> for Wrapped<W>{
	fn to_backend_device(self,device:&C::Device)->Self::OnBackend{
		Wrapped{inner:self.inner.to_backend_device(device)}
	}
	type OnBackend=Wrapped<W::With<C>>;
}
impl<T:?Sized+Op> Shortcuts for T{}
impl<W:AI<X,Y>+Wrappable,X,Y> AI<X,Y> for Wrapped<W>{
	fn forward(&self,input:X)->Y{self.inner.forward(input)}
	fn forward_mut(&mut self,input:X)->Y{self.inner.forward_mut(input)}
}
impl<W:Op+Wrappable> Op for Wrapped<W>{
	type Output=W::Output;
}
impl<W:UnwrapInner+Wrappable> UnwrapInner for Wrapped<W>{
	fn unwrap_inner(self)->Self::Inner{self.into_inner().unwrap_inner()}
	type Inner=W::Inner;
}
impl<W:Wrappable> Decompose for Wrapped<W>{
	fn compose(decomposition:Self::Decomposition)->Self{Self::new(W::compose(decomposition))}
	fn decompose(self)->Self::Decomposition{self.inner.decompose()}
	fn decompose_cloned(&self)->Self::Decomposition{self.inner.decompose_cloned()}
	type Decomposition=W::Decomposition;
}
impl<W:Wrappable> Display for Wrapped<W>{
	fn fmt(&self,f:&mut std::fmt::Formatter<'_>)->Result<(),std::fmt::Error>{write!(f,"todo")}
}
impl<W:Wrappable> From<W> for Wrapped<W>{
	fn from(value:W)->Self{Self::new(value)}
}
impl<W:Wrappable> ModuleDisplay for Wrapped<W> where W::Decomposition:ModuleDisplay{
	fn custom_content(&self,content:Content)->Option<Content>{self.inner.decompose_cloned().custom_content(content)}
	fn custom_settings(&self)->Option<DisplaySettings>{self.inner.decompose_cloned().custom_settings()}
	fn format(&self,passed_settings:DisplaySettings)->String{self.inner.decompose_cloned().format(passed_settings)}
}
impl<W:Wrappable> ModuleDisplayDefault for Wrapped<W> where W::Decomposition:ModuleDisplayDefault{
	fn content(&self,content:Content)->Option<Content>{self.inner.decompose_cloned().content(content)}
	fn num_params(&self)->usize{self.inner.decompose_cloned().num_params()}
}
impl<W:Wrappable> Wrappable for Abs<W>{
	type B=W::B;
	type With<C:Backend>=Abs<W::With<C>>;
}
impl<W:Wrappable> Wrappable for AccQ<W>{
	type B=W::B;
	type With<C:Backend>=AccQ<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Cat<W>{
	type B=W::B;
	type With<C:Backend>=Cat<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Classification<W>{
	type B=W::B;
	type With<C:Backend>=Classification<W::With<C>>;
}
impl<W:Wrappable> Wrappable for CrossEntropy<W>{
	type B=W::B;
	type With<C:Backend>=CrossEntropy<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Duplicate<W>{
	type B=W::B;
	type With<C:Backend>=Duplicate<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Graph<W>{
	type B=W::B;
	type With<C:Backend>=Graph<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Inner<W>{
	type B=W::B;
	type With<C:Backend>=Inner<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Mean<W>{
	type B=W::B;
	type With<C:Backend>=Mean<W::With<C>>;
}
impl<W:Wrappable> Wrappable for SquaredError<W>{
	type B=W::B;
	type With<C:Backend>=SquaredError<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Map<W>{
	type B=W::B;
	type With<C:Backend>=Map<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Regression<W>{
	type B=W::B;
	type With<C:Backend>=Regression<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Sequential<W>{
	type B=W::B;
	type With<C:Backend>=Sequential<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Choose<W>{
	type B=W::B;
	type With<C:Backend>=Choose<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Unvec<W>{
	type B=W::B;
	type With<C:Backend>=Unvec<W::With<C>>;
}
impl<W:Wrappable> Wrappable for Zip<W>{
	type B=W::B;
	type With<C:Backend>=Zip<W::With<C>>;
}
impl<W:Wrappable> Wrapped<W>{
	/// references the inner value
	pub fn inner(&self)->&W{&self.inner}
	/// mutably references the inner value
	pub fn inner_mut(&mut self)->&mut W{&mut self.inner}
	/// unwraps the inner value
	pub fn into_inner(self)->W{self.inner}
	/// creates a new wrapped value
	pub fn new(inner:W)->Self{
		Self{inner}
	}
}
#[cfg(test)]
mod tests{
	#[test]
	fn learn_xor(){
		type B=NdArray;
		type A=Autodiff<B>;
		let i0=Tensor::<A,1>::from_data(TensorData::new([0.0,0.0].to_vec(),[2]),&Default::default());
		let i1=Tensor::<A,1>::from_data(TensorData::new([0.0,1.0].to_vec(),[2]),&Default::default());
		let i2=Tensor::<A,1>::from_data(TensorData::new([1.0,0.0].to_vec(),[2]),&Default::default());
		let i3=Tensor::<A,1>::from_data(TensorData::new([1.0,1.0].to_vec(),[2]),&Default::default());
		let o0=Tensor::<A,1>::from_data(TensorData::new([0.0].to_vec(),[1]),&Default::default());
		let o1=Tensor::<A,1>::from_data(TensorData::new([1.0].to_vec(),[1]),&Default::default());
		let o2=Tensor::<A,1>::from_data(TensorData::new([1.0].to_vec(),[1]),&Default::default());
		let o3=Tensor::<A,1>::from_data(TensorData::new([0.0].to_vec(),[1]),&Default::default());

		let dataset:Vec<(Tensor<A,1>,Tensor<A,1>)>=[(i0,o0),(i1,o1),(i2,o2),(i3,o3)].into_iter().cycle().take(4000).collect();
		let train=InMemDataset::new(dataset.clone().into_iter().map(|(i,o)|(Value::from(i),Value::from(o))).collect());
		let valid=InMemDataset::new(dataset.into_iter().map(|(i,o)|(Value::from(i.valid()),Value::from(o.valid()))).collect());
		let mut graph:Graph<Layer<A>>=Graph::new();
		graph.connect("input","x").with_clear(true).with(Layer::linear(true,2,10,1.0));
		graph.connect("x","y").with_clear(true).with(Layer::relu());
		graph.connect("y","output").with_clear(true).with(Layer::linear(false,10,1,1.0));

		let graph=Unvec(graph).wrap_inner().squared_error().set_type::<(Value<A>,Value<A>),LossOutput<A>>().regression().wrap();
		let graph=graph.train(&TrainConfig::new().with_checkpoints(false),SgdConfig::new().init(),0.01,train,valid);
		let graph=graph.valid().unwrap_inner();

		let inputval=Value::from(Tensor::<B,2>::from_data(TensorData::new([0.0,0.0,0.0,1.0,1.0,0.0,1.0,1.0].to_vec(),[4,2]),&Default::default()));
		let outputval=graph.forward(inputval);
		if let Value::F2(o)=outputval{
			let target=Tensor::<B,2>::from_data(TensorData::new([0.0,1.0,1.0,0.0].to_vec(),[4,1]),&Default::default());
			let error=(target-o.clone()).abs().max();
			println!("{}",o);
			assert!(error.into_scalar()<0.1);
		}else{
			panic!("expected Value::F2 output from the graph");
		}
	}
	use burn::{
		backend::Autodiff,data::dataset::InMemDataset,optim::SgdConfig
	};
	use super::*;
}
mod layer;
mod shape;
mod value;
/// helper for applying an operation at a specific depth of a multi structure: the operation runs on every subvalue whose height matches depth, so a multi wrapping several appropriate inputs yields the operation applied to each of them. depth 0 targets empty multis, 1 plain values, 2+ nested multis
pub fn apply_depthwise<B:Backend,F:FnMut(Value<B>)->Value<B>>(depth:usize,mut op:F,value:Value<B>)->Value<B>{
	fn inner<B:Backend,F:FnMut(Value<B>)->Value<B>>(depth:usize,op:&mut F,value:Value<B>)->(Value<B>,usize){
		let mut height=1;
		let value=if value.is_multi(){
			let value:Value<B>=value.into_iter().map(|v|{
				let (v,h)=inner(depth,op,v);
				height=height.max(h);
				v
			}).collect();
			if value.len()==0{height=0}else{height+=1}
			value
		}else{
			value
		};
		(if depth==height{op(value)}else{value},height)
	}
	inner(depth,&mut op,value).0
}
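// example (hedged sketch; the op and value here are placeholders): depth selects where the
// op fires. With depth 1 it runs on every plain (non-multi) value in the tree, with depth 2
// on every multi of plain values:
// let stacked=apply_depthwise(2,|multi|multi.stack(0),value);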
/// starts building an ai structure in chained method style from an identity operation
pub fn new<B:Backend>()->Identity<B>{
	Identity{phantom:PhantomData}
}
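// e.g. (hedged): chained building starts from the identity op and composes rightward, as in
// new::<NdArray>().squared_error().regression().wrap()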
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
/// batcher that stacks dataset items into batch tensors along dimension 0
pub struct BatchStacker;
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
/// wrapper for converting loss to classification output
pub struct Classification<A>{inner:A,layer:ClassificationLayer}
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
/// layer for converting loss to classification output
pub struct ClassificationLayer{seal:PhantomData<()>}
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
/// metrics renderer implementation that doesn't actually do anything
pub struct DontRender;
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
/// identity operation that carries its backend type
pub struct Identity<B:Backend>{phantom:PhantomData<B>}
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
/// wrapper for converting loss to regression output
pub struct Regression<A>{inner:A,layer:RegressionLayer}
#[derive(Clone,Copy,Debug,Default,Deserialize,Serialize)]
/// layer for converting loss to regression output
pub struct RegressionLayer{seal:PhantomData<()>}
#[derive(Config,Debug)]
/// configuration for convenient training through the wrapper
pub struct TrainConfig{
	#[config(default="String::from(\".artifact\")")]
	artifact_directory:String,
	#[config(default="16")]
	batch_size:usize,
	#[config(default="true")]
	checkpoints:bool,
	#[config(default="false")]
	console_rendering:bool,
	#[config(default="10")]
	epochs:usize,
	#[config(default="false")]
	summary:bool,
	#[config(default="4")]
	workers:usize
}
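// TrainConfig goes through burn's Config derive, so each field gets a with_* setter
// (hedged example; field names as declared above):
// let config=TrainConfig::new().with_epochs(50).with_batch_size(32).with_checkpoints(false);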
#[derive(Clone,Copy,Debug,Default)]
/// wraps in a burn wrapper
pub struct Wrapped<W:Wrappable>{inner:W}
/// chained method shortcut trait
pub trait Shortcuts{
	/// wraps in a classification wrapper
	fn classification(self)->Classification<Self> where Classification<Self>:Op,Self:Sized{Classification::from_inner(self)}
	/// wraps in a regression wrapper
	fn regression(self)->Regression<Self> where Regression<Self>:Op,Self:Sized{Regression::from_inner(self)}
	/// wraps in a burn wrapper
	fn wrap(self)->Wrapped<Self> where Self:Wrappable{Wrapped::new(self)}
}
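// shortcut chaining (hedged sketch): since Shortcuts is blanket implemented for every Op,
// any op can be finished off in one expression, e.g. model.regression().wrap() yields a
// Wrapped<Regression<_>> ready for Wrapped::train.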
/// trait for switching the backend of a module
pub trait ToBackend<B:Backend>:Sized{
	/// moves the module to the backend with the given device
	fn to_backend_device(self,device:&B::Device)->Self::OnBackend;
	/// moves the module to the backend with the default device
	fn to_backend(self)->Self::OnBackend{self.to_backend_device(&Default::default())}
	/// the type on the new backend
	type OnBackend;
}
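// example (hedged): moving a value from the autodiff backend to plain NdArray for inference:
// let on_ndarray:Value<NdArray>=value.to_backend();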
/// higher kinded type trait that allows rewrapping burn modules on different backends, which some wrapper features require
pub trait Wrappable:Clone+Debug+Decompose+Send{
	/// the backend this type currently uses
	type B:Backend;
	/// the same structure rebuilt on backend C
	type With<C:Backend>:Wrappable<B=C,With<C>=Self::With<C>>+Wrappable<B=C,With<Self::B>=Self>;
}
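// the bounds on With<C> make backend switching a round trip: Self::With<C> is a fixed point
// under further rewrapping, and With<Self::B> maps back to Self, which is what lets Wrapped
// hop between an autodiff backend and its inner backend losslessly (see AutodiffModule above).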
pub use burn as lib;
pub use layer::{Attention,AttentionConfig,AttentionMask,BiasConfig,Cache,Config,Layer,KQV,KQVConfig,PowerMaskInfo};
pub use shape::{Kind,Reshape,Shape};
pub use value::{LossOutput,Value};
use burn::{
	backend::NdArray,
	data::{
		dataset::Dataset,dataloader::{batcher::Batcher,DataLoaderBuilder}
	},
	lr_scheduler::LrScheduler,
	module::{AutodiffModule,Content,DisplaySettings,ModuleDisplay,ModuleDisplayDefault,ModuleMapper,ModuleVisitor,Quantizer},
	optim::Optimizer,
	prelude::*,
	record::{CompactRecorder,FileRecorder,RecorderError},
	tensor::{BasicOps,TensorKind,backend::AutodiffBackend},
	train::{
		ClassificationOutput,LearnerBuilder,RegressionOutput,TrainOutput,TrainStep,ValidStep,metric::{Adaptor,ItemLazy,LossInput,LossMetric},renderer::{MetricState,MetricsRenderer,TrainingProgress}
	}
};
use crate::{
	AI,Decompose,Graph,Inner,Op,UnwrapInner,Unvec,
	builtin::{
		Duplicate,Map,Sequential,SetType,Zip,math::{Abs,Mean,SquaredError},reinforcement::AccQ,soft::{Choose,CrossEntropy},structural::Cat
	},
	ops::Stack as OpsStack
};
use rand::random;
use serde::{Deserialize,Serialize};
use std::{
	fmt::{Debug,Display},fs::{create_dir_all as create_folder},marker::PhantomData,path::PathBuf
};