aster/providers/openrouter.rs

use anyhow::Result;
use async_trait::async_trait;
use serde_json::{json, Value};

use super::api_client::{ApiClient, AuthMethod};
use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage};
use super::errors::ProviderError;
use super::retry::ProviderRetry;
use super::utils::{
    get_model, handle_response_google_compat, handle_response_openai_compat,
    handle_status_openai_compat, is_google_model, stream_openai_compat, RequestLog,
};
use crate::conversation::message::Message;
use crate::model::ModelConfig;
use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
use rmcp::model::Tool;

pub const OPENROUTER_DEFAULT_MODEL: &str = "anthropic/claude-sonnet-4";
pub const OPENROUTER_DEFAULT_FAST_MODEL: &str = "google/gemini-2.5-flash";
pub const OPENROUTER_MODEL_PREFIX_ANTHROPIC: &str = "anthropic";

// OpenRouter can route to many models; these are the suggested defaults
pub const OPENROUTER_KNOWN_MODELS: &[&str] = &[
    "x-ai/grok-code-fast-1",
    "anthropic/claude-sonnet-4.5",
    "anthropic/claude-sonnet-4",
    "anthropic/claude-opus-4.1",
    "anthropic/claude-opus-4",
    "google/gemini-2.5-pro",
    "google/gemini-2.5-flash",
    "deepseek/deepseek-r1-0528",
    "qwen/qwen3-coder",
    "moonshotai/kimi-k2",
];
pub const OPENROUTER_DOC_URL: &str = "https://openrouter.ai/models";

#[derive(serde::Serialize)]
pub struct OpenRouterProvider {
    #[serde(skip)]
    api_client: ApiClient,
    model: ModelConfig,
    supports_streaming: bool,
    #[serde(skip)]
    name: String,
}

impl OpenRouterProvider {
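    /// Builds the provider from the global config, reading `OPENROUTER_API_KEY`
    /// (required) and `OPENROUTER_HOST` (optional). A minimal usage sketch;
    /// `ModelConfig::new` is assumed here and may not match the real constructor:
    ///
    /// ```ignore
    /// let model = ModelConfig::new(OPENROUTER_DEFAULT_MODEL.to_string());
    /// let provider = OpenRouterProvider::from_env(model).await?;
    /// ```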
    pub async fn from_env(model: ModelConfig) -> Result<Self> {
        let model = model.with_fast(OPENROUTER_DEFAULT_FAST_MODEL.to_string());

        let config = crate::config::Config::global();
        let api_key: String = config.get_secret("OPENROUTER_API_KEY")?;
        let host: String = config
            .get_param("OPENROUTER_HOST")
            .unwrap_or_else(|_| "https://openrouter.ai".to_string());

        let auth = AuthMethod::BearerToken(api_key);
        let api_client = ApiClient::new(host, auth)?
            .with_header("HTTP-Referer", "https://astercloud.github.io/aster-rust")?
            .with_header("X-Title", "aster")?;

        Ok(Self {
            api_client,
            model,
            supports_streaming: true,
            name: Self::metadata().name,
        })
    }

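    /// POST a chat completion request and normalize provider-specific error
    /// payloads into `ProviderError` variants.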
    async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
        let response = self
            .api_client
            .response_post("api/v1/chat/completions", payload)
            .await?;

        // Handle Google-compatible model responses differently
        if is_google_model(payload) {
            return handle_response_google_compat(response).await;
        }

        // For OpenAI-compatible models, parse the response body to JSON
        let response_body = handle_response_openai_compat(response)
            .await
            .map_err(|e| ProviderError::RequestFailed(format!("Failed to parse response: {e}")))?;

        // Log the full request/response pair at trace level for debugging
        tracing::trace!(
            "OpenRouter request with payload: {} and response: {}",
            serde_json::to_string_pretty(payload).unwrap_or_else(|_| "Invalid JSON".to_string()),
            serde_json::to_string_pretty(&response_body)
                .unwrap_or_else(|_| "Invalid JSON".to_string())
        );

        // OpenRouter can return errors in 200 OK responses, so we have to check for errors explicitly
        // https://openrouter.ai/docs/api-reference/errors
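        // A sketch of the error body this code expects (field names follow the
        // error docs linked above; the exact payload may carry extra fields):
        // { "error": { "code": 429, "message": "Rate limited", ... } }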
        if let Some(error_obj) = response_body.get("error") {
            // If there's an error object, extract the error message and code
            let error_message = error_obj
                .get("message")
                .and_then(|m| m.as_str())
                .unwrap_or("Unknown OpenRouter error");

            let error_code = error_obj.get("code").and_then(|c| c.as_u64()).unwrap_or(0);

            // Check for context length errors in the error message
            if error_code == 400 && error_message.contains("maximum context length") {
                return Err(ProviderError::ContextLengthExceeded(
                    error_message.to_string(),
                ));
            }

            // Return the appropriate error based on the OpenRouter error code
            match error_code {
                401 | 403 => return Err(ProviderError::Authentication(error_message.to_string())),
                429 => {
                    return Err(ProviderError::RateLimitExceeded {
                        details: error_message.to_string(),
                        retry_delay: None,
                    })
                }
                500 | 503 => return Err(ProviderError::ServerError(error_message.to_string())),
                _ => return Err(ProviderError::RequestFailed(error_message.to_string())),
            }
        }

        // No error detected, return the response body
        Ok(response_body)
    }
}

/// Update the request when using an Anthropic model.
/// For Anthropic models we can enable prompt caching to save cost. Since OpenRouter exposes an
/// OpenAI-compatible endpoint, we modify the OpenAI-style request to carry Anthropic's
/// `cache_control` field.
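/// A minimal before/after sketch of the transformation (the shape follows
/// Anthropic's prompt-caching message format, as built by the code below):
///
/// ```ignore
/// // Before: {"role": "user", "content": "hello"}
/// // After:  {"role": "user", "content": [{"type": "text", "text": "hello",
/// //          "cache_control": {"type": "ephemeral"}}]}
/// ```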
fn update_request_for_anthropic(original_payload: &Value) -> Value {
    let mut payload = original_payload.clone();

    if let Some(messages_spec) = payload
        .as_object_mut()
        .and_then(|obj| obj.get_mut("messages"))
        .and_then(|messages| messages.as_array_mut())
    {
        // Add "cache_control" to the last and second-to-last "user" messages.
        // Each turn, we mark the final user message with cache_control so the conversation can be
        // incrementally cached; the second-to-last user message is marked as well so that this
        // checkpoint can read from the previous cache.
        let mut user_count = 0;
        for message in messages_spec.iter_mut().rev() {
            if message.get("role") == Some(&json!("user")) {
                if let Some(content) = message.get_mut("content") {
                    if let Some(content_str) = content.as_str() {
                        *content = json!([{
                            "type": "text",
                            "text": content_str,
                            "cache_control": { "type": "ephemeral" }
                        }]);
                    }
                }
                user_count += 1;
                if user_count >= 2 {
                    break;
                }
            }
        }

        // Update the system message to have a cache_control field.
        if let Some(system_message) = messages_spec
            .iter_mut()
            .find(|msg| msg.get("role") == Some(&json!("system")))
        {
            if let Some(content) = system_message.get_mut("content") {
                if let Some(content_str) = content.as_str() {
                    *system_message = json!({
                        "role": "system",
                        "content": [{
                            "type": "text",
                            "text": content_str,
                            "cache_control": { "type": "ephemeral" }
                        }]
                    });
                }
            }
        }
    }

    if let Some(tools_spec) = payload
        .as_object_mut()
        .and_then(|obj| obj.get_mut("tools"))
        .and_then(|tools| tools.as_array_mut())
    {
        // Add "cache_control" to the last tool spec, if any. This means that all tool definitions
        // will be cached as a single prefix.
        if let Some(last_tool) = tools_spec.last_mut() {
            if let Some(function) = last_tool.get_mut("function") {
                function
                    .as_object_mut()
                    .unwrap()
                    .insert("cache_control".to_string(), json!({ "type": "ephemeral" }));
            }
        }
    }
    payload
}

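/// Build the chat completion request for the configured model: OpenAI format,
/// plus Anthropic cache_control markers when the model supports them, plus
/// OpenRouter's "middle-out" transform.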
async fn create_request_based_on_model(
    provider: &OpenRouterProvider,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
) -> Result<Value> {
    let mut payload = create_request(
        &provider.model,
        system,
        messages,
        tools,
        &super::utils::ImageFormat::OpenAi,
        false,
    )?;

    if provider.supports_cache_control().await {
        payload = update_request_for_anthropic(&payload);
    }

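    // "middle-out" is OpenRouter's prompt-compression transform: prompts that
    // exceed the model's context window are compressed from the middle, where
    // models tend to pay the least attention.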
    payload
        .as_object_mut()
        .unwrap()
        .insert("transforms".to_string(), json!(["middle-out"]));

    Ok(payload)
}

#[async_trait]
impl Provider for OpenRouterProvider {
    fn metadata() -> ProviderMetadata {
        ProviderMetadata::new(
            "openrouter",
            "OpenRouter",
            "Router for many model providers",
            OPENROUTER_DEFAULT_MODEL,
            OPENROUTER_KNOWN_MODELS.to_vec(),
            OPENROUTER_DOC_URL,
            vec![
                ConfigKey::new("OPENROUTER_API_KEY", true, true, None),
                ConfigKey::new(
                    "OPENROUTER_HOST",
                    false,
                    false,
                    Some("https://openrouter.ai"),
                ),
            ],
        )
    }

    fn get_name(&self) -> &str {
        &self.name
    }

    fn get_model_config(&self) -> ModelConfig {
        self.model.clone()
    }

    #[tracing::instrument(
        skip(self, model_config, system, messages, tools),
        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
    )]
    async fn complete_with_model(
        &self,
        model_config: &ModelConfig,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        let payload = create_request_based_on_model(self, system, messages, tools).await?;
        let mut log = RequestLog::start(model_config, &payload)?;

        // Make the request, retrying transient failures
        let response = self
            .with_retry(|| async {
                let payload_clone = payload.clone();
                self.post(&payload_clone).await
            })
            .await?;

        // Parse the response into a message and token usage
        let message = response_to_message(&response)?;
        let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
            tracing::debug!("Failed to get usage data");
            Usage::default()
        });
        let response_model = get_model(&response);
        log.write(&response, Some(&usage))?;
        Ok((message, ProviderUsage::new(response_model, usage)))
    }

    /// Fetch supported models from the OpenRouter API (only models with tool support)
    async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
        // Handle request failures gracefully:
        // if the request fails, fall back to manual model entry
        let response = match self.api_client.response_get("api/v1/models").await {
            Ok(response) => response,
            Err(e) => {
                tracing::warn!("Failed to fetch models from OpenRouter API: {}, falling back to manual model entry", e);
                return Ok(None);
            }
        };

        // Handle JSON parsing failures gracefully
        let json: serde_json::Value = match response.json().await {
            Ok(json) => json,
            Err(e) => {
                tracing::warn!("Failed to parse OpenRouter API response as JSON: {}, falling back to manual model entry", e);
                return Ok(None);
            }
        };

        // Check for an error object in the response
        if let Some(err_obj) = json.get("error") {
            let msg = err_obj
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("unknown error");
            tracing::warn!("OpenRouter API returned an error: {}", msg);
            return Ok(None);
        }

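        // Expected shape of the models listing (abridged; based on the fields
        // this code reads, the real payload contains more):
        // { "data": [ { "id": "anthropic/claude-sonnet-4",
        //               "supported_parameters": ["tools", ...] }, ... ] }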
        let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
            ProviderError::UsageError("Missing data field in JSON response".into())
        })?;

        let mut models: Vec<String> = data
            .iter()
            .filter_map(|model| {
                // Get the model ID
                let id = model.get("id").and_then(|v| v.as_str())?;

                // Check if the model supports tools
                let supported_params =
                    match model.get("supported_parameters").and_then(|v| v.as_array()) {
                        Some(params) => params,
                        None => {
                            // If supported_parameters is missing, skip this model (assume no tool support)
                            tracing::debug!(
                                "Model '{}' missing supported_parameters field, skipping",
                                id
                            );
                            return None;
                        }
                    };

                let has_tool_support = supported_params
                    .iter()
                    .any(|param| param.as_str() == Some("tools"));

                if has_tool_support {
                    Some(id.to_string())
                } else {
                    None
                }
            })
            .collect();

        // If no models with tool support were found, fall back to manual entry
        if models.is_empty() {
            tracing::warn!("No models with tool support found in OpenRouter API response, falling back to manual model entry");
            return Ok(None);
        }

        models.sort();
        Ok(Some(models))
    }

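    /// Prompt caching is only applied to Anthropic-prefixed models (e.g.
    /// "anthropic/claude-sonnet-4"); other routed models are assumed not to
    /// understand Anthropic's cache_control field.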
    async fn supports_cache_control(&self) -> bool {
        self.model
            .model_name
            .starts_with(OPENROUTER_MODEL_PREFIX_ANTHROPIC)
    }

    fn supports_streaming(&self) -> bool {
        self.supports_streaming
    }

    async fn stream(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<MessageStream, ProviderError> {
        let mut payload = create_request(
            &self.model,
            system,
            messages,
            tools,
            &super::utils::ImageFormat::OpenAi,
            true,
        )?;

        if self.supports_cache_control().await {
            payload = update_request_for_anthropic(&payload);
        }

        payload
            .as_object_mut()
            .unwrap()
            .insert("transforms".to_string(), json!(["middle-out"]));

        let mut log = RequestLog::start(&self.model, &payload)?;

        let response = self
            .with_retry(|| async {
                let resp = self
                    .api_client
                    .response_post("api/v1/chat/completions", &payload)
                    .await?;
                handle_status_openai_compat(resp).await
            })
            .await
            .inspect_err(|e| {
                let _ = log.error(e);
            })?;

        stream_openai_compat(response, log)
    }
}