use std::sync::Arc;
use futures::StreamExt;
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;
use crate::client::Inner;
use crate::error::{AkribesError, Result};
use crate::models::*;
/// Owns the background tasks that deliver an event stream or drive an event callback.
///
/// Dropping the subscription aborts every task it owns, which stops delivery; keep it
/// alive for as long as events should keep arriving.
#[derive(Debug)]
#[must_use = "dropping an EventSubscription immediately aborts its background tasks"]
pub struct EventSubscription {
    handles: Vec<JoinHandle<()>>,
}
impl EventSubscription {
    /// Stops the subscription now by aborting all of its tasks.
    ///
    /// Equivalent to dropping it; the method exists so cancellation is explicit at call
    /// sites.
    pub fn cancel(self) {
        // `Drop` performs the aborts, so there is nothing to do beyond dropping `self`.
        drop(self);
    }
    /// Wraps a single task handle.
    pub(crate) fn from_handle(handle: JoinHandle<()>) -> Self {
        Self::from_handles(vec![handle])
    }
    /// Wraps a set of task handles that should be cancelled together.
    pub(crate) fn from_handles(handles: Vec<JoinHandle<()>>) -> Self {
        Self { handles }
    }
}
impl Drop for EventSubscription {
    fn drop(&mut self) {
        // Runs unconditionally: some tasks may already have finished, and tokio treats
        // aborting a completed task as harmless.
        for handle in &self.handles {
            handle.abort();
        }
    }
}
#[derive(Clone, Debug)]
pub struct EventsClient {
pub(crate) inner: Arc<Inner>,
pub(crate) project_id: i64,
}
impl EventsClient {
pub(crate) fn new(inner: Arc<Inner>, project_id: i64) -> Self {
Self { inner, project_id }
}
pub async fn event_stream(
&self,
script_name: Option<&str>,
) -> Result<(mpsc::UnboundedReceiver<HubEvent>, EventSubscription)> {
let base_url = self.inner.base_url.clone();
let project_id = self.project_id;
let script_name = script_name.map(|s| s.to_string());
let (tx, rx) = mpsc::unbounded_channel();
let http = self.inner.http.clone();
let token = self.inner.token.clone();
let handle = tokio::spawn(async move {
let _ = stream_sse_with_retry(http, token, base_url, project_id, script_name, tx, None)
.await;
});
Ok((
rx,
EventSubscription {
handles: vec![handle],
},
))
}
pub async fn event_stream_bounded(
&self,
script_name: Option<&str>,
buffer: usize,
) -> Result<(mpsc::Receiver<HubEvent>, EventSubscription)> {
let base_url = self.inner.base_url.clone();
let project_id = self.project_id;
let script_name = script_name.map(|s| s.to_string());
let (tx_bounded, rx_bounded) = mpsc::channel::<HubEvent>(buffer.max(1));
let (tx_inner, mut rx_inner) = mpsc::unbounded_channel::<HubEvent>();
let http = self.inner.http.clone();
let token = self.inner.token.clone();
let sse_handle = tokio::spawn(async move {
let _ = stream_sse_with_retry(
http,
token,
base_url,
project_id,
script_name,
tx_inner,
None,
)
.await;
});
let forward_handle = tokio::spawn(async move {
while let Some(evt) = rx_inner.recv().await {
if tx_bounded.send(evt).await.is_err() {
break;
}
}
});
Ok((
rx_bounded,
EventSubscription {
handles: vec![sse_handle, forward_handle],
},
))
}
pub async fn execution_stream(
&self,
script_name: &str,
) -> Result<(mpsc::UnboundedReceiver<EngineEvent>, EventSubscription)> {
let (mut hub_rx, sub) = self.event_stream(Some(script_name)).await?;
let (tx, rx) = mpsc::unbounded_channel();
let outer_handle = tokio::spawn(async move {
while let Some(evt) = hub_rx.recv().await {
if let HubEvent::Execution { event, .. } = evt {
if tx.send(event).is_err() {
break;
}
}
}
});
let combined = EventSubscription {
handles: vec![tokio::spawn(async move {
let _sub = sub;
outer_handle.await.ok();
})],
};
Ok((rx, combined))
}
pub async fn typed_execution_stream(
&self,
script_name: &str,
) -> Result<(
mpsc::UnboundedReceiver<crate::events::WorkflowEvent>,
EventSubscription,
)> {
let (mut raw_rx, sub) = self.execution_stream(script_name).await?;
let (tx, rx) = mpsc::unbounded_channel();
let outer_handle = tokio::spawn(async move {
while let Some(evt) = raw_rx.recv().await {
let typed: crate::events::WorkflowEvent = evt.into();
if tx.send(typed).is_err() {
break;
}
}
});
let combined = EventSubscription {
handles: vec![tokio::spawn(async move {
let _sub = sub;
outer_handle.await.ok();
})],
};
Ok((rx, combined))
}
pub async fn on_events<F>(
&self,
script_name: Option<&str>,
mut callback: F,
) -> Result<EventSubscription>
where
F: FnMut(HubEvent) + Send + 'static,
{
let (mut rx, sub) = self.event_stream(script_name).await?;
let handle = tokio::spawn(async move {
let _sub = sub;
while let Some(evt) = rx.recv().await {
callback(evt);
}
});
Ok(EventSubscription {
handles: vec![handle],
})
}
pub async fn on_script_execution<F>(
&self,
script_name: &str,
mut callback: F,
) -> Result<EventSubscription>
where
F: FnMut(EngineEvent) + Send + 'static,
{
let (mut rx, sub) = self.execution_stream(script_name).await?;
let handle = tokio::spawn(async move {
let _sub = sub;
while let Some(evt) = rx.recv().await {
callback(evt);
}
});
Ok(EventSubscription {
handles: vec![handle],
})
}
pub async fn on_script_change<F>(
&self,
script_name: &str,
mut callback: F,
) -> Result<EventSubscription>
where
F: FnMut(i64, Option<String>) + Send + 'static,
{
let name = script_name.to_string();
self.on_events(Some(script_name), move |hub_evt| {
if let HubEvent::Registry(RegistryEvent::ScriptUpdated {
script_name: ref evt_name,
version_id,
ref channel,
..
}) = hub_evt
{
if *evt_name == name {
callback(version_id, channel.clone());
}
}
})
.await
}
pub async fn on_script_schema_change<F>(
&self,
script_name: &str,
mut callback: F,
) -> Result<EventSubscription>
where
F: FnMut(i64, Option<String>) + Send + 'static,
{
let name = script_name.to_string();
let inner = Arc::clone(&self.inner);
self.on_events(Some(script_name), move |hub_evt| {
if let HubEvent::Registry(RegistryEvent::ScriptUpdated {
script_name: ref evt_name,
version_id,
ref channel,
..
}) = hub_evt
{
if *evt_name == name {
inner.broken_scripts.lock().unwrap().insert(name.clone());
callback(version_id, channel.clone());
}
}
})
.await
}
}
/// Builds the project event-feed URL, appending a percent-encoded `script_name` filter
/// when one is given.
async fn build_events_url(base_url: &str, project_id: i64, script_name: Option<&str>) -> String {
    let project_url = format!("{}/events?project_id={}", base_url, project_id);
    match script_name {
        None => project_url,
        Some(name) => format!("{}&script_name={}", project_url, urlencoding::encode(name)),
    }
}
/// Streams the project event feed into `tx`, reconnecting with jittered exponential
/// backoff when a connection attempt or an established stream fails.
///
/// On reconnect, the most recent SSE `id:` seen is sent as `Last-Event-ID`, so the server
/// can resume after the last delivered message.
///
/// `ready_tx`, when supplied, is resolved with `Ok(())` by the first connection that gets
/// a success status. It is resolved with an error if the function gives up before that.
///
/// # Errors
/// Returns the last connection error once more than `max_retries` consecutive attempts
/// have failed. It also returns that error as soon as an attempt fails after every
/// receiver for `tx` has been dropped.
///
/// Returns `Ok(())` when a stream ends without error: either the server closed it, or
/// `tx`'s receiver went away mid-stream.
pub(crate) async fn stream_sse_with_retry(
    http: reqwest::Client,
    token: Arc<tokio::sync::RwLock<Option<String>>>,
    base_url: String,
    project_id: i64,
    script_name: Option<String>,
    tx: mpsc::UnboundedSender<HubEvent>,
    mut ready_tx: Option<oneshot::Sender<Result<()>>>,
) -> Result<()> {
    let max_retries = 5u32;
    let mut attempt = 0u32;
    let last_event_id: Arc<std::sync::Mutex<Option<i64>>> = Arc::new(std::sync::Mutex::new(None));
    // The URL depends only on the arguments, so build it once instead of on every attempt.
    let url = build_events_url(&base_url, project_id, script_name.as_deref()).await;
    loop {
        let cursor = *last_event_id.lock().unwrap();
        let err = match stream_sse(
            http.clone(),
            token.clone(),
            &url,
            tx.clone(),
            &mut ready_tx,
            cursor,
            Arc::clone(&last_event_id),
        )
        .await
        {
            Ok(()) => return Ok(()),
            Err(e) => e,
        };
        // A connection that advanced the cursor was healthy before it dropped. Start a
        // fresh retry budget in that case, so `max_retries` bounds consecutive failures
        // rather than the total number of disconnects over the subscription's lifetime.
        if *last_event_id.lock().unwrap() != cursor {
            attempt = 0;
        }
        attempt += 1;
        if attempt > max_retries || tx.is_closed() {
            if let Some(rt) = ready_tx.take() {
                let _ = rt.send(Err(AkribesError::Other(format!(
                    "SSE subscribe failed after {} attempts: {}",
                    attempt, err
                ))));
            }
            return Err(err);
        }
        let delay = retry_backoff(attempt);
        tracing::warn!(attempt, max_retries, ?delay, "SSE disconnected, retrying");
        tokio::time::sleep(delay).await;
    }
}
/// Streams SSE events for one bench run into `tx`, until the server reports a terminal
/// state.
///
/// `ready_tx`, when supplied, is resolved once. It receives `Ok(())` after the server
/// accepts the subscription, or an error if the request fails or returns a non-success
/// status.
///
/// Frames are dispatched by their `event:` type:
/// - `result`: a JSON [`BenchResult`], forwarded as `BenchRunEvent::Result`. Unparseable
///   payloads are logged and skipped.
/// - `lagged`: forwarded as `BenchRunEvent::Lagged` with the payload's `dropped` count
///   (0 when absent or unparseable).
/// - `terminal`: forwarded as `BenchRunEvent::Terminal` with the payload's `status`
///   (`"unknown"` when absent), after which the function returns.
/// - Anything else is logged and ignored.
///
/// Returns `Ok(())` on a terminal frame, when `tx`'s receiver is gone, or when the server
/// closes the stream. Returns an error for request, status, or body-read failures.
pub(crate) async fn stream_bench_run_events(
    http: reqwest::Client,
    token: Arc<tokio::sync::RwLock<Option<String>>>,
    base_url: String,
    run_id: i64,
    tx: mpsc::UnboundedSender<BenchRunEvent>,
    mut ready_tx: Option<oneshot::Sender<Result<()>>>,
) -> Result<()> {
    let url = format!("{}/bench-runs/{}/events", base_url, run_id);
    let mut request = http.get(&url).header("Accept", "text/event-stream");
    if let Some(bearer) = &*token.read().await {
        request = request.bearer_auth(bearer);
    }

    let response = match request.send().await {
        Ok(response) => response,
        Err(source) => {
            let err = AkribesError::Http(source);
            if let Some(ready) = ready_tx.take() {
                let _ = ready.send(Err(AkribesError::Other(format!(
                    "bench SSE subscribe failed: {err}"
                ))));
            }
            return Err(err);
        }
    };

    let http_status = response.status();
    if !http_status.is_success() {
        // Built twice rather than cloned: one copy for the readiness channel, one for the
        // caller.
        let status_error = || AkribesError::HttpStatus {
            status: http_status.as_u16(),
            message: format!("bench SSE subscribe failed: {}", http_status),
        };
        if let Some(ready) = ready_tx.take() {
            let _ = ready.send(Err(status_error()));
        }
        return Err(status_error());
    }
    if let Some(ready) = ready_tx.take() {
        let _ = ready.send(Ok(()));
    }

    let mut body = response.bytes_stream();
    let mut pending: Vec<u8> = Vec::new();
    while let Some(chunk) = body.next().await {
        pending.extend_from_slice(&chunk.map_err(AkribesError::Http)?);
        while let Some((msg_len, delim_len)) = split_sse_message_bytes(&pending) {
            let text = String::from_utf8_lossy(&pending[..msg_len]).into_owned();
            pending.drain(..msg_len + delim_len);
            let Some(frame) = parse_sse_message(&text) else {
                continue;
            };
            match frame.event_type.as_str() {
                "result" => match serde_json::from_str::<BenchResult>(&frame.data) {
                    Ok(result) => {
                        if tx.send(BenchRunEvent::Result(Box::new(result))).is_err() {
                            return Ok(());
                        }
                    }
                    Err(e) => {
                        tracing::warn!(error = %e, "bench SSE result parse error");
                    }
                },
                "lagged" => {
                    let payload = serde_json::from_str::<serde_json::Value>(&frame.data).ok();
                    let dropped = payload
                        .as_ref()
                        .and_then(|v| v.get("dropped"))
                        .and_then(serde_json::Value::as_u64)
                        .unwrap_or(0);
                    if tx.send(BenchRunEvent::Lagged { dropped }).is_err() {
                        return Ok(());
                    }
                }
                "terminal" => {
                    let payload = serde_json::from_str::<serde_json::Value>(&frame.data).ok();
                    let status = payload
                        .as_ref()
                        .and_then(|v| v.get("status"))
                        .and_then(serde_json::Value::as_str)
                        .map_or_else(|| "unknown".to_string(), str::to_string);
                    let _ = tx.send(BenchRunEvent::Terminal { status });
                    return Ok(());
                }
                other => {
                    tracing::warn!(event_type = other, "ignoring unknown bench SSE event type");
                }
            }
        }
    }
    Ok(())
}
/// Returns the delay before retry number `attempt`, using "full jitter" exponential
/// backoff.
///
/// Attempt 0 waits nothing. Attempt `n >= 1` waits a pseudo-random duration in
/// `[0, min(1s * 2^(n-1), 30s))`. The jitter comes from the sub-second nanoseconds of the
/// system clock.
fn retry_backoff(attempt: u32) -> std::time::Duration {
    const BASE_MS: u64 = 1_000;
    const CAP_MS: u64 = 30_000;
    let Some(prior_attempts) = attempt.checked_sub(1) else {
        return std::time::Duration::ZERO;
    };
    // Clamp the shift well below 64 bits; the cap is reached long before that anyway.
    let ceiling_ms = BASE_MS
        .saturating_mul(1u64 << prior_attempts.min(20))
        .min(CAP_MS);
    let entropy = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |since_epoch| u64::from(since_epoch.subsec_nanos()));
    let jitter_ms = match ceiling_ms {
        0 => 0,
        ceiling => entropy % ceiling,
    };
    std::time::Duration::from_millis(jitter_ms)
}
/// Opens one SSE connection to `url` and forwards every decoded [`HubEvent`] into `tx`.
///
/// `cursor`, when present, is sent as `Last-Event-ID`. Each `id:` seen on the wire is
/// written to `last_event_id_out`, so a caller can resume from it after a disconnect. On a
/// success status, `ready_tx` (if still pending) is resolved with `Ok(())`.
///
/// Only frames with an empty or `batch` event type are decoded. Their data must be a JSON
/// array, and each element is deserialized as a `HubEvent`. Unparseable batches and
/// elements are logged and skipped, and other event types are logged and ignored.
///
/// Returns `Ok(())` when the server ends the stream or `tx`'s receiver has gone away.
/// Returns an error on request failure, a non-success status, or a body read error.
async fn stream_sse(
    http: reqwest::Client,
    token: Arc<tokio::sync::RwLock<Option<String>>>,
    url: &str,
    tx: mpsc::UnboundedSender<HubEvent>,
    ready_tx: &mut Option<oneshot::Sender<Result<()>>>,
    cursor: Option<i64>,
    last_event_id_out: Arc<std::sync::Mutex<Option<i64>>>,
) -> Result<()> {
    let mut request = http.get(url).header("Accept", "text/event-stream");
    if let Some(bearer) = &*token.read().await {
        request = request.bearer_auth(bearer);
    }
    if let Some(seq) = cursor {
        request = request.header("Last-Event-ID", seq.to_string());
    }

    let response = request.send().await.map_err(AkribesError::Http)?;
    let http_status = response.status();
    if !http_status.is_success() {
        return Err(AkribesError::HttpStatus {
            status: http_status.as_u16(),
            message: format!("SSE subscribe failed: {}", http_status),
        });
    }
    if let Some(ready) = ready_tx.take() {
        let _ = ready.send(Ok(()));
    }

    let mut body = response.bytes_stream();
    let mut pending: Vec<u8> = Vec::new();
    while let Some(chunk) = body.next().await {
        pending.extend_from_slice(&chunk.map_err(AkribesError::Http)?);
        while let Some((msg_len, delim_len)) = split_sse_message_bytes(&pending) {
            let text = String::from_utf8_lossy(&pending[..msg_len]).into_owned();
            pending.drain(..msg_len + delim_len);
            let Some(SseFrame {
                event_type,
                data,
                event_id,
            }) = parse_sse_message(&text)
            else {
                continue;
            };
            if let Some(seq) = event_id {
                *last_event_id_out.lock().unwrap() = Some(seq);
            }
            if !(event_type.is_empty() || event_type == "batch") {
                tracing::warn!(event_type, "ignoring unknown SSE event type");
                continue;
            }
            let raw_events = match serde_json::from_str::<Vec<serde_json::Value>>(&data) {
                Ok(raw_events) => raw_events,
                Err(e) => {
                    tracing::warn!(error = %e, "SSE JSON parse error");
                    continue;
                }
            };
            for raw in raw_events {
                match serde_json::from_value::<HubEvent>(raw) {
                    Ok(evt) => {
                        if tx.send(evt).is_err() {
                            return Ok(());
                        }
                    }
                    Err(e) => {
                        tracing::warn!(
                            error = %e,
                            "skipping unrecognised hub event in batch"
                        );
                    }
                }
            }
        }
    }
    Ok(())
}
/// One decoded server-sent-event message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct SseFrame {
    /// Value of the last `event` field, or empty when the message had none.
    pub event_type: String,
    /// All `data` field values joined with `\n`, in order.
    pub data: String,
    /// The last `id` field parsed as an `i64`; `None` if absent or not an integer.
    pub event_id: Option<i64>,
}
/// Parses one SSE message (the text between blank-line delimiters) into an [`SseFrame`].
///
/// Field handling follows the event-stream interpretation rules of the HTML standard:
/// - Lines may end in `\n`, `\r\n`, or a bare `\r`.
/// - A field name runs up to the first `:`, and a single space after the colon is dropped.
/// - A line without a colon is a field with an empty value.
/// - Comment lines (starting with `:`) and unknown fields are ignored.
///
/// Returns `None` when the message contains no `data` field, since it then carries no
/// payload to dispatch.
pub(crate) fn parse_sse_message(message: &str) -> Option<SseFrame> {
    let mut data_parts: Vec<&str> = Vec::new();
    let mut event_type = String::new();
    let mut event_id: Option<i64> = None;
    // Splitting on either CR or LF covers all three line endings; a CRLF pair yields an
    // empty segment in between, which is skipped like any blank line. (`str::lines` alone
    // does not split on a bare `\r`, which `split_sse_message_bytes` accepts via `\r\r`.)
    for line in message.split(|c: char| c == '\r' || c == '\n') {
        if line.is_empty() {
            continue;
        }
        let (field, value) = match line.split_once(':') {
            Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
            None => (line, ""),
        };
        match field {
            "data" => data_parts.push(value),
            "event" => event_type = value.to_string(),
            "id" => event_id = value.parse::<i64>().ok(),
            // Empty field name (a `:` comment line) or an unrecognised field.
            _ => {}
        }
    }
    if data_parts.is_empty() {
        return None;
    }
    Some(SseFrame {
        event_type,
        data: data_parts.join("\n"),
        event_id,
    })
}
/// Locates the first SSE message boundary in `buf`.
///
/// A boundary is a blank line written as `\r\n\r\n`, `\n\n`, or `\r\r`. On success,
/// returns `(message_len, delimiter_len)`: the message is `buf[..message_len]`, and the
/// boundary occupies the following `delimiter_len` bytes. Returns `None` when no complete
/// message has been buffered yet.
pub(crate) fn split_sse_message_bytes(buf: &[u8]) -> Option<(usize, usize)> {
    const DELIMITERS: [&[u8]; 3] = [b"\r\n\r\n", b"\n\n", b"\r\r"];
    // `min_by_key` keeps the first of equal minima, matching the listed precedence.
    DELIMITERS
        .iter()
        .filter_map(|delim| find_bytes(buf, delim).map(|pos| (pos, delim.len())))
        .min_by_key(|&(pos, _)| pos)
}
/// Returns the offset of the first occurrence of `needle` in `haystack`.
///
/// An empty needle never matches.
fn find_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        // `windows(0)` would panic; treat an empty needle as "not found".
        return None;
    }
    // `windows` yields nothing when the needle is longer than the haystack.
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
/// Unit tests for SSE message framing (`split_sse_message_bytes`) and field parsing
/// (`parse_sse_message`).
#[cfg(test)]
mod sse_split_tests {
    use super::{parse_sse_message, split_sse_message_bytes};
    #[test]
    fn picks_lf_lf_when_alone() {
        let buf = b"event: ping\ndata: 1\n\nrest";
        let (msg_len, delim_len) = split_sse_message_bytes(buf).expect("delim found");
        assert_eq!(&buf[..msg_len], b"event: ping\ndata: 1");
        assert_eq!(delim_len, 2);
    }
    #[test]
    fn picks_crlf_crlf_when_alone() {
        let buf = b"event: ping\r\ndata: 1\r\n\r\nrest";
        let (msg_len, delim_len) = split_sse_message_bytes(buf).expect("delim found");
        assert_eq!(&buf[..msg_len], b"event: ping\r\ndata: 1");
        assert_eq!(delim_len, 4);
    }
    #[test]
    fn picks_cr_cr_when_alone() {
        let buf = b"data: a\r\rdata: b";
        let (msg_len, delim_len) = split_sse_message_bytes(buf).expect("delim found");
        assert_eq!(&buf[..msg_len], b"data: a");
        assert_eq!(delim_len, 2);
    }
    #[test]
    fn picks_earliest_delimiter_when_mixed() {
        let buf = b"data: a\n\ndata: b\r\n\r\n";
        let (msg_len, delim_len) = split_sse_message_bytes(buf).expect("delim found");
        assert_eq!(&buf[..msg_len], b"data: a");
        assert_eq!(delim_len, 2);
    }
    #[test]
    fn picks_earliest_delimiter_crlf_first() {
        let buf = b"data: a\r\n\r\ndata: b\n\n";
        let (msg_len, delim_len) = split_sse_message_bytes(buf).expect("delim found");
        assert_eq!(&buf[..msg_len], b"data: a");
        assert_eq!(delim_len, 4);
    }
    #[test]
    fn returns_none_without_delimiter() {
        let buf = b"data: incomplete";
        assert!(split_sse_message_bytes(buf).is_none());
    }
    #[test]
    fn parse_collects_fields_and_joins_data_lines() {
        let frame = parse_sse_message("event: batch\nid: 42\ndata: [1,\ndata:2]")
            .expect("frame has data");
        assert_eq!(frame.event_type, "batch");
        assert_eq!(frame.event_id, Some(42));
        assert_eq!(frame.data, "[1,\n2]");
    }
    #[test]
    fn parse_strips_only_one_leading_space() {
        let frame = parse_sse_message("data:  padded").expect("frame has data");
        assert_eq!(frame.data, " padded");
        assert!(frame.event_type.is_empty());
        assert_eq!(frame.event_id, None);
    }
    #[test]
    fn parse_handles_crlf_and_skips_comments() {
        let frame = parse_sse_message(": keep-alive\r\nevent: batch\r\ndata: []")
            .expect("frame has data");
        assert_eq!(frame.event_type, "batch");
        assert_eq!(frame.data, "[]");
    }
    #[test]
    fn parse_returns_none_without_data() {
        assert!(parse_sse_message("event: ping\nid: 1").is_none());
        assert!(parse_sse_message(": comment only").is_none());
    }
}