Skip to main content

artificial_openai/
client.rs

1use async_stream::try_stream;
2
3use futures_core::Stream;
4use futures_util::StreamExt;
5use reqwest::{
6    Client as HttpClient,
7    header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue},
8};
9use std::time::Duration;
10
11use artificial_core::provider::{TranscriptionRequest, TranscriptionResult};
12
13use crate::{
14    api_v1::{
15        AudioTranscriptionResponse, ChatCompletionChunkResponse, ChatCompletionRequest,
16        ChatCompletionResponse,
17    },
18    error::{OpenAiError, OpenAiRateLimitHeaders},
19};
20
21fn parse_retry_after_seconds(headers: &reqwest::header::HeaderMap) -> Duration {
22    use reqwest::header::RETRY_AFTER;
23    if let Some(val) = headers.get(RETRY_AFTER).and_then(|hv| hv.to_str().ok())
24        && let Ok(secs) = val.trim().parse::<u64>()
25    {
26        return Duration::from_secs(secs);
27    }
28    Duration::from_secs(0)
29}
30
31fn header_u32(headers: &reqwest::header::HeaderMap, name: &str) -> Option<u32> {
32    headers
33        .get(name)
34        .and_then(|hv| hv.to_str().ok())
35        .and_then(|s| s.parse::<u32>().ok())
36}
37
38fn header_string(headers: &reqwest::header::HeaderMap, name: &str) -> Option<String> {
39    headers
40        .get(name)
41        .and_then(|hv| hv.to_str().ok())
42        .map(|s| s.to_string())
43}
44
45fn extract_rate_limit_info(
46    headers: &reqwest::header::HeaderMap,
47) -> (Option<Duration>, Option<String>, OpenAiRateLimitHeaders) {
48    let retry_after = {
49        let d = parse_retry_after_seconds(headers);
50        if d.as_secs() > 0 { Some(d) } else { None }
51    };
52
53    let info = OpenAiRateLimitHeaders {
54        limit_requests: header_u32(headers, "x-ratelimit-limit-requests"),
55        remaining_requests: header_u32(headers, "x-ratelimit-remaining-requests"),
56        reset_requests: header_string(headers, "x-ratelimit-reset-requests"),
57        limit_tokens: header_u32(headers, "x-ratelimit-limit-tokens"),
58        remaining_tokens: header_u32(headers, "x-ratelimit-remaining-tokens"),
59        reset_tokens: header_string(headers, "x-ratelimit-reset-tokens"),
60    };
61
62    // Prefer request reset, fall back to token reset.
63    let reset_at = info
64        .reset_requests
65        .clone()
66        .or_else(|| info.reset_tokens.clone());
67
68    (retry_after, reset_at, info)
69}
/// Log the rate-limit headroom reported by the `x-ratelimit-*` response headers.
///
/// Emits `warn!` when headroom is tight and `debug!` otherwise. Headroom is
/// tight when at most 2 requests or 128 tokens remain, or when the remainder is
/// at most 5% of the advertised limit.
///
/// Headers that are absent or unparseable are left out of the log fields rather
/// than being reported as placeholder numbers. They never count as "tight".
/// `context` is recorded as a field to tell call sites apart.
#[cfg(feature = "tracing")]
fn log_rate_limit_tight(headers: &reqwest::header::HeaderMap, context: &str) {
    /// `remaining` is tight when it is at or below `floor`, or at or below 5% of a
    /// known non-zero `limit`. The check is written as `remaining * 20 <= limit`
    /// in `u64` to stay exact and overflow-free.
    fn is_tight(remaining: Option<u32>, limit: Option<u32>, floor: u32) -> bool {
        let Some(remaining) = remaining else {
            return false;
        };
        remaining <= floor
            || limit.is_some_and(|limit| {
                limit > 0 && u64::from(remaining) * 20 <= u64::from(limit)
            })
    }

    let remaining_requests = header_u32(headers, "x-ratelimit-remaining-requests");
    let remaining_tokens = header_u32(headers, "x-ratelimit-remaining-tokens");
    let limit_requests = header_u32(headers, "x-ratelimit-limit-requests");
    let limit_tokens = header_u32(headers, "x-ratelimit-limit-tokens");

    let tight = is_tight(remaining_requests, limit_requests, 2)
        || is_tight(remaining_tokens, limit_tokens, 128);

    // `Option<u32>` fields are recorded only when `Some`, so missing headers
    // simply do not appear in the event.
    if tight {
        tracing::warn!(
            context,
            remaining_requests,
            limit_requests,
            remaining_tokens,
            limit_tokens,
            "rate limit headroom is tight"
        );
    } else {
        tracing::debug!(
            context,
            remaining_requests,
            limit_requests,
            remaining_tokens,
            limit_tokens,
            "rate limit status"
        );
    }
}
102
/// Retry behaviour for transient failures (HTTP 429, HTTP 5xx, transport errors).
///
/// A request is attempted at most `max_retries + 1` times. Before retry number
/// `attempt` (0-based) the client waits `base_delay * 2^min(attempt, 10)`,
/// capped at `max_delay`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Number of retries after the initial attempt (`0` disables retrying).
    pub max_retries: u32,
    /// Delay before the first retry; doubled for each further retry.
    pub base_delay: Duration,
    /// Upper bound for the computed exponential backoff.
    ///
    /// A server-provided `Retry-After` value (see `respect_retry_after`) is not
    /// subject to this cap.
    pub max_delay: Duration,
    /// When `true`, a `Retry-After` header longer than the computed backoff
    /// replaces it.
    pub respect_retry_after: bool,
}

impl Default for RetryPolicy {
    /// Three retries, starting at 500 ms and capped at 30 s, honouring `Retry-After`.
    fn default() -> Self {
        Self {
            max_retries: 3,
            base_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(30),
            respect_retry_after: true,
        }
    }
}

impl RetryPolicy {
    /// Exponential backoff to wait before retry number `attempt` (0-based).
    ///
    /// The exponent is clamped to 10 so the multiplier never exceeds 1024. The
    /// product saturates instead of overflowing, and the result never exceeds
    /// `max_delay`.
    fn backoff_for(&self, attempt: u32) -> Duration {
        let multiplier: u32 = 1 << attempt.min(10);
        self.base_delay
            .saturating_mul(multiplier)
            .min(self.max_delay)
    }
}
133
/// Timeouts applied to the HTTP requests made by `OpenAiClient`.
///
/// `None` in a field means this crate sets no timeout for that phase. A
/// caller-supplied `reqwest::Client` may still enforce its own.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HttpTimeoutConfig {
    /// TCP/TLS connection timeout.
    ///
    /// Only applied when the client builds its own `reqwest::Client`
    /// (`OpenAiClient::new` / `new_with_timeouts`). A client passed in by the
    /// caller keeps its own connection settings.
    pub connect_timeout: Option<Duration>,
    /// Total timeout for non-streaming requests (chat completions and audio
    /// transcriptions).
    pub request_timeout: Option<Duration>,
    /// Total timeout for streaming requests.
    ///
    /// `None` keeps streams open indefinitely (default).
    pub stream_timeout: Option<Duration>,
}

impl Default for HttpTimeoutConfig {
    /// 10 s connect timeout, 30 s request timeout, and no stream timeout.
    fn default() -> Self {
        Self {
            connect_timeout: Some(Duration::from_secs(10)),
            request_timeout: Some(Duration::from_secs(30)),
            stream_timeout: None,
        }
    }
}
155
156const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
157
158/// Minimal HTTP client for OpenAI’s *chat/completions* endpoint.
159///
160/// * Non-streaming only (one request ▶ one response).
161/// * Accepts and returns the `api_v1` request / response structs defined
162///   in this crate.
163/// * Shares a single `reqwest::Client`, so cloning `OpenAiClient` is cheap.
164#[derive(Clone)]
165pub struct OpenAiClient {
166    api_key: String,
167    http: HttpClient,
168    base: String,
169    retry: RetryPolicy,
170    timeouts: HttpTimeoutConfig,
171}
172
173impl OpenAiClient {
174    /// Convenience constructor building a default `reqwest` client.
175    pub fn new(api_key: impl Into<String>) -> Self {
176        Self::new_with_timeouts(api_key, HttpTimeoutConfig::default())
177    }
178
179    /// Convenience constructor with explicit timeout configuration.
180    pub fn new_with_timeouts(api_key: impl Into<String>, timeouts: HttpTimeoutConfig) -> Self {
181        let mut builder = HttpClient::builder();
182        if let Some(connect_timeout) = timeouts.connect_timeout {
183            builder = builder.connect_timeout(connect_timeout);
184        }
185        let http = builder.build().expect("building reqwest client");
186
187        Self::with_http_and_timeouts(api_key, http, None, timeouts)
188    }
189
190    /// Build with a custom `reqwest::Client` in case the caller needs proxy
191    /// settings, custom TLS, etc.
192    #[allow(dead_code)]
193    pub fn with_http(
194        api_key: impl Into<String>,
195        http: HttpClient,
196        base_url: Option<String>,
197    ) -> Self {
198        Self::with_http_and_timeouts(api_key, http, base_url, HttpTimeoutConfig::default())
199    }
200
201    /// Build with a custom `reqwest::Client` and timeout configuration.
202    pub fn with_http_and_timeouts(
203        api_key: impl Into<String>,
204        http: HttpClient,
205        base_url: Option<String>,
206        timeouts: HttpTimeoutConfig,
207    ) -> Self {
208        Self {
209            api_key: api_key.into(),
210            http,
211            base: base_url.unwrap_or_else(|| DEFAULT_BASE_URL.to_owned()),
212            retry: RetryPolicy::default(),
213            timeouts,
214        }
215    }
216
217    /// Allow callers to override the default retry policy.
218    pub fn with_retry_policy(mut self, retry: RetryPolicy) -> Self {
219        self.retry = retry;
220        self
221    }
222
223    // Internal: send POST with retry/backoff handling.
224    async fn post_json_with_retry(
225        &self,
226        url: String,
227        headers: HeaderMap,
228        request: &ChatCompletionRequest,
229        request_timeout: Option<Duration>,
230    ) -> Result<reqwest::Response, OpenAiError> {
231        let mut attempt: u32 = 0;
232        loop {
233            let mut req = self
234                .http
235                .post(url.clone())
236                .headers(headers.clone())
237                .json(request);
238            if let Some(timeout) = request_timeout {
239                req = req.timeout(timeout);
240            }
241            let res = req.send().await;
242
243            match res {
244                Ok(resp) => {
245                    let status = resp.status();
246                    if status.is_success() {
247                        #[cfg(feature = "tracing")]
248                        {
249                            log_rate_limit_tight(resp.headers(), "success");
250                        }
251                        return Ok(resp);
252                    }
253
254                    let should_retry = status == reqwest::StatusCode::TOO_MANY_REQUESTS
255                        || status.is_server_error();
256
257                    if should_retry && attempt < self.retry.max_retries {
258                        let mut delay = self.retry.backoff_for(attempt);
259                        #[allow(unused_assignments)]
260                        let mut hdr_delay = Duration::from_secs(0);
261                        if self.retry.respect_retry_after {
262                            hdr_delay = parse_retry_after_seconds(resp.headers());
263                            if hdr_delay > delay {
264                                delay = hdr_delay;
265                            }
266                        }
267                        #[cfg(feature = "tracing")]
268                        {
269                            tracing::info!(
270                                attempt = attempt,
271                                status = %status,
272                                backoff_ms = delay.as_millis() as u64,
273                                retry_after_ms = hdr_delay.as_millis() as u64,
274                                "retrying request due to transient status"
275                            );
276                            log_rate_limit_tight(resp.headers(), "retrying");
277                        }
278                        // Blocking sleep to avoid introducing a new async runtime dependency.
279                        std::thread::sleep(delay);
280                        attempt += 1;
281                        continue;
282                    } else {
283                        let status = resp.status();
284                        let headers_map = resp.headers().clone();
285                        let body = resp.text().await.unwrap_or_default();
286                        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
287                            let (retry_after, reset_at, headers) =
288                                extract_rate_limit_info(&headers_map);
289                            #[cfg(feature = "tracing")]
290                            {
291                                let ra_ms = retry_after.map(|d| d.as_millis() as u64);
292                                tracing::warn!(
293                                    status = %status,
294                                    retry_after_ms = ?ra_ms,
295                                    reset_at = ?reset_at,
296                                    "rate limited; giving up after retries"
297                                );
298                            }
299                            return Err(OpenAiError::RateLimited {
300                                status,
301                                body,
302                                retry_after,
303                                reset_at,
304                                headers,
305                            });
306                        } else {
307                            return Err(OpenAiError::Api { status, body });
308                        }
309                    }
310                }
311                Err(err) => {
312                    // Retry on transport errors up to max_retries.
313                    if attempt < self.retry.max_retries
314                        && (err.is_timeout() || err.is_connect() || !err.is_status())
315                    {
316                        let delay = self.retry.backoff_for(attempt);
317                        #[cfg(feature = "tracing")]
318                        {
319                            tracing::info!(
320                                attempt = attempt,
321                                backoff_ms = delay.as_millis() as u64,
322                                "retrying after transport error"
323                            );
324                        }
325                        std::thread::sleep(delay);
326                        attempt += 1;
327                        continue;
328                    } else {
329                        return Err(OpenAiError::Http(err));
330                    }
331                }
332            }
333        }
334    }
335
336    /// Perform a **non-streaming** chat completion.
337    pub async fn chat_completion(
338        &self,
339        request: ChatCompletionRequest,
340    ) -> Result<ChatCompletionResponse, OpenAiError> {
341        // Build headers once.
342        let mut headers = HeaderMap::new();
343        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
344        headers.insert(
345            AUTHORIZATION,
346            HeaderValue::from_str(&format!("Bearer {}", self.api_key)).unwrap(),
347        );
348
349        let url = format!("{}/chat/completions", self.base);
350        let resp = self
351            .post_json_with_retry(url, headers, &request, self.timeouts.request_timeout)
352            .await?;
353
354        let bytes = resp.bytes().await?;
355        let parsed: ChatCompletionResponse = serde_json::from_slice(&bytes)?;
356        Ok(parsed)
357    }
358
359    /// Perform a **streaming** chat completion.
360    pub fn chat_completion_stream(
361        &self,
362        mut request: ChatCompletionRequest,
363    ) -> impl Stream<Item = Result<ChatCompletionChunkResponse, OpenAiError>> + '_ {
364        use reqwest::header::{ACCEPT, HeaderValue};
365
366        // 1) enforce streaming flag
367        request.stream = Some(true);
368
369        // 2) headers (incl. SSE accept)
370        let mut headers = HeaderMap::new();
371        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
372        headers.insert(
373            AUTHORIZATION,
374            HeaderValue::from_str(&format!("Bearer {}", self.api_key)).unwrap(),
375        );
376        headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));
377
378        let url = format!("{}/chat/completions", self.base);
379
380        // 3) async stream wrapper
381        try_stream! {
382            let resp = self
383                .post_json_with_retry(url, headers, &request, self.timeouts.stream_timeout)
384                .await?;
385
386            let mut bytes_stream = resp.bytes_stream();
387            let mut buf = Vec::new();
388
389            while let Some(chunk) = bytes_stream.next().await {
390                let chunk = chunk?;
391                buf.extend_from_slice(&chunk);
392
393                while let Some(pos) = buf.windows(2).position(|w| w == b"\n\n") {
394                    let frame: Vec<u8> = buf.drain(..pos + 2).collect();
395                    let frame_str = std::str::from_utf8(&frame)?;
396
397                    if let Some(data) = frame_str.strip_prefix("data: ") {
398                        let data = data.trim();
399                        if data == "[DONE]" { return; }
400
401                        let parsed: ChatCompletionChunkResponse = serde_json::from_str(data)?;
402                        yield parsed;
403                    }
404                }
405            }
406        }
407    }
408
409    /// Perform an audio transcription via OpenAI `/audio/transcriptions`.
410    pub async fn audio_transcription(
411        &self,
412        request: TranscriptionRequest,
413    ) -> Result<TranscriptionResult, OpenAiError> {
414        if request.audio.is_empty() {
415            return Err(OpenAiError::Format(
416                "audio payload must not be empty".into(),
417            ));
418        }
419        if request.mime_type.trim().is_empty() {
420            return Err(OpenAiError::Format("mime_type must not be empty".into()));
421        }
422
423        use reqwest::multipart::{Form, Part};
424        let mut headers = HeaderMap::new();
425        headers.insert(
426            AUTHORIZATION,
427            HeaderValue::from_str(&format!("Bearer {}", self.api_key)).unwrap(),
428        );
429
430        let filename = request.filename.unwrap_or_else(|| "audio.wav".to_string());
431        let file_part = Part::bytes(request.audio)
432            .file_name(filename)
433            .mime_str(&request.mime_type)
434            .map_err(|e| OpenAiError::Format(format!("invalid mime type: {e}")))?;
435
436        let mut form = Form::new().part("file", file_part).text(
437            "model",
438            request
439                .model
440                .unwrap_or_else(|| "gpt-4o-mini-transcribe".to_string()),
441        );
442
443        if let Some(language) = request.language {
444            form = form.text("language", language);
445        }
446        if let Some(prompt) = request.prompt {
447            form = form.text("prompt", prompt);
448        }
449
450        let url = format!("{}/audio/transcriptions", self.base);
451        let mut req = self.http.post(url).headers(headers).multipart(form);
452        if let Some(timeout) = self.timeouts.request_timeout {
453            req = req.timeout(timeout);
454        }
455        let resp = req.send().await?;
456
457        if !resp.status().is_success() {
458            let status = resp.status();
459            let headers_map = resp.headers().clone();
460            let body = resp.text().await.unwrap_or_default();
461            if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
462                let (retry_after, reset_at, headers) = extract_rate_limit_info(&headers_map);
463                return Err(OpenAiError::RateLimited {
464                    status,
465                    body,
466                    retry_after,
467                    reset_at,
468                    headers,
469                });
470            }
471            return Err(OpenAiError::Api { status, body });
472        }
473
474        let bytes = resp.bytes().await?;
475        let parsed: AudioTranscriptionResponse = serde_json::from_slice(&bytes)?;
476        Ok(parsed.into())
477    }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        io::{Read, Write},
        net::TcpListener,
        thread,
    };

    use crate::api_v1::{ChatCompletionMessage, Content, MessageRole};

    /// Start a one-shot HTTP/1.1 server on an ephemeral localhost port.
    ///
    /// The server accepts a single connection and reads (at most 8 KiB of) the
    /// request. It then waits `delay` and replies `200 OK` with `body` and the
    /// given `content_type`. Returns the base URL (`http://127.0.0.1:<port>`).
    fn run_single_response_server(delay: Duration, body: String, content_type: &str) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind tcp listener");
        let addr = listener.local_addr().expect("listener addr");
        let response = format!(
            "HTTP/1.1 200 OK\r\ncontent-type: {content_type}\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
            body.len()
        );

        thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept connection");
            let _ = stream.set_read_timeout(Some(Duration::from_secs(2)));
            let mut req_buf = [0_u8; 8192];
            let _ = stream.read(&mut req_buf);
            thread::sleep(delay);
            // In the timeout tests the client has already hung up by now, so a
            // failed write is expected and must not panic this detached thread.
            let _ = stream.write_all(response.as_bytes());
            let _ = stream.flush();
        });

        format!("http://{addr}")
    }

    /// A minimal single-message chat request.
    fn sample_request() -> ChatCompletionRequest {
        ChatCompletionRequest::new(
            "gpt-4o-mini".to_string(),
            vec![ChatCompletionMessage {
                role: MessageRole::User,
                content: Some(Content::Text("hello".to_string())),
                name: None,
                tool_calls: None,
                tool_call_id: None,
            }],
        )
    }

    #[tokio::test]
    async fn non_streaming_respects_request_timeout() {
        let base_url = run_single_response_server(
            Duration::from_millis(200),
            r#"{"id":"x","object":"chat.completion","created":0,"model":"gpt-4o-mini","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop","finish_details":null}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2},"system_fingerprint":null}"#.to_string(),
            "application/json",
        );

        let client = OpenAiClient::with_http_and_timeouts(
            "test-key",
            reqwest::Client::new(),
            Some(base_url),
            HttpTimeoutConfig {
                connect_timeout: Some(Duration::from_secs(1)),
                request_timeout: Some(Duration::from_millis(50)),
                stream_timeout: None,
            },
        )
        .with_retry_policy(RetryPolicy {
            max_retries: 0,
            ..RetryPolicy::default()
        });

        let err = client
            .chat_completion(sample_request())
            .await
            .expect_err("non-stream request should timeout");
        match err {
            OpenAiError::Http(inner) => assert!(inner.is_timeout()),
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[tokio::test]
    async fn streaming_uses_stream_timeout_not_request_timeout() {
        let sse_body = format!(
            "data: {}\n\ndata: [DONE]\n\n",
            r#"{"id":"x","object":"chat.completion.chunk","created":0,"model":"gpt-4o-mini","choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":"stop"}]}"#
        );
        let base_url =
            run_single_response_server(Duration::from_millis(120), sse_body, "text/event-stream");

        let client = OpenAiClient::with_http_and_timeouts(
            "test-key",
            reqwest::Client::new(),
            Some(base_url),
            HttpTimeoutConfig {
                connect_timeout: Some(Duration::from_secs(1)),
                request_timeout: Some(Duration::from_millis(10)),
                stream_timeout: None,
            },
        )
        .with_retry_policy(RetryPolicy {
            max_retries: 0,
            ..RetryPolicy::default()
        });

        let mut stream = Box::pin(client.chat_completion_stream(sample_request()));
        let first = stream
            .next()
            .await
            .expect("stream should produce first chunk")
            .expect("first chunk should parse");
        assert_eq!(first.choices.len(), 1);
    }

    #[tokio::test]
    async fn audio_transcription_parses_text_response() {
        let base_url = run_single_response_server(
            Duration::from_millis(0),
            r#"{"text":"hello world","language":"en","duration":1.25}"#.to_string(),
            "application/json",
        );

        let client = OpenAiClient::with_http_and_timeouts(
            "test-key",
            reqwest::Client::new(),
            Some(base_url),
            HttpTimeoutConfig::default(),
        )
        .with_retry_policy(RetryPolicy {
            max_retries: 0,
            ..RetryPolicy::default()
        });

        let result = client
            .audio_transcription(
                TranscriptionRequest::new(vec![1, 2, 3], "audio/wav")
                    .with_filename("clip.wav")
                    .with_model("gpt-4o-mini-transcribe"),
            )
            .await
            .expect("transcription should succeed");

        assert_eq!(result.text, "hello world");
        assert_eq!(result.language.as_deref(), Some("en"));
        assert_eq!(result.duration_seconds, Some(1.25));
    }

    #[tokio::test]
    async fn audio_transcription_rejects_empty_audio() {
        let client = OpenAiClient::new("test-key");
        let err = client
            .audio_transcription(TranscriptionRequest::new(Vec::new(), "audio/wav"))
            .await
            .expect_err("empty audio should fail validation");

        match err {
            OpenAiError::Format(msg) => assert!(msg.contains("audio payload")),
            other => panic!("unexpected error: {other:?}"),
        }
    }
}