Skip to main content

llmkit_openai/
provider.rs

1//! [`OpenAiProvider`] — implements [`LlmProvider`] against the OpenAI REST API.
2
3use std::time::{Duration, Instant};
4
5use async_trait::async_trait;
6use llmkit_core::{
7    pricing, ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse,
8    LlmError, LlmProvider, LlmResult,
9};
10
11use crate::types::{ApiError, ChatCompletionResponse, EmbeddingsResponse};
12use crate::{chat, embed, stream};
13
/// Initial chat model for a new provider; replaced via `OpenAiProvider::model`
/// or overridden per request by `ChatRequest::model`.
const DEFAULT_MODEL: &str = "gpt-4o-mini";
/// Initial embedding model for a new provider; replaced via
/// `OpenAiProvider::embed_model` or overridden per request.
const DEFAULT_EMBED_MODEL: &str = "text-embedding-3-small";
/// Root of the public OpenAI REST API; endpoint paths are appended to it.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
17
18/// OpenAI provider (GPT-4o, o1, embeddings).
19#[derive(Clone)]
20pub struct OpenAiProvider {
21    http: reqwest::Client,
22    api_key: String,
23    base_url: String,
24    model: String,
25    embed_model: String,
26}
27
28impl OpenAiProvider {
29    /// Construct with an explicit API key.
30    pub fn new(api_key: impl Into<String>) -> Self {
31        Self {
32            http: reqwest::Client::new(),
33            api_key: api_key.into(),
34            base_url: DEFAULT_BASE_URL.to_string(),
35            model: DEFAULT_MODEL.to_string(),
36            embed_model: DEFAULT_EMBED_MODEL.to_string(),
37        }
38    }
39
40    /// Construct from the `OPENAI_API_KEY` environment variable.
41    pub fn from_env() -> LlmResult<Self> {
42        let key = std::env::var("OPENAI_API_KEY")
43            .map_err(|_| LlmError::Auth("OPENAI_API_KEY not set".into()))?;
44        Ok(Self::new(key))
45    }
46
47    /// Set the default chat model.
48    pub fn model(mut self, model: impl Into<String>) -> Self {
49        self.model = model.into();
50        self
51    }
52
53    /// Set the default embedding model.
54    pub fn embed_model(mut self, model: impl Into<String>) -> Self {
55        self.embed_model = model.into();
56        self
57    }
58
59    /// Override the base URL (e.g. for an OpenAI-compatible gateway).
60    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
61        self.base_url = base_url.into();
62        self
63    }
64
65    /// Provide a custom [`reqwest::Client`] (connection pooling, timeouts, …).
66    pub fn with_client(mut self, client: reqwest::Client) -> Self {
67        self.http = client;
68        self
69    }
70
71    fn resolved_model(&self, req: &ChatRequest) -> String {
72        req.model.clone().unwrap_or_else(|| self.model.clone())
73    }
74}
75
76#[async_trait]
77impl LlmProvider for OpenAiProvider {
78    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
79        let model = self.resolved_model(&req);
80        let body = chat::build_request(&req, model, false);
81
82        let start = Instant::now();
83        let resp = self
84            .http
85            .post(format!("{}/chat/completions", self.base_url))
86            .bearer_auth(&self.api_key)
87            .json(&body)
88            .send()
89            .await
90            .map_err(map_reqwest_err)?;
91
92        let resp = check_status(resp).await?;
93        let parsed: ChatCompletionResponse = resp.json().await.map_err(map_reqwest_err)?;
94        let mut out = chat::map_response(parsed, start.elapsed().as_millis() as u64)?;
95        out.cost = pricing::pricing_for(&out.model).map(|p| p.cost_for(out.usage));
96        Ok(out)
97    }
98
99    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
100        let model = self.resolved_model(&req);
101        let body = chat::build_request(&req, model, true);
102
103        let resp = self
104            .http
105            .post(format!("{}/chat/completions", self.base_url))
106            .bearer_auth(&self.api_key)
107            .json(&body)
108            .send()
109            .await
110            .map_err(map_reqwest_err)?;
111
112        let resp = check_status(resp).await?;
113        Ok(stream::parse(resp))
114    }
115
116    async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
117        let model = req.model.clone().unwrap_or_else(|| self.embed_model.clone());
118        let body = embed::build_request(req.input, model);
119
120        let resp = self
121            .http
122            .post(format!("{}/embeddings", self.base_url))
123            .bearer_auth(&self.api_key)
124            .json(&body)
125            .send()
126            .await
127            .map_err(map_reqwest_err)?;
128
129        let resp = check_status(resp).await?;
130        let parsed: EmbeddingsResponse = resp.json().await.map_err(map_reqwest_err)?;
131        Ok(embed::map_response(parsed))
132    }
133
134    fn name(&self) -> &'static str {
135        "openai"
136    }
137
138    fn model(&self) -> &str {
139        &self.model
140    }
141
142    fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
143        let model = self.resolved_model(req);
144        let pricing = pricing::pricing_for(&model)?;
145        // Rough pre-flight estimate: ~4 chars/token for the prompt, assume the
146        // response fills max_tokens (or a small default).
147        let prompt_chars: usize = req.messages.iter().filter_map(|m| m.content.as_text()).map(|t| t.len()).sum();
148        let prompt_tokens = (prompt_chars / 4) as u32;
149        let completion_tokens = req.max_tokens.unwrap_or(256);
150        Some(pricing.cost_for(llmkit_core::TokenUsage::new(prompt_tokens, completion_tokens)))
151    }
152}
153
154/// Map a `reqwest` transport error into an [`LlmError`].
155pub(crate) fn map_reqwest_err(e: reqwest::Error) -> LlmError {
156    if e.is_timeout() {
157        LlmError::Timeout
158    } else if e.is_decode() {
159        LlmError::Serialization(e.to_string())
160    } else {
161        LlmError::Transport(e.to_string())
162    }
163}
164
165/// Inspect the HTTP status and convert non-2xx responses into [`LlmError`].
166pub(crate) async fn check_status(resp: reqwest::Response) -> LlmResult<reqwest::Response> {
167    let status = resp.status();
168    if status.is_success() {
169        return Ok(resp);
170    }
171
172    let retry_after = resp
173        .headers()
174        .get(reqwest::header::RETRY_AFTER)
175        .and_then(|v| v.to_str().ok())
176        .and_then(|s| s.parse::<u64>().ok())
177        .map(Duration::from_secs);
178
179    let body = resp.text().await.unwrap_or_default();
180    let message = serde_json::from_str::<ApiError>(&body)
181        .map(|e| e.error.message)
182        .unwrap_or(body);
183
184    Err(match status.as_u16() {
185        401 | 403 => LlmError::Auth(message),
186        429 => LlmError::RateLimited { retry_after, message },
187        400 | 404 | 422 => LlmError::InvalidRequest(message),
188        code => LlmError::Provider { status: code, message },
189    })
190}