Skip to main content

multi_model/
main.rs

1use anyhow::{anyhow, Result};
2use mistralrs::{
3    IsqType, MultiModelBuilder, TextMessageRole, TextMessages, TextModelBuilder, VisionModelBuilder,
4};
5
// Model IDs - these are the actual HuggingFace model paths
/// Model path handed to `VisionModelBuilder::new` in `main`.
const GEMMA_MODEL_ID: &str = "google/gemma-3-4b-it";
/// Model path handed to `TextModelBuilder::new` in `main`.
const QWEN_MODEL_ID: &str = "Qwen/Qwen3-4B";
// Aliases - these are the short IDs used in API requests
/// Alias the Gemma vision model is registered under; also the initial default model in `main`.
const GEMMA_ALIAS: &str = "gemma-vision";
/// Alias the Qwen text model is registered under; used to address it explicitly in `main`.
const QWEN_ALIAS: &str = "qwen-text";
12
13#[tokio::main]
14async fn main() -> Result<()> {
15    println!("Loading multiple models...");
16
17    let model = MultiModelBuilder::new()
18        .add_model_with_alias(
19            GEMMA_ALIAS,
20            VisionModelBuilder::new(GEMMA_MODEL_ID)
21                .with_isq(IsqType::Q4K)
22                .with_logging(),
23        )
24        .add_model_with_alias(
25            QWEN_ALIAS,
26            TextModelBuilder::new(QWEN_MODEL_ID).with_isq(IsqType::Q4K),
27        )
28        .with_default_model(GEMMA_ALIAS)
29        .build()
30        .await?;
31
32    // List available models
33    println!("\n=== Available Models ===");
34    let models = model.list_models().map_err(|e| anyhow!(e))?;
35    for model_id in &models {
36        println!("  - {}", model_id);
37    }
38
39    // Get the default model
40    let default_model = model.get_default_model_id().map_err(|e| anyhow!(e))?;
41    println!("\nDefault model: {:?}", default_model);
42
43    // List models with their status
44    println!("\n=== Model Status ===");
45    let status = model.list_models_with_status()?;
46    for (model_id, status) in &status {
47        println!("  {} -> {:?}", model_id, status);
48    }
49
50    // Send a request to the default model (Gemma - vision model)
51    println!("\n=== Request to Default Model ({}) ===", GEMMA_ALIAS);
52    let messages =
53        TextMessages::new().add_message(TextMessageRole::User, "What is 2 + 2? Answer briefly.");
54
55    let response = model.send_chat_request(messages).await?;
56    println!(
57        "Response: {}",
58        response.choices[0].message.content.as_ref().unwrap()
59    );
60
61    // Send a request to a specific model (Qwen - text model)
62    println!("\n=== Request to Specific Model ({}) ===", QWEN_ALIAS);
63    let messages = TextMessages::new().add_message(TextMessageRole::User, "Say hello in one word.");
64
65    let response = model
66        .send_chat_request_with_model(messages, Some(QWEN_ALIAS))
67        .await?;
68    println!(
69        "Response: {}",
70        response.choices[0].message.content.as_ref().unwrap()
71    );
72
73    // Change the default model
74    println!("\n=== Changing Default Model ===");
75    model
76        .set_default_model_id(QWEN_ALIAS)
77        .map_err(|e| anyhow!(e))?;
78    let new_default = model.get_default_model_id().map_err(|e| anyhow!(e))?;
79    println!("New default model: {:?}", new_default);
80
81    // Now requests without model_id go to Qwen
82    let messages =
83        TextMessages::new().add_message(TextMessageRole::User, "What is your name? Be brief.");
84
85    let response = model.send_chat_request(messages).await?;
86    println!(
87        "Response from new default: {}",
88        response.choices[0].message.content.as_ref().unwrap()
89    );
90
91    // Model unloading/reloading demonstration
92    println!("\n=== Model Unloading/Reloading ===");
93
94    // Check if Gemma is loaded
95    let is_gemma_loaded = model.is_model_loaded(GEMMA_ALIAS)?;
96    println!("Is '{}' loaded? {}", GEMMA_ALIAS, is_gemma_loaded);
97
98    // Unload Gemma to free memory
99    println!("Unloading '{}' model...", GEMMA_ALIAS);
100    model.unload_model(GEMMA_ALIAS)?;
101
102    // Check status after unload
103    let status = model.list_models_with_status()?;
104    println!("Status after unload:");
105    for (model_id, status) in &status {
106        println!("  {} -> {:?}", model_id, status);
107    }
108
109    // Reload Gemma when needed
110    println!("Reloading '{}' model...", GEMMA_ALIAS);
111    model.reload_model(GEMMA_ALIAS).await?;
112
113    // Check status after reload
114    let is_gemma_loaded = model.is_model_loaded(GEMMA_ALIAS)?;
115    println!(
116        "Is '{}' loaded after reload? {}",
117        GEMMA_ALIAS, is_gemma_loaded
118    );
119
120    // Use the reloaded model
121    let messages =
122        TextMessages::new().add_message(TextMessageRole::User, "Hi! Respond with just 'Hello'.");
123
124    let response = model
125        .send_chat_request_with_model(messages, Some(GEMMA_ALIAS))
126        .await?;
127    println!(
128        "Response from reloaded {}: {}",
129        GEMMA_ALIAS,
130        response.choices[0].message.content.as_ref().unwrap()
131    );
132
133    println!("\n=== Multi-Model Example Complete ===");
134
135    Ok(())
136}