use std::{marker::PhantomData, thread};
use std::sync::mpsc::{Sender, Receiver, channel, SendError, RecvError};
use std::fmt::Debug;
use std::collections::HashMap;
use rand::seq::SliceRandom;
/// Tag identifying which exchange a message belongs to.
///
/// `INIT` marks the response to `Worker::work_init` and is also the default value.
/// `STEP(n)` marks the response to the `n`-th query issued by the swarm.
/// `DONE` marks the response to `Worker::work_done`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum QID {
    #[default]
    INIT,
    STEP(usize),
    DONE,
}
/// A query delivered from the swarm to a single worker, tagged with its query id.
pub struct QMsg<Q> {
    qid: QID,
    q: Q,
}
/// A response travelling from a worker back to the swarm.
///
/// It carries the sender's id, the id of the query it answers, and an optional payload.
#[derive(Debug)]
pub struct RMsg<R> {
    pub wid: WID,
    pub qid: QID,
    pub r: Option<R>,
}
/// Numeric worker identifier. The swarm hands these out sequentially as it spawns workers.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct WID {
    pub n: usize,
}
/// Behavior of one swarm worker. `Swarm` runs `work_loop` on the worker's own thread.
///
/// Type parameters:
/// - `Q` is the query type the worker receives.
/// - `R` is the response payload it sends back.
/// - `I` is the item type of an optional per-worker work queue (see `queue_push` /
///   `queue_pop`). The default `queue_push` panics, so a worker that uses the queue
///   must override both methods.
///
/// Only `new` and `get_wid` are required; every other hook has a default.
pub trait Worker<Q, R, I = ()>
where
    R: Debug,
    Q: Clone,
{
    /// Builds the worker that will run with id `_wid`.
    fn new(_wid: WID) -> Self;

    /// The id stamped on every response this worker sends.
    fn get_wid(&self) -> WID;

    /// Called once at the start of `work_loop` with the response channel, so a worker
    /// can keep its own sender. Default: ignore it.
    fn set_tx(&mut self, _tx: &Sender<RMsg<R>>) {}

    /// Sends response `r` for query `qid`. Failures are handed to `on_work_send_err`
    /// instead of panicking.
    fn send_msg(&self, tx: &Sender<RMsg<R>>, qid: QID, r: Option<R>) {
        if let Err(err) = tx.send(RMsg { wid: self.get_wid(), qid, r }) {
            self.on_work_send_err(err)
        }
    }

    /// Adds an item to the worker's private queue. Default: panics, because no queue exists.
    fn queue_push(&mut self, _item: I) {
        panic!("no queue defined");
    }

    /// Takes the next item from the worker's private queue. Default: always empty.
    fn queue_pop(&mut self) -> Option<I> {
        None
    }

    /// Main loop run on the worker thread.
    ///
    /// Sequence:
    /// 1. Send the `work_init` result tagged `QID::INIT`.
    /// 2. Alternate between processing one queued item (`work_item`) and handling incoming
    ///    queries (`work_step`, with each response tagged by the query's `QID`).
    /// 3. Stop on a `None` message or when the swarm's sender disconnects.
    /// 4. Send the `work_done` result tagged `QID::DONE`.
    ///
    /// When the queue is empty, the loop blocks on `rx` instead of spinning, so idle
    /// workers do not consume CPU.
    ///
    /// # Panics
    /// If a query arrives tagged with anything other than `QID::STEP`.
    fn work_loop(&mut self, wid: WID, rx: &Receiver<Option<QMsg<Q>>>, tx: &Sender<RMsg<R>>) {
        use std::sync::mpsc::TryRecvError;

        self.set_tx(tx);
        let msg = self.work_init(wid);
        self.send_msg(tx, QID::INIT, msg);
        loop {
            // If work is queued, process one item and only poll for messages, so the queue
            // keeps draining. Otherwise, wait for the next message.
            let next = match self.queue_pop() {
                Some(item) => {
                    self.work_item(item);
                    rx.try_recv()
                }
                None => rx.recv().map_err(|_| TryRecvError::Disconnected),
            };
            match next {
                Ok(None) => break,
                Ok(Some(QMsg { qid, q })) => {
                    if let QID::STEP(_) = qid {
                        let msg = self.work_step(&qid, q);
                        self.send_msg(tx, qid, msg);
                    } else {
                        panic!("Worker {:?} got unexpected qid instead of STEP: {:?}", wid, qid)
                    }
                }
                Err(TryRecvError::Empty) => {}
                Err(TryRecvError::Disconnected) => break,
            }
        }
        let msg = self.work_done();
        self.send_msg(tx, QID::DONE, msg);
    }

    /// Called when a response cannot be sent. Default: prints the error and carries on.
    fn on_work_send_err(&self, err: SendError<RMsg<R>>) {
        println!("failed to send response: {:?}", err.to_string());
    }

    /// Processes one item taken from the private queue. Default: no-op.
    fn work_item(&mut self, _item: I) {}

    /// Handles one query. The returned value, if any, is sent back tagged with `_qid`.
    fn work_step(&mut self, _qid: &QID, _q: Q) -> Option<R> {
        None
    }

    /// Runs once before any query. The result is sent tagged `QID::INIT`.
    fn work_init(&mut self, _wid: WID) -> Option<R> {
        None
    }

    /// Runs once when the loop ends. The result is sent tagged `QID::DONE`.
    fn work_done(&mut self) -> Option<R> {
        None
    }
}
/// Instruction returned by the `Swarm::run` callback, telling the swarm what to do next.
#[derive(Debug)]
pub enum SwarmCmd<Q: Debug, V: Debug> {
    /// Do nothing and keep running.
    Pass,
    /// Stop `run`, which then returns `None`.
    Halt,
    /// Send this query to the worker whose message is being handled.
    Send(Q),
    /// Send each query to the worker it is paired with.
    Batch(Vec<(WID, Q)>),
    /// Panic with the given message.
    Panic(String),
    /// Stop `run`, which then returns `Some(value)`.
    Return(V),
    /// Kill the given worker. `run` stops once no workers remain.
    Kill(WID),
}
/// A pool of worker threads of type `W`, driven by message passing.
#[derive(Debug)]
pub struct Swarm<Q, R, W, I = ()>
where
    W: Worker<Q, R, I>,
    Q: Debug + Clone,
    R: Debug,
{
    /// Counter used to mint the next `QID::STEP` id.
    nq: usize,
    /// Sending half of the response channel. A clone goes to every worker, and
    /// `send_to_self` also uses it.
    me: Sender<RMsg<R>>,
    /// Receiving half of the response channel.
    rx: Receiver<RMsg<R>>,
    /// Sender handed out by `q_sender` so queries can be submitted from elsewhere.
    qtx: Sender<Q>,
    /// Receiver that `run` polls for those externally submitted queries.
    qrx: Receiver<Q>,
    /// Per-worker query channels, keyed by worker id. Sending `None` tells a worker to stop.
    whs: HashMap<WID, Sender<Option<QMsg<Q>>>>,
    /// Number of workers spawned so far, which is also the id of the next one.
    nw: usize,
    _w: PhantomData<W>,
    _i: PhantomData<I>,
    /// Join handles of every worker thread spawned by this swarm.
    threads: Vec<thread::JoinHandle<()>>,
}
impl<Q, R, W, I> Default for Swarm<Q, R, W, I>
where
    Q: 'static + Send + Debug + Clone,
    R: 'static + Send + Debug,
    W: Worker<Q, R, I>,
{
    /// Builds a swarm that starts with four worker threads.
    fn default() -> Self {
        let default_workers = 4;
        Self::new_with_threads(default_workers)
    }
}
impl<Q,R,W,I> Drop for Swarm<Q,R,W,I> where Q:Debug+Clone, R:Debug, W:Worker<Q, R,I> {
fn drop(&mut self) { self.kill_swarm() }}
impl<Q, R, W, I> Swarm<Q, R, W, I>
where
    Q: Debug + Clone,
    R: Debug,
    W: Worker<Q, R, I>,
{
    /// Stops every remaining worker and joins all worker threads.
    ///
    /// Workers whose thread has already exited (for example after a panic) are not an
    /// error. If any worker thread panicked, the first such panic is re-raised after
    /// *all* threads have been joined. The exception is when the current thread is
    /// already unwinding (e.g. this is reached from `Drop` during a panic): a second
    /// panic there would abort the process, so the worker panic is dropped instead.
    pub fn kill_swarm(&mut self) {
        for (_, handle) in self.whs.drain() {
            // `send` only fails if the worker dropped its receiver, i.e. it already stopped.
            let _ = handle.send(None);
        }
        let mut worker_panic = None;
        while let Some(t) = self.threads.pop() {
            if let Err(payload) = t.join() {
                if worker_panic.is_none() {
                    worker_panic = Some(payload);
                }
            }
        }
        if let Some(payload) = worker_panic {
            if !thread::panicking() {
                std::panic::resume_unwind(payload);
            }
        }
    }

    /// Number of workers that have not been killed yet.
    pub fn num_workers(&self) -> usize {
        self.whs.len()
    }

    /// Tells worker `w` to stop and forgets its channel. Its thread is joined later by
    /// `kill_swarm`.
    ///
    /// # Panics
    /// If `w` is not a live worker of this swarm.
    pub fn kill(&mut self, w: WID) {
        match self.whs.remove(&w) {
            Some(handle) => {
                // A send error means the worker thread has already exited, which is the
                // outcome we want anyway. Any panic it had surfaces when `kill_swarm`
                // joins it.
                let _ = handle.send(None);
            }
            None => panic!("worker was already gone"),
        }
    }
}
impl<Q, R, W, I> Swarm<Q, R, W, I>
where
    Q: 'static + Send + Debug + Clone,
    R: 'static + Send + Debug,
    W: Worker<Q, R, I>,
{
    /// Creates a swarm with the default number of workers (see the `Default` impl).
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a swarm and spawns `n` workers. `n == 0` means one per logical CPU.
    pub fn new_with_threads(n: usize) -> Self {
        let (tx, rx) = channel();
        let (qtx, qrx) = channel();
        let mut me = Self {
            nq: 0,
            me: tx,
            rx,
            qtx,
            qrx,
            whs: HashMap::new(),
            nw: 0,
            _w: PhantomData,
            _i: PhantomData,
            threads: vec![],
        };
        me.start(n);
        me
    }

    /// Spawns `num_workers` more workers, or one per logical CPU when it is `0`.
    pub fn start(&mut self, num_workers: usize) {
        let n = if num_workers == 0 { num_cpus::get() } else { num_workers };
        for _ in 0..n {
            self.spawn();
        }
    }

    /// Spawns one worker thread running `W::work_loop`, registers its query channel,
    /// and returns its id.
    fn spawn(&mut self) -> WID {
        let wid = WID { n: self.nw };
        self.nw += 1;
        let me2 = self.me.clone();
        let (wtx, wrx) = channel();
        self.threads.push(thread::spawn(move || { W::new(wid).work_loop(wid, &wrx, &me2) }));
        self.whs.insert(wid, wtx);
        wid
    }

    /// Sends `q` to a randomly chosen live worker and returns the new query id.
    ///
    /// # Panics
    /// If the swarm has no live workers.
    pub fn add_query(&mut self, q: Q) -> QID {
        // Copy the ids out so no borrow of `whs` is held across `send(&mut self)`.
        let wids: Vec<WID> = self.whs.keys().copied().collect();
        let &wid = wids
            .choose(&mut rand::thread_rng())
            .expect("add_query called on a swarm with no workers");
        self.send(wid, q)
    }

    /// Sends `q` to worker `wid` under a fresh `QID::STEP` id and returns that id.
    ///
    /// # Panics
    /// If `wid` is not a live worker, or if its thread has already exited.
    pub fn send(&mut self, wid: WID, q: Q) -> QID {
        let qid = QID::STEP(self.nq);
        self.nq += 1;
        let w = self
            .whs
            .get(&wid)
            .unwrap_or_else(|| panic!("requested non-existent worker {:?}", wid));
        if w.send(Some(QMsg { qid, q })).is_err() {
            panic!("couldn't send message to worker {:?}", wid)
        }
        qid
    }

    /// Blocks until the next response (from a worker or `send_to_self`) arrives.
    pub fn recv(&self) -> Result<RMsg<R>, RecvError> {
        self.rx.recv()
    }

    /// Sends a clone of `q` to every live worker.
    pub fn send_to_all(&mut self, q: &Q) {
        let wids: Vec<WID> = self.whs.keys().cloned().collect();
        for wid in wids {
            self.send(wid, q.clone());
        }
    }

    /// Returns a sender for queries that `run` will forward to a random worker.
    pub fn q_sender(&self) -> Sender<Q> {
        self.qtx.clone()
    }

    /// Queues `r` on the swarm's own response channel, tagged with the default
    /// `WID` and `QID`.
    pub fn send_to_self(&self, r: R) {
        self.me
            .send(RMsg { wid: WID::default(), qid: QID::default(), r: Some(r) })
            .expect("failed to sent_self");
    }

    /// Drives the swarm. Queries submitted through `q_sender` are forwarded to random
    /// workers. Each response is passed to `on_msg`, and the returned `SwarmCmd` is
    /// carried out.
    ///
    /// Returns `Some(v)` on `SwarmCmd::Return(v)`. Returns `None` on `SwarmCmd::Halt`,
    /// or once the last worker has been removed via `SwarmCmd::Kill`.
    ///
    /// When idle, it waits up to 1 ms for a response rather than spinning. Externally
    /// submitted queries may therefore wait up to that long to be forwarded.
    pub fn run<F, V>(&mut self, mut on_msg: F) -> Option<V>
    where
        V: Debug,
        F: FnMut(WID, &QID, Option<R>) -> SwarmCmd<Q, V>,
    {
        let idle_wait = std::time::Duration::from_millis(1);
        loop {
            let got_query = match self.qrx.try_recv() {
                Ok(q) => {
                    self.add_query(q);
                    true
                }
                Err(_) => false,
            };
            // After forwarding a query, only poll, so a burst of queries drains quickly.
            // Otherwise wait briefly, so an idle swarm does not burn a core.
            let next = if got_query {
                self.rx.try_recv().ok()
            } else {
                self.rx.recv_timeout(idle_wait).ok()
            };
            let RMsg { wid, qid, r } = match next {
                Some(rmsg) => rmsg,
                None => continue,
            };
            match on_msg(wid, &qid, r) {
                SwarmCmd::Pass => {}
                SwarmCmd::Halt => return None,
                SwarmCmd::Kill(w) => {
                    self.kill(w);
                    if self.whs.is_empty() {
                        return None;
                    }
                }
                SwarmCmd::Send(q) => {
                    self.send(wid, q);
                }
                SwarmCmd::Batch(wqs) => {
                    for (target, q) in wqs {
                        self.send(target, q);
                    }
                }
                SwarmCmd::Panic(msg) => panic!("{}", msg),
                SwarmCmd::Return(v) => return Some(v),
            }
        }
    }
}