1use converge_traits::llm::{FinishReason, LlmError, LlmRequest, LlmResponse, TokenUsage};
12use reqwest::blocking::Client;
13use serde::{Deserialize, Serialize};
14
/// Connection settings shared by all OpenAI-compatible HTTP providers.
///
/// Bundles the credentials, model selection, endpoint root, and the
/// blocking `reqwest` client used to issue requests.
#[derive(Debug, Clone)]
pub struct HttpProviderConfig {
    /// Bearer token sent in the `Authorization` header.
    pub api_key: String,
    /// Model identifier placed in the request body (e.g. `"gpt-4o"`).
    pub model: String,
    /// Base URL the endpoint path is appended to; concatenated verbatim,
    /// so include or omit the trailing slash to match the endpoint string.
    pub base_url: String,
    /// Blocking HTTP client; replace via [`HttpProviderConfig::with_client`].
    pub client: Client,
}
27
28impl HttpProviderConfig {
29 #[must_use]
31 pub fn new(
32 api_key: impl Into<String>,
33 model: impl Into<String>,
34 base_url: impl Into<String>,
35 ) -> Self {
36 Self {
37 api_key: api_key.into(),
38 model: model.into(),
39 base_url: base_url.into(),
40 client: Client::new(),
41 }
42 }
43
44 #[must_use]
46 pub fn with_client(mut self, client: Client) -> Self {
47 self.client = client;
48 self
49 }
50}
51
/// A single message in an OpenAI-style `messages` array.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChatMessage {
    /// Speaker role: `"system"`, `"user"`, or `"assistant"`
    /// (see the constructors on this type).
    pub role: String,
    /// Plain-text message content.
    pub content: String,
}
60
61impl ChatMessage {
62 #[must_use]
64 pub fn system(content: impl Into<String>) -> Self {
65 Self {
66 role: "system".to_string(),
67 content: content.into(),
68 }
69 }
70
71 #[must_use]
73 pub fn user(content: impl Into<String>) -> Self {
74 Self {
75 role: "user".to_string(),
76 content: content.into(),
77 }
78 }
79
80 #[must_use]
82 pub fn assistant(content: impl Into<String>) -> Self {
83 Self {
84 role: "assistant".to_string(),
85 content: content.into(),
86 }
87 }
88}
89
/// Serialized body of an OpenAI-style `/chat/completions` request.
#[derive(Serialize, Debug)]
pub struct ChatCompletionRequest {
    /// Model identifier for the provider.
    pub model: String,
    /// Conversation so far, in order (optional system message first).
    pub messages: Vec<ChatMessage>,
    /// Upper bound on generated tokens.
    pub max_tokens: u32,
    /// Sampling temperature.
    pub temperature: f64,
    /// Stop sequences; omitted from the JSON body entirely when empty,
    /// since some providers reject an empty `stop` array.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
}
105
106impl ChatCompletionRequest {
107 #[must_use]
109 pub fn from_llm_request(model: impl Into<String>, request: &LlmRequest) -> Self {
110 let mut messages = Vec::new();
111
112 if let Some(ref system) = request.system {
113 messages.push(ChatMessage::system(system));
114 }
115
116 messages.push(ChatMessage::user(&request.prompt));
117
118 Self {
119 model: model.into(),
120 messages,
121 max_tokens: request.max_tokens,
122 temperature: request.temperature,
123 stop: request.stop_sequences.clone(),
124 }
125 }
126}
127
/// One completion candidate in the `choices` array of the response.
#[derive(Deserialize, Debug)]
pub struct ChatChoice {
    /// The generated assistant message.
    pub message: ChatChoiceMessage,
    /// Provider's reason for stopping (e.g. `"length"`); `None` when absent.
    pub finish_reason: Option<String>,
}
136
/// Message payload inside a [`ChatChoice`]; only the text content is kept.
#[derive(Deserialize, Debug)]
pub struct ChatChoiceMessage {
    /// Generated text content.
    pub content: String,
}
143
/// Token accounting block returned by the provider.
#[derive(Deserialize, Debug)]
pub struct ChatUsage {
    /// Tokens consumed by the input prompt.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Sum reported by the provider (not recomputed locally).
    pub total_tokens: u32,
}
154
/// Deserialized body of a successful `/chat/completions` response.
#[derive(Deserialize, Debug)]
pub struct ChatCompletionResponse {
    /// Model that actually served the request.
    pub model: String,
    /// Completion candidates; only the first is used downstream.
    pub choices: Vec<ChatChoice>,
    /// Token usage for this call.
    pub usage: ChatUsage,
}
165
166#[must_use]
168pub fn parse_finish_reason(reason: Option<&str>) -> FinishReason {
169 match reason {
170 Some("length" | "max_tokens") => FinishReason::MaxTokens,
171 Some("content_filter") => FinishReason::ContentFilter,
172 Some("stop_sequence") => FinishReason::StopSequence,
173 _ => FinishReason::Stop, }
175}
176
177pub fn chat_response_to_llm_response(
183 response: ChatCompletionResponse,
184) -> Result<LlmResponse, LlmError> {
185 let choice = response
186 .choices
187 .first()
188 .ok_or_else(|| LlmError::provider("No choices in response"))?;
189
190 Ok(LlmResponse {
191 content: choice.message.content.clone(),
192 model: response.model,
193 finish_reason: parse_finish_reason(choice.finish_reason.as_deref()),
194 usage: TokenUsage {
195 prompt_tokens: response.usage.prompt_tokens,
196 completion_tokens: response.usage.completion_tokens,
197 total_tokens: response.usage.total_tokens,
198 },
199 })
200}
201
202pub fn make_chat_completion_request(
208 config: &HttpProviderConfig,
209 endpoint: &str,
210 request: &ChatCompletionRequest,
211) -> Result<LlmResponse, LlmError> {
212 let url = format!("{}{}", config.base_url, endpoint);
213
214 let http_response = config
215 .client
216 .post(&url)
217 .header("Authorization", format!("Bearer {}", config.api_key))
218 .header("Content-Type", "application/json")
219 .json(&request)
220 .send()
221 .map_err(|e| LlmError::network(format!("Request failed: {e}")))?;
222
223 let status = http_response.status();
224
225 if !status.is_success() {
226 return handle_openai_style_error(http_response);
227 }
228
229 let api_response: ChatCompletionResponse = http_response
230 .json()
231 .map_err(|e| LlmError::parse(format!("Failed to parse response: {e}")))?;
232
233 chat_response_to_llm_response(api_response)
234}
235
/// Top-level OpenAI-style error envelope: `{"error": {...}}`.
#[derive(Deserialize, Debug)]
pub struct OpenAiStyleError {
    /// The nested error detail object.
    pub error: OpenAiStyleErrorDetail,
}
242
/// Detail object inside an OpenAI-style error envelope.
#[derive(Deserialize, Debug)]
pub struct OpenAiStyleErrorDetail {
    /// Human-readable error description from the provider.
    pub message: String,
    /// Machine-readable category (the JSON field is `"type"`, which is a
    /// Rust keyword, hence the rename); may be absent.
    #[serde(rename = "type")]
    pub error_type: Option<String>,
}
252
253pub fn handle_openai_style_error(
265 http_response: reqwest::blocking::Response,
266) -> Result<LlmResponse, LlmError> {
267 let error_body: OpenAiStyleError = http_response
268 .json()
269 .map_err(|e| LlmError::parse(format!("Failed to parse error: {e}")))?;
270
271 let error_type = error_body.error.error_type.as_deref().unwrap_or("unknown");
272 let message = error_body.error.message;
273
274 let llm_error = match error_type {
275 "invalid_request_error" | "authentication_error" => LlmError::auth(message),
276 "rate_limit_error" => LlmError::rate_limit(message),
277 _ => LlmError::provider(message),
278 };
279
280 Err(llm_error)
281}
282
/// Implemented by providers that speak the OpenAI chat-completions wire
/// format; supplies a ready-made `complete` path on top of the shared
/// HTTP helpers.
pub trait OpenAiCompatibleProvider {
    /// Connection settings (credentials, model, base URL, client).
    fn config(&self) -> &HttpProviderConfig;

    /// Endpoint path appended verbatim to the base URL
    /// (e.g. `"/v1/chat/completions"`).
    fn endpoint(&self) -> &str;

    /// Default implementation: translate the [`LlmRequest`] into the wire
    /// format for this provider's configured model, then POST it.
    fn complete_openai_compatible(&self, request: &LlmRequest) -> Result<LlmResponse, LlmError> {
        let chat_request =
            ChatCompletionRequest::from_llm_request(self.config().model.clone(), request);
        make_chat_completion_request(self.config(), self.endpoint(), &chat_request)
    }
}