#![deny(missing_docs)]
use std::ops::Deref;
use std::sync::{atomic::AtomicBool, atomic::Ordering, Arc};
use std::thread::{self, JoinHandle};
/// Tells the driving loop whether to keep iterating after a call to
/// [`Cancellable::for_each`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LoopState {
    /// Run another iteration (a spawned loop still stops if it was cancelled).
    Continue,
    /// Stop the loop normally; the driver then returns `Ok(())`.
    Break,
}
/// A loop body that can be driven to completion on the current thread, or on a
/// background thread that can be cancelled.
///
/// Implementors provide [`Cancellable::for_each`], which performs a single
/// iteration. [`Cancellable::run`] and [`Cancellable::spawn`] supply the loop
/// around it.
pub trait Cancellable {
    /// Error that ends the loop early when returned from
    /// [`Cancellable::for_each`].
    type Error;

    /// Performs one iteration of the loop.
    ///
    /// Return `Ok(LoopState::Continue)` to ask for another iteration,
    /// `Ok(LoopState::Break)` to finish normally, or `Err(_)` to stop the loop
    /// with that error.
    fn for_each(&mut self) -> Result<LoopState, Self::Error>;

    /// Calls [`Cancellable::for_each`] on the current thread until it returns
    /// `Ok(LoopState::Break)` or an error.
    ///
    /// This loop has no cancellation flag; it ends only when `for_each` says so.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `for_each`.
    fn run(&mut self) -> Result<(), Self::Error> {
        loop {
            match self.for_each()? {
                LoopState::Continue => {}
                LoopState::Break => return Ok(()),
            }
        }
    }

    /// Moves `self` onto a new thread and drives [`Cancellable::for_each`]
    /// there. The loop ends when `for_each` returns `Ok(LoopState::Break)` or an
    /// error, or when it is cancelled through the returned [`Handle`].
    ///
    /// The cancellation flag is checked before every iteration. An iteration
    /// already in progress (for example, one blocked on I/O) therefore finishes
    /// before cancellation takes effect.
    ///
    /// This method drives `for_each` directly and does not call
    /// [`Cancellable::run`], so overriding `run` has no effect here.
    fn spawn(mut self) -> Handle<Self::Error>
    where
        Self: Send + Sized + 'static,
        Self::Error: Send + 'static,
    {
        let keep_running = Arc::new(AtomicBool::new(true));
        let flag = Arc::clone(&keep_running);
        // The explicit return type lets `?` infer the error type inside the closure.
        let executor = thread::spawn(move || -> Result<(), Self::Error> {
            while flag.load(Ordering::SeqCst) {
                match self.for_each()? {
                    LoopState::Continue => {}
                    LoopState::Break => break,
                }
            }
            Ok(())
        });
        Handle {
            canceller: Canceller { keep_running },
            executor,
        }
    }
}
/// Owner of a loop started with [`Cancellable::spawn`].
///
/// Dereferences to its [`Canceller`], so `handle.cancel()` works directly. Call
/// [`Handle::wait`] to join the worker thread and collect the loop's result.
///
/// Dropping the handle does not stop the loop; the worker thread keeps running,
/// detached.
#[must_use = "dropping a Handle detaches the loop thread; call `cancel()` and/or `wait()`"]
pub struct Handle<E> {
    /// Shared flag used to ask the loop to stop.
    canceller: Canceller,
    /// Join handle of the worker thread running the loop.
    executor: JoinHandle<Result<(), E>>,
}
impl<E> Deref for Handle<E> {
type Target = Canceller;
fn deref(&self) -> &Self::Target {
&self.canceller
}
}
/// A cloneable token that asks a loop started with [`Cancellable::spawn`] to
/// stop.
///
/// All clones share the same flag, so cancelling through any clone affects the
/// same loop.
#[derive(Clone, Debug)]
pub struct Canceller {
    /// `true` while the loop may keep iterating; cleared by [`Canceller::cancel`].
    keep_running: Arc<AtomicBool>,
}

impl Canceller {
    /// Requests that the loop stop.
    ///
    /// The spawned loop checks the flag only between iterations. An iteration
    /// that is already running (for example, one blocked in I/O) therefore
    /// finishes first. Calling this more than once has no further effect.
    pub fn cancel(&self) {
        self.keep_running.store(false, Ordering::SeqCst);
    }
}
impl<E> Handle<E> {
    /// Blocks until the loop's worker thread finishes, then returns its result.
    ///
    /// This does not cancel the loop. Call `cancel()` first if the loop would
    /// otherwise run indefinitely.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `for_each` if that is what stopped the loop.
    ///
    /// # Panics
    ///
    /// If the worker thread panicked, that panic is resumed on the calling
    /// thread with its original payload.
    pub fn wait(self) -> Result<(), E> {
        // `resume_unwind` re-raises the worker's own payload. Formatting the
        // payload with `{:?}` would only print `Any { .. }` and lose the message.
        self.executor
            .join()
            .unwrap_or_else(|payload| std::panic::resume_unwind(payload))
    }
}
// Tests that drive a `Cancellable` TCP service over the loopback interface.
#[cfg(test)]
mod tests {
use super::*;
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::{io, net, thread};
// Test service: each iteration accepts one connection and greets it with "hello!".
struct Service(net::TcpListener);
impl Cancellable for Service {
type Error = io::Error;
// Blocks in `accept`. An interrupted accept is retried on the next iteration;
// any other accept or write error stops the loop with that error.
fn for_each(&mut self) -> Result<LoopState, Self::Error> {
let mut stream = match self.0.accept() {
Ok((stream, _)) => stream,
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {
return Ok(LoopState::Continue)
}
Err(e) => return Err(e),
};
// `stream` is dropped when this call returns, which closes the connection so
// the client's `read_to_string` reaches EOF.
write!(stream, "hello!")?;
Ok(LoopState::Continue)
}
}
impl Service {
// Binds to port 0 on loopback, so the OS picks a free port for each test.
fn new() -> Self {
Service(TcpListener::bind("127.0.0.1:0").unwrap())
}
// Port the OS assigned at bind time.
fn port(&self) -> u16 {
self.0.local_addr().unwrap().port()
}
}
// Connects to the service, reads until EOF and asserts the greeting.
// Returns `Some(err)` if connecting or reading fails, `None` on success.
fn connect_assert(port: u16) -> Option<io::Error> {
match TcpStream::connect(("127.0.0.1", port)) {
Ok(mut c) => {
let mut r = String::new();
if let Err(e) = c.read_to_string(&mut r) {
return Some(e);
}
assert_eq!(r, "hello!");
None
}
Err(e) => Some(e),
}
}
// `run` on a plain thread serves repeated connections. The thread is never
// joined, so it stays blocked in `accept` after the test ends, and a panic
// inside it would not fail this test.
#[test]
fn it_runs() {
let mut s = Service::new();
let port = s.port();
thread::spawn(move || {
s.run().unwrap();
});
assert!(connect_assert(port).is_none());
assert!(connect_assert(port).is_none());
}
// After `cancel`, the worker may already be blocked in `accept` (it passed the
// flag check before the flag was cleared). At most one more connection can
// therefore succeed. Once that iteration returns, the worker sees the cleared
// flag and exits, dropping the listener, so further connects fail and the
// loop below ends.
#[test]
fn it_cancels() {
let s = Service::new();
let port = s.port();
let h = s.spawn();
assert!(connect_assert(port).is_none());
assert!(connect_assert(port).is_none());
h.cancel();
let mut succeeded = 0;
while connect_assert(port).is_none() {
succeeded += 1;
assert!(succeeded <= 1);
}
h.wait().unwrap();
}
}