use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::Mutex;
use super::hooks::{HookManager, InteractionEvent};
use super::mode::InteractionMode;
use super::terminal::TerminalSize;
use crate::error::{ExpectError, Result};
use crate::expect::Pattern;
/// What the interact loop should do after a pattern or resize hook runs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InteractAction {
    /// Keep relaying; nothing extra is sent to the child.
    Continue,
    /// Write these bytes to the child transport, then keep relaying.
    Send(Vec<u8>),
    /// End the session, reporting [`InteractEndReason::PatternStop`].
    Stop,
    /// End the session, reporting [`InteractEndReason::Error`] with this message.
    Error(String),
}
impl InteractAction {
    /// Builds an [`InteractAction::Send`] carrying the UTF-8 bytes of `s`.
    pub fn send(s: impl Into<String>) -> Self {
        let text: String = s.into();
        let payload = text.into_bytes();
        Self::Send(payload)
    }

    /// Builds an [`InteractAction::Send`] carrying `data` unchanged.
    pub fn send_bytes(data: impl Into<Vec<u8>>) -> Self {
        let payload: Vec<u8> = data.into();
        Self::Send(payload)
    }
}
/// Match details handed to a pattern hook callback.
///
/// Every string field borrows from the text that was searched, so the
/// context is only valid for the duration of the callback.
#[derive(Debug, Clone, Copy)]
pub struct InteractContext<'a> {
    /// The text matched by the pattern.
    pub matched: &'a str,
    /// The searched text preceding the match.
    pub before: &'a str,
    /// The searched text following the match.
    pub after: &'a str,
    /// The whole searched text: the buffered child output for output hooks,
    /// or the current input chunk for input hooks.
    pub buffer: &'a str,
    /// Position of the firing hook within its own list (output and input
    /// hooks are numbered independently, in registration order).
    pub pattern_index: usize,
}
impl InteractContext<'_> {
    /// Returns an action that sends `data` to the child as-is.
    pub fn send(&self, data: impl Into<String>) -> InteractAction {
        let text: String = data.into();
        InteractAction::send(text)
    }

    /// Returns an action that sends `data` followed by a single `'\n'`.
    pub fn send_line(&self, data: impl Into<String>) -> InteractAction {
        let body: String = data.into();
        self.send(body + "\n")
    }
}
/// Callback run when an input or output pattern matches; its return value
/// tells the interact loop how to proceed.
pub type PatternHook =
    Box<dyn for<'ctx> Fn(&InteractContext<'ctx>) -> InteractAction + Send + Sync>;

/// Information passed to the resize hook.
#[derive(Debug, Clone, Copy)]
pub struct ResizeContext {
    /// Terminal size reported after the resize.
    pub size: TerminalSize,
    /// Size known before this resize, if it could be determined.
    pub previous: Option<TerminalSize>,
}

/// Callback run when the local terminal is resized.
pub type ResizeHook = Box<dyn for<'r> Fn(&'r ResizeContext) -> InteractAction + Send + Sync>;

/// A registered pattern that is checked against buffered child output.
struct OutputPatternHook {
    /// Pattern searched for in the output buffer.
    pattern: Pattern,
    /// Callback invoked on a match.
    callback: PatternHook,
}

/// A registered pattern that is checked against each chunk of user input.
struct InputPatternHook {
    /// Pattern searched for in the input chunk.
    pattern: Pattern,
    /// Callback invoked on a match.
    callback: PatternHook,
}
pub struct InteractBuilder<'a, T>
where
T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
{
transport: &'a Arc<Mutex<T>>,
output_hooks: Vec<OutputPatternHook>,
input_hooks: Vec<InputPatternHook>,
resize_hook: Option<ResizeHook>,
hook_manager: HookManager,
mode: InteractionMode,
buffer_size: usize,
escape_sequence: Option<Vec<u8>>,
timeout: Option<Duration>,
}
impl<'a, T> InteractBuilder<'a, T>
where
T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
{
pub(crate) fn new(transport: &'a Arc<Mutex<T>>) -> Self {
Self {
transport,
output_hooks: Vec::new(),
input_hooks: Vec::new(),
resize_hook: None,
hook_manager: HookManager::new(),
mode: InteractionMode::default(),
buffer_size: 8192,
escape_sequence: Some(vec![0x1d]), timeout: None,
}
}
#[must_use]
pub fn on_output<F>(mut self, pattern: impl Into<Pattern>, callback: F) -> Self
where
F: Fn(&InteractContext<'_>) -> InteractAction + Send + Sync + 'static,
{
self.output_hooks.push(OutputPatternHook {
pattern: pattern.into(),
callback: Box::new(callback),
});
self
}
#[must_use]
pub fn on_input<F>(mut self, pattern: impl Into<Pattern>, callback: F) -> Self
where
F: Fn(&InteractContext<'_>) -> InteractAction + Send + Sync + 'static,
{
self.input_hooks.push(InputPatternHook {
pattern: pattern.into(),
callback: Box::new(callback),
});
self
}
#[must_use]
pub fn on_resize<F>(mut self, callback: F) -> Self
where
F: Fn(&ResizeContext) -> InteractAction + Send + Sync + 'static,
{
self.resize_hook = Some(Box::new(callback));
self
}
#[must_use]
pub const fn with_mode(mut self, mode: InteractionMode) -> Self {
self.mode = mode;
self
}
#[must_use]
pub fn with_escape(mut self, escape: impl Into<Vec<u8>>) -> Self {
self.escape_sequence = Some(escape.into());
self
}
#[must_use]
pub fn no_escape(mut self) -> Self {
self.escape_sequence = None;
self
}
#[must_use]
pub const fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = Some(timeout);
self
}
#[must_use]
pub const fn with_buffer_size(mut self, size: usize) -> Self {
self.buffer_size = size;
self
}
#[must_use]
pub fn with_input_hook<F>(mut self, hook: F) -> Self
where
F: Fn(&[u8]) -> Vec<u8> + Send + Sync + 'static,
{
self.hook_manager.add_input_hook(hook);
self
}
#[must_use]
pub fn with_output_hook<F>(mut self, hook: F) -> Self
where
F: Fn(&[u8]) -> Vec<u8> + Send + Sync + 'static,
{
self.hook_manager.add_output_hook(hook);
self
}
pub async fn start(self) -> Result<InteractResult> {
let mut runner = InteractRunner::new(
Arc::clone(self.transport),
self.output_hooks,
self.input_hooks,
self.resize_hook,
self.hook_manager,
self.mode,
self.buffer_size,
self.escape_sequence,
self.timeout,
);
runner.run().await
}
}
/// Outcome of a finished interactive session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InteractResult {
    /// Why the session ended.
    pub reason: InteractEndReason,
    /// Snapshot of the output match buffer when the session ended: the most
    /// recent child output (bounded by the configured buffer size), minus any
    /// text already consumed by output hooks that returned `Continue`/`Send`.
    pub buffer: String,
}

/// Why an interactive session ended.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InteractEndReason {
    /// A hook returned `Stop`.
    PatternStop {
        /// Position of the hook within its own list (output or input hooks);
        /// a stop requested by the resize hook reports `0`.
        pattern_index: usize,
    },
    /// The user typed the escape sequence.
    Escape,
    /// The overall session timeout elapsed.
    Timeout,
    /// The child transport reached end of file.
    Eof,
    /// A hook returned `Error` with this message.
    Error(String),
}
struct InteractRunner<T>
where
T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
{
transport: Arc<Mutex<T>>,
output_hooks: Vec<OutputPatternHook>,
input_hooks: Vec<InputPatternHook>,
#[cfg_attr(windows, allow(dead_code))]
resize_hook: Option<ResizeHook>,
hook_manager: HookManager,
mode: InteractionMode,
buffer: String,
buffer_size: usize,
escape_sequence: Option<Vec<u8>>,
timeout: Option<Duration>,
#[cfg_attr(windows, allow(dead_code))]
current_size: Option<TerminalSize>,
}
impl<T> InteractRunner<T>
where
T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
{
#[allow(clippy::too_many_arguments)]
fn new(
transport: Arc<Mutex<T>>,
output_hooks: Vec<OutputPatternHook>,
input_hooks: Vec<InputPatternHook>,
resize_hook: Option<ResizeHook>,
hook_manager: HookManager,
mode: InteractionMode,
buffer_size: usize,
escape_sequence: Option<Vec<u8>>,
timeout: Option<Duration>,
) -> Self {
let current_size = super::terminal::Terminal::size().ok();
Self {
transport,
output_hooks,
input_hooks,
resize_hook,
hook_manager,
mode,
buffer: String::with_capacity(buffer_size),
buffer_size,
escape_sequence,
timeout,
current_size,
}
}
async fn run(&mut self) -> Result<InteractResult> {
#[cfg(unix)]
{
self.run_with_signals().await
}
#[cfg(not(unix))]
{
self.run_without_signals().await
}
}
#[cfg(unix)]
#[allow(clippy::significant_drop_tightening)]
async fn run_with_signals(&mut self) -> Result<InteractResult> {
use tokio::io::{BufReader, stdin, stdout};
self.hook_manager.notify(&InteractionEvent::Started);
let mut stdin = BufReader::new(stdin());
let mut input_buf = [0u8; 1024];
let mut output_buf = [0u8; 4096];
let mut escape_buf: Vec<u8> = Vec::new();
let deadline = self.timeout.map(|t| std::time::Instant::now() + t);
let mut sigwinch =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::window_change())
.map_err(ExpectError::Io)?;
loop {
if let Some(deadline) = deadline
&& std::time::Instant::now() >= deadline
{
self.hook_manager.notify(&InteractionEvent::Ended);
return Ok(InteractResult {
reason: InteractEndReason::Timeout,
buffer: self.buffer.clone(),
});
}
let read_timeout = self.mode.read_timeout;
let mut transport = self.transport.lock().await;
tokio::select! {
_ = sigwinch.recv() => {
drop(transport);
if let Some(result) = self.handle_resize().await? {
return Ok(result);
}
}
result = transport.read(&mut output_buf) => {
drop(transport); match result {
Ok(0) => {
self.hook_manager.notify(&InteractionEvent::Ended);
return Ok(InteractResult {
reason: InteractEndReason::Eof,
buffer: self.buffer.clone(),
});
}
Ok(n) => {
let data = &output_buf[..n];
let processed = self.hook_manager.process_output(data.to_vec());
self.hook_manager.notify(&InteractionEvent::Output(processed.clone()));
let mut stdout = stdout();
let _ = stdout.write_all(&processed).await;
let _ = stdout.flush().await;
if let Ok(s) = std::str::from_utf8(&processed) {
self.buffer.push_str(s);
if self.buffer.len() > self.buffer_size {
let start = self.buffer.len() - self.buffer_size;
self.buffer = self.buffer[start..].to_string();
}
}
if let Some(result) = self.check_output_patterns().await? {
return Ok(result);
}
}
Err(e) => {
self.hook_manager.notify(&InteractionEvent::Ended);
return Err(ExpectError::Io(e));
}
}
}
result = tokio::time::timeout(read_timeout, stdin.read(&mut input_buf)) => {
drop(transport);
if let Ok(Ok(n)) = result {
if n == 0 {
continue;
}
let data = &input_buf[..n];
if let Some(ref esc) = self.escape_sequence {
escape_buf.extend_from_slice(data);
if escape_buf.ends_with(esc) {
self.hook_manager.notify(&InteractionEvent::ExitRequested);
self.hook_manager.notify(&InteractionEvent::Ended);
return Ok(InteractResult {
reason: InteractEndReason::Escape,
buffer: self.buffer.clone(),
});
}
if escape_buf.len() > esc.len() {
escape_buf = escape_buf[escape_buf.len() - esc.len()..].to_vec();
}
}
let processed = self.hook_manager.process_input(data.to_vec());
self.hook_manager.notify(&InteractionEvent::Input(processed.clone()));
if let Some(result) = self.check_input_patterns(&processed).await? {
return Ok(result);
}
let mut transport = self.transport.lock().await;
transport.write_all(&processed).await.map_err(ExpectError::Io)?;
transport.flush().await.map_err(ExpectError::Io)?;
}
}
}
}
}
#[cfg(not(unix))]
#[allow(clippy::significant_drop_tightening)]
async fn run_without_signals(&mut self) -> Result<InteractResult> {
use tokio::io::{BufReader, stdin, stdout};
self.hook_manager.notify(&InteractionEvent::Started);
let mut stdin = BufReader::new(stdin());
let mut input_buf = [0u8; 1024];
let mut output_buf = [0u8; 4096];
let mut escape_buf: Vec<u8> = Vec::new();
let deadline = self.timeout.map(|t| std::time::Instant::now() + t);
loop {
if let Some(deadline) = deadline {
if std::time::Instant::now() >= deadline {
self.hook_manager.notify(&InteractionEvent::Ended);
return Ok(InteractResult {
reason: InteractEndReason::Timeout,
buffer: self.buffer.clone(),
});
}
}
let read_timeout = self.mode.read_timeout;
let mut transport = self.transport.lock().await;
tokio::select! {
result = transport.read(&mut output_buf) => {
drop(transport); match result {
Ok(0) => {
self.hook_manager.notify(&InteractionEvent::Ended);
return Ok(InteractResult {
reason: InteractEndReason::Eof,
buffer: self.buffer.clone(),
});
}
Ok(n) => {
let data = &output_buf[..n];
let processed = self.hook_manager.process_output(data.to_vec());
self.hook_manager.notify(&InteractionEvent::Output(processed.clone()));
let mut stdout = stdout();
let _ = stdout.write_all(&processed).await;
let _ = stdout.flush().await;
if let Ok(s) = std::str::from_utf8(&processed) {
self.buffer.push_str(s);
if self.buffer.len() > self.buffer_size {
let start = self.buffer.len() - self.buffer_size;
self.buffer = self.buffer[start..].to_string();
}
}
if let Some(result) = self.check_output_patterns().await? {
return Ok(result);
}
}
Err(e) => {
self.hook_manager.notify(&InteractionEvent::Ended);
return Err(ExpectError::Io(e));
}
}
}
result = tokio::time::timeout(read_timeout, stdin.read(&mut input_buf)) => {
drop(transport);
if let Ok(Ok(n)) = result {
if n == 0 {
continue;
}
let data = &input_buf[..n];
if let Some(ref esc) = self.escape_sequence {
escape_buf.extend_from_slice(data);
if escape_buf.ends_with(esc) {
self.hook_manager.notify(&InteractionEvent::ExitRequested);
self.hook_manager.notify(&InteractionEvent::Ended);
return Ok(InteractResult {
reason: InteractEndReason::Escape,
buffer: self.buffer.clone(),
});
}
if escape_buf.len() > esc.len() {
escape_buf = escape_buf[escape_buf.len() - esc.len()..].to_vec();
}
}
let processed = self.hook_manager.process_input(data.to_vec());
self.hook_manager.notify(&InteractionEvent::Input(processed.clone()));
if let Some(result) = self.check_input_patterns(&processed).await? {
return Ok(result);
}
let mut transport = self.transport.lock().await;
transport.write_all(&processed).await.map_err(ExpectError::Io)?;
transport.flush().await.map_err(ExpectError::Io)?;
}
}
}
}
}
#[allow(clippy::significant_drop_tightening)]
async fn check_output_patterns(&mut self) -> Result<Option<InteractResult>> {
for (index, hook) in self.output_hooks.iter().enumerate() {
if let Some(m) = hook.pattern.matches(&self.buffer) {
let matched = &self.buffer[m.start..m.end];
let before = &self.buffer[..m.start];
let after = &self.buffer[m.end..];
let ctx = InteractContext {
matched,
before,
after,
buffer: &self.buffer,
pattern_index: index,
};
match (hook.callback)(&ctx) {
InteractAction::Continue => {
self.buffer = after.to_string();
}
InteractAction::Send(data) => {
let mut transport = self.transport.lock().await;
transport.write_all(&data).await.map_err(ExpectError::Io)?;
transport.flush().await.map_err(ExpectError::Io)?;
self.buffer = after.to_string();
}
InteractAction::Stop => {
self.hook_manager.notify(&InteractionEvent::Ended);
return Ok(Some(InteractResult {
reason: InteractEndReason::PatternStop {
pattern_index: index,
},
buffer: self.buffer.clone(),
}));
}
InteractAction::Error(msg) => {
self.hook_manager.notify(&InteractionEvent::Ended);
return Ok(Some(InteractResult {
reason: InteractEndReason::Error(msg),
buffer: self.buffer.clone(),
}));
}
}
}
}
Ok(None)
}
#[allow(clippy::significant_drop_tightening)]
async fn check_input_patterns(&self, input: &[u8]) -> Result<Option<InteractResult>> {
let input_str = String::from_utf8_lossy(input);
for (index, hook) in self.input_hooks.iter().enumerate() {
if let Some(m) = hook.pattern.matches(&input_str) {
let matched = &input_str[m.start..m.end];
let before = &input_str[..m.start];
let after = &input_str[m.end..];
let ctx = InteractContext {
matched,
before,
after,
buffer: &input_str,
pattern_index: index,
};
match (hook.callback)(&ctx) {
InteractAction::Continue => {}
InteractAction::Send(data) => {
let mut transport = self.transport.lock().await;
transport.write_all(&data).await.map_err(ExpectError::Io)?;
transport.flush().await.map_err(ExpectError::Io)?;
}
InteractAction::Stop => {
return Ok(Some(InteractResult {
reason: InteractEndReason::PatternStop {
pattern_index: index,
},
buffer: self.buffer.clone(),
}));
}
InteractAction::Error(msg) => {
return Ok(Some(InteractResult {
reason: InteractEndReason::Error(msg),
buffer: self.buffer.clone(),
}));
}
}
}
}
Ok(None)
}
#[cfg_attr(windows, allow(dead_code))]
#[allow(clippy::significant_drop_tightening)]
async fn handle_resize(&mut self) -> Result<Option<InteractResult>> {
let Ok(new_size) = super::terminal::Terminal::size() else {
return Ok(None); };
let ctx = ResizeContext {
size: new_size,
previous: self.current_size,
};
self.hook_manager.notify(&InteractionEvent::Resize {
cols: new_size.cols,
rows: new_size.rows,
});
self.current_size = Some(new_size);
if let Some(ref hook) = self.resize_hook {
match hook(&ctx) {
InteractAction::Continue => {}
InteractAction::Send(data) => {
let mut transport = self.transport.lock().await;
transport.write_all(&data).await.map_err(ExpectError::Io)?;
transport.flush().await.map_err(ExpectError::Io)?;
}
InteractAction::Stop => {
self.hook_manager.notify(&InteractionEvent::Ended);
return Ok(Some(InteractResult {
reason: InteractEndReason::PatternStop { pattern_index: 0 },
buffer: self.buffer.clone(),
}));
}
InteractAction::Error(msg) => {
self.hook_manager.notify(&InteractionEvent::Ended);
return Ok(Some(InteractResult {
reason: InteractEndReason::Error(msg),
buffer: self.buffer.clone(),
}));
}
}
}
Ok(None)
}
}