use std::{pin::Pin, sync::Arc, time::Duration};
use futures_util::{Stream, StreamExt};
use reqwest::{StatusCode, header::CONTENT_TYPE};
use thiserror::Error;
use tokio::{
sync::oneshot,
task::JoinHandle,
time::{Instant, MissedTickBehavior},
};
use tracing::{debug, error, info, warn};
use crate::api_client::{
DeviceToken, FoxgloveApiClient, FoxgloveApiClientError, WatchHelloEvent, WatchQuery,
WatchWakeEvent,
};
use super::sse::{SseFrame, sse_event_stream};
const HELLO_TIMEOUT: Duration = Duration::from_secs(30);
const MAX_MISSED_HEARTBEAT_INTERVALS: u32 = 3;
const MIN_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(1);
const MIN_DEVICE_WAIT_FOR_VIEWER: Duration = Duration::from_secs(5);
#[derive(Error, Debug)]
#[non_exhaustive]
pub(super) enum WatchError {
#[error("watch stream conflict: another gateway holds the lease")]
Conflict,
#[error("device token unauthorized")]
Unauthorized,
#[error(transparent)]
Api(#[from] FoxgloveApiClientError),
#[error("stream transport error: {0}")]
Transport(#[source] reqwest::Error),
#[error("watch stream ended before `hello` event")]
UnexpectedEof,
#[error("timed out waiting for `hello` event")]
HelloTimeout,
#[error("unexpected response content-type: {content_type:?}")]
UnexpectedContentType { content_type: Option<String> },
#[error("malformed `{event}` event: {source}")]
MalformedEvent {
event: &'static str,
#[source]
source: serde_json::Error,
},
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) enum HeartbeatExit {
Conflict,
Gone,
Unauthorized,
Failed,
Cancelled,
}
#[derive(Debug)]
pub(super) enum WatchOutcome {
Wake(WatchWakeEvent),
StreamEnded,
ReadTimeout,
StreamError(WatchError),
HeartbeatLost(HeartbeatExit),
}
pub(super) struct Watch {
lease_id: String,
heartbeat_interval: Duration,
device_wait_for_viewer: Duration,
events: Pin<Box<dyn Stream<Item = Result<SseFrame, reqwest::Error>> + Send>>,
heartbeat_handle: Option<JoinHandle<()>>,
heartbeat_exit: oneshot::Receiver<HeartbeatExit>,
}
impl Watch {
pub async fn connect(
client: Arc<FoxgloveApiClient<DeviceToken>>,
query: WatchQuery,
) -> Result<Self, WatchError> {
Watch::connect_inner(client, query)
.await
.inspect_err(|e| match e {
WatchError::UnexpectedContentType { content_type } => info!(
content_type = content_type.as_deref(),
"watch endpoint returned non-SSE response; backing off"
),
_ => warn!(error=%e, "watch stream connect failed"),
})
}
async fn connect_inner(
client: Arc<FoxgloveApiClient<DeviceToken>>,
query: WatchQuery,
) -> Result<Self, WatchError> {
let hello_deadline = Instant::now() + HELLO_TIMEOUT;
let Ok(response) =
tokio::time::timeout_at(hello_deadline, client.open_watch_stream(&query)).await
else {
return Err(WatchError::HelloTimeout);
};
let response = match response {
Ok(response) => response,
Err(e) => {
if e.status_code() == Some(StatusCode::CONFLICT) {
return Err(WatchError::Conflict);
}
if e.status_code() == Some(StatusCode::UNAUTHORIZED) {
return Err(WatchError::Unauthorized);
}
return Err(WatchError::Api(e));
}
};
let content_type = response
.headers()
.get(CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(|v| {
v.split_once(';')
.map_or(v, |(a, _)| a)
.trim()
.to_ascii_lowercase()
});
if content_type.as_deref() != Some("text/event-stream") {
return Err(WatchError::UnexpectedContentType { content_type });
}
let mut events = sse_event_stream(response.bytes_stream());
let hello = loop {
let frame = match tokio::time::timeout_at(hello_deadline, events.next()).await {
Err(_) => return Err(WatchError::HelloTimeout),
Ok(None) => return Err(WatchError::UnexpectedEof),
Ok(Some(Err(e))) => return Err(WatchError::Transport(e)),
Ok(Some(Ok(frame))) => frame,
};
let event = match frame {
SseFrame::Comment => continue,
SseFrame::Event(event) => event,
};
if event.event == "hello" {
let hello: WatchHelloEvent =
serde_json::from_str(&event.data).map_err(|e| WatchError::MalformedEvent {
event: "hello",
source: e,
})?;
break hello;
}
debug!(event = %event.event, "ignoring unexpected event before hello");
};
let heartbeat_interval =
Duration::from_millis(hello.heartbeat_interval_ms).max(MIN_HEARTBEAT_INTERVAL);
let device_wait_for_viewer =
Duration::from_millis(hello.device_wait_for_viewer_ms).max(MIN_DEVICE_WAIT_FOR_VIEWER);
info!(
watch_lease_id = &hello.watch_lease_id,
?heartbeat_interval,
?device_wait_for_viewer,
"watch stream established"
);
let (exit_tx, exit_rx) = oneshot::channel::<HeartbeatExit>();
let heartbeat_handle = tokio::spawn(heartbeat_task(
client.clone(),
hello.watch_lease_id.clone(),
heartbeat_interval,
exit_tx,
));
Ok(Self {
lease_id: hello.watch_lease_id,
heartbeat_interval,
device_wait_for_viewer,
events,
heartbeat_handle: Some(heartbeat_handle),
heartbeat_exit: exit_rx,
})
}
pub fn lease_id(&self) -> &str {
&self.lease_id
}
pub fn device_wait_for_viewer(&self) -> Duration {
self.device_wait_for_viewer
}
pub fn heartbeat_interval(&self) -> Duration {
self.heartbeat_interval
}
pub async fn run(mut self) -> (WatchOutcome, Duration) {
let started_at = Instant::now();
let outcome = self.run_inner().await;
let duration = started_at.elapsed();
log_watch_outcome(&outcome, &self.lease_id);
self.close().await;
(outcome, duration)
}
async fn run_inner(&mut self) -> WatchOutcome {
let read_timeout = self
.heartbeat_interval
.saturating_mul(2)
.saturating_add(self.heartbeat_interval / 2);
let events = &mut self.events;
let heartbeat_exit = &mut self.heartbeat_exit;
loop {
tokio::select! {
biased;
hb_exit = &mut *heartbeat_exit => {
return match hb_exit {
Ok(reason) => WatchOutcome::HeartbeatLost(reason),
Err(_) => WatchOutcome::HeartbeatLost(HeartbeatExit::Cancelled),
};
},
ev = tokio::time::timeout(read_timeout, events.next()) => match ev {
Err(_) => return WatchOutcome::ReadTimeout,
Ok(None) => return WatchOutcome::StreamEnded,
Ok(Some(Err(e))) => {
return WatchOutcome::StreamError(WatchError::Transport(e));
}
Ok(Some(Ok(SseFrame::Comment))) => {
continue;
}
Ok(Some(Ok(SseFrame::Event(event)))) => match event.event.as_str() {
"wake" => match serde_json::from_str::<WatchWakeEvent>(&event.data) {
Ok(wake) => return WatchOutcome::Wake(wake),
Err(e) => return WatchOutcome::StreamError(WatchError::MalformedEvent {
event: "wake",
source: e,
}),
},
"hello" => {
warn!("received unexpected `hello` event on open stream; ignoring");
continue;
}
other => {
debug!(event = %other, "ignoring unknown SSE event");
continue;
}
},
},
};
}
}
async fn close(mut self) {
if let Some(handle) = self.heartbeat_handle.take() {
handle.abort();
let _ = handle.await;
}
}
}
impl Drop for Watch {
fn drop(&mut self) {
if let Some(handle) = self.heartbeat_handle.take() {
handle.abort();
}
}
}
async fn heartbeat_task(
client: Arc<FoxgloveApiClient<DeviceToken>>,
lease_id: String,
interval: Duration,
exit_tx: oneshot::Sender<HeartbeatExit>,
) {
let max_heartbeat_silence = interval.saturating_mul(MAX_MISSED_HEARTBEAT_INTERVALS);
let mut last_success = Instant::now();
let mut ticker = tokio::time::interval_at(Instant::now() + interval, interval);
ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
loop {
ticker.tick().await;
let result = client.post_watch_heartbeat(&lease_id).await;
match result {
Ok(()) => {
last_success = Instant::now();
}
Err(e) => match e.status_code() {
Some(StatusCode::CONFLICT) => {
let _ = exit_tx.send(HeartbeatExit::Conflict);
return;
}
Some(StatusCode::GONE) => {
let _ = exit_tx.send(HeartbeatExit::Gone);
return;
}
Some(StatusCode::UNAUTHORIZED) => {
let _ = exit_tx.send(HeartbeatExit::Unauthorized);
return;
}
_ => {
if last_success.elapsed() >= max_heartbeat_silence {
warn!(
error = %e,
stale_for_ms = last_success.elapsed().as_millis(),
"heartbeat failed for too long"
);
let _ = exit_tx.send(HeartbeatExit::Failed);
return;
}
debug!(error = %e, "heartbeat failed; will retry");
}
},
}
}
}
fn log_watch_outcome(outcome: &WatchOutcome, watch_lease_id: &str) {
match outcome {
WatchOutcome::Wake(wake) => info!(
watch_lease_id,
remote_access_session_id = wake.remote_access_session_id.as_deref(),
"received wake"
),
WatchOutcome::ReadTimeout => {
warn!(watch_lease_id, "watch stream read-timeout; reconnecting")
}
WatchOutcome::StreamEnded => warn!(
watch_lease_id,
"watch stream ended before wake; reconnecting"
),
WatchOutcome::StreamError(e) => warn!(
watch_lease_id,
error = %e,
"watch stream error; reconnecting"
),
WatchOutcome::HeartbeatLost(reason) => match reason {
HeartbeatExit::Conflict => warn!(
watch_lease_id,
"another gateway holds the watch lease; backing off"
),
HeartbeatExit::Gone => warn!(
watch_lease_id,
"watch lease gone; reconnecting to acquire a fresh lease"
),
HeartbeatExit::Unauthorized => error!(
watch_lease_id,
"device token unauthorized; stopping remote access gateway"
),
HeartbeatExit::Failed => warn!(
watch_lease_id,
"watch heartbeat failed for too long; reconnecting"
),
HeartbeatExit::Cancelled => error!(
watch_lease_id,
"heartbeat task exited without a terminal reason; check for panics"
),
},
}
}
#[cfg(test)]
mod tests {
use std::future;
use std::task::{Context, Poll};
use assert_matches::assert_matches;
use axum::extract::State;
use axum::http::{StatusCode as AxumStatusCode, header};
use axum::response::IntoResponse;
use axum::{Router, routing::get};
use futures_util::stream;
use tokio::net::TcpListener;
use crate::api_client::{FoxgloveApiClientBuilder, WatchQuery};
use crate::remote_access::sse::{SseEvent, SseFrame};
use super::*;
const TEST_TOKEN: &str = "fox_dt_testtoken";
#[derive(Clone)]
struct WatchResponse {
status: AxumStatusCode,
body: &'static str,
content_type: &'static str,
}
impl WatchResponse {
fn ok_sse(body: &'static str) -> Self {
Self {
status: AxumStatusCode::OK,
body,
content_type: "text/event-stream",
}
}
}
struct WatchServer {
url: String,
join_handle: JoinHandle<()>,
}
impl WatchServer {
fn url(&self) -> &str {
&self.url
}
}
impl Drop for WatchServer {
fn drop(&mut self) {
self.join_handle.abort();
}
}
struct NotifyOnDrop(Option<oneshot::Sender<()>>);
impl Drop for NotifyOnDrop {
fn drop(&mut self) {
if let Some(tx) = self.0.take() {
let _ = tx.send(());
}
}
}
impl future::Future for NotifyOnDrop {
type Output = ();
fn poll(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
Poll::Pending
}
}
async fn watch_handler(State(response): State<WatchResponse>) -> impl IntoResponse {
(
response.status,
[(header::CONTENT_TYPE, response.content_type)],
response.body,
)
}
async fn watch_server(response: WatchResponse) -> WatchServer {
let app = Router::new()
.route(
"/internal/platform/v1/remote-sessions/watch",
get(watch_handler),
)
.with_state(response);
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let join_handle = tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });
WatchServer {
url: format!("http://{addr}"),
join_handle,
}
}
fn test_client(server: &WatchServer) -> Arc<FoxgloveApiClient<DeviceToken>> {
Arc::new(
FoxgloveApiClientBuilder::new(DeviceToken::new(TEST_TOKEN))
.base_url(server.url())
.timeout(Duration::from_millis(100))
.build()
.unwrap(),
)
}
fn event(event: &str, data: &str) -> SseFrame {
SseFrame::Event(SseEvent {
event: event.to_string(),
data: data.to_string(),
})
}
fn watch_from_stream<S>(
events: S,
heartbeat_interval_ms: u64,
) -> (Watch, oneshot::Sender<HeartbeatExit>)
where
S: Stream<Item = Result<SseFrame, reqwest::Error>> + Send + 'static,
{
let (exit_tx, exit_rx) = oneshot::channel();
let watch = Watch {
lease_id: "lease-1".to_string(),
device_wait_for_viewer: Duration::from_secs(45),
heartbeat_interval: Duration::from_millis(heartbeat_interval_ms),
events: Box::pin(events),
heartbeat_handle: None,
heartbeat_exit: exit_rx,
};
(watch, exit_tx)
}
fn watch_from_frames(
frames: Vec<SseFrame>,
heartbeat_interval_ms: u64,
) -> (Watch, oneshot::Sender<HeartbeatExit>) {
watch_from_stream(
stream::iter(frames.into_iter().map(Ok::<SseFrame, reqwest::Error>)),
heartbeat_interval_ms,
)
}
#[tokio::test]
async fn connect_reads_hello_after_non_hello_frames() {
let server = watch_server(WatchResponse::ok_sse(concat!(
": keepalive\n",
"event: ignored\n",
"data: {}\n\n",
"event: hello\n",
"data: {\"watchLeaseId\":\"lease-1\",\"deviceWaitForViewerMs\":45000,\"heartbeatIntervalMs\":60000}\n\n",
)))
.await;
let client = test_client(&server);
let watch = Watch::connect(client, WatchQuery::default())
.await
.expect("watch should connect");
assert_eq!(watch.lease_id(), "lease-1");
assert_eq!(watch.device_wait_for_viewer, Duration::from_secs(45));
assert_eq!(watch.heartbeat_interval, Duration::from_secs(60));
watch.close().await;
}
#[tokio::test]
async fn connect_maps_conflict_status() {
let server = watch_server(WatchResponse {
status: AxumStatusCode::CONFLICT,
body: "{\"error\":\"lease already held\"}",
content_type: "application/json",
})
.await;
let client = test_client(&server);
let result = Watch::connect(client, WatchQuery::default()).await;
assert!(matches!(result, Err(WatchError::Conflict)));
}
#[tokio::test]
async fn connect_maps_unauthorized_status() {
let server = watch_server(WatchResponse {
status: AxumStatusCode::UNAUTHORIZED,
body: "{\"error\":\"invalid device token\"}",
content_type: "application/json",
})
.await;
let client = test_client(&server);
let result = Watch::connect(client, WatchQuery::default()).await;
assert!(matches!(result, Err(WatchError::Unauthorized)));
}
#[tokio::test]
async fn connect_rejects_non_sse_content_type() {
let server = watch_server(WatchResponse {
status: AxumStatusCode::OK,
body: "<html><body>Down for maintenance</body></html>",
content_type: "text/html; charset=utf-8",
})
.await;
let client = test_client(&server);
let result = Watch::connect(client, WatchQuery::default()).await;
let Err(WatchError::UnexpectedContentType { content_type }) = result else {
panic!("expected UnexpectedContentType");
};
assert_eq!(content_type.as_deref(), Some("text/html"));
}
#[tokio::test]
async fn connect_clamps_durations() {
let server = watch_server(WatchResponse::ok_sse(concat!(
"event: hello\n",
"data: {\"watchLeaseId\":\"lease-1\",\"deviceWaitForViewerMs\":1,\"heartbeatIntervalMs\":2}\n\n",
)))
.await;
let client = test_client(&server);
let watch = Watch::connect(client, WatchQuery::default())
.await
.expect("watch should connect");
assert_eq!(watch.device_wait_for_viewer, MIN_DEVICE_WAIT_FOR_VIEWER);
assert_eq!(watch.heartbeat_interval, MIN_HEARTBEAT_INTERVAL);
watch.close().await;
}
#[tokio::test]
async fn watch_ignores_non_terminal_frames_until_wake() {
let wake_json = "{\"remoteAccessSessionId\":\"ras-1\",\"url\":\"wss://example.test\",\"token\":\"token-1\"}";
let (watch, _exit_tx) = watch_from_frames(
vec![
SseFrame::Comment,
event("hello", "{}"),
event("ignored", "{}"),
event("wake", wake_json),
],
60_000,
);
let (outcome, _duration) = watch.run().await;
let WatchOutcome::Wake(wake) = outcome else {
panic!("expected wake outcome");
};
assert_eq!(wake.remote_access_session_id.as_deref(), Some("ras-1"));
assert_eq!(wake.url, "wss://example.test");
assert_eq!(wake.token, "token-1");
}
#[tokio::test]
async fn watch_reports_malformed_wake() {
let (watch, _exit_tx) = watch_from_frames(vec![event("wake", "{")], 60_000);
let (outcome, _duration) = watch.run().await;
assert_matches!(
outcome,
WatchOutcome::StreamError(WatchError::MalformedEvent { event: "wake", .. })
);
}
#[tokio::test]
async fn watch_reports_stream_end() {
let (watch, _exit_tx) = watch_from_frames(Vec::new(), 60_000);
let (outcome, _duration) = watch.run().await;
assert_matches!(outcome, WatchOutcome::StreamEnded);
}
#[tokio::test]
async fn watch_reports_read_timeout() {
let (watch, _exit_tx) =
watch_from_stream(stream::pending::<Result<SseFrame, reqwest::Error>>(), 1);
let (outcome, _duration) = tokio::time::timeout(Duration::from_secs(1), watch.run())
.await
.expect("watch read timeout should fire");
assert_matches!(outcome, WatchOutcome::ReadTimeout);
}
#[tokio::test]
async fn watch_prefers_heartbeat_exit() {
let wake_json = "{\"url\":\"wss://example.test\",\"token\":\"token-1\"}";
let (watch, exit_tx) = watch_from_frames(vec![event("wake", wake_json)], 60_000);
exit_tx.send(HeartbeatExit::Gone).unwrap();
let (outcome, _duration) = watch.run().await;
assert_matches!(outcome, WatchOutcome::HeartbeatLost(HeartbeatExit::Gone));
}
#[tokio::test]
async fn close_aborts_heartbeat_task() {
let (dropped_tx, dropped_rx) = oneshot::channel();
let heartbeat_handle = tokio::spawn(NotifyOnDrop(Some(dropped_tx)));
let (_exit_tx, exit_rx) = oneshot::channel();
let watch = Watch {
lease_id: "lease-1".to_string(),
device_wait_for_viewer: Duration::from_secs(45),
heartbeat_interval: Duration::from_secs(60),
events: Box::pin(stream::pending::<Result<SseFrame, reqwest::Error>>()),
heartbeat_handle: Some(heartbeat_handle),
heartbeat_exit: exit_rx,
};
watch.close().await;
dropped_rx
.await
.expect("heartbeat future should be dropped after close");
}
}