1#[path = "openai_compat_sse.rs"]
2mod openai_compat_sse;
3#[path = "openai_compat_stream.rs"]
4mod openai_compat_stream;
5
6use std::collections::VecDeque;
7
8use serde::{Deserialize, Deserializer};
9use serde_json::{json, Value};
10
11use crate::error::ApiError;
12use crate::providers::{parse_custom_provider_prefix, RetryPolicy};
13use crate::types::{
14 InputContentBlock, InputMessage, MessageRequest, MessageResponse, OutputContentBlock,
15 ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
16};
17
18use openai_compat_sse::{first_non_empty_field, OpenAiSseParser};
19use openai_compat_stream::StreamState;
20
21pub use openai_compat_stream::MessageStream;
22
/// Default base URL for xAI's OpenAI-compatible API.
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
/// Default base URL for OpenAI's API.
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
/// Primary response header carrying the provider-assigned request id.
const REQUEST_ID_HEADER: &str = "request-id";
/// Fallback request-id header checked when `request-id` is absent.
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
27
/// Static description of one OpenAI-compatible provider: its display name
/// plus the environment variables used to locate credentials and base URL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OpenAiCompatConfig {
    /// Human-readable provider name used in error messages (e.g. "xAI").
    pub provider_name: &'static str,
    /// Environment variable holding the API key.
    pub api_key_env: &'static str,
    /// Environment variable that overrides the base URL.
    pub base_url_env: &'static str,
    /// Base URL used when `base_url_env` is not set.
    pub default_base_url: &'static str,
}
35
// Credential env-var lists handed out by `credential_env_vars`; kept as
// consts so the method can return a `&'static [&'static str]`.
const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"];
const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"];
38
39impl OpenAiCompatConfig {
40 #[must_use]
41 pub const fn xai() -> Self {
42 Self {
43 provider_name: "xAI",
44 api_key_env: "XAI_API_KEY",
45 base_url_env: "XAI_BASE_URL",
46 default_base_url: DEFAULT_XAI_BASE_URL,
47 }
48 }
49
50 #[must_use]
51 pub const fn openai() -> Self {
52 Self {
53 provider_name: "OpenAI",
54 api_key_env: "OPENAI_API_KEY",
55 base_url_env: "OPENAI_BASE_URL",
56 default_base_url: DEFAULT_OPENAI_BASE_URL,
57 }
58 }
59 #[must_use]
60 pub fn credential_env_vars(self) -> &'static [&'static str] {
61 match self.api_key_env {
62 "XAI_API_KEY" => XAI_ENV_VARS,
63 "OPENAI_API_KEY" => OPENAI_ENV_VARS,
64 _ => &[],
65 }
66 }
67}
68
/// HTTP client for OpenAI-compatible chat-completions APIs.
#[derive(Clone)]
pub struct OpenAiCompatClient {
    http: reqwest::Client,
    // Bearer token; the auth header is omitted when this is empty.
    api_key: String,
    // Base URL; `/chat/completions` is appended unless already present.
    base_url: String,
    // Extra query string merged onto the endpoint URL (e.g. Azure `api-version`).
    endpoint_query: Option<String>,
    retry: RetryPolicy,
}
77
78impl std::fmt::Debug for OpenAiCompatClient {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 f.debug_struct("OpenAiCompatClient")
81 .field("base_url", &self.base_url)
82 .field("endpoint_query", &self.endpoint_query)
83 .field("api_key", &"***")
84 .finish()
85 }
86}
87
impl OpenAiCompatClient {
    /// Creates a client for a known provider; the base URL comes from the
    /// provider's env var, falling back to its compiled-in default.
    #[must_use]
    pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
        Self {
            http: crate::default_http_client(),
            api_key: api_key.into(),
            base_url: read_base_url(config),
            endpoint_query: None,
            retry: RetryPolicy::default(),
        }
    }

    /// Creates a client for an arbitrary OpenAI-compatible endpoint.
    #[must_use]
    pub fn new_custom(base_url: impl Into<String>, api_key: impl Into<String>) -> Self {
        Self {
            http: crate::default_http_client(),
            api_key: api_key.into(),
            base_url: base_url.into(),
            endpoint_query: None,
            retry: RetryPolicy::default(),
        }
    }

    /// Sets an extra query string appended to the endpoint URL
    /// (e.g. Azure's `api-version=...`). Whitespace-only input becomes `None`.
    #[must_use]
    pub fn with_endpoint_query(mut self, endpoint_query: Option<String>) -> Self {
        self.endpoint_query = endpoint_query
            .map(|s| s.trim().to_string())
            .filter(|s| !s.is_empty());
        self
    }

    /// Builds a client from the environment.
    ///
    /// # Errors
    /// Returns `ApiError::missing_credentials` when the provider's API-key
    /// env var is unset or empty.
    pub fn from_env(config: OpenAiCompatConfig) -> Result<Self, ApiError> {
        let Some(api_key) = read_env_non_empty(config.api_key_env)? else {
            return Err(ApiError::missing_credentials(
                config.provider_name,
                config.credential_env_vars(),
            ));
        };
        Ok(Self::new(api_key, config))
    }

    /// Overrides the base URL.
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }

    /// Overrides the retry policy.
    #[must_use]
    pub fn with_retry_policy(mut self, retry: RetryPolicy) -> Self {
        self.retry = retry;
        self
    }

    /// Sends a non-streaming request and normalizes the chat-completion
    /// payload into a `MessageResponse`.
    ///
    /// # Errors
    /// Propagates transport, HTTP-status, and JSON-decode failures as `ApiError`.
    pub async fn send_message(
        &self,
        request: &MessageRequest,
    ) -> Result<MessageResponse, ApiError> {
        // Force `stream: false` regardless of what the caller set.
        let request = MessageRequest {
            stream: false,
            ..request.clone()
        };
        let response = self.send_with_retry(&request).await?;
        let request_id = request_id_from_headers(response.headers());
        let payload = response.json::<ChatCompletionResponse>().await?;
        let mut normalized = normalize_response(&request.model, payload)?;
        // Prefer an id from the body; fall back to the response headers.
        if normalized.request_id.is_none() {
            normalized.request_id = request_id;
        }
        Ok(normalized)
    }

    /// Starts a streaming request and returns a `MessageStream` that parses
    /// the SSE body incrementally.
    ///
    /// # Errors
    /// Propagates transport and HTTP-status failures as `ApiError`.
    pub async fn stream_message(
        &self,
        request: &MessageRequest,
    ) -> Result<MessageStream, ApiError> {
        let response = self
            .send_with_retry(&request.clone().with_streaming())
            .await?;
        Ok(MessageStream {
            request_id: request_id_from_headers(response.headers()),
            response,
            parser: OpenAiSseParser::new(),
            pending: VecDeque::new(),
            done: false,
            state: StreamState::new(request.model.clone()),
        })
    }

    /// Sends the request, retrying retryable failures with exponential
    /// backoff; at most `retry.max_retries` retries after the first attempt.
    async fn send_with_retry(
        &self,
        request: &MessageRequest,
    ) -> Result<reqwest::Response, ApiError> {
        let mut attempts = 0;

        let last_error = loop {
            attempts += 1;
            // Only errors that are retryable AND within the attempt budget
            // fall through to the backoff; everything else returns now.
            let retryable_error = match self.send_raw_request(request).await {
                Ok(response) => match expect_success(response).await {
                    Ok(response) => return Ok(response),
                    Err(error)
                        if error.is_retryable() && attempts <= self.retry.max_retries + 1 =>
                    {
                        error
                    }
                    Err(error) => return Err(error),
                },
                Err(error) if error.is_retryable() && attempts <= self.retry.max_retries + 1 => {
                    error
                }
                Err(error) => return Err(error),
            };

            // Budget exhausted: surface the last retryable error below.
            if attempts > self.retry.max_retries {
                break retryable_error;
            }

            tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
        };

        Err(ApiError::RetriesExhausted {
            attempts,
            last_error: Box::new(last_error),
        })
    }

    /// Performs a single POST to the chat-completions endpoint.
    /// The bearer header is omitted when the API key is empty.
    async fn send_raw_request(
        &self,
        request: &MessageRequest,
    ) -> Result<reqwest::Response, ApiError> {
        let request_url = chat_completions_endpoint(&self.base_url, self.endpoint_query.as_deref());
        let mut req = self
            .http
            .post(&request_url)
            .header("content-type", "application/json");
        if !self.api_key.is_empty() {
            req = req.bearer_auth(&self.api_key);
        }
        req.json(&build_chat_completion_request(request))
            .send()
            .await
            .map_err(ApiError::from)
    }

    /// Exponential backoff: `initial_backoff * 2^(attempt - 1)`, capped at
    /// `retry.max_backoff` (also used when the multiplication overflows).
    ///
    /// # Errors
    /// Returns `ApiError::BackoffOverflow` when the shift amount exceeds the
    /// `u32` width (attempt >= 33).
    fn backoff_for_attempt(&self, attempt: u32) -> Result<std::time::Duration, ApiError> {
        let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
            return Err(ApiError::BackoffOverflow {
                attempt,
                base_delay: self.retry.initial_backoff,
            });
        };
        Ok(self
            .retry
            .initial_backoff
            .checked_mul(multiplier)
            .map_or(self.retry.max_backoff, |delay| {
                delay.min(self.retry.max_backoff)
            }))
    }
}
247
/// Raw (non-streaming) chat-completion response body.
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    id: String,
    model: String,
    choices: Vec<ChatChoice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

/// One completion choice; only the first is used by `normalize_response`.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessage,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Assistant message. Providers disagree on which field carries
/// reasoning text, so several alternatives are accepted.
#[derive(Debug, Deserialize)]
struct ChatMessage {
    role: String,
    // `content` may arrive as a plain string or an array of content parts.
    #[serde(default, deserialize_with = "deserialize_openai_text_content")]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    reasoning: Option<String>,
    #[serde(default)]
    thought: Option<String>,
    #[serde(default)]
    thinking: Option<String>,
    #[serde(default)]
    tool_calls: Vec<ResponseToolCall>,
}
284
285impl ChatMessage {
286 fn assistant_visible_text(&self) -> Option<String> {
287 first_non_empty_field(&[
288 &self.content,
289 &self.reasoning_content,
290 &self.reasoning,
291 &self.thought,
292 &self.thinking,
293 ])
294 }
295}
296
/// Tool call emitted by the model in a non-streaming response.
#[derive(Debug, Deserialize)]
struct ResponseToolCall {
    id: String,
    function: ResponseToolFunction,
}

/// Function name plus its JSON-encoded argument string.
#[derive(Debug, Deserialize)]
struct ResponseToolFunction {
    name: String,
    arguments: String,
}

/// OpenAI-style token usage; fields default to 0 when absent.
#[derive(Debug, Deserialize)]
pub(super) struct OpenAiUsage {
    #[serde(default)]
    pub prompt_tokens: u32,
    #[serde(default)]
    pub completion_tokens: u32,
}
316
/// OpenAI-style error payload: `{"error": {"type": ..., "message": ...}}`.
#[derive(Debug, Deserialize)]
struct ErrorEnvelope {
    error: ErrorBody,
}

#[derive(Debug, Deserialize)]
struct ErrorBody {
    // Serialized as "type"; renamed because `type` is a Rust keyword.
    #[serde(rename = "type")]
    error_type: Option<String>,
    message: Option<String>,
}
328
329fn upstream_openai_model(model: &str) -> String {
334 parse_custom_provider_prefix(model)
335 .map(|(_, rest)| rest.to_string())
336 .unwrap_or_else(|| model.to_string())
337}
338
339fn build_chat_completion_request(request: &MessageRequest) -> Value {
340 let mut messages = Vec::new();
341 if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
342 messages.push(json!({
343 "role": "system",
344 "content": system,
345 }));
346 }
347 for message in &request.messages {
348 messages.extend(translate_message(message));
349 }
350
351 let upstream_model = upstream_openai_model(&request.model);
352 const MAX_TOKENS_OPENAI_COMPAT_CAP: u32 = 32_768;
353 let max_tokens = request.max_tokens.clamp(1, MAX_TOKENS_OPENAI_COMPAT_CAP);
354 let mut payload = json!({
355 "model": upstream_model,
356 "max_tokens": max_tokens,
357 "messages": messages,
358 "stream": request.stream,
359 });
360
361 if let Some(tools) = &request.tools {
362 payload["tools"] =
363 Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
364 }
365 if let Some(tool_choice) = &request.tool_choice {
366 payload["tool_choice"] = openai_tool_choice(tool_choice);
367 }
368
369 payload
370}
371
372fn translate_message(message: &InputMessage) -> Vec<Value> {
373 match message.role.as_str() {
374 "assistant" => {
375 let mut text = String::new();
376 let mut tool_calls = Vec::new();
377 for block in &message.content {
378 match block {
379 InputContentBlock::Text { text: value } => text.push_str(value),
380 InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({
381 "id": id,
382 "type": "function",
383 "function": {
384 "name": name,
385 "arguments": serde_json::to_string(input).unwrap_or_default(),
386 }
387 })),
388 InputContentBlock::ToolResult { .. } => {}
389 }
390 }
391 if text.is_empty() && tool_calls.is_empty() {
392 Vec::new()
393 } else {
394 let mut msg = json!({
395 "role": "assistant",
396 "content": (!text.is_empty()).then_some(text),
397 });
398 if !tool_calls.is_empty() {
401 msg["tool_calls"] = json!(tool_calls);
402 }
403 vec![msg]
404 }
405 }
406 _ => message
407 .content
408 .iter()
409 .filter_map(|block| match block {
410 InputContentBlock::Text { text } => Some(json!({
411 "role": "user",
412 "content": text,
413 })),
414 InputContentBlock::ToolResult {
415 tool_use_id,
416 content,
417 is_error,
418 } => Some(json!({
419 "role": "tool",
420 "tool_call_id": tool_use_id,
421 "content": flatten_tool_result_content(content),
422 "is_error": is_error,
423 })),
424 InputContentBlock::ToolUse { .. } => None,
425 })
426 .collect(),
427 }
428}
429
430fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String {
431 content
432 .iter()
433 .map(|block| match block {
434 ToolResultContentBlock::Text { text } => text.clone(),
435 ToolResultContentBlock::Json { value } => value.to_string(),
436 })
437 .collect::<Vec<_>>()
438 .join("\n")
439}
440
441fn openai_tool_definition(tool: &ToolDefinition) -> Value {
442 json!({
443 "type": "function",
444 "function": {
445 "name": tool.name,
446 "description": tool.description,
447 "parameters": tool.input_schema,
448 }
449 })
450}
451
452fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
453 match tool_choice {
454 ToolChoice::Auto => Value::String("auto".to_string()),
455 ToolChoice::Any => Value::String("required".to_string()),
456 ToolChoice::Tool { name } => json!({
457 "type": "function",
458 "function": { "name": name },
459 }),
460 }
461}
462
463fn normalize_response(
464 model: &str,
465 response: ChatCompletionResponse,
466) -> Result<MessageResponse, ApiError> {
467 let choice = response
468 .choices
469 .into_iter()
470 .next()
471 .ok_or(ApiError::InvalidSseFrame(
472 "chat completion response missing choices",
473 ))?;
474 let mut content = Vec::new();
475 if let Some(text) = choice.message.assistant_visible_text() {
476 content.push(OutputContentBlock::Text { text });
477 }
478 for tool_call in choice.message.tool_calls {
479 content.push(OutputContentBlock::ToolUse {
480 id: tool_call.id,
481 name: tool_call.function.name,
482 input: parse_tool_arguments(&tool_call.function.arguments),
483 });
484 }
485
486 Ok(MessageResponse {
487 id: response.id,
488 kind: "message".to_string(),
489 role: choice.message.role,
490 content,
491 model: response.model.if_empty_then(model.to_string()),
492 stop_reason: choice
493 .finish_reason
494 .map(|value| normalize_finish_reason(&value)),
495 stop_sequence: None,
496 usage: Usage {
497 input_tokens: response
498 .usage
499 .as_ref()
500 .map_or(0, |usage| usage.prompt_tokens),
501 cache_creation_input_tokens: 0,
502 cache_read_input_tokens: 0,
503 output_tokens: response
504 .usage
505 .as_ref()
506 .map_or(0, |usage| usage.completion_tokens),
507 },
508 request_id: None,
509 })
510}
511
512fn parse_tool_arguments(arguments: &str) -> Value {
513 serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments }))
514}
515
/// Deserializes OpenAI `content`, which may be `null`, a plain string, or
/// an array of content parts; `null` and empty results map to `None`.
fn deserialize_openai_text_content<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
    D: Deserializer<'de>,
{
    // Untagged: serde tries each variant in order against the input shape.
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum Raw {
        Str(String),
        Arr(Vec<Value>),
    }
    match Option::<Raw>::deserialize(deserializer)? {
        None => Ok(None),
        // Empty strings are normalized to `None`.
        Some(Raw::Str(s)) if s.is_empty() => Ok(None),
        Some(Raw::Str(s)) => Ok(Some(s)),
        Some(Raw::Arr(parts)) => {
            // Concatenate each part's text; objects may carry it under
            // either a "text" or a "content" key.
            let mut joined = String::new();
            for part in parts {
                match part {
                    Value::Object(map) => {
                        if let Some(text) = map.get("text").and_then(Value::as_str) {
                            joined.push_str(text);
                        } else if let Some(text) = map.get("content").and_then(Value::as_str) {
                            joined.push_str(text);
                        }
                    }
                    Value::String(s) => joined.push_str(&s),
                    // Non-text parts (e.g. images) are ignored.
                    _ => {}
                }
            }
            Ok((!joined.is_empty()).then_some(joined))
        }
    }
}
554
555fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
560 match std::env::var(key) {
561 Ok(value) if !value.is_empty() => Ok(Some(value)),
562 Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
563 Err(error) => Err(ApiError::from(error)),
564 }
565}
566
567#[must_use]
568pub fn has_api_key(key: &str) -> bool {
569 read_env_non_empty(key)
570 .ok()
571 .and_then(std::convert::identity)
572 .is_some()
573}
574
575#[must_use]
576pub fn read_base_url(config: OpenAiCompatConfig) -> String {
577 std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string())
578}
579
/// Builds the final chat-completions URL from a base URL — which may
/// already carry the `/chat/completions` suffix and/or a query string —
/// plus an optional extra query string.
fn chat_completions_endpoint(base_url: &str, extra_query: Option<&str>) -> String {
    let trimmed = base_url.trim();
    let (raw_path, base_query) = match trimmed.split_once('?') {
        Some((path, query)) => (path, Some(query)),
        None => (trimmed, None),
    };
    let raw_path = raw_path.trim_end_matches('/');
    let mut path = raw_path.to_string();
    if !raw_path.ends_with("/chat/completions") {
        path.push_str("/chat/completions");
    }
    merge_url_query(&path, base_query, extra_query)
}

/// Appends the non-blank query segments (base first, then extra) to
/// `path`, joined with `&`; returns `path` unchanged when both are blank.
fn merge_url_query(path: &str, base_query: Option<&str>, extra_query: Option<&str>) -> String {
    let segments: Vec<&str> = [base_query, extra_query]
        .into_iter()
        .flatten()
        .map(str::trim)
        .filter(|segment| !segment.is_empty())
        .collect();
    if segments.is_empty() {
        path.to_string()
    } else {
        format!("{path}?{}", segments.join("&"))
    }
}
608
609fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
610 headers
611 .get(REQUEST_ID_HEADER)
612 .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
613 .and_then(|value| value.to_str().ok())
614 .map(ToOwned::to_owned)
615}
616
617async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
618 let status = response.status();
619 if status.is_success() {
620 return Ok(response);
621 }
622
623 let body = response.text().await.unwrap_or_default();
624 let parsed_error = serde_json::from_str::<ErrorEnvelope>(&body).ok();
625 let retryable = is_retryable_status(status);
626
627 Err(ApiError::Api {
628 status,
629 error_type: parsed_error
630 .as_ref()
631 .and_then(|error| error.error.error_type.clone()),
632 message: parsed_error
633 .as_ref()
634 .and_then(|error| error.error.message.clone()),
635 body,
636 retryable,
637 })
638}
639
/// Whether an HTTP status is treated as transient and worth retrying:
/// 408, 409, 429, and 500/502/503/504 (note 501 is not retried).
const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
    matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
}
643
/// Translates OpenAI finish reasons into the internal vocabulary;
/// unrecognized values pass through unchanged.
fn normalize_finish_reason(value: &str) -> String {
    let normalized = match value {
        "stop" => "end_turn",
        "tool_calls" => "tool_use",
        other => other,
    };
    normalized.to_string()
}
652
/// Small helper for substituting a fallback when a `String` is empty.
trait StringExt {
    fn if_empty_then(self, fallback: String) -> String;
}

impl StringExt for String {
    /// Returns `self` unless it is empty, in which case `fallback` is
    /// returned instead.
    fn if_empty_then(self, fallback: String) -> String {
        match self.is_empty() {
            true => fallback,
            false => self,
        }
    }
}
666
#[cfg(test)]
mod openai_compat_inner_tests {
    use super::*;
    use crate::types::OutputContentBlock;

    // Azure-style deployment URL: the /chat/completions suffix and the
    // extra api-version query are both appended.
    #[test]
    fn chat_completions_url_appends_api_version() {
        assert_eq!(
            chat_completions_endpoint(
                "https://my.openai.azure.com/openai/deployments/gpt4",
                Some("api-version=2024-02-15-preview"),
            ),
            "https://my.openai.azure.com/openai/deployments/gpt4/chat/completions?api-version=2024-02-15-preview"
        );
    }

    // A base URL that already carries a query keeps it, with the extra
    // query joined on via `&`.
    #[test]
    fn chat_completions_url_merges_base_query_and_api_version() {
        assert_eq!(
            chat_completions_endpoint(
                "https://x/v1/chat/completions?existing=1",
                Some("api-version=2024-02-15-preview"),
            ),
            "https://x/v1/chat/completions?existing=1&api-version=2024-02-15-preview"
        );
    }

    // `content` supplied as an array of parts is flattened into a single
    // text block by the custom deserializer.
    #[test]
    fn non_streaming_message_parses_content_array() {
        let json = r#"{
            "id":"1",
            "model":"qwen",
            "choices":[{
                "message":{"role":"assistant","content":[{"type":"text","text":"hello"}]},
                "finish_reason":"stop"
            }],
            "usage":{"prompt_tokens":1,"completion_tokens":1}
        }"#;
        let resp: ChatCompletionResponse = serde_json::from_str(json).unwrap();
        let msg = normalize_response("qwen", resp).expect("normalize");
        assert_eq!(
            msg.content,
            vec![OutputContentBlock::Text {
                text: "hello".to_string()
            }]
        );
    }

    // When `content` is null, the reasoning_content field supplies the
    // visible text.
    #[test]
    fn non_streaming_reasoning_only_message() {
        let json = r#"{
            "id":"1",
            "model":"qwen",
            "choices":[{
                "message":{"role":"assistant","content":null,"reasoning_content":"think"},
                "finish_reason":"stop"
            }],
            "usage":{"prompt_tokens":1,"completion_tokens":1}
        }"#;
        let resp: ChatCompletionResponse = serde_json::from_str(json).unwrap();
        let msg = normalize_response("qwen", resp).expect("normalize");
        assert_eq!(
            msg.content,
            vec![OutputContentBlock::Text {
                text: "think".to_string()
            }]
        );
    }
}
735}
736
737#[cfg(test)]
738#[path = "openai_compat_tests.rs"]
739mod tests;