use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::Mutex;
#[cfg(unix)]
use crate::backend::{AsyncPty, PtyConfig, PtySpawner};
#[cfg(windows)]
use crate::backend::{PtyConfig, PtySpawner, WindowsAsyncPty};
use crate::config::SessionConfig;
use crate::dialog::{Dialog, DialogExecutor, DialogResult};
use crate::error::{ExpectError, Result};
use crate::expect::{ExpectState, MatchResult, Matcher, Pattern, PatternManager, PatternSet};
use crate::interact::InteractBuilder;
use crate::types::{ControlChar, Dimensions, Match, ProcessExitStatus, SessionId, SessionState};
/// An expect-style session that drives a process through an async byte
/// transport `T`.
///
/// Bytes read from the transport are accumulated in a [`Matcher`], which
/// performs pattern matching; writes go directly to the transport. The
/// transport lives behind an `Arc<Mutex<_>>` (tokio's async mutex) so the
/// same handle can be shared outside the session.
pub struct Session<T>
where
    T: AsyncReadExt + AsyncWriteExt + Unpin + Send,
{
    /// Shared, async-locked handle to the underlying transport.
    transport: Arc<Mutex<T>>,
    /// Configuration supplied at construction time.
    config: SessionConfig,
    /// Holds buffered process output and runs pattern matching over it.
    matcher: Matcher,
    /// Handlers consulted before each match attempt in `expect_any`.
    pattern_manager: PatternManager,
    /// Current lifecycle state of the session.
    state: SessionState,
    /// Identifier assigned when the session is created.
    id: SessionId,
    /// Set once a read from the transport reports end-of-stream.
    eof: bool,
}
impl<T: AsyncReadExt + AsyncWriteExt + Unpin + Send> Session<T> {
pub fn new(transport: T, config: SessionConfig) -> Self {
let buffer_size = config.buffer.max_size;
let mut matcher = Matcher::new(buffer_size);
matcher.set_default_timeout(config.timeout.default);
Self {
transport: Arc::new(Mutex::new(transport)),
config,
matcher,
pattern_manager: PatternManager::new(),
state: SessionState::Starting,
id: SessionId::new(),
eof: false,
}
}
#[must_use]
pub const fn id(&self) -> &SessionId {
&self.id
}
#[must_use]
pub const fn state(&self) -> SessionState {
self.state
}
#[must_use]
pub const fn config(&self) -> &SessionConfig {
&self.config
}
#[must_use]
pub const fn is_eof(&self) -> bool {
self.eof
}
#[must_use]
pub fn buffer(&mut self) -> String {
self.matcher.buffer_str()
}
pub fn clear_buffer(&mut self) {
self.matcher.clear();
}
#[must_use]
pub const fn pattern_manager(&self) -> &PatternManager {
&self.pattern_manager
}
pub const fn pattern_manager_mut(&mut self) -> &mut PatternManager {
&mut self.pattern_manager
}
pub const fn set_state(&mut self, state: SessionState) {
self.state = state;
}
#[allow(clippy::significant_drop_tightening)]
pub async fn send(&mut self, data: &[u8]) -> Result<()> {
if matches!(self.state, SessionState::Closed | SessionState::Exited(_)) {
return Err(ExpectError::SessionClosed);
}
let mut transport = self.transport.lock().await;
transport
.write_all(data)
.await
.map_err(|e| ExpectError::io_context("writing to process", e))?;
transport
.flush()
.await
.map_err(|e| ExpectError::io_context("flushing process output", e))?;
Ok(())
}
pub async fn send_str(&mut self, s: &str) -> Result<()> {
self.send(s.as_bytes()).await
}
pub async fn send_line(&mut self, line: &str) -> Result<()> {
let line_ending = self.config.line_ending.as_str();
let data = format!("{line}{line_ending}");
self.send(data.as_bytes()).await
}
pub async fn send_control(&mut self, ctrl: ControlChar) -> Result<()> {
self.send(&[ctrl.as_byte()]).await
}
pub async fn expect(&mut self, pattern: impl Into<Pattern>) -> Result<Match> {
let patterns = PatternSet::from_patterns(vec![pattern.into()]);
self.expect_any(&patterns).await
}
pub async fn expect_any(&mut self, patterns: &PatternSet) -> Result<Match> {
let timeout = self.matcher.get_timeout(patterns);
let state = ExpectState::new(patterns.clone(), timeout);
loop {
if let Some((_, action)) = self
.pattern_manager
.check_before(&self.matcher.buffer_str())
{
match action {
crate::expect::HandlerAction::Continue => {}
crate::expect::HandlerAction::Return(s) => {
return Ok(Match::new(0, s, String::new(), self.matcher.buffer_str()));
}
crate::expect::HandlerAction::Abort(msg) => {
return Err(ExpectError::PatternNotFound {
pattern: msg,
buffer: self.matcher.buffer_str(),
});
}
crate::expect::HandlerAction::Respond(s) => {
self.send_str(&s).await?;
}
}
}
if let Some(result) = self.matcher.try_match_any(patterns) {
return Ok(self.matcher.consume_match(&result));
}
if state.is_timed_out() {
return Err(ExpectError::Timeout {
duration: timeout,
pattern: patterns
.iter()
.next()
.map(|p| p.pattern.as_str().to_string())
.unwrap_or_default(),
buffer: self.matcher.buffer_str(),
});
}
if self.eof {
if state.expects_eof() {
return Ok(Match::new(
0,
String::new(),
self.matcher.buffer_str(),
String::new(),
));
}
return Err(ExpectError::Eof {
buffer: self.matcher.buffer_str(),
});
}
self.read_with_timeout(state.remaining_time()).await?;
}
}
pub async fn expect_timeout(
&mut self,
pattern: impl Into<Pattern>,
timeout: Duration,
) -> Result<Match> {
let pattern = pattern.into();
let mut patterns = PatternSet::new();
patterns.add(pattern).add(Pattern::timeout(timeout));
self.expect_any(&patterns).await
}
async fn read_with_timeout(&mut self, timeout: Duration) -> Result<usize> {
let mut buf = [0u8; 4096];
let mut transport = self.transport.lock().await;
match tokio::time::timeout(timeout, transport.read(&mut buf)).await {
Ok(Ok(0)) => {
self.eof = true;
Ok(0)
}
Ok(Ok(n)) => {
self.matcher.append(&buf[..n]);
Ok(n)
}
Ok(Err(e)) => {
if is_pty_eof_error(&e) {
self.eof = true;
Ok(0)
} else {
Err(ExpectError::io_context("reading from process", e))
}
}
Err(_) => {
Ok(0)
}
}
}
pub async fn wait(&mut self) -> Result<ProcessExitStatus> {
while !self.eof {
if self.read_with_timeout(Duration::from_millis(100)).await? == 0 && !self.eof {
tokio::time::sleep(Duration::from_millis(10)).await;
}
}
self.state = SessionState::Exited(ProcessExitStatus::Unknown);
Ok(ProcessExitStatus::Unknown)
}
pub async fn wait_timeout(&mut self, timeout: Duration) -> Result<ProcessExitStatus> {
let deadline = tokio::time::Instant::now() + timeout;
while !self.eof {
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
if remaining.is_zero() {
return Err(ExpectError::timeout(
timeout,
"<EOF>",
self.matcher.buffer_str(),
));
}
let poll_timeout = remaining.min(Duration::from_millis(100));
if self.read_with_timeout(poll_timeout).await? == 0 && !self.eof {
tokio::time::sleep(Duration::from_millis(10)).await;
}
}
self.state = SessionState::Exited(ProcessExitStatus::Unknown);
Ok(ProcessExitStatus::Unknown)
}
#[must_use]
pub fn check(&mut self, pattern: &Pattern) -> Option<MatchResult> {
self.matcher.try_match(pattern)
}
#[must_use]
pub const fn transport(&self) -> &Arc<Mutex<T>> {
&self.transport
}
#[must_use]
pub fn interact(&self) -> InteractBuilder<'_, T>
where
T: 'static,
{
InteractBuilder::new(&self.transport)
}
pub async fn run_dialog(&mut self, dialog: &Dialog) -> Result<DialogResult> {
let executor = DialogExecutor::default();
executor.execute(self, dialog).await
}
pub async fn run_dialog_with(
&mut self,
dialog: &Dialog,
executor: &DialogExecutor,
) -> Result<DialogResult> {
executor.execute(self, dialog).await
}
pub async fn expect_eof(&mut self) -> Result<Match> {
self.expect(Pattern::eof()).await
}
pub async fn expect_eof_timeout(&mut self, timeout: Duration) -> Result<Match> {
let mut patterns = PatternSet::new();
patterns.add(Pattern::eof()).add(Pattern::timeout(timeout));
self.expect_any(&patterns).await
}
pub async fn run_script<I, S>(&mut self, commands: I, prompt: Pattern) -> Result<Vec<Match>>
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
let mut results = Vec::new();
for cmd in commands {
self.send_line(cmd.as_ref()).await?;
let result = self.expect(prompt.clone()).await?;
results.push(result);
}
Ok(results)
}
pub async fn run_script_timeout<I, S>(
&mut self,
commands: I,
prompt: Pattern,
timeout: Duration,
) -> Result<Vec<Match>>
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
let mut results = Vec::new();
for cmd in commands {
self.send_line(cmd.as_ref()).await?;
let result = self.expect_timeout(prompt.clone(), timeout).await?;
results.push(result);
}
Ok(results)
}
pub async fn run_script_with_results<I, S>(
&mut self,
commands: I,
prompt: Pattern,
) -> (Vec<Match>, Option<ExpectError>)
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
let mut results = Vec::new();
for cmd in commands {
match self.send_line(cmd.as_ref()).await {
Ok(()) => {}
Err(e) => return (results, Some(e)),
}
match self.expect(prompt.clone()).await {
Ok(result) => results.push(result),
Err(e) => return (results, Some(e)),
}
}
(results, None)
}
}
/// Prints the session's id, state and EOF flag; the transport, config,
/// matcher and pattern manager are omitted (shown as `..`).
impl<T> std::fmt::Debug for Session<T>
where
    T: AsyncReadExt + AsyncWriteExt + Unpin + Send,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Session");
        out.field("id", &self.id);
        out.field("state", &self.state);
        out.field("eof", &self.eof);
        out.finish_non_exhaustive()
    }
}
#[cfg(unix)]
impl Session<AsyncPty> {
    /// Spawns `command` with `args` in a new PTY using the default config.
    ///
    /// # Errors
    ///
    /// Same as [`Session::spawn_with_config`].
    pub async fn spawn(command: &str, args: &[&str]) -> Result<Self> {
        let config = SessionConfig::default();
        Self::spawn_with_config(command, args, config).await
    }

    /// Spawns `command` with `args` in a new PTY configured from `config`.
    ///
    /// The returned session is already in [`SessionState::Running`].
    ///
    /// # Errors
    ///
    /// Propagates spawn failures, and reports an I/O error if the PTY handle
    /// cannot be wrapped as an async transport.
    pub async fn spawn_with_config(
        command: &str,
        args: &[&str],
        config: SessionConfig,
    ) -> Result<Self> {
        let spawner = PtySpawner::with_config(PtyConfig::from(&config));
        let owned_args: Vec<String> = args.iter().copied().map(String::from).collect();
        let handle = spawner.spawn(command, &owned_args).await?;
        let pty = AsyncPty::from_handle(handle)
            .map_err(|err| ExpectError::io_context("creating async PTY wrapper", err))?;
        let mut session = Self::new(pty, config);
        session.set_state(SessionState::Running);
        Ok(session)
    }

    /// Returns the child's process id, or `0` if the transport lock is
    /// currently held elsewhere (e.g. during a read).
    #[must_use]
    pub fn pid(&self) -> u32 {
        self.transport.try_lock().map_or(0, |pty| pty.pid())
    }

    /// Resizes the PTY to `cols` x `rows`, waiting for the transport lock.
    ///
    /// # Errors
    ///
    /// Whatever the PTY resize reports.
    pub async fn resize_pty(&mut self, cols: u16, rows: u16) -> Result<()> {
        let mut pty = self.transport.lock().await;
        pty.resize(cols, rows)
    }

    /// Sends the raw signal number `signal` to the child process.
    ///
    /// # Errors
    ///
    /// Fails with a `WouldBlock` I/O error if the transport lock is held
    /// elsewhere; otherwise reports whatever signal delivery returns.
    pub fn signal(&self, signal: i32) -> Result<()> {
        match self.transport.try_lock() {
            Ok(pty) => pty.signal(signal),
            Err(_) => Err(ExpectError::io_context(
                "sending signal to process",
                std::io::Error::new(std::io::ErrorKind::WouldBlock, "transport is locked"),
            )),
        }
    }

    /// Kills the child process.
    ///
    /// # Errors
    ///
    /// Fails with a `WouldBlock` I/O error if the transport lock is held
    /// elsewhere; otherwise reports whatever the kill returns.
    pub fn kill(&self) -> Result<()> {
        match self.transport.try_lock() {
            Ok(pty) => pty.kill(),
            Err(_) => Err(ExpectError::io_context(
                "killing process",
                std::io::Error::new(std::io::ErrorKind::WouldBlock, "transport is locked"),
            )),
        }
    }
}
#[cfg(windows)]
impl Session<WindowsAsyncPty> {
    /// Spawns `command` with `args` in a new pseudo console using the
    /// default config.
    ///
    /// # Errors
    ///
    /// Same as [`Session::spawn_with_config`].
    pub async fn spawn(command: &str, args: &[&str]) -> Result<Self> {
        let config = SessionConfig::default();
        Self::spawn_with_config(command, args, config).await
    }

    /// Spawns `command` with `args` configured from `config`.
    ///
    /// The returned session is already in [`SessionState::Running`].
    ///
    /// # Errors
    ///
    /// Propagates spawn failures.
    pub async fn spawn_with_config(
        command: &str,
        args: &[&str],
        config: SessionConfig,
    ) -> Result<Self> {
        let spawner = PtySpawner::with_config(PtyConfig::from(&config));
        let owned_args: Vec<String> = args.iter().map(|arg| String::from(*arg)).collect();
        let handle = spawner.spawn(command, &owned_args).await?;
        let mut session = Self::new(WindowsAsyncPty::from_handle(handle), config);
        session.set_state(SessionState::Running);
        Ok(session)
    }

    /// Returns the child's process id, or `0` if the transport lock is
    /// currently held elsewhere (e.g. during a read).
    #[must_use]
    pub fn pid(&self) -> u32 {
        self.transport.try_lock().map_or(0, |pty| pty.pid())
    }

    /// Resizes the pseudo console to `cols` x `rows`, waiting for the lock.
    ///
    /// # Errors
    ///
    /// Whatever the resize reports.
    pub async fn resize_pty(&mut self, cols: u16, rows: u16) -> Result<()> {
        let mut pty = self.transport.lock().await;
        pty.resize(cols, rows)
    }

    /// Reports whether the child is still running.
    ///
    /// When the transport lock is held elsewhere (e.g. mid-read) the process
    /// is assumed to be running and `true` is returned.
    #[must_use]
    pub fn is_running(&self) -> bool {
        match self.transport.try_lock() {
            Ok(pty) => pty.is_running(),
            Err(_) => true,
        }
    }

    /// Kills the child process.
    ///
    /// # Errors
    ///
    /// Fails with a `WouldBlock` I/O error if the transport lock is held
    /// elsewhere; otherwise reports whatever the kill returns.
    pub fn kill(&self) -> Result<()> {
        match self.transport.try_lock() {
            Ok(mut pty) => pty.kill(),
            Err(_) => Err(ExpectError::io_context(
                "killing process",
                std::io::Error::new(std::io::ErrorKind::WouldBlock, "transport is locked"),
            )),
        }
    }
}
/// Extension operations for session types.
///
/// Both methods return futures that are required to be `Send`.
pub trait SessionExt {
    /// Presumably sends `send` and then waits for `expect` to match.
    ///
    /// NOTE(review): no implementation is visible alongside this trait; the
    /// exact contract (e.g. whether a line ending is appended) is defined by
    /// implementors — confirm.
    fn send_expect(
        &mut self,
        send: &str,
        expect: impl Into<Pattern>,
    ) -> impl Send + std::future::Future<Output = Result<Match>>;

    /// Resizes the session's terminal to `dimensions` (implementor-defined).
    fn resize(
        &mut self,
        dimensions: Dimensions,
    ) -> impl Send + std::future::Future<Output = Result<()>>;
}
/// Returns `true` when a read error should be treated as end-of-stream
/// rather than a failure.
///
/// A `BrokenPipe` error counts on every platform. On Unix, a raw `EIO` also
/// counts; PTY masters commonly report `EIO` once the child side has closed.
fn is_pty_eof_error(e: &std::io::Error) -> bool {
    let broken_pipe = matches!(e.kind(), std::io::ErrorKind::BrokenPipe);

    #[cfg(unix)]
    let unix_eio = e.raw_os_error() == Some(libc::EIO);
    #[cfg(not(unix))]
    let unix_eio = false;

    broken_pipe || unix_eio
}