Skip to main content

openrouter_rust/
chat.rs

1use crate::{
2    client::OpenRouterClient,
3    error::{OpenRouterError, Result},
4    types::{Message, Plugin, ProviderPreferences, ResponseFormat, Tool, ToolChoice, Usage},
5};
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct ChatCompletionRequest {
10    pub model: String,
11    pub messages: Vec<Message>,
12    #[serde(skip_serializing_if = "Option::is_none")]
13    pub temperature: Option<f32>,
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub top_p: Option<f32>,
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub max_tokens: Option<u32>,
18    #[serde(skip_serializing_if = "Option::is_none")]
19    pub stop: Option<Vec<String>>,
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub stream: Option<bool>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub tools: Option<Vec<Tool>>,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub tool_choice: Option<ToolChoice>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub response_format: Option<ResponseFormat>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub seed: Option<i64>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub top_k: Option<u32>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub frequency_penalty: Option<f32>,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub presence_penalty: Option<f32>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub repetition_penalty: Option<f32>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub provider: Option<ProviderPreferences>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub plugins: Option<Vec<Plugin>>,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub transforms: Option<Vec<String>>,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub models: Option<Vec<String>>,
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub route: Option<String>,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct ChatCompletionResponse {
52    pub id: String,
53    pub object: String,
54    pub created: i64,
55    pub model: String,
56    pub choices: Vec<Choice>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub usage: Option<Usage>,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub system_fingerprint: Option<String>,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct Choice {
65    pub index: u32,
66    pub message: Message,
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub finish_reason: Option<String>,
69    #[serde(skip_serializing_if = "Option::is_none")]
70    pub native_finish_reason: Option<String>,
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub error: Option<ChoiceError>,
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct ChoiceError {
77    pub code: u16,
78    pub message: String,
79}
80
81pub struct ChatCompletionBuilder {
82    request: ChatCompletionRequest,
83}
84
85impl ChatCompletionBuilder {
86    pub fn new(model: impl Into<String>) -> Self {
87        Self {
88            request: ChatCompletionRequest {
89                model: model.into(),
90                messages: Vec::new(),
91                temperature: None,
92                top_p: None,
93                max_tokens: None,
94                stop: None,
95                stream: None,
96                tools: None,
97                tool_choice: None,
98                response_format: None,
99                seed: None,
100                top_k: None,
101                frequency_penalty: None,
102                presence_penalty: None,
103                repetition_penalty: None,
104                provider: None,
105                plugins: None,
106                transforms: None,
107                models: None,
108                route: None,
109            },
110        }
111    }
112
113    pub fn message(mut self, role: crate::types::Role, content: impl Into<String>) -> Self {
114        self.request.messages.push(Message {
115            role,
116            content: Some(content.into()),
117            name: None,
118            tool_calls: None,
119        });
120        self
121    }
122
123    pub fn system_message(self, content: impl Into<String>) -> Self {
124        self.message(crate::types::Role::System, content)
125    }
126
127    pub fn user_message(self, content: impl Into<String>) -> Self {
128        self.message(crate::types::Role::User, content)
129    }
130
131    pub fn assistant_message(self, content: impl Into<String>) -> Self {
132        self.message(crate::types::Role::Assistant, content)
133    }
134
135    pub fn temperature(mut self, temp: f32) -> Self {
136        self.request.temperature = Some(temp);
137        self
138    }
139
140    pub fn top_p(mut self, top_p: f32) -> Self {
141        self.request.top_p = Some(top_p);
142        self
143    }
144
145    pub fn max_tokens(mut self, max: u32) -> Self {
146        self.request.max_tokens = Some(max);
147        self
148    }
149
150    pub fn stop(mut self, stop: Vec<String>) -> Self {
151        self.request.stop = Some(stop);
152        self
153    }
154
155    pub fn stream(mut self, stream: bool) -> Self {
156        self.request.stream = Some(stream);
157        self
158    }
159
160    pub fn tools(mut self, tools: Vec<Tool>) -> Self {
161        self.request.tools = Some(tools);
162        self
163    }
164
165    pub fn response_format_json(mut self) -> Self {
166        self.request.response_format = Some(ResponseFormat {
167            response_type: "json_object".to_string(),
168            json_schema: None,
169        });
170        self
171    }
172
173    pub fn build(self) -> ChatCompletionRequest {
174        self.request
175    }
176}
177
178impl OpenRouterClient {
179    pub async fn chat_completion(
180        &self,
181        request: ChatCompletionRequest,
182    ) -> Result<ChatCompletionResponse> {
183        let url = format!("{}/chat/completions", self.base_url);
184        let headers = self.build_headers()?;
185
186        let response = self
187            .client
188            .post(&url)
189            .headers(headers)
190            .json(&request)
191            .send()
192            .await
193            .map_err(OpenRouterError::HttpError)?;
194
195        let status = response.status();
196        
197        if !status.is_success() {
198            let error_text = response.text().await.unwrap_or_default();
199            return Err(OpenRouterError::ApiError {
200                code: status.as_u16(),
201                message: error_text,
202            });
203        }
204
205        let completion = response
206            .json::<ChatCompletionResponse>()
207            .await
208            .map_err(OpenRouterError::HttpError)?;
209
210        Ok(completion)
211    }
212}