use crate::client::Session;
use crate::error::{Error, Result};
#[cfg(feature = "tls")]
use native_tls::TlsStream;
use std::io::{self, Read, Write};
use std::net::TcpStream;
use std::time::Duration;
/// An IMAP IDLE (RFC 2177) in progress on a mutably borrowed [`Session`].
///
/// The IDLE command is sent when the handle is built. Dropping the handle sends DONE if the
/// handle has not already been marked done.
#[derive(Debug)]
pub struct Handle<'a, T>
where
    T: Read + Write,
{
    /// Session on which IDLE was issued; exclusively borrowed while the handle lives.
    session: &'a mut Session<T>,
    /// Read-timeout interval after which `wait_keepalive` ends and re-issues IDLE.
    keepalive: Duration,
    /// Set once there is no IDLE left to terminate, so DONE is sent at most once.
    done: bool,
}
/// Why a timed IDLE wait returned.
///
/// This is a plain fieldless value, so it derives the common traits (including `Copy` and
/// `Hash`) and callers can store, compare and key on it freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WaitOutcome {
    /// The read timeout elapsed before the server sent anything other than the
    /// `* OK Still here` keepalive.
    TimedOut,
    /// The server sent a line other than the `* OK Still here` keepalive while idling
    /// (presumably an untagged mailbox update).
    MailboxChanged,
}
/// A stream whose blocking reads can be bounded by a timeout.
///
/// Required by the timed IDLE waits (`wait_keepalive`, `wait_with_timeout`), which arm a
/// timeout before waiting and clear it afterwards.
pub trait SetReadTimeout {
    /// Sets (`Some`) or clears (`None`) the read timeout on this stream.
    ///
    /// # Errors
    /// Returns an error if the underlying stream refuses the requested timeout.
    fn set_read_timeout(&mut self, timeout: Option<Duration>) -> Result<()>;
}
impl<'a, T: Read + Write + 'a> Handle<'a, T> {
pub(crate) fn make(session: &'a mut Session<T>) -> Result<Self> {
let mut h = Handle {
session,
keepalive: Duration::from_secs(29 * 60),
done: false,
};
h.init()?;
Ok(h)
}
fn init(&mut self) -> Result<()> {
self.session.run_command("IDLE")?;
let mut v = Vec::new();
self.session.readline(&mut v)?;
if v.starts_with(b"+") {
self.done = false;
return Ok(());
}
self.session.read_response_onto(&mut v)?;
unreachable!();
}
fn terminate(&mut self) -> Result<()> {
if !self.done {
self.done = true;
self.session.write_line(b"DONE")?;
self.session.read_response().map(|_| ())
} else {
Ok(())
}
}
fn wait_inner(&mut self, reconnect: bool) -> Result<WaitOutcome> {
let mut v = Vec::new();
loop {
let result = match self.session.readline(&mut v).map(|_| ()) {
Err(Error::Io(ref e))
if e.kind() == io::ErrorKind::TimedOut
|| e.kind() == io::ErrorKind::WouldBlock =>
{
if reconnect {
self.terminate()?;
self.init()?;
return self.wait_inner(reconnect);
}
Ok(WaitOutcome::TimedOut)
}
Ok(()) => Ok(WaitOutcome::MailboxChanged),
Err(r) => Err(r),
}?;
if v.eq_ignore_ascii_case(b"* OK Still here\r\n") {
v.clear();
} else {
break Ok(result);
}
}
}
pub fn wait(mut self) -> Result<()> {
self.wait_inner(true).map(|_| ())
}
}
impl<'a, T> Handle<'a, T>
where
    T: SetReadTimeout + Read + Write + 'a,
{
    /// Sets how long [`wait_keepalive`](Self::wait_keepalive) may go without server data
    /// before it ends and re-issues IDLE.
    pub fn set_keepalive(&mut self, interval: Duration) {
        self.keepalive = interval;
    }

    /// Waits for server activity, restarting IDLE whenever the keepalive interval passes
    /// with no data.
    ///
    /// # Errors
    /// Propagates failures from arming the read timeout, reading, or restarting IDLE.
    pub fn wait_keepalive(self) -> Result<()> {
        // Copy the interval out first: `timed_wait` consumes `self`.
        let interval = self.keepalive;
        self.timed_wait(interval, true)?;
        Ok(())
    }

    /// Waits up to `timeout` for server activity, discarding which outcome occurred.
    #[deprecated(note = "use wait_with_timeout instead")]
    pub fn wait_timeout(self, timeout: Duration) -> Result<()> {
        self.wait_with_timeout(timeout)?;
        Ok(())
    }

    /// Waits up to `timeout` for server activity.
    ///
    /// Returns [`WaitOutcome::TimedOut`] if nothing but keepalives arrived in time, and
    /// [`WaitOutcome::MailboxChanged`] otherwise.
    pub fn wait_with_timeout(self, timeout: Duration) -> Result<WaitOutcome> {
        let reconnect = false;
        self.timed_wait(timeout, reconnect)
    }

    /// Runs `wait_inner` with the stream's read timeout temporarily set to `timeout`.
    ///
    /// The timeout is cleared before returning. A failure to clear it is deliberately
    /// ignored so that it cannot mask the result of the wait.
    fn timed_wait(mut self, timeout: Duration, reconnect: bool) -> Result<WaitOutcome> {
        self.session.stream.get_mut().set_read_timeout(Some(timeout))?;
        let outcome = self.wait_inner(reconnect);
        let _ = self.session.stream.get_mut().set_read_timeout(None);
        outcome
    }
}
impl<'a, T: Read + Write + 'a> Drop for Handle<'a, T> {
    /// Ends IDLE (sends DONE) if the handle has not already been marked done.
    ///
    /// `drop` has no way to report failure, so any error from `terminate` is discarded.
    fn drop(&mut self) {
        let _ = self.terminate();
    }
}
/// Plain TCP: delegates to the inherent [`TcpStream::set_read_timeout`].
///
/// The impl has no unused lifetime parameter. I/O errors are wrapped in `Error::Io`; std
/// rejects `Some(Duration::ZERO)` with such an error.
impl SetReadTimeout for TcpStream {
    fn set_read_timeout(&mut self, timeout: Option<Duration>) -> Result<()> {
        TcpStream::set_read_timeout(self, timeout).map_err(Error::Io)
    }
}
/// TLS over TCP: sets the timeout on the underlying socket, reached via `get_ref`.
///
/// The impl has no unused lifetime parameter. I/O errors are wrapped in `Error::Io`.
#[cfg(feature = "tls")]
impl SetReadTimeout for TlsStream<TcpStream> {
    fn set_read_timeout(&mut self, timeout: Option<Duration>) -> Result<()> {
        self.get_ref().set_read_timeout(timeout).map_err(Error::Io)
    }
}