use crate::{FabricClient, FabricError, Result};
use bytes::Bytes;
use futures_util::stream::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
/// A single event dispatched from a Server-Sent Events stream.
///
/// `event` and `id` are `None` when the event did not carry the corresponding
/// field, and are omitted from the serialized form in that case.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct SseEvent {
    /// Value of the `event` field (the event type), if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub event: Option<String>,
    /// Value of the `id` field, if present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Payload: the values of the event's `data` lines joined with `\n`.
    pub data: String,
}
pub fn parse_sse_stream<S>(stream: S) -> impl Stream<Item = Result<SseEvent>>
where
S: Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
{
let mut buffer: Vec<u8> = Vec::with_capacity(4096);
let mut current = SseEvent::default();
let mut data_buf = String::new();
let mut has_content = false;
async_stream::stream! {
let mut stream = Box::pin(stream);
while let Some(chunk) = stream.next().await {
let chunk = match chunk {
Ok(b) => b,
Err(e) => {
yield Err(FabricError::Http(e));
return;
}
};
buffer.extend_from_slice(&chunk);
while let Some(nl) = buffer.iter().position(|&b| b == b'\n') {
let mut line = buffer.drain(..=nl).collect::<Vec<u8>>();
line.pop(); if line.last() == Some(&b'\r') {
line.pop();
}
let line = match std::str::from_utf8(&line) {
Ok(s) => s,
Err(e) => {
yield Err(FabricError::Other(format!("invalid UTF-8 in SSE stream: {e}")));
return;
}
};
if line.is_empty() {
if has_content {
current.data = std::mem::take(&mut data_buf);
yield Ok(std::mem::take(&mut current));
has_content = false;
}
continue;
}
if line.starts_with(':') {
continue;
}
let (field, value) = match line.find(':') {
Some(i) => {
let v = &line[i + 1..];
let v = v.strip_prefix(' ').unwrap_or(v);
(&line[..i], v)
}
None => (line, ""),
};
match field {
"event" => {
current.event = Some(value.to_string());
has_content = true;
}
"id" => {
current.id = Some(value.to_string());
has_content = true;
}
"data" => {
if !data_buf.is_empty() {
data_buf.push('\n');
}
data_buf.push_str(value);
has_content = true;
}
_ => {}
}
}
}
if has_content {
current.data = data_buf;
yield Ok(current);
}
}
}
/// Returns `path` unchanged, or with `include_internal=true` appended when
/// `include_internal` is set. The parameter is joined with `&` if `path`
/// already contains a `?`, and with `?` otherwise.
fn with_internal_query(path: &str, include_internal: bool) -> String {
    if include_internal {
        // Extend an existing query string instead of starting a second one.
        let separator = if path.contains('?') { '&' } else { '?' };
        format!("{path}{separator}include_internal=true")
    } else {
        path.to_owned()
    }
}
impl FabricClient {
    /// Opens an SSE stream for workflow run `run_id`
    /// (`GET /v1/workflow-runs/{run_id}/events`).
    pub async fn stream_workflow_run(
        &self,
        run_id: &str,
    ) -> Result<impl Stream<Item = Result<SseEvent>>> {
        self.stream_workflow_run_with_internal(run_id, false).await
    }

    /// Like `stream_workflow_run`, but adds `include_internal=true` to the
    /// query when `include_internal` is set.
    pub async fn stream_workflow_run_with_internal(
        &self,
        run_id: &str,
        include_internal: bool,
    ) -> Result<impl Stream<Item = Result<SseEvent>>> {
        let events_path = format!("/v1/workflow-runs/{run_id}/events");
        self.sse_get(&with_internal_query(&events_path, include_internal))
            .await
    }

    /// Opens an SSE stream for job `job_id` (`GET /v1/jobs/{job_id}/events`).
    pub async fn stream_job(&self, job_id: &str) -> Result<impl Stream<Item = Result<SseEvent>>> {
        self.stream_job_with_internal(job_id, false).await
    }

    /// Like `stream_job`, but adds `include_internal=true` to the query when
    /// `include_internal` is set.
    pub async fn stream_job_with_internal(
        &self,
        job_id: &str,
        include_internal: bool,
    ) -> Result<impl Stream<Item = Result<SseEvent>>> {
        let events_path = format!("/v1/jobs/{job_id}/events");
        self.sse_get(&with_internal_query(&events_path, include_internal))
            .await
    }

    /// Opens the SSE stream at `GET /v1/events/stream`.
    pub async fn stream_events(&self) -> Result<impl Stream<Item = Result<SseEvent>>> {
        self.stream_events_with_internal(false).await
    }

    /// Like `stream_events`, but adds `include_internal=true` to the query when
    /// `include_internal` is set.
    pub async fn stream_events_with_internal(
        &self,
        include_internal: bool,
    ) -> Result<impl Stream<Item = Result<SseEvent>>> {
        self.sse_get(&with_internal_query("/v1/events/stream", include_internal))
            .await
    }

    /// POSTs `body` as JSON to `/v1/providers/execute/stream` and streams the
    /// SSE response.
    pub async fn stream_provider_execute(
        &self,
        body: serde_json::Value,
    ) -> Result<impl Stream<Item = Result<SseEvent>>> {
        self.sse_post("/v1/providers/execute/stream", body).await
    }

    /// Sends `GET {base_url}{path}` with `Accept: text/event-stream` and
    /// parses the response body as SSE.
    async fn sse_get(&self, path: &str) -> Result<impl Stream<Item = Result<SseEvent>>> {
        let endpoint = format!("{}{path}", self.base_url);
        let response = self
            .client
            .get(&endpoint)
            .header(reqwest::header::ACCEPT, "text/event-stream")
            .send()
            .await?;
        Self::sse_stream_from_response(response)
    }

    /// Sends `POST {base_url}{path}` with `Accept: text/event-stream` and a
    /// JSON body, then parses the response body as SSE.
    async fn sse_post(
        &self,
        path: &str,
        body: serde_json::Value,
    ) -> Result<impl Stream<Item = Result<SseEvent>>> {
        let endpoint = format!("{}{path}", self.base_url);
        let response = self
            .client
            .post(&endpoint)
            .header(reqwest::header::ACCEPT, "text/event-stream")
            .json(&body)
            .send()
            .await?;
        Self::sse_stream_from_response(response)
    }

    /// Rejects non-success responses, then wraps the body in the SSE parser.
    fn sse_stream_from_response(
        response: reqwest::Response,
    ) -> Result<impl Stream<Item = Result<SseEvent>>> {
        check_sse_status(&response)?;
        Ok(parse_sse_stream(response.bytes_stream()))
    }
}
/// Accepts a successful (2xx) SSE response. Any other status becomes
/// `FabricError::Api`, whose `code` is the numeric status as a string.
///
/// The response body is not read.
fn check_sse_status(resp: &reqwest::Response) -> Result<()> {
    let status = resp.status();
    if status.is_success() {
        Ok(())
    } else {
        Err(FabricError::Api {
            code: status.as_u16().to_string(),
            message: format!("SSE connection failed with HTTP {status}"),
        })
    }
}