Skip to main content

sparrow/provider/
discovery.rs

1use async_trait::async_trait;
2use serde_json::Value;
3use std::time::Duration;
4
5#[async_trait]
6pub trait ModelDiscovery: Send + Sync {
7    async fn fetch_model_names(&self, base_url: &str, api_key: &str)
8    -> anyhow::Result<Vec<String>>;
9}
10
/// Discovery backend that queries `{base_url}/models` and reads the `data[].id`
/// entries of the JSON response.
#[derive(Debug, Clone, Copy, Default)]
pub struct OpenAICompatDiscovery;

/// Discovery backend that queries the fixed `https://api.anthropic.com/v1/models`
/// endpoint.
#[derive(Debug, Clone, Copy, Default)]
pub struct AnthropicDiscovery;

/// Discovery backend that queries `{root}/api/tags` and reads the
/// `models[].name` entries of the JSON response.
#[derive(Debug, Clone, Copy, Default)]
pub struct OllamaDiscovery;
14
15#[async_trait]
16impl ModelDiscovery for OpenAICompatDiscovery {
17    async fn fetch_model_names(
18        &self,
19        base_url: &str,
20        api_key: &str,
21    ) -> anyhow::Result<Vec<String>> {
22        let client = reqwest::Client::builder()
23            .timeout(Duration::from_secs(10))
24            .build()?;
25        let url = format!("{}/models", base_url.trim_end_matches('/'));
26        let mut request = client.get(url);
27        if !api_key.trim().is_empty() {
28            request = request.bearer_auth(api_key);
29        }
30        let value: Value = request.send().await?.error_for_status()?.json().await?;
31        let models = value
32            .get("data")
33            .and_then(|data| data.as_array())
34            .map(|items| {
35                items
36                    .iter()
37                    .filter_map(|item| item.get("id").and_then(|id| id.as_str()))
38                    .filter(|id| is_chat_model_id(id))
39                    .map(str::to_string)
40                    .collect()
41            })
42            .unwrap_or_default();
43        Ok(models)
44    }
45}
46
47#[async_trait]
48impl ModelDiscovery for AnthropicDiscovery {
49    async fn fetch_model_names(
50        &self,
51        _base_url: &str,
52        api_key: &str,
53    ) -> anyhow::Result<Vec<String>> {
54        let client = reqwest::Client::builder()
55            .timeout(Duration::from_secs(10))
56            .build()?;
57        let value: Value = client
58            .get("https://api.anthropic.com/v1/models")
59            .header("x-api-key", api_key)
60            .header("anthropic-version", "2023-06-01")
61            .send()
62            .await?
63            .error_for_status()?
64            .json()
65            .await?;
66        Ok(value
67            .get("data")
68            .and_then(|data| data.as_array())
69            .map(|items| {
70                items
71                    .iter()
72                    .filter_map(|item| item.get("id").and_then(|id| id.as_str()))
73                    .map(str::to_string)
74                    .collect()
75            })
76            .unwrap_or_default())
77    }
78}
79
80#[async_trait]
81impl ModelDiscovery for OllamaDiscovery {
82    async fn fetch_model_names(
83        &self,
84        base_url: &str,
85        _api_key: &str,
86    ) -> anyhow::Result<Vec<String>> {
87        let client = reqwest::Client::builder()
88            .timeout(Duration::from_secs(10))
89            .build()?;
90        let root = base_url.trim_end_matches('/').trim_end_matches("/v1");
91        let value: Value = client
92            .get(format!("{}/api/tags", root))
93            .send()
94            .await?
95            .error_for_status()?
96            .json()
97            .await?;
98        Ok(value
99            .get("models")
100            .and_then(|models| models.as_array())
101            .map(|items| {
102                items
103                    .iter()
104                    .filter_map(|item| item.get("name").and_then(|name| name.as_str()))
105                    .filter(|name| is_chat_model_id(name))
106                    .map(str::to_string)
107                    .collect()
108            })
109            .unwrap_or_default())
110    }
111}
112
113pub async fn discover_models(
114    adapter: &str,
115    base_url: &str,
116    api_key: &str,
117) -> anyhow::Result<Vec<String>> {
118    match adapter {
119        "anthropic-messages" => {
120            AnthropicDiscovery
121                .fetch_model_names(base_url, api_key)
122                .await
123        }
124        "ollama" => OllamaDiscovery.fetch_model_names(base_url, api_key).await,
125        _ => {
126            OpenAICompatDiscovery
127                .fetch_model_names(base_url, api_key)
128                .await
129        }
130    }
131}
132
/// Returns `true` when `id` does not look like a non-chat model.
///
/// The id is ASCII-lowercased and then checked for each marker substring
/// below; if any marker occurs anywhere in the id the function returns
/// `false`. An id that contains none of them (including the empty string) is
/// treated as a chat model.
pub fn is_chat_model_id(id: &str) -> bool {
    // Substrings that mark non-chat model families: embeddings, image-gen,
    // audio, moderation, legacy completions (text-davinci/curie/babbage/ada),
    // and search/similarity helpers.
    const NON_CHAT_MARKERS: &[&str] = &[
        "embed",
        "embedding",
        "bge-",
        "e5-",
        "rerank",
        "retriever",
        "retrieval",
        "tts",
        "dall-e",
        "dall_e",
        "whisper",
        "moderation",
        "safety",
        "guard",
        "detector",
        "reward",
        "parse",
        "ocr",
        "clip",
        "vila",
        "neva",
        "text-davinci",
        "text-curie",
        "text-babbage",
        "text-ada",
        "babbage-",
        "ada-",
        "davinci-00",
        "code-search",
        "text-search",
        "similarity",
        "-edit-",
        "cushman",
        "text-similarity",
        "audio",
        "transcribe",
        "translate",
        "realtime",
        // Domain-specific / non-general models that polluted the fallback chain
        // (observed via NVIDIA /v1/models). These are not general chat models.
        "gliner",        // NER / PII extraction
        "pii",           // PII detection
        "deplot",        // chart-to-table
        "kosmos",        // vision grounding
        "fuyu",          // vision-only
        "calibration",   // internal calibration model
        "cosmos-reason", // physical-AI vision reasoning
        "palmyra-med",   // medical domain
        "palmyra-fin",   // finance domain
        "-med-70b",      // medical variants
        "chatqa",        // retrieval-augmented QA, not general chat
    ];

    // Markers are all lowercase, so lowercase the id once before matching.
    let lowered = id.to_ascii_lowercase();
    NON_CHAT_MARKERS
        .iter()
        .all(|marker| !lowered.contains(*marker))
}