use std::sync::Arc;
use std::sync::mpsc::{self, Receiver, RecvError, SendError, Sender, TryRecvError};
use super::super::edge::kind::EdgeKind;
use super::super::file::FileId;
use super::super::node::{NodeId, NodeKind};
/// A graph mutation or control message carried from an [`UpdateChannel`] to
/// its [`UpdateReceiver`].
///
/// This module only transports these values. How each variant is applied is
/// decided by whatever drains the receiver, which is not visible here.
#[derive(Debug, Clone)]
pub enum GraphUpdate {
/// Request to add a node of `kind` named `name`, attributed to `file`.
AddNode {
kind: NodeKind,
name: String,
file: FileId,
},
/// Request to remove the node identified by `node`.
RemoveNode {
node: NodeId,
},
/// Request to add a `kind` edge from `source` to `target`, attributed to `file`.
AddEdge {
source: NodeId,
target: NodeId,
kind: EdgeKind,
file: FileId,
},
/// Request to remove the `kind` edge from `source` to `target` attributed to `file`.
RemoveEdge {
source: NodeId,
target: NodeId,
kind: EdgeKind,
file: FileId,
},
/// Request to clear data associated with `file`. Presumably this means
/// everything that file contributed; confirm with the consumer.
ClearFile {
file: FileId,
},
/// Control message asking the consumer to compact. Semantics are defined by
/// the consumer.
TriggerCompaction,
/// Control message asking the consumer to stop. Semantics are defined by the
/// consumer.
Shutdown,
}
/// Failure reported by graph-update channel operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChannelError {
    /// The opposite half of the channel is gone. Failed `mpsc` sends and
    /// receives are converted into this variant.
    Disconnected,
    /// The channel had no free capacity for the update.
    Full,
}

impl std::fmt::Display for ChannelError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            ChannelError::Disconnected => "channel disconnected",
            ChannelError::Full => "channel full",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ChannelError {}

impl<T> From<SendError<T>> for ChannelError {
    /// A send fails only when the receiver has been dropped. The undelivered
    /// value carried by the error is discarded.
    fn from(_undelivered: SendError<T>) -> Self {
        ChannelError::Disconnected
    }
}

impl From<RecvError> for ChannelError {
    /// A blocking receive fails only once every sender is gone.
    fn from(_err: RecvError) -> Self {
        ChannelError::Disconnected
    }
}
#[derive(Debug, Clone)]
pub struct UpdateChannel {
sender: Sender<GraphUpdate>,
updates_sent: Arc<std::sync::atomic::AtomicU64>,
}
impl UpdateChannel {
#[must_use]
pub fn new() -> (Self, UpdateReceiver) {
let (sender, receiver) = mpsc::channel();
let updates_sent = Arc::new(std::sync::atomic::AtomicU64::new(0));
let updates_received = Arc::new(std::sync::atomic::AtomicU64::new(0));
(
Self {
sender,
updates_sent: Arc::clone(&updates_sent),
},
UpdateReceiver {
receiver,
updates_received,
},
)
}
#[must_use]
pub fn bounded(capacity: usize) -> (Self, UpdateReceiver) {
let (sender, receiver) = mpsc::sync_channel(capacity);
let updates_sent = Arc::new(std::sync::atomic::AtomicU64::new(0));
let updates_received = Arc::new(std::sync::atomic::AtomicU64::new(0));
(
Self {
sender: {
let (tx, rx) = mpsc::channel();
std::thread::spawn(move || {
while let Ok(update) = rx.recv() {
if sender.send(update).is_err() {
break;
}
}
});
tx
},
updates_sent: Arc::clone(&updates_sent),
},
UpdateReceiver {
receiver,
updates_received,
},
)
}
pub fn send(&self, update: GraphUpdate) -> Result<(), ChannelError> {
self.sender.send(update)?;
self.updates_sent
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
Ok(())
}
#[must_use]
pub fn updates_sent(&self) -> u64 {
self.updates_sent.load(std::sync::atomic::Ordering::Relaxed)
}
}
impl Default for UpdateChannel {
    /// Returns the sending half of a fresh unbounded channel.
    ///
    /// The receiver created alongside it is discarded before this returns.
    /// Sends on the returned handle therefore report
    /// `ChannelError::Disconnected`, so this value is mainly a placeholder.
    fn default() -> Self {
        let (channel, _discarded_receiver) = Self::new();
        channel
    }
}
/// Receiving half of a graph-update channel.
///
/// Keeps a count of every update it successfully hands out.
#[derive(Debug)]
pub struct UpdateReceiver {
    receiver: Receiver<GraphUpdate>,
    updates_received: Arc<std::sync::atomic::AtomicU64>,
}

impl UpdateReceiver {
    /// Blocks until the next update is available.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelError::Disconnected`] once every sender has been dropped
    /// and no buffered updates remain.
    pub fn recv(&self) -> Result<GraphUpdate, ChannelError> {
        let Ok(next) = self.receiver.recv() else {
            return Err(ChannelError::Disconnected);
        };
        self.bump_received();
        Ok(next)
    }

    /// Returns the next update if one is already queued, or `Ok(None)` if not.
    ///
    /// # Errors
    ///
    /// Returns [`ChannelError::Disconnected`] when the queue is empty and every
    /// sender has been dropped.
    pub fn try_recv(&self) -> Result<Option<GraphUpdate>, ChannelError> {
        let next = match self.receiver.try_recv() {
            Err(TryRecvError::Disconnected) => return Err(ChannelError::Disconnected),
            Err(TryRecvError::Empty) => return Ok(None),
            Ok(update) => update,
        };
        self.bump_received();
        Ok(Some(next))
    }

    /// Blocking iterator over incoming updates.
    ///
    /// The iterator ends at the first failed receive, i.e. on disconnection.
    pub fn iter(&self) -> impl Iterator<Item = GraphUpdate> + '_ {
        std::iter::repeat_with(move || self.recv()).map_while(Result::ok)
    }

    /// Number of updates handed out by `recv`, `try_recv`, or `iter` so far.
    #[must_use]
    pub fn updates_received(&self) -> u64 {
        self.updates_received
            .load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Records one successfully received update.
    fn bump_received(&self) {
        self.updates_received
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    }
}
/// Snapshot pairing a sent-count with a received-count.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChannelStats {
    /// Number of updates counted as sent.
    pub sent: u64,
    /// Number of updates counted as received.
    pub received: u64,
}

impl ChannelStats {
    /// Updates sent but not yet received.
    ///
    /// Clamped at zero: if `received` exceeds `sent` (for example, when the two
    /// counters were sampled at different moments), the result is 0 rather than
    /// an underflow.
    #[must_use]
    pub fn in_flight(&self) -> u64 {
        let Self { sent, received } = *self;
        sent.saturating_sub(received)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Extracts the file from a `ClearFile` update, failing the test on any other variant.
    fn expect_clear_file(update: GraphUpdate) -> FileId {
        match update {
            GraphUpdate::ClearFile { file } => file,
            other => panic!("wrong update type: {other:?}"),
        }
    }

    #[test]
    fn test_channel_new() {
        let (sender, _receiver) = UpdateChannel::new();
        assert_eq!(sender.updates_sent(), 0);
    }

    #[test]
    fn test_channel_default() {
        let sender: UpdateChannel = UpdateChannel::default();
        assert_eq!(sender.updates_sent(), 0);
    }

    #[test]
    fn test_send_receive() {
        let (sender, receiver) = UpdateChannel::new();
        let file = FileId::new(1);
        sender
            .send(GraphUpdate::ClearFile { file })
            .expect("send failed");
        let update = receiver.recv().expect("recv failed");
        assert_eq!(expect_clear_file(update), file);
    }

    #[test]
    fn test_updates_serialized() {
        let (sender, receiver) = UpdateChannel::new();
        let sender_clone = sender.clone();
        let handle1 = thread::spawn(move || {
            for i in 0..100 {
                let file = FileId::new(i);
                sender.send(GraphUpdate::ClearFile { file }).unwrap();
            }
        });
        let handle2 = thread::spawn(move || {
            for i in 100..200 {
                let file = FileId::new(i);
                sender_clone.send(GraphUpdate::ClearFile { file }).unwrap();
            }
        });
        handle1.join().unwrap();
        handle2.join().unwrap();
        let mut received_updates = Vec::new();
        while let Ok(Some(update)) = receiver.try_recv() {
            received_updates.push(update);
        }
        assert_eq!(received_updates.len(), 200);
        assert_eq!(receiver.updates_received(), 200);
    }

    #[test]
    fn test_channel_ordering() {
        let (sender, receiver) = UpdateChannel::new();
        for i in 0..100 {
            let file = FileId::new(i);
            sender.send(GraphUpdate::ClearFile { file }).unwrap();
        }
        for i in 0..100 {
            let file = expect_clear_file(receiver.recv().unwrap());
            assert_eq!(file.index(), i);
        }
    }

    #[test]
    fn test_try_recv_empty() {
        let (sender, receiver) = UpdateChannel::new();
        // Keep the sender alive so the empty channel is not also disconnected.
        let _sender = sender;
        assert!(receiver.try_recv().unwrap().is_none());
    }

    #[test]
    fn test_try_recv_available() {
        let (sender, receiver) = UpdateChannel::new();
        let file = FileId::new(42);
        sender.send(GraphUpdate::ClearFile { file }).unwrap();
        let result = receiver.try_recv().unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn test_disconnected_sender() {
        let (sender, receiver) = UpdateChannel::new();
        drop(sender);
        let result = receiver.recv();
        assert!(matches!(result, Err(ChannelError::Disconnected)));
    }

    #[test]
    fn test_disconnected_receiver() {
        let (sender, receiver) = UpdateChannel::new();
        drop(receiver);
        let file = FileId::new(1);
        let result = sender.send(GraphUpdate::ClearFile { file });
        assert!(matches!(result, Err(ChannelError::Disconnected)));
    }

    #[test]
    fn test_update_kinds() {
        let (sender, receiver) = UpdateChannel::new();
        sender
            .send(GraphUpdate::AddNode {
                kind: NodeKind::Function,
                name: "test".to_string(),
                file: FileId::new(1),
            })
            .unwrap();
        sender
            .send(GraphUpdate::RemoveNode {
                node: NodeId::new(1, 0),
            })
            .unwrap();
        sender
            .send(GraphUpdate::AddEdge {
                source: NodeId::new(1, 0),
                target: NodeId::new(2, 0),
                kind: EdgeKind::Calls {
                    argument_count: 0,
                    is_async: false,
                },
                file: FileId::new(1),
            })
            .unwrap();
        sender
            .send(GraphUpdate::RemoveEdge {
                source: NodeId::new(1, 0),
                target: NodeId::new(2, 0),
                kind: EdgeKind::Calls {
                    argument_count: 0,
                    is_async: false,
                },
                file: FileId::new(1),
            })
            .unwrap();
        sender
            .send(GraphUpdate::ClearFile {
                file: FileId::new(1),
            })
            .unwrap();
        sender.send(GraphUpdate::TriggerCompaction).unwrap();
        sender.send(GraphUpdate::Shutdown).unwrap();
        assert_eq!(sender.updates_sent(), 7);
        let mut count = 0;
        while receiver.try_recv().unwrap().is_some() {
            count += 1;
        }
        assert_eq!(count, 7);
    }

    #[test]
    fn test_channel_stats() {
        let stats = ChannelStats {
            sent: 100,
            received: 75,
        };
        assert_eq!(stats.in_flight(), 25);
    }

    #[test]
    fn test_channel_stats_saturating() {
        let stats = ChannelStats {
            sent: 50,
            received: 75,
        };
        assert_eq!(stats.in_flight(), 0);
    }

    #[test]
    fn test_iter() {
        let (sender, receiver) = UpdateChannel::new();
        for i in 0..10 {
            sender
                .send(GraphUpdate::ClearFile {
                    file: FileId::new(i),
                })
                .unwrap();
        }
        drop(sender);
        let updates: Vec<_> = receiver.iter().collect();
        assert_eq!(updates.len(), 10);
    }

    #[test]
    fn test_clone_sender() {
        let (sender, receiver) = UpdateChannel::new();
        let sender2 = sender.clone();
        sender
            .send(GraphUpdate::ClearFile {
                file: FileId::new(1),
            })
            .unwrap();
        sender2
            .send(GraphUpdate::ClearFile {
                file: FileId::new(2),
            })
            .unwrap();
        assert_eq!(sender.updates_sent(), 2);
        assert_eq!(sender2.updates_sent(), 2);
        let mut count = 0;
        while receiver.try_recv().unwrap().is_some() {
            count += 1;
        }
        assert_eq!(count, 2);
    }

    #[test]
    fn test_channel_error_display() {
        let err = ChannelError::Disconnected;
        assert_eq!(format!("{err}"), "channel disconnected");
        let err = ChannelError::Full;
        assert_eq!(format!("{err}"), "channel full");
    }

    #[test]
    fn test_concurrent_send_receive() {
        let (sender, receiver) = UpdateChannel::new();
        let sender_clone = sender.clone();
        let producer = thread::spawn(move || {
            for i in 0..1000 {
                sender
                    .send(GraphUpdate::ClearFile {
                        file: FileId::new(i),
                    })
                    .unwrap();
            }
        });
        let producer2 = thread::spawn(move || {
            for i in 1000..2000 {
                sender_clone
                    .send(GraphUpdate::ClearFile {
                        file: FileId::new(i),
                    })
                    .unwrap();
            }
        });
        // `iter` blocks on `recv` and ends once both producers have dropped
        // their senders, so the consumer needs no polling or sleeping.
        let consumer = thread::spawn(move || {
            let count = receiver.iter().count();
            (count, receiver.updates_received())
        });
        producer.join().unwrap();
        producer2.join().unwrap();
        let (received_count, counted) = consumer.join().unwrap();
        assert_eq!(received_count, 2000);
        assert_eq!(counted, 2000);
    }

    #[test]
    fn test_bounded_send_receive() {
        let (sender, receiver) = UpdateChannel::bounded(4);
        for i in 0..4 {
            sender
                .send(GraphUpdate::ClearFile {
                    file: FileId::new(i),
                })
                .unwrap();
        }
        for i in 0..4 {
            let file = expect_clear_file(receiver.recv().unwrap());
            assert_eq!(file.index(), i);
        }
        assert_eq!(sender.updates_sent(), 4);
        assert_eq!(receiver.updates_received(), 4);
    }

    #[test]
    fn test_bounded_with_concurrent_producer() {
        // Drive a small bounded channel from another thread. Every update must
        // arrive in order, and the receiver must observe disconnection once
        // the producer has finished and dropped its sender.
        let (sender, receiver) = UpdateChannel::bounded(1);
        let producer = thread::spawn(move || {
            for i in 0..50 {
                sender
                    .send(GraphUpdate::ClearFile {
                        file: FileId::new(i),
                    })
                    .unwrap();
            }
        });
        for i in 0..50 {
            let file = expect_clear_file(receiver.recv().unwrap());
            assert_eq!(file.index(), i);
        }
        producer.join().unwrap();
        assert!(matches!(receiver.recv(), Err(ChannelError::Disconnected)));
    }

    #[test]
    fn test_bounded_disconnected_sender() {
        let (sender, receiver) = UpdateChannel::bounded(2);
        drop(sender);
        assert!(matches!(receiver.recv(), Err(ChannelError::Disconnected)));
    }
}