1use std::sync::Arc;
2
3use futures::StreamExt;
4use tokio::sync::{mpsc, oneshot};
5use tokio::task::JoinHandle;
6
7use crate::client::Inner;
8use crate::error::{AkribesError, Result};
9use crate::models::*;
10
11pub struct EventSubscription {
14 handles: Vec<JoinHandle<()>>,
15}
16
17impl EventSubscription {
18 pub fn cancel(self) {
20 for h in &self.handles {
21 h.abort();
22 }
23 }
24
25 pub(crate) fn from_handle(handle: JoinHandle<()>) -> Self {
26 Self {
27 handles: vec![handle],
28 }
29 }
30
31 pub(crate) fn from_handles(handles: Vec<JoinHandle<()>>) -> Self {
32 Self { handles }
33 }
34}
35
36impl Drop for EventSubscription {
37 fn drop(&mut self) {
38 for h in &self.handles {
39 h.abort();
40 }
41 }
42}
43
44#[derive(Clone, Debug)]
46pub struct EventsClient {
47 pub(crate) inner: Arc<Inner>,
48 pub(crate) project_id: i64,
49}
50
51impl EventsClient {
52 pub(crate) fn new(inner: Arc<Inner>, project_id: i64) -> Self {
53 Self { inner, project_id }
54 }
55
56 pub async fn event_stream(
67 &self,
68 script_name: Option<&str>,
69 ) -> Result<(mpsc::UnboundedReceiver<HubEvent>, EventSubscription)> {
70 let base_url = self.inner.base_url.clone();
71 let project_id = self.project_id;
72 let script_name = script_name.map(|s| s.to_string());
73 let (tx, rx) = mpsc::unbounded_channel();
74 let http = self.inner.http.clone();
75 let token = self.inner.token.clone();
76
77 let handle = tokio::spawn(async move {
78 let _ = stream_sse_with_retry(http, token, base_url, project_id, script_name, tx, None)
79 .await;
80 });
81
82 Ok((
83 rx,
84 EventSubscription {
85 handles: vec![handle],
86 },
87 ))
88 }
89
90 pub async fn event_stream_bounded(
105 &self,
106 script_name: Option<&str>,
107 buffer: usize,
108 ) -> Result<(mpsc::Receiver<HubEvent>, EventSubscription)> {
109 let base_url = self.inner.base_url.clone();
110 let project_id = self.project_id;
111 let script_name = script_name.map(|s| s.to_string());
112 let (tx_bounded, rx_bounded) = mpsc::channel::<HubEvent>(buffer.max(1));
113 let (tx_inner, mut rx_inner) = mpsc::unbounded_channel::<HubEvent>();
118 let http = self.inner.http.clone();
119 let token = self.inner.token.clone();
120
121 let sse_handle = tokio::spawn(async move {
122 let _ = stream_sse_with_retry(
123 http,
124 token,
125 base_url,
126 project_id,
127 script_name,
128 tx_inner,
129 None,
130 )
131 .await;
132 });
133 let forward_handle = tokio::spawn(async move {
134 while let Some(evt) = rx_inner.recv().await {
135 if tx_bounded.send(evt).await.is_err() {
136 break;
137 }
138 }
139 });
140
141 Ok((
142 rx_bounded,
143 EventSubscription {
144 handles: vec![sse_handle, forward_handle],
145 },
146 ))
147 }
148
149 pub async fn execution_stream(
151 &self,
152 script_name: &str,
153 ) -> Result<(mpsc::UnboundedReceiver<EngineEvent>, EventSubscription)> {
154 let (mut hub_rx, sub) = self.event_stream(Some(script_name)).await?;
155 let (tx, rx) = mpsc::unbounded_channel();
156
157 let outer_handle = tokio::spawn(async move {
158 while let Some(evt) = hub_rx.recv().await {
159 if let HubEvent::Execution { event, .. } = evt {
160 if tx.send(event).is_err() {
161 break;
162 }
163 }
164 }
165 });
166
167 let combined = EventSubscription {
168 handles: vec![tokio::spawn(async move {
169 let _sub = sub;
170 outer_handle.await.ok();
171 })],
172 };
173
174 Ok((rx, combined))
175 }
176
177 pub async fn typed_execution_stream(
188 &self,
189 script_name: &str,
190 ) -> Result<(
191 mpsc::UnboundedReceiver<crate::events::WorkflowEvent>,
192 EventSubscription,
193 )> {
194 let (mut raw_rx, sub) = self.execution_stream(script_name).await?;
195 let (tx, rx) = mpsc::unbounded_channel();
196 let outer_handle = tokio::spawn(async move {
197 while let Some(evt) = raw_rx.recv().await {
198 let typed: crate::events::WorkflowEvent = evt.into();
199 if tx.send(typed).is_err() {
200 break;
201 }
202 }
203 });
204 let combined = EventSubscription {
205 handles: vec![tokio::spawn(async move {
206 let _sub = sub;
207 outer_handle.await.ok();
208 })],
209 };
210 Ok((rx, combined))
211 }
212
213 pub async fn on_events<F>(
215 &self,
216 script_name: Option<&str>,
217 mut callback: F,
218 ) -> Result<EventSubscription>
219 where
220 F: FnMut(HubEvent) + Send + 'static,
221 {
222 let (mut rx, sub) = self.event_stream(script_name).await?;
223 let handle = tokio::spawn(async move {
224 let _sub = sub;
225 while let Some(evt) = rx.recv().await {
226 callback(evt);
227 }
228 });
229 Ok(EventSubscription {
230 handles: vec![handle],
231 })
232 }
233
234 pub async fn on_script_execution<F>(
236 &self,
237 script_name: &str,
238 mut callback: F,
239 ) -> Result<EventSubscription>
240 where
241 F: FnMut(EngineEvent) + Send + 'static,
242 {
243 let (mut rx, sub) = self.execution_stream(script_name).await?;
244 let handle = tokio::spawn(async move {
245 let _sub = sub;
246 while let Some(evt) = rx.recv().await {
247 callback(evt);
248 }
249 });
250 Ok(EventSubscription {
251 handles: vec![handle],
252 })
253 }
254
255 pub async fn on_script_change<F>(
257 &self,
258 script_name: &str,
259 mut callback: F,
260 ) -> Result<EventSubscription>
261 where
262 F: FnMut(i64, Option<String>) + Send + 'static,
263 {
264 let name = script_name.to_string();
265 self.on_events(Some(script_name), move |hub_evt| {
266 if let HubEvent::Registry(RegistryEvent::ScriptUpdated {
267 script_name: ref evt_name,
268 version_id,
269 ref channel,
270 ..
271 }) = hub_evt
272 {
273 if *evt_name == name {
274 callback(version_id, channel.clone());
275 }
276 }
277 })
278 .await
279 }
280
281 pub async fn on_script_schema_change<F>(
285 &self,
286 script_name: &str,
287 mut callback: F,
288 ) -> Result<EventSubscription>
289 where
290 F: FnMut(i64, Option<String>) + Send + 'static,
291 {
292 let name = script_name.to_string();
293 let inner = Arc::clone(&self.inner);
294 self.on_events(Some(script_name), move |hub_evt| {
295 if let HubEvent::Registry(RegistryEvent::ScriptUpdated {
296 script_name: ref evt_name,
297 version_id,
298 ref channel,
299 ..
300 }) = hub_evt
301 {
302 if *evt_name == name {
303 inner.broken_scripts.lock().unwrap().insert(name.clone());
304 callback(version_id, channel.clone());
305 }
306 }
307 })
308 .await
309 }
310}
311
312async fn build_events_url(base_url: &str, project_id: i64, script_name: Option<&str>) -> String {
319 let mut url = format!("{}/events?project_id={}", base_url, project_id);
320 if let Some(name) = script_name {
321 url.push_str(&format!("&script_name={}", urlencoding::encode(name)));
322 }
323 url
324}
325
326pub(crate) async fn stream_sse_with_retry(
342 http: reqwest::Client,
343 token: Arc<tokio::sync::RwLock<Option<String>>>,
344 base_url: String,
345 project_id: i64,
346 script_name: Option<String>,
347 tx: mpsc::UnboundedSender<HubEvent>,
348 mut ready_tx: Option<oneshot::Sender<Result<()>>>,
349) -> Result<()> {
350 let max_retries = 5u32;
351 let mut attempt = 0;
352 let last_event_id: Arc<std::sync::Mutex<Option<i64>>> = Arc::new(std::sync::Mutex::new(None));
356 loop {
357 let url = build_events_url(&base_url, project_id, script_name.as_deref()).await;
358 let cursor = *last_event_id.lock().unwrap();
359 match stream_sse(
360 http.clone(),
361 token.clone(),
362 &url,
363 tx.clone(),
364 &mut ready_tx,
365 cursor,
366 Arc::clone(&last_event_id),
367 )
368 .await
369 {
370 Ok(()) => return Ok(()),
371 Err(e) => {
372 attempt += 1;
373 if attempt > max_retries || tx.is_closed() {
374 if let Some(rt) = ready_tx.take() {
375 let _ = rt.send(Err(AkribesError::Other(format!(
376 "SSE subscribe failed after {} attempts: {}",
377 attempt, e
378 ))));
379 }
380 return Err(e);
381 }
382 let delay = retry_backoff(attempt);
383 tracing::warn!(attempt, max_retries, ?delay, "SSE disconnected, retrying");
384 tokio::time::sleep(delay).await;
385 }
386 }
387 }
388}
389
390pub(crate) async fn stream_bench_run_events(
404 http: reqwest::Client,
405 token: Arc<tokio::sync::RwLock<Option<String>>>,
406 base_url: String,
407 run_id: i64,
408 tx: mpsc::UnboundedSender<BenchRunEvent>,
409 mut ready_tx: Option<oneshot::Sender<Result<()>>>,
410) -> Result<()> {
411 let url = format!("{}/bench-runs/{}/events", base_url, run_id);
412 let mut req = http.get(&url).header("Accept", "text/event-stream");
413 if let Some(ref t) = *token.read().await {
414 req = req.bearer_auth(t);
415 }
416 let res = match req.send().await.map_err(AkribesError::Http) {
417 Ok(r) => r,
418 Err(e) => {
419 if let Some(rt) = ready_tx.take() {
420 let _ = rt.send(Err(AkribesError::Other(format!(
421 "bench SSE subscribe failed: {e}"
422 ))));
423 }
424 return Err(e);
425 }
426 };
427 if !res.status().is_success() {
428 let status = res.status().as_u16();
429 let err = AkribesError::HttpStatus {
430 status,
431 message: format!("bench SSE subscribe failed: {}", res.status()),
432 };
433 if let Some(rt) = ready_tx.take() {
434 let _ = rt.send(Err(AkribesError::HttpStatus {
435 status,
436 message: format!("bench SSE subscribe failed: {}", res.status()),
437 }));
438 }
439 return Err(err);
440 }
441 if let Some(rt) = ready_tx.take() {
442 let _ = rt.send(Ok(()));
443 }
444
445 let mut stream = res.bytes_stream();
446 let mut buf: Vec<u8> = Vec::new();
447 while let Some(chunk) = stream.next().await {
448 let chunk = chunk.map_err(AkribesError::Http)?;
449 buf.extend_from_slice(&chunk);
450 while let Some((msg_bytes, delim_len)) = split_sse_message_bytes(&buf) {
451 let message = String::from_utf8_lossy(&buf[..msg_bytes]).into_owned();
452 buf.drain(..msg_bytes + delim_len);
453 let Some(frame) = parse_sse_message(&message) else {
454 continue;
455 };
456 match frame.event_type.as_str() {
457 "result" => match serde_json::from_str::<BenchResult>(&frame.data) {
458 Ok(r) => {
459 if tx.send(BenchRunEvent::Result(Box::new(r))).is_err() {
460 return Ok(());
461 }
462 }
463 Err(e) => {
464 tracing::warn!(error = %e, "bench SSE result parse error");
465 }
466 },
467 "lagged" => {
468 let dropped = serde_json::from_str::<serde_json::Value>(&frame.data)
469 .ok()
470 .and_then(|v| v.get("dropped").and_then(|d| d.as_u64()))
471 .unwrap_or(0);
472 if tx.send(BenchRunEvent::Lagged { dropped }).is_err() {
473 return Ok(());
474 }
475 }
476 "terminal" => {
477 let status = serde_json::from_str::<serde_json::Value>(&frame.data)
478 .ok()
479 .and_then(|v| {
480 v.get("status")
481 .and_then(|s| s.as_str())
482 .map(|s| s.to_string())
483 })
484 .unwrap_or_else(|| "unknown".to_string());
485 let _ = tx.send(BenchRunEvent::Terminal { status });
487 return Ok(());
488 }
489 other => {
490 tracing::warn!(event_type = other, "ignoring unknown bench SSE event type");
491 }
492 }
493 }
494 }
495 Ok(())
496}
497
/// Delay before retry number `attempt` (1-based), using "full jitter".
///
/// Attempt 0 means "no retry yet" and yields a zero delay. Otherwise the
/// ceiling doubles per attempt from 1 s, capped at 30 s, and the returned delay
/// is a pseudo-random value in `[0, ceiling)` derived from the sub-second part
/// of the wall clock (no RNG dependency).
fn retry_backoff(attempt: u32) -> std::time::Duration {
    const BASE_MS: u64 = 1_000;
    const CAP_MS: u64 = 30_000;
    // Bounds the shift below; anything past this is capped by CAP_MS anyway.
    const MAX_EXPONENT: u32 = 20;

    let Some(exponent) = attempt.checked_sub(1) else {
        return std::time::Duration::ZERO;
    };
    let ceiling_ms = BASE_MS
        .saturating_mul(1u64 << exponent.min(MAX_EXPONENT))
        .min(CAP_MS);
    let entropy = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |since_epoch| u64::from(since_epoch.subsec_nanos()));
    let jitter_ms = entropy.checked_rem(ceiling_ms).unwrap_or(0);
    std::time::Duration::from_millis(jitter_ms)
}
515
516async fn stream_sse(
528 http: reqwest::Client,
529 token: Arc<tokio::sync::RwLock<Option<String>>>,
530 url: &str,
531 tx: mpsc::UnboundedSender<HubEvent>,
532 ready_tx: &mut Option<oneshot::Sender<Result<()>>>,
533 cursor: Option<i64>,
534 last_event_id_out: Arc<std::sync::Mutex<Option<i64>>>,
535) -> Result<()> {
536 let mut req = http.get(url).header("Accept", "text/event-stream");
537 if let Some(ref t) = *token.read().await {
538 req = req.bearer_auth(t);
539 }
540 if let Some(seq) = cursor {
541 req = req.header("Last-Event-ID", seq.to_string());
542 }
543 let res = req.send().await.map_err(AkribesError::Http)?;
544 if !res.status().is_success() {
545 return Err(AkribesError::HttpStatus {
546 status: res.status().as_u16(),
547 message: format!("SSE subscribe failed: {}", res.status()),
548 });
549 }
550 if let Some(rt) = ready_tx.take() {
551 let _ = rt.send(Ok(()));
552 }
553 let mut stream = res.bytes_stream();
554 let mut buf: Vec<u8> = Vec::new();
558
559 while let Some(chunk) = stream.next().await {
560 let chunk = chunk.map_err(AkribesError::Http)?;
561 buf.extend_from_slice(&chunk);
562
563 while let Some((msg_bytes, delim_len)) = split_sse_message_bytes(&buf) {
565 let message = String::from_utf8_lossy(&buf[..msg_bytes]).into_owned();
569 buf.drain(..msg_bytes + delim_len);
570
571 let Some(frame) = parse_sse_message(&message) else {
572 continue;
573 };
574 let SseFrame {
575 event_type,
576 data,
577 event_id,
578 } = frame;
579 if let Some(seq) = event_id {
582 *last_event_id_out.lock().unwrap() = Some(seq);
583 }
584
585 if event_type == "batch" || event_type.is_empty() {
586 match serde_json::from_str::<Vec<serde_json::Value>>(&data) {
595 Ok(raw_events) => {
596 for raw in raw_events {
597 match serde_json::from_value::<HubEvent>(raw) {
598 Ok(evt) => {
599 if tx.send(evt).is_err() {
600 return Ok(());
601 }
602 }
603 Err(e) => {
604 tracing::warn!(
605 error = %e,
606 "skipping unrecognised hub event in batch"
607 );
608 }
609 }
610 }
611 }
612 Err(e) => {
613 tracing::warn!(error = %e, "SSE JSON parse error");
614 }
615 }
616 } else {
617 tracing::warn!(event_type, "ignoring unknown SSE event type");
618 }
619 }
620 }
621
622 Ok(())
623}
624
/// One dispatched Server-Sent Events message.
pub(crate) struct SseFrame {
    /// Value of the last `event:` field, or empty if none was present.
    pub event_type: String,
    /// All `data:` field values joined with `\n`.
    pub data: String,
    /// Value of the last `id:` field if it parsed as an `i64`.
    pub event_id: Option<i64>,
}

/// Parses the fields of a single SSE message (the text between blank lines).
///
/// Follows the SSE field rules:
/// - CRLF, LF and bare CR all terminate lines;
/// - lines starting with `:` are comments;
/// - `field:value` has at most one leading space stripped from `value`;
/// - a line with no colon is a field name with an empty value.
///
/// `data`, `event` and `id` are recognised; other fields (e.g. `retry`) are
/// ignored. Returns `None` when the message has no `data` field, such as a
/// keep-alive comment.
pub(crate) fn parse_sse_message(message: &str) -> Option<SseFrame> {
    let mut data_parts: Vec<&str> = Vec::new();
    let mut event_type = String::new();
    let mut event_id: Option<i64> = None;

    // Splitting on either byte turns a CRLF into a line followed by an empty
    // line, and the empty line is skipped below.
    for line in message.split(|c: char| c == '\r' || c == '\n') {
        if line.is_empty() || line.starts_with(':') {
            continue;
        }
        let (field, value) = match line.split_once(':') {
            Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
            None => (line, ""),
        };
        match field {
            "data" => data_parts.push(value),
            "event" => event_type = value.to_owned(),
            "id" => event_id = value.parse::<i64>().ok(),
            _ => {}
        }
    }

    if data_parts.is_empty() {
        return None;
    }
    Some(SseFrame {
        event_type,
        data: data_parts.join("\n"),
        event_id,
    })
}
669
/// Locates the end of the first complete SSE message in `buf`.
///
/// A message ends at a blank line, i.e. a line terminator immediately followed
/// by another one. Terminators may be CRLF, LF or bare CR and may be mixed,
/// so `\n\n`, `\r\n\r\n`, `\r\r`, `\n\r\n`, `\r\n\n` and so on all qualify.
///
/// Returns `(msg_len, delim_len)`: the message is `buf[..msg_len]` and the
/// blank-line delimiter occupies the next `delim_len` bytes. Returns `None` if
/// no complete message is buffered yet.
///
/// A CR as the final byte is ambiguous (it may be the first half of a CRLF):
/// - as the *first* terminator, the function waits for more data;
/// - as the *second* terminator, the blank line is already certain, so it is
///   accepted. A following LF then shows up as an empty leading line of the
///   next message, which the message parser skips.
pub(crate) fn split_sse_message_bytes(buf: &[u8]) -> Option<(usize, usize)> {
    /// Length of the line terminator starting at `i` (2 for CRLF, 1 for LF or
    /// CR), or `None` if `buf[i]` is not a terminator or is out of range.
    fn terminator_len(buf: &[u8], i: usize) -> Option<usize> {
        match buf.get(i)? {
            b'\r' if buf.get(i + 1) == Some(&b'\n') => Some(2),
            b'\r' | b'\n' => Some(1),
            _ => None,
        }
    }

    let mut i = 0;
    while i < buf.len() {
        match terminator_len(buf, i) {
            Some(first) => {
                if let Some(second) = terminator_len(buf, i + first) {
                    return Some((i, first + second));
                }
                i += first;
            }
            None => i += 1,
        }
    }
    None
}
704
#[cfg(test)]
mod sse_split_tests {
    use super::split_sse_message_bytes;

    #[test]
    fn picks_lf_lf_when_alone() {
        let buf = b"event: ping\ndata: 1\n\nrest";
        let (msg_len, delim_len) = split_sse_message_bytes(buf).expect("delim found");
        assert_eq!(&buf[..msg_len], b"event: ping\ndata: 1");
        assert_eq!(delim_len, 2);
    }

    #[test]
    fn picks_crlf_crlf_when_alone() {
        let buf = b"event: ping\r\ndata: 1\r\n\r\nrest";
        let (msg_len, delim_len) = split_sse_message_bytes(buf).expect("delim found");
        assert_eq!(&buf[..msg_len], b"event: ping\r\ndata: 1");
        assert_eq!(delim_len, 4);
    }

    #[test]
    fn picks_cr_cr_when_alone() {
        let buf = b"data: a\r\rrest";
        let (msg_len, delim_len) = split_sse_message_bytes(buf).expect("delim found");
        assert_eq!(&buf[..msg_len], b"data: a");
        assert_eq!(delim_len, 2);
    }

    #[test]
    fn picks_earliest_delimiter_when_mixed() {
        // The LF-LF boundary comes first, so it must win over the later CRLF-CRLF.
        let buf = b"data: a\n\ndata: b\r\n\r\n";
        let (msg_len, delim_len) = split_sse_message_bytes(buf).expect("delim found");
        assert_eq!(&buf[..msg_len], b"data: a");
        assert_eq!(delim_len, 2);
    }

    #[test]
    fn picks_earliest_delimiter_crlf_first() {
        let buf = b"data: a\r\n\r\ndata: b\n\n";
        let (msg_len, delim_len) = split_sse_message_bytes(buf).expect("delim found");
        assert_eq!(&buf[..msg_len], b"data: a");
        assert_eq!(delim_len, 4);
    }

    #[test]
    fn single_crlf_is_not_a_message_boundary() {
        assert!(split_sse_message_bytes(b"data: a\r\ndata: b").is_none());
    }

    #[test]
    fn returns_none_without_delimiter() {
        let buf = b"data: incomplete";
        assert!(split_sse_message_bytes(buf).is_none());
    }

    #[test]
    fn returns_none_for_empty_buffer() {
        assert!(split_sse_message_bytes(b"").is_none());
    }
}
751}