// Generate the binary numeric std::ops impls for Value<B>: each invocation wires
// an operator trait (e.g. Add) to the tensor method and its scalar variant.
bicop_num!(Add,add,add_scalar);
bicop_num!(Div,div,div_scalar);
bicop_num!(Mul,mul,mul_scalar);
bicop_num!(Rem,rem,remainder_scalar);
bicop_num!(Sub,sub,sub_scalar);
6fn broadcast_multi<B:Backend,F:FnMut(Value<B>,Value<B>)->Value<B>>(u:Vec<Value<B>>,v:Vec<Value<B>>,mut f:F)->Value<B>{
7 if u.len()==1{
8 u.into_iter().cycle().zip(v).map(|(x,y)|f(x,y)).collect()
9 }else if v.len()==1{
10 u.into_iter().zip(v.into_iter().cycle()).map(|(x,y)|f(x,y)).collect()
11 }else if u.len()==v.len(){
12 u.into_iter().zip(v).map(|(x,y)|f(x,y)).collect()
13 }else{
14 "mismatched lengths".into()
15 }
16}
/// Sample one index from an unnormalized weight distribution by inverse-CDF
/// sampling. `dim` may be negative (counted from the end); that axis is moved
/// innermost before the data is flattened.
/// NOTE(review): the whole flattened tensor is consumed as a single
/// distribution — presumably callers pass a single distribution; confirm.
fn hard_choose_burn_1<B:Backend,const N:usize>(dim:i32,distribution:Tensor<B,N>)->u32{
 // Normalize a negative dim into 0..N.
 let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
 let distribution=if dim==N-1{distribution}else{distribution.movedim(dim,N-1)}.into_data();
 // Total weight; weights need not sum to 1.
 let sum=distribution.iter().fold(0.0,|acc:f32,weight:f32|acc+weight);

 // Inverse CDF: subtract weights from a uniform draw in [0, sum); `scan`
 // yields while the remainder is still >= 0, so `count` is the first index
 // where the draw was exhausted.
 // NOTE(review): if every weight is 0 (or NaN) this returns the full length,
 // an out-of-range index — confirm upstream guards against that.
 distribution.iter().scan(random::<f32>()*sum,|choice:&mut f32,weight:f32|Some(*choice-=weight).filter(|_|*choice>=0.0)).count() as u32
}
/// Sample one index per distribution: the tensor is split into contiguous
/// chunks along `dim` (after moving that axis innermost) and each chunk is
/// sampled independently by inverse-CDF on its unnormalized weights.
fn hard_choose_burn_multi<B:Backend,const N:usize>(dim:i32,distribution:Tensor<B,N>)->Vec<u32>{
 // Normalize a negative dim into 0..N.
 let dim=if dim<0{N-(-dim) as usize}else{dim as usize};

 // Size of the chosen axis; read before movedim (the value is the same either way).
 let chunk=distribution.dims()[dim];
 let distribution=if dim==N-1{distribution}else{distribution.movedim(dim,N-1)}.into_data().to_vec().unwrap();

 distribution.chunks_exact(chunk).map(|d|{
  // Unnormalized total weight of this chunk.
  let sum=d.iter().fold(0.0,|acc:f32,weight:&f32|acc+weight);
  // Inverse CDF: count weights until a uniform draw in [0, sum) is exhausted.
  d.iter().scan(random::<f32>()*sum,|choice:&mut f32,weight:&f32|Some(*choice-=weight).filter(|_|*choice>=0.0)).count() as u32
 }).collect()
}
35fn hard_choose_burn_tensor<B:Backend,const N:usize>(dim:i32,distribution:Tensor<B,N>)->Tensor<B,N,Int>{let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
37 let device=distribution.device();
38 let mut dims=distribution.dims();
39
40 dims[N-1]=1;
41 let r:Tensor<B,N,Int>=Tensor::from_data(TensorData::new(hard_choose_burn_multi(dim as i32,distribution),dims),&device);
42
43 r.movedim(N-1,dim)
44}
45fn slice_slice<B:Backend,K:BasicOps<B>+TensorKind<B>,const N:usize>(ranges:&[Range<usize>],tensor:Tensor<B,N,K>)->Tensor<B,N,K>{
46 let mut n=0;
47 let mut acc=||{
48 let a=n;
49 n+=1;
50 a
51 };
52
53 match ranges.len(){0=>tensor,1=>tensor.slice([0;1].map(|_|ranges[acc()].clone())),2=>tensor.slice([0;2].map(|_|ranges[acc()].clone())),3=>tensor.slice([0;3].map(|_|ranges[acc()].clone())),4=>tensor.slice([0;4].map(|_|ranges[acc()].clone())),5=>tensor.slice([0;5].map(|_|ranges[acc()].clone())),6=>tensor.slice([0;6].map(|_|ranges[acc()].clone())),7=>tensor.slice([0;7].map(|_|ranges[acc()].clone())),8=>tensor.slice([0;8].map(|_|ranges[acc()].clone())),_=>panic!("too many ranges for current max 8 dims")}
54}
/// Sample one index from softmax(logits/temperature) along `dim` (negative
/// dims count from the end) by inverse-CDF sampling on a uniform draw.
/// NOTE(review): the whole flattened tensor is consumed as one distribution —
/// presumably callers pass a single distribution; confirm.
fn soft_choose_burn_1<B:Backend,const N:usize>(dim:i32,logits:Tensor<B,N>,temperature:f32)->u32{
 // Normalize a negative dim into 0..N.
 let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
 // Move the sampled axis innermost so the flattened data walks it contiguously.
 let logits=if dim==N-1{logits}else{logits.movedim(dim,N-1)};
 let distribution=softmax(logits/temperature,N-1).into_data();
 // softmax sums to 1, so the uniform draw needs no rescaling; `count` is the
 // first index at which the remaining probability mass goes negative.
 distribution.iter().scan(random(),|choice:&mut f32,weight:f32|Some(*choice-=weight).filter(|_|*choice>=0.0)).count() as u32
}
/// Sample one index per distribution of softmax(logits/temperature): the
/// tensor is chunked along `dim` (moved innermost) and each chunk sampled
/// independently by inverse-CDF on a uniform draw in [0, 1).
fn soft_choose_burn_multi<B:Backend,const N:usize>(dim:i32,logits:Tensor<B,N>,temperature:f32)->Vec<u32>{
 // Normalize a negative dim into 0..N.
 let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
 let logits=if dim==N-1{logits}else{logits.movedim(dim,N-1)};
 // After movedim the chosen axis is innermost, so chunks are contiguous.
 let chunk=logits.dims()[N-1];
 let distribution=softmax(logits/temperature,N-1).into_data().to_vec().unwrap();
 distribution.chunks_exact(chunk).map(|d|d.iter().scan(random(),|choice:&mut f32,weight:&f32|Some(*choice-=weight).filter(|_|*choice>=0.0)).count() as u32).collect()
}
68fn soft_choose_burn_tensor<B:Backend,const N:usize>(dim:i32,logits:Tensor<B,N>,temperature:f32)->Tensor<B,N,Int>{let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
70 let device=logits.device();
71 let mut dims=logits.dims();
72
73 dims[N-1]=1;
74 let r:Tensor<B,N,Int>=Tensor::from_data(TensorData::new(soft_choose_burn_multi(dim as i32,logits,temperature),dims),&device);
75 r.movedim(N-1,dim)
76}
77impl AsRef<Self> for Shape{fn as_ref(&self)->&Self{self}
79}
impl Shape{
 /// Total element count implied by this shape, when it is statically known.
 /// Returns None for incompatible shapes and for non-empty Multi shapes.
 pub fn count(&self)->Option<usize>{
  match self{
   Shape::Incompatible(_e)=>None,
   // NOTE(review): only an empty Multi has a known count (0); the count of a
   // non-empty Multi presumably depends on its members — confirm semantics.
   Shape::Multi(n)=>if *n==0{Some(0)}else{None},
   // Recursive shapes sum their children; any unknown child makes the whole
   // count unknown via `?`.
   Shape::Recursive(v)=>{
    let mut s=0;
    for v in v{s+=v.count()?}
    Some(s)
   },
   // Fixed-rank shapes multiply their extents.
   X1(x)=>Some(x.iter().product()),
   X2(x)=>Some(x.iter().product()),
   X3(x)=>Some(x.iter().product()),
   X4(x)=>Some(x.iter().product()),
   X5(x)=>Some(x.iter().product()),
   X6(x)=>Some(x.iter().product()),
   X7(x)=>Some(x.iter().product()),
   X8(x)=>Some(x.iter().product())
  }
 }
 /// Embed the shape into a fixed [usize;8], padding with 1s. `alignment`
 /// picks where the real dims sit: left, right, or centered. Incompatible,
 /// Multi, and Recursive shapes return all-ones.
 pub fn to_array(self,alignment:Alignment)->[usize;8]{
  let mut result=[1;8];
  let slice=match &self{Shape::Incompatible(_e)=>return result,Shape::Multi(_v)=>return result,Shape::Recursive(_r)=>return result,X1(x)=>x.as_slice(),X2(x)=>x.as_slice(),X3(x)=>x.as_slice(),X4(x)=>x.as_slice(),X5(x)=>x.as_slice(),X6(x)=>x.as_slice(),X7(x)=>x.as_slice(),X8(x)=>x.as_slice()};
  let l=slice.len();
  // Center starts at 4 - l/2 (biased left for odd l); left/right pin the ends.
  match alignment{Alignment::Center=>result[4-l/2..][..l].copy_from_slice(slice),Alignment::Left=>result[..l].copy_from_slice(slice),Alignment::Right=>result[8-l..].copy_from_slice(slice)}
  result
 }
}
110impl<'a,B:Backend> Deserialize<'a> for Value<B>{
111 fn deserialize<D:Deserializer<'a>>(deserializer:D)->Result<Self,D::Error>{ValueData::deserialize(deserializer).map(Into::into)}
112}
impl<A:AutodiffBackend> AutodiffModule<A> for Value<A>{
 /// Produce the validation (inner, non-autodiff) view of every tensor variant;
 /// Multi recurses element-wise and Incompatible carries its message through.
 fn valid(&self)->Self::InnerModule{
  match self{B1(x)=>B1(x.valid()),B2(x)=>B2(x.valid()),B3(x)=>B3(x.valid()),B4(x)=>B4(x.valid()),B5(x)=>B5(x.valid()),B6(x)=>B6(x.valid()),B7(x)=>B7(x.valid()),B8(x)=>B8(x.valid()),F1(x)=>F1(x.valid()),F2(x)=>F2(x.valid()),F3(x)=>F3(x.valid()),F4(x)=>F4(x.valid()),F5(x)=>F5(x.valid()),F6(x)=>F6(x.valid()),F7(x)=>F7(x.valid()),F8(x)=>F8(x.valid()),I1(x)=>I1(x.valid()),I2(x)=>I2(x.valid()),I3(x)=>I3(x.valid()),I4(x)=>I4(x.valid()),I5(x)=>I5(x.valid()),I6(x)=>I6(x.valid()),I7(x)=>I7(x.valid()),I8(x)=>I8(x.valid()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.iter().map(|x|x.valid()).collect()}
 }
 // The valid view lives on the backend's inner (non-autodiff) type.
 type InnerModule=Value<A::InnerBackend>;
}
119impl<A:Into<Value<B>>,B:Backend> FromIterator<A> for Value<B>{
120 fn from_iter<I:IntoIterator<Item=A>>(iter:I)->Self{Value::Multi(iter.into_iter().map(Into::into).collect())}
121}
122impl<B:Backend,K:'static+TensorKind<B>,const N:usize> From<Result<Tensor<B,N,K>,String>> for Value<B>{
123 fn from(value:Result<Tensor<B,N,K>,String>)->Self{
124 match value{Err(e)=>e.into(),Ok(t)=>t.into()}
125 }
126}
impl<B:Backend,K:'static+TensorKind<B>,const N:usize> From<Tensor<B,N,K>> for Value<B>{
 /// Erase a statically-typed tensor into the dynamic `Value` enum. Only bool,
 /// float, and int kinds with rank 1-8 are representable; anything else
 /// becomes an `Incompatible` message.
 fn from(value:Tensor<B,N,K>)->Self{
  let kind=TypeId::of::<K>();
  // Map the compile-time kind K onto a runtime tag.
  let kind=if kind==TypeId::of::<Bool>(){Kind::Bool}else if kind==TypeId::of::<Float>(){Kind::Float}else if kind==TypeId::of::<Int>(){Kind::Int}else{return "only bool, float, and int tensors with dimensions 1-8 are currently supported".into()};

  // SAFETY: each arm is only reached when (N, kind) identify exactly the
  // concrete Tensor<B,n,Kind> type stored in that variant, so transmute_copy
  // reinterprets `value` as the very same type. Ownership is transferred by
  // the bitwise copy; the `mem::forget(value)` below prevents a double drop.
  // On the `_` arm no copy was made, so `value` drops normally via `return`.
  let v=unsafe{
   match (N,kind){(1,Kind::Bool)=>B1(mem::transmute_copy(&value)),(2,Kind::Bool)=>B2(mem::transmute_copy(&value)),(3,Kind::Bool)=>B3(mem::transmute_copy(&value)),(4,Kind::Bool)=>B4(mem::transmute_copy(&value)),(5,Kind::Bool)=>B5(mem::transmute_copy(&value)),(6,Kind::Bool)=>B6(mem::transmute_copy(&value)),(7,Kind::Bool)=>B7(mem::transmute_copy(&value)),(8,Kind::Bool)=>B8(mem::transmute_copy(&value)),(1,Kind::Float)=>F1(mem::transmute_copy(&value)),(2,Kind::Float)=>F2(mem::transmute_copy(&value)),(3,Kind::Float)=>F3(mem::transmute_copy(&value)),(4,Kind::Float)=>F4(mem::transmute_copy(&value)),(5,Kind::Float)=>F5(mem::transmute_copy(&value)),(6,Kind::Float)=>F6(mem::transmute_copy(&value)),(7,Kind::Float)=>F7(mem::transmute_copy(&value)),(8,Kind::Float)=>F8(mem::transmute_copy(&value)),(1,Kind::Int)=>I1(mem::transmute_copy(&value)),(2,Kind::Int)=>I2(mem::transmute_copy(&value)),(3,Kind::Int)=>I3(mem::transmute_copy(&value)),(4,Kind::Int)=>I4(mem::transmute_copy(&value)),(5,Kind::Int)=>I5(mem::transmute_copy(&value)),(6,Kind::Int)=>I6(mem::transmute_copy(&value)),(7,Kind::Int)=>I7(mem::transmute_copy(&value)),(8,Kind::Int)=>I8(mem::transmute_copy(&value)),_=>return "only bool, float, and int tensors with dimensions 1-8 are currently supported".into()}
  };
  mem::forget(value);
  v
 }
}
impl<B:Backend,K:'static+TensorKind<B>,const N:usize> TryFrom<Value<B>> for Tensor<B,N,K>{
 /// Recover the statically-typed tensor from a dynamic Value. Fails (returning
 /// the value untouched) if the requested kind/rank don't match the variant.
 fn try_from(value:Value<B>)->Result<Self,Self::Error>{
  let kind=TypeId::of::<K>();
  // Map the requested compile-time kind onto the runtime tag.
  let kind=if kind==TypeId::of::<Bool>(){Kind::Bool}else if kind==TypeId::of::<Float>(){Kind::Float}else if kind==TypeId::of::<Int>(){Kind::Int}else{return Err(value)};

  // Guard: the stored variant must have exactly rank N and the requested kind.
  if Some(N)!=value.rank()||kind!=value.kind(){return Err(value)}
  // SAFETY: the guard above guarantees the matched variant holds exactly a
  // Tensor<B,N,K>, so transmute_copy reads the same type it copies from.
  // Ownership transfers via the bitwise copy; `mem::forget(value)` below
  // prevents the original from also dropping it.
  let r=unsafe{
   match &value{B1(x)=>mem::transmute_copy(x),B2(x)=>mem::transmute_copy(x),B3(x)=>mem::transmute_copy(x),B4(x)=>mem::transmute_copy(x),B5(x)=>mem::transmute_copy(x),B6(x)=>mem::transmute_copy(x),B7(x)=>mem::transmute_copy(x),B8(x)=>mem::transmute_copy(x),F1(x)=>mem::transmute_copy(x),F2(x)=>mem::transmute_copy(x),F3(x)=>mem::transmute_copy(x),F4(x)=>mem::transmute_copy(x),F5(x)=>mem::transmute_copy(x),F6(x)=>mem::transmute_copy(x),F7(x)=>mem::transmute_copy(x),F8(x)=>mem::transmute_copy(x),I1(x)=>mem::transmute_copy(x),I2(x)=>mem::transmute_copy(x),I3(x)=>mem::transmute_copy(x),I4(x)=>mem::transmute_copy(x),I5(x)=>mem::transmute_copy(x),I6(x)=>mem::transmute_copy(x),I7(x)=>mem::transmute_copy(x),I8(x)=>mem::transmute_copy(x),_=>panic!("internal error")}
  };
  mem::forget(value);
  Ok(r)
 }
 // A failed conversion hands the original value back to the caller.
 type Error=Value<B>;
}
153impl<B:Backend,S:?Sized+AsRef<str>> From<&S> for Value<B>{
154 fn from(value:&S)->Self{Self::Incompatible(value.as_ref().to_string())}
155}
impl<B:Backend,const D:usize> AI<Value<B>,Value<B>> for BatchNorm<B,D>{
 /// Apply batch norm to a dynamically-ranked value. Rank-E float inputs use a
 /// BatchNorm re-typed to F = E - 2 spatial dims; ranks 1-2 are unsqueezed up
 /// to rank 3 first and squeezed back after.
 fn forward(&self,input:Value<B>)->Value<B>{
  // Rebuild the layer at spatial-dim count F so the same parameters apply to
  // a rank-E input. NOTE(review): assumes BatchNorm's fields are
  // rank-independent (the const param only tags the expected input rank) —
  // confirm against the burn definition.
  fn f<B:Backend,const D:usize,const E:usize,const F:usize>(norm:&BatchNorm<B,D>,x:Tensor<B,E>)->Value<B>{
   let norm:BatchNorm<B,F>=BatchNorm{beta:norm.beta.clone(),epsilon:norm.epsilon.clone(),gamma:norm.gamma.clone(),momentum:norm.momentum.clone(),running_mean:norm.running_mean.clone(),running_var:norm.running_var.clone()};
   norm.forward(x).into()
  }
  match input.float(){
   // Low ranks: pad to rank 3 (the minimum batch-norm input), recurse, unpad.
   F1(x)=>AI::forward(self,F1(x).unsqueeze().unsqueeze()).squeeze().squeeze(),
   F2(x)=>AI::forward(self,F2(x).unsqueeze()).squeeze(),
   F3(x)=>f::<B,D,3,1>(self,x),
   F4(x)=>f::<B,D,4,2>(self,x),
   F5(x)=>f::<B,D,5,3>(self,x),
   F6(x)=>f::<B,D,6,4>(self,x),
   F7(x)=>f::<B,D,7,5>(self,x),
   F8(x)=>f::<B,D,8,6>(self,x),
   Value::Incompatible(e)=>e.into(),
   Value::Multi(v)=>v.into_iter().map(|x|AI::forward(self,x)).collect(),
   // input.float() only yields float/Incompatible/Multi variants.
   _=>panic!("internal error")
  }
 }
}
177impl<B:Backend> AI<(Value<B>,Value<B>),Vec<f32>> for CrossEntropyLayer{
178 fn forward(&self,input:(Value<B>,Value<B>))->Vec<f32>{
179 let output:Value<B>=self.forward(input);
180 output.into_float_vec()
181 }
182}
183impl<B:Backend> AI<(Value<B>,Value<B>),LossOutput<B>> for CrossEntropyLayer{
184 fn forward(&self,(output,target):(Value<B>,Value<B>))->LossOutput<B>{
185 let loss=self.forward((output.clone(),target.clone()));
186 LossOutput::new(loss,output,target)
187 }
188}
189impl<B:Backend> AI<(Value<B>,Value<B>),Value<B>> for CrossEntropyLayer{fn forward(&self,(output,target):(Value<B>,Value<B>))->Value<B>{
191 fn ff<B:Backend,const N:usize>(dim:i32,y:Tensor<B,N>,t:Tensor<B,N>,temperature:f32)->Result<Tensor<B,N>,String>{
192 let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
193 let (ydims,tdims)=(y.dims(),t.dims());
194 if ydims==tdims{
195 let logy=if temperature.is_nan(){y.log()}else{log_softmax(y/temperature,dim)};
196 Ok(logy*t.neg())
197 }else{
198 Err(format!("incompatible shapes to cross entropy. ydims: {ydims:?} tdims: {tdims:?}"))
199 }
200 }
201 fn fi<B:Backend,const N:usize,const K:usize>(dim:i32,y:Tensor<B,N>,t:Tensor<B,K,Int>,temperature:f32)->Result<Tensor<B,K>,String>{
202 let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
203 let (ydims,tdims)=(y.dims(),t.dims());
204 if ydims.iter().enumerate().filter_map(|(n,y)|(n!=dim).then_some(y)).eq(tdims.iter()){
205 let logy=if temperature.is_nan(){y.log()}else{log_softmax(y/temperature,dim)};
206 Ok(logy.gather(dim,t.unsqueeze_dim(dim)).neg().squeeze(dim))
207 }else{
208 Err(format!("incompatible shapes to cross entropy along dimension {dim}. ydims: {ydims:?} tdims: {tdims:?}"))
209 }
210 }
211 let (dim,temp)=(self.get_dim(),self.get_temperature());
212
213 match match (output,target){
214 (F1(y),F1(t))=>ff(dim,y,t,temp).map(Into::into),
215 (F2(y),F2(t))=>ff(dim,y,t,temp).map(Into::into),
216 (F3(y),F3(t))=>ff(dim,y,t,temp).map(Into::into),
217 (F4(y),F4(t))=>ff(dim,y,t,temp).map(Into::into),
218 (F5(y),F5(t))=>ff(dim,y,t,temp).map(Into::into),
219 (F6(y),F6(t))=>ff(dim,y,t,temp).map(Into::into),
220 (F7(y),F7(t))=>ff(dim,y,t,temp).map(Into::into),
221 (F8(y),F8(t))=>ff(dim,y,t,temp).map(Into::into),
222 (F1(y),I1(t))=>fi(dim,y.unsqueeze::<2>(),t,temp).map(Into::into),
223 (F2(y),I1(t))=>fi(dim,y,t,temp).map(Into::into),
224 (F3(y),I2(t))=>fi(dim,y,t,temp).map(Into::into),
225 (F4(y),I3(t))=>fi(dim,y,t,temp).map(Into::into),
226 (F5(y),I4(t))=>fi(dim,y,t,temp).map(Into::into),
227 (F6(y),I5(t))=>fi(dim,y,t,temp).map(Into::into),
228 (F7(y),I6(t))=>fi(dim,y,t,temp).map(Into::into),
229 (F7(y),I7(t))=>fi(dim,y,t,temp).map(Into::into),
230 (Value::Incompatible(y),_)=>Err(y),
231 (_,Value::Incompatible(t))=>Err(t),(Value::Multi(y),Value::Multi(t))=>if y.len()==t.len(){Ok(Value::Multi(y.into_iter().zip(t).map(|x|self.forward_typed::<_,Value<B>>(x)).collect()))}else{Err("mismatched lengths".into())},
233 _=>Err("incompatible".into())
234 }{
235 Err(e)=>Value::Incompatible(e),Ok(x)=>x
236 }
237 }
238}
239impl<B:Backend> AI<(Value<B>,Value<B>),Value<B>> for CrossEntropyLoss<B>{
240 fn forward(&self,(output,target):(Value<B>,Value<B>))->Value<B>{
241 let mut op=().fix_type::<Value<B>>().cross_entropy(1.0);
242 if !self.logits{op.set_temperature(f32::NAN)}
243 op.forward((output,target))
244 }
245}
246impl<B:Backend> AI<(Value<B>,Value<B>),LossOutput<B>> for SquaredErrorLayer{
247 fn forward(&self,(output,target):(Value<B>,Value<B>))->LossOutput<B>{
248 let loss=self.forward((output.clone(),target.clone()));
249 LossOutput::new(loss,output,target)
250 }
251}
252impl<B:Backend> AI<(Value<B>,Value<B>),Value<B>> for SquaredErrorLayer{
253 fn forward(&self,(output,target):(Value<B>,Value<B>))->Value<B>{
254 fn f<B:Backend,const N:usize>(y:Tensor<B,N>,t:Tensor<B,N>)->Value<B>{
255 if y.dims()==t.dims(){MseLoss.forward_no_reduction(y,t).into()}else{"compatible inputs for squared error are float tensors of the same shape".into()}
256 }
257 match (output.float(),target.float()){(F1(y),F1(t))=>f(y,t),(F2(y),F2(t))=>f(y,t),(F3(y),F3(t))=>f(y,t),(F4(y),F4(t))=>f(y,t),(F5(y),F5(t))=>f(y,t),(F6(y),F6(t))=>f(y,t),(F7(y),F7(t))=>f(y,t),(F8(y),F8(t))=>f(y,t),(Value::Incompatible(y),_)=>y.into(),(_,Value::Incompatible(t))=>t.into(),(Value::Multi(y),t)=>broadcast_multi(y,t.into_multi(),|y,t|self.forward((y,t))),(y,Value::Multi(t))=>broadcast_multi(y.into_multi(),t,|y,t|self.forward((y,t))),_=>"compatible inputs for squared error are float tensors of the same shape".into()}
258 }
259}
260impl<B:Backend> AI<(Value<B>,Value<B>),Vec<f32>> for SquaredErrorLayer{
261 fn forward(&self,(output,target):(Value<B>,Value<B>))->Vec<f32>{
262 let error:Value<B>=self.forward((output,target));
263 error.into_float_vec()
264 }
265}
266impl<B:Backend> AI<(Value<B>,Value<B>),f32> for SquaredErrorLayer{
267 fn forward(&self,(output,target):(Value<B>,Value<B>))->f32{().fix_type::<Value<B>>().squared_error().mean().forward((output,target))}
268}
269impl<B:Backend> AI<Value<B>,Tensor<B,1>> for MeanLayer{
270 fn forward(&self,input:Value<B>)->Tensor<B,1>{
271 fn avg<B:Backend,const N:usize>(x:Tensor<B,N>)->Tensor<B,1>{x.mean()}
272 let l=input.len();
273
274 if l==0{return Tensor::from_data(TensorData::new(vec![f32::NAN],[1]),&Default::default())}
275 match input.float(){F1(x)=>avg(x),F2(x)=>avg(x),F3(x)=>avg(x),F4(x)=>avg(x),F5(x)=>avg(x),F6(x)=>avg(x),F7(x)=>avg(x),F8(x)=>avg(x),Value::Incompatible(e)=>panic!("Could not reduce to a scalar due to incompatibility: {e}"),Value::Multi(v)=>v.into_iter().map(|x|self.forward(x)).reduce(|x:Tensor<B,1>,y:Tensor<B,1>|x+y).unwrap()/l as f32,_=>panic!("internal error")}
276 }
277}
278impl<B:Backend> AI<Value<B>,Tensor<B,1>> for SumLayer{
279 fn forward(&self,input:Value<B>)->Tensor<B,1>{
280 fn sum<B:Backend,const N:usize>(x:Tensor<B,N>)->Tensor<B,1>{x.sum()}
281 let l=input.len();
282
283 if l==0{return Tensor::from_data(TensorData::new(vec![f32::NAN],[1]),&Default::default())}
284 match input.float(){F1(x)=>sum(x),F2(x)=>sum(x),F3(x)=>sum(x),F4(x)=>sum(x),F5(x)=>sum(x),F6(x)=>sum(x),F7(x)=>sum(x),F8(x)=>sum(x),Value::Incompatible(e)=>panic!("Could not reduce to a scalar due to incompatibility: {e}"),Value::Multi(v)=>v.into_iter().map(|x|self.forward(x)).reduce(|x:Tensor<B,1>,y:Tensor<B,1>|x+y).unwrap(),_=>panic!("internal error")}
285 }
286}
impl<B:Backend> AI<Value<B>,Value<B>> for Conv2d<B>{
 /// Apply a 2-D convolution to a value of any rank by folding all leading
 /// dimensions into the batch axis, convolving, then restoring the shape.
 fn forward(&self,input:Value<B>)->Value<B>{
  fn f<B:Backend,const N:usize>(input:Tensor<B,N>,layer:&Conv2d<B>)->Value<B>{
   let mut dims=input.dims();
   let n:usize=dims.iter().product();

   // Interpret the trailing dims as (channels, height, width), defaulting
   // missing leading axes to 1 for ranks below 3.
   let c=if N<3{1}else{dims[N-3]};
   let h=if N<2{1}else{dims[N-2]};
   let w=dims[N-1];

   // Everything in front of (c, h, w) collapses into one batch axis.
   let b=n/(c*h*w);
   let output=layer.forward(input.reshape([b,c,h,w]));

   let [_b,c,h,w]=output.dims();

   // Rebuild the original layout with the post-conv extents. For ranks 1-2,
   // if the conv grew a channel/height axis the result must keep the higher
   // rank, so return it as F3/F2 instead of reshaping back down.
   dims[N-1]=w;
   if N<3&&c!=1{return F3(output.reshape([c,h,w]))}else if N>=3{dims[N-3]=c}
   if N<2&&h!=1{return F2(output.reshape([h,w]))}else if N>=2{dims[N-2]=h}
   output.reshape(dims).into()
  }
  let l=self;

  match input.float(){F1(x)=>f(x,l),F2(x)=>f(x,l),F3(x)=>f(x,l),F4(x)=>f(x,l),F5(x)=>f(x,l),F6(x)=>f(x,l),F7(x)=>f(x,l),F8(x)=>f(x,l),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|AI::forward(self,x)).collect(),_=>panic!("internal error")}
 }
}
312impl<B:Backend> AI<Value<B>,Value<B>> for CrossEntropyLayer{
313 fn forward(&self,input:Value<B>)->Value<B>{
314 match input{
315 Value::Incompatible(e)=>e.into(),
316 Value::Multi(v)=>if v.len()==2{
317 let [output,target]=v.try_into().unwrap();
318 self.forward((output,target))
319 }else{
320 v.into_iter().map(|x|self.forward(x)).collect()
321 },
322 _=>"cross entropy inputs must be in pairs".into()
323 }
324 }
325}
326impl<B:Backend> AI<Value<B>,Value<B>> for CrossEntropyLoss<B>{
327 fn forward(&self,input:Value<B>)->Value<B>{
328 let mut op=CrossEntropyLayer::new(1.0);
329 if !self.logits{op.set_temperature(f32::NAN)}
330 op.forward(input)
331 }
332}
333impl<B:Backend> AI<Value<B>,Value<B>> for MeanLayer{
334 fn forward(&self,input:Value<B>)->Value<B>{
335 fn avg<B:Backend,const N:usize,const K:usize>(d:i32,x:Tensor<B,N>)->Tensor<B,K>{
336 let d=if d<0{N-((-d) as usize)}else{d as usize};
337 x.mean_dim(d).squeeze(d)
338 }
339 let l=input.len();
340
341 if l==0{return input}
342 match self.get_reduction_mode(){
343 ReductionMode::Component=>F1(self.forward(input)),
344 ReductionMode::Dim(d)=>{
345 if let Some(r)=input.rank(){
346 if d>=r as i32||d<(-(r as i32)){return format!("rank {r} is too low to cat along dimension {d}").into()}
347 }
348 match input.float(){F1(x)=>F1(x.mean()),F2(x)=>F1(avg(d,x)),F3(x)=>F2(avg(d,x)),F4(x)=>F3(avg(d,x)),F5(x)=>F4(avg(d,x)),F6(x)=>F5(avg(d,x)),F7(x)=>F6(avg(d,x)),F8(x)=>F7(avg(d,x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|self.forward(x)).reduce(|x:Value<B>,y:Value<B>|x+y).unwrap()/l as f32,_=>panic!("internal error")}
349 },
350 ReductionMode::Tensor=>match input.float(){Value::Multi(v)=>v.into_iter().reduce(|x,y|x+y).unwrap()/l as f32,x=>x}
351 }
352 }
353}
354impl<B:Backend> AI<Value<B>,Value<B>> for MseLoss{
355 fn forward(&self,input:Value<B>)->Value<B>{SquaredErrorLayer::new().forward(input)}
356}
357impl<B:Backend> AI<Value<B>,Value<B>> for SquaredErrorLayer{
358 fn forward(&self,input:Value<B>)->Value<B>{
359 match input{
360 Value::Incompatible(e)=>e.into(),
361 Value::Multi(v)=>if v.len()==2{
362 let [output,target]=v.try_into().unwrap();
363 self.forward((output,target))
364 }else{
365 v.into_iter().map(|x|self.forward(x)).collect()
366 },
367 _=>"squared error inputs must be in pairs".into()
368 }
369 }
370}
371impl<B:Backend> AI<Value<B>,Value<B>> for SumLayer{
372 fn forward(&self,input:Value<B>)->Value<B>{
373 fn sum<B:Backend,const N:usize,const K:usize>(d:i32,x:Tensor<B,N>)->Tensor<B,K>{
374 let d=if d<0{N-((-d) as usize)}else{d as usize};
375 x.mean_dim(d).squeeze(d)
376 }
377 let l=input.len();
378
379 if l==0{return input}
380 match self.get_reduction_mode(){
381 ReductionMode::Component=>F1(self.forward(input)),
382 ReductionMode::Dim(d)=>{
383 if let Some(r)=input.rank(){
384 if d>=r as i32||d<(-(r as i32)){return format!("rank {r} is too low to cat along dimension {d}").into()}
385 }
386 match input.float(){F1(x)=>F1(x.sum()),F2(x)=>F1(sum(d,x)),F3(x)=>F2(sum(d,x)),F4(x)=>F3(sum(d,x)),F5(x)=>F4(sum(d,x)),F6(x)=>F5(sum(d,x)),F7(x)=>F6(sum(d,x)),F8(x)=>F7(sum(d,x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|self.forward(x)).reduce(|x:Value<B>,y:Value<B>|x+y).unwrap(),_=>panic!("internal error")}
387 },
388 ReductionMode::Tensor=>match input.float(){Value::Multi(v)=>v.into_iter().reduce(|x,y|x+y).unwrap(),x=>x}
389 }
390 }
391}
392impl<B:Backend> AI<Value<B>,f32> for MeanLayer{
393 fn forward(&self,input:Value<B>)->f32{
394 let y:Tensor<B,1>=self.forward(input);
395 y.into_scalar().to_f32()
396 }
397}
398impl<B:Backend> AI<Value<B>,f32> for SumLayer{
399 fn forward(&self,input:Value<B>)->f32{
400 let y:Tensor<B,1>=self.forward(input);
401 y.into_scalar().to_f32()
402 }
403}
impl<B:Backend> AI<Value<B>,Value<B>> for AccQLayer{
 /// Discounted reverse cumulative sum along `dim`: each slice becomes
 /// present + gamma * accumulated-future, i.e. a Q/return accumulation.
 fn forward(&self,input:Value<B>)->Value<B>{
  fn acc_q<B:Backend,const N:usize>(dim:i32,gamma:f32,i:Tensor<B,N>)->Tensor<B,N>{
   // Normalize a negative dim into 0..N.
   let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
   // Split into width-1 slices along dim, then fold backwards carrying the
   // already-accumulated future value into each earlier slice in place.
   let mut q=i.split(1,dim);
   q.iter_mut().rev().fold(None,|future,present|{
    if let Some(f)=future{*present=f*gamma+present.clone()}
    Some(present.clone())
   });
   // Reassemble the accumulated slices into the original shape.
   Tensor::cat(q,dim)
  }
  let (dim,gamma)=(self.get_dim(),self.get_gamma());

  match input.float(){F1(x)=>F1(acc_q(dim,gamma,x)),F2(x)=>F2(acc_q(dim,gamma,x)),F3(x)=>F3(acc_q(dim,gamma,x)),F4(x)=>F4(acc_q(dim,gamma,x)),F5(x)=>F5(acc_q(dim,gamma,x)),F6(x)=>F6(acc_q(dim,gamma,x)),F7(x)=>F7(acc_q(dim,gamma,x)),F8(x)=>F8(acc_q(dim,gamma,x)),Value::Incompatible(x)=>x.into(),Value::Multi(x)=>Value::Multi(x.into_iter().map(|x|self.forward(x)).collect()),_=>panic!("unexpected non float value")}
 }
}
420impl<B:Backend> AI<Value<B>,u32> for ChooseLayer{
421 fn forward(&self,input:Value<B>)->u32{
422 let (dim,temperature)=(self.get_dim(),self.get_temperature());
423
424 match input.float(){
425 F1(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
426 F2(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
427 F3(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
428 F4(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
429 F5(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
430 F6(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
431 F7(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
432 F8(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
433 Value::Incompatible(e)=>panic!("Could not create scalar due to incompatibility: {e}"),
434 Value::Multi(v)=>if v.len()==1{self.forward(v.into_iter().next().unwrap())}else{panic!("Cannot soft choose one scalar from multiple values")},
435 _=>panic!("internal error")
436 }
437 }
438}
439impl<B:Backend> AI<Value<B>,Vec<u32>> for ChooseLayer{
440 fn forward(&self,input:Value<B>)->Vec<u32>{
441 let (dim,temperature)=(self.get_dim(),self.get_temperature());
442
443 match input.float(){
444 F1(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
445 F2(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
446 F3(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
447 F4(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
448 F5(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
449 F6(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
450 F7(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
451 F8(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
452 Value::Incompatible(e)=>panic!("Could not create vector due to incompatibility: {e}"),
453 Value::Multi(v)=>v.into_iter().flat_map(|x|self.forward_typed::<_,Vec<u32>>(x)).collect(),
454 _=>panic!("internal error")
455 }
456 }
457}
458impl<B:Backend> AI<Value<B>,Value<B>> for Dropout{
459 fn forward(&self,input:Value<B>)->Value<B>{
460 match input.float(){F1(x)=>F1(self.forward(x)),F2(x)=>F2(self.forward(x)),F3(x)=>F3(self.forward(x)),F4(x)=>F4(self.forward(x)),F5(x)=>F5(self.forward(x)),F6(x)=>F6(self.forward(x)),F7(x)=>F7(self.forward(x)),F8(x)=>F8(self.forward(x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
461 }
462}
impl<B:Backend> AI<Value<B>,Value<B>> for Embedding<B>{
 /// Embedding lookup for int inputs (rank grows by one for the embedding
 /// axis); float inputs are instead projected linearly by the embedding
 /// weight. Rank-8 int input is refused since rank 9 is unsupported.
 fn forward(&self,input:Value<B>)->Value<B>{
  // Flatten an int index tensor to [batch, seq], embed, then restore the
  // original dims with the embedding size appended (K must equal N + 1).
  fn apply_embed<B:Backend,const N:usize,const K:usize>(this:&Embedding<B>,x:Tensor<B,N,Int>)->Tensor<B,K>{
   let dims=x.dims();
   // dims[0] stays the batch; all remaining axes collapse into one sequence.
   let [batch,seq]=[dims[0],dims.iter().skip(1).product()];
   let x=x.reshape([batch,seq]);
   let y=this.forward(x);
   let embed=y.dims().last().copied().unwrap();
   let mut ydims=[0;K];
   ydims[..N].copy_from_slice(&dims);
   ydims[N]=embed;
   y.reshape(ydims)
  }
  // Float inputs: use the embedding weight as a bias-free linear map.
  fn apply_linear<B:Backend,const N:usize>(this:&Embedding<B>,x:Tensor<B,N>)->Tensor<B,N>{
   Linear{bias:None,weight:this.weight.clone()}.forward(x)
  }
  match input{F1(x)=>apply_linear(self,x).into(),F2(x)=>apply_linear(self,x).into(),F3(x)=>apply_linear(self,x).into(),F4(x)=>apply_linear(self,x).into(),F5(x)=>apply_linear(self,x).into(),F6(x)=>apply_linear(self,x).into(),F7(x)=>apply_linear(self,x).into(),F8(x)=>apply_linear(self,x).into(),I1(x)=>apply_embed::<B,1,2>(self,x).into(),I2(x)=>apply_embed::<B,2,3>(self,x).into(),I3(x)=>apply_embed::<B,3,4>(self,x).into(),I4(x)=>apply_embed::<B,4,5>(self,x).into(),I5(x)=>apply_embed::<B,5,6>(self,x).into(),I6(x)=>apply_embed::<B,6,7>(self,x).into(),I7(x)=>apply_embed::<B,7,8>(self,x).into(),I8(_x)=>"embedding output would exceed maximum supported rank".into(),Value::Incompatible(x)=>x.into(),Value::Multi(x)=>x.into_iter().map(|x|AI::forward(self,x)).collect::<Vec<_>>().into(),_=>"embedding is only available for float or int inputs".into()}
 }
}
482impl<B:Backend> AI<Value<B>,Value<B>> for LayerNorm<B>{
483 fn forward(&self,input:Value<B>)->Value<B>{
484 fn f<B:Backend,const N:usize>(input:Tensor<B,N>,layer:&LayerNorm<B>)->Value<B>{
485 let b=layer.beta.dims();
486 let g=layer.gamma.dims();
487 let i=input.dims();
488
489 if b!=g{return format!("malformed layer norm. beta dims: {b:?}. gamma dims: {g:?}.").into()}
490 if b.last()!=i.last(){return format!("layer norm for dimension {b:?} is not compatible with input dimensions {i:?}. The last dimension must match the norm dimension.").into()}
491 layer.forward(input).into()
492 }
493 let l=self;
494
495 match input.float(){F1(x)=>f(x,l),F2(x)=>f(x,l),F3(x)=>f(x,l),F4(x)=>f(x,l),F5(x)=>f(x,l),F6(x)=>f(x,l),F7(x)=>f(x,l),F8(x)=>f(x,l),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
496 }
497}
498impl<B:Backend> AI<Value<B>,Value<B>> for Linear<B>{
499 fn forward(&self,input:Value<B>)->Value<B>{
500 match input.float(){F1(x)=>F1(self.forward(x)),F2(x)=>F2(self.forward(x)),F3(x)=>F3(self.forward(x)),F4(x)=>F4(self.forward(x)),F5(x)=>F5(self.forward(x)),F6(x)=>F6(self.forward(x)),F7(x)=>F7(self.forward(x)),F8(x)=>F8(self.forward(x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
501 }
502}
503impl<B:Backend> AI<Value<B>,Value<B>> for Relu{
504 fn forward(&self,input:Value<B>)->Value<B>{
505 match input.float(){F1(x)=>F1(self.forward(x)),F2(x)=>F2(self.forward(x)),F3(x)=>F3(self.forward(x)),F4(x)=>F4(self.forward(x)),F5(x)=>F5(self.forward(x)),F6(x)=>F6(self.forward(x)),F7(x)=>F7(self.forward(x)),F8(x)=>F8(self.forward(x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
506 }
507}
508impl<B:Backend> AI<Value<B>,Value<B>> for SoftmaxLayer{
509 fn forward(&self,input:Value<B>)->Value<B>{
510 fn f<B:Backend,const N:usize>(dim:i32,temperature:f32,x:Tensor<B,N>)->Tensor<B,N>{
511 let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
512 softmax(x/temperature,dim)
513 }
514 let (dim,temperature)=(self.get_dim(),self.get_temperature());
515
516 match input.float(){F1(x)=>F1(f(dim,temperature,x)),F2(x)=>F2(f(dim,temperature,x)),F3(x)=>F3(f(dim,temperature,x)),F4(x)=>F4(f(dim,temperature,x)),F5(x)=>F5(f(dim,temperature,x)),F6(x)=>F6(f(dim,temperature,x)),F7(x)=>F7(f(dim,temperature,x)),F8(x)=>F8(f(dim,temperature,x)),Value::Incompatible(x)=>x.into(),Value::Multi(x)=>x.into_iter().map(|x|self.forward(x)).collect(),_=>panic!("unexpected non float value")}
517 }
518}
impl<B:Backend> AI<Value<B>,Value<B>> for ChooseLayer{
 /// Sample indices along the configured dim, returning an Int value. For
 /// ranks >= 2 the chosen (now singleton) axis is squeezed away, lowering the
 /// rank by one; rank-1 input keeps rank 1.
 fn forward(&self,input:Value<B>)->Value<B>{
  let (dim,temperature)=(self.get_dim(),self.get_temperature());
  // Normalized dim used as the squeeze axis below. NOTE(review): for a Multi
  // value rank() may be None, so 8 is a placeholder — `d` is unused on that
  // path, but confirm rank() is always Some for tensor variants.
  let d=if dim<0{input.rank().unwrap_or(8)-((-dim) as usize)}else{dim as usize};

  match input.float(){
   // Rank 1: the chooser already returns a size-1 rank-1 tensor; no squeeze.
   F1(x)=>I1(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}),
   // Ranks 2-8: drop the singleton chosen axis to lower the rank by one.
   F2(x)=>I1(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
   F3(x)=>I2(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
   F4(x)=>I3(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
   F5(x)=>I4(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
   F6(x)=>I5(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
   F7(x)=>I6(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
   F8(x)=>I7(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
   Value::Incompatible(e)=>e.into(),
   Value::Multi(v)=>Value::Multi(v.into_iter().map(|v|self.forward_typed::<_,Value<B>>(v)).collect()),
   _=>panic!("internal error")}
 }
}
impl<B:Backend> AI<Value<B>,Value<B>> for MaxPool2d{
 /// Max-pool a value of any rank. Ranks 1-3 are unsqueezed up to rank 4 by
 /// recursion and squeezed back; rank 4 pools directly; higher ranks collapse
 /// their leading dims into one batch axis, pool, then restore the shape.
 fn forward(&self,input:Value<B>)->Value<B>{
  fn f<B:Backend,const N:usize>(pool:&MaxPool2d,x:Tensor<B,N>)->Value<B>{
   match N{
    0=>panic!("internal error"),
    // Pad low ranks toward rank 4, recurse, then undo the padding.
    1=>f::<B,2>(pool,x.unsqueeze()).squeeze(),
    2=>f::<B,3>(pool,x.unsqueeze()).squeeze(),
    3=>f::<B,4>(pool,x.unsqueeze()).squeeze(),
    // N==4: round-trip through Value to cast the const generic N to the
    // concrete rank 4 the pool expects.
    4=>pool.forward(Value::from(x).unwrap_f4()).into(),
    _=>{
     let mut dims=x.dims();

     // Keep the trailing (channels, h, w); fold everything else into batch.
     let [channels,h,w]=[dims[N-3],dims[N-2],dims[N-1]];
     let big:usize=dims.iter().take(N-3).product();
     let y=x.reshape([big,channels,h,w]);

     // Record the post-pool trailing extents before restoring the layout.
     dims[N-3..].copy_from_slice(&y.dims()[1..]);

     let y=pool.forward(y);
     y.reshape(dims).into()
    }
   }
  }
  match input.float(){
   F1(x)=>f(self,x),
   F2(x)=>f(self,x),
   F3(x)=>f(self,x),
   F4(x)=>f(self,x),
   F5(x)=>f(self,x),
   F6(x)=>f(self,x),
   F7(x)=>f(self,x),
   F8(x)=>f(self,x),
   Value::Incompatible(e)=>e.into(),
   Value::Multi(v)=>v.into_iter().map(|x|AI::forward(self,x)).collect(),
   _=>panic!("Internal error")
  }
 }
}
576impl<B:Backend> AI<Value<B>,Value<B>> for RotaryEncoding<B>{
577 fn forward(&self,input:Value<B>)->Value<B>{AI::forward(self,(input,0)).0}
578}
579impl<B:Backend> AI<(Value<B>,usize),(Value<B>,usize)> for RotaryEncoding<B>{
580 fn forward(&self,(input,offset):(Value<B>,usize))->(Value<B>,usize){
581 fn apply<B:Backend,const D:usize>(a:&RotaryEncoding<B>,input:Tensor<B,D>,offset:usize)->Value<B>{
582 assert!(D>=2);
583 const MAX_KERNEL:usize=65535; let device=input.device();
585 let freq=&a.freq_complex;
586 let shape=input.shape();
587
588 let (context,key)=(shape.dims[D-2],shape.dims[D-1]);
589 let [distance,head,_2]=freq.dims();
590
591 if context>distance{return "context length must not exceed rotary distance".into()}
592 if key%head!=0{return "input dimension must be a multiple of head".into()}
593 let count=shape.num_elements();
594 let big=count/(context*key);
595 let heads=key/head;
596 let group=count/head; let input=input.reshape([big,context,heads,head]).swap_dims(1,2).reshape([big*heads,context,head/2,2]);
598 let sign=Tensor::<B,2>::from_floats([[1.0,0.0,0.0,1.0],[0.0,-1.0,1.0,0.0]],&device).unsqueeze();
599
600 let chunks=input.chunk(group.div_ceil(MAX_KERNEL),0).into_iter().map(|x|{
601 let smaller=x.dims()[0];
602 let x=x.matmul(sign.clone()).reshape([smaller,context,head,2])*freq.clone().slice([offset..context+offset]).unsqueeze();
603 x.sum_dim(3)
604 }).collect();
605 Tensor::cat(chunks,0).reshape([big,heads,context,head]).swap_dims(1,2).reshape::<D,_>(shape).into()
606 }
607
608 (match input.float(){F1(x)=>apply(self,x.unsqueeze::<2>(),offset).squeeze(),F2(x)=>apply(self,x,offset),F3(x)=>apply(self,x,offset),F4(x)=>apply(self,x,offset),F5(x)=>apply(self,x,offset),F6(x)=>apply(self,x,offset),F7(x)=>apply(self,x,offset),F8(x)=>apply(self,x,offset),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|AI::forward(self,(x,offset)).0).collect(),_=>panic!("internal error")},offset)
609 }
610}
611impl<B:Backend> AI<Value<B>,Value<B>> for Tanh{
612 fn forward(&self,input:Value<B>)->Value<B>{
613 match input.float(){F1(x)=>F1(self.forward(x)),F2(x)=>F2(self.forward(x)),F3(x)=>F3(self.forward(x)),F4(x)=>F4(self.forward(x)),F5(x)=>F5(self.forward(x)),F6(x)=>F6(self.forward(x)),F7(x)=>F7(self.forward(x)),F8(x)=>F8(self.forward(x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
614 }
615}
impl<B:Backend> Abs for Value<B>{
 // Element-wise absolute value: float and int tensors use `Tensor::abs`; bool
 // tensors pass through unchanged; multi values recurse and incompatible
 // values propagate their message.
 fn abs(self)->Self::Output{
  match self{B1(x)=>B1(x),B2(x)=>B2(x),B3(x)=>B3(x),B4(x)=>B4(x),B5(x)=>B5(x),B6(x)=>B6(x),B7(x)=>B7(x),B8(x)=>B8(x),F1(x)=>F1(x.abs()),F2(x)=>F2(x.abs()),F3(x)=>F3(x.abs()),F4(x)=>F4(x.abs()),F5(x)=>F5(x.abs()),F6(x)=>F6(x.abs()),F7(x)=>F7(x.abs()),F8(x)=>F8(x.abs()),I1(x)=>I1(x.abs()),I2(x)=>I2(x.abs()),I3(x)=>I3(x.abs()),I4(x)=>I4(x.abs()),I5(x)=>I5(x.abs()),I6(x)=>I6(x.abs()),I7(x)=>I7(x.abs()),I8(x)=>I8(x.abs()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(Value::abs).collect()}
 }
 type Output=Value<B>;
}
622impl<B:Backend> AsRef<Self> for Value<B>{
623 fn as_ref(&self)->&Self{self}
624}
625impl<B:Backend> Cat for Value<B>{
626 fn cat(self,d:i32)->Self{
628 fn f<B:Backend,I:Iterator<Item=Tensor<B,N,K>>,K:BasicOps<B>+TensorKind<B>,const N:usize>(d:i32,x0:Tensor<B,N,K>,tensors:I)->Value<B> where Tensor<B,N,K>:Into<Value<B>>{
629 if d>=N as i32||d<(-(N as i32)){return format!("rank {N} is too low to cat along dimension {d}").into()}
630 let d=if d<0{N-((-d) as usize)}else{d as usize};
631 let shape=x0.dims();
632 let tensors:Vec<Tensor<B,N,K>>=once(x0).chain(tensors).collect();
633
634 if let Err(e)=tensors.iter().try_for_each(|x|{
635 let mut xshape=x.dims();
636 xshape[d]=shape[d];
637 if shape==xshape{Ok(())}else{Err("mismatched shapes {shape:?}, {xshape:?}")}
638 }){
639 return e.into()
640 }
641
642 Tensor::cat(tensors,d).into()
643 }
644 let v=if let Value::Multi(v)=self{v}else{return self};
645
646 if let Some(n)=v.iter().position(Value::is_incompatible){return v.into_iter().nth(n).unwrap()}
647 if v.iter().all(Value::is_multi){return v.into_iter().map(|x|x.cat(d)).collect()}
648 if v.iter().any(Value::is_multi){return "cannot mix single and multi values in a cat operation".into()}
649 let variant=mem::discriminant(&v[0]);
650
651 if v.iter().any(|x|mem::discriminant(x)!=variant){return "cannot mix variants in a cat operation".into()}
652 let mut v=v.into_iter();
653
654 match v.next().unwrap(){B1(x0)=>f(d,x0,v.map(Value::unwrap_b1)),B2(x0)=>f(d,x0,v.map(Value::unwrap_b2)),B3(x0)=>f(d,x0,v.map(Value::unwrap_b3)),B4(x0)=>f(d,x0,v.map(Value::unwrap_b4)),B5(x0)=>f(d,x0,v.map(Value::unwrap_b5)),B6(x0)=>f(d,x0,v.map(Value::unwrap_b6)),B7(x0)=>f(d,x0,v.map(Value::unwrap_b7)),B8(x0)=>f(d,x0,v.map(Value::unwrap_b8)),F1(x0)=>f(d,x0,v.map(Value::unwrap_f1)),F2(x0)=>f(d,x0,v.map(Value::unwrap_f2)),F3(x0)=>f(d,x0,v.map(Value::unwrap_f3)),F4(x0)=>f(d,x0,v.map(Value::unwrap_f4)),F5(x0)=>f(d,x0,v.map(Value::unwrap_f5)),F6(x0)=>f(d,x0,v.map(Value::unwrap_f6)),F7(x0)=>f(d,x0,v.map(Value::unwrap_f7)),F8(x0)=>f(d,x0,v.map(Value::unwrap_f8)),I1(x0)=>f(d,x0,v.map(Value::unwrap_i1)),I2(x0)=>f(d,x0,v.map(Value::unwrap_i2)),I3(x0)=>f(d,x0,v.map(Value::unwrap_i3)),I4(x0)=>f(d,x0,v.map(Value::unwrap_i4)),I5(x0)=>f(d,x0,v.map(Value::unwrap_i5)),I6(x0)=>f(d,x0,v.map(Value::unwrap_i6)),I7(x0)=>f(d,x0,v.map(Value::unwrap_i7)),I8(x0)=>f(d,x0,v.map(Value::unwrap_i8)),Value::Incompatible(_e)=>panic!("internal error not handled in correct location"),Value::Multi(_e)=>panic!("internal error not handled in correct location")}
655 }
656 type Output=Self;
657}
658impl<B:Backend> Decompose for LossOutput<B>{
659 fn compose((loss,output,target):Self::Decomposition)->Self{Self::new(loss,output,target)}
660 fn decompose(self)->Self::Decomposition{(self.loss(),self.output(),self.target())}
661 fn decompose_cloned(&self)->Self::Decomposition{(self.loss(),self.output(),self.target())}
662 type Decomposition=(Value<B>,Value<B>,Value<B>);
663}
664impl<B:Backend> Decompose for Value<B>{
665 fn compose(decomposition:Self::Decomposition)->Self{decomposition}
666 fn decompose(self)->Self::Decomposition{self}
667 fn decompose_cloned(&self)->Self::Decomposition{self.clone()}
668 type Decomposition=Self;
669}
670impl<B:Backend> Default for Value<B>{
671 fn default()->Self{Self::Multi(Vec::new())}
672}
impl<B:Backend> Display for Value<B>{
 // NOTE(review): placeholder implementation — always writes the literal "todo"
 // rather than rendering the value; `ModuleDisplay::format` below produces a
 // real rendering. Left untouched because callers may rely on the exact output.
 fn fmt(&self,f:&mut std::fmt::Formatter<'_>)->FmtResult{write!(f,"todo")}
}
676impl<B:Backend> From<Vec<bool>> for Value<B>{
677 fn from(value:Vec<bool>)->Self{
678 let l=value.len();
679 let t:Tensor<B,1,Bool>=Tensor::from_data(TensorData::new(value,[l]),&Default::default());
680
681 t.into()
682 }
683}
684impl<B:Backend> From<Vec<f32>> for Value<B>{
685 fn from(value:Vec<f32>)->Self{
686 let l=value.len();
687 let t:Tensor<B,1>=Tensor::from_data(TensorData::new(value,[l]),&Default::default());
688
689 t.into()
690 }
691}
692impl<B:Backend> From<Vec<i32>> for Value<B>{
693 fn from(value:Vec<i32>)->Self{
694 let l=value.len();
695 let t:Tensor<B,1,Int>=Tensor::from_data(TensorData::new(value,[l]),&Default::default());
696
697 t.into()
698 }
699}
700impl<B:Backend> From<Vec<u32>> for Value<B>{
701 fn from(value:Vec<u32>)->Self{
702 let l=value.len();
703 let t:Tensor<B,1,Int>=Tensor::from_data(TensorData::new(value,[l]),&Default::default());
704
705 t.into()
706 }
707}
708impl<B:Backend> ModuleDisplay for Value<B>{
709 fn custom_content(&self,_content:Content)->Option<Content>{None}
710 fn custom_settings(&self)->Option<DisplaySettings>{None}
711 fn format(&self,s:DisplaySettings)->String{
712 match self{
713 B1(x)=>x.format(s),
714 B2(x)=>x.format(s),
715 B3(x)=>x.format(s),
716 B4(x)=>x.format(s),
717 B5(x)=>x.format(s),
718 B6(x)=>x.format(s),
719 B7(x)=>x.format(s),
720 B8(x)=>x.format(s),
721 F1(x)=>x.format(s),
722 F2(x)=>x.format(s),
723 F3(x)=>x.format(s),
724 F4(x)=>x.format(s),
725 F5(x)=>x.format(s),
726 F6(x)=>x.format(s),
727 F7(x)=>x.format(s),
728 F8(x)=>x.format(s),
729 I1(x)=>x.format(s),
730 I2(x)=>x.format(s),
731 I3(x)=>x.format(s),
732 I4(x)=>x.format(s),
733 I5(x)=>x.format(s),
734 I6(x)=>x.format(s),
735 I7(x)=>x.format(s),
736 I8(x)=>x.format(s),
737 Value::Incompatible(e)=>e.to_string(),
738 Value::Multi(v)=>"[".chars().chain(v.iter().flat_map(|x|{
739 let x:Vec<char>=x.format(s.clone()).chars().chain(", ".chars()).collect();
740 x
741 })).chain("]".chars()).collect()
742 }
743 }
744}
745impl<B:Backend> ModuleDisplayDefault for Value<B>{
746 fn content(&self,content:Content)->Option<Content>{Some(content)}
747 fn num_params(&self)->usize{Module::num_params(self)}
748}
749impl<B:Backend> From<String> for Value<B>{
750 fn from(value:String)->Self{Self::Incompatible(value)}
751}
impl<B:Backend> From<Value<B>> for ValueData{
 // Lower a backend-bound `Value` to the backend-independent `ValueData` form:
 // each tensor is reduced to its raw `TensorData` tagged by kind (BX/FX/IX),
 // dropping the compile-time rank; multi values recurse.
 fn from(value:Value<B>)->Self{
  match value{B1(x)=>BX(x.into_data()),B2(x)=>BX(x.into_data()),B3(x)=>BX(x.into_data()),B4(x)=>BX(x.into_data()),B5(x)=>BX(x.into_data()),B6(x)=>BX(x.into_data()),B7(x)=>BX(x.into_data()),B8(x)=>BX(x.into_data()),F1(x)=>FX(x.into_data()),F2(x)=>FX(x.into_data()),F3(x)=>FX(x.into_data()),F4(x)=>FX(x.into_data()),F5(x)=>FX(x.into_data()),F6(x)=>FX(x.into_data()),F7(x)=>FX(x.into_data()),F8(x)=>FX(x.into_data()),I1(x)=>IX(x.into_data()),I2(x)=>IX(x.into_data()),I3(x)=>IX(x.into_data()),I4(x)=>IX(x.into_data()),I5(x)=>IX(x.into_data()),I6(x)=>IX(x.into_data()),I7(x)=>IX(x.into_data()),I8(x)=>IX(x.into_data()),Value::Incompatible(e)=>ValueData::Incompatible(e),Value::Multi(v)=>ValueData::Multi(v.into_iter().map(ValueData::from).collect())}
 }
}
757impl<B:Backend> From<ValueData> for Value<B>{
758 fn from(value:ValueData)->Self{
759 let device=Default::default();
760 match value{
761 BX(data)=>match data.shape.len(){1=>B1(Tensor::from_data(data,&device)),2=>B2(Tensor::from_data(data,&device)),3=>B3(Tensor::from_data(data,&device)),4=>B4(Tensor::from_data(data,&device)),5=>B5(Tensor::from_data(data,&device)),6=>B6(Tensor::from_data(data,&device)),7=>B7(Tensor::from_data(data,&device)),8=>B8(Tensor::from_data(data,&device)),_=>panic!("tensor ranks above 8 are currently not supported")},
762 FX(data)=>match data.shape.len(){1=>F1(Tensor::from_data(data,&device)),2=>F2(Tensor::from_data(data,&device)),3=>F3(Tensor::from_data(data,&device)),4=>F4(Tensor::from_data(data,&device)),5=>F5(Tensor::from_data(data,&device)),6=>F6(Tensor::from_data(data,&device)),7=>F7(Tensor::from_data(data,&device)),8=>F8(Tensor::from_data(data,&device)),_=>panic!("tensor ranks above 8 are currently not supported")},
763 IX(data)=>match data.shape.len(){1=>I1(Tensor::from_data(data,&device)),2=>I2(Tensor::from_data(data,&device)),3=>I3(Tensor::from_data(data,&device)),4=>I4(Tensor::from_data(data,&device)),5=>I5(Tensor::from_data(data,&device)),6=>I6(Tensor::from_data(data,&device)),7=>I7(Tensor::from_data(data,&device)),8=>I8(Tensor::from_data(data,&device)),_=>panic!("tensor ranks above 8 are currently not supported")},
764 ValueData::Incompatible(e)=>e.into(),
765 ValueData::Multi(v)=>v.into_iter().map(Value::from).collect(),
766 }
767 }
768}
769impl<B:Backend> From<Vec<Value<B>>> for Value<B>{
770 fn from(value:Vec<Value<B>>)->Self{Self::Multi(value)}
771}
772impl<B:Backend> IntoIterator for Value<B>{
773 fn into_iter(self)->Self::IntoIter{self.into_multi().into_iter()}
774 type IntoIter=VecIntoIter<Value<B>>;
775 type Item=Value<B>;
776}
777impl<B:Backend> LossOutput<B>{
778 pub fn loss(&self)->Value<B>{self.loss.clone()}
780 pub fn new(loss:Value<B>,output:Value<B>,target:Value<B>)->Self{
782 Self{loss,output,target}
783 }
784 pub fn output(&self)->Value<B>{self.output.clone()}
786 pub fn target(&self)->Value<B>{self.target.clone()}
788}
789impl<B:Backend> Merge for Value<B>{
790 fn merge(&mut self,other:Self){
791 match (mem::take(self),other){
792 (Value::Multi(mut u),Value::Multi(v))=>{
793 u.extend(v);
794 *self=u.into();
795 },
796 (Value::Multi(mut u),v)=>if u.len()==0{
797 *self=v;
798 }else{
799 u.push(v);
800 *self=u.into();
801 },
802 (u,Value::Multi(mut v))=>if v.len()==0{
803 *self=u;
804 }else{
805 v.insert(0,u);
806 *self=v.into();
807 },
808 (u,v)=>*self=vec![u,v].into()
809 }
810 }
811}
impl<B:Backend> Module<B> for Value<B>{
 // A `Value` acts as a parameterless constant module: records are empty,
 // save/load are no-ops, and every tensor operation is a per-variant dispatch.
 // Accumulate the devices of every contained tensor into `devices`.
 fn collect_devices(&self,devices:Vec<<B as Backend>::Device>)->Vec<<B as Backend>::Device>{
  match self{B1(x)=>x.collect_devices(devices),B2(x)=>x.collect_devices(devices),B3(x)=>x.collect_devices(devices),B4(x)=>x.collect_devices(devices),B5(x)=>x.collect_devices(devices),B6(x)=>x.collect_devices(devices),B7(x)=>x.collect_devices(devices),B8(x)=>x.collect_devices(devices),F1(x)=>x.collect_devices(devices),F2(x)=>x.collect_devices(devices),F3(x)=>x.collect_devices(devices),F4(x)=>x.collect_devices(devices),F5(x)=>x.collect_devices(devices),F6(x)=>x.collect_devices(devices),F7(x)=>x.collect_devices(devices),F8(x)=>x.collect_devices(devices),I1(x)=>x.collect_devices(devices),I2(x)=>x.collect_devices(devices),I3(x)=>x.collect_devices(devices),I4(x)=>x.collect_devices(devices),I5(x)=>x.collect_devices(devices),I6(x)=>x.collect_devices(devices),I7(x)=>x.collect_devices(devices),I8(x)=>x.collect_devices(devices),Value::Incompatible(_e)=>devices,Value::Multi(v)=>v.iter().fold(devices,|devices,x|x.collect_devices(devices))}
 }
 fn devices(&self)->Vec<<B as Backend>::Device>{self.collect_devices(Vec::new())}
 // Move (fork) every contained tensor to `device`.
 fn fork(self,device:&<B as Backend>::Device)->Self{
  match self{B1(x)=>B1(x.fork(device)),B2(x)=>B2(x.fork(device)),B3(x)=>B3(x.fork(device)),B4(x)=>B4(x.fork(device)),B5(x)=>B5(x.fork(device)),B6(x)=>B6(x.fork(device)),B7(x)=>B7(x.fork(device)),B8(x)=>B8(x.fork(device)),F1(x)=>F1(x.fork(device)),F2(x)=>F2(x.fork(device)),F3(x)=>F3(x.fork(device)),F4(x)=>F4(x.fork(device)),F5(x)=>F5(x.fork(device)),F6(x)=>F6(x.fork(device)),F7(x)=>F7(x.fork(device)),F8(x)=>F8(x.fork(device)),I1(x)=>I1(x.fork(device)),I2(x)=>I2(x.fork(device)),I3(x)=>I3(x.fork(device)),I4(x)=>I4(x.fork(device)),I5(x)=>I5(x.fork(device)),I6(x)=>I6(x.fork(device)),I7(x)=>I7(x.fork(device)),I8(x)=>I8(x.fork(device)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.fork(device)).collect()}
 }
 // No trainable state: the record is the unit ConstantRecord.
 fn into_record(self)->Self::Record{ConstantRecord}
 fn load_file<F:FileRecorder<B>,P:Into<PathBuf>>(self,_filepath:P,_recorder:&F,_device:&<B as Backend>::Device)->Result<Self,RecorderError>{Ok(self)}
 fn load_record(self,_record:Self::Record)->Self{self}
 // Apply a module mapper to every contained tensor.
 fn map<Mapper:ModuleMapper<B>>(self,mapper:&mut Mapper)->Self{
  match self{B1(x)=>B1(x.map(mapper)),B2(x)=>B2(x.map(mapper)),B3(x)=>B3(x.map(mapper)),B4(x)=>B4(x.map(mapper)),B5(x)=>B5(x.map(mapper)),B6(x)=>B6(x.map(mapper)),B7(x)=>B7(x.map(mapper)),B8(x)=>B8(x.map(mapper)),F1(x)=>F1(x.map(mapper)),F2(x)=>F2(x.map(mapper)),F3(x)=>F3(x.map(mapper)),F4(x)=>F4(x.map(mapper)),F5(x)=>F5(x.map(mapper)),F6(x)=>F6(x.map(mapper)),F7(x)=>F7(x.map(mapper)),F8(x)=>F8(x.map(mapper)),I1(x)=>I1(x.map(mapper)),I2(x)=>I2(x.map(mapper)),I3(x)=>I3(x.map(mapper)),I4(x)=>I4(x.map(mapper)),I5(x)=>I5(x.map(mapper)),I6(x)=>I6(x.map(mapper)),I7(x)=>I7(x.map(mapper)),I8(x)=>I8(x.map(mapper)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.map(mapper)).collect()}
 }
 // Sum of parameter counts; an incompatible value contributes none.
 fn num_params(&self)->usize{
  match self{B1(x)=>Module::num_params(x),B2(x)=>Module::num_params(x),B3(x)=>Module::num_params(x),B4(x)=>Module::num_params(x),B5(x)=>Module::num_params(x),B6(x)=>Module::num_params(x),B7(x)=>Module::num_params(x),B8(x)=>Module::num_params(x),F1(x)=>Module::num_params(x),F2(x)=>Module::num_params(x),F3(x)=>Module::num_params(x),F4(x)=>Module::num_params(x),F5(x)=>Module::num_params(x),F6(x)=>Module::num_params(x),F7(x)=>Module::num_params(x),F8(x)=>Module::num_params(x),I1(x)=>Module::num_params(x),I2(x)=>Module::num_params(x),I3(x)=>Module::num_params(x),I4(x)=>Module::num_params(x),I5(x)=>Module::num_params(x),I6(x)=>Module::num_params(x),I7(x)=>Module::num_params(x),I8(x)=>Module::num_params(x),Value::Incompatible(_e)=>0,Value::Multi(v)=>v.into_iter().map(|x|Module::num_params(x)).sum()}
 }
 // Quantize every contained tensor's weights in place.
 fn quantize_weights(self,quantizer:&mut Quantizer)->Self{
  match self{B1(x)=>B1(x.quantize_weights(quantizer)),B2(x)=>B2(x.quantize_weights(quantizer)),B3(x)=>B3(x.quantize_weights(quantizer)),B4(x)=>B4(x.quantize_weights(quantizer)),B5(x)=>B5(x.quantize_weights(quantizer)),B6(x)=>B6(x.quantize_weights(quantizer)),B7(x)=>B7(x.quantize_weights(quantizer)),B8(x)=>B8(x.quantize_weights(quantizer)),F1(x)=>F1(x.quantize_weights(quantizer)),F2(x)=>F2(x.quantize_weights(quantizer)),F3(x)=>F3(x.quantize_weights(quantizer)),F4(x)=>F4(x.quantize_weights(quantizer)),F5(x)=>F5(x.quantize_weights(quantizer)),F6(x)=>F6(x.quantize_weights(quantizer)),F7(x)=>F7(x.quantize_weights(quantizer)),F8(x)=>F8(x.quantize_weights(quantizer)),I1(x)=>I1(x.quantize_weights(quantizer)),I2(x)=>I2(x.quantize_weights(quantizer)),I3(x)=>I3(x.quantize_weights(quantizer)),I4(x)=>I4(x.quantize_weights(quantizer)),I5(x)=>I5(x.quantize_weights(quantizer)),I6(x)=>I6(x.quantize_weights(quantizer)),I7(x)=>I7(x.quantize_weights(quantizer)),I8(x)=>I8(x.quantize_weights(quantizer)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.quantize_weights(quantizer)).collect()}
 }
 // Saving a constant value is a no-op.
 fn save_file<F:FileRecorder<B>,P:Into<PathBuf>>(self,_filepath:P,_recorder:&F)->Result<(),RecorderError>{
  Ok(())
 }
 // Move every contained tensor to `device`.
 fn to_device(self,device:&<B as Backend>::Device)->Self{
  match self{B1(x)=>B1(x.to_device(device)),B2(x)=>B2(x.to_device(device)),B3(x)=>B3(x.to_device(device)),B4(x)=>B4(x.to_device(device)),B5(x)=>B5(x.to_device(device)),B6(x)=>B6(x.to_device(device)),B7(x)=>B7(x.to_device(device)),B8(x)=>B8(x.to_device(device)),F1(x)=>F1(x.to_device(device)),F2(x)=>F2(x.to_device(device)),F3(x)=>F3(x.to_device(device)),F4(x)=>F4(x.to_device(device)),F5(x)=>F5(x.to_device(device)),F6(x)=>F6(x.to_device(device)),F7(x)=>F7(x.to_device(device)),F8(x)=>F8(x.to_device(device)),I1(x)=>I1(x.to_device(device)),I2(x)=>I2(x.to_device(device)),I3(x)=>I3(x.to_device(device)),I4(x)=>I4(x.to_device(device)),I5(x)=>I5(x.to_device(device)),I6(x)=>I6(x.to_device(device)),I7(x)=>I7(x.to_device(device)),I8(x)=>I8(x.to_device(device)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.to_device(device)).collect()}
 }
 // Visit every contained tensor; incompatible values are skipped.
 fn visit<Visitor:ModuleVisitor<B>>(&self,visitor:&mut Visitor){
  match self{B1(x)=>x.visit(visitor),B2(x)=>x.visit(visitor),B3(x)=>x.visit(visitor),B4(x)=>x.visit(visitor),B5(x)=>x.visit(visitor),B6(x)=>x.visit(visitor),B7(x)=>x.visit(visitor),B8(x)=>x.visit(visitor),F1(x)=>x.visit(visitor),F2(x)=>x.visit(visitor),F3(x)=>x.visit(visitor),F4(x)=>x.visit(visitor),F5(x)=>x.visit(visitor),F6(x)=>x.visit(visitor),F7(x)=>x.visit(visitor),F8(x)=>x.visit(visitor),I1(x)=>x.visit(visitor),I2(x)=>x.visit(visitor),I3(x)=>x.visit(visitor),I4(x)=>x.visit(visitor),I5(x)=>x.visit(visitor),I6(x)=>x.visit(visitor),I7(x)=>x.visit(visitor),I8(x)=>x.visit(visitor),Value::Incompatible(_e)=>(),Value::Multi(v)=>v.iter().for_each(|x|x.visit(visitor))}
 }
 type Record=ConstantRecord;
}
843impl<B:Backend> Serialize for Value<B>{
844 fn serialize<S:Serializer>(&self,serializer:S)->Result<S::Ok,S::Error>{ValueData::from(self.clone()).serialize(serializer)}
845}
846impl<B:Backend> Squeeze for Value<B>{
847 fn squeeze(self,d:i32)->Self{self.squeeze_dim(d)}
848 type Output=Self;
849}
850impl<B:Backend> Stack for Value<B>{
851 fn stack(self,d:i32)->Self{self.unsqueeze_dim(d).cat(d)}
853 type Output=Self;
854}
855impl<B:Backend> Unsqueeze for Value<B>{
856 fn unsqueeze(self,d:i32)->Self{self.unsqueeze_dim(d)}
857 type Output=Self;
858}
impl<B:Backend> Value<B>{pub fn all(self)->Value<B>{
 // Reduce to a single-element bool tensor: true iff every element is truthy.
 fn reduce<B:Backend,const N:usize>(t:Tensor<B,N,Bool>)->Value<B>{t.all().into()}
 match self.bool(){B1(t)=>reduce(t),B2(t)=>reduce(t),B3(t)=>reduce(t),B4(t)=>reduce(t),B5(t)=>reduce(t),B6(t)=>reduce(t),B7(t)=>reduce(t),B8(t)=>reduce(t),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(Value::all).collect(),_=>panic!("internal error")}
 }
865 pub fn all_dim(self,d:i32)->Value<B>{
867 fn f<B:Backend,const N:usize>(d:i32,x:Tensor<B,N,Bool>)->Value<B>{
868 if d>=N as i32||d<(-(N as i32)){return format!("rank {N} is too low to all along dimension {d}").into()}
869 let d=if d<0{N-((-d) as usize)}else{d as usize};
870 x.all_dim(d).into()
871 }
872 match self.bool(){B1(x)=>f(d,x),B2(x)=>f(d,x),B3(x)=>f(d,x),B4(x)=>f(d,x),B5(x)=>f(d,x),B6(x)=>f(d,x),B7(x)=>f(d,x),B8(x)=>f(d,x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|v|v.all_dim(d)).collect(),_=>panic!("internal error")}
873 }
874 pub fn any(self)->Value<B>{
876 fn f<B:Backend,const N:usize>(x:Tensor<B,N,Bool>)->Value<B>{x.any().into()}
877 match self.bool(){B1(x)=>f(x),B2(x)=>f(x),B3(x)=>f(x),B4(x)=>f(x),B5(x)=>f(x),B6(x)=>f(x),B7(x)=>f(x),B8(x)=>f(x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(Value::any).collect(),_=>panic!("internal error")}
878 }
879 pub fn any_dim(self,d:i32)->Value<B>{
881 fn f<B:Backend,const N:usize>(d:i32,x:Tensor<B,N,Bool>)->Value<B>{
882 if d>=N as i32||d<(-(N as i32)){return format!("rank {N} is too low to any along dimension {d}").into()}
883 let d=if d<0{N-((-d) as usize)}else{d as usize};
884 x.any_dim(d).into()
885 }
886 match self.bool(){B1(x)=>f(d,x),B2(x)=>f(d,x),B3(x)=>f(d,x),B4(x)=>f(d,x),B5(x)=>f(d,x),B6(x)=>f(d,x),B7(x)=>f(d,x),B8(x)=>f(d,x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|v|v.any_dim(d)).collect(),_=>panic!("internal error")}
887 }
 /// Convert to bool tensors: bools pass through, floats and ints use
 /// `Tensor::bool`; multi values recurse and incompatible values propagate.
 pub fn bool(self)->Value<B>{
 match self{B1(x)=>B1(x),B2(x)=>B2(x),B3(x)=>B3(x),B4(x)=>B4(x),B5(x)=>B5(x),B6(x)=>B6(x),B7(x)=>B7(x),B8(x)=>B8(x),F1(x)=>B1(x.bool()),F2(x)=>B2(x.bool()),F3(x)=>B3(x.bool()),F4(x)=>B4(x.bool()),F5(x)=>B5(x.bool()),F6(x)=>B6(x.bool()),F7(x)=>B7(x.bool()),F8(x)=>B8(x.bool()),I1(x)=>B1(x.bool()),I2(x)=>B2(x.bool()),I3(x)=>B3(x.bool()),I4(x)=>B4(x.bool()),I5(x)=>B5(x.bool()),I6(x)=>B6(x.bool()),I7(x)=>B7(x.bool()),I8(x)=>B8(x.bool()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(Value::bool).collect())}
 }
892 pub fn count(&self)->usize{
894 match self{
895 B1(x)=>x.dims().iter().product(),
896 B2(x)=>x.dims().iter().product(),
897 B3(x)=>x.dims().iter().product(),
898 B4(x)=>x.dims().iter().product(),
899 B5(x)=>x.dims().iter().product(),
900 B6(x)=>x.dims().iter().product(),
901 B7(x)=>x.dims().iter().product(),
902 B8(x)=>x.dims().iter().product(),
903 F1(x)=>x.dims().iter().product(),
904 F2(x)=>x.dims().iter().product(),
905 F3(x)=>x.dims().iter().product(),
906 F4(x)=>x.dims().iter().product(),
907 F5(x)=>x.dims().iter().product(),
908 F6(x)=>x.dims().iter().product(),
909 F7(x)=>x.dims().iter().product(),
910 F8(x)=>x.dims().iter().product(),
911 I1(x)=>x.dims().iter().product(),
912 I2(x)=>x.dims().iter().product(),
913 I3(x)=>x.dims().iter().product(),
914 I4(x)=>x.dims().iter().product(),
915 I5(x)=>x.dims().iter().product(),
916 I6(x)=>x.dims().iter().product(),
917 I7(x)=>x.dims().iter().product(),
918 I8(x)=>x.dims().iter().product(),
919 Value::Incompatible(_e)=>0,
920 Value::Multi(v)=>v.iter().map(Value::count).sum()
921 }
922 }
 /// An empty `Multi` value containing no tensors (same as `Default`).
 pub fn empty()->Self{Self::Multi(Vec::new())}
925 pub fn flatten_values(self)->Self{
927 fn f<B:Backend>(mut acc:Vec<Value<B>>,x:Value<B>)->Vec<Value<B>>{
928 if x.is_multi(){acc=x.into_iter().fold(acc,|acc,x|f(acc,x))}else{acc.push(x)}
929 acc
930 }
931 f(Vec::new(),self).into()
932 }
 /// Convert to float tensors: floats pass through, bools and ints use
 /// `Tensor::float`; multi values recurse and incompatible values propagate.
 pub fn float(self)->Value<B>{
 match self{B1(x)=>F1(x.float()),B2(x)=>F2(x.float()),B3(x)=>F3(x.float()),B4(x)=>F4(x.float()),B5(x)=>F5(x.float()),B6(x)=>F6(x.float()),B7(x)=>F7(x.float()),B8(x)=>F8(x.float()),F1(x)=>F1(x),F2(x)=>F2(x),F3(x)=>F3(x),F4(x)=>F4(x),F5(x)=>F5(x),F6(x)=>F6(x),F7(x)=>F7(x),F8(x)=>F8(x),I1(x)=>F1(x.float()),I2(x)=>F2(x.float()),I3(x)=>F3(x.float()),I4(x)=>F4(x.float()),I5(x)=>F5(x.float()),I6(x)=>F6(x.float()),I7(x)=>F7(x.float()),I8(x)=>F8(x.float()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(Value::float).collect())}
 }
 /// Assemble a value from an iterator of values according to `shape`:
 /// a `Multi` shape drains the whole iterator, a `Recursive` shape consumes
 /// one sub-value per sub-shape, and each leaf shape (X1..X8) takes the next
 /// value (falling back to the default empty value when exhausted).
 // NOTE(review): leaf arms ignore the concrete dimensions of the shape — the
 // next value is taken as-is, its shape is not checked. Confirm intended.
 pub fn from_values<I:IntoIterator<Item=Self>,S:Into<Shape>>(inner:I,shape:S)->Self{
 fn f<B:Backend,I:Iterator<Item=Value<B>>,S:Into<Shape>>(inner:&mut I,shape:S)->Value<B>{
 match shape.into(){
 Shape::Incompatible(e)=>e.into(),
 Shape::Multi(_l)=>inner.collect(),
 // `&mut *inner` reborrows so the same iterator is shared across sub-shapes.
 Shape::Recursive(v)=>v.into_iter().map(|s|f(&mut *inner,s)).collect(),
 X1(_s)=>inner.next().unwrap_or_default(),
 X2(_s)=>inner.next().unwrap_or_default(),
 X3(_s)=>inner.next().unwrap_or_default(),
 X4(_s)=>inner.next().unwrap_or_default(),
 X5(_s)=>inner.next().unwrap_or_default(),
 X6(_s)=>inner.next().unwrap_or_default(),
 X7(_s)=>inner.next().unwrap_or_default(),
 X8(_s)=>inner.next().unwrap_or_default(),
 }
 }
 f(&mut inner.into_iter(),shape)
 }
 /// Gather elements along `dim` using int `indices` of the same rank as the
 /// data (negative dims count from the back). Bool data is gathered via an
 /// int round-trip; mismatched ranks or non-int indices yield an
 /// `Incompatible` value.
 pub fn gather(self,dim:i32,indices:Value<B>)->Self{
 // Bool tensors have no gather; convert to int, gather, convert back.
 fn b<B:Backend,const N:usize>(d:i32,data:Tensor<B,N,Bool>,indices:Tensor<B,N,Int>)->Value<B>{f(d,data.int(),indices).bool()}
 fn f<B:Backend,K:'static+BasicOps<B>+Numeric<B>+TensorKind<B>,const N:usize>(d:i32,data:Tensor<B,N,K>,indices:Tensor<B,N,Int>)->Value<B>{
 let d=if d<0{N-((-d) as usize)}else{d as usize};
 if d>=N{format!("dim {d} must be less than rank {N}").into()}else{data.gather(d,indices).into()}
 }

 match (self,indices){(B1(x),I1(i))=>b(dim,x,i),(B2(x),I2(i))=>b(dim,x,i),(B3(x),I3(i))=>b(dim,x,i),(B4(x),I4(i))=>b(dim,x,i),(B5(x),I5(i))=>b(dim,x,i),(B6(x),I6(i))=>b(dim,x,i),(B7(x),I7(i))=>b(dim,x,i),(B8(x),I8(i))=>b(dim,x,i),(F1(x),I1(i))=>f(dim,x,i),(F2(x),I2(i))=>f(dim,x,i),(F3(x),I3(i))=>f(dim,x,i),(F4(x),I4(i))=>f(dim,x,i),(F5(x),I5(i))=>f(dim,x,i),(F6(x),I6(i))=>f(dim,x,i),(F7(x),I7(i))=>f(dim,x,i),(F8(x),I8(i))=>f(dim,x,i),(I1(x),I1(i))=>f(dim,x,i),(I2(x),I2(i))=>f(dim,x,i),(I3(x),I3(i))=>f(dim,x,i),(I4(x),I4(i))=>f(dim,x,i),(I5(x),I5(i))=>f(dim,x,i),(I6(x),I6(i))=>f(dim,x,i),(I7(x),I7(i))=>f(dim,x,i),(I8(x),I8(i))=>f(dim,x,i),(Value::Incompatible(e),_)=>e.into(),(_,Value::Incompatible(e))=>e.into(),(Value::Multi(u),Value::Multi(v))=>u.into_iter().zip(v).map(|(u,v)|u.gather(dim,v)).collect(),_=>"gather is only available for tensors of matching dimensions with int indices".into()}
 }
 /// Convert to int tensors: ints pass through, bools and floats use
 /// `Tensor::int`; multi values recurse and incompatible values propagate.
 pub fn int(self)->Value<B>{
 match self{B1(x)=>I1(x.int()),B2(x)=>I2(x.int()),B3(x)=>I3(x.int()),B4(x)=>I4(x.int()),B5(x)=>I5(x.int()),B6(x)=>I6(x.int()),B7(x)=>I7(x.int()),B8(x)=>I8(x.int()),F1(x)=>I1(x.int()),F2(x)=>I2(x.int()),F3(x)=>I3(x.int()),F4(x)=>I4(x.int()),F5(x)=>I5(x.int()),F6(x)=>I6(x.int()),F7(x)=>I7(x.int()),F8(x)=>I8(x.int()),I1(x)=>I1(x),I2(x)=>I2(x),I3(x)=>I3(x),I4(x)=>I4(x),I5(x)=>I5(x),I6(x)=>I6(x),I7(x)=>I7(x),I8(x)=>I8(x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(Value::int).collect())}
 }
970 pub fn into_float_vec(self)->Vec<f32>{
972 fn cat_vec<T>(mut a:Vec<T>,b:Vec<T>)->Vec<T>{
973 a.extend(b);
974 a
975 }
976 fn to_vec<B:Backend,const N:usize>(x:Tensor<B,N>)->Vec<f32>{x.into_data().to_vec().unwrap_or_default()}
977
978 match self.float(){F1(x)=>to_vec(x),F2(x)=>to_vec(x),F3(x)=>to_vec(x),F4(x)=>to_vec(x),F5(x)=>to_vec(x),F6(x)=>to_vec(x),F7(x)=>to_vec(x),F8(x)=>to_vec(x),Value::Incompatible(_e)=>Vec::new(),Value::Multi(v)=>v.into_iter().map(Value::into_float_vec).reduce(cat_vec).unwrap_or_default(),_=>panic!("internal error")}
979 }
 /// True only for a `Multi` holding no values; any single tensor has length 1.
 pub fn is_empty(&self)->bool{self.len()==0}
982 pub fn is_incompatible(&self)->bool{
984 if let Value::Incompatible(_x)=self{true}else{false}
985 }
986 pub fn into_multi(self)->Vec<Value<B>>{
988 if let Value::Multi(v)=self{v}else{vec![self]}
989 }
990 pub fn is_multi(&self)->bool{
992 if let Value::Multi(_x)=self{true}else{false}
993 }
994 pub fn iter(&self)->SliceIter<'_,Self>{
996 if let Value::Multi(v)=self{v.iter()}else{slice::from_ref(self).iter()}
997 }
 /// The element kind of this value: Bool/Float/Int for tensors of any rank,
 /// or the Incompatible/Multi markers for the non-tensor variants.
 pub fn kind(&self)->Kind{
 match self{B1(_x)=>Kind::Bool,B2(_x)=>Kind::Bool,B3(_x)=>Kind::Bool,B4(_x)=>Kind::Bool,B5(_x)=>Kind::Bool,B6(_x)=>Kind::Bool,B7(_x)=>Kind::Bool,B8(_x)=>Kind::Bool,F1(_x)=>Kind::Float,F2(_x)=>Kind::Float,F3(_x)=>Kind::Float,F4(_x)=>Kind::Float,F5(_x)=>Kind::Float,F6(_x)=>Kind::Float,F7(_x)=>Kind::Float,F8(_x)=>Kind::Float,I1(_x)=>Kind::Int,I2(_x)=>Kind::Int,I3(_x)=>Kind::Int,I4(_x)=>Kind::Int,I5(_x)=>Kind::Int,I6(_x)=>Kind::Int,I7(_x)=>Kind::Int,I8(_x)=>Kind::Int,Value::Incompatible(_v)=>Kind::Incompatible,Value::Multi(_v)=>Kind::Multi}
 }
1002 pub fn len(&self)->usize{
1004 if let Value::Multi(v)=self{v.len()}else{1}
1005 }
1006 pub fn len_recursive(&self)->usize{
1008 if let Value::Multi(v)=self{v.iter().map(Value::len_recursive).sum()}else{1}
1009 }
 /// Fill positions where `mask` is true with `v`. The mask is converted to
 /// bool and both operands are rank-promoted first; bool data takes an
 /// int round-trip because bool tensors have no `mask_fill`. Multi values
 /// broadcast element-wise against the mask.
 pub fn mask_fill(self,mask:Value<B>,v:f32)->Self{
 let (x,mask)=self.promote_rank(mask.bool());
 match (x,mask){
 (B1(x),B1(m))=>B1(x.int().mask_fill(m,v).bool()),
 (B2(x),B2(m))=>B2(x.int().mask_fill(m,v).bool()),
 (B3(x),B3(m))=>B3(x.int().mask_fill(m,v).bool()),
 (B4(x),B4(m))=>B4(x.int().mask_fill(m,v).bool()),
 (B5(x),B5(m))=>B5(x.int().mask_fill(m,v).bool()),
 (B6(x),B6(m))=>B6(x.int().mask_fill(m,v).bool()),
 (B7(x),B7(m))=>B7(x.int().mask_fill(m,v).bool()),
 (B8(x),B8(m))=>B8(x.int().mask_fill(m,v).bool()),
 (F1(x),B1(m))=>F1(x.mask_fill(m,v)),
 (F2(x),B2(m))=>F2(x.mask_fill(m,v)),
 (F3(x),B3(m))=>F3(x.mask_fill(m,v)),
 (F4(x),B4(m))=>F4(x.mask_fill(m,v)),
 (F5(x),B5(m))=>F5(x.mask_fill(m,v)),
 (F6(x),B6(m))=>F6(x.mask_fill(m,v)),
 (F7(x),B7(m))=>F7(x.mask_fill(m,v)),
 (F8(x),B8(m))=>F8(x.mask_fill(m,v)),
 (I1(x),B1(m))=>I1(x.mask_fill(m,v)),
 (I2(x),B2(m))=>I2(x.mask_fill(m,v)),
 (I3(x),B3(m))=>I3(x.mask_fill(m,v)),
 (I4(x),B4(m))=>I4(x.mask_fill(m,v)),
 (I5(x),B5(m))=>I5(x.mask_fill(m,v)),
 (I6(x),B6(m))=>I6(x.mask_fill(m,v)),
 (I7(x),B7(m))=>I7(x.mask_fill(m,v)),
 (I8(x),B8(m))=>I8(x.mask_fill(m,v)),
 (Value::Incompatible(e),_)=>e.into(),
 (_,Value::Incompatible(e))=>e.into(),
 // Broadcast multi against single or multi masks (see broadcast_multi).
 (Value::Multi(x),m)=>broadcast_multi(x,m.into_multi(),|x,m|x.mask_fill(m,v)),
 (x,Value::Multi(m))=>broadcast_multi(x.into_multi(),m,|x,m|x.mask_fill(m,v)),
 _=>panic!("internal error")
 }
 }
1045 pub fn multi(self)->Self{
1047 if let Value::Multi(v)=self{v.into()}else{vec![self].into()}
1048 }
1049 pub fn new<S:Into<Shape>>(data:&[f32],device:&B::Device,shape:S)->Self{
1051 match shape.into(){
1052 Shape::Incompatible(e)=>e.into(),
1053 Shape::Multi(l)=>data.chunks(l).map(|d|Value::new(d,device,X1([d.len()]))).collect(),
1054 Shape::Recursive(v)=>v.into_iter().scan(data,|data,s|{
1055 let v=Value::new(*data,device,s);
1056 *data=&data[..v.count()];
1057 Some(v)
1058 }).collect(),
1059 X1(s)=>F1(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1060 X2(s)=>F2(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1061 X3(s)=>F3(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1062 X4(s)=>F4(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1063 X5(s)=>F5(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1064 X6(s)=>F6(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1065 X7(s)=>F7(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1066 X8(s)=>F8(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1067 }
1068 }
1069 pub fn promote(self,rhs:Value<B>)->(Value<B>,Value<B>){
1071 let (l,r)=self.promote_kind(rhs);
1072 l.promote_rank(r)
1073 }
1074 pub fn promote_kind(self,rhs:Value<B>)->(Value<B>,Value<B>){
1076 let (lk,rk)=(self.kind(),rhs.kind());
1077
1078 let (mut l,mut r)=(self,rhs);
1079 if lk==rk{()}else if lk==Kind::Multi{r=r.multi()}else if rk==Kind::Multi{l=l.multi()}else if lk==Kind::Float{r=r.float()}else if rk==Kind::Float{l=l.float()}else if lk==Kind::Int{r=r.int()}else if rk==Kind::Int{l=l.int()}else if lk==Kind::Incompatible{return (l,r)}else if rk==Kind::Incompatible{return (l,r)}
1080 (l,r)
1081 }
1082 pub fn promote_rank(self,rhs:Value<B>)->(Value<B>,Value<B>){
1084 let (mut l,mut r)=(self,rhs);
1085 let (mut lr,mut rr)=if let (Some(l),Some(r))=(l.rank(),r.rank()){(l,r)}else{return (l,r)};
1086 while lr<rr{
1087 l=l.unsqueeze();
1088 lr+=1;
1089 }
1090 while lr>rr{
1091 r=r.unsqueeze();
1092 rr+=1;
1093 }
1094 (l,r)
1095 }
1096 pub fn rank(&self)->Option<usize>{
1098 match self{B1(_x)=>Some(1),B2(_x)=>Some(2),B3(_x)=>Some(3),B4(_x)=>Some(4),B5(_x)=>Some(5),B6(_x)=>Some(6),B7(_x)=>Some(7),B8(_x)=>Some(8),F1(_x)=>Some(1),F2(_x)=>Some(2),F3(_x)=>Some(3),F4(_x)=>Some(4),F5(_x)=>Some(5),F6(_x)=>Some(6),F7(_x)=>Some(7),F8(_x)=>Some(8),I1(_x)=>Some(1),I2(_x)=>Some(2),I3(_x)=>Some(3),I4(_x)=>Some(4),I5(_x)=>Some(5),I6(_x)=>Some(6),I7(_x)=>Some(7),I8(_x)=>Some(8),Value::Incompatible(_x)=>None,Value::Multi(_x)=>None}
1099 }
1100 pub fn reshape<const N:usize>(self,dims:[usize;N])->Self{
1102 fn f<B:Backend,K:'static+BasicOps<B>+TensorKind<B>,const D:usize,const N:usize>(dims:[usize;D],x:Tensor<B,N,K>)->Value<B>{
1103 if dims.into_iter().product::<usize>()==x.dims().into_iter().product::<usize>(){x.reshape(dims).into()}else{"incompatible reshape".into()}
1104 }
1105 match self{
1106 B1(x)=>f(dims,x),
1107 B2(x)=>f(dims,x),
1108 B3(x)=>f(dims,x),
1109 B4(x)=>f(dims,x),
1110 B5(x)=>f(dims,x),
1111 B6(x)=>f(dims,x),
1112 B7(x)=>f(dims,x),
1113 B8(x)=>f(dims,x),
1114 F1(x)=>f(dims,x),
1115 F2(x)=>f(dims,x),
1116 F3(x)=>f(dims,x),
1117 F4(x)=>f(dims,x),
1118 F5(x)=>f(dims,x),
1119 F6(x)=>f(dims,x),
1120 F7(x)=>f(dims,x),
1121 F8(x)=>f(dims,x),
1122 I1(x)=>f(dims,x),
1123 I2(x)=>f(dims,x),
1124 I3(x)=>f(dims,x),
1125 I4(x)=>f(dims,x),
1126 I5(x)=>f(dims,x),
1127 I6(x)=>f(dims,x),
1128 I7(x)=>f(dims,x),
1129 I8(x)=>f(dims,x),
1130 Value::Incompatible(e)=>e.into(),
1131 Value::Multi(v)=>v.into_iter().map(|x|x.reshape(dims)).collect()
1132 }
1133 }
1134 pub fn shape(&self)->Shape{
1136 match self{B1(x)=>Shape::X1(x.dims()),B2(x)=>Shape::X2(x.dims()),B3(x)=>Shape::X3(x.dims()),B4(x)=>Shape::X4(x.dims()),B5(x)=>Shape::X5(x.dims()),B6(x)=>Shape::X6(x.dims()),B7(x)=>Shape::X7(x.dims()),B8(x)=>Shape::X8(x.dims()),F1(x)=>Shape::X1(x.dims()),F2(x)=>Shape::X2(x.dims()),F3(x)=>Shape::X3(x.dims()),F4(x)=>Shape::X4(x.dims()),F5(x)=>Shape::X5(x.dims()),F6(x)=>Shape::X6(x.dims()),F7(x)=>Shape::X7(x.dims()),F8(x)=>Shape::X8(x.dims()),I1(x)=>Shape::X1(x.dims()),I2(x)=>Shape::X2(x.dims()),I3(x)=>Shape::X3(x.dims()),I4(x)=>Shape::X4(x.dims()),I5(x)=>Shape::X5(x.dims()),I6(x)=>Shape::X6(x.dims()),I7(x)=>Shape::X7(x.dims()),I8(x)=>Shape::X8(x.dims()),Value::Incompatible(x)=>Shape::Incompatible(x.clone()),Value::Multi(x)=>Shape::Multi(x.len())}
1137 }
1138 pub fn shape_recursive(&self)->Shape{
1140 if let Value::Multi(x)=self{Shape::Recursive(x.iter().map(Value::shape_recursive).collect())}else{self.shape()}
1141 }
/// Shifts `self` by `n` positions along dimension `d` (negative `d` counts
/// from the back), filling the vacated slots with `v`. `n>0` shifts toward
/// higher indices, `n<0` toward lower; `n==0` is a no-op.
pub fn shift(self,d:i32,n:i32,v:f32)->Self{
    // Bool tensors shift via an int round-trip; any non-zero fill becomes `true`.
    fn b<B:Backend,const N:usize>(d:i32,n:i32,v:f32,x:Tensor<B,N,Bool>)->Value<B>{f(d,n,if v==0.0{0.0}else{1.0},x.int()).bool()}
    fn f<B:Backend,K:'static+BasicOps<B>+Numeric<B>+TensorKind<B>,const N:usize>(d:i32,n:i32,v:f32,x:Tensor<B,N,K>)->Value<B>{
        let device=x.device();
        // NOTE(review): unlike `split`, no range check here — `N-((-d) as usize)`
        // underflows (panics) when d < -N; confirm callers guarantee the range.
        let d=if d<0{N-((-d) as usize)}else{d as usize};
        let mut paddims=x.dims();
        let mut slicedims=paddims.map(|n|0..n);

        // The pad block is |n| wide along the shift dim; the slice keeps the
        // surviving part of `x` (front-trimmed for n<0, back-trimmed for n>0).
        paddims[d]=n.abs() as usize;
        slicedims[d]=if n<0{(-n) as usize..slicedims[d].end}else{0..slicedims[d].end.saturating_sub(n as usize)};
        // Shift distance >= dim size: the whole tensor becomes the fill value.
        if slicedims[d].len()==0{return x.full_like(v).into()}
        let pad:Tensor<B,N,K>=Tensor::full(paddims,v,&device);
        let slice=x.slice(slicedims);

        // Pad goes after the slice for n<0 (data moved down), before it for n>0.
        Tensor::cat(if n<0{vec![slice,pad]}else{vec![pad,slice]},d).into()
    }
    if n==0{return self}

    match self{B1(x)=>b(d,n,v,x),B2(x)=>b(d,n,v,x),B3(x)=>b(d,n,v,x),B4(x)=>b(d,n,v,x),B5(x)=>b(d,n,v,x),B6(x)=>b(d,n,v,x),B7(x)=>b(d,n,v,x),B8(x)=>b(d,n,v,x),F1(x)=>f(d,n,v,x),F2(x)=>f(d,n,v,x),F3(x)=>f(d,n,v,x),F4(x)=>f(d,n,v,x),F5(x)=>f(d,n,v,x),F6(x)=>f(d,n,v,x),F7(x)=>f(d,n,v,x),F8(x)=>f(d,n,v,x),I1(x)=>f(d,n,v,x),I2(x)=>f(d,n,v,x),I3(x)=>f(d,n,v,x),I4(x)=>f(d,n,v,x),I5(x)=>f(d,n,v,x),I6(x)=>f(d,n,v,x),I7(x)=>f(d,n,v,x),I8(x)=>f(d,n,v,x),Value::Incompatible(e)=>e.into(),Value::Multi(x)=>x.into_iter().map(|x|x.shift(d,n,v)).collect()}
}
1163 pub fn slice<A:AsRef<[R]>,R:RangeBounds<usize>>(self,ranges:A)->Self{
1165 let ranges=ranges.as_ref();
1166 let len=ranges.len();
1167 if let Value::Incompatible(x)=self{return x.into()}
1168 let rank=self.rank().unwrap_or(len);
1169 let shape=self.shape();
1170
1171 let mut normalizedranges=[0;8].map(|_|0..0);
1172 for ((d,n),r) in shape.clone().to_array(Alignment::Left).into_iter().zip(normalizedranges.iter_mut()).zip(ranges){
1173 n.start=match r.start_bound(){Excluded(&x)=>x+1,Included(&x)=>x,Unbounded=>0};
1174 n.end=match r.end_bound(){Excluded(&x)=>x,Included(&x)=>x+1,Unbounded=>d};
1175 }
1176 if len>rank{return format!("Length of ranges argument must be less than the the value's rank. len: {len} ranges: {normalizedranges:?} rank: {rank} shape: {shape:?}").into()}
1177 for (d,n) in shape.clone().to_array(Alignment::Left).into_iter().zip(normalizedranges.iter()).take(len){
1178 if n.start>=n.end{return format!("Empty or reverse ranges are currently not supported. ranges: {normalizedranges:?}").into()}
1179 if d<n.end{return format!("Cannot index beyond the end of a dimension. ranges: {normalizedranges:?} shape: {shape:?}").into()}
1180 }
1181 let ranges=&normalizedranges[..len];
1182
1183 match self{B1(x)=>B1(slice_slice(ranges,x)),B2(x)=>B2(slice_slice(ranges,x)),B3(x)=>B3(slice_slice(ranges,x)),B4(x)=>B4(slice_slice(ranges,x)),B5(x)=>B5(slice_slice(ranges,x)),B6(x)=>B6(slice_slice(ranges,x)),B7(x)=>B7(slice_slice(ranges,x)),B8(x)=>B8(slice_slice(ranges,x)),F1(x)=>F1(slice_slice(ranges,x)),F2(x)=>F2(slice_slice(ranges,x)),F3(x)=>F3(slice_slice(ranges,x)),F4(x)=>F4(slice_slice(ranges,x)),F5(x)=>F5(slice_slice(ranges,x)),F6(x)=>F6(slice_slice(ranges,x)),F7(x)=>F7(slice_slice(ranges,x)),F8(x)=>F8(slice_slice(ranges,x)),I1(x)=>I1(slice_slice(ranges,x)),I2(x)=>I2(slice_slice(ranges,x)),I3(x)=>I3(slice_slice(ranges,x)),I4(x)=>I4(slice_slice(ranges,x)),I5(x)=>I5(slice_slice(ranges,x)),I6(x)=>I6(slice_slice(ranges,x)),I7(x)=>I7(slice_slice(ranges,x)),I8(x)=>I8(slice_slice(ranges,x)),Value::Incompatible(x)=>x.into(),Value::Multi(x)=>Value::Multi(x.into_iter().map(|x|x.slice(ranges)).collect())}
1184 }
/// Splits along `dim` into chunks of `chunksize` (the last chunk may be
/// shorter), returning a `Multi` of the pieces. When `dim` is `None`, the
/// value's `Multi` elements themselves are grouped into chunks instead.
pub fn split<I:Into<Option<i32>>>(self,chunksize:usize,dim:I)->Self{
    // Validates the dimension against the tensor rank, then defers to burn's `split`.
    fn f<B:Backend,K:'static+BasicOps<B>+TensorKind<B>,const N:usize>(dim:i32,size:usize,tensor:Tensor<B,N,K>)->Value<B>{
        if dim>=N as i32||dim<(-(N as i32)){return format!("rank {N} is too low to split along dimension {dim}").into()}
        let dim=if dim<0{N-((-dim) as usize)}else{dim as usize};

        tensor.split(dim,size).into_iter().map(Value::from).collect()
    }
    // A chunk size of 0 would loop forever; reject it up front.
    let c=if chunksize==0{return "cannot split into chunks of 0 size".into()}else{chunksize};

    if let Some(d)=dim.into(){
        match self{B1(x)=>f(d,c,x),B2(x)=>f(d,c,x),B3(x)=>f(d,c,x),B4(x)=>f(d,c,x),B5(x)=>f(d,c,x),B6(x)=>f(d,c,x),B7(x)=>f(d,c,x),B8(x)=>f(d,c,x),F1(x)=>f(d,c,x),F2(x)=>f(d,c,x),F3(x)=>f(d,c,x),F4(x)=>f(d,c,x),F5(x)=>f(d,c,x),F6(x)=>f(d,c,x),F7(x)=>f(d,c,x),F8(x)=>f(d,c,x),I1(x)=>f(d,c,x),I2(x)=>f(d,c,x),I3(x)=>f(d,c,x),I4(x)=>f(d,c,x),I5(x)=>f(d,c,x),I6(x)=>f(d,c,x),I7(x)=>f(d,c,x),I8(x)=>f(d,c,x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.split(c,Some(d))).collect()}
    }else{
        // No dimension: group the `Multi` elements (a non-multi becomes a
        // single-element list first).
        let v=self.into_multi();
        v.chunks(chunksize).map(|c|Value::from(c.to_vec())).collect()
    }
}
1202 pub fn squeeze(self)->Self{self.squeeze_dim(0)}
/// Removes dimension `d` (negative counts from the back), reducing the rank
/// by one. Yields `Incompatible` when the dim is out of range, its size is
/// not 1, or the rank is already 1.
pub fn squeeze_dim(self,d:i32)->Self{
    // Checks bounds and that the target dim has size 1 before squeezing D -> N (= D-1).
    fn f<B:Backend,K:BasicOps<B>+TensorKind<B>,const D:usize,const N:usize>(x:Tensor<B,D,K>,d:i32)->Result<Tensor<B,N,K>,String>{
        let d=if d<0{D-((-d) as usize)}else{d as usize};

        if d>=D{return Err(format!("dim {d} must be less than {D}"))}
        let xdim=x.dims()[d];

        if xdim==1{Ok(x.squeeze(d))}else{Err(format!("cannot squeeze a dim of size not equal to 1. dim {d} was {xdim}"))}
    }
    // Each rank-R variant maps to the rank R-1 constructor; rank-1 tensors cannot
    // shrink further. The outer match collapses Err into `Incompatible`.
    match match self{B1(_x)=>Err("currently cannot decrease the number of tensor dimensions below 1".into()),B2(x)=>f(x,d).map(B1),B3(x)=>f(x,d).map(B2),B4(x)=>f(x,d).map(B3),B5(x)=>f(x,d).map(B4),B6(x)=>f(x,d).map(B5),B7(x)=>f(x,d).map(B6),B8(x)=>f(x,d).map(B7),F1(_x)=>Err("currently cannot decrease the number of tensor dimensions below 1".into()),F2(x)=>f(x,d).map(F1),F3(x)=>f(x,d).map(F2),F4(x)=>f(x,d).map(F3),F5(x)=>f(x,d).map(F4),F6(x)=>f(x,d).map(F5),F7(x)=>f(x,d).map(F6),F8(x)=>f(x,d).map(F7),I1(_x)=>Err("currently cannot decrease the number of tensor dimensions below 1".into()),I2(x)=>f(x,d).map(I1),I3(x)=>f(x,d).map(I2),I4(x)=>f(x,d).map(I3),I5(x)=>f(x,d).map(I4),I6(x)=>f(x,d).map(I5),I7(x)=>f(x,d).map(I6),I8(x)=>f(x,d).map(I7),Value::Incompatible(e)=>Err(e),Value::Multi(v)=>Ok(v.into_iter().map(|x|x.squeeze_dim(d)).collect())}{Err(e)=>e.into(),Ok(x)=>x}
}
1216 pub fn try_incompatible(self)->Result<String,Self>{
1218 if let Value::Incompatible(x)=self{Ok(x)}else{Err(self)}
1219 }
1220 pub fn try_multi(self)->Result<Vec<Value<B>>,Self>{
1222 if let Value::Multi(v)=self{Ok(v)}else{Err(self)}
1223 }
1224 pub fn unsqueeze(self)->Self{self.unsqueeze_dim(0)}
/// Inserts a size-1 dimension at `d` (negative counts from the back, so -1
/// appends after the last dim), increasing the rank by one. Rank-8 values
/// yield `Incompatible`.
pub fn unsqueeze_dim(self,d:i32)->Self{
    fn f<B:Backend,K:BasicOps<B>+TensorKind<B>,const D:usize,const N:usize>(x:Tensor<B,D,K>,d:i32)->Tensor<B,N,K>{
        x.unsqueeze_dim(if d<0{D-((-d) as usize)+1}else{d as usize})
    }
    if let Some(r)=self.rank(){
        // NOTE(review): `r-((-d) as usize)+1` underflows (panics) for d < -(r+1)
        // before the friendly error below can trigger — confirm acceptable.
        let e=if d<0{r-((-d) as usize)+1}else{d as usize};
        if e>r{return format!("dim {e} must be less than or equal to rank {r}").into()}
    }
    match self{B1(x)=>B2(f(x,d)),B2(x)=>B3(f(x,d)),B3(x)=>B4(f(x,d)),B4(x)=>B5(f(x,d)),B5(x)=>B6(f(x,d)),B6(x)=>B7(f(x,d)),B7(x)=>B8(f(x,d)),B8(_x)=>"currently can't increase number of tensor dimensions above 8".into(),F1(x)=>F2(f(x,d)),F2(x)=>F3(f(x,d)),F3(x)=>F4(f(x,d)),F4(x)=>F5(f(x,d)),F5(x)=>F6(f(x,d)),F6(x)=>F7(f(x,d)),F7(x)=>F8(f(x,d)),F8(_x)=>"currently can't increase number of tensor dimensions above 8".into(),I1(x)=>I2(f(x,d)),I2(x)=>I3(f(x,d)),I3(x)=>I4(f(x,d)),I4(x)=>I5(f(x,d)),I5(x)=>I6(f(x,d)),I6(x)=>I7(f(x,d)),I7(x)=>I8(f(x,d)),I8(_x)=>"currently can't increase number of tensor dimensions above 8".into(),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.unsqueeze_dim(d)).collect()}
}
1237 #[track_caller]
1238 pub fn unwrap_incompatible(self)->String{self.try_incompatible().unwrap()}
1240 #[track_caller]
1241 pub fn unwrap_multi(self)->Vec<Value<B>>{self.try_multi().unwrap()}
/// A value of the same shape and kind as `self` filled with zeros (`false`
/// for bool tensors, which round-trip through int because the bool kind has
/// no `zeros_like`). `Incompatible` propagates; `Multi` recurses.
pub fn zeros_like(&self)->Value<B>{match self{B1(x)=>B1(x.clone().int().zeros_like().bool()),B2(x)=>B2(x.clone().int().zeros_like().bool()),B3(x)=>B3(x.clone().int().zeros_like().bool()),B4(x)=>B4(x.clone().int().zeros_like().bool()),B5(x)=>B5(x.clone().int().zeros_like().bool()),B6(x)=>B6(x.clone().int().zeros_like().bool()),B7(x)=>B7(x.clone().int().zeros_like().bool()),B8(x)=>B8(x.clone().int().zeros_like().bool()),F1(x)=>F1(x.zeros_like()),F2(x)=>F2(x.zeros_like()),F3(x)=>F3(x.zeros_like()),F4(x)=>F4(x.zeros_like()),F5(x)=>F5(x.zeros_like()),F6(x)=>F6(x.zeros_like()),F7(x)=>F7(x.zeros_like()),F8(x)=>F8(x.zeros_like()),I1(x)=>I1(x.zeros_like()),I2(x)=>I2(x.zeros_like()),I3(x)=>I3(x.zeros_like()),I4(x)=>I4(x.zeros_like()),I5(x)=>I5(x.zeros_like()),I6(x)=>I6(x.zeros_like()),I7(x)=>I7(x.zeros_like()),I8(x)=>I8(x.zeros_like()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.iter().map(Value::zeros_like).collect()}
}
/// Transposes nested `Multi` values: result row `r` collects element `r`
/// from every child, then zips each new row recursively. Rows are as long
/// as the longest child; shorter children are padded with the default
/// value. Returns `self` unchanged when it has at most one element or no
/// child has more than one.
pub fn zip(self)->Self{if self.len()<=1||self.iter().all(|v|v.len()<=1){return self}

    // One iterator per child ("column"); rows are built by pulling one
    // element from each column in turn.
    let mut iters:Vec<_>=self.into_iter().map(Value::into_iter).collect();

    let cols=iters.len();
    let rows=iters.iter().map(ExactSizeIterator::len).max().expect("should not be empty at this point");
    let transposed:Vec<Value<B>>=(0..rows).map(|_|{
        // Exhausted columns contribute `Value::default()` (ragged padding).
        let v:Value<B>=(0..cols).map(|c|iters[c].next().unwrap_or_default()).collect();
        v.zip()
    }).collect();

    Value::Multi(transposed)
}
// Typed accessors generated by `try_unwrap!` for every kind/rank combination:
// `try_*` returns `Result<tensor, Self>`, `unwrap_*` panics on a mismatch.
try_unwrap!(Tensor<B,1,Bool>,try_b1,unwrap_b1);
try_unwrap!(Tensor<B,2,Bool>,try_b2,unwrap_b2);
try_unwrap!(Tensor<B,3,Bool>,try_b3,unwrap_b3);
try_unwrap!(Tensor<B,4,Bool>,try_b4,unwrap_b4);
try_unwrap!(Tensor<B,5,Bool>,try_b5,unwrap_b5);
try_unwrap!(Tensor<B,6,Bool>,try_b6,unwrap_b6);
try_unwrap!(Tensor<B,7,Bool>,try_b7,unwrap_b7);
try_unwrap!(Tensor<B,8,Bool>,try_b8,unwrap_b8);
try_unwrap!(Tensor<B,1,Float>,try_f1,unwrap_f1);
try_unwrap!(Tensor<B,2,Float>,try_f2,unwrap_f2);
try_unwrap!(Tensor<B,3,Float>,try_f3,unwrap_f3);
try_unwrap!(Tensor<B,4,Float>,try_f4,unwrap_f4);
try_unwrap!(Tensor<B,5,Float>,try_f5,unwrap_f5);
try_unwrap!(Tensor<B,6,Float>,try_f6,unwrap_f6);
try_unwrap!(Tensor<B,7,Float>,try_f7,unwrap_f7);
try_unwrap!(Tensor<B,8,Float>,try_f8,unwrap_f8);
try_unwrap!(Tensor<B,1,Int>,try_i1,unwrap_i1);
try_unwrap!(Tensor<B,2,Int>,try_i2,unwrap_i2);
try_unwrap!(Tensor<B,3,Int>,try_i3,unwrap_i3);
try_unwrap!(Tensor<B,4,Int>,try_i4,unwrap_i4);
try_unwrap!(Tensor<B,5,Int>,try_i5,unwrap_i5);
try_unwrap!(Tensor<B,6,Int>,try_i6,unwrap_i6);
try_unwrap!(Tensor<B,7,Int>,try_i7,unwrap_i7);
try_unwrap!(Tensor<B,8,Int>,try_i8,unwrap_i8);
1286}
/// Implements a numeric binary operator (`Add`/`Div`/`Mul`/`Rem`/`Sub`) for
/// `Value` in all owned/borrowed combinations, plus a scalar right-hand side.
///
/// Bool tensors are promoted to int before the op; mixed kinds/ranks go
/// through [`Value::promote`]; `Incompatible` propagates and `Multi`
/// broadcasts element-wise via `broadcast_multi`.
macro_rules! bicop_num{
    ($trait:ident,$traitfn:ident,$traitscalar:ident)=>(
        // `&Value op scalar` — delegates to the owned impl via clone.
        impl<B:Backend,E:Copy+ElementConversion> $trait<E> for &Value<B>{
            fn $traitfn(self,rhs:E)->Value<B>{self.clone().$traitfn(rhs)}
            type Output=Value<B>;
        }
        // `Value op scalar` — bool tensors are lifted to int first.
        impl<B:Backend,E:Copy+ElementConversion> $trait<E> for Value<B>{
            fn $traitfn(self,rhs:E)->Value<B>{
                match self{B1(x)=>I1(x.int().$traitscalar(rhs)),B2(x)=>I2(x.int().$traitscalar(rhs)),B3(x)=>I3(x.int().$traitscalar(rhs)),B4(x)=>I4(x.int().$traitscalar(rhs)),B5(x)=>I5(x.int().$traitscalar(rhs)),B6(x)=>I6(x.int().$traitscalar(rhs)),B7(x)=>I7(x.int().$traitscalar(rhs)),B8(x)=>I8(x.int().$traitscalar(rhs)),F1(x)=>F1(x.$traitscalar(rhs)),F2(x)=>F2(x.$traitscalar(rhs)),F3(x)=>F3(x.$traitscalar(rhs)),F4(x)=>F4(x.$traitscalar(rhs)),F5(x)=>F5(x.$traitscalar(rhs)),F6(x)=>F6(x.$traitscalar(rhs)),F7(x)=>F7(x.$traitscalar(rhs)),F8(x)=>F8(x.$traitscalar(rhs)),I1(x)=>I1(x.$traitscalar(rhs)),I2(x)=>I2(x.$traitscalar(rhs)),I3(x)=>I3(x.$traitscalar(rhs)),I4(x)=>I4(x.$traitscalar(rhs)),I5(x)=>I5(x.$traitscalar(rhs)),I6(x)=>I6(x.$traitscalar(rhs)),I7(x)=>I7(x.$traitscalar(rhs)),I8(x)=>I8(x.$traitscalar(rhs)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.$traitfn(rhs)).collect()}
            }
            type Output=Value<B>;
        }
        // `&Value op &Value`.
        impl<B:Backend> $trait<&Value<B>> for &Value<B>{
            fn $traitfn(self,rhs:&Value<B>)->Value<B>{self.clone().$traitfn(rhs.clone())}
            type Output=Value<B>;
        }
        // `Value op &Value`.
        impl<B:Backend> $trait<&Value<B>> for Value<B>{
            fn $traitfn(self,rhs:&Value<B>)->Value<B>{self.$traitfn(rhs.clone())}
            type Output=Value<B>;
        }
        // `&Value op Value`.
        impl<B:Backend> $trait<Value<B>> for &Value<B>{
            fn $traitfn(self,rhs:Value<B>)->Value<B>{self.clone().$traitfn(rhs)}
            type Output=Value<B>;
        }
        // `Value op Value` — promotes kind and rank, then applies element-wise.
        impl<B:Backend> $trait<Value<B>> for Value<B>{
            fn $traitfn(self,rhs:Value<B>)->Value<B>{match self.promote(rhs){(B1(l),B1(r))=>I1(l.int().$traitfn(r.int())),(B2(l),B2(r))=>I2(l.int().$traitfn(r.int())),(B3(l),B3(r))=>I3(l.int().$traitfn(r.int())),(B4(l),B4(r))=>I4(l.int().$traitfn(r.int())),(B5(l),B5(r))=>I5(l.int().$traitfn(r.int())),(B6(l),B6(r))=>I6(l.int().$traitfn(r.int())),(B7(l),B7(r))=>I7(l.int().$traitfn(r.int())),(B8(l),B8(r))=>I8(l.int().$traitfn(r.int())),(F1(l),F1(r))=>F1(l.$traitfn(r)),(F2(l),F2(r))=>F2(l.$traitfn(r)),(F3(l),F3(r))=>F3(l.$traitfn(r)),(F4(l),F4(r))=>F4(l.$traitfn(r)),(F5(l),F5(r))=>F5(l.$traitfn(r)),(F6(l),F6(r))=>F6(l.$traitfn(r)),(F7(l),F7(r))=>F7(l.$traitfn(r)),(F8(l),F8(r))=>F8(l.$traitfn(r)),(I1(l),I1(r))=>I1(l.$traitfn(r)),(I2(l),I2(r))=>I2(l.$traitfn(r)),(I3(l),I3(r))=>I3(l.$traitfn(r)),(I4(l),I4(r))=>I4(l.$traitfn(r)),(I5(l),I5(r))=>I5(l.$traitfn(r)),(I6(l),I6(r))=>I6(l.$traitfn(r)),(I7(l),I7(r))=>I7(l.$traitfn(r)),(I8(l),I8(r))=>I8(l.$traitfn(r)),(Value::Incompatible(e),_)=>e.into(),(_,Value::Incompatible(e))=>e.into(),(Value::Multi(l),r)=>broadcast_multi(l,r.into_multi(),$trait::$traitfn),(l,Value::Multi(r))=>broadcast_multi(l.into_multi(),r,$trait::$traitfn),_=>panic!("couldn't promote types for {}",stringify!($traitfn))}
            // Fix: macro metavariables are not substituted inside string
            // literals, so the old `panic!("... for $traitfn")` printed the
            // literal text `$traitfn`; `stringify!` inserts the real name.
            }
            type Output=Value<B>;
        }
    );
}
/// Generates a `try_*` (non-panicking, `TryInto`-based) and an `unwrap_*`
/// (panicking) accessor for one concrete tensor type; expanded inside
/// `impl Value<B>`.
macro_rules! try_unwrap{
    ($tensor:ty,$try_unwrap:ident,$unwrap:ident)=>{
        pub fn $try_unwrap(self)->Result<$tensor,Self>{self.try_into()}
        #[track_caller]
        pub fn $unwrap(self)->$tensor{self.try_into().unwrap()}
    }
}
1328
/// The broad category of a [`Value`], ignoring tensor rank.
#[derive(Clone,Copy,Debug,Eq,PartialEq,Deserialize,Serialize)]
pub enum Kind{Bool,Float,Incompatible,Int,Multi}
/// The shape of a [`Value`]: a dims array per rank (`X1`..`X8`), a flat
/// element count for `Multi`, a nested list for recursively-shaped multis
/// (`Recursive`), or an error string (`Incompatible`).
#[derive(Clone,Debug,Deserialize,Serialize)]pub enum Shape{Incompatible(String),Multi(usize),Recursive(Vec<Shape>),X1([usize;1]),X2([usize;2]),X3([usize;3]),X4([usize;4]),X5([usize;5]),X6([usize;6]),X7([usize;7]),X8([usize;8])}
/// A dynamically-typed burn tensor: bool (`B*`), float (`F*`) or int (`I*`)
/// of rank 1..=8, a list of values (`Multi`), or an error carried as a
/// string (`Incompatible`) so failures propagate through value pipelines
/// instead of panicking.
#[derive(Clone,Debug)]
pub enum Value<B:Backend>{B1(Tensor<B,1,Bool>),B2(Tensor<B,2,Bool>),B3(Tensor<B,3,Bool>),B4(Tensor<B,4,Bool>),B5(Tensor<B,5,Bool>),B6(Tensor<B,6,Bool>),B7(Tensor<B,7,Bool>),B8(Tensor<B,8,Bool>),F1(Tensor<B,1,Float>),F2(Tensor<B,2,Float>),F3(Tensor<B,3,Float>),F4(Tensor<B,4,Float>),F5(Tensor<B,5,Float>),F6(Tensor<B,6,Float>),F7(Tensor<B,7,Float>),F8(Tensor<B,8,Float>),I1(Tensor<B,1,Int>),I2(Tensor<B,2,Int>),I3(Tensor<B,3,Int>),I4(Tensor<B,4,Int>),I5(Tensor<B,5,Int>),I6(Tensor<B,6,Int>),I7(Tensor<B,7,Int>),I8(Tensor<B,8,Int>),Incompatible(String),Multi(Vec<Self>)}
/// Backend-independent, serializable form of [`Value`]; dims live inside the
/// `TensorData`, so rank is not encoded in the variant name.
#[derive(Clone,Debug,Deserialize,Serialize)]
pub enum ValueData{BX(TensorData),FX(TensorData),IX(TensorData),Incompatible(String),Multi(Vec<ValueData>)}
/// Bundles a loss value with the model output and target that produced it.
/// `#[serde(bound="")]` drops the implied `B: (De)Serialize` bound since the
/// backend itself is never serialized.
#[derive(Clone,Debug,Deserialize,Serialize)]
#[serde(bound="")]
pub struct LossOutput<B:Backend>{loss:Value<B>,output:Value<B>,target:Value<B>}
1345use {bicop_num,try_unwrap};
1346use Bound::{Excluded,Included,Unbounded};
1347use Shape::{X1,X2,X3,X4,X5,X6,X7,X8};
1348use Value::{B1,B2,B3,B4,B5,B6,B7,B8,F1,F2,F3,F4,F5,F6,F7,F8,I1,I2,I3,I4,I5,I6,I7,I8};
1349use ValueData::{BX,FX,IX};
1350use burn::{
1351 module::{AutodiffModule,ConstantRecord,Content,DisplaySettings,ModuleDisplay,ModuleDisplayDefault,ModuleMapper,ModuleVisitor,Quantizer},
1352 nn::{
1353 BatchNorm,Dropout,Embedding,LayerNorm,Linear,Relu,RotaryEncoding,Tanh,conv::Conv2d,loss::{CrossEntropyLoss,MseLoss},pool::MaxPool2d
1354 },
1355 prelude::{Backend,Bool,Float,Int,Module,Tensor,TensorData},
1356 record::{FileRecorder,RecorderError},
1357 tensor::{
1358 BasicOps,ElementConversion,Numeric,TensorKind,activation::{log_softmax,softmax},backend::AutodiffBackend,cast::ToElement
1359 }
1360};
1361use crate::{
1362 AI,Decompose,Merge,Op,
1363 builtin::{
1364 Alignment,ReductionMode,math::{MeanLayer,SquaredErrorLayer,SumLayer},reinforcement::AccQLayer,soft::{ChooseLayer,CrossEntropyLayer,SoftmaxLayer}
1365 },
1366 ops::{Abs,Cat,Stack,Squeeze,Unsqueeze}
1367};
1368use rand::random;
1369use serde::{Deserialize,Deserializer,Serialize,Serializer};
1370use std::{
1371 any::TypeId,fmt::{Display,Result as FmtResult},iter::{FromIterator,once},mem,ops::{Add,Bound,Div,Mul,RangeBounds,Range,Rem,Sub},path::PathBuf,slice::{Iter as SliceIter,self},vec::IntoIter as VecIntoIter
1372};