Skip to main content

rmux_sdk/
wait.rs

1//! Daemon-backed byte waits and snapshot-polled text wait helpers.
2
3#[path = "wait/visible.rs"]
4mod visible;
5
6use std::future::Future;
7use std::io;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use std::time::Duration;
11
12use rmux_proto::{
13    CancelSdkWaitRequest, PaneOutputSubscriptionStart, Request, Response, RmuxError as ProtoError,
14    SdkWaitForOutputRefRequest, SdkWaitForOutputRequest, SdkWaitId, SdkWaitOutcome,
15    CAPABILITY_SDK_PANE_BY_ID,
16};
17
18use crate::handles::{connect_transport_to_endpoint, Pane};
19use crate::transport::{DropGuard, PendingResponse};
20use crate::{Result, RmuxError};
21
22pub use visible::{VisibleTextExpectation, VisibleTextWait, WaitTimeoutError};
23
// Operation labels attached to timeout / transport errors, one per public wait entry point.

/// Label for `wait_for_bytes` (one-shot daemon byte wait).
const WAIT_FOR_BYTES_OPERATION: &str = "wait for pane output bytes";
/// Label for `wait_for_next_bytes` (armed daemon byte wait).
const WAIT_FOR_NEXT_BYTES_OPERATION: &str = "wait for next pane output bytes";
/// Label for `wait_for_text` (snapshot-polled text wait).
const WAIT_FOR_TEXT_OPERATION: &str = "wait for pane snapshot text";
/// Label for `wait_for_text_next` (armed daemon wait on the text's UTF-8 bytes).
const WAIT_FOR_TEXT_NEXT_OPERATION: &str = "wait for next pane output text";
/// Label for `wait_exit` (pane-info-polled exit wait).
const WAIT_FOR_EXIT_OPERATION: &str = "wait for pane process exit";

/// Sleep between successive polls in the snapshot / pane-info polling loops.
pub(crate) const TEXT_POLL_INTERVAL: Duration = Duration::from_millis(25);
30
31/// A daemon-armed wait for future pane output.
32///
33/// Values are returned by [`Pane::wait_for_next`](crate::Pane::wait_for_next)
34/// and [`Pane::wait_for_text_next`](crate::Pane::wait_for_text_next) after the
35/// SDK has written the daemon wait request. Awaiting the value completes when
36/// that daemon wait reports a match. Dropping it before a match sends a
37/// best-effort SDK wait cancellation request; cancellation never closes panes,
38/// sessions, child processes, or the daemon.
39#[must_use = "armed waits do nothing useful unless awaited or explicitly dropped"]
40pub struct ArmedWait {
41    response: PendingResponse,
42    wait_id: SdkWaitId,
43    cancel_guard: DropGuard,
44    timeout: Option<Pin<Box<tokio::time::Sleep>>>,
45    timeout_duration: Option<Duration>,
46    operation: &'static str,
47}
48
49impl ArmedWait {
50    fn new(
51        response: PendingResponse,
52        wait_id: SdkWaitId,
53        cancel_guard: DropGuard,
54        operation: &'static str,
55        timeout: Option<Duration>,
56    ) -> Self {
57        Self {
58            response,
59            wait_id,
60            cancel_guard,
61            timeout: timeout.map(|duration| Box::pin(tokio::time::sleep(duration))),
62            timeout_duration: timeout,
63            operation,
64        }
65    }
66}
67
68impl Future for ArmedWait {
69    type Output = Result<()>;
70
71    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
72        match Pin::new(&mut self.response).poll(cx) {
73            Poll::Ready(Ok(response)) => {
74                if sdk_wait_response_disarms_cancel(&response, self.wait_id) {
75                    self.cancel_guard.disarm();
76                }
77                let result = sdk_wait_response_to_result(response, self.wait_id);
78                return Poll::Ready(result);
79            }
80            Poll::Ready(Err(error)) => {
81                if sdk_wait_error_disarms_cancel(&error) {
82                    self.cancel_guard.disarm();
83                }
84                return Poll::Ready(Err(error));
85            }
86            Poll::Pending => {}
87        }
88
89        if let Some(duration) = self.timeout_duration {
90            if let Some(timeout) = self.timeout.as_mut() {
91                if timeout.as_mut().poll(cx).is_ready() {
92                    self.cancel_guard.trigger();
93                    return Poll::Ready(Err(wait_timeout_error(self.operation, duration)));
94                }
95            }
96        }
97
98        Poll::Pending
99    }
100}
101
102impl std::fmt::Debug for ArmedWait {
103    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        formatter
105            .debug_struct("ArmedWait")
106            .field("wait_id", &self.wait_id)
107            .field("operation", &self.operation)
108            .finish_non_exhaustive()
109    }
110}
111
112pub(crate) async fn wait_for_bytes(pane: &Pane, bytes: Vec<u8>) -> Result<()> {
113    if bytes.is_empty() {
114        return Err(RmuxError::protocol(ProtoError::Server(
115            "SDK wait bytes must not be empty".to_owned(),
116        )));
117    }
118
119    let timeout = resolved_wait_timeout(pane.configured_default_timeout());
120    with_wait_timeout(
121        WAIT_FOR_BYTES_OPERATION,
122        timeout,
123        wait_for_bytes_without_timeout(pane, bytes, timeout),
124    )
125    .await
126}
127
128pub(crate) async fn wait_for_next_bytes(pane: &Pane, bytes: Vec<u8>) -> Result<ArmedWait> {
129    if bytes.is_empty() {
130        return Err(RmuxError::protocol(ProtoError::Server(
131            "SDK wait bytes must not be empty".to_owned(),
132        )));
133    }
134
135    let timeout = resolved_wait_timeout(pane.configured_default_timeout());
136    arm_sdk_wait(pane, bytes, WAIT_FOR_NEXT_BYTES_OPERATION, timeout).await
137}
138
139pub(crate) async fn wait_for_text(pane: &Pane, text: String) -> Result<()> {
140    if text.is_empty() {
141        return Err(RmuxError::protocol(ProtoError::Server(
142            "SDK wait text must not be empty".to_owned(),
143        )));
144    }
145
146    let timeout = resolved_wait_timeout(pane.configured_default_timeout());
147    with_wait_timeout(
148        WAIT_FOR_TEXT_OPERATION,
149        timeout,
150        wait_for_text_without_timeout(pane, text),
151    )
152    .await
153}
154
155pub(crate) async fn wait_for_text_next(pane: &Pane, text: String) -> Result<ArmedWait> {
156    if text.is_empty() {
157        return Err(RmuxError::protocol(ProtoError::Server(
158            "SDK wait text must not be empty".to_owned(),
159        )));
160    }
161
162    let timeout = resolved_wait_timeout(pane.configured_default_timeout());
163    arm_sdk_wait(
164        pane,
165        text.into_bytes(),
166        WAIT_FOR_TEXT_NEXT_OPERATION,
167        timeout,
168    )
169    .await
170}
171
172pub(crate) async fn wait_exit(pane: &Pane) -> Result<Option<crate::PaneExitState>> {
173    let timeout = resolved_wait_timeout(pane.configured_default_timeout());
174    with_wait_timeout(
175        WAIT_FOR_EXIT_OPERATION,
176        timeout,
177        wait_exit_without_timeout(pane),
178    )
179    .await
180}
181
182async fn wait_for_bytes_without_timeout(
183    pane: &Pane,
184    bytes: Vec<u8>,
185    timeout: Option<Duration>,
186) -> Result<()> {
187    let owner_id = pane.transport().sdk_wait_owner_id();
188    let wait_id = pane.transport().allocate_sdk_wait_id();
189    let cancel_request = Request::CancelSdkWait(CancelSdkWaitRequest { owner_id, wait_id });
190    let cancel_client = connect_transport_to_endpoint(pane.endpoint(), timeout).await?;
191    let mut cancel_guard = DropGuard::best_effort(cancel_client, cancel_request);
192
193    let response = if pane.is_stable_id() {
194        crate::capabilities::require(pane.transport(), &[CAPABILITY_SDK_PANE_BY_ID]).await?;
195        pane.transport()
196            .request(Request::SdkWaitForOutputRef(SdkWaitForOutputRefRequest {
197                owner_id,
198                wait_id,
199                target: pane.proto_target_ref(),
200                bytes,
201                start: PaneOutputSubscriptionStart::Now,
202            }))
203            .await
204    } else {
205        pane.transport()
206            .request(Request::SdkWaitForOutput(SdkWaitForOutputRequest {
207                owner_id,
208                wait_id,
209                target: pane.target().into(),
210                bytes,
211                start: PaneOutputSubscriptionStart::Now,
212            }))
213            .await
214    };
215
216    let response = match response {
217        Ok(response) => response,
218        Err(error) => {
219            if sdk_wait_error_disarms_cancel(&error) {
220                cancel_guard.disarm();
221            }
222            return Err(error);
223        }
224    };
225
226    if sdk_wait_response_disarms_cancel(&response, wait_id) {
227        cancel_guard.disarm();
228    }
229    sdk_wait_response_to_result(response, wait_id)
230}
231
232async fn arm_sdk_wait(
233    pane: &Pane,
234    bytes: Vec<u8>,
235    operation: &'static str,
236    timeout: Option<Duration>,
237) -> Result<ArmedWait> {
238    let wait_client = connect_transport_to_endpoint(pane.endpoint(), timeout).await?;
239    let cancel_client = connect_transport_to_endpoint(pane.endpoint(), timeout).await?;
240    let owner_id = wait_client.sdk_wait_owner_id();
241    let wait_id = wait_client.allocate_sdk_wait_id();
242    let cancel_request = Request::CancelSdkWait(CancelSdkWaitRequest { owner_id, wait_id });
243    let cancel_guard = DropGuard::best_effort(cancel_client, cancel_request);
244
245    let response = with_wait_timeout(
246        operation,
247        timeout,
248        wait_client.armed_request(sdk_wait_request_for_pane(pane, owner_id, wait_id, bytes).await?),
249    )
250    .await?;
251
252    Ok(ArmedWait::new(
253        response,
254        wait_id,
255        cancel_guard,
256        operation,
257        timeout,
258    ))
259}
260
261async fn sdk_wait_request_for_pane(
262    pane: &Pane,
263    owner_id: rmux_proto::SdkWaitOwnerId,
264    wait_id: SdkWaitId,
265    bytes: Vec<u8>,
266) -> Result<Request> {
267    if pane.is_stable_id() {
268        crate::capabilities::require(pane.transport(), &[CAPABILITY_SDK_PANE_BY_ID]).await?;
269        return Ok(Request::SdkWaitForOutputRef(SdkWaitForOutputRefRequest {
270            owner_id,
271            wait_id,
272            target: pane.proto_target_ref(),
273            bytes,
274            start: PaneOutputSubscriptionStart::Now,
275        }));
276    }
277
278    Ok(Request::SdkWaitForOutput(SdkWaitForOutputRequest {
279        owner_id,
280        wait_id,
281        target: pane.target().into(),
282        bytes,
283        start: PaneOutputSubscriptionStart::Now,
284    }))
285}
286
287async fn wait_for_text_without_timeout(pane: &Pane, text: String) -> Result<()> {
288    loop {
289        let snapshot = pane.snapshot().await?;
290        if snapshot.visible_text().contains(&text) {
291            return Ok(());
292        }
293        tokio::time::sleep(TEXT_POLL_INTERVAL).await;
294    }
295}
296
297async fn wait_exit_without_timeout(pane: &Pane) -> Result<Option<crate::PaneExitState>> {
298    loop {
299        match pane_exit_observation(pane).await? {
300            PaneExitObservation::Running => {}
301            PaneExitObservation::Exited(exit_state) => return Ok(exit_state),
302        }
303        tokio::time::sleep(TEXT_POLL_INTERVAL).await;
304    }
305}
306
307pub(crate) async fn pane_exit_observation(pane: &Pane) -> Result<PaneExitObservation> {
308    let info = pane.info().await?;
309    let Some(pane) = info.panes.first() else {
310        return Ok(PaneExitObservation::Exited(None));
311    };
312
313    if matches!(pane.process, crate::PaneProcessState::Exited) || pane.exit_state.is_some() {
314        return Ok(PaneExitObservation::Exited(pane.exit_state.clone()));
315    }
316
317    Ok(PaneExitObservation::Running)
318}
319
320pub(crate) enum PaneExitObservation {
321    Running,
322    Exited(Option<crate::PaneExitState>),
323}
324
325pub(crate) async fn with_wait_timeout<F, T>(
326    operation: &'static str,
327    timeout: Option<Duration>,
328    future: F,
329) -> Result<T>
330where
331    F: Future<Output = Result<T>>,
332{
333    match timeout {
334        Some(timeout) => tokio::time::timeout(timeout, future)
335            .await
336            .map_err(|_| wait_timeout_error(operation, timeout))?,
337        None => future.await,
338    }
339}
340
341pub(crate) fn resolved_wait_timeout(default_timeout: Option<Duration>) -> Option<Duration> {
342    crate::bootstrap::discovery::resolve_timeout(None, default_timeout)
343}
344
345pub(crate) fn wait_timeout_error(operation: &'static str, timeout: Duration) -> RmuxError {
346    RmuxError::transport(
347        operation,
348        io::Error::new(
349            io::ErrorKind::TimedOut,
350            format!(
351                "timed out after {}s while {operation}",
352                timeout.as_secs_f32()
353            ),
354        ),
355    )
356}
357
358fn sdk_wait_response_disarms_cancel(response: &Response, expected_wait_id: SdkWaitId) -> bool {
359    matches!(
360        response,
361        Response::SdkWaitForOutput(response) if response.wait_id == expected_wait_id
362    )
363}
364
365fn sdk_wait_error_disarms_cancel(error: &RmuxError) -> bool {
366    matches!(
367        error,
368        RmuxError::Protocol { .. } | RmuxError::Unsupported { .. }
369    )
370}
371
372fn sdk_wait_response_to_result(response: Response, expected_wait_id: SdkWaitId) -> Result<()> {
373    match response {
374        Response::SdkWaitForOutput(response)
375            if response.wait_id == expected_wait_id
376                && response.outcome == SdkWaitOutcome::Matched =>
377        {
378            Ok(())
379        }
380        Response::SdkWaitForOutput(response)
381            if response.wait_id == expected_wait_id
382                && response.outcome == SdkWaitOutcome::Cancelled =>
383        {
384            Err(RmuxError::protocol(ProtoError::Server(format!(
385                "SDK wait {} was cancelled",
386                response.wait_id.as_u64()
387            ))))
388        }
389        Response::SdkWaitForOutput(response) => {
390            if response.wait_id != expected_wait_id {
391                return Err(RmuxError::protocol(ProtoError::Server(format!(
392                    "SDK wait response id {} did not match request id {}",
393                    response.wait_id.as_u64(),
394                    expected_wait_id.as_u64()
395                ))));
396            }
397
398            Err(RmuxError::protocol(ProtoError::Server(format!(
399                "SDK wait {} completed with unexpected outcome {:?}",
400                response.wait_id.as_u64(),
401                response.outcome
402            ))))
403        }
404        response => Err(crate::handles::session::unexpected_response(
405            "sdk-wait-output",
406            response,
407        )),
408    }
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::TransportClient;
    use rmux_proto::{encode_frame, CancelSdkWaitResponse, FrameDecoder, SdkWaitForOutputResponse};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Plays the daemon: reads from `stream` until one whole `Request` frame
    /// decodes. Panics if the peer closes first.
    async fn read_request(stream: &mut tokio::io::DuplexStream) -> Request {
        let mut decoder = FrameDecoder::new();
        let mut chunk = [0_u8; 512];

        loop {
            let frame = decoder
                .next_frame::<Request>()
                .expect("request frame decodes");
            if let Some(request) = frame {
                break request;
            }

            let filled = stream.read(&mut chunk).await.expect("read request");
            assert_ne!(filled, 0, "stream closed before request");
            decoder.push_bytes(&chunk[..filled]);
        }
    }

    /// Plays the daemon: writes `response` as one encoded frame and flushes.
    async fn write_response(stream: &mut tokio::io::DuplexStream, response: Response) {
        let encoded = encode_frame(&response).expect("response encodes");
        stream.write_all(&encoded).await.expect("write response");
        stream.flush().await.expect("flush response");
    }

    #[tokio::test]
    async fn drop_guard_sends_cancel_request_once_when_wait_future_is_dropped() {
        let (client_end, mut daemon_end) = tokio::io::duplex(4096);
        let transport = TransportClient::spawn(client_end);
        let owner_id = transport.sdk_wait_owner_id();
        let wait_id = transport.allocate_sdk_wait_id();
        let cancel = Request::CancelSdkWait(CancelSdkWaitRequest { owner_id, wait_id });

        // An armed guard must emit the cancel request it holds when dropped.
        drop(DropGuard::best_effort(transport, cancel));

        let observed = read_request(&mut daemon_end).await;
        assert_eq!(
            observed,
            Request::CancelSdkWait(CancelSdkWaitRequest { owner_id, wait_id })
        );

        let ack = Response::CancelSdkWait(CancelSdkWaitResponse {
            wait_id,
            removed: true,
        });
        write_response(&mut daemon_end, ack).await;
    }

    #[tokio::test]
    async fn disarmed_drop_guard_does_not_send_stale_cancel() {
        let (client_end, mut daemon_end) = tokio::io::duplex(4096);
        let transport = TransportClient::spawn(client_end);
        let owner_id = transport.sdk_wait_owner_id();
        let stale = Request::CancelSdkWait(CancelSdkWaitRequest {
            owner_id,
            wait_id: SdkWaitId::new(9),
        });
        let mut guard = DropGuard::best_effort(transport, stale);
        guard.disarm();
        drop(guard);

        // Silence within the window, or a closed stream, is acceptable; any
        // byte means a cancel leaked out of a disarmed guard.
        let mut probe = [0_u8; 1];
        let read = tokio::time::timeout(Duration::from_millis(50), daemon_end.read(&mut probe)).await;
        match read {
            Err(_) | Ok(Ok(0)) => {}
            Ok(other) => panic!("disarmed guard must not write cancel, got {other:?}"),
        }
    }

    #[test]
    fn sdk_wait_response_rejects_mismatched_wait_id() {
        let foreign = Response::SdkWaitForOutput(SdkWaitForOutputResponse {
            wait_id: SdkWaitId::new(10),
            outcome: SdkWaitOutcome::Matched,
        });

        let error = sdk_wait_response_to_result(foreign, SdkWaitId::new(9))
            .expect_err("mismatched wait id must fail");

        match error {
            RmuxError::Protocol {
                source: ProtoError::Server(message),
                ..
            } => assert!(message.contains("did not match request id 9")),
            error => panic!("expected protocol mismatch, got {error:?}"),
        }
    }

    #[test]
    fn duration_max_resolves_to_no_timeout_for_wait_operations() {
        let resolved = resolved_wait_timeout(Some(Duration::MAX));
        assert_eq!(resolved, None);
    }

    #[tokio::test]
    async fn finite_wait_timeout_surfaces_typed_timeout_error() {
        let never = std::future::pending::<Result<()>>();
        let error = with_wait_timeout("test wait operation", Some(Duration::from_millis(1)), never)
            .await
            .expect_err("pending wait must time out");

        let RmuxError::Transport { operation, source } = error else {
            panic!("expected typed transport timeout, got {error:?}");
        };
        assert_eq!(operation, "test wait operation");
        assert_eq!(source.kind(), io::ErrorKind::TimedOut);
    }

    #[tokio::test]
    async fn no_timeout_branch_awaits_future_directly() {
        let ready = async { Ok(7_u8) };
        let value = with_wait_timeout("test no timeout", None, ready)
            .await
            .expect("untimed ready future completes");

        assert_eq!(value, 7);
    }
}