1use std::collections::{BTreeMap, VecDeque};
2
3use serde::Deserialize;
4use serde_json::{json, Value};
5
6use crate::error::ApiError;
7use crate::providers::RetryPolicy;
8use crate::types::{
9 ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
10 InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest,
11 MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
12 ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
13};
14
/// Default REST endpoint for xAI's OpenAI-compatible API.
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
/// Default REST endpoint for the OpenAI API.
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
// Response headers checked (in this order) for a server-assigned request id.
const REQUEST_ID_HEADER: &str = "request-id";
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
19
/// Static description of one OpenAI-compatible provider: its display name,
/// the environment variables it is configured through, and the fallback
/// base URL used when no override is present.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OpenAiCompatConfig {
    // Human-readable provider name used in error messages (e.g. "xAI").
    pub provider_name: &'static str,
    // Environment variable holding the API key.
    pub api_key_env: &'static str,
    // Environment variable that may override the base URL.
    pub base_url_env: &'static str,
    // Base URL used when `base_url_env` is not set.
    pub default_base_url: &'static str,
}

// Credential env-var lists returned by `credential_env_vars`; kept as
// file-level consts so the method can hand out `&'static` slices.
const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"];
const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"];
30
31impl OpenAiCompatConfig {
32 #[must_use]
33 pub const fn xai() -> Self {
34 Self {
35 provider_name: "xAI",
36 api_key_env: "XAI_API_KEY",
37 base_url_env: "XAI_BASE_URL",
38 default_base_url: DEFAULT_XAI_BASE_URL,
39 }
40 }
41
42 #[must_use]
43 pub const fn openai() -> Self {
44 Self {
45 provider_name: "OpenAI",
46 api_key_env: "OPENAI_API_KEY",
47 base_url_env: "OPENAI_BASE_URL",
48 default_base_url: DEFAULT_OPENAI_BASE_URL,
49 }
50 }
51 #[must_use]
52 pub fn credential_env_vars(self) -> &'static [&'static str] {
53 match self.api_key_env {
54 "XAI_API_KEY" => XAI_ENV_VARS,
55 "OPENAI_API_KEY" => OPENAI_ENV_VARS,
56 _ => &[],
57 }
58 }
59}
60
/// HTTP client for OpenAI-compatible chat-completion APIs, translating to
/// and from this crate's message/stream-event types.
#[derive(Debug, Clone)]
pub struct OpenAiCompatClient {
    http: reqwest::Client,
    // Bearer token attached to every request.
    api_key: String,
    // Base URL; `/chat/completions` is appended unless already present.
    base_url: String,
    // Retry count and exponential-backoff timing.
    retry: RetryPolicy,
}
68
impl OpenAiCompatClient {
    /// Create a client with an explicit API key. The base URL comes from the
    /// provider's env-var override or its default (see [`read_base_url`]);
    /// the retry policy starts at its default.
    #[must_use]
    pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
        Self {
            http: reqwest::Client::new(),
            api_key: api_key.into(),
            base_url: read_base_url(config),
            retry: RetryPolicy::default(),
        }
    }

    /// Create a client using the API key read from `config.api_key_env`.
    ///
    /// # Errors
    /// Returns a missing-credentials error when the env var is unset or
    /// empty, and propagates env-var read failures.
    pub fn from_env(config: OpenAiCompatConfig) -> Result<Self, ApiError> {
        let Some(api_key) = read_env_non_empty(config.api_key_env)? else {
            return Err(ApiError::missing_credentials(
                config.provider_name,
                config.credential_env_vars(),
            ));
        };
        Ok(Self::new(api_key, config))
    }

    /// Override the base URL (builder style).
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }

    /// Override the retry policy (builder style).
    #[must_use]
    pub fn with_retry_policy(mut self, retry: RetryPolicy) -> Self {
        self.retry = retry;
        self
    }

    /// Send a non-streaming request and return the normalized response.
    ///
    /// `stream` is forced to `false` regardless of the caller's setting.
    /// When the normalized body carries no request id, the id from the
    /// response headers (if any) is attached instead.
    ///
    /// # Errors
    /// Propagates transport failures, non-success HTTP statuses (after
    /// retries), and payload-decoding failures.
    pub async fn send_message(
        &self,
        request: &MessageRequest,
    ) -> Result<MessageResponse, ApiError> {
        let request = MessageRequest {
            stream: false,
            ..request.clone()
        };
        let response = self.send_with_retry(&request).await?;
        // Capture the header id before the body read consumes the response.
        let request_id = request_id_from_headers(response.headers());
        let payload = response.json::<ChatCompletionResponse>().await?;
        let mut normalized = normalize_response(&request.model, payload)?;
        if normalized.request_id.is_none() {
            normalized.request_id = request_id;
        }
        Ok(normalized)
    }

    /// Start a streaming request; events are then pulled through
    /// [`MessageStream::next_event`].
    ///
    /// # Errors
    /// Fails if the initial HTTP exchange (including retries) fails.
    pub async fn stream_message(
        &self,
        request: &MessageRequest,
    ) -> Result<MessageStream, ApiError> {
        let response = self
            .send_with_retry(&request.clone().with_streaming())
            .await?;
        Ok(MessageStream {
            request_id: request_id_from_headers(response.headers()),
            response,
            parser: OpenAiSseParser::new(),
            pending: VecDeque::new(),
            done: false,
            state: StreamState::new(request.model.clone()),
        })
    }

    /// Issue the request, retrying retryable failures with exponential
    /// backoff; a total of `max_retries + 1` attempts are made before
    /// reporting `RetriesExhausted`.
    async fn send_with_retry(
        &self,
        request: &MessageRequest,
    ) -> Result<reqwest::Response, ApiError> {
        let mut attempts = 0;

        let last_error = loop {
            attempts += 1;
            // Non-retryable errors (or retryable ones past the attempt
            // budget) are returned immediately; retryable ones fall through
            // to the backoff below.
            let retryable_error = match self.send_raw_request(request).await {
                Ok(response) => match expect_success(response).await {
                    Ok(response) => return Ok(response),
                    Err(error)
                        if error.is_retryable() && attempts <= self.retry.max_retries + 1 =>
                    {
                        error
                    }
                    Err(error) => return Err(error),
                },
                Err(error) if error.is_retryable() && attempts <= self.retry.max_retries + 1 => {
                    error
                }
                Err(error) => return Err(error),
            };

            if attempts > self.retry.max_retries {
                break retryable_error;
            }

            tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
        };

        Err(ApiError::RetriesExhausted {
            attempts,
            last_error: Box::new(last_error),
        })
    }

    /// One raw POST of the translated request to the chat-completions
    /// endpoint (no retries, no status checking).
    async fn send_raw_request(
        &self,
        request: &MessageRequest,
    ) -> Result<reqwest::Response, ApiError> {
        let request_url = chat_completions_endpoint(&self.base_url);
        self.http
            .post(&request_url)
            .header("content-type", "application/json")
            .bearer_auth(&self.api_key)
            .json(&build_chat_completion_request(request))
            .send()
            .await
            .map_err(ApiError::from)
    }

    /// Backoff for the given 1-based attempt:
    /// `initial_backoff * 2^(attempt - 1)`, capped at `max_backoff`.
    ///
    /// # Errors
    /// Returns `BackoffOverflow` when the shift would overflow `u32`
    /// (attempt >= 33).
    fn backoff_for_attempt(&self, attempt: u32) -> Result<std::time::Duration, ApiError> {
        let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
            return Err(ApiError::BackoffOverflow {
                attempt,
                base_delay: self.retry.initial_backoff,
            });
        };
        // Saturate at max_backoff if the multiplication itself overflows.
        Ok(self
            .retry
            .initial_backoff
            .checked_mul(multiplier)
            .map_or(self.retry.max_backoff, |delay| {
                delay.min(self.retry.max_backoff)
            }))
    }
}
205
/// In-progress streaming response: wraps the raw SSE byte stream and
/// re-emits it as normalized [`StreamEvent`]s.
#[derive(Debug)]
pub struct MessageStream {
    // Request id taken from the response headers, if present.
    request_id: Option<String>,
    response: reqwest::Response,
    // Splits raw bytes into SSE frames and decodes each into a chunk.
    parser: OpenAiSseParser,
    // Events already translated but not yet handed to the caller.
    pending: VecDeque<StreamEvent>,
    // True once the HTTP body has been fully consumed.
    done: bool,
    // Accumulates cross-chunk state (text phase, tool calls, usage).
    state: StreamState,
}
215
impl MessageStream {
    /// Request id extracted from the response headers, if the server sent one.
    #[must_use]
    pub fn request_id(&self) -> Option<&str> {
        self.request_id.as_deref()
    }

    /// Pull the next normalized stream event.
    ///
    /// Returns `Ok(None)` once the HTTP body is exhausted and all
    /// synthesized trailing events (block stops, message delta/stop) have
    /// been delivered.
    ///
    /// # Errors
    /// Propagates transport errors and SSE/JSON parsing failures.
    pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
        loop {
            // Drain events translated from earlier chunks first.
            if let Some(event) = self.pending.pop_front() {
                return Ok(Some(event));
            }

            if self.done {
                // Body exhausted: flush the closing events. `finish` is
                // idempotent, so reaching this arm repeatedly is safe.
                self.pending.extend(self.state.finish());
                if let Some(event) = self.pending.pop_front() {
                    return Ok(Some(event));
                }
                return Ok(None);
            }

            match self.response.chunk().await? {
                Some(chunk) => {
                    // One network chunk may complete zero or more SSE frames.
                    for parsed in self.parser.push(&chunk)? {
                        self.pending.extend(self.state.ingest_chunk(parsed));
                    }
                }
                None => {
                    self.done = true;
                }
            }
        }
    }
}
249
/// Incremental SSE framer: buffers raw bytes until a complete
/// frame-separator sequence arrives.
#[derive(Debug, Default)]
struct OpenAiSseParser {
    // Bytes received but not yet assembled into a complete frame.
    buffer: Vec<u8>,
}
254
255impl OpenAiSseParser {
256 fn new() -> Self {
257 Self::default()
258 }
259
260 fn push(&mut self, chunk: &[u8]) -> Result<Vec<ChatCompletionChunk>, ApiError> {
261 self.buffer.extend_from_slice(chunk);
262 if self.buffer.len() > 16 * 1024 * 1024 {
263 return Err(ApiError::ResponsePayloadTooLarge {
264 limit: 16 * 1024 * 1024,
265 });
266 }
267 let mut events = Vec::new();
268
269 while let Some(frame) = next_sse_frame(&mut self.buffer) {
270 if let Some(event) = parse_sse_frame(&frame)? {
271 events.push(event);
272 }
273 }
274
275 Ok(events)
276 }
277}
278
/// Lifecycle of the single text content block (always block index 0).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TextPhase {
    // No text delta seen yet; no start event emitted.
    Pending,
    // Start event emitted; deltas may still arrive.
    Active,
    // Stop event emitted during `finish`.
    Done,
}
285
/// Cross-chunk bookkeeping used to synthesize Anthropic-style stream events
/// from OpenAI streaming deltas.
#[derive(Debug)]
struct StreamState {
    // Requested model; fallback when chunks omit their own model name.
    model: String,
    // Whether the synthetic `message_start` event has been emitted.
    message_started: bool,
    text_phase: TextPhase,
    // Set once `finish` has run, making it idempotent.
    finished: bool,
    // Normalized finish reason from the last chunk that carried one.
    stop_reason: Option<String>,
    // Latest usage totals reported by the server, if any.
    usage: Option<Usage>,
    // Tool-call accumulators keyed by OpenAI tool-call index (BTreeMap so
    // flush order in `finish` is deterministic).
    tool_calls: BTreeMap<u32, ToolCallState>,
}
296
impl StreamState {
    /// Fresh state for a stream requested against `model`.
    fn new(model: String) -> Self {
        Self {
            model,
            message_started: false,
            text_phase: TextPhase::Pending,
            finished: false,
            stop_reason: None,
            usage: None,
            tool_calls: BTreeMap::new(),
        }
    }

    /// Translate one decoded chunk into zero or more stream events,
    /// updating text/tool-call/usage state along the way.
    fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Vec<StreamEvent> {
        let mut events = Vec::new();
        // The very first chunk triggers a synthetic `message_start` with an
        // empty body; real content follows as block events.
        if !self.message_started {
            self.message_started = true;
            events.push(StreamEvent::MessageStart(MessageStartEvent {
                message: MessageResponse {
                    id: chunk.id.clone(),
                    kind: "message".to_string(),
                    role: "assistant".to_string(),
                    content: Vec::new(),
                    model: chunk.model.clone().unwrap_or_else(|| self.model.clone()),
                    stop_reason: None,
                    stop_sequence: None,
                    usage: Usage {
                        input_tokens: 0,
                        cache_creation_input_tokens: 0,
                        cache_read_input_tokens: 0,
                        output_tokens: 0,
                    },
                    request_id: None,
                },
            }));
        }

        // Later usage reports overwrite earlier ones; the final value is
        // emitted with the message delta in `finish`.
        if let Some(usage) = chunk.usage {
            self.usage = Some(Usage {
                input_tokens: usage.prompt_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
                output_tokens: usage.completion_tokens,
            });
        }

        for choice in chunk.choices {
            // Text always lives in content block 0; open it lazily on the
            // first non-empty text delta.
            if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) {
                if self.text_phase == TextPhase::Pending {
                    self.text_phase = TextPhase::Active;
                    events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent {
                        index: 0,
                        content_block: OutputContentBlock::Text {
                            text: String::new(),
                        },
                    }));
                }
                events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
                    index: 0,
                    delta: ContentBlockDelta::TextDelta { text: content },
                }));
            }

            for tool_call in choice.delta.tool_calls {
                let state = self.tool_calls.entry(tool_call.index).or_default();
                state.apply(tool_call);
                let block_index = state.block_index();
                // The start event needs the tool name; until it arrives we
                // keep accumulating without emitting anything.
                if !state.started {
                    if let Some(start_event) = state.start_event() {
                        state.started = true;
                        events.push(StreamEvent::ContentBlockStart(start_event));
                    } else {
                        continue;
                    }
                }
                if let Some(delta_event) = state.delta_event() {
                    events.push(StreamEvent::ContentBlockDelta(delta_event));
                }
                if choice.finish_reason.as_deref() == Some("tool_calls") && !state.stopped {
                    state.stopped = true;
                    events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                        index: block_index,
                    }));
                }
            }

            if let Some(finish_reason) = choice.finish_reason {
                self.stop_reason = Some(normalize_finish_reason(&finish_reason));
                // On "tool_calls", also close any started-but-open tool
                // blocks that this chunk's delta loop did not touch.
                if finish_reason == "tool_calls" {
                    for state in self.tool_calls.values_mut() {
                        if state.started && !state.stopped {
                            state.stopped = true;
                            events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                                index: state.block_index(),
                            }));
                        }
                    }
                }
            }
        }

        events
    }

    /// Close everything still open and emit the trailing
    /// `message_delta`/`message_stop` pair. Idempotent: the second and later
    /// calls return no events.
    fn finish(&mut self) -> Vec<StreamEvent> {
        if self.finished {
            return Vec::new();
        }
        self.finished = true;

        let mut events = Vec::new();
        if self.text_phase == TextPhase::Active {
            self.text_phase = TextPhase::Done;
            events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                index: 0,
            }));
        }

        for state in self.tool_calls.values_mut() {
            // A tool call whose name only became known late still gets its
            // start (and any pending argument delta) before the stop.
            if !state.started {
                if let Some(start_event) = state.start_event() {
                    state.started = true;
                    events.push(StreamEvent::ContentBlockStart(start_event));
                    if let Some(delta_event) = state.delta_event() {
                        events.push(StreamEvent::ContentBlockDelta(delta_event));
                    }
                }
            }
            if state.started && !state.stopped {
                state.stopped = true;
                events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                    index: state.block_index(),
                }));
            }
        }

        if self.message_started {
            events.push(StreamEvent::MessageDelta(MessageDeltaEvent {
                delta: MessageDelta {
                    stop_reason: Some(
                        self.stop_reason
                            .clone()
                            .unwrap_or_else(|| "end_turn".to_string()),
                    ),
                    stop_sequence: None,
                },
                usage: self.usage.clone().unwrap_or(Usage {
                    input_tokens: 0,
                    cache_creation_input_tokens: 0,
                    cache_read_input_tokens: 0,
                    output_tokens: 0,
                }),
            }));
            events.push(StreamEvent::MessageStop(MessageStopEvent {}));
        }
        events
    }
}
455
/// Accumulator for one streamed tool call, assembled from partial deltas.
#[derive(Debug, Default)]
struct ToolCallState {
    // Tool-call index as reported by the provider.
    openai_index: u32,
    // Tool-call id; synthesized from the index if never supplied.
    id: Option<String>,
    // Function name; the start event cannot be emitted until this arrives.
    name: Option<String>,
    // Concatenated JSON argument fragments.
    arguments: String,
    // How many bytes of `arguments` have already been emitted as deltas.
    emitted_len: usize,
    // Whether the content_block_start / content_block_stop events went out.
    started: bool,
    stopped: bool,
}
466
467impl ToolCallState {
468 fn apply(&mut self, tool_call: DeltaToolCall) {
469 self.openai_index = tool_call.index;
470 if let Some(id) = tool_call.id {
471 self.id = Some(id);
472 }
473 if let Some(name) = tool_call.function.name {
474 self.name = Some(name);
475 }
476 if let Some(arguments) = tool_call.function.arguments {
477 self.arguments.push_str(&arguments);
478 }
479 }
480
481 const fn block_index(&self) -> u32 {
482 self.openai_index + 1
483 }
484
485 fn start_event(&self) -> Option<ContentBlockStartEvent> {
486 let name = self.name.clone()?;
487 let id = self
488 .id
489 .clone()
490 .unwrap_or_else(|| format!("tool_call_{}", self.openai_index));
491 Some(ContentBlockStartEvent {
492 index: self.block_index(),
493 content_block: OutputContentBlock::ToolUse {
494 id,
495 name,
496 input: json!({}),
497 },
498 })
499 }
500
501 fn delta_event(&mut self) -> Option<ContentBlockDeltaEvent> {
502 if self.emitted_len >= self.arguments.len() {
503 return None;
504 }
505 let delta = self.arguments[self.emitted_len..].to_string();
506 self.emitted_len = self.arguments.len();
507 Some(ContentBlockDeltaEvent {
508 index: self.block_index(),
509 delta: ContentBlockDelta::InputJsonDelta {
510 partial_json: delta,
511 },
512 })
513 }
514}
515
/// Non-streaming `/chat/completions` response body (subset we consume).
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    id: String,
    model: String,
    choices: Vec<ChatChoice>,
    // Some providers omit usage; default to `None`.
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

/// One completion choice; only the first is used.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Assistant message carried by a choice.
#[derive(Debug, Deserialize)]
struct ChatMessage {
    role: String,
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Vec<ResponseToolCall>,
}

/// Fully-formed tool call in a non-streaming response.
#[derive(Debug, Deserialize)]
struct ResponseToolCall {
    id: String,
    function: ResponseToolFunction,
}

/// Function name plus its arguments as a JSON-encoded string.
#[derive(Debug, Deserialize)]
struct ResponseToolFunction {
    name: String,
    arguments: String,
}

/// Token counts reported by the server.
#[derive(Debug, Deserialize)]
struct OpenAiUsage {
    #[serde(default)]
    prompt_tokens: u32,
    #[serde(default)]
    completion_tokens: u32,
}
560
/// One streamed SSE chunk of a `/chat/completions` response.
#[derive(Debug, Deserialize)]
struct ChatCompletionChunk {
    id: String,
    #[serde(default)]
    model: Option<String>,
    #[serde(default)]
    choices: Vec<ChunkChoice>,
    // Typically only present on the final chunk, if at all.
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

/// One choice within a chunk.
#[derive(Debug, Deserialize)]
struct ChunkChoice {
    delta: ChunkDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental payload: a text fragment and/or tool-call fragments.
#[derive(Debug, Default, Deserialize)]
struct ChunkDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Vec<DeltaToolCall>,
}

/// Partial tool call; `id`/`name` arrive once, arguments in fragments.
#[derive(Debug, Deserialize)]
struct DeltaToolCall {
    #[serde(default)]
    index: u32,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: DeltaFunction,
}

/// Partial function payload inside a tool-call delta.
#[derive(Debug, Default, Deserialize)]
struct DeltaFunction {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}
604
/// OpenAI-style error response wrapper: `{"error": {...}}`.
#[derive(Debug, Deserialize)]
struct ErrorEnvelope {
    error: ErrorBody,
}

/// Structured error detail; both fields are optional in practice.
#[derive(Debug, Deserialize)]
struct ErrorBody {
    #[serde(rename = "type")]
    error_type: Option<String>,
    message: Option<String>,
}
616
617fn build_chat_completion_request(request: &MessageRequest) -> Value {
618 let mut messages = Vec::new();
619 if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
620 messages.push(json!({
621 "role": "system",
622 "content": system,
623 }));
624 }
625 for message in &request.messages {
626 messages.extend(translate_message(message));
627 }
628
629 let mut payload = json!({
630 "model": request.model,
631 "max_tokens": request.max_tokens,
632 "messages": messages,
633 "stream": request.stream,
634 });
635
636 if let Some(tools) = &request.tools {
637 payload["tools"] =
638 Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
639 }
640 if let Some(tool_choice) = &request.tool_choice {
641 payload["tool_choice"] = openai_tool_choice(tool_choice);
642 }
643
644 payload
645}
646
647fn translate_message(message: &InputMessage) -> Vec<Value> {
648 match message.role.as_str() {
649 "assistant" => {
650 let mut text = String::new();
651 let mut tool_calls = Vec::new();
652 for block in &message.content {
653 match block {
654 InputContentBlock::Text { text: value } => text.push_str(value),
655 InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({
656 "id": id,
657 "type": "function",
658 "function": {
659 "name": name,
660 "arguments": serde_json::to_string(input).unwrap_or_default(),
661 }
662 })),
663 InputContentBlock::ToolResult { .. } => {}
664 }
665 }
666 if text.is_empty() && tool_calls.is_empty() {
667 Vec::new()
668 } else {
669 vec![json!({
670 "role": "assistant",
671 "content": (!text.is_empty()).then_some(text),
672 "tool_calls": tool_calls,
673 })]
674 }
675 }
676 _ => message
677 .content
678 .iter()
679 .filter_map(|block| match block {
680 InputContentBlock::Text { text } => Some(json!({
681 "role": "user",
682 "content": text,
683 })),
684 InputContentBlock::ToolResult {
685 tool_use_id,
686 content,
687 is_error,
688 } => Some(json!({
689 "role": "tool",
690 "tool_call_id": tool_use_id,
691 "content": flatten_tool_result_content(content),
692 "is_error": is_error,
693 })),
694 InputContentBlock::ToolUse { .. } => None,
695 })
696 .collect(),
697 }
698}
699
700fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String {
701 content
702 .iter()
703 .map(|block| match block {
704 ToolResultContentBlock::Text { text } => text.clone(),
705 ToolResultContentBlock::Json { value } => value.to_string(),
706 })
707 .collect::<Vec<_>>()
708 .join("\n")
709}
710
711fn openai_tool_definition(tool: &ToolDefinition) -> Value {
712 json!({
713 "type": "function",
714 "function": {
715 "name": tool.name,
716 "description": tool.description,
717 "parameters": tool.input_schema,
718 }
719 })
720}
721
722fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
723 match tool_choice {
724 ToolChoice::Auto => Value::String("auto".to_string()),
725 ToolChoice::Any => Value::String("required".to_string()),
726 ToolChoice::Tool { name } => json!({
727 "type": "function",
728 "function": { "name": name },
729 }),
730 }
731}
732
733fn normalize_response(
734 model: &str,
735 response: ChatCompletionResponse,
736) -> Result<MessageResponse, ApiError> {
737 let choice = response
738 .choices
739 .into_iter()
740 .next()
741 .ok_or(ApiError::InvalidSseFrame(
742 "chat completion response missing choices",
743 ))?;
744 let mut content = Vec::new();
745 if let Some(text) = choice.message.content.filter(|value| !value.is_empty()) {
746 content.push(OutputContentBlock::Text { text });
747 }
748 for tool_call in choice.message.tool_calls {
749 content.push(OutputContentBlock::ToolUse {
750 id: tool_call.id,
751 name: tool_call.function.name,
752 input: parse_tool_arguments(&tool_call.function.arguments),
753 });
754 }
755
756 Ok(MessageResponse {
757 id: response.id,
758 kind: "message".to_string(),
759 role: choice.message.role,
760 content,
761 model: response.model.if_empty_then(model.to_string()),
762 stop_reason: choice
763 .finish_reason
764 .map(|value| normalize_finish_reason(&value)),
765 stop_sequence: None,
766 usage: Usage {
767 input_tokens: response
768 .usage
769 .as_ref()
770 .map_or(0, |usage| usage.prompt_tokens),
771 cache_creation_input_tokens: 0,
772 cache_read_input_tokens: 0,
773 output_tokens: response
774 .usage
775 .as_ref()
776 .map_or(0, |usage| usage.completion_tokens),
777 },
778 request_id: None,
779 })
780}
781
782fn parse_tool_arguments(arguments: &str) -> Value {
783 serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments }))
784}
785
/// Pop the next complete SSE frame off the front of `buffer`, or return
/// `None` when no full frame has arrived yet.
///
/// A frame ends at a blank line: either `\n\n` or `\r\n\r\n`. The earliest
/// separator of either kind wins — the previous code always preferred
/// `\n\n`, so a CRLF-delimited frame followed by an LF-delimited one was
/// merged into a single frame. The separator is removed from the buffer and
/// excluded from the returned text (decoded lossily as UTF-8).
fn next_sse_frame(buffer: &mut Vec<u8>) -> Option<String> {
    let lf = buffer
        .windows(2)
        .position(|window| window == b"\n\n")
        .map(|position| (position, 2));
    let crlf = buffer
        .windows(4)
        .position(|window| window == b"\r\n\r\n")
        .map(|position| (position, 4));

    // Take whichever separator occurs first in the buffer.
    let (position, separator_len) = match (lf, crlf) {
        (Some(lf), Some(crlf)) => {
            if crlf.0 < lf.0 {
                crlf
            } else {
                lf
            }
        }
        (Some(lf), None) => lf,
        (None, Some(crlf)) => crlf,
        (None, None) => return None,
    };

    let frame: Vec<u8> = buffer.drain(..position + separator_len).collect();
    Some(String::from_utf8_lossy(&frame[..position]).into_owned())
}
803
804fn parse_sse_frame(frame: &str) -> Result<Option<ChatCompletionChunk>, ApiError> {
805 let trimmed = frame.trim();
806 if trimmed.is_empty() {
807 return Ok(None);
808 }
809
810 let mut data_lines = Vec::new();
811 for line in trimmed.lines() {
812 if line.starts_with(':') {
813 continue;
814 }
815 if let Some(data) = line.strip_prefix("data:") {
816 data_lines.push(data.trim_start());
817 }
818 }
819 if data_lines.is_empty() {
820 return Ok(None);
821 }
822 let payload = data_lines.join("\n");
823 if payload == "[DONE]" {
824 return Ok(None);
825 }
826 serde_json::from_str(&payload)
827 .map(Some)
828 .map_err(ApiError::from)
829}
830
831fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
832 match std::env::var(key) {
833 Ok(value) if !value.is_empty() => Ok(Some(value)),
834 Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
835 Err(error) => Err(ApiError::from(error)),
836 }
837}
838
839#[must_use]
840pub fn has_api_key(key: &str) -> bool {
841 read_env_non_empty(key)
842 .ok()
843 .and_then(std::convert::identity)
844 .is_some()
845}
846
847#[must_use]
848pub fn read_base_url(config: OpenAiCompatConfig) -> String {
849 std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string())
850}
851
/// Build the chat-completions URL from a base URL, tolerating a trailing
/// slash and a base that already ends in `/chat/completions`.
fn chat_completions_endpoint(base_url: &str) -> String {
    let mut endpoint = base_url.trim_end_matches('/').to_string();
    if !endpoint.ends_with("/chat/completions") {
        endpoint.push_str("/chat/completions");
    }
    endpoint
}
860
861fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
862 headers
863 .get(REQUEST_ID_HEADER)
864 .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
865 .and_then(|value| value.to_str().ok())
866 .map(ToOwned::to_owned)
867}
868
869async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
870 let status = response.status();
871 if status.is_success() {
872 return Ok(response);
873 }
874
875 let body = response.text().await.unwrap_or_default();
876 let parsed_error = serde_json::from_str::<ErrorEnvelope>(&body).ok();
877 let retryable = is_retryable_status(status);
878
879 Err(ApiError::Api {
880 status,
881 error_type: parsed_error
882 .as_ref()
883 .and_then(|error| error.error.error_type.clone()),
884 message: parsed_error
885 .as_ref()
886 .and_then(|error| error.error.message.clone()),
887 body,
888 retryable,
889 })
890}
891
892const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
893 matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
894}
895
/// Map OpenAI finish reasons onto Anthropic-style stop reasons; values we
/// do not recognize pass through unchanged.
fn normalize_finish_reason(value: &str) -> String {
    if value == "stop" {
        "end_turn".to_string()
    } else if value == "tool_calls" {
        "tool_use".to_string()
    } else {
        value.to_string()
    }
}
904
/// Fallback helper for possibly-empty strings.
trait StringExt {
    /// Return `self` unless it is empty, in which case return `fallback`.
    fn if_empty_then(self, fallback: String) -> String;
}

impl StringExt for String {
    fn if_empty_then(self, fallback: String) -> String {
        match self.is_empty() {
            true => fallback,
            false => self,
        }
    }
}
918
919#[cfg(test)]
920#[path = "openai_compat_tests.rs"]
921mod tests;