1bicop_num!(Add,add,add_scalar);
2bicop_num!(Div,div,div_scalar);
3bicop_num!(Mul,mul,mul_scalar);
4bicop_num!(Rem,rem,remainder_scalar);
5bicop_num!(Sub,sub,sub_scalar);
6fn broadcast_multi<B:Backend,F:FnMut(Value<B>,Value<B>)->Value<B>>(u:Vec<Value<B>>,v:Vec<Value<B>>,mut f:F)->Value<B>{
7 if u.len()==1{
8 u.into_iter().cycle().zip(v).map(|(x,y)|f(x,y)).collect()
9 }else if v.len()==1{
10 u.into_iter().zip(v.into_iter().cycle()).map(|(x,y)|f(x,y)).collect()
11 }else if u.len()==v.len(){
12 u.into_iter().zip(v).map(|(x,y)|f(x,y)).collect()
13 }else{
14 "mismatched lengths".into()
15 }
16}
17fn hard_choose_burn_1<B:Backend,const N:usize>(dim:i32,distribution:Tensor<B,N>)->u32{
18 let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
19 let distribution=if dim==N-1{distribution}else{distribution.movedim(dim,N-1)}.into_data();
20 let sum=distribution.iter().fold(0.0,|acc:f32,weight:f32|acc+weight);
21
22 distribution.iter().scan(random::<f32>()*sum,|choice:&mut f32,weight:f32|Some(*choice-=weight).filter(|_|*choice>=0.0)).count() as u32
23}
24fn hard_choose_burn_multi<B:Backend,const N:usize>(dim:i32,distribution:Tensor<B,N>)->Vec<u32>{
25 let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
26
27 let chunk=distribution.dims()[dim];
28 let distribution=if dim==N-1{distribution}else{distribution.movedim(dim,N-1)}.into_data().to_vec().unwrap();
29
30 distribution.chunks_exact(chunk).map(|d|{
31 let sum=d.iter().fold(0.0,|acc:f32,weight:&f32|acc+weight);
32 d.iter().scan(random::<f32>()*sum,|choice:&mut f32,weight:&f32|Some(*choice-=weight).filter(|_|*choice>=0.0)).count() as u32
33 }).collect()
34}
35fn hard_choose_burn_tensor<B:Backend,const N:usize>(dim:i32,distribution:Tensor<B,N>)->Tensor<B,N,Int>{let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
37 let device=distribution.device();
38 let mut dims=distribution.dims();
39
40 dims[N-1]=1;
41 let r:Tensor<B,N,Int>=Tensor::from_data(TensorData::new(hard_choose_burn_multi(dim as i32,distribution),dims),&device);
42
43 r.movedim(N-1,dim)
44}
45fn slice_slice<B:Backend,K:BasicOps<B>+TensorKind<B>,const N:usize>(ranges:&[Range<usize>],tensor:Tensor<B,N,K>)->Tensor<B,N,K>{
46 let mut n=0;
47 let mut acc=||{
48 let a=n;
49 n+=1;
50 a
51 };
52
53 match ranges.len(){0=>tensor,1=>tensor.slice([0;1].map(|_|ranges[acc()].clone())),2=>tensor.slice([0;2].map(|_|ranges[acc()].clone())),3=>tensor.slice([0;3].map(|_|ranges[acc()].clone())),4=>tensor.slice([0;4].map(|_|ranges[acc()].clone())),5=>tensor.slice([0;5].map(|_|ranges[acc()].clone())),6=>tensor.slice([0;6].map(|_|ranges[acc()].clone())),7=>tensor.slice([0;7].map(|_|ranges[acc()].clone())),8=>tensor.slice([0;8].map(|_|ranges[acc()].clone())),_=>panic!("too many ranges for current max 8 dims")}
54}
55fn soft_choose_burn_1<B:Backend,const N:usize>(dim:i32,logits:Tensor<B,N>,temperature:f32)->u32{
56 let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
57 let logits=if dim==N-1{logits}else{logits.movedim(dim,N-1)};
58 let distribution=softmax(logits/temperature,N-1).into_data();
59 distribution.iter().scan(random(),|choice:&mut f32,weight:f32|Some(*choice-=weight).filter(|_|*choice>=0.0)).count() as u32
60}
61fn soft_choose_burn_multi<B:Backend,const N:usize>(dim:i32,logits:Tensor<B,N>,temperature:f32)->Vec<u32>{
62 let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
63 let logits=if dim==N-1{logits}else{logits.movedim(dim,N-1)};
64 let chunk=logits.dims()[N-1];
65 let distribution=softmax(logits/temperature,N-1).into_data().to_vec().unwrap();
66 distribution.chunks_exact(chunk).map(|d|d.iter().scan(random(),|choice:&mut f32,weight:&f32|Some(*choice-=weight).filter(|_|*choice>=0.0)).count() as u32).collect()
67}
68fn soft_choose_burn_tensor<B:Backend,const N:usize>(dim:i32,logits:Tensor<B,N>,temperature:f32)->Tensor<B,N,Int>{let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
70 let device=logits.device();
71 let mut dims=logits.dims();
72
73 dims[N-1]=1;
74 let r:Tensor<B,N,Int>=Tensor::from_data(TensorData::new(soft_choose_burn_multi(dim as i32,logits,temperature),dims),&device);
75 r.movedim(N-1,dim)
76}
77impl AsRef<Self> for Shape{fn as_ref(&self)->&Self{self}
79}
80impl Shape{
81 pub fn count(&self)->Option<usize>{
83 match self{
84 Shape::Incompatible(_e)=>None,
85 Shape::Multi(n)=>if *n==0{Some(0)}else{None},
86 Shape::Recursive(v)=>{
87 let mut s=0;
88 for v in v{s+=v.count()?}
89 Some(s)
90 },
91 X1(x)=>Some(x.iter().product()),
92 X2(x)=>Some(x.iter().product()),
93 X3(x)=>Some(x.iter().product()),
94 X4(x)=>Some(x.iter().product()),
95 X5(x)=>Some(x.iter().product()),
96 X6(x)=>Some(x.iter().product()),
97 X7(x)=>Some(x.iter().product()),
98 X8(x)=>Some(x.iter().product())
99 }
100 }
101 pub fn to_array(self,alignment:Alignment)->[usize;8]{
103 let mut result=[1;8];
104 let slice=match &self{Shape::Incompatible(_e)=>return result,Shape::Multi(_v)=>return result,Shape::Recursive(_r)=>return result,X1(x)=>x.as_slice(),X2(x)=>x.as_slice(),X3(x)=>x.as_slice(),X4(x)=>x.as_slice(),X5(x)=>x.as_slice(),X6(x)=>x.as_slice(),X7(x)=>x.as_slice(),X8(x)=>x.as_slice()};
105 let l=slice.len();
106 match alignment{Alignment::Center=>result[4-l/2..][..l].copy_from_slice(slice),Alignment::Left=>result[..l].copy_from_slice(slice),Alignment::Right=>result[8-l..].copy_from_slice(slice)}
107 result
108 }
109}
110impl<'a,B:Backend> Deserialize<'a> for Value<B>{
111 fn deserialize<D:Deserializer<'a>>(deserializer:D)->Result<Self,D::Error>{ValueData::deserialize(deserializer).map(Into::into)}
112}
113impl<A:AutodiffBackend> AutodiffModule<A> for Value<A>{
114 fn valid(&self)->Self::InnerModule{
115 match self{B1(x)=>B1(x.valid()),B2(x)=>B2(x.valid()),B3(x)=>B3(x.valid()),B4(x)=>B4(x.valid()),B5(x)=>B5(x.valid()),B6(x)=>B6(x.valid()),B7(x)=>B7(x.valid()),B8(x)=>B8(x.valid()),F1(x)=>F1(x.valid()),F2(x)=>F2(x.valid()),F3(x)=>F3(x.valid()),F4(x)=>F4(x.valid()),F5(x)=>F5(x.valid()),F6(x)=>F6(x.valid()),F7(x)=>F7(x.valid()),F8(x)=>F8(x.valid()),I1(x)=>I1(x.valid()),I2(x)=>I2(x.valid()),I3(x)=>I3(x.valid()),I4(x)=>I4(x.valid()),I5(x)=>I5(x.valid()),I6(x)=>I6(x.valid()),I7(x)=>I7(x.valid()),I8(x)=>I8(x.valid()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.iter().map(|x|x.valid()).collect()}
116 }
117 type InnerModule=Value<A::InnerBackend>;
118}
119impl<A:Into<Value<B>>,B:Backend> FromIterator<A> for Value<B>{
120 fn from_iter<I:IntoIterator<Item=A>>(iter:I)->Self{Value::Multi(iter.into_iter().map(Into::into).collect())}
121}
122impl<B:Backend,K:'static+TensorKind<B>,const N:usize> From<Result<Tensor<B,N,K>,String>> for Value<B>{
123 fn from(value:Result<Tensor<B,N,K>,String>)->Self{
124 match value{Err(e)=>e.into(),Ok(t)=>t.into()}
125 }
126}
127impl<B:Backend,K:'static+TensorKind<B>,const N:usize> From<Tensor<B,N,K>> for Value<B>{
128 fn from(value:Tensor<B,N,K>)->Self{
129 let kind=TypeId::of::<K>();
130 let kind=if kind==TypeId::of::<Bool>(){Kind::Bool}else if kind==TypeId::of::<Float>(){Kind::Float}else if kind==TypeId::of::<Int>(){Kind::Int}else{return "only bool, float, and int tensors with dimensions 1-8 are currently supported".into()};
131
132 let v=unsafe{
133 match (N,kind){(1,Kind::Bool)=>B1(mem::transmute_copy(&value)),(2,Kind::Bool)=>B2(mem::transmute_copy(&value)),(3,Kind::Bool)=>B3(mem::transmute_copy(&value)),(4,Kind::Bool)=>B4(mem::transmute_copy(&value)),(5,Kind::Bool)=>B5(mem::transmute_copy(&value)),(6,Kind::Bool)=>B6(mem::transmute_copy(&value)),(7,Kind::Bool)=>B7(mem::transmute_copy(&value)),(8,Kind::Bool)=>B8(mem::transmute_copy(&value)),(1,Kind::Float)=>F1(mem::transmute_copy(&value)),(2,Kind::Float)=>F2(mem::transmute_copy(&value)),(3,Kind::Float)=>F3(mem::transmute_copy(&value)),(4,Kind::Float)=>F4(mem::transmute_copy(&value)),(5,Kind::Float)=>F5(mem::transmute_copy(&value)),(6,Kind::Float)=>F6(mem::transmute_copy(&value)),(7,Kind::Float)=>F7(mem::transmute_copy(&value)),(8,Kind::Float)=>F8(mem::transmute_copy(&value)),(1,Kind::Int)=>I1(mem::transmute_copy(&value)),(2,Kind::Int)=>I2(mem::transmute_copy(&value)),(3,Kind::Int)=>I3(mem::transmute_copy(&value)),(4,Kind::Int)=>I4(mem::transmute_copy(&value)),(5,Kind::Int)=>I5(mem::transmute_copy(&value)),(6,Kind::Int)=>I6(mem::transmute_copy(&value)),(7,Kind::Int)=>I7(mem::transmute_copy(&value)),(8,Kind::Int)=>I8(mem::transmute_copy(&value)),_=>return "only bool, float, and int tensors with dimensions 1-8 are currently supported".into()}
134 };
135 mem::forget(value);
136 v
137 }
138}
139impl<B:Backend,K:'static+TensorKind<B>,const N:usize> TryFrom<Value<B>> for Tensor<B,N,K>{
140 fn try_from(value:Value<B>)->Result<Self,Self::Error>{
141 let kind=TypeId::of::<K>();
142 let kind=if kind==TypeId::of::<Bool>(){Kind::Bool}else if kind==TypeId::of::<Float>(){Kind::Float}else if kind==TypeId::of::<Int>(){Kind::Int}else{return Err(value)};
143
144 if Some(N)!=value.rank()||kind!=value.kind(){return Err(value)}
145 let r=unsafe{
146 match &value{B1(x)=>mem::transmute_copy(x),B2(x)=>mem::transmute_copy(x),B3(x)=>mem::transmute_copy(x),B4(x)=>mem::transmute_copy(x),B5(x)=>mem::transmute_copy(x),B6(x)=>mem::transmute_copy(x),B7(x)=>mem::transmute_copy(x),B8(x)=>mem::transmute_copy(x),F1(x)=>mem::transmute_copy(x),F2(x)=>mem::transmute_copy(x),F3(x)=>mem::transmute_copy(x),F4(x)=>mem::transmute_copy(x),F5(x)=>mem::transmute_copy(x),F6(x)=>mem::transmute_copy(x),F7(x)=>mem::transmute_copy(x),F8(x)=>mem::transmute_copy(x),I1(x)=>mem::transmute_copy(x),I2(x)=>mem::transmute_copy(x),I3(x)=>mem::transmute_copy(x),I4(x)=>mem::transmute_copy(x),I5(x)=>mem::transmute_copy(x),I6(x)=>mem::transmute_copy(x),I7(x)=>mem::transmute_copy(x),I8(x)=>mem::transmute_copy(x),_=>panic!("internal error")}
147 };
148 mem::forget(value);
149 Ok(r)
150 }
151 type Error=Value<B>;
152}
153impl<B:Backend,S:?Sized+AsRef<str>> From<&S> for Value<B>{
154 fn from(value:&S)->Self{Self::Incompatible(value.as_ref().to_string())}
155}
156impl<B:Backend,const D:usize> AI<Value<B>,Value<B>> for BatchNorm<B,D>{
157 fn forward(&self,input:Value<B>)->Value<B>{
158 fn f<B:Backend,const D:usize,const E:usize,const F:usize>(norm:&BatchNorm<B,D>,x:Tensor<B,E>)->Value<B>{
159 let norm:BatchNorm<B,F>=BatchNorm{beta:norm.beta.clone(),epsilon:norm.epsilon.clone(),gamma:norm.gamma.clone(),momentum:norm.momentum.clone(),running_mean:norm.running_mean.clone(),running_var:norm.running_var.clone()};
160 norm.forward(x).into()
161 }
162 match input.float(){
163 F1(x)=>AI::forward(self,F1(x).unsqueeze().unsqueeze()).squeeze().squeeze(),
164 F2(x)=>AI::forward(self,F2(x).unsqueeze()).squeeze(),
165 F3(x)=>f::<B,D,3,1>(self,x),
166 F4(x)=>f::<B,D,4,2>(self,x),
167 F5(x)=>f::<B,D,5,3>(self,x),
168 F6(x)=>f::<B,D,6,4>(self,x),
169 F7(x)=>f::<B,D,7,5>(self,x),
170 F8(x)=>f::<B,D,8,6>(self,x),
171 Value::Incompatible(e)=>e.into(),
172 Value::Multi(v)=>v.into_iter().map(|x|AI::forward(self,x)).collect(),
173 _=>panic!("internal error")
174 }
175 }
176}
177impl<B:Backend> AI<(Value<B>,Value<B>),Vec<f32>> for CrossEntropyLayer{
178 fn forward(&self,input:(Value<B>,Value<B>))->Vec<f32>{
179 let output:Value<B>=self.forward(input);
180 output.into_float_vec()
181 }
182}
183impl<B:Backend> AI<(Value<B>,Value<B>),LossOutput<B>> for CrossEntropyLayer{
184 fn forward(&self,(output,target):(Value<B>,Value<B>))->LossOutput<B>{
185 let loss=self.forward((output.clone(),target.clone()));
186 LossOutput::new(loss,output,target)
187 }
188}
189impl<B:Backend> AI<(Value<B>,Value<B>),Value<B>> for CrossEntropyLayer{fn forward(&self,(output,target):(Value<B>,Value<B>))->Value<B>{
191 fn ff<B:Backend,const N:usize>(dim:i32,y:Tensor<B,N>,t:Tensor<B,N>,temperature:f32)->Result<Tensor<B,N>,String>{
192 let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
193 let (ydims,tdims)=(y.dims(),t.dims());
194 if ydims==tdims{
195 let logy=if temperature.is_nan(){y.log()}else{log_softmax(y/temperature,dim)};
196 Ok(logy*t.neg())
197 }else{
198 Err(format!("incompatible shapes to cross entropy. ydims: {ydims:?} tdims: {tdims:?}"))
199 }
200 }
201 fn fi<B:Backend,const N:usize,const K:usize>(dim:i32,y:Tensor<B,N>,t:Tensor<B,K,Int>,temperature:f32)->Result<Tensor<B,K>,String>{
202 let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
203 let (ydims,tdims)=(y.dims(),t.dims());
204 if ydims.iter().enumerate().filter_map(|(n,y)|(n!=dim).then_some(y)).eq(tdims.iter()){
205 let logy=if temperature.is_nan(){y.log()}else{log_softmax(y/temperature,dim)};
206 Ok(logy.gather(dim,t.unsqueeze_dim(dim)).neg().squeeze(dim))
207 }else{
208 Err(format!("incompatible shapes to cross entropy along dimension {dim}. ydims: {ydims:?} tdims: {tdims:?}"))
209 }
210 }
211 let (dim,temp)=(self.get_dim(),self.get_temperature());
212
213 match match (output,target){
214 (F1(y),F1(t))=>ff(dim,y,t,temp).map(Into::into),
215 (F2(y),F2(t))=>ff(dim,y,t,temp).map(Into::into),
216 (F3(y),F3(t))=>ff(dim,y,t,temp).map(Into::into),
217 (F4(y),F4(t))=>ff(dim,y,t,temp).map(Into::into),
218 (F5(y),F5(t))=>ff(dim,y,t,temp).map(Into::into),
219 (F6(y),F6(t))=>ff(dim,y,t,temp).map(Into::into),
220 (F7(y),F7(t))=>ff(dim,y,t,temp).map(Into::into),
221 (F8(y),F8(t))=>ff(dim,y,t,temp).map(Into::into),
222 (F1(y),I1(t))=>fi(dim,y.unsqueeze::<2>(),t,temp).map(Into::into),
223 (F2(y),I1(t))=>fi(dim,y,t,temp).map(Into::into),
224 (F3(y),I2(t))=>fi(dim,y,t,temp).map(Into::into),
225 (F4(y),I3(t))=>fi(dim,y,t,temp).map(Into::into),
226 (F5(y),I4(t))=>fi(dim,y,t,temp).map(Into::into),
227 (F6(y),I5(t))=>fi(dim,y,t,temp).map(Into::into),
228 (F7(y),I6(t))=>fi(dim,y,t,temp).map(Into::into),
229 (F7(y),I7(t))=>fi(dim,y,t,temp).map(Into::into),
230 (Value::Incompatible(y),_)=>Err(y),
231 (_,Value::Incompatible(t))=>Err(t),(Value::Multi(y),Value::Multi(t))=>if y.len()==t.len(){Ok(Value::Multi(y.into_iter().zip(t).map(|x|self.forward_typed::<_,Value<B>>(x)).collect()))}else{Err("mismatched lengths".into())},
233 _=>Err("incompatible".into())
234 }{
235 Err(e)=>Value::Incompatible(e),Ok(x)=>x
236 }
237 }
238}
239impl<B:Backend> AI<(Value<B>,Value<B>),Value<B>> for CrossEntropyLoss<B>{
240 fn forward(&self,(output,target):(Value<B>,Value<B>))->Value<B>{
241 let mut op=().fix_type::<Value<B>>().cross_entropy(1.0);
242 if !self.logits{op.set_temperature(f32::NAN)}
243 op.forward((output,target))
244 }
245}
246impl<B:Backend> AI<(Value<B>,Value<B>),LossOutput<B>> for SquaredErrorLayer{
247 fn forward(&self,(output,target):(Value<B>,Value<B>))->LossOutput<B>{
248 let loss=self.forward((output.clone(),target.clone()));
249 LossOutput::new(loss,output,target)
250 }
251}
252impl<B:Backend> AI<(Value<B>,Value<B>),Value<B>> for SquaredErrorLayer{
253 fn forward(&self,(output,target):(Value<B>,Value<B>))->Value<B>{
254 fn f<B:Backend,const N:usize>(y:Tensor<B,N>,t:Tensor<B,N>)->Value<B>{
255 if y.dims()==t.dims(){MseLoss.forward_no_reduction(y,t).into()}else{"compatible inputs for squared error are float tensors of the same shape".into()}
256 }
257 match (output.float(),target.float()){(F1(y),F1(t))=>f(y,t),(F2(y),F2(t))=>f(y,t),(F3(y),F3(t))=>f(y,t),(F4(y),F4(t))=>f(y,t),(F5(y),F5(t))=>f(y,t),(F6(y),F6(t))=>f(y,t),(F7(y),F7(t))=>f(y,t),(F8(y),F8(t))=>f(y,t),(Value::Incompatible(y),_)=>y.into(),(_,Value::Incompatible(t))=>t.into(),(Value::Multi(y),t)=>broadcast_multi(y,t.into_multi(),|y,t|self.forward((y,t))),(y,Value::Multi(t))=>broadcast_multi(y.into_multi(),t,|y,t|self.forward((y,t))),_=>"compatible inputs for squared error are float tensors of the same shape".into()}
258 }
259}
260impl<B:Backend> AI<(Value<B>,Value<B>),Vec<f32>> for SquaredErrorLayer{
261 fn forward(&self,(output,target):(Value<B>,Value<B>))->Vec<f32>{
262 let error:Value<B>=self.forward((output,target));
263 error.into_float_vec()
264 }
265}
266impl<B:Backend> AI<(Value<B>,Value<B>),f32> for SquaredErrorLayer{
267 fn forward(&self,(output,target):(Value<B>,Value<B>))->f32{().fix_type::<Value<B>>().squared_error().mean().forward((output,target))}
268}
269impl<B:Backend> AI<Value<B>,Tensor<B,1>> for MeanLayer{
270 fn forward(&self,input:Value<B>)->Tensor<B,1>{
271 fn avg<B:Backend,const N:usize>(x:Tensor<B,N>)->Tensor<B,1>{x.mean()}
272 let l=input.len();
273
274 if l==0{return Tensor::from_data(TensorData::new(vec![f32::NAN],[1]),&Default::default())}
275 match input.float(){F1(x)=>avg(x),F2(x)=>avg(x),F3(x)=>avg(x),F4(x)=>avg(x),F5(x)=>avg(x),F6(x)=>avg(x),F7(x)=>avg(x),F8(x)=>avg(x),Value::Incompatible(e)=>panic!("Could not reduce to a scalar due to incompatibility: {e}"),Value::Multi(v)=>v.into_iter().map(|x|self.forward(x)).reduce(|x:Tensor<B,1>,y:Tensor<B,1>|x+y).unwrap()/l as f32,_=>panic!("internal error")}
276 }
277}
278impl<B:Backend> AI<Value<B>,Tensor<B,1>> for SumLayer{
279 fn forward(&self,input:Value<B>)->Tensor<B,1>{
280 fn sum<B:Backend,const N:usize>(x:Tensor<B,N>)->Tensor<B,1>{x.sum()}
281 let l=input.len();
282
283 if l==0{return Tensor::from_data(TensorData::new(vec![f32::NAN],[1]),&Default::default())}
284 match input.float(){F1(x)=>sum(x),F2(x)=>sum(x),F3(x)=>sum(x),F4(x)=>sum(x),F5(x)=>sum(x),F6(x)=>sum(x),F7(x)=>sum(x),F8(x)=>sum(x),Value::Incompatible(e)=>panic!("Could not reduce to a scalar due to incompatibility: {e}"),Value::Multi(v)=>v.into_iter().map(|x|self.forward(x)).reduce(|x:Tensor<B,1>,y:Tensor<B,1>|x+y).unwrap(),_=>panic!("internal error")}
285 }
286}
287impl<B:Backend> AI<Value<B>,Value<B>> for Conv2d<B>{
288 fn forward(&self,input:Value<B>)->Value<B>{
289 fn f<B:Backend,const N:usize>(input:Tensor<B,N>,layer:&Conv2d<B>)->Value<B>{let mut dims=input.dims();
291 let n:usize=dims.iter().product();
292
293 let c=if N<3{1}else{dims[N-3]};
294 let h=if N<2{1}else{dims[N-2]};
295 let w=dims[N-1];
296
297 let b=n/(c*h*w);
298 let output=layer.forward(input.reshape([b,c,h,w]));
299
300 let [_b,c,h,w]=output.dims();
301
302 dims[N-1]=w;
303 if N<3&&c!=1{return F3(output.reshape([c,h,w]))}else if N>=3{dims[N-3]=c}
304 if N<2&&h!=1{return F2(output.reshape([h,w]))}else if N>=2{dims[N-2]=h}
305 output.reshape(dims).into()
306 }
307 let l=self;
308
309 match input.float(){F1(x)=>f(x,l),F2(x)=>f(x,l),F3(x)=>f(x,l),F4(x)=>f(x,l),F5(x)=>f(x,l),F6(x)=>f(x,l),F7(x)=>f(x,l),F8(x)=>f(x,l),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|AI::forward(self,x)).collect(),_=>panic!("internal error")}
310 }
311}
312impl<B:Backend> AI<Value<B>,Value<B>> for CrossEntropyLayer{
313 fn forward(&self,input:Value<B>)->Value<B>{
314 match input{
315 Value::Incompatible(e)=>e.into(),
316 Value::Multi(v)=>if v.len()==2{
317 let [output,target]=v.try_into().unwrap();
318 self.forward((output,target))
319 }else{
320 v.into_iter().map(|x|self.forward(x)).collect()
321 },
322 _=>"cross entropy inputs must be in pairs".into()
323 }
324 }
325}
326impl<B:Backend> AI<Value<B>,Value<B>> for CrossEntropyLoss<B>{
327 fn forward(&self,input:Value<B>)->Value<B>{
328 let mut op=CrossEntropyLayer::new(1.0);
329 if !self.logits{op.set_temperature(f32::NAN)}
330 op.forward(input)
331 }
332}
333impl<B:Backend> AI<Value<B>,Value<B>> for MeanLayer{
334 fn forward(&self,input:Value<B>)->Value<B>{
335 fn avg<B:Backend,const N:usize,const K:usize>(d:i32,x:Tensor<B,N>)->Tensor<B,K>{
336 let d=if d<0{N-((-d) as usize)}else{d as usize};
337 x.mean_dim(d).squeeze(d)
338 }
339 let l=input.len();
340
341 if l==0{return input}
342 match self.get_reduction_mode(){
343 ReductionMode::Component=>F1(self.forward(input)),
344 ReductionMode::Dim(d)=>{
345 if let Some(r)=input.rank(){
346 if d>=r as i32||d<(-(r as i32)){return format!("rank {r} is too low to cat along dimension {d}").into()}
347 }
348 match input.float(){F1(x)=>F1(x.mean()),F2(x)=>F1(avg(d,x)),F3(x)=>F2(avg(d,x)),F4(x)=>F3(avg(d,x)),F5(x)=>F4(avg(d,x)),F6(x)=>F5(avg(d,x)),F7(x)=>F6(avg(d,x)),F8(x)=>F7(avg(d,x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|self.forward(x)).reduce(|x:Value<B>,y:Value<B>|x+y).unwrap()/l as f32,_=>panic!("internal error")}
349 },
350 ReductionMode::Tensor=>match input.float(){Value::Multi(v)=>v.into_iter().reduce(|x,y|x+y).unwrap()/l as f32,x=>x}
351 }
352 }
353}
354impl<B:Backend> AI<Value<B>,Value<B>> for MseLoss{
355 fn forward(&self,input:Value<B>)->Value<B>{SquaredErrorLayer::new().forward(input)}
356}
357impl<B:Backend> AI<Value<B>,Value<B>> for SquaredErrorLayer{
358 fn forward(&self,input:Value<B>)->Value<B>{
359 match input{
360 Value::Incompatible(e)=>e.into(),
361 Value::Multi(v)=>if v.len()==2{
362 let [output,target]=v.try_into().unwrap();
363 self.forward((output,target))
364 }else{
365 v.into_iter().map(|x|self.forward(x)).collect()
366 },
367 _=>"squared error inputs must be in pairs".into()
368 }
369 }
370}
371impl<B:Backend> AI<Value<B>,Value<B>> for SumLayer{
372 fn forward(&self,input:Value<B>)->Value<B>{
373 fn sum<B:Backend,const N:usize,const K:usize>(d:i32,x:Tensor<B,N>)->Tensor<B,K>{
374 let d=if d<0{N-((-d) as usize)}else{d as usize};
375 x.mean_dim(d).squeeze(d)
376 }
377 let l=input.len();
378
379 if l==0{return input}
380 match self.get_reduction_mode(){
381 ReductionMode::Component=>F1(self.forward(input)),
382 ReductionMode::Dim(d)=>{
383 if let Some(r)=input.rank(){
384 if d>=r as i32||d<(-(r as i32)){return format!("rank {r} is too low to cat along dimension {d}").into()}
385 }
386 match input.float(){F1(x)=>F1(x.sum()),F2(x)=>F1(sum(d,x)),F3(x)=>F2(sum(d,x)),F4(x)=>F3(sum(d,x)),F5(x)=>F4(sum(d,x)),F6(x)=>F5(sum(d,x)),F7(x)=>F6(sum(d,x)),F8(x)=>F7(sum(d,x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|self.forward(x)).reduce(|x:Value<B>,y:Value<B>|x+y).unwrap(),_=>panic!("internal error")}
387 },
388 ReductionMode::Tensor=>match input.float(){Value::Multi(v)=>v.into_iter().reduce(|x,y|x+y).unwrap(),x=>x}
389 }
390 }
391}
392impl<B:Backend> AI<Value<B>,f32> for MeanLayer{
393 fn forward(&self,input:Value<B>)->f32{
394 let y:Tensor<B,1>=self.forward(input);
395 y.into_scalar().to_f32()
396 }
397}
398impl<B:Backend> AI<Value<B>,f32> for SumLayer{
399 fn forward(&self,input:Value<B>)->f32{
400 let y:Tensor<B,1>=self.forward(input);
401 y.into_scalar().to_f32()
402 }
403}
404impl<B:Backend> AI<Value<B>,Value<B>> for AccQLayer{
405 fn forward(&self,input:Value<B>)->Value<B>{
406 fn acc_q<B:Backend,const N:usize>(dim:i32,gamma:f32,i:Tensor<B,N>)->Tensor<B,N>{
407 let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
408 let mut q=i.split(1,dim);
409 q.iter_mut().rev().fold(None,|future,present|{
410 if let Some(f)=future{*present=f*gamma+present.clone()}
411 Some(present.clone())
412 });
413 Tensor::cat(q,dim)
414 }
415 let (dim,gamma)=(self.get_dim(),self.get_gamma());
416
417 match input.float(){F1(x)=>F1(acc_q(dim,gamma,x)),F2(x)=>F2(acc_q(dim,gamma,x)),F3(x)=>F3(acc_q(dim,gamma,x)),F4(x)=>F4(acc_q(dim,gamma,x)),F5(x)=>F5(acc_q(dim,gamma,x)),F6(x)=>F6(acc_q(dim,gamma,x)),F7(x)=>F7(acc_q(dim,gamma,x)),F8(x)=>F8(acc_q(dim,gamma,x)),Value::Incompatible(x)=>x.into(),Value::Multi(x)=>Value::Multi(x.into_iter().map(|x|self.forward(x)).collect()),_=>panic!("unexpected non float value")}
418 }
419}
420impl<B:Backend> AI<Value<B>,u32> for ChooseLayer{
421 fn forward(&self,input:Value<B>)->u32{
422 let (dim,temperature)=(self.get_dim(),self.get_temperature());
423
424 match input.float(){
425 F1(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
426 F2(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
427 F3(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
428 F4(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
429 F5(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
430 F6(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
431 F7(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
432 F8(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
433 Value::Incompatible(e)=>panic!("Could not create scalar due to incompatibility: {e}"),
434 Value::Multi(v)=>if v.len()==1{self.forward(v.into_iter().next().unwrap())}else{panic!("Cannot soft choose one scalar from multiple values")},
435 _=>panic!("internal error")
436 }
437 }
438}
439impl<B:Backend> AI<Value<B>,Vec<u32>> for ChooseLayer{
440 fn forward(&self,input:Value<B>)->Vec<u32>{
441 let (dim,temperature)=(self.get_dim(),self.get_temperature());
442
443 match input.float(){
444 F1(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
445 F2(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
446 F3(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
447 F4(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
448 F5(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
449 F6(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
450 F7(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
451 F8(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
452 Value::Incompatible(e)=>panic!("Could not create vector due to incompatibility: {e}"),
453 Value::Multi(v)=>v.into_iter().flat_map(|x|self.forward_typed::<_,Vec<u32>>(x)).collect(),
454 _=>panic!("internal error")
455 }
456 }
457}
458impl<B:Backend> AI<Value<B>,Value<B>> for Dropout{
459 fn forward(&self,input:Value<B>)->Value<B>{
460 match input.float(){F1(x)=>F1(self.forward(x)),F2(x)=>F2(self.forward(x)),F3(x)=>F3(self.forward(x)),F4(x)=>F4(self.forward(x)),F5(x)=>F5(self.forward(x)),F6(x)=>F6(self.forward(x)),F7(x)=>F7(self.forward(x)),F8(x)=>F8(self.forward(x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
461 }
462}
463impl<B:Backend> AI<Value<B>,Value<B>> for Embedding<B>{
464 fn forward(&self,input:Value<B>)->Value<B>{
465 fn apply_embed<B:Backend,const N:usize,const K:usize>(this:&Embedding<B>,x:Tensor<B,N,Int>)->Tensor<B,K>{
466 let dims=x.dims();
467 let [batch,seq]=[dims[0],dims.iter().skip(1).product()];
468 let x=x.reshape([batch,seq]);
469 let y=this.forward(x);
470 let embed=y.dims().last().copied().unwrap();
471 let mut ydims=[0;K];
472 ydims[..N].copy_from_slice(&dims);
473 ydims[N]=embed;
474 y.reshape(ydims)
475 }
476 fn apply_linear<B:Backend,const N:usize>(this:&Embedding<B>,x:Tensor<B,N>)->Tensor<B,N>{
477 Linear{bias:None,weight:this.weight.clone()}.forward(x)
478 }
479 match input{F1(x)=>apply_linear(self,x).into(),F2(x)=>apply_linear(self,x).into(),F3(x)=>apply_linear(self,x).into(),F4(x)=>apply_linear(self,x).into(),F5(x)=>apply_linear(self,x).into(),F6(x)=>apply_linear(self,x).into(),F7(x)=>apply_linear(self,x).into(),F8(x)=>apply_linear(self,x).into(),I1(x)=>apply_embed::<B,1,2>(self,x).into(),I2(x)=>apply_embed::<B,2,3>(self,x).into(),I3(x)=>apply_embed::<B,3,4>(self,x).into(),I4(x)=>apply_embed::<B,4,5>(self,x).into(),I5(x)=>apply_embed::<B,5,6>(self,x).into(),I6(x)=>apply_embed::<B,6,7>(self,x).into(),I7(x)=>apply_embed::<B,7,8>(self,x).into(),I8(_x)=>"embedding output would exceed maximum supported rank".into(),Value::Incompatible(x)=>x.into(),Value::Multi(x)=>x.into_iter().map(|x|AI::forward(self,x)).collect::<Vec<_>>().into(),_=>"embedding is only available for float or int inputs".into()}
480 }
481}
482impl<B:Backend> AI<Value<B>,Value<B>> for LayerNorm<B>{
483 fn forward(&self,input:Value<B>)->Value<B>{
484 fn f<B:Backend,const N:usize>(input:Tensor<B,N>,layer:&LayerNorm<B>)->Value<B>{
485 let b=layer.beta.dims();
486 let g=layer.gamma.dims();
487 let i=input.dims();
488
489 if b!=g{return format!("malformed layer norm. beta dims: {b:?}. gamma dims: {g:?}.").into()}
490 if b.last()!=i.last(){return format!("layer norm for dimension {b:?} is not compatible with input dimensions {i:?}. The last dimension must match the norm dimension.").into()}
491 layer.forward(input).into()
492 }
493 let l=self;
494
495 match input.float(){F1(x)=>f(x,l),F2(x)=>f(x,l),F3(x)=>f(x,l),F4(x)=>f(x,l),F5(x)=>f(x,l),F6(x)=>f(x,l),F7(x)=>f(x,l),F8(x)=>f(x,l),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
496 }
497}
498impl<B:Backend> AI<Value<B>,Value<B>> for Linear<B>{
499 fn forward(&self,input:Value<B>)->Value<B>{
500 match input.float(){F1(x)=>F1(self.forward(x)),F2(x)=>F2(self.forward(x)),F3(x)=>F3(self.forward(x)),F4(x)=>F4(self.forward(x)),F5(x)=>F5(self.forward(x)),F6(x)=>F6(self.forward(x)),F7(x)=>F7(self.forward(x)),F8(x)=>F8(self.forward(x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
501 }
502}
503impl<B:Backend> AI<Value<B>,Value<B>> for Relu{
504 fn forward(&self,input:Value<B>)->Value<B>{
505 match input.float(){F1(x)=>F1(self.forward(x)),F2(x)=>F2(self.forward(x)),F3(x)=>F3(self.forward(x)),F4(x)=>F4(self.forward(x)),F5(x)=>F5(self.forward(x)),F6(x)=>F6(self.forward(x)),F7(x)=>F7(self.forward(x)),F8(x)=>F8(self.forward(x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
506 }
507}
508impl<B:Backend> AI<Value<B>,Value<B>> for SoftmaxLayer{
509 fn forward(&self,input:Value<B>)->Value<B>{
510 fn f<B:Backend,const N:usize>(dim:i32,temperature:f32,x:Tensor<B,N>)->Tensor<B,N>{
511 let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
512 softmax(x/temperature,dim)
513 }
514 let (dim,temperature)=(self.get_dim(),self.get_temperature());
515
516 match input.float(){F1(x)=>F1(f(dim,temperature,x)),F2(x)=>F2(f(dim,temperature,x)),F3(x)=>F3(f(dim,temperature,x)),F4(x)=>F4(f(dim,temperature,x)),F5(x)=>F5(f(dim,temperature,x)),F6(x)=>F6(f(dim,temperature,x)),F7(x)=>F7(f(dim,temperature,x)),F8(x)=>F8(f(dim,temperature,x)),Value::Incompatible(x)=>x.into(),Value::Multi(x)=>x.into_iter().map(|x|self.forward(x)).collect(),_=>panic!("unexpected non float value")}
517 }
518}
519impl<B:Backend> AI<Value<B>,Value<B>> for ChooseLayer{
520 fn forward(&self,input:Value<B>)->Value<B>{let (dim,temperature)=(self.get_dim(),self.get_temperature());
522 let d=if dim<0{input.rank().unwrap_or(8)-((-dim) as usize)}else{dim as usize};
523
524 match input.float(){
525 F1(x)=>I1(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}),
526 F2(x)=>I1(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
527 F3(x)=>I2(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
528 F4(x)=>I3(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
529 F5(x)=>I4(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
530 F6(x)=>I5(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
531 F7(x)=>I6(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
532 F8(x)=>I7(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
533 Value::Incompatible(e)=>e.into(),
534 Value::Multi(v)=>Value::Multi(v.into_iter().map(|v|self.forward_typed::<_,Value<B>>(v)).collect()),
535 _=>panic!("internal error")}
536 }
537}
538impl<B:Backend> AI<Value<B>,Value<B>> for MaxPool2d{
539 fn forward(&self,input:Value<B>)->Value<B>{
540 fn f<B:Backend,const N:usize>(pool:&MaxPool2d,x:Tensor<B,N>)->Value<B>{
541 match N{
542 0=>panic!("internal error"),
543 1=>f::<B,2>(pool,x.unsqueeze()).squeeze(),
544 2=>f::<B,3>(pool,x.unsqueeze()).squeeze(),
545 3=>f::<B,4>(pool,x.unsqueeze()).squeeze(),
546 4=>pool.forward(Value::from(x).unwrap_f4()).into(),
547 _=>{
548 let mut dims=x.dims();
549
550 let [channels,h,w]=[dims[N-3],dims[N-2],dims[N-1]];
551 let big:usize=dims.iter().take(N-3).product();
552 let y=x.reshape([big,channels,h,w]);
553
554 dims[N-3..].copy_from_slice(&y.dims()[1..]);
555
556 let y=pool.forward(y);
557 y.reshape(dims).into()
558 }
559 }
560 }
561 match input.float(){
562 F1(x)=>f(self,x),
563 F2(x)=>f(self,x),
564 F3(x)=>f(self,x),
565 F4(x)=>f(self,x),
566 F5(x)=>f(self,x),
567 F6(x)=>f(self,x),
568 F7(x)=>f(self,x),
569 F8(x)=>f(self,x),
570 Value::Incompatible(e)=>e.into(),
571 Value::Multi(v)=>v.into_iter().map(|x|AI::forward(self,x)).collect(),
572 _=>panic!("Internal error")
573 }
574 }
575}
576impl<B:Backend> AI<Value<B>,Value<B>> for RotaryEncoding<B>{
577 fn forward(&self,input:Value<B>)->Value<B>{AI::forward(self,(input,0)).0}
578}
579impl<B:Backend> AI<(Value<B>,usize),(Value<B>,usize)> for RotaryEncoding<B>{
580 fn forward(&self,(input,offset):(Value<B>,usize))->(Value<B>,usize){
581 fn apply<B:Backend,const D:usize>(a:&RotaryEncoding<B>,input:Tensor<B,D>,offset:usize)->Value<B>{
582 assert!(D>=2);
583 const MAX_KERNEL:usize=65535; let device=input.device();
585 let freq=&a.freq_complex;
586 let shape=input.shape();
587
588 let (context,key)=(shape.dims[D-2],shape.dims[D-1]);
589 let [distance,head,_2]=freq.dims();
590
591 if context>distance{return "context length must not exceed rotary distance".into()}
592 if key%head!=0{return "input dimension must be a multiple of head".into()}
593 let count=shape.num_elements();
594 let big=count/(context*key);
595 let heads=key/head;
596 let group=count/head; let input=input.reshape([big,context,heads,head]).swap_dims(1,2).reshape([big*heads,context,head/2,2]);
598 let sign=Tensor::<B,2>::from_floats([[1.0,0.0,0.0,1.0],[0.0,-1.0,1.0,0.0]],&device).unsqueeze();
599
600 let chunks=input.chunk(group.div_ceil(MAX_KERNEL),0).into_iter().map(|x|{
601 let smaller=x.dims()[0];
602 let x=x.matmul(sign.clone()).reshape([smaller,context,head,2])*freq.clone().slice([offset..context+offset]).unsqueeze();
603 x.sum_dim(3)
604 }).collect();
605 Tensor::cat(chunks,0).reshape([big,heads,context,head]).swap_dims(1,2).reshape::<D,_>(shape).into()
606 }
607
608 (match input.float(){F1(x)=>apply(self,x.unsqueeze::<2>(),offset).squeeze(),F2(x)=>apply(self,x,offset),F3(x)=>apply(self,x,offset),F4(x)=>apply(self,x,offset),F5(x)=>apply(self,x,offset),F6(x)=>apply(self,x,offset),F7(x)=>apply(self,x,offset),F8(x)=>apply(self,x,offset),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|AI::forward(self,(x,offset)).0).collect(),_=>panic!("internal error")},offset)
609 }
610}
611impl<B:Backend> AI<Value<B>,Value<B>> for Tanh{
612 fn forward(&self,input:Value<B>)->Value<B>{
613 match input.float(){F1(x)=>F1(self.forward(x)),F2(x)=>F2(self.forward(x)),F3(x)=>F3(self.forward(x)),F4(x)=>F4(self.forward(x)),F5(x)=>F5(self.forward(x)),F6(x)=>F6(self.forward(x)),F7(x)=>F7(self.forward(x)),F8(x)=>F8(self.forward(x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
614 }
615}
616impl<B:Backend> Abs for Value<B>{
617 fn abs(self)->Self::Output{
618 match self{B1(x)=>B1(x),B2(x)=>B2(x),B3(x)=>B3(x),B4(x)=>B4(x),B5(x)=>B5(x),B6(x)=>B6(x),B7(x)=>B7(x),B8(x)=>B8(x),F1(x)=>F1(x.abs()),F2(x)=>F2(x.abs()),F3(x)=>F3(x.abs()),F4(x)=>F4(x.abs()),F5(x)=>F5(x.abs()),F6(x)=>F6(x.abs()),F7(x)=>F7(x.abs()),F8(x)=>F8(x.abs()),I1(x)=>I1(x.abs()),I2(x)=>I2(x.abs()),I3(x)=>I3(x.abs()),I4(x)=>I4(x.abs()),I5(x)=>I5(x.abs()),I6(x)=>I6(x.abs()),I7(x)=>I7(x.abs()),I8(x)=>I8(x.abs()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(Value::abs).collect()}
619 }
620 type Output=Value<B>;
621}
622impl<B:Backend> AsRef<Self> for Value<B>{
623 fn as_ref(&self)->&Self{self}
624}
625impl<B:Backend> Cat for Value<B>{
626 fn cat(self,d:i32)->Self{
628 fn f<B:Backend,I:Iterator<Item=Tensor<B,N,K>>,K:BasicOps<B>+TensorKind<B>,const N:usize>(d:i32,x0:Tensor<B,N,K>,tensors:I)->Value<B> where Tensor<B,N,K>:Into<Value<B>>{
629 if d>=N as i32||d<(-(N as i32)){return format!("rank {N} is too low to cat along dimension {d}").into()}
630 let d=if d<0{N-((-d) as usize)}else{d as usize};
631 let shape=x0.dims();
632 let tensors:Vec<Tensor<B,N,K>>=once(x0).chain(tensors).collect();
633
634 if let Err(e)=tensors.iter().try_for_each(|x|{
635 let mut xshape=x.dims();
636 xshape[d]=shape[d];
637 if shape==xshape{Ok(())}else{Err("mismatched shapes {shape:?}, {xshape:?}")}
638 }){
639 return e.into()
640 }
641
642 Tensor::cat(tensors,d).into()
643 }
644 let v=if let Value::Multi(v)=self{v}else{return self};
645
646 if let Some(n)=v.iter().position(Value::is_incompatible){return v.into_iter().nth(n).unwrap()}
647 if v.iter().all(Value::is_multi){return v.into_iter().map(|x|x.cat(d)).collect()}
648 if v.iter().any(Value::is_multi){return "cannot mix single and multi values in a cat operation".into()}
649 let variant=mem::discriminant(&v[0]);
650
651 if v.iter().any(|x|mem::discriminant(x)!=variant){return "cannot mix variants in a cat operation".into()}
652 let mut v=v.into_iter();
653
654 match v.next().unwrap(){B1(x0)=>f(d,x0,v.map(Value::unwrap_b1)),B2(x0)=>f(d,x0,v.map(Value::unwrap_b2)),B3(x0)=>f(d,x0,v.map(Value::unwrap_b3)),B4(x0)=>f(d,x0,v.map(Value::unwrap_b4)),B5(x0)=>f(d,x0,v.map(Value::unwrap_b5)),B6(x0)=>f(d,x0,v.map(Value::unwrap_b6)),B7(x0)=>f(d,x0,v.map(Value::unwrap_b7)),B8(x0)=>f(d,x0,v.map(Value::unwrap_b8)),F1(x0)=>f(d,x0,v.map(Value::unwrap_f1)),F2(x0)=>f(d,x0,v.map(Value::unwrap_f2)),F3(x0)=>f(d,x0,v.map(Value::unwrap_f3)),F4(x0)=>f(d,x0,v.map(Value::unwrap_f4)),F5(x0)=>f(d,x0,v.map(Value::unwrap_f5)),F6(x0)=>f(d,x0,v.map(Value::unwrap_f6)),F7(x0)=>f(d,x0,v.map(Value::unwrap_f7)),F8(x0)=>f(d,x0,v.map(Value::unwrap_f8)),I1(x0)=>f(d,x0,v.map(Value::unwrap_i1)),I2(x0)=>f(d,x0,v.map(Value::unwrap_i2)),I3(x0)=>f(d,x0,v.map(Value::unwrap_i3)),I4(x0)=>f(d,x0,v.map(Value::unwrap_i4)),I5(x0)=>f(d,x0,v.map(Value::unwrap_i5)),I6(x0)=>f(d,x0,v.map(Value::unwrap_i6)),I7(x0)=>f(d,x0,v.map(Value::unwrap_i7)),I8(x0)=>f(d,x0,v.map(Value::unwrap_i8)),Value::Incompatible(_e)=>panic!("internal error not handled in correct location"),Value::Multi(_e)=>panic!("internal error not handled in correct location")}
655 }
656 type Output=Self;
657}
658impl<B:Backend> Decompose for LossOutput<B>{
659 fn compose((loss,output,target):Self::Decomposition)->Self{Self::new(loss,output,target)}
660 fn decompose(self)->Self::Decomposition{(self.loss(),self.output(),self.target())}
661 fn decompose_cloned(&self)->Self::Decomposition{(self.loss(),self.output(),self.target())}
662 type Decomposition=(Value<B>,Value<B>,Value<B>);
663}
664impl<B:Backend> Decompose for Value<B>{
665 fn compose(decomposition:Self::Decomposition)->Self{decomposition}
666 fn decompose(self)->Self::Decomposition{self}
667 fn decompose_cloned(&self)->Self::Decomposition{self.clone()}
668 type Decomposition=Self;
669}
670impl<B:Backend> Default for Value<B>{
671 fn default()->Self{Self::Multi(Vec::new())}
672}
673impl<B:Backend> Display for Value<B>{
674 fn fmt(&self,f:&mut std::fmt::Formatter<'_>)->FmtResult{
675 match self{
676 B1(x)=>x.fmt(f),
677 B2(x)=>x.fmt(f),
678 B3(x)=>x.fmt(f),
679 B4(x)=>x.fmt(f),
680 B5(x)=>x.fmt(f),
681 B6(x)=>x.fmt(f),
682 B7(x)=>x.fmt(f),
683 B8(x)=>x.fmt(f),
684 F1(x)=>x.fmt(f),
685 F2(x)=>x.fmt(f),
686 F3(x)=>x.fmt(f),
687 F4(x)=>x.fmt(f),
688 F5(x)=>x.fmt(f),
689 F6(x)=>x.fmt(f),
690 F7(x)=>x.fmt(f),
691 F8(x)=>x.fmt(f),
692 I1(x)=>x.fmt(f),
693 I2(x)=>x.fmt(f),
694 I3(x)=>x.fmt(f),
695 I4(x)=>x.fmt(f),
696 I5(x)=>x.fmt(f),
697 I6(x)=>x.fmt(f),
698 I7(x)=>x.fmt(f),
699 I8(x)=>x.fmt(f),
700 Value::Incompatible(e)=>e.fmt(f),
701 Value::Multi(v)=>{
702 write!(f,"[")?;
703 v.iter().take(v.len().saturating_sub(1)).try_for_each(|x|{
704 x.fmt(f)?;
705 write!(f,", ")
706 })?;
707 if let Some(x)=v.last(){
708 x.fmt(f)?;
709 }
710 write!(f,"]")
711 }
712 }
713 }
714}
715impl<B:Backend> From<Vec<bool>> for Value<B>{
716 fn from(value:Vec<bool>)->Self{
717 let l=value.len();
718 let t:Tensor<B,1,Bool>=Tensor::from_data(TensorData::new(value,[l]),&Default::default());
719
720 t.into()
721 }
722}
723impl<B:Backend> From<Vec<f32>> for Value<B>{
724 fn from(value:Vec<f32>)->Self{
725 let l=value.len();
726 let t:Tensor<B,1>=Tensor::from_data(TensorData::new(value,[l]),&Default::default());
727
728 t.into()
729 }
730}
731impl<B:Backend> From<Vec<i32>> for Value<B>{
732 fn from(value:Vec<i32>)->Self{
733 let l=value.len();
734 let t:Tensor<B,1,Int>=Tensor::from_data(TensorData::new(value,[l]),&Default::default());
735
736 t.into()
737 }
738}
739impl<B:Backend> From<Vec<u32>> for Value<B>{
740 fn from(value:Vec<u32>)->Self{
741 let l=value.len();
742 let t:Tensor<B,1,Int>=Tensor::from_data(TensorData::new(value,[l]),&Default::default());
743
744 t.into()
745 }
746}
747impl<B:Backend> ModuleDisplay for Value<B>{
748 fn custom_content(&self,_content:Content)->Option<Content>{None}
749 fn custom_settings(&self)->Option<DisplaySettings>{None}
750 fn format(&self,s:DisplaySettings)->String{
751 match self{
752 B1(x)=>x.format(s),
753 B2(x)=>x.format(s),
754 B3(x)=>x.format(s),
755 B4(x)=>x.format(s),
756 B5(x)=>x.format(s),
757 B6(x)=>x.format(s),
758 B7(x)=>x.format(s),
759 B8(x)=>x.format(s),
760 F1(x)=>x.format(s),
761 F2(x)=>x.format(s),
762 F3(x)=>x.format(s),
763 F4(x)=>x.format(s),
764 F5(x)=>x.format(s),
765 F6(x)=>x.format(s),
766 F7(x)=>x.format(s),
767 F8(x)=>x.format(s),
768 I1(x)=>x.format(s),
769 I2(x)=>x.format(s),
770 I3(x)=>x.format(s),
771 I4(x)=>x.format(s),
772 I5(x)=>x.format(s),
773 I6(x)=>x.format(s),
774 I7(x)=>x.format(s),
775 I8(x)=>x.format(s),
776 Value::Incompatible(e)=>e.to_string(),
777 Value::Multi(v)=>"[".chars().chain(v.iter().flat_map(|x|{
778 let x:Vec<char>=x.format(s.clone()).chars().chain(", ".chars()).collect();
779 x
780 })).chain("]".chars()).collect()
781 }
782 }
783}
784impl<B:Backend> ModuleDisplayDefault for Value<B>{
785 fn content(&self,content:Content)->Option<Content>{Some(content)}
786 fn num_params(&self)->usize{Module::num_params(self)}
787}
788impl<B:Backend> From<String> for Value<B>{
789 fn from(value:String)->Self{Self::Incompatible(value)}
790}
791impl<B:Backend> From<Value<B>> for ValueData{
792 fn from(value:Value<B>)->Self{
793 match value{B1(x)=>BX(x.into_data()),B2(x)=>BX(x.into_data()),B3(x)=>BX(x.into_data()),B4(x)=>BX(x.into_data()),B5(x)=>BX(x.into_data()),B6(x)=>BX(x.into_data()),B7(x)=>BX(x.into_data()),B8(x)=>BX(x.into_data()),F1(x)=>FX(x.into_data()),F2(x)=>FX(x.into_data()),F3(x)=>FX(x.into_data()),F4(x)=>FX(x.into_data()),F5(x)=>FX(x.into_data()),F6(x)=>FX(x.into_data()),F7(x)=>FX(x.into_data()),F8(x)=>FX(x.into_data()),I1(x)=>IX(x.into_data()),I2(x)=>IX(x.into_data()),I3(x)=>IX(x.into_data()),I4(x)=>IX(x.into_data()),I5(x)=>IX(x.into_data()),I6(x)=>IX(x.into_data()),I7(x)=>IX(x.into_data()),I8(x)=>IX(x.into_data()),Value::Incompatible(e)=>ValueData::Incompatible(e),Value::Multi(v)=>ValueData::Multi(v.into_iter().map(ValueData::from).collect())}
794 }
795}
796impl<B:Backend> From<ValueData> for Value<B>{
797 fn from(value:ValueData)->Self{
798 let device=Default::default();
799 match value{
800 BX(data)=>match data.shape.len(){1=>B1(Tensor::from_data(data,&device)),2=>B2(Tensor::from_data(data,&device)),3=>B3(Tensor::from_data(data,&device)),4=>B4(Tensor::from_data(data,&device)),5=>B5(Tensor::from_data(data,&device)),6=>B6(Tensor::from_data(data,&device)),7=>B7(Tensor::from_data(data,&device)),8=>B8(Tensor::from_data(data,&device)),_=>panic!("tensor ranks above 8 are currently not supported")},
801 FX(data)=>match data.shape.len(){1=>F1(Tensor::from_data(data,&device)),2=>F2(Tensor::from_data(data,&device)),3=>F3(Tensor::from_data(data,&device)),4=>F4(Tensor::from_data(data,&device)),5=>F5(Tensor::from_data(data,&device)),6=>F6(Tensor::from_data(data,&device)),7=>F7(Tensor::from_data(data,&device)),8=>F8(Tensor::from_data(data,&device)),_=>panic!("tensor ranks above 8 are currently not supported")},
802 IX(data)=>match data.shape.len(){1=>I1(Tensor::from_data(data,&device)),2=>I2(Tensor::from_data(data,&device)),3=>I3(Tensor::from_data(data,&device)),4=>I4(Tensor::from_data(data,&device)),5=>I5(Tensor::from_data(data,&device)),6=>I6(Tensor::from_data(data,&device)),7=>I7(Tensor::from_data(data,&device)),8=>I8(Tensor::from_data(data,&device)),_=>panic!("tensor ranks above 8 are currently not supported")},
803 ValueData::Incompatible(e)=>e.into(),
804 ValueData::Multi(v)=>v.into_iter().map(Value::from).collect(),
805 }
806 }
807}
808impl<B:Backend> From<Vec<Value<B>>> for Value<B>{
809 fn from(value:Vec<Value<B>>)->Self{Self::Multi(value)}
810}
811impl<B:Backend> IntoIterator for Value<B>{
812 fn into_iter(self)->Self::IntoIter{self.into_multi().into_iter()}
813 type IntoIter=VecIntoIter<Value<B>>;
814 type Item=Value<B>;
815}
816impl<B:Backend> LossOutput<B>{
817 pub fn loss(&self)->Value<B>{self.loss.clone()}
819 pub fn new(loss:Value<B>,output:Value<B>,target:Value<B>)->Self{
821 Self{loss,output,target}
822 }
823 pub fn output(&self)->Value<B>{self.output.clone()}
825 pub fn target(&self)->Value<B>{self.target.clone()}
827}
828impl<B:Backend> Merge for Value<B>{
829 fn merge(&mut self,other:Self){
830 match (mem::take(self),other){
831 (Value::Multi(mut u),Value::Multi(v))=>{
832 u.extend(v);
833 *self=u.into();
834 },
835 (Value::Multi(mut u),v)=>if u.len()==0{
836 *self=v;
837 }else{
838 u.push(v);
839 *self=u.into();
840 },
841 (u,Value::Multi(mut v))=>if v.len()==0{
842 *self=u;
843 }else{
844 v.insert(0,u);
845 *self=v.into();
846 },
847 (u,v)=>*self=vec![u,v].into()
848 }
849 }
850}
851impl<B:Backend> Module<B> for Value<B>{
852 fn collect_devices(&self,devices:Vec<<B as Backend>::Device>)->Vec<<B as Backend>::Device>{
853 match self{B1(x)=>x.collect_devices(devices),B2(x)=>x.collect_devices(devices),B3(x)=>x.collect_devices(devices),B4(x)=>x.collect_devices(devices),B5(x)=>x.collect_devices(devices),B6(x)=>x.collect_devices(devices),B7(x)=>x.collect_devices(devices),B8(x)=>x.collect_devices(devices),F1(x)=>x.collect_devices(devices),F2(x)=>x.collect_devices(devices),F3(x)=>x.collect_devices(devices),F4(x)=>x.collect_devices(devices),F5(x)=>x.collect_devices(devices),F6(x)=>x.collect_devices(devices),F7(x)=>x.collect_devices(devices),F8(x)=>x.collect_devices(devices),I1(x)=>x.collect_devices(devices),I2(x)=>x.collect_devices(devices),I3(x)=>x.collect_devices(devices),I4(x)=>x.collect_devices(devices),I5(x)=>x.collect_devices(devices),I6(x)=>x.collect_devices(devices),I7(x)=>x.collect_devices(devices),I8(x)=>x.collect_devices(devices),Value::Incompatible(_e)=>devices,Value::Multi(v)=>v.iter().fold(devices,|devices,x|x.collect_devices(devices))}
854 }
855 fn devices(&self)->Vec<<B as Backend>::Device>{self.collect_devices(Vec::new())}
856 fn fork(self,device:&<B as Backend>::Device)->Self{
857 match self{B1(x)=>B1(x.fork(device)),B2(x)=>B2(x.fork(device)),B3(x)=>B3(x.fork(device)),B4(x)=>B4(x.fork(device)),B5(x)=>B5(x.fork(device)),B6(x)=>B6(x.fork(device)),B7(x)=>B7(x.fork(device)),B8(x)=>B8(x.fork(device)),F1(x)=>F1(x.fork(device)),F2(x)=>F2(x.fork(device)),F3(x)=>F3(x.fork(device)),F4(x)=>F4(x.fork(device)),F5(x)=>F5(x.fork(device)),F6(x)=>F6(x.fork(device)),F7(x)=>F7(x.fork(device)),F8(x)=>F8(x.fork(device)),I1(x)=>I1(x.fork(device)),I2(x)=>I2(x.fork(device)),I3(x)=>I3(x.fork(device)),I4(x)=>I4(x.fork(device)),I5(x)=>I5(x.fork(device)),I6(x)=>I6(x.fork(device)),I7(x)=>I7(x.fork(device)),I8(x)=>I8(x.fork(device)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.fork(device)).collect()}
858 }
859 fn into_record(self)->Self::Record{ConstantRecord}
860 fn load_file<F:FileRecorder<B>,P:Into<PathBuf>>(self,_filepath:P,_recorder:&F,_device:&<B as Backend>::Device)->Result<Self,RecorderError>{Ok(self)}
861 fn load_record(self,_record:Self::Record)->Self{self}
862 fn map<Mapper:ModuleMapper<B>>(self,mapper:&mut Mapper)->Self{
863 match self{B1(x)=>B1(x.map(mapper)),B2(x)=>B2(x.map(mapper)),B3(x)=>B3(x.map(mapper)),B4(x)=>B4(x.map(mapper)),B5(x)=>B5(x.map(mapper)),B6(x)=>B6(x.map(mapper)),B7(x)=>B7(x.map(mapper)),B8(x)=>B8(x.map(mapper)),F1(x)=>F1(x.map(mapper)),F2(x)=>F2(x.map(mapper)),F3(x)=>F3(x.map(mapper)),F4(x)=>F4(x.map(mapper)),F5(x)=>F5(x.map(mapper)),F6(x)=>F6(x.map(mapper)),F7(x)=>F7(x.map(mapper)),F8(x)=>F8(x.map(mapper)),I1(x)=>I1(x.map(mapper)),I2(x)=>I2(x.map(mapper)),I3(x)=>I3(x.map(mapper)),I4(x)=>I4(x.map(mapper)),I5(x)=>I5(x.map(mapper)),I6(x)=>I6(x.map(mapper)),I7(x)=>I7(x.map(mapper)),I8(x)=>I8(x.map(mapper)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.map(mapper)).collect()}
864 }
865 fn num_params(&self)->usize{
866 match self{B1(x)=>Module::num_params(x),B2(x)=>Module::num_params(x),B3(x)=>Module::num_params(x),B4(x)=>Module::num_params(x),B5(x)=>Module::num_params(x),B6(x)=>Module::num_params(x),B7(x)=>Module::num_params(x),B8(x)=>Module::num_params(x),F1(x)=>Module::num_params(x),F2(x)=>Module::num_params(x),F3(x)=>Module::num_params(x),F4(x)=>Module::num_params(x),F5(x)=>Module::num_params(x),F6(x)=>Module::num_params(x),F7(x)=>Module::num_params(x),F8(x)=>Module::num_params(x),I1(x)=>Module::num_params(x),I2(x)=>Module::num_params(x),I3(x)=>Module::num_params(x),I4(x)=>Module::num_params(x),I5(x)=>Module::num_params(x),I6(x)=>Module::num_params(x),I7(x)=>Module::num_params(x),I8(x)=>Module::num_params(x),Value::Incompatible(_e)=>0,Value::Multi(v)=>v.into_iter().map(|x|Module::num_params(x)).sum()}
867 }
868 fn quantize_weights(self,quantizer:&mut Quantizer)->Self{
869 match self{B1(x)=>B1(x.quantize_weights(quantizer)),B2(x)=>B2(x.quantize_weights(quantizer)),B3(x)=>B3(x.quantize_weights(quantizer)),B4(x)=>B4(x.quantize_weights(quantizer)),B5(x)=>B5(x.quantize_weights(quantizer)),B6(x)=>B6(x.quantize_weights(quantizer)),B7(x)=>B7(x.quantize_weights(quantizer)),B8(x)=>B8(x.quantize_weights(quantizer)),F1(x)=>F1(x.quantize_weights(quantizer)),F2(x)=>F2(x.quantize_weights(quantizer)),F3(x)=>F3(x.quantize_weights(quantizer)),F4(x)=>F4(x.quantize_weights(quantizer)),F5(x)=>F5(x.quantize_weights(quantizer)),F6(x)=>F6(x.quantize_weights(quantizer)),F7(x)=>F7(x.quantize_weights(quantizer)),F8(x)=>F8(x.quantize_weights(quantizer)),I1(x)=>I1(x.quantize_weights(quantizer)),I2(x)=>I2(x.quantize_weights(quantizer)),I3(x)=>I3(x.quantize_weights(quantizer)),I4(x)=>I4(x.quantize_weights(quantizer)),I5(x)=>I5(x.quantize_weights(quantizer)),I6(x)=>I6(x.quantize_weights(quantizer)),I7(x)=>I7(x.quantize_weights(quantizer)),I8(x)=>I8(x.quantize_weights(quantizer)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.quantize_weights(quantizer)).collect()}
870 }
871 fn save_file<F:FileRecorder<B>,P:Into<PathBuf>>(self,_filepath:P,_recorder:&F)->Result<(),RecorderError>{
872 Ok(())
873 }
874 fn to_device(self,device:&<B as Backend>::Device)->Self{
875 match self{B1(x)=>B1(x.to_device(device)),B2(x)=>B2(x.to_device(device)),B3(x)=>B3(x.to_device(device)),B4(x)=>B4(x.to_device(device)),B5(x)=>B5(x.to_device(device)),B6(x)=>B6(x.to_device(device)),B7(x)=>B7(x.to_device(device)),B8(x)=>B8(x.to_device(device)),F1(x)=>F1(x.to_device(device)),F2(x)=>F2(x.to_device(device)),F3(x)=>F3(x.to_device(device)),F4(x)=>F4(x.to_device(device)),F5(x)=>F5(x.to_device(device)),F6(x)=>F6(x.to_device(device)),F7(x)=>F7(x.to_device(device)),F8(x)=>F8(x.to_device(device)),I1(x)=>I1(x.to_device(device)),I2(x)=>I2(x.to_device(device)),I3(x)=>I3(x.to_device(device)),I4(x)=>I4(x.to_device(device)),I5(x)=>I5(x.to_device(device)),I6(x)=>I6(x.to_device(device)),I7(x)=>I7(x.to_device(device)),I8(x)=>I8(x.to_device(device)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.to_device(device)).collect()}
876 }
877 fn visit<Visitor:ModuleVisitor<B>>(&self,visitor:&mut Visitor){
878 match self{B1(x)=>x.visit(visitor),B2(x)=>x.visit(visitor),B3(x)=>x.visit(visitor),B4(x)=>x.visit(visitor),B5(x)=>x.visit(visitor),B6(x)=>x.visit(visitor),B7(x)=>x.visit(visitor),B8(x)=>x.visit(visitor),F1(x)=>x.visit(visitor),F2(x)=>x.visit(visitor),F3(x)=>x.visit(visitor),F4(x)=>x.visit(visitor),F5(x)=>x.visit(visitor),F6(x)=>x.visit(visitor),F7(x)=>x.visit(visitor),F8(x)=>x.visit(visitor),I1(x)=>x.visit(visitor),I2(x)=>x.visit(visitor),I3(x)=>x.visit(visitor),I4(x)=>x.visit(visitor),I5(x)=>x.visit(visitor),I6(x)=>x.visit(visitor),I7(x)=>x.visit(visitor),I8(x)=>x.visit(visitor),Value::Incompatible(_e)=>(),Value::Multi(v)=>v.iter().for_each(|x|x.visit(visitor))}
879 }
880 type Record=ConstantRecord;
881}
882impl<B:Backend> Serialize for Value<B>{
883 fn serialize<S:Serializer>(&self,serializer:S)->Result<S::Ok,S::Error>{ValueData::from(self.clone()).serialize(serializer)}
884}
885impl<B:Backend> Squeeze for Value<B>{
886 fn squeeze(self,d:i32)->Self{self.squeeze_dim(d)}
887 type Output=Self;
888}
889impl<B:Backend> Stack for Value<B>{
890 fn stack(self,d:i32)->Self{self.unsqueeze_dim(d).cat(d)}
892 type Output=Self;
893}
894impl<B:Backend> Unsqueeze for Value<B>{
895 fn unsqueeze(self,d:i32)->Self{self.unsqueeze_dim(d)}
896 type Output=Self;
897}
impl<B:Backend> Value<B>{
 /// Casts to bool, then reduces every tensor with `Tensor::all`; multi values reduce member-wise.
 pub fn all(self)->Value<B>{
  fn reduce<B:Backend,const N:usize>(t:Tensor<B,N,Bool>)->Value<B>{t.all().into()}
  match self.bool(){
   B1(t)=>reduce(t),B2(t)=>reduce(t),B3(t)=>reduce(t),B4(t)=>reduce(t),
   B5(t)=>reduce(t),B6(t)=>reduce(t),B7(t)=>reduce(t),B8(t)=>reduce(t),
   Value::Incompatible(err)=>err.into(),
   Value::Multi(items)=>items.into_iter().map(Value::all).collect(),
   _=>panic!("internal error")
  }
 }
904 pub fn all_dim(self,d:i32)->Value<B>{
906 fn f<B:Backend,const N:usize>(d:i32,x:Tensor<B,N,Bool>)->Value<B>{
907 if d>=N as i32||d<(-(N as i32)){return format!("rank {N} is too low to all along dimension {d}").into()}
908 let d=if d<0{N-((-d) as usize)}else{d as usize};
909 x.all_dim(d).into()
910 }
911 match self.bool(){B1(x)=>f(d,x),B2(x)=>f(d,x),B3(x)=>f(d,x),B4(x)=>f(d,x),B5(x)=>f(d,x),B6(x)=>f(d,x),B7(x)=>f(d,x),B8(x)=>f(d,x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|v|v.all_dim(d)).collect(),_=>panic!("internal error")}
912 }
913 pub fn any(self)->Value<B>{
915 fn f<B:Backend,const N:usize>(x:Tensor<B,N,Bool>)->Value<B>{x.any().into()}
916 match self.bool(){B1(x)=>f(x),B2(x)=>f(x),B3(x)=>f(x),B4(x)=>f(x),B5(x)=>f(x),B6(x)=>f(x),B7(x)=>f(x),B8(x)=>f(x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(Value::any).collect(),_=>panic!("internal error")}
917 }
918 pub fn any_dim(self,d:i32)->Value<B>{
920 fn f<B:Backend,const N:usize>(d:i32,x:Tensor<B,N,Bool>)->Value<B>{
921 if d>=N as i32||d<(-(N as i32)){return format!("rank {N} is too low to any along dimension {d}").into()}
922 let d=if d<0{N-((-d) as usize)}else{d as usize};
923 x.any_dim(d).into()
924 }
925 match self.bool(){B1(x)=>f(d,x),B2(x)=>f(d,x),B3(x)=>f(d,x),B4(x)=>f(d,x),B5(x)=>f(d,x),B6(x)=>f(d,x),B7(x)=>f(d,x),B8(x)=>f(d,x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|v|v.any_dim(d)).collect(),_=>panic!("internal error")}
926 }
927 pub fn bool(self)->Value<B>{
929 match self{B1(x)=>B1(x),B2(x)=>B2(x),B3(x)=>B3(x),B4(x)=>B4(x),B5(x)=>B5(x),B6(x)=>B6(x),B7(x)=>B7(x),B8(x)=>B8(x),F1(x)=>B1(x.bool()),F2(x)=>B2(x.bool()),F3(x)=>B3(x.bool()),F4(x)=>B4(x.bool()),F5(x)=>B5(x.bool()),F6(x)=>B6(x.bool()),F7(x)=>B7(x.bool()),F8(x)=>B8(x.bool()),I1(x)=>B1(x.bool()),I2(x)=>B2(x.bool()),I3(x)=>B3(x.bool()),I4(x)=>B4(x.bool()),I5(x)=>B5(x.bool()),I6(x)=>B6(x.bool()),I7(x)=>B7(x.bool()),I8(x)=>B8(x.bool()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(Value::bool).collect())}
930 }
931 pub fn count(&self)->usize{
933 match self{
934 B1(x)=>x.dims().iter().product(),
935 B2(x)=>x.dims().iter().product(),
936 B3(x)=>x.dims().iter().product(),
937 B4(x)=>x.dims().iter().product(),
938 B5(x)=>x.dims().iter().product(),
939 B6(x)=>x.dims().iter().product(),
940 B7(x)=>x.dims().iter().product(),
941 B8(x)=>x.dims().iter().product(),
942 F1(x)=>x.dims().iter().product(),
943 F2(x)=>x.dims().iter().product(),
944 F3(x)=>x.dims().iter().product(),
945 F4(x)=>x.dims().iter().product(),
946 F5(x)=>x.dims().iter().product(),
947 F6(x)=>x.dims().iter().product(),
948 F7(x)=>x.dims().iter().product(),
949 F8(x)=>x.dims().iter().product(),
950 I1(x)=>x.dims().iter().product(),
951 I2(x)=>x.dims().iter().product(),
952 I3(x)=>x.dims().iter().product(),
953 I4(x)=>x.dims().iter().product(),
954 I5(x)=>x.dims().iter().product(),
955 I6(x)=>x.dims().iter().product(),
956 I7(x)=>x.dims().iter().product(),
957 I8(x)=>x.dims().iter().product(),
958 Value::Incompatible(_e)=>0,
959 Value::Multi(v)=>v.iter().map(Value::count).sum()
960 }
961 }
962 pub fn empty()->Self{Self::Multi(Vec::new())}
964 pub fn flatten_values(self)->Self{
966 fn f<B:Backend>(mut acc:Vec<Value<B>>,x:Value<B>)->Vec<Value<B>>{
967 if x.is_multi(){acc=x.into_iter().fold(acc,|acc,x|f(acc,x))}else{acc.push(x)}
968 acc
969 }
970 f(Vec::new(),self).into()
971 }
972 pub fn float(self)->Value<B>{
974 match self{B1(x)=>F1(x.float()),B2(x)=>F2(x.float()),B3(x)=>F3(x.float()),B4(x)=>F4(x.float()),B5(x)=>F5(x.float()),B6(x)=>F6(x.float()),B7(x)=>F7(x.float()),B8(x)=>F8(x.float()),F1(x)=>F1(x),F2(x)=>F2(x),F3(x)=>F3(x),F4(x)=>F4(x),F5(x)=>F5(x),F6(x)=>F6(x),F7(x)=>F7(x),F8(x)=>F8(x),I1(x)=>F1(x.float()),I2(x)=>F2(x.float()),I3(x)=>F3(x.float()),I4(x)=>F4(x.float()),I5(x)=>F5(x.float()),I6(x)=>F6(x.float()),I7(x)=>F7(x.float()),I8(x)=>F8(x.float()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(Value::float).collect())}
975 }
976 pub fn from_values<I:IntoIterator<Item=Self>,S:Into<Shape>>(inner:I,shape:S)->Self{
978 fn f<B:Backend,I:Iterator<Item=Value<B>>,S:Into<Shape>>(inner:&mut I,shape:S)->Value<B>{
979 match shape.into(){
980 Shape::Incompatible(e)=>e.into(),
981 Shape::Multi(_l)=>inner.collect(),
982 Shape::Recursive(v)=>v.into_iter().map(|s|f(&mut *inner,s)).collect(),
983 X1(_s)=>inner.next().unwrap_or_default(),
984 X2(_s)=>inner.next().unwrap_or_default(),
985 X3(_s)=>inner.next().unwrap_or_default(),
986 X4(_s)=>inner.next().unwrap_or_default(),
987 X5(_s)=>inner.next().unwrap_or_default(),
988 X6(_s)=>inner.next().unwrap_or_default(),
989 X7(_s)=>inner.next().unwrap_or_default(),
990 X8(_s)=>inner.next().unwrap_or_default(),
991 }
992 }
993 f(&mut inner.into_iter(),shape)
994 }
995 pub fn gather(self,dim:i32,indices:Value<B>)->Self{
997 fn b<B:Backend,const N:usize>(d:i32,data:Tensor<B,N,Bool>,indices:Tensor<B,N,Int>)->Value<B>{f(d,data.int(),indices).bool()}
998 fn f<B:Backend,K:'static+BasicOps<B>+Numeric<B>+TensorKind<B>,const N:usize>(d:i32,data:Tensor<B,N,K>,indices:Tensor<B,N,Int>)->Value<B>{
999 let d=if d<0{N-((-d) as usize)}else{d as usize};
1000 if d>=N{format!("dim {d} must be less than rank {N}").into()}else{data.gather(d,indices).into()}
1001 }
1002
1003 match (self,indices){(B1(x),I1(i))=>b(dim,x,i),(B2(x),I2(i))=>b(dim,x,i),(B3(x),I3(i))=>b(dim,x,i),(B4(x),I4(i))=>b(dim,x,i),(B5(x),I5(i))=>b(dim,x,i),(B6(x),I6(i))=>b(dim,x,i),(B7(x),I7(i))=>b(dim,x,i),(B8(x),I8(i))=>b(dim,x,i),(F1(x),I1(i))=>f(dim,x,i),(F2(x),I2(i))=>f(dim,x,i),(F3(x),I3(i))=>f(dim,x,i),(F4(x),I4(i))=>f(dim,x,i),(F5(x),I5(i))=>f(dim,x,i),(F6(x),I6(i))=>f(dim,x,i),(F7(x),I7(i))=>f(dim,x,i),(F8(x),I8(i))=>f(dim,x,i),(I1(x),I1(i))=>f(dim,x,i),(I2(x),I2(i))=>f(dim,x,i),(I3(x),I3(i))=>f(dim,x,i),(I4(x),I4(i))=>f(dim,x,i),(I5(x),I5(i))=>f(dim,x,i),(I6(x),I6(i))=>f(dim,x,i),(I7(x),I7(i))=>f(dim,x,i),(I8(x),I8(i))=>f(dim,x,i),(Value::Incompatible(e),_)=>e.into(),(_,Value::Incompatible(e))=>e.into(),(Value::Multi(u),Value::Multi(v))=>u.into_iter().zip(v).map(|(u,v)|u.gather(dim,v)).collect(),_=>"gather is only available for tensors of matching dimensions with int indices".into()}
1004 }
1005 pub fn int(self)->Value<B>{
1007 match self{B1(x)=>I1(x.int()),B2(x)=>I2(x.int()),B3(x)=>I3(x.int()),B4(x)=>I4(x.int()),B5(x)=>I5(x.int()),B6(x)=>I6(x.int()),B7(x)=>I7(x.int()),B8(x)=>I8(x.int()),F1(x)=>I1(x.int()),F2(x)=>I2(x.int()),F3(x)=>I3(x.int()),F4(x)=>I4(x.int()),F5(x)=>I5(x.int()),F6(x)=>I6(x.int()),F7(x)=>I7(x.int()),F8(x)=>I8(x.int()),I1(x)=>I1(x),I2(x)=>I2(x),I3(x)=>I3(x),I4(x)=>I4(x),I5(x)=>I5(x),I6(x)=>I6(x),I7(x)=>I7(x),I8(x)=>I8(x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(Value::int).collect())}
1008 }
1009 pub fn into_float_vec(self)->Vec<f32>{
1011 fn cat_vec<T>(mut a:Vec<T>,b:Vec<T>)->Vec<T>{
1012 a.extend(b);
1013 a
1014 }
1015 fn to_vec<B:Backend,const N:usize>(x:Tensor<B,N>)->Vec<f32>{x.into_data().to_vec().unwrap_or_default()}
1016
1017 match self.float(){F1(x)=>to_vec(x),F2(x)=>to_vec(x),F3(x)=>to_vec(x),F4(x)=>to_vec(x),F5(x)=>to_vec(x),F6(x)=>to_vec(x),F7(x)=>to_vec(x),F8(x)=>to_vec(x),Value::Incompatible(_e)=>Vec::new(),Value::Multi(v)=>v.into_iter().map(Value::into_float_vec).reduce(cat_vec).unwrap_or_default(),_=>panic!("internal error")}
1018 }
1019 pub fn is_empty(&self)->bool{self.len()==0}
1021 pub fn is_incompatible(&self)->bool{
1023 if let Value::Incompatible(_x)=self{true}else{false}
1024 }
1025 pub fn into_multi(self)->Vec<Value<B>>{
1027 if let Value::Multi(v)=self{v}else{vec![self]}
1028 }
1029 pub fn is_multi(&self)->bool{
1031 if let Value::Multi(_x)=self{true}else{false}
1032 }
1033 pub fn iter(&self)->SliceIter<'_,Self>{
1035 if let Value::Multi(v)=self{v.iter()}else{slice::from_ref(self).iter()}
1036 }
1037 pub fn kind(&self)->Kind{
1039 match self{B1(_x)=>Kind::Bool,B2(_x)=>Kind::Bool,B3(_x)=>Kind::Bool,B4(_x)=>Kind::Bool,B5(_x)=>Kind::Bool,B6(_x)=>Kind::Bool,B7(_x)=>Kind::Bool,B8(_x)=>Kind::Bool,F1(_x)=>Kind::Float,F2(_x)=>Kind::Float,F3(_x)=>Kind::Float,F4(_x)=>Kind::Float,F5(_x)=>Kind::Float,F6(_x)=>Kind::Float,F7(_x)=>Kind::Float,F8(_x)=>Kind::Float,I1(_x)=>Kind::Int,I2(_x)=>Kind::Int,I3(_x)=>Kind::Int,I4(_x)=>Kind::Int,I5(_x)=>Kind::Int,I6(_x)=>Kind::Int,I7(_x)=>Kind::Int,I8(_x)=>Kind::Int,Value::Incompatible(_v)=>Kind::Incompatible,Value::Multi(_v)=>Kind::Multi}
1040 }
1041 pub fn len(&self)->usize{
1043 if let Value::Multi(v)=self{v.len()}else{1}
1044 }
1045 pub fn len_recursive(&self)->usize{
1047 if let Value::Multi(v)=self{v.iter().map(Value::len_recursive).sum()}else{1}
1048 }
1049 pub fn mask_fill(self,mask:Value<B>,v:f32)->Self{
1051 let (x,mask)=self.promote_rank(mask.bool());
1052 match (x,mask){
1053 (B1(x),B1(m))=>B1(x.int().mask_fill(m,v).bool()),
1054 (B2(x),B2(m))=>B2(x.int().mask_fill(m,v).bool()),
1055 (B3(x),B3(m))=>B3(x.int().mask_fill(m,v).bool()),
1056 (B4(x),B4(m))=>B4(x.int().mask_fill(m,v).bool()),
1057 (B5(x),B5(m))=>B5(x.int().mask_fill(m,v).bool()),
1058 (B6(x),B6(m))=>B6(x.int().mask_fill(m,v).bool()),
1059 (B7(x),B7(m))=>B7(x.int().mask_fill(m,v).bool()),
1060 (B8(x),B8(m))=>B8(x.int().mask_fill(m,v).bool()),
1061 (F1(x),B1(m))=>F1(x.mask_fill(m,v)),
1062 (F2(x),B2(m))=>F2(x.mask_fill(m,v)),
1063 (F3(x),B3(m))=>F3(x.mask_fill(m,v)),
1064 (F4(x),B4(m))=>F4(x.mask_fill(m,v)),
1065 (F5(x),B5(m))=>F5(x.mask_fill(m,v)),
1066 (F6(x),B6(m))=>F6(x.mask_fill(m,v)),
1067 (F7(x),B7(m))=>F7(x.mask_fill(m,v)),
1068 (F8(x),B8(m))=>F8(x.mask_fill(m,v)),
1069 (I1(x),B1(m))=>I1(x.mask_fill(m,v)),
1070 (I2(x),B2(m))=>I2(x.mask_fill(m,v)),
1071 (I3(x),B3(m))=>I3(x.mask_fill(m,v)),
1072 (I4(x),B4(m))=>I4(x.mask_fill(m,v)),
1073 (I5(x),B5(m))=>I5(x.mask_fill(m,v)),
1074 (I6(x),B6(m))=>I6(x.mask_fill(m,v)),
1075 (I7(x),B7(m))=>I7(x.mask_fill(m,v)),
1076 (I8(x),B8(m))=>I8(x.mask_fill(m,v)),
1077 (Value::Incompatible(e),_)=>e.into(),
1078 (_,Value::Incompatible(e))=>e.into(),
1079 (Value::Multi(x),m)=>broadcast_multi(x,m.into_multi(),|x,m|x.mask_fill(m,v)),
1080 (x,Value::Multi(m))=>broadcast_multi(x.into_multi(),m,|x,m|x.mask_fill(m,v)),
1081 _=>panic!("internal error")
1082 }
1083 }
1084 pub fn multi(self)->Self{
1086 if let Value::Multi(v)=self{v.into()}else{vec![self].into()}
1087 }
1088 pub fn new<S:Into<Shape>>(data:&[f32],device:&B::Device,shape:S)->Self{
1090 match shape.into(){
1091 Shape::Incompatible(e)=>e.into(),
1092 Shape::Multi(l)=>data.chunks(l).map(|d|Value::new(d,device,X1([d.len()]))).collect(),
1093 Shape::Recursive(v)=>v.into_iter().scan(data,|data,s|{
1094 let v=Value::new(*data,device,s);
1095 *data=&data[..v.count()];
1096 Some(v)
1097 }).collect(),
1098 X1(s)=>F1(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1099 X2(s)=>F2(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1100 X3(s)=>F3(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1101 X4(s)=>F4(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1102 X5(s)=>F5(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1103 X6(s)=>F6(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1104 X7(s)=>F7(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1105 X8(s)=>F8(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1106 }
1107 }
1108 pub fn promote(self,rhs:Value<B>)->(Value<B>,Value<B>){
1110 let (l,r)=self.promote_kind(rhs);
1111 l.promote_rank(r)
1112 }
1113 pub fn promote_kind(self,rhs:Value<B>)->(Value<B>,Value<B>){
1115 let (lk,rk)=(self.kind(),rhs.kind());
1116
1117 let (mut l,mut r)=(self,rhs);
1118 if lk==rk{()}else if lk==Kind::Multi{r=r.multi()}else if rk==Kind::Multi{l=l.multi()}else if lk==Kind::Float{r=r.float()}else if rk==Kind::Float{l=l.float()}else if lk==Kind::Int{r=r.int()}else if rk==Kind::Int{l=l.int()}else if lk==Kind::Incompatible{return (l,r)}else if rk==Kind::Incompatible{return (l,r)}
1119 (l,r)
1120 }
1121 pub fn promote_rank(self,rhs:Value<B>)->(Value<B>,Value<B>){
1123 let (mut l,mut r)=(self,rhs);
1124 let (mut lr,mut rr)=if let (Some(l),Some(r))=(l.rank(),r.rank()){(l,r)}else{return (l,r)};
1125 while lr<rr{
1126 l=l.unsqueeze();
1127 lr+=1;
1128 }
1129 while lr>rr{
1130 r=r.unsqueeze();
1131 rr+=1;
1132 }
1133 (l,r)
1134 }
1135 pub fn rank(&self)->Option<usize>{
1137 match self{B1(_x)=>Some(1),B2(_x)=>Some(2),B3(_x)=>Some(3),B4(_x)=>Some(4),B5(_x)=>Some(5),B6(_x)=>Some(6),B7(_x)=>Some(7),B8(_x)=>Some(8),F1(_x)=>Some(1),F2(_x)=>Some(2),F3(_x)=>Some(3),F4(_x)=>Some(4),F5(_x)=>Some(5),F6(_x)=>Some(6),F7(_x)=>Some(7),F8(_x)=>Some(8),I1(_x)=>Some(1),I2(_x)=>Some(2),I3(_x)=>Some(3),I4(_x)=>Some(4),I5(_x)=>Some(5),I6(_x)=>Some(6),I7(_x)=>Some(7),I8(_x)=>Some(8),Value::Incompatible(_x)=>None,Value::Multi(_x)=>None}
1138 }
1139 pub fn reshape<const N:usize>(self,dims:[usize;N])->Self{
1141 fn f<B:Backend,K:'static+BasicOps<B>+TensorKind<B>,const D:usize,const N:usize>(dims:[usize;D],x:Tensor<B,N,K>)->Value<B>{
1142 if dims.into_iter().product::<usize>()==x.dims().into_iter().product::<usize>(){x.reshape(dims).into()}else{"incompatible reshape".into()}
1143 }
1144 match self{
1145 B1(x)=>f(dims,x),
1146 B2(x)=>f(dims,x),
1147 B3(x)=>f(dims,x),
1148 B4(x)=>f(dims,x),
1149 B5(x)=>f(dims,x),
1150 B6(x)=>f(dims,x),
1151 B7(x)=>f(dims,x),
1152 B8(x)=>f(dims,x),
1153 F1(x)=>f(dims,x),
1154 F2(x)=>f(dims,x),
1155 F3(x)=>f(dims,x),
1156 F4(x)=>f(dims,x),
1157 F5(x)=>f(dims,x),
1158 F6(x)=>f(dims,x),
1159 F7(x)=>f(dims,x),
1160 F8(x)=>f(dims,x),
1161 I1(x)=>f(dims,x),
1162 I2(x)=>f(dims,x),
1163 I3(x)=>f(dims,x),
1164 I4(x)=>f(dims,x),
1165 I5(x)=>f(dims,x),
1166 I6(x)=>f(dims,x),
1167 I7(x)=>f(dims,x),
1168 I8(x)=>f(dims,x),
1169 Value::Incompatible(e)=>e.into(),
1170 Value::Multi(v)=>v.into_iter().map(|x|x.reshape(dims)).collect()
1171 }
1172 }
1173 pub fn shape(&self)->Shape{
1175 match self{B1(x)=>Shape::X1(x.dims()),B2(x)=>Shape::X2(x.dims()),B3(x)=>Shape::X3(x.dims()),B4(x)=>Shape::X4(x.dims()),B5(x)=>Shape::X5(x.dims()),B6(x)=>Shape::X6(x.dims()),B7(x)=>Shape::X7(x.dims()),B8(x)=>Shape::X8(x.dims()),F1(x)=>Shape::X1(x.dims()),F2(x)=>Shape::X2(x.dims()),F3(x)=>Shape::X3(x.dims()),F4(x)=>Shape::X4(x.dims()),F5(x)=>Shape::X5(x.dims()),F6(x)=>Shape::X6(x.dims()),F7(x)=>Shape::X7(x.dims()),F8(x)=>Shape::X8(x.dims()),I1(x)=>Shape::X1(x.dims()),I2(x)=>Shape::X2(x.dims()),I3(x)=>Shape::X3(x.dims()),I4(x)=>Shape::X4(x.dims()),I5(x)=>Shape::X5(x.dims()),I6(x)=>Shape::X6(x.dims()),I7(x)=>Shape::X7(x.dims()),I8(x)=>Shape::X8(x.dims()),Value::Incompatible(x)=>Shape::Incompatible(x.clone()),Value::Multi(x)=>Shape::Multi(x.len())}
1176 }
1177 pub fn shape_recursive(&self)->Shape{
1179 if let Value::Multi(x)=self{Shape::Recursive(x.iter().map(Value::shape_recursive).collect())}else{self.shape()}
1180 }
1181 pub fn shift(self,d:i32,n:i32,v:f32)->Self{
1183 fn b<B:Backend,const N:usize>(d:i32,n:i32,v:f32,x:Tensor<B,N,Bool>)->Value<B>{f(d,n,if v==0.0{0.0}else{1.0},x.int()).bool()}
1184 fn f<B:Backend,K:'static+BasicOps<B>+Numeric<B>+TensorKind<B>,const N:usize>(d:i32,n:i32,v:f32,x:Tensor<B,N,K>)->Value<B>{
1185 let device=x.device();
1186 let d=if d<0{N-((-d) as usize)}else{d as usize};
1187 let mut paddims=x.dims();
1188 let mut slicedims=paddims.map(|n|0..n);
1189
1190 paddims[d]=n.abs() as usize;
1191 slicedims[d]=if n<0{(-n) as usize..slicedims[d].end}else{0..slicedims[d].end.saturating_sub(n as usize)};
1192 if slicedims[d].len()==0{return x.full_like(v).into()}
1193 let pad:Tensor<B,N,K>=Tensor::full(paddims,v,&device);
1194 let slice=x.slice(slicedims);
1195
1196 Tensor::cat(if n<0{vec![slice,pad]}else{vec![pad,slice]},d).into()
1197 }
1198 if n==0{return self}
1199
1200 match self{B1(x)=>b(d,n,v,x),B2(x)=>b(d,n,v,x),B3(x)=>b(d,n,v,x),B4(x)=>b(d,n,v,x),B5(x)=>b(d,n,v,x),B6(x)=>b(d,n,v,x),B7(x)=>b(d,n,v,x),B8(x)=>b(d,n,v,x),F1(x)=>f(d,n,v,x),F2(x)=>f(d,n,v,x),F3(x)=>f(d,n,v,x),F4(x)=>f(d,n,v,x),F5(x)=>f(d,n,v,x),F6(x)=>f(d,n,v,x),F7(x)=>f(d,n,v,x),F8(x)=>f(d,n,v,x),I1(x)=>f(d,n,v,x),I2(x)=>f(d,n,v,x),I3(x)=>f(d,n,v,x),I4(x)=>f(d,n,v,x),I5(x)=>f(d,n,v,x),I6(x)=>f(d,n,v,x),I7(x)=>f(d,n,v,x),I8(x)=>f(d,n,v,x),Value::Incompatible(e)=>e.into(),Value::Multi(x)=>x.into_iter().map(|x|x.shift(d,n,v)).collect()}
1201 }
1202 pub fn slice<A:AsRef<[R]>,R:RangeBounds<usize>>(self,ranges:A)->Self{
1204 let ranges=ranges.as_ref();
1205 let len=ranges.len();
1206 if let Value::Incompatible(x)=self{return x.into()}
1207 let rank=self.rank().unwrap_or(len);
1208 let shape=self.shape();
1209
1210 let mut normalizedranges=[0;8].map(|_|0..0);
1211 for ((d,n),r) in shape.clone().to_array(Alignment::Left).into_iter().zip(normalizedranges.iter_mut()).zip(ranges){
1212 n.start=match r.start_bound(){Excluded(&x)=>x+1,Included(&x)=>x,Unbounded=>0};
1213 n.end=match r.end_bound(){Excluded(&x)=>x,Included(&x)=>x+1,Unbounded=>d};
1214 }
1215 if len>rank{return format!("Length of ranges argument must be less than the the value's rank. len: {len} ranges: {normalizedranges:?} rank: {rank} shape: {shape:?}").into()}
1216 for (d,n) in shape.clone().to_array(Alignment::Left).into_iter().zip(normalizedranges.iter()).take(len){
1217 if n.start>=n.end{return format!("Empty or reverse ranges are currently not supported. ranges: {normalizedranges:?}").into()}
1218 if d<n.end{return format!("Cannot index beyond the end of a dimension. ranges: {normalizedranges:?} shape: {shape:?}").into()}
1219 }
1220 let ranges=&normalizedranges[..len];
1221
1222 match self{B1(x)=>B1(slice_slice(ranges,x)),B2(x)=>B2(slice_slice(ranges,x)),B3(x)=>B3(slice_slice(ranges,x)),B4(x)=>B4(slice_slice(ranges,x)),B5(x)=>B5(slice_slice(ranges,x)),B6(x)=>B6(slice_slice(ranges,x)),B7(x)=>B7(slice_slice(ranges,x)),B8(x)=>B8(slice_slice(ranges,x)),F1(x)=>F1(slice_slice(ranges,x)),F2(x)=>F2(slice_slice(ranges,x)),F3(x)=>F3(slice_slice(ranges,x)),F4(x)=>F4(slice_slice(ranges,x)),F5(x)=>F5(slice_slice(ranges,x)),F6(x)=>F6(slice_slice(ranges,x)),F7(x)=>F7(slice_slice(ranges,x)),F8(x)=>F8(slice_slice(ranges,x)),I1(x)=>I1(slice_slice(ranges,x)),I2(x)=>I2(slice_slice(ranges,x)),I3(x)=>I3(slice_slice(ranges,x)),I4(x)=>I4(slice_slice(ranges,x)),I5(x)=>I5(slice_slice(ranges,x)),I6(x)=>I6(slice_slice(ranges,x)),I7(x)=>I7(slice_slice(ranges,x)),I8(x)=>I8(slice_slice(ranges,x)),Value::Incompatible(x)=>x.into(),Value::Multi(x)=>Value::Multi(x.into_iter().map(|x|x.slice(ranges)).collect())}
1223 }
1224 pub fn split<I:Into<Option<i32>>>(self,chunksize:usize,dim:I)->Self{
1226 fn f<B:Backend,K:'static+BasicOps<B>+TensorKind<B>,const N:usize>(dim:i32,size:usize,tensor:Tensor<B,N,K>)->Value<B>{
1227 if dim>=N as i32||dim<(-(N as i32)){return format!("rank {N} is too low to split along dimension {dim}").into()}
1228 let dim=if dim<0{N-((-dim) as usize)}else{dim as usize};
1229
1230 tensor.split(dim,size).into_iter().map(Value::from).collect()
1231 }
1232 let c=if chunksize==0{return "cannot split into chunks of 0 size".into()}else{chunksize};
1233
1234 if let Some(d)=dim.into(){
1235 match self{B1(x)=>f(d,c,x),B2(x)=>f(d,c,x),B3(x)=>f(d,c,x),B4(x)=>f(d,c,x),B5(x)=>f(d,c,x),B6(x)=>f(d,c,x),B7(x)=>f(d,c,x),B8(x)=>f(d,c,x),F1(x)=>f(d,c,x),F2(x)=>f(d,c,x),F3(x)=>f(d,c,x),F4(x)=>f(d,c,x),F5(x)=>f(d,c,x),F6(x)=>f(d,c,x),F7(x)=>f(d,c,x),F8(x)=>f(d,c,x),I1(x)=>f(d,c,x),I2(x)=>f(d,c,x),I3(x)=>f(d,c,x),I4(x)=>f(d,c,x),I5(x)=>f(d,c,x),I6(x)=>f(d,c,x),I7(x)=>f(d,c,x),I8(x)=>f(d,c,x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.split(c,Some(d))).collect()}
1236 }else{
1237 let v=self.into_multi();
1238 v.chunks(chunksize).map(|c|Value::from(c.to_vec())).collect()
1239 }
1240 }
1241 pub fn squeeze(self)->Self{self.squeeze_dim(0)}
1243 pub fn squeeze_dim(self,d:i32)->Self{
1245 fn f<B:Backend,K:BasicOps<B>+TensorKind<B>,const D:usize,const N:usize>(x:Tensor<B,D,K>,d:i32)->Result<Tensor<B,N,K>,String>{
1246 let d=if d<0{D-((-d) as usize)}else{d as usize};
1247
1248 if d>=D{return Err(format!("dim {d} must be less than {D}"))}
1249 let xdim=x.dims()[d];
1250
1251 if xdim==1{Ok(x.squeeze(d))}else{Err(format!("cannot squeeze a dim of size not equal to 1. dim {d} was {xdim}"))}
1252 }
1253 match match self{B1(_x)=>Err("currently cannot decrease the number of tensor dimensions below 1".into()),B2(x)=>f(x,d).map(B1),B3(x)=>f(x,d).map(B2),B4(x)=>f(x,d).map(B3),B5(x)=>f(x,d).map(B4),B6(x)=>f(x,d).map(B5),B7(x)=>f(x,d).map(B6),B8(x)=>f(x,d).map(B7),F1(_x)=>Err("currently cannot decrease the number of tensor dimensions below 1".into()),F2(x)=>f(x,d).map(F1),F3(x)=>f(x,d).map(F2),F4(x)=>f(x,d).map(F3),F5(x)=>f(x,d).map(F4),F6(x)=>f(x,d).map(F5),F7(x)=>f(x,d).map(F6),F8(x)=>f(x,d).map(F7),I1(_x)=>Err("currently cannot decrease the number of tensor dimensions below 1".into()),I2(x)=>f(x,d).map(I1),I3(x)=>f(x,d).map(I2),I4(x)=>f(x,d).map(I3),I5(x)=>f(x,d).map(I4),I6(x)=>f(x,d).map(I5),I7(x)=>f(x,d).map(I6),I8(x)=>f(x,d).map(I7),Value::Incompatible(e)=>Err(e),Value::Multi(v)=>Ok(v.into_iter().map(|x|x.squeeze_dim(d)).collect())}{Err(e)=>e.into(),Ok(x)=>x}
1254 }
1255 pub fn try_incompatible(self)->Result<String,Self>{
1257 if let Value::Incompatible(x)=self{Ok(x)}else{Err(self)}
1258 }
1259 pub fn try_multi(self)->Result<Vec<Value<B>>,Self>{
1261 if let Value::Multi(v)=self{Ok(v)}else{Err(self)}
1262 }
1263 pub fn unsqueeze(self)->Self{self.unsqueeze_dim(0)}
1265 pub fn unsqueeze_dim(self,d:i32)->Self{
1267 fn f<B:Backend,K:BasicOps<B>+TensorKind<B>,const D:usize,const N:usize>(x:Tensor<B,D,K>,d:i32)->Tensor<B,N,K>{
1268 x.unsqueeze_dim(if d<0{D-((-d) as usize)+1}else{d as usize})
1269 }
1270 if let Some(r)=self.rank(){
1271 let e=if d<0{r-((-d) as usize)+1}else{d as usize};
1272 if e>r{return format!("dim {e} must be less than or equal to rank {r}").into()}
1273 }
1274 match self{B1(x)=>B2(f(x,d)),B2(x)=>B3(f(x,d)),B3(x)=>B4(f(x,d)),B4(x)=>B5(f(x,d)),B5(x)=>B6(f(x,d)),B6(x)=>B7(f(x,d)),B7(x)=>B8(f(x,d)),B8(_x)=>"currently can't increase number of tensor dimensions above 8".into(),F1(x)=>F2(f(x,d)),F2(x)=>F3(f(x,d)),F3(x)=>F4(f(x,d)),F4(x)=>F5(f(x,d)),F5(x)=>F6(f(x,d)),F6(x)=>F7(f(x,d)),F7(x)=>F8(f(x,d)),F8(_x)=>"currently can't increase number of tensor dimensions above 8".into(),I1(x)=>I2(f(x,d)),I2(x)=>I3(f(x,d)),I3(x)=>I4(f(x,d)),I4(x)=>I5(f(x,d)),I5(x)=>I6(f(x,d)),I6(x)=>I7(f(x,d)),I7(x)=>I8(f(x,d)),I8(_x)=>"currently can't increase number of tensor dimensions above 8".into(),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.unsqueeze_dim(d)).collect()}
1275 }
1276 #[track_caller]
1277 pub fn unwrap_incompatible(self)->String{self.try_incompatible().unwrap()}
1279 #[track_caller]
1280 pub fn unwrap_multi(self)->Vec<Value<B>>{self.try_multi().unwrap()}
1282 pub fn zeros_like(&self)->Value<B>{match self{B1(x)=>B1(x.clone().int().zeros_like().bool()),B2(x)=>B2(x.clone().int().zeros_like().bool()),B3(x)=>B3(x.clone().int().zeros_like().bool()),B4(x)=>B4(x.clone().int().zeros_like().bool()),B5(x)=>B5(x.clone().int().zeros_like().bool()),B6(x)=>B6(x.clone().int().zeros_like().bool()),B7(x)=>B7(x.clone().int().zeros_like().bool()),B8(x)=>B8(x.clone().int().zeros_like().bool()),F1(x)=>F1(x.zeros_like()),F2(x)=>F2(x.zeros_like()),F3(x)=>F3(x.zeros_like()),F4(x)=>F4(x.zeros_like()),F5(x)=>F5(x.zeros_like()),F6(x)=>F6(x.zeros_like()),F7(x)=>F7(x.zeros_like()),F8(x)=>F8(x.zeros_like()),I1(x)=>I1(x.zeros_like()),I2(x)=>I2(x.zeros_like()),I3(x)=>I3(x.zeros_like()),I4(x)=>I4(x.zeros_like()),I5(x)=>I5(x.zeros_like()),I6(x)=>I6(x.zeros_like()),I7(x)=>I7(x.zeros_like()),I8(x)=>I8(x.zeros_like()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.iter().map(Value::zeros_like).collect()}
1285 }
1286 pub fn zip(self)->Self{if self.len()<=1||self.iter().all(|v|v.len()<=1){return self}
1289
1290 let mut iters:Vec<_>=self.into_iter().map(Value::into_iter).collect();
1291
1292 let cols=iters.len();
1293 let rows=iters.iter().map(ExactSizeIterator::len).max().expect("should not be empty at this point");
1294 let transposed:Vec<Value<B>>=(0..rows).map(|_|{
1295 let v:Value<B>=(0..cols).map(|c|iters[c].next().unwrap_or_default()).collect();
1296 v.zip()
1297 }).collect();
1298
1299 Value::Multi(transposed)
1300 }
	// Typed accessors generated by `try_unwrap!`: `try_xN` converts to the concrete tensor type via
	// `try_into` (the `Err` side is a `Value`), and `unwrap_xN` unwraps that result at the caller's location.
	try_unwrap!(Tensor<B,1,Bool>,try_b1,unwrap_b1);
	try_unwrap!(Tensor<B,2,Bool>,try_b2,unwrap_b2);
	try_unwrap!(Tensor<B,3,Bool>,try_b3,unwrap_b3);
	try_unwrap!(Tensor<B,4,Bool>,try_b4,unwrap_b4);
	try_unwrap!(Tensor<B,5,Bool>,try_b5,unwrap_b5);
	try_unwrap!(Tensor<B,6,Bool>,try_b6,unwrap_b6);
	try_unwrap!(Tensor<B,7,Bool>,try_b7,unwrap_b7);
	try_unwrap!(Tensor<B,8,Bool>,try_b8,unwrap_b8);
	try_unwrap!(Tensor<B,1,Float>,try_f1,unwrap_f1);
	try_unwrap!(Tensor<B,2,Float>,try_f2,unwrap_f2);
	try_unwrap!(Tensor<B,3,Float>,try_f3,unwrap_f3);
	try_unwrap!(Tensor<B,4,Float>,try_f4,unwrap_f4);
	try_unwrap!(Tensor<B,5,Float>,try_f5,unwrap_f5);
	try_unwrap!(Tensor<B,6,Float>,try_f6,unwrap_f6);
	try_unwrap!(Tensor<B,7,Float>,try_f7,unwrap_f7);
	try_unwrap!(Tensor<B,8,Float>,try_f8,unwrap_f8);
	try_unwrap!(Tensor<B,1,Int>,try_i1,unwrap_i1);
	try_unwrap!(Tensor<B,2,Int>,try_i2,unwrap_i2);
	try_unwrap!(Tensor<B,3,Int>,try_i3,unwrap_i3);
	try_unwrap!(Tensor<B,4,Int>,try_i4,unwrap_i4);
	try_unwrap!(Tensor<B,5,Int>,try_i5,unwrap_i5);
	try_unwrap!(Tensor<B,6,Int>,try_i6,unwrap_i6);
	try_unwrap!(Tensor<B,7,Int>,try_i7,unwrap_i7);
	try_unwrap!(Tensor<B,8,Int>,try_i8,unwrap_i8);
1325}
macro_rules! bicop_num{
	// Implements the arithmetic operator `$trait` (method `$traitfn`) for every owned/borrowed
	// combination of `Value` operands, plus `Value op scalar` through the tensor method `$traitscalar`.
	// Bool operands are converted to Int first, so ops on Bool produce Int values. Borrowed operands
	// are cloned and forwarded to the owned impls.
	($trait:ident,$traitfn:ident,$traitscalar:ident)=>(
		impl<B:Backend,E:Copy+ElementConversion> $trait<E> for &Value<B>{
			fn $traitfn(self,rhs:E)->Value<B>{self.clone().$traitfn(rhs)}
			type Output=Value<B>;
		}
		impl<B:Backend,E:Copy+ElementConversion> $trait<E> for Value<B>{
			fn $traitfn(self,rhs:E)->Value<B>{
				match self{
					B1(x)=>I1(x.int().$traitscalar(rhs)),B2(x)=>I2(x.int().$traitscalar(rhs)),
					B3(x)=>I3(x.int().$traitscalar(rhs)),B4(x)=>I4(x.int().$traitscalar(rhs)),
					B5(x)=>I5(x.int().$traitscalar(rhs)),B6(x)=>I6(x.int().$traitscalar(rhs)),
					B7(x)=>I7(x.int().$traitscalar(rhs)),B8(x)=>I8(x.int().$traitscalar(rhs)),
					F1(x)=>F1(x.$traitscalar(rhs)),F2(x)=>F2(x.$traitscalar(rhs)),
					F3(x)=>F3(x.$traitscalar(rhs)),F4(x)=>F4(x.$traitscalar(rhs)),
					F5(x)=>F5(x.$traitscalar(rhs)),F6(x)=>F6(x.$traitscalar(rhs)),
					F7(x)=>F7(x.$traitscalar(rhs)),F8(x)=>F8(x.$traitscalar(rhs)),
					I1(x)=>I1(x.$traitscalar(rhs)),I2(x)=>I2(x.$traitscalar(rhs)),
					I3(x)=>I3(x.$traitscalar(rhs)),I4(x)=>I4(x.$traitscalar(rhs)),
					I5(x)=>I5(x.$traitscalar(rhs)),I6(x)=>I6(x.$traitscalar(rhs)),
					I7(x)=>I7(x.$traitscalar(rhs)),I8(x)=>I8(x.$traitscalar(rhs)),
					Value::Incompatible(e)=>e.into(),
					Value::Multi(v)=>v.into_iter().map(|x|x.$traitfn(rhs)).collect()
				}
			}
			type Output=Value<B>;
		}
		impl<B:Backend> $trait<&Value<B>> for &Value<B>{
			fn $traitfn(self,rhs:&Value<B>)->Value<B>{self.clone().$traitfn(rhs.clone())}
			type Output=Value<B>;
		}
		impl<B:Backend> $trait<&Value<B>> for Value<B>{
			fn $traitfn(self,rhs:&Value<B>)->Value<B>{self.$traitfn(rhs.clone())}
			type Output=Value<B>;
		}
		impl<B:Backend> $trait<Value<B>> for &Value<B>{
			fn $traitfn(self,rhs:Value<B>)->Value<B>{self.clone().$traitfn(rhs)}
			type Output=Value<B>;
		}
		impl<B:Backend> $trait<Value<B>> for Value<B>{
			fn $traitfn(self,rhs:Value<B>)->Value<B>{
				// `promote` aligns kind and rank, so matching tensor pairs share a variant number
				match self.promote(rhs){
					(B1(l),B1(r))=>I1(l.int().$traitfn(r.int())),(B2(l),B2(r))=>I2(l.int().$traitfn(r.int())),
					(B3(l),B3(r))=>I3(l.int().$traitfn(r.int())),(B4(l),B4(r))=>I4(l.int().$traitfn(r.int())),
					(B5(l),B5(r))=>I5(l.int().$traitfn(r.int())),(B6(l),B6(r))=>I6(l.int().$traitfn(r.int())),
					(B7(l),B7(r))=>I7(l.int().$traitfn(r.int())),(B8(l),B8(r))=>I8(l.int().$traitfn(r.int())),
					(F1(l),F1(r))=>F1(l.$traitfn(r)),(F2(l),F2(r))=>F2(l.$traitfn(r)),
					(F3(l),F3(r))=>F3(l.$traitfn(r)),(F4(l),F4(r))=>F4(l.$traitfn(r)),
					(F5(l),F5(r))=>F5(l.$traitfn(r)),(F6(l),F6(r))=>F6(l.$traitfn(r)),
					(F7(l),F7(r))=>F7(l.$traitfn(r)),(F8(l),F8(r))=>F8(l.$traitfn(r)),
					(I1(l),I1(r))=>I1(l.$traitfn(r)),(I2(l),I2(r))=>I2(l.$traitfn(r)),
					(I3(l),I3(r))=>I3(l.$traitfn(r)),(I4(l),I4(r))=>I4(l.$traitfn(r)),
					(I5(l),I5(r))=>I5(l.$traitfn(r)),(I6(l),I6(r))=>I6(l.$traitfn(r)),
					(I7(l),I7(r))=>I7(l.$traitfn(r)),(I8(l),I8(r))=>I8(l.$traitfn(r)),
					(Value::Incompatible(e),_)=>e.into(),
					(_,Value::Incompatible(e))=>e.into(),
					(Value::Multi(l),r)=>broadcast_multi(l,r.into_multi(),$trait::$traitfn),
					(l,Value::Multi(r))=>broadcast_multi(l.into_multi(),r,$trait::$traitfn),
					// metavariables are not substituted inside string literals, so the old message
					// printed the literal text "$traitfn"; stringify the operator name explicitly
					_=>panic!("couldn't promote types for {}",stringify!($traitfn))
				}
			}
			type Output=Value<B>;
		}
	);
}
macro_rules! try_unwrap{
	// Generates a fallible accessor `$try_unwrap` and a panicking accessor `$unwrap` that extract
	// the concrete tensor type `$tensor` from a `Value`.
	($tensor:ty,$try_unwrap:ident,$unwrap:ident)=>{
		/// Converts `self` into the concrete tensor type; the error side is a `Value`.
		pub fn $try_unwrap(self)->Result<$tensor,Self>{self.try_into()}
		/// Like the fallible accessor, but panics at the caller's location on a variant mismatch.
		#[track_caller]
		pub fn $unwrap(self)->$tensor{
			self.$try_unwrap().unwrap()
		}
	}
}
1367
/// Coarse category of a `Value`, as reported by `Value::kind`: `Bool`, `Float` or `Int` for tensor
/// variants of that element kind, `Incompatible` for error values, and `Multi` for lists of values.
#[derive(Clone,Copy,Debug,Eq,PartialEq,Deserialize,Serialize)]
pub enum Kind{Bool,Float,Incompatible,Int,Multi}
/// Shape of a `Value`: `X1`..`X8` hold the dims of a rank-1..8 tensor, `Multi(n)` is the entry count
/// of a `Multi` value (from `Value::shape`), `Recursive` holds one shape per entry (from
/// `Value::shape_recursive`), and `Incompatible` carries the message of an incompatible value.
#[derive(Clone,Debug,Deserialize,Serialize)]pub enum Shape{Incompatible(String),Multi(usize),Recursive(Vec<Shape>),X1([usize;1]),X2([usize;2]),X3([usize;3]),X4([usize;4]),X5([usize;5]),X6([usize;6]),X7([usize;7]),X8([usize;8])}
/// Dynamically typed tensor value. `B1`..`B8`, `F1`..`F8` and `I1`..`I8` wrap Bool, Float and Int
/// tensors of rank 1 through 8. `Incompatible` carries an error message, which operations propagate
/// as a value instead of panicking. `Multi` is a (possibly nested) list of values.
#[derive(Clone,Debug)]
pub enum Value<B:Backend>{B1(Tensor<B,1,Bool>),B2(Tensor<B,2,Bool>),B3(Tensor<B,3,Bool>),B4(Tensor<B,4,Bool>),B5(Tensor<B,5,Bool>),B6(Tensor<B,6,Bool>),B7(Tensor<B,7,Bool>),B8(Tensor<B,8,Bool>),F1(Tensor<B,1,Float>),F2(Tensor<B,2,Float>),F3(Tensor<B,3,Float>),F4(Tensor<B,4,Float>),F5(Tensor<B,5,Float>),F6(Tensor<B,6,Float>),F7(Tensor<B,7,Float>),F8(Tensor<B,8,Float>),I1(Tensor<B,1,Int>),I2(Tensor<B,2,Int>),I3(Tensor<B,3,Int>),I4(Tensor<B,4,Int>),I5(Tensor<B,5,Int>),I6(Tensor<B,6,Int>),I7(Tensor<B,7,Int>),I8(Tensor<B,8,Int>),Incompatible(String),Multi(Vec<Self>)}
/// Serializable data form tagged by element kind: `BX`, `FX` and `IX` hold raw `TensorData` for
/// Bool, Float and Int tensors; `Incompatible` holds a message and `Multi` a nested list.
/// NOTE(review): presumably the (de)serialization counterpart of `Value`, with the rank carried by
/// the `TensorData` shape; the conversion code is not in this chunk, so confirm.
#[derive(Clone,Debug,Deserialize,Serialize)]
pub enum ValueData{BX(TensorData),FX(TensorData),IX(TensorData),Incompatible(String),Multi(Vec<ValueData>)}
/// Groups a `loss`, an `output` and a `target` value (presumably a loss together with the
/// prediction and target it was computed from — confirm at the construction sites).
/// `#[serde(bound="")]` suppresses serde's inferred bounds on the backend parameter `B`.
#[derive(Clone,Debug,Deserialize,Serialize)]
#[serde(bound="")]
pub struct LossOutput<B:Backend>{loss:Value<B>,output:Value<B>,target:Value<B>}
1384use {bicop_num,try_unwrap};
1385use Bound::{Excluded,Included,Unbounded};
1386use Shape::{X1,X2,X3,X4,X5,X6,X7,X8};
1387use Value::{B1,B2,B3,B4,B5,B6,B7,B8,F1,F2,F3,F4,F5,F6,F7,F8,I1,I2,I3,I4,I5,I6,I7,I8};
1388use ValueData::{BX,FX,IX};
1389use burn::{
1390 module::{AutodiffModule,ConstantRecord,Content,DisplaySettings,ModuleDisplay,ModuleDisplayDefault,ModuleMapper,ModuleVisitor,Quantizer},
1391 nn::{
1392 BatchNorm,Dropout,Embedding,LayerNorm,Linear,Relu,RotaryEncoding,Tanh,conv::Conv2d,loss::{CrossEntropyLoss,MseLoss},pool::MaxPool2d
1393 },
1394 prelude::{Backend,Bool,Float,Int,Module,Tensor,TensorData},
1395 record::{FileRecorder,RecorderError},
1396 tensor::{
1397 BasicOps,ElementConversion,Numeric,TensorKind,activation::{log_softmax,softmax},backend::AutodiffBackend,cast::ToElement
1398 }
1399};
1400use crate::{
1401 AI,Decompose,Merge,Op,
1402 builtin::{
1403 Alignment,ReductionMode,math::{MeanLayer,SquaredErrorLayer,SumLayer},reinforcement::AccQLayer,soft::{ChooseLayer,CrossEntropyLayer,SoftmaxLayer}
1404 },
1405 ops::{Abs,Cat,Stack,Squeeze,Unsqueeze}
1406};
1407use rand::random;
1408use serde::{Deserialize,Deserializer,Serialize,Serializer};
1409use std::{
1410 any::TypeId,fmt::{Display,Result as FmtResult},iter::{FromIterator,once},mem,ops::{Add,Bound,Div,Mul,RangeBounds,Range,Rem,Sub},path::PathBuf,slice::{Iter as SliceIter,self},vec::IntoIter as VecIntoIter
1411};