use std::time::Duration;
use regex::Regex;
use crate::tools::strip_ansi;
use super::config::ShellConfig;
/// Selects which signals `ReadinessDetector::check` consults.
#[derive(Debug, Clone)]
pub enum ReadinessStrategy {
    /// Only the silence timeout is consulted; prompt patterns are ignored.
    Timeout,
    /// Only prompt patterns are consulted; the silence timeout is ignored.
    Prompt,
    /// Prompt patterns are tried first, then the silence timeout.
    Hybrid,
}

impl ReadinessStrategy {
    /// Parses a strategy name case-insensitively.
    ///
    /// `"timeout"` and `"prompt"` select the matching variant. Every other
    /// value — including `"hybrid"`, the empty string, names with surrounding
    /// whitespace and typos — yields [`ReadinessStrategy::Hybrid`]; this
    /// function never fails.
    pub fn from_str(s: &str) -> Self {
        // Compare in place rather than allocating a lowercased copy of `s`;
        // both targets are pure ASCII, so the result is the same.
        if s.eq_ignore_ascii_case("timeout") {
            ReadinessStrategy::Timeout
        } else if s.eq_ignore_ascii_case("prompt") {
            ReadinessStrategy::Prompt
        } else {
            ReadinessStrategy::Hybrid
        }
    }
}
/// Outcome of a single `ReadinessDetector::check` call.
///
/// The enum is fieldless, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadinessResult {
    /// The last non-empty output line matched one of the prompt patterns.
    Ready,
    /// None of the ready/timeout conditions holds yet.
    Waiting,
    /// The silence duration reached the configured silence timeout.
    SilenceTimeout,
    /// The total elapsed time reached the configured maximum; this is checked
    /// before anything else.
    MaxTimeout,
}
/// Classifies captured output plus elapsed timings into a `ReadinessResult`.
///
/// Construct it with `ReadinessDetector::new` or `ReadinessDetector::from_config`.
pub struct ReadinessDetector {
/// Strategy applied by `check`; `new` replaces `Prompt`/`Hybrid` with `Timeout`
/// when no pattern compiled.
strategy: ReadinessStrategy,
/// Compiled prompt regexes; patterns that failed to compile are not stored.
patterns: Vec<Regex>,
/// Silence duration at which `check` reports `SilenceTimeout`
/// (consulted by the `Timeout` and `Hybrid` strategies only).
silence_timeout: Duration,
/// Total elapsed time at which `check` reports `MaxTimeout`, whatever the strategy.
max_timeout: Duration,
}
/// Number of trailing bytes of the output that `matches_prompt` inspects.
/// The cut is moved back to the nearest char boundary, so the inspected slice
/// can be up to three bytes longer than this.
const PROMPT_TAIL_LEN: usize = 200;
impl ReadinessDetector {
pub fn new(
strategy: ReadinessStrategy,
patterns_str: &[String],
silence_timeout_ms: u64,
max_timeout_ms: u64,
) -> Self {
let patterns: Vec<Regex> = patterns_str
.iter()
.filter_map(|p| match Regex::new(p) {
Ok(re) => Some(re),
Err(e) => {
tracing::warn!(pattern = %p, error = %e, "skipping invalid prompt regex");
None
}
})
.collect();
let strategy = if patterns.is_empty() {
match strategy {
ReadinessStrategy::Prompt | ReadinessStrategy::Hybrid => {
tracing::warn!(
"no valid prompt patterns — falling back to Timeout strategy"
);
ReadinessStrategy::Timeout
}
other => other,
}
} else {
strategy
};
Self {
strategy,
patterns,
silence_timeout: Duration::from_millis(silence_timeout_ms),
max_timeout: Duration::from_millis(max_timeout_ms),
}
}
pub fn from_config(config: &ShellConfig) -> Self {
let strategy = crate::tools::shell::readiness::ReadinessStrategy::from_str(&config.readiness_strategy);
Self::new(
strategy,
&config.prompt_patterns,
config.readiness_timeout_ms,
config.max_readiness_timeout_ms,
)
}
pub fn check(
&self,
output: &str,
silence_elapsed: Duration,
total_elapsed: Duration,
) -> ReadinessResult {
if total_elapsed >= self.max_timeout {
return ReadinessResult::MaxTimeout;
}
match &self.strategy {
ReadinessStrategy::Timeout => {
if silence_elapsed >= self.silence_timeout {
ReadinessResult::SilenceTimeout
} else {
ReadinessResult::Waiting
}
}
ReadinessStrategy::Prompt => {
if self.matches_prompt(output) {
ReadinessResult::Ready
} else {
ReadinessResult::Waiting
}
}
ReadinessStrategy::Hybrid => {
if self.matches_prompt(output) {
ReadinessResult::Ready
} else if silence_elapsed >= self.silence_timeout {
ReadinessResult::SilenceTimeout
} else {
ReadinessResult::Waiting
}
}
}
}
pub fn matches_prompt(&self, output: &str) -> bool {
if output.is_empty() || self.patterns.is_empty() {
return false;
}
let tail = if output.len() > PROMPT_TAIL_LEN {
let mut start = output.len() - PROMPT_TAIL_LEN;
while start > 0 && !output.is_char_boundary(start) {
start -= 1;
}
&output[start..]
} else {
output
};
let clean = strip_ansi(tail);
let last_line = clean
.lines()
.rev()
.find(|l| !l.trim().is_empty())
.unwrap_or("");
if last_line.is_empty() {
return false;
}
self.patterns.iter().any(|re| re.is_match(last_line))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Detector using the `Timeout` strategy, no patterns and a 10 000 ms maximum.
    fn timeout_detector(silence_ms: u64) -> ReadinessDetector {
        ReadinessDetector::new(ReadinessStrategy::Timeout, &[], silence_ms, 10_000)
    }

    /// Detector forced to the `Prompt` strategy with the default config's
    /// patterns and timeouts.
    fn prompt_detector() -> ReadinessDetector {
        let config = ShellConfig::default();
        ReadinessDetector::new(
            ReadinessStrategy::Prompt,
            &config.prompt_patterns,
            config.readiness_timeout_ms,
            config.max_readiness_timeout_ms,
        )
    }

    /// Detector built straight from the default config.
    // NOTE(review): the name assumes the default `readiness_strategy` resolves
    // to Hybrid and that the default patterns are valid — confirm in ShellConfig.
    fn hybrid_detector() -> ReadinessDetector {
        ReadinessDetector::from_config(&ShellConfig::default())
    }

    /// Builds output in which the tail cut made by `matches_prompt`
    /// (`len - PROMPT_TAIL_LEN`) lands exactly `offset` bytes after the start
    /// of `glyph`. `trailer` is appended last and is expected to end with a
    /// prompt line.
    ///
    /// The glyph has to be positioned relative to the END of the output,
    /// because that is where the tail is measured from.
    fn output_with_cut_inside_glyph(glyph: char, offset: usize, trailer: &str) -> String {
        let glyph_len = glyph.len_utf8();
        assert!(offset <= glyph_len, "offset must lie within the glyph");

        // Bytes that follow the cut: the rest of the glyph, the filler, the trailer.
        let filler_len = PROMPT_TAIL_LEN - (glyph_len - offset) - trailer.len();

        let mut output = "x".repeat(16);
        output.push(glyph);
        output.push_str(&"y".repeat(filler_len));
        output.push_str(trailer);

        // Guard the fixture itself: the cut must be mid-character unless the
        // offset sits on one of the glyph's edges.
        let cut = output.len() - PROMPT_TAIL_LEN;
        let on_edge = offset == 0 || offset == glyph_len;
        assert_eq!(output.is_char_boundary(cut), on_edge);
        output
    }

    #[test]
    fn timeout_strategy_silence_triggers() {
        let det = timeout_detector(300);
        let result = det.check(
            "some output\n",
            Duration::from_millis(301),
            Duration::from_millis(500),
        );
        assert_eq!(result, ReadinessResult::SilenceTimeout);
    }

    #[test]
    fn timeout_strategy_still_waiting() {
        let det = timeout_detector(300);
        let result = det.check(
            "some output\n",
            Duration::from_millis(100),
            Duration::from_millis(200),
        );
        assert_eq!(result, ReadinessResult::Waiting);
    }

    #[test]
    fn prompt_strategy_dollar_ready() {
        let det = prompt_detector();
        let result = det.check(
            "user@host:~$ ",
            Duration::from_millis(0),
            Duration::from_millis(50),
        );
        assert_eq!(result, ReadinessResult::Ready);
    }

    #[test]
    fn prompt_strategy_python_ready() {
        let det = prompt_detector();
        let result = det.check(
            "Python 3.11.0\n>>> ",
            Duration::from_millis(0),
            Duration::from_millis(50),
        );
        assert_eq!(result, ReadinessResult::Ready);
    }

    /// The `Prompt` strategy never falls back to the silence timeout: with no
    /// prompt match it keeps waiting however long the output has been silent.
    #[test]
    fn prompt_strategy_no_match_ignores_silence_timeout() {
        let det = prompt_detector();
        let result = det.check(
            "compiling crate...\n",
            Duration::from_millis(500),
            Duration::from_millis(1000),
        );
        assert_eq!(result, ReadinessResult::Waiting);
    }

    #[test]
    fn prompt_strategy_no_match_waiting() {
        let det = prompt_detector();
        let result = det.check(
            "compiling crate...\n",
            Duration::from_millis(100),
            Duration::from_millis(200),
        );
        assert_eq!(result, ReadinessResult::Waiting);
    }

    #[test]
    fn hybrid_prompt_match_before_silence() {
        let det = hybrid_detector();
        let result = det.check(
            "welcome\nuser@host:~$ ",
            Duration::from_millis(10),
            Duration::from_millis(50),
        );
        assert_eq!(result, ReadinessResult::Ready);
    }

    #[test]
    fn hybrid_silence_fallback() {
        let det = hybrid_detector();
        let result = det.check(
            "running long task...\n",
            Duration::from_millis(500),
            Duration::from_millis(1000),
        );
        assert_eq!(result, ReadinessResult::SilenceTimeout);
    }

    /// The maximum timeout beats even a matching prompt.
    // NOTE(review): 10_001 ms exceeds `timeout_detector`'s 10_000 ms cap; for
    // the other two detectors this assumes the default
    // `max_readiness_timeout_ms` is at most 10_001 — confirm in ShellConfig.
    #[test]
    fn max_timeout_always_wins() {
        for det in [timeout_detector(300), prompt_detector(), hybrid_detector()] {
            let result = det.check(
                "user@host:~$ ",
                Duration::from_millis(0),
                Duration::from_millis(10_001),
            );
            assert_eq!(result, ReadinessResult::MaxTimeout);
        }
    }

    #[test]
    fn matches_prompt_common_patterns() {
        let det = hybrid_detector();
        let prompts = [
            "user@host:~$ ",
            "root@server:/var# ",
            ">>> ",
            "(gdb) ",
            "Password: ",
        ];
        for prompt in &prompts {
            assert!(
                det.matches_prompt(prompt),
                "expected match for prompt: {:?}",
                prompt,
            );
        }
    }

    #[test]
    fn invalid_regex_skipped_no_panic() {
        let patterns: Vec<String> = vec!["[invalid(".into(), r"[$#] $".into()];
        let det = ReadinessDetector::new(ReadinessStrategy::Hybrid, &patterns, 300, 10_000);
        assert!(det.matches_prompt("user@host:~$ "));
    }

    #[test]
    fn all_invalid_patterns_fallback_to_timeout() {
        let patterns: Vec<String> = vec!["[broken(".into(), "(also[bad".into()];
        let det = ReadinessDetector::new(ReadinessStrategy::Hybrid, &patterns, 300, 10_000);
        assert!(!det.matches_prompt("user@host:~$ "));
        let result = det.check(
            "user@host:~$ ",
            Duration::from_millis(301),
            Duration::from_millis(500),
        );
        assert_eq!(result, ReadinessResult::SilenceTimeout);
    }

    #[test]
    fn empty_output_no_match() {
        let det = hybrid_detector();
        assert!(!det.matches_prompt(""));
    }

    #[test]
    fn ansi_stripped_before_matching() {
        let det = hybrid_detector();
        let ansi_prompt = "\x1b[32muser@host\x1b[0m:\x1b[34m~\x1b[0m$ ";
        assert!(det.matches_prompt(ansi_prompt));
    }

    #[test]
    fn only_last_line_checked() {
        let det = hybrid_detector();
        assert!(!det.matches_prompt("user@host:~$ \nstill running..."));
        assert!(det.matches_prompt("still running...\nuser@host:~$ "));
    }

    /// A 3-byte glyph straddling the tail cut must not cause a slicing panic,
    /// and the prompt on the last line must still be detected.
    #[test]
    fn matches_prompt_handles_multibyte_glyph_at_tail_boundary() {
        let det = hybrid_detector();
        let glyph = '●';
        for offset in 0..=glyph.len_utf8() {
            let output =
                output_with_cut_inside_glyph(glyph, offset, "\n~/repo on main\nuser@host:~$ ");
            assert!(
                det.matches_prompt(&output),
                "no prompt match with the cut {} byte(s) into the glyph",
                offset,
            );
        }
    }

    /// Same as above for a 4-byte emoji, covering every cut position from the
    /// emoji's first byte to the byte just past it.
    #[test]
    fn matches_prompt_handles_4byte_emoji_at_tail_boundary() {
        let det = hybrid_detector();
        let emoji = '🔥';
        for offset in 0..=emoji.len_utf8() {
            let output = output_with_cut_inside_glyph(emoji, offset, "\nuser@host:~$ ");
            assert!(
                det.matches_prompt(&output),
                "no prompt match with the cut {} byte(s) into the emoji",
                offset,
            );
        }
    }
}