1use std::sync::Arc;
31use std::time::Duration;
32
33use tokio::io::{AsyncReadExt, AsyncWriteExt};
34use tokio::sync::Mutex;
35
36use super::hooks::{HookManager, InteractionEvent};
37use super::mode::InteractionMode;
38use super::terminal::TerminalSize;
39use crate::error::{ExpectError, Result};
40use crate::expect::Pattern;
41
/// What the interact loop should do after a pattern or resize hook runs.
///
/// Derives `PartialEq`/`Eq` so hook results can be compared directly
/// (e.g. `assert_eq!(hook(&ctx), InteractAction::Continue)`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InteractAction {
    /// Keep the session running without sending anything.
    Continue,
    /// Write these bytes to the transport (then flush), and keep running.
    Send(Vec<u8>),
    /// End the session; reported as `InteractEndReason::PatternStop`.
    Stop,
    /// End the session; reported as `InteractEndReason::Error` with this message.
    Error(String),
}
54
55impl InteractAction {
56 pub fn send(s: impl Into<String>) -> Self {
58 Self::Send(s.into().into_bytes())
59 }
60
61 pub fn send_bytes(data: impl Into<Vec<u8>>) -> Self {
63 Self::Send(data.into())
64 }
65}
66
/// View of a pattern match handed to a pattern hook.
///
/// All string fields borrow from the text the pattern was matched against.
/// The struct is just borrowed slices plus an index, so it is cheap to `Copy`
/// and is `Debug` for logging inside hooks.
#[derive(Debug, Clone, Copy)]
pub struct InteractContext<'a> {
    /// The exact text the pattern matched.
    pub matched: &'a str,
    /// Text preceding the match.
    pub before: &'a str,
    /// Text following the match.
    pub after: &'a str,
    /// The whole text that was searched (`before` + `matched` + `after`).
    pub buffer: &'a str,
    /// Position of the firing hook in its registration list.
    pub pattern_index: usize,
}
80
81impl InteractContext<'_> {
82 pub fn send(&self, data: impl Into<String>) -> InteractAction {
84 InteractAction::send(data)
85 }
86
87 pub fn send_line(&self, data: impl Into<String>) -> InteractAction {
89 let mut s = data.into();
90 s.push('\n');
91 InteractAction::send(s)
92 }
93}
94
/// Callback for an output or input pattern hook.
///
/// It receives the match context and returns the action the interact loop
/// should take next.
pub type PatternHook = Box<dyn Fn(&InteractContext<'_>) -> InteractAction + Send + Sync>;
97
/// Information passed to the resize hook when the terminal size changes.
#[derive(Debug, Clone, Copy)]
pub struct ResizeContext {
    /// Terminal size after the change.
    pub size: TerminalSize,
    /// Last size the runner knew about. `None` if the initial size query
    /// failed and no resize has been seen since.
    pub previous: Option<TerminalSize>,
}
106
/// Callback invoked on terminal resize.
///
/// Its returned action is applied just like a pattern hook's.
pub type ResizeHook = Box<dyn Fn(&ResizeContext) -> InteractAction + Send + Sync>;
109
/// A pattern checked against accumulated transport output, paired with its callback.
struct OutputPatternHook {
    pattern: Pattern,
    callback: PatternHook,
}
115
/// A pattern checked against each chunk of user input, paired with its callback.
struct InputPatternHook {
    pattern: Pattern,
    callback: PatternHook,
}
121
/// Builder for an interactive session over a shared transport.
///
/// It is constructed inside the crate (see `new`) and consumed by `start`.
pub struct InteractBuilder<'a, T>
where
    T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
{
    /// Shared transport; the `Arc` is cloned into the runner on `start`.
    transport: &'a Arc<Mutex<T>>,
    /// Hooks matched against accumulated output, in registration order.
    output_hooks: Vec<OutputPatternHook>,
    /// Hooks matched against each chunk of user input, in registration order.
    input_hooks: Vec<InputPatternHook>,
    /// Optional callback for terminal resize events.
    resize_hook: Option<ResizeHook>,
    /// Byte-transform hooks plus the sink for `InteractionEvent` notifications.
    hook_manager: HookManager,
    /// Mode settings; its `read_timeout` bounds each wait for user input.
    mode: InteractionMode,
    /// Maximum number of bytes of recent output kept for pattern matching.
    buffer_size: usize,
    /// Input bytes that end the session; `None` disables the escape.
    escape_sequence: Option<Vec<u8>>,
    /// Overall session time limit; `None` means no limit.
    timeout: Option<Duration>,
    /// Observers that receive every raw output chunk.
    output_taps: Vec<crate::session::OutputTap>,
}
151
152impl<'a, T> InteractBuilder<'a, T>
153where
154 T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
155{
156 pub(crate) fn new(
158 transport: &'a Arc<Mutex<T>>,
159 output_taps: Vec<crate::session::OutputTap>,
160 ) -> Self {
161 Self {
162 transport,
163 output_hooks: Vec::new(),
164 input_hooks: Vec::new(),
165 resize_hook: None,
166 hook_manager: HookManager::new(),
167 mode: InteractionMode::default(),
168 buffer_size: 8192,
169 escape_sequence: Some(vec![0x1d]), timeout: None,
171 output_taps,
172 }
173 }
174
175 #[must_use]
190 pub fn on_output<F>(mut self, pattern: impl Into<Pattern>, callback: F) -> Self
191 where
192 F: Fn(&InteractContext<'_>) -> InteractAction + Send + Sync + 'static,
193 {
194 self.output_hooks.push(OutputPatternHook {
195 pattern: pattern.into(),
196 callback: Box::new(callback),
197 });
198 self
199 }
200
201 #[must_use]
205 pub fn on_input<F>(mut self, pattern: impl Into<Pattern>, callback: F) -> Self
206 where
207 F: Fn(&InteractContext<'_>) -> InteractAction + Send + Sync + 'static,
208 {
209 self.input_hooks.push(InputPatternHook {
210 pattern: pattern.into(),
211 callback: Box::new(callback),
212 });
213 self
214 }
215
216 #[must_use]
239 pub fn on_resize<F>(mut self, callback: F) -> Self
240 where
241 F: Fn(&ResizeContext) -> InteractAction + Send + Sync + 'static,
242 {
243 self.resize_hook = Some(Box::new(callback));
244 self
245 }
246
247 #[must_use]
249 pub const fn with_mode(mut self, mode: InteractionMode) -> Self {
250 self.mode = mode;
251 self
252 }
253
254 #[must_use]
258 pub fn with_escape(mut self, escape: impl Into<Vec<u8>>) -> Self {
259 self.escape_sequence = Some(escape.into());
260 self
261 }
262
263 #[must_use]
265 pub fn no_escape(mut self) -> Self {
266 self.escape_sequence = None;
267 self
268 }
269
270 #[must_use]
272 pub const fn with_timeout(mut self, timeout: Duration) -> Self {
273 self.timeout = Some(timeout);
274 self
275 }
276
277 #[must_use]
279 pub const fn with_buffer_size(mut self, size: usize) -> Self {
280 self.buffer_size = size;
281 self
282 }
283
284 #[must_use]
286 pub fn with_input_hook<F>(mut self, hook: F) -> Self
287 where
288 F: Fn(&[u8]) -> Vec<u8> + Send + Sync + 'static,
289 {
290 self.hook_manager.add_input_hook(hook);
291 self
292 }
293
294 #[must_use]
296 pub fn with_output_hook<F>(mut self, hook: F) -> Self
297 where
298 F: Fn(&[u8]) -> Vec<u8> + Send + Sync + 'static,
299 {
300 self.hook_manager.add_output_hook(hook);
301 self
302 }
303
304 pub async fn start(self) -> Result<InteractResult> {
319 let mut runner = InteractRunner::new(
320 Arc::clone(self.transport),
321 self.output_hooks,
322 self.input_hooks,
323 self.resize_hook,
324 self.hook_manager,
325 self.mode,
326 self.buffer_size,
327 self.escape_sequence,
328 self.timeout,
329 self.output_taps,
330 );
331 runner.run().await
332 }
333}
334
/// Outcome of an interactive session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InteractResult {
    /// Why the session ended.
    pub reason: InteractEndReason,
    /// Recent hook-processed output text at the moment the session ended.
    pub buffer: String,
}

/// Why an interactive session ended.
///
/// Derives `PartialEq`/`Eq` so callers can write
/// `assert_eq!(result.reason, InteractEndReason::Eof)` instead of `matches!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InteractEndReason {
    /// A hook returned `InteractAction::Stop`.
    PatternStop {
        /// Index of the firing hook among the output hooks or the input hooks,
        /// whichever list it came from. A resize hook's `Stop` is reported as `0`.
        pattern_index: usize,
    },
    /// The user typed the escape sequence.
    Escape,
    /// The overall session timeout elapsed.
    Timeout,
    /// A read from the transport returned zero bytes.
    Eof,
    /// A hook returned `InteractAction::Error` with this message.
    Error(String),
}
361
/// Internal state of a running interactive session, created by `InteractBuilder::start`.
struct InteractRunner<T>
where
    T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
{
    /// Shared transport. Locked while waiting for output and for each write.
    transport: Arc<Mutex<T>>,
    /// Output pattern hooks, checked after every output chunk.
    output_hooks: Vec<OutputPatternHook>,
    /// Input pattern hooks, checked against every processed input chunk.
    input_hooks: Vec<InputPatternHook>,
    // Resize events are only delivered on Unix (SIGWINCH), hence the Windows
    // dead-code allowance.
    #[cfg_attr(windows, allow(dead_code))]
    resize_hook: Option<ResizeHook>,
    /// Byte-transform hooks and `InteractionEvent` sink.
    hook_manager: HookManager,
    /// Mode settings; `read_timeout` bounds each wait for user input.
    mode: InteractionMode,
    /// Hook-processed output text that output patterns are matched against.
    buffer: String,
    /// Upper bound, in bytes, on `buffer`; older text is dropped first.
    buffer_size: usize,
    /// Input bytes that end the session; `None` disables the escape.
    escape_sequence: Option<Vec<u8>>,
    /// Observers fed every raw output chunk. Panics inside them are caught.
    output_taps: Vec<crate::session::OutputTap>,
    /// Overall time limit, measured from the start of the interact loop.
    timeout: Option<Duration>,
    /// Last known terminal size, reported as `previous` to the resize hook.
    #[cfg_attr(windows, allow(dead_code))]
    current_size: Option<TerminalSize>,
}
388
389impl<T> InteractRunner<T>
390where
391 T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
392{
393 #[allow(clippy::too_many_arguments)]
394 fn new(
395 transport: Arc<Mutex<T>>,
396 output_hooks: Vec<OutputPatternHook>,
397 input_hooks: Vec<InputPatternHook>,
398 resize_hook: Option<ResizeHook>,
399 hook_manager: HookManager,
400 mode: InteractionMode,
401 buffer_size: usize,
402 escape_sequence: Option<Vec<u8>>,
403 timeout: Option<Duration>,
404 output_taps: Vec<crate::session::OutputTap>,
405 ) -> Self {
406 let current_size = super::terminal::Terminal::size().ok();
408
409 Self {
410 transport,
411 output_hooks,
412 input_hooks,
413 resize_hook,
414 hook_manager,
415 mode,
416 buffer: String::with_capacity(buffer_size),
417 buffer_size,
418 escape_sequence,
419 timeout,
420 current_size,
421 output_taps,
422 }
423 }
424
425 fn fire_taps(&self, chunk: &[u8]) {
429 for tap in &self.output_taps {
430 let tap_clone = tap.clone();
431 let chunk_ref = chunk;
432 let result =
433 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| tap_clone(chunk_ref)));
434 if result.is_err() {
435 tracing::warn!("output tap panicked during interact; caught and continuing");
436 }
437 }
438 }
439
440 async fn run(&mut self) -> Result<InteractResult> {
441 #[cfg(unix)]
442 {
443 self.run_with_signals().await
444 }
445 #[cfg(not(unix))]
446 {
447 self.run_without_signals().await
448 }
449 }
450
451 #[cfg(unix)]
453 #[allow(clippy::significant_drop_tightening)]
454 async fn run_with_signals(&mut self) -> Result<InteractResult> {
455 use tokio::io::{BufReader, stdin, stdout};
456
457 self.hook_manager.notify(&InteractionEvent::Started);
458
459 let mut stdin = BufReader::new(stdin());
460 let mut input_buf = [0u8; 1024];
461 let mut output_buf = [0u8; 4096];
462 let mut escape_buf: Vec<u8> = Vec::new();
463
464 let deadline = self.timeout.map(|t| std::time::Instant::now() + t);
465
466 let mut sigwinch =
468 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::window_change())
469 .map_err(ExpectError::Io)?;
470
471 loop {
472 if let Some(deadline) = deadline
474 && std::time::Instant::now() >= deadline
475 {
476 self.hook_manager.notify(&InteractionEvent::Ended);
477 return Ok(InteractResult {
478 reason: InteractEndReason::Timeout,
479 buffer: self.buffer.clone(),
480 });
481 }
482
483 let read_timeout = self.mode.read_timeout;
484 let mut transport = self.transport.lock().await;
485
486 tokio::select! {
487 _ = sigwinch.recv() => {
489 drop(transport); if let Some(result) = self.handle_resize().await? {
492 return Ok(result);
493 }
494 }
495
496 result = transport.read(&mut output_buf) => {
498 drop(transport); match result {
500 Ok(0) => {
501 self.hook_manager.notify(&InteractionEvent::Ended);
502 return Ok(InteractResult {
503 reason: InteractEndReason::Eof,
504 buffer: self.buffer.clone(),
505 });
506 }
507 Ok(n) => {
508 let data = &output_buf[..n];
509 self.fire_taps(data);
513 let processed = self.hook_manager.process_output(data.to_vec());
514
515 self.hook_manager.notify(&InteractionEvent::Output(processed.clone()));
516
517 let mut stdout = stdout();
519 let _ = stdout.write_all(&processed).await;
520 let _ = stdout.flush().await;
521
522 if let Ok(s) = std::str::from_utf8(&processed) {
524 self.buffer.push_str(s);
525 if self.buffer.len() > self.buffer_size {
527 let start = self.buffer.len() - self.buffer_size;
528 self.buffer = self.buffer[start..].to_string();
529 }
530 }
531
532 if let Some(result) = self.check_output_patterns().await? {
534 return Ok(result);
535 }
536 }
537 Err(e) => {
538 self.hook_manager.notify(&InteractionEvent::Ended);
539 return Err(ExpectError::Io(e));
540 }
541 }
542 }
543
544 result = tokio::time::timeout(read_timeout, stdin.read(&mut input_buf)) => {
546 drop(transport); if let Ok(Ok(n)) = result {
549 if n == 0 {
550 continue;
551 }
552
553 let data = &input_buf[..n];
554
555 if let Some(ref esc) = self.escape_sequence {
557 escape_buf.extend_from_slice(data);
558 if escape_buf.ends_with(esc) {
559 self.hook_manager.notify(&InteractionEvent::ExitRequested);
560 self.hook_manager.notify(&InteractionEvent::Ended);
561 return Ok(InteractResult {
562 reason: InteractEndReason::Escape,
563 buffer: self.buffer.clone(),
564 });
565 }
566 if escape_buf.len() > esc.len() {
568 escape_buf = escape_buf[escape_buf.len() - esc.len()..].to_vec();
569 }
570 }
571
572 let processed = self.hook_manager.process_input(data.to_vec());
574
575 self.hook_manager.notify(&InteractionEvent::Input(processed.clone()));
576
577 if let Some(result) = self.check_input_patterns(&processed).await? {
579 return Ok(result);
580 }
581
582 let mut transport = self.transport.lock().await;
584 transport.write_all(&processed).await.map_err(ExpectError::Io)?;
585 transport.flush().await.map_err(ExpectError::Io)?;
586 }
587 }
588 }
589 }
590 }
591
592 #[cfg(not(unix))]
594 #[allow(clippy::significant_drop_tightening)]
595 async fn run_without_signals(&mut self) -> Result<InteractResult> {
596 use tokio::io::{BufReader, stdin, stdout};
597
598 self.hook_manager.notify(&InteractionEvent::Started);
599
600 let mut stdin = BufReader::new(stdin());
601 let mut input_buf = [0u8; 1024];
602 let mut output_buf = [0u8; 4096];
603 let mut escape_buf: Vec<u8> = Vec::new();
604
605 let deadline = self.timeout.map(|t| std::time::Instant::now() + t);
606
607 loop {
608 if let Some(deadline) = deadline {
610 if std::time::Instant::now() >= deadline {
611 self.hook_manager.notify(&InteractionEvent::Ended);
612 return Ok(InteractResult {
613 reason: InteractEndReason::Timeout,
614 buffer: self.buffer.clone(),
615 });
616 }
617 }
618
619 let read_timeout = self.mode.read_timeout;
620 let mut transport = self.transport.lock().await;
621
622 tokio::select! {
623 result = transport.read(&mut output_buf) => {
625 drop(transport); match result {
627 Ok(0) => {
628 self.hook_manager.notify(&InteractionEvent::Ended);
629 return Ok(InteractResult {
630 reason: InteractEndReason::Eof,
631 buffer: self.buffer.clone(),
632 });
633 }
634 Ok(n) => {
635 let data = &output_buf[..n];
636 self.fire_taps(data);
637 let processed = self.hook_manager.process_output(data.to_vec());
638
639 self.hook_manager.notify(&InteractionEvent::Output(processed.clone()));
640
641 let mut stdout = stdout();
643 let _ = stdout.write_all(&processed).await;
644 let _ = stdout.flush().await;
645
646 if let Ok(s) = std::str::from_utf8(&processed) {
648 self.buffer.push_str(s);
649 if self.buffer.len() > self.buffer_size {
651 let start = self.buffer.len() - self.buffer_size;
652 self.buffer = self.buffer[start..].to_string();
653 }
654 }
655
656 if let Some(result) = self.check_output_patterns().await? {
658 return Ok(result);
659 }
660 }
661 Err(e) => {
662 self.hook_manager.notify(&InteractionEvent::Ended);
663 return Err(ExpectError::Io(e));
664 }
665 }
666 }
667
668 result = tokio::time::timeout(read_timeout, stdin.read(&mut input_buf)) => {
670 drop(transport); if let Ok(Ok(n)) = result {
673 if n == 0 {
674 continue;
675 }
676
677 let data = &input_buf[..n];
678
679 if let Some(ref esc) = self.escape_sequence {
681 escape_buf.extend_from_slice(data);
682 if escape_buf.ends_with(esc) {
683 self.hook_manager.notify(&InteractionEvent::ExitRequested);
684 self.hook_manager.notify(&InteractionEvent::Ended);
685 return Ok(InteractResult {
686 reason: InteractEndReason::Escape,
687 buffer: self.buffer.clone(),
688 });
689 }
690 if escape_buf.len() > esc.len() {
692 escape_buf = escape_buf[escape_buf.len() - esc.len()..].to_vec();
693 }
694 }
695
696 let processed = self.hook_manager.process_input(data.to_vec());
698
699 self.hook_manager.notify(&InteractionEvent::Input(processed.clone()));
700
701 if let Some(result) = self.check_input_patterns(&processed).await? {
703 return Ok(result);
704 }
705
706 let mut transport = self.transport.lock().await;
708 transport.write_all(&processed).await.map_err(ExpectError::Io)?;
709 transport.flush().await.map_err(ExpectError::Io)?;
710 }
711 }
712 }
713 }
714 }
715
716 #[allow(clippy::significant_drop_tightening)]
717 async fn check_output_patterns(&mut self) -> Result<Option<InteractResult>> {
718 for (index, hook) in self.output_hooks.iter().enumerate() {
719 if let Some(m) = hook.pattern.matches(&self.buffer) {
720 let matched = &self.buffer[m.start..m.end];
721 let before = &self.buffer[..m.start];
722 let after = &self.buffer[m.end..];
723
724 let ctx = InteractContext {
725 matched,
726 before,
727 after,
728 buffer: &self.buffer,
729 pattern_index: index,
730 };
731
732 match (hook.callback)(&ctx) {
733 InteractAction::Continue => {
734 self.buffer = after.to_string();
736 }
737 InteractAction::Send(data) => {
738 let mut transport = self.transport.lock().await;
739 transport.write_all(&data).await.map_err(ExpectError::Io)?;
740 transport.flush().await.map_err(ExpectError::Io)?;
741 self.buffer = after.to_string();
743 }
744 InteractAction::Stop => {
745 self.hook_manager.notify(&InteractionEvent::Ended);
746 return Ok(Some(InteractResult {
747 reason: InteractEndReason::PatternStop {
748 pattern_index: index,
749 },
750 buffer: self.buffer.clone(),
751 }));
752 }
753 InteractAction::Error(msg) => {
754 self.hook_manager.notify(&InteractionEvent::Ended);
755 return Ok(Some(InteractResult {
756 reason: InteractEndReason::Error(msg),
757 buffer: self.buffer.clone(),
758 }));
759 }
760 }
761 }
762 }
763 Ok(None)
764 }
765
766 #[allow(clippy::significant_drop_tightening)]
767 async fn check_input_patterns(&self, input: &[u8]) -> Result<Option<InteractResult>> {
768 let input_str = String::from_utf8_lossy(input);
769
770 for (index, hook) in self.input_hooks.iter().enumerate() {
771 if let Some(m) = hook.pattern.matches(&input_str) {
772 let matched = &input_str[m.start..m.end];
773 let before = &input_str[..m.start];
774 let after = &input_str[m.end..];
775
776 let ctx = InteractContext {
777 matched,
778 before,
779 after,
780 buffer: &input_str,
781 pattern_index: index,
782 };
783
784 match (hook.callback)(&ctx) {
785 InteractAction::Continue => {}
786 InteractAction::Send(data) => {
787 let mut transport = self.transport.lock().await;
788 transport.write_all(&data).await.map_err(ExpectError::Io)?;
789 transport.flush().await.map_err(ExpectError::Io)?;
790 }
791 InteractAction::Stop => {
792 return Ok(Some(InteractResult {
793 reason: InteractEndReason::PatternStop {
794 pattern_index: index,
795 },
796 buffer: self.buffer.clone(),
797 }));
798 }
799 InteractAction::Error(msg) => {
800 return Ok(Some(InteractResult {
801 reason: InteractEndReason::Error(msg),
802 buffer: self.buffer.clone(),
803 }));
804 }
805 }
806 }
807 }
808 Ok(None)
809 }
810
811 #[cfg_attr(windows, allow(dead_code))]
816 #[allow(clippy::significant_drop_tightening)]
817 async fn handle_resize(&mut self) -> Result<Option<InteractResult>> {
818 let Ok(new_size) = super::terminal::Terminal::size() else {
820 return Ok(None); };
822
823 let ctx = ResizeContext {
825 size: new_size,
826 previous: self.current_size,
827 };
828
829 self.hook_manager.notify(&InteractionEvent::Resize {
831 cols: new_size.cols,
832 rows: new_size.rows,
833 });
834
835 self.current_size = Some(new_size);
837
838 if let Some(ref hook) = self.resize_hook {
840 match hook(&ctx) {
841 InteractAction::Continue => {}
842 InteractAction::Send(data) => {
843 let mut transport = self.transport.lock().await;
844 transport.write_all(&data).await.map_err(ExpectError::Io)?;
845 transport.flush().await.map_err(ExpectError::Io)?;
846 }
847 InteractAction::Stop => {
848 self.hook_manager.notify(&InteractionEvent::Ended);
849 return Ok(Some(InteractResult {
850 reason: InteractEndReason::PatternStop { pattern_index: 0 },
851 buffer: self.buffer.clone(),
852 }));
853 }
854 InteractAction::Error(msg) => {
855 self.hook_manager.notify(&InteractionEvent::Ended);
856 return Ok(Some(InteractResult {
857 reason: InteractEndReason::Error(msg),
858 buffer: self.buffer.clone(),
859 }));
860 }
861 }
862 }
863
864 Ok(None)
865 }
866}