use crate::time::Duration;
use crate::TaskRunner;
use std::cmp::Ordering as CmpOrdering;
use std::rc::{Rc, Weak as WeakRc};
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use crate::network::{Latency, NetworkMessage, Process, ProcessId};
/// A bidirectional link between two processes.
///
/// Internally it consists of two directed queues, one per direction (built in
/// `Link::new`), plus the task runner on which message deliveries are spawned.
pub struct Link<Message: NetworkMessage> {
/// Queue carrying messages from `process1` to `process2` (as passed to `Link::new`).
queue1: Rc<LinkQueue<Message>>,
/// Queue carrying messages from `process2` to `process1`.
queue2: Rc<LinkQueue<Message>>,
/// Runner used to spawn the asynchronous delivery task for each sent message.
task_runner: Rc<TaskRunner>,
/// Number of directed queues that currently hold messages whose delivery
/// callback has not run yet: incremented when a queue goes from empty to
/// non-empty, decremented when it drains back to empty.
active_queues: AtomicU32,
}
impl<Message: NetworkMessage> Link<Message> {
    /// Creates a link between `process1` and `process2`.
    ///
    /// One directed queue is built per direction, both with the same `latency`.
    /// The link starts with no active (non-empty) queues.
    pub(super) fn new(
        latency: Latency,
        process1: WeakRc<Process<Message>>,
        process2: WeakRc<Process<Message>>,
        task_runner: Rc<TaskRunner>,
    ) -> Self {
        // Direction process1 -> process2.
        let forward = LinkQueue::new(latency, process1.clone(), process2.clone());
        // Direction process2 -> process1.
        let backward = LinkQueue::new(latency, process2, process1);

        Self {
            queue1: Rc::new(forward),
            queue2: Rc::new(backward),
            task_runner,
            active_queues: AtomicU32::new(0),
        }
    }

    /// Returns both endpoint processes, ordered by ascending identifier.
    ///
    /// # Panics
    ///
    /// Panics if either endpoint process has been dropped, or if both
    /// endpoints report the same identifier.
    pub fn get_processes(&self) -> (Rc<Process<Message>>, Rc<Process<Message>>) {
        let src = self.queue1.get_source();
        let dst = self.queue1.get_destination();
        let ordering = src.get_identifier().cmp(&dst.get_identifier());

        match ordering {
            CmpOrdering::Equal => panic!("Invalid state: src and dst process are the same"),
            CmpOrdering::Greater => (dst, src),
            CmpOrdering::Less => (src, dst),
        }
    }

    /// Sends `message` over the direction whose sending endpoint is `source`.
    ///
    /// # Panics
    ///
    /// Panics with `"Invalid state"` if `source` is neither endpoint of this
    /// link, or if an endpoint process has already been dropped.
    pub fn send(self_ptr: &Rc<Link<Message>>, source: ProcessId, message: Message) {
        // Pick the outgoing queue first; queue1 is checked before queue2.
        let outgoing = if self_ptr.queue1.get_source().get_identifier() == source {
            &self_ptr.queue1
        } else if self_ptr.queue2.get_source().get_identifier() == source {
            &self_ptr.queue2
        } else {
            panic!("Invalid state");
        };

        // The (was_empty, latency) result is not needed here.
        let _ = LinkQueue::send(Rc::clone(outgoing), Rc::clone(self_ptr), message);
    }

    /// Total number of messages ever sent over this link, in both directions.
    pub fn num_total_messages(&self) -> u64 {
        [&self.queue1, &self.queue2]
            .iter()
            .map(|queue| queue.total_message_count.load(Ordering::SeqCst))
            .sum()
    }
}
/// One direction of a `Link`: messages flowing from `source` to `dest`.
struct LinkQueue<Message: NetworkMessage> {
/// Latency configured for this direction; `LinkQueue::send` returns it to
/// its caller. NOTE(review): `send` applies no delay based on it before
/// delivery — confirm where latency is actually simulated.
latency: Duration,
/// Sending process, held weakly; upgraded with `unwrap` on each use.
source: WeakRc<Process<Message>>,
/// Receiving process, held weakly; upgraded with `unwrap` on each use.
dest: WeakRc<Process<Message>>,
/// Messages sent on this queue whose delivery callback has not run yet.
current_message_count: AtomicU32,
/// Cumulative number of messages ever sent on this queue (only incremented).
total_message_count: AtomicU64,
}
impl<Message: NetworkMessage> LinkQueue<Message> {
    /// Creates an empty directed queue from `source` to `dest`.
    fn new(
        latency: Latency,
        source: WeakRc<Process<Message>>,
        dest: WeakRc<Process<Message>>,
    ) -> Self {
        Self {
            latency,
            source,
            dest,
            current_message_count: AtomicU32::new(0),
            total_message_count: AtomicU64::new(0),
        }
    }

    /// Records `message` on this queue and spawns a task that hands it to the
    /// destination process.
    ///
    /// The destination receives a callback that must be invoked once the
    /// message has been delivered. The callback decrements this queue's pending
    /// count and, when the queue drains, the link's active-queue count.
    ///
    /// Returns whether the queue was empty before this message, together with
    /// the queue's configured latency.
    ///
    /// NOTE(review): no delay based on `latency` is applied before delivery
    /// here, and `Link::send` discards the returned value. Confirm whether
    /// latency is simulated elsewhere, e.g. inside `deliver_message`.
    ///
    /// # Panics
    ///
    /// Panics if the source process is already dropped when this is called,
    /// or if the destination process is dropped before the delivery task runs.
    fn send(
        self_ptr: Rc<LinkQueue<Message>>,
        link: Rc<Link<Message>>,
        message: Message,
    ) -> (bool, Duration) {
        // `link` is moved into the task below, so spawn through a separate handle.
        let task_runner = Rc::clone(&link.task_runner);
        let latency = self_ptr.latency;

        // Resolve the sender id now, while the sender is guaranteed alive.
        // Upgrading the weak reference later, inside the task, would panic if
        // the sender were removed while the message is still in flight.
        let source_id = self_ptr.get_source().get_identifier();

        self_ptr.total_message_count.fetch_add(1, Ordering::SeqCst);
        let was_empty = self_ptr
            .current_message_count
            .fetch_add(1, Ordering::SeqCst)
            == 0;
        if was_empty {
            // This queue just became active.
            link.active_queues.fetch_add(1, Ordering::SeqCst);
        }

        task_runner.spawn(async move {
            // Look up the destination first so both handles can be moved into
            // the callback without extra clones.
            let dst = self_ptr.get_destination();
            let notify_delivery_fn = Box::new(move || {
                let prev = self_ptr
                    .current_message_count
                    .fetch_sub(1, Ordering::SeqCst);
                assert!(prev > 0);
                if prev == 1 {
                    // Last pending message delivered: this queue is idle again.
                    link.active_queues.fetch_sub(1, Ordering::SeqCst);
                }
            });
            dst.deliver_message(source_id, message, notify_delivery_fn);
        });

        (was_empty, latency)
    }

    /// Upgrades the sending process.
    ///
    /// # Panics
    ///
    /// Panics if the process has been dropped.
    fn get_source(&self) -> Rc<Process<Message>> {
        self.source.upgrade().unwrap()
    }

    /// Upgrades the receiving process.
    ///
    /// # Panics
    ///
    /// Panics if the process has been dropped.
    fn get_destination(&self) -> Rc<Process<Message>> {
        self.dest.upgrade().unwrap()
    }
}