Skip to main content

reflex/semantic/providers/
openrouter.rs

//! OpenRouter API provider implementation
//!
//! OpenRouter is an OpenAI-compatible API aggregator that routes requests
//! to 200+ models across providers (Claude, GPT, Gemini, Llama, etc.).
//! It adds a "sort" strategy for provider routing (by price, latency, or throughput);
//! the legacy value "speed" is accepted and sent to the API as "latency".
6
7use super::LlmProvider;
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use serde_json::json;
11use std::time::Duration;
12
/// Model info fetched from the OpenRouter API.
///
/// Prices are stored per *million* tokens; `fetch_models` converts them from
/// the per-token strings found in the API response.
#[derive(Debug, Clone, PartialEq)]
pub struct OpenRouterModel {
    /// Model identifier as reported by the API (e.g. `anthropic/claude-sonnet-4`).
    pub id: String,
    /// Human-readable model name; `fetch_models` falls back to `id` when absent.
    pub name: String,
    /// Prompt (input) price in USD per million tokens.
    pub prompt_price: f64,
    /// Completion (output) price in USD per million tokens.
    pub completion_price: f64,
    /// Context window size; `fetch_models` stores 0 when the API omits it.
    pub context_length: u64,
}
22
23/// Fetch available models from OpenRouter API
24pub async fn fetch_models(api_key: &str) -> Result<Vec<OpenRouterModel>> {
25    let client = reqwest::Client::new();
26
27    let response = client
28        .get("https://openrouter.ai/api/v1/models")
29        .header("Authorization", format!("Bearer {}", api_key))
30        .timeout(std::time::Duration::from_secs(10))
31        .send()
32        .await
33        .context("Failed to fetch models from OpenRouter")?;
34
35    if !response.status().is_success() {
36        let status = response.status();
37        let error_text = response
38            .text()
39            .await
40            .unwrap_or_else(|_| "Unknown error".to_string());
41        anyhow::bail!("OpenRouter API error ({}): {}", status, error_text);
42    }
43
44    let data: serde_json::Value = response
45        .json()
46        .await
47        .context("Failed to parse OpenRouter models response")?;
48
49    let models_array = data["data"]
50        .as_array()
51        .context("No 'data' array in OpenRouter models response")?;
52
53    let mut models: Vec<OpenRouterModel> = models_array
54        .iter()
55        .filter_map(|m| {
56            let id = m["id"].as_str()?;
57            let name = m["name"].as_str().unwrap_or(id);
58
59            // Skip models without prompt/completion pricing (image, audio, embedding models)
60            let prompt_str = m["pricing"]["prompt"].as_str()?;
61            let completion_str = m["pricing"]["completion"].as_str()?;
62
63            let prompt_per_token: f64 = prompt_str.parse().ok()?;
64            let completion_per_token: f64 = completion_str.parse().ok()?;
65
66            // Skip free/zero-cost models that are likely non-text or test endpoints
67            // Also skip if both are zero (often indicates non-functional endpoints)
68            if prompt_per_token < 0.0 || completion_per_token < 0.0 {
69                return None;
70            }
71
72            let context_length = m["context_length"].as_u64().unwrap_or(0);
73
74            Some(OpenRouterModel {
75                id: id.to_string(),
76                name: name.to_string(),
77                prompt_price: prompt_per_token * 1_000_000.0,
78                completion_price: completion_per_token * 1_000_000.0,
79                context_length,
80            })
81        })
82        .collect();
83
84    models.sort_by(|a, b| a.id.cmp(&b.id));
85
86    Ok(models)
87}
88
89/// OpenRouter provider (OpenAI-compatible API with provider routing)
90pub struct OpenRouterProvider {
91    client: reqwest::Client,
92    api_key: String,
93    model: String,
94    sort: String,
95}
96
97impl OpenRouterProvider {
98    /// Create a new OpenRouter provider
99    ///
100    /// # Arguments
101    /// * `api_key` - OpenRouter API key
102    /// * `model` - Optional model override (default: anthropic/claude-sonnet-4)
103    /// * `sort` - Optional sort strategy: "price", "speed", or "throughput" (default: "price")
104    pub fn new(api_key: String, model: Option<String>, sort: Option<String>, timeout_secs: u64) -> Result<Self> {
105        // Normalize sort value: map legacy "speed" to the correct API value "latency"
106        let sort = sort
107            .map(|s| if s == "speed" { "latency".to_string() } else { s })
108            .unwrap_or_else(|| "price".to_string());
109        let client = reqwest::Client::builder()
110            .timeout(Duration::from_secs(timeout_secs))
111            .build()
112            .context("Failed to build reqwest client")?;
113        Ok(Self {
114            client,
115            api_key,
116            model: model.unwrap_or_else(|| "anthropic/claude-sonnet-4".to_string()),
117            sort,
118        })
119    }
120}
121
122#[async_trait]
123impl LlmProvider for OpenRouterProvider {
124    async fn complete(&self, prompt: &str, json_mode: bool) -> Result<String> {
125        let messages = vec![json!({
126            "role": "user",
127            "content": prompt
128        })];
129
130        let mut request_body = json!({
131            "model": self.model,
132            "messages": messages,
133            "temperature": 0.1,
134            "max_tokens": 4000,
135            "provider": {
136                "sort": self.sort,
137                "allow_fallbacks": true
138            }
139        });
140
141        // Add JSON response format if requested
142        if json_mode {
143            request_body["response_format"] = json!({
144                "type": "json_object"
145            });
146        }
147
148        let response = self
149            .client
150            .post("https://openrouter.ai/api/v1/chat/completions")
151            .header("Authorization", format!("Bearer {}", self.api_key))
152            .header("Content-Type", "application/json")
153            .header("HTTP-Referer", "https://github.com/reflex-search/reflex")
154            .header("X-Title", "Reflex")
155            .json(&request_body)
156            .send()
157            .await
158            .map_err(|e| {
159                log::error!("OpenRouter API request failed: {}", e);
160                if e.is_timeout() {
161                    log::error!("  Reason: Request timeout (>60s)");
162                } else if e.is_connect() {
163                    log::error!("  Reason: Connection failed");
164                } else if e.is_request() {
165                    log::error!("  Reason: Invalid request");
166                }
167                anyhow::anyhow!("Failed to send request to OpenRouter API: {}", e)
168            })?;
169
170        // Check for HTTP errors
171        if !response.status().is_success() {
172            let status = response.status();
173            let error_text = response
174                .text()
175                .await
176                .unwrap_or_else(|_| "Unknown error".to_string());
177
178            let error_msg = match status.as_u16() {
179                429 => {
180                    log::warn!("OpenRouter rate limit exceeded: {}", error_text);
181                    "Rate limit exceeded (try again in a few seconds)".to_string()
182                }
183                503 | 502 | 504 => {
184                    log::warn!("OpenRouter service unavailable ({}): {}", status, error_text);
185                    format!("OpenRouter service temporarily unavailable ({})", status)
186                }
187                401 => {
188                    log::error!("OpenRouter authentication failed: {}", error_text);
189                    "Authentication failed - check API key".to_string()
190                }
191                _ => {
192                    log::error!("OpenRouter API error ({}): {}", status, error_text);
193                    format!("API error ({}): {}", status, error_text)
194                }
195            };
196
197            anyhow::bail!("{}", error_msg);
198        }
199
200        let data: serde_json::Value = response
201            .json()
202            .await
203            .context("Failed to parse OpenRouter response as JSON")?;
204
205        // Extract content from response (OpenAI-compatible format)
206        let content = data["choices"][0]["message"]["content"]
207            .as_str()
208            .context("No content in OpenRouter response")?;
209
210        Ok(content.to_string())
211    }
212
213    fn name(&self) -> &str {
214        "openrouter"
215    }
216
217    fn default_model(&self) -> &str {
218        "anthropic/claude-sonnet-4"
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a provider with the fixed test key, panicking if construction fails.
    fn provider_with(
        model: Option<&str>,
        sort: Option<&str>,
        timeout_secs: u64,
    ) -> OpenRouterProvider {
        OpenRouterProvider::new(
            "test-key".to_string(),
            model.map(str::to_string),
            sort.map(str::to_string),
            timeout_secs,
        )
        .unwrap()
    }

    /// With no overrides the provider uses the default model and "price" sort.
    #[test]
    fn test_new_with_defaults() {
        let provider = provider_with(None, None, 30);

        assert_eq!(provider.name(), "openrouter");
        assert_eq!(provider.model, "anthropic/claude-sonnet-4");
        assert_eq!(provider.sort, "price");
    }

    /// Explicit model and sort values are stored as given.
    #[test]
    fn test_new_with_custom_model_and_sort() {
        let provider = provider_with(Some("openai/gpt-4o-mini"), Some("latency"), 300);

        assert_eq!(provider.model, "openai/gpt-4o-mini");
        assert_eq!(provider.sort, "latency");
    }

    /// The legacy "speed" sort value is rewritten to "latency".
    #[test]
    fn test_new_maps_legacy_speed_to_latency() {
        let provider = provider_with(None, Some("speed"), 300);

        assert_eq!(provider.sort, "latency");
    }

    /// Per-token price strings scale to the expected per-million values.
    #[test]
    fn test_openrouter_model_pricing_conversion() {
        // Simulate what fetch_models does with per-token pricing strings
        let per_million = |raw: &str| -> f64 {
            let per_token: f64 = raw.parse().unwrap();
            per_token * 1_000_000.0
        };

        let prompt_per_million = per_million("0.000003");
        let completion_per_million = per_million("0.000015");

        assert!((prompt_per_million - 3.0).abs() < 0.001);
        assert!((completion_per_million - 15.0).abs() < 0.001);
    }

    /// Fields of a hand-built model read back unchanged.
    #[test]
    fn test_openrouter_model_struct() {
        let id = "anthropic/claude-sonnet-4".to_string();
        let name = "Anthropic: Claude Sonnet 4".to_string();
        let model = OpenRouterModel {
            id,
            name,
            prompt_price: 3.0,
            completion_price: 15.0,
            context_length: 200000,
        };

        assert_eq!(model.id, "anthropic/claude-sonnet-4");
        assert_eq!(model.prompt_price, 3.0);
        assert_eq!(model.completion_price, 15.0);
        assert_eq!(model.context_length, 200000);
    }
}