// converge_provider/common.rs

// Copyright 2024-2025 Aprio One AB, Sweden
// Author: Kenneth Pernyer, kenneth@aprio.one
// SPDX-License-Identifier: MIT
// See LICENSE file in the project root for full license information.

//! Common abstractions for LLM providers.
//!
//! This module provides shared types and utilities to reduce code duplication
//! across provider implementations.
use converge_traits::llm::{FinishReason, LlmError, LlmRequest, LlmResponse, TokenUsage};
use reqwest::blocking::Client;
use serde::{Deserialize, Serialize};
14
15/// Base configuration for HTTP-based LLM providers.
16#[derive(Debug, Clone)]
17pub struct HttpProviderConfig {
18    /// API key for authentication.
19    pub api_key: String,
20    /// Model identifier.
21    pub model: String,
22    /// Base URL for the API.
23    pub base_url: String,
24    /// HTTP client.
25    pub client: Client,
26}
27
28impl HttpProviderConfig {
29    /// Creates a new HTTP provider configuration.
30    #[must_use]
31    pub fn new(
32        api_key: impl Into<String>,
33        model: impl Into<String>,
34        base_url: impl Into<String>,
35    ) -> Self {
36        Self {
37            api_key: api_key.into(),
38            model: model.into(),
39            base_url: base_url.into(),
40            client: Client::new(),
41        }
42    }
43
44    /// Uses a custom HTTP client.
45    #[must_use]
46    pub fn with_client(mut self, client: Client) -> Self {
47        self.client = client;
48        self
49    }
50}
51
52/// OpenAI-compatible message format.
53#[derive(Serialize, Deserialize, Debug, Clone)]
54pub struct ChatMessage {
55    /// Message role (system, user, assistant).
56    pub role: String,
57    /// Message content.
58    pub content: String,
59}
60
61impl ChatMessage {
62    /// Creates a system message.
63    #[must_use]
64    pub fn system(content: impl Into<String>) -> Self {
65        Self {
66            role: "system".to_string(),
67            content: content.into(),
68        }
69    }
70
71    /// Creates a user message.
72    #[must_use]
73    pub fn user(content: impl Into<String>) -> Self {
74        Self {
75            role: "user".to_string(),
76            content: content.into(),
77        }
78    }
79
80    /// Creates an assistant message.
81    #[must_use]
82    pub fn assistant(content: impl Into<String>) -> Self {
83        Self {
84            role: "assistant".to_string(),
85            content: content.into(),
86        }
87    }
88}
89
90/// OpenAI-compatible chat completion request.
91#[derive(Serialize, Debug)]
92pub struct ChatCompletionRequest {
93    /// Model identifier.
94    pub model: String,
95    /// Messages in the conversation.
96    pub messages: Vec<ChatMessage>,
97    /// Maximum tokens to generate.
98    pub max_tokens: u32,
99    /// Temperature (0.0-2.0).
100    pub temperature: f64,
101    /// Stop sequences.
102    #[serde(skip_serializing_if = "Vec::is_empty")]
103    pub stop: Vec<String>,
104}
105
106impl ChatCompletionRequest {
107    /// Creates a request from an `LlmRequest`.
108    #[must_use]
109    pub fn from_llm_request(model: impl Into<String>, request: &LlmRequest) -> Self {
110        let mut messages = Vec::new();
111
112        if let Some(ref system) = request.system {
113            messages.push(ChatMessage::system(system));
114        }
115
116        messages.push(ChatMessage::user(&request.prompt));
117
118        Self {
119            model: model.into(),
120            messages,
121            max_tokens: request.max_tokens,
122            temperature: request.temperature,
123            stop: request.stop_sequences.clone(),
124        }
125    }
126}
127
128/// OpenAI-compatible choice in response.
129#[derive(Deserialize, Debug)]
130pub struct ChatChoice {
131    /// Message from the choice.
132    pub message: ChatChoiceMessage,
133    /// Finish reason.
134    pub finish_reason: Option<String>,
135}
136
137/// OpenAI-compatible message in choice.
138#[derive(Deserialize, Debug)]
139pub struct ChatChoiceMessage {
140    /// Message content.
141    pub content: String,
142}
143
144/// OpenAI-compatible usage statistics.
145#[derive(Deserialize, Debug)]
146pub struct ChatUsage {
147    /// Prompt tokens.
148    pub prompt_tokens: u32,
149    /// Completion tokens.
150    pub completion_tokens: u32,
151    /// Total tokens.
152    pub total_tokens: u32,
153}
154
155/// OpenAI-compatible chat completion response.
156#[derive(Deserialize, Debug)]
157pub struct ChatCompletionResponse {
158    /// Model used.
159    pub model: String,
160    /// Choices in the response.
161    pub choices: Vec<ChatChoice>,
162    /// Usage statistics.
163    pub usage: ChatUsage,
164}
165
166/// Converts a finish reason string to `FinishReason` enum.
167#[must_use]
168pub fn parse_finish_reason(reason: Option<&str>) -> FinishReason {
169    match reason {
170        Some("length" | "max_tokens") => FinishReason::MaxTokens,
171        Some("content_filter") => FinishReason::ContentFilter,
172        Some("stop_sequence") => FinishReason::StopSequence,
173        _ => FinishReason::Stop, // "stop" or unknown
174    }
175}
176
177/// Converts an OpenAI-compatible response to `LlmResponse`.
178///
179/// # Errors
180///
181/// Returns error if response has no choices.
182pub fn chat_response_to_llm_response(
183    response: ChatCompletionResponse,
184) -> Result<LlmResponse, LlmError> {
185    let choice = response
186        .choices
187        .first()
188        .ok_or_else(|| LlmError::provider("No choices in response"))?;
189
190    Ok(LlmResponse {
191        content: choice.message.content.clone(),
192        model: response.model,
193        finish_reason: parse_finish_reason(choice.finish_reason.as_deref()),
194        usage: TokenUsage {
195            prompt_tokens: response.usage.prompt_tokens,
196            completion_tokens: response.usage.completion_tokens,
197            total_tokens: response.usage.total_tokens,
198        },
199    })
200}
201
202/// Makes an OpenAI-compatible chat completion request.
203///
204/// # Errors
205///
206/// Returns error if the HTTP request fails or response cannot be parsed.
207pub fn make_chat_completion_request(
208    config: &HttpProviderConfig,
209    endpoint: &str,
210    request: &ChatCompletionRequest,
211) -> Result<LlmResponse, LlmError> {
212    let url = format!("{}{}", config.base_url, endpoint);
213
214    let http_response = config
215        .client
216        .post(&url)
217        .header("Authorization", format!("Bearer {}", config.api_key))
218        .header("Content-Type", "application/json")
219        .json(&request)
220        .send()
221        .map_err(|e| LlmError::network(format!("Request failed: {e}")))?;
222
223    let status = http_response.status();
224
225    if !status.is_success() {
226        return handle_openai_style_error(http_response);
227    }
228
229    let api_response: ChatCompletionResponse = http_response
230        .json()
231        .map_err(|e| LlmError::parse(format!("Failed to parse response: {e}")))?;
232
233    chat_response_to_llm_response(api_response)
234}
235
236/// OpenAI-compatible error response format.
237#[derive(Deserialize, Debug)]
238pub struct OpenAiStyleError {
239    /// Error details.
240    pub error: OpenAiStyleErrorDetail,
241}
242
243/// Error detail in OpenAI-compatible format.
244#[derive(Deserialize, Debug)]
245pub struct OpenAiStyleErrorDetail {
246    /// Error message.
247    pub message: String,
248    /// Error type (e.g., "`authentication_error`", "`rate_limit_error`").
249    #[serde(rename = "type")]
250    pub error_type: Option<String>,
251}
252
253/// Handles HTTP error responses for OpenAI-compatible providers.
254///
255/// Parses the error response and maps error types to appropriate `LlmError` kinds.
256///
257/// # Errors
258///
259/// Returns error if:
260/// - Response cannot be parsed as JSON
261/// - Error type indicates authentication failure → `LlmError::auth()`
262/// - Error type indicates rate limit → `LlmError::rate_limit()`
263/// - Other errors → `LlmError::provider()`
264pub fn handle_openai_style_error(
265    http_response: reqwest::blocking::Response,
266) -> Result<LlmResponse, LlmError> {
267    let error_body: OpenAiStyleError = http_response
268        .json()
269        .map_err(|e| LlmError::parse(format!("Failed to parse error: {e}")))?;
270
271    let error_type = error_body.error.error_type.as_deref().unwrap_or("unknown");
272    let message = error_body.error.message;
273
274    let llm_error = match error_type {
275        "invalid_request_error" | "authentication_error" => LlmError::auth(message),
276        "rate_limit_error" => LlmError::rate_limit(message),
277        _ => LlmError::provider(message),
278    };
279
280    Err(llm_error)
281}
282
283/// Helper for providers that use OpenAI-compatible API.
284///
285/// This trait can be implemented by providers to reduce boilerplate.
286pub trait OpenAiCompatibleProvider {
287    /// Gets the provider configuration.
288    fn config(&self) -> &HttpProviderConfig;
289
290    /// Gets the API endpoint path (e.g., "/v1/chat/completions").
291    fn endpoint(&self) -> &str;
292
293    /// Makes a completion request.
294    ///
295    /// Default implementation uses `make_chat_completion_request`.
296    ///
297    /// # Errors
298    ///
299    /// Returns error if the HTTP request fails or response cannot be parsed.
300    fn complete_openai_compatible(&self, request: &LlmRequest) -> Result<LlmResponse, LlmError> {
301        let chat_request =
302            ChatCompletionRequest::from_llm_request(self.config().model.clone(), request);
303        make_chat_completion_request(self.config(), self.endpoint(), &chat_request)
304    }
305}