use std::pin::Pin;
use futures_core::Stream;
use serde::Deserialize;
use crate::providers::AuthConfig;
use crate::wire_formats;
use crate::{Error, ProviderConfig, StreamEvent, WireFormat};
fn apply_auth(builder: reqwest::RequestBuilder, auth: &AuthConfig) -> reqwest::RequestBuilder {
match auth {
AuthConfig::None => builder,
AuthConfig::Bearer(token) => builder.bearer_auth(token),
AuthConfig::ApiKey { header, value } => builder.header(header.as_str(), value.as_str()),
}
}
fn apply_auth_blocking(
builder: reqwest::blocking::RequestBuilder,
auth: &AuthConfig,
) -> reqwest::blocking::RequestBuilder {
match auth {
AuthConfig::None => builder,
AuthConfig::Bearer(token) => builder.bearer_auth(token),
AuthConfig::ApiKey { header, value } => builder.header(header.as_str(), value.as_str()),
}
}
fn parse_sse_buffer(buffer: &str, format: WireFormat) -> (Vec<StreamEvent>, String) {
let mut events = Vec::new();
let mut remaining = String::new();
let parts: Vec<&str> = buffer.split("\n\n").collect();
let last_idx = parts.len().saturating_sub(1);
for (i, chunk) in parts.iter().enumerate() {
if chunk.is_empty() {
continue;
}
if i == last_idx && !buffer.ends_with("\n\n") {
remaining = chunk.to_string();
break;
}
for line in chunk.lines() {
let line = line.trim();
if let Some(data) = line.strip_prefix("data: ") {
let data = data.trim();
if data == "[DONE]" {
events.push(StreamEvent::Done);
} else if let Some(event) = wire_formats::parse_stream_chunk(data, format) {
events.push(event);
}
}
}
}
(events, remaining)
}
const MAX_RETRIES: u32 = 10;
const RETRY_BASE_MS: u64 = 500;
const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
pub async fn chat_async(
config: &ProviderConfig,
messages: &[crate::ChatMessage],
) -> Result<String, Error> {
let url = wire_formats::endpoint(&config.base_url, config.wire_format);
let body = wire_formats::build_request(messages, config);
let client = reqwest::Client::builder()
.timeout(REQUEST_TIMEOUT)
.build()
.map_err(|e| Error::Unavailable(e.to_string()))?;
let mut last_err = Error::Unavailable("no attempts made".to_string());
for attempt in 0..MAX_RETRIES {
if attempt > 0 {
let delay = std::time::Duration::from_millis(RETRY_BASE_MS * 2u64.pow(attempt - 1));
tokio::time::sleep(delay).await;
}
let resp = apply_auth(client.post(&url), &config.auth)
.json(&body)
.send()
.await
.map_err(|e| Error::Unavailable(e.to_string()))?;
if resp.status() == reqwest::StatusCode::SERVICE_UNAVAILABLE {
last_err = Error::RequestFailed("server busy (503), retrying...".to_string());
continue;
}
let bytes = resp
.error_for_status()
.map_err(|e| Error::RequestFailed(e.to_string()))?
.bytes()
.await
.map_err(|e| Error::RequestFailed(e.to_string()))?;
return wire_formats::parse_response(&bytes, config.wire_format);
}
Err(last_err)
}
pub fn chat(config: &ProviderConfig, messages: &[crate::ChatMessage]) -> Result<String, Error> {
let url = wire_formats::endpoint(&config.base_url, config.wire_format);
let body = wire_formats::build_request(messages, config);
let client = reqwest::blocking::Client::builder()
.timeout(REQUEST_TIMEOUT)
.build()
.map_err(|e| Error::Unavailable(e.to_string()))?;
let mut last_err = Error::Unavailable("no attempts made".to_string());
for attempt in 0..MAX_RETRIES {
if attempt > 0 {
let delay = RETRY_BASE_MS * 2u64.pow(attempt - 1);
std::thread::sleep(std::time::Duration::from_millis(delay));
}
let resp = apply_auth_blocking(client.post(&url), &config.auth)
.json(&body)
.send()
.map_err(|e| Error::Unavailable(e.to_string()))?;
if resp.status() == reqwest::StatusCode::SERVICE_UNAVAILABLE {
last_err = Error::RequestFailed("server busy (503), retrying...".to_string());
continue;
}
let bytes = resp
.error_for_status()
.map_err(|e| Error::RequestFailed(e.to_string()))?
.bytes()
.map_err(|e| Error::RequestFailed(e.to_string()))?;
return wire_formats::parse_response(&bytes, config.wire_format);
}
Err(last_err)
}
pub async fn chat_stream(
config: &ProviderConfig,
messages: &[crate::ChatMessage],
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, Error>> + Send>>, Error> {
let url = wire_formats::endpoint(&config.base_url, config.wire_format);
let mut body = wire_formats::build_request(messages, config);
body.as_object_mut()
.ok_or_else(|| Error::RequestFailed("request body is not a JSON object".to_string()))?
.insert("stream".into(), true.into());
let client = reqwest::Client::builder()
.build()
.map_err(|e| Error::Unavailable(e.to_string()))?;
let resp = apply_auth(client.post(&url), &config.auth)
.json(&body)
.send()
.await
.map_err(|e| Error::Unavailable(e.to_string()))?;
if !resp.status().is_success() {
return Err(Error::RequestFailed(format!("HTTP {}", resp.status())));
}
let format = config.wire_format;
let byte_stream = resp.bytes_stream();
use futures_util::StreamExt;
let event_stream = futures_util::stream::unfold(
(byte_stream, String::new()),
move |(mut byte_stream, mut buffer)| async move {
type Items = Vec<Result<StreamEvent, Error>>;
loop {
match byte_stream.next().await {
Some(Ok(bytes)) => {
buffer.push_str(&String::from_utf8_lossy(&bytes));
let (events, remaining) = parse_sse_buffer(&buffer, format);
buffer = remaining;
if !events.is_empty() {
let is_done = events.iter().any(|e| matches!(e, StreamEvent::Done));
let items: Items = events.into_iter().map(Ok).collect();
let stream = futures_util::stream::iter(items);
if is_done {
return Some((stream, (byte_stream, String::new())));
}
return Some((stream, (byte_stream, buffer)));
}
}
Some(Err(e)) => {
let items: Items = vec![Err(Error::RequestFailed(e.to_string()))];
let stream = futures_util::stream::iter(items);
return Some((stream, (byte_stream, String::new())));
}
None => {
if !buffer.is_empty() {
let (events, _) = parse_sse_buffer(&buffer, format);
if !events.is_empty() {
let items: Items = events.into_iter().map(Ok).collect();
let stream = futures_util::stream::iter(items);
return Some((stream, (byte_stream, String::new())));
}
}
return None;
}
}
}
},
)
.flatten();
Ok(Box::pin(event_stream))
}
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
pub struct HealthResponse {
pub status: String,
}
pub async fn health_check(base_url: &str) -> Result<HealthResponse, Error> {
let url = format!("{}/health", base_url);
let resp = reqwest::Client::new()
.get(&url)
.send()
.await
.map_err(|e| Error::Unavailable(e.to_string()))?
.error_for_status()
.map_err(|e| Error::RequestFailed(e.to_string()))?
.json()
.await
.map_err(|e| Error::RequestFailed(e.to_string()))?;
Ok(resp)
}
pub fn health_check_blocking(base_url: &str) -> Result<HealthResponse, Error> {
let url = format!("{}/health", base_url);
let resp = reqwest::blocking::get(&url)
.map_err(|e| Error::Unavailable(e.to_string()))?
.error_for_status()
.map_err(|e| Error::RequestFailed(e.to_string()))?
.json()
.map_err(|e| Error::RequestFailed(e.to_string()))?;
Ok(resp)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_sse_buffer_single_event() {
let buffer = "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n";
let (events, remaining) = parse_sse_buffer(buffer, WireFormat::Completions);
assert_eq!(events.len(), 1);
assert!(matches!(&events[0], StreamEvent::Delta(s) if s == "hi"));
assert!(remaining.is_empty());
}
#[test]
fn parse_sse_buffer_done_signal() {
let buffer = "data: [DONE]\n\n";
let (events, remaining) = parse_sse_buffer(buffer, WireFormat::Completions);
assert_eq!(events.len(), 1);
assert!(matches!(&events[0], StreamEvent::Done));
assert!(remaining.is_empty());
}
#[test]
fn parse_sse_buffer_multiple_events() {
let buffer = "data: {\"choices\":[{\"delta\":{\"content\":\"a\"}}]}\n\ndata: {\"choices\":[{\"delta\":{\"content\":\"b\"}}]}\n\n";
let (events, _) = parse_sse_buffer(buffer, WireFormat::Completions);
assert_eq!(events.len(), 2);
}
#[test]
fn parse_sse_buffer_incomplete_chunk() {
let buffer = "data: {\"choices\":[{\"delta\":{\"content\":\"a\"}}]}\n\ndata: partial";
let (events, remaining) = parse_sse_buffer(buffer, WireFormat::Completions);
assert_eq!(events.len(), 1);
assert_eq!(remaining, "data: partial");
}
#[test]
fn parse_sse_buffer_with_event_prefix() {
let buffer = "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"hi\"}}\n\n";
let (events, _) = parse_sse_buffer(buffer, WireFormat::Messages);
assert_eq!(events.len(), 1);
assert!(matches!(&events[0], StreamEvent::Delta(s) if s == "hi"));
}
}