reflex/semantic/providers/
openrouter.rs1use super::LlmProvider;
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use serde_json::json;
11
/// One model entry from OpenRouter's `/api/v1/models` listing, as produced by
/// [`fetch_models`].
#[derive(Debug, Clone)]
pub struct OpenRouterModel {
    /// Model identifier, e.g. `"anthropic/claude-sonnet-4"`.
    pub id: String,
    /// Display name; `fetch_models` falls back to `id` when the API omits one.
    pub name: String,
    /// Prompt (input) price per million tokens.
    pub prompt_price: f64,
    /// Completion (output) price per million tokens.
    pub completion_price: f64,
    /// Context length reported by the API; `0` when the API does not report one.
    pub context_length: u64,
}
21
22pub async fn fetch_models(api_key: &str) -> Result<Vec<OpenRouterModel>> {
24 let client = reqwest::Client::new();
25
26 let response = client
27 .get("https://openrouter.ai/api/v1/models")
28 .header("Authorization", format!("Bearer {}", api_key))
29 .timeout(std::time::Duration::from_secs(10))
30 .send()
31 .await
32 .context("Failed to fetch models from OpenRouter")?;
33
34 if !response.status().is_success() {
35 let status = response.status();
36 let error_text = response
37 .text()
38 .await
39 .unwrap_or_else(|_| "Unknown error".to_string());
40 anyhow::bail!("OpenRouter API error ({}): {}", status, error_text);
41 }
42
43 let data: serde_json::Value = response
44 .json()
45 .await
46 .context("Failed to parse OpenRouter models response")?;
47
48 let models_array = data["data"]
49 .as_array()
50 .context("No 'data' array in OpenRouter models response")?;
51
52 let mut models: Vec<OpenRouterModel> = models_array
53 .iter()
54 .filter_map(|m| {
55 let id = m["id"].as_str()?;
56 let name = m["name"].as_str().unwrap_or(id);
57
58 let prompt_str = m["pricing"]["prompt"].as_str()?;
60 let completion_str = m["pricing"]["completion"].as_str()?;
61
62 let prompt_per_token: f64 = prompt_str.parse().ok()?;
63 let completion_per_token: f64 = completion_str.parse().ok()?;
64
65 if prompt_per_token < 0.0 || completion_per_token < 0.0 {
68 return None;
69 }
70
71 let context_length = m["context_length"].as_u64().unwrap_or(0);
72
73 Some(OpenRouterModel {
74 id: id.to_string(),
75 name: name.to_string(),
76 prompt_price: prompt_per_token * 1_000_000.0,
77 completion_price: completion_per_token * 1_000_000.0,
78 context_length,
79 })
80 })
81 .collect();
82
83 models.sort_by(|a, b| a.id.cmp(&b.id));
84
85 Ok(models)
86}
87
88pub struct OpenRouterProvider {
90 client: reqwest::Client,
91 api_key: String,
92 model: String,
93 sort: String,
94}
95
96impl OpenRouterProvider {
97 pub fn new(api_key: String, model: Option<String>, sort: Option<String>) -> Result<Self> {
104 let sort = sort
106 .map(|s| if s == "speed" { "latency".to_string() } else { s })
107 .unwrap_or_else(|| "price".to_string());
108 Ok(Self {
109 client: reqwest::Client::new(),
110 api_key,
111 model: model.unwrap_or_else(|| "anthropic/claude-sonnet-4".to_string()),
112 sort,
113 })
114 }
115}
116
#[async_trait]
impl LlmProvider for OpenRouterProvider {
    /// Sends `prompt` as a single user message to OpenRouter's chat-completions
    /// endpoint and returns the text content of the first choice.
    ///
    /// The request uses the configured model, `temperature` 0.1, `max_tokens` 4000, and
    /// provider routing with the configured `sort` strategy and fallbacks allowed. When
    /// `json_mode` is true, `response_format` is set to `{"type": "json_object"}`.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// * the request cannot be sent or exceeds the 60-second timeout;
    /// * the API answers with a non-success status (429, 502/503/504 and 401 get
    ///   friendlier messages);
    /// * the body is not JSON;
    /// * the body has no text content in `choices[0].message.content`. In this case,
    ///   an `error.message` found in the body is surfaced instead of the generic
    ///   message.
    async fn complete(&self, prompt: &str, json_mode: bool) -> Result<String> {
        // Per-request timeout. The timeout log message below quotes this same value.
        const REQUEST_TIMEOUT_SECS: u64 = 60;

        let messages = vec![json!({
            "role": "user",
            "content": prompt
        })];

        let mut request_body = json!({
            "model": self.model,
            "messages": messages,
            "temperature": 0.1,
            "max_tokens": 4000,
            "provider": {
                "sort": self.sort,
                "allow_fallbacks": true
            }
        });

        if json_mode {
            request_body["response_format"] = json!({
                "type": "json_object"
            });
        }

        let response = self
            .client
            .post("https://openrouter.ai/api/v1/chat/completions")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            // App-identification headers (OpenRouter documents these as optional
            // attribution headers).
            .header("HTTP-Referer", "https://github.com/reflex-search/reflex")
            .header("X-Title", "Reflex")
            .json(&request_body)
            .timeout(std::time::Duration::from_secs(REQUEST_TIMEOUT_SECS))
            .send()
            .await
            .map_err(|e| {
                log::error!("OpenRouter API request failed: {}", e);
                if e.is_timeout() {
                    log::error!(" Reason: Request timeout (>{}s)", REQUEST_TIMEOUT_SECS);
                } else if e.is_connect() {
                    log::error!(" Reason: Connection failed");
                } else if e.is_request() {
                    log::error!(" Reason: Invalid request");
                }
                anyhow::anyhow!("Failed to send request to OpenRouter API: {}", e)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            let error_msg = match status.as_u16() {
                429 => {
                    log::warn!("OpenRouter rate limit exceeded: {}", error_text);
                    "Rate limit exceeded (try again in a few seconds)".to_string()
                }
                503 | 502 | 504 => {
                    log::warn!("OpenRouter service unavailable ({}): {}", status, error_text);
                    format!("OpenRouter service temporarily unavailable ({})", status)
                }
                401 => {
                    log::error!("OpenRouter authentication failed: {}", error_text);
                    "Authentication failed - check API key".to_string()
                }
                _ => {
                    log::error!("OpenRouter API error ({}): {}", status, error_text);
                    format!("API error ({}): {}", status, error_text)
                }
            };

            anyhow::bail!("{}", error_msg);
        }

        let data: serde_json::Value = response
            .json()
            .await
            .context("Failed to parse OpenRouter response as JSON")?;

        if let Some(content) = data["choices"][0]["message"]["content"].as_str() {
            return Ok(content.to_string());
        }

        // No text content. A 2xx body may still carry an error object, either at the
        // top level or on the first choice. Report that message rather than the
        // generic "no content" error, which would hide the cause.
        let body_error = data["error"]["message"]
            .as_str()
            .or_else(|| data["choices"][0]["error"]["message"].as_str());
        if let Some(message) = body_error {
            log::error!("OpenRouter returned an error in the response body: {}", message);
            anyhow::bail!("OpenRouter API error: {}", message);
        }

        Err(anyhow::anyhow!("No content in OpenRouter response"))
    }

    /// Provider identifier: always `"openrouter"`.
    fn name(&self) -> &str {
        "openrouter"
    }

    /// Model used by `OpenRouterProvider::new` when no model is supplied.
    fn default_model(&self) -> &str {
        "anthropic/claude-sonnet-4"
    }
}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructs a provider with a dummy key so no test needs a real credential.
    fn provider_with(model: Option<&str>, sort: Option<&str>) -> OpenRouterProvider {
        OpenRouterProvider::new(
            String::from("test-key"),
            model.map(String::from),
            sort.map(String::from),
        )
        .expect("OpenRouterProvider::new should succeed")
    }

    #[test]
    fn test_new_with_defaults() {
        let provider = provider_with(None, None);
        assert_eq!(provider.name(), "openrouter");
        assert_eq!(provider.model, "anthropic/claude-sonnet-4");
        assert_eq!(provider.sort, "price");
    }

    #[test]
    fn test_new_with_custom_model_and_sort() {
        let OpenRouterProvider { model, sort, .. } =
            provider_with(Some("openai/gpt-4o-mini"), Some("latency"));
        assert_eq!(model, "openai/gpt-4o-mini");
        assert_eq!(sort, "latency");
    }

    #[test]
    fn test_new_maps_legacy_speed_to_latency() {
        assert_eq!(provider_with(None, Some("speed")).sort, "latency");
    }

    #[test]
    fn test_openrouter_model_pricing_conversion() {
        // (per-token price as reported by the API, expected price per million tokens)
        let cases = [("0.000003", 3.0_f64), ("0.000015", 15.0_f64)];
        for &(per_token, expected_per_million) in &cases {
            let per_million = per_token.parse::<f64>().unwrap() * 1_000_000.0;
            assert!((per_million - expected_per_million).abs() < 0.001);
        }
    }

    #[test]
    fn test_openrouter_model_struct() {
        let OpenRouterModel {
            id,
            prompt_price,
            completion_price,
            context_length,
            ..
        } = OpenRouterModel {
            id: "anthropic/claude-sonnet-4".to_string(),
            name: "Anthropic: Claude Sonnet 4".to_string(),
            prompt_price: 3.0,
            completion_price: 15.0,
            context_length: 200000,
        };

        assert_eq!(id, "anthropic/claude-sonnet-4");
        assert_eq!(prompt_price, 3.0);
        assert_eq!(completion_price, 15.0);
        assert_eq!(context_length, 200000);
    }
}