// Implement the numeric binary operators for `Value` via the `bicop_num!`
// macro: each invocation pairs a std ops trait with the tensor-tensor method
// and the tensor-scalar method used when one operand is a plain number.
bicop_num!(Add,add,add_scalar);
bicop_num!(Div,div,div_scalar);
bicop_num!(Mul,mul,mul_scalar);
bicop_num!(Rem,rem,remainder_scalar);
bicop_num!(Sub,sub,sub_scalar);
6fn broadcast_multi<B:Backend,F:FnMut(Value<B>,Value<B>)->Value<B>>(u:Vec<Value<B>>,v:Vec<Value<B>>,mut f:F)->Value<B>{
7 if u.len()==1{
8 u.into_iter().cycle().zip(v).map(|(x,y)|f(x,y)).collect()
9 }else if v.len()==1{
10 u.into_iter().zip(v.into_iter().cycle()).map(|(x,y)|f(x,y)).collect()
11 }else if u.len()==v.len(){
12 u.into_iter().zip(v).map(|(x,y)|f(x,y)).collect()
13 }else{
14 "mismatched lengths".into()
15 }
16}
/// Draws a single weighted sample index along `dim` from an unnormalized
/// `distribution`.
///
/// `dim` may be negative (counted from the end). A threshold is drawn
/// uniformly in `[0, sum)` and the index of the first weight that pushes the
/// running total past it is returned.
fn hard_choose_burn_1<B:Backend,const N:usize>(dim:i32,distribution:Tensor<B,N>)->u32{
    // Normalize a possibly-negative dimension index.
    let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
    // Move the sampled dimension last so its weights are contiguous in the data.
    let distribution=if dim==N-1{distribution}else{distribution.movedim(dim,N-1)}.into_data();
    let sum=distribution.iter().fold(0.0,|acc:f32,weight:f32|acc+weight);

    // `*choice -= weight` evaluates to `()`, so `Some(())` survives the filter
    // only while the remaining threshold is non-negative; the scan stops at the
    // first weight that drives it negative, making `count` the chosen index.
    // NOTE(review): the scan runs over the whole flattened data, so for N > 1
    // the returned index ranges over all elements, not one slice — confirm
    // callers only pass full distributions here.
    distribution.iter().scan(random::<f32>()*sum,|choice:&mut f32,weight:f32|Some(*choice-=weight).filter(|_|*choice>=0.0)).count() as u32
}
/// Draws one weighted sample index per slice along `dim`.
///
/// The tensor is flattened with `dim` moved last, split into chunks of that
/// dimension's length, and each chunk is sampled independently the same way
/// as `hard_choose_burn_1`.
fn hard_choose_burn_multi<B:Backend,const N:usize>(dim:i32,distribution:Tensor<B,N>)->Vec<u32>{
    // Normalize a possibly-negative dimension index.
    let dim=if dim<0{N-(-dim) as usize}else{dim as usize};

    // Chunk length is the size of the sampled dimension (read before moving it).
    let chunk=distribution.dims()[dim];
    let distribution=if dim==N-1{distribution}else{distribution.movedim(dim,N-1)}.into_data().to_vec().unwrap();

    distribution.chunks_exact(chunk).map(|d|{
        let sum=d.iter().fold(0.0,|acc:f32,weight:&f32|acc+weight);
        // `*choice -= weight` yields `()`; the scan stops at the first weight
        // that drives the threshold negative, so `count` is the chosen index.
        d.iter().scan(random::<f32>()*sum,|choice:&mut f32,weight:&f32|Some(*choice-=weight).filter(|_|*choice>=0.0)).count() as u32
    }).collect()
}
35fn hard_choose_burn_tensor<B:Backend,const N:usize>(dim:i32,distribution:Tensor<B,N>)->Tensor<B,N,Int>{let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
37 let device=distribution.device();
38 let mut dims=distribution.dims();
39
40 dims[N-1]=1;
41 let r:Tensor<B,N,Int>=Tensor::from_data(TensorData::new(hard_choose_burn_multi(dim as i32,distribution),dims),&device);
42
43 r.movedim(N-1,dim)
44}
45fn slice_slice<B:Backend,K:BasicOps<B>+TensorKind<B>,const N:usize>(ranges:&[Range<usize>],tensor:Tensor<B,N,K>)->Tensor<B,N,K>{
46 let mut n=0;
47 let mut acc=||{
48 let a=n;
49 n+=1;
50 a
51 };
52
53 match ranges.len(){0=>tensor,1=>tensor.slice([0;1].map(|_|ranges[acc()].clone())),2=>tensor.slice([0;2].map(|_|ranges[acc()].clone())),3=>tensor.slice([0;3].map(|_|ranges[acc()].clone())),4=>tensor.slice([0;4].map(|_|ranges[acc()].clone())),5=>tensor.slice([0;5].map(|_|ranges[acc()].clone())),6=>tensor.slice([0;6].map(|_|ranges[acc()].clone())),7=>tensor.slice([0;7].map(|_|ranges[acc()].clone())),8=>tensor.slice([0;8].map(|_|ranges[acc()].clone())),_=>panic!("too many ranges for current max 8 dims")}
54}
/// Samples one index along `dim` from `logits` after a tempered softmax.
///
/// `dim` may be negative. The softmax output sums to 1, so the uniform draw
/// in `[0, 1)` needs no renormalization.
fn soft_choose_burn_1<B:Backend,const N:usize>(dim:i32,logits:Tensor<B,N>,temperature:f32)->u32{
    // Normalize a possibly-negative dimension index.
    let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
    // Move the sampled dimension last so its probabilities are contiguous.
    let logits=if dim==N-1{logits}else{logits.movedim(dim,N-1)};
    let distribution=softmax(logits/temperature,N-1).into_data();
    // `*choice -= weight` yields `()`; the scan stops at the first weight that
    // drives the threshold negative, so `count` is the chosen index.
    distribution.iter().scan(random(),|choice:&mut f32,weight:f32|Some(*choice-=weight).filter(|_|*choice>=0.0)).count() as u32
}
/// Samples one index per slice along `dim` from tempered-softmax `logits`.
///
/// The tensor is flattened with `dim` moved last and split into chunks of
/// that dimension's length; each chunk is a softmax distribution sampled
/// independently with a uniform draw in `[0, 1)`.
fn soft_choose_burn_multi<B:Backend,const N:usize>(dim:i32,logits:Tensor<B,N>,temperature:f32)->Vec<u32>{
    // Normalize a possibly-negative dimension index.
    let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
    let logits=if dim==N-1{logits}else{logits.movedim(dim,N-1)};
    let chunk=logits.dims()[N-1];
    let distribution=softmax(logits/temperature,N-1).into_data().to_vec().unwrap();
    // `*choice -= weight` yields `()`; each chunk's scan stops at the first
    // weight that drives the threshold negative, making `count` the index.
    distribution.chunks_exact(chunk).map(|d|d.iter().scan(random(),|choice:&mut f32,weight:&f32|Some(*choice-=weight).filter(|_|*choice>=0.0)).count() as u32).collect()
}
68fn soft_choose_burn_tensor<B:Backend,const N:usize>(dim:i32,logits:Tensor<B,N>,temperature:f32)->Tensor<B,N,Int>{let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
70 let device=logits.device();
71 let mut dims=logits.dims();
72
73 dims[N-1]=1;
74 let r:Tensor<B,N,Int>=Tensor::from_data(TensorData::new(soft_choose_burn_multi(dim as i32,logits,temperature),dims),&device);
75 r.movedim(N-1,dim)
76}
77impl<'a,B:Backend> Deserialize<'a> for Value<B>{
78 fn deserialize<D:Deserializer<'a>>(deserializer:D)->Result<Self,D::Error>{ValueData::deserialize(deserializer).map(Into::into)}
79}
impl<A:AutodiffBackend> AutodiffModule<A> for Value<A>{
    /// Maps `valid()` over every tensor variant (B=bool, F=float, I=int;
    /// numeric suffix = rank), switching from the autodiff backend to its
    /// inner backend. `Multi` recurses per element and `Incompatible` carries
    /// its message across backends.
    fn valid(&self)->Self::InnerModule{
        match self{B1(x)=>B1(x.valid()),B2(x)=>B2(x.valid()),B3(x)=>B3(x.valid()),B4(x)=>B4(x.valid()),B5(x)=>B5(x.valid()),B6(x)=>B6(x.valid()),B7(x)=>B7(x.valid()),B8(x)=>B8(x.valid()),F1(x)=>F1(x.valid()),F2(x)=>F2(x.valid()),F3(x)=>F3(x.valid()),F4(x)=>F4(x.valid()),F5(x)=>F5(x.valid()),F6(x)=>F6(x.valid()),F7(x)=>F7(x.valid()),F8(x)=>F8(x.valid()),I1(x)=>I1(x.valid()),I2(x)=>I2(x.valid()),I3(x)=>I3(x.valid()),I4(x)=>I4(x.valid()),I5(x)=>I5(x.valid()),I6(x)=>I6(x.valid()),I7(x)=>I7(x.valid()),I8(x)=>I8(x.valid()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.iter().map(|x|x.valid()).collect()}
    }
    type InnerModule=Value<A::InnerBackend>;
}
86impl<A:Into<Value<B>>,B:Backend> FromIterator<A> for Value<B>{
87 fn from_iter<I:IntoIterator<Item=A>>(iter:I)->Self{Value::Multi(iter.into_iter().map(Into::into).collect())}
88}
89impl<B:Backend,K:'static+TensorKind<B>,const N:usize> From<Result<Tensor<B,N,K>,String>> for Value<B>{
90 fn from(value:Result<Tensor<B,N,K>,String>)->Self{
91 match value{Err(e)=>e.into(),Ok(t)=>t.into()}
92 }
93}
impl<B:Backend,K:'static+TensorKind<B>,const N:usize> From<Tensor<B,N,K>> for Value<B>{
    /// Erases a tensor's static rank `N` and kind `K` into the dynamic
    /// `Value` enum. Unsupported kinds/ranks become `Incompatible`.
    fn from(value:Tensor<B,N,K>)->Self{
        // Map the kind type parameter to a runtime tag.
        let kind=TypeId::of::<K>();
        let kind=if kind==TypeId::of::<Bool>(){Kind::Bool}else if kind==TypeId::of::<Float>(){Kind::Float}else if kind==TypeId::of::<Int>(){Kind::Int}else{return "only bool, float, and int tensors with dimensions 1-8 are currently supported".into()};

        // SAFETY: each arm is only reached when `(N, kind)` matches the
        // variant's concrete `Tensor<B,N,K>` type, so `transmute_copy` reads
        // the same type that was passed in. The bitwise copy would duplicate
        // ownership, so the original is `mem::forget`-ten below instead of
        // being dropped.
        let v=unsafe{
            match (N,kind){(1,Kind::Bool)=>B1(mem::transmute_copy(&value)),(2,Kind::Bool)=>B2(mem::transmute_copy(&value)),(3,Kind::Bool)=>B3(mem::transmute_copy(&value)),(4,Kind::Bool)=>B4(mem::transmute_copy(&value)),(5,Kind::Bool)=>B5(mem::transmute_copy(&value)),(6,Kind::Bool)=>B6(mem::transmute_copy(&value)),(7,Kind::Bool)=>B7(mem::transmute_copy(&value)),(8,Kind::Bool)=>B8(mem::transmute_copy(&value)),(1,Kind::Float)=>F1(mem::transmute_copy(&value)),(2,Kind::Float)=>F2(mem::transmute_copy(&value)),(3,Kind::Float)=>F3(mem::transmute_copy(&value)),(4,Kind::Float)=>F4(mem::transmute_copy(&value)),(5,Kind::Float)=>F5(mem::transmute_copy(&value)),(6,Kind::Float)=>F6(mem::transmute_copy(&value)),(7,Kind::Float)=>F7(mem::transmute_copy(&value)),(8,Kind::Float)=>F8(mem::transmute_copy(&value)),(1,Kind::Int)=>I1(mem::transmute_copy(&value)),(2,Kind::Int)=>I2(mem::transmute_copy(&value)),(3,Kind::Int)=>I3(mem::transmute_copy(&value)),(4,Kind::Int)=>I4(mem::transmute_copy(&value)),(5,Kind::Int)=>I5(mem::transmute_copy(&value)),(6,Kind::Int)=>I6(mem::transmute_copy(&value)),(7,Kind::Int)=>I7(mem::transmute_copy(&value)),(8,Kind::Int)=>I8(mem::transmute_copy(&value)),_=>return "only bool, float, and int tensors with dimensions 1-8 are currently supported".into()}
        };
        // Ownership was moved into `v` via the bitwise copy; skip the drop.
        mem::forget(value);
        v
    }
}
impl<B:Backend,K:'static+TensorKind<B>,const N:usize> TryFrom<Value<B>> for Tensor<B,N,K>{
    /// Recovers a statically-typed tensor from a `Value`, returning the value
    /// unchanged in `Err` when its runtime rank/kind don't match `N`/`K`.
    fn try_from(value:Value<B>)->Result<Self,Self::Error>{
        // Map the kind type parameter to the runtime tag used by `Value`.
        let kind=TypeId::of::<K>();
        let kind=if kind==TypeId::of::<Bool>(){Kind::Bool}else if kind==TypeId::of::<Float>(){Kind::Float}else if kind==TypeId::of::<Int>(){Kind::Int}else{return Err(value)};

        if Some(N)!=value.rank()||kind!=value.kind(){return Err(value)}
        // SAFETY: the rank/kind check above guarantees the matched variant
        // holds exactly a `Tensor<B,N,K>`, so `transmute_copy` reads the type
        // it was stored as. The bitwise copy would duplicate ownership, so the
        // original is `mem::forget`-ten below instead of being dropped.
        let r=unsafe{
            match &value{B1(x)=>mem::transmute_copy(x),B2(x)=>mem::transmute_copy(x),B3(x)=>mem::transmute_copy(x),B4(x)=>mem::transmute_copy(x),B5(x)=>mem::transmute_copy(x),B6(x)=>mem::transmute_copy(x),B7(x)=>mem::transmute_copy(x),B8(x)=>mem::transmute_copy(x),F1(x)=>mem::transmute_copy(x),F2(x)=>mem::transmute_copy(x),F3(x)=>mem::transmute_copy(x),F4(x)=>mem::transmute_copy(x),F5(x)=>mem::transmute_copy(x),F6(x)=>mem::transmute_copy(x),F7(x)=>mem::transmute_copy(x),F8(x)=>mem::transmute_copy(x),I1(x)=>mem::transmute_copy(x),I2(x)=>mem::transmute_copy(x),I3(x)=>mem::transmute_copy(x),I4(x)=>mem::transmute_copy(x),I5(x)=>mem::transmute_copy(x),I6(x)=>mem::transmute_copy(x),I7(x)=>mem::transmute_copy(x),I8(x)=>mem::transmute_copy(x),_=>panic!("internal error")}
        };
        // Ownership was moved into `r` via the bitwise copy; skip the drop.
        mem::forget(value);
        Ok(r)
    }
    type Error=Value<B>;
}
impl<B:Backend,R:Clone+RangeBounds<isize>> Flatten<R> for Value<B>{
    /// Flattens the (possibly negative-indexed) dimension range `args` of
    /// every tensor in `self` into a single dimension, lowering the rank by
    /// the length of the range minus one.
    fn flatten(self,args:R)->Self::Output{
        // Flattens dims `a..b` of a rank-`N` tensor and re-wraps the result as
        // a `Value` of the reduced rank `N - (b - a) + 1`.
        fn f<B:Backend,K:'static+BasicOps<B>+TensorKind<B>,R:RangeBounds<isize>,const N:usize>(args:R,x:Tensor<B,N,K>)->Value<B>{
            // Inclusive start of the flattened range; negative indices count
            // from the end, and an excluded start shifts forward by one.
            let a=match args.start_bound(){
                Excluded(&n)=>(if n<0{N-((-n) as usize)}else{n as usize})+1,
                Included(&n)=>if n<0{N-((-n) as usize)}else{n as usize},
                Unbounded=>0
            };
            // Exclusive end of the flattened range.
            // NOTE(review): an excluded end of 0 maps to N (the whole tail)
            // rather than an empty range — confirm this special case is the
            // intended "..0 means to the end" convention.
            let b=match args.end_bound(){
                Excluded(&n)=>if n<0{N-((-n) as usize)}else if n==0{N}else{n as usize},
                Included(&n)=>(if n<0{N-((-n) as usize)}else{n as usize})+1,
                Unbounded=>N
            };
            let flattenedlen=b-a;

            // Resulting rank; `Tensor::flatten` takes it as a const generic, so
            // dispatch over every supported rank. A rank-8 result can only come
            // from flattening a single dimension, which is a no-op.
            match N-flattenedlen+1{
                1=>x.flatten::<1>(a,b-1).into(),
                2=>x.flatten::<2>(a,b-1).into(),
                3=>x.flatten::<3>(a,b-1).into(),
                4=>x.flatten::<4>(a,b-1).into(),
                5=>x.flatten::<5>(a,b-1).into(),
                6=>x.flatten::<6>(a,b-1).into(),
                7=>x.flatten::<7>(a,b-1).into(),
                8=>x.into(),
                _=>"invalid flatten".into()
            }
        }

        match self{
            B1(x)=>f(args,x),
            B2(x)=>f(args,x),
            B3(x)=>f(args,x),
            B4(x)=>f(args,x),
            B5(x)=>f(args,x),
            B6(x)=>f(args,x),
            B7(x)=>f(args,x),
            B8(x)=>f(args,x),
            F1(x)=>f(args,x),
            F2(x)=>f(args,x),
            F3(x)=>f(args,x),
            F4(x)=>f(args,x),
            F5(x)=>f(args,x),
            F6(x)=>f(args,x),
            F7(x)=>f(args,x),
            F8(x)=>f(args,x),
            I1(x)=>f(args,x),
            I2(x)=>f(args,x),
            I3(x)=>f(args,x),
            I4(x)=>f(args,x),
            I5(x)=>f(args,x),
            I6(x)=>f(args,x),
            I7(x)=>f(args,x),
            I8(x)=>f(args,x),
            Value::Incompatible(e)=>e.into(),
            Value::Multi(v)=>v.into_iter().map(|x|x.flatten(args.clone())).collect(),
        }
    }
    type Output=Self;
}
impl<B:Backend,R:Into<Reshape>> OpsReshape<R> for Value<B>{
    /// Reshapes every tensor in `self` according to `args`; a
    /// `Reshape::Recursive` spec distributes its sub-specs over the members
    /// of a `Multi` at the matching nesting depth.
    fn reshape(self,args:R)->Self::Output{
        // Applies a flat reshape spec (R1..R8 select the target rank) to one
        // tensor of any rank/kind.
        fn f<B:Backend,K:'static+BasicOps<B>+TensorKind<B>,const N:usize>(args:Reshape,x:Tensor<B,N,K>)->Value<B>{
            match args{
                // Sizes are passed as i32 so negative entries keep their
                // inferred-dimension meaning in `Tensor::reshape`.
                R1(r)=>x.reshape(r.map(|x|x as i32)).into(),
                R2(r)=>x.reshape(r.map(|x|x as i32)).into(),
                R3(r)=>x.reshape(r.map(|x|x as i32)).into(),
                R4(r)=>x.reshape(r.map(|x|x as i32)).into(),
                R5(r)=>x.reshape(r.map(|x|x as i32)).into(),
                R6(r)=>x.reshape(r.map(|x|x as i32)).into(),
                R7(r)=>x.reshape(r.map(|x|x as i32)).into(),
                R8(r)=>x.reshape(r.map(|x|x as i32)).into(),
                Reshape::Recursive(_v)=>"reshaping a single tensor into multiple tensors is currently not supported".into()
            }
        }

        let args=args.into();
        let depth=args.depth();

        // Descend `depth` levels of `Multi` nesting before applying the spec
        // to each value found there.
        apply_depthwise(depth,|value|{
            let args=args.clone();
            match value{
                B1(x)=>f(args,x),
                B2(x)=>f(args,x),
                B3(x)=>f(args,x),
                B4(x)=>f(args,x),
                B5(x)=>f(args,x),
                B6(x)=>f(args,x),
                B7(x)=>f(args,x),
                B8(x)=>f(args,x),
                F1(x)=>f(args,x),
                F2(x)=>f(args,x),
                F3(x)=>f(args,x),
                F4(x)=>f(args,x),
                F5(x)=>f(args,x),
                F6(x)=>f(args,x),
                F7(x)=>f(args,x),
                F8(x)=>f(args,x),
                I1(x)=>f(args,x),
                I2(x)=>f(args,x),
                I3(x)=>f(args,x),
                I4(x)=>f(args,x),
                I5(x)=>f(args,x),
                I6(x)=>f(args,x),
                I7(x)=>f(args,x),
                I8(x)=>f(args,x),
                Value::Incompatible(e)=>e.into(),
                // A Recursive spec is zipped over the Multi's members; a flat
                // spec cannot merge multiple tensors into one.
                Value::Multi(v)=>if let Reshape::Recursive(r)=args{r.into_iter().zip(v).map(|(r,v)|v.reshape(r)).collect()}else{"reshaping multiple tensors into a single tensor is currently not supported".into()}
            }
        },self)
    }
    type Output=Self;
}
232impl<B:Backend,S:?Sized+AsRef<str>> From<&S> for Value<B>{
233 fn from(value:&S)->Self{Self::Incompatible(value.as_ref().to_string())}
234}
impl<B:Backend,const D:usize> AI<Value<B>,Value<B>> for BatchNorm<B,D>{
    /// Applies batch norm to a float value of any rank, viewing the input as
    /// `(batch, channels, spatial...)`.
    fn forward(&self,input:Value<B>)->Value<B>{
        // Rebuilds the layer with a different spatial-rank const parameter `F`
        // for a rank-`E` input (arms below pair E with F = E - 2); the learned
        // state is shared by cloning the parameter handles.
        fn f<B:Backend,const D:usize,const E:usize,const F:usize>(norm:&BatchNorm<B,D>,x:Tensor<B,E>)->Value<B>{
            let norm:BatchNorm<B,F>=BatchNorm{beta:norm.beta.clone(),epsilon:norm.epsilon.clone(),gamma:norm.gamma.clone(),momentum:norm.momentum.clone(),running_mean:norm.running_mean.clone(),running_var:norm.running_var.clone()};
            norm.forward(x).into()
        }
        match input.float(){
            // Ranks 1 and 2 are padded with leading axes up to rank 3, then the
            // added axes are squeezed back off after the recursive call.
            F1(x)=>AI::forward(self,F1(x).unsqueeze().unsqueeze()).squeeze().squeeze(),
            F2(x)=>AI::forward(self,F2(x).unsqueeze()).squeeze(),
            F3(x)=>f::<B,D,3,1>(self,x),
            F4(x)=>f::<B,D,4,2>(self,x),
            F5(x)=>f::<B,D,5,3>(self,x),
            F6(x)=>f::<B,D,6,4>(self,x),
            F7(x)=>f::<B,D,7,5>(self,x),
            F8(x)=>f::<B,D,8,6>(self,x),
            Value::Incompatible(e)=>e.into(),
            Value::Multi(v)=>v.into_iter().map(|x|AI::forward(self,x)).collect(),
            _=>panic!("internal error")
        }
    }
}
256impl<B:Backend> AI<(Value<B>,Value<B>),Vec<f32>> for CrossEntropyLayer{
257 fn forward(&self,input:(Value<B>,Value<B>))->Vec<f32>{
258 let output:Value<B>=self.forward(input);
259 output.into_float_vec()
260 }
261}
262impl<B:Backend> AI<(Value<B>,Value<B>),LossOutput<B>> for CrossEntropyLayer{
263 fn forward(&self,(output,target):(Value<B>,Value<B>))->LossOutput<B>{
264 let loss=self.forward((output.clone(),target.clone()));
265 LossOutput::new(loss,output,target)
266 }
267}
impl<B:Backend> AI<(Value<B>,Value<B>),Value<B>> for CrossEntropyLayer{
    /// Elementwise cross entropy between `output` and `target`.
    ///
    /// Float targets are treated as full distributions; int targets as class
    /// indices gathered along the configured dimension. A NaN temperature
    /// skips the softmax and takes `log` of the output directly.
    fn forward(&self,(output,target):(Value<B>,Value<B>))->Value<B>{
        // Float/float case: -t * log_softmax(y / T), elementwise; shapes must match.
        fn ff<B:Backend,const N:usize>(dim:i32,y:Tensor<B,N>,t:Tensor<B,N>,temperature:f32)->Result<Tensor<B,N>,String>{
            let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
            let (ydims,tdims)=(y.dims(),t.dims());
            if ydims==tdims{
                let logy=if temperature.is_nan(){y.log()}else{log_softmax(y/temperature,dim)};
                Ok(logy*t.neg())
            }else{
                Err(format!("incompatible shapes to cross entropy. ydims: {ydims:?} tdims: {tdims:?}"))
            }
        }
        // Float/int case: gather the log-probability at each target class index
        // along `dim`; the target shape must equal the output shape minus `dim`.
        fn fi<B:Backend,const N:usize,const K:usize>(dim:i32,y:Tensor<B,N>,t:Tensor<B,K,Int>,temperature:f32)->Result<Tensor<B,K>,String>{
            let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
            let (ydims,tdims)=(y.dims(),t.dims());
            if ydims.iter().enumerate().filter_map(|(n,y)|(n!=dim).then_some(y)).eq(tdims.iter()){
                let logy=if temperature.is_nan(){y.log()}else{log_softmax(y/temperature,dim)};
                Ok(logy.gather(dim,t.unsqueeze_dim(dim)).neg().squeeze(dim))
            }else{
                Err(format!("incompatible shapes to cross entropy along dimension {dim}. ydims: {ydims:?} tdims: {tdims:?}"))
            }
        }
        let (dim,temp)=(self.get_dim(),self.get_temperature());

        // Inner match produces Result; the outer match maps Err into
        // Value::Incompatible.
        match match (output,target){
            (F1(y),F1(t))=>ff(dim,y,t,temp).map(Into::into),
            (F2(y),F2(t))=>ff(dim,y,t,temp).map(Into::into),
            (F3(y),F3(t))=>ff(dim,y,t,temp).map(Into::into),
            (F4(y),F4(t))=>ff(dim,y,t,temp).map(Into::into),
            (F5(y),F5(t))=>ff(dim,y,t,temp).map(Into::into),
            (F6(y),F6(t))=>ff(dim,y,t,temp).map(Into::into),
            (F7(y),F7(t))=>ff(dim,y,t,temp).map(Into::into),
            (F8(y),F8(t))=>ff(dim,y,t,temp).map(Into::into),
            // Rank-1 output with rank-1 int target: lift the output to rank 2
            // so the gather has a class dimension to index.
            (F1(y),I1(t))=>fi(dim,y.unsqueeze::<2>(),t,temp).map(Into::into),
            (F2(y),I1(t))=>fi(dim,y,t,temp).map(Into::into),
            (F3(y),I2(t))=>fi(dim,y,t,temp).map(Into::into),
            (F4(y),I3(t))=>fi(dim,y,t,temp).map(Into::into),
            (F5(y),I4(t))=>fi(dim,y,t,temp).map(Into::into),
            (F6(y),I5(t))=>fi(dim,y,t,temp).map(Into::into),
            (F7(y),I6(t))=>fi(dim,y,t,temp).map(Into::into),
            (F8(y),I7(t))=>fi(dim,y,t,temp).map(Into::into),
            (Value::Incompatible(y),_)=>Err(y),
            (_,Value::Incompatible(t))=>Err(t),
            (Value::Multi(y),Value::Multi(t))=>if y.len()==t.len(){Ok(Value::Multi(y.into_iter().zip(t).map(|x|self.forward_typed::<_,Value<B>>(x)).collect()))}else{Err("mismatched lengths".into())},
            _=>Err("incompatible".into())
        }{
            Err(e)=>Value::Incompatible(e),Ok(x)=>x
        }
    }
}
impl<B:Backend> AI<(Value<B>,Value<B>),Value<B>> for CrossEntropyLoss<B>{
    /// Cross entropy over an `(output, target)` pair.
    ///
    /// When `logits` is false the temperature is set to NaN, which selects
    /// the no-softmax (plain `log`) path in `CrossEntropyLayer` — i.e. the
    /// outputs are treated as probabilities already.
    fn forward(&self,(output,target):(Value<B>,Value<B>))->Value<B>{
        let mut op=().fix_type::<Value<B>>().cross_entropy(1.0);
        if !self.logits{op.set_temperature(f32::NAN)}
        op.forward((output,target))
    }
}
325impl<B:Backend> AI<(Value<B>,Value<B>),LossOutput<B>> for SquaredErrorLayer{
326 fn forward(&self,(output,target):(Value<B>,Value<B>))->LossOutput<B>{
327 let loss=self.forward((output.clone(),target.clone()));
328 LossOutput::new(loss,output,target)
329 }
330}
impl<B:Backend> AI<(Value<B>,Value<B>),Value<B>> for SquaredErrorLayer{
    /// Unreduced elementwise squared error between `output` and `target`.
    ///
    /// `Multi` values broadcast against each other via `broadcast_multi`
    /// (length 1 against length n, or equal lengths).
    fn forward(&self,(output,target):(Value<B>,Value<B>))->Value<B>{
        // Same-shape float tensors only; keeps per-element values (no reduction).
        fn f<B:Backend,const N:usize>(y:Tensor<B,N>,t:Tensor<B,N>)->Value<B>{
            if y.dims()==t.dims(){MseLoss.forward_no_reduction(y,t).into()}else{"compatible inputs for squared error are float tensors of the same shape".into()}
        }
        match (output.float(),target.float()){(F1(y),F1(t))=>f(y,t),(F2(y),F2(t))=>f(y,t),(F3(y),F3(t))=>f(y,t),(F4(y),F4(t))=>f(y,t),(F5(y),F5(t))=>f(y,t),(F6(y),F6(t))=>f(y,t),(F7(y),F7(t))=>f(y,t),(F8(y),F8(t))=>f(y,t),(Value::Incompatible(y),_)=>y.into(),(_,Value::Incompatible(t))=>t.into(),(Value::Multi(y),t)=>broadcast_multi(y,t.into_multi(),|y,t|self.forward((y,t))),(y,Value::Multi(t))=>broadcast_multi(y.into_multi(),t,|y,t|self.forward((y,t))),_=>"compatible inputs for squared error are float tensors of the same shape".into()}
    }
}
339impl<B:Backend> AI<(Value<B>,Value<B>),Vec<f32>> for SquaredErrorLayer{
340 fn forward(&self,(output,target):(Value<B>,Value<B>))->Vec<f32>{
341 let error:Value<B>=self.forward((output,target));
342 error.into_float_vec()
343 }
344}
345impl<B:Backend> AI<(Value<B>,Value<B>),f32> for SquaredErrorLayer{
346 fn forward(&self,(output,target):(Value<B>,Value<B>))->f32{().fix_type::<Value<B>>().squared_error().mean().forward((output,target))}
347}
impl<B:Backend> AI<Value<B>,Tensor<B,1>> for MeanLayer{
    /// Reduces `input` to a single-element tensor holding its overall mean;
    /// a `Multi` averages its members' means. Empty input yields NaN on the
    /// default device.
    fn forward(&self,input:Value<B>)->Tensor<B,1>{
        // Full reduction of one tensor to its scalar mean.
        fn avg<B:Backend,const N:usize>(x:Tensor<B,N>)->Tensor<B,1>{x.mean()}
        let l=input.len();

        if l==0{return Tensor::from_data(TensorData::new(vec![f32::NAN],[1]),&Default::default())}
        match input.float(){F1(x)=>avg(x),F2(x)=>avg(x),F3(x)=>avg(x),F4(x)=>avg(x),F5(x)=>avg(x),F6(x)=>avg(x),F7(x)=>avg(x),F8(x)=>avg(x),Value::Incompatible(e)=>panic!("Could not reduce to a scalar due to incompatibility: {e}"),Value::Multi(v)=>v.into_iter().map(|x|self.forward(x)).reduce(|x:Tensor<B,1>,y:Tensor<B,1>|x+y).unwrap()/l as f32,_=>panic!("internal error")}
    }
}
impl<B:Backend> AI<Value<B>,Tensor<B,1>> for SumLayer{
    /// Reduces `input` to a single-element tensor holding its overall sum;
    /// a `Multi` sums its members' sums. Empty input yields NaN on the
    /// default device.
    fn forward(&self,input:Value<B>)->Tensor<B,1>{
        // Full reduction of one tensor to its scalar sum.
        fn sum<B:Backend,const N:usize>(x:Tensor<B,N>)->Tensor<B,1>{x.sum()}
        let l=input.len();

        if l==0{return Tensor::from_data(TensorData::new(vec![f32::NAN],[1]),&Default::default())}
        match input.float(){F1(x)=>sum(x),F2(x)=>sum(x),F3(x)=>sum(x),F4(x)=>sum(x),F5(x)=>sum(x),F6(x)=>sum(x),F7(x)=>sum(x),F8(x)=>sum(x),Value::Incompatible(e)=>panic!("Could not reduce to a scalar due to incompatibility: {e}"),Value::Multi(v)=>v.into_iter().map(|x|self.forward(x)).reduce(|x:Tensor<B,1>,y:Tensor<B,1>|x+y).unwrap(),_=>panic!("internal error")}
    }
}
impl<B:Backend> AI<Value<B>,Value<B>> for Conv2d<B>{
    /// Runs 2D convolution on a float value of any rank by interpreting the
    /// trailing dimensions as (C, H, W) and collapsing everything in front of
    /// them into the batch.
    fn forward(&self,input:Value<B>)->Value<B>{
        fn f<B:Backend,const N:usize>(input:Tensor<B,N>,layer:&Conv2d<B>)->Value<B>{
            let mut dims=input.dims();
            let n:usize=dims.iter().product();

            // Missing leading dims (rank < 3) default to size 1.
            let c=if N<3{1}else{dims[N-3]};
            let h=if N<2{1}else{dims[N-2]};
            let w=dims[N-1];

            // Batch size is whatever remains after (C, H, W).
            let b=n/(c*h*w);
            let output=layer.forward(input.reshape([b,c,h,w]));

            let [_b,c,h,w]=output.dims();

            // Write the convolved sizes back into the original shape. If the
            // conv produced a dimension the low-rank input never had (e.g.
            // multiple output channels for a rank-1/2 input), return at the
            // higher rank instead of squeezing it away.
            dims[N-1]=w;
            if N<3&&c!=1{return F3(output.reshape([c,h,w]))}else if N>=3{dims[N-3]=c}
            if N<2&&h!=1{return F2(output.reshape([h,w]))}else if N>=2{dims[N-2]=h}
            output.reshape(dims).into()
        }
        let l=self;

        match input.float(){F1(x)=>f(x,l),F2(x)=>f(x,l),F3(x)=>f(x,l),F4(x)=>f(x,l),F5(x)=>f(x,l),F6(x)=>f(x,l),F7(x)=>f(x,l),F8(x)=>f(x,l),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|AI::forward(self,x)).collect(),_=>panic!("internal error")}
    }
}
391impl<B:Backend> AI<Value<B>,Value<B>> for CrossEntropyLayer{
392 fn forward(&self,input:Value<B>)->Value<B>{
393 match input{
394 Value::Incompatible(e)=>e.into(),
395 Value::Multi(v)=>if v.len()==2{
396 let [output,target]=v.try_into().unwrap();
397 self.forward((output,target))
398 }else{
399 v.into_iter().map(|x|self.forward(x)).collect()
400 },
401 _=>"cross entropy inputs must be in pairs".into()
402 }
403 }
404}
405impl<B:Backend> AI<Value<B>,Value<B>> for CrossEntropyLoss<B>{
406 fn forward(&self,input:Value<B>)->Value<B>{
407 let mut op=CrossEntropyLayer::new(1.0);
408 if !self.logits{op.set_temperature(f32::NAN)}
409 op.forward(input)
410 }
411}
412impl<B:Backend> AI<Value<B>,Value<B>> for MeanLayer{
413 fn forward(&self,input:Value<B>)->Value<B>{
414 fn avg<B:Backend,const N:usize,const K:usize>(d:i32,x:Tensor<B,N>)->Tensor<B,K>{
415 let d=if d<0{N-((-d) as usize)}else{d as usize};
416 x.mean_dim(d).squeeze(d)
417 }
418 let l=input.len();
419
420 if l==0{return input}
421 match self.get_reduction_mode(){
422 ReductionMode::Component=>F1(self.forward(input)),
423 ReductionMode::Dim(d)=>{
424 if let Some(r)=input.rank(){
425 if d>=r as i32||d<(-(r as i32)){return format!("rank {r} is too low to cat along dimension {d}").into()}
426 }
427 match input.float(){F1(x)=>F1(x.mean()),F2(x)=>F1(avg(d,x)),F3(x)=>F2(avg(d,x)),F4(x)=>F3(avg(d,x)),F5(x)=>F4(avg(d,x)),F6(x)=>F5(avg(d,x)),F7(x)=>F6(avg(d,x)),F8(x)=>F7(avg(d,x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|self.forward(x)).reduce(|x:Value<B>,y:Value<B>|x+y).unwrap()/l as f32,_=>panic!("internal error")}
428 },
429 ReductionMode::Tensor=>match input.float(){Value::Multi(v)=>v.into_iter().reduce(|x,y|x+y).unwrap()/l as f32,x=>x}
430 }
431 }
432}
433impl<B:Backend> AI<Value<B>,Value<B>> for MseLoss{
434 fn forward(&self,input:Value<B>)->Value<B>{SquaredErrorLayer::new().forward(input)}
435}
436impl<B:Backend> AI<Value<B>,Value<B>> for SquaredErrorLayer{
437 fn forward(&self,input:Value<B>)->Value<B>{
438 match input{
439 Value::Incompatible(e)=>e.into(),
440 Value::Multi(v)=>if v.len()==2{
441 let [output,target]=v.try_into().unwrap();
442 self.forward((output,target))
443 }else{
444 v.into_iter().map(|x|self.forward(x)).collect()
445 },
446 _=>"squared error inputs must be in pairs".into()
447 }
448 }
449}
450impl<B:Backend> AI<Value<B>,Value<B>> for SumLayer{
451 fn forward(&self,input:Value<B>)->Value<B>{
452 fn sum<B:Backend,const N:usize,const K:usize>(d:i32,x:Tensor<B,N>)->Tensor<B,K>{
453 let d=if d<0{N-((-d) as usize)}else{d as usize};
454 x.mean_dim(d).squeeze(d)
455 }
456 let l=input.len();
457
458 if l==0{return input}
459 match self.get_reduction_mode(){
460 ReductionMode::Component=>F1(self.forward(input)),
461 ReductionMode::Dim(d)=>{
462 if let Some(r)=input.rank(){
463 if d>=r as i32||d<(-(r as i32)){return format!("rank {r} is too low to cat along dimension {d}").into()}
464 }
465 match input.float(){F1(x)=>F1(x.sum()),F2(x)=>F1(sum(d,x)),F3(x)=>F2(sum(d,x)),F4(x)=>F3(sum(d,x)),F5(x)=>F4(sum(d,x)),F6(x)=>F5(sum(d,x)),F7(x)=>F6(sum(d,x)),F8(x)=>F7(sum(d,x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|self.forward(x)).reduce(|x:Value<B>,y:Value<B>|x+y).unwrap(),_=>panic!("internal error")}
466 },
467 ReductionMode::Tensor=>match input.float(){Value::Multi(v)=>v.into_iter().reduce(|x,y|x+y).unwrap(),x=>x}
468 }
469 }
470}
471impl<B:Backend> AI<Value<B>,f32> for MeanLayer{
472 fn forward(&self,input:Value<B>)->f32{
473 let y:Tensor<B,1>=self.forward(input);
474 y.into_scalar().to_f32()
475 }
476}
477impl<B:Backend> AI<Value<B>,f32> for SumLayer{
478 fn forward(&self,input:Value<B>)->f32{
479 let y:Tensor<B,1>=self.forward(input);
480 y.into_scalar().to_f32()
481 }
482}
impl<B:Backend> AI<Value<B>,Value<B>> for AccQLayer{
    /// Accumulates discounted future values along `dim`:
    /// `q[i] = input[i] + gamma * q[i+1]`, computed right-to-left
    /// (Q-value style backward accumulation).
    fn forward(&self,input:Value<B>)->Value<B>{
        fn acc_q<B:Backend,const N:usize>(dim:i32,gamma:f32,i:Tensor<B,N>)->Tensor<B,N>{
            // Normalize a possibly-negative dimension index.
            let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
            // One size-1 slice per step along `dim`.
            let mut q=i.split(1,dim);
            // Sweep backwards, folding each step's discounted future into it;
            // the accumulator carries the previous (later-in-time) slice.
            q.iter_mut().rev().fold(None,|future,present|{
                if let Some(f)=future{*present=f*gamma+present.clone()}
                Some(present.clone())
            });
            Tensor::cat(q,dim)
        }
        let (dim,gamma)=(self.get_dim(),self.get_gamma());

        match input.float(){F1(x)=>F1(acc_q(dim,gamma,x)),F2(x)=>F2(acc_q(dim,gamma,x)),F3(x)=>F3(acc_q(dim,gamma,x)),F4(x)=>F4(acc_q(dim,gamma,x)),F5(x)=>F5(acc_q(dim,gamma,x)),F6(x)=>F6(acc_q(dim,gamma,x)),F7(x)=>F7(acc_q(dim,gamma,x)),F8(x)=>F8(acc_q(dim,gamma,x)),Value::Incompatible(x)=>x.into(),Value::Multi(x)=>Value::Multi(x.into_iter().map(|x|self.forward(x)).collect()),_=>panic!("unexpected non float value")}
    }
}
499impl<B:Backend> AI<Value<B>,u32> for ChooseLayer{
500 fn forward(&self,input:Value<B>)->u32{
501 let (dim,temperature)=(self.get_dim(),self.get_temperature());
502
503 match input.float(){
504 F1(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
505 F2(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
506 F3(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
507 F4(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
508 F5(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
509 F6(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
510 F7(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
511 F8(x)=>if temperature.is_nan(){hard_choose_burn_1(dim,x)}else{soft_choose_burn_1(dim,x,temperature)},
512 Value::Incompatible(e)=>panic!("Could not create scalar due to incompatibility: {e}"),
513 Value::Multi(v)=>if v.len()==1{self.forward(v.into_iter().next().unwrap())}else{panic!("Cannot soft choose one scalar from multiple values")},
514 _=>panic!("internal error")
515 }
516 }
517}
518impl<B:Backend> AI<Value<B>,Vec<u32>> for ChooseLayer{
519 fn forward(&self,input:Value<B>)->Vec<u32>{
520 let (dim,temperature)=(self.get_dim(),self.get_temperature());
521
522 match input.float(){
523 F1(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
524 F2(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
525 F3(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
526 F4(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
527 F5(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
528 F6(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
529 F7(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
530 F8(x)=>if temperature.is_nan(){hard_choose_burn_multi(dim,x)}else{soft_choose_burn_multi(dim,x,temperature)},
531 Value::Incompatible(e)=>panic!("Could not create vector due to incompatibility: {e}"),
532 Value::Multi(v)=>v.into_iter().flat_map(|x|self.forward_typed::<_,Vec<u32>>(x)).collect(),
533 _=>panic!("internal error")
534 }
535 }
536}
537impl<B:Backend> AI<Value<B>,Value<B>> for Dropout{
538 fn forward(&self,input:Value<B>)->Value<B>{
539 match input.float(){F1(x)=>F1(self.forward(x)),F2(x)=>F2(self.forward(x)),F3(x)=>F3(self.forward(x)),F4(x)=>F4(self.forward(x)),F5(x)=>F5(self.forward(x)),F6(x)=>F6(self.forward(x)),F7(x)=>F7(self.forward(x)),F8(x)=>F8(self.forward(x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
540 }
541}
impl<B:Backend> AI<Value<B>,Value<B>> for Embedding<B>{
    /// Embeds int tensors (appending an embedding axis, so rank N becomes
    /// N + 1) and treats float inputs as a bias-free linear projection by the
    /// embedding weight matrix.
    fn forward(&self,input:Value<B>)->Value<B>{
        // Flattens int input to (batch, seq), embeds, then restores the
        // original leading shape with the embedding size appended (K = N + 1).
        fn apply_embed<B:Backend,const N:usize,const K:usize>(this:&Embedding<B>,x:Tensor<B,N,Int>)->Tensor<B,K>{
            let dims=x.dims();
            let [batch,seq]=[dims[0],dims.iter().skip(1).product()];
            let x=x.reshape([batch,seq]);
            let y=this.forward(x);
            // Embedding width read back from the output's trailing dimension.
            let embed=y.dims().last().copied().unwrap();
            let mut ydims=[0;K];
            ydims[..N].copy_from_slice(&dims);
            ydims[N]=embed;
            y.reshape(ydims)
        }
        // Float path: multiply by the embedding table as a linear layer (no bias).
        fn apply_linear<B:Backend,const N:usize>(this:&Embedding<B>,x:Tensor<B,N>)->Tensor<B,N>{
            Linear{bias:None,weight:this.weight.clone()}.forward(x)
        }
        match input{F1(x)=>apply_linear(self,x).into(),F2(x)=>apply_linear(self,x).into(),F3(x)=>apply_linear(self,x).into(),F4(x)=>apply_linear(self,x).into(),F5(x)=>apply_linear(self,x).into(),F6(x)=>apply_linear(self,x).into(),F7(x)=>apply_linear(self,x).into(),F8(x)=>apply_linear(self,x).into(),I1(x)=>apply_embed::<B,1,2>(self,x).into(),I2(x)=>apply_embed::<B,2,3>(self,x).into(),I3(x)=>apply_embed::<B,3,4>(self,x).into(),I4(x)=>apply_embed::<B,4,5>(self,x).into(),I5(x)=>apply_embed::<B,5,6>(self,x).into(),I6(x)=>apply_embed::<B,6,7>(self,x).into(),I7(x)=>apply_embed::<B,7,8>(self,x).into(),I8(_x)=>"embedding output would exceed maximum supported rank".into(),Value::Incompatible(x)=>x.into(),Value::Multi(x)=>x.into_iter().map(|x|AI::forward(self,x)).collect::<Vec<_>>().into(),_=>"embedding is only available for float or int inputs".into()}
    }
}
impl<B:Backend> AI<Value<B>,Value<B>> for LayerNorm<B>{
    /// Applies layer norm to float tensors of any rank after validating that
    /// the norm's parameter shape matches the input's trailing dimension.
    fn forward(&self,input:Value<B>)->Value<B>{
        fn f<B:Backend,const N:usize>(input:Tensor<B,N>,layer:&LayerNorm<B>)->Value<B>{
            let b=layer.beta.dims();
            let g=layer.gamma.dims();
            let i=input.dims();

            // beta and gamma must agree with each other, and with the input's
            // last (normalized) dimension.
            if b!=g{return format!("malformed layer norm. beta dims: {b:?}. gamma dims: {g:?}.").into()}
            if b.last()!=i.last(){return format!("layer norm for dimension {b:?} is not compatible with input dimensions {i:?}. The last dimension must match the norm dimension.").into()}
            layer.forward(input).into()
        }
        let l=self;

        match input.float(){F1(x)=>f(x,l),F2(x)=>f(x,l),F3(x)=>f(x,l),F4(x)=>f(x,l),F5(x)=>f(x,l),F6(x)=>f(x,l),F7(x)=>f(x,l),F8(x)=>f(x,l),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
    }
}
577impl<B:Backend> AI<Value<B>,Value<B>> for Linear<B>{
578 fn forward(&self,input:Value<B>)->Value<B>{
579 match input.float(){F1(x)=>F1(self.forward(x)),F2(x)=>F2(self.forward(x)),F3(x)=>F3(self.forward(x)),F4(x)=>F4(self.forward(x)),F5(x)=>F5(self.forward(x)),F6(x)=>F6(self.forward(x)),F7(x)=>F7(self.forward(x)),F8(x)=>F8(self.forward(x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
580 }
581}
582impl<B:Backend> AI<Value<B>,Value<B>> for Relu{
583 fn forward(&self,input:Value<B>)->Value<B>{
584 match input.float(){F1(x)=>F1(self.forward(x)),F2(x)=>F2(self.forward(x)),F3(x)=>F3(self.forward(x)),F4(x)=>F4(self.forward(x)),F5(x)=>F5(self.forward(x)),F6(x)=>F6(self.forward(x)),F7(x)=>F7(self.forward(x)),F8(x)=>F8(self.forward(x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
585 }
586}
impl<B:Backend> AI<Value<B>,Value<B>> for SoftmaxLayer{
    /// Tempered softmax along the configured dimension for every float
    /// tensor in `input`.
    fn forward(&self,input:Value<B>)->Value<B>{
        // Normalizes a possibly-negative `dim`, then softmaxes `x / temperature`.
        fn f<B:Backend,const N:usize>(dim:i32,temperature:f32,x:Tensor<B,N>)->Tensor<B,N>{
            let dim=if dim<0{N-(-dim) as usize}else{dim as usize};
            softmax(x/temperature,dim)
        }
        let (dim,temperature)=(self.get_dim(),self.get_temperature());

        match input.float(){F1(x)=>F1(f(dim,temperature,x)),F2(x)=>F2(f(dim,temperature,x)),F3(x)=>F3(f(dim,temperature,x)),F4(x)=>F4(f(dim,temperature,x)),F5(x)=>F5(f(dim,temperature,x)),F6(x)=>F6(f(dim,temperature,x)),F7(x)=>F7(f(dim,temperature,x)),F8(x)=>F8(f(dim,temperature,x)),Value::Incompatible(x)=>x.into(),Value::Multi(x)=>x.into_iter().map(|x|self.forward(x)).collect(),_=>panic!("unexpected non float value")}
    }
}
impl<B:Backend> AI<Value<B>,Value<B>> for ChooseLayer{
 /// Samples indices along the configured dimension: a NaN temperature selects
 /// the hard (weighted, non-tempered) choice, otherwise the soft
 /// (temperature-scaled) choice.
 fn forward(&self,input:Value<B>)->Value<B>{let (dim,temperature)=(self.get_dim(),self.get_temperature());
  // Normalized dim for the squeeze below; rank 8 is assumed when the rank is
  // unknown — NOTE(review): confirm that fallback is intended for Multi.
  let d=if dim<0{input.rank().unwrap_or(8)-((-dim) as usize)}else{dim as usize};

  match input.float(){
   // The chosen-index tensor keeps a size-1 slot at the chosen dimension,
   // which is squeezed so the result drops one rank (F1 stays rank 1 — there
   // is no lower rank to squeeze into). Assumes choose_burn_tensor leaves the
   // size-1 slot at `d` — TODO confirm against its implementation.
   F1(x)=>I1(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}),
   F2(x)=>I1(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
   F3(x)=>I2(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
   F4(x)=>I3(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
   F5(x)=>I4(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
   F6(x)=>I5(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
   F7(x)=>I6(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
   F8(x)=>I7(if temperature.is_nan(){hard_choose_burn_tensor(dim,x)}else{soft_choose_burn_tensor(dim,x,temperature)}.squeeze(d)),
   Value::Incompatible(e)=>e.into(),
   Value::Multi(v)=>Value::Multi(v.into_iter().map(|v|self.forward_typed::<_,Value<B>>(v)).collect()),
   _=>panic!("internal error")}
 }
}
617impl<B:Backend> AI<Value<B>,Value<B>> for MaxPool2d{
618 fn forward(&self,input:Value<B>)->Value<B>{
619 fn f<B:Backend,const N:usize>(pool:&MaxPool2d,x:Tensor<B,N>)->Value<B>{
620 match N{
621 0=>panic!("internal error"),
622 1=>f::<B,2>(pool,x.unsqueeze()).squeeze(),
623 2=>f::<B,3>(pool,x.unsqueeze()).squeeze(),
624 3=>f::<B,4>(pool,x.unsqueeze()).squeeze(),
625 4=>pool.forward(Value::from(x).unwrap_f4()).into(),
626 _=>{
627 let mut dims=x.dims();
628
629 let [channels,h,w]=[dims[N-3],dims[N-2],dims[N-1]];
630 let big:usize=dims.iter().take(N-3).product();
631 let y=x.reshape([big,channels,h,w]);
632
633 dims[N-3..].copy_from_slice(&y.dims()[1..]);
634
635 let y=pool.forward(y);
636 y.reshape(dims).into()
637 }
638 }
639 }
640 match input.float(){
641 F1(x)=>f(self,x),
642 F2(x)=>f(self,x),
643 F3(x)=>f(self,x),
644 F4(x)=>f(self,x),
645 F5(x)=>f(self,x),
646 F6(x)=>f(self,x),
647 F7(x)=>f(self,x),
648 F8(x)=>f(self,x),
649 Value::Incompatible(e)=>e.into(),
650 Value::Multi(v)=>v.into_iter().map(|x|AI::forward(self,x)).collect(),
651 _=>panic!("Internal error")
652 }
653 }
654}
655impl<B:Backend> AI<Value<B>,Value<B>> for RotaryEncoding<B>{
656 fn forward(&self,input:Value<B>)->Value<B>{AI::forward(self,(input,0)).0}
657}
impl<B:Backend> AI<(Value<B>,usize),(Value<B>,usize)> for RotaryEncoding<B>{
 /// Applies rotary positional encoding over the trailing (context, key)
 /// dimensions, reading the frequency table starting at `offset`; the offset
 /// itself is passed through unchanged.
 fn forward(&self,(input,offset):(Value<B>,usize))->(Value<B>,usize){
  fn apply<B:Backend,const D:usize>(a:&RotaryEncoding<B>,input:Tensor<B,D>,offset:usize)->Value<B>{
   assert!(D>=2);
   // Batch cap per matmul chunk — presumably a backend kernel-size limit;
   // TODO confirm it holds for all backends.
   const MAX_KERNEL:usize=65535; let device=input.device();
   let freq=&a.freq_complex;
   let shape=input.shape();

   let (context,key)=(shape.dims[D-2],shape.dims[D-1]);
   let [distance,head,_2]=freq.dims();

   if context>distance{return format!("context length must not exceed rotary distance. context: {context}, distance: {distance}").into()}
   if key%head!=0{return "input dimension must be a multiple of head".into()}
   let count=shape.num_elements();
   let big=count/(context*key);
   let heads=key/head;
   // Reshape to (batch*heads, context, head/2, 2): adjacent value pairs act
   // as the real/imaginary halves rotated by the complex frequency table.
   let group=count/head; let input=input.reshape([big,context,heads,head]).swap_dims(1,2).reshape([big*heads,context,head/2,2]);
   let sign=Tensor::<B,2>::from_floats([[1.0,0.0,0.0,1.0],[0.0,-1.0,1.0,0.0]],&device).unsqueeze();

   // Process in chunks so no single matmul batch exceeds MAX_KERNEL rows.
   let chunks=input.chunk(group.div_ceil(MAX_KERNEL),0).into_iter().map(|x|{
    let small=x.dims()[0];
    let x=x.matmul(sign.clone()).reshape([small,context,head,2])*freq.clone().slice([offset..context+offset]).unsqueeze();
    x.sum_dim(3)
   }).collect();
   // Undo the head/context transposition and restore the caller's shape.
   Tensor::cat(chunks,0).reshape([big,heads,context,head]).swap_dims(1,2).reshape::<D,_>(shape).into()
  }

  (match input.float(){F1(x)=>apply(self,x.unsqueeze::<2>(),offset).squeeze(),F2(x)=>apply(self,x,offset),F3(x)=>apply(self,x,offset),F4(x)=>apply(self,x,offset),F5(x)=>apply(self,x,offset),F6(x)=>apply(self,x,offset),F7(x)=>apply(self,x,offset),F8(x)=>apply(self,x,offset),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|AI::forward(self,(x,offset)).0).collect(),_=>panic!("internal error")},offset)
 }
}
690impl<B:Backend> AI<Value<B>,Value<B>> for Tanh{
691 fn forward(&self,input:Value<B>)->Value<B>{
692 match input.float(){F1(x)=>F1(self.forward(x)),F2(x)=>F2(self.forward(x)),F3(x)=>F3(self.forward(x)),F4(x)=>F4(self.forward(x)),F5(x)=>F5(self.forward(x)),F6(x)=>F6(self.forward(x)),F7(x)=>F7(self.forward(x)),F8(x)=>F8(self.forward(x)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(|x|AI::forward(self,x)).collect()),_=>panic!("internal error")}
693 }
694}
impl<B:Backend> Abs for Value<B>{
 /// Element-wise absolute value; bool tensors are returned unchanged (they
 /// have no sign), errors pass through, and `Multi` recurses per child.
 fn abs(self)->Self::Output{
  match self{B1(x)=>B1(x),B2(x)=>B2(x),B3(x)=>B3(x),B4(x)=>B4(x),B5(x)=>B5(x),B6(x)=>B6(x),B7(x)=>B7(x),B8(x)=>B8(x),F1(x)=>F1(x.abs()),F2(x)=>F2(x.abs()),F3(x)=>F3(x.abs()),F4(x)=>F4(x.abs()),F5(x)=>F5(x.abs()),F6(x)=>F6(x.abs()),F7(x)=>F7(x.abs()),F8(x)=>F8(x.abs()),I1(x)=>I1(x.abs()),I2(x)=>I2(x.abs()),I3(x)=>I3(x.abs()),I4(x)=>I4(x.abs()),I5(x)=>I5(x.abs()),I6(x)=>I6(x.abs()),I7(x)=>I7(x.abs()),I8(x)=>I8(x.abs()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(Value::abs).collect()}
 }
 type Output=Value<B>;
}
701impl<B:Backend> AsRef<Self> for Value<B>{
702 fn as_ref(&self)->&Self{self}
703}
704impl<B:Backend> Cat for Value<B>{
705 fn cat(self,d:i32)->Self{
707 fn f<B:Backend,I:Iterator<Item=Tensor<B,N,K>>,K:BasicOps<B>+TensorKind<B>,const N:usize>(d:i32,x0:Tensor<B,N,K>,tensors:I)->Value<B> where Tensor<B,N,K>:Into<Value<B>>{
708 if d>=N as i32||d<(-(N as i32)){return format!("rank {N} is too low to cat along dimension {d}").into()}
709 let d=if d<0{N-((-d) as usize)}else{d as usize};
710 let shape=x0.dims();
711 let tensors:Vec<Tensor<B,N,K>>=once(x0).chain(tensors).collect();
712
713 if let Err(e)=tensors.iter().try_for_each(|x|{
714 let mut xshape=x.dims();
715 xshape[d]=shape[d];
716 if shape==xshape{Ok(())}else{Err("mismatched shapes {shape:?}, {xshape:?}")}
717 }){
718 return e.into()
719 }
720
721 Tensor::cat(tensors,d).into()
722 }
723 let v=if let Value::Multi(v)=self{v}else{return self};
724
725 if let Some(n)=v.iter().position(Value::is_incompatible){return v.into_iter().nth(n).unwrap()}
726 if v.iter().all(Value::is_multi){return v.into_iter().map(|x|x.cat(d)).collect()}
727 if v.iter().any(Value::is_multi){return "cannot mix single and multi values in a cat operation".into()}
728 let variant=mem::discriminant(&v[0]);
729
730 if v.iter().any(|x|mem::discriminant(x)!=variant){return "cannot mix variants in a cat operation".into()}
731 let mut v=v.into_iter();
732
733 match v.next().unwrap(){B1(x0)=>f(d,x0,v.map(Value::unwrap_b1)),B2(x0)=>f(d,x0,v.map(Value::unwrap_b2)),B3(x0)=>f(d,x0,v.map(Value::unwrap_b3)),B4(x0)=>f(d,x0,v.map(Value::unwrap_b4)),B5(x0)=>f(d,x0,v.map(Value::unwrap_b5)),B6(x0)=>f(d,x0,v.map(Value::unwrap_b6)),B7(x0)=>f(d,x0,v.map(Value::unwrap_b7)),B8(x0)=>f(d,x0,v.map(Value::unwrap_b8)),F1(x0)=>f(d,x0,v.map(Value::unwrap_f1)),F2(x0)=>f(d,x0,v.map(Value::unwrap_f2)),F3(x0)=>f(d,x0,v.map(Value::unwrap_f3)),F4(x0)=>f(d,x0,v.map(Value::unwrap_f4)),F5(x0)=>f(d,x0,v.map(Value::unwrap_f5)),F6(x0)=>f(d,x0,v.map(Value::unwrap_f6)),F7(x0)=>f(d,x0,v.map(Value::unwrap_f7)),F8(x0)=>f(d,x0,v.map(Value::unwrap_f8)),I1(x0)=>f(d,x0,v.map(Value::unwrap_i1)),I2(x0)=>f(d,x0,v.map(Value::unwrap_i2)),I3(x0)=>f(d,x0,v.map(Value::unwrap_i3)),I4(x0)=>f(d,x0,v.map(Value::unwrap_i4)),I5(x0)=>f(d,x0,v.map(Value::unwrap_i5)),I6(x0)=>f(d,x0,v.map(Value::unwrap_i6)),I7(x0)=>f(d,x0,v.map(Value::unwrap_i7)),I8(x0)=>f(d,x0,v.map(Value::unwrap_i8)),Value::Incompatible(_e)=>panic!("internal error not handled in correct location"),Value::Multi(_e)=>panic!("internal error not handled in correct location")}
734 }
735 type Output=Self;
736}
737impl<B:Backend> Decompose for LossOutput<B>{
738 fn compose((loss,output,target):Self::Decomposition)->Self{Self::new(loss,output,target)}
739 fn decompose(self)->Self::Decomposition{(self.loss(),self.output(),self.target())}
740 fn decompose_cloned(&self)->Self::Decomposition{(self.loss(),self.output(),self.target())}
741 type Decomposition=(Value<B>,Value<B>,Value<B>);
742}
743impl<B:Backend> Decompose for Value<B>{
744 fn compose(decomposition:Self::Decomposition)->Self{decomposition}
745 fn decompose(self)->Self::Decomposition{self}
746 fn decompose_cloned(&self)->Self::Decomposition{self.clone()}
747 type Decomposition=Self;
748}
749impl<B:Backend> Default for Value<B>{
750 fn default()->Self{Self::Multi(Vec::new())}
751}
752impl<B:Backend> Display for Value<B>{
753 fn fmt(&self,f:&mut std::fmt::Formatter<'_>)->FmtResult{
754 match self{
755 B1(x)=>x.fmt(f),
756 B2(x)=>x.fmt(f),
757 B3(x)=>x.fmt(f),
758 B4(x)=>x.fmt(f),
759 B5(x)=>x.fmt(f),
760 B6(x)=>x.fmt(f),
761 B7(x)=>x.fmt(f),
762 B8(x)=>x.fmt(f),
763 F1(x)=>x.fmt(f),
764 F2(x)=>x.fmt(f),
765 F3(x)=>x.fmt(f),
766 F4(x)=>x.fmt(f),
767 F5(x)=>x.fmt(f),
768 F6(x)=>x.fmt(f),
769 F7(x)=>x.fmt(f),
770 F8(x)=>x.fmt(f),
771 I1(x)=>x.fmt(f),
772 I2(x)=>x.fmt(f),
773 I3(x)=>x.fmt(f),
774 I4(x)=>x.fmt(f),
775 I5(x)=>x.fmt(f),
776 I6(x)=>x.fmt(f),
777 I7(x)=>x.fmt(f),
778 I8(x)=>x.fmt(f),
779 Value::Incompatible(e)=>e.fmt(f),
780 Value::Multi(v)=>{
781 write!(f,"[")?;
782 v.iter().take(v.len().saturating_sub(1)).try_for_each(|x|{
783 x.fmt(f)?;
784 write!(f,", ")
785 })?;
786 if let Some(x)=v.last(){
787 x.fmt(f)?;
788 }
789 write!(f,"]")
790 }
791 }
792 }
793}
794impl<B:Backend> From<Vec<bool>> for Value<B>{
795 fn from(value:Vec<bool>)->Self{
796 let l=value.len();
797 let t:Tensor<B,1,Bool>=Tensor::from_data(TensorData::new(value,[l]),&Default::default());
798
799 t.into()
800 }
801}
802impl<B:Backend> From<Vec<f32>> for Value<B>{
803 fn from(value:Vec<f32>)->Self{
804 let l=value.len();
805 let t:Tensor<B,1>=Tensor::from_data(TensorData::new(value,[l]),&Default::default());
806
807 t.into()
808 }
809}
810impl<B:Backend> From<Vec<i32>> for Value<B>{
811 fn from(value:Vec<i32>)->Self{
812 let l=value.len();
813 let t:Tensor<B,1,Int>=Tensor::from_data(TensorData::new(value,[l]),&Default::default());
814
815 t.into()
816 }
817}
818impl<B:Backend> From<Vec<u32>> for Value<B>{
819 fn from(value:Vec<u32>)->Self{
820 let l=value.len();
821 let t:Tensor<B,1,Int>=Tensor::from_data(TensorData::new(value,[l]),&Default::default());
822
823 t.into()
824 }
825}
826impl<B:Backend> ModuleDisplay for Value<B>{
827 fn custom_content(&self,_content:Content)->Option<Content>{None}
828 fn custom_settings(&self)->Option<DisplaySettings>{None}
829 fn format(&self,s:DisplaySettings)->String{
830 match self{
831 B1(x)=>x.format(s),
832 B2(x)=>x.format(s),
833 B3(x)=>x.format(s),
834 B4(x)=>x.format(s),
835 B5(x)=>x.format(s),
836 B6(x)=>x.format(s),
837 B7(x)=>x.format(s),
838 B8(x)=>x.format(s),
839 F1(x)=>x.format(s),
840 F2(x)=>x.format(s),
841 F3(x)=>x.format(s),
842 F4(x)=>x.format(s),
843 F5(x)=>x.format(s),
844 F6(x)=>x.format(s),
845 F7(x)=>x.format(s),
846 F8(x)=>x.format(s),
847 I1(x)=>x.format(s),
848 I2(x)=>x.format(s),
849 I3(x)=>x.format(s),
850 I4(x)=>x.format(s),
851 I5(x)=>x.format(s),
852 I6(x)=>x.format(s),
853 I7(x)=>x.format(s),
854 I8(x)=>x.format(s),
855 Value::Incompatible(e)=>e.to_string(),
856 Value::Multi(v)=>"[".chars().chain(v.iter().flat_map(|x|{
857 let x:Vec<char>=x.format(s.clone()).chars().chain(", ".chars()).collect();
858 x
859 })).chain("]".chars()).collect()
860 }
861 }
862}
863impl<B:Backend> ModuleDisplayDefault for Value<B>{
864 fn content(&self,content:Content)->Option<Content>{Some(content)}
865 fn num_params(&self)->usize{Module::num_params(self)}
866}
867impl<B:Backend> From<String> for Value<B>{
868 fn from(value:String)->Self{Self::Incompatible(value)}
869}
impl<B:Backend> From<Value<B>> for ValueData{
 /// Detaches a value from its backend: every tensor variant collapses into
 /// kind-tagged raw data (BX/FX/IX), recursing through `Multi`.
 fn from(value:Value<B>)->Self{
  match value{B1(x)=>BX(x.into_data()),B2(x)=>BX(x.into_data()),B3(x)=>BX(x.into_data()),B4(x)=>BX(x.into_data()),B5(x)=>BX(x.into_data()),B6(x)=>BX(x.into_data()),B7(x)=>BX(x.into_data()),B8(x)=>BX(x.into_data()),F1(x)=>FX(x.into_data()),F2(x)=>FX(x.into_data()),F3(x)=>FX(x.into_data()),F4(x)=>FX(x.into_data()),F5(x)=>FX(x.into_data()),F6(x)=>FX(x.into_data()),F7(x)=>FX(x.into_data()),F8(x)=>FX(x.into_data()),I1(x)=>IX(x.into_data()),I2(x)=>IX(x.into_data()),I3(x)=>IX(x.into_data()),I4(x)=>IX(x.into_data()),I5(x)=>IX(x.into_data()),I6(x)=>IX(x.into_data()),I7(x)=>IX(x.into_data()),I8(x)=>IX(x.into_data()),Value::Incompatible(e)=>ValueData::Incompatible(e),Value::Multi(v)=>ValueData::Multi(v.into_iter().map(ValueData::from).collect())}
 }
}
impl<B:Backend> From<ValueData> for Value<B>{
 /// Re-materializes backend-independent data on the default device, picking
 /// the rank variant from the stored shape length (ranks 1..=8 only; rank 0
 /// or >8 panics).
 fn from(value:ValueData)->Self{
  let device=Default::default();
  match value{
   BX(data)=>match data.shape.len(){1=>B1(Tensor::from_data(data,&device)),2=>B2(Tensor::from_data(data,&device)),3=>B3(Tensor::from_data(data,&device)),4=>B4(Tensor::from_data(data,&device)),5=>B5(Tensor::from_data(data,&device)),6=>B6(Tensor::from_data(data,&device)),7=>B7(Tensor::from_data(data,&device)),8=>B8(Tensor::from_data(data,&device)),_=>panic!("tensor ranks above 8 are currently not supported")},
   FX(data)=>match data.shape.len(){1=>F1(Tensor::from_data(data,&device)),2=>F2(Tensor::from_data(data,&device)),3=>F3(Tensor::from_data(data,&device)),4=>F4(Tensor::from_data(data,&device)),5=>F5(Tensor::from_data(data,&device)),6=>F6(Tensor::from_data(data,&device)),7=>F7(Tensor::from_data(data,&device)),8=>F8(Tensor::from_data(data,&device)),_=>panic!("tensor ranks above 8 are currently not supported")},
   IX(data)=>match data.shape.len(){1=>I1(Tensor::from_data(data,&device)),2=>I2(Tensor::from_data(data,&device)),3=>I3(Tensor::from_data(data,&device)),4=>I4(Tensor::from_data(data,&device)),5=>I5(Tensor::from_data(data,&device)),6=>I6(Tensor::from_data(data,&device)),7=>I7(Tensor::from_data(data,&device)),8=>I8(Tensor::from_data(data,&device)),_=>panic!("tensor ranks above 8 are currently not supported")},
   ValueData::Incompatible(e)=>e.into(),
   ValueData::Multi(v)=>v.into_iter().map(Value::from).collect(),
  }
 }
}
887impl<B:Backend> From<Vec<Value<B>>> for Value<B>{
888 fn from(value:Vec<Value<B>>)->Self{Self::Multi(value)}
889}
890impl<B:Backend> IntoIterator for Value<B>{
891 fn into_iter(self)->Self::IntoIter{self.into_multi().into_iter()}
892 type IntoIter=VecIntoIter<Value<B>>;
893 type Item=Value<B>;
894}
895impl<B:Backend> LossOutput<B>{
896 pub fn loss(&self)->Value<B>{self.loss.clone()}
898 pub fn new(loss:Value<B>,output:Value<B>,target:Value<B>)->Self{
900 Self{loss,output,target}
901 }
902 pub fn output(&self)->Value<B>{self.output.clone()}
904 pub fn target(&self)->Value<B>{self.target.clone()}
906}
907impl<B:Backend> Merge for Value<B>{
908 fn merge(&mut self,other:Self){
909 match (mem::take(self),other){
910 (Value::Multi(mut u),Value::Multi(v))=>{
911 u.extend(v);
912 *self=u.into();
913 },
914 (Value::Multi(mut u),v)=>if u.len()==0{
915 *self=v;
916 }else{
917 u.push(v);
918 *self=u.into();
919 },
920 (u,Value::Multi(mut v))=>if v.len()==0{
921 *self=u;
922 }else{
923 v.insert(0,u);
924 *self=v.into();
925 },
926 (u,v)=>*self=vec![u,v].into()
927 }
928 }
929}
impl<B:Backend> Module<B> for Value<B>{
 /// Gathers the devices of every tensor in the value tree.
 fn collect_devices(&self,devices:Vec<<B as Backend>::Device>)->Vec<<B as Backend>::Device>{
  match self{B1(x)=>x.collect_devices(devices),B2(x)=>x.collect_devices(devices),B3(x)=>x.collect_devices(devices),B4(x)=>x.collect_devices(devices),B5(x)=>x.collect_devices(devices),B6(x)=>x.collect_devices(devices),B7(x)=>x.collect_devices(devices),B8(x)=>x.collect_devices(devices),F1(x)=>x.collect_devices(devices),F2(x)=>x.collect_devices(devices),F3(x)=>x.collect_devices(devices),F4(x)=>x.collect_devices(devices),F5(x)=>x.collect_devices(devices),F6(x)=>x.collect_devices(devices),F7(x)=>x.collect_devices(devices),F8(x)=>x.collect_devices(devices),I1(x)=>x.collect_devices(devices),I2(x)=>x.collect_devices(devices),I3(x)=>x.collect_devices(devices),I4(x)=>x.collect_devices(devices),I5(x)=>x.collect_devices(devices),I6(x)=>x.collect_devices(devices),I7(x)=>x.collect_devices(devices),I8(x)=>x.collect_devices(devices),Value::Incompatible(_e)=>devices,Value::Multi(v)=>v.iter().fold(devices,|devices,x|x.collect_devices(devices))}
 }
 fn devices(&self)->Vec<<B as Backend>::Device>{self.collect_devices(Vec::new())}
 /// Moves every tensor to `device`, detaching from the autodiff graph.
 fn fork(self,device:&<B as Backend>::Device)->Self{
  match self{B1(x)=>B1(x.fork(device)),B2(x)=>B2(x.fork(device)),B3(x)=>B3(x.fork(device)),B4(x)=>B4(x.fork(device)),B5(x)=>B5(x.fork(device)),B6(x)=>B6(x.fork(device)),B7(x)=>B7(x.fork(device)),B8(x)=>B8(x.fork(device)),F1(x)=>F1(x.fork(device)),F2(x)=>F2(x.fork(device)),F3(x)=>F3(x.fork(device)),F4(x)=>F4(x.fork(device)),F5(x)=>F5(x.fork(device)),F6(x)=>F6(x.fork(device)),F7(x)=>F7(x.fork(device)),F8(x)=>F8(x.fork(device)),I1(x)=>I1(x.fork(device)),I2(x)=>I2(x.fork(device)),I3(x)=>I3(x.fork(device)),I4(x)=>I4(x.fork(device)),I5(x)=>I5(x.fork(device)),I6(x)=>I6(x.fork(device)),I7(x)=>I7(x.fork(device)),I8(x)=>I8(x.fork(device)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.fork(device)).collect()}
 }
 // A Value carries no trainable parameters, so its record is constant.
 fn into_record(self)->Self::Record{ConstantRecord}
 fn load_file<F:FileRecorder<B>,P:Into<PathBuf>>(self,_filepath:P,_recorder:&F,_device:&<B as Backend>::Device)->Result<Self,RecorderError>{Ok(self)}
 fn load_record(self,_record:Self::Record)->Self{self}
 /// Applies `mapper` to every tensor in the value tree.
 fn map<Mapper:ModuleMapper<B>>(self,mapper:&mut Mapper)->Self{
  match self{B1(x)=>B1(x.map(mapper)),B2(x)=>B2(x.map(mapper)),B3(x)=>B3(x.map(mapper)),B4(x)=>B4(x.map(mapper)),B5(x)=>B5(x.map(mapper)),B6(x)=>B6(x.map(mapper)),B7(x)=>B7(x.map(mapper)),B8(x)=>B8(x.map(mapper)),F1(x)=>F1(x.map(mapper)),F2(x)=>F2(x.map(mapper)),F3(x)=>F3(x.map(mapper)),F4(x)=>F4(x.map(mapper)),F5(x)=>F5(x.map(mapper)),F6(x)=>F6(x.map(mapper)),F7(x)=>F7(x.map(mapper)),F8(x)=>F8(x.map(mapper)),I1(x)=>I1(x.map(mapper)),I2(x)=>I2(x.map(mapper)),I3(x)=>I3(x.map(mapper)),I4(x)=>I4(x.map(mapper)),I5(x)=>I5(x.map(mapper)),I6(x)=>I6(x.map(mapper)),I7(x)=>I7(x.map(mapper)),I8(x)=>I8(x.map(mapper)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.map(mapper)).collect()}
 }
 /// Sums the parameter counts of every tensor; errors contribute zero.
 fn num_params(&self)->usize{
  match self{B1(x)=>Module::num_params(x),B2(x)=>Module::num_params(x),B3(x)=>Module::num_params(x),B4(x)=>Module::num_params(x),B5(x)=>Module::num_params(x),B6(x)=>Module::num_params(x),B7(x)=>Module::num_params(x),B8(x)=>Module::num_params(x),F1(x)=>Module::num_params(x),F2(x)=>Module::num_params(x),F3(x)=>Module::num_params(x),F4(x)=>Module::num_params(x),F5(x)=>Module::num_params(x),F6(x)=>Module::num_params(x),F7(x)=>Module::num_params(x),F8(x)=>Module::num_params(x),I1(x)=>Module::num_params(x),I2(x)=>Module::num_params(x),I3(x)=>Module::num_params(x),I4(x)=>Module::num_params(x),I5(x)=>Module::num_params(x),I6(x)=>Module::num_params(x),I7(x)=>Module::num_params(x),I8(x)=>Module::num_params(x),Value::Incompatible(_e)=>0,Value::Multi(v)=>v.into_iter().map(|x|Module::num_params(x)).sum()}
 }
 /// Quantizes every tensor in the value tree with `quantizer`.
 fn quantize_weights(self,quantizer:&mut Quantizer)->Self{
  match self{B1(x)=>B1(x.quantize_weights(quantizer)),B2(x)=>B2(x.quantize_weights(quantizer)),B3(x)=>B3(x.quantize_weights(quantizer)),B4(x)=>B4(x.quantize_weights(quantizer)),B5(x)=>B5(x.quantize_weights(quantizer)),B6(x)=>B6(x.quantize_weights(quantizer)),B7(x)=>B7(x.quantize_weights(quantizer)),B8(x)=>B8(x.quantize_weights(quantizer)),F1(x)=>F1(x.quantize_weights(quantizer)),F2(x)=>F2(x.quantize_weights(quantizer)),F3(x)=>F3(x.quantize_weights(quantizer)),F4(x)=>F4(x.quantize_weights(quantizer)),F5(x)=>F5(x.quantize_weights(quantizer)),F6(x)=>F6(x.quantize_weights(quantizer)),F7(x)=>F7(x.quantize_weights(quantizer)),F8(x)=>F8(x.quantize_weights(quantizer)),I1(x)=>I1(x.quantize_weights(quantizer)),I2(x)=>I2(x.quantize_weights(quantizer)),I3(x)=>I3(x.quantize_weights(quantizer)),I4(x)=>I4(x.quantize_weights(quantizer)),I5(x)=>I5(x.quantize_weights(quantizer)),I6(x)=>I6(x.quantize_weights(quantizer)),I7(x)=>I7(x.quantize_weights(quantizer)),I8(x)=>I8(x.quantize_weights(quantizer)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.quantize_weights(quantizer)).collect()}
 }
 // Values are not persisted as module records; saving is a no-op.
 fn save_file<F:FileRecorder<B>,P:Into<PathBuf>>(self,_filepath:P,_recorder:&F)->Result<(),RecorderError>{
  Ok(())
 }
 /// Moves every tensor to `device` without detaching.
 fn to_device(self,device:&<B as Backend>::Device)->Self{
  match self{B1(x)=>B1(x.to_device(device)),B2(x)=>B2(x.to_device(device)),B3(x)=>B3(x.to_device(device)),B4(x)=>B4(x.to_device(device)),B5(x)=>B5(x.to_device(device)),B6(x)=>B6(x.to_device(device)),B7(x)=>B7(x.to_device(device)),B8(x)=>B8(x.to_device(device)),F1(x)=>F1(x.to_device(device)),F2(x)=>F2(x.to_device(device)),F3(x)=>F3(x.to_device(device)),F4(x)=>F4(x.to_device(device)),F5(x)=>F5(x.to_device(device)),F6(x)=>F6(x.to_device(device)),F7(x)=>F7(x.to_device(device)),F8(x)=>F8(x.to_device(device)),I1(x)=>I1(x.to_device(device)),I2(x)=>I2(x.to_device(device)),I3(x)=>I3(x.to_device(device)),I4(x)=>I4(x.to_device(device)),I5(x)=>I5(x.to_device(device)),I6(x)=>I6(x.to_device(device)),I7(x)=>I7(x.to_device(device)),I8(x)=>I8(x.to_device(device)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.to_device(device)).collect()}
 }
 /// Visits every tensor in the value tree; errors are skipped.
 fn visit<Visitor:ModuleVisitor<B>>(&self,visitor:&mut Visitor){
  match self{B1(x)=>x.visit(visitor),B2(x)=>x.visit(visitor),B3(x)=>x.visit(visitor),B4(x)=>x.visit(visitor),B5(x)=>x.visit(visitor),B6(x)=>x.visit(visitor),B7(x)=>x.visit(visitor),B8(x)=>x.visit(visitor),F1(x)=>x.visit(visitor),F2(x)=>x.visit(visitor),F3(x)=>x.visit(visitor),F4(x)=>x.visit(visitor),F5(x)=>x.visit(visitor),F6(x)=>x.visit(visitor),F7(x)=>x.visit(visitor),F8(x)=>x.visit(visitor),I1(x)=>x.visit(visitor),I2(x)=>x.visit(visitor),I3(x)=>x.visit(visitor),I4(x)=>x.visit(visitor),I5(x)=>x.visit(visitor),I6(x)=>x.visit(visitor),I7(x)=>x.visit(visitor),I8(x)=>x.visit(visitor),Value::Incompatible(_e)=>(),Value::Multi(v)=>v.iter().for_each(|x|x.visit(visitor))}
 }
 type Record=ConstantRecord;
}
961impl<B:Backend> Serialize for Value<B>{
962 fn serialize<S:Serializer>(&self,serializer:S)->Result<S::Ok,S::Error>{ValueData::from(self.clone()).serialize(serializer)}
963}
964impl<B:Backend> Squeeze for Value<B>{
965 fn squeeze(self,d:i32)->Self{self.squeeze_dim(d)}
966 type Output=Self;
967}
968impl<B:Backend> Stack for Value<B>{
969 fn stack(self,d:i32)->Self{self.unsqueeze_dim(d).cat(d)}
971 type Output=Self;
972}
973impl<B:Backend> Unsqueeze for Value<B>{
974 fn unsqueeze(self,d:i32)->Self{self.unsqueeze_dim(d)}
975 type Output=Self;
976}
impl<B:Backend> Value<B>{
 /// Reduces to a single boolean element: true iff every element is true.
 /// `Multi` maps over its children; errors propagate unchanged.
 pub fn all(self)->Value<B>{
  fn f<B:Backend,const N:usize>(x:Tensor<B,N,Bool>)->Value<B>{x.all().into()}
  match self.bool(){
   B1(x)=>f(x),B2(x)=>f(x),B3(x)=>f(x),B4(x)=>f(x),
   B5(x)=>f(x),B6(x)=>f(x),B7(x)=>f(x),B8(x)=>f(x),
   Value::Incompatible(e)=>e.into(),
   Value::Multi(v)=>v.into_iter().map(Value::all).collect(),
   // bool() only returns B*/Incompatible/Multi variants.
   _=>panic!("internal error")
  }
 }
983 pub fn all_dim(self,d:i32)->Value<B>{
985 fn f<B:Backend,const N:usize>(d:i32,x:Tensor<B,N,Bool>)->Value<B>{
986 if d>=N as i32||d<(-(N as i32)){return format!("rank {N} is too low to all along dimension {d}").into()}
987 let d=if d<0{N-((-d) as usize)}else{d as usize};
988 x.all_dim(d).into()
989 }
990 match self.bool(){B1(x)=>f(d,x),B2(x)=>f(d,x),B3(x)=>f(d,x),B4(x)=>f(d,x),B5(x)=>f(d,x),B6(x)=>f(d,x),B7(x)=>f(d,x),B8(x)=>f(d,x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|v|v.all_dim(d)).collect(),_=>panic!("internal error")}
991 }
992 pub fn any(self)->Value<B>{
994 fn f<B:Backend,const N:usize>(x:Tensor<B,N,Bool>)->Value<B>{x.any().into()}
995 match self.bool(){B1(x)=>f(x),B2(x)=>f(x),B3(x)=>f(x),B4(x)=>f(x),B5(x)=>f(x),B6(x)=>f(x),B7(x)=>f(x),B8(x)=>f(x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(Value::any).collect(),_=>panic!("internal error")}
996 }
997 pub fn any_dim(self,d:i32)->Value<B>{
999 fn f<B:Backend,const N:usize>(d:i32,x:Tensor<B,N,Bool>)->Value<B>{
1000 if d>=N as i32||d<(-(N as i32)){return format!("rank {N} is too low to any along dimension {d}").into()}
1001 let d=if d<0{N-((-d) as usize)}else{d as usize};
1002 x.any_dim(d).into()
1003 }
1004 match self.bool(){B1(x)=>f(d,x),B2(x)=>f(d,x),B3(x)=>f(d,x),B4(x)=>f(d,x),B5(x)=>f(d,x),B6(x)=>f(d,x),B7(x)=>f(d,x),B8(x)=>f(d,x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|v|v.any_dim(d)).collect(),_=>panic!("internal error")}
1005 }
 /// Converts every tensor to the bool kind (bool variants pass through
 /// unchanged); errors propagate and `Multi` recurses per child.
 pub fn bool(self)->Value<B>{
  match self{B1(x)=>B1(x),B2(x)=>B2(x),B3(x)=>B3(x),B4(x)=>B4(x),B5(x)=>B5(x),B6(x)=>B6(x),B7(x)=>B7(x),B8(x)=>B8(x),F1(x)=>B1(x.bool()),F2(x)=>B2(x.bool()),F3(x)=>B3(x.bool()),F4(x)=>B4(x.bool()),F5(x)=>B5(x.bool()),F6(x)=>B6(x.bool()),F7(x)=>B7(x.bool()),F8(x)=>B8(x.bool()),I1(x)=>B1(x.bool()),I2(x)=>B2(x.bool()),I3(x)=>B3(x.bool()),I4(x)=>B4(x.bool()),I5(x)=>B5(x.bool()),I6(x)=>B6(x.bool()),I7(x)=>B7(x.bool()),I8(x)=>B8(x.bool()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(Value::bool).collect())}
 }
 /// Total number of tensor elements across the whole value tree; an
 /// `Incompatible` value contributes zero.
 pub fn count(&self)->usize{
  match self{
   B1(x)=>x.dims().iter().product(),
   B2(x)=>x.dims().iter().product(),
   B3(x)=>x.dims().iter().product(),
   B4(x)=>x.dims().iter().product(),
   B5(x)=>x.dims().iter().product(),
   B6(x)=>x.dims().iter().product(),
   B7(x)=>x.dims().iter().product(),
   B8(x)=>x.dims().iter().product(),
   F1(x)=>x.dims().iter().product(),
   F2(x)=>x.dims().iter().product(),
   F3(x)=>x.dims().iter().product(),
   F4(x)=>x.dims().iter().product(),
   F5(x)=>x.dims().iter().product(),
   F6(x)=>x.dims().iter().product(),
   F7(x)=>x.dims().iter().product(),
   F8(x)=>x.dims().iter().product(),
   I1(x)=>x.dims().iter().product(),
   I2(x)=>x.dims().iter().product(),
   I3(x)=>x.dims().iter().product(),
   I4(x)=>x.dims().iter().product(),
   I5(x)=>x.dims().iter().product(),
   I6(x)=>x.dims().iter().product(),
   I7(x)=>x.dims().iter().product(),
   I8(x)=>x.dims().iter().product(),
   Value::Incompatible(_e)=>0,
   Value::Multi(v)=>v.iter().map(Value::count).sum()
  }
 }
1041 pub fn empty()->Self{Self::Multi(Vec::new())}
1043 pub fn flatten_values(self)->Self{
1047 fn f<B:Backend>(mut acc:Vec<Value<B>>,x:Value<B>)->Vec<Value<B>>{
1048 if x.is_multi(){acc=x.into_iter().fold(acc,|acc,x|f(acc,x))}else{acc.push(x)}
1049 acc
1050 }
1051 f(Vec::new(),self).into()
1052 }
 /// Converts every tensor to the float kind (float variants pass through
 /// unchanged); errors propagate and `Multi` recurses per child.
 pub fn float(self)->Value<B>{
  match self{B1(x)=>F1(x.float()),B2(x)=>F2(x.float()),B3(x)=>F3(x.float()),B4(x)=>F4(x.float()),B5(x)=>F5(x.float()),B6(x)=>F6(x.float()),B7(x)=>F7(x.float()),B8(x)=>F8(x.float()),F1(x)=>F1(x),F2(x)=>F2(x),F3(x)=>F3(x),F4(x)=>F4(x),F5(x)=>F5(x),F6(x)=>F6(x),F7(x)=>F7(x),F8(x)=>F8(x),I1(x)=>F1(x.float()),I2(x)=>F2(x.float()),I3(x)=>F3(x.float()),I4(x)=>F4(x.float()),I5(x)=>F5(x.float()),I6(x)=>F6(x.float()),I7(x)=>F7(x.float()),I8(x)=>F8(x.float()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(Value::float).collect())}
 }
 /// Rebuilds a (possibly nested) value from a flat stream of values guided
 /// by a `Shape` tree: tensor-shaped leaves consume one value each, while
 /// `Recursive` shapes rebuild nested `Multi`s in order.
 pub fn from_values<I:IntoIterator<Item=Self>,S:Into<Shape>>(inner:I,shape:S)->Self{
  fn f<B:Backend,I:Iterator<Item=Value<B>>,S:Into<Shape>>(inner:&mut I,shape:S)->Value<B>{
   match shape.into(){
    Shape::Incompatible(e)=>e.into(),
    // A plain multi shape drains the remaining stream into one Multi.
    Shape::Multi(_l)=>inner.collect(),
    Shape::Recursive(v)=>v.into_iter().map(|s|f(&mut *inner,s)).collect(),
    // Tensor-shaped slots take the next value; a dry stream yields the
    // default (empty Multi).
    X1(_s)=>inner.next().unwrap_or_default(),
    X2(_s)=>inner.next().unwrap_or_default(),
    X3(_s)=>inner.next().unwrap_or_default(),
    X4(_s)=>inner.next().unwrap_or_default(),
    X5(_s)=>inner.next().unwrap_or_default(),
    X6(_s)=>inner.next().unwrap_or_default(),
    X7(_s)=>inner.next().unwrap_or_default(),
    X8(_s)=>inner.next().unwrap_or_default(),
   }
  }
  f(&mut inner.into_iter(),shape)
 }
1076 pub fn gather(self,dim:i32,indices:Value<B>)->Self{
1078 fn b<B:Backend,const N:usize>(d:i32,data:Tensor<B,N,Bool>,indices:Tensor<B,N,Int>)->Value<B>{f(d,data.int(),indices).bool()}
1079 fn f<B:Backend,K:'static+BasicOps<B>+Numeric<B>+TensorKind<B>,const N:usize>(d:i32,data:Tensor<B,N,K>,indices:Tensor<B,N,Int>)->Value<B>{
1080 let d=if d<0{N-((-d) as usize)}else{d as usize};
1081 if d>=N{format!("dim {d} must be less than rank {N}").into()}else{data.gather(d,indices).into()}
1082 }
1083
1084 match (self,indices){(B1(x),I1(i))=>b(dim,x,i),(B2(x),I2(i))=>b(dim,x,i),(B3(x),I3(i))=>b(dim,x,i),(B4(x),I4(i))=>b(dim,x,i),(B5(x),I5(i))=>b(dim,x,i),(B6(x),I6(i))=>b(dim,x,i),(B7(x),I7(i))=>b(dim,x,i),(B8(x),I8(i))=>b(dim,x,i),(F1(x),I1(i))=>f(dim,x,i),(F2(x),I2(i))=>f(dim,x,i),(F3(x),I3(i))=>f(dim,x,i),(F4(x),I4(i))=>f(dim,x,i),(F5(x),I5(i))=>f(dim,x,i),(F6(x),I6(i))=>f(dim,x,i),(F7(x),I7(i))=>f(dim,x,i),(F8(x),I8(i))=>f(dim,x,i),(I1(x),I1(i))=>f(dim,x,i),(I2(x),I2(i))=>f(dim,x,i),(I3(x),I3(i))=>f(dim,x,i),(I4(x),I4(i))=>f(dim,x,i),(I5(x),I5(i))=>f(dim,x,i),(I6(x),I6(i))=>f(dim,x,i),(I7(x),I7(i))=>f(dim,x,i),(I8(x),I8(i))=>f(dim,x,i),(Value::Incompatible(e),_)=>e.into(),(_,Value::Incompatible(e))=>e.into(),(Value::Multi(u),Value::Multi(v))=>u.into_iter().zip(v).map(|(u,v)|u.gather(dim,v)).collect(),_=>"gather is only available for tensors of matching dimensions with int indices".into()}
1085 }
 /// Converts every tensor to the int kind (int variants pass through
 /// unchanged); errors propagate and `Multi` recurses per child.
 pub fn int(self)->Value<B>{
  match self{B1(x)=>I1(x.int()),B2(x)=>I2(x.int()),B3(x)=>I3(x.int()),B4(x)=>I4(x.int()),B5(x)=>I5(x.int()),B6(x)=>I6(x.int()),B7(x)=>I7(x.int()),B8(x)=>I8(x.int()),F1(x)=>I1(x.int()),F2(x)=>I2(x.int()),F3(x)=>I3(x.int()),F4(x)=>I4(x.int()),F5(x)=>I5(x.int()),F6(x)=>I6(x.int()),F7(x)=>I7(x.int()),F8(x)=>I8(x.int()),I1(x)=>I1(x),I2(x)=>I2(x),I3(x)=>I3(x),I4(x)=>I4(x),I5(x)=>I5(x),I6(x)=>I6(x),I7(x)=>I7(x),I8(x)=>I8(x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>Value::Multi(v.into_iter().map(Value::int).collect())}
 }
1090 pub fn into_float_vec(self)->Vec<f32>{
1092 fn cat_vec<T>(mut a:Vec<T>,b:Vec<T>)->Vec<T>{
1093 a.extend(b);
1094 a
1095 }
1096 fn to_vec<B:Backend,const N:usize>(x:Tensor<B,N>)->Vec<f32>{x.into_data().to_vec().unwrap_or_default()}
1097
1098 match self.float(){F1(x)=>to_vec(x),F2(x)=>to_vec(x),F3(x)=>to_vec(x),F4(x)=>to_vec(x),F5(x)=>to_vec(x),F6(x)=>to_vec(x),F7(x)=>to_vec(x),F8(x)=>to_vec(x),Value::Incompatible(_e)=>Vec::new(),Value::Multi(v)=>v.into_iter().map(Value::into_float_vec).reduce(cat_vec).unwrap_or_default(),_=>panic!("internal error")}
1099 }
1100 pub fn into_int_vec(self)->Vec<i32>{
1102 fn cat_vec<T>(mut a:Vec<T>,b:Vec<T>)->Vec<T>{
1103 a.extend(b);
1104 a
1105 }
1106 fn to_vec<B:Backend,const N:usize>(x:Tensor<B,N,Int>)->Vec<i32>{x.into_data().to_vec().unwrap_or_default()}
1107
1108 match self.int(){I1(x)=>to_vec(x),I2(x)=>to_vec(x),I3(x)=>to_vec(x),I4(x)=>to_vec(x),I5(x)=>to_vec(x),I6(x)=>to_vec(x),I7(x)=>to_vec(x),I8(x)=>to_vec(x),Value::Incompatible(_e)=>Vec::new(),Value::Multi(v)=>v.into_iter().map(Value::into_int_vec).reduce(cat_vec).unwrap_or_default(),_=>panic!("internal error")}
1109 }
1110 pub fn height(&self)->usize{
1112 if let Value::Multi(v)=self{v.iter().map(Value::height).max()}else{None}.unwrap_or(0)
1113 }
1114 pub fn is_empty(&self)->bool{self.len()==0}
1116 pub fn is_incompatible(&self)->bool{
1118 if let Value::Incompatible(_x)=self{true}else{false}
1119 }
1120 pub fn into_multi(self)->Vec<Value<B>>{
1122 if let Value::Multi(v)=self{v}else{vec![self]}
1123 }
1124 pub fn is_multi(&self)->bool{
1126 if let Value::Multi(_v)=self{true}else{false}
1127 }
1128 pub fn iter(&self)->SliceIter<'_,Self>{
1130 if let Value::Multi(v)=self{v.iter()}else{slice::from_ref(self).iter()}
1131 }
1132 pub fn kind(&self)->Kind{
1134 match self{B1(_x)=>Kind::Bool,B2(_x)=>Kind::Bool,B3(_x)=>Kind::Bool,B4(_x)=>Kind::Bool,B5(_x)=>Kind::Bool,B6(_x)=>Kind::Bool,B7(_x)=>Kind::Bool,B8(_x)=>Kind::Bool,F1(_x)=>Kind::Float,F2(_x)=>Kind::Float,F3(_x)=>Kind::Float,F4(_x)=>Kind::Float,F5(_x)=>Kind::Float,F6(_x)=>Kind::Float,F7(_x)=>Kind::Float,F8(_x)=>Kind::Float,I1(_x)=>Kind::Int,I2(_x)=>Kind::Int,I3(_x)=>Kind::Int,I4(_x)=>Kind::Int,I5(_x)=>Kind::Int,I6(_x)=>Kind::Int,I7(_x)=>Kind::Int,I8(_x)=>Kind::Int,Value::Incompatible(_v)=>Kind::Incompatible,Value::Multi(_v)=>Kind::Multi}
1135 }
1136 pub fn len(&self)->usize{
1138 if let Value::Multi(v)=self{v.len()}else{1}
1139 }
1140 pub fn len_recursive(&self)->usize{
1142 if let Value::Multi(v)=self{v.iter().map(Value::len_recursive).sum()}else{1}
1143 }
/// Applies `f` to every sub-value sitting at exactly `depth` levels of
/// `Multi` nesting (depth 0 = leaves; a `Multi` of leaves has height 1,
/// and so on). Values at other heights are rebuilt unchanged.
pub fn map_multi<F:FnMut(Value<B>)->Value<B>>(self,depth:usize,mut f:F)->Self{
 // Returns the (possibly mapped) value together with its height so the
 // caller one level up can decide whether it sits at `depth`.
 fn recur<B:Backend,F:FnMut(Value<B>)->Value<B>>(depth:usize,f:&mut F,x:Value<B>)->(Value<B>,usize){
  if x.is_empty(){
   // An empty Multi is treated like a leaf: height 0.
   (if depth==0{f(x)}else{x},0)
  }else if x.is_multi(){
   // Map the children first, tracking the tallest child.
   let mut height=0;
   let y=x.into_iter().map(|x|{
    let (y,h)=recur(depth,f,x);
    height=height.max(h);
    y
   }).collect();
   // This Multi sits one level above its tallest child.
   height+=1;
   (if depth==height{f(y)}else{y},height)
  }else{
   // Leaf tensor (or Incompatible): height 0.
   (if depth==0{f(x)}else{x},0)
  }
 }
 recur(depth,&mut f,self).0
}
1164 pub fn map_values<F:FnMut(Value<B>)->Value<B>>(self,mut f:F)->Self{
1166 fn recur<B:Backend,F:FnMut(Value<B>)->Value<B>>(f:&mut F,x:Value<B>)->Value<B>{
1167 if let Value::Multi(v)=x{v.into_iter().map(|x|recur(f,x)).collect()}else{f(x)}
1168 }
1169 recur(&mut f,self)
1170 }
/// Fills the positions selected by `mask` with `v`.
///
/// The mask is coerced to bool and rank-promoted against `self`. Bool
/// data round-trips through int so the fill can be applied, then is
/// converted back to bool; float and int tensors use `mask_fill`
/// directly. `Incompatible` on either side propagates, and `Multi`
/// operands broadcast element-wise via `broadcast_multi`. The final
/// `panic!` arm covers kind/rank pairings that `promote_rank` should
/// have ruled out.
pub fn mask_fill(self,mask:Value<B>,v:f32)->Self{
 let (x,mask)=self.promote_rank(mask.bool());
 match (x,mask){
  (B1(x),B1(m))=>B1(x.int().mask_fill(m,v).bool()),
  (B2(x),B2(m))=>B2(x.int().mask_fill(m,v).bool()),
  (B3(x),B3(m))=>B3(x.int().mask_fill(m,v).bool()),
  (B4(x),B4(m))=>B4(x.int().mask_fill(m,v).bool()),
  (B5(x),B5(m))=>B5(x.int().mask_fill(m,v).bool()),
  (B6(x),B6(m))=>B6(x.int().mask_fill(m,v).bool()),
  (B7(x),B7(m))=>B7(x.int().mask_fill(m,v).bool()),
  (B8(x),B8(m))=>B8(x.int().mask_fill(m,v).bool()),
  (F1(x),B1(m))=>F1(x.mask_fill(m,v)),
  (F2(x),B2(m))=>F2(x.mask_fill(m,v)),
  (F3(x),B3(m))=>F3(x.mask_fill(m,v)),
  (F4(x),B4(m))=>F4(x.mask_fill(m,v)),
  (F5(x),B5(m))=>F5(x.mask_fill(m,v)),
  (F6(x),B6(m))=>F6(x.mask_fill(m,v)),
  (F7(x),B7(m))=>F7(x.mask_fill(m,v)),
  (F8(x),B8(m))=>F8(x.mask_fill(m,v)),
  (I1(x),B1(m))=>I1(x.mask_fill(m,v)),
  (I2(x),B2(m))=>I2(x.mask_fill(m,v)),
  (I3(x),B3(m))=>I3(x.mask_fill(m,v)),
  (I4(x),B4(m))=>I4(x.mask_fill(m,v)),
  (I5(x),B5(m))=>I5(x.mask_fill(m,v)),
  (I6(x),B6(m))=>I6(x.mask_fill(m,v)),
  (I7(x),B7(m))=>I7(x.mask_fill(m,v)),
  (I8(x),B8(m))=>I8(x.mask_fill(m,v)),
  (Value::Incompatible(e),_)=>e.into(),
  (_,Value::Incompatible(e))=>e.into(),
  (Value::Multi(x),m)=>broadcast_multi(x,m.into_multi(),|x,m|x.mask_fill(m,v)),
  (x,Value::Multi(m))=>broadcast_multi(x.into_multi(),m,|x,m|x.mask_fill(m,v)),
  _=>panic!("internal error")
 }
}
1206 pub fn multi(self)->Self{
1208 if let Value::Multi(v)=self{v.into()}else{vec![self].into()}
1209 }
1210 pub fn new<S:Into<Shape>>(data:&[f32],device:&B::Device,shape:S)->Self{
1212 match shape.into(){
1213 Shape::Incompatible(e)=>e.into(),
1214 Shape::Multi(l)=>data.chunks(l).map(|d|Value::new(d,device,X1([d.len()]))).collect(),
1215 Shape::Recursive(v)=>v.into_iter().scan(data,|data,s|{
1216 let v=Value::new(*data,device,s);
1217 *data=&data[..v.count()];
1218 Some(v)
1219 }).collect(),
1220 X1(s)=>F1(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1221 X2(s)=>F2(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1222 X3(s)=>F3(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1223 X4(s)=>F4(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1224 X5(s)=>F5(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1225 X6(s)=>F6(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1226 X7(s)=>F7(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1227 X8(s)=>F8(Tensor::from_data(TensorData::new(data[..s.iter().product::<usize>()].to_vec(),s),device)),
1228 }
1229 }
1230 pub fn promote(self,rhs:Value<B>)->(Value<B>,Value<B>){
1232 let (l,r)=self.promote_kind(rhs);
1233 l.promote_rank(r)
1234 }
1235 pub fn promote_kind(self,rhs:Value<B>)->(Value<B>,Value<B>){
1237 let (lk,rk)=(self.kind(),rhs.kind());
1238
1239 let (mut l,mut r)=(self,rhs);
1240 if lk==rk{()}else if lk==Kind::Multi{r=r.multi()}else if rk==Kind::Multi{l=l.multi()}else if lk==Kind::Float{r=r.float()}else if rk==Kind::Float{l=l.float()}else if lk==Kind::Int{r=r.int()}else if rk==Kind::Int{l=l.int()}else if lk==Kind::Incompatible{return (l,r)}else if rk==Kind::Incompatible{return (l,r)}
1241 (l,r)
1242 }
1243 pub fn promote_rank(self,rhs:Value<B>)->(Value<B>,Value<B>){
1245 let (mut l,mut r)=(self,rhs);
1246 let (mut lr,mut rr)=if let (Some(l),Some(r))=(l.rank(),r.rank()){(l,r)}else{return (l,r)};
1247 while lr<rr{
1248 l=l.unsqueeze();
1249 lr+=1;
1250 }
1251 while lr>rr{
1252 r=r.unsqueeze();
1253 rr+=1;
1254 }
1255 (l,r)
1256 }
1257 pub fn rank(&self)->Option<usize>{
1259 match self{B1(_x)=>Some(1),B2(_x)=>Some(2),B3(_x)=>Some(3),B4(_x)=>Some(4),B5(_x)=>Some(5),B6(_x)=>Some(6),B7(_x)=>Some(7),B8(_x)=>Some(8),F1(_x)=>Some(1),F2(_x)=>Some(2),F3(_x)=>Some(3),F4(_x)=>Some(4),F5(_x)=>Some(5),F6(_x)=>Some(6),F7(_x)=>Some(7),F8(_x)=>Some(8),I1(_x)=>Some(1),I2(_x)=>Some(2),I3(_x)=>Some(3),I4(_x)=>Some(4),I5(_x)=>Some(5),I6(_x)=>Some(6),I7(_x)=>Some(7),I8(_x)=>Some(8),Value::Incompatible(_x)=>None,Value::Multi(_x)=>None}
1260 }
/// Scatters `values` into `self` at `indices` along `dim` (negative dims
/// count from the end).
///
/// Requires `self`, `indices` and `values` to share the same rank, with
/// int indices. Bool data and values are converted to int first, and the
/// result stays an int value (there is no conversion back to bool).
/// `Incompatible` propagates from any operand; three `Multi` operands are
/// zipped element-wise. Any other combination yields an `Incompatible`
/// error message.
pub fn scatter(self,dim:i32,indices:Value<B>,values:Value<B>)->Self{
 // Normalizes a possibly-negative dim and applies tensor scatter, with a
 // bounds check on the dim.
 fn f<B:Backend,K:'static+BasicOps<B>+Numeric<B>+TensorKind<B>,const N:usize>(d:i32,data:Tensor<B,N,K>,indices:Tensor<B,N,Int>,values:Tensor<B,N,K>)->Value<B>{
  let d=if d<0{N-((-d) as usize)}else{d as usize};
  if d>=N{format!("dim {d} must be less than rank {N}").into()}else{data.scatter(d,indices,values).into()}
 }

 match (self,indices,values){
  (B1(x),I1(i),B1(y))=>f(dim,x.int(),i,y.int()),
  (B2(x),I2(i),B2(y))=>f(dim,x.int(),i,y.int()),
  (B3(x),I3(i),B3(y))=>f(dim,x.int(),i,y.int()),
  (B4(x),I4(i),B4(y))=>f(dim,x.int(),i,y.int()),
  (B5(x),I5(i),B5(y))=>f(dim,x.int(),i,y.int()),
  (B6(x),I6(i),B6(y))=>f(dim,x.int(),i,y.int()),
  (B7(x),I7(i),B7(y))=>f(dim,x.int(),i,y.int()),
  (B8(x),I8(i),B8(y))=>f(dim,x.int(),i,y.int()),
  (F1(x),I1(i),F1(y))=>f(dim,x,i,y),
  (F2(x),I2(i),F2(y))=>f(dim,x,i,y),
  (F3(x),I3(i),F3(y))=>f(dim,x,i,y),
  (F4(x),I4(i),F4(y))=>f(dim,x,i,y),
  (F5(x),I5(i),F5(y))=>f(dim,x,i,y),
  (F6(x),I6(i),F6(y))=>f(dim,x,i,y),
  (F7(x),I7(i),F7(y))=>f(dim,x,i,y),
  (F8(x),I8(i),F8(y))=>f(dim,x,i,y),
  (I1(x),I1(i),I1(y))=>f(dim,x,i,y),
  (I2(x),I2(i),I2(y))=>f(dim,x,i,y),
  (I3(x),I3(i),I3(y))=>f(dim,x,i,y),
  (I4(x),I4(i),I4(y))=>f(dim,x,i,y),
  (I5(x),I5(i),I5(y))=>f(dim,x,i,y),
  (I6(x),I6(i),I6(y))=>f(dim,x,i,y),
  (I7(x),I7(i),I7(y))=>f(dim,x,i,y),
  (I8(x),I8(i),I8(y))=>f(dim,x,i,y),
  (Value::Incompatible(e),_,_)=>e.into(),
  (_,Value::Incompatible(e),_)=>e.into(),
  (_,_,Value::Incompatible(e))=>e.into(),
  (Value::Multi(u),Value::Multi(v),Value::Multi(y))=>u.into_iter().zip(v).zip(y).map(|((u,v),y)|u.scatter(dim,v,y)).collect(),
  _=>"scatter is only available for tensors of matching dimensions with int indices".into()
 }
}
/// Shape of this value: `X1..X8` with the tensor's dims for tensor
/// variants, `Shape::Incompatible` carrying a clone of the error message,
/// and `Shape::Multi` carrying only the child count (use
/// `shape_recursive` for the full nested structure).
pub fn shape(&self)->Shape{
 match self{B1(x)=>Shape::X1(x.dims()),B2(x)=>Shape::X2(x.dims()),B3(x)=>Shape::X3(x.dims()),B4(x)=>Shape::X4(x.dims()),B5(x)=>Shape::X5(x.dims()),B6(x)=>Shape::X6(x.dims()),B7(x)=>Shape::X7(x.dims()),B8(x)=>Shape::X8(x.dims()),F1(x)=>Shape::X1(x.dims()),F2(x)=>Shape::X2(x.dims()),F3(x)=>Shape::X3(x.dims()),F4(x)=>Shape::X4(x.dims()),F5(x)=>Shape::X5(x.dims()),F6(x)=>Shape::X6(x.dims()),F7(x)=>Shape::X7(x.dims()),F8(x)=>Shape::X8(x.dims()),I1(x)=>Shape::X1(x.dims()),I2(x)=>Shape::X2(x.dims()),I3(x)=>Shape::X3(x.dims()),I4(x)=>Shape::X4(x.dims()),I5(x)=>Shape::X5(x.dims()),I6(x)=>Shape::X6(x.dims()),I7(x)=>Shape::X7(x.dims()),I8(x)=>Shape::X8(x.dims()),Value::Incompatible(x)=>Shape::Incompatible(x.clone()),Value::Multi(x)=>Shape::Multi(x.len())}
}
1305 pub fn shape_recursive(&self)->Shape{
1307 if let Value::Multi(x)=self{Shape::Recursive(x.iter().map(Value::shape_recursive).collect())}else{self.shape()}
1308 }
/// Shifts the tensor by `n` positions along dimension `d` (negative `d`
/// counts from the end), filling the vacated slots with `v`.
///
/// Positive `n` shifts toward higher indices (padding is prepended);
/// negative `n` shifts toward lower indices (padding is appended). If the
/// shift covers the whole dimension the result is entirely filled with
/// `v`. Bool tensors round-trip through int, with `v` snapped to 0.0/1.0.
/// `n == 0` is a no-op; `Multi` recurses per child.
pub fn shift(self,d:i32,n:i32,v:f32)->Self{
 // Bool wrapper: shift as int with a 0/1 fill, then convert back.
 fn b<B:Backend,const N:usize>(d:i32,n:i32,v:f32,x:Tensor<B,N,Bool>)->Value<B>{f(d,n,if v==0.0{0.0}else{1.0},x.int()).bool()}
 fn f<B:Backend,K:'static+BasicOps<B>+Numeric<B>+TensorKind<B>,const N:usize>(d:i32,n:i32,v:f32,x:Tensor<B,N,K>)->Value<B>{
  let device=x.device();
  let d=if d<0{N-((-d) as usize)}else{d as usize};
  let mut paddims=x.dims();
  let mut slicedims=paddims.map(|n|0..n);

  // Pad block is |n| wide along d; the kept slice is the rest of the dim.
  paddims[d]=n.abs() as usize;
  slicedims[d]=if n<0{(-n) as usize..slicedims[d].end}else{0..slicedims[d].end.saturating_sub(n as usize)};
  // Shift covers the entire dimension: nothing survives, fill with v.
  if slicedims[d].len()==0{return x.full_like(v).into()}
  let pad:Tensor<B,N,K>=Tensor::full(paddims,v,&device);
  let slice=x.slice(slicedims);

  // Negative n: kept data first, pad after; positive n: pad first.
  Tensor::cat(if n<0{vec![slice,pad]}else{vec![pad,slice]},d).into()
 }
 if n==0{return self}

 match self{B1(x)=>b(d,n,v,x),B2(x)=>b(d,n,v,x),B3(x)=>b(d,n,v,x),B4(x)=>b(d,n,v,x),B5(x)=>b(d,n,v,x),B6(x)=>b(d,n,v,x),B7(x)=>b(d,n,v,x),B8(x)=>b(d,n,v,x),F1(x)=>f(d,n,v,x),F2(x)=>f(d,n,v,x),F3(x)=>f(d,n,v,x),F4(x)=>f(d,n,v,x),F5(x)=>f(d,n,v,x),F6(x)=>f(d,n,v,x),F7(x)=>f(d,n,v,x),F8(x)=>f(d,n,v,x),I1(x)=>f(d,n,v,x),I2(x)=>f(d,n,v,x),I3(x)=>f(d,n,v,x),I4(x)=>f(d,n,v,x),I5(x)=>f(d,n,v,x),I6(x)=>f(d,n,v,x),I7(x)=>f(d,n,v,x),I8(x)=>f(d,n,v,x),Value::Incompatible(e)=>e.into(),Value::Multi(x)=>x.into_iter().map(|x|x.shift(d,n,v)).collect()}
}
1330 pub fn slice<A:AsRef<[R]>,R:RangeBounds<usize>>(self,ranges:A)->Self{let ranges=ranges.as_ref();
1333 let len=ranges.len();
1334 if let Value::Incompatible(x)=self{return x.into()}
1335 let rank=self.rank().unwrap_or(len);
1336 let shape=self.shape();
1337
1338 let mut normalizedranges=[0;8].map(|_|0..0);
1339 for ((d,n),r) in shape.clone().to_array(Alignment::Left).into_iter().zip(normalizedranges.iter_mut()).zip(ranges){
1340 n.start=match r.start_bound(){Excluded(&x)=>x+1,Included(&x)=>x,Unbounded=>0};
1341 n.end=match r.end_bound(){Excluded(&x)=>x,Included(&x)=>x+1,Unbounded=>d};
1342 }
1343 if len>rank{return format!("Length of ranges argument must be less than the the value's rank. len: {len} ranges: {normalizedranges:?} rank: {rank} shape: {shape:?}").into()}
1344 for (d,n) in shape.clone().to_array(Alignment::Left).into_iter().zip(normalizedranges.iter()).take(len){
1345 if n.start>=n.end{return format!("Empty or reverse ranges are currently not supported. ranges: {normalizedranges:?}").into()}
1346 if d<n.end{return format!("Cannot index beyond the end of a dimension. ranges: {normalizedranges:?} shape: {shape:?}").into()}
1347 }
1348 let ranges=&normalizedranges[..len];
1349
1350 match self{B1(x)=>B1(slice_slice(ranges,x)),B2(x)=>B2(slice_slice(ranges,x)),B3(x)=>B3(slice_slice(ranges,x)),B4(x)=>B4(slice_slice(ranges,x)),B5(x)=>B5(slice_slice(ranges,x)),B6(x)=>B6(slice_slice(ranges,x)),B7(x)=>B7(slice_slice(ranges,x)),B8(x)=>B8(slice_slice(ranges,x)),F1(x)=>F1(slice_slice(ranges,x)),F2(x)=>F2(slice_slice(ranges,x)),F3(x)=>F3(slice_slice(ranges,x)),F4(x)=>F4(slice_slice(ranges,x)),F5(x)=>F5(slice_slice(ranges,x)),F6(x)=>F6(slice_slice(ranges,x)),F7(x)=>F7(slice_slice(ranges,x)),F8(x)=>F8(slice_slice(ranges,x)),I1(x)=>I1(slice_slice(ranges,x)),I2(x)=>I2(slice_slice(ranges,x)),I3(x)=>I3(slice_slice(ranges,x)),I4(x)=>I4(slice_slice(ranges,x)),I5(x)=>I5(slice_slice(ranges,x)),I6(x)=>I6(slice_slice(ranges,x)),I7(x)=>I7(slice_slice(ranges,x)),I8(x)=>I8(slice_slice(ranges,x)),Value::Incompatible(x)=>x.into(),Value::Multi(x)=>Value::Multi(x.into_iter().map(|x|x.slice(ranges)).collect())}
1351 }
/// Splits this value into chunks of `chunksize`.
///
/// With `dim: Some(d)` each tensor is split along that dimension
/// (negative dims count from the end) into a `Multi` of chunks, and
/// existing `Multi` values recurse per child. With `dim: None` the value
/// is first flattened to its direct children via `into_multi`, then those
/// children are grouped into `Multi`s of `chunksize` each.
/// A `chunksize` of 0 yields an `Incompatible` error.
pub fn split<I:Into<Option<i32>>>(self,chunksize:usize,dim:I)->Self{
 // Validates the dim against the tensor's rank, then delegates to the
 // tensor split, collecting the chunks into a Multi.
 fn f<B:Backend,K:'static+BasicOps<B>+TensorKind<B>,const N:usize>(dim:i32,size:usize,tensor:Tensor<B,N,K>)->Value<B>{
  if dim>=N as i32||dim<(-(N as i32)){return format!("rank {N} is too low to split along dimension {dim}").into()}
  let dim=if dim<0{N-((-dim) as usize)}else{dim as usize};

  tensor.split(size,dim).into_iter().map(Value::from).collect()
 }
 let c=if chunksize==0{return "cannot split into chunks of 0 size".into()}else{chunksize};

 if let Some(d)=dim.into(){
  match self{B1(x)=>f(d,c,x),B2(x)=>f(d,c,x),B3(x)=>f(d,c,x),B4(x)=>f(d,c,x),B5(x)=>f(d,c,x),B6(x)=>f(d,c,x),B7(x)=>f(d,c,x),B8(x)=>f(d,c,x),F1(x)=>f(d,c,x),F2(x)=>f(d,c,x),F3(x)=>f(d,c,x),F4(x)=>f(d,c,x),F5(x)=>f(d,c,x),F6(x)=>f(d,c,x),F7(x)=>f(d,c,x),F8(x)=>f(d,c,x),I1(x)=>f(d,c,x),I2(x)=>f(d,c,x),I3(x)=>f(d,c,x),I4(x)=>f(d,c,x),I5(x)=>f(d,c,x),I6(x)=>f(d,c,x),I7(x)=>f(d,c,x),I8(x)=>f(d,c,x),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.split(c,Some(d))).collect()}
 }else{
  let v=self.into_multi();
  v.chunks(chunksize).map(|c|Value::from(c.to_vec())).collect()
 }
}
/// Removes the leading dimension; shorthand for `squeeze_dim(0)`.
pub fn squeeze(self)->Self{self.squeeze_dim(0)}
/// Removes dimension `d` (negative dims count from the end), reducing the
/// tensor's rank by one.
///
/// Fails with an `Incompatible` value when the dim is out of range, when
/// its size is not exactly 1, or when the tensor is already rank 1 (rank
/// 0 values are not representable). `Multi` recurses per child.
pub fn squeeze_dim(self,d:i32)->Self{
 // Rank-generic squeeze: D is the input rank, N the output rank (D-1 at
 // each call site). Validates the dim and its size before squeezing.
 fn f<B:Backend,K:BasicOps<B>+TensorKind<B>,const D:usize,const N:usize>(x:Tensor<B,D,K>,d:i32)->Result<Tensor<B,N,K>,String>{
  let d=if d<0{D-((-d) as usize)}else{d as usize};

  if d>=D{return Err(format!("dim {d} must be less than {D}"))}
  let xdim=x.dims()[d];

  if xdim==1{Ok(x.squeeze(d))}else{Err(format!("cannot squeeze a dim of size not equal to 1. dim {d} was {xdim}"))}
 }
 // Outer match turns the per-variant Result into a Value: Ok passes
 // through, Err becomes Incompatible.
 match match self{B1(_x)=>Err("currently cannot decrease the number of tensor dimensions below 1".into()),B2(x)=>f(x,d).map(B1),B3(x)=>f(x,d).map(B2),B4(x)=>f(x,d).map(B3),B5(x)=>f(x,d).map(B4),B6(x)=>f(x,d).map(B5),B7(x)=>f(x,d).map(B6),B8(x)=>f(x,d).map(B7),F1(_x)=>Err("currently cannot decrease the number of tensor dimensions below 1".into()),F2(x)=>f(x,d).map(F1),F3(x)=>f(x,d).map(F2),F4(x)=>f(x,d).map(F3),F5(x)=>f(x,d).map(F4),F6(x)=>f(x,d).map(F5),F7(x)=>f(x,d).map(F6),F8(x)=>f(x,d).map(F7),I1(_x)=>Err("currently cannot decrease the number of tensor dimensions below 1".into()),I2(x)=>f(x,d).map(I1),I3(x)=>f(x,d).map(I2),I4(x)=>f(x,d).map(I3),I5(x)=>f(x,d).map(I4),I6(x)=>f(x,d).map(I5),I7(x)=>f(x,d).map(I6),I8(x)=>f(x,d).map(I7),Value::Incompatible(e)=>Err(e),Value::Multi(v)=>Ok(v.into_iter().map(|x|x.squeeze_dim(d)).collect())}{Err(e)=>e.into(),Ok(x)=>x}
}
1383 pub fn try_incompatible(self)->Result<String,Self>{
1385 if let Value::Incompatible(x)=self{Ok(x)}else{Err(self)}
1386 }
1387 pub fn try_multi(self)->Result<Vec<Value<B>>,Self>{
1389 if let Value::Multi(v)=self{Ok(v)}else{Err(self)}
1390 }
/// Prepends a size-1 dimension; shorthand for `unsqueeze_dim(0)`.
pub fn unsqueeze(self)->Self{self.unsqueeze_dim(0)}
/// Inserts a size-1 dimension at `d` (negative dims insert counting from
/// the end), raising the tensor's rank by one.
///
/// Fails with an `Incompatible` value when the resulting dim exceeds the
/// rank or the tensor is already rank 8. `Multi` recurses per child;
/// `Incompatible` propagates.
/// NOTE(review): for a strongly negative `d` the usize subtraction
/// `r-((-d) as usize)+1` can underflow before the bounds check — confirm
/// callers constrain `d` to `-rank..=rank`.
pub fn unsqueeze_dim(self,d:i32)->Self{
 // Rank-generic unsqueeze: D is the input rank, N the output rank (D+1).
 fn f<B:Backend,K:BasicOps<B>+TensorKind<B>,const D:usize,const N:usize>(x:Tensor<B,D,K>,d:i32)->Tensor<B,N,K>{
  x.unsqueeze_dim(if d<0{D-((-d) as usize)+1}else{d as usize})
 }
 // Pre-validate the normalized dim against the rank when one is known.
 if let Some(r)=self.rank(){
  let e=if d<0{r-((-d) as usize)+1}else{d as usize};
  if e>r{return format!("dim {e} must be less than or equal to rank {r}").into()}
 }
 match self{B1(x)=>B2(f(x,d)),B2(x)=>B3(f(x,d)),B3(x)=>B4(f(x,d)),B4(x)=>B5(f(x,d)),B5(x)=>B6(f(x,d)),B6(x)=>B7(f(x,d)),B7(x)=>B8(f(x,d)),B8(_x)=>"currently can't increase number of tensor dimensions above 8".into(),F1(x)=>F2(f(x,d)),F2(x)=>F3(f(x,d)),F3(x)=>F4(f(x,d)),F4(x)=>F5(f(x,d)),F5(x)=>F6(f(x,d)),F6(x)=>F7(f(x,d)),F7(x)=>F8(f(x,d)),F8(_x)=>"currently can't increase number of tensor dimensions above 8".into(),I1(x)=>I2(f(x,d)),I2(x)=>I3(f(x,d)),I3(x)=>I4(f(x,d)),I4(x)=>I5(f(x,d)),I5(x)=>I6(f(x,d)),I6(x)=>I7(f(x,d)),I7(x)=>I8(f(x,d)),I8(_x)=>"currently can't increase number of tensor dimensions above 8".into(),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.unsqueeze_dim(d)).collect()}
}
/// Returns the `Incompatible` error message, panicking (at the caller's
/// location, via `#[track_caller]`) when this is any other variant.
#[track_caller]
pub fn unwrap_incompatible(self)->String{self.try_incompatible().unwrap()}
/// Returns the `Multi` children, panicking (at the caller's location,
/// via `#[track_caller]`) when this is any other variant.
#[track_caller]
pub fn unwrap_multi(self)->Vec<Value<B>>{self.try_multi().unwrap()}
/// A value of the same shape, kind and nesting filled with zeros (`false`
/// for bool, which round-trips through int since bool tensors have no
/// `zeros_like`). `Incompatible` propagates; `Multi` recurses per child.
pub fn zeros_like(&self)->Value<B>{match self{B1(x)=>B1(x.clone().int().zeros_like().bool()),B2(x)=>B2(x.clone().int().zeros_like().bool()),B3(x)=>B3(x.clone().int().zeros_like().bool()),B4(x)=>B4(x.clone().int().zeros_like().bool()),B5(x)=>B5(x.clone().int().zeros_like().bool()),B6(x)=>B6(x.clone().int().zeros_like().bool()),B7(x)=>B7(x.clone().int().zeros_like().bool()),B8(x)=>B8(x.clone().int().zeros_like().bool()),F1(x)=>F1(x.zeros_like()),F2(x)=>F2(x.zeros_like()),F3(x)=>F3(x.zeros_like()),F4(x)=>F4(x.zeros_like()),F5(x)=>F5(x.zeros_like()),F6(x)=>F6(x.zeros_like()),F7(x)=>F7(x.zeros_like()),F8(x)=>F8(x.zeros_like()),I1(x)=>I1(x.zeros_like()),I2(x)=>I2(x.zeros_like()),I3(x)=>I3(x.zeros_like()),I4(x)=>I4(x.zeros_like()),I5(x)=>I5(x.zeros_like()),I6(x)=>I6(x.zeros_like()),I7(x)=>I7(x.zeros_like()),I8(x)=>I8(x.zeros_like()),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.iter().map(Value::zeros_like).collect()}
}
1414 pub fn zip(self)->Self{fn build<B:Backend>(input:Vec<Value<B>>)->Value<B>{
1417 if input.len()==0{return input.into()}
1418 if !input.iter().all(Value::is_multi){return input.into()}
1419
1420 let len=input.iter().map(Value::len).max().unwrap();
1421 let mut iters:Vec<_>=input.into_iter().map(Value::into_iter).collect();
1422 (0..len).map(|_|build(iters.iter_mut().map(|v|v.next().unwrap_or_default()).collect())).collect()
1423 }
1424 match self{Value::Multi(v)=>build(v),x=>x}
1425 }
// Generated accessors, one pair per rank/kind variant: `try_*` returns
// `Result<tensor, Self>` via `TryInto`, and `unwrap_*` panics at the
// caller's location when the variant does not match.
try_unwrap!(Tensor<B,1,Bool>,try_b1,unwrap_b1);
try_unwrap!(Tensor<B,2,Bool>,try_b2,unwrap_b2);
try_unwrap!(Tensor<B,3,Bool>,try_b3,unwrap_b3);
try_unwrap!(Tensor<B,4,Bool>,try_b4,unwrap_b4);
try_unwrap!(Tensor<B,5,Bool>,try_b5,unwrap_b5);
try_unwrap!(Tensor<B,6,Bool>,try_b6,unwrap_b6);
try_unwrap!(Tensor<B,7,Bool>,try_b7,unwrap_b7);
try_unwrap!(Tensor<B,8,Bool>,try_b8,unwrap_b8);
try_unwrap!(Tensor<B,1,Float>,try_f1,unwrap_f1);
try_unwrap!(Tensor<B,2,Float>,try_f2,unwrap_f2);
try_unwrap!(Tensor<B,3,Float>,try_f3,unwrap_f3);
try_unwrap!(Tensor<B,4,Float>,try_f4,unwrap_f4);
try_unwrap!(Tensor<B,5,Float>,try_f5,unwrap_f5);
try_unwrap!(Tensor<B,6,Float>,try_f6,unwrap_f6);
try_unwrap!(Tensor<B,7,Float>,try_f7,unwrap_f7);
try_unwrap!(Tensor<B,8,Float>,try_f8,unwrap_f8);
try_unwrap!(Tensor<B,1,Int>,try_i1,unwrap_i1);
try_unwrap!(Tensor<B,2,Int>,try_i2,unwrap_i2);
try_unwrap!(Tensor<B,3,Int>,try_i3,unwrap_i3);
try_unwrap!(Tensor<B,4,Int>,try_i4,unwrap_i4);
try_unwrap!(Tensor<B,5,Int>,try_i5,unwrap_i5);
try_unwrap!(Tensor<B,6,Int>,try_i6,unwrap_i6);
try_unwrap!(Tensor<B,7,Int>,try_i7,unwrap_i7);
try_unwrap!(Tensor<B,8,Int>,try_i8,unwrap_i8);
1450}
/// Implements a numeric binary operator (Add/Div/Mul/Rem/Sub) for
/// `Value<B>` in all four owned/borrowed combinations, plus a scalar RHS
/// via the tensor `*_scalar` method.
///
/// Bool operands are promoted to int (so e.g. bool + bool yields an int
/// value); `Incompatible` propagates; `Multi` operands broadcast through
/// `broadcast_multi`. The value-value impl promotes kind and rank first;
/// pairs that still mismatch hit the final `panic!` arm.
macro_rules! bicop_num{
 ($trait:ident,$traitfn:ident,$traitscalar:ident)=>(
  // &Value op scalar: delegates to the owned impl via clone.
  impl<B:Backend,E:Copy+ElementConversion> $trait<E> for &Value<B>{
   fn $traitfn(self,rhs:E)->Value<B>{self.clone().$traitfn(rhs)}
   type Output=Value<B>;
  }
  // Value op scalar: per-variant dispatch to the tensor scalar method.
  impl<B:Backend,E:Copy+ElementConversion> $trait<E> for Value<B>{
   fn $traitfn(self,rhs:E)->Value<B>{
    match self{B1(x)=>I1(x.int().$traitscalar(rhs)),B2(x)=>I2(x.int().$traitscalar(rhs)),B3(x)=>I3(x.int().$traitscalar(rhs)),B4(x)=>I4(x.int().$traitscalar(rhs)),B5(x)=>I5(x.int().$traitscalar(rhs)),B6(x)=>I6(x.int().$traitscalar(rhs)),B7(x)=>I7(x.int().$traitscalar(rhs)),B8(x)=>I8(x.int().$traitscalar(rhs)),F1(x)=>F1(x.$traitscalar(rhs)),F2(x)=>F2(x.$traitscalar(rhs)),F3(x)=>F3(x.$traitscalar(rhs)),F4(x)=>F4(x.$traitscalar(rhs)),F5(x)=>F5(x.$traitscalar(rhs)),F6(x)=>F6(x.$traitscalar(rhs)),F7(x)=>F7(x.$traitscalar(rhs)),F8(x)=>F8(x.$traitscalar(rhs)),I1(x)=>I1(x.$traitscalar(rhs)),I2(x)=>I2(x.$traitscalar(rhs)),I3(x)=>I3(x.$traitscalar(rhs)),I4(x)=>I4(x.$traitscalar(rhs)),I5(x)=>I5(x.$traitscalar(rhs)),I6(x)=>I6(x.$traitscalar(rhs)),I7(x)=>I7(x.$traitscalar(rhs)),I8(x)=>I8(x.$traitscalar(rhs)),Value::Incompatible(e)=>e.into(),Value::Multi(v)=>v.into_iter().map(|x|x.$traitfn(rhs)).collect()}
   }
   type Output=Value<B>;
  }
  // &Value op &Value: delegates via clones.
  impl<B:Backend> $trait<&Value<B>> for &Value<B>{
   fn $traitfn(self,rhs:&Value<B>)->Value<B>{self.clone().$traitfn(rhs.clone())}
   type Output=Value<B>;
  }
  // Value op &Value: clones the RHS only.
  impl<B:Backend> $trait<&Value<B>> for Value<B>{
   fn $traitfn(self,rhs:&Value<B>)->Value<B>{self.$traitfn(rhs.clone())}
   type Output=Value<B>;
  }
  // &Value op Value: clones the LHS only.
  impl<B:Backend> $trait<Value<B>> for &Value<B>{
   fn $traitfn(self,rhs:Value<B>)->Value<B>{self.clone().$traitfn(rhs)}
   type Output=Value<B>;
  }
  // Value op Value: promote kind and rank, then per-variant dispatch.
  impl<B:Backend> $trait<Value<B>> for Value<B>{
   fn $traitfn(self,rhs:Value<B>)->Value<B>{match self.promote(rhs){(B1(l),B1(r))=>I1(l.int().$traitfn(r.int())),(B2(l),B2(r))=>I2(l.int().$traitfn(r.int())),(B3(l),B3(r))=>I3(l.int().$traitfn(r.int())),(B4(l),B4(r))=>I4(l.int().$traitfn(r.int())),(B5(l),B5(r))=>I5(l.int().$traitfn(r.int())),(B6(l),B6(r))=>I6(l.int().$traitfn(r.int())),(B7(l),B7(r))=>I7(l.int().$traitfn(r.int())),(B8(l),B8(r))=>I8(l.int().$traitfn(r.int())),(F1(l),F1(r))=>F1(l.$traitfn(r)),(F2(l),F2(r))=>F2(l.$traitfn(r)),(F3(l),F3(r))=>F3(l.$traitfn(r)),(F4(l),F4(r))=>F4(l.$traitfn(r)),(F5(l),F5(r))=>F5(l.$traitfn(r)),(F6(l),F6(r))=>F6(l.$traitfn(r)),(F7(l),F7(r))=>F7(l.$traitfn(r)),(F8(l),F8(r))=>F8(l.$traitfn(r)),(I1(l),I1(r))=>I1(l.$traitfn(r)),(I2(l),I2(r))=>I2(l.$traitfn(r)),(I3(l),I3(r))=>I3(l.$traitfn(r)),(I4(l),I4(r))=>I4(l.$traitfn(r)),(I5(l),I5(r))=>I5(l.$traitfn(r)),(I6(l),I6(r))=>I6(l.$traitfn(r)),(I7(l),I7(r))=>I7(l.$traitfn(r)),(I8(l),I8(r))=>I8(l.$traitfn(r)),(Value::Incompatible(e),_)=>e.into(),(_,Value::Incompatible(e))=>e.into(),(Value::Multi(l),r)=>broadcast_multi(l,r.into_multi(),$trait::$traitfn),(l,Value::Multi(r))=>broadcast_multi(l.into_multi(),r,$trait::$traitfn),_=>panic!("couldn't promote types for $traitfn")}
   }
   type Output=Value<B>;
  }
 );
}
/// Generates a `try_*` / `unwrap_*` accessor pair for one concrete tensor
/// variant: `try_*` converts via the `TryInto` impl (returning the value
/// back on mismatch), `unwrap_*` panics at the caller's location.
macro_rules! try_unwrap{
 ($tensor:ty,$try_unwrap:ident,$unwrap:ident)=>{
  pub fn $try_unwrap(self)->Result<$tensor,Self>{self.try_into()}
  #[track_caller]
  pub fn $unwrap(self)->$tensor{self.try_into().unwrap()}
 }
}
/// A dynamically-typed tensor value: one variant per kind (B = bool,
/// F = float, I = int) and rank (1..=8), plus `Incompatible` carrying an
/// error message and `Multi` for an ordered, possibly nested list of
/// values.
#[derive(Clone,Debug)]
pub enum Value<B:Backend>{B1(Tensor<B,1,Bool>),B2(Tensor<B,2,Bool>),B3(Tensor<B,3,Bool>),B4(Tensor<B,4,Bool>),B5(Tensor<B,5,Bool>),B6(Tensor<B,6,Bool>),B7(Tensor<B,7,Bool>),B8(Tensor<B,8,Bool>),F1(Tensor<B,1,Float>),F2(Tensor<B,2,Float>),F3(Tensor<B,3,Float>),F4(Tensor<B,4,Float>),F5(Tensor<B,5,Float>),F6(Tensor<B,6,Float>),F7(Tensor<B,7,Float>),F8(Tensor<B,8,Float>),I1(Tensor<B,1,Int>),I2(Tensor<B,2,Int>),I3(Tensor<B,3,Int>),I4(Tensor<B,4,Int>),I5(Tensor<B,5,Int>),I6(Tensor<B,6,Int>),I7(Tensor<B,7,Int>),I8(Tensor<B,8,Int>),Incompatible(String),Multi(Vec<Self>)}
/// Serializable, backend-independent counterpart of `Value`: raw
/// `TensorData` tagged with kind only (rank lives inside the data), plus
/// the `Incompatible` and `Multi` structural variants.
#[derive(Clone,Debug,Deserialize,Serialize)]
pub enum ValueData{BX(TensorData),FX(TensorData),IX(TensorData),Incompatible(String),Multi(Vec<ValueData>)}
/// Bundle of loss, model output and target values produced by a training
/// step. `#[serde(bound="")]` drops the inferred `B: (De)Serialize`
/// bounds so the struct (de)serializes for any backend.
#[derive(Clone,Debug,Deserialize,Serialize)]
#[serde(bound="")]
pub struct LossOutput<B:Backend>{loss:Value<B>,output:Value<B>,target:Value<B>}
1502use {bicop_num,try_unwrap};
1503use Bound::{Excluded,Included,Unbounded};
1504use Reshape::{R1,R2,R3,R4,R5,R6,R7,R8};
1505use Shape::{X1,X2,X3,X4,X5,X6,X7,X8};
1506use Value::{B1,B2,B3,B4,B5,B6,B7,B8,F1,F2,F3,F4,F5,F6,F7,F8,I1,I2,I3,I4,I5,I6,I7,I8};
1507use ValueData::{BX,FX,IX};
1508use burn::{
1509 module::{AutodiffModule,ConstantRecord,Content,DisplaySettings,ModuleDisplay,ModuleDisplayDefault,ModuleMapper,ModuleVisitor,Quantizer},
1510 nn::{
1511 BatchNorm,Dropout,Embedding,LayerNorm,Linear,Relu,RotaryEncoding,Tanh,conv::Conv2d,loss::{CrossEntropyLoss,MseLoss},pool::MaxPool2d
1512 },
1513 prelude::{Backend,Bool,Float,Int,Module,Tensor,TensorData},
1514 record::{FileRecorder,RecorderError},
1515 tensor::{
1516 BasicOps,ElementConversion,Numeric,TensorKind,activation::{log_softmax,softmax},backend::AutodiffBackend,cast::ToElement
1517 }
1518};
1519use crate::{
1520 AI,Decompose,Merge,Op,
1521 builtin::{
1522 Alignment,ReductionMode,math::{MeanLayer,SquaredErrorLayer,SumLayer},reinforcement::AccQLayer,soft::{ChooseLayer,CrossEntropyLayer,SoftmaxLayer}
1523 },
1524 ops::{Abs,Cat,Flatten,Reshape as OpsReshape,Stack,Squeeze,Unsqueeze}
1525};
1526use rand::random;
1527use serde::{Deserialize,Deserializer,Serialize,Serializer};
1528use std::{
1529 any::TypeId,fmt::{Display,Result as FmtResult},iter::{FromIterator,once},mem,ops::{Add,Bound,Div,Mul,RangeBounds,Range,Rem,Sub},path::PathBuf,slice::{Iter as SliceIter,self},vec::IntoIter as VecIntoIter
1530};
1531use super::{Kind,Reshape,Shape,apply_depthwise};