block_graph/burn/
layer.rs

1fn derror<D:Display,E:Derror>(msg:D)->E{E::custom(msg)}
2fn deserialize_batch_norm<'a,B:Backend,D:Deserializer<'a>>(deserializer:D)->Result<BatchNorm<B,1>,D::Error>{
3	let record:BatchNormRecord<B>=BatchNormRecord::deserialize(deserializer)?;
4
5	let (beta,epsilon,gamma,mean,momentum,variance)=(record.beta,record.epsilon,record.gamma,record.mean,record.momentum,record.variance);
6	let (beta,gamma)=if let (Ok(b),Ok(g))=(beta.try_into(),gamma.try_into()){(Param::from_tensor(b),Param::from_tensor(g))}else{return Err(derror("batch norm beta and gamma parameters must be rank 1 floats"))};
7	let (mean,variance)=if let (Ok(m),Ok(v))=(mean.try_into(),variance.try_into()){(RunningState::new(m),RunningState::new(v))}else{return Err(derror("batch norm mean and variance states must be rank 1 floats"))};
8
9	Ok(BatchNorm{beta,epsilon,gamma,momentum,running_mean:mean,running_var:variance})
10}
11fn deserialize_conv2d<'a,B:Backend,D:Deserializer<'a>>(deserializer:D)->Result<Conv2d<B>,D::Error>{
12	let record=Conv2dRecord::deserialize(deserializer)?;
13
14	let (dilation,groups,kernelsize,stride)=(record.dilation,record.groups,record.kernelsize,record.stride);
15	let bias=if let Some(b)=record.bias{
16		if let Ok(b)=b.try_into(){Some(Param::from_tensor(b))}else{return Err(derror("linear bias parameter must be a rank 1 float"))}
17	}else{
18		None
19	};
20	let padding=record.padding.clone();
21	let weight=Param::from_tensor(if let Ok(w)=record.weight.try_into(){w}else{return Err(derror("linear weight parameter must be a rank 2 float"))});
22
23	Ok(Conv2d{bias,dilation,groups,kernel_size:kernelsize,padding,stride,weight})
24}
25fn deserialize_cross_entropy<'a,B:Backend,D:Deserializer<'a>>(deserializer:D)->Result<CrossEntropyLoss<B>,D::Error>{
26	let record=CrossEntropyRecord::deserialize(deserializer)?;
27
28	let (logits,pad,smoothing)=(record.logits,record.pad,record.smoothing);
29	let weights=if let Some(s)=record.weights{
30		if let Ok(s)=s.try_into(){Some(s)}else{return Err(derror("cross entropy weights parameter must be a rank 1 float"))}
31	}else{
32		None
33	};
34
35	Ok(CrossEntropyLoss{logits,pad_tokens:pad,smoothing,weights})
36}
37fn deserialize_dropout<'a,D:Deserializer<'a>>(deserializer:D)->Result<Dropout,D::Error>{
38	Ok(Dropout{prob:f64::deserialize(deserializer)?})
39}
40fn deserialize_embedding<'a,B:Backend,D:Deserializer<'a>>(deserializer:D)->Result<Embedding<B>,D::Error>{
41	let weight=deserialize_param(deserializer)?;
42	Ok(Embedding{weight})
43}
44fn deserialize_ignored<'a,D:Deserializer<'a>,T:Deserialize<'a>>(deserializer:D)->Result<Ignored<T>,D::Error>{
45	let data:T=T::deserialize(deserializer)?;
46	Ok(Ignored(data))
47}
48fn deserialize_layer_norm<'a,B:Backend,D:Deserializer<'a>>(deserializer:D)->Result<LayerNorm<B>,D::Error>{
49	let mut layer=LayerNormConfig::new(1).init(&Default::default());
50	let record=LayerNormRecord::deserialize(deserializer)?;
51
52	if let Ok(b)=record.beta.try_into(){layer.beta=Param::from_tensor(b)}else{return Err(derror("beta parameter must be a rank 1 float"))}
53	if let Ok(g)=record.gamma.try_into(){layer.gamma=Param::from_tensor(g)}else{return Err(derror("gamma parameter must be a rank 1 float"))}
54
55	Ok(layer)
56}
57fn deserialize_linear<'a,B:Backend,D:Deserializer<'a>>(deserializer:D)->Result<Linear<B>,D::Error>{
58	let record=LinearRecord::deserialize(deserializer)?;
59
60	let bias=if let Some(b)=record.bias{
61		if let Ok(b)=b.try_into(){Some(Param::from_tensor(b))}else{return Err(derror("linear bias parameter must be a rank 1 float"))}
62	}else{
63		None
64	};
65	let weight=Param::from_tensor(if let Ok(w)=record.weight.try_into(){w}else{return Err(derror("linear weight parameter must be a rank 2 float"))});
66
67	Ok(Linear{bias,weight})
68}
69fn deserialize_max_pool_2d<'a,D:Deserializer<'a>>(deserializer:D)->Result<MaxPool2d,D::Error>{
70	let config=MaxPool2dConfig::deserialize(deserializer)?;
71	Ok(config.init())
72}
73fn deserialize_nothing<'a,D:Deserializer<'a>,T:Default>(_deserializer:D)->Result<T,D::Error>{Ok(T::default())}
74fn deserialize_param<'a,B:Backend,D:Deserializer<'a>,const N:usize>(deserializer:D)->Result<Param<Tensor<B,N>>,D::Error>{
75	let data:Value<B>=Value::deserialize(deserializer)?;
76	if let Ok(t)=data.try_into(){Ok(Param::from_tensor(t))}else{Err(derror(format!("expected parameter to be a rank {N} float")))}
77}
78fn deserialize_rotary<'a,B:Backend,D:Deserializer<'a>>(deserializer:D)->Result<RotaryEncoding<B>,D::Error>{Ok(RotaryEncodingConfig::deserialize(deserializer)?.init(&Default::default()))}
79fn serialize_batch_norm<B:Backend,S:Serializer>(layer:&BatchNorm<B,1>,serializer:S)->Result<S::Ok,S::Error>{
80	let (beta,gamma)=(Value::from(layer.beta.val()),Value::from(layer.gamma.val()));
81	let (epsilon,momentum)=(layer.epsilon,layer.momentum);
82	let (mean,variance)=(Value::from(layer.running_mean.value()),Value::from(layer.running_var.value()));
83
84	BatchNormRecord{beta,epsilon,gamma,mean,momentum,variance}.serialize(serializer)
85}
86fn serialize_conv2d<B:Backend,S:Serializer>(layer:&Conv2d<B>,serializer:S)->Result<S::Ok,S::Error>{
87	let (dilation,groups,kernelsize,stride)=(layer.dilation,layer.groups,layer.kernel_size,layer.stride);
88	let bias=layer.bias.as_ref().map(|b|b.val().into());
89	let padding=layer.padding.clone();
90	let weight=layer.weight.val().into();
91
92	Conv2dRecord{bias,dilation,groups,kernelsize,padding,stride,weight}.serialize(serializer)
93}
94fn serialize_cross_entropy<'a,B:Backend,S:Serializer>(layer:&CrossEntropyLoss<B>,serializer:S)->Result<S::Ok,S::Error>{
95	let (logits,pad,smoothing)=(layer.logits.clone(),layer.pad_tokens.clone(),layer.smoothing.clone());
96	let weights=layer.weights.clone().map(Into::into);
97
98	CrossEntropyRecord{logits,pad,smoothing,weights}.serialize(serializer)
99}
100fn serialize_dropout<S:Serializer>(data:&Dropout,serializer:S)->Result<S::Ok,S::Error>{data.prob.serialize(serializer)}
101fn serialize_embedding<B:Backend,S:Serializer>(layer:&Embedding<B>,serializer:S)->Result<S::Ok,S::Error>{serialize_param(&layer.weight,serializer)}
102fn serror<D:Display,E:Serror>(msg:D)->E{E::custom(msg)}
103fn serialize_ignored<S:Serializer,T:Serialize>(data:&Ignored<T>,serializer:S)->Result<S::Ok,S::Error>{
104	let data:&T=data;
105	data.serialize(serializer)
106}
107fn serialize_layer_norm<B:Backend,S:Serializer>(layer:&LayerNorm<B>,serializer:S)->Result<S::Ok,S::Error>{
108	LayerNormRecord{beta:layer.beta.val().into(),gamma:layer.gamma.val().into()}.serialize(serializer)
109}
110fn serialize_linear<B:Backend,S:Serializer>(layer:&Linear<B>,serializer:S)->Result<S::Ok,S::Error>{
111	let bias=layer.bias.as_ref().map(|b|b.val().into());
112	let weight=layer.weight.val().into();
113
114	LinearRecord{bias,weight}.serialize(serializer)
115}
116fn serialize_max_pool_2d<S:Serializer>(layer:&MaxPool2d,serializer:S)->Result<S::Ok,S::Error>{
117	MaxPool2dConfig{kernel_size:layer.kernel_size,strides:layer.stride,padding:layer.padding.0.clone(),dilation:layer.dilation}.serialize(serializer)
118}
119fn serialize_nothing<S:Serializer,T:Default>(_data:&T,serializer:S)->Result<S::Ok,S::Error>{().serialize(serializer)}
120fn serialize_param<B:Backend,S:Serializer,const N:usize>(data:&Param<Tensor<B,N>>,serializer:S)->Result<S::Ok,S::Error>{
121	if N>8{return Err(serror("tensor rank greater than 8 is not currently supported"))}
122	let data:Value<B>=data.val().into();
123	data.serialize(serializer)
124}
125fn serialize_rotary<B:Backend,S:Serializer>(data:&RotaryEncoding<B>,serializer:S)->Result<S::Ok,S::Error>{
126	let [distance,head,_2]=data.freq_complex.dims();
127	let theta:f32=data.theta.clone().into_scalar().elem();
128
129	RotaryEncodingConfig::new(distance,head).with_theta(theta).serialize(serializer)
130}
131impl AttentionConfig{
132	pub fn init<B:Backend>(&self,_device:&B::Device)->Attention<B>{
133		let (dropout,heads,mask)=(self.dropout,self.heads,self.mask);
134		let mask=Ignored(mask);
135		let phantom=PhantomData;
136
137		Attention{dropout,heads,mask,phantom}
138	}
139}
140impl BiasConfig{
141	pub fn init<B:Backend>(&self,device:&B::Device)->Bias<B>{
142		let dim=self.dim;
143		let shape=[dim];
144
145		Bias{bias:self.initializer.init_with(shape,None,Some(dim),device)}
146	}
147}
148impl Config{
149	/// creates an attention config
150	pub fn attention(heads:usize,mask:AttentionMask)->Self{Self::Attention(AttentionConfig::new(heads,mask))}
151	/// creates a batch norm config
152	pub fn batch_norm(countfeatures:usize,epsilon:f32,momentum:f32)->Self{Self::BatchNorm(BatchNormConfig::new(countfeatures).with_epsilon(epsilon as f64).with_momentum(momentum as f64))}
153	/// creates a bias config
154	pub fn bias(dim:usize)->Self{Self::Bias(BiasConfig::new(dim))}
155	/// creates a embedding config
156	pub fn embedding(input:usize,output:usize)->Self{Self::Embedding(EmbeddingConfig::new(input,output))}
157	/// initializes the layer
158	pub fn init<B:Backend>(&self,device:&B::Device)->Layer<B>{
159		match self{Config::Attention(c)=>Layer::Attention(c.init(device)),Config::BatchNorm(c)=>Layer::BatchNorm(c.init(device)),Config::Bias(c)=>Layer::Bias(c.init(device)),Config::CacheKV=>Layer::CacheKV(CacheKV::default()),Config::Cat(c)=>Layer::Cat(Ignored(*c)),Config::Conv2d(c)=>Layer::Conv2d(c.init(device)),Config::Dropout(c)=>Layer::Dropout(c.init()),Config::Embedding(c)=>Layer::Embedding(c.init(device)),Config::LayerNorm(c)=>Layer::LayerNorm(c.init(device)),Config::Linear(c)=>Layer::Linear(c.init(device)),Config::KQV(c)=>Layer::KQV(c.init(device)),Config::CrossEntropy(c)=>Layer::CrossEntropy(c.init(device)),Config::MaxPool2d(c)=>Layer::MaxPool2d(c.init()),Config::Mse=>Layer::Mse(MseLoss),Config::Relu=>Layer::Relu(Relu::new()),Config::Rotary(c)=>Layer::Rotary(c.init(device)),Config::ScaleShift(c)=>Layer::ScaleShift(c.init(device)),Config::Stack(d)=>Layer::Stack(Ignored(*d)),Config::Squeeze(c)=>Layer::Squeeze(Ignored(*c)),Config::Sum(c)=>Layer::Sum(Ignored(*c)),Config::Tanh=>Layer::Tanh(Tanh::new()),Config::Unsqueeze(c)=>Layer::Unsqueeze(Ignored(*c))}
160	}
161	/// creates a layer norm config
162	pub fn layer_norm(dim:usize)->Self{Self::LayerNorm(LayerNormConfig::new(dim))}
163	/// creates a linear config
164	pub fn linear(bias:bool,input:usize,output:usize)->Self{Self::Linear(LinearConfig::new(input,output).with_bias(bias))}
165	/// creates a max pool 2d config
166	pub fn max_pool_2d(kernel:[usize;2],strides:[usize;2])->Self{MaxPool2dConfig::new(kernel).with_strides(strides).into()}
167	/// creates a relu config
168	pub fn relu()->Self{Self::Relu}
169	/// creates a rotary config
170	pub fn rotary(distance:usize,head:usize)->Self{Self::Rotary(RotaryEncodingConfig::new(distance,head))}
171	/// creates a scale shift config
172	pub fn scale_shift()->Self{Self::ScaleShift(ScaleShiftConfig::new())}
173	/// creates a tanh config
174	pub fn tanh()->Self{Self::Tanh}
175	/// scales the initializer
176	pub fn w_scale(mut self,r:f32)->Self{
177		match &mut self{Config::Attention(_c)=>(),Config::BatchNorm(_c)=>(),Config::Bias(c)=>w_scale_mut(&mut c.initializer,r),Config::CacheKV=>(),Config::Cat(_c)=>(),Config::Conv2d(c)=>w_scale_mut(&mut c.initializer,r),Config::CrossEntropy(_c)=>(),Config::Dropout(_c)=>(),Config::Embedding(c)=>w_scale_mut(&mut c.initializer,r),Config::KQV(c)=>w_scale_mut(&mut c.initializer,r),Config::LayerNorm(_c)=>(),Config::Linear(c)=>w_scale_mut(&mut c.initializer,r),Config::MaxPool2d(_c)=>(),Config::Mse=>(),Config::Relu=>(),Config::Rotary(_c)=>(),Config::ScaleShift(c)=>c.initializer.as_mut().into_iter().for_each(|i|w_scale_mut(i,r)),Config::Squeeze(_d)=>(),Config::Stack(_d)=>(),Config::Sum(_c)=>(),Config::Tanh=>(),Config::Unsqueeze(_c)=>()}
178		self
179	}
180}
181impl Decompose for Config{
182	fn compose(decomposition:Self::Decomposition)->Self{decomposition}
183	fn decompose(self)->Self::Decomposition{self}
184	fn decompose_cloned(&self)->Self::Decomposition{self.clone()}
185	type Decomposition=Self;
186}
187impl From<AttentionConfig> for Config{
188	fn from(value:AttentionConfig)->Self{Self::Attention(value)}
189}
190impl From<BatchNormConfig> for Config{
191	fn from(value:BatchNormConfig)->Self{Self::BatchNorm(value)}
192}
193impl From<BiasConfig> for Config{
194	fn from(value:BiasConfig)->Self{Self::Bias(value)}
195}
196impl From<CatLayer> for Config{
197	fn from(value:CatLayer)->Self{Config::Cat(value)}
198}
199impl From<CrossEntropyLossConfig> for Config{
200	fn from(value:CrossEntropyLossConfig)->Self{Config::CrossEntropy(value)}
201}
202impl From<DropoutConfig> for Config{
203	fn from(value:DropoutConfig)->Self{Config::Dropout(value)}
204}
205impl From<EmbeddingConfig> for Config{
206	fn from(value:EmbeddingConfig)->Self{Config::Embedding(value)}
207}
208impl From<LayerNormConfig> for Config{
209	fn from(value:LayerNormConfig)->Self{Config::LayerNorm(value)}
210}
211impl From<LinearConfig> for Config{
212	fn from(value:LinearConfig)->Self{Config::Linear(value)}
213}
214impl From<MaxPool2dConfig> for Config{
215	fn from(value:MaxPool2dConfig)->Self{Config::MaxPool2d(value)}
216}
217impl From<MseLoss> for Config{
218	fn from(_value:MseLoss)->Self{Config::Mse}
219}
220impl From<Relu> for Config{
221	fn from(_value:Relu)->Self{Config::Relu}
222}
223impl From<RotaryEncodingConfig> for Config{
224	fn from(value:RotaryEncodingConfig)->Self{Config::Rotary(value)}
225}
226impl From<SqueezeLayer> for Config{
227	fn from(value:SqueezeLayer)->Self{Config::Squeeze(value)}
228}
229impl From<StackLayer> for Config{
230	fn from(value:StackLayer)->Self{Config::Stack(value)}
231}
232impl From<SumLayer> for Config{
233	fn from(value:SumLayer)->Self{Config::Sum(value)}
234}
235impl From<Tanh> for Config{
236	fn from(_value:Tanh)->Self{Config::Tanh}
237}
238impl From<UnsqueezeLayer> for Config{
239	fn from(value:UnsqueezeLayer)->Self{Config::Unsqueeze(value)}
240}
241impl KQVConfig{
242	pub fn init<B:Backend>(&self,device:&B::Device)->KQV<B>{
243		let (embed,initializer,kdim,vdim)=(self.embed.clone(),self.initializer.clone(),self.kdim.clone(),self.vdim.clone());
244		let (key,value)=(LinearConfig::new(embed,kdim).with_initializer(initializer.clone()).init(device),LinearConfig::new(embed,vdim).with_initializer(initializer.clone()).init(device));
245		let query=LinearConfig::new(embed,kdim).with_initializer(initializer).init(device);
246
247		KQV{key,query,value}
248	}
249}
250impl ScaleShiftConfig{
251	pub fn init<B:Backend>(&self,device:&B::Device)->ScaleShift<B>{
252		let initializer=&self.initializer;
253
254		let a=if let Some(i)=initializer{i.init_with([1],None,None,device)}else{Initializer::Constant{value:1.0}.init_with([1],None,None,device)};
255		let b=if let Some(i)=initializer{i.init_with([1],None,None,device)}else{Initializer::Constant{value:0.0}.init_with([1],None,None,device)};
256		ScaleShift{a,b}
257	}
258}
259impl<B:Backend,M:AI<M::Output,M::Output>+Op> IntoSequence<M> for Layer<B> where Layer<B>:Into<M>{
260	fn into_sequence(self)->Sequential<Vec<M>>{vec![self.into()].sequential()}
261}
262impl<B:Backend,const N:usize> From<BatchNorm<B,N>> for Layer<B>{
263	fn from(value:BatchNorm<B,N>)->Self{
264		Self::BatchNorm(BatchNorm{beta:value.beta,epsilon:value.epsilon,gamma:value.gamma,momentum:value.momentum,running_mean:value.running_mean,running_var:value.running_var})
265	}
266}
267impl<B:Backend> AI<(Value<B>,Value<B>),(Value<B>,Value<B>)> for CacheKV<B>{
268	fn forward(&self,(k,v):(Value<B>,Value<B>))->(Value<B>,Value<B>){
269		let (keys,values)=(self.keys.clone(),self.values.clone());
270		(if keys.is_empty(){k}else{Value::from(vec![keys,k]).cat(1)},if values.is_empty(){v}else{Value::from(vec![values,v]).cat(1)})
271	}
272	fn forward_mut(&mut self,(k,v):(Value<B>,Value<B>))->(Value<B>,Value<B>){
273		let (keys,values)=(mem::take(&mut self.keys),mem::take(&mut self.values));
274
275		let (keys,values)=(if keys.is_empty(){k}else{Value::from(vec![keys,k]).cat(1)},if values.is_empty(){v}else{Value::from(vec![values,v]).cat(1)});
276		(self.keys,self.values)=if keys.is_incompatible()||values.is_incompatible(){Default::default()}else{(keys.clone(),values.clone())};
277
278		(keys,values)
279	}
280}
281impl<B:Backend> AI<(Value<B>,Value<B>,Value<B>),Value<B>> for Attention<B>{
282	fn forward(&self,(k,q,v):(Value<B>,Value<B>,Value<B>))->Value<B>{// TODO support for other numbers of dimensions
283		fn apply_mask<B:Backend,const D:usize>(a:Tensor<B,D>,mask:AttentionMask,value:f32)->Tensor<B,D>{
284			match mask{AttentionMask::Causal=>mask_causal(a,value as f64),AttentionMask::None=>a,AttentionMask::Window(n)=>mask_window(a,n,value as f64)}
285		}
286		fn f_3d<B:Backend>(dropout:f32,heads:usize,mask:AttentionMask,k:Tensor<B,3>,q:Tensor<B,3>,v:Tensor<B,3>)->Result<Tensor<B,3>,String>{
287			let (kdims,qdims,vdims)=(k.dims(),q.dims(),v.dims());
288
289			if kdims!=qdims{return Err("mismatched dims".into())}
290			if kdims!=vdims{return Err("mismatched dims".into())}
291			let [batch,sequence,embed]=kdims;
292			let dropout=Dropout{prob:dropout as f64};
293			let head=if embed%heads==0{embed/heads}else{return Err("embed must be a multiple of heads".into())};
294
295			let (k,q,v)=(k.reshape([batch,sequence,heads,head]).swap_dims(1,2),q.reshape([batch,sequence,heads,head]).swap_dims(1,2),v.reshape([batch,sequence,heads,head]).swap_dims(1,2));
296			let a=activation::softmax(apply_mask(q.matmul(k.transpose())/(head as f32).sqrt(),mask,-9999.0),3);
297			let a=dropout.forward(a);
298			let s=a.matmul(v).swap_dims(1,2).reshape([0,0,-1]);
299
300			Ok(s)
301		}
302		fn mask_causal<B:Backend,const D:usize>(a:Tensor<B,D>,value:f64)->Tensor<B,D>{
303			if D<2{return mask_causal::<B,2>(a.unsqueeze(),value).squeeze(0)}									// shouldn't actually happen but if the dimension is less than 2 we can just treat it like it has a second dimension of size 1
304
305			let (device,dims)=(a.device(),a.dims());
306			let (key,query)=(dims[D-1],dims[D-2]);
307			let extrakeys=key.saturating_sub(query);															// due to caching, there might be more keys than queries
308
309			let causal:Tensor<B,2,Bool>=Tensor::tril_mask([query,key],extrakeys as i64,&device);
310			let a=a.mask_fill(causal.unsqueeze(),value);
311			a
312		}
313		/// fills the attention tensor with the value where the query position is less than the key position minus length, or greater than the key position. Assumes attention dimensions are [.., query, key]
314		fn mask_window<B:Backend,const D:usize>(a:Tensor<B,D>,length:usize,value:f64)->Tensor<B,D>{
315			if D<2{return mask_window::<B,2>(a.unsqueeze(),length,value).squeeze(0)}							// shouldn't actually happen but if the dimension is less than 2 we can just treat it like it has a second dimension of size 1
316
317			let (device,dims)=(a.device(),a.dims());
318			let (key,query)=(dims[D-1],dims[D-2]);
319			let extrakeys=key.saturating_sub(query);															// due to caching, there might be more keys than queries
320
321			let causal:Tensor<B,2,Bool>=Tensor::tril_mask([query,key],extrakeys as i64,&device);
322			let window:Tensor<B,2,Bool>=Tensor::triu_mask([query,key],extrakeys as i64-length as i64,&device);
323			let a=a.mask_fill(causal.unsqueeze(),value).mask_fill(window.unsqueeze(),value);
324			a
325		}
326		let (dropout,heads,mask)=(self.dropout,self.heads,self.mask.0);
327
328		match match (k.float(),q.float(),v.float()){
329			(Value::F3(k),Value::F3(q),Value::F3(v))=>f_3d(dropout,heads,mask,k,q,v).map(Into::into),
330			(Value::Multi(k),Value::Multi(q),Value::Multi(v))=>if k.len()==q.len()&&q.len()==v.len(){Ok(k.into_iter().zip(q).zip(v).map(|((k,q),v)|self.forward((k,q,v))).collect())}else{Err("incompatible lengths".into())}
331			_=>Err("attention is currently only supported for 3d float inputs [batch, seq, embed]".into())
332		}{
333			Err(e)=>e.into(),
334			Ok(x)=>x
335		}
336	}
337}
338impl<B:Backend> AI<Value<B>,(Value<B>,Value<B>,Value<B>)> for KQV<B>{
339	fn forward(&self,input:Value<B>)->(Value<B>,Value<B>,Value<B>){
340		let (k,q)=(input.clone(),input.clone());
341		let v=input;
342
343		(AI::forward(&self.key,k),AI::forward(&self.query,q),AI::forward(&self.value,v))
344	}
345}
346impl<B:Backend> AI<Value<B>,Value<B>> for Attention<B>{
347	fn forward(&self,input:Value<B>)->Value<B>{
348		match input{
349			Value::Incompatible(e)=>e.into(),
350			Value::Multi(v) if v.len()>=3=>if v.len()==3{
351				let [k,q,v]=v.try_into().unwrap();
352				self.forward((k,q,v))
353			}else{
354				v.into_iter().map(|x|self.forward(x)).collect()
355			},
356			_=>"attention inputs must be in triples".into()
357		}
358	}
359}
360impl<B:Backend> AI<Value<B>,Value<B>> for Bias<B>{
361	fn forward(&self,input:Value<B>)->Value<B>{input+Value::from(self.bias.val())}
362}
363impl<B:Backend> AI<Value<B>,Value<B>> for CacheKV<B>{
364	fn forward(&self,input:Value<B>)->Value<B>{
365		match input{
366			Value::Incompatible(e)=>e.into(),
367			Value::Multi(v) if v.len()>=2=>match v.len(){
368				2=>{
369					let [k,v]=v.try_into().unwrap();
370
371					let (k,v)=self.forward((k,v));
372					vec![k,v].into()
373				},
374				3=>{
375					let [k,q,v]=v.try_into().unwrap();
376
377					let (k,v)=self.forward((k,v));
378					vec![k,q,v].into()
379				},
380				_=>{
381					v.into_iter().map(|x|self.forward(x)).collect()
382				}
383			},
384			_=>"cache kv inputs must be in pairs or triples".into()
385		}
386	}
387	fn forward_mut(&mut self,input:Value<B>)->Value<B>{
388		match input{
389			Value::Incompatible(e)=>e.into(),
390			Value::Multi(v) if v.len()>=2=>match v.len(){
391				2=>{
392					let [k,v]=v.try_into().unwrap();
393
394					let (k,v)=self.forward_mut((k,v));
395					vec![k,v].into()
396				},
397				3=>{
398					let [k,q,v]=v.try_into().unwrap();
399
400					let (k,v)=self.forward_mut((k,v));
401					vec![k,q,v].into()
402				},
403				_=>{
404					v.into_iter().map(|x|self.forward_mut(x)).collect()
405				}
406			},
407			_=>"cache kv inputs must be in pairs or triples".into()
408		}
409	}
410}
411impl<B:Backend> AI<Value<B>,Value<B>> for KQV<B>{
412	fn forward(&self,input:Value<B>)->Value<B>{
413		let (k,q,v)=self.forward(input);
414		vec![k,q,v].into()
415	}
416}
417impl<B:Backend> AI<Value<B>,Value<B>> for Layer<B>{
418	fn forward(&self,input:Value<B>)->Value<B>{
419		match self{
420			Layer::Attention(f)=>f.forward(input),
421			Layer::BatchNorm(f)=>AI::forward(f,input),
422			Layer::Bias(f)=>f.forward(input),
423			Layer::CacheKV(f)=>f.forward(input),
424			Layer::Cat(f)=>f.forward(input),
425			Layer::Conv2d(f)=>AI::forward(f,input),
426			Layer::CrossEntropy(f)=>AI::forward(f,input),
427			Layer::Dropout(f)=>AI::forward(f,input),
428			Layer::Embedding(f)=>AI::forward(f,input),
429			Layer::KQV(f)=>f.forward(input),
430			Layer::LayerNorm(f)=>AI::forward(f,input),
431			Layer::Linear(f)=>AI::forward(f,input),
432			Layer::MaxPool2d(f)=>AI::forward(f,input),
433			Layer::Mse(f)=>AI::forward(f,input),
434			Layer::Relu(f)=>AI::forward(f,input),
435			Layer::Rotary(f)=>AI::forward(f,input),
436			Layer::ScaleShift(f)=>f.forward(input),
437			Layer::Squeeze(f)=>f.forward(input),
438			Layer::Stack(f)=>f.forward(input),
439			Layer::Sum(f)=>f.forward(input),
440			Layer::Tanh(f)=>AI::forward(f,input),
441			Layer::Unsqueeze(f)=>f.forward(input),
442		}
443	}
444	fn forward_mut(&mut self,input:Value<B>)->Value<B>{
445		match self{
446			Layer::Attention(f)=>f.forward_mut(input),
447			Layer::BatchNorm(f)=>AI::forward_mut(f,input),
448			Layer::Bias(f)=>f.forward_mut(input),
449			Layer::CacheKV(f)=>f.forward_mut(input),
450			Layer::Cat(f)=>f.0.forward_mut(input),
451			Layer::Conv2d(f)=>f.forward_mut(input),
452			Layer::CrossEntropy(f)=>AI::forward_mut(f,input),
453			Layer::Dropout(f)=>AI::forward_mut(f,input),
454			Layer::Embedding(f)=>AI::forward_mut(f,input),
455			Layer::KQV(f)=>f.forward_mut(input),
456			Layer::LayerNorm(f)=>AI::forward_mut(f,input),
457			Layer::Linear(f)=>AI::forward_mut(f,input),
458			Layer::MaxPool2d(f)=>AI::forward_mut(f,input),
459			Layer::Mse(f)=>AI::forward_mut(f,input),
460			Layer::Relu(f)=>AI::forward_mut(f,input),
461			Layer::Rotary(f)=>AI::forward_mut(f,input),
462			Layer::ScaleShift(f)=>f.forward_mut(input),
463			Layer::Squeeze(f)=>f.0.forward_mut(input),
464			Layer::Stack(f)=>f.0.forward_mut(input),
465			Layer::Sum(f)=>f.0.forward_mut(input),
466			Layer::Tanh(f)=>AI::forward_mut(f,input),
467			Layer::Unsqueeze(f)=>f.0.forward_mut(input),
468		}
469	}
470}
471impl<B:Backend> AI<Value<B>,Value<B>> for ScaleShift<B>{
472	fn forward(&self,input:Value<B>)->Value<B>{
473		let (a,b)=(Value::from(self.a.val()),Value::from(self.b.val()));
474		input*a+b
475	}
476}
477impl<B:Backend> Decompose for Layer<B>{
478	fn compose(decomposition:Self::Decomposition)->Self{decomposition}
479	fn decompose(self)->Self::Decomposition{self}
480	fn decompose_cloned(&self)->Self::Decomposition{self.clone()}
481	type Decomposition=Self;
482}
483impl<B:Backend> From<CatLayer> for Layer<B>{
484	fn from(value:CatLayer)->Self{Layer::Cat(Ignored(value))}
485}
486impl<B:Backend> From<CrossEntropyLoss<B>> for Layer<B>{
487	fn from(value:CrossEntropyLoss<B>)->Self{Layer::CrossEntropy(value)}
488}
489impl<B:Backend> From<Dropout> for Layer<B>{
490	fn from(value:Dropout)->Self{Layer::Dropout(value)}
491}
492impl<B:Backend> From<Embedding<B>> for Layer<B>{
493	fn from(value:Embedding<B>)->Self{Layer::Embedding(value)}
494}
495impl<B:Backend> From<LayerNorm<B>> for Layer<B>{
496	fn from(value:LayerNorm<B>)->Self{Layer::LayerNorm(value)}
497}
498impl<B:Backend> From<Linear<B>> for Layer<B>{
499	fn from(value:Linear<B>)->Self{Layer::Linear(value)}
500}
501impl<B:Backend> From<MaxPool2d> for Layer<B>{
502	fn from(value:MaxPool2d)->Self{Layer::MaxPool2d(value)}
503}
504impl<B:Backend> From<MseLoss> for Layer<B>{
505	fn from(value:MseLoss)->Self{Layer::Mse(value)}
506}
507impl<B:Backend> From<Relu> for Layer<B>{
508	fn from(value:Relu)->Self{Layer::Relu(value)}
509}
510impl<B:Backend> From<RotaryEncoding<B>> for Layer<B>{
511	fn from(value:RotaryEncoding<B>)->Self{Layer::Rotary(value)}
512}
513impl<B:Backend> From<SqueezeLayer> for Layer<B>{
514	fn from(value:SqueezeLayer)->Self{Layer::Squeeze(Ignored(value))}
515}
516impl<B:Backend> From<StackLayer> for Layer<B>{
517	fn from(value:StackLayer)->Self{Layer::Stack(Ignored(value))}
518}
519impl<B:Backend> From<SumLayer> for Layer<B>{
520	fn from(value:SumLayer)->Self{Layer::Sum(Ignored(value))}
521}
522impl<B:Backend> From<Tanh> for Layer<B>{
523	fn from(value:Tanh)->Self{Layer::Tanh(value)}
524}
525impl<B:Backend> From<UnsqueezeLayer> for Layer<B>{
526	fn from(value:UnsqueezeLayer)->Self{Layer::Unsqueeze(Ignored(value))}
527}
528impl<B:Backend> Layer<B>{
529	/// creates a batch norm layer
530	pub fn batch_norm(countfeatures:usize,epsilon:f32,momentum:f32)->Self{Config::batch_norm(countfeatures,epsilon,momentum).init(&Default::default())}
531	/// creates a embedding layer
532	pub fn embedding(input:usize,output:usize,wscale:f32)->Self{
533		let mut l=EmbeddingConfig::new(input,output);
534		if wscale!=1.0{l.initializer=w_scale(l.initializer,wscale)}
535		let l=l.init(&Default::default());
536		Self::Embedding(l)
537	}
538	/// creates a layer norm layer
539	pub fn layer_norm(dim:usize)->Self{Self::LayerNorm(LayerNormConfig::new(dim).init(&Default::default()))}
540	/// creates a linear layer
541	pub fn linear(bias:bool,input:usize,output:usize,wscale:f32)->Self{
542		let mut l=LinearConfig::new(input,output).with_bias(bias);
543		if wscale!=1.0{l.initializer=w_scale(l.initializer,wscale)}
544		let l=l.init(&Default::default());
545		Self::Linear(l)
546	}
547	/// creates a max pool 2d layer
548	pub fn max_pool_2d(kernel:[usize;2],strides:[usize;2])->Self{MaxPool2dConfig::new(kernel).with_strides(strides).init().into()}
549	/// creates a relu layer
550	pub fn relu()->Self{Self::Relu(Relu)}
551	/// creates a rotary layer
552	pub fn rotary(distance:usize,head:usize)->Self{Self::Rotary(RotaryEncodingConfig::new(distance,head).init(&Default::default()))}
553	/// creates a scale shift layer
554	pub fn scale_shift()->Self{Self::ScaleShift(ScaleShiftConfig::new().init(&Default::default()))}
555	/// creates a tanh layer
556	pub fn tanh()->Self{Self::Tanh(Tanh)}
557}
558impl<B:Backend> Op for Layer<B>{
559	type Output=Value<B>;
560}
561#[derive(Clone,Copy,Debug,Deserialize,Serialize)]
562pub enum AttentionMask{Causal,None,Window(usize)}
563#[derive(Config)]
564/// enumerates config for some burn layers
565pub enum Config{Attention(AttentionConfig),BatchNorm(BatchNormConfig),Bias(BiasConfig),CacheKV,Cat(CatLayer),Conv2d(Conv2dConfig),CrossEntropy(CrossEntropyLossConfig),Dropout(DropoutConfig),Embedding(EmbeddingConfig),KQV(KQVConfig),LayerNorm(LayerNormConfig),Linear(LinearConfig),MaxPool2d(MaxPool2dConfig),Mse,Relu,Rotary(RotaryEncodingConfig),ScaleShift(ScaleShiftConfig),Squeeze(SqueezeLayer),Stack(StackLayer),Sum(SumLayer),Tanh,Unsqueeze(UnsqueezeLayer)}
566#[derive(Debug,Deserialize,Module,Serialize)]//TODO more layers
567#[serde(bound="")]
568/// enumerates some burn layers
569pub enum Layer<B:Backend>{
570	Attention(Attention<B>),
571	Bias(Bias<B>),
572	#[serde(deserialize_with="deserialize_batch_norm")]
573	#[serde(serialize_with="serialize_batch_norm")]
574	BatchNorm(BatchNorm<B,1>),
575	CacheKV(CacheKV<B>),
576	#[serde(deserialize_with="deserialize_ignored")]
577	#[serde(serialize_with="serialize_ignored")]
578	Cat(Ignored<CatLayer>),
579	#[serde(deserialize_with="deserialize_conv2d")]
580	#[serde(serialize_with="serialize_conv2d")]
581	Conv2d(Conv2d<B>),
582	#[serde(deserialize_with="deserialize_cross_entropy")]
583	#[serde(serialize_with="serialize_cross_entropy")]
584	CrossEntropy(CrossEntropyLoss<B>),
585	#[serde(deserialize_with="deserialize_dropout")]
586	#[serde(serialize_with="serialize_dropout")]
587	Dropout(Dropout),
588	#[serde(deserialize_with="deserialize_embedding")]
589	#[serde(serialize_with="serialize_embedding")]
590	Embedding(Embedding<B>),
591	KQV(KQV<B>),
592	#[serde(deserialize_with="deserialize_layer_norm")]
593	#[serde(serialize_with="serialize_layer_norm")]
594	LayerNorm(LayerNorm<B>),
595	#[serde(deserialize_with="deserialize_linear")]
596	#[serde(serialize_with="serialize_linear")]
597	Linear(Linear<B>),
598	#[serde(deserialize_with="deserialize_max_pool_2d")]
599	#[serde(serialize_with="serialize_max_pool_2d")]
600	MaxPool2d(MaxPool2d),
601	#[serde(deserialize_with="deserialize_nothing")]
602	#[serde(serialize_with="serialize_nothing")]
603	Mse(MseLoss),
604	#[serde(deserialize_with="deserialize_nothing")]
605	#[serde(serialize_with="serialize_nothing")]
606	Relu(Relu),
607	#[serde(deserialize_with="deserialize_rotary")]
608	#[serde(serialize_with="serialize_rotary")]
609	Rotary(RotaryEncoding<B>),
610	ScaleShift(ScaleShift<B>),
611	#[serde(deserialize_with="deserialize_ignored")]
612	#[serde(serialize_with="serialize_ignored")]
613	Squeeze(Ignored<SqueezeLayer>),
614	#[serde(deserialize_with="deserialize_ignored")]
615	#[serde(serialize_with="serialize_ignored")]
616	Stack(Ignored<StackLayer>),
617	#[serde(deserialize_with="deserialize_ignored")]
618	#[serde(serialize_with="serialize_ignored")]
619	Sum(Ignored<SumLayer>),
620	#[serde(deserialize_with="deserialize_nothing")]
621	#[serde(serialize_with="serialize_nothing")]
622	Tanh(Tanh),
623	#[serde(deserialize_with="deserialize_ignored")]
624	#[serde(serialize_with="serialize_ignored")]
625	Unsqueeze(Ignored<UnsqueezeLayer>),
626}
627/// scales the initializer
628pub fn w_scale(initializer:Initializer,r:f32)->Initializer{
629	let r=r as f64;// apparently
630	match initializer{
631		Initializer::Constant{value}=>Initializer::Constant{value:value*r},
632		Initializer::KaimingNormal{gain,fan_out_only}=>Initializer::KaimingNormal{gain:gain*r,fan_out_only},
633		Initializer::KaimingUniform{gain,fan_out_only}=>Initializer::KaimingUniform{gain:gain*r,fan_out_only},
634		Initializer::Normal{mean,std}=>Initializer::Normal{mean:mean*r,std:std*r},
635		Initializer::Ones=>Initializer::Constant{value:r},
636		Initializer::Orthogonal{gain}=>Initializer::Orthogonal{gain:gain*r},
637		Initializer::Uniform{min,max}=>Initializer::Uniform{min:min*r,max:max*r},
638		Initializer::XavierNormal{gain}=>Initializer::XavierNormal{gain:gain*r},
639		Initializer::XavierUniform{gain}=>Initializer::XavierUniform{gain:gain*r},
640		Initializer::Zeros=>Initializer::Zeros
641	}
642}
643/// scales the initializer
644pub fn w_scale_mut(initializer:&mut Initializer,r:f32){*initializer=w_scale(initializer.clone(),r)}
#[derive(Config,Debug)]
/// layer for computing attention from [key,query,value] inputs
///
/// configuration counterpart of [`Attention`]; `dropout` has a default, while `heads` and `mask` are required
pub struct AttentionConfig{
	/// dropout rate; defaults to 0.2
	#[config(default="0.2")]
	dropout:f32,
	/// presumably the number of attention heads — TODO confirm against the forward implementation
	heads:usize,
	/// attention mask setting; its meaning is defined by `AttentionMask` elsewhere in the crate
	mask:AttentionMask
}
#[derive(Debug,Deserialize,Module,Serialize)]
// empty bound: keeps serde from inferring `B: Serialize`/`B: Deserialize` bounds on the backend parameter
#[serde(bound="")]
/// layer for computing attention from [key,query,value] inputs
///
/// holds no tensors of its own: only hyperparameters, the mask setting, and a backend marker
pub struct Attention<B:Backend>{
	/// dropout rate
	dropout:f32,
	/// presumably the number of attention heads — TODO confirm
	heads:usize,
	// wrapped in `Ignored` and routed through the crate's ignored-value serde helpers instead of a direct serde impl
	#[serde(deserialize_with="deserialize_ignored")]
	#[serde(serialize_with="serialize_ignored")]
	mask:Ignored<AttentionMask>,
	/// only use of `B`; ties this otherwise tensor-free layer to its backend
	phantom:PhantomData<B>
}
#[derive(Config,Debug)]
/// layer for adding bias somewhere
///
/// configuration counterpart of [`Bias`]
pub struct BiasConfig{
	/// presumably the length of the rank 1 bias tensor — TODO confirm against the init code
	dim:usize,
	/// initializer for the bias values; defaults to a normal distribution with mean 0 and std 1
	#[config(default="Initializer::Normal{mean:0.0,std:1.0}")]
	initializer:Initializer
}
#[derive(Config,Debug)]
/// layer for linear splitting into [key,query,value] for attention purposes
///
/// configuration counterpart of [`KQV`]; only `initializer` has a default
pub struct KQVConfig{
	/// presumably the input embedding width shared by the three projections — TODO confirm
	embed:usize,
	/// initializer, presumably for the projection weights; defaults to xavier normal with gain 1
	#[config(default="Initializer::XavierNormal{gain:1.0}")]
	initializer:Initializer,
	/// presumably the output width of the key (and query) projection — TODO confirm
	kdim:usize,
	/// presumably the output width of the value projection — TODO confirm
	vdim:usize
}
#[derive(Debug,Deserialize,Module,Serialize)]
// empty bound: keeps serde from inferring `B: Serialize`/`B: Deserialize` bounds on the backend parameter
#[serde(bound="")]
/// layer for adding bias anywhere
pub struct Bias<B:Backend>{
	/// trainable rank 1 bias tensor; (de)serialized through the crate's param helpers rather than a direct serde impl
	#[serde(deserialize_with="deserialize_param")]
	#[serde(serialize_with="serialize_param")]
	bias:Param<Tensor<B,1>>
}
#[derive(Debug,Default,Deserialize,Module,Serialize)]
// empty bound: keeps serde from inferring `B: Serialize`/`B: Deserialize` bounds on the backend parameter
#[serde(bound="")]
/// layer for caching kv values from kqv when run mutably. cats along d1 and outputs the concatenated keys and values. clears cache on forward_mut when new data is incompatible for concatenation
///
/// `keys` and `values` hold the cached data; because the struct derives `Default`, a new cache starts from `Value::default()` for both (presumably an empty value — confirm)
pub struct CacheKV<B:Backend>{keys:Value<B>,values:Value<B>}
#[derive(Debug,Deserialize,Module,Serialize)]
// empty bound: keeps serde from inferring `B: Serialize`/`B: Deserialize` bounds on the backend parameter
#[serde(bound="")]
/// layer for linear splitting into [key,query,value] for attention purposes
///
/// holds one linear projection per output; each goes through the crate's linear (de)serialization helpers
pub struct KQV<B:Backend>{
	/// projection producing the keys
	#[serde(deserialize_with="deserialize_linear")]
	#[serde(serialize_with="serialize_linear")]
	key:Linear<B>,
	/// projection producing the queries
	#[serde(deserialize_with="deserialize_linear")]
	#[serde(serialize_with="serialize_linear")]
	query:Linear<B>,
	/// projection producing the values
	#[serde(deserialize_with="deserialize_linear")]
	#[serde(serialize_with="serialize_linear")]
	value:Linear<B>
}
#[derive(Debug,Deserialize,Module,Serialize)]
// empty bound: keeps serde from inferring `B: Serialize`/`B: Deserialize` bounds on the backend parameter
#[serde(bound="")]
/// layer that applies a componentwise scalar affine transformation: f(x)=ax+b where a and b are tunable scalars
pub struct ScaleShift<B:Backend>{
	/// trainable scale `a`, stored as a rank 1 tensor (presumably of a single element, given the "scalar" description — confirm)
	#[serde(deserialize_with="deserialize_param")]
	#[serde(serialize_with="serialize_param")]
	a:Param<Tensor<B,1>>,
	/// trainable shift `b`, stored as a rank 1 tensor (presumably of a single element — confirm)
	#[serde(deserialize_with="deserialize_param")]
	#[serde(serialize_with="serialize_param")]
	b:Param<Tensor<B,1>>
}
#[derive(Config,Debug)]
/// scale shift config
pub struct ScaleShiftConfig{
	/// optional initializer for the `a`/`b` parameters; `None` by default — the fallback used when unset is chosen by the init code, not visible here
	#[config(default="None")]
	initializer:Option<Initializer>
}
#[derive(Deserialize,Serialize)]
// empty bound: keeps serde from inferring `B: Serialize`/`B: Deserialize` bounds on the backend parameter
#[serde(bound="")]
/// serializable stand-in for burn's `Conv2d`; `deserialize_conv2d` reads this record and rebuilds the layer from it
struct Conv2dRecord<B:Backend>{
	/// optional bias; converted back into a tensor parameter on load, which fails if the conversion does
	bias:Option<Value<B>>,
	dilation:[usize;2],
	groups:usize,
	/// becomes `Conv2d::kernel_size` on load
	kernelsize:[usize;2],
	// `Ignored` wrapper routed through the crate's ignored-value serde helpers
	#[serde(deserialize_with="deserialize_ignored")]
	#[serde(serialize_with="serialize_ignored")]
	padding:Ignored<PaddingConfig2d>,
	stride:[usize;2],
	/// convolution weight; converted back into a tensor parameter on load, which fails if the conversion does
	weight:Value<B>
}
#[derive(Deserialize,Serialize)]
// empty bound: keeps serde from inferring `B: Serialize`/`B: Deserialize` bounds on the backend parameter
#[serde(bound="")]
/// serializable stand-in for burn's rank 1 `BatchNorm`; `deserialize_batch_norm` reads it back.
/// on load, `beta`/`gamma` must convert to rank 1 float parameters and `mean`/`variance` to rank 1 float running states (`running_mean`/`running_var`)
struct BatchNormRecord<B:Backend>{beta:Value<B>,epsilon:f64,gamma:Value<B>,mean:Value<B>,momentum:f64,variance:Value<B>}
#[derive(Deserialize,Serialize)]
// empty bound: keeps serde from inferring `B: Serialize`/`B: Deserialize` bounds on the backend parameter
#[serde(bound="")]
/// serializable stand-in for burn's `CrossEntropyLoss`; `deserialize_cross_entropy` reads it back.
/// on load, the optional `weights` value must convert to a rank 1 float tensor
struct CrossEntropyRecord<B:Backend>{logits:bool,pad:Option<Vec<usize>>,weights:Option<Value<B>>,smoothing:Option<f32>}
#[derive(Deserialize,Serialize)]
// empty bound: keeps serde from inferring `B: Serialize`/`B: Deserialize` bounds on the backend parameter
#[serde(bound="")]
/// serializable stand-in for the learnable `beta`/`gamma` parameters of burn's `LayerNorm` (presumably consumed by a layer norm deserializer elsewhere in this file — confirm)
struct LayerNormRecord<B:Backend>{beta:Value<B>,gamma:Value<B>}
#[derive(Deserialize,Serialize)]
// empty bound: keeps serde from inferring `B: Serialize`/`B: Deserialize` bounds on the backend parameter
#[serde(bound="")]
/// serializable stand-in for burn's `Linear` parameters: an optional bias and a weight (presumably consumed by `deserialize_linear` — confirm)
struct LinearRecord<B:Backend>{bias:Option<Value<B>>,weight:Value<B>}
748use burn::{
749	module::{Ignored,Param,RunningState},
750	nn::{
751		BatchNorm,BatchNormConfig,Dropout,DropoutConfig,Embedding,EmbeddingConfig,Initializer,LayerNorm,LayerNormConfig,Linear,LinearConfig,PaddingConfig2d,Relu,RotaryEncoding,RotaryEncodingConfig,Tanh,conv::{Conv2d,Conv2dConfig},loss::{CrossEntropyLoss,CrossEntropyLossConfig,MseLoss},pool::{MaxPool2d,MaxPool2dConfig}
752	},
753	prelude::*,
754	tensor::activation
755};
756use crate::{
757	ai::{AI,Decompose,IntoSequence,Op},
758	builtin::{
759		Sequential,math::SumLayer,structural::{CatLayer,SqueezeLayer,StackLayer,UnsqueezeLayer}
760	},
761	burn::Value,
762	ops::Cat as OpsCat
763};
764use serde::{Deserialize,Deserializer,Serialize,Serializer,de::Error as Derror,ser::Error as Serror};
765use std::{fmt::Display,marker::PhantomData,mem};