block_graph/burn/
layer.rs

1fn derror<D:Display,E:Derror>(msg:D)->E{E::custom(msg)}
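// note on the helpers below: several burn modules (BatchNorm, Conv2d, Linear, ...) don't implement
// serde's Serialize/Deserialize themselves, so each one round-trips through a small record struct
// whose tensors are stored as Value<B>; the #[serde(serialize_with/deserialize_with)] attributes on
// the Layer enum route through these functions. Layers with no tensor state go through
// deserialize_nothing/serialize_nothing, and plain config-like fields are wrapped in Ignored.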
2fn deserialize_batch_norm<'a,B:Backend,D:Deserializer<'a>>(deserializer:D)->Result<BatchNorm<B>,D::Error>{
3	let record:BatchNormRecord<B>=BatchNormRecord::deserialize(deserializer)?;
4
5	let (beta,epsilon,gamma,mean,momentum,variance)=(record.beta,record.epsilon,record.gamma,record.mean,record.momentum,record.variance);
6	let (beta,gamma)=if let (Ok(b),Ok(g))=(beta.try_into(),gamma.try_into()){(Param::from_tensor(b),Param::from_tensor(g))}else{return Err(derror("batch norm beta and gamma parameters must be rank 1 floats"))};
7	let (mean,variance)=if let (Ok(m),Ok(v))=(mean.try_into(),variance.try_into()){(RunningState::new(m),RunningState::new(v))}else{return Err(derror("batch norm mean and variance states must be rank 1 floats"))};
8
9	Ok(BatchNorm{beta,epsilon,gamma,momentum,running_mean:mean,running_var:variance})
10}
11fn deserialize_conv2d<'a,B:Backend,D:Deserializer<'a>>(deserializer:D)->Result<Conv2d<B>,D::Error>{
12	let record=Conv2dRecord::deserialize(deserializer)?;
13
14	let (dilation,groups,kernelsize,stride)=(record.dilation,record.groups,record.kernelsize,record.stride);
15	let bias=if let Some(b)=record.bias{
16		if let Ok(b)=b.try_into(){Some(Param::from_tensor(b))}else{return Err(derror("conv2d bias parameter must be a rank 1 float"))}
17	}else{
18		None
19	};
20	let padding=record.padding.clone();
21	let weight=Param::from_tensor(if let Ok(w)=record.weight.try_into(){w}else{return Err(derror("conv2d weight parameter must be a rank 4 float"))});
22
23	Ok(Conv2d{bias,dilation,groups,kernel_size:kernelsize,padding,stride,weight})
24}
25fn deserialize_cross_entropy<'a,B:Backend,D:Deserializer<'a>>(deserializer:D)->Result<CrossEntropyLoss<B>,D::Error>{
26	let record=CrossEntropyRecord::deserialize(deserializer)?;
27
28	let (logits,pad,smoothing)=(record.logits,record.pad,record.smoothing);
29	let weights=if let Some(s)=record.weights{
30		if let Ok(s)=s.try_into(){Some(s)}else{return Err(derror("cross entropy weights parameter must be a rank 1 float"))}
31	}else{
32		None
33	};
34
35	Ok(CrossEntropyLoss{logits,pad_tokens:pad,smoothing,weights})
36}
37fn deserialize_dropout<'a,D:Deserializer<'a>>(deserializer:D)->Result<Dropout,D::Error>{
38	Ok(Dropout{prob:f64::deserialize(deserializer)?})
39}
40fn deserialize_embedding<'a,B:Backend,D:Deserializer<'a>>(deserializer:D)->Result<Embedding<B>,D::Error>{
41	let weight=deserialize_param(deserializer)?;
42	Ok(Embedding{weight})
43}
44fn deserialize_ignored<'a,D:Deserializer<'a>,T:Deserialize<'a>>(deserializer:D)->Result<Ignored<T>,D::Error>{
45	let data:T=T::deserialize(deserializer)?;
46	Ok(Ignored(data))
47}
48fn deserialize_layer_norm<'a,B:Backend,D:Deserializer<'a>>(deserializer:D)->Result<LayerNorm<B>,D::Error>{
49	let mut layer=LayerNormConfig::new(1).init(&Default::default());
50	let record=LayerNormRecord::deserialize(deserializer)?;
51
52	if let Ok(b)=record.beta.try_into(){layer.beta=Param::from_tensor(b)}else{return Err(derror("layer norm beta parameter must be a rank 1 float"))}
53	if let Ok(g)=record.gamma.try_into(){layer.gamma=Param::from_tensor(g)}else{return Err(derror("layer norm gamma parameter must be a rank 1 float"))}
54
55	Ok(layer)
56}
57fn deserialize_linear<'a,B:Backend,D:Deserializer<'a>>(deserializer:D)->Result<Linear<B>,D::Error>{
58	let record=LinearRecord::deserialize(deserializer)?;
59
60	let bias=if let Some(b)=record.bias{
61		if let Ok(b)=b.try_into(){Some(Param::from_tensor(b))}else{return Err(derror("linear bias parameter must be a rank 1 float"))}
62	}else{
63		None
64	};
65	let weight=Param::from_tensor(if let Ok(w)=record.weight.try_into(){w}else{return Err(derror("linear weight parameter must be a rank 2 float"))});
66
67	Ok(Linear{bias,weight})
68}
69fn deserialize_max_pool_2d<'a,D:Deserializer<'a>>(deserializer:D)->Result<MaxPool2d,D::Error>{
70	let config=MaxPool2dConfig::deserialize(deserializer)?;
71	Ok(config.init())
72}
73fn deserialize_nothing<'a,D:Deserializer<'a>,T:Default>(deserializer:D)->Result<T,D::Error>{
74	let _x:()=Deserialize::deserialize(deserializer)?;
75	Ok(T::default())
76}
77fn deserialize_param<'a,B:Backend,D:Deserializer<'a>,const N:usize>(deserializer:D)->Result<Param<Tensor<B,N>>,D::Error>{
78	let data:Value<B>=Value::deserialize(deserializer)?;
79	if let Ok(t)=data.try_into(){Ok(Param::from_tensor(t))}else{Err(derror(format!("expected parameter to be a rank {N} float")))}
80}
81fn deserialize_rotary<'a,B:Backend,D:Deserializer<'a>>(deserializer:D)->Result<RotaryEncoding<B>,D::Error>{Ok(RotaryEncodingConfig::deserialize(deserializer)?.init(&Default::default()))}
82fn serialize_batch_norm<B:Backend,S:Serializer>(layer:&BatchNorm<B>,serializer:S)->Result<S::Ok,S::Error>{
83	let (beta,gamma)=(Value::from(layer.beta.val()),Value::from(layer.gamma.val()));
84	let (epsilon,momentum)=(layer.epsilon,layer.momentum);
85	let (mean,variance)=(Value::from(layer.running_mean.value()),Value::from(layer.running_var.value()));
86
87	BatchNormRecord{beta,epsilon,gamma,mean,momentum,variance}.serialize(serializer)
88}
89fn serialize_conv2d<B:Backend,S:Serializer>(layer:&Conv2d<B>,serializer:S)->Result<S::Ok,S::Error>{
90	let (dilation,groups,kernelsize,stride)=(layer.dilation,layer.groups,layer.kernel_size,layer.stride);
91	let bias=layer.bias.as_ref().map(|b|b.val().into());
92	let padding=layer.padding.clone();
93	let weight=layer.weight.val().into();
94
95	Conv2dRecord{bias,dilation,groups,kernelsize,padding,stride,weight}.serialize(serializer)
96}
97fn serialize_cross_entropy<B:Backend,S:Serializer>(layer:&CrossEntropyLoss<B>,serializer:S)->Result<S::Ok,S::Error>{
98	let (logits,pad,smoothing)=(layer.logits.clone(),layer.pad_tokens.clone(),layer.smoothing.clone());
99	let weights=layer.weights.clone().map(Into::into);
100
101	CrossEntropyRecord{logits,pad,smoothing,weights}.serialize(serializer)
102}
103fn serialize_dropout<S:Serializer>(data:&Dropout,serializer:S)->Result<S::Ok,S::Error>{data.prob.serialize(serializer)}
104fn serialize_embedding<B:Backend,S:Serializer>(layer:&Embedding<B>,serializer:S)->Result<S::Ok,S::Error>{serialize_param(&layer.weight,serializer)}
105fn serror<D:Display,E:Serror>(msg:D)->E{E::custom(msg)}
106fn serialize_ignored<S:Serializer,T:Serialize>(data:&Ignored<T>,serializer:S)->Result<S::Ok,S::Error>{
107	let data:&T=data;
108	data.serialize(serializer)
109}
110fn serialize_layer_norm<B:Backend,S:Serializer>(layer:&LayerNorm<B>,serializer:S)->Result<S::Ok,S::Error>{
111	LayerNormRecord{beta:layer.beta.val().into(),gamma:layer.gamma.val().into()}.serialize(serializer)
112}
113fn serialize_linear<B:Backend,S:Serializer>(layer:&Linear<B>,serializer:S)->Result<S::Ok,S::Error>{
114	let bias=layer.bias.as_ref().map(|b|b.val().into());
115	let weight=layer.weight.val().into();
116
117	LinearRecord{bias,weight}.serialize(serializer)
118}
119fn serialize_max_pool_2d<S:Serializer>(layer:&MaxPool2d,serializer:S)->Result<S::Ok,S::Error>{
120	MaxPool2dConfig{kernel_size:layer.kernel_size,strides:layer.stride,padding:layer.padding.0.clone(),dilation:layer.dilation}.serialize(serializer)
121}
122fn serialize_nothing<S:Serializer,T:Default>(_data:&T,serializer:S)->Result<S::Ok,S::Error>{().serialize(serializer)}
123fn serialize_param<B:Backend,S:Serializer,const N:usize>(data:&Param<Tensor<B,N>>,serializer:S)->Result<S::Ok,S::Error>{
124	if N>8{return Err(serror("tensor rank greater than 8 is not currently supported"))}
125	let data:Value<B>=data.val().into();
126	data.serialize(serializer)
127}
128fn serialize_rotary<B:Backend,S:Serializer>(data:&RotaryEncoding<B>,serializer:S)->Result<S::Ok,S::Error>{
129	let [distance,head,_2]=data.freq_complex.dims();
130	//let theta:f32=data.theta.clone().into_scalar().elem();// TODO determine theta somehow
131
132	//RotaryEncodingConfig::new(distance,head).with_theta(theta).serialize(serializer)
133	RotaryEncodingConfig::new(distance,head).serialize(serializer)
134}
135impl AttentionConfig{
136	pub fn init<B:Backend>(&self,_device:&B::Device)->Attention<B>{
137		let (dropout,heads,mask)=(self.dropout,self.heads,self.mask);
138		let mask=Ignored(mask);
139		let phantom=PhantomData;
140
141		Attention{dropout,heads,mask,phantom}
142	}
143}
144impl BiasConfig{
145	pub fn init<B:Backend>(&self,device:&B::Device)->Bias<B>{
146		let dim=self.dim;
147		let shape=[dim];
148
149		Bias{bias:self.initializer.init_with(shape,None,Some(dim),device)}
150	}
151}
152impl Config{
153	/// creates an attention config
154	pub fn attention(heads:usize,mask:AttentionMask)->Self{Self::Attention(AttentionConfig::new(heads,mask))}
155	/// creates a batch norm config
156	pub fn batch_norm(countfeatures:usize,epsilon:f32,momentum:f32)->Self{Self::BatchNorm(BatchNormConfig::new(countfeatures).with_epsilon(epsilon as f64).with_momentum(momentum as f64))}
157	/// creates a bias config
158	pub fn bias(dim:usize)->Self{Self::Bias(BiasConfig::new(dim))}
159	/// creates a cache config
160	pub fn cache(limit:usize)->Self{Self::Cache(CacheConfig::new(limit))}
161	/// creates a dropout config
162	pub fn dropout(chance:f32)->Self{Self::Dropout(DropoutConfig::new(chance as f64))}
163	/// creates an embedding config
164	pub fn embedding(input:usize,output:usize)->Self{Self::Embedding(EmbeddingConfig::new(input,output))}
165	/// creates a flatten config
166	pub fn flatten<R:RangeBounds<isize>>(dims:R)->Self{
167		let a=match dims.start_bound(){Excluded(&n)=>n+1,Included(&n)=>n,Unbounded=>0};
168		let b=match dims.end_bound(){Excluded(&n)=>n,Included(n)=>n+1,Unbounded=>0};
169		Self::Flatten(FlattenLayer::new(a..b))
170	}
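	// e.g. flatten(1..=3) resolves the inclusive end bound to the half-open range 1..4, while an
	// unbounded start or end resolves to 0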
171	/// initializes the layer
172	pub fn init<B:Backend>(&self,device:&B::Device)->Layer<B>{
173		match self{Config::Attention(c)=>Layer::Attention(c.init(device)),Config::BatchNorm(c)=>Layer::BatchNorm(c.init(device)),Config::Bias(c)=>Layer::Bias(c.init(device)),Config::Cache(c)=>Layer::Cache(Cache::new(c.limit)),Config::Cat(c)=>Layer::Cat(Ignored(*c)),Config::Conv2d(c)=>Layer::Conv2d(c.init(device)),Config::Dropout(c)=>Layer::Dropout(c.init()),Config::Embedding(c)=>Layer::Embedding(c.init(device)),Config::Flatten(c)=>Layer::Flatten(Ignored(c.clone())),Config::LayerNorm(c)=>Layer::LayerNorm(c.init(device)),Config::Linear(c)=>Layer::Linear(c.init(device)),Config::KQV(c)=>Layer::KQV(c.init(device)),Config::CrossEntropy(c)=>Layer::CrossEntropy(c.init(device)),Config::MaxPool2d(c)=>Layer::MaxPool2d(c.init()),Config::Mse=>Layer::Mse(MseLoss),Config::Relu=>Layer::Relu(Relu::new()),Config::Reshape(c)=>Layer::Reshape(Ignored(c.clone())),Config::Rotary(c)=>Layer::Rotary(c.init(device)),Config::ScaleShift(c)=>Layer::ScaleShift(c.init(device)),Config::Stack(d)=>Layer::Stack(Ignored(*d)),Config::Squeeze(c)=>Layer::Squeeze(Ignored(*c)),Config::Sum(c)=>Layer::Sum(Ignored(*c)),Config::Tanh=>Layer::Tanh(Tanh::new()),Config::Unsqueeze(c)=>Layer::Unsqueeze(Ignored(*c))}
174	}
175	/// creates a layer norm config
176	pub fn layer_norm(dim:usize)->Self{Self::LayerNorm(LayerNormConfig::new(dim))}
177	/// creates a linear config
178	pub fn linear(bias:bool,input:usize,output:usize)->Self{Self::Linear(LinearConfig::new(input,output).with_bias(bias))}
179	/// creates a max pool 2d config
180	pub fn max_pool_2d(kernel:[usize;2],strides:[usize;2])->Self{MaxPool2dConfig::new(kernel).with_strides(strides).into()}
181	/// creates a relu config
182	pub fn relu()->Self{Self::Relu}
183	/// creates a reshape config
184	pub fn reshape<R:Into<Reshape>>(args:R)->Self{Self::Reshape(ReshapeLayer::new(args.into()))}
185	/// creates a rotary config
186	pub fn rotary(distance:usize,head:usize)->Self{Self::Rotary(RotaryEncodingConfig::new(distance,head))}
187	/// creates a scale shift config
188	pub fn scale_shift()->Self{Self::ScaleShift(ScaleShiftConfig::new())}
189	/// sets the dropout if this is an attention layer
190	pub fn set_attention_dropout(&mut self,dropout:f32)->bool{
191		if let Config::Attention(c)=self{
192			c.dropout=dropout;
193			true
194		}else{
195			false
196		}
197	}
198	/// creates a tanh config
199	pub fn tanh()->Self{Self::Tanh}
200	/// scales the initializer
201	pub fn w_scale(mut self,r:f32)->Self{
202		match &mut self{Config::Attention(_c)=>(),Config::BatchNorm(_c)=>(),Config::Bias(c)=>w_scale_mut(&mut c.initializer,r),Config::Cache(_c)=>(),Config::Cat(_c)=>(),Config::Conv2d(c)=>w_scale_mut(&mut c.initializer,r),Config::CrossEntropy(_c)=>(),Config::Dropout(_c)=>(),Config::Embedding(c)=>w_scale_mut(&mut c.initializer,r),Config::Flatten(_c)=>(),Config::KQV(c)=>w_scale_mut(&mut c.initializer,r),Config::LayerNorm(_c)=>(),Config::Linear(c)=>w_scale_mut(&mut c.initializer,r),Config::MaxPool2d(_c)=>(),Config::Mse=>(),Config::Relu=>(),Config::Reshape(_c)=>(),Config::Rotary(_c)=>(),Config::ScaleShift(c)=>c.initializer.as_mut().into_iter().for_each(|i|w_scale_mut(i,r)),Config::Squeeze(_d)=>(),Config::Stack(_d)=>(),Config::Sum(_c)=>(),Config::Tanh=>(),Config::Unsqueeze(_c)=>()}
203		self
204	}
205}
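// A minimal usage sketch for the Config builders above, assuming the NdArray backend feature is
// enabled (as in the commented-out test further down); the shapes and dropout value are arbitrary.
#[cfg(test)]
mod config_examples{
	use super::*;

	#[test]
	fn build_and_init_configs(){
		type B=burn::backend::NdArray;

		// builder plus initializer scaling, then initialization into a Layer
		let linear=Config::linear(true,4,8).w_scale(0.5);
		assert!(matches!(linear.init::<B>(&Default::default()),Layer::Linear(_)));

		// dropout can be adjusted after the fact only on attention configs
		let mut attention=Config::attention(2,AttentionMask::Causal);
		assert!(attention.set_attention_dropout(0.1));
		let mut relu=Config::relu();
		assert!(!relu.set_attention_dropout(0.1));
	}
}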
206impl Decompose for Config{
207	fn compose(decomposition:Self::Decomposition)->Self{decomposition}
208	fn decompose(self)->Self::Decomposition{self}
209	fn decompose_cloned(&self)->Self::Decomposition{self.clone()}
210	type Decomposition=Self;
211}
212impl From<AttentionConfig> for Config{
213	fn from(value:AttentionConfig)->Self{Self::Attention(value)}
214}
215impl From<BatchNormConfig> for Config{
216	fn from(value:BatchNormConfig)->Self{Self::BatchNorm(value)}
217}
218impl From<BiasConfig> for Config{
219	fn from(value:BiasConfig)->Self{Self::Bias(value)}
220}
221impl<B:Backend> From<Cache<B>> for Layer<B>{
222	fn from(value:Cache<B>)->Self{Self::Cache(value)}
223}
224impl From<CatLayer> for Config{
225	fn from(value:CatLayer)->Self{Config::Cat(value)}
226}
227impl From<CrossEntropyLossConfig> for Config{
228	fn from(value:CrossEntropyLossConfig)->Self{Config::CrossEntropy(value)}
229}
230impl From<DropoutConfig> for Config{
231	fn from(value:DropoutConfig)->Self{Config::Dropout(value)}
232}
233impl From<EmbeddingConfig> for Config{
234	fn from(value:EmbeddingConfig)->Self{Config::Embedding(value)}
235}
236impl From<FlattenLayer<Range<isize>>> for Config{
237	fn from(value:FlattenLayer<Range<isize>>)->Self{Config::Flatten(value)}
238}
239impl From<LayerNormConfig> for Config{
240	fn from(value:LayerNormConfig)->Self{Config::LayerNorm(value)}
241}
242impl From<LinearConfig> for Config{
243	fn from(value:LinearConfig)->Self{Config::Linear(value)}
244}
245impl From<MaxPool2dConfig> for Config{
246	fn from(value:MaxPool2dConfig)->Self{Config::MaxPool2d(value)}
247}
248impl From<MseLoss> for Config{
249	fn from(_value:MseLoss)->Self{Config::Mse}
250}
251impl From<Relu> for Config{
252	fn from(_value:Relu)->Self{Config::Relu}
253}
254impl From<ReshapeLayer<Reshape>> for Config{
255	fn from(value:ReshapeLayer<Reshape>)->Self{Config::Reshape(value)}
256}
257impl From<RotaryEncodingConfig> for Config{
258	fn from(value:RotaryEncodingConfig)->Self{Config::Rotary(value)}
259}
260impl From<ScaleShiftConfig> for Config{
261	fn from(value:ScaleShiftConfig)->Self{Config::ScaleShift(value)}
262}
263impl From<SqueezeLayer> for Config{
264	fn from(value:SqueezeLayer)->Self{Config::Squeeze(value)}
265}
266impl From<StackLayer> for Config{
267	fn from(value:StackLayer)->Self{Config::Stack(value)}
268}
269impl From<SumLayer> for Config{
270	fn from(value:SumLayer)->Self{Config::Sum(value)}
271}
272impl From<Tanh> for Config{
273	fn from(_value:Tanh)->Self{Config::Tanh}
274}
275impl From<UnsqueezeLayer> for Config{
276	fn from(value:UnsqueezeLayer)->Self{Config::Unsqueeze(value)}
277}
278impl KQVConfig{
279	pub fn init<B:Backend>(&self,device:&B::Device)->KQV<B>{
280		let (embed,initializer,kdim,vdim)=(self.embed.clone(),self.initializer.clone(),self.kdim.clone(),self.vdim.clone());
281		let (key,value)=(LinearConfig::new(embed,kdim).with_initializer(initializer.clone()).init(device),LinearConfig::new(embed,vdim).with_initializer(initializer.clone()).init(device));
282		let query=LinearConfig::new(embed,kdim).with_initializer(initializer).init(device);
283
284		KQV{key,query,value}
285	}
286}
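// note: the query projection above shares kdim with the key projection so that q·kᵀ in the attention
// layer is well-defined (the current attention implementation also requires the key and value shapes
// to match)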
287impl ScaleShiftConfig{
288	pub fn init<B:Backend>(&self,device:&B::Device)->ScaleShift<B>{
289		let initializer=&self.initializer;
290
291		let a=if let Some(i)=initializer{i.init_with([1],None,None,device)}else{Initializer::Constant{value:1.0}.init_with([1],None,None,device)};
292		let b=if let Some(i)=initializer{i.init_with([1],None,None,device)}else{Initializer::Constant{value:0.0}.init_with([1],None,None,device)};
293		ScaleShift{a,b}
294	}
295}
296impl<B:Backend,M:AI<M::Output,M::Output>+Op> IntoSequence<M> for Layer<B> where Layer<B>:Into<M>{
297	fn into_sequence(self)->Sequential<Vec<M>>{vec![self.into()].sequential()}
298}
299impl<B:Backend> From<BatchNorm<B>> for Layer<B>{
300	fn from(value:BatchNorm<B>)->Self{
301		Self::BatchNorm(BatchNorm{beta:value.beta,epsilon:value.epsilon,gamma:value.gamma,momentum:value.momentum,running_mean:value.running_mean,running_var:value.running_var})
302	}
303}
304impl<B:Backend> AI<(Value<B>,Value<B>,Value<B>),Value<B>> for Attention<B>{
305	fn forward(&self,(k,q,v):(Value<B>,Value<B>,Value<B>))->Value<B>{// TODO support for other numbers of dimensions
306		fn apply_mask<B:Backend,const D:usize>(a:Tensor<B,D>,mask:AttentionMask,value:f32)->Tensor<B,D>{
307			match mask{AttentionMask::Causal=>mask_causal(a,value as f64),AttentionMask::None=>a,AttentionMask::Power(n)=>mask_power(a,n,value as f64),AttentionMask::Window(n)=>mask_window(a,n,value as f64)}
308		}
309		fn f_3d<B:Backend>(dropout:f32,heads:usize,mask:AttentionMask,k:Tensor<B,3>,q:Tensor<B,3>,v:Tensor<B,3>)->Result<Tensor<B,3>,String>{
310			let (kdims,qdims,vdims)=(k.dims(),q.dims(),v.dims());
311			let (kseq,qseq,vseq)=(kdims[1],qdims[1],vdims[1]);
312
313			if kdims[0]!=qdims[0]{return Err("mismatched key and query batch dims".into())}
314			if kdims[2]!=qdims[2]{return Err("mismatched key and query embed dims".into())}
315			if kdims!=vdims{return Err("mismatched key and value dims".into())}
316			let [batch,_sequence,embed]=kdims;
317			let dropout=Dropout{prob:dropout as f64};
318			let head=if embed%heads==0{embed/heads}else{return Err("embed must be a multiple of heads".into())};
319
320			let (k,q,v)=(k.reshape([batch,kseq,heads,head]).swap_dims(1,2),q.reshape([batch,qseq,heads,head]).swap_dims(1,2),v.reshape([batch,vseq,heads,head]).swap_dims(1,2));
321			let a=activation::softmax(apply_mask(q.matmul(k.transpose())/(head as f32).sqrt(),mask,-9999.0),3);
322			let a=dropout.forward(a);
323			let s=a.matmul(v).swap_dims(1,2).reshape([0,0,-1]);
324
325			Ok(s)
326		}
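		// shape walkthrough for f_3d: with batch=2, seq=5, embed=8, heads=2 the inputs reshape to
		// [2,5,2,4] and swap to [2,2,5,4]; q.matmul(k.transpose()) gives scores [2,2,5,5], which are
		// masked, softmaxed over the key dim and passed through dropout; multiplying by v and undoing
		// the head split returns [2,5,8]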
327		fn mask_causal<B:Backend,const D:usize>(a:Tensor<B,D>,value:f64)->Tensor<B,D>{
328			if D<2{return mask_causal::<B,2>(a.unsqueeze(),value).squeeze_dim(0)}									// shouldn't actually happen but if the dimension is less than 2 we can just treat it like it has a second dimension of size 1
329
330			let (device,dims)=(a.device(),a.dims());
331			let (key,query)=(dims[D-1],dims[D-2]);
332			let extrakeys=key.saturating_sub(query);															// due to caching, there might be more keys than queries
333
334			let causal:Tensor<B,2,Bool>=Tensor::tril_mask([query,key],extrakeys as i64,&device);
335			let a=a.mask_fill(causal.unsqueeze(),value);
336			a
337		}
338		fn mask_power<B:Backend,const D:usize>(a:Tensor<B,D>,info:PowerMaskInfo,value:f64)->Tensor<B,D>{
339			if D<2{return mask_power::<B,2>(a.unsqueeze(),info,value).squeeze_dim(0)}
340			let (block,window)=(info.block,info.window);
341			let dims=a.dims();
342
343			let (key,query)=(dims[D-1],dims[D-2]);
344			let mask=generate_power_attention_mask(block,key,query,window);
345
346			a.mask_fill(mask.unsqueeze(),value)
347		}
348		/// fills the attention tensor with the value where the key position is more than `length` behind the query position, or ahead of it (non-causal), accounting for any extra cached keys. Assumes attention dimensions are [.., query, key]
349		fn mask_window<B:Backend,const D:usize>(a:Tensor<B,D>,length:usize,value:f64)->Tensor<B,D>{
350			if D<2{return mask_window::<B,2>(a.unsqueeze(),length,value).squeeze_dim(0)}							// shouldn't actually happen but if the dimension is less than 2 we can just treat it like it has a second dimension of size 1
351
352			let (device,dims)=(a.device(),a.dims());
353			let (key,query)=(dims[D-1],dims[D-2]);
354			let extrakeys=key.saturating_sub(query);															// due to caching, there might be more keys than queries
355
356			let causal:Tensor<B,2,Bool>=Tensor::tril_mask([query,key],extrakeys as i64,&device);
357			let window:Tensor<B,2,Bool>=Tensor::triu_mask([query,key],extrakeys as i64-length as i64,&device);
358			let a=a.mask_fill(causal.unsqueeze(),value).mask_fill(window.unsqueeze(),value);
359			a
360		}
361		let (dropout,heads,mask)=(self.dropout,self.heads,self.mask.0);
362
363		match match (k.float(),q.float(),v.float()){
364			(Value::F3(k),Value::F3(q),Value::F3(v))=>f_3d(dropout,heads,mask,k,q,v).map(Into::into),
365			(Value::Multi(k),Value::Multi(q),Value::Multi(v))=>if k.len()==q.len()&&q.len()==v.len(){Ok(k.into_iter().zip(q).zip(v).map(|((k,q),v)|self.forward((k,q,v))).collect())}else{Err("incompatible lengths".into())}
366			(k,q,v)=>Err(format!("attention is currently only supported for 3d float inputs [batch, seq, embed], k: {:?}, q: {:?}, v: {:?}",k.shape_recursive(),q.shape_recursive(),v.shape_recursive()))
367		}{
368			Err(e)=>e.into(),
369			Ok(x)=>x
370		}
371	}
372}
373impl<B:Backend> AI<Value<B>,(Value<B>,Value<B>,Value<B>)> for KQV<B>{
374	fn forward(&self,input:Value<B>)->(Value<B>,Value<B>,Value<B>){
375		let (k,q)=(input.clone(),input.clone());
376		let v=input;
377
378		(AI::forward(&self.key,k),AI::forward(&self.query,q),AI::forward(&self.value,v))
379	}
380}
381impl<B:Backend> AI<Value<B>,Value<B>> for Attention<B>{
382	fn forward(&self,input:Value<B>)->Value<B>{
383		input.map_multi(1,|input|match input{
384			Value::Incompatible(e)=>e.into(),
385			Value::Multi(v) if v.len()>=3=>if v.len()==3{
386				let [k,q,v]=v.try_into().unwrap();
387				self.forward((k,q,v))
388			}else{
389				v.into_iter().map(|x|self.forward(x)).collect()
390			},
391			_=>"attention inputs must be in triples".into()
392		})
393	}
394}
395impl<B:Backend> AI<Value<B>,Value<B>> for Bias<B>{
396	fn forward(&self,input:Value<B>)->Value<B>{input+Value::from(self.bias.val())}
397}
398impl<B:Backend> AI<Value<B>,Value<B>> for Cache<B>{
399	fn forward(&self,input:Value<B>)->Value<B>{self.clone().forward_mut(input)}
400	fn forward_mut(&mut self,input:Value<B>)->Value<B>{
401		let limit=self.limit;
402		if self.cache.is_empty(){
403			self.cache=input.clone();//slice([0..,0..self.limit]); TODO start slice from -limit?
404			return input
405		}
406
407		match (mem::take(&mut self.cache),input){
408			(Value::Multi(mut cache),Value::Multi(input))=>{
409				if cache.len()<input.len(){cache.resize_with(input.len(),Default::default)}
410				let (cache,output):(Vec<Value<B>>,Vec<Value<B>>)=cache.into_iter().zip(input).map(|(cache,input)|{
411					let mut c=Cache{cache,limit};
412					let o=c.forward_mut(input);
413
414					(c.cache,o)
415				}).unzip();
416
417				self.cache=cache.into();
418				output.into()
419			},
420			(cache,input)=>{// TODO what if one is multi and the other isn't
421				let seq=cache.shape().to_array(Default::default())[1]+input.shape().to_array(Default::default())[1];
422
423				let cacheinput=Value::from(vec![cache,input]);
424				let cacheoutput=cacheinput.cat(1);
425				let cacheoutput=if seq>limit{cacheoutput.slice([0..,seq-limit..])}else{cacheoutput};
426
427				self.cache=cacheoutput.clone();
428				cacheoutput
429			}
430		}
431	}
432}
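// worked example for Cache: with limit=4, feeding a [batch,3,embed] value into forward_mut stores and
// returns it unchanged; feeding a further [batch,2,embed] value concatenates along dim 1 into a
// length-5 sequence and slices to the last 4 positions, which become both the new cache and the output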
433impl<B:Backend> AI<Value<B>,Value<B>> for KQV<B>{
434	fn forward(&self,input:Value<B>)->Value<B>{
435		input.map_values(|input|{
436			let (k,q,v)=self.forward(input);
437			vec![k,q,v].into()
438		})
439	}
440}
441impl<B:Backend> AI<Value<B>,Value<B>> for Layer<B>{
442	fn forward(&self,input:Value<B>)->Value<B>{
443		match self{
444			Layer::Attention(f)=>f.forward(input),
445			Layer::BatchNorm(f)=>AI::forward(f,input),
446			Layer::Bias(f)=>f.forward(input),
447			Layer::Cache(f)=>f.forward(input),
448			Layer::Cat(f)=>f.forward(input),
449			Layer::Conv2d(f)=>AI::forward(f,input),
450			Layer::CrossEntropy(f)=>AI::forward(f,input),
451			Layer::Dropout(f)=>AI::forward(f,input),
452			Layer::Embedding(f)=>AI::forward(f,input),
453			Layer::Flatten(f)=>f.0.forward(input),
454			Layer::KQV(f)=>f.forward(input),
455			Layer::LayerNorm(f)=>AI::forward(f,input),
456			Layer::Linear(f)=>AI::forward(f,input),
457			Layer::MaxPool2d(f)=>AI::forward(f,input),
458			Layer::Mse(f)=>AI::forward(f,input),
459			Layer::Relu(f)=>AI::forward(f,input),
460			Layer::Reshape(f)=>f.0.forward(input),
461			Layer::Rotary(f)=>AI::forward(f,input),
462			Layer::ScaleShift(f)=>f.forward(input),
463			Layer::Squeeze(f)=>f.forward(input),
464			Layer::Stack(f)=>f.forward(input),
465			Layer::Sum(f)=>f.forward(input),
466			Layer::Tanh(f)=>AI::forward(f,input),
467			Layer::Unsqueeze(f)=>f.forward(input),
468		}
469	}
470	fn forward_mut(&mut self,input:Value<B>)->Value<B>{
471		match self{
472			Layer::Attention(f)=>f.forward_mut(input),
473			Layer::BatchNorm(f)=>AI::forward_mut(f,input),
474			Layer::Bias(f)=>f.forward_mut(input),
475			Layer::Cache(f)=>f.forward_mut(input),
476			Layer::Cat(f)=>f.0.forward_mut(input),
477			Layer::Conv2d(f)=>f.forward_mut(input),
478			Layer::CrossEntropy(f)=>AI::forward_mut(f,input),
479			Layer::Dropout(f)=>AI::forward_mut(f,input),
480			Layer::Embedding(f)=>AI::forward_mut(f,input),
481			Layer::Flatten(f)=>f.0.forward_mut(input),
482			Layer::KQV(f)=>f.forward_mut(input),
483			Layer::LayerNorm(f)=>AI::forward_mut(f,input),
484			Layer::Linear(f)=>AI::forward_mut(f,input),
485			Layer::MaxPool2d(f)=>AI::forward_mut(f,input),
486			Layer::Mse(f)=>AI::forward_mut(f,input),
487			Layer::Relu(f)=>AI::forward_mut(f,input),
488			Layer::Reshape(f)=>f.0.forward_mut(input),
489			Layer::Rotary(f)=>AI::forward_mut(f,input),
490			Layer::ScaleShift(f)=>f.forward_mut(input),
491			Layer::Squeeze(f)=>f.0.forward_mut(input),
492			Layer::Stack(f)=>f.0.forward_mut(input),
493			Layer::Sum(f)=>f.0.forward_mut(input),
494			Layer::Tanh(f)=>AI::forward_mut(f,input),
495			Layer::Unsqueeze(f)=>f.0.forward_mut(input),
496		}
497	}
498}
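// forward and forward_mut dispatch to the same layers; the distinction matters for stateful layers
// like Cache, which appends to its cache in forward_mut but works on a throwaway clone in forward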
499impl<B:Backend> AI<Value<B>,Value<B>> for ScaleShift<B>{
500	fn forward(&self,input:Value<B>)->Value<B>{
501		let (a,b)=(Value::from(self.a.val()),Value::from(self.b.val()));
502		input*a+b
503	}
504}
505impl<B:Backend> From<Attention<B>> for Layer<B>{
506	fn from(value:Attention<B>)->Self{Self::Attention(value)}
507}
508impl<B:Backend> Cache<B>{
509	fn new(limit:usize)->Self{
510		let cache=Value::default();
511		Self{cache,limit}
512	}
513}
514impl<B:Backend> Decompose for Layer<B>{
515	fn compose(decomposition:Self::Decomposition)->Self{decomposition}
516	fn decompose(self)->Self::Decomposition{self}
517	fn decompose_cloned(&self)->Self::Decomposition{self.clone()}
518	type Decomposition=Self;
519}
520impl<B:Backend> From<CatLayer> for Layer<B>{
521	fn from(value:CatLayer)->Self{Layer::Cat(Ignored(value))}
522}
523impl<B:Backend> From<CrossEntropyLoss<B>> for Layer<B>{
524	fn from(value:CrossEntropyLoss<B>)->Self{Layer::CrossEntropy(value)}
525}
526impl<B:Backend> From<Dropout> for Layer<B>{
527	fn from(value:Dropout)->Self{Layer::Dropout(value)}
528}
529impl<B:Backend> From<Embedding<B>> for Layer<B>{
530	fn from(value:Embedding<B>)->Self{Layer::Embedding(value)}
531}
532impl<B:Backend> From<FlattenLayer<Range<isize>>> for Layer<B>{
533	fn from(value:FlattenLayer<Range<isize>>)->Self{Layer::Flatten(Ignored(value))}
534}
535impl<B:Backend> From<LayerNorm<B>> for Layer<B>{
536	fn from(value:LayerNorm<B>)->Self{Layer::LayerNorm(value)}
537}
538impl<B:Backend> From<Linear<B>> for Layer<B>{
539	fn from(value:Linear<B>)->Self{Layer::Linear(value)}
540}
541impl<B:Backend> From<MaxPool2d> for Layer<B>{
542	fn from(value:MaxPool2d)->Self{Layer::MaxPool2d(value)}
543}
544impl<B:Backend> From<MseLoss> for Layer<B>{
545	fn from(value:MseLoss)->Self{Layer::Mse(value)}
546}
547impl<B:Backend> From<Relu> for Layer<B>{
548	fn from(value:Relu)->Self{Layer::Relu(value)}
549}
550impl<B:Backend> From<ReshapeLayer<Reshape>> for Layer<B>{
551	fn from(value:ReshapeLayer<Reshape>)->Self{Layer::Reshape(Ignored(value))}
552}
553impl<B:Backend> From<RotaryEncoding<B>> for Layer<B>{
554	fn from(value:RotaryEncoding<B>)->Self{Layer::Rotary(value)}
555}
556impl<B:Backend> From<ScaleShift<B>> for Layer<B>{
557	fn from(value:ScaleShift<B>)->Self{Layer::ScaleShift(value)}
558}
559impl<B:Backend> From<SqueezeLayer> for Layer<B>{
560	fn from(value:SqueezeLayer)->Self{Layer::Squeeze(Ignored(value))}
561}
562impl<B:Backend> From<StackLayer> for Layer<B>{
563	fn from(value:StackLayer)->Self{Layer::Stack(Ignored(value))}
564}
565impl<B:Backend> From<SumLayer> for Layer<B>{
566	fn from(value:SumLayer)->Self{Layer::Sum(Ignored(value))}
567}
568impl<B:Backend> From<Tanh> for Layer<B>{
569	fn from(value:Tanh)->Self{Layer::Tanh(value)}
570}
571impl<B:Backend> From<UnsqueezeLayer> for Layer<B>{
572	fn from(value:UnsqueezeLayer)->Self{Layer::Unsqueeze(Ignored(value))}
573}
574
575/*
576#[cfg(test)]
577mod tests {
578    use super::*;
579
580    #[test]
581    fn test_small_mask() {
582		type B=burn::backend::NdArray;
583
584        let mask = generate_power_attention_mask::<B>(1,15,10,2).int();
585
586        // Print mask for debugging
587        println!("{:?}", mask.clone());
588		assert!(false);
589    }
590}*/
591
592
593pub fn generate_power_attention_mask<B:Backend>(block:usize,k:usize,q:usize,window:usize)->Tensor<B,2,Bool>{
594	let device=Default::default();
595	let kx:Tensor<B,1,Int>=Tensor::arange(0..k as i64,&device);
596	let qx:Tensor<B,1,Int>=Tensor::arange(0..q as i64,&device);
597
598	let kx:Tensor<B,2,Int>=kx.unsqueeze_dim(0).repeat_dim(0,q);
599	let qx:Tensor<B,2,Int>=qx.unsqueeze_dim(1).repeat_dim(1,k);
600
601	let bx=qx.clone()/block as i64-kx.clone()/block as i64;
602	let causal=qx.greater_equal(kx.clone());
603	let power=bx.clone().bitwise_and(bx.clone()-1).equal_elem(0);
604	let sink=kx.lower_elem(block as i64);
605	let window=bx.lower_elem((window/block) as i64);
606
607	//causal.bool_and(power.bool_or(sink).bool_or(window)).bool_not()
608
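	// for 0/1 integers, (p+s+w+2)/3 is 1 exactly when at least one of p, s, w is 1, so the line below
	// is an integer-arithmetic version of the commented boolean expression above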
609	(causal.int()*((power.int()+sink.int()+window.int()+2)/3)).bool().bool_not()
610}
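// A minimal sketch exercising generate_power_attention_mask, again assuming the NdArray backend
// feature; the expected count of masked positions was worked out by hand from the rule "attend when
// query >= key and (query-key is zero or a power of two, or the key is in the sink block, or
// query-key is inside the window)".
#[cfg(test)]
mod power_mask_examples{
	use super::*;

	#[test]
	fn power_mask_shape_and_count(){
		type B=burn::backend::NdArray;

		let mask=generate_power_attention_mask::<B>(1,8,8,2);
		assert_eq!(mask.dims(),[8,8]);							// [query, key]
		assert_eq!(mask.int().sum().into_scalar(),35);			// 64 positions minus 29 attended ones
	}
}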
611
612
613
614impl<B:Backend> Layer<B>{
615	/// creates an attention layer
616	pub fn attention(heads:usize,mask:AttentionMask)->Self{Config::attention(heads,mask).init(&Default::default())}
617	/// creates a batch norm layer
618	pub fn batch_norm(countfeatures:usize,epsilon:f32,momentum:f32)->Self{Config::batch_norm(countfeatures,epsilon,momentum).init(&Default::default())}
619	/// creates a bias layer
620	pub fn bias(dim:usize)->Self{Config::bias(dim).init(&Default::default())}
621	/// creates a cache layer
622	pub fn cache(limit:usize)->Self{Self::Cache(Cache::new(limit))}
623	/// clears the cache if the layer has one
624	pub fn clear_cache(&mut self)->bool{
625		match self{
626			Self::Cache(c)=>{
627				c.cache=Default::default();
628				true
629			},
630			_=>false
631		}
632	}
633	/// creates a dropout layer
634	pub fn dropout(chance:f32)->Self{Config::dropout(chance).init(&Default::default())}
635	/// creates an embedding layer
636	pub fn embedding(input:usize,output:usize,wscale:f32)->Self{
637		let mut l=EmbeddingConfig::new(input,output);
638		if wscale!=1.0{l.initializer=w_scale(l.initializer,wscale)}
639		let l=l.init(&Default::default());
640		Self::Embedding(l)
641	}
642	/// creates a flatten layer
643	pub fn flatten<R:RangeBounds<isize>>(dims:R)->Self{
644		let a=match dims.start_bound(){Excluded(&n)=>n+1,Included(&n)=>n,Unbounded=>0};
645		let b=match dims.end_bound(){Excluded(&n)=>n,Included(n)=>n+1,Unbounded=>0};
646		Self::Flatten(Ignored(FlattenLayer::new(a..b)))
647	}
648	/// creates a layer norm layer
649	pub fn layer_norm(dim:usize)->Self{Self::LayerNorm(LayerNormConfig::new(dim).init(&Default::default()))}
650	/// creates a linear layer
651	pub fn linear(bias:bool,input:usize,output:usize,wscale:f32)->Self{
652		let mut l=LinearConfig::new(input,output).with_bias(bias);
653		if wscale!=1.0{l.initializer=w_scale(l.initializer,wscale)}
654		let l=l.init(&Default::default());
655		Self::Linear(l)
656	}
657	/// creates a max pool 2d layer
658	pub fn max_pool_2d(kernel:[usize;2],strides:[usize;2])->Self{MaxPool2dConfig::new(kernel).with_strides(strides).init().into()}
659	/// creates a relu layer
660	pub fn relu()->Self{Self::Relu(Relu)}
661	/// creates a reshape layer
662	pub fn reshape<R:Into<Reshape>>(args:R)->Self{Self::Reshape(Ignored(ReshapeLayer::new(args.into())))}
663	/// creates a rotary layer
664	pub fn rotary(distance:usize,head:usize)->Self{Self::Rotary(RotaryEncodingConfig::new(distance,head).init(&Default::default()))}
665	/// creates a scale shift layer
666	pub fn scale_shift()->Self{Self::ScaleShift(ScaleShiftConfig::new().init(&Default::default()))}
667	/// creates a tanh layer
668	pub fn tanh()->Self{Self::Tanh(Tanh)}
669}
670impl<B:Backend> Op for Layer<B>{
671	type Output=Value<B>;
672}
673#[derive(Clone,Copy,Debug,Deserialize,Serialize)]
674pub enum AttentionMask{Causal,None,Power(PowerMaskInfo),Window(usize)}
675#[derive(Config,Debug)]
676/// enumerates configs for some burn layers
677pub enum Config{Attention(AttentionConfig),BatchNorm(BatchNormConfig),Bias(BiasConfig),Cache(CacheConfig),Cat(CatLayer),Conv2d(Conv2dConfig),CrossEntropy(CrossEntropyLossConfig),Dropout(DropoutConfig),Embedding(EmbeddingConfig),Flatten(FlattenLayer<Range<isize>>),KQV(KQVConfig),LayerNorm(LayerNormConfig),Linear(LinearConfig),MaxPool2d(MaxPool2dConfig),Mse,Relu,Reshape(ReshapeLayer<Reshape>),Rotary(RotaryEncodingConfig),ScaleShift(ScaleShiftConfig),Squeeze(SqueezeLayer),Stack(StackLayer),Sum(SumLayer),Tanh,Unsqueeze(UnsqueezeLayer)}
678#[derive(Debug,Deserialize,Module,Serialize)]//TODO more layers
679#[serde(bound="")]
680/// enumerates some burn layers
681pub enum Layer<B:Backend>{
682	Attention(Attention<B>),
683	Bias(Bias<B>),
684	#[serde(deserialize_with="deserialize_batch_norm")]
685	#[serde(serialize_with="serialize_batch_norm")]
686	BatchNorm(BatchNorm<B>),
687	Cache(Cache<B>),
688	#[serde(deserialize_with="deserialize_ignored")]
689	#[serde(serialize_with="serialize_ignored")]
690	Cat(Ignored<CatLayer>),
691	#[serde(deserialize_with="deserialize_conv2d")]
692	#[serde(serialize_with="serialize_conv2d")]
693	Conv2d(Conv2d<B>),
694	#[serde(deserialize_with="deserialize_cross_entropy")]
695	#[serde(serialize_with="serialize_cross_entropy")]
696	CrossEntropy(CrossEntropyLoss<B>),
697	#[serde(deserialize_with="deserialize_dropout")]
698	#[serde(serialize_with="serialize_dropout")]
699	Dropout(Dropout),
700	#[serde(deserialize_with="deserialize_embedding")]
701	#[serde(serialize_with="serialize_embedding")]
702	Embedding(Embedding<B>),
703	#[serde(deserialize_with="deserialize_ignored")]
704	#[serde(serialize_with="serialize_ignored")]
705	Flatten(Ignored<FlattenLayer<Range<isize>>>),
706	KQV(KQV<B>),
707	#[serde(deserialize_with="deserialize_layer_norm")]
708	#[serde(serialize_with="serialize_layer_norm")]
709	LayerNorm(LayerNorm<B>),
710	#[serde(deserialize_with="deserialize_linear")]
711	#[serde(serialize_with="serialize_linear")]
712	Linear(Linear<B>),
713	#[serde(deserialize_with="deserialize_max_pool_2d")]
714	#[serde(serialize_with="serialize_max_pool_2d")]
715	MaxPool2d(MaxPool2d),
716	#[serde(deserialize_with="deserialize_nothing")]
717	#[serde(serialize_with="serialize_nothing")]
718	Mse(MseLoss),
719	#[serde(deserialize_with="deserialize_nothing")]
720	#[serde(serialize_with="serialize_nothing")]
721	Relu(Relu),
722	#[serde(deserialize_with="deserialize_ignored")]
723	#[serde(serialize_with="serialize_ignored")]
724	Reshape(Ignored<ReshapeLayer<Reshape>>),
725	#[serde(deserialize_with="deserialize_rotary")]
726	#[serde(serialize_with="serialize_rotary")]
727	Rotary(RotaryEncoding<B>),
728	ScaleShift(ScaleShift<B>),
729	#[serde(deserialize_with="deserialize_ignored")]
730	#[serde(serialize_with="serialize_ignored")]
731	Squeeze(Ignored<SqueezeLayer>),
732	#[serde(deserialize_with="deserialize_ignored")]
733	#[serde(serialize_with="serialize_ignored")]
734	Stack(Ignored<StackLayer>),
735	#[serde(deserialize_with="deserialize_ignored")]
736	#[serde(serialize_with="serialize_ignored")]
737	Sum(Ignored<SumLayer>),
738	#[serde(deserialize_with="deserialize_nothing")]
739	#[serde(serialize_with="serialize_nothing")]
740	Tanh(Tanh),
741	#[serde(deserialize_with="deserialize_ignored")]
742	#[serde(serialize_with="serialize_ignored")]
743	Unsqueeze(Ignored<UnsqueezeLayer>),
744}
745/// scales the initializer
746pub fn w_scale(initializer:Initializer,r:f32)->Initializer{
747	let r=r as f64;// initializer parameters are stored as f64
748	match initializer{
749		Initializer::Constant{value}=>Initializer::Constant{value:value*r},
750		Initializer::KaimingNormal{gain,fan_out_only}=>Initializer::KaimingNormal{gain:gain*r,fan_out_only},
751		Initializer::KaimingUniform{gain,fan_out_only}=>Initializer::KaimingUniform{gain:gain*r,fan_out_only},
752		Initializer::Normal{mean,std}=>Initializer::Normal{mean:mean*r,std:std*r},
753		Initializer::Ones=>Initializer::Constant{value:r},
754		Initializer::Orthogonal{gain}=>Initializer::Orthogonal{gain:gain*r},
755		Initializer::Uniform{min,max}=>Initializer::Uniform{min:min*r,max:max*r},
756		Initializer::XavierNormal{gain}=>Initializer::XavierNormal{gain:gain*r},
757		Initializer::XavierUniform{gain}=>Initializer::XavierUniform{gain:gain*r},
758		Initializer::Zeros=>Initializer::Zeros
759	}
760}
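// a small sanity sketch for w_scale: scaling applies to whichever numeric field the variant carries,
// and Ones becomes a Constant of the scale factor (the compared values are exactly representable)
#[cfg(test)]
mod w_scale_examples{
	use super::*;

	#[test]
	fn scales_initializers(){
		assert!(matches!(
			w_scale(Initializer::Normal{mean:1.0,std:2.0},0.5),
			Initializer::Normal{mean,std} if mean==0.5&&std==1.0
		));
		assert!(matches!(
			w_scale(Initializer::Ones,0.5),
			Initializer::Constant{value} if value==0.5
		));
	}
}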
761/// scales the initializer
762pub fn w_scale_mut(initializer:&mut Initializer,r:f32){*initializer=w_scale(initializer.clone(),r)}
763#[derive(Config,Debug)]
764/// config for a layer that computes attention from [key,query,value] inputs
765pub struct AttentionConfig{
766	#[config(default="0.2")]
767	dropout:f32,
768	heads:usize,
769	mask:AttentionMask
770}
771#[derive(Debug,Deserialize,Module,Serialize)]
772#[serde(bound="")]
773/// layer for computing attention from [key,query,value] inputs
774pub struct Attention<B:Backend>{
775	dropout:f32,
776	heads:usize,
777	#[serde(deserialize_with="deserialize_ignored")]
778	#[serde(serialize_with="serialize_ignored")]
779	mask:Ignored<AttentionMask>,
780	phantom:PhantomData<B>
781}
782#[derive(Config,Debug)]
783/// config for a layer that adds a learnable bias
784pub struct BiasConfig{
785	dim:usize,
786	#[config(default="Initializer::Normal{mean:0.0,std:1.0}")]
787	initializer:Initializer
788}
789#[derive(Config,Debug)]
790/// config for a layer that linearly splits its input into [key,query,value] for attention purposes
791pub struct KQVConfig{
792	embed:usize,
793	#[config(default="Initializer::XavierNormal{gain:1.0}")]
794	initializer:Initializer,
795	kdim:usize,
796	vdim:usize
797}
798#[derive(Debug,Deserialize,Module,Serialize)]
799#[serde(bound="")]
800/// layer that adds a learnable bias to its input
801pub struct Bias<B:Backend>{
802	#[serde(deserialize_with="deserialize_param")]
803	#[serde(serialize_with="serialize_param")]
804	bias:Param<Tensor<B,1>>
805}
806#[derive(Debug,Default,Deserialize,Module,Serialize)]// TODO clear cache should be a layer level function
807#[serde(bound="")]
808/// layer for caching kv values from kqv when run mutably. concatenates along dimension 1 and outputs the concatenated keys and values.
809pub struct Cache<B:Backend>{cache:Value<B>,limit:usize}
810#[derive(Config,Debug)]
811pub struct CacheConfig{limit:usize}
812#[derive(Debug,Deserialize,Module,Serialize)]
813#[serde(bound="")]
814/// layer that linearly splits its input into [key,query,value] for attention purposes
815pub struct KQV<B:Backend>{
816	#[serde(deserialize_with="deserialize_linear")]
817	#[serde(serialize_with="serialize_linear")]
818	key:Linear<B>,
819	#[serde(deserialize_with="deserialize_linear")]
820	#[serde(serialize_with="serialize_linear")]
821	query:Linear<B>,
822	#[serde(deserialize_with="deserialize_linear")]
823	#[serde(serialize_with="serialize_linear")]
824	value:Linear<B>
825}
826#[derive(Clone,Copy,Debug,Deserialize,Serialize)]
827/// power mask information
828pub struct PowerMaskInfo{pub block:usize,pub window:usize}
829#[derive(Debug,Deserialize,Module,Serialize)]
830#[serde(bound="")]
831/// layer that applies a componentwise scalar affine transformation: f(x)=ax+b where a and b are tunable scalars
832pub struct ScaleShift<B:Backend>{
833	#[serde(deserialize_with="deserialize_param")]
834	#[serde(serialize_with="serialize_param")]
835	a:Param<Tensor<B,1>>,
836	#[serde(deserialize_with="deserialize_param")]
837	#[serde(serialize_with="serialize_param")]
838	b:Param<Tensor<B,1>>
839}
840#[derive(Config,Debug)]
841/// scale shift config
842pub struct ScaleShiftConfig{
843	#[config(default="None")]
844	initializer:Option<Initializer>
845}
846#[derive(Deserialize,Serialize)]
847#[serde(bound="")]
848struct Conv2dRecord<B:Backend>{
849	bias:Option<Value<B>>,
850	dilation:[usize;2],
851	groups:usize,
852	kernelsize:[usize;2],
853	#[serde(deserialize_with="deserialize_ignored")]
854	#[serde(serialize_with="serialize_ignored")]
855	padding:Ignored<PaddingConfig2d>,
856	stride:[usize;2],
857	weight:Value<B>
858}
859#[derive(Deserialize,Serialize)]
860#[serde(bound="")]
861struct BatchNormRecord<B:Backend>{beta:Value<B>,epsilon:f64,gamma:Value<B>,mean:Value<B>,momentum:f64,variance:Value<B>}
862#[derive(Deserialize,Serialize)]
863#[serde(bound="")]
864struct CrossEntropyRecord<B:Backend>{logits:bool,pad:Option<Vec<usize>>,weights:Option<Value<B>>,smoothing:Option<f32>}
865#[derive(Deserialize,Serialize)]
866#[serde(bound="")]
867struct LayerNormRecord<B:Backend>{beta:Value<B>,gamma:Value<B>}
868#[derive(Deserialize,Serialize)]
869#[serde(bound="")]
870struct LinearRecord<B:Backend>{bias:Option<Value<B>>,weight:Value<B>}
871use Bound::{Excluded,Included,Unbounded};
872use burn::{
873	module::{Ignored,Param,RunningState},
874	nn::{
875		BatchNorm,BatchNormConfig,Dropout,DropoutConfig,Embedding,EmbeddingConfig,Initializer,LayerNorm,LayerNormConfig,Linear,LinearConfig,PaddingConfig2d,Relu,RotaryEncoding,RotaryEncodingConfig,Tanh,conv::{Conv2d,Conv2dConfig},loss::{CrossEntropyLoss,CrossEntropyLossConfig,MseLoss},pool::{MaxPool2d,MaxPool2dConfig}
876	},
877	prelude::*,
878	tensor::activation
879};
880use crate::{
881	ai::{AI,Decompose,IntoSequence,Op},
882	builtin::{
883		Sequential,math::SumLayer,structural::{FlattenLayer,CatLayer,ReshapeLayer,SqueezeLayer,StackLayer,UnsqueezeLayer}
884	},
885	burn::{Reshape,Value},
886	ops::Cat as OpsCat
887};
888use serde::{Deserialize,Deserializer,Serialize,Serializer,de::Error as Derror,ser::Error as Serror};
889use std::{
890	fmt::Display,marker::PhantomData,mem,ops::{Bound,Range,RangeBounds}
891};