1use anyhow::{Error, Result};
2use async_trait::async_trait;
3use serde_json::{json, Value};
4
5use super::api_client::{ApiClient, AuthMethod};
6use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage};
7use super::errors::ProviderError;
8use super::retry::ProviderRetry;
9use super::utils::{
10 get_model, handle_response_google_compat, handle_response_openai_compat,
11 handle_status_openai_compat, is_google_model, stream_openai_compat, RequestLog,
12};
13use crate::conversation::message::Message;
14
15use crate::model::ModelConfig;
16use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
17use rmcp::model::Tool;
18
/// Default model used when the user has not configured one explicitly.
pub const OPENROUTER_DEFAULT_MODEL: &str = "anthropic/claude-sonnet-4";
/// Cheaper/faster model installed as the "fast" fallback in `from_env`
/// via `ModelConfig::with_fast`.
pub const OPENROUTER_DEFAULT_FAST_MODEL: &str = "google/gemini-2.5-flash";
/// Model-name prefix identifying Anthropic models; requests for these get
/// `cache_control` hints (see `update_request_for_anthropic`).
pub const OPENROUTER_MODEL_PREFIX_ANTHROPIC: &str = "anthropic";

/// Curated list of known-good model ids, used when the live list cannot be
/// fetched from the API (see `fetch_supported_models`).
pub const OPENROUTER_KNOWN_MODELS: &[&str] = &[
    "x-ai/grok-code-fast-1",
    "anthropic/claude-sonnet-4.5",
    "anthropic/claude-sonnet-4",
    "anthropic/claude-opus-4.1",
    "anthropic/claude-opus-4",
    "google/gemini-2.5-pro",
    "google/gemini-2.5-flash",
    "deepseek/deepseek-r1-0528",
    "qwen/qwen3-coder",
    "moonshotai/kimi-k2",
];
/// Documentation URL surfaced through the provider metadata.
pub const OPENROUTER_DOC_URL: &str = "https://openrouter.ai/models";
37
/// Provider backed by the OpenRouter API (<https://openrouter.ai>), which
/// routes requests to many upstream model vendors behind an
/// OpenAI-compatible chat-completions endpoint.
#[derive(serde::Serialize)]
pub struct OpenRouterProvider {
    // HTTP client pre-configured with host, bearer auth, and attribution headers.
    #[serde(skip)]
    api_client: ApiClient,
    // Active model configuration (includes the "fast" fallback model).
    model: ModelConfig,
    // Set to true in `from_env`; reported via `Provider::supports_streaming`.
    supports_streaming: bool,
    // Display name, taken from `Self::metadata().name` at construction time.
    #[serde(skip)]
    name: String,
}
47
48impl OpenRouterProvider {
49 pub async fn from_env(model: ModelConfig) -> Result<Self> {
50 let model = model.with_fast(OPENROUTER_DEFAULT_FAST_MODEL.to_string());
51
52 let config = crate::config::Config::global();
53 let api_key: String = config.get_secret("OPENROUTER_API_KEY")?;
54 let host: String = config
55 .get_param("OPENROUTER_HOST")
56 .unwrap_or_else(|_| "https://openrouter.ai".to_string());
57
58 let auth = AuthMethod::BearerToken(api_key);
59 let api_client = ApiClient::new(host, auth)?
60 .with_header("HTTP-Referer", "https://astercloud.github.io/aster-rust")?
61 .with_header("X-Title", "aster")?;
62
63 Ok(Self {
64 api_client,
65 model,
66 supports_streaming: true,
67 name: Self::metadata().name,
68 })
69 }
70
71 async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
72 let response = self
73 .api_client
74 .response_post("api/v1/chat/completions", payload)
75 .await?;
76
77 if is_google_model(payload) {
79 return handle_response_google_compat(response).await;
80 }
81
82 let response_body = handle_response_openai_compat(response)
84 .await
85 .map_err(|e| ProviderError::RequestFailed(format!("Failed to parse response: {e}")))?;
86
87 let _debug = format!(
88 "OpenRouter request with payload: {} and response: {}",
89 serde_json::to_string_pretty(payload).unwrap_or_else(|_| "Invalid JSON".to_string()),
90 serde_json::to_string_pretty(&response_body)
91 .unwrap_or_else(|_| "Invalid JSON".to_string())
92 );
93
94 if let Some(error_obj) = response_body.get("error") {
97 let error_message = error_obj
99 .get("message")
100 .and_then(|m| m.as_str())
101 .unwrap_or("Unknown OpenRouter error");
102
103 let error_code = error_obj.get("code").and_then(|c| c.as_u64()).unwrap_or(0);
104
105 if error_code == 400 && error_message.contains("maximum context length") {
107 return Err(ProviderError::ContextLengthExceeded(
108 error_message.to_string(),
109 ));
110 }
111
112 match error_code {
114 401 | 403 => return Err(ProviderError::Authentication(error_message.to_string())),
115 429 => {
116 return Err(ProviderError::RateLimitExceeded {
117 details: error_message.to_string(),
118 retry_delay: None,
119 })
120 }
121 500 | 503 => return Err(ProviderError::ServerError(error_message.to_string())),
122 _ => return Err(ProviderError::RequestFailed(error_message.to_string())),
123 }
124 }
125
126 Ok(response_body)
128 }
129}
130
131fn update_request_for_anthropic(original_payload: &Value) -> Value {
135 let mut payload = original_payload.clone();
136
137 if let Some(messages_spec) = payload
138 .as_object_mut()
139 .and_then(|obj| obj.get_mut("messages"))
140 .and_then(|messages| messages.as_array_mut())
141 {
142 let mut user_count = 0;
147 for message in messages_spec.iter_mut().rev() {
148 if message.get("role") == Some(&json!("user")) {
149 if let Some(content) = message.get_mut("content") {
150 if let Some(content_str) = content.as_str() {
151 *content = json!([{
152 "type": "text",
153 "text": content_str,
154 "cache_control": { "type": "ephemeral" }
155 }]);
156 }
157 }
158 user_count += 1;
159 if user_count >= 2 {
160 break;
161 }
162 }
163 }
164
165 if let Some(system_message) = messages_spec
167 .iter_mut()
168 .find(|msg| msg.get("role") == Some(&json!("system")))
169 {
170 if let Some(content) = system_message.get_mut("content") {
171 if let Some(content_str) = content.as_str() {
172 *system_message = json!({
173 "role": "system",
174 "content": [{
175 "type": "text",
176 "text": content_str,
177 "cache_control": { "type": "ephemeral" }
178 }]
179 });
180 }
181 }
182 }
183 }
184
185 if let Some(tools_spec) = payload
186 .as_object_mut()
187 .and_then(|obj| obj.get_mut("tools"))
188 .and_then(|tools| tools.as_array_mut())
189 {
190 if let Some(last_tool) = tools_spec.last_mut() {
193 if let Some(function) = last_tool.get_mut("function") {
194 function
195 .as_object_mut()
196 .unwrap()
197 .insert("cache_control".to_string(), json!({ "type": "ephemeral" }));
198 }
199 }
200 }
201 payload
202}
203
204async fn create_request_based_on_model(
205 provider: &OpenRouterProvider,
206 system: &str,
207 messages: &[Message],
208 tools: &[Tool],
209) -> anyhow::Result<Value, Error> {
210 let mut payload = create_request(
211 &provider.model,
212 system,
213 messages,
214 tools,
215 &super::utils::ImageFormat::OpenAi,
216 false,
217 )?;
218
219 if provider.supports_cache_control().await {
220 payload = update_request_for_anthropic(&payload);
221 }
222
223 payload
224 .as_object_mut()
225 .unwrap()
226 .insert("transforms".to_string(), json!(["middle-out"]));
227
228 Ok(payload)
229}
230
231#[async_trait]
232impl Provider for OpenRouterProvider {
233 fn metadata() -> ProviderMetadata {
234 ProviderMetadata::new(
235 "openrouter",
236 "OpenRouter",
237 "Router for many model providers",
238 OPENROUTER_DEFAULT_MODEL,
239 OPENROUTER_KNOWN_MODELS.to_vec(),
240 OPENROUTER_DOC_URL,
241 vec![
242 ConfigKey::new("OPENROUTER_API_KEY", true, true, None),
243 ConfigKey::new(
244 "OPENROUTER_HOST",
245 false,
246 false,
247 Some("https://openrouter.ai"),
248 ),
249 ],
250 )
251 }
252
253 fn get_name(&self) -> &str {
254 &self.name
255 }
256
257 fn get_model_config(&self) -> ModelConfig {
258 self.model.clone()
259 }
260
261 #[tracing::instrument(
262 skip(self, model_config, system, messages, tools),
263 fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
264 )]
265 async fn complete_with_model(
266 &self,
267 model_config: &ModelConfig,
268 system: &str,
269 messages: &[Message],
270 tools: &[Tool],
271 ) -> Result<(Message, ProviderUsage), ProviderError> {
272 let payload = create_request_based_on_model(self, system, messages, tools).await?;
273 let mut log = RequestLog::start(model_config, &payload)?;
274
275 let response = self
277 .with_retry(|| async {
278 let payload_clone = payload.clone();
279 self.post(&payload_clone).await
280 })
281 .await?;
282
283 let message = response_to_message(&response)?;
285 let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
286 tracing::debug!("Failed to get usage data");
287 Usage::default()
288 });
289 let response_model = get_model(&response);
290 log.write(&response, Some(&usage))?;
291 Ok((message, ProviderUsage::new(response_model, usage)))
292 }
293
294 async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
296 let response = match self.api_client.response_get("api/v1/models").await {
299 Ok(response) => response,
300 Err(e) => {
301 tracing::warn!("Failed to fetch models from OpenRouter API: {}, falling back to manual model entry", e);
302 return Ok(None);
303 }
304 };
305
306 let json: serde_json::Value = match response.json().await {
308 Ok(json) => json,
309 Err(e) => {
310 tracing::warn!("Failed to parse OpenRouter API response as JSON: {}, falling back to manual model entry", e);
311 return Ok(None);
312 }
313 };
314
315 if let Some(err_obj) = json.get("error") {
317 let msg = err_obj
318 .get("message")
319 .and_then(|v| v.as_str())
320 .unwrap_or("unknown error");
321 tracing::warn!("OpenRouter API returned an error: {}", msg);
322 return Ok(None);
323 }
324
325 let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
326 ProviderError::UsageError("Missing data field in JSON response".into())
327 })?;
328
329 let mut models: Vec<String> = data
330 .iter()
331 .filter_map(|model| {
332 let id = model.get("id").and_then(|v| v.as_str())?;
334
335 let supported_params =
337 match model.get("supported_parameters").and_then(|v| v.as_array()) {
338 Some(params) => params,
339 None => {
340 tracing::debug!(
342 "Model '{}' missing supported_parameters field, skipping",
343 id
344 );
345 return None;
346 }
347 };
348
349 let has_tool_support = supported_params
350 .iter()
351 .any(|param| param.as_str() == Some("tools"));
352
353 if has_tool_support {
354 Some(id.to_string())
355 } else {
356 None
357 }
358 })
359 .collect();
360
361 if models.is_empty() {
363 tracing::warn!("No models with tool support found in OpenRouter API response, falling back to manual model entry");
364 return Ok(None);
365 }
366
367 models.sort();
368 Ok(Some(models))
369 }
370
371 async fn supports_cache_control(&self) -> bool {
372 self.model
373 .model_name
374 .starts_with(OPENROUTER_MODEL_PREFIX_ANTHROPIC)
375 }
376
377 fn supports_streaming(&self) -> bool {
378 self.supports_streaming
379 }
380
381 async fn stream(
382 &self,
383 system: &str,
384 messages: &[Message],
385 tools: &[Tool],
386 ) -> Result<MessageStream, ProviderError> {
387 let mut payload = create_request(
388 &self.model,
389 system,
390 messages,
391 tools,
392 &super::utils::ImageFormat::OpenAi,
393 true,
394 )?;
395
396 if self.supports_cache_control().await {
397 payload = update_request_for_anthropic(&payload);
398 }
399
400 payload
401 .as_object_mut()
402 .unwrap()
403 .insert("transforms".to_string(), json!(["middle-out"]));
404
405 let mut log = RequestLog::start(&self.model, &payload)?;
406
407 let response = self
408 .with_retry(|| async {
409 let resp = self
410 .api_client
411 .response_post("api/v1/chat/completions", &payload)
412 .await?;
413 handle_status_openai_compat(resp).await
414 })
415 .await
416 .inspect_err(|e| {
417 let _ = log.error(e);
418 })?;
419
420 stream_openai_compat(response, log)
421 }
422}