use std::collections::HashMap;
use corophage::prelude::*;
use tokio::sync::mpsc;
/// A participant (location) in the choreography.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum Loc {
    Alice,
    Bob,
    Carol,
}

impl std::fmt::Display for Loc {
    /// Writes the location's name, e.g. `Alice`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Alice => "Alice",
            Self::Bob => "Bob",
            Self::Carol => "Carol",
        };
        f.write_str(name)
    }
}
/// Effect: produce `value` at location `at`.
///
/// The effect resumes with a `String`. Under `handle_locally` that is `value`
/// at location `at` and an empty string at every other location.
#[effect(String)]
struct Locally {
/// Location that owns the value.
at: Loc,
/// The value, already evaluated by the code that yielded the effect.
value: String,
}
/// Effect: transfer `payload` from location `from` to location `to`.
///
/// The effect resumes with a `String`. Under `handle_comm`:
/// - the receiving location resumes with the transmitted payload;
/// - the sender resumes with an empty string when `from != to`;
/// - locations not involved resume with an empty string.
#[effect(String)]
struct Comm {
/// Sending location.
from: Loc,
/// Receiving location.
to: Loc,
/// Value being sent; only the sender's handler reads it.
payload: String,
}
/// The global choreography: a three-party word-count pipeline.
///
/// Every location runs this same function (`main` runs it once per node);
/// the installed handlers decide what each `Locally` / `Comm` effect means at
/// that location.
///
/// Each `value` / payload expression is evaluated at *every* location before
/// the effect is yielded; only the handler decides whether the result is kept.
/// For example, Alice and Carol also count the words of their (empty)
/// `at_bob`.
///
/// Returns the result of the final `Carol => Alice` send. With the handlers in
/// this file, that is the formatted report at Alice and an empty string at Bob
/// and Carol.
#[effectful(Locally, Comm)]
fn pipeline() -> String {
use Loc::*;
// `locally!(loc, value)` yields a `Locally` effect. The handler's resume value
// becomes the result of the whole expression.
macro_rules! locally {
($loc:expr, $value:expr) => {
yield_!(Locally {
at: $loc,
value: $value,
})
};
}
// `send!(from => to, val)` yields a `Comm` effect. The handler's resume value
// becomes the result of the whole expression.
macro_rules! send {
($from:expr => $to:expr, $val:expr) => {
yield_!(Comm {
from: $from,
to: $to,
payload: $val,
})
};
}
// Alice introduces the input sentence.
let input = locally!(Alice, "the quick brown fox jumps over the lazy dog".into());
let at_bob = send!(Alice => Bob, input);
// Bob counts the words he received.
let count = locally!(Bob, at_bob.split_whitespace().count().to_string());
let at_carol = send!(Bob => Carol, count);
// Carol formats the count into the report.
let report = locally!(Carol, format!("=== Report: {at_carol} words ==="));
// The report is delivered back to Alice; this is the function's result.
send!(Carol => Alice, report)
}
/// Runtime state of a single endpoint (one location) of the choreography.
struct NodeState {
    /// The location this node plays.
    loc: Loc,
    /// Outgoing channel ends, keyed by destination location.
    senders: HashMap<Loc, mpsc::Sender<String>>,
    /// Incoming channel ends, keyed by source location.
    receivers: HashMap<Loc, mpsc::Receiver<String>>,
}
impl NodeState {
    /// Sends `msg` to location `dest` over this node's outgoing channel,
    /// waiting if that channel's buffer is full.
    ///
    /// # Panics
    ///
    /// Panics, naming both locations, in two cases:
    /// - this node has no channel to `dest`; for example, `dest` is this node
    ///   itself, since `channel_mesh` creates no loopback channels;
    /// - `dest`'s receiving end has been dropped.
    async fn send_to(&self, dest: Loc, msg: String) {
        let me = self.loc;
        // `get` + explicit panic instead of `self.senders[&dest]`, so the message
        // says which link is missing rather than std's generic missing-key text.
        let Some(tx) = self.senders.get(&dest) else {
            panic!("[{me}] no channel to {dest}");
        };
        if tx.send(msg).await.is_err() {
            panic!("[{me}] channel to {dest} closed");
        }
    }

    /// Waits for the next message from location `source` and returns it.
    ///
    /// # Panics
    ///
    /// Panics, naming both locations, in two cases:
    /// - this node has no channel from `source`;
    /// - that channel closes before a message arrives.
    async fn recv_from(&mut self, source: Loc) -> String {
        let me = self.loc;
        let Some(rx) = self.receivers.get_mut(&source) else {
            panic!("[{me}] no channel from {source}");
        };
        rx.recv()
            .await
            .unwrap_or_else(|| panic!("[{me}] channel from {source} closed"))
    }
}
/// Endpoint handler for [`Locally`] effects.
///
/// At the effect's own location the value is logged and handed back to the
/// choreography. At every other location it is erased to an empty string.
async fn handle_locally(state: &mut NodeState, eff: Locally) -> Control<String> {
    let here = state.loc;
    if eff.at != here {
        return Control::resume(String::new());
    }
    println!(" [{}] local: {:?}", here, eff.value);
    Control::resume(eff.value)
}
/// Endpoint handler for [`Comm`] effects at location `state.loc`.
///
/// - Sender (`from == me`, `to != me`): pushes the payload into `to`'s channel
///   and resumes with an empty string, because the value now lives at the
///   receiver.
/// - Receiver (`to == me`, `from != me`): waits for the payload from `from` and
///   resumes with it.
/// - Self-communication (`from == to == me`): no transfer is needed; resumes
///   with the payload unchanged.
/// - Uninvolved locations resume with an empty string.
async fn handle_comm(state: &mut NodeState, eff: Comm) -> Control<String> {
    let me = state.loc;
    match (eff.from == me, eff.to == me) {
        // There is no loopback channel in the mesh, so routing a self-send
        // through `send_to` would panic. Treating it as a plain send would also
        // erase the value; it is already where it belongs.
        (true, true) => {
            println!(" [{me}] send to self: {:?}", eff.payload);
            Control::resume(eff.payload)
        }
        (true, false) => {
            println!(" [{me}] send to {}: {:?}", eff.to, eff.payload);
            state.send_to(eff.to, eff.payload).await;
            Control::resume(String::new())
        }
        (false, true) => {
            let value = state.recv_from(eff.from).await;
            println!(" [{me}] recv from {}: {:?}", eff.from, value);
            Control::resume(value)
        }
        (false, false) => Control::resume(String::new()),
    }
}
/// Per-location outgoing channel ends: `senders[from][to]`.
type SenderMap = HashMap<Loc, HashMap<Loc, mpsc::Sender<String>>>;
/// Per-location incoming channel ends: `receivers[to][from]`.
type ReceiverMap = HashMap<Loc, HashMap<Loc, mpsc::Receiver<String>>>;

/// Builds a full mesh of point-to-point channels between distinct locations.
///
/// One bounded channel is created for every ordered pair `(from, to)` in
/// `locs` with `from != to`:
/// - its sender is stored at `senders[from][to]`;
/// - its receiver is stored at `receivers[to][from]`.
///
/// No channel is created from a location to itself. A location with no peers
/// has no entry in either map.
fn channel_mesh(locs: &[Loc]) -> (SenderMap, ReceiverMap) {
    // Use the aliases declared above rather than re-spelling the nested types.
    let mut senders = SenderMap::new();
    let mut receivers = ReceiverMap::new();
    for &from in locs {
        for &to in locs {
            if from == to {
                continue;
            }
            // Capacity 1: a sender can deposit one message without waiting
            // for the receiver. A second send on the same link waits until
            // the first message has been received.
            let (tx, rx) = mpsc::channel(1);
            senders.entry(from).or_default().insert(to, tx);
            receivers.entry(to).or_default().insert(from, rx);
        }
    }
    (senders, receivers)
}
/// Creates one [`NodeState`] per entry of `locs`, in the same order.
///
/// Each node owns its share of a fresh `channel_mesh(locs)`.
fn make_nodes(locs: &[Loc]) -> Vec<NodeState> {
    let (mut outgoing, mut incoming) = channel_mesh(locs);
    let mut nodes = Vec::with_capacity(locs.len());
    for &loc in locs {
        // A location absent from the mesh gets empty maps rather than a panic.
        // Examples: the only entry in `locs`, or a repeated entry after its
        // first occurrence.
        let senders = outgoing.remove(&loc).unwrap_or_default();
        let receivers = incoming.remove(&loc).unwrap_or_default();
        nodes.push(NodeState {
            loc,
            senders,
            receivers,
        });
    }
    nodes
}
#[tokio::main(flavor = "current_thread")]
async fn main() {
    // Global view of the choreography, printed for the reader before running it.
    const CHOREOGRAPHY: [&str; 6] = [
        " input <- locally Alice, \"the quick brown fox ...\"",
        " at_bob <- send Alice => Bob, input",
        " count <- locally Bob, countWords(at_bob)",
        " at_carol <- send Bob => Carol, count",
        " report <- locally Carol, formatReport(at_carol)",
        " result <- send Carol => Alice, report",
    ];

    println!("=== Choreographic Data Pipeline ===\n");
    println!("Choreography (global view):\n");
    for step in CHOREOGRAPHY {
        println!("{step}");
    }
    println!();

    let mut nodes = make_nodes(&[Loc::Alice, Loc::Bob, Loc::Carol]);
    let [alice, bob, carol] = &mut nodes[..] else {
        unreachable!()
    };

    // Each location runs the same choreography with its own state; they
    // synchronise only through the channels inside their `NodeState`s.
    println!("Running endpoint projections concurrently:\n");
    let (a, b, c) = tokio::join!(
        pipeline()
            .handle(handle_locally)
            .handle(handle_comm)
            .run_stateful(alice),
        pipeline()
            .handle(handle_locally)
            .handle(handle_comm)
            .run_stateful(bob),
        pipeline()
            .handle(handle_locally)
            .handle(handle_comm)
            .run_stateful(carol),
    );

    println!();
    println!("Results (each location ran the same choreography):");
    let outcomes = [
        ("Alice", a, "the final report"),
        ("Bob", b, "erased (not at Alice)"),
        ("Carol", c, "erased (not at Alice)"),
    ];
    for (name, outcome, note) in outcomes {
        println!(" {name}: {:?} <-- {note}", outcome.unwrap());
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sentence Alice introduces at the start of `pipeline`.
    const INPUT: &str = "the quick brown fox jumps over the lazy dog";
    /// Report Carol produces for `INPUT` (9 words).
    const REPORT: &str = "=== Report: 9 words ===";

    /// Runs all three endpoint projections over real channels and checks that
    /// only Alice, the final receiver, ends up with a non-erased value.
    #[tokio::test]
    async fn full_pipeline_produces_correct_results() {
        let mut nodes = make_nodes(&[Loc::Alice, Loc::Bob, Loc::Carol]);
        let [alice, bob, carol] = nodes.as_mut_slice() else {
            unreachable!()
        };
        let (a, b, c) = tokio::join!(
            pipeline()
                .handle(handle_locally)
                .handle(handle_comm)
                .run_stateful(alice),
            pipeline()
                .handle(handle_locally)
                .handle(handle_comm)
                .run_stateful(bob),
            pipeline()
                .handle(handle_locally)
                .handle(handle_comm)
                .run_stateful(carol),
        );
        assert_eq!(a.unwrap(), REPORT);
        assert_eq!(b.unwrap(), "");
        assert_eq!(c.unwrap(), "");
    }

    /// Alice's projection in isolation. She sends exactly the input, to Bob,
    /// and receives the report from Carol.
    #[tokio::test]
    async fn mock_epp_for_alice() {
        let result = pipeline()
            .handle(async |eff: Locally| -> Control<String> {
                if eff.at == Loc::Alice {
                    Control::resume(eff.value)
                } else {
                    Control::resume(String::new())
                }
            })
            .handle(async |eff: Comm| -> Control<String> {
                if eff.from == Loc::Alice {
                    assert_eq!(eff.to, Loc::Bob, "Alice should send to Bob");
                    assert_eq!(eff.payload, INPUT, "Alice should send the input");
                    Control::resume(String::new())
                } else if eff.to == Loc::Alice {
                    assert_eq!(eff.from, Loc::Carol, "Alice should receive from Carol");
                    Control::resume(REPORT.into())
                } else {
                    Control::resume(String::new())
                }
            })
            .run()
            .await
            .unwrap();
        assert_eq!(result, REPORT);
    }

    /// Bob's projection in isolation. He receives from Alice and sends the word
    /// count to Carol. His final result is erased.
    #[tokio::test]
    async fn mock_epp_for_bob() {
        let result = pipeline()
            .handle(async |eff: Locally| -> Control<String> {
                if eff.at == Loc::Bob {
                    Control::resume(eff.value)
                } else {
                    Control::resume(String::new())
                }
            })
            .handle(async |eff: Comm| -> Control<String> {
                if eff.to == Loc::Bob {
                    assert_eq!(eff.from, Loc::Alice, "Bob should receive from Alice");
                    Control::resume("some words here".into())
                } else if eff.from == Loc::Bob {
                    assert_eq!(eff.to, Loc::Carol, "Bob should send to Carol");
                    assert_eq!(eff.payload, "3", "Bob should send the word count");
                    Control::resume(String::new())
                } else {
                    Control::resume(String::new())
                }
            })
            .run()
            .await
            .unwrap();
        assert_eq!(result, "");
    }
}