1use std::{collections::HashMap, marker::PhantomData, sync::Arc};
2
3#[cfg(not(target_arch = "wasm32"))]
4use futures::future::BoxFuture;
5#[cfg(target_arch = "wasm32")]
6use futures::future::LocalBoxFuture;
7use half::f16;
8use itertools::Itertools;
9use serde::{Deserialize, Serialize};
10use web_rwkv_derive::DeserializeSeed;
11use wgpu::CommandBuffer;
12
13use super::{
14 infer::{RnnChunk, RnnInfo, RnnInput, RnnOutput, RnnOutputBatch, RnnRedirect, Token},
15 loader::{Loader, LoaderError, Reader},
16 model::{AsAny, ModelBuilder, ModelCustomInfo, ModelInfo, State as _},
17 Dispatcher, Job, RuntimeError,
18};
19use crate::{
20 context::Context,
21 num::Float,
22 runtime::model::Quant,
23 tensor::{
24 cache::ResourceCache,
25 kind::ReadWrite,
26 matrix::Matrix,
27 ops::{Activation, TensorCommand, TensorOp},
28 serialization::Seed,
29 shape::{Shape, TensorDimension},
30 DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorGpuView, TensorInit,
31 TensorReshape, TensorShape, TensorStack,
32 },
33};
34
35#[derive(Debug, Clone, Serialize, DeserializeSeed)]
36#[serde_seed(seed = "Seed", context = "Context")]
37pub struct Model {
38 pub context: Context,
39 pub info: ModelInfo,
40 pub rescale: usize,
41 pub sep: usize,
42 pub tensor: ModelTensor,
43}
44
45impl Model {
46 pub const L2_EPS: f32 = 1.0e-12;
47 pub const LN_EPS: f32 = 1.0e-5;
48 pub const GN_EPS: f32 = 64.0e-5;
49
50 pub const DEFAULT_RESCALE: usize = 1024;
51 pub const DEFAULT_SEP: usize = 1024;
52}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
55pub struct CustomInfo {
56 pub w: usize,
57 pub a: usize,
58 pub g: usize,
59 pub v: usize,
60}
61
62#[derive(Debug, Clone, Serialize, DeserializeSeed)]
63#[serde_seed(seed = "Seed", context = "Context")]
64pub struct ModelTensor {
65 pub embed: Embed,
66 pub head: Head,
67 pub layers: Vec<Layer>,
68}
69
70#[derive(Debug, Clone, Serialize, DeserializeSeed)]
71#[serde_seed(seed = "Seed", context = "Context")]
72pub struct LayerNorm {
73 pub w: TensorGpu<f16, ReadWrite>,
74 pub b: TensorGpu<f16, ReadWrite>,
75}
76
77#[derive(Debug, Clone, Serialize, DeserializeSeed)]
78#[serde_seed(seed = "Seed", context = "Context")]
79pub struct Att {
80 pub x_r: TensorGpu<f16, ReadWrite>,
81 pub x_w: TensorGpu<f16, ReadWrite>,
82 pub x_k: TensorGpu<f16, ReadWrite>,
83 pub x_v: TensorGpu<f16, ReadWrite>,
84 pub x_a: TensorGpu<f16, ReadWrite>,
85 pub x_g: TensorGpu<f16, ReadWrite>,
86
87 pub w0: TensorGpu<f16, ReadWrite>,
88 pub a0: TensorGpu<f16, ReadWrite>,
89 pub v0: TensorGpu<f16, ReadWrite>,
90
91 pub w1: Matrix,
92 pub w2: Matrix,
93 pub a1: Matrix,
94 pub a2: Matrix,
95 pub g1: Matrix,
96 pub g2: Matrix,
97 pub v1: Matrix,
98 pub v2: Matrix,
99
100 pub r_k: TensorGpu<f16, ReadWrite>,
101 pub k_k: TensorGpu<f16, ReadWrite>,
102 pub k_a: TensorGpu<f16, ReadWrite>,
103
104 pub w_k: Matrix,
105 pub w_v: Matrix,
106 pub w_r: Matrix,
107 pub w_o: Matrix,
108
109 pub gn: LayerNorm,
110}
111
112#[derive(Debug, Clone, Serialize, DeserializeSeed)]
113#[serde_seed(seed = "Seed", context = "Context")]
114pub struct Ffn {
115 pub x_k: TensorGpu<f16, ReadWrite>,
116
117 pub w_k: Matrix,
118 pub w_v: Matrix,
119}
120
121#[derive(Debug, Clone, Serialize, DeserializeSeed)]
122#[serde_seed(seed = "Seed", context = "Context")]
123pub struct Layer {
124 pub att_ln: LayerNorm,
125 pub ffn_ln: LayerNorm,
126 pub att: Att,
127 pub ffn: Ffn,
128}
129
130#[derive(Debug, Clone, Serialize, DeserializeSeed)]
131#[serde_seed(seed = "Seed", context = "Context")]
132pub struct Embed {
133 pub ln: LayerNorm,
134 pub w: TensorCpu<f16>,
135}
136
137#[derive(Debug, Clone, Serialize, DeserializeSeed)]
138#[serde_seed(seed = "Seed", context = "Context")]
139pub struct Head {
140 pub ln: LayerNorm,
141 pub w: Matrix,
142}
143
144#[derive(Debug, Clone, Serialize, DeserializeSeed)]
145#[serde_seed(seed = "Seed", context = "Context")]
146pub struct State {
147 pub context: Context,
148 pub info: ModelInfo,
149 pub data: Vec<TensorGpu<f32, ReadWrite>>,
150}
151
152impl State {
153 async fn back(&self, batch: usize) -> Result<TensorCpu<f32>, TensorError> {
154 let context = &self.context;
155 let mut tensors = Vec::with_capacity(self.info.num_layer);
156 let mut encoder = context.device.create_command_encoder(&Default::default());
157 for data in self.data.iter() {
158 let shape = data.shape();
159 let destination = context.tensor_init([shape[0], shape[1], 1, 1]);
160 encoder.copy_tensor_batch(data, &destination, batch, 0)?;
161 tensors.push(destination);
162 }
163 context.queue.submit(Some(encoder.finish()));
164
165 let mut backed = Vec::with_capacity(tensors.len());
166 for tensor in tensors {
167 backed.push(tensor.back().await);
168 }
169 TensorCpu::stack(backed, 2)
170 }
171}
172
173impl AsAny for State {
174 fn as_any(&self) -> &dyn std::any::Any {
175 self
176 }
177}
178
179impl super::model::State for State {
180 #[inline]
181 fn num_batch(&self) -> usize {
182 self.data[0].shape()[2]
183 }
184
185 #[inline]
186 fn init_shape(&self) -> Shape {
187 let info = &self.info;
188 let head_size = info.num_emb / info.num_head;
189 [info.num_emb, head_size + 2, info.num_layer, 1].into()
190 }
191
192 fn init(&self) -> TensorCpu<f32> {
193 let shape = self.init_shape();
194 let data = vec![0.0; shape.len()];
195 TensorCpu::from_data(shape, data).unwrap()
196 }
197
198 fn att(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError> {
199 let head_size = self.info.num_emb / self.info.num_head;
200 let end = head_size + 1;
201 self.data[layer].view(.., 0..end, .., ..)
202 }
203
204 fn ffn(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError> {
205 let head_size = self.info.num_emb / self.info.num_head;
206 let start = head_size + 1;
207 self.data[layer].view(.., start, .., ..)
208 }
209
210 fn load(&self, tensor: TensorCpu<f32>, batch: usize) -> Result<(), TensorError> {
211 let head_size = self.info.num_emb / self.info.num_head;
212 tensor.check_shape([self.info.num_emb, head_size + 2, self.info.num_layer, 1])?;
213 for (data, source) in self.data.iter().zip(tensor.split(2)?.into_iter()) {
214 data.load_batch(&source, batch)?;
215 }
216 Ok(())
217 }
218
219 #[cfg(not(target_arch = "wasm32"))]
220 fn back(&self, batch: usize) -> BoxFuture<'_, Result<TensorCpu<f32>, TensorError>> {
221 Box::pin(self.back(batch))
222 }
223
224 #[cfg(target_arch = "wasm32")]
225 fn back(&self, batch: usize) -> LocalBoxFuture<'_, Result<TensorCpu<f32>, TensorError>> {
226 Box::pin(self.back(batch))
227 }
228
229 fn write(&self, tensor: TensorGpu<f32, ReadWrite>, batch: usize) -> Result<(), TensorError> {
230 let head_size = self.info.num_emb / self.info.num_head;
231 tensor.check_shape([self.info.num_emb, head_size + 2, self.info.num_layer, 1])?;
232
233 let context = &self.context;
234 let mut ops = Vec::with_capacity(self.data.len());
235 for (layer, data) in self.data.iter().enumerate() {
236 ops.push(TensorOp::blit(
237 tensor.view(.., .., layer, ..)?,
238 data.view(.., .., batch, ..)?,
239 )?);
240 }
241 context.queue.submit(context.encode(&TensorOp::List(ops)));
242
243 Ok(())
244 }
245
246 fn read(&self, batch: usize) -> Result<TensorGpu<f32, ReadWrite>, TensorError> {
247 let context = &self.context;
248 let head_size = self.info.num_emb / self.info.num_head;
249 let shape = [self.info.num_emb, head_size + 2, self.info.num_layer, 1];
250 let tensor: TensorGpu<_, _> = context.tensor_init(shape);
251
252 let mut ops = Vec::with_capacity(self.data.len());
253 for (layer, data) in self.data.iter().enumerate() {
254 ops.push(TensorOp::blit(
255 data.view(.., .., batch, ..)?,
256 tensor.view(.., .., layer, ..)?,
257 )?);
258 }
259 context.queue.submit(context.encode(&TensorOp::List(ops)));
260
261 Ok(tensor)
262 }
263
264 fn embed(&self, layer: usize, backed: TensorCpu<f32>) -> Result<TensorCpu<f32>, TensorError> {
265 backed.slice(.., 0, layer, ..)
266 }
267}
268
269impl DeepClone for State {
270 fn deep_clone(&self) -> Self {
271 let data = self.data.iter().map(|tensor| tensor.deep_clone()).collect();
272 Self {
273 data,
274 ..self.clone()
275 }
276 }
277}
278
279#[derive(Debug, Clone, Serialize, DeserializeSeed)]
280#[serde_seed(seed = "Seed", context = "Context")]
281pub struct Runtime<F: Float> {
282 pub cursors: TensorGpu<u32, ReadWrite>,
283 pub input: TensorGpu<f16, ReadWrite>,
284
285 pub x: TensorGpu<F, ReadWrite>,
286
287 pub att_x: TensorGpu<F, ReadWrite>,
288 pub att_v0: TensorGpu<F, ReadWrite>,
289
290 pub att_rx: TensorGpu<F, ReadWrite>,
291 pub att_wx: TensorGpu<F, ReadWrite>,
292 pub att_kx: TensorGpu<F, ReadWrite>,
293 pub att_vx: TensorGpu<F, ReadWrite>,
294 pub att_ax: TensorGpu<F, ReadWrite>,
295 pub att_gx: TensorGpu<F, ReadWrite>,
296
297 pub att_r: TensorGpu<F, ReadWrite>,
298 pub att_w: TensorGpu<F, ReadWrite>,
299 pub att_k: TensorGpu<F, ReadWrite>,
300 pub att_v: TensorGpu<F, ReadWrite>,
301 pub att_a: TensorGpu<F, ReadWrite>,
302 pub att_g: TensorGpu<F, ReadWrite>,
303 pub att_o: TensorGpu<F, ReadWrite>,
304
305 pub att_kk: TensorGpu<F, ReadWrite>,
306 pub att_vv: TensorGpu<F, ReadWrite>,
307
308 pub att_n: TensorGpu<F, ReadWrite>,
309
310 pub aux_w: TensorGpu<F, ReadWrite>,
312 pub aux_a: TensorGpu<F, ReadWrite>,
313 pub aux_g: TensorGpu<F, ReadWrite>,
314 pub aux_v: TensorGpu<F, ReadWrite>,
315
316 pub ffn_x: TensorGpu<F, ReadWrite>,
317 pub ffn_kx: TensorGpu<F, ReadWrite>,
318 pub ffn_k: TensorGpu<F, ReadWrite>,
319 pub ffn_v: TensorGpu<F, ReadWrite>,
320}
321
322impl<F: Float> Runtime<F> {
323 pub fn new(context: &Context, info: &ModelInfo, num_token: usize) -> Self {
324 let ModelCustomInfo::V7(custom) = info.custom else {
325 unreachable!()
326 };
327
328 let shape = Shape::new(info.num_emb, num_token, 1, 1);
329 let cursors_shape = Shape::new(num_token, 1, 1, 1);
330 let hidden_shape = Shape::new(info.num_hidden, num_token, 1, 1);
331
332 Self {
333 cursors: context.tensor_init(cursors_shape),
334 input: context.tensor_init(shape),
335 x: context.tensor_init(shape),
336 att_x: context.tensor_init(shape),
337 att_v0: context.tensor_init(shape),
338 att_rx: context.tensor_init(shape),
339 att_wx: context.tensor_init(shape),
340 att_kx: context.tensor_init(shape),
341 att_vx: context.tensor_init(shape),
342 att_ax: context.tensor_init(shape),
343 att_gx: context.tensor_init(shape),
344 att_r: context.tensor_init(shape),
345 att_w: context.tensor_init(shape),
346 att_k: context.tensor_init(shape),
347 att_v: context.tensor_init(shape),
348 att_a: context.tensor_init(shape),
349 att_g: context.tensor_init(shape),
350 att_o: context.tensor_init(shape),
351 att_kk: context.tensor_init(shape),
352 att_vv: context.tensor_init(shape),
353 att_n: context.tensor_init([shape[0], shape[1], 4, 1]),
354 aux_w: context.tensor_init([custom.w, shape[1], 1, 1]),
355 aux_a: context.tensor_init([custom.a, shape[1], 1, 1]),
356 aux_g: context.tensor_init([custom.g, shape[1], 1, 1]),
357 aux_v: context.tensor_init([custom.v, shape[1], 1, 1]),
358 ffn_x: context.tensor_init(shape),
359 ffn_kx: context.tensor_init(shape),
360 ffn_k: context.tensor_init(hidden_shape),
361 ffn_v: context.tensor_init(shape),
362 }
363 }
364}
365
366#[derive(Debug, Clone, Serialize, DeserializeSeed)]
367#[serde_seed(seed = "Seed", context = "Context")]
368pub struct Header<F: Float> {
369 pub head_x: TensorGpu<F, ReadWrite>,
370 pub head_o: TensorGpu<f32, ReadWrite>,
371}
372
373impl<F: Float> Header<F> {
374 pub fn new(context: &Context, info: &ModelInfo, num_header: usize) -> Self {
375 let head_shape = Shape::new(info.num_emb, num_header, 1, 1);
376 let output_shape = Shape::new(info.num_vocab_padded(), num_header, 1, 1);
377
378 Self {
379 head_x: context.tensor_init(head_shape),
380 head_o: context.tensor_init(output_shape),
381 }
382 }
383}
384
385#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
386pub enum Hook {
387 PostEmbedLoaded,
388 PostEmbedLayerNorm,
389 PreAtt(usize),
390 PostAttLayerNorm(usize),
391 PreAttTokenShift(usize),
392 PostAttTokenShift(usize),
393 PreAttLinear(usize),
394 PostAttLinear(usize),
395 PreAttAdapt(usize),
396 PostAttAdapt(usize),
397 PreAttControl(usize),
398 PostAttControl(usize),
399 PreAttValueResidual(usize),
400 PostAttValueResidual(usize),
401 PreAttTimeMix(usize),
402 PostAttTimeMix(usize),
403 PreAttGate(usize),
404 PostAttGate(usize),
405 PreAttOut(usize),
406 PostAttOut(usize),
407 PostAtt(usize),
408 PreFfn(usize),
409 PostFfnLayerNorm(usize),
410 PreFfnTokenShift(usize),
411 PostFfnTokenShift(usize),
412 PreFfnLinear(usize),
413 PostFfnLinear(usize),
414 PostFfnActivate(usize),
415 PreFfnChannelMix(usize),
416 PostFfnChannelMix(usize),
417 PostFfn(usize),
418 PreHead,
419 PostHeadLayerNorm,
420 PostHead,
421}
422
423pub struct RnnJob {
424 commands: Vec<CommandBuffer>,
425 redirect: RnnRedirect,
426
427 embed: TensorCpu<f16>,
428
429 cursors: TensorGpu<u32, ReadWrite>,
430 input: TensorGpu<f16, ReadWrite>,
431 output: TensorGpu<f32, ReadWrite>,
432}
433
434impl Job for RnnJob {
435 type Input = RnnInput;
436 type Output = RnnOutput;
437
438 fn load(&self, input: &RnnChunk) -> Result<(), RuntimeError> {
439 if input.num_token() == 0 {
440 return Ok(());
441 }
442
443 let stack: Vec<TensorCpu<f16>> = input
444 .iter()
445 .map(|chunk| {
446 let num_emb = self.embed.shape()[0];
447 let data = self.embed.data();
448 let data = chunk
449 .iter()
450 .map(|token| match token {
451 &Token::Token(token) => {
452 let start = num_emb * token as usize;
453 let end = start + num_emb;
454 let data = data[start..end].to_vec();
455 TensorCpu::from_data_1d(data)
456 }
457 Token::Embed(tensor) => tensor.clone(),
458 })
459 .collect_vec();
460 match TensorCpu::stack(data, 1) {
461 Ok(tensor) => tensor,
462 Err(_) => TensorCpu::init([num_emb, 0, 1, 1]),
463 }
464 })
465 .collect();
466 let stack = TensorStack::try_from(stack)?;
467
468 let cursors = stack.cursors.clone().into_cursors();
469 let cursors = TensorCpu::from_data(self.cursors.shape(), cursors)?;
470 self.cursors.load(&cursors)?;
471 self.input.load(&stack.tensor)?;
472
473 Ok(())
474 }
475
476 fn submit(&mut self) {
477 let commands = std::mem::take(&mut self.commands);
478 self.output.context.queue.submit(commands);
479 }
480
481 async fn back(self) -> Result<Self::Output, RuntimeError> {
482 let output = self.output.back().await;
483 let batches: Vec<_> = self
484 .redirect
485 .outputs
486 .into_iter()
487 .map(|(start, end)| output.slice(.., start..end, .., ..))
488 .try_collect()?;
489 let batches = batches.into_iter().map(RnnOutputBatch).collect();
490 Ok(RnnOutput(batches))
491 }
492}
493
494#[derive(Debug, Clone)]
495pub struct Frame<F: Float> {
496 pub state: State,
497 pub buffer: Arc<Runtime<F>>,
498 pub header: Arc<Header<F>>,
499}
500
501pub type HookFn<F> = Box<dyn Fn(Frame<F>) -> Result<TensorOp, TensorError> + Send + Sync>;
502pub type HookMap<F> = HashMap<Hook, HookFn<F>>;
503
504#[derive(Clone)]
505pub struct Bundle<F: Float> {
506 model: Model,
507 state: State,
508 hooks: Arc<HookMap<F>>,
509 buffers: ResourceCache<usize, Runtime<F>>,
510 headers: ResourceCache<usize, Header<F>>,
511 phantom: PhantomData<F>,
512}
513
514impl<F: Float> Bundle<F> {
515 pub fn new(model: Model, num_batch: usize) -> Self {
516 let context = model.context.clone();
517 let info = model.info.clone();
518 let state = {
519 let head_size = info.num_emb / info.num_head;
520 let shape = Shape::new(info.num_emb, head_size + 2, num_batch, 1);
521 let data = (0..info.num_layer).map(|_| context.zeros(shape)).collect();
522 State {
523 context,
524 info,
525 data,
526 }
527 };
528 Self {
529 model,
530 state,
531 hooks: Default::default(),
532 buffers: ResourceCache::new(4),
533 headers: ResourceCache::new(4),
534 phantom: PhantomData,
535 }
536 }
537
538 pub fn new_with_hooks(model: Model, num_batch: usize, hooks: HookMap<F>) -> Self {
539 Self {
540 hooks: Arc::new(hooks),
541 ..Self::new(model, num_batch)
542 }
543 }
544
545 fn checkout_buffer(
546 &self,
547 context: &Context,
548 info: &ModelInfo,
549 num_token: usize,
550 ) -> Arc<Runtime<F>> {
551 self.buffers
552 .checkout(num_token, || Runtime::new(context, info, num_token))
553 }
554
555 fn checkout_header(
556 &self,
557 context: &Context,
558 info: &ModelInfo,
559 num_header: usize,
560 ) -> Arc<Header<F>> {
561 self.headers
562 .checkout(num_header, || Header::new(context, info, num_header))
563 }
564}
565
566impl<F: Float> super::model::Bundle for Bundle<F> {
567 #[inline]
568 fn info(&self) -> ModelInfo {
569 self.model.info.clone()
570 }
571
572 #[inline]
573 fn state(&self) -> impl super::model::State + AsAny + 'static {
574 self.state.clone()
575 }
576
577 #[inline]
578 fn model(&self) -> impl Serialize + 'static {
579 self.model.clone()
580 }
581}
582
583fn turbo(num_token: usize) -> bool {
584 num_token.is_multiple_of(super::infer::rnn::MIN_TOKEN_CHUNK_SIZE)
585}
586
587fn hook_op<F: Float>(
588 hooks: &HookMap<F>,
589 hook: &Hook,
590 frame: &Frame<F>,
591) -> Result<TensorOp, TensorError> {
592 match hooks.get(hook) {
593 Some(f) => f(frame.clone()),
594 None => Ok(TensorOp::empty()),
595 }
596}
597
598impl<F: Float> Dispatcher<RnnJob> for Bundle<F> {
599 type Info = RnnInfo;
600
601 fn dispatch(&self, seed: Self::Info) -> Result<RnnJob, RuntimeError> {
602 let model = &self.model;
603 let state = &self.state;
604 let context = &model.context;
605 let info = &model.info;
606 let tensor = &model.tensor;
607
608 let num_token = seed.num_token();
609 let head_size = info.num_emb / info.num_head;
610
611 let redirect = seed.redirect();
612 let num_header = redirect.headers.len();
613
614 let buffer = self.checkout_buffer(context, info, num_token);
615 let header = self.checkout_header(context, info, num_header);
616 let frame = Frame {
617 state: state.clone(),
618 buffer: buffer.clone(),
619 header: header.clone(),
620 };
621
622 context.maintain();
623 self.buffers.maintain();
624 self.headers.maintain();
625
626 if num_token == 0 {
627 return Ok(RnnJob {
628 commands: vec![],
629 redirect,
630 embed: model.tensor.embed.w.clone(),
631 cursors: buffer.cursors.clone(),
632 input: buffer.input.clone(),
633 output: header.head_o.clone(),
634 });
635 }
636
637 #[cfg(feature = "trace")]
638 let _span = tracing::trace_span!("build").entered();
639
640 let (head_op, head_x) = redirect.op(&buffer.x, &header.head_x)?;
641
642 let hook_op = |hook: Hook| hook_op(&self.hooks, &hook, &frame);
643 let mut ops = vec![];
644
645 {
646 #[cfg(feature = "trace")]
647 let _span = tracing::trace_span!("embed").entered();
648
649 ops.extend([
650 hook_op(Hook::PostEmbedLoaded)?,
651 TensorOp::layer_norm(
652 &tensor.embed.ln.w,
653 &tensor.embed.ln.b,
654 &buffer.input,
655 Model::LN_EPS,
656 )?,
657 TensorOp::blit(&buffer.input, &buffer.x)?,
658 hook_op(Hook::PostEmbedLayerNorm)?,
659 ]);
660 };
661
662 for (index, layer) in tensor.layers.iter().enumerate() {
663 #[cfg(feature = "trace")]
664 let _span = tracing::trace_span!("layer", index).entered();
665
666 let hooks = self.hooks.clone();
667 let frame = frame.clone();
668 let layer = layer.clone();
669
670 let op = dispatch_layer(
671 hooks,
672 frame,
673 layer,
674 index,
675 num_token,
676 head_size,
677 model.rescale,
678 )?;
679 ops.push(op);
680
681 if (index + 1) % model.sep == 0 {
682 ops.push(TensorOp::Sep);
683 }
684 }
685
686 {
687 #[cfg(feature = "trace")]
688 let _span = tracing::trace_span!("header").entered();
689
690 let hooks = self.hooks.clone();
691 let frame = frame.clone();
692 let head = model.tensor.head.clone();
693
694 let op = dispatch_header(hooks, frame, head, head_x, num_header, head_op)?;
695 ops.push(op);
696 }
697
698 let commands = {
699 #[cfg(feature = "trace")]
700 let _span = tracing::trace_span!("encode").entered();
701 context.encode(&TensorOp::List(ops))
702 };
703
704 Ok(RnnJob {
705 commands,
706 redirect,
707 embed: model.tensor.embed.w.clone(),
708 cursors: buffer.cursors.clone(),
709 input: buffer.input.clone(),
710 output: header.head_o.clone(),
711 })
712 }
713}
714
715#[allow(clippy::too_many_arguments)]
716fn dispatch_layer<F: Float>(
717 hooks: Arc<HookMap<F>>,
718 frame: Frame<F>,
719 layer: Layer,
720 index: usize,
721 num_token: usize,
722 head_size: usize,
723 rescale: usize,
724) -> Result<TensorOp, TensorError> {
725 let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
726 let Frame { state, buffer, .. } = &frame;
727
728 let att_kk = buffer.att_kk.reshape(
729 TensorDimension::Size(head_size),
730 TensorDimension::Auto,
731 TensorDimension::Size(num_token),
732 TensorDimension::Size(1),
733 )?;
734 let att_x = buffer.att_x.reshape(
735 TensorDimension::Size(head_size),
736 TensorDimension::Auto,
737 TensorDimension::Size(num_token),
738 TensorDimension::Size(1),
739 )?;
740 let att_r = buffer.att_r.reshape(
741 TensorDimension::Size(head_size),
742 TensorDimension::Auto,
743 TensorDimension::Size(num_token),
744 TensorDimension::Size(1),
745 )?;
746 let att_w = buffer.att_w.reshape(
747 TensorDimension::Size(head_size),
748 TensorDimension::Auto,
749 TensorDimension::Size(num_token),
750 TensorDimension::Size(1),
751 )?;
752 let att_n = buffer.att_n.reshape(
753 TensorDimension::Size(head_size),
754 TensorDimension::Auto,
755 TensorDimension::Size(num_token),
756 TensorDimension::Size(4),
757 )?;
758
759 let mut ops = vec![];
760
761 ops.extend([
762 TensorOp::blit(&buffer.x, &buffer.att_x)?,
763 hook_op(Hook::PreAtt(index))?,
764 TensorOp::layer_norm(
765 &layer.att_ln.w,
766 &layer.att_ln.b,
767 &buffer.att_x,
768 Model::LN_EPS,
769 )?,
770 hook_op(Hook::PostAttLayerNorm(index))?,
771 hook_op(Hook::PreAttTokenShift(index))?,
772 TensorOp::token_shift(
773 &buffer.cursors,
774 &layer.att.x_r,
775 state.att(index)?,
776 &buffer.att_x,
777 &buffer.att_rx,
778 true,
779 )?,
780 TensorOp::token_shift(
781 &buffer.cursors,
782 &layer.att.x_w,
783 state.att(index)?,
784 &buffer.att_x,
785 &buffer.att_wx,
786 true,
787 )?,
788 TensorOp::token_shift(
789 &buffer.cursors,
790 &layer.att.x_k,
791 state.att(index)?,
792 &buffer.att_x,
793 &buffer.att_kx,
794 true,
795 )?,
796 TensorOp::token_shift(
797 &buffer.cursors,
798 &layer.att.x_v,
799 state.att(index)?,
800 &buffer.att_x,
801 &buffer.att_vx,
802 true,
803 )?,
804 TensorOp::token_shift(
805 &buffer.cursors,
806 &layer.att.x_a,
807 state.att(index)?,
808 &buffer.att_x,
809 &buffer.att_ax,
810 true,
811 )?,
812 TensorOp::token_shift(
813 &buffer.cursors,
814 &layer.att.x_g,
815 state.att(index)?,
816 &buffer.att_x,
817 &buffer.att_gx,
818 true,
819 )?,
820 hook_op(Hook::PostAttTokenShift(index))?,
821 hook_op(Hook::PreAttLinear(index))?,
822 layer.att.w_r.matmul_op(
823 &buffer.att_rx,
824 &buffer.att_r,
825 Activation::None,
826 turbo(num_token),
827 )?,
828 layer.att.w_k.matmul_op(
829 &buffer.att_kx,
830 &buffer.att_k,
831 Activation::None,
832 turbo(num_token),
833 )?,
834 layer.att.w_v.matmul_op(
835 &buffer.att_vx,
836 &buffer.att_v,
837 Activation::None,
838 turbo(num_token),
839 )?,
840 hook_op(Hook::PostAttLinear(index))?,
841 hook_op(Hook::PreAttAdapt(index))?,
842 layer.att.w1.matmul_op(
843 &buffer.att_wx,
844 &buffer.aux_w,
845 Activation::Tanh,
846 turbo(num_token),
847 )?,
848 layer.att.w2.matmul_op(
849 &buffer.aux_w,
850 &buffer.att_w,
851 Activation::None,
852 turbo(num_token),
853 )?,
854 TensorOp::add(&layer.att.w0, &buffer.att_w)?,
855 layer.att.a1.matmul_op(
856 &buffer.att_ax,
857 &buffer.aux_a,
858 Activation::None,
859 turbo(num_token),
860 )?,
861 layer.att.a2.matmul_op(
862 &buffer.aux_a,
863 &buffer.att_a,
864 Activation::None,
865 turbo(num_token),
866 )?,
867 TensorOp::add_activate(
868 &layer.att.a0,
869 &buffer.att_a,
870 Activation::None,
871 Activation::None,
872 Activation::Sigmoid,
873 )?,
874 layer.att.g1.matmul_op(
875 &buffer.att_gx,
876 &buffer.aux_g,
877 Activation::Sigmoid,
878 turbo(num_token),
879 )?,
880 layer.att.g2.matmul_op(
881 &buffer.aux_g,
882 &buffer.att_g,
883 Activation::None,
884 turbo(num_token),
885 )?,
886 hook_op(Hook::PostAttAdapt(index))?,
887 hook_op(Hook::PreAttControl(index))?,
888 TensorOp::blit(&buffer.att_k, &buffer.att_kk)?,
889 TensorOp::mul(&layer.att.k_k, &buffer.att_kk)?,
890 TensorOp::l2_norm(&att_kk, Model::L2_EPS)?,
891 TensorOp::control_k_v7(&layer.att.k_a, &buffer.att_a, &buffer.att_k)?,
892 hook_op(Hook::PostAttControl(index))?,
893 ]);
894
895 ops.push(hook_op(Hook::PreAttValueResidual(index))?);
896 match index {
897 0 => ops.push(TensorOp::blit(&buffer.att_v, &buffer.att_v0)?),
898 _ => ops.extend([
899 layer.att.v1.matmul_op(
900 &buffer.att_vx,
901 &buffer.aux_v,
902 Activation::None,
903 turbo(num_token),
904 )?,
905 layer.att.v2.matmul_op(
906 &buffer.aux_v,
907 &buffer.att_vv,
908 Activation::None,
909 turbo(num_token),
910 )?,
911 TensorOp::add_activate(
912 &layer.att.v0,
913 &buffer.att_vv,
914 Activation::None,
915 Activation::None,
916 Activation::Sigmoid,
917 )?,
918 TensorOp::lerp(&buffer.att_v0, &buffer.att_v, &buffer.att_vv, true)?,
919 ]),
920 };
921 ops.push(hook_op(Hook::PostAttValueResidual(index))?);
922
923 ops.extend([
924 hook_op(Hook::PreAttTimeMix(index))?,
925 TensorOp::blit(&buffer.att_k, buffer.att_n.view(.., .., 0, ..)?)?,
926 TensorOp::blit(&buffer.att_v, buffer.att_n.view(.., .., 1, ..)?)?,
927 TensorOp::blit(&buffer.att_a, buffer.att_n.view(.., .., 2, ..)?)?,
928 TensorOp::blit(&buffer.att_kk, buffer.att_n.view(.., .., 3, ..)?)?,
929 TensorOp::time_mix_v7(
930 &buffer.cursors,
931 state.att(index)?,
932 &att_r,
933 &att_w,
934 &att_n,
935 &att_x,
936 )?,
937 TensorOp::group_norm(&layer.att.gn.w, &layer.att.gn.b, &att_x, Model::GN_EPS)?,
938 TensorOp::time_first_v7(&layer.att.r_k, &att_r, &att_n, &att_x)?,
939 hook_op(Hook::PostAttTimeMix(index))?,
940 hook_op(Hook::PreAttGate(index))?,
941 TensorOp::mul(&buffer.att_g, &buffer.att_x)?,
942 hook_op(Hook::PostAttGate(index))?,
943 hook_op(Hook::PreAttOut(index))?,
944 layer.att.w_o.matmul_op(
945 &buffer.att_x,
946 &buffer.att_o,
947 Activation::None,
948 turbo(num_token),
949 )?,
950 hook_op(Hook::PostAttOut(index))?,
951 TensorOp::add(&buffer.att_o, &buffer.x)?,
952 hook_op(Hook::PostAtt(index))?,
953 ]);
954
955 ops.extend([
956 TensorOp::blit(&buffer.x, &buffer.ffn_x)?,
957 hook_op(Hook::PreFfn(index))?,
958 TensorOp::layer_norm(
959 &layer.ffn_ln.w,
960 &layer.ffn_ln.b,
961 &buffer.ffn_x,
962 Model::LN_EPS,
963 )?,
964 hook_op(Hook::PostFfnLayerNorm(index))?,
965 hook_op(Hook::PreFfnTokenShift(index))?,
966 TensorOp::token_shift(
967 &buffer.cursors,
968 &layer.ffn.x_k,
969 state.ffn(index)?,
970 &buffer.ffn_x,
971 &buffer.ffn_kx,
972 true,
973 )?,
974 hook_op(Hook::PostFfnTokenShift(index))?,
975 hook_op(Hook::PreFfnLinear(index))?,
976 layer.ffn.w_k.matmul_op(
977 &buffer.ffn_kx,
978 &buffer.ffn_k,
979 Activation::SquaredRelu,
980 turbo(num_token),
981 )?,
982 hook_op(Hook::PostFfnActivate(index))?,
983 layer.ffn.w_v.matmul_op_sparse(
984 &buffer.ffn_k,
985 &buffer.ffn_v,
986 Activation::None,
987 turbo(num_token),
988 )?,
989 hook_op(Hook::PostFfnLinear(index))?,
990 hook_op(Hook::PreFfnChannelMix(index))?,
991 TensorOp::channel_mix_v7(
992 &buffer.cursors,
993 state.ffn(index)?,
994 &buffer.ffn_v,
995 &buffer.ffn_x,
996 )?,
997 hook_op(Hook::PostFfnChannelMix(index))?,
998 TensorOp::add(&buffer.ffn_x, &buffer.x)?,
999 hook_op(Hook::PostFfn(index))?,
1000 ]);
1001
1002 if (index + 1).is_multiple_of(rescale) {
1003 ops.push(TensorOp::affine(&buffer.x, 0.5, 0.0)?);
1004 }
1005
1006 Ok(TensorOp::List(ops))
1007}
1008
1009fn dispatch_header<F: Float>(
1010 hooks: Arc<HookMap<F>>,
1011 frame: Frame<F>,
1012 head: Head,
1013 head_x: TensorGpu<F, ReadWrite>,
1014 num_header: usize,
1015 head_op: TensorOp,
1016) -> Result<TensorOp, TensorError> {
1017 let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
1018 let header = &frame.header;
1019 let mut ops = vec![head_op];
1020
1021 if num_header > 0 {
1022 ops.extend([
1023 hook_op(Hook::PreHead)?,
1024 TensorOp::layer_norm(&head.ln.w, &head.ln.b, &head_x, Model::LN_EPS)?,
1025 hook_op(Hook::PostHeadLayerNorm)?,
1026 head.w.matmul_op(
1027 head_x.view(.., .., .., ..)?,
1028 header.head_o.view(.., .., .., ..)?,
1029 Activation::None,
1030 turbo(num_header),
1031 )?,
1032 hook_op(Hook::PostHead)?,
1033 ]);
1034 }
1035 Ok(TensorOp::List(ops))
1036}
1037
1038impl<R: Reader> ModelBuilder<R> {
1039 pub async fn build_v7(self) -> Result<Model, LoaderError> {
1040 let ModelBuilder {
1041 context,
1042 model,
1043 rescale,
1044 sep,
1045 lora,
1046 quant,
1047 ..
1048 } = self;
1049
1050 let rescale = rescale.unwrap_or(Model::DEFAULT_RESCALE);
1051 let sep = sep.unwrap_or(Model::DEFAULT_SEP);
1052
1053 let info = Loader::info(&model)?;
1054 let loader = Loader {
1055 context: context.clone(),
1056 model,
1057 lora,
1058 };
1059
1060 let embed = Embed {
1061 ln: LayerNorm {
1062 w: loader.load_vector_f16("blocks.0.ln0.weight")?,
1063 b: loader.load_vector_f16("blocks.0.ln0.bias")?,
1064 },
1065 w: loader.load_matrix_f16_padded_cpu("emb.weight")?,
1066 };
1067
1068 let head = Head {
1069 ln: LayerNorm {
1070 w: loader.load_vector_f16("ln_out.weight")?,
1071 b: loader.load_vector_f16("ln_out.bias")?,
1072 },
1073 w: Matrix::Fp16(loader.load_matrix_f16_padded("head.weight")?),
1074 };
1075
1076 let submission_index = Some(context.queue.submit(None));
1077 _ = context.device.poll(wgpu::PollType::Wait {
1078 submission_index,
1079 timeout: None,
1080 });
1081
1082 let load_matrix = |name: String, quant: Quant| loader.load_matrix(name, quant);
1083 let load_matrix_discount = |name: String, quant: Quant, discount: f32| {
1084 loader.load_matrix_discount(name, quant, discount)
1085 };
1086
1087 let mut layers = vec![];
1088 for layer in 0..info.num_layer {
1089 let quant = quant.get(&layer).copied().unwrap_or_default();
1090 let discount = 2.0_f32.powi(-((layer / rescale) as i32));
1091
1092 let att_ln = LayerNorm {
1093 w: loader.load_vector_f16(format!("blocks.{layer}.ln1.weight"))?,
1094 b: loader.load_vector_f16(format!("blocks.{layer}.ln1.bias"))?,
1095 };
1096
1097 let att = format!("blocks.{layer}.att");
1098 let x_r = loader.load_vector_f16(format!("{att}.x_r"))?;
1099 let x_w = loader.load_vector_f16(format!("{att}.x_w"))?;
1100 let x_k = loader.load_vector_f16(format!("{att}.x_k"))?;
1101 let x_v = loader.load_vector_f16(format!("{att}.x_v"))?;
1102 let x_a = loader.load_vector_f16(format!("{att}.x_a"))?;
1103 let x_g = loader.load_vector_f16(format!("{att}.x_g"))?;
1104
1105 let w0 = loader.load_vector_f16(format!("{att}.w0"))?;
1106 let a0 = loader.load_vector_f16(format!("{att}.a0"))?;
1107
1108 let w1 = Matrix::Fp16(loader.load_matrix_f16(format!("{att}.w1"))?);
1109 let w2 = Matrix::Fp16(loader.load_matrix_f16(format!("{att}.w2"))?);
1110 let a1 = Matrix::Fp16(loader.load_matrix_f16(format!("{att}.a1"))?);
1111 let a2 = Matrix::Fp16(loader.load_matrix_f16(format!("{att}.a2"))?);
1112 let g1 = Matrix::Fp16(loader.load_matrix_f16(format!("{att}.g1"))?);
1113 let g2 = Matrix::Fp16(loader.load_matrix_f16(format!("{att}.g2"))?);
1114
1115 let (v0, v1, v2) = match layer {
1116 0 => (a0.clone(), a1.clone(), a2.clone()), _ => (
1118 loader.load_vector_f16(format!("{att}.v0"))?,
1119 Matrix::Fp16(loader.load_matrix_f16(format!("{att}.v1"))?),
1120 Matrix::Fp16(loader.load_matrix_f16(format!("{att}.v2"))?),
1121 ),
1122 };
1123
1124 let r_k = loader.load_matrix_f16(format!("{att}.r_k"))?;
1125 let k_k = loader.load_vector_f16(format!("{att}.k_k"))?;
1126 let k_a = loader.load_vector_f16(format!("{att}.k_a"))?;
1127
1128 let gn = LayerNorm {
1129 w: loader
1130 .load_vector_f16(format!("{att}.ln_x.weight"))?
1131 .reshape(
1132 TensorDimension::Auto,
1133 TensorDimension::Size(info.num_head),
1134 TensorDimension::Size(1),
1135 TensorDimension::Size(1),
1136 )?,
1137 b: loader
1138 .load_vector_f16(format!("{att}.ln_x.bias"))?
1139 .reshape(
1140 TensorDimension::Auto,
1141 TensorDimension::Size(info.num_head),
1142 TensorDimension::Size(1),
1143 TensorDimension::Size(1),
1144 )?,
1145 };
1146
1147 let att = Att {
1148 x_r,
1149 x_w,
1150 x_k,
1151 x_v,
1152 x_a,
1153 x_g,
1154 w0,
1155 a0,
1156 v0,
1157 w1,
1158 w2,
1159 a1,
1160 a2,
1161 g1,
1162 g2,
1163 v1,
1164 v2,
1165 r_k,
1166 k_k,
1167 k_a,
1168 w_k: load_matrix(format!("{att}.key.weight"), quant)?,
1169 w_v: load_matrix(format!("{att}.value.weight"), quant)?,
1170 w_r: load_matrix(format!("{att}.receptance.weight"), quant)?,
1171 w_o: load_matrix_discount(format!("{att}.output.weight"), quant, discount)?,
1172 gn,
1173 };
1174
1175 let ffn_ln = LayerNorm {
1176 w: loader.load_vector_f16(format!("blocks.{layer}.ln2.weight"))?,
1177 b: loader.load_vector_f16(format!("blocks.{layer}.ln2.bias"))?,
1178 };
1179
1180 let ffn = format!("blocks.{layer}.ffn");
1181 let x_k = loader.load_vector_f16(format!("{ffn}.x_k"))?;
1182
1183 let ffn = Ffn {
1184 x_k,
1185 w_k: load_matrix(format!("{ffn}.key.weight"), quant)?,
1186 w_v: load_matrix_discount(format!("{ffn}.value.weight"), quant, discount)?,
1187 };
1188
1189 let submission_index = Some(context.queue.submit(None));
1190 _ = context.device.poll(wgpu::PollType::Wait {
1191 submission_index,
1192 timeout: None,
1193 });
1194
1195 layers.push(Layer {
1196 att_ln,
1197 ffn_ln,
1198 att,
1199 ffn,
1200 })
1201 }
1202
1203 let submission_index = Some(context.queue.submit(None));
1204 _ = context.device.poll(wgpu::PollType::Wait {
1205 submission_index,
1206 timeout: None,
1207 });
1208
1209 let tensor = ModelTensor {
1210 embed,
1211 head,
1212 layers,
1213 };
1214 let model = {
1215 let context = context.clone();
1216 let info = info.clone();
1217 Model {
1218 context,
1219 info,
1220 rescale,
1221 sep,
1222 tensor,
1223 }
1224 };
1225 Ok(model)
1226 }
1227}
1228
1229pub async fn read_state<R: Reader>(
1231 context: &Context,
1232 info: &ModelInfo,
1233 model: R,
1234) -> Result<TensorCpu<f32>, LoaderError> {
1235 let loader = Loader {
1236 context: context.clone(),
1237 model,
1238 lora: vec![],
1239 };
1240
1241 let head_size = info.num_emb / info.num_head;
1242 let data: TensorGpu<f32, _> = context.zeros([info.num_emb, head_size + 2, info.num_layer, 1]);
1243
1244 let mut ops = vec![];
1245 for layer in 0..info.num_layer {
1246 let matrix = loader.load_matrix_f16(format!("blocks.{layer}.att.time_state"))?;
1247 let state: TensorGpu<_, _> = context.tensor_init([head_size, info.num_head, head_size, 1]);
1248 let reshaped: TensorGpu<f16, _> = state.reshape(
1249 TensorDimension::Size(info.num_emb),
1250 TensorDimension::Size(head_size),
1251 TensorDimension::Size(1),
1252 TensorDimension::Auto,
1253 )?;
1254 ops.extend([
1255 TensorOp::transpose(&matrix, &state)?,
1256 TensorOp::blit(&reshaped, data.view(.., 1..head_size + 1, layer, ..)?)?,
1257 ]);
1258 }
1259 context.queue.submit(context.encode(&TensorOp::List(ops)));
1260
1261 Ok(data.back().await)
1262}