1#[path = "wait/visible.rs"]
4mod visible;
5
6use std::future::Future;
7use std::io;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use std::time::Duration;
11
12use rmux_proto::{
13 CancelSdkWaitRequest, PaneOutputSubscriptionStart, Request, Response, RmuxError as ProtoError,
14 SdkWaitForOutputRefRequest, SdkWaitForOutputRequest, SdkWaitId, SdkWaitOutcome,
15 CAPABILITY_SDK_PANE_BY_ID,
16};
17
18use crate::handles::{connect_transport_to_endpoint, Pane};
19use crate::transport::{DropGuard, PendingResponse};
20use crate::{Result, RmuxError};
21
22pub use visible::{VisibleTextExpectation, VisibleTextWait, WaitTimeoutError};
23
// Operation labels passed to `with_wait_timeout` / `ArmedWait` and embedded by
// `wait_timeout_error` into the timeout error and its message.

/// Label for `wait_for_bytes`.
const WAIT_FOR_BYTES_OPERATION: &str = "wait for pane output bytes";
/// Label for `wait_for_text` (polls pane snapshots).
const WAIT_FOR_TEXT_OPERATION: &str = "wait for pane snapshot text";
/// Label for `wait_for_next_bytes` (armed wait).
const WAIT_FOR_NEXT_BYTES_OPERATION: &str = "wait for next pane output bytes";
/// Label for `wait_for_text_next` (armed wait on the text's UTF-8 bytes).
const WAIT_FOR_TEXT_NEXT_OPERATION: &str = "wait for next pane output text";
/// Label for `wait_exit` (polls pane info).
const WAIT_FOR_EXIT_OPERATION: &str = "wait for pane process exit";
/// Delay between successive polls in this module's snapshot-text and
/// process-exit waits.
pub(crate) const TEXT_POLL_INTERVAL: Duration = Duration::from_millis(25);
30
/// An SDK output wait whose request has already been sent to the server.
///
/// Returned by the "next output" waits (`wait_for_next_bytes`,
/// `wait_for_text_next`). Awaiting it resolves to `Ok(())` once the server
/// reports a matched outcome for this wait id. If a configured timeout elapses
/// first, it resolves to a `TimedOut` transport error.
///
/// The value owns a best-effort cancel guard. Dropping an `ArmedWait` before
/// the server has finished the wait lets that guard send a `CancelSdkWait`
/// request for it.
#[must_use = "armed waits do nothing useful unless awaited or explicitly dropped"]
pub struct ArmedWait {
    // Pending server reply to the wait request.
    response: PendingResponse,
    // Id allocated for this wait; used to validate the server's reply.
    wait_id: SdkWaitId,
    // Sends `CancelSdkWait` on drop unless disarmed.
    // Field order matters: fields drop in declaration order.
    cancel_guard: DropGuard,
    // Deadline timer, created when the `ArmedWait` is constructed.
    timeout: Option<Pin<Box<tokio::time::Sleep>>>,
    // Duration behind `timeout`, kept for the timeout error message.
    timeout_duration: Option<Duration>,
    // Human-readable operation label used in timeout errors.
    operation: &'static str,
}
48
49impl ArmedWait {
50 fn new(
51 response: PendingResponse,
52 wait_id: SdkWaitId,
53 cancel_guard: DropGuard,
54 operation: &'static str,
55 timeout: Option<Duration>,
56 ) -> Self {
57 Self {
58 response,
59 wait_id,
60 cancel_guard,
61 timeout: timeout.map(|duration| Box::pin(tokio::time::sleep(duration))),
62 timeout_duration: timeout,
63 operation,
64 }
65 }
66}
67
68impl Future for ArmedWait {
69 type Output = Result<()>;
70
71 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
72 match Pin::new(&mut self.response).poll(cx) {
73 Poll::Ready(Ok(response)) => {
74 if sdk_wait_response_disarms_cancel(&response, self.wait_id) {
75 self.cancel_guard.disarm();
76 }
77 let result = sdk_wait_response_to_result(response, self.wait_id);
78 return Poll::Ready(result);
79 }
80 Poll::Ready(Err(error)) => {
81 if sdk_wait_error_disarms_cancel(&error) {
82 self.cancel_guard.disarm();
83 }
84 return Poll::Ready(Err(error));
85 }
86 Poll::Pending => {}
87 }
88
89 if let Some(duration) = self.timeout_duration {
90 if let Some(timeout) = self.timeout.as_mut() {
91 if timeout.as_mut().poll(cx).is_ready() {
92 self.cancel_guard.trigger();
93 return Poll::Ready(Err(wait_timeout_error(self.operation, duration)));
94 }
95 }
96 }
97
98 Poll::Pending
99 }
100}
101
102impl std::fmt::Debug for ArmedWait {
103 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 formatter
105 .debug_struct("ArmedWait")
106 .field("wait_id", &self.wait_id)
107 .field("operation", &self.operation)
108 .finish_non_exhaustive()
109 }
110}
111
112pub(crate) async fn wait_for_bytes(pane: &Pane, bytes: Vec<u8>) -> Result<()> {
113 if bytes.is_empty() {
114 return Err(RmuxError::protocol(ProtoError::Server(
115 "SDK wait bytes must not be empty".to_owned(),
116 )));
117 }
118
119 let timeout = resolved_wait_timeout(pane.configured_default_timeout());
120 with_wait_timeout(
121 WAIT_FOR_BYTES_OPERATION,
122 timeout,
123 wait_for_bytes_without_timeout(pane, bytes, timeout),
124 )
125 .await
126}
127
128pub(crate) async fn wait_for_next_bytes(pane: &Pane, bytes: Vec<u8>) -> Result<ArmedWait> {
129 if bytes.is_empty() {
130 return Err(RmuxError::protocol(ProtoError::Server(
131 "SDK wait bytes must not be empty".to_owned(),
132 )));
133 }
134
135 let timeout = resolved_wait_timeout(pane.configured_default_timeout());
136 arm_sdk_wait(pane, bytes, WAIT_FOR_NEXT_BYTES_OPERATION, timeout).await
137}
138
139pub(crate) async fn wait_for_text(pane: &Pane, text: String) -> Result<()> {
140 if text.is_empty() {
141 return Err(RmuxError::protocol(ProtoError::Server(
142 "SDK wait text must not be empty".to_owned(),
143 )));
144 }
145
146 let timeout = resolved_wait_timeout(pane.configured_default_timeout());
147 with_wait_timeout(
148 WAIT_FOR_TEXT_OPERATION,
149 timeout,
150 wait_for_text_without_timeout(pane, text),
151 )
152 .await
153}
154
155pub(crate) async fn wait_for_text_next(pane: &Pane, text: String) -> Result<ArmedWait> {
156 if text.is_empty() {
157 return Err(RmuxError::protocol(ProtoError::Server(
158 "SDK wait text must not be empty".to_owned(),
159 )));
160 }
161
162 let timeout = resolved_wait_timeout(pane.configured_default_timeout());
163 arm_sdk_wait(
164 pane,
165 text.into_bytes(),
166 WAIT_FOR_TEXT_NEXT_OPERATION,
167 timeout,
168 )
169 .await
170}
171
172pub(crate) async fn wait_exit(pane: &Pane) -> Result<Option<crate::PaneExitState>> {
173 let timeout = resolved_wait_timeout(pane.configured_default_timeout());
174 with_wait_timeout(
175 WAIT_FOR_EXIT_OPERATION,
176 timeout,
177 wait_exit_without_timeout(pane),
178 )
179 .await
180}
181
182async fn wait_for_bytes_without_timeout(
183 pane: &Pane,
184 bytes: Vec<u8>,
185 timeout: Option<Duration>,
186) -> Result<()> {
187 let owner_id = pane.transport().sdk_wait_owner_id();
188 let wait_id = pane.transport().allocate_sdk_wait_id();
189 let cancel_request = Request::CancelSdkWait(CancelSdkWaitRequest { owner_id, wait_id });
190 let cancel_client = connect_transport_to_endpoint(pane.endpoint(), timeout).await?;
191 let mut cancel_guard = DropGuard::best_effort(cancel_client, cancel_request);
192
193 let response = if pane.is_stable_id() {
194 crate::capabilities::require(pane.transport(), &[CAPABILITY_SDK_PANE_BY_ID]).await?;
195 pane.transport()
196 .request(Request::SdkWaitForOutputRef(SdkWaitForOutputRefRequest {
197 owner_id,
198 wait_id,
199 target: pane.proto_target_ref(),
200 bytes,
201 start: PaneOutputSubscriptionStart::Now,
202 }))
203 .await
204 } else {
205 pane.transport()
206 .request(Request::SdkWaitForOutput(SdkWaitForOutputRequest {
207 owner_id,
208 wait_id,
209 target: pane.target().into(),
210 bytes,
211 start: PaneOutputSubscriptionStart::Now,
212 }))
213 .await
214 };
215
216 let response = match response {
217 Ok(response) => response,
218 Err(error) => {
219 if sdk_wait_error_disarms_cancel(&error) {
220 cancel_guard.disarm();
221 }
222 return Err(error);
223 }
224 };
225
226 if sdk_wait_response_disarms_cancel(&response, wait_id) {
227 cancel_guard.disarm();
228 }
229 sdk_wait_response_to_result(response, wait_id)
230}
231
232async fn arm_sdk_wait(
233 pane: &Pane,
234 bytes: Vec<u8>,
235 operation: &'static str,
236 timeout: Option<Duration>,
237) -> Result<ArmedWait> {
238 let wait_client = connect_transport_to_endpoint(pane.endpoint(), timeout).await?;
239 let cancel_client = connect_transport_to_endpoint(pane.endpoint(), timeout).await?;
240 let owner_id = wait_client.sdk_wait_owner_id();
241 let wait_id = wait_client.allocate_sdk_wait_id();
242 let cancel_request = Request::CancelSdkWait(CancelSdkWaitRequest { owner_id, wait_id });
243 let cancel_guard = DropGuard::best_effort(cancel_client, cancel_request);
244
245 let response = with_wait_timeout(
246 operation,
247 timeout,
248 wait_client.armed_request(sdk_wait_request_for_pane(pane, owner_id, wait_id, bytes).await?),
249 )
250 .await?;
251
252 Ok(ArmedWait::new(
253 response,
254 wait_id,
255 cancel_guard,
256 operation,
257 timeout,
258 ))
259}
260
261async fn sdk_wait_request_for_pane(
262 pane: &Pane,
263 owner_id: rmux_proto::SdkWaitOwnerId,
264 wait_id: SdkWaitId,
265 bytes: Vec<u8>,
266) -> Result<Request> {
267 if pane.is_stable_id() {
268 crate::capabilities::require(pane.transport(), &[CAPABILITY_SDK_PANE_BY_ID]).await?;
269 return Ok(Request::SdkWaitForOutputRef(SdkWaitForOutputRefRequest {
270 owner_id,
271 wait_id,
272 target: pane.proto_target_ref(),
273 bytes,
274 start: PaneOutputSubscriptionStart::Now,
275 }));
276 }
277
278 Ok(Request::SdkWaitForOutput(SdkWaitForOutputRequest {
279 owner_id,
280 wait_id,
281 target: pane.target().into(),
282 bytes,
283 start: PaneOutputSubscriptionStart::Now,
284 }))
285}
286
287async fn wait_for_text_without_timeout(pane: &Pane, text: String) -> Result<()> {
288 loop {
289 let snapshot = pane.snapshot().await?;
290 if snapshot.visible_text().contains(&text) {
291 return Ok(());
292 }
293 tokio::time::sleep(TEXT_POLL_INTERVAL).await;
294 }
295}
296
297async fn wait_exit_without_timeout(pane: &Pane) -> Result<Option<crate::PaneExitState>> {
298 loop {
299 match pane_exit_observation(pane).await? {
300 PaneExitObservation::Running => {}
301 PaneExitObservation::Exited(exit_state) => return Ok(exit_state),
302 }
303 tokio::time::sleep(TEXT_POLL_INTERVAL).await;
304 }
305}
306
307pub(crate) async fn pane_exit_observation(pane: &Pane) -> Result<PaneExitObservation> {
308 let info = pane.info().await?;
309 let Some(pane) = info.panes.first() else {
310 return Ok(PaneExitObservation::Exited(None));
311 };
312
313 if matches!(pane.process, crate::PaneProcessState::Exited) || pane.exit_state.is_some() {
314 return Ok(PaneExitObservation::Exited(pane.exit_state.clone()));
315 }
316
317 Ok(PaneExitObservation::Running)
318}
319
/// Outcome of a single pane-exit probe made by `pane_exit_observation`.
pub(crate) enum PaneExitObservation {
    /// The pane is still listed.
    ///
    /// Its process is not reported as exited, and it has no exit state.
    Running,
    /// The pane's process exited, or the pane is no longer listed.
    ///
    /// Carries the reported exit state, if any; it is always `None` when the
    /// pane is no longer listed.
    Exited(Option<crate::PaneExitState>),
}
324
325pub(crate) async fn with_wait_timeout<F, T>(
326 operation: &'static str,
327 timeout: Option<Duration>,
328 future: F,
329) -> Result<T>
330where
331 F: Future<Output = Result<T>>,
332{
333 match timeout {
334 Some(timeout) => tokio::time::timeout(timeout, future)
335 .await
336 .map_err(|_| wait_timeout_error(operation, timeout))?,
337 None => future.await,
338 }
339}
340
341pub(crate) fn resolved_wait_timeout(default_timeout: Option<Duration>) -> Option<Duration> {
342 crate::bootstrap::discovery::resolve_timeout(None, default_timeout)
343}
344
345pub(crate) fn wait_timeout_error(operation: &'static str, timeout: Duration) -> RmuxError {
346 RmuxError::transport(
347 operation,
348 io::Error::new(
349 io::ErrorKind::TimedOut,
350 format!(
351 "timed out after {}s while {operation}",
352 timeout.as_secs_f32()
353 ),
354 ),
355 )
356}
357
358fn sdk_wait_response_disarms_cancel(response: &Response, expected_wait_id: SdkWaitId) -> bool {
359 matches!(
360 response,
361 Response::SdkWaitForOutput(response) if response.wait_id == expected_wait_id
362 )
363}
364
365fn sdk_wait_error_disarms_cancel(error: &RmuxError) -> bool {
366 matches!(
367 error,
368 RmuxError::Protocol { .. } | RmuxError::Unsupported { .. }
369 )
370}
371
372fn sdk_wait_response_to_result(response: Response, expected_wait_id: SdkWaitId) -> Result<()> {
373 match response {
374 Response::SdkWaitForOutput(response)
375 if response.wait_id == expected_wait_id
376 && response.outcome == SdkWaitOutcome::Matched =>
377 {
378 Ok(())
379 }
380 Response::SdkWaitForOutput(response)
381 if response.wait_id == expected_wait_id
382 && response.outcome == SdkWaitOutcome::Cancelled =>
383 {
384 Err(RmuxError::protocol(ProtoError::Server(format!(
385 "SDK wait {} was cancelled",
386 response.wait_id.as_u64()
387 ))))
388 }
389 Response::SdkWaitForOutput(response) => {
390 if response.wait_id != expected_wait_id {
391 return Err(RmuxError::protocol(ProtoError::Server(format!(
392 "SDK wait response id {} did not match request id {}",
393 response.wait_id.as_u64(),
394 expected_wait_id.as_u64()
395 ))));
396 }
397
398 Err(RmuxError::protocol(ProtoError::Server(format!(
399 "SDK wait {} completed with unexpected outcome {:?}",
400 response.wait_id.as_u64(),
401 response.outcome
402 ))))
403 }
404 response => Err(crate::handles::session::unexpected_response(
405 "sdk-wait-output",
406 response,
407 )),
408 }
409}
410
// Unit tests for the SDK wait plumbing:
// - cancel-guard behaviour;
// - reply validation;
// - timeout resolution and how timeouts are surfaced.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::TransportClient;
    use rmux_proto::{encode_frame, CancelSdkWaitResponse, FrameDecoder, SdkWaitForOutputResponse};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Reads from the server side of the duplex stream until one complete
    /// `Request` frame decodes.
    ///
    /// Panics if the stream closes first.
    async fn read_request(stream: &mut tokio::io::DuplexStream) -> Request {
        let mut decoder = FrameDecoder::new();
        let mut buffer = [0_u8; 512];

        loop {
            // Decode from already-buffered bytes before reading more.
            if let Some(request) = decoder
                .next_frame::<Request>()
                .expect("request frame decodes")
            {
                return request;
            }

            let read = stream.read(&mut buffer).await.expect("read request");
            assert_ne!(read, 0, "stream closed before request");
            decoder.push_bytes(&buffer[..read]);
        }
    }

    /// Encodes `response` as a frame and writes it to the server side of the
    /// duplex stream.
    async fn write_response(stream: &mut tokio::io::DuplexStream, response: Response) {
        let frame = encode_frame(&response).expect("response encodes");
        stream.write_all(&frame).await.expect("write response");
        stream.flush().await.expect("flush response");
    }

    /// Dropping an armed guard writes the configured `CancelSdkWait` request.
    ///
    /// The reply is presumably written so the guard's in-flight request is
    /// answered.
    ///
    /// NOTE(review): only the first request is asserted. Despite the name,
    /// this neither checks that no second cancel follows nor drops an actual
    /// wait future.
    #[tokio::test]
    async fn drop_guard_sends_cancel_request_once_when_wait_future_is_dropped() {
        let (client_stream, mut server_stream) = tokio::io::duplex(4096);
        let client = TransportClient::spawn(client_stream);
        let owner_id = client.sdk_wait_owner_id();
        let wait_id = client.allocate_sdk_wait_id();
        let guard = DropGuard::best_effort(
            client,
            Request::CancelSdkWait(CancelSdkWaitRequest { owner_id, wait_id }),
        );

        drop(guard);

        assert_eq!(
            read_request(&mut server_stream).await,
            Request::CancelSdkWait(CancelSdkWaitRequest { owner_id, wait_id })
        );
        write_response(
            &mut server_stream,
            Response::CancelSdkWait(CancelSdkWaitResponse {
                wait_id,
                removed: true,
            }),
        )
        .await;
    }

    /// A disarmed guard must write nothing on drop.
    ///
    /// Two outcomes are accepted:
    /// - the 50 ms read times out (`Err(_)`);
    /// - the stream reaches EOF (`Ok(Ok(0))`).
    ///
    /// Any received bytes fail the test.
    #[tokio::test]
    async fn disarmed_drop_guard_does_not_send_stale_cancel() {
        let (client_stream, mut server_stream) = tokio::io::duplex(4096);
        let client = TransportClient::spawn(client_stream);
        let owner_id = client.sdk_wait_owner_id();
        let mut guard = DropGuard::best_effort(
            client,
            Request::CancelSdkWait(CancelSdkWaitRequest {
                owner_id,
                wait_id: SdkWaitId::new(9),
            }),
        );
        guard.disarm();
        drop(guard);

        let mut buffer = [0_u8; 1];
        let read = tokio::time::timeout(
            std::time::Duration::from_millis(50),
            server_stream.read(&mut buffer),
        )
        .await;
        match read {
            Err(_) => {}
            Ok(Ok(0)) => {}
            Ok(other) => panic!("disarmed guard must not write cancel, got {other:?}"),
        }
    }

    /// A `Matched` reply for a different wait id must be rejected with a
    /// protocol error naming the requested id.
    #[test]
    fn sdk_wait_response_rejects_mismatched_wait_id() {
        let result = sdk_wait_response_to_result(
            Response::SdkWaitForOutput(SdkWaitForOutputResponse {
                wait_id: SdkWaitId::new(10),
                outcome: SdkWaitOutcome::Matched,
            }),
            SdkWaitId::new(9),
        );

        match result.expect_err("mismatched wait id must fail") {
            RmuxError::Protocol {
                source: ProtoError::Server(message),
                ..
            } => assert!(message.contains("did not match request id 9")),
            error => panic!("expected protocol mismatch, got {error:?}"),
        }
    }

    /// A configured default of `Duration::MAX` means "no timeout" for waits.
    #[test]
    fn duration_max_resolves_to_no_timeout_for_wait_operations() {
        assert_eq!(resolved_wait_timeout(Some(Duration::MAX)), None);
    }

    /// A finite timeout on a never-ready future yields a transport error.
    ///
    /// The error carries the operation label and `io::ErrorKind::TimedOut`.
    #[tokio::test]
    async fn finite_wait_timeout_surfaces_typed_timeout_error() {
        let error = with_wait_timeout(
            "test wait operation",
            Some(Duration::from_millis(1)),
            std::future::pending::<Result<()>>(),
        )
        .await
        .expect_err("pending wait must time out");

        match error {
            RmuxError::Transport { operation, source } => {
                assert_eq!(operation, "test wait operation");
                assert_eq!(source.kind(), io::ErrorKind::TimedOut);
            }
            other => panic!("expected typed transport timeout, got {other:?}"),
        }
    }

    /// With no timeout, the future's own result is returned unchanged.
    #[tokio::test]
    async fn no_timeout_branch_awaits_future_directly() {
        let value = with_wait_timeout("test no timeout", None, async { Ok(7_u8) })
            .await
            .expect("untimed ready future completes");

        assert_eq!(value, 7);
    }
}