1use anyhow::Result;
2use mistralrs::{
3 AnyMoeConfig, AnyMoeExpertType, AnyMoeModelBuilder, IsqType, PagedAttentionMetaBuilder,
4 TextMessageRole, TextMessages, TextModelBuilder,
5};
6
7#[tokio::main]
8async fn main() -> Result<()> {
9 let text_builder = TextModelBuilder::new("mistralai/Mistral-7B-Instruct-v0.1")
10 .with_isq(IsqType::Q8_0)
11 .with_logging()
12 .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?;
13
14 let model = AnyMoeModelBuilder::from_text_builder(
15 text_builder,
16 AnyMoeConfig {
17 hidden_size: 4096,
18 lr: 1e-3,
19 epochs: 100,
20 batch_size: 4,
21 expert_type: AnyMoeExpertType::FineTuned,
22 gate_model_id: None, training: true,
24 loss_csv_path: None,
25 },
26 "model.layers",
27 "mlp",
28 "examples/amoe.json",
29 vec!["typeof/zephyr-7b-beta-lora"],
30 vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
31 )
32 .build()
33 .await?;
34
35 let messages = TextMessages::new()
36 .add_message(
37 TextMessageRole::System,
38 "You are an AI agent with a specialty in programming.",
39 )
40 .add_message(
41 TextMessageRole::User,
42 "Hello! How are you? Please write generic binary search function in Rust.",
43 );
44
45 let response = model.send_chat_request(messages).await?;
46
47 println!("{}", response.choices[0].message.content.as_ref().unwrap());
48 dbg!(
49 response.usage.avg_prompt_tok_per_sec,
50 response.usage.avg_compl_tok_per_sec
51 );
52
53 Ok(())
54}