use flyllm::{
ProviderType, LlmManager, GenerationRequest, LlmManagerResponse, TaskDefinition, LlmResult,
use_logging, ModelDiscovery, ModelInfo
};
use std::env;
use std::path::PathBuf;
use std::time::Instant;
use std::collections::HashMap;
use futures::future::join_all;
use log::info;
/// Entry point for the task-routing example.
///
/// Reads the Anthropic, OpenAI, Mistral and Google API keys from the environment
/// and prints the models each provider reports. It then builds an `LlmManager`:
/// named tasks carry generation settings, and each provider instance declares the
/// tasks it supports. The same batch of requests is run twice, first with
/// `generate_sequentially` and then with `batch_generate`. The function prints
/// both result sets, a timing comparison, and the token usage.
///
/// # Panics
/// Panics if any of `ANTHROPIC_API_KEY`, `OPENAI_API_KEY`, `MISTRAL_API_KEY` or
/// `GOOGLE_API_KEY` is not set.
///
/// # Errors
/// Propagates the error returned by the manager builder's `build()` if
/// construction fails.
#[tokio::main]
async fn main() -> LlmResult<()> {
    // Force debug-level logging for this example. It is set before `use_logging()`,
    // presumably so the logger reads it when it initialises — TODO confirm.
    env::set_var("RUST_LOG", "debug");
    use_logging();
    info!("Starting Task Routing Example");

    let anthropic_api_key = env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set");
    let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
    let mistral_api_key = env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
    let google_api_key = env::var("GOOGLE_API_KEY").expect("GOOGLE_API_KEY not set");

    print_available_models(&anthropic_api_key, &openai_api_key, &mistral_api_key, &google_api_key)
        .await;

    // Tasks carry per-task generation settings. Each `add_instance` is followed by
    // the tasks that instance declares support for. Requests tagged with a task
    // are presumably routed only to instances registered for it — verify against
    // the flyllm docs.
    let manager = LlmManager::builder()
        .define_task(
            TaskDefinition::new("summary")
                .with_max_tokens(500)
                .with_param("temperature", 0.3),
        )
        .define_task(
            TaskDefinition::new("creative_writing")
                .with_max_tokens(1500)
                .with_temperature(0.9),
        )
        .define_task(TaskDefinition::new("code_generation"))
        .define_task(
            TaskDefinition::new("short_poem")
                .with_max_tokens(100)
                .with_temperature(0.8),
        )
        .add_instance(ProviderType::Mistral, "mistral-large-latest", &mistral_api_key)
        .supports("summary")
        .supports("code_generation")
        .add_instance(ProviderType::Anthropic, "claude-3-sonnet-20240229", &anthropic_api_key)
        .supports("summary")
        .supports("creative_writing")
        .supports("code_generation")
        .add_instance(ProviderType::Anthropic, "claude-3-opus-20240229", &anthropic_api_key)
        .supports_many(&["creative_writing", "short_poem"])
        .add_instance(ProviderType::Google, "gemini-2.0-flash", &google_api_key)
        .supports("short_poem")
        .add_instance(ProviderType::OpenAI, "gpt-3.5-turbo", &openai_api_key)
        .supports("summary")
        .debug_folder(PathBuf::from("debug_folder"))
        .build()
        .await?;

    let provider_count = manager.get_provider_count().await;
    info!("LlmManager configured with {} providers.", provider_count);

    // `print_results` labels responses by position, so the order of this list
    // must match its label list.
    let requests = vec![
        GenerationRequest::builder(
            "Summarize the following text: Climate change refers to long-term shifts...",
        )
        .task("summary")
        .build(),
        GenerationRequest::builder("Write a short story about a robot discovering emotions.")
            .task("creative_writing")
            .build(),
        GenerationRequest::builder(
            "Write a Python function that calculates the Fibonacci sequence up to n terms.",
        )
        .task("code_generation")
        .build(),
        // Uses the creative_writing task but sets `max_tokens(50)` on the request,
        // presumably overriding the task's 1500 for this request only.
        GenerationRequest::builder("Write a VERY short poem about the rain.")
            .task("creative_writing")
            .max_tokens(50)
            .build(),
        GenerationRequest::builder("Write a rust program to sum two input numbers via console.")
            .task("code_generation")
            .build(),
        GenerationRequest::builder("Craft a haiku about a silent dawn.")
            .task("short_poem")
            .build(),
    ];
    info!("Defined {} requests using builder pattern.", requests.len());

    println!("\n=== Running requests sequentially... ===");
    let sequential_start = Instant::now();
    // Cloned because `requests` is consumed again by `batch_generate` below.
    let sequential_results = manager.generate_sequentially(requests.clone()).await;
    let sequential_duration = sequential_start.elapsed();
    println!("Sequential processing completed in {:?}", sequential_duration);
    print_results(&sequential_results);

    println!("\n=== Running requests in parallel... ===");
    let parallel_start = Instant::now();
    let parallel_results = manager.batch_generate(requests).await;
    let parallel_duration = parallel_start.elapsed();
    println!("Parallel processing completed in {:?}", parallel_duration);
    print_results(&parallel_results);

    info!("Task Routing Example Finished.");

    println!("\n--- Comparison ---");
    println!("Sequential Duration: {:?}", sequential_duration);
    println!("Parallel Duration: {:?}", parallel_duration);
    if parallel_duration < sequential_duration && parallel_duration.as_nanos() > 0 {
        let speedup = sequential_duration.as_secs_f64() / parallel_duration.as_secs_f64();
        println!("Parallel execution was roughly {:.2}x faster.", speedup);
    } else if parallel_duration >= sequential_duration {
        println!("Parallel execution was not faster (or was equal) in this run.");
    } else {
        // Only reachable when the parallel run measured as exactly zero, which
        // would make the speedup division meaningless.
        println!("Parallel execution finished too quickly to measure speedup reliably.");
    }

    manager.print_token_usage().await;
    Ok(())
}
/// Lists the models reported by each provider and prints them as a
/// `PROVIDER` / `MODEL NAME` table.
///
/// The Anthropic, OpenAI, Mistral, Google and Ollama discovery calls run
/// concurrently as spawned tokio tasks. A failing provider does not abort the
/// listing; its failure is printed to stdout instead:
/// - `Ok(Err(_))` is an error returned by the discovery call itself.
/// - `Err(_)` is a `JoinError`, meaning the spawned task did not complete normally.
///
/// Ollama is queried with `None`, presumably meaning its default local endpoint —
/// TODO confirm.
async fn print_available_models(
    anthropic_api_key: &str,
    openai_api_key: &str,
    mistral_api_key: &str,
    google_api_key: &str,
) {
    const PROVIDER_COL: usize = 15;
    const MODEL_COL: usize = 40;
    // Header is "<provider col> <model col>", so the rule spans both plus the separating space.
    const RULE_WIDTH: usize = PROVIDER_COL + 1 + MODEL_COL;

    println!("\n=== AVAILABLE MODELS ===");

    // `tokio::spawn` needs `'static` futures, so each key is moved in as an owned String.
    let anthropic_key = anthropic_api_key.to_string();
    let openai_key = openai_api_key.to_string();
    let mistral_key = mistral_api_key.to_string();
    let google_key = google_api_key.to_string();

    // `providers[k]` labels `handles[k]`; the two lists are paired by `zip` below,
    // so keep them in the same order.
    let providers = [
        ProviderType::Anthropic,
        ProviderType::OpenAI,
        ProviderType::Mistral,
        ProviderType::Google,
        ProviderType::Ollama,
    ];
    let handles = vec![
        tokio::spawn(async move { ModelDiscovery::list_anthropic_models(&anthropic_key).await }),
        tokio::spawn(async move { ModelDiscovery::list_openai_models(&openai_key).await }),
        tokio::spawn(async move { ModelDiscovery::list_mistral_models(&mistral_key).await }),
        tokio::spawn(async move { ModelDiscovery::list_google_models(&google_key).await }),
        tokio::spawn(async { ModelDiscovery::list_ollama_models(None).await }),
    ];

    let results = join_all(handles).await;

    let mut models_by_provider: HashMap<ProviderType, Vec<ModelInfo>> = HashMap::new();
    for (provider, result) in providers.iter().copied().zip(results) {
        match result {
            Ok(Ok(models)) => {
                models_by_provider.insert(provider, models);
            }
            Ok(Err(e)) => println!("Error fetching {} models: {}", provider, e),
            Err(e) => println!("Task error fetching {} models: {}", provider, e),
        }
    }

    println!(
        "\n{:<pw$} {:<mw$}",
        "PROVIDER",
        "MODEL NAME",
        pw = PROVIDER_COL,
        mw = MODEL_COL
    );
    println!("{}", "=".repeat(RULE_WIDTH));
    // Iterate `providers` rather than the map so output order is stable.
    for provider in providers.iter() {
        if let Some(models) = models_by_provider.get(provider) {
            for model in models {
                // `to_string()` first so the padding applies even if `Display` ignores width.
                println!(
                    "{:<pw$} {:<mw$}",
                    provider.to_string(),
                    model.name,
                    pw = PROVIDER_COL,
                    mw = MODEL_COL
                );
            }
            println!("{}", "-".repeat(RULE_WIDTH));
        }
    }
}
/// Prints one labelled entry per response in `results`.
///
/// Labels are assigned by position from a fixed list. The list follows the order
/// of the requests built in `main`; responses past its end are labelled
/// "Unknown Task".
///
/// A successful response shows at most the first `PREVIEW_CHARS` characters of its
/// content, with "..." appended when the content is longer. A failed response
/// shows its error message, or "Unknown error" if none was recorded.
fn print_results(results: &[LlmManagerResponse]) {
    // Maximum number of content characters (not bytes) shown per successful response.
    const PREVIEW_CHARS: usize = 150;
    const TASK_LABELS: [&str; 6] = [
        "Summary Request",
        "Creative Writing Request",
        "Code Generation Request",
        "Short Poem Request (Override)",
        "Rust Code Request",
        "Haiku Request",
    ];

    println!("\n--- Request Results ---");
    for (i, result) in results.iter().enumerate() {
        let task_label = TASK_LABELS.get(i).copied().unwrap_or("Unknown Task");
        println!("{}:", task_label);
        if result.success {
            // Char-based truncation so multi-byte UTF-8 content is never split mid-character.
            let content_preview: String = result.content.chars().take(PREVIEW_CHARS).collect();
            // `nth` stops after PREVIEW_CHARS + 1 chars instead of counting the whole string.
            let ellipsis = if result.content.chars().nth(PREVIEW_CHARS).is_some() {
                "..."
            } else {
                ""
            };
            println!(" Success: {}{}\n", content_preview, ellipsis);
        } else {
            // Borrow as &str so the fallback needs no allocation.
            println!(" Error: {}\n", result.error.as_deref().unwrap_or("Unknown error"));
        }
    }
}