use std::fmt;
use std::time::Duration;
use regex::Regex;
/// A condition that can be waited for.
///
/// Only [`Pattern::Literal`], [`Pattern::Regex`] and [`Pattern::Glob`] are
/// ever matched against text by [`Pattern::matches`]; the remaining variants
/// are markers that `matches` always answers with `None`.
#[derive(Clone)]
pub enum Pattern {
    /// An exact substring.
    Literal(String),
    /// A compiled regular expression together with its source text.
    Regex(CompiledRegex),
    /// A wildcard pattern handed to this module's `glob_match` helper.
    Glob(String),
    /// End-of-file marker (rendered as `<EOF>` by [`Pattern::as_str`]).
    Eof,
    /// Timeout marker carrying its duration.
    Timeout(Duration),
    /// Byte-count marker carrying the count.
    Bytes(usize),
}

impl Pattern {
    /// Builds a [`Pattern::Literal`] from anything convertible into a `String`.
    #[must_use]
    pub fn literal(s: impl Into<String>) -> Self {
        let needle: String = s.into();
        Self::Literal(needle)
    }

    /// Compiles `pattern` into a [`Pattern::Regex`], keeping the source text
    /// alongside the compiled form.
    ///
    /// # Errors
    ///
    /// Returns the `regex::Error` produced by `Regex::new` when `pattern`
    /// does not compile.
    pub fn regex(pattern: &str) -> Result<Self, regex::Error> {
        Regex::new(pattern)
            .map(|compiled| Self::Regex(CompiledRegex::new(pattern.to_owned(), compiled)))
    }

    /// Builds a [`Pattern::Glob`] from anything convertible into a `String`.
    #[must_use]
    pub fn glob(pattern: impl Into<String>) -> Self {
        let wildcard: String = pattern.into();
        Self::Glob(wildcard)
    }

    /// Returns the [`Pattern::Eof`] marker.
    #[must_use]
    pub const fn eof() -> Self {
        Self::Eof
    }

    /// Returns a [`Pattern::Timeout`] marker holding `duration`.
    #[must_use]
    pub const fn timeout(duration: Duration) -> Self {
        Self::Timeout(duration)
    }

    /// Returns a [`Pattern::Bytes`] marker holding `n`.
    #[must_use]
    pub const fn bytes(n: usize) -> Self {
        Self::Bytes(n)
    }

    /// Returns a textual form of the pattern: the stored text for literals,
    /// globs and regexes, and a fixed placeholder for the marker variants.
    #[must_use]
    pub fn as_str(&self) -> &str {
        match self {
            Self::Literal(text) | Self::Glob(text) => text,
            Self::Regex(compiled) => compiled.pattern(),
            Self::Eof => "<EOF>",
            Self::Timeout(_) => "<TIMEOUT>",
            Self::Bytes(_) => "<BYTES>",
        }
    }

    /// Searches `text` for this pattern.
    ///
    /// * Literal: the first occurrence of the substring; no captures.
    /// * Regex: the first regex match, with the capture groups reported by
    ///   [`CompiledRegex::captures`].
    /// * Glob: `start` is whatever `glob_match` reports and `end` is always
    ///   `text.len()`; no captures.
    /// * `Eof`, `Timeout`, `Bytes`: always `None`.
    #[must_use]
    pub fn matches(&self, text: &str) -> Option<PatternMatch> {
        let (start, end, captures) = match self {
            Self::Literal(needle) => {
                let start = text.find(needle.as_str())?;
                (start, start + needle.len(), Vec::new())
            }
            Self::Regex(compiled) => {
                let found = compiled.find(text)?;
                (found.start(), found.end(), compiled.captures(text))
            }
            Self::Glob(wildcard) => (glob_match(wildcard, text)?, text.len(), Vec::new()),
            Self::Eof | Self::Timeout(_) | Self::Bytes(_) => return None,
        };
        Some(PatternMatch {
            start,
            end,
            captures,
        })
    }

    /// Reports whether this is a [`Pattern::Timeout`].
    #[must_use]
    pub const fn is_timeout(&self) -> bool {
        self.timeout_duration().is_some()
    }

    /// Reports whether this is [`Pattern::Eof`].
    #[must_use]
    pub const fn is_eof(&self) -> bool {
        matches!(self, Self::Eof)
    }

    /// Returns the duration of a [`Pattern::Timeout`], or `None` for every
    /// other variant.
    #[must_use]
    pub const fn timeout_duration(&self) -> Option<Duration> {
        if let Self::Timeout(duration) = self {
            Some(*duration)
        } else {
            None
        }
    }

    /// Compiles `pattern` as a regex, or evaluates `fallback` when
    /// compilation fails. Shared by the infallible presets below.
    fn regex_or_else(pattern: &str, fallback: impl FnOnce() -> Self) -> Self {
        match Self::regex(pattern) {
            Ok(compiled) => compiled,
            Err(_) => fallback(),
        }
    }

    /// Regex preset: one of `$`, `#`, `>` or `%`, then optional whitespace,
    /// at the end of the text. Falls back to the literal `"$ "`.
    #[must_use]
    pub fn shell_prompt() -> Self {
        Self::regex_or_else(r"[$#>%]\s*$", || Self::Literal("$ ".to_owned()))
    }

    /// Glob preset `*$*`.
    #[must_use]
    pub fn any_prompt() -> Self {
        Self::Glob("*$*".to_owned())
    }

    /// Regex preset for a dotted quad whose four fields are each 0-255.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`Pattern::regex`].
    pub fn ipv4() -> Result<Self, regex::Error> {
        Self::regex(
            r"\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b",
        )
    }

    /// Regex preset for a `local@domain.tld` shaped address.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`Pattern::regex`].
    pub fn email() -> Result<Self, regex::Error> {
        Self::regex(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
    }

    /// Regex preset for `YYYY-MM-DD`, then `T` or a space, then `HH:MM:SS`
    /// (digit counts only; values are not range-checked).
    ///
    /// # Errors
    ///
    /// Propagates the error from [`Pattern::regex`].
    pub fn timestamp_iso8601() -> Result<Self, regex::Error> {
        Self::regex(r"\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}:\d{2}")
    }

    /// Case-insensitive regex preset for the whole words `error`, `failed`
    /// or `fatal`.
    ///
    /// NOTE(review): the fallback glob contains `[Ee]`, but the glob matcher
    /// in this file only interprets `*` and `?`, so the brackets would be
    /// matched literally. The fallback is only reached if the regex fails to
    /// compile.
    #[must_use]
    pub fn error_indicator() -> Self {
        Self::regex_or_else(r"(?i)\b(?:error|failed|fatal)\b", || {
            Self::Glob("*[Ee]rror*".to_owned())
        })
    }

    /// Case-insensitive regex preset for the whole words `success`,
    /// `successful`, `passed`, `complete` or `ok`.
    ///
    /// NOTE(review): as with `error_indicator`, the bracket expression in the
    /// fallback glob would be matched literally by this file's glob matcher.
    #[must_use]
    pub fn success_indicator() -> Self {
        Self::regex_or_else(r"(?i)\b(?:success|successful|passed|complete|ok)\b", || {
            Self::Glob("*[Ss]uccess*".to_owned())
        })
    }

    /// Case-insensitive regex preset for `password` or `passphrase` followed
    /// by a colon at the end of the text (whitespace allowed around the
    /// colon). Falls back to the literal `"password:"`.
    #[must_use]
    pub fn password_prompt() -> Self {
        Self::regex_or_else(r"(?i)(?:password|passphrase)\s*:\s*$", || {
            Self::Literal("password:".to_owned())
        })
    }

    /// Case-insensitive regex preset for `login`, `username` or `user`
    /// followed by a colon at the end of the text. Falls back to the literal
    /// `"login:"`.
    #[must_use]
    pub fn login_prompt() -> Self {
        Self::regex_or_else(r"(?i)(?:login|username|user)\s*:\s*$", || {
            Self::Literal("login:".to_owned())
        })
    }

    /// Regex preset for bracketed `[y/n]`-style choices (either letter in
    /// either case on both sides) or parenthesised `(yes/no)` with an
    /// optional leading capital. Falls back to the glob `*[y/n]*`.
    #[must_use]
    pub fn confirmation_prompt() -> Self {
        Self::regex_or_else(r"\[([yYnN])/([yYnN])\]|\(([yY]es)/([nN]o)\)", || {
            Self::Glob("*[y/n]*".to_owned())
        })
    }

    /// Case-insensitive regex preset for `continue?`, `press any key` or
    /// `hit enter`. Falls back to the glob `*continue*`.
    #[must_use]
    pub fn continue_prompt() -> Self {
        Self::regex_or_else(r"(?i)(?:continue\s*\?|press any key|hit enter)", || {
            Self::Glob("*continue*".to_owned())
        })
    }
}
/// Debug output is the variant name followed by its payload in parentheses;
/// a regex is shown by its source text rather than its compiled form.
impl fmt::Debug for Pattern {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // The only variant without a payload.
            Self::Eof => f.write_str("Eof"),
            Self::Literal(text) => write!(f, "Literal({text:?})"),
            Self::Glob(wildcard) => write!(f, "Glob({wildcard:?})"),
            Self::Regex(compiled) => write!(f, "Regex({:?})", compiled.pattern()),
            Self::Timeout(duration) => write!(f, "Timeout({duration:?})"),
            Self::Bytes(count) => write!(f, "Bytes({count})"),
        }
    }
}
/// A borrowed string converts into a literal pattern holding a copy of it.
impl From<&str> for Pattern {
    fn from(value: &str) -> Self {
        Self::Literal(value.to_owned())
    }
}
/// An owned string converts into a literal pattern without copying.
impl From<String> for Pattern {
    fn from(value: String) -> Self {
        let literal = value;
        Self::Literal(literal)
    }
}
/// A compiled [`Regex`] stored next to the pattern text supplied with it.
#[derive(Clone)]
pub struct CompiledRegex {
    /// Pattern text as given to [`CompiledRegex::new`].
    pattern: String,
    /// The compiled expression used for searching.
    regex: Regex,
}

impl CompiledRegex {
    /// Pairs `pattern` with `regex`. Nothing here checks that the two
    /// actually correspond; the caller is responsible for that.
    #[must_use]
    pub const fn new(pattern: String, regex: Regex) -> Self {
        Self { pattern, regex }
    }

    /// Returns the stored pattern text.
    #[must_use]
    pub fn pattern(&self) -> &str {
        self.pattern.as_str()
    }

    /// Returns the first match of the regex in `text`, if any.
    #[must_use]
    pub fn find<'a>(&self, text: &'a str) -> Option<regex::Match<'a>> {
        self.regex.find(text)
    }

    /// Returns the text of every capture group that took part in the first
    /// match, in group order.
    ///
    /// The whole-match group is skipped. Groups that did not participate are
    /// dropped rather than represented, so positions in the returned vector
    /// do not necessarily equal group numbers. Returns an empty vector when
    /// the regex does not match at all.
    #[must_use]
    pub fn captures(&self, text: &str) -> Vec<String> {
        let Some(groups) = self.regex.captures(text) else {
            return Vec::new();
        };
        groups
            .iter()
            .skip(1)
            .flatten()
            .map(|group| group.as_str().to_owned())
            .collect()
    }
}
/// The location of a match inside the searched text, plus any captured
/// groups.
#[derive(Debug, Clone)]
pub struct PatternMatch {
    /// Offset where the match begins; used as a byte index by
    /// [`PatternMatch::as_str`].
    pub start: usize,
    /// Offset one past the end of the match; used as a byte index by
    /// [`PatternMatch::as_str`].
    pub end: usize,
    /// Captured group texts (empty when the pattern has none).
    pub captures: Vec<String>,
}

impl PatternMatch {
    /// Returns the matched slice `text[start..end]`.
    ///
    /// # Panics
    ///
    /// Panics if the range does not fit `text`: `start > end`, `end` past the
    /// end of `text`, or either offset not on a UTF-8 character boundary.
    #[must_use]
    pub fn as_str<'a>(&self, text: &'a str) -> &'a str {
        &text[self.start..self.end]
    }

    /// Returns the length of the match, `end - start`.
    ///
    /// The fields are public, so an inverted range (`end < start`) can be
    /// constructed; the subtraction saturates at zero instead of panicking
    /// in debug builds or wrapping around in release builds.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.end.saturating_sub(self.start)
    }

    /// Reports whether the match covers no text, i.e. `len() == 0`.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// An ordered collection of patterns that can be searched as a group.
///
/// The derived `Default` yields a set with no patterns.
#[derive(Debug, Clone, Default)]
pub struct PatternSet {
// Entries in insertion order; each entry records its own position in
// `NamedPattern::index`.
patterns: Vec<NamedPattern>,
}
#[derive(Clone)]
pub struct NamedPattern {
pub pattern: Pattern,
pub name: Option<String>,
pub index: usize,
}
impl fmt::Debug for NamedPattern {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("NamedPattern")
.field("pattern", &self.pattern)
.field("name", &self.name)
.field("index", &self.index)
.finish()
}
}
impl PatternSet {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn from_patterns(patterns: Vec<Pattern>) -> Self {
let patterns = patterns
.into_iter()
.enumerate()
.map(|(index, pattern)| NamedPattern {
pattern,
name: None,
index,
})
.collect();
Self { patterns }
}
pub fn add(&mut self, pattern: Pattern) -> &mut Self {
let index = self.patterns.len();
self.patterns.push(NamedPattern {
pattern,
name: None,
index,
});
self
}
pub fn add_named(&mut self, name: impl Into<String>, pattern: Pattern) -> &mut Self {
let index = self.patterns.len();
self.patterns.push(NamedPattern {
pattern,
name: Some(name.into()),
index,
});
self
}
#[must_use]
pub const fn len(&self) -> usize {
self.patterns.len()
}
#[must_use]
pub const fn is_empty(&self) -> bool {
self.patterns.is_empty()
}
#[must_use]
pub fn find_match(&self, text: &str) -> Option<(usize, PatternMatch)> {
let mut best_match: Option<(usize, PatternMatch)> = None;
for (idx, named) in self.patterns.iter().enumerate() {
if let Some(m) = named.pattern.matches(text) {
match &best_match {
None => best_match = Some((idx, m)),
Some((_, current)) if m.start < current.start => {
best_match = Some((idx, m));
}
_ => {}
}
}
}
best_match
}
#[must_use]
pub fn get(&self, index: usize) -> Option<&NamedPattern> {
self.patterns.get(index)
}
#[must_use]
pub fn min_timeout(&self) -> Option<Duration> {
self.patterns
.iter()
.filter_map(|p| p.pattern.timeout_duration())
.min()
}
#[must_use]
pub fn has_eof(&self) -> bool {
self.patterns.iter().any(|p| p.pattern.is_eof())
}
pub fn iter(&self) -> impl Iterator<Item = &NamedPattern> {
self.patterns.iter()
}
}
/// Finds the first position in `text` from which `pattern` matches and
/// returns it as a **byte** offset into `text`.
///
/// Matching is done per character, but the result is used as a byte offset
/// by callers that slice the original `&str`, so the winning character
/// index is translated back to the byte offset of that character. Returning
/// the raw character index is wrong for any text containing multi-byte
/// characters before the match.
///
/// Candidate starts include the position just past the last character, so
/// an empty pattern (or one made only of `*`) matches even an empty text.
/// Returns `None` when no start position matches.
fn glob_match(pattern: &str, text: &str) -> Option<usize> {
    let pattern_chars: Vec<char> = pattern.chars().collect();
    // Keep the byte offset of every character next to the character itself.
    let (byte_offsets, text_chars): (Vec<usize>, Vec<char>) = text.char_indices().unzip();
    (0..=text_chars.len())
        .find(|&start| glob_match_from(&pattern_chars, &text_chars[start..]))
        // `start == text_chars.len()` has no entry in `byte_offsets`; it
        // denotes the end of the text.
        .map(|start| byte_offsets.get(start).copied().unwrap_or(text.len()))
}

/// Reports whether `pattern` matches starting at the beginning of `text`.
///
/// `*` matches any run of characters (including none) and `?` matches
/// exactly one character; every other character matches only itself. The
/// function returns `true` as soon as the whole pattern has been consumed,
/// so characters of `text` left over after that point are ignored.
const fn glob_match_from(pattern: &[char], text: &[char]) -> bool {
    let mut p = 0;
    let mut t = 0;
    // Position of the most recent `*` in the pattern, and the text position
    // it is currently assumed to extend to; used for backtracking.
    let mut star_p = None;
    let mut star_t = 0;
    while p < pattern.len() {
        if pattern[p] == '*' {
            // Remember the star and first try letting it match nothing.
            star_p = Some(p);
            star_t = t;
            p += 1;
        } else if t < text.len() && (pattern[p] == '?' || pattern[p] == text[t]) {
            p += 1;
            t += 1;
        } else if let Some(sp) = star_p {
            // Mismatch after a star: let the star absorb one more character
            // and retry the remainder of the pattern from there.
            p = sp + 1;
            star_t += 1;
            if star_t > text.len() {
                return false;
            }
            t = star_t;
        } else {
            return false;
        }
    }
    true
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A literal reports the range of its first occurrence.
    #[test]
    fn literal_pattern_matches() {
        let found = Pattern::literal("hello").matches("say hello world");
        assert!(matches!(
            found,
            Some(PatternMatch {
                start: 4,
                end: 9,
                ..
            })
        ));
    }

    /// A regex match can be sliced back out of the searched text.
    #[test]
    fn regex_pattern_matches() {
        let haystack = "test 123 value";
        let found = Pattern::regex(r"\d+").unwrap().matches(haystack);
        assert_eq!(found.map(|m| m.as_str(haystack)), Some("123"));
    }

    /// Capture groups are returned in order, without the whole match.
    #[test]
    fn regex_pattern_captures() {
        let found = Pattern::regex(r"(\w+)@(\w+)")
            .unwrap()
            .matches("email: user@domain here")
            .unwrap();
        assert_eq!(found.captures, ["user", "domain"]);
    }

    /// A glob can match in the middle of the text.
    #[test]
    fn glob_pattern_matches() {
        let found = Pattern::glob("hello*world").matches("say hello beautiful world!");
        assert!(found.is_some());
    }

    /// The pattern whose match starts earliest wins, not the one added first.
    #[test]
    fn pattern_set_finds_first() {
        let mut set = PatternSet::new();
        set.add(Pattern::literal("world"));
        set.add(Pattern::literal("hello"));
        let winner = set.find_match("hello world").map(|(idx, _)| idx);
        assert_eq!(winner, Some(1));
    }

    /// The smallest timeout in the set is reported.
    #[test]
    fn pattern_set_min_timeout() {
        let mut set = PatternSet::new();
        set.add(Pattern::timeout(Duration::from_secs(10)));
        set.add(Pattern::timeout(Duration::from_secs(5)));
        assert_eq!(set.min_timeout(), Some(Duration::from_secs(5)));
    }
}