use std::{
    cmp::Ordering,
    collections::{HashMap, VecDeque},
    error::Error,
    ops::Deref,
    sync::{Arc, Weak},
    time::Duration,
};

use anyhow::{bail, Result};
use derivative::Derivative;
use flume::{Receiver, Sender, TryRecvError};
use itertools::Itertools;
use memmap2::Mmap;
use qp_trie::Trie;
use safetensors::SafeTensors;
use tokio::{
    sync::{Mutex, RwLock},
    task::JoinHandle,
    time::Instant,
};
use web_rwkv::{
    context::Context,
    runtime::{
        infer::{Rnn, RnnInput, RnnInputBatch, RnnOption, RnnOutputBatch},
        model::{ModelInfo, State},
        Runtime,
    },
    tensor::{kind::ReadWrite, TensorCpu, TensorGpu, TensorShape},
    tokenizer::Tokenizer,
};

use crate::{
    load_model_state,
    sampler::{bnf::BnfSampler, Formatter, Sampler},
    FinishReason, GenerateKind, GenerateRequest, InitState, InputState, ReloadRequest, RuntimeInfo,
    StateId, Token, TokenCounter,
};

const MIN_PROMPT_CACHE_TOKENS: usize = 32;
const MAX_CACHE_ITEMS: usize = 256;

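/// A sequence of token ids, newtyped so it can serve as a key in the prompt
/// cache trie (via its `Borrow<[u8]>` and `qp_trie::Break` impls).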
#[repr(transparent)]
#[derive(Debug, Default, Clone)]
pub struct Tokens(pub Vec<u32>);

impl std::ops::Deref for Tokens {
    type Target = TokenSlice;

    fn deref(&self) -> &Self::Target {
        self.0.as_token_slice()
    }
}

impl std::borrow::Borrow<[u8]> for Tokens {
    fn borrow(&self) -> &[u8] {
        bytemuck::cast_slice(&self.0)
    }
}

impl std::borrow::Borrow<[u32]> for Tokens {
    fn borrow(&self) -> &[u32] {
        &self.0
    }
}

impl std::borrow::Borrow<TokenSlice> for Tokens {
    fn borrow(&self) -> &TokenSlice {
        self.0[..].as_token_slice()
    }
}

impl qp_trie::Break for Tokens {
    type Split = TokenSlice;

    fn empty<'a>() -> &'a Self::Split {
        Default::default()
    }

    fn find_break(&self, loc: usize) -> &Self::Split {
        self.0[..loc >> 2].as_token_slice()
    }
}

#[repr(transparent)]
pub struct TokenSlice([u32]);

impl std::ops::Deref for TokenSlice {
    type Target = [u32];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl std::borrow::Borrow<[u8]> for TokenSlice {
    fn borrow(&self) -> &[u8] {
        bytemuck::cast_slice(&self.0)
    }
}

impl Default for &TokenSlice {
    fn default() -> Self {
        <&[u32]>::default().as_token_slice()
    }
}

pub trait AsTokenSlice {
    fn as_token_slice(&self) -> &TokenSlice;
}

impl AsTokenSlice for [u32] {
    fn as_token_slice(&self) -> &TokenSlice {
        let ptr = self as *const [u32] as *const TokenSlice;
        unsafe { &*ptr }
    }
}

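/// Per-request generation state that travels with a request through the queue
/// and the per-slot inference loop.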
#[derive(Derivative, Clone)]
#[derivative(Debug)]
pub struct GenerateContext {
    /// Tokens that were provided in the prompt.
    pub prompt_tokens: Vec<u32>,
    /// Whether the prompt is being cached, and through which channel.
    pub prompt_cached: CachedPrompt,
    /// Tokens that have already been computed and are part of the slot's state.
    pub prefix: Tokens,
    /// Tokens that still need to be fed through the model.
    pub suffix: Tokens,
    /// Output logits of the model from the last run, if any.
    pub output: Option<TensorCpu<f32>>,
    /// Tokenized candidates, if this is a choose request.
    pub choices: Vec<Tokens>,
    /// Raw bytes the model has produced so far.
    pub model_text: Vec<u8>,
    /// Pending bytes that may end mid UTF-8 sequence or mid stop-sequence.
    pub buffer: Vec<u8>,
    /// Tokens the model has produced so far.
    pub model_tokens: Vec<u32>,
    /// Active output formatters (e.g., a compiled BNF sampler).
    #[derivative(Debug = "ignore")]
    pub formatters: Vec<Arc<RwLock<dyn Formatter + Send + Sync>>>,
    /// Start time of generation, used for token counting.
    pub instant: Option<Instant>,
    /// The original request from the caller.
    pub request: GenerateRequest,
    /// Channel for streaming tokens back to the caller.
    pub sender: Sender<Token>,
}

impl GenerateContext {
    pub async fn new(
        request: GenerateRequest,
        sender: Sender<Token>,
        tokenizer: &Tokenizer,
    ) -> Result<Self> {
        let tokens = Tokens(tokenizer.encode(request.prompt.as_bytes())?);
        let model_tokens = Tokens(tokenizer.encode(request.model_text.as_bytes())?);

        request.sampler.write().await.init(&model_tokens);

        let choices = match &request.kind {
            GenerateKind::Choose { choices, .. } => {
                let choices: Vec<_> = choices
                    .iter()
                    .map(|prompt| tokenizer.encode(prompt.as_bytes()))
                    .try_collect()?;
                choices.into_iter().map(Tokens).collect()
            }
            _ => Vec::new(),
        };
        Ok(Self {
            prompt_tokens: tokens.to_vec(),
            prompt_cached: Default::default(),
            prefix: Default::default(),
            suffix: tokens,
            output: None,
            choices,
            model_text: Vec::new(),
            buffer: Vec::new(),
            model_tokens: Vec::new(),
            formatters: Vec::new(),
            instant: None,
            request,
            sender,
        })
    }
}

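/// Caching status of the current request's prompt: not cached, a pending
/// `watch` channel that will receive the cached item, or already done.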
#[derive(Debug, Default, Clone)]
pub enum CachedPrompt {
    #[default]
    None,
    Future(tokio::sync::watch::Sender<Option<CachedItem>>),
    Done,
}

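/// A cached model state together with the output logits produced right after it,
/// timestamped for least-recently-used eviction.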
#[derive(Debug, Clone)]
pub struct CachedItem {
    state: TensorCpu<f32>,
    output: TensorCpu<f32>,
    instant: Instant,
}

impl CachedItem {
    pub fn new(state: TensorCpu<f32>, output: TensorCpu<f32>) -> Self {
        Self {
            state,
            output,
            instant: Instant::now(),
        }
    }

    pub fn update(cached: CachedItem) -> Self {
        Self {
            instant: Instant::now(),
            ..cached
        }
    }
}

struct CacheCheckout {
    prefix: Vec<u32>,
    state: TensorCpu<f32>,
    output: Option<TensorCpu<f32>>,
}

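/// Prompt cache for a single initial state: an optional [`InitState`] plus a
/// trie mapping token prefixes to cached model states.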
#[derive(Debug, Default)]
struct Cache {
    state: Option<InitState>,
    cache: Trie<Tokens, tokio::sync::watch::Sender<Option<CachedItem>>>,
}

impl Cache {
    fn maintain(&mut self) {
        let cache = &mut self.cache;
        if cache.count() <= MAX_CACHE_ITEMS {
            return;
        }

        let mut remove = vec![];
        for (tokens, _) in cache
            .iter()
            .filter_map(|(tokens, item)| item.borrow().clone().map(|item| (tokens, item)))
            .sorted_unstable_by_key(|(_, item)| item.instant.elapsed())
            .skip(MAX_CACHE_ITEMS)
        {
            remove.push(tokens.to_owned());
        }

        for tokens in remove.into_iter() {
            cache.remove(&tokens);
        }
    }
}

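/// All prompt caches keyed by [`StateId`], plus a fallback cache for the default state.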
#[derive(Debug, Default)]
struct CacheHub {
    backed: HashMap<StateId, Cache>,
    default: Cache,
}

impl CacheHub {
    fn fetch(&mut self, id: StateId) -> &mut Cache {
        match self.backed.get_mut(&id) {
            Some(item) => item,
            None => &mut self.default,
        }
    }
}

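/// Outcome of trying to place a request onto an inference slot.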
#[derive(Debug)]
enum SlotResult {
    /// The request is queued onto an empty slot or one whose content is a prefix of the prompt.
    Success(usize),
    /// The request is queued onto a slot whose previous content had to be discarded.
    Fault(usize),
    /// No slot is currently available; the context is returned for retry.
    Failure(Box<GenerateContext>),
    /// Setting up the request failed (e.g., an invalid BNF schema).
    Error(Box<dyn Error>),
}

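/// The current occupancy of an inference slot (one batch lane of the runtime).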
#[derive(Debug)]
enum SlotState {
    /// The slot is available, holding the tokens it served last and the time it went idle.
    Idle(Tokens, Instant),
    /// The slot is currently running a generation task.
    Busy(JoinHandle<Result<GenerateContext>>),
    /// The slot has been claimed by the scheduler and is about to receive a task.
    Locked,
}

impl Default for SlotState {
    fn default() -> Self {
        Self::Idle(Default::default(), Instant::now())
    }
}

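/// Ranking of candidate slots for a new request. The `Ord` impl prefers
/// `Continue` (longest shared prefix first), then `Empty`, then `Back`.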
#[derive(Debug, PartialEq, Eq)]
enum SlotChoice {
    Continue(usize, usize),
    Back(usize),
    Empty(usize),
}

impl std::cmp::Ord for SlotChoice {
    fn cmp(&self, other: &Self) -> Ordering {
        use SlotChoice::{Back, Continue, Empty};
        match (self, other) {
            (Continue(_, x), Continue(_, y)) => x.cmp(y),
            (Continue(_, _), _) => Ordering::Greater,
            (_, Continue(_, _)) => Ordering::Less,
            (Empty(_), Empty(_)) => Ordering::Equal,
            (Empty(_), Back(_)) => Ordering::Greater,
            (Back(_), Empty(_)) => Ordering::Less,
            (Back(_), Back(_)) => Ordering::Equal,
        }
    }
}

impl std::cmp::PartialOrd for SlotChoice {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

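/// Commands sent to the `infer` task: run tokens through a batch, or
/// load/back/write/read that batch's state.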
#[derive(Debug, Clone)]
enum InferBatch {
    Run {
        batch: usize,
        tokens: Vec<u32>,
        option: RnnOption,
        sender: Sender<TensorCpu<f32>>,
    },
    Load {
        batch: usize,
        tensor: TensorCpu<f32>,
    },
    Back {
        batch: usize,
        sender: Sender<TensorCpu<f32>>,
    },
    Write {
        batch: usize,
        tensor: TensorGpu<f32, ReadWrite>,
    },
    Read {
        batch: usize,
        sender: Sender<TensorGpu<f32, ReadWrite>>,
    },
}

#[derive(Debug, Clone)]
struct SoftmaxBatch {
    input: TensorCpu<f32>,
    sender: Sender<TensorCpu<f32>>,
}

#[derive(Debug, Clone)]
struct RuntimeSender {
    infer: Sender<InferBatch>,
    softmax: Sender<SoftmaxBatch>,
}

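/// Shared handle used by the per-slot tasks to drive the model: GPU context,
/// model info, channels to the `infer`/`softmax` tasks, the tokenizer, the slot
/// table and the prompt caches.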
#[derive(Derivative, Clone)]
#[derivative(Debug)]
struct CoreRuntime {
    context: Context,
    info: ModelInfo,
    reload: Arc<ReloadRequest>,
    #[derivative(Debug = "ignore")]
    state: Arc<dyn State + Send + Sync>,
    sender: RuntimeSender,
    tokenizer: Arc<Tokenizer>,
    slots: Arc<Mutex<Vec<SlotState>>>,
    caches: Arc<Mutex<CacheHub>>,
}

impl CoreRuntime {
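    /// Resolve an [`InputState`] into a [`StateId`], registering freshly
    /// provided or file-backed states in the cache hub.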
    async fn check_in_state(&self, state: &InputState) -> Result<StateId> {
        match state {
            InputState::Key(id) => Ok(*id),
            InputState::Value(value) => {
                let id = value.id;
                let state = InitState::try_from(value.clone())?;
                let mut caches = self.caches.lock().await;
                caches.backed.insert(
                    id,
                    Cache {
                        state: Some(state),
                        cache: Trie::new(),
                    },
                );
                Ok(id)
            }
            InputState::File(file) => {
                let name = file.name.clone();
                let id = file.id;
                let default = false;

                let file = tokio::fs::File::open(&file.path).await?;
                let data = unsafe { Mmap::map(&file) }?;

                let st = SafeTensors::deserialize(&data);
                let prefab = cbor4ii::serde::from_slice::<InitState>(&data);
                let state = match (st, prefab) {
                    (Ok(model), _) => {
                        let data = load_model_state(&self.context, &self.info, model).await?;
                        InitState {
                            name,
                            id,
                            default,
                            data,
                        }
                    }
                    (_, Ok(state)) => state,
                    _ => bail!("failed to load init state"),
                };

                let mut caches = self.caches.lock().await;
                caches.backed.insert(
                    id,
                    Cache {
                        state: Some(state),
                        cache: Trie::new(),
                    },
                );

                Ok(id)
            }
        }
    }

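    /// Find the longest cached prefix of `tokens` for the given state id and
    /// return the corresponding backed state and output; falls back to the
    /// initial state when nothing is cached.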
    async fn checkout(&self, id: StateId, tokens: &[u32]) -> CacheCheckout {
        let mut caches = self.caches.lock().await;

        let Cache { state, cache } = caches.fetch(id);
        let prefix = cache.longest_common_prefix(tokens.as_token_slice());
        let len = (1..=prefix.len())
            .rev()
            .find(|len| cache.contains_key(prefix[0..*len].as_token_slice()))
            .unwrap_or_default();
        let prefix = prefix[0..len].to_vec();

        let state = state.clone().map(|state| state.data);
        let item = cache.get(prefix[..].as_token_slice()).cloned();
        drop(caches);

        match item {
            Some(sender) => {
                let mut receiver = sender.subscribe();
                let item = loop {
                    if let Some(item) = receiver.borrow_and_update().deref().clone() {
                        break item;
                    }
                    let _ = receiver.changed().await;
                };
                let item = CachedItem::update(item);
                sender.send_replace(Some(item.clone()));
                CacheCheckout {
                    prefix,
                    state: item.state,
                    output: Some(item.output),
                }
            }
            None => {
                let prefix = vec![];
                let state = state.unwrap_or_else(|| self.state.init());
                CacheCheckout {
                    prefix,
                    state,
                    output: None,
                }
            }
        }
    }

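    /// Try to place a request onto an inference slot, preferring a slot that
    /// already holds a prefix of the prompt; returns `Failure` when every slot
    /// is busy.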
    async fn queue(&self, context: GenerateContext) -> SlotResult {
        let tokens = match [context.prefix, context.suffix].concat() {
            tokens if tokens.is_empty() => vec![0u32],
            tokens => tokens,
        };

        let mut formatters = Vec::<Arc<RwLock<dyn Formatter + Send + Sync>>>::new();
        if let Some(schema) = context.request.bnf_schema.clone() {
            match BnfSampler::new(&self.tokenizer, &schema) {
                Ok(bnf) => formatters.push(Arc::new(RwLock::new(bnf))),
                Err(err) => return SlotResult::Error(err.into()),
            }
        }

        let choice = {
            let mut slots = self.slots.lock().await;
            let choice = slots
                .iter()
                .enumerate()
                .filter_map(|(batch, slot)| match slot {
                    SlotState::Idle(content, instant) => {
                        let delta = instant.elapsed();
                        match (content.is_empty(), tokens.starts_with(content)) {
                            (true, _) => Some((SlotChoice::Empty(batch), delta)),
                            (_, true) => Some((SlotChoice::Continue(batch, content.len()), delta)),
                            (_, false) => Some((SlotChoice::Back(batch), delta)),
                        }
                    }
                    _ => None,
                })
                .max_by(|lhs, rhs| lhs.0.cmp(&rhs.0).then(lhs.1.cmp(&rhs.1)))
                .map(|(x, _)| x);
            match choice {
                None => (),
                Some(SlotChoice::Empty(batch))
                | Some(SlotChoice::Back(batch))
                | Some(SlotChoice::Continue(batch, _)) => slots[batch] = SlotState::Locked,
            }
            choice
        };

        let check_in_state = |state: Arc<InputState>, batch: usize| async move {
            match self.check_in_state(&state).await {
                Ok(state) => state,
                Err(err) => {
                    log::error!("[queue][state][slot: {batch}] error: {err:#?}");
                    Default::default()
                }
            }
        };

        match choice {
            None => SlotResult::Failure(
                GenerateContext {
                    prefix: Default::default(),
                    suffix: Tokens(tokens),
                    formatters,
                    ..context
                }
                .into(),
            ),
            Some(SlotChoice::Back(batch)) => {
                log::info!("[queue][back][slot: {batch}]");
                let state = check_in_state(context.request.state.clone(), batch).await;
                let checkout = self.checkout(state, &tokens).await;
                self.load(batch, checkout.state).await;

                let len = checkout.prefix.len();
                assert!(len == 0 || (len > 0 && checkout.output.is_some()));
565 log::info!("[cache][checkout[[slot: {batch}][len: {len}]");

                let context = GenerateContext {
                    prefix: Tokens(tokens[..len].to_vec()),
                    suffix: Tokens(tokens[len..].to_vec()),
                    output: checkout.output,
                    formatters,
                    ..context
                };
                let handle = tokio::spawn(self.clone().process(batch, context));
                let mut slots = self.slots.lock().await;
                slots[batch] = SlotState::Busy(handle);
                SlotResult::Fault(batch)
            }
            Some(SlotChoice::Empty(batch)) => {
                log::info!("[queue][empty][slot: {batch}]");
                let state = check_in_state(context.request.state.clone(), batch).await;
                let checkout = self.checkout(state, &tokens).await;
                self.load(batch, checkout.state).await;

                let len = checkout.prefix.len();
                assert!(len == 0 || (len > 0 && checkout.output.is_some()));
                log::info!("[cache][checkout][slot: {batch}][len: {len}]");

                let context = GenerateContext {
                    prefix: Tokens(tokens[..len].to_vec()),
                    suffix: Tokens(tokens[len..].to_vec()),
                    output: checkout.output,
                    formatters,
                    ..context
                };
                let handle = tokio::spawn(self.clone().process(batch, context));
                let mut slots = self.slots.lock().await;
                slots[batch] = SlotState::Busy(handle);
                SlotResult::Success(batch)
            }
            Some(SlotChoice::Continue(batch, ..)) => {
                log::info!("[queue][continue][slot: {batch}]");
                let state = check_in_state(context.request.state.clone(), batch).await;
                let checkout = self.checkout(state, &tokens).await;
                self.load(batch, checkout.state).await;

                let len = checkout.prefix.len();
                assert!(len == 0 || (len > 0 && checkout.output.is_some()));
610 log::info!("[cache][checkout[[slot: {batch}][len: {len}]");

                let context = GenerateContext {
                    prefix: Tokens(tokens[..len].to_vec()),
                    suffix: Tokens(tokens[len..].to_vec()),
                    output: checkout.output,
                    formatters,
                    ..context
                };
                let handle = tokio::spawn(self.clone().process(batch, context));
                let mut slots = self.slots.lock().await;
                slots[batch] = SlotState::Busy(handle);
                SlotResult::Success(batch)
            }
        }
    }

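    /// Poll all busy slots and move the finished ones back to idle, retaining their prefix.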
    async fn update(&self) {
        let update = |handle: JoinHandle<_>| async move {
            if !handle.is_finished() {
                return Ok(SlotState::Busy(handle));
            }

            let context = handle.await??;
            Ok::<_, Box<dyn Error + Send + Sync>>(SlotState::Idle(context.prefix, Instant::now()))
        };

        for batch in 0..self.reload.max_batch {
            let handle = {
                let mut slots = self.slots.lock().await;
                let slot = std::mem::replace(&mut slots[batch], SlotState::Locked);
                let SlotState::Busy(handle) = slot else {
                    slots[batch] = slot;
                    continue;
                };
                handle
            };

            let updated = match update(handle).await {
                Ok(updated) => updated,
                Err(err) => {
                    log::error!("[update][error][slot: {batch}] {err:#?}");
                    let mut slots = self.slots.lock().await;
                    slots[batch] = Default::default();
                    continue;
                }
            };

            let mut slots = self.slots.lock().await;
            slots[batch] = updated;
        }
    }

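    /// Apply sampler, formatter and bias transforms to the logits, softmax them
    /// on the GPU, and sample the next token.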
    async fn sample(
        &self,
        output: TensorCpu<f32>,
        sampler: Arc<RwLock<dyn Sampler + Send + Sync>>,
        formatters: Vec<Arc<RwLock<dyn Formatter + Send + Sync>>>,
        bias: Arc<HashMap<u32, f32>>,
    ) -> Result<(u32, TensorCpu<f32>)> {
        let num_vocab = self.info.num_vocab;
        let input = {
            let mut data = output.to_vec();
            assert_eq!(data.len(), num_vocab);

            sampler.read().await.transform(&mut data);
            for formatter in formatters {
                formatter.read().await.transform(&mut data);
            }
            for (token, bias) in bias.iter() {
                data[*token as usize] += *bias;
            }

            self.context.tensor_from_data([num_vocab, 1, 1, 1], data)?
        };

        let (sender, receiver) = flume::bounded(1);
        let _ = self.sender.softmax.send(SoftmaxBatch { input, sender });
        let output = receiver.recv_async().await?;

        assert_eq!(output.len(), num_vocab);
        let token = sampler.write().await.sample(&output);
        Ok((token, output))
    }

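    /// Compute the negative mean log-likelihood of `tokens` for the given batch,
    /// optionally seeded with a pre-computed probability for the first token (`head`).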
    async fn perplexity(&self, batch: usize, tokens: &[u32], head: Option<f32>) -> Result<f32> {
        let mut p = Vec::with_capacity(tokens.len().max(1));
        let len = tokens.len();
        let tokens = match head {
            Some(head) => {
                p.push(head);
                tokens.to_vec()
            }
            None => [&[0], tokens].concat(),
        };

        let (sender, receiver) = flume::unbounded();
        let _ = self
            .sender
            .infer
            .send_async({
                let tokens = tokens.clone();
                let option = RnnOption::Full;
                InferBatch::Run {
                    batch,
                    tokens,
                    option,
                    sender,
                }
            })
            .await;

        let index = Arc::new(Mutex::new(1));
        while p.len() < len {
            let tokens = tokens.clone();
            let output = receiver.recv_async().await?;
            let output = output.split(1)?;
            let f = {
                let index = index.clone();
                move || {
                    let mut index = index.blocking_lock();
                    let mut p = Vec::with_capacity(output.len());
                    for data in output {
                        if *index < tokens.len() {
                            let data = data.map(|x| x.exp()).to_vec();
                            let sum: f32 = data.iter().sum();
                            let token = tokens[*index] as usize;
                            p.push(data[token] / sum);
                        }
                        *index += 1;
                    }
                    p
                }
            };
            let mut q = tokio::task::spawn_blocking(f).await?;
            p.append(&mut q);
        }

        let ppl: f32 = p.into_iter().map(|x| x.ln()).sum();
        let ppl = -ppl / tokens.len() as f32;
        Ok(ppl)
    }

    async fn load(&self, batch: usize, tensor: TensorCpu<f32>) {
        let _ = self
            .sender
            .infer
            .send_async(InferBatch::Load { batch, tensor })
            .await;
    }

    async fn back(&self, batch: usize) -> Result<TensorCpu<f32>> {
        let (sender, receiver) = flume::bounded(1);
        let _ = self.sender.infer.send(InferBatch::Back { batch, sender });
        let tensor = receiver.recv_async().await?;
        Ok(tensor)
    }

    async fn write(&self, batch: usize, tensor: TensorGpu<f32, ReadWrite>) {
        let _ = self
            .sender
            .infer
            .send_async(InferBatch::Write { batch, tensor })
            .await;
    }

    async fn read(&self, batch: usize) -> Result<TensorGpu<f32, ReadWrite>> {
        let (sender, receiver) = flume::bounded(1);
        let _ = self.sender.infer.send(InferBatch::Read { batch, sender });
        let tensor = receiver.recv_async().await?;
        Ok(tensor)
    }

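    /// The per-slot generation loop: feed the suffix through the model, cache
    /// the prompt state, sample tokens, match stop sequences, and stream
    /// results back to the request's sender.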
    async fn process(self, batch: usize, mut context: GenerateContext) -> Result<GenerateContext> {
        {
            let mut caches = self.caches.lock().await;
            let cache = &mut caches.fetch(context.request.state.id()).cache;

            let enable = context.prompt_tokens.len() > MIN_PROMPT_CACHE_TOKENS;
            let enable = enable && !cache.contains_key(context.prompt_tokens.as_token_slice());
            if enable {
                let (sender, _) = tokio::sync::watch::channel(None);
                context.prompt_cached = CachedPrompt::Future(sender.clone());
                cache.insert(Tokens(context.prompt_tokens.clone()), sender);

                let len = context.prompt_tokens.len();
                log::info!("[cache][future][slot: {batch}][len: {len}]");
            }
        }

        let _ = context.sender.send(Token::Start);

        loop {
            let output = match (context.suffix.len(), context.output.clone()) {
                (0, Some(output)) => output,
                _ => {
                    let (sender, receiver) = flume::bounded(1);
                    let _ = self
                        .sender
                        .infer
                        .send_async(InferBatch::Run {
                            batch,
                            tokens: context.suffix.to_vec(),
                            option: RnnOption::Last,
                            sender,
                        })
                        .await;

                    let prefix = std::mem::take(&mut context.prefix);
                    let suffix = std::mem::take(&mut context.suffix);

                    context.prefix = Tokens([prefix.0, suffix.0].concat());
                    context.suffix = Tokens(vec![]);

                    receiver.recv_async().await?
                }
            };

            if let CachedPrompt::Future(sender) = context.prompt_cached.clone() {
                assert_eq!(context.prefix.len(), context.prompt_tokens.len());

                let backed = self.back(batch).await?;
                let output = output.clone();
                sender.send_replace(Some(CachedItem::new(backed, output)));
                context.prompt_cached = CachedPrompt::Done;

                let len = context.prefix.len();
                log::info!("[cache][insert][slot: {batch}][len: {len}]");
            }

            let (token, output) = {
                let output = output.clone();
                let sampler = context.request.sampler.clone();
                let formatters = context.formatters.clone();
                let bias = context.request.bias.clone();
                self.sample(output, sampler, formatters, bias).await?
            };

            let mut stop_token = token == 0;
            let mut word = match self.tokenizer.decode(&[token]) {
                Ok(word) => word,
                Err(err) => {
                    log::warn!("[process][error] {err:#?}");
                    stop_token = true;
                    Vec::new()
                }
            };

            context.output = Some(output.clone());
            context.suffix.0.push(token);
            context.model_tokens.push(token);
            context.model_text.extend(&word);
            context.buffer.append(&mut word);

            let instant = context.instant.get_or_insert(Instant::now());
            let mut done = false;
            let mut stop = |reason| {
                let counter = {
                    let prompt = context.prompt_tokens.len();
                    let completion = context.model_tokens.len();
                    let total = prompt + completion;
                    let duration = instant.elapsed();
                    TokenCounter {
                        prompt,
                        completion,
                        total,
                        duration,
                    }
                };

                let _ = context.sender.send(Token::Stop(reason, counter));
                let _ = context.sender.send(Token::Done);
                done = true;
            };

            let mut halt = false;
            for formatter in context.formatters.iter() {
                let mut formatter = formatter.write().await;
                halt |= formatter.update(token);
            }

            let ((head, tail), stop_matched) = context
                .request
                .stop
                .iter()
                .map(|stop| {
                    let stop = stop.as_bytes();
                    let mut index_safe = 0;
                    let mut index_unsafe = 0;
                    while index_unsafe < context.buffer.len() {
                        let index_stop = index_unsafe - index_safe;
                        if index_stop >= stop.len() {
                            return (index_safe, true);
                        }

                        let output = context.buffer[index_unsafe];
                        let stop = stop[index_stop];

                        index_unsafe += 1;
                        if output != stop {
                            index_safe = index_unsafe;
                        }
                    }
                    (index_safe, index_unsafe - index_safe >= stop.len())
                })
                .min_by(|x, y| match (x.1, y.1) {
                    (true, false) => Ordering::Less,
                    (false, true) => Ordering::Greater,
                    _ => x.0.cmp(&y.0),
                })
                .map(|(mid, matched)| (context.buffer.split_at(mid), matched))
                .unwrap_or(((&context.buffer[..], &[]), false));

            if context.sender.is_disconnected() {
                done = true;
            } else if let GenerateKind::Choose { calibrate, .. } = context.request.kind {
                let backed = self.read(batch).await?;
                let mut ppl = vec![f32::INFINITY; context.choices.len()];

                if calibrate {
                    let init = {
                        let id = context.request.state.id();
                        let mut caches = self.caches.lock().await;
                        caches
                            .fetch(id)
                            .state
                            .clone()
                            .map(|state| state.data)
                            .unwrap_or_else(|| self.state.init())
                    };
                    for (index, choice) in context
                        .choices
                        .iter()
                        .enumerate()
                        .filter(|(_, choice)| !choice.is_empty())
                    {
                        self.load(batch, init.clone()).await;
                        ppl[index] = -self.perplexity(batch, choice, None).await?;
                    }
                    self.write(batch, backed.clone()).await;
                }

                for (index, choice) in context
                    .choices
                    .iter()
                    .enumerate()
                    .filter(|(_, choice)| !choice.is_empty())
                {
                    let output = output.clone().to_vec();
                    let head = Some(output[choice[0] as usize]);
                    let p = self.perplexity(batch, choice, head).await?;
                    ppl[index] = match calibrate {
                        true => ppl[index] + p,
                        false => p,
                    };
                    self.write(batch, backed.clone()).await;
                }

                let _ = context.sender.send(Token::Choose(ppl));
                done = true;
            } else if let GenerateKind::State = context.request.kind {
                let backed = self.back(batch).await?;
                let embed = backed.to_vec();
                let shape = backed.shape().into();
                let _ = context.sender.send(Token::Embed(embed, shape));
                done = true;
            } else if halt || stop_matched || stop_token {
                let output = String::from_utf8_lossy(head);
                let _ = context.sender.send(Token::Content(output.into()));
                stop(FinishReason::Stop);

                if let Some(output) = context.output.clone() {
                    let backed = self.back(batch).await?;
                    let mut caches = self.caches.lock().await;
                    let cache = &mut caches.fetch(context.request.state.id()).cache;
                    let item = CachedItem::new(backed, output);
                    let (item, _) = tokio::sync::watch::channel(Some(item));
                    cache.insert(context.prefix.clone(), item);

                    let len = context.prefix.len();
                    log::info!("[cache][insert][slot: {batch}][len: {len}]");
                }
            } else if context.model_tokens.len() >= context.request.max_tokens {
                stop(FinishReason::Length);
            } else if let Ok(word) = String::from_utf8(head.to_vec()) {
                let _ = context.sender.send(Token::Content(word));
                context.buffer = tail.to_vec();
            }

            if done {
                log::info!("[process][done][slot: {batch}]");
                break;
            }
        }

        Ok(context)
    }

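    /// Evict least-recently-used entries from every prompt cache.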
    async fn maintain_cache(&self) {
        let mut caches = self.caches.lock().await;
        caches.default.maintain();
        caches.backed.iter_mut().for_each(|(_, x)| x.maintain());
    }
}

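/// Task that pulls pending requests off the queue and retries until each one is
/// accepted by a slot.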
async fn enqueue(runtime: CoreRuntime, receiver: Receiver<GenerateContext>, timer: Duration) {
    let mut queue = Vec::<GenerateContext>::new();

    'outer: while let Ok(context) = receiver.recv_async().await {
        queue.push(context);

        'inner: loop {
            runtime.maintain_cache().await;
            runtime.update().await;

            let mut temp = Vec::new();
            for context in queue.drain(..) {
                match runtime.queue(context).await {
                    SlotResult::Failure(context) => temp.push(*context),
                    SlotResult::Success(batch) => log::info!("[enqueue][ok][slot: {batch}]"),
                    SlotResult::Fault(batch) => log::info!("[enqueue][fault][slot: {batch}]"),
                    SlotResult::Error(err) => log::error!("[enqueue][error] {err:#?}"),
                }
            }
            std::mem::swap(&mut queue, &mut temp);

            if queue.is_empty() {
                break 'inner;
            }

            match receiver.try_recv() {
                Ok(context) => queue.push(context),
                Err(TryRecvError::Empty) => tokio::time::sleep(timer).await,
                Err(TryRecvError::Disconnected) => break 'outer,
            }
        }
    }
}

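/// Housekeeping task that keeps updating slots and trimming caches until the
/// request channel is disconnected.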
async fn finalize(runtime: CoreRuntime, receiver: Receiver<GenerateContext>, timer: Duration) {
    while !receiver.is_disconnected() {
        runtime.maintain_cache().await;
        runtime.update().await;
        tokio::time::sleep(timer).await;
    }
}

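/// Task that owns inference scheduling: collects [`InferBatch`] commands,
/// queues them per batch, and feeds them through the `web_rwkv` runtime until
/// the runtime is dropped.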
async fn infer(
    reload: Arc<ReloadRequest>,
    runtime: Weak<dyn Runtime<Rnn> + Send + Sync>,
    state: Arc<dyn State + Send + Sync>,
    receiver: Receiver<InferBatch>,
) -> Result<()> {
    type Batch = (Vec<u32>, RnnOption, Sender<TensorCpu<f32>>);
    let mut batches: HashMap<usize, VecDeque<Batch>> = HashMap::new();

    async fn schedule(
        batches: &mut HashMap<usize, VecDeque<Batch>>,
        state: Arc<dyn State + Send + Sync>,
        batch: InferBatch,
    ) -> Result<()> {
        match batch {
            InferBatch::Run {
                batch,
                tokens,
                option,
                sender,
            } => match batches.get_mut(&batch) {
                Some(batches) => batches.push_back((tokens, option, sender)),
                None => {
                    let deque = VecDeque::from_iter([(tokens, option, sender)]);
                    batches.insert(batch, deque);
                }
            },
            InferBatch::Load { batch, tensor } => state.load(tensor, batch)?,
            InferBatch::Back { batch, sender } => {
                let tensor = state.back(batch).await?;
                let _ = sender.send_async(tensor).await;
            }
            InferBatch::Write { batch, tensor } => state.write(tensor, batch)?,
            InferBatch::Read { batch, sender } => {
                let tensor = state.read(batch)?;
                let _ = sender.send_async(tensor).await;
            }
        }
        Ok(())
    }

    'outer: while let Ok(batch) = receiver.recv_async().await {
        schedule(&mut batches, state.clone(), batch).await?;

        for batch in receiver.drain() {
            schedule(&mut batches, state.clone(), batch).await?;
        }

        while batches.values().map(|x| x.len()).sum::<usize>() > 0 {
            let mut inference = vec![Default::default(); reload.max_batch];
            let mut senders = HashMap::new();

            for (&batch, deque) in batches.iter_mut() {
                let Some((tokens, option, sender)) = deque.pop_front() else {
                    continue;
                };
                inference[batch] = RnnInputBatch::new(tokens, option);
                senders.insert(batch, sender);
            }

            let mut inference = Some(RnnInput::new(inference, reload.token_chunk_size));

            while inference
                .as_ref()
                .map(|input| input.num_token() > 0)
                .expect("inference must not be `None`")
            {
                let Some(runtime) = runtime.upgrade() else {
                    break 'outer;
                };
                let input = inference.take().expect("inference must not be `None`");
                let (input, output) = runtime.infer(input).await?;
                inference.replace(input);

                for (batch, RnnOutputBatch(output)) in output
                    .iter()
                    .enumerate()
                    .filter(|(_, output)| !output.is_empty())
                {
                    let Some(sender) = senders.get(&batch) else {
                        continue;
                    };
                    let _ = sender.send(output.clone());
                }
            }
        }
    }

    log::info!("[infer] exit");
    Ok(())
}

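/// Task that batches softmax requests and runs them through the GPU softmax kernel.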
async fn softmax(
    reload: Arc<ReloadRequest>,
    context: Context,
    receiver: Receiver<SoftmaxBatch>,
) -> Result<()> {
    let mut batches = Vec::with_capacity(reload.max_batch);

    while let Ok(batch) = receiver.recv_async().await {
        batches.push(batch);

        for batch in receiver.drain() {
            batches.push(batch);
        }

        let input = batches.iter().map(|batch| batch.input.clone()).collect();
        let output = web_rwkv::runtime::softmax::softmax(&context, input).await?;

        for (batch, tensor) in batches.iter().zip_eq(output.into_iter()) {
            let _ = batch.sender.send(tensor);
        }

        batches.clear();
    }

    log::info!("[softmax] exit");
    Ok(())
}

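/// Spawn the whole inference stack for a loaded model: the `infer` and
/// `softmax` tasks, one `enqueue` task per batch slot, and a `finalize` task.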
pub async fn run(
    context: Context,
    runtime: Weak<dyn Runtime<Rnn> + Send + Sync>,
    state: Arc<dyn State + Send + Sync>,
    receiver: Receiver<GenerateContext>,
    RuntimeInfo {
        reload,
        info,
        states,
        tokenizer,
        ..
    }: RuntimeInfo,
) {
    let slots = std::iter::repeat_with(Default::default)
        .take(reload.max_batch)
        .collect();
    let slots = Arc::new(Mutex::new(slots));

    let caches = {
        let mut caches = CacheHub::default();
        if let Some(state) = states.iter().find(|state| state.default) {
            caches.default.state = Some(state.clone());
        }
        for state in states {
            let id = state.id;
            let item = Cache {
                state: Some(state),
                cache: Trie::new(),
            };
            caches.backed.insert(id, item);
        }
        Arc::new(Mutex::new(caches))
    };

    let max_batch = reload.max_batch;
    let runtime = {
        let infer = {
            let (sender, receiver) = flume::unbounded();
            tokio::spawn(infer(reload.clone(), runtime, state.clone(), receiver));
            sender
        };
        let softmax = {
            let (sender, receiver) = flume::unbounded();
            tokio::spawn(softmax(reload.clone(), context.clone(), receiver));
            sender
        };
        let sender = RuntimeSender { infer, softmax };
        CoreRuntime {
            context,
            info,
            reload,
            state,
            sender,
            tokenizer,
            slots,
            caches,
        }
    };
    let timer = Duration::from_secs_f32(1.0);
    for _ in 0..max_batch {
        tokio::spawn(enqueue(runtime.clone(), receiver.clone(), timer));
    }
    tokio::spawn(finalize(runtime, receiver, timer));
}