llmkit_openai/
provider.rs1use std::time::{Duration, Instant};
4
5use async_trait::async_trait;
6use llmkit_core::{
7 pricing, ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse,
8 LlmError, LlmProvider, LlmResult,
9};
10
11use crate::types::{ApiError, ChatCompletionResponse, EmbeddingsResponse};
12use crate::{chat, embed, stream};
13
/// Root of the OpenAI REST API; endpoint paths such as `chat/completions` are appended to it.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
/// Chat model assigned by `OpenAiProvider::new` until overridden with `.model(..)`.
const DEFAULT_MODEL: &str = "gpt-4o-mini";
/// Embedding model assigned by `OpenAiProvider::new` until overridden with `.embed_model(..)`.
const DEFAULT_EMBED_MODEL: &str = "text-embedding-3-small";
17
18#[derive(Clone)]
20pub struct OpenAiProvider {
21 http: reqwest::Client,
22 api_key: String,
23 base_url: String,
24 model: String,
25 embed_model: String,
26}
27
28impl OpenAiProvider {
29 pub fn new(api_key: impl Into<String>) -> Self {
31 Self {
32 http: reqwest::Client::new(),
33 api_key: api_key.into(),
34 base_url: DEFAULT_BASE_URL.to_string(),
35 model: DEFAULT_MODEL.to_string(),
36 embed_model: DEFAULT_EMBED_MODEL.to_string(),
37 }
38 }
39
40 pub fn from_env() -> LlmResult<Self> {
42 let key = std::env::var("OPENAI_API_KEY")
43 .map_err(|_| LlmError::Auth("OPENAI_API_KEY not set".into()))?;
44 Ok(Self::new(key))
45 }
46
47 pub fn model(mut self, model: impl Into<String>) -> Self {
49 self.model = model.into();
50 self
51 }
52
53 pub fn embed_model(mut self, model: impl Into<String>) -> Self {
55 self.embed_model = model.into();
56 self
57 }
58
59 pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
61 self.base_url = base_url.into();
62 self
63 }
64
65 pub fn with_client(mut self, client: reqwest::Client) -> Self {
67 self.http = client;
68 self
69 }
70
71 fn resolved_model(&self, req: &ChatRequest) -> String {
72 req.model.clone().unwrap_or_else(|| self.model.clone())
73 }
74}
75
76#[async_trait]
77impl LlmProvider for OpenAiProvider {
78 async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
79 let model = self.resolved_model(&req);
80 let body = chat::build_request(&req, model, false);
81
82 let start = Instant::now();
83 let resp = self
84 .http
85 .post(format!("{}/chat/completions", self.base_url))
86 .bearer_auth(&self.api_key)
87 .json(&body)
88 .send()
89 .await
90 .map_err(map_reqwest_err)?;
91
92 let resp = check_status(resp).await?;
93 let parsed: ChatCompletionResponse = resp.json().await.map_err(map_reqwest_err)?;
94 let mut out = chat::map_response(parsed, start.elapsed().as_millis() as u64)?;
95 out.cost = pricing::pricing_for(&out.model).map(|p| p.cost_for(out.usage));
96 Ok(out)
97 }
98
99 async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
100 let model = self.resolved_model(&req);
101 let body = chat::build_request(&req, model, true);
102
103 let resp = self
104 .http
105 .post(format!("{}/chat/completions", self.base_url))
106 .bearer_auth(&self.api_key)
107 .json(&body)
108 .send()
109 .await
110 .map_err(map_reqwest_err)?;
111
112 let resp = check_status(resp).await?;
113 Ok(stream::parse(resp))
114 }
115
116 async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
117 let model = req.model.clone().unwrap_or_else(|| self.embed_model.clone());
118 let body = embed::build_request(req.input, model);
119
120 let resp = self
121 .http
122 .post(format!("{}/embeddings", self.base_url))
123 .bearer_auth(&self.api_key)
124 .json(&body)
125 .send()
126 .await
127 .map_err(map_reqwest_err)?;
128
129 let resp = check_status(resp).await?;
130 let parsed: EmbeddingsResponse = resp.json().await.map_err(map_reqwest_err)?;
131 Ok(embed::map_response(parsed))
132 }
133
134 fn name(&self) -> &'static str {
135 "openai"
136 }
137
138 fn model(&self) -> &str {
139 &self.model
140 }
141
142 fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
143 let model = self.resolved_model(req);
144 let pricing = pricing::pricing_for(&model)?;
145 let prompt_chars: usize = req.messages.iter().filter_map(|m| m.content.as_text()).map(|t| t.len()).sum();
148 let prompt_tokens = (prompt_chars / 4) as u32;
149 let completion_tokens = req.max_tokens.unwrap_or(256);
150 Some(pricing.cost_for(llmkit_core::TokenUsage::new(prompt_tokens, completion_tokens)))
151 }
152}
153
154pub(crate) fn map_reqwest_err(e: reqwest::Error) -> LlmError {
156 if e.is_timeout() {
157 LlmError::Timeout
158 } else if e.is_decode() {
159 LlmError::Serialization(e.to_string())
160 } else {
161 LlmError::Transport(e.to_string())
162 }
163}
164
165pub(crate) async fn check_status(resp: reqwest::Response) -> LlmResult<reqwest::Response> {
167 let status = resp.status();
168 if status.is_success() {
169 return Ok(resp);
170 }
171
172 let retry_after = resp
173 .headers()
174 .get(reqwest::header::RETRY_AFTER)
175 .and_then(|v| v.to_str().ok())
176 .and_then(|s| s.parse::<u64>().ok())
177 .map(Duration::from_secs);
178
179 let body = resp.text().await.unwrap_or_default();
180 let message = serde_json::from_str::<ApiError>(&body)
181 .map(|e| e.error.message)
182 .unwrap_or(body);
183
184 Err(match status.as_u16() {
185 401 | 403 => LlmError::Auth(message),
186 429 => LlmError::RateLimited { retry_after, message },
187 400 | 404 | 422 => LlmError::InvalidRequest(message),
188 code => LlmError::Provider { status: code, message },
189 })
190}