use std::collections::HashMap;
use traceforge::msg::Message;
use traceforge::thread::{self, ThreadId};
use traceforge::*;
use log::{info, trace};
mod utils;
/// Envelope for everything delivered to an actor thread running [`actor`].
///
/// The loop in [`actor`] expects an `Init` first, then any number of `Msg`s,
/// and stops on `Terminate`.
#[derive(Clone, Debug, PartialEq)]
enum ActorMsg<S, T>
where
    S: Send + 'static,
    T: Message + Send + 'static,
{
    /// Start-up payload, passed to [`Actor::start`].
    Init(S),
    /// Shutdown request: [`Actor::stop`] is called and the loop exits.
    Terminate,
    /// Application message, passed to [`Actor::handle`].
    Msg(T),
}
/// Behaviour plugged into the generic [`actor`] message loop.
///
/// `S` is the start-up payload type and `T` the application message type.
pub trait Actor<S, T>
where
    T: Message + Send + 'static,
{
    /// Called once, with the payload of the first (`Init`) message.
    fn start(&mut self, d: S);
    /// Called for every application message received after start-up.
    fn handle(&mut self, m: T);
    /// Called when `Terminate` is received, right before the loop exits.
    fn stop(&mut self);
}
/// Generic actor loop, meant to be run as the body of an actor thread.
///
/// The first received message must be `Init`; any other first message is
/// ruled out with `traceforge::assume!(false)`. After that, every `Msg` is
/// dispatched to [`Actor::handle`] until `Terminate` arrives, at which point
/// [`Actor::stop`] is called and the function returns.
///
/// # Panics
/// On a second `Init`, or on a `Msg` received while the actor is not started.
pub fn actor<S, T>(a: &mut dyn Actor<S, T>)
where
    S: Clone + std::fmt::Debug + PartialEq + Send + 'static,
    T: Clone + std::fmt::Debug + PartialEq + Send + 'static,
{
    let mut started = false;
    match recv_msg_block::<ActorMsg<S, T>>() {
        ActorMsg::Init(init) => {
            a.start(init);
            started = true;
        }
        _ => {
            // Executions whose first message is not `Init` are pruned.
            traceforge::assume!(false);
        }
    }
    loop {
        let msg: ActorMsg<S, T> = recv_msg_block();
        trace!("msg = {:?}", msg);
        match msg {
            ActorMsg::Init(_) => panic!("Restarting actor?"),
            ActorMsg::Msg(m) if started => a.handle(m),
            ActorMsg::Msg(_) => panic!("Message to actor that is not started"),
            ActorMsg::Terminate => {
                a.stop();
                break;
            }
        }
    }
}
pub fn send_msg_to_actor<
S: Clone + std::fmt::Debug + PartialEq + Send + 'static,
M: Clone + std::fmt::Debug + PartialEq + Send + 'static,
>(
tid: ThreadId,
m: M,
) {
trace!("Sending message {:?}", &m);
traceforge::send_msg(tid, ActorMsg::<S, M>::Msg(m));
}
/// Application messages handled by the coordinator actor.
#[derive(Clone, Debug, PartialEq)]
enum CoordinatorMsg {
    /// Start a new transaction with the given request id.
    Request(u32),
    /// A participant voted yes for the given request id.
    Yes(u32),
    /// A participant voted no for the given request id.
    No(u32),
}
/// Application messages handled by a participant actor.
#[derive(Clone, Debug, PartialEq)]
enum ParticipantMsg {
    /// Vote request: the thread to reply to and the request id.
    Prepare(ThreadId, u32),
    /// Final decision: commit the given request id.
    Commit(u32),
    /// Final decision: abort the given request id.
    Abort(u32),
}
/// Coordinator side of two-phase commit.
///
/// For each request it sends `Prepare` to every participant and tallies the
/// `Yes`/`No` votes. Once every participant has voted, it sends either
/// `Commit` (all votes yes) or `Abort` (at least one no) to all participants.
#[derive(Default)]
struct Coordinator {
    /// Participant threads, set from the actor's `Init` payload.
    participants: Vec<ThreadId>,
    /// In-flight requests: request id -> (votes received, yes votes received).
    reply_map: HashMap<u32, (u32, u32)>,
}

impl Coordinator {
    /// Sends `Commit(id)` to every participant, in `participants` order.
    fn commit(&self, id: u32) {
        for &tid in &self.participants {
            send_msg_to_participant(tid, ParticipantMsg::Commit(id));
        }
    }

    /// Sends `Abort(id)` to every participant, in `participants` order.
    fn abort(&self, id: u32) {
        for &tid in &self.participants {
            send_msg_to_participant(tid, ParticipantMsg::Abort(id));
        }
    }

    /// Records one participant's vote on request `id`.
    ///
    /// When this is the last outstanding vote, it sends the outcome to all
    /// participants and removes the request from `reply_map`.
    ///
    /// # Panics
    /// If `id` is not an in-flight request.
    fn record_vote(&mut self, id: u32, voted_yes: bool) {
        let tally = self
            .reply_map
            .get_mut(&id)
            .expect("vote received for a request that is not in flight");
        tally.0 += 1;
        if voted_yes {
            tally.1 += 1;
        }
        let (total, yes) = *tally;
        let expected = self.participants.len();
        // Decide only once *all* votes are in, whatever the last vote was.
        // Deciding earlier would leave later votes for an id already removed
        // from `reply_map`. Deciding on only some vote orders (the previous
        // behaviour) left "No ... then Yes" requests undecided forever.
        // (u32 -> usize is lossless on 32/64-bit targets.)
        if total as usize == expected {
            if yes as usize == expected {
                self.commit(id);
            } else {
                self.abort(id);
            }
            self.reply_map.remove(&id);
        }
    }
}

impl Actor<Vec<ThreadId>, CoordinatorMsg> for Coordinator {
    /// Stores the participant set this coordinator drives.
    fn start(&mut self, d: Vec<ThreadId>) {
        self.participants = d;
    }

    fn stop(&mut self) {}

    fn handle(&mut self, msg: CoordinatorMsg) {
        trace!("C received {:?}", &msg);
        match msg {
            CoordinatorMsg::Request(reqid) => {
                // Request ids must be unique among in-flight requests.
                assert!(!self.reply_map.contains_key(&reqid));
                self.reply_map.insert(reqid, (0, 0));
                for &tid in &self.participants {
                    send_msg_to_participant(
                        tid,
                        ParticipantMsg::Prepare(thread::current().id(), reqid),
                    );
                }
            }
            CoordinatorMsg::Yes(id) => self.record_vote(id, true),
            CoordinatorMsg::No(id) => self.record_vote(id, false),
        }
    }
}
/// Full envelope type received by the coordinator's actor loop.
type CoordinatorActorMsg = ActorMsg<Vec<ThreadId>, CoordinatorMsg>;

/// Sends an application message to the coordinator actor running on `t`.
fn send_msg_to_coordinator(t: ThreadId, m: CoordinatorMsg) {
    send_msg_to_actor::<Vec<ThreadId>, _>(t, m);
}
/// Spawns a thread running the actor loop for a fresh [`Coordinator`].
///
/// The thread blocks until it receives its `Init` message.
fn coordinator() -> thread::JoinHandle<()> {
    let mut state = Coordinator::default();
    let body = move || actor(&mut state);
    thread::spawn(body)
}
/// Participant side of two-phase commit.
///
/// It votes nondeterministically on each `Prepare`, then checks that the
/// coordinator's decision is consistent with its own vote.
#[derive(Default)]
struct Participant {
    /// Request id -> this participant's vote (`true` = yes).
    requests: HashMap<u32, bool>,
}

impl Actor<(), ParticipantMsg> for Participant {
    fn start(&mut self, _d: ()) {}

    fn handle(&mut self, msg: ParticipantMsg) {
        trace!("P received {:?}", &msg);
        match msg {
            ParticipantMsg::Prepare(tid, reqid) => {
                // Let the model checker explore both a yes and a no vote.
                let response = traceforge::nondet();
                let vote = if response {
                    CoordinatorMsg::Yes(reqid)
                } else {
                    CoordinatorMsg::No(reqid)
                };
                send_msg_to_coordinator(tid, vote);
                self.requests.insert(reqid, response);
            }
            ParticipantMsg::Abort(reqid) => {
                // Abort is valid whatever we voted, but only for a request we
                // were actually asked to vote on.
                assert!(self.requests.contains_key(&reqid));
            }
            ParticipantMsg::Commit(reqid) => {
                // Commit is only valid if this participant voted yes.
                assert_eq!(
                    self.requests.get(&reqid),
                    Some(&true),
                    "commit of request {} without a yes vote",
                    reqid
                );
            }
        }
    }

    fn stop(&mut self) {}
}
/// Full envelope type received by a participant's actor loop.
type ParticipantActorMsg = ActorMsg<(), ParticipantMsg>;

/// Sends an application message to the participant actor running on `tid`.
fn send_msg_to_participant(tid: ThreadId, m: ParticipantMsg) {
    send_msg_to_actor::<(), _>(tid, m);
}
/// Spawns a thread running the actor loop for a fresh [`Participant`].
///
/// The thread blocks until it receives its `Init` message.
fn participant() -> thread::JoinHandle<()> {
    let mut state = Participant::default();
    let body = move || actor(&mut state);
    thread::spawn(body)
}
fn two_pc(num_ps: u32) {
let mut ps = Vec::new();
let c = coordinator();
for _i in 0..num_ps {
ps.push(participant());
}
traceforge::send_msg(
c.thread().id(),
CoordinatorActorMsg::Init(ps.iter().map(|h| h.thread().id()).collect()),
);
let _: Vec<()> = ps
.iter()
.map(|h| traceforge::send_msg(h.thread().id(), ParticipantActorMsg::Init(())))
.collect();
send_msg_to_actor::<Vec<ThreadId>, _>(c.thread().id(), CoordinatorMsg::Request(1));
send_msg_to_actor::<Vec<ThreadId>, _>(c.thread().id(), CoordinatorMsg::Request(2));
traceforge::send_msg(c.thread().id(), CoordinatorActorMsg::Terminate);
let _ = c.join();
let _: Vec<()> = ps
.iter()
.map(|h| traceforge::send_msg(h.thread().id(), ParticipantActorMsg::Terminate))
.collect();
for h in ps {
let _ = h.join();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use traceforge::*;

    /// Model-checks `two_pc` with two participants under FIFO consistency.
    /// Traces are written to `/tmp/twopc.traces`.
    #[test]
    fn two_pc_verify() {
        let participants: u32 = 2;
        let config = Config::builder()
            .with_cons_type(ConsType::FIFO)
            .with_progress_report(100)
            .with_verbose(1)
            .with_trace_out("/tmp/twopc.traces")
            .build();
        let stats = traceforge::verify(config, move || two_pc(participants));
        info!("Stats: {} {}", stats.execs, stats.block);
    }
}