Skip to main content

mur_core/model/
openrouter.rs

//! OpenRouter provider — access multiple models (Gemini Flash, etc.) through OpenRouter's unified API.
2
3use super::provider::ModelResponse;
4use anyhow::{Context, Result};
5use serde::{Deserialize, Serialize};
6
7/// OpenRouter API provider.
8pub struct OpenRouterProvider {
9    api_key: String,
10    client: reqwest::Client,
11}
12
/// JSON body POSTed to OpenRouter's `/api/v1/chat/completions` endpoint.
///
/// Field names are serialized verbatim (no `serde` renames), so they must
/// match the API's JSON keys exactly.
#[derive(Serialize)]
struct OpenRouterRequest {
    /// Model id, forwarded unchanged from the caller
    /// (e.g. `google/gemini-2.0-flash-001`).
    model: String,
    /// Conversation turns; `complete` sends a single `user` message.
    messages: Vec<ORMessage>,
    /// Sampling temperature, passed through unchanged.
    temperature: f64,
    /// Generation cap in tokens, passed through unchanged.
    max_tokens: u32,
}
20
/// One chat message inside an [`OpenRouterRequest`].
#[derive(Serialize)]
struct ORMessage {
    /// Speaker role; `complete` sets this to `"user"`.
    role: String,
    /// Message text (the caller's prompt).
    content: String,
}
26
/// The subset of the chat-completions response body that this provider reads.
///
/// Any other JSON fields are ignored, because the struct does not use
/// `deny_unknown_fields`. `choices` and `model` are required: a body missing
/// either one fails to parse. `usage` is optional.
#[derive(Deserialize)]
struct OpenRouterResponse {
    /// Returned choices; `complete` concatenates their text.
    choices: Vec<ORChoice>,
    /// Model id as reported by the API; used for cost estimation and
    /// returned to the caller.
    model: String,
    /// Token accounting; when absent, `complete` reports zero tokens.
    usage: Option<ORUsage>,
}
33
/// One entry of [`OpenRouterResponse::choices`].
#[derive(Deserialize)]
struct ORChoice {
    /// The message object for this choice.
    message: ORChoiceMessage,
}
38
/// Message payload of an [`ORChoice`].
#[derive(Deserialize)]
struct ORChoiceMessage {
    /// Message text. `Option` so that a `null` or missing `content` still
    /// parses; `complete` skips such choices.
    content: Option<String>,
}
43
/// Token counts from the response's `usage` object.
#[derive(Deserialize)]
struct ORUsage {
    /// Prompt-side tokens; surfaced as `ModelResponse::input_tokens`.
    prompt_tokens: u32,
    /// Completion-side tokens; surfaced as `ModelResponse::output_tokens`.
    completion_tokens: u32,
}
49
50impl OpenRouterProvider {
51    /// Create a new provider. Reads API key from env if not provided.
52    pub fn new(api_key: Option<String>) -> Result<Self> {
53        let key = api_key
54            .or_else(|| std::env::var("OPENROUTER_API_KEY").ok())
55            .context("OPENROUTER_API_KEY not set")?;
56
57        Ok(Self {
58            api_key: key,
59            client: reqwest::Client::new(),
60        })
61    }
62
63    /// Send a completion request.
64    pub async fn complete(
65        &self,
66        prompt: &str,
67        model: &str,
68        temperature: f64,
69        max_tokens: u32,
70    ) -> Result<ModelResponse> {
71        let request = OpenRouterRequest {
72            model: model.to_string(),
73            messages: vec![ORMessage {
74                role: "user".into(),
75                content: prompt.to_string(),
76            }],
77            temperature,
78            max_tokens,
79        };
80
81        let response = self
82            .client
83            .post("https://openrouter.ai/api/v1/chat/completions")
84            .header("Authorization", format!("Bearer {}", self.api_key))
85            .header("HTTP-Referer", "https://mur.run")
86            .header("X-Title", "MUR Commander")
87            .json(&request)
88            .send()
89            .await
90            .context("Sending request to OpenRouter")?;
91
92        if !response.status().is_success() {
93            let status = response.status();
94            let body = response.text().await.unwrap_or_default();
95            anyhow::bail!("OpenRouter API error {}: {}", status, body);
96        }
97
98        let body: OpenRouterResponse = response
99            .json()
100            .await
101            .context("Parsing OpenRouter response")?;
102
103        let content = body
104            .choices
105            .into_iter()
106            .filter_map(|c| c.message.content)
107            .collect::<Vec<_>>()
108            .join("");
109
110        let (input_tokens, output_tokens) = body
111            .usage
112            .map(|u| (u.prompt_tokens, u.completion_tokens))
113            .unwrap_or((0, 0));
114
115        // OpenRouter pricing varies; estimate conservatively
116        let cost = estimate_openrouter_cost(&body.model, input_tokens, output_tokens);
117
118        Ok(ModelResponse {
119            content,
120            model: body.model,
121            input_tokens,
122            output_tokens,
123            cost,
124        })
125    }
126}
127
/// Estimate the cost of an OpenRouter call from its token counts.
///
/// Prices are looked up by substring match on the model id. For example,
/// `google/gemini-2.0-flash-001` matches `gemini-2.0-flash`. Prices are per
/// million tokens, and unknown models fall back to a conservative default.
///
/// These are hard-coded estimates, not live OpenRouter pricing. The currency
/// follows the table (presumably USD — confirm with callers).
fn estimate_openrouter_cost(model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
    const PER_MILLION: f64 = 1_000_000.0;

    // (input, output) price per 1M tokens. Arms are tried top to bottom and the
    // first substring hit wins. A more specific id must therefore come before
    // any id that is a substring of it: `gpt-4o-mini` contains `gpt-4o`, so
    // `gpt-4o-mini` has to be listed first or its arm is unreachable.
    let (input_price, output_price) = match model {
        m if m.contains("gemini-2.0-flash") => (0.1, 0.4),
        m if m.contains("gemini-2.5-pro") => (1.25, 10.0),
        m if m.contains("gpt-4o-mini") => (0.15, 0.6),
        m if m.contains("gpt-4o") => (2.5, 10.0),
        m if m.contains("llama") => (0.05, 0.05),
        _ => (1.0, 2.0), // conservative default for unlisted models
    };

    (f64::from(input_tokens) / PER_MILLION * input_price)
        + (f64::from(output_tokens) / PER_MILLION * output_price)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cost_estimation() {
        let cost = estimate_openrouter_cost("google/gemini-2.0-flash-001", 1000, 500);
        assert!(cost > 0.0);
        assert!(cost < 0.001); // Flash is very cheap

        let pro_cost = estimate_openrouter_cost("google/gemini-2.5-pro", 1000, 500);
        assert!(pro_cost > cost);
    }

    #[test]
    fn test_gpt_4o_mini_not_priced_as_gpt_4o() {
        // With 1M input + 1M output tokens, the result is the sum of the two
        // per-million prices.
        let mini = estimate_openrouter_cost("openai/gpt-4o-mini", 1_000_000, 1_000_000);
        let full = estimate_openrouter_cost("openai/gpt-4o", 1_000_000, 1_000_000);
        assert!((mini - 0.75).abs() < 1e-9);
        assert!((full - 12.5).abs() < 1e-9);
        assert!(mini < full);
    }

    #[test]
    fn test_default_price_and_zero_tokens() {
        let unknown = estimate_openrouter_cost("acme/unknown-model", 1_000_000, 1_000_000);
        assert!((unknown - 3.0).abs() < 1e-9);
        assert!(estimate_openrouter_cost("acme/unknown-model", 0, 0).abs() < 1e-12);
    }
}
156}