sparrow/provider/
discovery.rs1use async_trait::async_trait;
2use serde_json::Value;
3use std::time::Duration;
4
5#[async_trait]
6pub trait ModelDiscovery: Send + Sync {
7 async fn fetch_model_names(&self, base_url: &str, api_key: &str)
8 -> anyhow::Result<Vec<String>>;
9}
10
/// Discovers models through an OpenAI-compatible `GET {base_url}/models` endpoint.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct OpenAICompatDiscovery;

/// Discovers models through Anthropic's `GET /v1/models` endpoint.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AnthropicDiscovery;

/// Discovers models through Ollama's native `GET /api/tags` endpoint.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct OllamaDiscovery;
14
15#[async_trait]
16impl ModelDiscovery for OpenAICompatDiscovery {
17 async fn fetch_model_names(
18 &self,
19 base_url: &str,
20 api_key: &str,
21 ) -> anyhow::Result<Vec<String>> {
22 let client = reqwest::Client::builder()
23 .timeout(Duration::from_secs(10))
24 .build()?;
25 let url = format!("{}/models", base_url.trim_end_matches('/'));
26 let mut request = client.get(url);
27 if !api_key.trim().is_empty() {
28 request = request.bearer_auth(api_key);
29 }
30 let value: Value = request.send().await?.error_for_status()?.json().await?;
31 let models = value
32 .get("data")
33 .and_then(|data| data.as_array())
34 .map(|items| {
35 items
36 .iter()
37 .filter_map(|item| item.get("id").and_then(|id| id.as_str()))
38 .filter(|id| is_chat_model_id(id))
39 .map(str::to_string)
40 .collect()
41 })
42 .unwrap_or_default();
43 Ok(models)
44 }
45}
46
47#[async_trait]
48impl ModelDiscovery for AnthropicDiscovery {
49 async fn fetch_model_names(
50 &self,
51 _base_url: &str,
52 api_key: &str,
53 ) -> anyhow::Result<Vec<String>> {
54 let client = reqwest::Client::builder()
55 .timeout(Duration::from_secs(10))
56 .build()?;
57 let value: Value = client
58 .get("https://api.anthropic.com/v1/models")
59 .header("x-api-key", api_key)
60 .header("anthropic-version", "2023-06-01")
61 .send()
62 .await?
63 .error_for_status()?
64 .json()
65 .await?;
66 Ok(value
67 .get("data")
68 .and_then(|data| data.as_array())
69 .map(|items| {
70 items
71 .iter()
72 .filter_map(|item| item.get("id").and_then(|id| id.as_str()))
73 .map(str::to_string)
74 .collect()
75 })
76 .unwrap_or_default())
77 }
78}
79
80#[async_trait]
81impl ModelDiscovery for OllamaDiscovery {
82 async fn fetch_model_names(
83 &self,
84 base_url: &str,
85 _api_key: &str,
86 ) -> anyhow::Result<Vec<String>> {
87 let client = reqwest::Client::builder()
88 .timeout(Duration::from_secs(10))
89 .build()?;
90 let root = base_url.trim_end_matches('/').trim_end_matches("/v1");
91 let value: Value = client
92 .get(format!("{}/api/tags", root))
93 .send()
94 .await?
95 .error_for_status()?
96 .json()
97 .await?;
98 Ok(value
99 .get("models")
100 .and_then(|models| models.as_array())
101 .map(|items| {
102 items
103 .iter()
104 .filter_map(|item| item.get("name").and_then(|name| name.as_str()))
105 .filter(|name| is_chat_model_id(name))
106 .map(str::to_string)
107 .collect()
108 })
109 .unwrap_or_default())
110 }
111}
112
113pub async fn discover_models(
114 adapter: &str,
115 base_url: &str,
116 api_key: &str,
117) -> anyhow::Result<Vec<String>> {
118 match adapter {
119 "anthropic-messages" => {
120 AnthropicDiscovery
121 .fetch_model_names(base_url, api_key)
122 .await
123 }
124 "ollama" => OllamaDiscovery.fetch_model_names(base_url, api_key).await,
125 _ => {
126 OpenAICompatDiscovery
127 .fetch_model_names(base_url, api_key)
128 .await
129 }
130 }
131}
132
/// Heuristically decides whether a model id names a chat model.
///
/// The id is ASCII-lower-cased and rejected when it contains any entry of a fixed list of
/// marker substrings (for example `"embed"`, `"whisper"`, `"tts"`, `"moderation"`).
/// Otherwise it is accepted, including the empty string. Matching is by plain substring,
/// so a marker can also hit unrelated ids that merely contain it.
pub fn is_chat_model_id(id: &str) -> bool {
    // Compared against the lower-cased id; every entry must itself be lower case.
    const NON_CHAT_MARKERS: &[&str] = &[
        "embed",
        "embedding",
        "bge-",
        "e5-",
        "rerank",
        "retriever",
        "retrieval",
        "tts",
        "dall-e",
        "dall_e",
        "whisper",
        "moderation",
        "safety",
        "guard",
        "detector",
        "reward",
        "parse",
        "ocr",
        "clip",
        "vila",
        "neva",
        "text-davinci",
        "text-curie",
        "text-babbage",
        "text-ada",
        "babbage-",
        "ada-",
        "davinci-00",
        "code-search",
        "text-search",
        "similarity",
        "-edit-",
        "cushman",
        "text-similarity",
        "audio",
        "transcribe",
        "translate",
        "realtime",
        "gliner",
        "pii",
        "deplot",
        "kosmos",
        "fuyu",
        "calibration",
        "cosmos-reason",
        "palmyra-med",
        "palmyra-fin",
        "-med-70b",
        "chatqa",
    ];

    let normalized = id.to_ascii_lowercase();
    NON_CHAT_MARKERS
        .iter()
        .all(|marker| !normalized.contains(marker))
}