// Binary arithmetic operators for `Value`: each invocation implements the
// named `std::ops` trait (tensor ⊕ tensor) together with its scalar variant,
// routed through the given burn tensor method (e.g. `add_scalar`).
bicop_num!(Add,add,add_scalar);
bicop_num!(Div,div,div_scalar);
bicop_num!(Mul,mul,mul_scalar);
bicop_num!(Rem,rem,remainder_scalar);
bicop_num!(Sub,sub,sub_scalar);
6fn broadcast_multi<B:Backend,F:FnMut(Value<B>,Value<B>)->Value<B>>(u:Vec<Value<B>>,v:Vec<Value<B>>,mut f:F)->Value<B>{
7 if u.len()==1{
8 u.into_iter().cycle().zip(v).map(|(x,y)|f(x,y)).collect()
9 }else if v.len()==1{
10 u.into_iter().zip(v.into_iter().cycle()).map(|(x,y)|f(x,y)).collect()
11 }else if u.len()==v.len(){
12 u.into_iter().zip(v).map(|(x,y)|f(x,y)).collect()
13 }else{
14 "mismatched lengths".into()
15 }
16}
17fn hard_choose_burn_1<B:Backend,const N:usize>(dim:i32,distribution:Tensor<B,N>)->u32{
18 let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
19 let distribution=if dim==N-1{distribution}else{distribution.movedim(dim,N-1)}.into_data();
20 let sum=distribution.iter().fold(0.0,|acc:f32,weight:f32|acc+weight);
21
22 distribution.iter().scan(random::<f32>()*sum,|choice:&mut f32,weight:f32|Some(*choice-=weight).filter(|_|*choice>=0.0)).count() as u32
23}
24fn hard_choose_burn_multi<B:Backend,const N:usize>(dim:i32,distribution:Tensor<B,N>)->Vec<u32>{
25 let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
26
27 let chunk=distribution.dims()[dim];
28 let distribution=if dim==N-1{distribution}else{distribution.movedim(dim,N-1)}.into_data().to_vec().unwrap();
29
30 distribution.chunks_exact(chunk).map(|d|{
31 let sum=d.iter().fold(0.0,|acc:f32,weight:&f32|acc+weight);
32 d.iter().scan(random::<f32>()*sum,|choice:&mut f32,weight:&f32|Some(*choice-=weight).filter(|_|*choice>=0.0)).count() as u32
33 }).collect()
34}
// Samples one index per slice along `dim` and returns an int tensor with the
// sampled dimension collapsed to extent 1 (all other dims preserved).
fn hard_choose_burn_tensor<B:Backend,const N:usize>(dim:i32,distribution:Tensor<B,N>)->Tensor<B,N,Int>{
    let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
    let device=distribution.device();
    let mut dims=distribution.dims();

    // One sampled index remains where the distribution dimension was.
    dims[N-1]=1;
    // `hard_choose_burn_multi` samples with the chosen dim moved last, so the
    // result is built with the last dim collapsed and moved back afterwards.
    let r:Tensor<B,N,Int>=Tensor::from_data(TensorData::new(hard_choose_burn_multi(dim as i32,distribution),dims),&device);

    r.movedim(N-1,dim)
}
45fn slice_slice<B:Backend,K:BasicOps<B>+TensorKind<B>,const N:usize>(ranges:&[Range<usize>],tensor:Tensor<B,N,K>)->Tensor<B,N,K>{
46 let mut n=0;
47 let mut acc=||{
48 let a=n;
49 n+=1;
50 a
51 };
52
53 match ranges.len(){0=>tensor,1=>tensor.slice([0;1].map(|_|ranges[acc()].clone())),2=>tensor.slice([0;2].map(|_|ranges[acc()].clone())),3=>tensor.slice([0;3].map(|_|ranges[acc()].clone())),4=>tensor.slice([0;4].map(|_|ranges[acc()].clone())),5=>tensor.slice([0;5].map(|_|ranges[acc()].clone())),6=>tensor.slice([0;6].map(|_|ranges[acc()].clone())),7=>tensor.slice([0;7].map(|_|ranges[acc()].clone())),8=>tensor.slice([0;8].map(|_|ranges[acc()].clone())),_=>panic!("too many ranges for current max 8 dims")}
54}
// Samples a single index from the temperature-scaled softmax of `logits`
// along `dim` (negative dims count from the back).
fn soft_choose_burn_1<B:Backend,const N:usize>(dim:i32,logits:Tensor<B,N>,temperature:f32)->u32{
    let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
    // Move the chosen dim last so one distribution is contiguous in the data.
    let logits=if dim==N-1{logits}else{logits.movedim(dim,N-1)};

    let chunk=logits.dims()[N-1];
    let distribution=softmax(logits/temperature,N-1).into_data();
    // Inverse-CDF sampling: subtract softmax weights from a uniform draw in
    // [0,1); the count of weights consumed is the sampled index, clamped to
    // chunk-1 to guard against floating-point rounding.
    distribution.iter().scan(random(),|choice:&mut f32,weight:f32|Some(*choice-=weight).filter(|_|*choice>=0.0)).count().min(chunk-1) as u32
}
// Samples one index per slice along `dim` from the temperature-scaled softmax
// of `logits`, returning one chosen index per slice.
fn soft_choose_burn_multi<B:Backend,const N:usize>(dim:i32,logits:Tensor<B,N>,temperature:f32)->Vec<u32>{
    let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
    // Move the chosen dim last so each distribution is one contiguous chunk.
    let logits=if dim==N-1{logits}else{logits.movedim(dim,N-1)};
    let chunk=logits.dims()[N-1];
    let distribution=softmax(logits/temperature,N-1).into_data().to_vec().unwrap();
    // Inverse-CDF sampling per chunk; `.min(chunk-1)` guards against
    // floating-point rounding walking past the final weight.
    distribution.chunks_exact(chunk).map(|d|d.iter().scan(random(),|choice:&mut f32,weight:&f32|Some(*choice-=weight).filter(|_|*choice>=0.0)).count().min(chunk-1) as u32).collect()
}
// Samples along `dim` from the temperature-scaled softmax for every slice and
// rebuilds an int tensor with the sampled dimension collapsed to extent 1.
fn soft_choose_burn_tensor<B:Backend,const N:usize>(dim:i32,logits:Tensor<B,N>,temperature:f32)->Tensor<B,N,Int>{
    let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
    let device=logits.device();
    let mut dims=logits.dims();

    // One sampled index remains where the distribution dimension was.
    dims[N-1]=1;
    // The multi sampler works with `dim` moved last; build the result in that
    // layout, then move the collapsed dimension back into place.
    let r:Tensor<B,N,Int>=Tensor::from_data(TensorData::new(soft_choose_burn_multi(dim as i32,logits,temperature),dims),&device);
    r.movedim(N-1,dim)
}
79impl<'a,B:Backend> Deserialize<'a> for Value<B>{
80 fn deserialize<D:Deserializer<'a>>(deserializer:D)->Result<Self,D::Error>{ValueData::deserialize(deserializer).map(Into::into)}
81}
impl<A:AutodiffBackend> AutodiffModule<A> for Value<A>{
    /// Maps `valid()` over every tensor variant, converting the autodiff-backend
    /// value into its inner-backend counterpart. `Incompatible` carries its
    /// message through; `Multi` recurses elementwise.
    fn valid(&self)->Self::InnerModule{
        match self{B1(x)=>B1(x.valid()),B2(x)=>B2(x.valid()),B3(x)=>B3(x.valid()),B4(x)=>B4(x.valid()),B5(x)=>B5(x.valid()),B6(x)=>B6(x.valid()),B7(x)=>B7(x.valid()),B8(x)=>B8(x.valid()),F1(x)=>F1(x.valid()),F2(x)=>F2(x.valid()),F3(x)=>F3(x.valid()),F4(x)=>F4(x.valid()),F5(x)=>F5(x.valid()),F6(x)=>F6(x.valid()),F7(x)=>F7(x.valid()),F8(x)=>F8(x.valid()),I1(x)=>I1(x.valid()),I2(x)=>I2(x.valid()),I3(x)=>I3(x.valid()),I4(x)=>I4(x.valid()),I5(x)=>I5(x.valid()),I6(x)=>I6(x.valid()),I7(x)=>I7(x.valid()),I8(x)=>I8(x.valid()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.iter().map(|x|x.valid()).collect()}
    }
    type InnerModule=Value<A::InnerBackend>;
}
88impl<A:Into<Value<B>>,B:Backend> FromIterator<A> for Value<B>{
89 fn from_iter<I:IntoIterator<Item=A>>(iter:I)->Self{Value::Multi(iter.into_iter().map(Into::into).collect())}
90}
91impl<B:Backend,K:'static+TensorKind<B>,const N:usize> From<Result<Tensor<B,N,K>,String>> for Value<B>{
92 fn from(value:Result<Tensor<B,N,K>,String>)->Self{
93 match value{Err(e)=>e.into(),Ok(t)=>t.into()}
94 }
95}
impl<B:Backend,K:'static+TensorKind<B>,const N:usize> From<Tensor<B,N,K>> for Value<B>{
    /// Wraps a tensor of any supported kind (Bool/Float/Int) and rank (1-8)
    /// into the matching `Value` variant; anything else becomes `Incompatible`.
    fn from(value:Tensor<B,N,K>)->Self{
        // Classify the generic kind parameter at runtime via `TypeId`.
        let kind=TypeId::of::<K>();
        let kind=if kind==TypeId::of::<Bool>(){Kind::Bool}else if kind==TypeId::of::<Float>(){Kind::Float}else if kind==TypeId::of::<Int>(){Kind::Int}else{return "only bool, float, and int tensors with dimensions 1-8 are currently supported".into()};

        // SAFETY: each arm is only reached when the runtime `TypeId` of `K` and
        // the rank `N` prove that `Tensor<B,N,K>` is exactly the variant's
        // concrete tensor type, so `transmute_copy` reinterprets identical
        // types. The source is `mem::forget`-ten below so ownership is moved,
        // not duplicated.
        let v=unsafe{
            match (N,kind){(1,Kind::Bool)=>B1(mem::transmute_copy(&value)),(2,Kind::Bool)=>B2(mem::transmute_copy(&value)),(3,Kind::Bool)=>B3(mem::transmute_copy(&value)),(4,Kind::Bool)=>B4(mem::transmute_copy(&value)),(5,Kind::Bool)=>B5(mem::transmute_copy(&value)),(6,Kind::Bool)=>B6(mem::transmute_copy(&value)),(7,Kind::Bool)=>B7(mem::transmute_copy(&value)),(8,Kind::Bool)=>B8(mem::transmute_copy(&value)),(1,Kind::Float)=>F1(mem::transmute_copy(&value)),(2,Kind::Float)=>F2(mem::transmute_copy(&value)),(3,Kind::Float)=>F3(mem::transmute_copy(&value)),(4,Kind::Float)=>F4(mem::transmute_copy(&value)),(5,Kind::Float)=>F5(mem::transmute_copy(&value)),(6,Kind::Float)=>F6(mem::transmute_copy(&value)),(7,Kind::Float)=>F7(mem::transmute_copy(&value)),(8,Kind::Float)=>F8(mem::transmute_copy(&value)),(1,Kind::Int)=>I1(mem::transmute_copy(&value)),(2,Kind::Int)=>I2(mem::transmute_copy(&value)),(3,Kind::Int)=>I3(mem::transmute_copy(&value)),(4,Kind::Int)=>I4(mem::transmute_copy(&value)),(5,Kind::Int)=>I5(mem::transmute_copy(&value)),(6,Kind::Int)=>I6(mem::transmute_copy(&value)),(7,Kind::Int)=>I7(mem::transmute_copy(&value)),(8,Kind::Int)=>I8(mem::transmute_copy(&value)),_=>return "only bool, float, and int tensors with dimensions 1-8 are currently supported".into()}
        };
        // The bits now live inside `v`; forget the source to avoid a double drop.
        mem::forget(value);
        v
    }
}
impl<B:Backend,K:'static+TensorKind<B>,const N:usize> TryFrom<Value<B>> for Tensor<B,N,K>{
    /// Extracts the concrete tensor when the value's kind and rank match the
    /// requested `Tensor<B,N,K>`; otherwise the value is returned unchanged
    /// as the error.
    fn try_from(value:Value<B>)->Result<Self,Self::Error>{
        // Classify the requested kind parameter at runtime via `TypeId`.
        let kind=TypeId::of::<K>();
        let kind=if kind==TypeId::of::<Bool>(){Kind::Bool}else if kind==TypeId::of::<Float>(){Kind::Float}else if kind==TypeId::of::<Int>(){Kind::Int}else{return Err(value)};

        // Rank and kind must both match the stored variant exactly.
        if Some(N)!=value.rank()||kind!=value.kind(){return Err(value)}
        // SAFETY: the rank/kind check above guarantees the matched variant's
        // inner tensor type is exactly `Tensor<B,N,K>`, so `transmute_copy`
        // reinterprets identical types. The source value is `mem::forget`-ten
        // below so ownership is moved, not duplicated.
        let r=unsafe{
            match &value{B1(x)=>mem::transmute_copy(x),B2(x)=>mem::transmute_copy(x),B3(x)=>mem::transmute_copy(x),B4(x)=>mem::transmute_copy(x),B5(x)=>mem::transmute_copy(x),B6(x)=>mem::transmute_copy(x),B7(x)=>mem::transmute_copy(x),B8(x)=>mem::transmute_copy(x),F1(x)=>mem::transmute_copy(x),F2(x)=>mem::transmute_copy(x),F3(x)=>mem::transmute_copy(x),F4(x)=>mem::transmute_copy(x),F5(x)=>mem::transmute_copy(x),F6(x)=>mem::transmute_copy(x),F7(x)=>mem::transmute_copy(x),F8(x)=>mem::transmute_copy(x),I1(x)=>mem::transmute_copy(x),I2(x)=>mem::transmute_copy(x),I3(x)=>mem::transmute_copy(x),I4(x)=>mem::transmute_copy(x),I5(x)=>mem::transmute_copy(x),I6(x)=>mem::transmute_copy(x),I7(x)=>mem::transmute_copy(x),I8(x)=>mem::transmute_copy(x),_=>panic!("internal error")}
        };
        mem::forget(value);
        Ok(r)
    }
    type Error=Value<B>;
}
impl<B:Backend,R:Clone+RangeBounds<isize>> Flatten<R> for Value<B>{
    /// Flattens the dimension range `args` into one dimension. Negative
    /// indices count from the back; `Multi` recurses elementwise.
    fn flatten(self,args:R)->Self::Output{
        // Resolves the bounds into a half-open dim range [a, b) for rank N,
        // then dispatches to the burn `flatten` with the resulting const rank.
        fn f<B:Backend,K:'static+BasicOps<B>+TensorKind<B>,R:RangeBounds<isize>,const N:usize>(args:R,x:Tensor<B,N,K>)->Value<B>{
            let a=match args.start_bound(){
                Excluded(&n)=>(if n<0{N-((-n) as usize)}else{n as usize})+1,
                Included(&n)=>if n<0{N-((-n) as usize)}else{n as usize},
                Unbounded=>0
            };
            let b=match args.end_bound(){
                // NOTE(review): an exclusive end of 0 is treated as "through
                // the last dim" (maps to N) — confirm this is intended.
                Excluded(&n)=>if n<0{N-((-n) as usize)}else if n==0{N}else{n as usize},
                Included(&n)=>(if n<0{N-((-n) as usize)}else{n as usize})+1,
                Unbounded=>N
            };
            let flattenedlen=b-a;

            // Output rank = N with `flattenedlen` dims merged into one.
            // `flatten`'s second index is inclusive, hence `b-1`.
            match N-flattenedlen+1{
                1=>x.flatten::<1>(a,b-1).into(),
                2=>x.flatten::<2>(a,b-1).into(),
                3=>x.flatten::<3>(a,b-1).into(),
                4=>x.flatten::<4>(a,b-1).into(),
                5=>x.flatten::<5>(a,b-1).into(),
                6=>x.flatten::<6>(a,b-1).into(),
                7=>x.flatten::<7>(a,b-1).into(),
                // Rank unchanged (a single dim "flattened"): identity.
                8=>x.into(),
                _=>"invalid flatten".into()
            }
        }

        match self{
            B1(x)=>f(args,x),
            B2(x)=>f(args,x),
            B3(x)=>f(args,x),
            B4(x)=>f(args,x),
            B5(x)=>f(args,x),
            B6(x)=>f(args,x),
            B7(x)=>f(args,x),
            B8(x)=>f(args,x),
            F1(x)=>f(args,x),
            F2(x)=>f(args,x),
            F3(x)=>f(args,x),
            F4(x)=>f(args,x),
            F5(x)=>f(args,x),
            F6(x)=>f(args,x),
            F7(x)=>f(args,x),
            F8(x)=>f(args,x),
            I1(x)=>f(args,x),
            I2(x)=>f(args,x),
            I3(x)=>f(args,x),
            I4(x)=>f(args,x),
            I5(x)=>f(args,x),
            I6(x)=>f(args,x),
            I7(x)=>f(args,x),
            I8(x)=>f(args,x),
            Value::Incompatible(e)=>e.into(),
            Value::Multi(v)=>v.into_iter().map(|x|x.flatten(args.clone())).collect(),
        }
    }
    type Output=Self;
}
impl<B:Backend,R:Into<Reshape>> OpsReshape<R> for Value<B>{
    /// Reshapes each tensor to the target shape carried by `args`. A
    /// `Reshape::Recursive` spec pairs its sub-specs with the elements of a
    /// `Multi` value at the matching depth.
    fn reshape(self,args:R)->Self::Output{
        // Applies a flat (non-recursive) reshape spec to a single tensor.
        fn f<B:Backend,K:'static+BasicOps<B>+TensorKind<B>,const N:usize>(args:Reshape,x:Tensor<B,N,K>)->Value<B>{
            match args{
                R1(r)=>x.reshape(r.map(|x|x as i32)).into(),
                R2(r)=>x.reshape(r.map(|x|x as i32)).into(),
                R3(r)=>x.reshape(r.map(|x|x as i32)).into(),
                R4(r)=>x.reshape(r.map(|x|x as i32)).into(),
                R5(r)=>x.reshape(r.map(|x|x as i32)).into(),
                R6(r)=>x.reshape(r.map(|x|x as i32)).into(),
                R7(r)=>x.reshape(r.map(|x|x as i32)).into(),
                R8(r)=>x.reshape(r.map(|x|x as i32)).into(),
                Reshape::Recursive(_v)=>"reshaping a single tensor into multiple tensors is currently not supported".into()
            }
        }

        let args=args.into();
        // The spec's nesting depth selects how far into nested `Multi`s the
        // reshape is applied.
        let depth=args.depth();

        apply_depthwise(depth,|value|{
            let args=args.clone();
            match value{
                B1(x)=>f(args,x),
                B2(x)=>f(args,x),
                B3(x)=>f(args,x),
                B4(x)=>f(args,x),
                B5(x)=>f(args,x),
                B6(x)=>f(args,x),
                B7(x)=>f(args,x),
                B8(x)=>f(args,x),
                F1(x)=>f(args,x),
                F2(x)=>f(args,x),
                F3(x)=>f(args,x),
                F4(x)=>f(args,x),
                F5(x)=>f(args,x),
                F6(x)=>f(args,x),
                F7(x)=>f(args,x),
                F8(x)=>f(args,x),
                I1(x)=>f(args,x),
                I2(x)=>f(args,x),
                I3(x)=>f(args,x),
                I4(x)=>f(args,x),
                I5(x)=>f(args,x),
                I6(x)=>f(args,x),
                I7(x)=>f(args,x),
                I8(x)=>f(args,x),
                Value::Incompatible(e)=>e.into(),
                // Recursive specs zip sub-specs against the Multi's elements.
                Value::Multi(v)=>if let Reshape::Recursive(r)=args{r.into_iter().zip(v).map(|(r,v)|v.reshape(r)).collect()}else{"reshaping multiple tensors into a single tensor is currently not supported".into()}
            }
        },self)
    }
    type Output=Self;
}
234impl<B:Backend,S:?Sized+AsRef<str>> From<&S> for Value<B>{
235 fn from(value:&S)->Self{Self::Incompatible(value.as_ref().to_string())}
236}
impl<B:Backend> AI<Value<B>,Value<B>> for BatchNorm<B>{
    /// Applies batch normalization to the float view of the value. Rank 1 and
    /// 2 inputs are unsqueezed up to rank 3 first (the layer's minimum), then
    /// squeezed back so the caller sees the original rank.
    fn forward(&self,input:Value<B>)->Value<B>{
        // Direct application for ranks the norm layer accepts as-is.
        fn f<B:Backend,const E:usize>(norm:&BatchNorm<B>,x:Tensor<B,E>)->Value<B>{norm.forward(x).into()}
        match input.float(){
            F1(x)=>AI::forward(self,F1(x).unsqueeze().unsqueeze()).squeeze().squeeze(),
            F2(x)=>AI::forward(self,F2(x).unsqueeze()).squeeze(),
            F3(x)=>f::<B,3>(self,x),
            F4(x)=>f::<B,4>(self,x),
            F5(x)=>f::<B,5>(self,x),
            F6(x)=>f::<B,6>(self,x),
            F7(x)=>f::<B,7>(self,x),
            F8(x)=>f::<B,8>(self,x),
            Value::Incompatible(e)=>e.into(),
            Value::Multi(v)=>v.into_iter().map(|x|AI::forward(self,x)).collect(),
            // `float()` only yields float variants, Incompatible, or Multi.
            _=>panic!("internal error")
        }
    }
}
255impl<B:Backend> AI<(Value<B>,Value<B>),Vec<f32>> for CrossEntropyLayer{
256 fn forward(&self,input:(Value<B>,Value<B>))->Vec<f32>{
257 let output:Value<B>=self.forward(input);
258 output.into_float_vec()
259 }
260}
261impl<B:Backend> AI<(Value<B>,Value<B>),LossOutput<B>> for CrossEntropyLayer{
262 fn forward(&self,(output,target):(Value<B>,Value<B>))->LossOutput<B>{
263 let loss=self.forward((output.clone(),target.clone()));
264 LossOutput::new(loss,output,target)
265 }
266}
impl<B:Backend> AI<(Value<B>,Value<B>),Value<B>> for CrossEntropyLayer{
    /// Per-element cross entropy. Float targets of the same shape are treated
    /// as dense (soft) distributions; int targets of one lower rank are class
    /// indices gathered along `dim`. A NaN temperature means the outputs are
    /// already probabilities (log directly instead of log-softmax).
    fn forward(&self,(output,target):(Value<B>,Value<B>))->Value<B>{
        // Dense targets: -t * log(softmax(y/T)) with matching shapes.
        fn ff<B:Backend,const N:usize>(dim:i32,y:Tensor<B,N>,t:Tensor<B,N>,temperature:f32)->Result<Tensor<B,N>,String>{
            let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
            let (ydims,tdims)=(y.dims(),t.dims());
            if ydims==tdims{
                let logy=if temperature.is_nan(){y.log()}else{log_softmax(y/temperature,dim)};
                Ok(logy*t.neg())
            }else{
                Err(format!("incompatible shapes to cross entropy. ydims: {ydims:?} tdims: {tdims:?}"))
            }
        }
        // Index targets: target dims must equal output dims with `dim`
        // removed; gather the log-probability of each target class.
        fn fi<B:Backend,const N:usize,const K:usize>(dim:i32,y:Tensor<B,N>,t:Tensor<B,K,Int>,temperature:f32)->Result<Tensor<B,K>,String>{
            let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
            let (ydims,tdims)=(y.dims(),t.dims());
            if ydims.iter().enumerate().filter_map(|(n,y)|(n!=dim).then_some(y)).eq(tdims.iter()){
                let logy=if temperature.is_nan(){y.log()}else{log_softmax(y/temperature,dim)};
                Ok(logy.gather(dim,t.unsqueeze_dim(dim)).neg().squeeze_dim(dim))
            }else{
                Err(format!("incompatible shapes to cross entropy along dimension {dim}. ydims: {ydims:?} tdims: {tdims:?}"))
            }
        }
        let (dim,temp)=(self.get_dim(),self.get_temperature());

        match match (output,target){
            (F1(y),F1(t))=>ff(dim,y,t,temp).map(Into::into),
            (F2(y),F2(t))=>ff(dim,y,t,temp).map(Into::into),
            (F3(y),F3(t))=>ff(dim,y,t,temp).map(Into::into),
            (F4(y),F4(t))=>ff(dim,y,t,temp).map(Into::into),
            (F5(y),F5(t))=>ff(dim,y,t,temp).map(Into::into),
            (F6(y),F6(t))=>ff(dim,y,t,temp).map(Into::into),
            (F7(y),F7(t))=>ff(dim,y,t,temp).map(Into::into),
            (F8(y),F8(t))=>ff(dim,y,t,temp).map(Into::into),
            // Rank-1 output with rank-1 index target: promote the output so
            // `fi`'s "one lower rank" relation holds.
            (F1(y),I1(t))=>fi(dim,y.unsqueeze::<2>(),t,temp).map(Into::into),
            (F2(y),I1(t))=>fi(dim,y,t,temp).map(Into::into),
            (F3(y),I2(t))=>fi(dim,y,t,temp).map(Into::into),
            (F4(y),I3(t))=>fi(dim,y,t,temp).map(Into::into),
            (F5(y),I4(t))=>fi(dim,y,t,temp).map(Into::into),
            (F6(y),I5(t))=>fi(dim,y,t,temp).map(Into::into),
            (F7(y),I6(t))=>fi(dim,y,t,temp).map(Into::into),
            (F8(y),I7(t))=>fi(dim,y,t,temp).map(Into::into),
            (Value::Incompatible(y),_)=>Err(y),
            (_,Value::Incompatible(t))=>Err(t),
            // Paired Multis recurse pairwise.
            (Value::Multi(y),Value::Multi(t))=>if y.len()==t.len(){Ok(Value::Multi(y.into_iter().zip(t).map(|x|self.forward_typed::<_,Value<B>>(x)).collect()))}else{Err("mismatched lengths".into())},
            _=>Err("incompatible".into())
        }{
            Err(e)=>Value::Incompatible(e),Ok(x)=>x
        }
    }
}
317impl<B:Backend> AI<(Value<B>,Value<B>),Value<B>> for CrossEntropyLoss<B>{
318 fn forward(&self,(output,target):(Value<B>,Value<B>))->Value<B>{
319 let mut op=().fix_type::<Value<B>>().cross_entropy(1.0);
320 if !self.logits{op.set_temperature(f32::NAN)}
321 op.forward((output,target))
322 }
323}
324impl<B:Backend> AI<(Value<B>,Value<B>),LossOutput<B>> for SquaredErrorLayer{
325 fn forward(&self,(output,target):(Value<B>,Value<B>))->LossOutput<B>{
326 let loss=self.forward((output.clone(),target.clone()));
327 LossOutput::new(loss,output,target)
328 }
329}
impl<B:Backend> AI<(Value<B>,Value<B>),Value<B>> for SquaredErrorLayer{
    /// Elementwise squared error between same-shape float tensors; `Multi`
    /// operands are broadcast against each other via `broadcast_multi`.
    fn forward(&self,(output,target):(Value<B>,Value<B>))->Value<B>{
        // Unreduced MSE requires identical shapes.
        fn f<B:Backend,const N:usize>(y:Tensor<B,N>,t:Tensor<B,N>)->Value<B>{
            if y.dims()==t.dims(){MseLoss.forward_no_reduction(y,t).into()}else{"compatible inputs for squared error are float tensors of the same shape".into()}
        }
        match (output.float(),target.float()){(F1(y),F1(t))=>f(y,t),(F2(y),F2(t))=>f(y,t),(F3(y),F3(t))=>f(y,t),(F4(y),F4(t))=>f(y,t),(F5(y),F5(t))=>f(y,t),(F6(y),F6(t))=>f(y,t),(F7(y),F7(t))=>f(y,t),(F8(y),F8(t))=>f(y,t),(Value::Incompatible(y),_)=>y.into(),(_,Value::Incompatible(t))=>t.into(),(Value::Multi(y),t)=>broadcast_multi(y,t.into_multi(),|y,t|self.forward((y,t))),(y,Value::Multi(t))=>broadcast_multi(y.into_multi(),t,|y,t|self.forward((y,t))),_=>"compatible inputs for squared error are float tensors of the same shape".into()}
    }
}
338impl<B:Backend> AI<(Value<B>,Value<B>),Vec<f32>> for SquaredErrorLayer{
339 fn forward(&self,(output,target):(Value<B>,Value<B>))->Vec<f32>{
340 let error:Value<B>=self.forward((output,target));
341 error.into_float_vec()
342 }
343}
344impl<B:Backend> AI<(Value<B>,Value<B>),f32> for SquaredErrorLayer{
345 fn forward(&self,(output,target):(Value<B>,Value<B>))->f32{().fix_type::<Value<B>>().squared_error().mean().forward((output,target))}
346}
impl<B:Backend> AI<Value<B>,Tensor<B,1>> for MeanLayer{
    /// Reduces the whole value to a single-element tensor holding its mean.
    /// Note: a `Multi` is reduced as the mean of its elements' means, each
    /// element weighted equally regardless of its size.
    fn forward(&self,input:Value<B>)->Tensor<B,1>{
        fn avg<B:Backend,const N:usize>(x:Tensor<B,N>)->Tensor<B,1>{x.mean()}
        let l=input.len();

        // An empty value has no mean; represent that as NaN.
        if l==0{return Tensor::from_data(TensorData::new(vec![f32::NAN],[1]),&Default::default())}
        match input.float(){F1(x)=>avg(x),F2(x)=>avg(x),F3(x)=>avg(x),F4(x)=>avg(x),F5(x)=>avg(x),F6(x)=>avg(x),F7(x)=>avg(x),F8(x)=>avg(x),Value::Incompatible(e)=>panic!("Could not reduce to a scalar due to incompatibility: {e}"),Value::Multi(v)=>v.into_iter().map(|x|self.forward(x)).reduce(|x:Tensor<B,1>,y:Tensor<B,1>|x+y).unwrap()/l as f32,_=>panic!("internal error")}
    }
}
impl<B:Backend> AI<Value<B>,Tensor<B,1>> for SumLayer{
    /// Reduces the whole value to a single-element tensor holding its sum;
    /// a `Multi` sums its elements' sums.
    fn forward(&self,input:Value<B>)->Tensor<B,1>{
        fn sum<B:Backend,const N:usize>(x:Tensor<B,N>)->Tensor<B,1>{x.sum()}
        let l=input.len();

        // An empty value has no sum; represent that as NaN.
        if l==0{return Tensor::from_data(TensorData::new(vec![f32::NAN],[1]),&Default::default())}
        match input.float(){F1(x)=>sum(x),F2(x)=>sum(x),F3(x)=>sum(x),F4(x)=>sum(x),F5(x)=>sum(x),F6(x)=>sum(x),F7(x)=>sum(x),F8(x)=>sum(x),Value::Incompatible(e)=>panic!("Could not reduce to a scalar due to incompatibility: {e}"),Value::Multi(v)=>v.into_iter().map(|x|self.forward(x)).reduce(|x:Tensor<B,1>,y:Tensor<B,1>|x+y).unwrap(),_=>panic!("internal error")}
    }
}
impl<B:Backend> AI<Value<B>,Value<B>> for Conv2d<B>{
    /// Runs 2-D convolution on a float tensor of any rank by viewing it as
    /// [batch, channel, height, width] and restoring the original layout.
    fn forward(&self,input:Value<B>)->Value<B>{
        fn f<B:Backend,const N:usize>(input:Tensor<B,N>,layer:&Conv2d<B>)->Value<B>{
            let mut dims=input.dims();
            let n:usize=dims.iter().product();

            // The trailing dims act as (channel, height, width); dims the
            // input lacks default to 1 and the batch absorbs the remainder.
            let c=if N<3{1}else{dims[N-3]};
            let h=if N<2{1}else{dims[N-2]};
            let w=dims[N-1];

            let b=n/(c*h*w);
            let output=layer.forward(input.reshape([b,c,h,w]));

            let [_b,c,h,w]=output.dims();

            // Write the convolved extents back into the original dims; if the
            // conv produced an extent a low-rank input cannot express, return
            // a higher-rank tensor instead of the original rank.
            dims[N-1]=w;
            if N<3&&c!=1{return F3(output.reshape([c,h,w]))}else if N>=3{dims[N-3]=c}
            if N<2&&h!=1{return F2(output.reshape([h,w]))}else if N>=2{dims[N-2]=h}
            output.reshape(dims).into()
        }
        let l=self;

        match input.float(){F1(x)=>f(x,l),F2(x)=>f(x,l),F3(x)=>f(x,l),F4(x)=>f(x,l),F5(x)=>f(x,l),F6(x)=>f(x,l),F7(x)=>f(x,l),F8(x)=>f(x,l),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|AI::forward(self,x)).collect(),_=>panic!("internal error")}
    }
}
390impl<B:Backend> AI<Value<B>,Value<B>> for CrossEntropyLayer{
391 fn forward(&self,input:Value<B>)->Value<B>{
392 match input{
393 Value::Incompatible(e)=>e.into(),
394 Value::Multi(v)=>if v.len()==2{
395 let [output,target]=v.try_into().unwrap();
396 self.forward((output,target))
397 }else{
398 v.into_iter().map(|x|self.forward(x)).collect()
399 },
400 _=>"cross entropy inputs must be in pairs".into()
401 }
402 }
403}
404impl<B:Backend> AI<Value<B>,Value<B>> for CrossEntropyLoss<B>{
405 fn forward(&self,input:Value<B>)->Value<B>{
406 let mut op=CrossEntropyLayer::new(1.0);
407 if !self.logits{op.set_temperature(f32::NAN)}
408 op.forward(input)
409 }
410}
411impl<B:Backend> AI<Value<B>,Value<B>> for MeanLayer{
412 fn forward(&self,input:Value<B>)->Value<B>{
413 fn avg<B:Backend,const N:usize,const K:usize>(d:i32,x:Tensor<B,N>)->Tensor<B,K>{
414 let d=if d<0{N-((-d) as usize)}else{d as usize};
415 x.mean_dim(d).squeeze_dim(d)
416 }
417 let l=input.len();
418
419 if l==0{return input}
420 match self.get_reduction_mode(){
421 ReductionMode::Component=>F1(self.forward(input)),
422 ReductionMode::Dim(d)=>{
423 if let Some(r)=input.rank(){
424 if d>=r as i32||d<(-(r as i32)){return format!("rank {r} is too low to cat along dimension {d}").into()}
425 }
426 match input.float(){F1(x)=>F1(x.mean()),F2(x)=>F1(avg(d,x)),F3(x)=>F2(avg(d,x)),F4(x)=>F3(avg(d,x)),F5(x)=>F4(avg(d,x)),F6(x)=>F5(avg(d,x)),F7(x)=>F6(avg(d,x)),F8(x)=>F7(avg(d,x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|self.forward(x)).reduce(|x:Value<B>,y:Value<B>|x+y).unwrap()/l as f32,_=>panic!("internal error")}
427 },
428 ReductionMode::Tensor=>match input.float(){Value::Multi(v)=>v.into_iter().reduce(|x,y|x+y).unwrap()/l as f32,x=>x}
429 }
430 }
431}
432impl<B:Backend> AI<Value<B>,Value<B>> for MseLoss{
433 fn forward(&self,input:Value<B>)->Value<B>{SquaredErrorLayer::new().forward(input)}
434}
435impl<B:Backend> AI<Value<B>,Value<B>> for SquaredErrorLayer{
436 fn forward(&self,input:Value<B>)->Value<B>{
437 match input{
438 Value::Incompatible(e)=>e.into(),
439 Value::Multi(v)=>if v.len()==2{
440 let [output,target]=v.try_into().unwrap();
441 self.forward((output,target))
442 }else{
443 v.into_iter().map(|x|self.forward(x)).collect()
444 },
445 _=>"squared error inputs must be in pairs".into()
446 }
447 }
448}
449impl<B:Backend> AI<Value<B>,Value<B>> for SumLayer{
450 fn forward(&self,input:Value<B>)->Value<B>{
451 fn sum<B:Backend,const N:usize,const K:usize>(d:i32,x:Tensor<B,N>)->Tensor<B,K>{
452 let d=if d<0{N-((-d) as usize)}else{d as usize};
453 x.mean_dim(d).squeeze_dim(d)
454 }
455 let l=input.len();
456
457 if l==0{return input}
458 match self.get_reduction_mode(){
459 ReductionMode::Component=>F1(self.forward(input)),
460 ReductionMode::Dim(d)=>{
461 if let Some(r)=input.rank(){
462 if d>=r as i32||d<(-(r as i32)){return format!("rank {r} is too low to cat along dimension {d}").into()}
463 }
464 match input.float(){F1(x)=>F1(x.sum()),F2(x)=>F1(sum(d,x)),F3(x)=>F2(sum(d,x)),F4(x)=>F3(sum(d,x)),F5(x)=>F4(sum(d,x)),F6(x)=>F5(sum(d,x)),F7(x)=>F6(sum(d,x)),F8(x)=>F7(sum(d,x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|self.forward(x)).reduce(|x:Value<B>,y:Value<B>|x+y).unwrap(),_=>panic!("internal error")}
465 },
466 ReductionMode::Tensor=>match input.float(){Value::Multi(v)=>v.into_iter().reduce(|x,y|x+y).unwrap(),x=>x}
467 }
468 }
469}
470impl<B:Backend> AI<Value<B>,f32> for MeanLayer{
471 fn forward(&self,input:Value<B>)->f32{
472 let y:Tensor<B,1>=self.forward(input);
473 y.into_scalar().to_f32()
474 }
475}
476impl<B:Backend> AI<Value<B>,f32> for SumLayer{
477 fn forward(&self,input:Value<B>)->f32{
478 let y:Tensor<B,1>=self.forward(input);
479 y.into_scalar().to_f32()
480 }
481}
impl<B:Backend> AI<Value<B>,Value<B>> for AccQLayer{
    /// Accumulates discounted future values along `dim`:
    /// q[i] = x[i] + gamma * q[i+1], computed by a reverse sweep.
    fn forward(&self,input:Value<B>)->Value<B>{
        fn acc_q<B:Backend,const N:usize>(dim:i32,gamma:f32,i:Tensor<B,N>)->Tensor<B,N>{
            // Normalize negative dims, then split into length-1 slices along
            // `dim` so each step can be updated in place.
            let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
            let mut q=i.split(1,dim);
            // Reverse fold: each slice gains gamma times the accumulated
            // future of the slice after it.
            q.iter_mut().rev().fold(None,|future,present|{
                if let Some(f)=future{*present=f*gamma+present.clone()}
                Some(present.clone())
            });
            Tensor::cat(q,dim)
        }
        let (dim,gamma)=(self.get_dim(),self.get_gamma());

        match input.float(){F1(x)=>F1(acc_q(dim,gamma,x)),F2(x)=>F2(acc_q(dim,gamma,x)),F3(x)=>F3(acc_q(dim,gamma,x)),F4(x)=>F4(acc_q(dim,gamma,x)),F5(x)=>F5(acc_q(dim,gamma,x)),F6(x)=>F6(acc_q(dim,gamma,x)),F7(x)=>F7(acc_q(dim,gamma,x)),F8(x)=>F8(acc_q(dim,gamma,x)),Value::Incompatible(x)=>x.into(),Value::Multi(x)=>Value::Multi(x.into_iter().map(|x|self.forward(x)).collect()),_=>panic!("unexpected non float value")}
    }
}
impl<B:Backend> AI<Value<B>,u32> for ChooseLayer{
    /// Samples a single index along `dim`. A NaN temperature selects hard
    /// sampling (elements used directly as unnormalized weights); otherwise
    /// the temperature-scaled softmax distribution is sampled.
    fn forward(&self,input:Value<B>)->u32{
        let (dim,temperature)=(self.get_dim(),self.get_temperature());

        match input.float(){
            F1(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
            F2(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
            F3(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
            F4(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
            F5(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
            F6(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
            F7(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
            F8(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
            Value::Incompatible(e)=>panic!("Could not create scalar due to incompatibility: {e}"),
            // A singleton Multi unwraps; more than one value cannot yield one index.
            Value::Multi(v)=>if v.len()==1{self.forward(v.into_iter().next().unwrap())}else{panic!("Cannot soft choose one scalar from multiple values")},
            _=>panic!("internal error")
        }
    }
}
impl<B:Backend> AI<Value<B>,Vec<u32>> for ChooseLayer{
    /// Samples one index per slice along `dim`. A NaN temperature selects hard
    /// sampling (unnormalized weights); otherwise the temperature-scaled
    /// softmax is sampled. `Multi` values are flattened by concatenation.
    fn forward(&self,input:Value<B>)->Vec<u32>{
        let (dim,temperature)=(self.get_dim(),self.get_temperature());

        match input.float(){
            F1(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
            F2(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
            F3(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
            F4(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
            F5(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
            F6(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
            F7(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
            F8(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
            Value::Incompatible(e)=>panic!("Could not create vector due to incompatibility: {e}"),
            Value::Multi(v)=>v.into_iter().flat_map(|x|self.forward_typed::<_,Vec<u32>>(x)).collect(),
            _=>panic!("internal error")
        }
    }
}
536impl<B:Backend> AI<Value<B>,Value<B>> for Dropout{
537 fn forward(&self,input:Value<B>)->Value<B>{
538 match input.float(){F1(x)=>F1(self.forward(x)),F2(x)=>F2(self.forward(x)),F3(x)=>F3(self.forward(x)),F4(x)=>F4(self.forward(x)),F5(x)=>F5(self.forward(x)),F6(x)=>F6(self.forward(x)),F7(x)=>F7(self.forward(x)),F8(x)=>F8(self.forward(x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
539 }
540}
impl<B:Backend> AI<Value<B>,Value<B>> for Embedding<B>{
    /// Int inputs perform an embedding lookup (output rank = input rank + 1,
    /// with the embedding size appended); float inputs are instead multiplied
    /// by the embedding weights as a bias-free linear layer.
    fn forward(&self,input:Value<B>)->Value<B>{
        // Lookup path: flatten all trailing dims into a sequence, embed, then
        // restore the original dims with the embedding size appended (K=N+1).
        fn apply_embed<B:Backend,const N:usize,const K:usize>(this:&Embedding<B>,x:Tensor<B,N,Int>)->Tensor<B,K>{
            let dims=x.dims();
            let [batch,seq]=[dims[0],dims.iter().skip(1).product()];
            let x=x.reshape([batch,seq]);
            let y=this.forward(x);
            let embed=y.dims().last().copied().unwrap();
            let mut ydims=[0;K];
            ydims[..N].copy_from_slice(&dims);
            ydims[N]=embed;
            y.reshape(ydims)
        }
        // Float path: reuse the embedding weights as a linear map, no bias.
        fn apply_linear<B:Backend,const N:usize>(this:&Embedding<B>,x:Tensor<B,N>)->Tensor<B,N>{
            Linear{bias:None,weight:this.weight.clone()}.forward(x)
        }
        // Rank-8 int input is rejected: the lookup would need rank 9.
        match input{F1(x)=>apply_linear(self,x).into(),F2(x)=>apply_linear(self,x).into(),F3(x)=>apply_linear(self,x).into(),F4(x)=>apply_linear(self,x).into(),F5(x)=>apply_linear(self,x).into(),F6(x)=>apply_linear(self,x).into(),F7(x)=>apply_linear(self,x).into(),F8(x)=>apply_linear(self,x).into(),I1(x)=>apply_embed::<B,1,2>(self,x).into(),I2(x)=>apply_embed::<B,2,3>(self,x).into(),I3(x)=>apply_embed::<B,3,4>(self,x).into(),I4(x)=>apply_embed::<B,4,5>(self,x).into(),I5(x)=>apply_embed::<B,5,6>(self,x).into(),I6(x)=>apply_embed::<B,6,7>(self,x).into(),I7(x)=>apply_embed::<B,7,8>(self,x).into(),I8(_x)=>"embedding output would exceed maximum supported rank".into(),Value::Incompatible(x)=>x.into(),Value::Multi(x)=>x.into_iter().map(|x|AI::forward(self,x)).collect::<Vec<_>>().into(),_=>"embedding is only available for float or int inputs".into()}
    }
}
impl<B:Backend> AI<Value<B>,Value<B>> for LayerNorm<B>{
    /// Applies layer normalization after validating that the layer's
    /// beta/gamma shapes agree and match the input's last dimension.
    fn forward(&self,input:Value<B>)->Value<B>{
        fn f<B:Backend,const N:usize>(input:Tensor<B,N>,layer:&LayerNorm<B>)->Value<B>{
            let b=layer.beta.dims();
            let g=layer.gamma.dims();
            let i=input.dims();

            // Scale and shift parameters must be the same shape.
            if b!=g{return format!("malformed layer norm. beta dims: {b:?}. gamma dims: {g:?}.").into()}
            // The normalized (last) dimension must match the parameter size.
            if b.last()!=i.last(){return format!("layer norm for dimension {b:?} is not compatible with input dimensions {i:?}. The last dimension must match the norm dimension.").into()}
            layer.forward(input).into()
        }
        let l=self;

        match input.float(){F1(x)=>f(x,l),F2(x)=>f(x,l),F3(x)=>f(x,l),F4(x)=>f(x,l),F5(x)=>f(x,l),F6(x)=>f(x,l),F7(x)=>f(x,l),F8(x)=>f(x,l),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
    }
}
576impl<B:Backend> AI<Value<B>,Value<B>> for Linear<B>{
577 fn forward(&self,input:Value<B>)->Value<B>{
578 match input.float(){F1(x)=>F1(self.forward(x)),F2(x)=>F2(self.forward(x)),F3(x)=>F3(self.forward(x)),F4(x)=>F4(self.forward(x)),F5(x)=>F5(self.forward(x)),F6(x)=>F6(self.forward(x)),F7(x)=>F7(self.forward(x)),F8(x)=>F8(self.forward(x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
579 }
580}
581impl<B:Backend> AI<Value<B>,Value<B>> for Relu{
582 fn forward(&self,input:Value<B>)->Value<B>{
583 match input.float(){F1(x)=>F1(self.forward(x)),F2(x)=>F2(self.forward(x)),F3(x)=>F3(self.forward(x)),F4(x)=>F4(self.forward(x)),F5(x)=>F5(self.forward(x)),F6(x)=>F6(self.forward(x)),F7(x)=>F7(self.forward(x)),F8(x)=>F8(self.forward(x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
584 }
585}
impl<B:Backend> AI<Value<B>,Value<B>> for SoftmaxLayer{
    /// Temperature-scaled softmax along the configured dimension (negative
    /// dims count from the back).
    fn forward(&self,input:Value<B>)->Value<B>{
        fn f<B:Backend,const N:usize>(dim:i32,temperature:f32,x:Tensor<B,N>)->Tensor<B,N>{
            let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
            softmax(x/temperature,dim)
        }
        let (dim,temperature)=(self.get_dim(),self.get_temperature());

        match input.float(){F1(x)=>F1(f(dim,temperature,x)),F2(x)=>F2(f(dim,temperature,x)),F3(x)=>F3(f(dim,temperature,x)),F4(x)=>F4(f(dim,temperature,x)),F5(x)=>F5(f(dim,temperature,x)),F6(x)=>F6(f(dim,temperature,x)),F7(x)=>F7(f(dim,temperature,x)),F8(x)=>F8(f(dim,temperature,x)),Value::Incompatible(x)=>x.into(),Value::Multi(x)=>x.into_iter().map(|x|self.forward(x)).collect(),_=>panic!("unexpected non float value")}
    }
}
impl<B:Backend> AI<Value<B>,Value<B>> for ChooseLayer{
 /// Samples an index along the layer's dimension from a distribution tensor.
 ///
 /// A NaN temperature selects the "hard" sampler, otherwise the "soft"
 /// (temperature-weighted) sampler is used. The chosen dimension collapses
 /// to size 1 and is then squeezed away, so rank N input yields a rank N-1
 /// Int result — except rank-1 input, which stays rank 1 (cannot squeeze
 /// to rank 0).
 fn forward(&self,input:Value<B>)->Value<B>{let (dim,temperature)=(self.get_dim(),self.get_temperature());
 // Absolute axis for the squeeze; rank() falls back to 8 for non-tensor
 // values — assumed unreachable here since only F* arms use `d`.
 let d=if dim<0{input.rank().unwrap_or(8)-((-dim) as usize)}else{dim as usize};

 match input.float(){
 F1(x)=>I1(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}),
 F2(x)=>I1(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze_dim(d)),
 F3(x)=>I2(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze_dim(d)),
 F4(x)=>I3(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze_dim(d)),
 F5(x)=>I4(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze_dim(d)),
 F6(x)=>I5(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze_dim(d)),
 F7(x)=>I6(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze_dim(d)),
 F8(x)=>I7(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze_dim(d)),
 Value::Incompatible(e)=>e.into(),
 Value::Multi(v)=>Value::Multi(v.into_iter().map(|v|self.forward_typed::<_,Value<B>>(v)).collect()),
 _=>panic!("internal error")}
 }
}
impl<B:Backend> AI<Value<B>,Value<B>> for MaxPool2d{
 /// 2-D max pooling over a rank-erased `Value`.
 ///
 /// The pool itself only accepts rank-4 input (batch, channel, h, w):
 /// lower ranks are unsqueezed up to 4 and squeezed back afterwards;
 /// higher ranks collapse all leading dimensions into a single batch
 /// dimension, pool, then reshape back with the pooled h/w.
 fn forward(&self,input:Value<B>)->Value<B>{
  fn f<B:Backend,const N:usize>(pool:&MaxPool2d,x:Tensor<B,N>)->Value<B>{
   match N{
    0=>panic!("internal error"),
    // Ranks 1-3: add a leading dim, recurse, then drop it again.
    1=>f::<B,2>(pool,x.unsqueeze()).squeeze(),
    2=>f::<B,3>(pool,x.unsqueeze()).squeeze(),
    3=>f::<B,4>(pool,x.unsqueeze()).squeeze(),
    4=>pool.forward(Value::from(x).unwrap_f4()).into(),
    _=>{
     let mut dims=x.dims();

     // Flatten every leading dim into one batch dim of size `big`.
     let [channels,h,w]=[dims[N-3],dims[N-2],dims[N-1]];
     let big:usize=dims.iter().take(N-3).product();
     let y=x.reshape([big,channels,h,w]);

     // Record the pooled channel/h/w so the final reshape restores
     // the original leading dims with the new spatial extents.
     dims[N-3..].copy_from_slice(&y.dims()[1..]);

     let y=pool.forward(y);
     y.reshape(dims).into()
    }
   }
  }
  match input.float(){
   F1(x)=>f(self,x),
   F2(x)=>f(self,x),
   F3(x)=>f(self,x),
   F4(x)=>f(self,x),
   F5(x)=>f(self,x),
   F6(x)=>f(self,x),
   F7(x)=>f(self,x),
   F8(x)=>f(self,x),
   Value::Incompatible(e)=>e.into(),
   Value::Multi(v)=>v.into_iter().map(|x|AI::forward(self,x)).collect(),
   _=>panic!("Internal error")
  }
 }
}
654impl<B:Backend> AI<Value<B>,Value<B>> for RotaryEncoding<B>{
655 fn forward(&self,input:Value<B>)->Value<B>{AI::forward(self,(input,0)).0}
656}
impl<B:Backend> AI<(Value<B>,usize),(Value<B>,usize)> for RotaryEncoding<B>{
 /// Applies rotary positional encoding starting at `offset`, returning the
 /// encoded value together with the unchanged offset.
 fn forward(&self,(input,offset):(Value<B>,usize))->(Value<B>,usize){
  // Core rotation for a rank-D float tensor whose last two dims are
  // (context, key). The key dim is split into heads, each head into
  // complex pairs, rotated by the precomputed frequency table, and the
  // original shape restored at the end.
  fn apply<B:Backend,const D:usize>(a:&RotaryEncoding<B>,input:Tensor<B,D>,offset:usize)->Value<B>{
   assert!(D>=2);
   // NOTE(review): presumably a backend kernel-size limit for one launch;
   // work is chunked along dim 0 to stay under it — confirm the bound.
   const MAX_KERNEL:usize=65535; let device=input.device();
   let freq=&a.freq_complex;
   let shape=input.shape();

   let (context,key)=(shape.dims[D-2],shape.dims[D-1]);
   let [distance,head,_2]=freq.dims();

   if context>distance{return format!("context length must not exceed rotary distance. context: {context}, distance: {distance}").into()}
   if key%head!=0{return "input dimension must be a multiple of head".into()}
   let count=shape.num_elements();
   let big=count/(context*key);
   let heads=key/head;
   // Reshape to (big*heads, context, head/2, 2): one complex pair per
   // trailing 2-element slot.
   let group=count/head; let input=input.reshape([big,context,heads,head]).swap_dims(1,2).reshape([big*heads,context,head/2,2]);
   // 2x(2x2) matrices implementing the complex multiply as a matmul.
   let sign=Tensor::<B,2>::from_floats([[1.0,0.0,0.0,1.0],[0.0,-1.0,1.0,0.0]],&device).unsqueeze();

   let chunks=input.chunk(group.div_ceil(MAX_KERNEL),0).into_iter().map(|x|{
    let small=x.dims()[0];
    // Rotate, then weight by the frequency slice for positions
    // offset..offset+context and reduce the complex components.
    let x=x.matmul(sign.clone()).reshape([small,context,head,2])*freq.clone().slice([offset..context+offset]).unsqueeze();
    x.sum_dim(3)
   }).collect();
   Tensor::cat(chunks,0).reshape([big,heads,context,head]).swap_dims(1,2).reshape::<D,_>(shape).into()
  }

  // Rank 1 is lifted to rank 2 (context=1) and squeezed back afterwards.
  (match input.float(){F1(x)=>apply(self,x.unsqueeze::<2>(),offset).squeeze(),F2(x)=>apply(self,x,offset),F3(x)=>apply(self,x,offset),F4(x)=>apply(self,x,offset),F5(x)=>apply(self,x,offset),F6(x)=>apply(self,x,offset),F7(x)=>apply(self,x,offset),F8(x)=>apply(self,x,offset),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|AI::forward(self,(x,offset)).0).collect(),_=>panic!("internal error")},offset)
 }
}
impl<B:Backend> AI<Value<B>,Value<B>> for Tanh{
 /// Applies tanh to a rank-erased `Value`; same dispatch pattern as the
 /// other activation impls.
 fn forward(&self,input:Value<B>)->Value<B>{
  match input.float(){F1(x)=>F1(self.forward(x)),F2(x)=>F2(self.forward(x)),F3(x)=>F3(self.forward(x)),F4(x)=>F4(self.forward(x)),F5(x)=>F5(self.forward(x)),F6(x)=>F6(self.forward(x)),F7(x)=>F7(self.forward(x)),F8(x)=>F8(self.forward(x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
 }
}
impl<B:Backend> Abs for Value<B>{
 /// Element-wise absolute value. Bool tensors pass through unchanged,
 /// float/int tensors use `Tensor::abs`, `Multi` recurses and
 /// `Incompatible` propagates.
 fn abs(self)->Self::Output{
  match self{B1(x)=>B1(x),B2(x)=>B2(x),B3(x)=>B3(x),B4(x)=>B4(x),B5(x)=>B5(x),B6(x)=>B6(x),B7(x)=>B7(x),B8(x)=>B8(x),F1(x)=>F1(x.abs()),F2(x)=>F2(x.abs()),F3(x)=>F3(x.abs()),F4(x)=>F4(x.abs()),F5(x)=>F5(x.abs()),F6(x)=>F6(x.abs()),F7(x)=>F7(x.abs()),F8(x)=>F8(x.abs()),I1(x)=>I1(x.abs()),I2(x)=>I2(x.abs()),I3(x)=>I3(x.abs()),I4(x)=>I4(x.abs()),I5(x)=>I5(x.abs()),I6(x)=>I6(x.abs()),I7(x)=>I7(x.abs()),I8(x)=>I8(x.abs()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(Value::abs).collect()}
 }
 type Output=Value<B>;
}
700impl<B:Backend> AsRef<Self> for Value<B>{
701 fn as_ref(&self)->&Self{self}
702}
703impl<B:Backend> Cat for Value<B>{
704 fn cat(self,d:i32)->Self{
706 fn f<B:Backend,I:Iterator<Item=Tensor<B,N,K>>,K:BasicOps<B>+TensorKind<B>,const N:usize>(d:i32,x0:Tensor<B,N,K>,tensors:I)->Value<B> where Tensor<B,N,K>:Into<Value<B>>{
707 if d>=N as i32||d<(-(N as i32)){return format!("rank {N} is too low to cat along dimension {d}").into()}
708 let d=if d<0{N-((-d) as usize)}else{d as usize};
709 let shape=x0.dims();
710 let tensors:Vec<Tensor<B,N,K>>=once(x0).chain(tensors).collect();
711
712 if let Err(e)=tensors.iter().try_for_each(|x|{
713 let mut xshape=x.dims();
714 xshape[d]=shape[d];
715 if shape==xshape{Ok(())}else{Err("mismatched shapes {shape:?}, {xshape:?}")}
716 }){
717 return e.into()
718 }
719
720 Tensor::cat(tensors,d).into()
721 }
722 let v=if let Value::Multi(v)=self{v}else{return self};
723
724 if let Some(n)=v.iter().position(Value::is_incompatible){return v.into_iter().nth(n).unwrap()}
725 if v.iter().all(Value::is_multi){return v.into_iter().map(|x|x.cat(d)).collect()}
726 if v.iter().any(Value::is_multi){return "cannot mix single and multi values in a cat operation".into()}
727 let variant=mem::discriminant(&v[0]);
728
729 if v.iter().any(|x|mem::discriminant(x)!=variant){return "cannot mix variants in a cat operation".into()}
730 let mut v=v.into_iter();
731
732 match v.next().unwrap(){B1(x0)=>f(d,x0,v.map(Value::unwrap_b1)),B2(x0)=>f(d,x0,v.map(Value::unwrap_b2)),B3(x0)=>f(d,x0,v.map(Value::unwrap_b3)),B4(x0)=>f(d,x0,v.map(Value::unwrap_b4)),B5(x0)=>f(d,x0,v.map(Value::unwrap_b5)),B6(x0)=>f(d,x0,v.map(Value::unwrap_b6)),B7(x0)=>f(d,x0,v.map(Value::unwrap_b7)),B8(x0)=>f(d,x0,v.map(Value::unwrap_b8)),F1(x0)=>f(d,x0,v.map(Value::unwrap_f1)),F2(x0)=>f(d,x0,v.map(Value::unwrap_f2)),F3(x0)=>f(d,x0,v.map(Value::unwrap_f3)),F4(x0)=>f(d,x0,v.map(Value::unwrap_f4)),F5(x0)=>f(d,x0,v.map(Value::unwrap_f5)),F6(x0)=>f(d,x0,v.map(Value::unwrap_f6)),F7(x0)=>f(d,x0,v.map(Value::unwrap_f7)),F8(x0)=>f(d,x0,v.map(Value::unwrap_f8)),I1(x0)=>f(d,x0,v.map(Value::unwrap_i1)),I2(x0)=>f(d,x0,v.map(Value::unwrap_i2)),I3(x0)=>f(d,x0,v.map(Value::unwrap_i3)),I4(x0)=>f(d,x0,v.map(Value::unwrap_i4)),I5(x0)=>f(d,x0,v.map(Value::unwrap_i5)),I6(x0)=>f(d,x0,v.map(Value::unwrap_i6)),I7(x0)=>f(d,x0,v.map(Value::unwrap_i7)),I8(x0)=>f(d,x0,v.map(Value::unwrap_i8)),Value::Incompatible(_e)=>panic!("internal error not handled in correct location"),Value::Multi(_e)=>panic!("internal error not handled in correct location")}
733 }
734 type Output=Self;
735}
736impl<B:Backend> Decompose for LossOutput<B>{
737 fn compose((loss,output,target):Self::Decomposition)->Self{Self::new(loss,output,target)}
738 fn decompose(self)->Self::Decomposition{(self.loss(),self.output(),self.target())}
739 fn decompose_cloned(&self)->Self::Decomposition{(self.loss(),self.output(),self.target())}
740 type Decomposition=(Value<B>,Value<B>,Value<B>);
741}
742impl<B:Backend> Decompose for Value<B>{
743 fn compose(decomposition:Self::Decomposition)->Self{decomposition}
744 fn decompose(self)->Self::Decomposition{self}
745 fn decompose_cloned(&self)->Self::Decomposition{self.clone()}
746 type Decomposition=Self;
747}
748impl<B:Backend> Default for Value<B>{
749 fn default()->Self{Self::Multi(Vec::new())}
750}
751impl<B:Backend> Display for Value<B>{
752 fn fmt(&self,f:&mut std::fmt::Formatter<'_>)->FmtResult{
753 match self{
754 B1(x)=>x.fmt(f),
755 B2(x)=>x.fmt(f),
756 B3(x)=>x.fmt(f),
757 B4(x)=>x.fmt(f),
758 B5(x)=>x.fmt(f),
759 B6(x)=>x.fmt(f),
760 B7(x)=>x.fmt(f),
761 B8(x)=>x.fmt(f),
762 F1(x)=>x.fmt(f),
763 F2(x)=>x.fmt(f),
764 F3(x)=>x.fmt(f),
765 F4(x)=>x.fmt(f),
766 F5(x)=>x.fmt(f),
767 F6(x)=>x.fmt(f),
768 F7(x)=>x.fmt(f),
769 F8(x)=>x.fmt(f),
770 I1(x)=>x.fmt(f),
771 I2(x)=>x.fmt(f),
772 I3(x)=>x.fmt(f),
773 I4(x)=>x.fmt(f),
774 I5(x)=>x.fmt(f),
775 I6(x)=>x.fmt(f),
776 I7(x)=>x.fmt(f),
777 I8(x)=>x.fmt(f),
778 Value::Incompatible(e)=>e.fmt(f),
779 Value::Multi(v)=>{
780 write!(f,"[")?;
781 v.iter().take(v.len().saturating_sub(1)).try_for_each(|x|{
782 x.fmt(f)?;
783 write!(f,", ")
784 })?;
785 if let Some(x)=v.last(){
786 x.fmt(f)?;
787 }
788 write!(f,"]")
789 }
790 }
791 }
792}
793impl<B:Backend> From<Vec<bool>> for Value<B>{
794 fn from(value:Vec<bool>)->Self{
795 let l=value.len();
796 let t:Tensor<B,1,Bool>=Tensor::from_data(TensorData::new(value,[l]),&Default::default());
797
798 t.into()
799 }
800}
801impl<B:Backend> From<Vec<f32>> for Value<B>{
802 fn from(value:Vec<f32>)->Self{
803 let l=value.len();
804 let t:Tensor<B,1>=Tensor::from_data(TensorData::new(value,[l]),&Default::default());
805
806 t.into()
807 }
808}
809impl<B:Backend> From<Vec<i32>> for Value<B>{
810 fn from(value:Vec<i32>)->Self{
811 let l=value.len();
812 let t:Tensor<B,1,Int>=Tensor::from_data(TensorData::new(value,[l]),&Default::default());
813
814 t.into()
815 }
816}
817impl<B:Backend> From<Vec<u32>> for Value<B>{
818 fn from(value:Vec<u32>)->Self{
819 let l=value.len();
820 let t:Tensor<B,1,Int>=Tensor::from_data(TensorData::new(value,[l]),&Default::default());
821
822 t.into()
823 }
824}
825impl<B:Backend> ModuleDisplay for Value<B>{
826 fn custom_content(&self,_content:Content)->Option<Content>{None}
827 fn custom_settings(&self)->Option<DisplaySettings>{None}
828 fn format(&self,s:DisplaySettings)->String{
829 match self{
830 B1(x)=>x.format(s),
831 B2(x)=>x.format(s),
832 B3(x)=>x.format(s),
833 B4(x)=>x.format(s),
834 B5(x)=>x.format(s),
835 B6(x)=>x.format(s),
836 B7(x)=>x.format(s),
837 B8(x)=>x.format(s),
838 F1(x)=>x.format(s),
839 F2(x)=>x.format(s),
840 F3(x)=>x.format(s),
841 F4(x)=>x.format(s),
842 F5(x)=>x.format(s),
843 F6(x)=>x.format(s),
844 F7(x)=>x.format(s),
845 F8(x)=>x.format(s),
846 I1(x)=>x.format(s),
847 I2(x)=>x.format(s),
848 I3(x)=>x.format(s),
849 I4(x)=>x.format(s),
850 I5(x)=>x.format(s),
851 I6(x)=>x.format(s),
852 I7(x)=>x.format(s),
853 I8(x)=>x.format(s),
854 Value::Incompatible(e)=>e.to_string(),
855 Value::Multi(v)=>"[".chars().chain(v.iter().flat_map(|x|{
856 let x:Vec<char>=x.format(s.clone()).chars().chain(", ".chars()).collect();
857 x
858 })).chain("]".chars()).collect()
859 }
860 }
861}
862impl<B:Backend> ModuleDisplayDefault for Value<B>{
863 fn content(&self,content:Content)->Option<Content>{Some(content)}
864 fn num_params(&self)->usize{Module::num_params(self)}
865}
866impl<B:Backend> From<String> for Value<B>{
867 fn from(value:String)->Self{Self::Incompatible(value)}
868}
impl<B:Backend> From<Value<B>> for ValueData{
 /// Converts a `Value` to its backend-independent data form by extracting
 /// each tensor's raw `TensorData` (BX/FX/IX keep the kind, ranks are
 /// erased); `Multi` and `Incompatible` convert structurally.
 fn from(value:Value<B>)->Self{
  match value{B1(x)=>BX(x.into_data()),B2(x)=>BX(x.into_data()),B3(x)=>BX(x.into_data()),B4(x)=>BX(x.into_data()),B5(x)=>BX(x.into_data()),B6(x)=>BX(x.into_data()),B7(x)=>BX(x.into_data()),B8(x)=>BX(x.into_data()),F1(x)=>FX(x.into_data()),F2(x)=>FX(x.into_data()),F3(x)=>FX(x.into_data()),F4(x)=>FX(x.into_data()),F5(x)=>FX(x.into_data()),F6(x)=>FX(x.into_data()),F7(x)=>FX(x.into_data()),F8(x)=>FX(x.into_data()),I1(x)=>IX(x.into_data()),I2(x)=>IX(x.into_data()),I3(x)=>IX(x.into_data()),I4(x)=>IX(x.into_data()),I5(x)=>IX(x.into_data()),I6(x)=>IX(x.into_data()),I7(x)=>IX(x.into_data()),I8(x)=>IX(x.into_data()),Value::Incompatible(e)=>ValueData::Incompatible(e),Value::Multi(v)=>ValueData::Multi(v.into_iter().map(ValueData::from).collect())}
 }
}
impl<B:Backend> From<ValueData> for Value<B>{
 /// Reconstructs a `Value` on the default device, recovering the rank
 /// variant (1..=8) from the stored shape length; ranks above 8 panic
 /// because the enum has no variant for them.
 fn from(value:ValueData)->Self{
  let device=Default::default();
  match value{
   BX(data)=>match data.shape.len(){1=>B1(Tensor::from_data(data,&device)),2=>B2(Tensor::from_data(data,&device)),3=>B3(Tensor::from_data(data,&device)),4=>B4(Tensor::from_data(data,&device)),5=>B5(Tensor::from_data(data,&device)),6=>B6(Tensor::from_data(data,&device)),7=>B7(Tensor::from_data(data,&device)),8=>B8(Tensor::from_data(data,&device)),_=>panic!("tensor ranks above 8 are currently not supported")},
   FX(data)=>match data.shape.len(){1=>F1(Tensor::from_data(data,&device)),2=>F2(Tensor::from_data(data,&device)),3=>F3(Tensor::from_data(data,&device)),4=>F4(Tensor::from_data(data,&device)),5=>F5(Tensor::from_data(data,&device)),6=>F6(Tensor::from_data(data,&device)),7=>F7(Tensor::from_data(data,&device)),8=>F8(Tensor::from_data(data,&device)),_=>panic!("tensor ranks above 8 are currently not supported")},
   IX(data)=>match data.shape.len(){1=>I1(Tensor::from_data(data,&device)),2=>I2(Tensor::from_data(data,&device)),3=>I3(Tensor::from_data(data,&device)),4=>I4(Tensor::from_data(data,&device)),5=>I5(Tensor::from_data(data,&device)),6=>I6(Tensor::from_data(data,&device)),7=>I7(Tensor::from_data(data,&device)),8=>I8(Tensor::from_data(data,&device)),_=>panic!("tensor ranks above 8 are currently not supported")},
   ValueData::Incompatible(e)=>e.into(),
   ValueData::Multi(v)=>v.into_iter().map(Value::from).collect(),
  }
 }
}
886impl<B:Backend> From<Vec<Value<B>>> for Value<B>{
887 fn from(value:Vec<Value<B>>)->Self{Self::Multi(value)}
888}
889impl<B:Backend> IntoIterator for Value<B>{
890 fn into_iter(self)->Self::IntoIter{self.into_multi().into_iter()}
891 type IntoIter=VecIntoIter<Value<B>>;
892 type Item=Value<B>;
893}
894impl<B:Backend> LossOutput<B>{
895 pub fn loss(&self)->Value<B>{self.loss.clone()}
897 pub fn new(loss:Value<B>,output:Value<B>,target:Value<B>)->Self{
899 Self{loss,output,target}
900 }
901 pub fn output(&self)->Value<B>{self.output.clone()}
903 pub fn target(&self)->Value<B>{self.target.clone()}
905}
906impl<B:Backend> Merge for Value<B>{
907 fn merge(&mut self,other:Self){
908 match (mem::take(self),other){
909 (Value::Multi(mut u),Value::Multi(v))=>{
910 u.extend(v);
911 *self=u.into();
912 },
913 (Value::Multi(mut u),v)=>if u.len()==0{
914 *self=v;
915 }else{
916 u.push(v);
917 *self=u.into();
918 },
919 (u,Value::Multi(mut v))=>if v.len()==0{
920 *self=u;
921 }else{
922 v.insert(0,u);
923 *self=v.into();
924 },
925 (u,v)=>*self=vec![u,v].into()
926 }
927 }
928}
impl<B:Backend> Module<B> for Value<B>{
 /// Collects devices from every wrapped tensor; `Incompatible` contributes
 /// none and `Multi` folds over its children.
 fn collect_devices(&self,devices:Vec<<B as Backend>::Device>)->Vec<<B as Backend>::Device>{
  match self{B1(x)=>x.collect_devices(devices),B2(x)=>x.collect_devices(devices),B3(x)=>x.collect_devices(devices),B4(x)=>x.collect_devices(devices),B5(x)=>x.collect_devices(devices),B6(x)=>x.collect_devices(devices),B7(x)=>x.collect_devices(devices),B8(x)=>x.collect_devices(devices),F1(x)=>x.collect_devices(devices),F2(x)=>x.collect_devices(devices),F3(x)=>x.collect_devices(devices),F4(x)=>x.collect_devices(devices),F5(x)=>x.collect_devices(devices),F6(x)=>x.collect_devices(devices),F7(x)=>x.collect_devices(devices),F8(x)=>x.collect_devices(devices),I1(x)=>x.collect_devices(devices),I2(x)=>x.collect_devices(devices),I3(x)=>x.collect_devices(devices),I4(x)=>x.collect_devices(devices),I5(x)=>x.collect_devices(devices),I6(x)=>x.collect_devices(devices),I7(x)=>x.collect_devices(devices),I8(x)=>x.collect_devices(devices),Value::Incompatible(_e)=>devices,Value::Multi(v)=>v.iter().fold(devices,|devices,x|x.collect_devices(devices))}
 }
 fn devices(&self)->Vec<<B as Backend>::Device>{self.collect_devices(Vec::new())}
 /// Forks every wrapped tensor to `device`; structure is preserved.
 fn fork(self,device:&<B as Backend>::Device)->Self{
  match self{B1(x)=>B1(x.fork(device)),B2(x)=>B2(x.fork(device)),B3(x)=>B3(x.fork(device)),B4(x)=>B4(x.fork(device)),B5(x)=>B5(x.fork(device)),B6(x)=>B6(x.fork(device)),B7(x)=>B7(x.fork(device)),B8(x)=>B8(x.fork(device)),F1(x)=>F1(x.fork(device)),F2(x)=>F2(x.fork(device)),F3(x)=>F3(x.fork(device)),F4(x)=>F4(x.fork(device)),F5(x)=>F5(x.fork(device)),F6(x)=>F6(x.fork(device)),F7(x)=>F7(x.fork(device)),F8(x)=>F8(x.fork(device)),I1(x)=>I1(x.fork(device)),I2(x)=>I2(x.fork(device)),I3(x)=>I3(x.fork(device)),I4(x)=>I4(x.fork(device)),I5(x)=>I5(x.fork(device)),I6(x)=>I6(x.fork(device)),I7(x)=>I7(x.fork(device)),I8(x)=>I8(x.fork(device)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.fork(device)).collect()}
 }
 // Values carry no trainable state of their own, so the record is a
 // ConstantRecord and load/save are no-ops.
 fn into_record(self)->Self::Record{ConstantRecord}
 fn load_file<F:FileRecorder<B>,P:Into<PathBuf>>(self,_filepath:P,_recorder:&F,_device:&<B as Backend>::Device)->Result<Self,RecorderError>{Ok(self)}
 fn load_record(self,_record:Self::Record)->Self{self}
 /// Maps every wrapped tensor through `mapper`; structure is preserved.
 fn map<Mapper:ModuleMapper<B>>(self,mapper:&mut Mapper)->Self{
  match self{B1(x)=>B1(x.map(mapper)),B2(x)=>B2(x.map(mapper)),B3(x)=>B3(x.map(mapper)),B4(x)=>B4(x.map(mapper)),B5(x)=>B5(x.map(mapper)),B6(x)=>B6(x.map(mapper)),B7(x)=>B7(x.map(mapper)),B8(x)=>B8(x.map(mapper)),F1(x)=>F1(x.map(mapper)),F2(x)=>F2(x.map(mapper)),F3(x)=>F3(x.map(mapper)),F4(x)=>F4(x.map(mapper)),F5(x)=>F5(x.map(mapper)),F6(x)=>F6(x.map(mapper)),F7(x)=>F7(x.map(mapper)),F8(x)=>F8(x.map(mapper)),I1(x)=>I1(x.map(mapper)),I2(x)=>I2(x.map(mapper)),I3(x)=>I3(x.map(mapper)),I4(x)=>I4(x.map(mapper)),I5(x)=>I5(x.map(mapper)),I6(x)=>I6(x.map(mapper)),I7(x)=>I7(x.map(mapper)),I8(x)=>I8(x.map(mapper)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.map(mapper)).collect()}
 }
 /// Sums parameter counts over all wrapped tensors; `Incompatible` is 0.
 fn num_params(&self)->usize{
  match self{B1(x)=>Module::num_params(x),B2(x)=>Module::num_params(x),B3(x)=>Module::num_params(x),B4(x)=>Module::num_params(x),B5(x)=>Module::num_params(x),B6(x)=>Module::num_params(x),B7(x)=>Module::num_params(x),B8(x)=>Module::num_params(x),F1(x)=>Module::num_params(x),F2(x)=>Module::num_params(x),F3(x)=>Module::num_params(x),F4(x)=>Module::num_params(x),F5(x)=>Module::num_params(x),F6(x)=>Module::num_params(x),F7(x)=>Module::num_params(x),F8(x)=>Module::num_params(x),I1(x)=>Module::num_params(x),I2(x)=>Module::num_params(x),I3(x)=>Module::num_params(x),I4(x)=>Module::num_params(x),I5(x)=>Module::num_params(x),I6(x)=>Module::num_params(x),I7(x)=>Module::num_params(x),I8(x)=>Module::num_params(x),Value::Incompatible(_e)=>0,Value::Multi(v)=>v.into_iter().map(|x|Module::num_params(x)).sum()}
 }
 /// Quantizes every wrapped tensor's weights; structure is preserved.
 fn quantize_weights(self,quantizer:&mut Quantizer)->Self{
  match self{B1(x)=>B1(x.quantize_weights(quantizer)),B2(x)=>B2(x.quantize_weights(quantizer)),B3(x)=>B3(x.quantize_weights(quantizer)),B4(x)=>B4(x.quantize_weights(quantizer)),B5(x)=>B5(x.quantize_weights(quantizer)),B6(x)=>B6(x.quantize_weights(quantizer)),B7(x)=>B7(x.quantize_weights(quantizer)),B8(x)=>B8(x.quantize_weights(quantizer)),F1(x)=>F1(x.quantize_weights(quantizer)),F2(x)=>F2(x.quantize_weights(quantizer)),F3(x)=>F3(x.quantize_weights(quantizer)),F4(x)=>F4(x.quantize_weights(quantizer)),F5(x)=>F5(x.quantize_weights(quantizer)),F6(x)=>F6(x.quantize_weights(quantizer)),F7(x)=>F7(x.quantize_weights(quantizer)),F8(x)=>F8(x.quantize_weights(quantizer)),I1(x)=>I1(x.quantize_weights(quantizer)),I2(x)=>I2(x.quantize_weights(quantizer)),I3(x)=>I3(x.quantize_weights(quantizer)),I4(x)=>I4(x.quantize_weights(quantizer)),I5(x)=>I5(x.quantize_weights(quantizer)),I6(x)=>I6(x.quantize_weights(quantizer)),I7(x)=>I7(x.quantize_weights(quantizer)),I8(x)=>I8(x.quantize_weights(quantizer)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.quantize_weights(quantizer)).collect()}
 }
 fn save_file<F:FileRecorder<B>,P:Into<PathBuf>>(self,_filepath:P,_recorder:&F)->Result<(),RecorderError>{
  Ok(())
 }
 /// Moves every wrapped tensor to `device`; structure is preserved.
 fn to_device(self,device:&<B as Backend>::Device)->Self{
  match self{B1(x)=>B1(x.to_device(device)),B2(x)=>B2(x.to_device(device)),B3(x)=>B3(x.to_device(device)),B4(x)=>B4(x.to_device(device)),B5(x)=>B5(x.to_device(device)),B6(x)=>B6(x.to_device(device)),B7(x)=>B7(x.to_device(device)),B8(x)=>B8(x.to_device(device)),F1(x)=>F1(x.to_device(device)),F2(x)=>F2(x.to_device(device)),F3(x)=>F3(x.to_device(device)),F4(x)=>F4(x.to_device(device)),F5(x)=>F5(x.to_device(device)),F6(x)=>F6(x.to_device(device)),F7(x)=>F7(x.to_device(device)),F8(x)=>F8(x.to_device(device)),I1(x)=>I1(x.to_device(device)),I2(x)=>I2(x.to_device(device)),I3(x)=>I3(x.to_device(device)),I4(x)=>I4(x.to_device(device)),I5(x)=>I5(x.to_device(device)),I6(x)=>I6(x.to_device(device)),I7(x)=>I7(x.to_device(device)),I8(x)=>I8(x.to_device(device)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.to_device(device)).collect()}
 }
 /// Visits every wrapped tensor; `Incompatible` is skipped.
 fn visit<Visitor:ModuleVisitor<B>>(&self,visitor:&mut Visitor){
  match self{B1(x)=>x.visit(visitor),B2(x)=>x.visit(visitor),B3(x)=>x.visit(visitor),B4(x)=>x.visit(visitor),B5(x)=>x.visit(visitor),B6(x)=>x.visit(visitor),B7(x)=>x.visit(visitor),B8(x)=>x.visit(visitor),F1(x)=>x.visit(visitor),F2(x)=>x.visit(visitor),F3(x)=>x.visit(visitor),F4(x)=>x.visit(visitor),F5(x)=>x.visit(visitor),F6(x)=>x.visit(visitor),F7(x)=>x.visit(visitor),F8(x)=>x.visit(visitor),I1(x)=>x.visit(visitor),I2(x)=>x.visit(visitor),I3(x)=>x.visit(visitor),I4(x)=>x.visit(visitor),I5(x)=>x.visit(visitor),I6(x)=>x.visit(visitor),I7(x)=>x.visit(visitor),I8(x)=>x.visit(visitor),Value::Incompatible(_e)=>(),Value::Multi(v)=>v.iter().for_each(|x|x.visit(visitor))}
 }
 type Record=ConstantRecord;
}
960impl<B:Backend> Serialize for Value<B>{
961 fn serialize<S:Serializer>(&self,serializer:S)->Result<S::Ok,S::Error>{ValueData::from(self.clone()).serialize(serializer)}
962}
963impl<B:Backend> Squeeze for Value<B>{
964 fn squeeze(self,d:i32)->Self{self.squeeze_dim(d)}
965 type Output=Self;
966}
967impl<B:Backend> Stack for Value<B>{
968 fn stack(self,d:i32)->Self{self.unsqueeze_dim(d).cat(d)}
970 type Output=Self;
971}
972impl<B:Backend> Unsqueeze for Value<B>{
973 fn unsqueeze(self,d:i32)->Self{self.unsqueeze_dim(d)}
974 type Output=Self;
975}
impl<B:Backend> Value<B>{
 /// Reduces to a single-element bool tensor that is true iff every element
 /// is true (after `bool()` conversion); `Multi` recurses element-wise.
 pub fn all(self)->Value<B>{
 fn f<B:Backend,const N:usize>(x:Tensor<B,N,Bool>)->Value<B>{x.all().into()}
 match self.bool(){B1(x)=>f(x),B2(x)=>f(x),B3(x)=>f(x),B4(x)=>f(x),B5(x)=>f(x),B6(x)=>f(x),B7(x)=>f(x),B8(x)=>f(x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(Value::all).collect(),_=>panic!("internal error")}
 }
 /// Like `all`, but reduces only along dimension `d` (negative `d` counts
 /// from the back); an out-of-range `d` yields `Incompatible`.
 pub fn all_dim(self,d:i32)->Value<B>{
 fn f<B:Backend,const N:usize>(d:i32,x:Tensor<B,N,Bool>)->Value<B>{
  if d>=N as i32||d<(-(N as i32)){return format!("rank {N} is too low to all along dimension {d}").into()}
  let d=if d<0{N-((-d) as usize)}else{d as usize};
  x.all_dim(d).into()
 }
 match self.bool(){B1(x)=>f(d,x),B2(x)=>f(d,x),B3(x)=>f(d,x),B4(x)=>f(d,x),B5(x)=>f(d,x),B6(x)=>f(d,x),B7(x)=>f(d,x),B8(x)=>f(d,x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|v|v.all_dim(d)).collect(),_=>panic!("internal error")}
 }
 /// Reduces to a single-element bool tensor that is true iff any element
 /// is true (after `bool()` conversion); `Multi` recurses element-wise.
 pub fn any(self)->Value<B>{
 fn f<B:Backend,const N:usize>(x:Tensor<B,N,Bool>)->Value<B>{x.any().into()}
 match self.bool(){B1(x)=>f(x),B2(x)=>f(x),B3(x)=>f(x),B4(x)=>f(x),B5(x)=>f(x),B6(x)=>f(x),B7(x)=>f(x),B8(x)=>f(x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(Value::any).collect(),_=>panic!("internal error")}
 }
 /// Like `any`, but reduces only along dimension `d` (negative `d` counts
 /// from the back); an out-of-range `d` yields `Incompatible`.
 pub fn any_dim(self,d:i32)->Value<B>{
 fn f<B:Backend,const N:usize>(d:i32,x:Tensor<B,N,Bool>)->Value<B>{
  if d>=N as i32||d<(-(N as i32)){return format!("rank {N} is too low to any along dimension {d}").into()}
  let d=if d<0{N-((-d) as usize)}else{d as usize};
  x.any_dim(d).into()
 }
 match self.bool(){B1(x)=>f(d,x),B2(x)=>f(d,x),B3(x)=>f(d,x),B4(x)=>f(d,x),B5(x)=>f(d,x),B6(x)=>f(d,x),B7(x)=>f(d,x),B8(x)=>f(d,x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|v|v.any_dim(d)).collect(),_=>panic!("internal error")}
 }
 /// Converts every tensor to the bool kind (rank preserved); bool tensors
 /// pass through, `Multi` recurses, `Incompatible` propagates.
 pub fn bool(self)->Value<B>{
 match self{B1(x)=>B1(x),B2(x)=>B2(x),B3(x)=>B3(x),B4(x)=>B4(x),B5(x)=>B5(x),B6(x)=>B6(x),B7(x)=>B7(x),B8(x)=>B8(x),F1(x)=>B1(x.bool()),F2(x)=>B2(x.bool()),F3(x)=>B3(x.bool()),F4(x)=>B4(x.bool()),F5(x)=>B5(x.bool()),F6(x)=>B6(x.bool()),F7(x)=>B7(x.bool()),F8(x)=>B8(x.bool()),I1(x)=>B1(x.bool()),I2(x)=>B2(x.bool()),I3(x)=>B3(x.bool()),I4(x)=>B4(x.bool()),I5(x)=>B5(x.bool()),I6(x)=>B6(x.bool()),I7(x)=>B7(x.bool()),I8(x)=>B8(x.bool()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(Value::bool).collect())}
 }
 /// Total number of tensor elements (product of dims) summed across all
 /// nested values; `Incompatible` contributes 0.
 pub fn count(&self)->usize{
 match self{
  B1(x)=>x.dims().iter().product(),
  B2(x)=>x.dims().iter().product(),
  B3(x)=>x.dims().iter().product(),
  B4(x)=>x.dims().iter().product(),
  B5(x)=>x.dims().iter().product(),
  B6(x)=>x.dims().iter().product(),
  B7(x)=>x.dims().iter().product(),
  B8(x)=>x.dims().iter().product(),
  F1(x)=>x.dims().iter().product(),
  F2(x)=>x.dims().iter().product(),
  F3(x)=>x.dims().iter().product(),
  F4(x)=>x.dims().iter().product(),
  F5(x)=>x.dims().iter().product(),
  F6(x)=>x.dims().iter().product(),
  F7(x)=>x.dims().iter().product(),
  F8(x)=>x.dims().iter().product(),
  I1(x)=>x.dims().iter().product(),
  I2(x)=>x.dims().iter().product(),
  I3(x)=>x.dims().iter().product(),
  I4(x)=>x.dims().iter().product(),
  I5(x)=>x.dims().iter().product(),
  I6(x)=>x.dims().iter().product(),
  I7(x)=>x.dims().iter().product(),
  I8(x)=>x.dims().iter().product(),
  Value::Incompatible(_e)=>0,
  Value::Multi(v)=>v.iter().map(Value::count).sum()
 }
 }
1040 pub fn empty()->Self{Self::Multi(Vec::new())}
 /// Flattens arbitrarily nested `Multi` values into a single flat `Multi`
 /// of leaves, preserving left-to-right order.
 pub fn flatten_values(self)->Self{
 // Depth-first accumulation: Multi nodes are recursed into, leaves pushed.
 fn f<B:Backend>(mut acc:Vec<Value<B>>,x:Value<B>)->Vec<Value<B>>{
  if x.is_multi(){acc=x.into_iter().fold(acc,|acc,x|f(acc,x))}else{acc.push(x)}
  acc
 }
 f(Vec::new(),self).into()
 }
 /// Converts every tensor to the float kind (rank preserved); float tensors
 /// pass through, `Multi` recurses, `Incompatible` propagates.
 pub fn float(self)->Value<B>{
 match self{B1(x)=>F1(x.float()),B2(x)=>F2(x.float()),B3(x)=>F3(x.float()),B4(x)=>F4(x.float()),B5(x)=>F5(x.float()),B6(x)=>F6(x.float()),B7(x)=>F7(x.float()),B8(x)=>F8(x.float()),F1(x)=>F1(x),F2(x)=>F2(x),F3(x)=>F3(x),F4(x)=>F4(x),F5(x)=>F5(x),F6(x)=>F6(x),F7(x)=>F7(x),F8(x)=>F8(x),I1(x)=>F1(x.float()),I2(x)=>F2(x.float()),I3(x)=>F3(x.float()),I4(x)=>F4(x.float()),I5(x)=>F5(x.float()),I6(x)=>F6(x.float()),I7(x)=>F7(x.float()),I8(x)=>F8(x.float()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(Value::float).collect())}
 }
 /// Rebuilds a value tree with the structure described by `shape`, drawing
 /// leaves from `inner` in order.
 ///
 /// Tensor-shaped nodes (X1..X8) each consume one value from the iterator
 /// (defaulting to empty if exhausted — NOTE(review): the consumed value's
 /// own shape is not checked against the X* shape here; confirm callers
 /// guarantee it). `Recursive` shapes recurse per child, `Multi` drains the
 /// remaining iterator, `Incompatible` propagates.
 pub fn from_values<I:IntoIterator<Item=Self>,S:Into<Shape>>(inner:I,shape:S)->Self{
 fn f<B:Backend,I:Iterator<Item=Value<B>>,S:Into<Shape>>(inner:&mut I,shape:S)->Value<B>{
  match shape.into(){
   Shape::Incompatible(e)=>e.into(),
   Shape::Multi(_l)=>inner.collect(),
   Shape::Recursive(v)=>v.into_iter().map(|s|f(&mut *inner,s)).collect(),
   X1(_s)=>inner.next().unwrap_or_default(),
   X2(_s)=>inner.next().unwrap_or_default(),
   X3(_s)=>inner.next().unwrap_or_default(),
   X4(_s)=>inner.next().unwrap_or_default(),
   X5(_s)=>inner.next().unwrap_or_default(),
   X6(_s)=>inner.next().unwrap_or_default(),
   X7(_s)=>inner.next().unwrap_or_default(),
   X8(_s)=>inner.next().unwrap_or_default(),
  }
 }
 f(&mut inner.into_iter(),shape)
 }
 /// Gathers elements along `dim` using `indices`, which must be an Int
 /// value of the same rank as `self`; bool data round-trips through int.
 /// Mismatched rank/kind pairings yield `Incompatible`.
 pub fn gather(self,dim:i32,indices:Value<B>)->Self{
 // Bool gather: promote to int, gather, convert back to bool.
 fn b<B:Backend,const N:usize>(d:i32,data:Tensor<B,N,Bool>,indices:Tensor<B,N,Int>)->Value<B>{f(d,data.int(),indices).bool()}
 fn f<B:Backend,K:'static+BasicOps<B>+Numeric<B>+TensorKind<B>,const N:usize>(d:i32,data:Tensor<B,N,K>,indices:Tensor<B,N,Int>)->Value<B>{
  let d=if d<0{N-((-d) as usize)}else{d as usize};
  if d>=N{format!("dim {d} must be less than rank {N}").into()}else{data.gather(d,indices).into()}
 }

 match (self,indices){(B1(x),I1(i))=>b(dim,x,i),(B2(x),I2(i))=>b(dim,x,i),(B3(x),I3(i))=>b(dim,x,i),(B4(x),I4(i))=>b(dim,x,i),(B5(x),I5(i))=>b(dim,x,i),(B6(x),I6(i))=>b(dim,x,i),(B7(x),I7(i))=>b(dim,x,i),(B8(x),I8(i))=>b(dim,x,i),(F1(x),I1(i))=>f(dim,x,i),(F2(x),I2(i))=>f(dim,x,i),(F3(x),I3(i))=>f(dim,x,i),(F4(x),I4(i))=>f(dim,x,i),(F5(x),I5(i))=>f(dim,x,i),(F6(x),I6(i))=>f(dim,x,i),(F7(x),I7(i))=>f(dim,x,i),(F8(x),I8(i))=>f(dim,x,i),(I1(x),I1(i))=>f(dim,x,i),(I2(x),I2(i))=>f(dim,x,i),(I3(x),I3(i))=>f(dim,x,i),(I4(x),I4(i))=>f(dim,x,i),(I5(x),I5(i))=>f(dim,x,i),(I6(x),I6(i))=>f(dim,x,i),(I7(x),I7(i))=>f(dim,x,i),(I8(x),I8(i))=>f(dim,x,i),(Value::Incompatible(e),_)=>e.into(),(_,Value::Incompatible(e))=>e.into(),(Value::Multi(u),Value::Multi(v))=>u.into_iter().zip(v).map(|(u,v)|u.gather(dim,v)).collect(),_=>"gather is only available for tensors of matching dimensions with int indices".into()}
 }
 /// Converts every tensor to the int kind (rank preserved); int tensors
 /// pass through, `Multi` recurses, `Incompatible` propagates.
 pub fn int(self)->Value<B>{
 match self{B1(x)=>I1(x.int()),B2(x)=>I2(x.int()),B3(x)=>I3(x.int()),B4(x)=>I4(x.int()),B5(x)=>I5(x.int()),B6(x)=>I6(x.int()),B7(x)=>I7(x.int()),B8(x)=>I8(x.int()),F1(x)=>I1(x.int()),F2(x)=>I2(x.int()),F3(x)=>I3(x.int()),F4(x)=>I4(x.int()),F5(x)=>I5(x.int()),F6(x)=>I6(x.int()),F7(x)=>I7(x.int()),F8(x)=>I8(x.int()),I1(x)=>I1(x),I2(x)=>I2(x),I3(x)=>I3(x),I4(x)=>I4(x),I5(x)=>I5(x),I6(x)=>I6(x),I7(x)=>I7(x),I8(x)=>I8(x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(Value::int).collect())}
 }
 /// Flattens all data into a single `Vec<f32>` (row-major per tensor,
 /// concatenated across `Multi` elements); `Incompatible` and read failures
 /// silently become empty vectors.
 pub fn into_float_vec(self)->Vec<f32>{
 fn cat_vec<T>(mut a:Vec<T>,b:Vec<T>)->Vec<T>{
  a.extend(b);
  a
 }
 fn to_vec<B:Backend,const N:usize>(x:Tensor<B,N>)->Vec<f32>{x.into_data().to_vec().unwrap_or_default()}

 match self.float(){F1(x)=>to_vec(x),F2(x)=>to_vec(x),F3(x)=>to_vec(x),F4(x)=>to_vec(x),F5(x)=>to_vec(x),F6(x)=>to_vec(x),F7(x)=>to_vec(x),F8(x)=>to_vec(x),Value::Incompatible(_e)=>Vec::new(),Value::Multi(v)=>v.into_iter().map(Value::into_float_vec).reduce(cat_vec).unwrap_or_default(),_=>panic!("internal error")}
 }
 /// Flattens all data into a single `Vec<i32>`; int analogue of
 /// `into_float_vec`, with the same silent-empty failure behavior.
 pub fn into_int_vec(self)->Vec<i32>{
 fn cat_vec<T>(mut a:Vec<T>,b:Vec<T>)->Vec<T>{
  a.extend(b);
  a
 }
 fn to_vec<B:Backend,const N:usize>(x:Tensor<B,N,Int>)->Vec<i32>{x.into_data().to_vec().unwrap_or_default()}

 match self.int(){I1(x)=>to_vec(x),I2(x)=>to_vec(x),I3(x)=>to_vec(x),I4(x)=>to_vec(x),I5(x)=>to_vec(x),I6(x)=>to_vec(x),I7(x)=>to_vec(x),I8(x)=>to_vec(x),Value::Incompatible(_e)=>Vec::new(),Value::Multi(v)=>v.into_iter().map(Value::into_int_vec).reduce(cat_vec).unwrap_or_default(),_=>panic!("internal error")}
 }
1109 pub fn height(&self)->usize{
1111 if let Value::Multi(v)=self{v.iter().map(Value::height).max()}else{None}.unwrap_or(0)
1112 }
1113 pub fn is_empty(&self)->bool{self.len()==0}
1115 pub fn is_incompatible(&self)->bool{
1117 if let Value::Incompatible(_x)=self{true}else{false}
1118 }
1119 pub fn into_multi(self)->Vec<Value<B>>{
1121 if let Value::Multi(v)=self{v}else{vec![self]}
1122 }
1123 pub fn is_multi(&self)->bool{
1125 if let Value::Multi(_v)=self{true}else{false}
1126 }
1127 pub fn iter(&self)->SliceIter<'_,Self>{
1129 if let Value::Multi(v)=self{v.iter()}else{slice::from_ref(self).iter()}
1130 }
1131 pub fn kind(&self)->Kind{
1133 match self{B1(_x)=>Kind::Bool,B2(_x)=>Kind::Bool,B3(_x)=>Kind::Bool,B4(_x)=>Kind::Bool,B5(_x)=>Kind::Bool,B6(_x)=>Kind::Bool,B7(_x)=>Kind::Bool,B8(_x)=>Kind::Bool,F1(_x)=>Kind::Float,F2(_x)=>Kind::Float,F3(_x)=>Kind::Float,F4(_x)=>Kind::Float,F5(_x)=>Kind::Float,F6(_x)=>Kind::Float,F7(_x)=>Kind::Float,F8(_x)=>Kind::Float,I1(_x)=>Kind::Int,I2(_x)=>Kind::Int,I3(_x)=>Kind::Int,I4(_x)=>Kind::Int,I5(_x)=>Kind::Int,I6(_x)=>Kind::Int,I7(_x)=>Kind::Int,I8(_x)=>Kind::Int,Value::Incompatible(_v)=>Kind::Incompatible,Value::Multi(_v)=>Kind::Multi}
1134 }
1135 pub fn len(&self)->usize{
1137 if let Value::Multi(v)=self{v.len()}else{1}
1138 }
1139 pub fn len_recursive(&self)->usize{
1141 if let Value::Multi(v)=self{v.iter().map(Value::len_recursive).sum()}else{1}
1142 }
 /// Applies `f` to every sub-value at nesting level `depth`: 0 targets leaves
 /// and empty values, and a `Multi` sits one level above its tallest child.
 pub fn map_multi<F:FnMut(Value<B>)->Value<B>>(self,depth:usize,mut f:F)->Self{
  // Returns the (possibly mapped) value together with its computed height so
  // parents can decide whether they are at the target depth.
  fn recur<B:Backend,F:FnMut(Value<B>)->Value<B>>(depth:usize,f:&mut F,x:Value<B>)->(Value<B>,usize){
   if x.is_empty(){
    // An empty `Multi` is treated like a leaf: height 0.
    (if depth==0{f(x)}else{x},0)
   }else if x.is_multi(){
    let mut height=0;
    let y=x.into_iter().map(|x|{
     let (y,h)=recur(depth,f,x);
     height=height.max(h);
     y
    }).collect();
    // This node sits one level above its tallest child.
    height+=1;
    (if depth==height{f(y)}else{y},height)
   }else{
    // Tensor / Incompatible leaf.
    (if depth==0{f(x)}else{x},0)
   }
  }
  recur(depth,&mut f,self).0
 }
1163 pub fn map_values<F:FnMut(Value<B>)->Value<B>>(self,mut f:F)->Self{
1165 fn recur<B:Backend,F:FnMut(Value<B>)->Value<B>>(f:&mut F,x:Value<B>)->Value<B>{
1166 if let Value::Multi(v)=x{v.into_iter().map(|x|recur(f,x)).collect()}else{f(x)}
1167 }
1168 recur(&mut f,self)
1169 }
 /// Replaces the elements of `self` selected by `mask` with the scalar `v`.
 ///
 /// `mask` is converted to bool and rank-promoted against `self` first. Bool
 /// values round-trip through int so the filled result stays bool; `Multi`
 /// values broadcast element-wise via `broadcast_multi`; `Incompatible`
 /// propagates its message.
 pub fn mask_fill(self,mask:Value<B>,v:f32)->Self{
  let (x,mask)=self.promote_rank(mask.bool());
  match (x,mask){
   (B1(x),B1(m))=>B1(x.int().mask_fill(m,v).bool()),
   (B2(x),B2(m))=>B2(x.int().mask_fill(m,v).bool()),
   (B3(x),B3(m))=>B3(x.int().mask_fill(m,v).bool()),
   (B4(x),B4(m))=>B4(x.int().mask_fill(m,v).bool()),
   (B5(x),B5(m))=>B5(x.int().mask_fill(m,v).bool()),
   (B6(x),B6(m))=>B6(x.int().mask_fill(m,v).bool()),
   (B7(x),B7(m))=>B7(x.int().mask_fill(m,v).bool()),
   (B8(x),B8(m))=>B8(x.int().mask_fill(m,v).bool()),
   (F1(x),B1(m))=>F1(x.mask_fill(m,v)),
   (F2(x),B2(m))=>F2(x.mask_fill(m,v)),
   (F3(x),B3(m))=>F3(x.mask_fill(m,v)),
   (F4(x),B4(m))=>F4(x.mask_fill(m,v)),
   (F5(x),B5(m))=>F5(x.mask_fill(m,v)),
   (F6(x),B6(m))=>F6(x.mask_fill(m,v)),
   (F7(x),B7(m))=>F7(x.mask_fill(m,v)),
   (F8(x),B8(m))=>F8(x.mask_fill(m,v)),
   (I1(x),B1(m))=>I1(x.mask_fill(m,v)),
   (I2(x),B2(m))=>I2(x.mask_fill(m,v)),
   (I3(x),B3(m))=>I3(x.mask_fill(m,v)),
   (I4(x),B4(m))=>I4(x.mask_fill(m,v)),
   (I5(x),B5(m))=>I5(x.mask_fill(m,v)),
   (I6(x),B6(m))=>I6(x.mask_fill(m,v)),
   (I7(x),B7(m))=>I7(x.mask_fill(m,v)),
   (I8(x),B8(m))=>I8(x.mask_fill(m,v)),
   (Value::Incompatible(e),_)=>e.into(),
   (_,Value::Incompatible(e))=>e.into(),
   (Value::Multi(x),m)=>broadcast_multi(x,m.into_multi(),|x,m|x.mask_fill(m,v)),
   (x,Value::Multi(m))=>broadcast_multi(x.into_multi(),m,|x,m|x.mask_fill(m,v)),
   // After bool conversion and rank promotion no other pairing is reachable.
   _=>panic!("internal error")
  }
 }
1205 pub fn multi(self)->Self{
1207 if let Value::Multi(v)=self{v.into()}else{vec![self].into()}
1208 }
1209 pub fn new<S:Into<Shape>>(data:&[f32],device:&B::Device,shape:S)->Self{
1211 match shape.into(){
1212 Shape::Incompatible(e)=>e.into(),
1213 Shape::Multi(l)=>data.chunks(l).map(|d|Value::new(d,device,X1([d.len()]))).collect(),
1214 Shape::Recursive(v)=>v.into_iter().scan(data,|data,s|{
1215 let v=Value::new(*data,device,s);
1216 *data=&data[..v.count()];
1217 Some(v)
1218 }).collect(),
1219 X1(s)=>F1(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1220 X2(s)=>F2(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1221 X3(s)=>F3(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1222 X4(s)=>F4(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1223 X5(s)=>F5(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1224 X6(s)=>F6(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1225 X7(s)=>F7(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1226 X8(s)=>F8(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1227 }
1228 }
1229 pub fn promote(self,rhs:Value<B>)->(Value<B>,Value<B>){
1231 let (l,r)=self.promote_kind(rhs);
1232 l.promote_rank(r)
1233 }
1234 pub fn promote_kind(self,rhs:Value<B>)->(Value<B>,Value<B>){
1236 let (lk,rk)=(self.kind(),rhs.kind());
1237
1238 let (mut l,mut r)=(self,rhs);
1239 if lk==rk{()}else if lk==Kind::Multi{r=r.multi()}else if rk==Kind::Multi{l=l.multi()}else if lk==Kind::Float{r=r.float()}else if rk==Kind::Float{l=l.float()}else if lk==Kind::Int{r=r.int()}else if rk==Kind::Int{l=l.int()}else if lk==Kind::Incompatible{return (l,r)}else if rk==Kind::Incompatible{return (l,r)}
1240 (l,r)
1241 }
1242 pub fn promote_rank(self,rhs:Value<B>)->(Value<B>,Value<B>){
1244 let (mut l,mut r)=(self,rhs);
1245 let (mut lr,mut rr)=if let (Some(l),Some(r))=(l.rank(),r.rank()){(l,r)}else{return (l,r)};
1246 while lr<rr{
1247 l=l.unsqueeze();
1248 lr+=1;
1249 }
1250 while lr>rr{
1251 r=r.unsqueeze();
1252 rr+=1;
1253 }
1254 (l,r)
1255 }
1256 pub fn rank(&self)->Option<usize>{
1258 match self{B1(_x)=>Some(1),B2(_x)=>Some(2),B3(_x)=>Some(3),B4(_x)=>Some(4),B5(_x)=>Some(5),B6(_x)=>Some(6),B7(_x)=>Some(7),B8(_x)=>Some(8),F1(_x)=>Some(1),F2(_x)=>Some(2),F3(_x)=>Some(3),F4(_x)=>Some(4),F5(_x)=>Some(5),F6(_x)=>Some(6),F7(_x)=>Some(7),F8(_x)=>Some(8),I1(_x)=>Some(1),I2(_x)=>Some(2),I3(_x)=>Some(3),I4(_x)=>Some(4),I5(_x)=>Some(5),I6(_x)=>Some(6),I7(_x)=>Some(7),I8(_x)=>Some(8),Value::Incompatible(_x)=>None,Value::Multi(_x)=>None}
1259 }
 /// Scatters `values` into `self` at `indices` along `dim` (negative dims count
 /// from the back). All three operands must have matching ranks and `indices`
 /// must be int; bool operands are scattered as int, so the result stays int.
 /// `Multi`s must align element-wise; mismatches yield `Incompatible`.
 pub fn scatter(self,dim:i32,indices:Value<B>,values:Value<B>)->Self{
  // Normalizes the dimension and delegates to the tensor scatter; an
  // out-of-range dim becomes an `Incompatible` message instead of a panic.
  fn f<B:Backend,K:'static+BasicOps<B>+Numeric<B>+TensorKind<B>,const N:usize>(d:i32,data:Tensor<B,N,K>,indices:Tensor<B,N,Int>,values:Tensor<B,N,K>)->Value<B>{
   let d=if d<0{N-((-d) as usize)}else{d as usize};
   if d>=N{format!("dim {d} must be less than rank {N}").into()}else{data.scatter(d,indices,values).into()}
  }

  match (self,indices,values){
   (B1(x),I1(i),B1(y))=>f(dim,x.int(),i,y.int()),
   (B2(x),I2(i),B2(y))=>f(dim,x.int(),i,y.int()),
   (B3(x),I3(i),B3(y))=>f(dim,x.int(),i,y.int()),
   (B4(x),I4(i),B4(y))=>f(dim,x.int(),i,y.int()),
   (B5(x),I5(i),B5(y))=>f(dim,x.int(),i,y.int()),
   (B6(x),I6(i),B6(y))=>f(dim,x.int(),i,y.int()),
   (B7(x),I7(i),B7(y))=>f(dim,x.int(),i,y.int()),
   (B8(x),I8(i),B8(y))=>f(dim,x.int(),i,y.int()),
   (F1(x),I1(i),F1(y))=>f(dim,x,i,y),
   (F2(x),I2(i),F2(y))=>f(dim,x,i,y),
   (F3(x),I3(i),F3(y))=>f(dim,x,i,y),
   (F4(x),I4(i),F4(y))=>f(dim,x,i,y),
   (F5(x),I5(i),F5(y))=>f(dim,x,i,y),
   (F6(x),I6(i),F6(y))=>f(dim,x,i,y),
   (F7(x),I7(i),F7(y))=>f(dim,x,i,y),
   (F8(x),I8(i),F8(y))=>f(dim,x,i,y),
   (I1(x),I1(i),I1(y))=>f(dim,x,i,y),
   (I2(x),I2(i),I2(y))=>f(dim,x,i,y),
   (I3(x),I3(i),I3(y))=>f(dim,x,i,y),
   (I4(x),I4(i),I4(y))=>f(dim,x,i,y),
   (I5(x),I5(i),I5(y))=>f(dim,x,i,y),
   (I6(x),I6(i),I6(y))=>f(dim,x,i,y),
   (I7(x),I7(i),I7(y))=>f(dim,x,i,y),
   (I8(x),I8(i),I8(y))=>f(dim,x,i,y),
   (Value::Incompatible(e),_,_)=>e.into(),
   (_,Value::Incompatible(e),_)=>e.into(),
   (_,_,Value::Incompatible(e))=>e.into(),
   (Value::Multi(u),Value::Multi(v),Value::Multi(y))=>u.into_iter().zip(v).zip(y).map(|((u,v),y)|u.scatter(dim,v,y)).collect(),
   _=>"scatter is only available for tensors of matching dimensions with int indices".into()
  }
 }
 /// Shape of this value: the tensor dims for any tensor variant, the child
 /// count for a `Multi` (not the children's shapes — see `shape_recursive`),
 /// or a clone of the `Incompatible` message.
 pub fn shape(&self)->Shape{
  match self{B1(x)=>Shape::X1(x.dims()),B2(x)=>Shape::X2(x.dims()),B3(x)=>Shape::X3(x.dims()),B4(x)=>Shape::X4(x.dims()),B5(x)=>Shape::X5(x.dims()),B6(x)=>Shape::X6(x.dims()),B7(x)=>Shape::X7(x.dims()),B8(x)=>Shape::X8(x.dims()),F1(x)=>Shape::X1(x.dims()),F2(x)=>Shape::X2(x.dims()),F3(x)=>Shape::X3(x.dims()),F4(x)=>Shape::X4(x.dims()),F5(x)=>Shape::X5(x.dims()),F6(x)=>Shape::X6(x.dims()),F7(x)=>Shape::X7(x.dims()),F8(x)=>Shape::X8(x.dims()),I1(x)=>Shape::X1(x.dims()),I2(x)=>Shape::X2(x.dims()),I3(x)=>Shape::X3(x.dims()),I4(x)=>Shape::X4(x.dims()),I5(x)=>Shape::X5(x.dims()),I6(x)=>Shape::X6(x.dims()),I7(x)=>Shape::X7(x.dims()),I8(x)=>Shape::X8(x.dims()),Value::Incompatible(x)=>Shape::Incompatible(x.clone()),Value::Multi(x)=>Shape::Multi(x.len())}
 }
1304 pub fn shape_recursive(&self)->Shape{
1306 if let Value::Multi(x)=self{Shape::Recursive(x.iter().map(Value::shape_recursive).collect())}else{self.shape()}
1307 }
 /// Shifts the tensor along dimension `d` by `n` positions (sign gives the
 /// direction), filling the vacated slots with `v`. A shift of 0 is a no-op;
 /// shifting a dimension completely away returns a tensor filled with `v`.
 /// Bool tensors shift via int with the pad clamped to 0/1.
 pub fn shift(self,d:i32,n:i32,v:f32)->Self{
  // Bool wrapper: shift as int, then convert back.
  fn b<B:Backend,const N:usize>(d:i32,n:i32,v:f32,x:Tensor<B,N,Bool>)->Value<B>{f(d,n,if v==0.0{0.0}else{1.0},x.int()).bool()}
  // Slices off `|n|` elements from one end and concatenates a constant pad on
  // the other, producing the shifted tensor.
  fn f<B:Backend,K:'static+BasicOps<B>+Numeric<B>+TensorKind<B>,const N:usize>(d:i32,n:i32,v:f32,x:Tensor<B,N,K>)->Value<B>{
   let device=x.device();
   let d=if d<0{N-((-d) as usize)}else{d as usize};
   let mut paddims=x.dims();
   let mut slicedims=paddims.map(|n|0..n);

   paddims[d]=n.abs() as usize;
   // Keep the part of the tensor that survives the shift.
   slicedims[d]=if n<0{(-n) as usize..slicedims[d].end}else{0..slicedims[d].end.saturating_sub(n as usize)};
   // Nothing survives: the result is all pad.
   if slicedims[d].len()==0{return x.full_like(v).into()}
   let pad:Tensor<B,N,K>=Tensor::full(paddims,v,&device);
   let slice=x.slice(slicedims);

   // Pad goes on the side the data moved away from.
   Tensor::cat(if n<0{vec![slice,pad]}else{vec![pad,slice]},d).into()
  }
  if n==0{return self}

  match self{B1(x)=>b(d,n,v,x),B2(x)=>b(d,n,v,x),B3(x)=>b(d,n,v,x),B4(x)=>b(d,n,v,x),B5(x)=>b(d,n,v,x),B6(x)=>b(d,n,v,x),B7(x)=>b(d,n,v,x),B8(x)=>b(d,n,v,x),F1(x)=>f(d,n,v,x),F2(x)=>f(d,n,v,x),F3(x)=>f(d,n,v,x),F4(x)=>f(d,n,v,x),F5(x)=>f(d,n,v,x),F6(x)=>f(d,n,v,x),F7(x)=>f(d,n,v,x),F8(x)=>f(d,n,v,x),I1(x)=>f(d,n,v,x),I2(x)=>f(d,n,v,x),I3(x)=>f(d,n,v,x),I4(x)=>f(d,n,v,x),I5(x)=>f(d,n,v,x),I6(x)=>f(d,n,v,x),I7(x)=>f(d,n,v,x),I8(x)=>f(d,n,v,x),Value::Incompatible(e)=>e.into(),Value::Multi(x)=>x.into_iter().map(|x|x.shift(d,n,v)).collect()}
 }
1329 pub fn slice<A:AsRef<[R]>,R:RangeBounds<usize>>(self,ranges:A)->Self{let ranges=ranges.as_ref();
1332 let len=ranges.len();
1333 if let Value::Incompatible(x)=self{return x.into()}
1334 let rank=self.rank().unwrap_or(len);
1335 let shape=self.shape();
1336
1337 let mut normalizedranges=[0;8].map(|_|0..0);
1338 for ((d,n),r) in shape.clone().to_array(Alignment::Left).into_iter().zip(normalizedranges.iter_mut()).zip(ranges){
1339 n.start=match r.start_bound(){Excluded(&x)=>x+1,Included(&x)=>x,Unbounded=>0};
1340 n.end=match r.end_bound(){Excluded(&x)=>x,Included(&x)=>x+1,Unbounded=>d};
1341 }
1342 if len>rank{return format!("Length of ranges argument must be less than the the value's rank. len: {len} ranges: {normalizedranges:?} rank: {rank} shape: {shape:?}").into()}
1343 for (d,n) in shape.clone().to_array(Alignment::Left).into_iter().zip(normalizedranges.iter()).take(len){
1344 if n.start>=n.end{return format!("Empty or reverse ranges are currently not supported. ranges: {normalizedranges:?}").into()}
1345 if d<n.end{return format!("Cannot index beyond the end of a dimension. ranges: {normalizedranges:?} shape: {shape:?}").into()}
1346 }
1347 let ranges=&normalizedranges[..len];
1348
1349 match self{B1(x)=>B1(slice_slice(ranges,x)),B2(x)=>B2(slice_slice(ranges,x)),B3(x)=>B3(slice_slice(ranges,x)),B4(x)=>B4(slice_slice(ranges,x)),B5(x)=>B5(slice_slice(ranges,x)),B6(x)=>B6(slice_slice(ranges,x)),B7(x)=>B7(slice_slice(ranges,x)),B8(x)=>B8(slice_slice(ranges,x)),F1(x)=>F1(slice_slice(ranges,x)),F2(x)=>F2(slice_slice(ranges,x)),F3(x)=>F3(slice_slice(ranges,x)),F4(x)=>F4(slice_slice(ranges,x)),F5(x)=>F5(slice_slice(ranges,x)),F6(x)=>F6(slice_slice(ranges,x)),F7(x)=>F7(slice_slice(ranges,x)),F8(x)=>F8(slice_slice(ranges,x)),I1(x)=>I1(slice_slice(ranges,x)),I2(x)=>I2(slice_slice(ranges,x)),I3(x)=>I3(slice_slice(ranges,x)),I4(x)=>I4(slice_slice(ranges,x)),I5(x)=>I5(slice_slice(ranges,x)),I6(x)=>I6(slice_slice(ranges,x)),I7(x)=>I7(slice_slice(ranges,x)),I8(x)=>I8(slice_slice(ranges,x)),Value::Incompatible(x)=>x.into(),Value::Multi(x)=>Value::Multi(x.into_iter().map(|x|x.slice(ranges)).collect())}
1350 }
 /// Splits the value into chunks of `chunksize`.
 ///
 /// With a dimension (negative counts from the back) every tensor is split
 /// along it, producing a `Multi` of the pieces. With `None`, the children of
 /// a `Multi` are grouped into sub-`Multi`s of `chunksize` values each.
 /// A `chunksize` of 0 yields `Incompatible`.
 pub fn split<I:Into<Option<i32>>>(self,chunksize:usize,dim:I)->Self{
  // Splits one tensor along `dim`, rejecting out-of-range dims.
  fn f<B:Backend,K:'static+BasicOps<B>+TensorKind<B>,const N:usize>(dim:i32,size:usize,tensor:Tensor<B,N,K>)->Value<B>{
   if dim>=N as i32||dim<(-(N as i32)){return format!("rank {N} is too low to split along dimension {dim}").into()}
   let dim=if dim<0{N-((-dim) as usize)}else{dim as usize};

   tensor.split(size,dim).into_iter().map(Value::from).collect()
  }
  let c=if chunksize==0{return "cannot split into chunks of 0 size".into()}else{chunksize};

  if let Some(d)=dim.into(){
   match self{B1(x)=>f(d,c,x),B2(x)=>f(d,c,x),B3(x)=>f(d,c,x),B4(x)=>f(d,c,x),B5(x)=>f(d,c,x),B6(x)=>f(d,c,x),B7(x)=>f(d,c,x),B8(x)=>f(d,c,x),F1(x)=>f(d,c,x),F2(x)=>f(d,c,x),F3(x)=>f(d,c,x),F4(x)=>f(d,c,x),F5(x)=>f(d,c,x),F6(x)=>f(d,c,x),F7(x)=>f(d,c,x),F8(x)=>f(d,c,x),I1(x)=>f(d,c,x),I2(x)=>f(d,c,x),I3(x)=>f(d,c,x),I4(x)=>f(d,c,x),I5(x)=>f(d,c,x),I6(x)=>f(d,c,x),I7(x)=>f(d,c,x),I8(x)=>f(d,c,x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.split(c,Some(d))).collect()}
  }else{
   // No dimension: chunk the list of children instead of the tensors.
   let v=self.into_multi();
   v.chunks(chunksize).map(|c|Value::from(c.to_vec())).collect()
  }
 }
 /// Removes the leading dimension; shorthand for `squeeze_dim(0)`.
 pub fn squeeze(self)->Self{self.squeeze_dim(0)}
 /// Removes dimension `d` (negative counts from the back), which must have
 /// size 1. Rank-1 tensors, out-of-range dims and non-unit dims all yield
 /// `Incompatible`; a `Multi` squeezes every child.
 pub fn squeeze_dim(self,d:i32)->Self{
  // Validates the dim then drops it, reducing the tensor rank by one.
  fn f<B:Backend,K:BasicOps<B>+TensorKind<B>,const D:usize,const N:usize>(x:Tensor<B,D,K>,d:i32)->Result<Tensor<B,N,K>,String>{
   let d=if d<0{D-((-d) as usize)}else{d as usize};

   if d>=D{return Err(format!("dim {d} must be less than {D}"))}
   let xdim=x.dims()[d];

   if xdim==1{Ok(x.squeeze_dim(d))}else{Err(format!("cannot squeeze a dim of size not equal to 1. dim {d} was {xdim}"))}
  }
  // Inner match produces Result<Value,String>; the outer match folds the error
  // back into an `Incompatible` value.
  match match self{B1(_x)=>Err("currently cannot decrease the number of tensor dimensions below 1".into()),B2(x)=>f(x,d).map(B1),B3(x)=>f(x,d).map(B2),B4(x)=>f(x,d).map(B3),B5(x)=>f(x,d).map(B4),B6(x)=>f(x,d).map(B5),B7(x)=>f(x,d).map(B6),B8(x)=>f(x,d).map(B7),F1(_x)=>Err("currently cannot decrease the number of tensor dimensions below 1".into()),F2(x)=>f(x,d).map(F1),F3(x)=>f(x,d).map(F2),F4(x)=>f(x,d).map(F3),F5(x)=>f(x,d).map(F4),F6(x)=>f(x,d).map(F5),F7(x)=>f(x,d).map(F6),F8(x)=>f(x,d).map(F7),I1(_x)=>Err("currently cannot decrease the number of tensor dimensions below 1".into()),I2(x)=>f(x,d).map(I1),I3(x)=>f(x,d).map(I2),I4(x)=>f(x,d).map(I3),I5(x)=>f(x,d).map(I4),I6(x)=>f(x,d).map(I5),I7(x)=>f(x,d).map(I6),I8(x)=>f(x,d).map(I7),Value::Incompatible(e)=>Err(e),Value::Multi(v)=>Ok(v.into_iter().map(|x|x.squeeze_dim(d)).collect())}{Err(e)=>e.into(),Ok(x)=>x}
 }
1382 pub fn try_incompatible(self)->Result<String,Self>{
1384 if let Value::Incompatible(x)=self{Ok(x)}else{Err(self)}
1385 }
1386 pub fn try_multi(self)->Result<Vec<Value<B>>,Self>{
1388 if let Value::Multi(v)=self{Ok(v)}else{Err(self)}
1389 }
 /// Inserts a leading size-1 dimension; shorthand for `unsqueeze_dim(0)`.
 pub fn unsqueeze(self)->Self{self.unsqueeze_dim(0)}
 /// Inserts a size-1 dimension at `d`. Negative dims count from the back and
 /// are offset by one so `-1` appends after the last dimension. Rank 8 is the
 /// maximum; out-of-range dims yield `Incompatible`; a `Multi` unsqueezes
 /// every child.
 pub fn unsqueeze_dim(self,d:i32)->Self{
  // Performs the actual insert, increasing the tensor rank by one.
  fn f<B:Backend,K:BasicOps<B>+TensorKind<B>,const D:usize,const N:usize>(x:Tensor<B,D,K>,d:i32)->Tensor<B,N,K>{
   x.unsqueeze_dim(if d<0{D-((-d) as usize)+1}else{d as usize})
  }
  // Pre-validate the (normalized) dim whenever the value has a definite rank.
  if let Some(r)=self.rank(){
   let e=if d<0{r-((-d) as usize)+1}else{d as usize};
   if e>r{return format!("dim {e} must be less than or equal to rank {r}").into()}
  }
  match self{B1(x)=>B2(f(x,d)),B2(x)=>B3(f(x,d)),B3(x)=>B4(f(x,d)),B4(x)=>B5(f(x,d)),B5(x)=>B6(f(x,d)),B6(x)=>B7(f(x,d)),B7(x)=>B8(f(x,d)),B8(_x)=>"currently can't increase number of tensor dimensions above 8".into(),F1(x)=>F2(f(x,d)),F2(x)=>F3(f(x,d)),F3(x)=>F4(f(x,d)),F4(x)=>F5(f(x,d)),F5(x)=>F6(f(x,d)),F6(x)=>F7(f(x,d)),F7(x)=>F8(f(x,d)),F8(_x)=>"currently can't increase number of tensor dimensions above 8".into(),I1(x)=>I2(f(x,d)),I2(x)=>I3(f(x,d)),I3(x)=>I4(f(x,d)),I4(x)=>I5(f(x,d)),I5(x)=>I6(f(x,d)),I6(x)=>I7(f(x,d)),I7(x)=>I8(f(x,d)),I8(_x)=>"currently can't increase number of tensor dimensions above 8".into(),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.unsqueeze_dim(d)).collect()}
 }
 /// Returns the `Incompatible` message, panicking (at the caller's location)
 /// for any other variant.
 #[track_caller]
 pub fn unwrap_incompatible(self)->String{self.try_incompatible().unwrap()}
 /// Returns the children of a `Multi`, panicking (at the caller's location)
 /// for any other variant.
 #[track_caller]
 pub fn unwrap_multi(self)->Vec<Value<B>>{self.try_multi().unwrap()}
 /// Zero-filled value with the same shape and kind as `self`; bool tensors go
 /// through int (zeros → all-false), `Incompatible` propagates and `Multi`
 /// maps over its children.
 pub fn zeros_like(&self)->Value<B>{match self{B1(x)=>B1(x.clone().int().zeros_like().bool()),B2(x)=>B2(x.clone().int().zeros_like().bool()),B3(x)=>B3(x.clone().int().zeros_like().bool()),B4(x)=>B4(x.clone().int().zeros_like().bool()),B5(x)=>B5(x.clone().int().zeros_like().bool()),B6(x)=>B6(x.clone().int().zeros_like().bool()),B7(x)=>B7(x.clone().int().zeros_like().bool()),B8(x)=>B8(x.clone().int().zeros_like().bool()),F1(x)=>F1(x.zeros_like()),F2(x)=>F2(x.zeros_like()),F3(x)=>F3(x.zeros_like()),F4(x)=>F4(x.zeros_like()),F5(x)=>F5(x.zeros_like()),F6(x)=>F6(x.zeros_like()),F7(x)=>F7(x.zeros_like()),F8(x)=>F8(x.zeros_like()),I1(x)=>I1(x.zeros_like()),I2(x)=>I2(x.zeros_like()),I3(x)=>I3(x.zeros_like()),I4(x)=>I4(x.zeros_like()),I5(x)=>I5(x.zeros_like()),I6(x)=>I6(x.zeros_like()),I7(x)=>I7(x.zeros_like()),I8(x)=>I8(x.zeros_like()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.iter().map(Value::zeros_like).collect()}
 }
 /// Transposes a `Multi` of `Multi`s: element i of the result gathers the i-th
 /// child of every input, recursing while all children are still `Multi`.
 /// Inputs shorter than the longest sibling are padded with `Value::default()`.
 /// Non-`Multi` values are returned unchanged.
 pub fn zip(self)->Self{fn build<B:Backend>(input:Vec<Value<B>>)->Value<B>{
  // Recursion stops when the list is empty or any element is a leaf.
  if input.len()==0{return input.into()}
  if !input.iter().all(Value::is_multi){return input.into()}

  let len=input.iter().map(Value::len).max().unwrap();
  let mut iters:Vec<_>=input.into_iter().map(Value::into_iter).collect();
  // Pull one child from every iterator per output slot, defaulting when a
  // shorter input runs out.
  (0..len).map(|_|build(iters.iter_mut().map(|v|v.next().unwrap_or_default()).collect())).collect()
 }
 match self{Value::Multi(v)=>build(v),x=>x}
 }
 // Typed accessors for every tensor variant: `try_*` returns `Result` on
 // mismatch, `unwrap_*` panics at the caller (see the `try_unwrap!` macro).
 try_unwrap!(Tensor<B,1,Bool>,try_b1,unwrap_b1);
 try_unwrap!(Tensor<B,2,Bool>,try_b2,unwrap_b2);
 try_unwrap!(Tensor<B,3,Bool>,try_b3,unwrap_b3);
 try_unwrap!(Tensor<B,4,Bool>,try_b4,unwrap_b4);
 try_unwrap!(Tensor<B,5,Bool>,try_b5,unwrap_b5);
 try_unwrap!(Tensor<B,6,Bool>,try_b6,unwrap_b6);
 try_unwrap!(Tensor<B,7,Bool>,try_b7,unwrap_b7);
 try_unwrap!(Tensor<B,8,Bool>,try_b8,unwrap_b8);
 try_unwrap!(Tensor<B,1,Float>,try_f1,unwrap_f1);
 try_unwrap!(Tensor<B,2,Float>,try_f2,unwrap_f2);
 try_unwrap!(Tensor<B,3,Float>,try_f3,unwrap_f3);
 try_unwrap!(Tensor<B,4,Float>,try_f4,unwrap_f4);
 try_unwrap!(Tensor<B,5,Float>,try_f5,unwrap_f5);
 try_unwrap!(Tensor<B,6,Float>,try_f6,unwrap_f6);
 try_unwrap!(Tensor<B,7,Float>,try_f7,unwrap_f7);
 try_unwrap!(Tensor<B,8,Float>,try_f8,unwrap_f8);
 try_unwrap!(Tensor<B,1,Int>,try_i1,unwrap_i1);
 try_unwrap!(Tensor<B,2,Int>,try_i2,unwrap_i2);
 try_unwrap!(Tensor<B,3,Int>,try_i3,unwrap_i3);
 try_unwrap!(Tensor<B,4,Int>,try_i4,unwrap_i4);
 try_unwrap!(Tensor<B,5,Int>,try_i5,unwrap_i5);
 try_unwrap!(Tensor<B,6,Int>,try_i6,unwrap_i6);
 try_unwrap!(Tensor<B,7,Int>,try_i7,unwrap_i7);
 try_unwrap!(Tensor<B,8,Int>,try_i8,unwrap_i8);
1449}
/// Implements the numeric binary operator `$trait` (and its scalar form) for
/// every `Value`/`&Value` combination. Bool operands are promoted to int,
/// `Multi`s broadcast element-wise via `broadcast_multi`, and `Incompatible`
/// propagates its message.
macro_rules! bicop_num{
 ($trait:ident,$traitfn:ident,$traitscalar:ident)=>(
  impl<B:Backend,E:Copy+ElementConversion> $trait<E> for &Value<B>{
   fn $traitfn(self,rhs:E)->Value<B>{self.clone().$traitfn(rhs)}
   type Output=Value<B>;
  }
  impl<B:Backend,E:Copy+ElementConversion> $trait<E> for Value<B>{
   fn $traitfn(self,rhs:E)->Value<B>{
    match self{B1(x)=>I1(x.int().$traitscalar(rhs)),B2(x)=>I2(x.int().$traitscalar(rhs)),B3(x)=>I3(x.int().$traitscalar(rhs)),B4(x)=>I4(x.int().$traitscalar(rhs)),B5(x)=>I5(x.int().$traitscalar(rhs)),B6(x)=>I6(x.int().$traitscalar(rhs)),B7(x)=>I7(x.int().$traitscalar(rhs)),B8(x)=>I8(x.int().$traitscalar(rhs)),F1(x)=>F1(x.$traitscalar(rhs)),F2(x)=>F2(x.$traitscalar(rhs)),F3(x)=>F3(x.$traitscalar(rhs)),F4(x)=>F4(x.$traitscalar(rhs)),F5(x)=>F5(x.$traitscalar(rhs)),F6(x)=>F6(x.$traitscalar(rhs)),F7(x)=>F7(x.$traitscalar(rhs)),F8(x)=>F8(x.$traitscalar(rhs)),I1(x)=>I1(x.$traitscalar(rhs)),I2(x)=>I2(x.$traitscalar(rhs)),I3(x)=>I3(x.$traitscalar(rhs)),I4(x)=>I4(x.$traitscalar(rhs)),I5(x)=>I5(x.$traitscalar(rhs)),I6(x)=>I6(x.$traitscalar(rhs)),I7(x)=>I7(x.$traitscalar(rhs)),I8(x)=>I8(x.$traitscalar(rhs)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.$traitfn(rhs)).collect()}
   }
   type Output=Value<B>;
  }
  impl<B:Backend> $trait<&Value<B>> for &Value<B>{
   fn $traitfn(self,rhs:&Value<B>)->Value<B>{self.clone().$traitfn(rhs.clone())}
   type Output=Value<B>;
  }
  impl<B:Backend> $trait<&Value<B>> for Value<B>{
   fn $traitfn(self,rhs:&Value<B>)->Value<B>{self.$traitfn(rhs.clone())}
   type Output=Value<B>;
  }
  impl<B:Backend> $trait<Value<B>> for &Value<B>{
   fn $traitfn(self,rhs:Value<B>)->Value<B>{self.clone().$traitfn(rhs)}
   type Output=Value<B>;
  }
  impl<B:Backend> $trait<Value<B>> for Value<B>{
   // Metavariables are NOT substituted inside string literals, so the old
   // `panic!("… for $traitfn")` printed the literal text "$traitfn";
   // `stringify!` produces the actual operator name.
   fn $traitfn(self,rhs:Value<B>)->Value<B>{match self.promote(rhs){(B1(l),B1(r))=>I1(l.int().$traitfn(r.int())),(B2(l),B2(r))=>I2(l.int().$traitfn(r.int())),(B3(l),B3(r))=>I3(l.int().$traitfn(r.int())),(B4(l),B4(r))=>I4(l.int().$traitfn(r.int())),(B5(l),B5(r))=>I5(l.int().$traitfn(r.int())),(B6(l),B6(r))=>I6(l.int().$traitfn(r.int())),(B7(l),B7(r))=>I7(l.int().$traitfn(r.int())),(B8(l),B8(r))=>I8(l.int().$traitfn(r.int())),(F1(l),F1(r))=>F1(l.$traitfn(r)),(F2(l),F2(r))=>F2(l.$traitfn(r)),(F3(l),F3(r))=>F3(l.$traitfn(r)),(F4(l),F4(r))=>F4(l.$traitfn(r)),(F5(l),F5(r))=>F5(l.$traitfn(r)),(F6(l),F6(r))=>F6(l.$traitfn(r)),(F7(l),F7(r))=>F7(l.$traitfn(r)),(F8(l),F8(r))=>F8(l.$traitfn(r)),(I1(l),I1(r))=>I1(l.$traitfn(r)),(I2(l),I2(r))=>I2(l.$traitfn(r)),(I3(l),I3(r))=>I3(l.$traitfn(r)),(I4(l),I4(r))=>I4(l.$traitfn(r)),(I5(l),I5(r))=>I5(l.$traitfn(r)),(I6(l),I6(r))=>I6(l.$traitfn(r)),(I7(l),I7(r))=>I7(l.$traitfn(r)),(I8(l),I8(r))=>I8(l.$traitfn(r)),(Value::Incompatible(e),_)=>e.into(),(_,Value::Incompatible(e))=>e.into(),(Value::Multi(l),r)=>broadcast_multi(l,r.into_multi(),$trait::$traitfn),(l,Value::Multi(r))=>broadcast_multi(l.into_multi(),r,$trait::$traitfn),_=>panic!("couldn't promote types for {}",stringify!($traitfn))}
   }
   type Output=Value<B>;
  }
 );
}
/// Generates a fallible accessor (`Result`) and a panicking accessor for one
/// concrete tensor type, both delegating to the value's `TryInto` impl.
macro_rules! try_unwrap{
 ($tensor_ty:ty,$try_fn:ident,$panic_fn:ident)=>{
  /// Converts into the concrete tensor type, returning `self` on mismatch.
  pub fn $try_fn(self)->Result<$tensor_ty,Self>{self.try_into()}
  /// Converts into the concrete tensor type, panicking at the caller on mismatch.
  #[track_caller]
  pub fn $panic_fn(self)->$tensor_ty{self.try_into().unwrap()}
 }
}
/// A rank- and kind-erased burn tensor: bool (`B1..B8`), float (`F1..F8`) or
/// int (`I1..I8`) tensors of rank 1..=8, an `Incompatible` error message, or a
/// nested `Multi` collection of values.
#[derive(Clone,Debug)]
pub enum Value<B:Backend>{B1(Tensor<B,1,Bool>),B2(Tensor<B,2,Bool>),B3(Tensor<B,3,Bool>),B4(Tensor<B,4,Bool>),B5(Tensor<B,5,Bool>),B6(Tensor<B,6,Bool>),B7(Tensor<B,7,Bool>),B8(Tensor<B,8,Bool>),F1(Tensor<B,1,Float>),F2(Tensor<B,2,Float>),F3(Tensor<B,3,Float>),F4(Tensor<B,4,Float>),F5(Tensor<B,5,Float>),F6(Tensor<B,6,Float>),F7(Tensor<B,7,Float>),F8(Tensor<B,8,Float>),I1(Tensor<B,1,Int>),I2(Tensor<B,2,Int>),I3(Tensor<B,3,Int>),I4(Tensor<B,4,Int>),I5(Tensor<B,5,Int>),I6(Tensor<B,6,Int>),I7(Tensor<B,7,Int>),I8(Tensor<B,8,Int>),Incompatible(String),Multi(Vec<Self>)}
/// Serializable, backend-free counterpart of `Value`: raw `TensorData` tagged
/// by kind (bool/float/int), plus the `Incompatible` and `Multi` cases.
#[derive(Clone,Debug,Deserialize,Serialize)]
pub enum ValueData{BX(TensorData),FX(TensorData),IX(TensorData),Incompatible(String),Multi(Vec<ValueData>)}
/// Groups a loss value with the output and target values it relates to.
/// `#[serde(bound="")]` drops the implicit `B:Serialize/Deserialize` bounds,
/// since serialization goes through the values themselves.
#[derive(Clone,Debug,Deserialize,Serialize)]
#[serde(bound="")]
pub struct LossOutput<B:Backend>{loss:Value<B>,output:Value<B>,target:Value<B>}
1501use {bicop_num,try_unwrap};
1502use Bound::{Excluded,Included,Unbounded};
1503use Reshape::{R1,R2,R3,R4,R5,R6,R7,R8};
1504use Shape::{X1,X2,X3,X4,X5,X6,X7,X8};
1505use Value::{B1,B2,B3,B4,B5,B6,B7,B8,F1,F2,F3,F4,F5,F6,F7,F8,I1,I2,I3,I4,I5,I6,I7,I8};
1506use ValueData::{BX,FX,IX};
1507use burn::{
1508 module::{AutodiffModule,ConstantRecord,Content,DisplaySettings,ModuleDisplay,ModuleDisplayDefault,ModuleMapper,ModuleVisitor,Quantizer},
1509 nn::{
1510 BatchNorm,Dropout,Embedding,LayerNorm,Linear,Relu,RotaryEncoding,Tanh,conv::Conv2d,loss::{CrossEntropyLoss,MseLoss},pool::MaxPool2d
1511 },
1512 prelude::{Backend,Bool,Float,Int,Module,Tensor,TensorData},
1513 record::{FileRecorder,RecorderError},
1514 tensor::{
1515 BasicOps,ElementConversion,Numeric,TensorKind,activation::{log_softmax,softmax},backend::AutodiffBackend,cast::ToElement
1516 }
1517};
1518use crate::{
1519 AI,Decompose,Merge,Op,
1520 builtin::{
1521 Alignment,ReductionMode,math::{MeanLayer,SquaredErrorLayer,SumLayer},reinforcement::AccQLayer,soft::{ChooseLayer,CrossEntropyLayer,SoftmaxLayer}
1522 },
1523 ops::{Abs,Cat,Flatten,Reshape as OpsReshape,Stack,Squeeze,Unsqueeze}
1524};
1525use rand::random;
1526use serde::{Deserialize,Deserializer,Serialize,Serializer};
1527use std::{
1528 any::TypeId,fmt::{Display,Result as FmtResult},iter::{FromIterator,once},mem,ops::{Add,Bound,Div,Mul,RangeBounds,Range,Rem,Sub},path::PathBuf,slice::{Iter as SliceIter,self},vec::IntoIter as VecIntoIter
1529};
1530use super::{Kind,Reshape,Shape,apply_depthwise};