1use async_stream::try_stream;
2
3use futures_core::Stream;
4use futures_util::StreamExt;
5use reqwest::{
6 Client as HttpClient,
7 header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue},
8};
9use std::time::Duration;
10
11use artificial_core::provider::{TranscriptionRequest, TranscriptionResult};
12
13use crate::{
14 api_v1::{
15 AudioTranscriptionResponse, ChatCompletionChunkResponse, ChatCompletionRequest,
16 ChatCompletionResponse,
17 },
18 error::{OpenAiError, OpenAiRateLimitHeaders},
19};
20
21fn parse_retry_after_seconds(headers: &reqwest::header::HeaderMap) -> Duration {
22 use reqwest::header::RETRY_AFTER;
23 if let Some(val) = headers.get(RETRY_AFTER).and_then(|hv| hv.to_str().ok())
24 && let Ok(secs) = val.trim().parse::<u64>()
25 {
26 return Duration::from_secs(secs);
27 }
28 Duration::from_secs(0)
29}
30
31fn header_u32(headers: &reqwest::header::HeaderMap, name: &str) -> Option<u32> {
32 headers
33 .get(name)
34 .and_then(|hv| hv.to_str().ok())
35 .and_then(|s| s.parse::<u32>().ok())
36}
37
38fn header_string(headers: &reqwest::header::HeaderMap, name: &str) -> Option<String> {
39 headers
40 .get(name)
41 .and_then(|hv| hv.to_str().ok())
42 .map(|s| s.to_string())
43}
44
45fn extract_rate_limit_info(
46 headers: &reqwest::header::HeaderMap,
47) -> (Option<Duration>, Option<String>, OpenAiRateLimitHeaders) {
48 let retry_after = {
49 let d = parse_retry_after_seconds(headers);
50 if d.as_secs() > 0 { Some(d) } else { None }
51 };
52
53 let info = OpenAiRateLimitHeaders {
54 limit_requests: header_u32(headers, "x-ratelimit-limit-requests"),
55 remaining_requests: header_u32(headers, "x-ratelimit-remaining-requests"),
56 reset_requests: header_string(headers, "x-ratelimit-reset-requests"),
57 limit_tokens: header_u32(headers, "x-ratelimit-limit-tokens"),
58 remaining_tokens: header_u32(headers, "x-ratelimit-remaining-tokens"),
59 reset_tokens: header_string(headers, "x-ratelimit-reset-tokens"),
60 };
61
62 let reset_at = info
64 .reset_requests
65 .clone()
66 .or_else(|| info.reset_tokens.clone());
67
68 (retry_after, reset_at, info)
69}
/// Emits a tracing event summarising the `x-ratelimit-*` headers of a response.
///
/// A bucket counts as tight when it has at most a small absolute headroom
/// left (2 requests, 128 tokens) or at most 5% of its advertised limit.
/// - If either bucket is tight, the event is logged at `warn`; otherwise at `debug`.
/// - A missing "remaining" header is treated as `u32::MAX`, so it is never tight.
/// - A missing "limit" header is treated as `0`, so the ratio check is skipped.
#[cfg(feature = "tracing")]
fn log_rate_limit_tight(headers: &reqwest::header::HeaderMap, context: &str) {
    // `floor` is the absolute headroom threshold; the 5% ratio only applies
    // when a non-zero limit is known.
    let near_exhaustion = |remaining: u32, limit: u32, floor: u32| {
        remaining <= floor || (limit > 0 && remaining as f32 / limit as f32 <= 0.05)
    };

    let requests_left = header_u32(headers, "x-ratelimit-remaining-requests").unwrap_or(u32::MAX);
    let tokens_left = header_u32(headers, "x-ratelimit-remaining-tokens").unwrap_or(u32::MAX);
    let requests_cap = header_u32(headers, "x-ratelimit-limit-requests").unwrap_or(0);
    let tokens_cap = header_u32(headers, "x-ratelimit-limit-tokens").unwrap_or(0);

    let tight = near_exhaustion(requests_left, requests_cap, 2)
        || near_exhaustion(tokens_left, tokens_cap, 128);

    if !tight {
        tracing::debug!(
            context,
            remaining_requests = requests_left,
            limit_requests = requests_cap,
            remaining_tokens = tokens_left,
            limit_tokens = tokens_cap,
            "rate limit status"
        );
    } else {
        tracing::warn!(
            context,
            remaining_requests = requests_left,
            limit_requests = requests_cap,
            remaining_tokens = tokens_left,
            limit_tokens = tokens_cap,
            "rate limit headroom is tight"
        );
    }
}
102
/// Retry behaviour for transient failures (HTTP 429, 5xx, transport errors).
///
/// The wait before retry number `attempt` (0-based) is
/// `base_delay * 2^min(attempt, 10)`, capped at `max_delay`.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Maximum number of retries after the initial attempt (`0` disables retrying).
    pub max_retries: u32,
    /// Wait before the first retry; doubles for each subsequent retry.
    pub base_delay: Duration,
    /// Upper bound on the computed exponential backoff.
    pub max_delay: Duration,
    /// When `true`, a longer server-supplied `Retry-After` replaces the backoff.
    pub respect_retry_after: bool,
}

impl Default for RetryPolicy {
    /// Three retries, starting at 500 ms and capped at 30 s, honouring `Retry-After`.
    fn default() -> Self {
        RetryPolicy {
            respect_retry_after: true,
            max_delay: Duration::from_secs(30),
            base_delay: Duration::from_millis(500),
            max_retries: 3,
        }
    }
}

impl RetryPolicy {
    /// Exponential backoff for the given 0-based retry `attempt`.
    fn backoff_for(&self, attempt: u32) -> Duration {
        // Clamp the exponent so the multiplier never exceeds 2^10 (no shift overflow).
        let multiplier = 1u32 << attempt.min(10);
        self.base_delay
            .saturating_mul(multiplier)
            .min(self.max_delay)
    }
}
133
/// Timeouts used by [`OpenAiClient`]; `None` disables the corresponding timeout.
#[derive(Clone, Debug)]
pub struct HttpTimeoutConfig {
    /// Connection-establishment timeout.
    ///
    /// It is installed on the `reqwest::Client` built by
    /// `OpenAiClient::new_with_timeouts`. It is not applied to a
    /// caller-supplied client.
    pub connect_timeout: Option<Duration>,
    /// Whole-request deadline for non-streaming calls.
    pub request_timeout: Option<Duration>,
    /// Whole-request deadline for streaming chat completions.
    pub stream_timeout: Option<Duration>,
}

impl Default for HttpTimeoutConfig {
    /// 10 s to connect, 30 s per non-streaming request, and no streaming deadline.
    fn default() -> Self {
        let secs = Duration::from_secs;
        HttpTimeoutConfig {
            connect_timeout: Some(secs(10)),
            request_timeout: Some(secs(30)),
            stream_timeout: None,
        }
    }
}
155
156const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
157
158#[derive(Clone)]
165pub struct OpenAiClient {
166 api_key: String,
167 http: HttpClient,
168 base: String,
169 retry: RetryPolicy,
170 timeouts: HttpTimeoutConfig,
171}
172
173impl OpenAiClient {
174 pub fn new(api_key: impl Into<String>) -> Self {
176 Self::new_with_timeouts(api_key, HttpTimeoutConfig::default())
177 }
178
179 pub fn new_with_timeouts(api_key: impl Into<String>, timeouts: HttpTimeoutConfig) -> Self {
181 let mut builder = HttpClient::builder();
182 if let Some(connect_timeout) = timeouts.connect_timeout {
183 builder = builder.connect_timeout(connect_timeout);
184 }
185 let http = builder.build().expect("building reqwest client");
186
187 Self::with_http_and_timeouts(api_key, http, None, timeouts)
188 }
189
190 #[allow(dead_code)]
193 pub fn with_http(
194 api_key: impl Into<String>,
195 http: HttpClient,
196 base_url: Option<String>,
197 ) -> Self {
198 Self::with_http_and_timeouts(api_key, http, base_url, HttpTimeoutConfig::default())
199 }
200
201 pub fn with_http_and_timeouts(
203 api_key: impl Into<String>,
204 http: HttpClient,
205 base_url: Option<String>,
206 timeouts: HttpTimeoutConfig,
207 ) -> Self {
208 Self {
209 api_key: api_key.into(),
210 http,
211 base: base_url.unwrap_or_else(|| DEFAULT_BASE_URL.to_owned()),
212 retry: RetryPolicy::default(),
213 timeouts,
214 }
215 }
216
217 pub fn with_retry_policy(mut self, retry: RetryPolicy) -> Self {
219 self.retry = retry;
220 self
221 }
222
223 async fn post_json_with_retry(
225 &self,
226 url: String,
227 headers: HeaderMap,
228 request: &ChatCompletionRequest,
229 request_timeout: Option<Duration>,
230 ) -> Result<reqwest::Response, OpenAiError> {
231 let mut attempt: u32 = 0;
232 loop {
233 let mut req = self
234 .http
235 .post(url.clone())
236 .headers(headers.clone())
237 .json(request);
238 if let Some(timeout) = request_timeout {
239 req = req.timeout(timeout);
240 }
241 let res = req.send().await;
242
243 match res {
244 Ok(resp) => {
245 let status = resp.status();
246 if status.is_success() {
247 #[cfg(feature = "tracing")]
248 {
249 log_rate_limit_tight(resp.headers(), "success");
250 }
251 return Ok(resp);
252 }
253
254 let should_retry = status == reqwest::StatusCode::TOO_MANY_REQUESTS
255 || status.is_server_error();
256
257 if should_retry && attempt < self.retry.max_retries {
258 let mut delay = self.retry.backoff_for(attempt);
259 #[allow(unused_assignments)]
260 let mut hdr_delay = Duration::from_secs(0);
261 if self.retry.respect_retry_after {
262 hdr_delay = parse_retry_after_seconds(resp.headers());
263 if hdr_delay > delay {
264 delay = hdr_delay;
265 }
266 }
267 #[cfg(feature = "tracing")]
268 {
269 tracing::info!(
270 attempt = attempt,
271 status = %status,
272 backoff_ms = delay.as_millis() as u64,
273 retry_after_ms = hdr_delay.as_millis() as u64,
274 "retrying request due to transient status"
275 );
276 log_rate_limit_tight(resp.headers(), "retrying");
277 }
278 std::thread::sleep(delay);
280 attempt += 1;
281 continue;
282 } else {
283 let status = resp.status();
284 let headers_map = resp.headers().clone();
285 let body = resp.text().await.unwrap_or_default();
286 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
287 let (retry_after, reset_at, headers) =
288 extract_rate_limit_info(&headers_map);
289 #[cfg(feature = "tracing")]
290 {
291 let ra_ms = retry_after.map(|d| d.as_millis() as u64);
292 tracing::warn!(
293 status = %status,
294 retry_after_ms = ?ra_ms,
295 reset_at = ?reset_at,
296 "rate limited; giving up after retries"
297 );
298 }
299 return Err(OpenAiError::RateLimited {
300 status,
301 body,
302 retry_after,
303 reset_at,
304 headers,
305 });
306 } else {
307 return Err(OpenAiError::Api { status, body });
308 }
309 }
310 }
311 Err(err) => {
312 if attempt < self.retry.max_retries
314 && (err.is_timeout() || err.is_connect() || !err.is_status())
315 {
316 let delay = self.retry.backoff_for(attempt);
317 #[cfg(feature = "tracing")]
318 {
319 tracing::info!(
320 attempt = attempt,
321 backoff_ms = delay.as_millis() as u64,
322 "retrying after transport error"
323 );
324 }
325 std::thread::sleep(delay);
326 attempt += 1;
327 continue;
328 } else {
329 return Err(OpenAiError::Http(err));
330 }
331 }
332 }
333 }
334 }
335
336 pub async fn chat_completion(
338 &self,
339 request: ChatCompletionRequest,
340 ) -> Result<ChatCompletionResponse, OpenAiError> {
341 let mut headers = HeaderMap::new();
343 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
344 headers.insert(
345 AUTHORIZATION,
346 HeaderValue::from_str(&format!("Bearer {}", self.api_key)).unwrap(),
347 );
348
349 let url = format!("{}/chat/completions", self.base);
350 let resp = self
351 .post_json_with_retry(url, headers, &request, self.timeouts.request_timeout)
352 .await?;
353
354 let bytes = resp.bytes().await?;
355 let parsed: ChatCompletionResponse = serde_json::from_slice(&bytes)?;
356 Ok(parsed)
357 }
358
359 pub fn chat_completion_stream(
361 &self,
362 mut request: ChatCompletionRequest,
363 ) -> impl Stream<Item = Result<ChatCompletionChunkResponse, OpenAiError>> + '_ {
364 use reqwest::header::{ACCEPT, HeaderValue};
365
366 request.stream = Some(true);
368
369 let mut headers = HeaderMap::new();
371 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
372 headers.insert(
373 AUTHORIZATION,
374 HeaderValue::from_str(&format!("Bearer {}", self.api_key)).unwrap(),
375 );
376 headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));
377
378 let url = format!("{}/chat/completions", self.base);
379
380 try_stream! {
382 let resp = self
383 .post_json_with_retry(url, headers, &request, self.timeouts.stream_timeout)
384 .await?;
385
386 let mut bytes_stream = resp.bytes_stream();
387 let mut buf = Vec::new();
388
389 while let Some(chunk) = bytes_stream.next().await {
390 let chunk = chunk?;
391 buf.extend_from_slice(&chunk);
392
393 while let Some(pos) = buf.windows(2).position(|w| w == b"\n\n") {
394 let frame: Vec<u8> = buf.drain(..pos + 2).collect();
395 let frame_str = std::str::from_utf8(&frame)?;
396
397 if let Some(data) = frame_str.strip_prefix("data: ") {
398 let data = data.trim();
399 if data == "[DONE]" { return; }
400
401 let parsed: ChatCompletionChunkResponse = serde_json::from_str(data)?;
402 yield parsed;
403 }
404 }
405 }
406 }
407 }
408
409 pub async fn audio_transcription(
411 &self,
412 request: TranscriptionRequest,
413 ) -> Result<TranscriptionResult, OpenAiError> {
414 if request.audio.is_empty() {
415 return Err(OpenAiError::Format(
416 "audio payload must not be empty".into(),
417 ));
418 }
419 if request.mime_type.trim().is_empty() {
420 return Err(OpenAiError::Format("mime_type must not be empty".into()));
421 }
422
423 use reqwest::multipart::{Form, Part};
424 let mut headers = HeaderMap::new();
425 headers.insert(
426 AUTHORIZATION,
427 HeaderValue::from_str(&format!("Bearer {}", self.api_key)).unwrap(),
428 );
429
430 let filename = request.filename.unwrap_or_else(|| "audio.wav".to_string());
431 let file_part = Part::bytes(request.audio)
432 .file_name(filename)
433 .mime_str(&request.mime_type)
434 .map_err(|e| OpenAiError::Format(format!("invalid mime type: {e}")))?;
435
436 let mut form = Form::new().part("file", file_part).text(
437 "model",
438 request
439 .model
440 .unwrap_or_else(|| "gpt-4o-mini-transcribe".to_string()),
441 );
442
443 if let Some(language) = request.language {
444 form = form.text("language", language);
445 }
446 if let Some(prompt) = request.prompt {
447 form = form.text("prompt", prompt);
448 }
449
450 let url = format!("{}/audio/transcriptions", self.base);
451 let mut req = self.http.post(url).headers(headers).multipart(form);
452 if let Some(timeout) = self.timeouts.request_timeout {
453 req = req.timeout(timeout);
454 }
455 let resp = req.send().await?;
456
457 if !resp.status().is_success() {
458 let status = resp.status();
459 let headers_map = resp.headers().clone();
460 let body = resp.text().await.unwrap_or_default();
461 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
462 let (retry_after, reset_at, headers) = extract_rate_limit_info(&headers_map);
463 return Err(OpenAiError::RateLimited {
464 status,
465 body,
466 retry_after,
467 reset_at,
468 headers,
469 });
470 }
471 return Err(OpenAiError::Api { status, body });
472 }
473
474 let bytes = resp.bytes().await?;
475 let parsed: AudioTranscriptionResponse = serde_json::from_slice(&bytes)?;
476 Ok(parsed.into())
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        io::{Read, Write},
        net::TcpListener,
        thread,
    };

    use crate::api_v1::{ChatCompletionMessage, Content, MessageRole};

    /// Non-streaming chat completion body returned by the fake server.
    const CHAT_COMPLETION_JSON: &str = r#"{"id":"x","object":"chat.completion","created":0,"model":"gpt-4o-mini","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop","finish_details":null}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2},"system_fingerprint":null}"#;

    /// Single streaming chunk embedded in the SSE body.
    const CHAT_CHUNK_JSON: &str = r#"{"id":"x","object":"chat.completion.chunk","created":0,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":"stop"}]}"#;

    /// Transcription body returned by the fake server.
    const TRANSCRIPTION_JSON: &str = r#"{"text":"hello world","language":"en","duration":1.25}"#;

    /// Spawns a one-shot HTTP/1.1 server on an ephemeral localhost port.
    ///
    /// The server thread:
    /// - accepts one connection;
    /// - performs a single read of up to 8 KiB (2 s read timeout);
    /// - waits `delay`;
    /// - writes a `200 OK` carrying `body` as `content_type`, then closes.
    ///
    /// Returns the base URL `http://127.0.0.1:<port>`.
    fn run_single_response_server(delay: Duration, body: String, content_type: &str) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind tcp listener");
        let addr = listener.local_addr().expect("listener addr");
        let payload = format!(
            "HTTP/1.1 200 OK\r\ncontent-type: {content_type}\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
            body.len()
        );

        thread::spawn(move || {
            let (mut conn, _) = listener.accept().expect("accept connection");
            let _ = conn.set_read_timeout(Some(Duration::from_secs(2)));
            // Only one read is performed; assumes the test requests are small.
            let mut scratch = [0_u8; 8192];
            let _ = conn.read(&mut scratch);
            thread::sleep(delay);
            conn.write_all(payload.as_bytes()).expect("write response");
            let _ = conn.flush();
        });

        format!("http://{addr}")
    }

    /// Client pointed at `base_url` with the given timeouts and retries disabled.
    fn no_retry_client(base_url: String, timeouts: HttpTimeoutConfig) -> OpenAiClient {
        OpenAiClient::with_http_and_timeouts(
            "test-key",
            reqwest::Client::new(),
            Some(base_url),
            timeouts,
        )
        .with_retry_policy(RetryPolicy {
            max_retries: 0,
            ..RetryPolicy::default()
        })
    }

    /// Minimal one-message chat request.
    fn sample_request() -> ChatCompletionRequest {
        let greeting = ChatCompletionMessage {
            role: MessageRole::User,
            content: Some(Content::Text("hello".to_string())),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        };
        ChatCompletionRequest::new("gpt-4o-mini".to_string(), vec![greeting])
    }

    #[tokio::test]
    async fn non_streaming_respects_request_timeout() {
        let base_url = run_single_response_server(
            Duration::from_millis(200),
            CHAT_COMPLETION_JSON.to_string(),
            "application/json",
        );
        let client = no_retry_client(
            base_url,
            HttpTimeoutConfig {
                connect_timeout: Some(Duration::from_secs(1)),
                request_timeout: Some(Duration::from_millis(50)),
                stream_timeout: None,
            },
        );

        let outcome = client.chat_completion(sample_request()).await;
        let err = outcome.expect_err("non-stream request should timeout");
        if let OpenAiError::Http(inner) = &err {
            assert!(inner.is_timeout());
        } else {
            panic!("unexpected error: {err:?}");
        }
    }

    #[tokio::test]
    async fn streaming_uses_stream_timeout_not_request_timeout() {
        let sse_body = format!("data: {}\n\ndata: [DONE]\n\n", CHAT_CHUNK_JSON);
        let base_url =
            run_single_response_server(Duration::from_millis(120), sse_body, "text/event-stream");
        let client = no_retry_client(
            base_url,
            HttpTimeoutConfig {
                connect_timeout: Some(Duration::from_secs(1)),
                request_timeout: Some(Duration::from_millis(10)),
                stream_timeout: None,
            },
        );

        let mut chunks = Box::pin(client.chat_completion_stream(sample_request()));
        let first_item = chunks
            .next()
            .await
            .expect("stream should produce first chunk");
        let first = first_item.expect("first chunk should parse");
        assert_eq!(first.choices.len(), 1);
    }

    #[tokio::test]
    async fn audio_transcription_parses_text_response() {
        let base_url = run_single_response_server(
            Duration::from_millis(0),
            TRANSCRIPTION_JSON.to_string(),
            "application/json",
        );
        let client = no_retry_client(base_url, HttpTimeoutConfig::default());

        let request = TranscriptionRequest::new(vec![1, 2, 3], "audio/wav")
            .with_filename("clip.wav")
            .with_model("gpt-4o-mini-transcribe");
        let result = client
            .audio_transcription(request)
            .await
            .expect("transcription should succeed");

        assert_eq!(result.text, "hello world");
        assert_eq!(result.language.as_deref(), Some("en"));
        assert_eq!(result.duration_seconds, Some(1.25));
    }

    #[tokio::test]
    async fn audio_transcription_rejects_empty_audio() {
        let client = OpenAiClient::new("test-key");
        let outcome = client
            .audio_transcription(TranscriptionRequest::new(Vec::new(), "audio/wav"))
            .await;
        let err = outcome.expect_err("empty audio should fail validation");

        if let OpenAiError::Format(msg) = &err {
            assert!(msg.contains("audio payload"));
        } else {
            panic!("unexpected error: {err:?}");
        }
    }
}