use crate::error::Error;
pub use regex::Regex;
use std::io::prelude::*;
use std::io::{self, BufReader};
use std::sync::mpsc::{Receiver, channel};
use std::thread;
use std::{fmt, time};
/// Settings for `NBReader::new`.
///
/// Start from `Options::new()` (or `Options::default()`) and chain the builder methods:
/// no timeout and no ANSI stripping unless configured.
#[derive(Debug, Clone, Default)]
pub struct Options {
    /// Maximum time `NBReader::read_until` may wait for a match, in milliseconds.
    /// `None` means it keeps waiting until a match or end of input.
    pub(crate) timeout_ms: Option<u64>,
    /// When `true`, ANSI escape sequences are dropped before input reaches the buffer.
    pub(crate) strip_ansi_escape_codes: bool,
}

impl Options {
    /// Creates the default options: no timeout, ANSI stripping disabled.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the `read_until` timeout in milliseconds (`None` disables it).
    #[must_use]
    pub fn timeout_ms(mut self, timeout_ms: Option<u64>) -> Self {
        self.timeout_ms = timeout_ms;
        self
    }

    /// Enables or disables stripping of ANSI escape sequences from the input.
    #[must_use]
    pub fn strip_ansi_escape_codes(mut self, yes: bool) -> Self {
        self.strip_ansi_escape_codes = yes;
        self
    }
}
/// Reader whose methods never block on the underlying source directly.
///
/// A background thread performs the actual reads and feeds the results through a channel;
/// the methods on this type pull whatever has already arrived into `buffer`.
pub struct NBReader {
    /// Input received but not yet consumed; each source byte is stored as one `char`.
    buffer: String,
    /// Set once end of input has been observed; no more data is pulled after that.
    eof: bool,
    /// Deadline for `read_until`; `None` lets it poll until a match or end of input.
    timeout: Option<time::Duration>,
    /// Receiving end of the channel filled by the background reader thread.
    reader: Receiver<Result<PipedChar, PipeError>>,
}
impl NBReader {
pub fn new<R: Read + Send + 'static>(f: R, options: Options) -> NBReader {
let (tx, rx) = channel();
thread::spawn(move || -> Result<(), Error> {
let mut reader = BufReader::new(f);
let mut byte = [0u8];
let mut in_escape_code = false;
loop {
match reader.read(&mut byte) {
Ok(0) => {
tx.send(Ok(PipedChar::EOF))
.map_err(|_| Error::MpscSendError)?;
break;
}
Ok(_) => {
if options.strip_ansi_escape_codes && byte[0] == 27 {
in_escape_code = true;
} else if options.strip_ansi_escape_codes && in_escape_code {
if char::from(byte[0]).is_alphabetic() {
in_escape_code = false;
}
} else {
tx.send(Ok(PipedChar::Char(byte[0])))
.map_err(|_| Error::MpscSendError)?;
}
}
Err(error) => {
tx.send(Err(PipeError::IO(error)))
.map_err(|_| Error::MpscSendError)?;
}
}
}
Ok(())
});
NBReader {
reader: rx,
buffer: String::with_capacity(1024),
eof: false,
timeout: options.timeout_ms.map(time::Duration::from_millis),
}
}
fn read_into_buffer(&mut self) -> Result<(), Error> {
if self.eof {
return Ok(());
}
while let Ok(from_channel) = self.reader.try_recv() {
match from_channel {
Ok(PipedChar::Char(c)) => self.buffer.push(c as char),
Ok(PipedChar::EOF) => self.eof = true,
Err(PipeError::IO(ref err)) => {
self.eof = err.raw_os_error() == Some(5);
}
}
}
Ok(())
}
pub fn read_until(&mut self, needle: &ReadUntil) -> Result<(String, String), Error> {
let start = time::Instant::now();
loop {
self.read_into_buffer()?;
if let Some(tuple_pos) = find(needle, &self.buffer, self.eof) {
let first = self.buffer.drain(..tuple_pos.0).collect();
let second = self.buffer.drain(..tuple_pos.1 - tuple_pos.0).collect();
return Ok((first, second));
}
if self.eof {
return Err(Error::EOF {
expected: needle.to_string(),
got: self.buffer.clone(),
exit_code: None,
});
}
if let Some(timeout) = self.timeout {
if start.elapsed() > timeout {
return Err(Error::Timeout {
expected: needle.to_string(),
got: self.buffer.clone(),
timeout,
});
}
}
thread::sleep(time::Duration::from_millis(100));
}
}
pub fn try_read(&mut self) -> Option<char> {
let _ = self.read_into_buffer();
let first = self.buffer.chars().next()?;
self.buffer.drain(..first.len_utf8());
Some(first)
}
}
/// What `NBReader::read_until` should wait for.
#[non_exhaustive]
pub enum ReadUntil {
    /// The first occurrence of this exact string.
    String(String),
    /// The first match of this regular expression.
    Regex(Regex),
    /// A fixed amount of input; once input has ended, any non-empty remainder also matches.
    NBytes(usize),
    /// End of input; matches everything still buffered at that point.
    EOF,
    /// Whichever alternative matches earliest (on a tie, the shorter match).
    Any(Vec<ReadUntil>),
}
impl fmt::Display for ReadUntil {
    /// Renders the needle for humans (it becomes `expected` in `read_until`'s errors).
    ///
    /// Newline and carriage-return strings get a readable label, other strings are quoted,
    /// and `Any` lists its alternatives separated by `", "`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ReadUntil::String(text) => match text.as_str() {
                "\n" => f.write_str("\\n (newline)"),
                "\r" => f.write_str("\\r (carriage return)"),
                other => write!(f, "\"{other}\""),
            },
            ReadUntil::Regex(re) => write!(f, "Regex: \"{re}\""),
            ReadUntil::NBytes(count) => write!(f, "reading {count} bytes"),
            ReadUntil::EOF => f.write_str("EOF (End of File)"),
            ReadUntil::Any(alternatives) => {
                let mut separator = "";
                for alternative in alternatives {
                    write!(f, "{separator}{alternative}")?;
                    separator = ", ";
                }
                Ok(())
            }
        }
    }
}
/// Searches `buffer` for `needle`.
///
/// Returns `Some((start, end))`, the byte range of the match inside `buffer`, or `None` when
/// there is no match (yet). `eof` reports whether input has ended. It affects two needles:
/// - `ReadUntil::EOF` then matches the whole buffer.
/// - `ReadUntil::NBytes` then lets a shorter, non-empty buffer match completely.
///
/// For `ReadUntil::Any`, the match that starts earliest wins. Ties go to the shorter match,
/// then to the alternative listed first.
fn find(needle: &ReadUntil, buffer: &str, eof: bool) -> Option<(usize, usize)> {
    match needle {
        ReadUntil::String(s) => buffer.find(s).map(|pos| (pos, pos + s.len())),
        ReadUntil::Regex(pattern) => pattern.find(buffer).map(|mat| (mat.start(), mat.end())),
        ReadUntil::EOF => {
            if eof {
                Some((0, buffer.len()))
            } else {
                None
            }
        }
        ReadUntil::NBytes(n) => {
            // `NBReader` stores every source byte as one `char`, and bytes >= 0x80 take two
            // bytes in the UTF-8 `String`. Counting chars therefore counts source bytes and
            // always lands on a char boundary, so the caller's `String::drain` cannot panic.
            let mut boundaries = buffer
                .char_indices()
                .map(|(idx, _)| idx)
                .chain(std::iter::once(buffer.len()));
            match boundaries.nth(*n) {
                Some(end) => Some((0, end)),
                None if eof && !buffer.is_empty() => Some((0, buffer.len())),
                None => None,
            }
        }
        ReadUntil::Any(anys) => anys
            .iter()
            .filter_map(|any| find(any, buffer, eof))
            // `(start, end)` tuples compare lexicographically: earliest start, then
            // shortest match; `min` keeps the first of equal candidates.
            .min(),
    }
}
/// Failure forwarded from the background reader thread to `NBReader`.
#[derive(Debug)]
enum PipeError {
    /// `read` on the underlying source returned this error.
    IO(io::Error),
}
/// One successful message from the background reader thread.
// Clippy's `upper_case_acronyms` lint would flag the `EOF` variant; it is allowed here,
// presumably so the name matches `ReadUntil::EOF`.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug)]
enum PipedChar {
    /// A single byte read from the source.
    Char(u8),
    /// The source's `read` returned 0 bytes, i.e. end of input.
    EOF,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reader over an in-memory string with default options.
    fn reader_over(text: &'static str) -> NBReader {
        NBReader::new(io::Cursor::new(text), Options::default())
    }

    /// Reader over an in-memory string that strips ANSI escape codes.
    fn ansi_stripping_reader_over(text: &'static str) -> NBReader {
        NBReader::new(
            io::Cursor::new(text),
            Options::new().strip_ansi_escape_codes(true),
        )
    }

    /// Shorthand for an owned `(before, matched)` pair as returned by `read_until`.
    fn pair(before: &str, matched: &str) -> (String, String) {
        (before.to_owned(), matched.to_owned())
    }

    #[test]
    fn test_expect_melon() {
        let mut reader = reader_over("a melon\r\n");
        let line = reader
            .read_until(&ReadUntil::String("\r\n".to_owned()))
            .expect("cannot read line");
        assert_eq!(line, pair("a melon", "\r\n"));
        // The input is exhausted, so asking for more bytes must end in an EOF error.
        let outcome = reader.read_until(&ReadUntil::NBytes(10));
        assert!(matches!(outcome, Err(Error::EOF { .. })));
    }

    #[test]
    fn test_regex() {
        let mut reader = reader_over("2014-03-15");
        let date = Regex::new(r"^\d{4}-\d{2}-\d{2}$").unwrap();
        let found = reader
            .read_until(&ReadUntil::Regex(date))
            .expect("regex doesn't match");
        assert_eq!(found, pair("", "2014-03-15"));
    }

    #[test]
    fn test_regex2() {
        let mut reader = reader_over("2014-03-15");
        let month = Regex::new(r"-\d{2}-").unwrap();
        let found = reader
            .read_until(&ReadUntil::Regex(month))
            .expect("regex doesn't match");
        assert_eq!(found, pair("2014", "-03-"));
    }

    #[test]
    fn test_nbytes() {
        let mut reader = reader_over("abcdef");
        // The last request asks for more than is left and receives the remainder at EOF.
        for (count, expected) in [(2, "ab"), (3, "cde"), (4, "f")] {
            let chunk = reader
                .read_until(&ReadUntil::NBytes(count))
                .expect("reading fixed-size chunk");
            assert_eq!(chunk, pair("", expected));
        }
    }

    #[test]
    fn test_any_with_multiple_possible_matches() {
        let mut reader = reader_over("zero one two three four five");
        let needle = ReadUntil::Any(vec![
            ReadUntil::String("two".to_owned()),
            ReadUntil::String("one".to_owned()),
        ]);
        let found = reader.read_until(&needle).expect("finding string");
        assert_eq!(found, pair("zero ", "one"));
    }

    #[test]
    fn test_any_with_same_start_different_length() {
        let mut reader = reader_over("hi hello");
        let needle = ReadUntil::Any(vec![
            ReadUntil::String("hello".to_owned()),
            ReadUntil::String("hell".to_owned()),
        ]);
        let found = reader.read_until(&needle).expect("finding string");
        assert_eq!(found, pair("hi ", "hell"));
    }

    #[test]
    fn test_eof() {
        let mut reader = reader_over("lorem ipsum dolor sit amet");
        reader.read_until(&ReadUntil::NBytes(2)).expect("2 bytes");
        let rest = reader
            .read_until(&ReadUntil::EOF)
            .expect("reading until EOF");
        assert_eq!(rest, pair("", "rem ipsum dolor sit amet"));
    }

    #[test]
    fn test_skip_partial_ansi_code() {
        let mut reader = ansi_stripping_reader_over("\x1b[31;1;4mHello\x1b[1");
        let found = reader
            .read_until(&ReadUntil::String("Hello".to_owned()))
            .unwrap();
        assert_eq!(found, pair("", "Hello"));
        // The unterminated trailing escape sequence is swallowed entirely.
        assert_eq!(reader.try_read(), None);
    }

    #[test]
    fn test_skip_ansi_codes() {
        let mut reader = ansi_stripping_reader_over("\x1b[31;1;4mHello\x1b[0m");
        let found = reader
            .read_until(&ReadUntil::String("Hello".to_owned()))
            .unwrap();
        assert_eq!(found, pair("", "Hello"));
        assert_eq!(reader.try_read(), None);
    }

    #[test]
    fn test_try_read() {
        let mut reader = reader_over("lorem");
        let (before, matched) = reader.read_until(&ReadUntil::NBytes(4)).unwrap();
        assert!(before.is_empty());
        assert_eq!(matched, "lore");
        assert_eq!(reader.try_read(), Some('m'));
        // Once drained, repeated calls keep returning `None`.
        for _ in 0..4 {
            assert_eq!(reader.try_read(), None);
        }
    }

    #[test]
    fn test_try_read_multibyte() {
        let mut reader = reader_over("");
        // Chars >= U+0080 occupy two UTF-8 bytes; `try_read` must pop whole chars.
        reader.buffer.push_str("\u{c3}\u{83}");
        assert_eq!(reader.try_read(), Some('\u{c3}'));
        assert_eq!(reader.try_read(), Some('\u{83}'));
        assert_eq!(reader.try_read(), None);
    }
}