use crate::client::Session;
use crate::error::{Error, Result};
use crate::parse::parse_idle;
use crate::types::UnsolicitedResponse;
use crate::Connection;
#[cfg(feature = "native-tls")]
use native_tls::TlsStream;
#[cfg(feature = "rustls-tls")]
use rustls_connector::TlsStream as RustlsStream;
use std::io::{self, Read, Write};
use std::net::TcpStream;
use std::ops::DerefMut;
use std::time::Duration;
/// Handle for running the IMAP `IDLE` command on a borrowed session.
///
/// Configure it with `timeout` / `keepalive`, then call `wait_while`. When the handle is
/// dropped, `terminate` runs and sends `DONE` if `done` is still `false`.
#[derive(Debug)]
pub struct Handle<'a, T: Read + Write> {
/// Session the `IDLE` command is issued on; borrowed mutably for the handle's lifetime.
session: &'a mut Session<T>,
/// Read timeout applied while waiting (set to 29 minutes by `make`).
timeout: Duration,
/// When `true`, an elapsed `timeout` re-issues `IDLE` instead of returning `TimedOut`.
keepalive: bool,
/// While `false`, `terminate` (and therefore `Drop`) sends `DONE` and reads the
/// server's completion response; it is set to `true` once that has been attempted.
done: bool,
}
/// Reason a call to `Handle::wait_while` stopped waiting.
///
/// This is a plain fieldless value, so it is `Copy` and `Hash` as well as comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WaitOutcome {
    /// A read timed out while keepalive was disabled, so the wait ended without the
    /// callback asking to stop.
    TimedOut,
    /// The callback returned `false` for an unsolicited response.
    MailboxChanged,
}
/// Ready-made callback for `Handle::wait_while` that stops at the first unsolicited
/// response.
///
/// The response itself is ignored. Returning `false` is what ends the wait, and the
/// wait then reports `WaitOutcome::MailboxChanged`.
pub fn stop_on_any(_response: UnsolicitedResponse) -> bool {
    /// `wait_while` keeps waiting only while its callback returns `true`.
    const KEEP_WAITING: bool = false;
    KEEP_WAITING
}
/// Streams whose read timeout can be changed.
///
/// `Handle::wait_while` requires this so that it can bound each blocking read by the
/// configured IDLE timeout, and then clear the timeout again afterwards.
pub trait SetReadTimeout {
/// Sets the read timeout; `None` clears it. For the `TcpStream`-backed impls in this
/// file, `None` means reads block indefinitely.
fn set_read_timeout(&mut self, timeout: Option<Duration>) -> Result<()>;
}
impl<'a, T: Read + Write + 'a> Handle<'a, T> {
pub(crate) fn make(session: &'a mut Session<T>) -> Self {
Handle {
session,
timeout: Duration::from_secs(29 * 60),
keepalive: true,
done: false,
}
}
fn init(&mut self) -> Result<()> {
self.session.run_command("IDLE")?;
let mut v = Vec::new();
self.session.readline(&mut v)?;
if v.starts_with(b"+") {
self.done = false;
return Ok(());
}
self.session.read_response_onto(&mut v)?;
unreachable!();
}
fn terminate(&mut self) -> Result<()> {
if !self.done {
self.done = true;
self.session.write_line(b"DONE")?;
self.session.read_response().map(|_| ())
} else {
Ok(())
}
}
fn wait_inner<F>(&mut self, reconnect: bool, mut callback: F) -> Result<WaitOutcome>
where
F: FnMut(UnsolicitedResponse) -> bool,
{
let mut v = Vec::new();
let result = loop {
match self.session.readline(&mut v) {
Err(Error::Io(ref e))
if e.kind() == io::ErrorKind::TimedOut
|| e.kind() == io::ErrorKind::WouldBlock =>
{
break Ok(WaitOutcome::TimedOut);
}
Ok(_len) => {
if v.eq_ignore_ascii_case(b"* OK Still here\r\n") {
v.clear();
continue;
}
match parse_idle(&v) {
(_rest, Some(Err(r))) => break Err(r),
(rest, Some(Ok(response))) => {
if !callback(response) {
break Ok(WaitOutcome::MailboxChanged);
}
debug_assert!(
rest.is_empty(),
"Unexpected partial parse: input: {:?}, output: {:?}",
v,
rest,
);
if rest.is_empty() {
v.clear();
} else {
let used = v.len() - rest.len();
v.drain(0..used);
}
}
(_rest, None) => {}
}
}
Err(r) => break Err(r),
};
};
match (reconnect, result) {
(true, Ok(WaitOutcome::TimedOut)) => {
self.terminate()?;
self.init()?;
self.wait_inner(reconnect, callback)
}
(_, result) => result,
}
}
}
impl<'a, T: SetReadTimeout + Read + Write + 'a> Handle<'a, T> {
    /// Sets how long a single read may block in `wait_while` before the wait counts as
    /// timed out. The default, set by `make`, is 29 minutes.
    ///
    /// NOTE(review): `TcpStream::set_read_timeout` rejects a zero duration, so a zero
    /// interval presumably makes `wait_while` fail for TCP-backed streams.
    pub fn timeout(&mut self, interval: Duration) -> &mut Self {
        self.timeout = interval;
        self
    }

    /// Chooses whether an elapsed timeout re-issues `IDLE` (`true`, the default) or ends
    /// the wait with `WaitOutcome::TimedOut` (`false`).
    pub fn keepalive(&mut self, keepalive: bool) -> &mut Self {
        self.keepalive = keepalive;
        self
    }

    /// Enters `IDLE` and blocks, handing each unsolicited response to `callback` until it
    /// returns `false`.
    ///
    /// Every read is bounded by the configured timeout. Once waiting has begun, the read
    /// timeout is reset to `None` before returning, regardless of the outcome. Any failure
    /// to reset it is ignored.
    ///
    /// # Errors
    ///
    /// Propagates errors from issuing `IDLE`, from applying the read timeout, and from
    /// reading or parsing the server's responses.
    pub fn wait_while<F>(&mut self, callback: F) -> Result<WaitOutcome>
    where
        F: FnMut(UnsolicitedResponse) -> bool,
    {
        self.init()?;

        let read_limit = Some(self.timeout);
        self.session.stream.get_mut().set_read_timeout(read_limit)?;

        let keepalive = self.keepalive;
        let outcome = self.wait_inner(keepalive, callback);

        // Best effort: failing to clear the timeout must not mask the wait's own result.
        let _ = self.session.stream.get_mut().set_read_timeout(None);
        outcome
    }
}
impl<'a, T: Read + Write + 'a> Drop for Handle<'a, T> {
    /// Ends any outstanding `IDLE` via `terminate`.
    fn drop(&mut self) {
        // `drop` cannot report errors, so ending IDLE here is strictly best effort.
        let _ = self.terminate();
    }
}
impl SetReadTimeout for Connection {
    /// Forwards to the `SetReadTimeout` impl of the value `Connection` dereferences to.
    fn set_read_timeout(&mut self, timeout: Option<Duration>) -> Result<()> {
        let inner = DerefMut::deref_mut(self);
        inner.set_read_timeout(timeout)
    }
}
impl SetReadTimeout for TcpStream {
fn set_read_timeout(&mut self, timeout: Option<Duration>) -> Result<()> {
TcpStream::set_read_timeout(self, timeout).map_err(Error::Io)
}
}
#[cfg(feature = "native-tls")]
impl SetReadTimeout for TlsStream<TcpStream> {
    /// Applies the timeout to the TCP socket underneath the native-tls stream.
    fn set_read_timeout(&mut self, timeout: Option<Duration>) -> Result<()> {
        let socket: &TcpStream = self.get_ref();
        TcpStream::set_read_timeout(socket, timeout).map_err(Error::Io)
    }
}
#[cfg(feature = "rustls-tls")]
impl SetReadTimeout for RustlsStream<TcpStream> {
    /// Applies the timeout to the TCP socket underneath the rustls stream.
    fn set_read_timeout(&mut self, timeout: Option<Duration>) -> Result<()> {
        let socket: &TcpStream = self.get_ref();
        TcpStream::set_read_timeout(socket, timeout).map_err(Error::Io)
    }
}