// File: cake_core/models/qwen2/qwen.rs

1use anyhow::Result;
2use async_trait::async_trait;
3
4use super::QwenHistory;
5use crate::models::common::text_model::TextModelBase;
6use crate::models::common::Transformer;
7use crate::models::TextGenerator;
8use crate::{
9    cake::Context,
10    models::{chat::Message, Generator, Token},
11};
12
/// End-of-stream token used as a fallback when the model configuration
/// does not specify one.
const DEFAULT_EOS_TOKEN: &str = "<|endoftext|>";
15
16/// Qwen2/Qwen2.5 main class.
17pub struct Qwen2 {
18    base: TextModelBase,
19    history: QwenHistory,
20}
21
22#[async_trait]
23impl Generator for Qwen2 {
24    type Shardable = Transformer;
25    const MODEL_NAME: &'static str = "qwen2";
26
27    /// Load this model from the context.
28    async fn load(ctx: &mut Context) -> Result<Option<Box<Self>>> {
29        let base = TextModelBase::load::<Transformer>(ctx, DEFAULT_EOS_TOKEN).await?;
30        let history = QwenHistory::new();
31        Ok(Some(Box::new(Self { base, history })))
32    }
33}
34
35#[async_trait]
36impl TextGenerator for Qwen2 {
37    /// Add a message to the chat history.
38    fn add_message(&mut self, message: Message) -> Result<()> {
39        self.history.push(message);
40        Ok(())
41    }
42
43    /// Reset the chat pipeline state.
44    fn reset(&mut self) -> Result<()> {
45        self.history.clear();
46        self.base.reset();
47        Ok(())
48    }
49
50    async fn goodbye(&mut self) -> Result<()> {
51        self.base.goodbye().await
52    }
53
54    /// Return the next token.
55    async fn next_token(&mut self, index: usize) -> Result<Token> {
56        // Prefill tokens with chat history the first time.
57        if self.base.generated == 0 {
58            let dialog = self.history.encode_dialog_to_prompt();
59            self.base.prepare_prompt(&dialog)?;
60        }
61        self.base.next_token(index).await
62    }
63
64    /// Return the number of generated tokens so far.
65    fn generated_tokens(&self) -> usize {
66        self.base.generated
67    }
68}