Skip to main content

codineer_api/providers/
openai_compat.rs

1use std::collections::{BTreeMap, VecDeque};
2
3use serde::Deserialize;
4use serde_json::{json, Value};
5
6use crate::error::ApiError;
7use crate::providers::RetryPolicy;
8use crate::types::{
9    ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
10    InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest,
11    MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
12    ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
13};
14
/// Default API root for xAI.
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
/// Default API root for OpenAI.
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
/// Response header consulted first for a server-assigned request id.
const REQUEST_ID_HEADER: &str = "request-id";
/// Fallback response header for the request id.
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";

/// Static description of an OpenAI-compatible provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OpenAiCompatConfig {
    /// Human-readable provider name (reported in missing-credential errors).
    pub provider_name: &'static str,
    /// Environment variable holding the API key.
    pub api_key_env: &'static str,
    /// Environment variable that may override the base URL.
    pub base_url_env: &'static str,
    /// Base URL used when no override is configured.
    pub default_base_url: &'static str,
}

/// Credential variables reported for the xAI provider.
const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"];
/// Credential variables reported for the OpenAI provider.
const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"];

impl OpenAiCompatConfig {
    /// Shared constructor used by the provider presets below.
    const fn preset(
        provider_name: &'static str,
        api_key_env: &'static str,
        base_url_env: &'static str,
        default_base_url: &'static str,
    ) -> Self {
        Self {
            default_base_url,
            base_url_env,
            api_key_env,
            provider_name,
        }
    }

    /// Preset for xAI (`XAI_API_KEY`, `XAI_BASE_URL`).
    #[must_use]
    pub const fn xai() -> Self {
        Self::preset("xAI", "XAI_API_KEY", "XAI_BASE_URL", DEFAULT_XAI_BASE_URL)
    }

    /// Preset for OpenAI (`OPENAI_API_KEY`, `OPENAI_BASE_URL`).
    #[must_use]
    pub const fn openai() -> Self {
        Self::preset(
            "OpenAI",
            "OPENAI_API_KEY",
            "OPENAI_BASE_URL",
            DEFAULT_OPENAI_BASE_URL,
        )
    }

    /// Environment variables that can supply credentials for this provider.
    ///
    /// The lookup is keyed on `api_key_env`; any value other than `XAI_API_KEY` or
    /// `OPENAI_API_KEY` yields an empty slice.
    #[must_use]
    pub fn credential_env_vars(self) -> &'static [&'static str] {
        [XAI_ENV_VARS, OPENAI_ENV_VARS]
            .iter()
            .copied()
            .find(|vars| vars.contains(&self.api_key_env))
            .unwrap_or(&[])
    }
}
60
61#[derive(Debug, Clone)]
62pub struct OpenAiCompatClient {
63    http: reqwest::Client,
64    api_key: String,
65    base_url: String,
66    retry: RetryPolicy,
67}
68
69impl OpenAiCompatClient {
70    #[must_use]
71    pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
72        Self {
73            http: reqwest::Client::new(),
74            api_key: api_key.into(),
75            base_url: read_base_url(config),
76            retry: RetryPolicy::default(),
77        }
78    }
79
80    pub fn from_env(config: OpenAiCompatConfig) -> Result<Self, ApiError> {
81        let Some(api_key) = read_env_non_empty(config.api_key_env)? else {
82            return Err(ApiError::missing_credentials(
83                config.provider_name,
84                config.credential_env_vars(),
85            ));
86        };
87        Ok(Self::new(api_key, config))
88    }
89
90    #[must_use]
91    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
92        self.base_url = base_url.into();
93        self
94    }
95
96    #[must_use]
97    pub fn with_retry_policy(mut self, retry: RetryPolicy) -> Self {
98        self.retry = retry;
99        self
100    }
101
102    pub async fn send_message(
103        &self,
104        request: &MessageRequest,
105    ) -> Result<MessageResponse, ApiError> {
106        let request = MessageRequest {
107            stream: false,
108            ..request.clone()
109        };
110        let response = self.send_with_retry(&request).await?;
111        let request_id = request_id_from_headers(response.headers());
112        let payload = response.json::<ChatCompletionResponse>().await?;
113        let mut normalized = normalize_response(&request.model, payload)?;
114        if normalized.request_id.is_none() {
115            normalized.request_id = request_id;
116        }
117        Ok(normalized)
118    }
119
120    pub async fn stream_message(
121        &self,
122        request: &MessageRequest,
123    ) -> Result<MessageStream, ApiError> {
124        let response = self
125            .send_with_retry(&request.clone().with_streaming())
126            .await?;
127        Ok(MessageStream {
128            request_id: request_id_from_headers(response.headers()),
129            response,
130            parser: OpenAiSseParser::new(),
131            pending: VecDeque::new(),
132            done: false,
133            state: StreamState::new(request.model.clone()),
134        })
135    }
136
137    async fn send_with_retry(
138        &self,
139        request: &MessageRequest,
140    ) -> Result<reqwest::Response, ApiError> {
141        let mut attempts = 0;
142
143        let last_error = loop {
144            attempts += 1;
145            let retryable_error = match self.send_raw_request(request).await {
146                Ok(response) => match expect_success(response).await {
147                    Ok(response) => return Ok(response),
148                    Err(error)
149                        if error.is_retryable() && attempts <= self.retry.max_retries + 1 =>
150                    {
151                        error
152                    }
153                    Err(error) => return Err(error),
154                },
155                Err(error) if error.is_retryable() && attempts <= self.retry.max_retries + 1 => {
156                    error
157                }
158                Err(error) => return Err(error),
159            };
160
161            if attempts > self.retry.max_retries {
162                break retryable_error;
163            }
164
165            tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
166        };
167
168        Err(ApiError::RetriesExhausted {
169            attempts,
170            last_error: Box::new(last_error),
171        })
172    }
173
174    async fn send_raw_request(
175        &self,
176        request: &MessageRequest,
177    ) -> Result<reqwest::Response, ApiError> {
178        let request_url = chat_completions_endpoint(&self.base_url);
179        self.http
180            .post(&request_url)
181            .header("content-type", "application/json")
182            .bearer_auth(&self.api_key)
183            .json(&build_chat_completion_request(request))
184            .send()
185            .await
186            .map_err(ApiError::from)
187    }
188
189    fn backoff_for_attempt(&self, attempt: u32) -> Result<std::time::Duration, ApiError> {
190        let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
191            return Err(ApiError::BackoffOverflow {
192                attempt,
193                base_delay: self.retry.initial_backoff,
194            });
195        };
196        Ok(self
197            .retry
198            .initial_backoff
199            .checked_mul(multiplier)
200            .map_or(self.retry.max_backoff, |delay| {
201                delay.min(self.retry.max_backoff)
202            }))
203    }
204}
205
206#[derive(Debug)]
207pub struct MessageStream {
208    request_id: Option<String>,
209    response: reqwest::Response,
210    parser: OpenAiSseParser,
211    pending: VecDeque<StreamEvent>,
212    done: bool,
213    state: StreamState,
214}
215
216impl MessageStream {
217    #[must_use]
218    pub fn request_id(&self) -> Option<&str> {
219        self.request_id.as_deref()
220    }
221
222    pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
223        loop {
224            if let Some(event) = self.pending.pop_front() {
225                return Ok(Some(event));
226            }
227
228            if self.done {
229                self.pending.extend(self.state.finish());
230                if let Some(event) = self.pending.pop_front() {
231                    return Ok(Some(event));
232                }
233                return Ok(None);
234            }
235
236            match self.response.chunk().await? {
237                Some(chunk) => {
238                    for parsed in self.parser.push(&chunk)? {
239                        self.pending.extend(self.state.ingest_chunk(parsed));
240                    }
241                }
242                None => {
243                    self.done = true;
244                }
245            }
246        }
247    }
248}
249
250#[derive(Debug, Default)]
251struct OpenAiSseParser {
252    buffer: Vec<u8>,
253}
254
255impl OpenAiSseParser {
256    fn new() -> Self {
257        Self::default()
258    }
259
260    fn push(&mut self, chunk: &[u8]) -> Result<Vec<ChatCompletionChunk>, ApiError> {
261        self.buffer.extend_from_slice(chunk);
262        if self.buffer.len() > 16 * 1024 * 1024 {
263            return Err(ApiError::ResponsePayloadTooLarge {
264                limit: 16 * 1024 * 1024,
265            });
266        }
267        let mut events = Vec::new();
268
269        while let Some(frame) = next_sse_frame(&mut self.buffer) {
270            if let Some(event) = parse_sse_frame(&frame)? {
271                events.push(event);
272            }
273        }
274
275        Ok(events)
276    }
277}
278
/// Progress of the single text content block, which always uses index 0.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TextPhase {
    /// No non-empty text delta seen yet; block 0 has not been started.
    Pending,
    /// `ContentBlockStart` for block 0 was emitted; text deltas are being forwarded.
    Active,
    /// `ContentBlockStop` for block 0 was emitted by `StreamState::finish`.
    Done,
}

/// Translates OpenAI chat-completion chunks into Anthropic-style [`StreamEvent`]s.
///
/// Text is emitted as content block 0. Each tool call gets the block index computed by
/// `ToolCallState::block_index` from its OpenAI `index`.
#[derive(Debug)]
struct StreamState {
    /// Requested model; used in `MessageStart` when the first chunk carries no `model`.
    model: String,
    /// Whether `MessageStart` has been emitted (happens on the first ingested chunk).
    message_started: bool,
    text_phase: TextPhase,
    /// Set once `finish` has run, so the trailer events are produced only once.
    finished: bool,
    /// Most recent `finish_reason`, normalized via `normalize_finish_reason`.
    stop_reason: Option<String>,
    /// Usage from the most recent chunk that carried a `usage` object.
    usage: Option<Usage>,
    /// Per-tool-call accumulators keyed by OpenAI `index` (BTreeMap: iterated in index order).
    tool_calls: BTreeMap<u32, ToolCallState>,
}

impl StreamState {
    fn new(model: String) -> Self {
        Self {
            model,
            message_started: false,
            text_phase: TextPhase::Pending,
            finished: false,
            stop_reason: None,
            usage: None,
            tool_calls: BTreeMap::new(),
        }
    }

    /// Folds one chunk into the state and returns the events it produces, in order.
    ///
    /// The sequence is:
    /// 1. `MessageStart` (first chunk only; its usage is all zeros, real usage is reported
    ///    by `finish`).
    /// 2. For each choice, text start/delta events.
    /// 3. For each choice, tool-call start/delta/stop events.
    /// 4. When `finish_reason == "tool_calls"`, stops for every still-open tool block.
    fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Vec<StreamEvent> {
        let mut events = Vec::new();
        if !self.message_started {
            self.message_started = true;
            events.push(StreamEvent::MessageStart(MessageStartEvent {
                message: MessageResponse {
                    id: chunk.id.clone(),
                    kind: "message".to_string(),
                    role: "assistant".to_string(),
                    content: Vec::new(),
                    model: chunk.model.clone().unwrap_or_else(|| self.model.clone()),
                    stop_reason: None,
                    stop_sequence: None,
                    usage: Usage {
                        input_tokens: 0,
                        cache_creation_input_tokens: 0,
                        cache_read_input_tokens: 0,
                        output_tokens: 0,
                    },
                    request_id: None,
                },
            }));
        }

        // Later usage chunks overwrite earlier ones; only the last one is reported.
        if let Some(usage) = chunk.usage {
            self.usage = Some(Usage {
                input_tokens: usage.prompt_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
                output_tokens: usage.completion_tokens,
            });
        }

        for choice in chunk.choices {
            // Empty-string content deltas are ignored so they cannot open an empty text block.
            if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) {
                if self.text_phase == TextPhase::Pending {
                    self.text_phase = TextPhase::Active;
                    events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent {
                        index: 0,
                        content_block: OutputContentBlock::Text {
                            text: String::new(),
                        },
                    }));
                }
                events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
                    index: 0,
                    delta: ContentBlockDelta::TextDelta { text: content },
                }));
            }

            for tool_call in choice.delta.tool_calls {
                let state = self.tool_calls.entry(tool_call.index).or_default();
                state.apply(tool_call);
                let block_index = state.block_index();
                if !state.started {
                    if let Some(start_event) = state.start_event() {
                        state.started = true;
                        events.push(StreamEvent::ContentBlockStart(start_event));
                    } else {
                        // Name not known yet: arguments keep accumulating in `state` and are
                        // flushed by the first `delta_event` after the block starts (here on a
                        // later delta, or in `finish`).
                        continue;
                    }
                }
                if let Some(delta_event) = state.delta_event() {
                    events.push(StreamEvent::ContentBlockDelta(delta_event));
                }
                if choice.finish_reason.as_deref() == Some("tool_calls") && !state.stopped {
                    state.stopped = true;
                    events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                        index: block_index,
                    }));
                }
            }

            if let Some(finish_reason) = choice.finish_reason {
                self.stop_reason = Some(normalize_finish_reason(&finish_reason));
                // Also close tool blocks that received no delta in this chunk.
                if finish_reason == "tool_calls" {
                    for state in self.tool_calls.values_mut() {
                        if state.started && !state.stopped {
                            state.stopped = true;
                            events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                                index: state.block_index(),
                            }));
                        }
                    }
                }
            }
        }

        events
    }

    /// Emits the trailer events at end of stream; subsequent calls return nothing.
    ///
    /// In order:
    /// 1. Closes text block 0 if it is open.
    /// 2. Starts (and flushes the arguments of) any tool call that never started but now has
    ///    a name.
    /// 3. Closes every open tool block.
    /// 4. If a message was started, emits `MessageDelta` (stop reason defaults to `end_turn`,
    ///    usage defaults to zeros) followed by `MessageStop`.
    fn finish(&mut self) -> Vec<StreamEvent> {
        if self.finished {
            return Vec::new();
        }
        self.finished = true;

        let mut events = Vec::new();
        if self.text_phase == TextPhase::Active {
            self.text_phase = TextPhase::Done;
            events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                index: 0,
            }));
        }

        for state in self.tool_calls.values_mut() {
            if !state.started {
                if let Some(start_event) = state.start_event() {
                    state.started = true;
                    events.push(StreamEvent::ContentBlockStart(start_event));
                    if let Some(delta_event) = state.delta_event() {
                        events.push(StreamEvent::ContentBlockDelta(delta_event));
                    }
                }
            }
            if state.started && !state.stopped {
                state.stopped = true;
                events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                    index: state.block_index(),
                }));
            }
        }

        if self.message_started {
            events.push(StreamEvent::MessageDelta(MessageDeltaEvent {
                delta: MessageDelta {
                    stop_reason: Some(
                        self.stop_reason
                            .clone()
                            .unwrap_or_else(|| "end_turn".to_string()),
                    ),
                    stop_sequence: None,
                },
                usage: self.usage.clone().unwrap_or(Usage {
                    input_tokens: 0,
                    cache_creation_input_tokens: 0,
                    cache_read_input_tokens: 0,
                    output_tokens: 0,
                }),
            }));
            events.push(StreamEvent::MessageStop(MessageStopEvent {}));
        }
        events
    }
}
455
456#[derive(Debug, Default)]
457struct ToolCallState {
458    openai_index: u32,
459    id: Option<String>,
460    name: Option<String>,
461    arguments: String,
462    emitted_len: usize,
463    started: bool,
464    stopped: bool,
465}
466
467impl ToolCallState {
468    fn apply(&mut self, tool_call: DeltaToolCall) {
469        self.openai_index = tool_call.index;
470        if let Some(id) = tool_call.id {
471            self.id = Some(id);
472        }
473        if let Some(name) = tool_call.function.name {
474            self.name = Some(name);
475        }
476        if let Some(arguments) = tool_call.function.arguments {
477            self.arguments.push_str(&arguments);
478        }
479    }
480
481    const fn block_index(&self) -> u32 {
482        self.openai_index + 1
483    }
484
485    fn start_event(&self) -> Option<ContentBlockStartEvent> {
486        let name = self.name.clone()?;
487        let id = self
488            .id
489            .clone()
490            .unwrap_or_else(|| format!("tool_call_{}", self.openai_index));
491        Some(ContentBlockStartEvent {
492            index: self.block_index(),
493            content_block: OutputContentBlock::ToolUse {
494                id,
495                name,
496                input: json!({}),
497            },
498        })
499    }
500
501    fn delta_event(&mut self) -> Option<ContentBlockDeltaEvent> {
502        if self.emitted_len >= self.arguments.len() {
503            return None;
504        }
505        let delta = self.arguments[self.emitted_len..].to_string();
506        self.emitted_len = self.arguments.len();
507        Some(ContentBlockDeltaEvent {
508            index: self.block_index(),
509            delta: ContentBlockDelta::InputJsonDelta {
510                partial_json: delta,
511            },
512        })
513    }
514}
515
/// Body of a non-streaming `chat/completions` response.
///
/// Fields not declared here are ignored, since no `deny_unknown_fields` is set.
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    /// Required: deserialization fails if missing.
    id: String,
    /// Required; `normalize_response` falls back to the requested model if it is empty.
    model: String,
    choices: Vec<ChatChoice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

/// One completion choice; only the first one is used by `normalize_response`.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Assistant message inside a choice.
#[derive(Debug, Deserialize)]
struct ChatMessage {
    role: String,
    /// May be absent or `null`.
    #[serde(default)]
    content: Option<String>,
    /// Defaults to empty when absent.
    #[serde(default)]
    tool_calls: Vec<ResponseToolCall>,
}

/// A completed tool call in a non-streaming response.
#[derive(Debug, Deserialize)]
struct ResponseToolCall {
    id: String,
    function: ResponseToolFunction,
}

/// Function name plus its arguments as a JSON-encoded string (decoded by `parse_tool_arguments`).
#[derive(Debug, Deserialize)]
struct ResponseToolFunction {
    name: String,
    arguments: String,
}

/// Token counts; each defaults to 0 when absent.
#[derive(Debug, Deserialize)]
struct OpenAiUsage {
    #[serde(default)]
    prompt_tokens: u32,
    #[serde(default)]
    completion_tokens: u32,
}
560
/// One streamed `data:` payload.
#[derive(Debug, Deserialize)]
struct ChatCompletionChunk {
    /// Required: a payload without `id` fails to deserialize.
    id: String,
    #[serde(default)]
    model: Option<String>,
    /// Defaults to empty (e.g. for a chunk that only carries `usage`).
    #[serde(default)]
    choices: Vec<ChunkChoice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

/// Incremental update for one choice.
#[derive(Debug, Deserialize)]
struct ChunkChoice {
    delta: ChunkDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Text and/or tool-call fragments carried by a chunk.
#[derive(Debug, Default, Deserialize)]
struct ChunkDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Vec<DeltaToolCall>,
}

/// Fragment of one tool call; fragments sharing `index` belong to the same call.
#[derive(Debug, Deserialize)]
struct DeltaToolCall {
    /// Defaults to 0 when absent.
    #[serde(default)]
    index: u32,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: DeltaFunction,
}

/// Name and/or partial argument text for a tool-call fragment.
#[derive(Debug, Default, Deserialize)]
struct DeltaFunction {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}
604
/// `{"error": {...}}` body parsed from non-success responses by `expect_success`.
#[derive(Debug, Deserialize)]
struct ErrorEnvelope {
    error: ErrorBody,
}

/// Error details; both fields are optional.
#[derive(Debug, Deserialize)]
struct ErrorBody {
    #[serde(rename = "type")]
    error_type: Option<String>,
    message: Option<String>,
}
616
617fn build_chat_completion_request(request: &MessageRequest) -> Value {
618    let mut messages = Vec::new();
619    if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
620        messages.push(json!({
621            "role": "system",
622            "content": system,
623        }));
624    }
625    for message in &request.messages {
626        messages.extend(translate_message(message));
627    }
628
629    let mut payload = json!({
630        "model": request.model,
631        "max_tokens": request.max_tokens,
632        "messages": messages,
633        "stream": request.stream,
634    });
635
636    if let Some(tools) = &request.tools {
637        payload["tools"] =
638            Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
639    }
640    if let Some(tool_choice) = &request.tool_choice {
641        payload["tool_choice"] = openai_tool_choice(tool_choice);
642    }
643
644    payload
645}
646
647fn translate_message(message: &InputMessage) -> Vec<Value> {
648    match message.role.as_str() {
649        "assistant" => {
650            let mut text = String::new();
651            let mut tool_calls = Vec::new();
652            for block in &message.content {
653                match block {
654                    InputContentBlock::Text { text: value } => text.push_str(value),
655                    InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({
656                        "id": id,
657                        "type": "function",
658                        "function": {
659                            "name": name,
660                            "arguments": serde_json::to_string(input).unwrap_or_default(),
661                        }
662                    })),
663                    InputContentBlock::ToolResult { .. } => {}
664                }
665            }
666            if text.is_empty() && tool_calls.is_empty() {
667                Vec::new()
668            } else {
669                vec![json!({
670                    "role": "assistant",
671                    "content": (!text.is_empty()).then_some(text),
672                    "tool_calls": tool_calls,
673                })]
674            }
675        }
676        _ => message
677            .content
678            .iter()
679            .filter_map(|block| match block {
680                InputContentBlock::Text { text } => Some(json!({
681                    "role": "user",
682                    "content": text,
683                })),
684                InputContentBlock::ToolResult {
685                    tool_use_id,
686                    content,
687                    is_error,
688                } => Some(json!({
689                    "role": "tool",
690                    "tool_call_id": tool_use_id,
691                    "content": flatten_tool_result_content(content),
692                    "is_error": is_error,
693                })),
694                InputContentBlock::ToolUse { .. } => None,
695            })
696            .collect(),
697    }
698}
699
700fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String {
701    content
702        .iter()
703        .map(|block| match block {
704            ToolResultContentBlock::Text { text } => text.clone(),
705            ToolResultContentBlock::Json { value } => value.to_string(),
706        })
707        .collect::<Vec<_>>()
708        .join("\n")
709}
710
711fn openai_tool_definition(tool: &ToolDefinition) -> Value {
712    json!({
713        "type": "function",
714        "function": {
715            "name": tool.name,
716            "description": tool.description,
717            "parameters": tool.input_schema,
718        }
719    })
720}
721
722fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
723    match tool_choice {
724        ToolChoice::Auto => Value::String("auto".to_string()),
725        ToolChoice::Any => Value::String("required".to_string()),
726        ToolChoice::Tool { name } => json!({
727            "type": "function",
728            "function": { "name": name },
729        }),
730    }
731}
732
733fn normalize_response(
734    model: &str,
735    response: ChatCompletionResponse,
736) -> Result<MessageResponse, ApiError> {
737    let choice = response
738        .choices
739        .into_iter()
740        .next()
741        .ok_or(ApiError::InvalidSseFrame(
742            "chat completion response missing choices",
743        ))?;
744    let mut content = Vec::new();
745    if let Some(text) = choice.message.content.filter(|value| !value.is_empty()) {
746        content.push(OutputContentBlock::Text { text });
747    }
748    for tool_call in choice.message.tool_calls {
749        content.push(OutputContentBlock::ToolUse {
750            id: tool_call.id,
751            name: tool_call.function.name,
752            input: parse_tool_arguments(&tool_call.function.arguments),
753        });
754    }
755
756    Ok(MessageResponse {
757        id: response.id,
758        kind: "message".to_string(),
759        role: choice.message.role,
760        content,
761        model: response.model.if_empty_then(model.to_string()),
762        stop_reason: choice
763            .finish_reason
764            .map(|value| normalize_finish_reason(&value)),
765        stop_sequence: None,
766        usage: Usage {
767            input_tokens: response
768                .usage
769                .as_ref()
770                .map_or(0, |usage| usage.prompt_tokens),
771            cache_creation_input_tokens: 0,
772            cache_read_input_tokens: 0,
773            output_tokens: response
774                .usage
775                .as_ref()
776                .map_or(0, |usage| usage.completion_tokens),
777        },
778        request_id: None,
779    })
780}
781
782fn parse_tool_arguments(arguments: &str) -> Value {
783    serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments }))
784}
785
/// Removes the first complete SSE frame from `buffer` and returns it without its terminator.
///
/// A frame ends at the earliest blank line, written either as `\n\n` or `\r\n\r\n`. The
/// earliest terminator wins, so a buffer mixing both styles is split frame by frame
/// instead of merging everything up to a later `\n\n`. Invalid UTF-8 is replaced lossily.
/// Returns `None`, leaving `buffer` untouched, while no complete frame is buffered.
fn next_sse_frame(buffer: &mut Vec<u8>) -> Option<String> {
    /// Position of the first occurrence of `needle` in `haystack`.
    fn find_subslice(haystack: &[u8], needle: &[u8]) -> Option<usize> {
        haystack
            .windows(needle.len())
            .position(|window| window == needle)
    }

    let bytes = buffer.as_slice();
    let lf = find_subslice(bytes, b"\n\n").map(|position| (position, 2));
    let crlf = find_subslice(bytes, b"\r\n\r\n").map(|position| (position, 4));
    let (position, separator_len) = lf
        .into_iter()
        .chain(crlf)
        .min_by_key(|&(position, _)| position)?;

    let frame = buffer.drain(..position + separator_len).collect::<Vec<_>>();
    Some(String::from_utf8_lossy(&frame[..position]).into_owned())
}
803
804fn parse_sse_frame(frame: &str) -> Result<Option<ChatCompletionChunk>, ApiError> {
805    let trimmed = frame.trim();
806    if trimmed.is_empty() {
807        return Ok(None);
808    }
809
810    let mut data_lines = Vec::new();
811    for line in trimmed.lines() {
812        if line.starts_with(':') {
813            continue;
814        }
815        if let Some(data) = line.strip_prefix("data:") {
816            data_lines.push(data.trim_start());
817        }
818    }
819    if data_lines.is_empty() {
820        return Ok(None);
821    }
822    let payload = data_lines.join("\n");
823    if payload == "[DONE]" {
824        return Ok(None);
825    }
826    serde_json::from_str(&payload)
827        .map(Some)
828        .map_err(ApiError::from)
829}
830
831fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
832    match std::env::var(key) {
833        Ok(value) if !value.is_empty() => Ok(Some(value)),
834        Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
835        Err(error) => Err(ApiError::from(error)),
836    }
837}
838
839#[must_use]
840pub fn has_api_key(key: &str) -> bool {
841    read_env_non_empty(key)
842        .ok()
843        .and_then(std::convert::identity)
844        .is_some()
845}
846
847#[must_use]
848pub fn read_base_url(config: OpenAiCompatConfig) -> String {
849    std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string())
850}
851
/// Returns the full chat-completions URL for `base_url`.
///
/// Trailing slashes are stripped and `/chat/completions` is appended unless the URL
/// already ends with it.
fn chat_completions_endpoint(base_url: &str) -> String {
    const CHAT_COMPLETIONS_PATH: &str = "/chat/completions";

    let base = base_url.trim_end_matches('/');
    if base.ends_with(CHAT_COMPLETIONS_PATH) {
        return base.to_owned();
    }
    let mut endpoint = String::with_capacity(base.len() + CHAT_COMPLETIONS_PATH.len());
    endpoint.push_str(base);
    endpoint.push_str(CHAT_COMPLETIONS_PATH);
    endpoint
}
860
861fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
862    headers
863        .get(REQUEST_ID_HEADER)
864        .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
865        .and_then(|value| value.to_str().ok())
866        .map(ToOwned::to_owned)
867}
868
869async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
870    let status = response.status();
871    if status.is_success() {
872        return Ok(response);
873    }
874
875    let body = response.text().await.unwrap_or_default();
876    let parsed_error = serde_json::from_str::<ErrorEnvelope>(&body).ok();
877    let retryable = is_retryable_status(status);
878
879    Err(ApiError::Api {
880        status,
881        error_type: parsed_error
882            .as_ref()
883            .and_then(|error| error.error.error_type.clone()),
884        message: parsed_error
885            .as_ref()
886            .and_then(|error| error.error.message.clone()),
887        body,
888        retryable,
889    })
890}
891
892const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
893    matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
894}
895
/// Maps an OpenAI `finish_reason` onto the Anthropic-style stop reasons used by
/// [`MessageResponse`].
///
/// The mappings are:
/// * `stop` → `end_turn`
/// * `tool_calls` → `tool_use`
/// * `length` (output truncated at the token limit) → `max_tokens`
///
/// Anything else is passed through unchanged.
fn normalize_finish_reason(value: &str) -> String {
    match value {
        "stop" => "end_turn",
        "tool_calls" => "tool_use",
        "length" => "max_tokens",
        other => other,
    }
    .to_string()
}
904
/// Small helpers on owned strings.
trait StringExt {
    /// Returns `fallback` when `self` is empty, otherwise `self` (whitespace is not trimmed).
    fn if_empty_then(self, fallback: String) -> String;
}

impl StringExt for String {
    fn if_empty_then(self, fallback: String) -> String {
        if !self.is_empty() {
            return self;
        }
        fallback
    }
}
918
// Unit tests for this module live in the sibling file `openai_compat_tests.rs` and are
// compiled only under `cfg(test)`.
#[cfg(test)]
#[path = "openai_compat_tests.rs"]
mod tests;