Skip to main content

cake_core/models/qwen3_5/
model.rs

1use anyhow::Result;
2use async_trait::async_trait;
3
4use crate::models::common::chatml_history::ChatMLHistory;
5use crate::models::common::text_model::TextModelBase;
6use crate::models::TextGenerator;
7use crate::{
8    cake::Context,
9    models::{chat::Message, Generator, Token},
10};
11
12use super::block::Qwen3_5Block;
13
/// Fallback end-of-stream marker (the ChatML `<|im_end|>` token), passed to
/// `TextModelBase::load` for use when the model configuration does not name one.
const DEFAULT_EOS_TOKEN: &str = "<|im_end|>";
16
17/// Qwen3.5 main class.
18pub struct Qwen3_5 {
19    base: TextModelBase,
20    history: ChatMLHistory,
21}
22
23#[async_trait]
24impl Generator for Qwen3_5 {
25    type Shardable = Qwen3_5Block;
26    const MODEL_NAME: &'static str = "qwen3_5";
27
28    /// Load this model from the context.
29    async fn load(ctx: &mut Context) -> Result<Option<Box<Self>>> {
30        let base = TextModelBase::load::<Qwen3_5Block>(ctx, DEFAULT_EOS_TOKEN).await?;
31        let history = ChatMLHistory::new();
32        Ok(Some(Box::new(Self { base, history })))
33    }
34}
35
36#[async_trait]
37impl TextGenerator for Qwen3_5 {
38    /// Add a message to the chat history.
39    fn add_message(&mut self, message: Message) -> Result<()> {
40        self.history.push(message);
41        Ok(())
42    }
43
44    /// Reset the chat pipeline state.
45    fn reset(&mut self) -> Result<()> {
46        self.history.clear();
47        self.base.reset();
48        Ok(())
49    }
50
51    async fn goodbye(&mut self) -> Result<()> {
52        self.base.goodbye().await
53    }
54
55    /// Return the next token.
56    async fn next_token(&mut self, index: usize) -> Result<Token> {
57        // Prefill tokens with chat history the first time.
58        if self.base.generated == 0 {
59            let dialog = self.history.encode_dialog_to_prompt();
60            self.base.prepare_prompt(&dialog)?;
61        }
62        self.base.next_token(index).await
63    }
64
65    /// Return the number of generated tokens so far.
66    fn generated_tokens(&self) -> usize {
67        self.base.generated
68    }
69}