cake_core/models/qwen2/
qwen.rs1use anyhow::Result;
2use async_trait::async_trait;
3
4use super::QwenHistory;
5use crate::models::common::text_model::TextModelBase;
6use crate::models::common::Transformer;
7use crate::models::TextGenerator;
8use crate::{
9 cake::Context,
10 models::{chat::Message, Generator, Token},
11};
12
/// End-of-sequence token string supplied to `TextModelBase::load` when a
/// [`Qwen2`] model is loaded.
const DEFAULT_EOS_TOKEN: &str = "<|endoftext|>";
15
16pub struct Qwen2 {
18 base: TextModelBase,
19 history: QwenHistory,
20}
21
22#[async_trait]
23impl Generator for Qwen2 {
24 type Shardable = Transformer;
25 const MODEL_NAME: &'static str = "qwen2";
26
27 async fn load(ctx: &mut Context) -> Result<Option<Box<Self>>> {
29 let base = TextModelBase::load::<Transformer>(ctx, DEFAULT_EOS_TOKEN).await?;
30 let history = QwenHistory::new();
31 Ok(Some(Box::new(Self { base, history })))
32 }
33}
34
35#[async_trait]
36impl TextGenerator for Qwen2 {
37 fn add_message(&mut self, message: Message) -> Result<()> {
39 self.history.push(message);
40 Ok(())
41 }
42
43 fn reset(&mut self) -> Result<()> {
45 self.history.clear();
46 self.base.reset();
47 Ok(())
48 }
49
50 async fn goodbye(&mut self) -> Result<()> {
51 self.base.goodbye().await
52 }
53
54 async fn next_token(&mut self, index: usize) -> Result<Token> {
56 if self.base.generated == 0 {
58 let dialog = self.history.encode_dialog_to_prompt();
59 self.base.prepare_prompt(&dialog)?;
60 }
61 self.base.next_token(index).await
62 }
63
64 fn generated_tokens(&self) -> usize {
66 self.base.generated
67 }
68}