Skip to main content

anymoe_lora/
main.rs

1use anyhow::Result;
2use mistralrs::{
3    AnyMoeConfig, AnyMoeExpertType, AnyMoeModelBuilder, IsqType, PagedAttentionMetaBuilder,
4    TextMessageRole, TextMessages, TextModelBuilder,
5};
6
7#[tokio::main]
8async fn main() -> Result<()> {
9    let text_builder = TextModelBuilder::new("mistralai/Mistral-7B-Instruct-v0.1")
10        .with_isq(IsqType::Q8_0)
11        .with_logging()
12        .with_paged_attn(|| PagedAttentionMetaBuilder::default().build())?;
13
14    let model = AnyMoeModelBuilder::from_text_builder(
15        text_builder,
16        AnyMoeConfig {
17            hidden_size: 4096,
18            lr: 1e-3,
19            epochs: 100,
20            batch_size: 4,
21            expert_type: AnyMoeExpertType::FineTuned,
22            gate_model_id: None, // Set this to Some("path/to/model/id") for the pretrained gating model id
23            training: true,
24            loss_csv_path: None,
25        },
26        "model.layers",
27        "mlp",
28        "examples/amoe.json",
29        vec!["typeof/zephyr-7b-beta-lora"],
30        vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
31    )
32    .build()
33    .await?;
34
35    let messages = TextMessages::new()
36        .add_message(
37            TextMessageRole::System,
38            "You are an AI agent with a specialty in programming.",
39        )
40        .add_message(
41            TextMessageRole::User,
42            "Hello! How are you? Please write generic binary search function in Rust.",
43        );
44
45    let response = model.send_chat_request(messages).await?;
46
47    println!("{}", response.choices[0].message.content.as_ref().unwrap());
48    dbg!(
49        response.usage.avg_prompt_tok_per_sec,
50        response.usage.avg_compl_tok_per_sec
51    );
52
53    Ok(())
54}