use bytes::Bytes;
use futures_util::{Stream, StreamExt};
use reqwest::header::{HeaderName, HeaderValue};
use serde::Serialize;
use crate::config::{SDK_LANGUAGE, SDK_VERSION};
use crate::error::{ApiError, Error};
use crate::types::common::RequestOptions;
/// A single Server-Sent Event decoded from a `text/event-stream` response body.
///
/// Each field mirrors the SSE line of the same name; events are produced by the
/// SSE parser in this module.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SseEvent {
    /// Value of the `event:` field, if the event named one.
    pub event: Option<String>,
    /// Event payload: the values of the event's `data:` lines, joined with `'\n'`.
    pub data: String,
    /// Value of the `id:` field, if present.
    pub id: Option<String>,
    /// Integer value of the `retry:` field, if it parsed as a `u64`.
    /// NOTE(review): presumably milliseconds, as in the SSE specification's
    /// reconnection time — confirm against the server.
    pub retry: Option<u64>,
}
pub async fn stream<Body>(
client: &crate::ReasoningLayerClient,
method: reqwest::Method,
path: &str,
body: Option<&Body>,
options: Option<&RequestOptions>,
) -> Result<impl Stream<Item = Result<SseEvent, Error>>, Error>
where
Body: Serialize + ?Sized,
{
let http = client.http();
let config = &http.config;
let resolved_path = crate::http::resolve_path_public(path);
let url = format!("{}{}", config.base_url, resolved_path);
let mut req = http.inner.request(method, &url);
req = req
.timeout(options.and_then(|o| o.timeout).unwrap_or(config.timeout))
.header(reqwest::header::ACCEPT, "text/event-stream")
.header(reqwest::header::CONTENT_TYPE, "application/json")
.header(HeaderName::from_static("x-sdk-version"), SDK_VERSION)
.header(HeaderName::from_static("x-sdk-language"), SDK_LANGUAGE)
.header(
HeaderName::from_static("x-tenant-id"),
config.tenant_id.as_str(),
);
if let Some(user_id) = options
.and_then(|o| o.user_id.as_deref())
.or(config.user_id.as_deref())
{
req = req.header(HeaderName::from_static("x-user-id"), user_id);
}
if let Some(ns) = options
.and_then(|o| o.namespace_id.as_deref())
.or(config.namespace_id.as_deref())
{
req = req.header(HeaderName::from_static("x-namespace-id"), ns);
}
if let Some(user) = config.authenticated_user.as_deref() {
req = req.header(HeaderName::from_static("x-authenticated-user"), user);
}
if let crate::config::AuthConfig::Bearer(token) = &config.auth {
let value = HeaderValue::from_str(&format!("Bearer {token}"))
.map_err(|_| Error::validation("auth", "invalid characters in bearer token"))?;
req = req.header(reqwest::header::AUTHORIZATION, value);
}
if let Some(body) = body {
req = req.json(body);
}
let response = req
.send()
.await
.map_err(|e: reqwest::Error| Error::Network {
message: e.to_string(),
source: Some(Box::new(e)),
})?;
if !response.status().is_success() {
let status = response.status();
let headers = response.headers().clone();
let body_bytes: bytes::Bytes = response.bytes().await.unwrap_or_default();
let body_value: Option<serde_json::Value> = if body_bytes.is_empty() {
None
} else {
serde_json::from_slice(&body_bytes).ok()
};
return Err(ApiError::from_response(status, body_value, headers).into());
}
let byte_stream = response.bytes_stream();
Ok(parse_sse(byte_stream))
}
fn parse_sse<S>(mut byte_stream: S) -> impl Stream<Item = Result<SseEvent, Error>>
where
S: Stream<Item = Result<Bytes, reqwest::Error>> + Unpin,
{
async_stream::stream! {
let mut buffer = String::new();
while let Some(chunk) = byte_stream.next().await {
let chunk = match chunk {
Ok(b) => b,
Err(e) => {
yield Err(Error::Network {
message: e.to_string(),
source: Some(Box::new(e)),
});
return;
}
};
let s = match std::str::from_utf8(&chunk) {
Ok(s) => s,
Err(e) => {
yield Err(Error::Network {
message: format!("invalid utf-8 in sse stream: {e}"),
source: None,
});
return;
}
};
buffer.push_str(s);
while let Some((idx, sep_len)) = find_event_boundary(&buffer) {
let block = buffer[..idx].to_string();
buffer = buffer[idx + sep_len..].to_string();
let mut event = SseEvent::default();
for line in block.lines() {
if line.is_empty() || line.starts_with(':') {
continue;
}
let (field, value) = match line.split_once(':') {
Some((f, v)) => (f, v.strip_prefix(' ').unwrap_or(v)),
None => (line, ""),
};
match field {
"event" => event.event = Some(value.to_string()),
"id" => event.id = Some(value.to_string()),
"retry" => event.retry = value.parse().ok(),
"data" => {
if !event.data.is_empty() {
event.data.push('\n');
}
event.data.push_str(value);
}
_ => {}
}
}
if !event.data.is_empty() || event.event.is_some() {
yield Ok(event);
}
}
}
}
}
/// Locates the first blank-line event separator in `s`.
///
/// A separator is a line terminator immediately followed by another one, where
/// each terminator is either CRLF or LF. That covers `"\r\n\r\n"`, `"\r\n\n"`,
/// `"\n\r\n"` and `"\n\n"`. The earliest occurrence wins, so an LF-separated
/// event is not merged with a later CRLF-separated one.
///
/// Returns `(byte_index_of_separator, separator_len)`, or `None` when `s` does
/// not yet contain a complete separator.
fn find_event_boundary(s: &str) -> Option<(usize, usize)> {
    const SEPARATORS: [&str; 4] = ["\r\n\r\n", "\r\n\n", "\n\r\n", "\n\n"];
    SEPARATORS
        .iter()
        .copied()
        .filter_map(|sep| s.find(sep).map(|idx| (idx, sep.len())))
        // Overlapping candidates (e.g. "\n\n" inside "\r\n\n") start later than
        // the longer separator that contains them, so the minimum index is the
        // full separator.
        .min_by_key(|&(idx, _)| idx)
}