1use std::collections::{BTreeMap, VecDeque};
2
3use serde::Deserialize;
4use serde_json::{json, Value};
5
6use crate::error::ApiError;
7use crate::providers::RetryPolicy;
8use crate::types::{
9 ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
10 InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest,
11 MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
12 ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
13};
14
15pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
16pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
17const REQUEST_ID_HEADER: &str = "request-id";
18const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
19
/// Static description of an OpenAI-compatible provider: display name plus the
/// environment variables and default endpoint used for configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OpenAiCompatConfig {
    /// Human-readable provider name used in error messages.
    pub provider_name: &'static str,
    /// Environment variable that holds the API key.
    pub api_key_env: &'static str,
    /// Environment variable that overrides the base URL.
    pub base_url_env: &'static str,
    /// Base URL used when `base_url_env` is unset.
    pub default_base_url: &'static str,
}
27
// Environment variables that may hold each provider's API key. Kept as named
// statics so `credential_env_vars` can return them with a `'static` lifetime.
const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"];
const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"];
30
impl OpenAiCompatConfig {
    /// Preset configuration for the xAI API.
    #[must_use]
    pub const fn xai() -> Self {
        Self {
            provider_name: "xAI",
            api_key_env: "XAI_API_KEY",
            base_url_env: "XAI_BASE_URL",
            default_base_url: DEFAULT_XAI_BASE_URL,
        }
    }

    /// Preset configuration for the OpenAI API.
    #[must_use]
    pub const fn openai() -> Self {
        Self {
            provider_name: "OpenAI",
            api_key_env: "OPENAI_API_KEY",
            base_url_env: "OPENAI_BASE_URL",
            default_base_url: DEFAULT_OPENAI_BASE_URL,
        }
    }
    /// Environment variables that may supply this provider's credentials,
    /// suitable for inclusion in a missing-credentials error.
    ///
    /// Matches on `api_key_env` so the returned slice can be `'static`.
    #[must_use]
    pub fn credential_env_vars(self) -> &'static [&'static str] {
        match self.api_key_env {
            "XAI_API_KEY" => XAI_ENV_VARS,
            "OPENAI_API_KEY" => OPENAI_ENV_VARS,
            _ => &[],
        }
    }
}
60
/// HTTP client for OpenAI-compatible chat-completions APIs (OpenAI, xAI, or
/// any custom server speaking the same protocol).
#[derive(Clone)]
pub struct OpenAiCompatClient {
    http: reqwest::Client,
    /// Bearer token; an empty string means no Authorization header is sent.
    api_key: String,
    base_url: String,
    retry: RetryPolicy,
}
68
// Manual `Debug` so the API key is never printed in logs or error output.
impl std::fmt::Debug for OpenAiCompatClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiCompatClient")
            .field("base_url", &self.base_url)
            // Redact the credential itself.
            .field("api_key", &"***")
            .finish()
    }
}
77
impl OpenAiCompatClient {
    /// Creates a client with the given API key; the base URL is resolved from
    /// the provider's base-URL environment variable or its default.
    #[must_use]
    pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
        Self {
            http: crate::default_http_client(),
            api_key: api_key.into(),
            base_url: read_base_url(config),
            retry: RetryPolicy::default(),
        }
    }

    /// Creates a client for an arbitrary OpenAI-compatible server.
    #[must_use]
    pub fn new_custom(base_url: impl Into<String>, api_key: impl Into<String>) -> Self {
        Self {
            http: crate::default_http_client(),
            api_key: api_key.into(),
            base_url: base_url.into(),
            retry: RetryPolicy::default(),
        }
    }

    /// Creates a client by reading the API key from `config.api_key_env`.
    ///
    /// # Errors
    /// Returns a missing-credentials error when the variable is unset or empty.
    pub fn from_env(config: OpenAiCompatConfig) -> Result<Self, ApiError> {
        let Some(api_key) = read_env_non_empty(config.api_key_env)? else {
            return Err(ApiError::missing_credentials(
                config.provider_name,
                config.credential_env_vars(),
            ));
        };
        Ok(Self::new(api_key, config))
    }

    /// Overrides the base URL (builder style).
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }

    /// Overrides the retry policy (builder style).
    #[must_use]
    pub fn with_retry_policy(mut self, retry: RetryPolicy) -> Self {
        self.retry = retry;
        self
    }

    /// Sends a non-streaming request and normalizes the reply into the
    /// Anthropic-shaped [`MessageResponse`].
    ///
    /// # Errors
    /// Propagates transport, HTTP-status, and decoding failures as `ApiError`.
    pub async fn send_message(
        &self,
        request: &MessageRequest,
    ) -> Result<MessageResponse, ApiError> {
        // Force `stream: false` so the server returns a single JSON body.
        let request = MessageRequest {
            stream: false,
            ..request.clone()
        };
        let response = self.send_with_retry(&request).await?;
        let request_id = request_id_from_headers(response.headers());
        let payload = response.json::<ChatCompletionResponse>().await?;
        let mut normalized = normalize_response(&request.model, payload)?;
        // Fall back to the header-derived request id when the body had none.
        if normalized.request_id.is_none() {
            normalized.request_id = request_id;
        }
        Ok(normalized)
    }

    /// Sends a streaming request and returns a stream that yields
    /// Anthropic-style [`StreamEvent`]s.
    ///
    /// # Errors
    /// Propagates transport and HTTP-status failures from the initial request.
    pub async fn stream_message(
        &self,
        request: &MessageRequest,
    ) -> Result<MessageStream, ApiError> {
        let response = self
            .send_with_retry(&request.clone().with_streaming())
            .await?;
        Ok(MessageStream {
            request_id: request_id_from_headers(response.headers()),
            response,
            parser: OpenAiSseParser::new(),
            pending: VecDeque::new(),
            done: false,
            state: StreamState::new(request.model.clone()),
        })
    }

    /// Issues the request, retrying retryable failures with exponential
    /// backoff; gives up after `retry.max_retries` additional attempts.
    async fn send_with_retry(
        &self,
        request: &MessageRequest,
    ) -> Result<reqwest::Response, ApiError> {
        let mut attempts = 0;

        let last_error = loop {
            attempts += 1;
            // Success and non-retryable errors return immediately; a
            // retryable error within budget falls through to the sleep below.
            let retryable_error = match self.send_raw_request(request).await {
                Ok(response) => match expect_success(response).await {
                    Ok(response) => return Ok(response),
                    Err(error)
                        if error.is_retryable() && attempts <= self.retry.max_retries + 1 =>
                    {
                        error
                    }
                    Err(error) => return Err(error),
                },
                Err(error) if error.is_retryable() && attempts <= self.retry.max_retries + 1 => {
                    error
                }
                Err(error) => return Err(error),
            };

            // Budget exhausted: report the last retryable error below.
            if attempts > self.retry.max_retries {
                break retryable_error;
            }

            tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
        };

        Err(ApiError::RetriesExhausted {
            attempts,
            last_error: Box::new(last_error),
        })
    }

    /// Performs a single POST to `{base_url}/chat/completions`.
    async fn send_raw_request(
        &self,
        request: &MessageRequest,
    ) -> Result<reqwest::Response, ApiError> {
        let request_url = chat_completions_endpoint(&self.base_url);
        let mut req = self
            .http
            .post(&request_url)
            .header("content-type", "application/json");
        // Some self-hosted servers accept unauthenticated requests, so only
        // attach the bearer token when a key was provided.
        if !self.api_key.is_empty() {
            req = req.bearer_auth(&self.api_key);
        }
        req.json(&build_chat_completion_request(request))
            .send()
            .await
            .map_err(ApiError::from)
    }

    /// Exponential backoff for a 1-based attempt number:
    /// `initial_backoff * 2^(attempt - 1)`, clamped to `max_backoff`.
    ///
    /// # Errors
    /// Returns `BackoffOverflow` when the shift exceeds the width of `u32`.
    fn backoff_for_attempt(&self, attempt: u32) -> Result<std::time::Duration, ApiError> {
        let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
            return Err(ApiError::BackoffOverflow {
                attempt,
                base_delay: self.retry.initial_backoff,
            });
        };
        // If the multiplication overflows Duration, clamp to max_backoff.
        Ok(self
            .retry
            .initial_backoff
            .checked_mul(multiplier)
            .map_or(self.retry.max_backoff, |delay| {
                delay.min(self.retry.max_backoff)
            }))
    }
}
229
/// An in-flight streaming response: reads the SSE body incrementally and
/// yields normalized Anthropic-style events via [`MessageStream::next_event`].
#[derive(Debug)]
pub struct MessageStream {
    /// Request id captured from the response headers, if present.
    request_id: Option<String>,
    /// The open HTTP response whose body carries the SSE stream.
    response: reqwest::Response,
    /// Re-frames raw body bytes into parsed chunks.
    parser: OpenAiSseParser,
    /// Events synthesized but not yet handed to the caller.
    pending: VecDeque<StreamEvent>,
    /// True once the HTTP body has been fully read.
    done: bool,
    /// Chunk-to-event translation state.
    state: StreamState,
}
239
impl MessageStream {
    /// The request id captured from the response headers, if the server sent one.
    #[must_use]
    pub fn request_id(&self) -> Option<&str> {
        self.request_id.as_deref()
    }

    /// Returns the next normalized stream event, or `None` once the stream is
    /// fully drained (including the synthesized trailing events).
    ///
    /// # Errors
    /// Propagates transport errors from the HTTP body and decode errors from
    /// malformed SSE payloads.
    pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
        loop {
            // Drain events synthesized from previously-read chunks first.
            if let Some(event) = self.pending.pop_front() {
                return Ok(Some(event));
            }

            if self.done {
                // Body exhausted: flush the closing events (block stops,
                // message_delta, message_stop). `finish` is idempotent, so a
                // second pass through here yields `None`.
                self.pending.extend(self.state.finish());
                if let Some(event) = self.pending.pop_front() {
                    return Ok(Some(event));
                }
                return Ok(None);
            }

            match self.response.chunk().await? {
                Some(chunk) => {
                    for parsed in self.parser.push(&chunk)? {
                        self.pending.extend(self.state.ingest_chunk(parsed));
                    }
                }
                None => {
                    self.done = true;
                }
            }
        }
    }
}
273
/// Incremental parser that re-frames raw HTTP body bytes into SSE frames and
/// decodes each frame's `data:` payload as a [`ChatCompletionChunk`].
#[derive(Debug, Default)]
struct OpenAiSseParser {
    /// Bytes received but not yet forming a complete frame.
    buffer: Vec<u8>,
}
278
279impl OpenAiSseParser {
280 fn new() -> Self {
281 Self::default()
282 }
283
284 fn push(&mut self, chunk: &[u8]) -> Result<Vec<ChatCompletionChunk>, ApiError> {
285 self.buffer.extend_from_slice(chunk);
286 if self.buffer.len() > 16 * 1024 * 1024 {
287 return Err(ApiError::ResponsePayloadTooLarge {
288 limit: 16 * 1024 * 1024,
289 });
290 }
291 let mut events = Vec::new();
292
293 while let Some(frame) = next_sse_frame(&mut self.buffer) {
294 if let Some(event) = parse_sse_frame(&frame)? {
295 events.push(event);
296 }
297 }
298
299 Ok(events)
300 }
301}
302
/// Lifecycle of the single streamed text content block (block index 0).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TextPhase {
    /// No text delta seen yet; no start event emitted.
    Pending,
    /// Start event emitted; text deltas are flowing.
    Active,
    /// Stop event emitted.
    Done,
}
309
/// Accumulated state for translating OpenAI streaming chunks into
/// Anthropic-style events.
#[derive(Debug)]
struct StreamState {
    /// Fallback model name when chunks omit one.
    model: String,
    /// Whether `message_start` has been emitted.
    message_started: bool,
    /// State of the text content block (index 0).
    text_phase: TextPhase,
    /// Whether `finish` has already run (it is idempotent).
    finished: bool,
    /// Normalized stop reason from the last `finish_reason` seen.
    stop_reason: Option<String>,
    /// Most recent usage report from any chunk.
    usage: Option<Usage>,
    /// Per-tool-call accumulation keyed by the OpenAI tool-call index.
    tool_calls: BTreeMap<u32, ToolCallState>,
}
320
impl StreamState {
    /// Fresh state; `model` is the fallback name when chunks omit theirs.
    fn new(model: String) -> Self {
        Self {
            model,
            message_started: false,
            text_phase: TextPhase::Pending,
            finished: false,
            stop_reason: None,
            usage: None,
            tool_calls: BTreeMap::new(),
        }
    }

    /// Translates one OpenAI chunk into zero or more Anthropic-style events.
    ///
    /// Block layout: streamed text lives at content-block index 0; each tool
    /// call lives at `openai_index + 1` (see `ToolCallState::block_index`).
    fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Vec<StreamEvent> {
        let mut events = Vec::new();
        // Synthesize `message_start` from the first chunk we see.
        if !self.message_started {
            self.message_started = true;
            events.push(StreamEvent::MessageStart(MessageStartEvent {
                message: MessageResponse {
                    id: chunk.id.clone(),
                    kind: "message".to_string(),
                    role: "assistant".to_string(),
                    content: Vec::new(),
                    model: chunk.model.clone().unwrap_or_else(|| self.model.clone()),
                    stop_reason: None,
                    stop_sequence: None,
                    usage: Usage {
                        input_tokens: 0,
                        cache_creation_input_tokens: 0,
                        cache_read_input_tokens: 0,
                        output_tokens: 0,
                    },
                    request_id: None,
                },
            }));
        }

        // Usage can arrive on any chunk; the latest report wins and is
        // replayed in the final `message_delta` (see `finish`).
        if let Some(usage) = chunk.usage {
            self.usage = Some(Usage {
                input_tokens: usage.prompt_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
                output_tokens: usage.completion_tokens,
            });
        }

        for choice in chunk.choices {
            // Text deltas: lazily open block 0 on the first non-empty delta.
            if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) {
                if self.text_phase == TextPhase::Pending {
                    self.text_phase = TextPhase::Active;
                    events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent {
                        index: 0,
                        content_block: OutputContentBlock::Text {
                            text: String::new(),
                        },
                    }));
                }
                events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
                    index: 0,
                    delta: ContentBlockDelta::TextDelta { text: content },
                }));
            }

            for tool_call in choice.delta.tool_calls {
                let state = self.tool_calls.entry(tool_call.index).or_default();
                state.apply(tool_call);
                let block_index = state.block_index();
                // Defer the start event until the tool's name is known; until
                // then no deltas are emitted for this tool call either.
                if !state.started {
                    if let Some(start_event) = state.start_event() {
                        state.started = true;
                        events.push(StreamEvent::ContentBlockStart(start_event));
                    } else {
                        continue;
                    }
                }
                if let Some(delta_event) = state.delta_event() {
                    events.push(StreamEvent::ContentBlockDelta(delta_event));
                }
                if choice.finish_reason.as_deref() == Some("tool_calls") && !state.stopped {
                    state.stopped = true;
                    events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                        index: block_index,
                    }));
                }
            }

            if let Some(finish_reason) = choice.finish_reason {
                self.stop_reason = Some(normalize_finish_reason(&finish_reason));
                // Close every still-open tool block when the model stops in
                // order to call tools.
                if finish_reason == "tool_calls" {
                    for state in self.tool_calls.values_mut() {
                        if state.started && !state.stopped {
                            state.stopped = true;
                            events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                                index: state.block_index(),
                            }));
                        }
                    }
                }
            }
        }

        events
    }

    /// Emits the closing events once the HTTP body ends: stops the text block,
    /// starts/flushes/stops any tool calls that never emitted, then
    /// `message_delta` (stop reason + usage) and `message_stop`.
    /// Idempotent: subsequent calls return an empty vec.
    fn finish(&mut self) -> Vec<StreamEvent> {
        if self.finished {
            return Vec::new();
        }
        self.finished = true;

        let mut events = Vec::new();
        if self.text_phase == TextPhase::Active {
            self.text_phase = TextPhase::Done;
            events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                index: 0,
            }));
        }

        for state in self.tool_calls.values_mut() {
            // A tool call whose name only became complete at the very end may
            // never have started; emit its start and buffered arguments now.
            if !state.started {
                if let Some(start_event) = state.start_event() {
                    state.started = true;
                    events.push(StreamEvent::ContentBlockStart(start_event));
                    if let Some(delta_event) = state.delta_event() {
                        events.push(StreamEvent::ContentBlockDelta(delta_event));
                    }
                }
            }
            if state.started && !state.stopped {
                state.stopped = true;
                events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                    index: state.block_index(),
                }));
            }
        }

        if self.message_started {
            events.push(StreamEvent::MessageDelta(MessageDeltaEvent {
                delta: MessageDelta {
                    stop_reason: Some(
                        self.stop_reason
                            .clone()
                            .unwrap_or_else(|| "end_turn".to_string()),
                    ),
                    stop_sequence: None,
                },
                usage: self.usage.clone().unwrap_or(Usage {
                    input_tokens: 0,
                    cache_creation_input_tokens: 0,
                    cache_read_input_tokens: 0,
                    output_tokens: 0,
                }),
            }));
            events.push(StreamEvent::MessageStop(MessageStopEvent {}));
        }
        events
    }
}
479
/// Incremental accumulation of one streamed tool call.
#[derive(Debug, Default)]
struct ToolCallState {
    /// The `index` reported by OpenAI for this tool call.
    openai_index: u32,
    /// Tool-call id, once the stream has provided it.
    id: Option<String>,
    /// Function name, once the stream has provided it.
    name: Option<String>,
    /// Concatenated JSON argument fragments.
    arguments: String,
    /// How many bytes of `arguments` have been emitted as deltas.
    emitted_len: usize,
    /// Whether the start event has been emitted.
    started: bool,
    /// Whether the stop event has been emitted.
    stopped: bool,
}
490
491impl ToolCallState {
492 fn apply(&mut self, tool_call: DeltaToolCall) {
493 self.openai_index = tool_call.index;
494 if let Some(id) = tool_call.id {
495 self.id = Some(id);
496 }
497 if let Some(name) = tool_call.function.name {
498 self.name = Some(name);
499 }
500 if let Some(arguments) = tool_call.function.arguments {
501 self.arguments.push_str(&arguments);
502 }
503 }
504
505 const fn block_index(&self) -> u32 {
506 self.openai_index + 1
507 }
508
509 fn start_event(&self) -> Option<ContentBlockStartEvent> {
510 let name = self.name.clone()?;
511 let id = self
512 .id
513 .clone()
514 .unwrap_or_else(|| format!("tool_call_{}", self.openai_index));
515 Some(ContentBlockStartEvent {
516 index: self.block_index(),
517 content_block: OutputContentBlock::ToolUse {
518 id,
519 name,
520 input: json!({}),
521 },
522 })
523 }
524
525 fn delta_event(&mut self) -> Option<ContentBlockDeltaEvent> {
526 if self.emitted_len >= self.arguments.len() {
527 return None;
528 }
529 let delta = self.arguments[self.emitted_len..].to_string();
530 self.emitted_len = self.arguments.len();
531 Some(ContentBlockDeltaEvent {
532 index: self.block_index(),
533 delta: ContentBlockDelta::InputJsonDelta {
534 partial_json: delta,
535 },
536 })
537 }
538}
539
/// Wire format of a non-streaming `/chat/completions` response body.
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    id: String,
    model: String,
    choices: Vec<ChatChoice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

/// One completion choice; only the first is consumed (see `normalize_response`).
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// The assistant message inside a choice.
#[derive(Debug, Deserialize)]
struct ChatMessage {
    role: String,
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Vec<ResponseToolCall>,
}

/// A fully-formed tool call in a non-streaming response.
#[derive(Debug, Deserialize)]
struct ResponseToolCall {
    id: String,
    function: ResponseToolFunction,
}

/// Function name plus the raw JSON-encoded argument string.
#[derive(Debug, Deserialize)]
struct ResponseToolFunction {
    name: String,
    arguments: String,
}

/// Token accounting as reported by OpenAI-compatible servers.
#[derive(Debug, Deserialize)]
struct OpenAiUsage {
    #[serde(default)]
    prompt_tokens: u32,
    #[serde(default)]
    completion_tokens: u32,
}
584
/// Wire format of one SSE `data:` payload in a streaming response.
#[derive(Debug, Deserialize)]
struct ChatCompletionChunk {
    id: String,
    #[serde(default)]
    model: Option<String>,
    #[serde(default)]
    choices: Vec<ChunkChoice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

/// One streamed choice with its incremental delta.
#[derive(Debug, Deserialize)]
struct ChunkChoice {
    delta: ChunkDelta,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental content: appended text and/or tool-call fragments.
#[derive(Debug, Default, Deserialize)]
struct ChunkDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Vec<DeltaToolCall>,
}

/// A fragment of a streamed tool call, correlated by `index` across chunks.
#[derive(Debug, Deserialize)]
struct DeltaToolCall {
    #[serde(default)]
    index: u32,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: DeltaFunction,
}

/// Partial function data inside a streamed tool-call fragment.
#[derive(Debug, Default, Deserialize)]
struct DeltaFunction {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}
628
/// Wire format of an OpenAI-style error response: `{"error": {...}}`.
#[derive(Debug, Deserialize)]
struct ErrorEnvelope {
    error: ErrorBody,
}

/// The error object inside the envelope.
#[derive(Debug, Deserialize)]
struct ErrorBody {
    #[serde(rename = "type")]
    error_type: Option<String>,
    message: Option<String>,
}
640
641fn build_chat_completion_request(request: &MessageRequest) -> Value {
642 let mut messages = Vec::new();
643 if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
644 messages.push(json!({
645 "role": "system",
646 "content": system,
647 }));
648 }
649 for message in &request.messages {
650 messages.extend(translate_message(message));
651 }
652
653 let mut payload = json!({
654 "model": request.model,
655 "max_tokens": request.max_tokens,
656 "messages": messages,
657 "stream": request.stream,
658 });
659
660 if let Some(tools) = &request.tools {
661 payload["tools"] =
662 Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
663 }
664 if let Some(tool_choice) = &request.tool_choice {
665 payload["tool_choice"] = openai_tool_choice(tool_choice);
666 }
667
668 payload
669}
670
671fn translate_message(message: &InputMessage) -> Vec<Value> {
672 match message.role.as_str() {
673 "assistant" => {
674 let mut text = String::new();
675 let mut tool_calls = Vec::new();
676 for block in &message.content {
677 match block {
678 InputContentBlock::Text { text: value } => text.push_str(value),
679 InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({
680 "id": id,
681 "type": "function",
682 "function": {
683 "name": name,
684 "arguments": serde_json::to_string(input).unwrap_or_default(),
685 }
686 })),
687 InputContentBlock::ToolResult { .. } => {}
688 }
689 }
690 if text.is_empty() && tool_calls.is_empty() {
691 Vec::new()
692 } else {
693 vec![json!({
694 "role": "assistant",
695 "content": (!text.is_empty()).then_some(text),
696 "tool_calls": tool_calls,
697 })]
698 }
699 }
700 _ => message
701 .content
702 .iter()
703 .filter_map(|block| match block {
704 InputContentBlock::Text { text } => Some(json!({
705 "role": "user",
706 "content": text,
707 })),
708 InputContentBlock::ToolResult {
709 tool_use_id,
710 content,
711 is_error,
712 } => Some(json!({
713 "role": "tool",
714 "tool_call_id": tool_use_id,
715 "content": flatten_tool_result_content(content),
716 "is_error": is_error,
717 })),
718 InputContentBlock::ToolUse { .. } => None,
719 })
720 .collect(),
721 }
722}
723
724fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String {
725 content
726 .iter()
727 .map(|block| match block {
728 ToolResultContentBlock::Text { text } => text.clone(),
729 ToolResultContentBlock::Json { value } => value.to_string(),
730 })
731 .collect::<Vec<_>>()
732 .join("\n")
733}
734
/// Wraps an Anthropic tool definition in OpenAI's `function` envelope; the
/// Anthropic `input_schema` maps directly onto OpenAI's `parameters`.
fn openai_tool_definition(tool: &ToolDefinition) -> Value {
    json!({
        "type": "function",
        "function": {
            "name": tool.name,
            "description": tool.description,
            "parameters": tool.input_schema,
        }
    })
}
745
746fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
747 match tool_choice {
748 ToolChoice::Auto => Value::String("auto".to_string()),
749 ToolChoice::Any => Value::String("required".to_string()),
750 ToolChoice::Tool { name } => json!({
751 "type": "function",
752 "function": { "name": name },
753 }),
754 }
755}
756
757fn normalize_response(
758 model: &str,
759 response: ChatCompletionResponse,
760) -> Result<MessageResponse, ApiError> {
761 let choice = response
762 .choices
763 .into_iter()
764 .next()
765 .ok_or(ApiError::InvalidSseFrame(
766 "chat completion response missing choices",
767 ))?;
768 let mut content = Vec::new();
769 if let Some(text) = choice.message.content.filter(|value| !value.is_empty()) {
770 content.push(OutputContentBlock::Text { text });
771 }
772 for tool_call in choice.message.tool_calls {
773 content.push(OutputContentBlock::ToolUse {
774 id: tool_call.id,
775 name: tool_call.function.name,
776 input: parse_tool_arguments(&tool_call.function.arguments),
777 });
778 }
779
780 Ok(MessageResponse {
781 id: response.id,
782 kind: "message".to_string(),
783 role: choice.message.role,
784 content,
785 model: response.model.if_empty_then(model.to_string()),
786 stop_reason: choice
787 .finish_reason
788 .map(|value| normalize_finish_reason(&value)),
789 stop_sequence: None,
790 usage: Usage {
791 input_tokens: response
792 .usage
793 .as_ref()
794 .map_or(0, |usage| usage.prompt_tokens),
795 cache_creation_input_tokens: 0,
796 cache_read_input_tokens: 0,
797 output_tokens: response
798 .usage
799 .as_ref()
800 .map_or(0, |usage| usage.completion_tokens),
801 },
802 request_id: None,
803 })
804}
805
806fn parse_tool_arguments(arguments: &str) -> Value {
807 serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments }))
808}
809
/// Removes and returns the next complete SSE frame from `buffer`, or `None`
/// when no frame separator is buffered yet.
///
/// Frames end at a blank line (`\n\n` or `\r\n\r\n`). Whichever separator
/// starts earliest ends the frame, so streams that mix LF and CRLF blank
/// lines are still split correctly. (The previous code always preferred
/// `\n\n`, which could jump past an earlier `\r\n\r\n` boundary and merge
/// two frames into one.)
fn next_sse_frame(buffer: &mut Vec<u8>) -> Option<String> {
    let lf_lf = buffer
        .windows(2)
        .position(|window| window == b"\n\n")
        .map(|position| (position, 2));
    let crlf_crlf = buffer
        .windows(4)
        .position(|window| window == b"\r\n\r\n")
        .map(|position| (position, 4));

    // Take whichever separator starts first (positions can never be equal,
    // as one begins with '\n' and the other with '\r').
    let (position, separator_len) = match (lf_lf, crlf_crlf) {
        (Some(lf), Some(crlf)) => {
            if lf.0 < crlf.0 {
                lf
            } else {
                crlf
            }
        }
        (Some(lf), None) => lf,
        (None, Some(crlf)) => crlf,
        (None, None) => return None,
    };

    // Drain the frame plus its separator, then drop the separator bytes.
    let frame = buffer.drain(..position + separator_len).collect::<Vec<_>>();
    let frame_len = frame.len().saturating_sub(separator_len);
    Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned())
}
827
828fn parse_sse_frame(frame: &str) -> Result<Option<ChatCompletionChunk>, ApiError> {
829 let trimmed = frame.trim();
830 if trimmed.is_empty() {
831 return Ok(None);
832 }
833
834 let mut data_lines = Vec::new();
835 for line in trimmed.lines() {
836 if line.starts_with(':') {
837 continue;
838 }
839 if let Some(data) = line.strip_prefix("data:") {
840 data_lines.push(data.trim_start());
841 }
842 }
843 if data_lines.is_empty() {
844 return Ok(None);
845 }
846 let payload = data_lines.join("\n");
847 if payload == "[DONE]" {
848 return Ok(None);
849 }
850 serde_json::from_str(&payload)
851 .map(Some)
852 .map_err(ApiError::from)
853}
854
855fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
856 match std::env::var(key) {
857 Ok(value) if !value.is_empty() => Ok(Some(value)),
858 Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
859 Err(error) => Err(ApiError::from(error)),
860 }
861}
862
863#[must_use]
864pub fn has_api_key(key: &str) -> bool {
865 read_env_non_empty(key)
866 .ok()
867 .and_then(std::convert::identity)
868 .is_some()
869}
870
871#[must_use]
872pub fn read_base_url(config: OpenAiCompatConfig) -> String {
873 std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string())
874}
875
/// Appends `/chat/completions` to the base URL unless it already ends with
/// it, tolerating trailing slashes on the base URL.
fn chat_completions_endpoint(base_url: &str) -> String {
    const SUFFIX: &str = "/chat/completions";
    let base = base_url.trim_end_matches('/');
    if base.ends_with(SUFFIX) {
        base.to_owned()
    } else {
        [base, SUFFIX].concat()
    }
}
884
885fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
886 headers
887 .get(REQUEST_ID_HEADER)
888 .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
889 .and_then(|value| value.to_str().ok())
890 .map(ToOwned::to_owned)
891}
892
893async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
894 let status = response.status();
895 if status.is_success() {
896 return Ok(response);
897 }
898
899 let body = response.text().await.unwrap_or_default();
900 let parsed_error = serde_json::from_str::<ErrorEnvelope>(&body).ok();
901 let retryable = is_retryable_status(status);
902
903 Err(ApiError::Api {
904 status,
905 error_type: parsed_error
906 .as_ref()
907 .and_then(|error| error.error.error_type.clone()),
908 message: parsed_error
909 .as_ref()
910 .and_then(|error| error.error.message.clone()),
911 body,
912 retryable,
913 })
914}
915
916const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
917 matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
918}
919
/// Maps an OpenAI finish reason onto the Anthropic stop-reason vocabulary;
/// unknown values pass through unchanged.
fn normalize_finish_reason(value: &str) -> String {
    match value {
        "stop" => "end_turn".to_string(),
        "tool_calls" => "tool_use".to_string(),
        other => other.to_string(),
    }
}
928
/// Small helper for substituting a fallback when a `String` is empty.
trait StringExt {
    fn if_empty_then(self, fallback: String) -> String;
}

impl StringExt for String {
    /// Returns `self` unless it is empty, in which case `fallback` is used.
    fn if_empty_then(self, fallback: String) -> String {
        if self.is_empty() {
            return fallback;
        }
        self
    }
}
942
943#[cfg(test)]
944#[path = "openai_compat_tests.rs"]
945mod tests;