use reqwest::{
Response,
header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue},
};
use serde::de::DeserializeOwned;
use crate::error::{ProviderError, TransportError};
/// Upper bound, in bytes, on how much of a response body is buffered in memory (8 MiB).
pub(crate) const MAX_BODY_BYTES: usize = 8 << 20;
/// Finishes `builder` into a client that sends `Authorization: Bearer <api_key>`
/// and `Accept: application/json` on every request.
///
/// The authorization value is marked sensitive, so the API key is not exposed
/// through `Debug` formatting of the header map. The flag also tells HTTP/2
/// encoders not to compress or index the value.
///
/// # Errors
///
/// - [`TransportError::InvalidConfig`] if `api_key` contains characters that are
///   not allowed in an HTTP header value.
/// - [`TransportError::BuildClient`] if reqwest fails to build the client.
pub fn build_client(
    builder: reqwest::ClientBuilder,
    api_key: &str,
) -> Result<reqwest::Client, TransportError> {
    let mut auth_value = HeaderValue::from_str(&format!("Bearer {api_key}"))
        .map_err(|_| TransportError::InvalidConfig("api key contains invalid header characters"))?;
    // The header holds a credential: keep it out of Debug output and HPACK tables.
    auth_value.set_sensitive(true);

    let mut default_headers = HeaderMap::new();
    default_headers.insert(AUTHORIZATION, auth_value);
    default_headers.insert(ACCEPT, HeaderValue::from_static("application/json"));

    builder
        .default_headers(default_headers)
        .build()
        .map_err(TransportError::BuildClient)
}
/// Appends `chunk` to `buf`, but only if the resulting length stays within `limit`.
///
/// The check happens before any bytes are copied. On overflow `buf` is left
/// untouched and [`TransportError::BodyTooLarge`] is returned.
fn append_capped(buf: &mut Vec<u8>, chunk: &[u8], limit: usize) -> Result<(), TransportError> {
    // Saturating so a pathological length can never wrap around and slip under `limit`.
    let combined = buf.len().saturating_add(chunk.len());
    if combined <= limit {
        buf.extend_from_slice(chunk);
        Ok(())
    } else {
        Err(TransportError::BodyTooLarge { limit })
    }
}
/// Reads the entire response body into a UTF-8 `String`, refusing bodies larger
/// than [`MAX_BODY_BYTES`].
///
/// If reqwest already knows the exact body length and it exceeds the cap, this
/// fails immediately instead of downloading up to `MAX_BODY_BYTES` first. A known,
/// in-bounds length is used to preallocate the buffer. Bodies of unknown length
/// are still capped chunk by chunk while streaming.
///
/// # Errors
///
/// - [`TransportError::BodyTooLarge`] if the body exceeds `MAX_BODY_BYTES`.
/// - [`TransportError::Transport`] if reading a chunk fails.
/// - [`TransportError::Utf8`] if the collected bytes are not valid UTF-8.
async fn read_body_text(response: Response) -> Result<String, TransportError> {
    use futures_util::StreamExt;

    let limit = MAX_BODY_BYTES;
    let mut buf = match response.content_length() {
        // usize -> u64 is a widening conversion on every platform Rust supports.
        Some(len) if len > limit as u64 => return Err(TransportError::BodyTooLarge { limit }),
        // The guard above ensures `len <= limit`, so this cast cannot truncate.
        Some(len) => Vec::with_capacity(len as usize),
        None => Vec::new(),
    };

    let mut stream = response.bytes_stream();
    while let Some(chunk) = stream.next().await {
        let chunk = chunk.map_err(TransportError::Transport)?;
        append_capped(&mut buf, &chunk, limit)?;
    }
    String::from_utf8(buf).map_err(TransportError::Utf8)
}
pub async fn ensure_success(response: Response) -> Result<Response, TransportError> {
let status = response.status();
if status.is_success() {
return Ok(response);
}
let body = read_body_text(response).await?;
Err(TransportError::HttpStatus { status, body })
}
/// Joins `path` onto `base_url`, treating the base as a directory.
///
/// The base path gets a trailing `/` if it lacks one, and all leading `/` are
/// stripped from `path`. This keeps any prefix in the base (such as `/v1` or
/// `/openai/v1`); otherwise URL join rules would replace or drop it.
///
/// # Errors
///
/// Returns [`TransportError::InvalidConfig`] if `base_url` cannot be parsed or
/// the joined URL is invalid.
pub fn endpoint_url(base_url: &str, path: &str) -> Result<String, TransportError> {
    let mut base = match reqwest::Url::parse(base_url) {
        Ok(url) => url,
        Err(_) => return Err(TransportError::InvalidConfig("invalid base URL")),
    };

    // Without a trailing slash, `Url::join` would replace the last path segment.
    if !base.path().ends_with('/') {
        let directory = format!("{}/", base.path());
        base.set_path(&directory);
    }

    let relative = path.trim_start_matches('/');
    match base.join(relative) {
        Ok(url) => Ok(url.to_string()),
        Err(_) => Err(TransportError::InvalidConfig("invalid endpoint path")),
    }
}
/// Checks the response status, reads the body, and deserializes it as JSON into `T`.
///
/// # Errors
///
/// - Failures from [`ensure_success`] and from reading the body are converted
///   into [`ProviderError`] via `?`.
/// - Empty or malformed bodies are reported as described on `parse_body`.
pub async fn parse_json<T: DeserializeOwned>(response: Response) -> Result<T, ProviderError> {
    let body = read_body_text(ensure_success(response).await?).await?;
    parse_body(body)
}
fn parse_body<T: DeserializeOwned>(body: String) -> Result<T, ProviderError> {
if body.trim().is_empty() {
return Err(TransportError::InvalidResponse("empty response body".to_owned()).into());
}
serde_json::from_str(&body).map_err(|source| ProviderError::Deserialize { source, body })
}
#[cfg(test)]
mod tests {
    use super::*;

    // URL joining must keep any base path prefix, however the slashes are arranged.
    #[test]
    fn endpoint_url_resolves_paths_correctly() {
        let cases = [
            (
                "https://api.example.com/v1",
                "chat/completions",
                "https://api.example.com/v1/chat/completions",
            ),
            (
                "https://api.example.com/v1/",
                "chat/completions",
                "https://api.example.com/v1/chat/completions",
            ),
            (
                "https://api.example.com/v1",
                "/chat/completions",
                "https://api.example.com/v1/chat/completions",
            ),
            (
                "https://api.example.com/v1",
                "//chat/completions",
                "https://api.example.com/v1/chat/completions",
            ),
            (
                "https://api.deepseek.com",
                "chat/completions",
                "https://api.deepseek.com/chat/completions",
            ),
            (
                "https://api.deepseek.com/",
                "/chat/completions",
                "https://api.deepseek.com/chat/completions",
            ),
            (
                "https://proxy.example.com/openai/v1",
                "chat/completions",
                "https://proxy.example.com/openai/v1/chat/completions",
            ),
        ];
        for (base, path, expected) in cases {
            assert_eq!(
                endpoint_url(base, path).unwrap(),
                expected,
                "endpoint_url({base:?}, {path:?})"
            );
        }
    }

    #[test]
    fn endpoint_url_rejects_invalid_base() {
        assert!(endpoint_url("not a url", "chat/completions").is_err());
    }

    // Blank bodies are a protocol violation, not a JSON syntax error.
    #[test]
    fn parse_body_rejects_empty_body() {
        let err = parse_body::<serde_json::Value>(String::new()).unwrap_err();
        assert!(
            matches!(
                err,
                ProviderError::Transport(TransportError::InvalidResponse(_))
            ),
            "empty body must surface as InvalidResponse, got {err:?}"
        );
        let err = parse_body::<serde_json::Value>("  \n\t ".to_owned()).unwrap_err();
        assert!(
            matches!(
                err,
                ProviderError::Transport(TransportError::InvalidResponse(_))
            ),
            "whitespace-only body must surface as InvalidResponse, got {err:?}"
        );
    }

    #[test]
    fn parse_body_maps_malformed_to_deserialize() {
        let err = parse_body::<serde_json::Value>("not valid json".to_owned()).unwrap_err();
        assert!(
            matches!(err, ProviderError::Deserialize { .. }),
            "malformed body must surface as Deserialize, got {err:?}"
        );
    }

    // The raw body is kept on the error so callers can log what the server sent.
    #[test]
    fn parse_body_deserialize_error_keeps_raw_body() {
        let err = parse_body::<serde_json::Value>("not valid json".to_owned()).unwrap_err();
        match err {
            ProviderError::Deserialize { body, .. } => assert_eq!(body, "not valid json"),
            other => panic!("expected Deserialize, got {other:?}"),
        }
    }

    #[test]
    fn parse_body_deserializes_valid_json() {
        let value = parse_body::<serde_json::Value>(r#"{"x":7}"#.to_owned()).unwrap();
        assert_eq!(value["x"], 7);
    }

    // Whitespace around an otherwise valid document must not trip the blank-body check.
    #[test]
    fn parse_body_tolerates_surrounding_whitespace() {
        let value = parse_body::<serde_json::Value>(" \n{\"x\":7}\n ".to_owned()).unwrap();
        assert_eq!(value["x"], 7);
    }

    #[test]
    fn append_capped_accepts_up_to_limit() {
        let mut buf = Vec::new();
        let fill = vec![0u8; 16];
        append_capped(&mut buf, &fill, 16).unwrap();
        assert_eq!(buf.len(), 16);
    }

    // A zero-length chunk never pushes the buffer past the limit, even when full.
    #[test]
    fn append_capped_accepts_empty_chunk_at_limit() {
        let mut buf = vec![0u8; 16];
        append_capped(&mut buf, &[], 16).unwrap();
        assert_eq!(buf.len(), 16);
    }

    #[test]
    fn append_capped_rejects_overflow_without_appending() {
        let mut buf = Vec::new();
        let limit = 16;
        let fill = vec![0u8; limit];
        append_capped(&mut buf, &fill, limit).unwrap();
        let err = append_capped(&mut buf, &[1], limit).unwrap_err();
        assert!(
            matches!(err, TransportError::BodyTooLarge { limit: 16 }),
            "overflow must surface as BodyTooLarge, got {err:?}"
        );
        assert_eq!(buf.len(), limit);
        assert!(!buf.contains(&1));
    }

    #[test]
    fn append_capped_rejects_single_oversized_chunk() {
        let mut buf = Vec::new();
        let oversized = vec![0u8; 17];
        let err = append_capped(&mut buf, &oversized, 16).unwrap_err();
        assert!(
            matches!(err, TransportError::BodyTooLarge { limit: 16 }),
            "first-chunk overflow must surface as BodyTooLarge, got {err:?}"
        );
        assert!(
            buf.is_empty(),
            "buffer must stay empty when the first chunk overflows"
        );
    }
}