// mur_core/model/openrouter.rs
use super::provider::ModelResponse;
4use anyhow::{Context, Result};
5use serde::{Deserialize, Serialize};
6
/// Client for the OpenRouter chat-completions API.
///
/// Holds the bearer API key and a single reusable `reqwest::Client`
/// (reused so connection pooling applies across requests).
pub struct OpenRouterProvider {
    api_key: String,
    client: reqwest::Client,
}
12
/// JSON request body for `POST /api/v1/chat/completions`.
/// Field names follow OpenRouter's OpenAI-compatible wire format,
/// so no `#[serde(rename)]` is needed.
#[derive(Serialize)]
struct OpenRouterRequest {
    model: String,
    messages: Vec<ORMessage>,
    temperature: f64,
    max_tokens: u32,
}
20
/// A single chat message in the outgoing request
/// (this module only ever sends one with role `"user"`).
#[derive(Serialize)]
struct ORMessage {
    role: String,
    content: String,
}
26
/// Top-level response payload from the chat-completions endpoint.
/// `usage` is optional because the API may omit it; callers fall back
/// to zero token counts in that case.
#[derive(Deserialize)]
struct OpenRouterResponse {
    choices: Vec<ORChoice>,
    model: String,
    usage: Option<ORUsage>,
}
33
/// One completion choice; only the message payload is consumed here.
#[derive(Deserialize)]
struct ORChoice {
    message: ORChoiceMessage,
}
38
/// Assistant message inside a choice. `content` can be null on the
/// wire (e.g. tool-call-only responses), hence the `Option`.
#[derive(Deserialize)]
struct ORChoiceMessage {
    content: Option<String>,
}
43
/// Token accounting reported by the API; feeds the cost estimate.
#[derive(Deserialize)]
struct ORUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}
49
50impl OpenRouterProvider {
51 pub fn new(api_key: Option<String>) -> Result<Self> {
53 let key = api_key
54 .or_else(|| std::env::var("OPENROUTER_API_KEY").ok())
55 .context("OPENROUTER_API_KEY not set")?;
56
57 Ok(Self {
58 api_key: key,
59 client: reqwest::Client::new(),
60 })
61 }
62
63 pub async fn complete(
65 &self,
66 prompt: &str,
67 model: &str,
68 temperature: f64,
69 max_tokens: u32,
70 ) -> Result<ModelResponse> {
71 let request = OpenRouterRequest {
72 model: model.to_string(),
73 messages: vec![ORMessage {
74 role: "user".into(),
75 content: prompt.to_string(),
76 }],
77 temperature,
78 max_tokens,
79 };
80
81 let response = self
82 .client
83 .post("https://openrouter.ai/api/v1/chat/completions")
84 .header("Authorization", format!("Bearer {}", self.api_key))
85 .header("HTTP-Referer", "https://mur.run")
86 .header("X-Title", "MUR Commander")
87 .json(&request)
88 .send()
89 .await
90 .context("Sending request to OpenRouter")?;
91
92 if !response.status().is_success() {
93 let status = response.status();
94 let body = response.text().await.unwrap_or_default();
95 anyhow::bail!("OpenRouter API error {}: {}", status, body);
96 }
97
98 let body: OpenRouterResponse = response
99 .json()
100 .await
101 .context("Parsing OpenRouter response")?;
102
103 let content = body
104 .choices
105 .into_iter()
106 .filter_map(|c| c.message.content)
107 .collect::<Vec<_>>()
108 .join("");
109
110 let (input_tokens, output_tokens) = body
111 .usage
112 .map(|u| (u.prompt_tokens, u.completion_tokens))
113 .unwrap_or((0, 0));
114
115 let cost = estimate_openrouter_cost(&body.model, input_tokens, output_tokens);
117
118 Ok(ModelResponse {
119 content,
120 model: body.model,
121 input_tokens,
122 output_tokens,
123 cost,
124 })
125 }
126}
127
/// Rough USD cost of a request. Prices are USD per **million** tokens,
/// looked up by substring match on the model id.
///
/// NOTE(review): the table is a hard-coded snapshot of OpenRouter's
/// pricing and will drift over time — treat the result as an estimate.
fn estimate_openrouter_cost(model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
    // Guard arms are tried top to bottom, so more specific substrings
    // must precede their prefixes: "gpt-4o-mini" contains "gpt-4o" and
    // would otherwise be billed at the full gpt-4o rate.
    let (input_price, output_price) = match model {
        m if m.contains("gemini-2.0-flash") => (0.1, 0.4),
        m if m.contains("gemini-2.5-pro") => (1.25, 10.0),
        m if m.contains("gpt-4o-mini") => (0.15, 0.6),
        m if m.contains("gpt-4o") => (2.5, 10.0),
        m if m.contains("llama") => (0.05, 0.05),
        // Unknown models fall back to a conservative mid-range price.
        _ => (1.0, 2.0),
    };

    (input_tokens as f64 / 1_000_000.0 * input_price)
        + (output_tokens as f64 / 1_000_000.0 * output_price)
}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity-check the pricing table: a flash-tier call costs a tiny
    /// but non-zero amount, and the pro tier costs strictly more for
    /// the same token counts.
    #[test]
    fn test_cost_estimation() {
        let flash_cost = estimate_openrouter_cost("google/gemini-2.0-flash-001", 1000, 500);
        assert!(flash_cost > 0.0 && flash_cost < 0.001);

        let pro_cost = estimate_openrouter_cost("google/gemini-2.5-pro", 1000, 500);
        assert!(pro_cost > flash_cost);
    }
}
156}