Skip to main content

reflex/semantic/providers/
openrouter.rs

//! OpenRouter API provider implementation
//!
//! OpenRouter is an OpenAI-compatible API aggregator that routes requests
//! to 200+ models across providers (Claude, GPT, Gemini, Llama, etc.).
//! It adds a "sort" strategy for provider routing (by price, latency, or throughput).
6
7use super::LlmProvider;
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use serde_json::json;
11use std::time::Duration;
12
/// Model info fetched from the OpenRouter API.
///
/// Prices are expressed in USD per million tokens. `PartialEq` is derived so
/// that whole models can be compared as values (for example in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct OpenRouterModel {
    /// Model identifier, e.g. `anthropic/claude-sonnet-4`.
    pub id: String,
    /// Display name of the model.
    pub name: String,
    /// Prompt price in USD per million tokens.
    pub prompt_price: f64,
    /// Completion price in USD per million tokens.
    pub completion_price: f64,
    /// Context length reported for the model.
    pub context_length: u64,
}
22
23/// Fetch available models from OpenRouter API
24pub async fn fetch_models(api_key: &str) -> Result<Vec<OpenRouterModel>> {
25    let client = reqwest::Client::new();
26
27    let response = client
28        .get("https://openrouter.ai/api/v1/models")
29        .header("Authorization", format!("Bearer {}", api_key))
30        .timeout(std::time::Duration::from_secs(10))
31        .send()
32        .await
33        .context("Failed to fetch models from OpenRouter")?;
34
35    if !response.status().is_success() {
36        let status = response.status();
37        let error_text = response
38            .text()
39            .await
40            .unwrap_or_else(|_| "Unknown error".to_string());
41        anyhow::bail!("OpenRouter API error ({}): {}", status, error_text);
42    }
43
44    let data: serde_json::Value = response
45        .json()
46        .await
47        .context("Failed to parse OpenRouter models response")?;
48
49    let models_array = data["data"]
50        .as_array()
51        .context("No 'data' array in OpenRouter models response")?;
52
53    let mut models: Vec<OpenRouterModel> = models_array
54        .iter()
55        .filter_map(|m| {
56            let id = m["id"].as_str()?;
57            let name = m["name"].as_str().unwrap_or(id);
58
59            // Skip models without prompt/completion pricing (image, audio, embedding models)
60            let prompt_str = m["pricing"]["prompt"].as_str()?;
61            let completion_str = m["pricing"]["completion"].as_str()?;
62
63            let prompt_per_token: f64 = prompt_str.parse().ok()?;
64            let completion_per_token: f64 = completion_str.parse().ok()?;
65
66            // Skip free/zero-cost models that are likely non-text or test endpoints
67            // Also skip if both are zero (often indicates non-functional endpoints)
68            if prompt_per_token < 0.0 || completion_per_token < 0.0 {
69                return None;
70            }
71
72            let context_length = m["context_length"].as_u64().unwrap_or(0);
73
74            Some(OpenRouterModel {
75                id: id.to_string(),
76                name: name.to_string(),
77                prompt_price: prompt_per_token * 1_000_000.0,
78                completion_price: completion_per_token * 1_000_000.0,
79                context_length,
80            })
81        })
82        .collect();
83
84    models.sort_by(|a, b| a.id.cmp(&b.id));
85
86    Ok(models)
87}
88
89/// OpenRouter provider (OpenAI-compatible API with provider routing)
90pub struct OpenRouterProvider {
91    client: reqwest::Client,
92    api_key: String,
93    model: String,
94    sort: String,
95}
96
97impl OpenRouterProvider {
98    /// Create a new OpenRouter provider
99    ///
100    /// # Arguments
101    /// * `api_key` - OpenRouter API key
102    /// * `model` - Optional model override (default: anthropic/claude-sonnet-4)
103    /// * `sort` - Optional sort strategy: "price", "speed", or "throughput" (default: "price")
104    pub fn new(
105        api_key: String,
106        model: Option<String>,
107        sort: Option<String>,
108        timeout_secs: u64,
109    ) -> Result<Self> {
110        // Normalize sort value: map legacy "speed" to the correct API value "latency"
111        let sort = sort
112            .map(|s| {
113                if s == "speed" {
114                    "latency".to_string()
115                } else {
116                    s
117                }
118            })
119            .unwrap_or_else(|| "price".to_string());
120        let client = reqwest::Client::builder()
121            .timeout(Duration::from_secs(timeout_secs))
122            .build()
123            .context("Failed to build reqwest client")?;
124        Ok(Self {
125            client,
126            api_key,
127            model: model.unwrap_or_else(|| "anthropic/claude-sonnet-4".to_string()),
128            sort,
129        })
130    }
131}
132
133#[async_trait]
134impl LlmProvider for OpenRouterProvider {
135    async fn complete(&self, prompt: &str, json_mode: bool) -> Result<String> {
136        let messages = vec![json!({
137            "role": "user",
138            "content": prompt
139        })];
140
141        let mut request_body = json!({
142            "model": self.model,
143            "messages": messages,
144            "temperature": 0.1,
145            "max_tokens": 4000,
146            "provider": {
147                "sort": self.sort,
148                "allow_fallbacks": true
149            }
150        });
151
152        // Add JSON response format if requested
153        if json_mode {
154            request_body["response_format"] = json!({
155                "type": "json_object"
156            });
157        }
158
159        let response = self
160            .client
161            .post("https://openrouter.ai/api/v1/chat/completions")
162            .header("Authorization", format!("Bearer {}", self.api_key))
163            .header("Content-Type", "application/json")
164            .header("HTTP-Referer", "https://github.com/reflex-search/reflex")
165            .header("X-Title", "Reflex")
166            .json(&request_body)
167            .send()
168            .await
169            .map_err(|e| {
170                log::error!("OpenRouter API request failed: {}", e);
171                if e.is_timeout() {
172                    log::error!("  Reason: Request timeout (>60s)");
173                } else if e.is_connect() {
174                    log::error!("  Reason: Connection failed");
175                } else if e.is_request() {
176                    log::error!("  Reason: Invalid request");
177                }
178                anyhow::anyhow!("Failed to send request to OpenRouter API: {}", e)
179            })?;
180
181        // Check for HTTP errors
182        if !response.status().is_success() {
183            let status = response.status();
184            let error_text = response
185                .text()
186                .await
187                .unwrap_or_else(|_| "Unknown error".to_string());
188
189            let error_msg = match status.as_u16() {
190                429 => {
191                    log::warn!("OpenRouter rate limit exceeded: {}", error_text);
192                    "Rate limit exceeded (try again in a few seconds)".to_string()
193                }
194                503 | 502 | 504 => {
195                    log::warn!(
196                        "OpenRouter service unavailable ({}): {}",
197                        status,
198                        error_text
199                    );
200                    format!("OpenRouter service temporarily unavailable ({})", status)
201                }
202                401 => {
203                    log::error!("OpenRouter authentication failed: {}", error_text);
204                    "Authentication failed - check API key".to_string()
205                }
206                _ => {
207                    log::error!("OpenRouter API error ({}): {}", status, error_text);
208                    format!("API error ({}): {}", status, error_text)
209                }
210            };
211
212            anyhow::bail!("{}", error_msg);
213        }
214
215        let data: serde_json::Value = response
216            .json()
217            .await
218            .context("Failed to parse OpenRouter response as JSON")?;
219
220        // Extract content from response (OpenAI-compatible format)
221        let content = data["choices"][0]["message"]["content"]
222            .as_str()
223            .context("No content in OpenRouter response")?;
224
225        Ok(content.to_string())
226    }
227
228    fn name(&self) -> &str {
229        "openrouter"
230    }
231
232    fn default_model(&self) -> &str {
233        "anthropic/claude-sonnet-4"
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// API key passed to every constructor call in these tests.
    fn test_key() -> String {
        "test-key".to_string()
    }

    /// Without overrides the provider uses the default model and "price" sort.
    #[test]
    fn test_new_with_defaults() {
        let provider = OpenRouterProvider::new(test_key(), None, None, 30).unwrap();

        assert_eq!(provider.name(), "openrouter");
        assert_eq!(provider.model, "anthropic/claude-sonnet-4");
        assert_eq!(provider.sort, "price");
    }

    /// Explicit model and sort values are stored as given.
    #[test]
    fn test_new_with_custom_model_and_sort() {
        let model = Some("openai/gpt-4o-mini".to_string());
        let sort = Some("latency".to_string());

        let provider = OpenRouterProvider::new(test_key(), model, sort, 300).unwrap();

        assert_eq!(provider.model, "openai/gpt-4o-mini");
        assert_eq!(provider.sort, "latency");
    }

    /// The legacy "speed" sort value is rewritten to "latency".
    #[test]
    fn test_new_maps_legacy_speed_to_latency() {
        let legacy_sort = Some("speed".to_string());

        let provider = OpenRouterProvider::new(test_key(), None, legacy_sort, 300).unwrap();

        assert_eq!(provider.sort, "latency");
    }

    /// Per-token price strings scale to the expected per-million values.
    #[test]
    fn test_openrouter_model_pricing_conversion() {
        // Mirrors the conversion fetch_models applies to per-token pricing strings.
        let cases = [("0.000003", 3.0), ("0.000015", 15.0)];

        for &(per_token_str, expected_per_million) in cases.iter() {
            let per_token: f64 = per_token_str.parse().unwrap();
            let per_million = per_token * 1_000_000.0;
            assert!((per_million - expected_per_million).abs() < 0.001);
        }
    }

    /// Fields of a hand-built model read back exactly as written.
    #[test]
    fn test_openrouter_model_struct() {
        let OpenRouterModel {
            id,
            prompt_price,
            completion_price,
            context_length,
            ..
        } = OpenRouterModel {
            id: "anthropic/claude-sonnet-4".to_string(),
            name: "Anthropic: Claude Sonnet 4".to_string(),
            prompt_price: 3.0,
            completion_price: 15.0,
            context_length: 200000,
        };

        assert_eq!(id, "anthropic/claude-sonnet-4");
        assert_eq!(prompt_price, 3.0);
        assert_eq!(completion_price, 15.0);
        assert_eq!(context_length, 200000);
    }
}