1use crate::provider::Provider;
2use anyhow::{Context, Result};
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6use std::fs;
7use std::path::PathBuf;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct ModelMetadata {
11 pub id: String,
12 pub provider: String,
13 pub display_name: Option<String>,
14 pub description: Option<String>,
15 pub owned_by: Option<String>,
16 pub created: Option<i64>,
17
18 pub context_length: Option<u32>,
20 pub max_input_tokens: Option<u32>,
21 pub max_output_tokens: Option<u32>,
22
23 pub input_price_per_m: Option<f64>,
25 pub output_price_per_m: Option<f64>,
26
27 pub supports_tools: bool,
30 pub supports_vision: bool,
32 pub supports_audio: bool,
34 pub supports_reasoning: bool,
36 pub supports_code: bool,
38 pub supports_function_calling: bool,
40 pub supports_json_mode: bool,
42 pub supports_streaming: bool,
44
45 pub model_type: ModelType,
47 pub is_deprecated: bool,
49 pub is_fine_tunable: bool,
51
52 pub raw_data: serde_json::Value,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub enum ModelType {
58 Chat,
59 Completion,
60 Embedding,
61 ImageGeneration,
62 AudioGeneration,
63 Moderation,
64 Other(String),
65}
66
67impl Default for ModelMetadata {
68 fn default() -> Self {
69 Self {
70 id: String::new(),
71 provider: String::new(),
72 display_name: None,
73 description: None,
74 owned_by: None,
75 created: None,
76 context_length: None,
77 max_input_tokens: None,
78 max_output_tokens: None,
79 input_price_per_m: None,
80 output_price_per_m: None,
81 supports_tools: false,
82 supports_vision: false,
83 supports_audio: false,
84 supports_reasoning: false,
85 supports_code: false,
86 supports_function_calling: false,
87 supports_json_mode: false,
88 supports_streaming: false,
89 model_type: ModelType::Chat,
90 is_deprecated: false,
91 is_fine_tunable: false,
92 raw_data: serde_json::Value::Null,
93 }
94 }
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct ModelPaths {
100 pub paths: Vec<String>,
101 #[serde(default)]
102 pub field_mappings: FieldMappings,
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct FieldMappings {
107 pub id_fields: Vec<String>,
109 pub name_fields: Vec<String>,
111}
112
113impl Default for FieldMappings {
114 fn default() -> Self {
115 Self {
116 id_fields: vec![
117 "id".to_string(),
118 "modelId".to_string(),
119 "name".to_string(),
120 "modelName".to_string(),
121 ],
122 name_fields: vec![
123 "display_name".to_string(),
124 "name".to_string(),
125 "modelName".to_string(),
126 ],
127 }
128 }
129}
130
131impl Default for ModelPaths {
132 fn default() -> Self {
133 Self {
134 paths: vec![
135 ".data[]".to_string(),
136 ".models[]".to_string(),
137 ".".to_string(),
138 ],
139 field_mappings: FieldMappings::default(),
140 }
141 }
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct TagConfig {
146 pub tags: HashMap<String, TagRule>,
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct TagRule {
151 pub paths: Vec<String>,
152 pub value_type: String,
153 pub transform: Option<String>,
154}
155
156impl Default for TagConfig {
157 fn default() -> Self {
158 let mut tags = HashMap::new();
159
160 tags.insert(
162 "context_length".to_string(),
163 TagRule {
164 paths: vec![
165 ".context_length".to_string(),
166 ".context_window".to_string(),
167 ".context_size".to_string(),
168 ".max_context_length".to_string(),
169 ".input_token_limit".to_string(),
170 ".inputTokenLimit".to_string(),
171 ".limits.max_input_tokens".to_string(),
172 ".top_provider.context_length".to_string(),
173 ],
174 value_type: "u32".to_string(),
175 transform: None,
176 },
177 );
178
179 tags.insert(
181 "output".to_string(),
182 TagRule {
183 paths: vec![
184 ".max_completion_tokens".to_string(),
185 ".outputTokenLimit".to_string(),
186 ".max_output_tokens".to_string(),
187 ".limits.max_output_tokens".to_string(),
188 ".top_provider.max_completion_tokens".to_string(),
189 ".max_tokens".to_string(),
190 ],
191 value_type: "u32".to_string(),
192 transform: None,
193 },
194 );
195
196 tags.insert(
198 "input_price_per_m".to_string(),
199 TagRule {
200 paths: vec![
201 ".pricing.prompt".to_string(),
202 ".pricing.input.usd".to_string(),
203 ".input_price".to_string(),
204 ],
205 value_type: "f64".to_string(),
206 transform: Some("multiply_million".to_string()),
207 },
208 );
209
210 tags.insert(
212 "input_price_per_m_direct".to_string(),
213 TagRule {
214 paths: vec![".input_token_price_per_m".to_string()],
215 value_type: "f64".to_string(),
216 transform: None,
217 },
218 );
219
220 tags.insert(
222 "output_price_per_m".to_string(),
223 TagRule {
224 paths: vec![
225 ".pricing.completion".to_string(),
226 ".pricing.output.usd".to_string(),
227 ".output_price".to_string(),
228 ],
229 value_type: "f64".to_string(),
230 transform: Some("multiply_million".to_string()),
231 },
232 );
233
234 tags.insert(
236 "output_price_per_m_direct".to_string(),
237 TagRule {
238 paths: vec![".output_token_price_per_m".to_string()],
239 value_type: "f64".to_string(),
240 transform: None,
241 },
242 );
243
244 tags.insert(
246 "supports_vision".to_string(),
247 TagRule {
248 paths: vec![
249 ".supports_vision".to_string(),
250 ".supports_image_input".to_string(),
251 ".capabilities.vision".to_string(),
252 ".architecture.input_modalities[] | select(. == \"image\")".to_string(),
253 ".architecture.output_modalities[] | select(. == \"image\")".to_string(),
254 "@name_contains(\"image\")".to_string(),
255 "@name_contains(\"flux\")".to_string(),
256 "@name_contains(\"dall-e\")".to_string(),
257 "@name_contains(\"midjourney\")".to_string(),
258 "@name_contains(\"stable\")".to_string(),
259 "@name_contains(\"diffusion\")".to_string(),
260 "@name_contains(\"vision\")".to_string(),
261 "@name_contains(\"visual\")".to_string(),
262 "@name_contains(\"photo\")".to_string(),
263 "@name_contains(\"picture\")".to_string(),
264 "@name_contains(\"draw\")".to_string(),
265 "@name_contains(\"paint\")".to_string(),
266 "@name_contains(\"art\")".to_string(),
267 "@name_contains(\"generate\")".to_string(),
268 ],
269 value_type: "bool".to_string(),
270 transform: None,
271 },
272 );
273
274 tags.insert(
276 "supports_tools".to_string(),
277 TagRule {
278 paths: vec![
279 ".supports_tools".to_string(),
280 ".capabilities.function_calling".to_string(),
281 ".features[] | select(. == \"tools\")".to_string(),
282 ".features[] | select(. == \"function-calling\")".to_string(),
283 ".capabilities[] | select(. == \"tool-calling\")".to_string(),
284 ".supported_parameters[] | select(. == \"tools\")".to_string(),
285 ],
286 value_type: "bool".to_string(),
287 transform: None,
288 },
289 );
290
291 tags.insert(
293 "supports_audio".to_string(),
294 TagRule {
295 paths: vec![
296 ".supports_audio".to_string(),
297 "@name_contains(\"audio\")".to_string(),
298 ".features[] | select(. == \"audio\")".to_string(),
299 ".capabilities[] | select(. == \"audio\")".to_string(),
300 ".supported_input_modalities[] | select(. == \"audio\")".to_string(),
301 ".supported_output_modalities[] | select(. == \"audio\")".to_string(),
302 ".architecture.input_modalities[] | select(. == \"audio\")".to_string(),
303 ".architecture.output_modalities[] | select(. == \"audio\")".to_string(),
304 ],
305 value_type: "bool".to_string(),
306 transform: None,
307 },
308 );
309
310 tags.insert(
312 "supports_reasoning".to_string(),
313 TagRule {
314 paths: vec![
315 ".supports_reasoning".to_string(),
316 ".features[] | select(. == \"think\")".to_string(),
317 ".features[] | select(. == \"reasoning\")".to_string(),
318 ".capabilities[] | select(. == \"reasoning\")".to_string(),
319 ".supported_input_modalities[] | select(. == \"reasoning\")".to_string(),
320 ".supported_output_modalities[] | select(. == \"reasoning\")".to_string(),
321 ".architecture.input_modalities[] | select(. == \"reasoning\")".to_string(),
322 ".architecture.output_modalities[] | select(. == \"reasoning\")".to_string(),
323 ],
324 value_type: "bool".to_string(),
325 transform: None,
326 },
327 );
328
329 Self { tags }
330 }
331}
332
333pub struct ModelMetadataExtractor {
335 model_paths: ModelPaths,
336 tag_config: TagConfig,
337}
338
339impl ModelMetadataExtractor {
340 pub fn new() -> Result<Self> {
341 if let Err(e) = Self::ensure_config_files_exist() {
343 eprintln!(
344 "Warning: Failed to ensure model metadata config files exist: {}",
345 e
346 );
347 }
348
349 let model_paths = Self::load_model_paths()?;
350 let tag_config = Self::load_tag_config()?;
351
352 Ok(Self {
353 model_paths,
354 tag_config,
355 })
356 }
357
358 fn ensure_config_files_exist() -> Result<()> {
360 let config_dir = Self::get_config_dir()?;
361
362 fs::create_dir_all(&config_dir)?;
364
365 let model_paths_file = config_dir.join("model_paths.toml");
367 if !model_paths_file.exists() {
368 let default_paths = ModelPaths::default();
369 let content = toml::to_string_pretty(&default_paths)?;
370 fs::write(&model_paths_file, content)?;
371 }
372
373 let tags_file = config_dir.join("tags.toml");
375 if !tags_file.exists() {
376 let default_tags = TagConfig::default();
377 let content = toml::to_string_pretty(&default_tags)?;
378 fs::write(&tags_file, content)?;
379 }
380
381 Ok(())
382 }
383
384 fn get_config_dir() -> Result<PathBuf> {
385 if let Ok(xdg_config) = std::env::var("XDG_CONFIG_HOME") {
387 return Ok(std::path::PathBuf::from(xdg_config).join("lc"));
388 }
389
390 if let Ok(home) = std::env::var("HOME") {
391 if home.contains("tmp") || home.contains("temp") {
393 return Ok(std::path::PathBuf::from(home).join(".config").join("lc"));
394 }
395 }
396
397 let config_dir = dirs::config_dir()
399 .context("Failed to get config directory")?
400 .join("lc");
401 Ok(config_dir)
402 }
403
404 fn load_model_paths() -> Result<ModelPaths> {
405 let config_dir = Self::get_config_dir()?;
406 let path = config_dir.join("model_paths.toml");
407
408 fs::create_dir_all(&config_dir)?;
410
411 if path.exists() {
412 let content = fs::read_to_string(&path)?;
413 toml::from_str(&content).context("Failed to parse model_paths.toml")
414 } else {
415 let default = ModelPaths::default();
417 let content = toml::to_string_pretty(&default)?;
418 fs::write(&path, content)?;
419 Ok(default)
420 }
421 }
422
423 fn load_tag_config() -> Result<TagConfig> {
424 let config_dir = Self::get_config_dir()?;
425 let path = config_dir.join("tags.toml");
426
427 fs::create_dir_all(&config_dir)?;
429
430 if path.exists() {
431 let content = fs::read_to_string(&path)?;
432 toml::from_str(&content).context("Failed to parse tags.toml")
433 } else {
434 let default = TagConfig::default();
436 let content = toml::to_string_pretty(&default)?;
437 fs::write(&path, content)?;
438 Ok(default)
439 }
440 }
441
442 pub fn extract_models(&self, provider: &Provider, response: &Value) -> Result<Vec<Value>> {
443 let mut models = Vec::new();
444
445 for path in &self.model_paths.paths {
446 if let Ok(extracted) = self.extract_with_jq_path(response, path) {
447 match &extracted {
448 Value::Array(arr) => models.extend(arr.clone()),
449 Value::Object(obj) => {
450 let has_model_field = self
452 .model_paths
453 .field_mappings
454 .id_fields
455 .iter()
456 .any(|field| obj.contains_key(field))
457 || obj.contains_key("model"); if has_model_field {
460 models.push(extracted);
461 }
462 }
463 _ => {}
464 }
465 }
466 }
467
468 if provider.provider == "hf" || provider.provider == "huggingface" {
470 models = self.expand_huggingface_models(models)?;
471 }
472
473 Ok(models)
474 }
475
476 pub fn extract_with_jq_path(&self, data: &Value, path: &str) -> Result<Value> {
477 if path == "." {
479 return Ok(data.clone());
480 }
481
482 if path.contains(" | ") {
484 return self.extract_with_jq_filter(data, path);
485 }
486
487 let parts: Vec<&str> = path.split('.').filter(|s| !s.is_empty()).collect();
488 let mut current = data;
489
490 for part in parts {
491 if part.ends_with("[]") {
492 let field = &part[..part.len() - 2];
493 current = current
494 .get(field)
495 .context(format!("Field {} not found", field))?;
496 if !current.is_array() {
497 anyhow::bail!("Expected array at {}", field);
498 }
499 } else {
500 current = current
501 .get(part)
502 .context(format!("Field {} not found", part))?;
503 }
504 }
505
506 Ok(current.clone())
507 }
508
509 fn extract_with_jq_filter(&self, data: &Value, path: &str) -> Result<Value> {
510 let parts: Vec<&str> = path.split(" | ").collect();
511 if parts.len() != 2 {
512 anyhow::bail!("Complex JQ filters not supported: {}", path);
513 }
514
515 let array_path = parts[0].trim();
516 let filter = parts[1].trim();
517
518 let array_value = self.extract_with_jq_path(data, array_path)?;
520
521 if filter.starts_with("select(") && filter.ends_with(")") {
523 let condition = &filter[7..filter.len() - 1]; if let Value::Array(arr) = array_value {
526 for item in arr {
528 if self.evaluate_select_condition(&item, condition)? {
529 return Ok(Value::Bool(true));
530 }
531 }
532 return Ok(Value::Bool(false));
533 } else {
534 if self.evaluate_select_condition(&array_value, condition)? {
536 return Ok(array_value);
537 } else {
538 return Ok(Value::Null);
539 }
540 }
541 }
542
543 anyhow::bail!("Unsupported JQ filter: {}", filter)
544 }
545
546 fn evaluate_select_condition(&self, value: &Value, condition: &str) -> Result<bool> {
547 if condition.starts_with(". == ") {
549 let expected = condition[5..].trim();
550
551 let expected = if expected.starts_with('"') && expected.ends_with('"') {
553 &expected[1..expected.len() - 1]
554 } else {
555 expected
556 };
557
558 match value {
559 Value::String(s) => Ok(s == expected),
560 Value::Number(n) => {
561 if let Ok(num) = expected.parse::<f64>() {
562 Ok(n.as_f64() == Some(num))
563 } else {
564 Ok(false)
565 }
566 }
567 Value::Bool(b) => {
568 if let Ok(bool_val) = expected.parse::<bool>() {
569 Ok(*b == bool_val)
570 } else {
571 Ok(false)
572 }
573 }
574 _ => Ok(false),
575 }
576 } else {
577 anyhow::bail!("Unsupported select condition: {}", condition)
578 }
579 }
580
581 fn expand_huggingface_models(&self, models: Vec<Value>) -> Result<Vec<Value>> {
582 let mut expanded = Vec::new();
583
584 for model in models {
585 if let Some(providers) = model.get("providers").and_then(|p| p.as_array()) {
586 for provider in providers {
587 let mut new_model = model.clone();
588 if let Some(obj) = new_model.as_object_mut() {
589 obj.insert("provider".to_string(), provider.clone());
590 obj.remove("providers");
591 }
592 expanded.push(new_model);
593 }
594 } else {
595 expanded.push(model);
596 }
597 }
598
599 Ok(expanded)
600 }
601
602 pub fn extract_metadata(&self, provider: &Provider, model: &Value) -> Result<ModelMetadata> {
603 let mut metadata = ModelMetadata::default();
604
605 let base_id = self
607 .model_paths
608 .field_mappings
609 .id_fields
610 .iter()
611 .find_map(|field| model.get(field).and_then(|v| v.as_str()))
612 .map(|s| s.to_string())
613 .ok_or_else(|| {
614 let fields = self.model_paths.field_mappings.id_fields.join(", ");
615 anyhow::anyhow!(
616 "Model missing required ID field. Checked fields: {}",
617 fields
618 )
619 })?;
620
621 if (provider.provider == "hf" || provider.provider == "huggingface")
623 && model.get("provider").is_some()
624 {
625 if let Some(provider_obj) = model.get("provider") {
626 if let Some(provider_name) = provider_obj.get("provider").and_then(|v| v.as_str()) {
627 metadata.id = format!("{}:{}", base_id, provider_name);
628 } else {
629 metadata.id = base_id;
630 }
631 } else {
632 metadata.id = base_id;
633 }
634 } else {
635 metadata.id = base_id;
636 }
637
638 metadata.provider = provider.provider.clone();
639 metadata.raw_data = model.clone();
640
641 if let Some(name) = self
643 .model_paths
644 .field_mappings
645 .name_fields
646 .iter()
647 .find_map(|field| model.get(field).and_then(|v| v.as_str()))
648 {
649 metadata.display_name = Some(name.to_string());
650 }
651
652 if let Some(desc) = model.get("description").and_then(|v| v.as_str()) {
653 metadata.description = Some(desc.to_string());
654 }
655
656 if let Some(owner) = model.get("owned_by").and_then(|v| v.as_str()) {
657 metadata.owned_by = Some(owner.to_string());
658 }
659
660 if let Some(created) = model.get("created").and_then(|v| v.as_i64()) {
661 metadata.created = Some(created);
662 }
663
664 for (tag_name, rule) in &self.tag_config.tags {
666 if let Some(value) = self.extract_tag_value(model, rule) {
667 self.apply_tag_value(&mut metadata, tag_name, value, &rule.value_type)?;
668 }
669 }
670
671 metadata.model_type = self.determine_model_type(&metadata.id, metadata.display_name.as_deref());
673
674 Ok(metadata)
675 }
676
677 fn extract_tag_value(&self, model: &Value, rule: &TagRule) -> Option<Value> {
678 let is_bool_field = rule.value_type == "bool";
680 let mut found_false = false;
681
682 for path in &rule.paths {
683 if path.starts_with("@name_contains(") && path.ends_with(")") {
685 let pattern = &path[15..path.len() - 1]; let pattern = pattern.trim_matches('"'); if let Some(result) = self.check_name_contains(model, pattern) {
689 if is_bool_field && result {
690 return Some(Value::Bool(true));
691 } else if !is_bool_field {
692 return Some(Value::Bool(result));
693 } else if result == false {
694 found_false = true;
695 }
696 }
697 continue;
698 }
699
700 if path.starts_with("@name_matches(") && path.ends_with(")") {
701 let pattern = &path[14..path.len() - 1]; let pattern = pattern.trim_matches('"'); if let Some(result) = self.check_name_matches(model, pattern) {
705 if is_bool_field && result {
706 return Some(Value::Bool(true));
707 } else if !is_bool_field {
708 return Some(Value::Bool(result));
709 } else if result == false {
710 found_false = true;
711 }
712 }
713 continue;
714 }
715
716 if let Ok(value) = self.extract_with_jq_path(model, path) {
718 if !value.is_null() {
719 if is_bool_field {
721 if let Some(bool_val) = value.as_bool() {
722 if bool_val {
723 if let Some(transform) = &rule.transform {
725 return self.apply_transform(value, transform);
726 }
727 return Some(value);
728 } else {
729 found_false = true;
731 }
732 }
733 } else {
734 if let Some(transform) = &rule.transform {
736 return self.apply_transform(value, transform);
737 }
738 return Some(value);
739 }
740 }
741 }
742 }
743
744 if is_bool_field && found_false {
747 Some(Value::Bool(false))
748 } else {
749 None
750 }
751 }
752
753 fn apply_transform(&self, value: Value, transform: &str) -> Option<Value> {
754 match transform {
755 "multiply_million" => {
756 if let Some(num) = value.as_f64() {
757 Some(Value::from(num * 1_000_000.0))
758 } else {
759 None
760 }
761 }
762 _ => Some(value),
763 }
764 }
765
766 fn apply_tag_value(
767 &self,
768 metadata: &mut ModelMetadata,
769 tag_name: &str,
770 value: Value,
771 value_type: &str,
772 ) -> Result<()> {
773 match tag_name {
774 "context_length" => {
775 if let Some(v) = self.parse_value_as_u32(&value, value_type)? {
776 metadata.context_length = Some(v);
777 }
778 }
779 "max_input_tokens" => {
780 if let Some(v) = self.parse_value_as_u32(&value, value_type)? {
781 metadata.max_input_tokens = Some(v);
782 }
783 }
784 "max_output_tokens" | "output" => {
785 if let Some(v) = self.parse_value_as_u32(&value, value_type)? {
786 metadata.max_output_tokens = Some(v);
787 }
788 }
789 "input_price_per_m" | "input_price_per_m_direct" => {
790 if let Some(v) = self.parse_value_as_f64(&value, value_type)? {
791 metadata.input_price_per_m = Some(v);
792 }
793 }
794 "output_price_per_m" | "output_price_per_m_direct" => {
795 if let Some(v) = self.parse_value_as_f64(&value, value_type)? {
796 metadata.output_price_per_m = Some(v);
797 }
798 }
799 "supports_tools" => {
800 if let Some(v) = self.parse_value_as_bool(&value, value_type)? {
801 metadata.supports_tools = v;
802 }
803 }
804 "supports_vision" => {
805 if let Some(v) = self.parse_value_as_bool(&value, value_type)? {
806 metadata.supports_vision = v;
807 }
808 }
809 "supports_audio" => {
810 if let Some(v) = self.parse_value_as_bool(&value, value_type)? {
811 metadata.supports_audio = v;
812 }
813 }
814 "supports_reasoning" => {
815 if let Some(v) = self.parse_value_as_bool(&value, value_type)? {
816 metadata.supports_reasoning = v;
817 }
818 }
819 "supports_code" => {
820 if let Some(v) = self.parse_value_as_bool(&value, value_type)? {
821 metadata.supports_code = v;
822 }
823 }
824 "supports_function_calling" => {
825 if let Some(v) = self.parse_value_as_bool(&value, value_type)? {
826 metadata.supports_function_calling = v;
827 }
828 }
829 "supports_json_mode" => {
830 if let Some(v) = self.parse_value_as_bool(&value, value_type)? {
831 metadata.supports_json_mode = v;
832 }
833 }
834 "supports_streaming" => {
835 if let Some(v) = self.parse_value_as_bool(&value, value_type)? {
836 metadata.supports_streaming = v;
837 }
838 }
839 "is_deprecated" => {
840 if let Some(v) = self.parse_value_as_bool(&value, value_type)? {
841 metadata.is_deprecated = v;
842 }
843 }
844 "is_fine_tunable" => {
845 if let Some(v) = self.parse_value_as_bool(&value, value_type)? {
846 metadata.is_fine_tunable = v;
847 }
848 }
849 _ => {
850 }
852 }
853 Ok(())
854 }
855
856 fn parse_value_as_bool(&self, value: &Value, _value_type: &str) -> Result<Option<bool>> {
857 match value {
858 Value::Bool(b) => Ok(Some(*b)),
859 Value::String(s) => Ok(Some(s == "true" || s == "yes" || s == "1")),
860 Value::Number(n) => Ok(Some(n.as_i64().unwrap_or(0) != 0)),
861 _ => Ok(None),
862 }
863 }
864
865 fn parse_value_as_u32(&self, value: &Value, _value_type: &str) -> Result<Option<u32>> {
866 match value {
867 Value::Number(n) => {
868 if let Some(v) = n.as_u64() {
869 Ok(Some(v as u32))
870 } else if let Some(v) = n.as_i64() {
871 Ok(Some(v as u32))
872 } else {
873 Ok(None)
874 }
875 }
876 Value::String(s) => Ok(s.parse::<u32>().ok()),
877 _ => Ok(None),
878 }
879 }
880
881 fn parse_value_as_f64(&self, value: &Value, _value_type: &str) -> Result<Option<f64>> {
882 match value {
883 Value::Number(n) => Ok(n.as_f64()),
884 Value::String(s) => Ok(s.parse::<f64>().ok()),
885 _ => Ok(None),
886 }
887 }
888
889 fn check_name_contains(&self, model: &Value, pattern: &str) -> Option<bool> {
891 let pattern_lower = pattern.to_lowercase();
892
893 for field in &self.model_paths.field_mappings.id_fields {
895 if let Some(value) = model.get(field).and_then(|v| v.as_str()) {
896 if value.to_lowercase().contains(&pattern_lower) {
897 return Some(true);
898 }
899 }
900 }
901
902 for field in &self.model_paths.field_mappings.name_fields {
904 if let Some(value) = model.get(field).and_then(|v| v.as_str()) {
905 if value.to_lowercase().contains(&pattern_lower) {
906 return Some(true);
907 }
908 }
909 }
910
911 Some(false)
912 }
913
914 fn determine_model_type(&self, model_id: &str, display_name: Option<&str>) -> ModelType {
916 let id_lower = model_id.to_lowercase();
917 let name_lower = display_name.map(|n| n.to_lowercase());
918
919 let embedding_patterns = [
921 "embed",
922 "embedding",
923 "text-embedding",
924 "text_embedding",
925 "ada",
926 "similarity",
927 "bge",
928 "e5",
929 "gte",
930 "instructor",
931 "voyage",
932 "titan-embed",
933 "embedding-gecko",
934 "embed-english",
935 "embed-multilingual",
936 ];
937
938 for pattern in &embedding_patterns {
939 if id_lower.contains(pattern) {
940 return ModelType::Embedding;
941 }
942 if let Some(ref name) = name_lower {
943 if name.contains(pattern) {
944 return ModelType::Embedding;
945 }
946 }
947 }
948
949 let image_patterns = [
951 "dall-e",
952 "dalle",
953 "stable-diffusion",
954 "midjourney",
955 "imagen",
956 "image",
957 ];
958
959 for pattern in &image_patterns {
960 if id_lower.contains(pattern) {
961 return ModelType::ImageGeneration;
962 }
963 if let Some(ref name) = name_lower {
964 if name.contains(pattern) {
965 return ModelType::ImageGeneration;
966 }
967 }
968 }
969
970 let audio_patterns = [
972 "whisper",
973 "tts",
974 "audio",
975 "speech",
976 "voice",
977 ];
978
979 for pattern in &audio_patterns {
980 if id_lower.contains(pattern) {
981 return ModelType::AudioGeneration;
982 }
983 if let Some(ref name) = name_lower {
984 if name.contains(pattern) {
985 return ModelType::AudioGeneration;
986 }
987 }
988 }
989
990 let moderation_patterns = [
992 "moderation",
993 "moderate",
994 "safety",
995 ];
996
997 for pattern in &moderation_patterns {
998 if id_lower.contains(pattern) {
999 return ModelType::Moderation;
1000 }
1001 if let Some(ref name) = name_lower {
1002 if name.contains(pattern) {
1003 return ModelType::Moderation;
1004 }
1005 }
1006 }
1007
1008 let completion_patterns = [
1010 "davinci",
1011 "curie",
1012 "babbage",
1013 "ada-001",
1014 "text-davinci",
1015 "text-curie",
1016 "text-babbage",
1017 "code-davinci",
1018 "code-cushman",
1019 ];
1020
1021 for pattern in &completion_patterns {
1022 if id_lower.contains(pattern) && !id_lower.contains("embed") {
1023 return ModelType::Completion;
1024 }
1025 if let Some(ref name) = name_lower {
1026 if name.contains(pattern) && !name.contains("embed") {
1027 return ModelType::Completion;
1028 }
1029 }
1030 }
1031
1032 ModelType::Chat
1034 }
1035
1036 fn check_name_matches(&self, model: &Value, pattern: &str) -> Option<bool> {
1038 use regex::RegexBuilder;
1039
1040 let regex = match RegexBuilder::new(pattern).case_insensitive(true).build() {
1042 Ok(r) => r,
1043 Err(_) => return Some(false), };
1045
1046 for field in &self.model_paths.field_mappings.id_fields {
1048 if let Some(value) = model.get(field).and_then(|v| v.as_str()) {
1049 if regex.is_match(value) {
1050 return Some(true);
1051 }
1052 }
1053 }
1054
1055 for field in &self.model_paths.field_mappings.name_fields {
1057 if let Some(value) = model.get(field).and_then(|v| v.as_str()) {
1058 if regex.is_match(value) {
1059 return Some(true);
1060 }
1061 }
1062 }
1063
1064 Some(false)
1065 }
1066}
1067
1068pub fn extract_models_from_provider(
1070 provider: &Provider,
1071 raw_json: &str,
1072) -> Result<Vec<ModelMetadata>> {
1073 let response: Value = serde_json::from_str(raw_json)?;
1074 let extractor = ModelMetadataExtractor::new()?;
1075
1076 let models = extractor.extract_models(provider, &response)?;
1077 let mut metadata_list = Vec::new();
1078
1079 for model in models {
1080 match extractor.extract_metadata(provider, &model) {
1081 Ok(metadata) => metadata_list.push(metadata),
1082 Err(e) => {
1083 eprintln!("Warning: Failed to extract metadata for model: {}", e);
1084 }
1085 }
1086 }
1087
1088 Ok(metadata_list)
1089}
1090
1091pub fn add_model_path(path: String) -> Result<()> {
1093 let config_dir = ModelMetadataExtractor::get_config_dir()?;
1094 let file_path = config_dir.join("model_paths.toml");
1095
1096 let mut paths = if file_path.exists() {
1097 let content = fs::read_to_string(&file_path)?;
1098 toml::from_str(&content)?
1099 } else {
1100 ModelPaths::default()
1101 };
1102
1103 if !paths.paths.contains(&path) {
1104 paths.paths.push(path);
1105 let content = toml::to_string_pretty(&paths)?;
1106 fs::write(&file_path, content)?;
1107 println!("Added model path");
1108 } else {
1109 println!("Path already exists");
1110 }
1111
1112 Ok(())
1113}
1114
1115pub fn remove_model_path(path: String) -> Result<()> {
1116 let config_dir = ModelMetadataExtractor::get_config_dir()?;
1117 let file_path = config_dir.join("model_paths.toml");
1118
1119 if !file_path.exists() {
1120 anyhow::bail!("No model paths configured");
1121 }
1122
1123 let mut paths: ModelPaths = {
1124 let content = fs::read_to_string(&file_path)?;
1125 toml::from_str(&content)?
1126 };
1127
1128 if let Some(pos) = paths.paths.iter().position(|p| p == &path) {
1129 paths.paths.remove(pos);
1130 let content = toml::to_string_pretty(&paths)?;
1131 fs::write(&file_path, content)?;
1132 println!("Removed model path");
1133 } else {
1134 println!("Path not found");
1135 }
1136
1137 Ok(())
1138}
1139
1140pub fn list_model_paths() -> Result<()> {
1141 let config_dir = ModelMetadataExtractor::get_config_dir()?;
1142 let file_path = config_dir.join("model_paths.toml");
1143
1144 let paths = if file_path.exists() {
1145 let content = fs::read_to_string(&file_path)?;
1146 toml::from_str(&content)?
1147 } else {
1148 ModelPaths::default()
1149 };
1150
1151 println!("Model paths:");
1152 for path in &paths.paths {
1153 println!(" - {}", path);
1154 }
1155
1156 Ok(())
1157}
1158
1159pub fn add_tag(
1160 name: String,
1161 paths: Vec<String>,
1162 value_type: String,
1163 transform: Option<String>,
1164) -> Result<()> {
1165 let config_dir = ModelMetadataExtractor::get_config_dir()?;
1166 let file_path = config_dir.join("tags.toml");
1167
1168 let mut config = if file_path.exists() {
1169 let content = fs::read_to_string(&file_path)?;
1170 toml::from_str(&content)?
1171 } else {
1172 TagConfig::default()
1173 };
1174
1175 config.tags.insert(
1176 name.clone(),
1177 TagRule {
1178 paths,
1179 value_type,
1180 transform,
1181 },
1182 );
1183
1184 let content = toml::to_string_pretty(&config)?;
1185 fs::write(&file_path, content)?;
1186 println!("Added tag: {}", name);
1187
1188 Ok(())
1189}
1190
1191pub fn initialize_model_metadata_config() -> Result<()> {
1195 ModelMetadataExtractor::ensure_config_files_exist()
1196}
1197
1198pub fn list_tags() -> Result<()> {
1199 let config_dir = ModelMetadataExtractor::get_config_dir()?;
1200 let file_path = config_dir.join("tags.toml");
1201
1202 let config = if file_path.exists() {
1203 let content = fs::read_to_string(&file_path)?;
1204 toml::from_str(&content)?
1205 } else {
1206 TagConfig::default()
1207 };
1208
1209 println!("Tags:");
1210 for (name, rule) in &config.tags {
1211 println!(" {}:", name);
1212 println!(" Type: {}", rule.value_type);
1213 println!(" Paths:");
1214 for path in &rule.paths {
1215 println!(" - {}", path);
1216 }
1217 if let Some(transform) = &rule.transform {
1218 println!(" Transform: {}", transform);
1219 }
1220 }
1221
1222 Ok(())
1223}
1224
1225pub struct MetadataExtractor;
1227
1228impl MetadataExtractor {
1229 pub fn extract_from_provider(
1230 provider: &str,
1231 raw_json: &str,
1232 ) -> Result<Vec<ModelMetadata>, Box<dyn std::error::Error>> {
1233 let provider_obj = Provider {
1234 provider: provider.to_string(),
1235 status: "active".to_string(),
1236 supports_tools: false,
1237 supports_structured_output: false,
1238 };
1239
1240 extract_models_from_provider(&provider_obj, raw_json).map_err(|e| e.into())
1241 }
1242}