use std::sync::{mpsc, Arc, Mutex};
use std::thread::{self, JoinHandle};
/// A type-erased, heap-allocated closure that a pool worker calls exactly once.
///
/// It must be `Send` so it can be moved onto a worker thread. Because the trait
/// object sits in a `Box` outside of any expression, its lifetime bound
/// defaults to `'static`, so a job cannot borrow data from the caller's stack.
pub type Job = Box<dyn FnOnce() + Send>;
/// A fixed-size pool of worker threads that run queued jobs.
///
/// Jobs travel over one channel whose receiving half is shared by every worker
/// thread, so each queued job is taken by exactly one worker.
pub struct ThreadPool {
    // Sending half of the job channel.
    sender: mpsc::Sender<ThreadPoolMessage>,
    workers: Vec<Worker>,
    // Set once `join` has completed, so a later `join` (or `Drop`) is a no-op.
    joined: bool,
}

impl ThreadPool {
    /// Creates a pool and spawns `size` worker threads.
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero.
    pub fn new(size: usize) -> Self {
        assert!(size > 0, "ThreadPool size must be greater than zero");
        let (sender, receiver) = mpsc::channel();
        // `mpsc::Receiver` has a single consumer; the mutex lets all worker
        // threads share it through their own `Arc` handle.
        let receiver = Arc::new(Mutex::new(receiver));
        let workers: Vec<Worker> = (0..size)
            .map(|_| Worker::new(Arc::clone(&receiver)))
            .collect();
        Self {
            sender,
            workers,
            joined: false,
        }
    }

    /// Enqueues `job` to be run by the next worker that takes it off the channel.
    ///
    /// # Panics
    ///
    /// Panics if no worker is left to receive the job, in particular once
    /// [`ThreadPool::join`] has returned.
    pub fn queue<F>(&self, job: F)
    where
        F: FnOnce() + Send + 'static,
    {
        self.sender
            .send(ThreadPoolMessage::NewJob(Box::new(job)))
            .expect("ThreadPool has no running workers (was it already joined?)");
    }

    /// Shuts the pool down.
    ///
    /// One `Terminate` message per worker is queued behind every job already
    /// queued, and this call then blocks until all worker threads have exited.
    /// Calling `join` again afterwards does nothing.
    ///
    /// # Panics
    ///
    /// Panics, after every worker thread has been joined, if any worker thread
    /// panicked. It does not panic if the current thread is already unwinding.
    pub fn join(&mut self) {
        if self.joined {
            return;
        }
        for _ in &self.workers {
            // A send error means every worker has already exited and dropped
            // the receiver, so there is nobody left to tell to stop.
            let _ = self.sender.send(ThreadPoolMessage::Terminate);
        }
        let mut any_panicked = false;
        for worker in &mut self.workers {
            if let Some(handle) = worker.thread.take() {
                any_panicked |= handle.join().is_err();
            }
        }
        self.joined = true;
        // Report only after every thread is joined, so none is left detached.
        // Never panic while already unwinding (e.g. when reached from `Drop`),
        // because a double panic aborts the process.
        if any_panicked && !thread::panicking() {
            panic!("a ThreadPool worker thread panicked");
        }
    }
}
impl Drop for ThreadPool {
    /// Shuts the pool down when it goes out of scope, unless the owner already
    /// did so with an explicit `join`.
    fn drop(&mut self) {
        if self.joined {
            return;
        }
        self.join();
    }
}
/// Messages sent from the pool to its workers over the shared channel.
enum ThreadPoolMessage {
    /// Tells the worker that receives it to leave its receive loop and exit.
    Terminate,
    /// A job for the worker that receives it to run.
    NewJob(Job),
}
struct Worker {
thread: Option<JoinHandle<()>>,
}
impl Worker {
fn new(receiver: Arc<Mutex<mpsc::Receiver<ThreadPoolMessage>>>) -> Self {
use ThreadPoolMessage::*;
let thread = thread::spawn(move || loop {
let message = receiver.lock().unwrap().recv().unwrap();
match message {
NewJob(job) => job(),
Terminate => break,
}
});
Self {
thread: Some(thread),
}
}
}
#[cfg(test)]
mod tests {
    use crate::*;
    use std::io::prelude::*;
    use std::net::{TcpListener, TcpStream};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    /// Every job queued before `join` runs exactly once.
    #[test]
    fn runs_every_queued_job() {
        let mut pool = ThreadPool::new(4);
        let counter = Arc::new(AtomicUsize::new(0));
        for _ in 0..100 {
            let counter = Arc::clone(&counter);
            pool.queue(move || {
                counter.fetch_add(1, Ordering::SeqCst);
            });
        }
        pool.join();
        assert_eq!(counter.load(Ordering::SeqCst), 100);
    }

    /// Serves a fixed number of HTTP connections through the pool and checks
    /// every response. It accepts exactly `REQUESTS` connections instead of
    /// looping on `incoming()` forever, so the test terminates.
    #[test]
    fn it_works() {
        const REQUESTS: usize = 3;
        let mut pool = ThreadPool::new(2);
        // Port 0 asks the OS for a free port, so the test cannot collide with
        // something already listening on a fixed port.
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let client = thread::spawn(move || {
            (0..REQUESTS)
                .map(|_| {
                    let mut stream = TcpStream::connect(addr).unwrap();
                    stream.write_all(b"GET / HTTP/1.1\r\n\r\n").unwrap();
                    let mut response = String::new();
                    // The handler drops (closes) the stream after replying,
                    // which ends this read.
                    stream.read_to_string(&mut response).unwrap();
                    response
                })
                .collect::<Vec<_>>()
        });
        for stream in listener.incoming().take(REQUESTS) {
            let stream = stream.unwrap();
            pool.queue(move || handle_stream(stream));
        }
        pool.join();
        let responses = client.join().unwrap();
        assert_eq!(responses.len(), REQUESTS);
        for response in responses {
            assert_eq!(
                response,
                "HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello world!"
            );
        }
    }

    /// Answers one HTTP request with a fixed body. If the request starts with
    /// `GET /sleep HTTP/1.1`, it sleeps five seconds before replying.
    fn handle_stream(mut stream: TcpStream) {
        let mut buf = [0; 1024];
        let read = stream.read(&mut buf).unwrap();
        if buf[..read].starts_with(b"GET /sleep HTTP/1.1\r\n") {
            println!("sleeping");
            thread::sleep(Duration::from_secs(5));
        }
        let message = "Hello world!";
        let response = format!(
            "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
            message.len(),
            message
        );
        // `write` may send only part of the buffer; `write_all` loops until done.
        stream.write_all(response.as_bytes()).unwrap();
        stream.flush().unwrap();
    }
}