1use crate::cli::{ModelsCommands, ModelsPathCommands, ModelsTagsCommands};
4use crate::{chat, config, debug_log};
5use anyhow::Result;
6use colored::Colorize;
7
8pub async fn handle(
10 command: Option<ModelsCommands>,
11 query: Option<String>,
12 tags: Option<String>,
13 context_length: Option<u64>,
14 input_length: Option<u64>,
15 output_length: Option<u64>,
16 input_price: Option<f64>,
17 output_price: Option<f64>,
18) -> Result<()> {
19 let context_length_str = context_length.map(|v| v.to_string());
21 let input_length_str = input_length.map(|v| v.to_string());
22 let output_length_str = output_length.map(|v| v.to_string());
23
24 handle_models_command(
25 command,
26 query,
27 tags,
28 context_length_str,
29 input_length_str,
30 output_length_str,
31 input_price,
32 output_price,
33 )
34 .await
35}
36
37async fn handle_models_command(
39 command: Option<ModelsCommands>,
40 query: Option<String>,
41 tags: Option<String>,
42 context_length: Option<String>,
43 input_length: Option<String>,
44 output_length: Option<String>,
45 input_price: Option<f64>,
46 output_price: Option<f64>,
47) -> Result<()> {
48 match command {
49 Some(ModelsCommands::Refresh) => {
50 crate::unified_cache::UnifiedCache::refresh_all_providers().await?;
51 }
52 Some(ModelsCommands::Info) => {
53 debug_log!("Handling models info command");
54
55 let models_dir = crate::unified_cache::UnifiedCache::models_dir()?;
56 debug_log!("Models cache directory: {}", models_dir.display());
57
58 println!("\n{}", "Models Cache Information:".bold().blue());
59 println!("Cache Directory: {}", models_dir.display());
60
61 if !models_dir.exists() {
62 debug_log!("Cache directory does not exist");
63 println!("Status: No cache directory found");
64 return Ok(());
65 }
66
67 let entries = std::fs::read_dir(&models_dir)?;
68 let mut provider_count = 0;
69 let mut total_models = 0;
70
71 debug_log!("Reading cache directory entries");
72
73 let mut provider_info = Vec::new();
75 for entry in entries {
76 let entry = entry?;
77 let path = entry.path();
78
79 if let Some(extension) = path.extension() {
80 if extension == "json" {
81 if let Some(provider_name) = path.file_stem().and_then(|s| s.to_str()) {
82 debug_log!("Processing cache file for provider: {}", provider_name);
83 provider_count += 1;
84 match crate::unified_cache::UnifiedCache::load_provider_models(
85 provider_name,
86 )
87 .await
88 {
89 Ok(models) => {
90 let count = models.len();
91 total_models += count;
92 debug_log!(
93 "Provider '{}' has {} cached models",
94 provider_name,
95 count
96 );
97
98 let age_display =
99 crate::unified_cache::UnifiedCache::get_cache_age_display(
100 provider_name,
101 )
102 .await
103 .unwrap_or_else(|_| "Unknown".to_string());
104 let is_fresh =
105 crate::unified_cache::UnifiedCache::is_cache_fresh(
106 provider_name,
107 )
108 .await
109 .unwrap_or(false);
110 debug_log!(
111 "Provider '{}' cache age: {}, fresh: {}",
112 provider_name,
113 age_display,
114 is_fresh
115 );
116
117 let status = if is_fresh {
118 age_display.green()
119 } else {
120 format!("{} (expired)", age_display).red()
121 };
122 provider_info.push((provider_name.to_string(), count, status));
123 }
124 Err(e) => {
125 debug_log!(
126 "Error loading cache for provider '{}': {}",
127 provider_name,
128 e
129 );
130 provider_info.push((
131 provider_name.to_string(),
132 0,
133 "Error loading cache".red(),
134 ));
135 }
136 }
137 }
138 }
139 }
140 }
141
142 debug_log!("Sorting {} providers alphabetically", provider_info.len());
143
144 provider_info.sort_by(|a, b| a.0.cmp(&b.0));
146
147 println!("\nCached Providers:");
148 for (provider_name, count, status) in provider_info {
149 if count > 0 {
150 println!(
151 " {} {} - {} models ({})",
152 "•".blue(),
153 provider_name.bold(),
154 count,
155 status
156 );
157 } else {
158 println!(" {} {} - {}", "•".blue(), provider_name.bold(), status);
159 }
160 }
161
162 debug_log!(
163 "Cache summary: {} providers, {} total models",
164 provider_count,
165 total_models
166 );
167
168 println!("\nSummary:");
169 println!(" Providers: {}", provider_count);
170 println!(" Total Models: {}", total_models);
171 }
172 Some(ModelsCommands::Dump) => {
173 dump_models_data().await?;
174 }
175 Some(ModelsCommands::Embed) => {
176 debug_log!("Handling embedding models command");
177
178 debug_log!("Loading all cached models from unified cache");
180 let enhanced_models =
181 crate::unified_cache::UnifiedCache::load_all_cached_models().await?;
182
183 debug_log!("Loaded {} models from cache", enhanced_models.len());
184
185 if enhanced_models.is_empty() {
187 debug_log!("No cached models found, refreshing all providers");
188 println!("No cached models found. Refreshing all providers...");
189 crate::unified_cache::UnifiedCache::refresh_all_providers().await?;
190 let enhanced_models =
191 crate::unified_cache::UnifiedCache::load_all_cached_models().await?;
192
193 debug_log!("After refresh, loaded {} models", enhanced_models.len());
194
195 if enhanced_models.is_empty() {
196 debug_log!("Still no models found after refresh");
197 println!("No models found after refresh.");
198 return Ok(());
199 }
200 }
201
202 debug_log!("Filtering for embedding models");
203
204 let embedding_models: Vec<_> = enhanced_models
206 .into_iter()
207 .filter(|model| {
208 matches!(
209 model.model_type,
210 crate::model_metadata::ModelType::Embedding
211 )
212 })
213 .collect();
214
215 debug_log!("Found {} embedding models", embedding_models.len());
216
217 if embedding_models.is_empty() {
218 println!("No embedding models found.");
219 return Ok(());
220 }
221
222 debug_log!("Displaying {} embedding models", embedding_models.len());
224 display_embedding_models(&embedding_models)?;
225 }
226 Some(ModelsCommands::Path { command }) => match command {
227 ModelsPathCommands::List => {
228 crate::model_metadata::list_model_paths()?;
229 }
230 ModelsPathCommands::Add { path } => {
231 crate::model_metadata::add_model_path(path)?;
232 }
233 ModelsPathCommands::Delete { path } => {
234 crate::model_metadata::remove_model_path(path)?;
235 }
236 },
237 Some(ModelsCommands::Tags { command }) => {
238 match command {
239 ModelsTagsCommands::List => {
240 crate::model_metadata::list_tags()?;
241 }
242 ModelsTagsCommands::Add { tag, rule } => {
243 crate::model_metadata::add_tag(tag, vec![rule], "string".to_string(), None)?;
245 }
246 }
247 }
248 Some(ModelsCommands::Filter { tags: filter_tags }) => {
249 let models = crate::unified_cache::UnifiedCache::load_all_cached_models().await?;
251
252 let required_tags: Vec<&str> = filter_tags.split(',').map(|s| s.trim()).collect();
254
255 let filtered: Vec<_> = models
257 .into_iter()
258 .filter(|model| {
259 for tag in &required_tags {
260 match *tag {
261 "tools" => {
262 if !model.supports_tools && !model.supports_function_calling {
263 return false;
264 }
265 }
266 "vision" => {
267 if !model.supports_vision {
268 return false;
269 }
270 }
271 "audio" => {
272 if !model.supports_audio {
273 return false;
274 }
275 }
276 "reasoning" => {
277 if !model.supports_reasoning {
278 return false;
279 }
280 }
281 "code" => {
282 if !model.supports_code {
283 return false;
284 }
285 }
286 _ => {
287 if tag.starts_with("ctx") {
289 if let Some(ctx) = model.context_length {
290 if tag.contains('>') {
291 if let Some(min_str) = tag.split('>').nth(1) {
292 if let Ok(min_ctx) = parse_token_count(min_str) {
293 if ctx < min_ctx {
294 return false;
295 }
296 }
297 }
298 }
299 }
300 }
301 }
302 }
303 }
304 true
305 })
306 .collect();
307
308 if filtered.is_empty() {
309 println!("No models found with tags: {}", filter_tags);
310 } else {
311 println!(
312 "\n{} Models with tags [{}] ({} found):",
313 "Filtered Results:".bold().blue(),
314 filter_tags,
315 filtered.len()
316 );
317
318 let mut current_provider = String::new();
319 for model in filtered {
320 if model.provider != current_provider {
321 current_provider = model.provider.clone();
322 println!("\n{}", format!("{}:", current_provider).bold().green());
323 }
324
325 print!(" {} {}", "•".blue(), model.id.bold());
326
327 let mut capabilities = Vec::new();
329 if model.supports_tools || model.supports_function_calling {
330 capabilities.push("🔧 tools".blue());
331 }
332 if model.supports_vision {
333 capabilities.push("👁 vision".magenta());
334 }
335 if model.supports_audio {
336 capabilities.push("🔊 audio".yellow());
337 }
338 if model.supports_reasoning {
339 capabilities.push("🧠 reasoning".cyan());
340 }
341 if model.supports_code {
342 capabilities.push("💻 code".green());
343 }
344
345 if !capabilities.is_empty() {
346 let capability_strings: Vec<String> =
347 capabilities.iter().map(|c| c.to_string()).collect();
348 print!(" [{}]", capability_strings.join(" "));
349 }
350
351 if let Some(ctx) = model.context_length {
353 if ctx >= 1000 {
354 print!(" ({}k ctx)", ctx / 1000);
355 } else {
356 print!(" ({} ctx)", ctx);
357 }
358 }
359
360 println!();
361 }
362 }
363 }
364 None => {
365 debug_log!("Handling global models command");
366
367 debug_log!("Loading all cached models from unified cache");
369 let enhanced_models =
370 crate::unified_cache::UnifiedCache::load_all_cached_models().await?;
371
372 debug_log!("Loaded {} models from cache", enhanced_models.len());
373
374 if enhanced_models.is_empty() {
376 debug_log!("No cached models found, refreshing all providers");
377 println!("No cached models found. Refreshing all providers...");
378 crate::unified_cache::UnifiedCache::refresh_all_providers().await?;
379 let enhanced_models =
380 crate::unified_cache::UnifiedCache::load_all_cached_models().await?;
381
382 debug_log!("After refresh, loaded {} models", enhanced_models.len());
383
384 if enhanced_models.is_empty() {
385 debug_log!("Still no models found after refresh");
386 println!("No models found after refresh.");
387 return Ok(());
388 }
389 }
390
391 debug_log!("Applying filters to {} models", enhanced_models.len());
392
393 let tag_filters = if let Some(ref tag_str) = tags {
395 let tags_vec: Vec<String> =
396 tag_str.split(',').map(|s| s.trim().to_string()).collect();
397 Some(tags_vec)
398 } else {
399 None
400 };
401
402 let filtered_models = apply_model_filters_with_tags(
404 enhanced_models,
405 &query,
406 tag_filters,
407 &context_length,
408 &input_length,
409 &output_length,
410 input_price,
411 output_price,
412 )?;
413
414 debug_log!("After filtering, {} models remain", filtered_models.len());
415
416 if filtered_models.is_empty() {
417 debug_log!("No models match the specified criteria");
418 println!("No models found matching the specified criteria.");
419 return Ok(());
420 }
421
422 debug_log!("Displaying {} filtered models", filtered_models.len());
424 display_enhanced_models(&filtered_models, &query)?;
425 }
426 }
427
428 Ok(())
429}
430
431async fn dump_models_data() -> Result<()> {
433 println!("{} Dumping /models for each provider...", "🔍".blue());
434
435 let config = config::Config::load()?;
437
438 std::fs::create_dir_all("models")?;
440
441 let mut successful_dumps = 0;
442 let mut total_providers = 0;
443
444 for (provider_name, provider_config) in &config.providers {
445 total_providers += 1;
446
447 if provider_config.api_key.is_none() {
449 println!("{} Skipping {} (no API key)", "⚠️".yellow(), provider_name);
450 continue;
451 }
452
453 println!("{} Fetching models from {}...", "📡".blue(), provider_name);
454
455 let mut config_mut = config.clone();
457 match chat::create_authenticated_client(&mut config_mut, provider_name).await {
458 Ok(client) => {
459 match fetch_raw_models_response(&client, provider_config).await {
461 Ok(raw_response) => {
462 let filename = format!("models/{}.json", provider_name);
464 match std::fs::write(&filename, &raw_response) {
465 Ok(_) => {
466 println!(
467 "{} Saved {} models data to {}",
468 "✅".green(),
469 provider_name,
470 filename
471 );
472 successful_dumps += 1;
473 }
474 Err(e) => {
475 println!(
476 "{} Failed to save {} models data: {}",
477 "❌".red(),
478 provider_name,
479 e
480 );
481 }
482 }
483 }
484 Err(e) => {
485 println!(
486 "{} Failed to fetch models from {}: {}",
487 "❌".red(),
488 provider_name,
489 e
490 );
491 }
492 }
493 }
494 Err(e) => {
495 println!(
496 "{} Failed to create client for {}: {}",
497 "❌".red(),
498 provider_name,
499 e
500 );
501 }
502 }
503 }
504
505 println!("\n{} Summary:", "📊".blue());
506 println!(" Total providers: {}", total_providers);
507 println!(" Successful dumps: {}", successful_dumps);
508 println!(" Models data saved to: ./models/");
509
510 if successful_dumps > 0 {
511 println!("\n{} Model data collection complete!", "🎉".green());
512 println!(" Next step: Analyze the JSON files to extract metadata patterns");
513 }
514
515 Ok(())
516}
517
518fn apply_model_filters_with_tags(
519 models: Vec<crate::model_metadata::ModelMetadata>,
520 query: &Option<String>,
521 tag_filters: Option<Vec<String>>,
522 context_length: &Option<String>,
523 input_length: &Option<String>,
524 output_length: &Option<String>,
525 input_price: Option<f64>,
526 output_price: Option<f64>,
527) -> Result<Vec<crate::model_metadata::ModelMetadata>> {
528 let mut filtered = models;
529
530 if let Some(ref search_query) = query {
532 let query_lower = search_query.to_lowercase();
533 filtered.retain(|model| {
534 model.id.to_lowercase().contains(&query_lower)
535 || model
536 .display_name
537 .as_ref()
538 .map_or(false, |name| name.to_lowercase().contains(&query_lower))
539 || model
540 .description
541 .as_ref()
542 .map_or(false, |desc| desc.to_lowercase().contains(&query_lower))
543 });
544 }
545
546 if let Some(tags) = tag_filters {
548 for tag in tags {
549 match tag.as_str() {
550 "tools" => {
551 filtered
552 .retain(|model| model.supports_tools || model.supports_function_calling);
553 }
554 "reasoning" => {
555 filtered.retain(|model| model.supports_reasoning);
556 }
557 "vision" => {
558 filtered.retain(|model| model.supports_vision);
559 }
560 "audio" => {
561 filtered.retain(|model| model.supports_audio);
562 }
563 "code" => {
564 filtered.retain(|model| model.supports_code);
565 }
566 _ => {
567 }
569 }
570 }
571 }
572
573 if let Some(ref ctx_str) = context_length {
575 let min_ctx = parse_token_count(ctx_str)?;
576 filtered.retain(|model| model.context_length.map_or(false, |ctx| ctx >= min_ctx));
577 }
578
579 if let Some(ref input_str) = input_length {
581 let min_input = parse_token_count(input_str)?;
582 filtered.retain(|model| {
583 model
584 .max_input_tokens
585 .map_or(false, |input| input >= min_input)
586 || model.context_length.map_or(false, |ctx| ctx >= min_input)
587 });
588 }
589
590 if let Some(ref output_str) = output_length {
592 let min_output = parse_token_count(output_str)?;
593 filtered.retain(|model| {
594 model
595 .max_output_tokens
596 .map_or(false, |output| output >= min_output)
597 });
598 }
599
600 if let Some(max_input_price) = input_price {
602 filtered.retain(|model| {
603 model
604 .input_price_per_m
605 .map_or(true, |price| price <= max_input_price)
606 });
607 }
608
609 if let Some(max_output_price) = output_price {
610 filtered.retain(|model| {
611 model
612 .output_price_per_m
613 .map_or(true, |price| price <= max_output_price)
614 });
615 }
616
617 filtered.sort_by(|a, b| a.provider.cmp(&b.provider).then(a.id.cmp(&b.id)));
619
620 Ok(filtered)
621}
622
623fn parse_token_count(input: &str) -> Result<u32> {
624 let input = input.to_lowercase();
625 if let Some(num_str) = input.strip_suffix('k') {
626 let num: f32 = num_str
627 .parse()
628 .map_err(|_| anyhow::anyhow!("Invalid token count format: '{}'", input))?;
629 Ok((num * 1000.0) as u32)
630 } else if let Some(num_str) = input.strip_suffix('m') {
631 let num: f32 = num_str
632 .parse()
633 .map_err(|_| anyhow::anyhow!("Invalid token count format: '{}'", input))?;
634 Ok((num * 1000000.0) as u32)
635 } else {
636 input
637 .parse()
638 .map_err(|_| anyhow::anyhow!("Invalid token count format: '{}'", input))
639 }
640}
641
642fn display_enhanced_models(
643 models: &[crate::model_metadata::ModelMetadata],
644 query: &Option<String>,
645) -> Result<()> {
646 if let Some(ref search_query) = query {
647 println!(
648 "\n{} Models matching '{}' ({} found):",
649 "Search Results:".bold().blue(),
650 search_query,
651 models.len()
652 );
653 } else {
654 println!(
655 "\n{} Available models ({} total):",
656 "Models:".bold().blue(),
657 models.len()
658 );
659 }
660
661 let mut current_provider = String::new();
662 for model in models {
663 if model.provider != current_provider {
664 current_provider = model.provider.clone();
665 println!("\n{}", format!("{}:", current_provider).bold().green());
666 }
667
668 let mut capabilities = Vec::new();
670 if model.supports_tools || model.supports_function_calling {
671 capabilities.push("🔧 tools".blue());
672 }
673 if model.supports_vision {
674 capabilities.push("👁 vision".magenta());
675 }
676 if model.supports_audio {
677 capabilities.push("🔊 audio".yellow());
678 }
679 if model.supports_reasoning {
680 capabilities.push("🧠 reasoning".cyan());
681 }
682 if model.supports_code {
683 capabilities.push("💻 code".green());
684 }
685
686 let mut context_info = Vec::new();
688 if let Some(ctx) = model.context_length {
689 context_info.push(format!("{}k ctx", ctx / 1000));
690 }
691 if let Some(max_out) = model.max_output_tokens {
692 context_info.push(format!("{}k out", max_out / 1000));
693 }
694
695 let model_display = if let Some(ref display_name) = model.display_name {
697 format!("{} ({})", model.id, display_name)
698 } else {
699 model.id.clone()
700 };
701
702 print!(" {} {}", "•".blue(), model_display.bold());
703
704 if !capabilities.is_empty() {
705 let capability_strings: Vec<String> =
706 capabilities.iter().map(|c| c.to_string()).collect();
707 print!(" [{}]", capability_strings.join(" "));
708 }
709
710 if !context_info.is_empty() {
711 print!(" ({})", context_info.join(", ").dimmed());
712 }
713
714 println!();
715 }
716
717 Ok(())
718}
719
/// Fetch the raw body of a provider's `/models` endpoint.
///
/// Returns the body pretty-printed when it parses as JSON, otherwise verbatim.
///
/// # Errors
/// Fails when the HTTP client cannot be built, the request errors out, the
/// server responds with a non-success status (the response body is included
/// in the error), or no authentication can be derived (no API key and no
/// custom headers).
async fn fetch_raw_models_response(
    // NOTE(review): `_client` is unused — a fresh reqwest client is built
    // below. Consider reusing the authenticated client instead.
    _client: &crate::chat::LLMClient,
    provider_config: &crate::config::ProviderConfig,
) -> Result<String> {
    use serde_json::Value;

    // Dedicated HTTP client: 60 s overall timeout, 10 s to connect, with
    // connection pooling and TCP keepalive configured explicitly.
    let http_client = reqwest::Client::builder()
        .pool_max_idle_per_host(10)
        .pool_idle_timeout(std::time::Duration::from_secs(90))
        .tcp_keepalive(std::time::Duration::from_secs(60))
        .timeout(std::time::Duration::from_secs(60))
        .connect_timeout(std::time::Duration::from_secs(10))
        .build()?;

    let url = provider_config.get_models_url();

    debug_log!("Making API request to: {}", url);
    debug_log!("Request timeout: 60 seconds");

    let mut req = http_client
        .get(&url)
        .header("Content-Type", "application/json");

    debug_log!("Added Content-Type: application/json header");

    // Custom headers (if configured) take full control of authentication:
    // when any are present, the default Bearer Authorization header is
    // deliberately skipped below.
    let mut has_custom_headers = false;
    for (name, value) in &provider_config.headers {
        debug_log!("Adding custom header: {}: {}", name, value);
        req = req.header(name, value);
        has_custom_headers = true;
    }

    if !has_custom_headers {
        if let Some(api_key) = provider_config.api_key.as_ref() {
            req = req.header("Authorization", format!("Bearer {}", api_key));
            debug_log!("Added Authorization header with API key");
        } else {
            // No way to authenticate this request — bail rather than send an
            // anonymous request that will almost certainly be rejected.
            debug_log!("No API key configured and no custom headers provided; cannot add Authorization header");
            anyhow::bail!("No API key configured and no custom headers set for models request");
        }
    } else {
        debug_log!("Skipping Authorization header due to custom headers present");
    }

    debug_log!("Sending HTTP GET request...");
    let response = req.send().await?;

    let status = response.status();
    debug_log!("Received response with status: {}", status);

    if !status.is_success() {
        // Surface the error body to the caller; ignore failures reading it.
        let text = response.text().await.unwrap_or_default();
        debug_log!("API request failed with error response: {}", text);
        anyhow::bail!("API request failed with status {}: {}", status, text);
    }

    let response_text = response.text().await?;
    debug_log!("Received response body ({} bytes)", response_text.len());

    // Pretty-print JSON bodies for readability; pass anything else through.
    match serde_json::from_str::<Value>(&response_text) {
        Ok(json_value) => {
            debug_log!("Response is valid JSON, pretty-printing");
            Ok(serde_json::to_string_pretty(&json_value)?)
        }
        Err(_) => {
            debug_log!("Response is not valid JSON, returning as-is");
            Ok(response_text)
        }
    }
}
796
797fn display_embedding_models(models: &[crate::model_metadata::ModelMetadata]) -> Result<()> {
799 println!(
800 "\n{} Available embedding models ({} total):",
801 "Embedding Models:".bold().blue(),
802 models.len()
803 );
804
805 let mut current_provider = String::new();
806 for model in models {
807 if model.provider != current_provider {
808 current_provider = model.provider.clone();
809 println!("\n{}", format!("{}:", current_provider).bold().green());
810 }
811
812 let mut info_parts = Vec::new();
814 if let Some(ctx) = model.context_length {
815 if ctx >= 1000000 {
816 info_parts.push(format!("{}m ctx", ctx / 1000000));
817 } else if ctx >= 1000 {
818 info_parts.push(format!("{}k ctx", ctx / 1000));
819 } else {
820 info_parts.push(format!("{} ctx", ctx));
821 }
822 }
823 if let Some(input_price) = model.input_price_per_m {
824 info_parts.push(format!("${:.2}/M", input_price));
825 }
826
827 let model_display = if let Some(ref display_name) = model.display_name {
829 if display_name != &model.id {
830 format!("{} ({})", model.id, display_name)
831 } else {
832 model.id.clone()
833 }
834 } else {
835 model.id.clone()
836 };
837
838 print!(" {} {}", "•".blue(), model_display.bold());
839
840 if !info_parts.is_empty() {
841 print!(" ({})", info_parts.join(", ").dimmed());
842 }
843
844 println!();
845 }
846
847 Ok(())
848}