ai00_core/
run.rs

use std::{
    cmp::Ordering,
    collections::{HashMap, VecDeque},
    error::Error,
    ops::Deref,
    sync::{Arc, Weak},
    time::Duration,
};

use anyhow::{bail, Result};
use derivative::Derivative;
use flume::{Receiver, Sender, TryRecvError};
use itertools::Itertools;
use memmap2::Mmap;
use qp_trie::Trie;
use safetensors::SafeTensors;
use tokio::{
    sync::{Mutex, RwLock},
    task::JoinHandle,
    time::Instant,
};
use web_rwkv::{
    context::Context,
    runtime::{
        infer::{Rnn, RnnInput, RnnInputBatch, RnnOption, RnnOutputBatch},
        model::{ModelInfo, State},
        Runtime,
    },
    tensor::{kind::ReadWrite, TensorCpu, TensorGpu, TensorShape},
    tokenizer::Tokenizer,
};

use crate::{
    load_model_state,
    sampler::{bnf::BnfSampler, Formatter, Sampler},
    FinishReason, GenerateKind, GenerateRequest, InitState, InputState, ReloadRequest, RuntimeInfo,
    StateId, Token, TokenCounter,
};

const MIN_PROMPT_CACHE_TOKENS: usize = 32;
const MAX_CACHE_ITEMS: usize = 256;

#[repr(transparent)]
#[derive(Debug, Default, Clone)]
pub struct Tokens(pub Vec<u32>);

impl std::ops::Deref for Tokens {
    type Target = TokenSlice;

    fn deref(&self) -> &Self::Target {
        self.0.as_token_slice()
    }
}

impl std::borrow::Borrow<[u8]> for Tokens {
    fn borrow(&self) -> &[u8] {
        bytemuck::cast_slice(&self.0)
    }
}

impl std::borrow::Borrow<[u32]> for Tokens {
    fn borrow(&self) -> &[u32] {
        &self.0
    }
}

impl std::borrow::Borrow<TokenSlice> for Tokens {
    fn borrow(&self) -> &TokenSlice {
        self.0[..].as_token_slice()
    }
}

impl qp_trie::Break for Tokens {
    type Split = TokenSlice;

    fn empty<'a>() -> &'a Self::Split {
        Default::default()
    }

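    // The trie compares keys through their `Borrow<[u8]>` view, so `loc` is a
    // byte offset into that view; with 4 bytes per `u32` token, `loc >> 2`
    // rounds it down to a whole-token boundary.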
    fn find_break(&self, loc: usize) -> &Self::Split {
        self.0[..loc >> 2].as_token_slice()
    }
}

#[repr(transparent)]
pub struct TokenSlice([u32]);

impl std::ops::Deref for TokenSlice {
    type Target = [u32];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl std::borrow::Borrow<[u8]> for TokenSlice {
    fn borrow(&self) -> &[u8] {
        bytemuck::cast_slice(&self.0)
    }
}

impl Default for &TokenSlice {
    fn default() -> Self {
        <&[u32]>::default().as_token_slice()
    }
}

pub trait AsTokenSlice {
    fn as_token_slice(&self) -> &TokenSlice;
}

impl AsTokenSlice for [u32] {
    fn as_token_slice(&self) -> &TokenSlice {
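        // SAFETY: `TokenSlice` is `#[repr(transparent)]` over `[u32]`, so the
        // two types share the same layout and the cast below is sound; the
        // returned reference borrows from `self`.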
        let ptr = self as *const [u32] as *const TokenSlice;
        unsafe { &*ptr }
    }
}

#[derive(Derivative, Clone)]
#[derivative(Debug)]
pub struct GenerateContext {
    /// Tokens provided in the initial prompt.
    pub prompt_tokens: Vec<u32>,
    /// Whether the prompt has already been processed and cached.
    pub prompt_cached: CachedPrompt,
    /// Tokens that have been computed and cached.
    pub prefix: Tokens,
    /// Tokens to be computed.
    pub suffix: Tokens,
    /// The output of the model from the last run.
    pub output: Option<TensorCpu<f32>>,
    /// Tokens to be chosen if this is a choose request.
    pub choices: Vec<Tokens>,
    /// Text output by the model, as raw bytes.
    pub model_text: Vec<u8>,
    /// The model may output partial UTF-8; this buffer ensures the emitted text is always valid.
    pub buffer: Vec<u8>,
    /// Tokens that are output by the model.
    pub model_tokens: Vec<u32>,
    /// Compiled BNF schema, if any.
    #[derivative(Debug = "ignore")]
    pub formatters: Vec<Arc<RwLock<dyn Formatter + Send + Sync>>>,
    /// For measuring the time taken.
    pub instant: Option<Instant>,
    /// Generate request provided by the caller.
    pub request: GenerateRequest,
    /// To send back generated tokens.
    pub sender: Sender<Token>,
}

impl GenerateContext {
    pub async fn new(
        request: GenerateRequest,
        sender: Sender<Token>,
        tokenizer: &Tokenizer,
    ) -> Result<Self> {
        let tokens = Tokens(tokenizer.encode(request.prompt.as_bytes())?);
        let model_tokens = Tokens(tokenizer.encode(request.model_text.as_bytes())?);

        // initialize the sampler state here
        request.sampler.write().await.init(&model_tokens);

        let choices = match &request.kind {
            GenerateKind::Choose { choices, .. } => {
                let choices: Vec<_> = choices
                    .iter()
                    .map(|prompt| tokenizer.encode(prompt.as_bytes()))
                    .try_collect()?;
                choices.into_iter().map(Tokens).collect()
            }
            _ => Vec::new(),
        };
        Ok(Self {
            prompt_tokens: tokens.to_vec(),
            prompt_cached: Default::default(),
            prefix: Default::default(),
            suffix: tokens,
            output: None,
            choices,
            model_text: Vec::new(),
            buffer: Vec::new(),
            model_tokens: Vec::new(),
            formatters: Vec::new(),
            instant: None,
            request,
            sender,
        })
    }
}

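/// Caching status of a request's prompt. `Future` holds the watch sender that
/// is eventually filled with the cached state once the prompt has been fully
/// processed.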
#[derive(Debug, Default, Clone)]
pub enum CachedPrompt {
    #[default]
    None,
    Future(tokio::sync::watch::Sender<Option<CachedItem>>),
    Done,
}

/// An item that a cache slot holds, including a state, the last model output, and a timestamp.
#[derive(Debug, Clone)]
pub struct CachedItem {
    state: TensorCpu<f32>,
    output: TensorCpu<f32>,
    instant: Instant,
}

impl CachedItem {
    pub fn new(state: TensorCpu<f32>, output: TensorCpu<f32>) -> Self {
        Self {
            state,
            output,
            instant: Instant::now(),
        }
    }

    /// Update an existing cache item's timestamp.
    pub fn update(cached: CachedItem) -> Self {
        Self {
            instant: Instant::now(),
            ..cached
        }
    }
}

struct CacheCheckout {
    prefix: Vec<u32>,
    state: TensorCpu<f32>,
    output: Option<TensorCpu<f32>>,
}

#[derive(Debug, Default)]
struct Cache {
    state: Option<InitState>,
    cache: Trie<Tokens, tokio::sync::watch::Sender<Option<CachedItem>>>,
}

impl Cache {
    fn maintain(&mut self) {
        let cache = &mut self.cache;
        if cache.count() <= MAX_CACHE_ITEMS {
            return;
        }

        let mut remove = vec![];
        for (tokens, _) in cache
            .iter()
            .filter_map(|(tokens, item)| item.borrow().clone().map(|item| (tokens, item)))
            .sorted_unstable_by_key(|(_, item)| item.instant.elapsed())
            .skip(MAX_CACHE_ITEMS)
        {
            remove.push(tokens.to_owned());
        }

        for tokens in remove.into_iter() {
            cache.remove(&tokens);
        }
    }
}

#[derive(Debug, Default)]
struct CacheHub {
    backed: HashMap<StateId, Cache>,
    default: Cache,
}

impl CacheHub {
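    /// Fetch the cache bound to a state id, falling back to the shared default
    /// cache when the id is unknown.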
    fn fetch(&mut self, id: StateId) -> &mut Cache {
        match self.backed.get_mut(&id) {
            Some(item) => item,
            None => &mut self.default,
        }
    }
}

/// The result of trying to queue a task.
#[derive(Debug)]
enum SlotResult {
    /// There is an idle slot ready to be picked up.
    Success(usize),
    /// An idle slot is swapped out and reused.
    Fault(usize),
    /// There is no idle slot left.
    Failure(Box<GenerateContext>),
    /// An error occurred.
    Error(Box<dyn Error>),
}

#[derive(Debug)]
enum SlotState {
    /// The slot might be either picked up or swapped.
    Idle(Tokens, Instant),
    /// The slot is currently being processed.
    Busy(JoinHandle<Result<GenerateContext>>),
    /// The slot is locked for updating.
    Locked,
}

impl Default for SlotState {
    fn default() -> Self {
        Self::Idle(Default::default(), Instant::now())
    }
}

#[derive(Debug, PartialEq, Eq)]
enum SlotChoice {
    Continue(usize, usize),
    Back(usize),
    Empty(usize),
}

impl std::cmp::Ord for SlotChoice {
    fn cmp(&self, other: &Self) -> Ordering {
        // priority: continue > empty > back
        use SlotChoice::{Back, Continue, Empty};
        match (self, other) {
            (Continue(_, x), Continue(_, y)) => x.cmp(y),
            (Continue(_, _), _) => Ordering::Greater,
            (_, Continue(_, _)) => Ordering::Less,
            (Empty(_), Empty(_)) => Ordering::Equal,
            (Empty(_), Back(_)) => Ordering::Greater,
            (Back(_), Empty(_)) => Ordering::Less,
            (Back(_), Back(_)) => Ordering::Equal,
        }
    }
}

impl std::cmp::PartialOrd for SlotChoice {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

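/// Commands handled by the background inference task: run a chunk of tokens for
/// a batch slot, load or back up its CPU-side state, or write/read its GPU-side
/// state directly.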
#[derive(Debug, Clone)]
enum InferBatch {
    Run {
        batch: usize,
        tokens: Vec<u32>,
        option: RnnOption,
        sender: Sender<TensorCpu<f32>>,
    },
    Load {
        batch: usize,
        tensor: TensorCpu<f32>,
    },
    Back {
        batch: usize,
        sender: Sender<TensorCpu<f32>>,
    },
    Write {
        batch: usize,
        tensor: TensorGpu<f32, ReadWrite>,
    },
    Read {
        batch: usize,
        sender: Sender<TensorGpu<f32, ReadWrite>>,
    },
}

#[derive(Debug, Clone)]
struct SoftmaxBatch {
    input: TensorCpu<f32>,
    sender: Sender<TensorCpu<f32>>,
}

#[derive(Debug, Clone)]
struct RuntimeSender {
    infer: Sender<InferBatch>,
    softmax: Sender<SoftmaxBatch>,
}

#[derive(Derivative, Clone)]
#[derivative(Debug)]
struct CoreRuntime {
    context: Context,
    info: ModelInfo,
    reload: Arc<ReloadRequest>,
    #[derivative(Debug = "ignore")]
    state: Arc<dyn State + Send + Sync>,
    sender: RuntimeSender,
    tokenizer: Arc<Tokenizer>,
    slots: Arc<Mutex<Vec<SlotState>>>,
    caches: Arc<Mutex<CacheHub>>,
}

impl CoreRuntime {
    /// Check an input state into the cache.
    async fn check_in_state(&self, state: &InputState) -> Result<StateId> {
        match state {
            InputState::Key(id) => Ok(*id),
            InputState::Value(value) => {
                let id = value.id;
                let state = InitState::try_from(value.clone())?;
                let mut caches = self.caches.lock().await;
                caches.backed.insert(
                    id,
                    Cache {
                        state: Some(state),
                        cache: Trie::new(),
                    },
                );
                Ok(id)
            }
            InputState::File(file) => {
                let name = file.name.clone();
                let id = file.id;
                let default = false;

                let file = tokio::fs::File::open(&file.path).await?;
                let data = unsafe { Mmap::map(&file) }?;

                let st = SafeTensors::deserialize(&data);
                let prefab = cbor4ii::serde::from_slice::<InitState>(&data);
                let state = match (st, prefab) {
                    (Ok(model), _) => {
                        let data = load_model_state(&self.context, &self.info, model).await?;
                        InitState {
                            name,
                            id,
                            default,
                            data,
                        }
                    }
                    (_, Ok(state)) => state,
                    _ => bail!("failed to load init state"),
                };

                let mut caches = self.caches.lock().await;
                caches.backed.insert(
                    id,
                    Cache {
                        state: Some(state),
                        cache: Trie::new(),
                    },
                );

                Ok(id)
            }
        }
    }

    /// Search for the longest common prefix in the memory cache and check out the state from that point.
    /// Should there be a cache miss, an initial state is returned.
    async fn checkout(&self, id: StateId, tokens: &[u32]) -> CacheCheckout {
        let mut caches = self.caches.lock().await;

        let Cache { state, cache } = caches.fetch(id);
        let prefix = cache.longest_common_prefix(tokens.as_token_slice());
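        // The longest common prefix reported by the trie is not necessarily a
        // cached key itself; walk backwards to the longest prefix that actually
        // has an entry.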
        let len = (1..=prefix.len())
            .rev()
            .find(|len| cache.contains_key(prefix[0..*len].as_token_slice()))
            .unwrap_or_default();
        let prefix = prefix[0..len].to_vec();

        let state = state.clone().map(|state| state.data);
        let item = cache.get(prefix[..].as_token_slice()).cloned();
        drop(caches);

        match item {
            Some(sender) => {
                let mut receiver = sender.subscribe();
                let item = loop {
                    if let Some(item) = receiver.borrow_and_update().deref().clone() {
                        break item;
                    }
                    let _ = receiver.changed().await;
                };
                let item = CachedItem::update(item);
                sender.send_replace(Some(item.clone()));
                CacheCheckout {
                    prefix,
                    state: item.state,
                    output: Some(item.output),
                }
            }
            None => {
                let prefix = vec![];
                let state = state.unwrap_or_else(|| self.state.init());
                CacheCheckout {
                    prefix,
                    state,
                    output: None,
                }
            }
        }
    }

    /// Queue an inference task.
    async fn queue(&self, context: GenerateContext) -> SlotResult {
        let tokens = match [context.prefix, context.suffix].concat() {
            tokens if tokens.is_empty() => vec![0u32],
            tokens => tokens,
        };

        // compile the BNF schema.
        let mut formatters = Vec::<Arc<RwLock<dyn Formatter + Send + Sync>>>::new();
        if let Some(schema) = context.request.bnf_schema.clone() {
            match BnfSampler::new(&self.tokenizer, &schema) {
                Ok(bnf) => formatters.push(Arc::new(RwLock::new(bnf))),
                Err(err) => return SlotResult::Error(err.into()),
            }
        }

        // find the best idle slot by:
        // 1. find the slot that matches the context (continue)
        // 2. find an empty slot
        // 3. find the oldest non-empty slot
        let choice = {
            let mut slots = self.slots.lock().await;
            let choice = slots
                .iter()
                .enumerate()
                .filter_map(|(batch, slot)| match slot {
                    SlotState::Idle(content, instant) => {
                        let delta = instant.elapsed();
                        match (content.is_empty(), tokens.starts_with(content)) {
                            (true, _) => Some((SlotChoice::Empty(batch), delta)),
                            (_, true) => Some((SlotChoice::Continue(batch, content.len()), delta)),
                            (_, false) => Some((SlotChoice::Back(batch), delta)),
                        }
                    }
                    _ => None,
                })
                .max_by(|lhs, rhs| lhs.0.cmp(&rhs.0).then(lhs.1.cmp(&rhs.1)))
                .map(|(x, _)| x);
            match choice {
                None => (),
                Some(SlotChoice::Empty(batch))
                | Some(SlotChoice::Back(batch))
                | Some(SlotChoice::Continue(batch, _)) => slots[batch] = SlotState::Locked,
            }
            choice
        };

        let check_in_state = |state: Arc<InputState>, batch: usize| async move {
            match self.check_in_state(&state).await {
                Ok(state) => state,
                Err(err) => {
                    log::error!("[queue][state][slot: {batch}] error: {err:#?}");
                    Default::default()
                }
            }
        };

        match choice {
            // we cannot find a slot because all slots are occupied;
            // in this case, we hand the request back to the caller
            None => SlotResult::Failure(
                GenerateContext {
                    prefix: Default::default(),
                    suffix: Tokens(tokens),
                    formatters,
                    ..context
                }
                .into(),
            ),
            // back up an unrelated, non-empty slot and reuse it for the new context
            Some(SlotChoice::Back(batch)) => {
                log::info!("[queue][back][slot: {batch}]");
                let state = check_in_state(context.request.state.clone(), batch).await;
                let checkout = self.checkout(state, &tokens).await;
                self.load(batch, checkout.state).await;

                let len = checkout.prefix.len();
                assert!(len == 0 || (len > 0 && checkout.output.is_some()));
                log::info!("[cache][checkout][slot: {batch}][len: {len}]");

                let context = GenerateContext {
                    prefix: Tokens(tokens[..len].to_vec()),
                    suffix: Tokens(tokens[len..].to_vec()),
                    output: checkout.output,
                    formatters,
                    ..context
                };
                let handle = tokio::spawn(self.clone().process(batch, context));
                let mut slots = self.slots.lock().await;
                slots[batch] = SlotState::Busy(handle);
                SlotResult::Fault(batch)
            }
            // directly occupy an empty slot, so no backing is needed
            Some(SlotChoice::Empty(batch)) => {
                log::info!("[queue][empty][slot: {batch}]");
                let state = check_in_state(context.request.state.clone(), batch).await;
                let checkout = self.checkout(state, &tokens).await;
                self.load(batch, checkout.state).await;

                let len = checkout.prefix.len();
                assert!(len == 0 || (len > 0 && checkout.output.is_some()));
                log::info!("[cache][checkout][slot: {batch}][len: {len}]");

                let context = GenerateContext {
                    prefix: Tokens(tokens[..len].to_vec()),
                    suffix: Tokens(tokens[len..].to_vec()),
                    output: checkout.output,
                    formatters,
                    ..context
                };
                let handle = tokio::spawn(self.clone().process(batch, context));
                let mut slots = self.slots.lock().await;
                slots[batch] = SlotState::Busy(handle);
                SlotResult::Success(batch)
            }
            Some(SlotChoice::Continue(batch, ..)) => {
                log::info!("[queue][continue][slot: {batch}]");
                let state = check_in_state(context.request.state.clone(), batch).await;
                let checkout = self.checkout(state, &tokens).await;
                self.load(batch, checkout.state).await;

                let len = checkout.prefix.len();
                assert!(len == 0 || (len > 0 && checkout.output.is_some()));
                log::info!("[cache][checkout][slot: {batch}][len: {len}]");

                let context = GenerateContext {
                    prefix: Tokens(tokens[..len].to_vec()),
                    suffix: Tokens(tokens[len..].to_vec()),
                    output: checkout.output,
                    formatters,
                    ..context
                };
                let handle = tokio::spawn(self.clone().process(batch, context));
                let mut slots = self.slots.lock().await;
                slots[batch] = SlotState::Busy(handle);
                SlotResult::Success(batch)
            }
        }
    }

    /// Reset finished slots to `idle`. The prefix of a finished slot is kept so that it can be picked up later.
    async fn update(&self) {
        let update = |handle: JoinHandle<_>| async move {
            if !handle.is_finished() {
                return Ok(SlotState::Busy(handle));
            }

            let context = handle.await??;
            Ok::<_, Box<dyn Error + Send + Sync>>(SlotState::Idle(context.prefix, Instant::now()))
        };

        for batch in 0..self.reload.max_batch {
            let handle = {
                let mut slots = self.slots.lock().await;
                let slot = std::mem::replace(&mut slots[batch], SlotState::Locked);
                let SlotState::Busy(handle) = slot else {
                    slots[batch] = slot;
                    continue;
                };
                handle
            };

            let updated = match update(handle).await {
                Ok(updated) => updated,
                Err(err) => {
                    log::error!("[update][error][slot: {batch}] {err:#?}");
                    let mut slots = self.slots.lock().await;
                    slots[batch] = Default::default();
                    continue;
                }
            };

            let mut slots = self.slots.lock().await;
            slots[batch] = updated;
        }
    }

    async fn sample(
        &self,
        output: TensorCpu<f32>,
        sampler: Arc<RwLock<dyn Sampler + Send + Sync>>,
        formatters: Vec<Arc<RwLock<dyn Formatter + Send + Sync>>>,
        bias: Arc<HashMap<u32, f32>>,
    ) -> Result<(u32, TensorCpu<f32>)> {
        // process raw model outputs
        let num_vocab = self.info.num_vocab;
        let input = {
            let mut data = output.to_vec();
            assert_eq!(data.len(), num_vocab);

            sampler.read().await.transform(&mut data);
            for formatter in formatters {
                formatter.read().await.transform(&mut data);
            }
            for (token, bias) in bias.iter() {
                data[*token as usize] += *bias;
            }

            self.context.tensor_from_data([num_vocab, 1, 1, 1], data)?
        };

        // compute probabilities
        let (sender, receiver) = flume::bounded(1);
        let _ = self.sender.softmax.send(SoftmaxBatch { input, sender });
        let output = receiver.recv_async().await?;

        // sample tokens
        assert_eq!(output.len(), num_vocab);
        let token = sampler.write().await.sample(&output);
        Ok((token, output))
    }

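    /// Score `tokens` under the current state of `batch` by their average negative
    /// log-probability. If `head` is given it is used as the probability of the
    /// first token; otherwise token `0` is prepended as a dummy prompt.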
    async fn perplexity(&self, batch: usize, tokens: &[u32], head: Option<f32>) -> Result<f32> {
        let mut p = Vec::with_capacity(tokens.len().max(1));
        let len = tokens.len();
        let tokens = match head {
            Some(head) => {
                p.push(head);
                tokens.to_vec()
            }
            None => [&[0], tokens].concat(),
        };

        let (sender, receiver) = flume::unbounded();
        let _ = self
            .sender
            .infer
            .send_async({
                let tokens = tokens.clone();
                let option = RnnOption::Full;
                InferBatch::Run {
                    batch,
                    tokens,
                    option,
                    sender,
                }
            })
            .await;

        let index = Arc::new(Mutex::new(1));
        while p.len() < len {
            let tokens = tokens.clone();
            let output = receiver.recv_async().await?;
            let output = output.split(1)?;
            let f = {
                let index = index.clone();
                move || {
                    let mut index = index.blocking_lock();
                    let mut p = Vec::with_capacity(output.len());
                    for data in output {
                        if *index < tokens.len() {
                            let data = data.map(|x| x.exp()).to_vec();
                            let sum: f32 = data.iter().sum();
                            let token = tokens[*index] as usize;
                            p.push(data[token] / sum);
                        }
                        *index += 1;
                    }
                    p
                }
            };
            let mut q = tokio::task::spawn_blocking(f).await?;
            p.append(&mut q);
        }

        let ppl: f32 = p.into_iter().map(|x| x.ln()).sum();
        let ppl = -ppl / tokens.len() as f32;
        Ok(ppl)
    }

    async fn load(&self, batch: usize, tensor: TensorCpu<f32>) {
        let _ = self
            .sender
            .infer
            .send_async(InferBatch::Load { batch, tensor })
            .await;
    }

    async fn back(&self, batch: usize) -> Result<TensorCpu<f32>> {
        let (sender, receiver) = flume::bounded(1);
        let _ = self.sender.infer.send(InferBatch::Back { batch, sender });
        let tensor = receiver.recv_async().await?;
        Ok(tensor)
    }

    async fn write(&self, batch: usize, tensor: TensorGpu<f32, ReadWrite>) {
        let _ = self
            .sender
            .infer
            .send_async(InferBatch::Write { batch, tensor })
            .await;
    }

    async fn read(&self, batch: usize) -> Result<TensorGpu<f32, ReadWrite>> {
        let (sender, receiver) = flume::bounded(1);
        let _ = self.sender.infer.send(InferBatch::Read { batch, sender });
        let tensor = receiver.recv_async().await?;
        Ok(tensor)
    }

    /// Read in the prompt of a batch and continuously sample it until it is done.
    async fn process(self, batch: usize, mut context: GenerateContext) -> Result<GenerateContext> {
        // schedule a future cache slot for the prompt
        {
            let mut caches = self.caches.lock().await;
            let cache = &mut caches.fetch(context.request.state.id()).cache;

            let enable = context.prompt_tokens.len() > MIN_PROMPT_CACHE_TOKENS;
            let enable = enable && !cache.contains_key(context.prompt_tokens.as_token_slice());
            if enable {
                let (sender, _) = tokio::sync::watch::channel(None);
                context.prompt_cached = CachedPrompt::Future(sender.clone());
                cache.insert(Tokens(context.prompt_tokens.clone()), sender);

                let len = context.prompt_tokens.len();
                log::info!("[cache][future][slot: {batch}][len: {len}]");
            }
        }

        let _ = context.sender.send(Token::Start);

        loop {
            let output = match (context.suffix.len(), context.output.clone()) {
                (0, Some(output)) => output,
                _ => {
                    let (sender, receiver) = flume::bounded(1);
                    let _ = self
                        .sender
                        .infer
                        .send_async(InferBatch::Run {
                            batch,
                            tokens: context.suffix.to_vec(),
                            option: RnnOption::Last,
                            sender,
                        })
                        .await;

                    let prefix = std::mem::take(&mut context.prefix);
                    let suffix = std::mem::take(&mut context.suffix);

                    context.prefix = Tokens([prefix.0, suffix.0].concat());
                    context.suffix = Tokens(vec![]);

                    receiver.recv_async().await?
                }
            };

            // cache the prompt if a future cache slot was scheduled for it
            if let CachedPrompt::Future(sender) = context.prompt_cached.clone() {
                assert_eq!(context.prefix.len(), context.prompt_tokens.len());

                let backed = self.back(batch).await?;
                let output = output.clone();
                sender.send_replace(Some(CachedItem::new(backed, output)));
                context.prompt_cached = CachedPrompt::Done;

                let len = context.prefix.len();
                log::info!("[cache][insert][slot: {batch}][len: {len}]");
            }

            let (token, output) = {
                let output = output.clone();
                let sampler = context.request.sampler.clone();
                let formatters = context.formatters.clone();
                let bias = context.request.bias.clone();
                self.sample(output, sampler, formatters, bias).await?
            };

            let mut stop_token = token == 0;
            let mut word = match self.tokenizer.decode(&[token]) {
                Ok(word) => word,
                Err(err) => {
                    log::warn!("[process][error] {err:#?}");
                    stop_token = true;
                    Vec::new()
                }
            };

            context.output = Some(output.clone());
            context.suffix.0.push(token);
            context.model_tokens.push(token);
            context.model_text.extend(&word);
            context.buffer.append(&mut word);

            let instant = context.instant.get_or_insert(Instant::now());
            let mut done = false;
            let mut stop = |reason| {
                let counter = {
                    let prompt = context.prompt_tokens.len();
                    let completion = context.model_tokens.len();
                    let total = prompt + completion;
                    let duration = instant.elapsed();
                    TokenCounter {
                        prompt,
                        completion,
                        total,
                        duration,
                    }
                };

                let _ = context.sender.send(Token::Stop(reason, counter));
                let _ = context.sender.send(Token::Done);
                done = true;
            };

            // update the formatter (BNF) state
            let mut halt = false;
            for formatter in context.formatters.iter() {
                let mut formatter = formatter.write().await;
                halt |= formatter.update(token);
            }

            // here we detect if there is a stop word in our buffer
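            // `head` is the part of the buffer that is safe to emit now; `tail`
            // may still be the beginning of a stop word and stays buffered.
            // `stop_matched` is true when a stop word has fully appeared.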
            let ((head, tail), stop_matched) = context
                .request
                .stop
                .iter()
                .map(|stop| {
                    let stop = stop.as_bytes();
                    let mut index_safe = 0;
                    let mut index_unsafe = 0;
                    while index_unsafe < context.buffer.len() {
                        // the maximum match of the current stop string
                        let index_stop = index_unsafe - index_safe;
                        if index_stop >= stop.len() {
                            // we have a total match
                            return (index_safe, true);
                        }

                        let output = context.buffer[index_unsafe];
                        let stop = stop[index_stop];

                        index_unsafe += 1;
                        if output != stop {
                            index_safe = index_unsafe;
                        }
                    }
                    (index_safe, index_unsafe - index_safe >= stop.len())
                })
                .min_by(|x, y| match (x.1, y.1) {
                    (true, false) => Ordering::Less,
                    (false, true) => Ordering::Greater,
                    _ => x.0.cmp(&y.0),
                })
                .map(|(mid, matched)| (context.buffer.split_at(mid), matched))
                .unwrap_or(((&context.buffer[..], &[]), false));

            if context.sender.is_disconnected() {
                done = true;
            } else if let GenerateKind::Choose { calibrate, .. } = context.request.kind {
                let backed = self.read(batch).await?;
                let mut ppl = vec![f32::INFINITY; context.choices.len()];

                if calibrate {
                    // compute perplexities of the choices themselves and calibrate their effects
                    let init = {
                        let id = context.request.state.id();
                        let mut caches = self.caches.lock().await;
                        caches
                            .fetch(id)
                            .state
                            .clone()
                            .map(|state| state.data)
                            .unwrap_or_else(|| self.state.init())
                    };
                    for (index, choice) in context
                        .choices
                        .iter()
                        .enumerate()
                        .filter(|(_, choice)| !choice.is_empty())
                    {
                        self.load(batch, init.clone()).await;
                        ppl[index] = -self.perplexity(batch, choice, None).await?;
                    }
                    // recover the state
                    self.write(batch, backed.clone()).await;
                }

                for (index, choice) in context
                    .choices
                    .iter()
                    .enumerate()
                    .filter(|(_, choice)| !choice.is_empty())
                {
                    let output = output.clone().to_vec();
                    let head = Some(output[choice[0] as usize]);
                    let p = self.perplexity(batch, choice, head).await?;
                    ppl[index] = match calibrate {
                        true => ppl[index] + p,
                        false => p,
                    };
                    // recover the state
                    self.write(batch, backed.clone()).await;
                }

                let _ = context.sender.send(Token::Choose(ppl));
                done = true;
            } else if let GenerateKind::State = context.request.kind {
                let backed = self.back(batch).await?;
                let embed = backed.to_vec();
                let shape = backed.shape().into();
                let _ = context.sender.send(Token::Embed(embed, shape));
                done = true;
            } else if halt || stop_matched || stop_token {
                let output = String::from_utf8_lossy(head);
                let _ = context.sender.send(Token::Content(output.into()));
                stop(FinishReason::Stop);

                if let Some(output) = context.output.clone() {
                    let backed = self.back(batch).await?;
                    let mut caches = self.caches.lock().await;
                    let cache = &mut caches.fetch(context.request.state.id()).cache;
                    let item = CachedItem::new(backed, output);
                    let (item, _) = tokio::sync::watch::channel(Some(item));
                    cache.insert(context.prefix.clone(), item);

                    let len = context.prefix.len();
                    log::info!("[cache][insert][slot: {batch}][len: {len}]");
                }
            } else if context.model_tokens.len() >= context.request.max_tokens {
                stop(FinishReason::Length);
            } else if let Ok(word) = String::from_utf8(head.to_vec()) {
                let _ = context.sender.send(Token::Content(word));
                context.buffer = tail.to_vec();
            }

            if done {
                log::info!("[process][done][slot: {batch}]");
                break;
            }
        }

        Ok(context)
    }

    /// Keep the number of items in each cache at or below [`MAX_CACHE_ITEMS`].
    async fn maintain_cache(&self) {
        let mut caches = self.caches.lock().await;
        caches.default.maintain();
        caches.backed.iter_mut().for_each(|(_, x)| x.maintain());
    }
}

async fn enqueue(runtime: CoreRuntime, receiver: Receiver<GenerateContext>, timer: Duration) {
    let mut queue = Vec::<GenerateContext>::new();

    'outer: while let Ok(context) = receiver.recv_async().await {
        queue.push(context);

        'inner: loop {
            runtime.maintain_cache().await;
            runtime.update().await;

            let mut temp = Vec::new();
            for context in queue.drain(..) {
                match runtime.queue(context).await {
                    SlotResult::Failure(context) => temp.push(*context),
                    SlotResult::Success(batch) => log::info!("[enqueue][ok][slot: {batch}]"),
                    SlotResult::Fault(batch) => log::info!("[enqueue][fault][slot: {batch}]"),
                    SlotResult::Error(err) => log::error!("[enqueue][error] {err:#?}"),
                }
            }
            std::mem::swap(&mut queue, &mut temp);

            if queue.is_empty() {
                break 'inner;
            }

            match receiver.try_recv() {
                Ok(context) => queue.push(context),
                Err(TryRecvError::Empty) => tokio::time::sleep(timer).await,
                Err(TryRecvError::Disconnected) => break 'outer,
            }
        }
    }
}

async fn finalize(runtime: CoreRuntime, receiver: Receiver<GenerateContext>, timer: Duration) {
    while !receiver.is_disconnected() {
        runtime.maintain_cache().await;
        runtime.update().await;
        tokio::time::sleep(timer).await;
    }
}

async fn infer(
    reload: Arc<ReloadRequest>,
    runtime: Weak<dyn Runtime<Rnn> + Send + Sync>,
    state: Arc<dyn State + Send + Sync>,
    receiver: Receiver<InferBatch>,
) -> Result<()> {
    type Batch = (Vec<u32>, RnnOption, Sender<TensorCpu<f32>>);
    let mut batches: HashMap<usize, VecDeque<Batch>> = HashMap::new();

    async fn schedule(
        batches: &mut HashMap<usize, VecDeque<Batch>>,
        state: Arc<dyn State + Send + Sync>,
        batch: InferBatch,
    ) -> Result<()> {
        match batch {
            InferBatch::Run {
                batch,
                tokens,
                option,
                sender,
            } => match batches.get_mut(&batch) {
                Some(batches) => batches.push_back((tokens, option, sender)),
                None => {
                    let deque = VecDeque::from_iter([(tokens, option, sender)]);
                    batches.insert(batch, deque);
                }
            },
            InferBatch::Load { batch, tensor } => state.load(tensor, batch)?,
            InferBatch::Back { batch, sender } => {
                let tensor = state.back(batch).await?;
                let _ = sender.send_async(tensor).await;
            }
            InferBatch::Write { batch, tensor } => state.write(tensor, batch)?,
            InferBatch::Read { batch, sender } => {
                let tensor = state.read(batch)?;
                let _ = sender.send_async(tensor).await;
            }
        }
        Ok(())
    }

    'outer: while let Ok(batch) = receiver.recv_async().await {
        schedule(&mut batches, state.clone(), batch).await?;

        for batch in receiver.drain() {
            schedule(&mut batches, state.clone(), batch).await?;
        }

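        // Drain at most one queued request per batch slot, run them together as
        // a single `RnnInput`, and route each slot's output back to the sender
        // of the request that produced it.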
        while batches.values().map(|x| x.len()).sum::<usize>() > 0 {
            let mut inference = vec![Default::default(); reload.max_batch];
            let mut senders = HashMap::new();

            for (&batch, deque) in batches.iter_mut() {
                let Some((tokens, option, sender)) = deque.pop_front() else {
                    continue;
                };
                inference[batch] = RnnInputBatch::new(tokens, option);
                senders.insert(batch, sender);
            }

            let mut inference = Some(RnnInput::new(inference, reload.token_chunk_size));

            while inference
                .as_ref()
                .map(|input| input.num_token() > 0)
                .expect("inference must not be `None`")
            {
                let Some(runtime) = runtime.upgrade() else {
                    break 'outer;
                };
                let input = inference.take().expect("inference must not be `None`");
                let (input, output) = runtime.infer(input).await?;
                inference.replace(input);

                for (batch, RnnOutputBatch(output)) in output
                    .iter()
                    .enumerate()
                    .filter(|(_, output)| !output.is_empty())
                {
                    let Some(sender) = senders.get(&batch) else {
                        continue;
                    };
                    let _ = sender.send(output.clone());
                }
            }
        }
    }

    log::info!("[infer] exit");
    Ok(())
}

async fn softmax(
    reload: Arc<ReloadRequest>,
    context: Context,
    receiver: Receiver<SoftmaxBatch>,
) -> Result<()> {
    let mut batches = Vec::with_capacity(reload.max_batch);

    while let Ok(batch) = receiver.recv_async().await {
        batches.push(batch);

        for batch in receiver.drain() {
            batches.push(batch);
        }

        let input = batches.iter().map(|batch| batch.input.clone()).collect();
        let output = web_rwkv::runtime::softmax::softmax(&context, input).await?;

        for (batch, tensor) in batches.iter().zip_eq(output.into_iter()) {
            let _ = batch.sender.send(tensor);
        }

        batches.clear();
    }

    log::info!("[softmax] exit");
    Ok(())
}

pub async fn run(
    context: Context,
    runtime: Weak<dyn Runtime<Rnn> + Send + Sync>,
    state: Arc<dyn State + Send + Sync>,
    receiver: Receiver<GenerateContext>,
    RuntimeInfo {
        reload,
        info,
        states,
        tokenizer,
        ..
    }: RuntimeInfo,
) {
    let slots = std::iter::repeat_with(Default::default)
        .take(reload.max_batch)
        .collect();
    let slots = Arc::new(Mutex::new(slots));

    let caches = {
        let mut caches = CacheHub::default();
        // set up default initial state
        if let Some(state) = states.iter().find(|state| state.default) {
            caches.default.state = Some(state.clone());
        }
        // set up other initial states with ids
        for state in states {
            let id = state.id;
            let item = Cache {
                state: Some(state),
                cache: Trie::new(),
            };
            caches.backed.insert(id, item);
        }
        Arc::new(Mutex::new(caches))
    };

    let max_batch = reload.max_batch;
    let runtime = {
        let infer = {
            let (sender, receiver) = flume::unbounded();
            tokio::spawn(infer(reload.clone(), runtime, state.clone(), receiver));
            sender
        };
        let softmax = {
            let (sender, receiver) = flume::unbounded();
            tokio::spawn(softmax(reload.clone(), context.clone(), receiver));
            sender
        };
        let sender = RuntimeSender { infer, softmax };
        CoreRuntime {
            context,
            info,
            reload,
            state,
            sender,
            tokenizer,
            slots,
            caches,
        }
    };
    let timer = Duration::from_secs_f32(1.0);
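    // Spawn one scheduling loop per batch slot (all sharing the same request
    // channel), plus a task that keeps maintaining caches and reaping finished
    // slots until the channel is disconnected.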
    for _ in 0..max_batch {
        tokio::spawn(enqueue(runtime.clone(), receiver.clone(), timer));
    }
    tokio::spawn(finalize(runtime, receiver, timer));
}
1257}