use dashmap::DashMap;
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use once_cell::sync::Lazy;
use std::{
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
time::Duration,
};
use tokio::{
sync::mpsc,
task,
};
/// Type-erased view of a hooked channel, letting channels with different message types
/// share a single registry of `Arc<dyn ChannelInfo>`.
pub trait ChannelInfo: Send + Sync {
    /// Returns the channel's id together with its current queue depth.
    fn get_queue_depth(&self) -> (String, usize);
    /// Returns the channel capacity recorded when the channel was hooked.
    fn get_channel_length(&self) -> usize;
}
/// Registry record for one hooked channel: its id, the capacity supplied at registration,
/// and a sender handle used to query how many slots are currently free.
///
/// NOTE(review): the registry holds this `Sender` for as long as the entry exists, and no
/// code visible in this file ever removes entries. As a result, the receiving side of a
/// hooked channel never observes closure (`recv()` returning `None`) merely because every
/// other sender was dropped. Consider storing an `mpsc::WeakSender` instead; confirm the
/// tokio version supports it.
struct ChannelMetadata<T> {
    // Identifier returned by `get_queue_depth`; `hook_channel` also uses it as the registry key.
    id: String,
    // Capacity as reported by the caller of `hook_channel`; it is not checked against the
    // channel's real buffer size.
    len: usize,
    // Sender handle; only `capacity()` is called on it in this file.
    sender: mpsc::Sender<T>,
}
impl<T: Send + Sync + 'static> ChannelInfo for ChannelMetadata<T> {
    /// Returns `(id, depth)`, where depth is the recorded length minus the slots currently free.
    ///
    /// `Sender::capacity()` reports free slots; slots taken by buffered messages or by
    /// outstanding reserved permits both count toward the depth.
    ///
    /// If the `len` given at registration is smaller than the channel's real buffer size,
    /// the free-slot count can exceed `len`. The subtraction then saturates at 0 instead of
    /// panicking (debug builds) or wrapping to a huge value (release builds).
    fn get_queue_depth(&self) -> (String, usize) {
        let available = self.sender.capacity();
        let qdepth = self.len.saturating_sub(available);
        (self.id.clone(), qdepth)
    }

    /// Returns the capacity recorded at registration (the `len` passed to `hook_channel`).
    fn get_channel_length(&self) -> usize {
        self.len
    }
}
/// Global registry of hooked channels, keyed by the id passed to [`hook_channel`].
static CHANNELS: Lazy<DashMap<String, Arc<dyn ChannelInfo>>> = Lazy::new(DashMap::new);
/// Flipped to `true` by the first call to [`init`]; later calls see it set and return early.
static INITIALIZED: AtomicBool = AtomicBool::new(false);
/// Registers `sender` under `id` so the reporter started by [`init`] can display its depth.
///
/// `len` should be the capacity the channel was created with. It is used as the progress-bar
/// length and as the baseline from which queue depth is derived. Hooking an `id` that is
/// already registered replaces the previous entry. The registry keeps `sender` for as long as
/// the entry exists, so callers typically pass a clone.
pub fn hook_channel<T: Send + Sync + 'static>(sender: mpsc::Sender<T>, id: &str, len: usize) {
    let key = id.to_owned();
    let info: Arc<dyn ChannelInfo> = Arc::new(ChannelMetadata {
        id: key.clone(),
        len,
        sender,
    });
    CHANNELS.insert(key, info);
}
pub fn init(interval: u64) {
if INITIALIZED.swap(true, Ordering::SeqCst) {
return;
}
let multi = Arc::new(MultiProgress::new());
let style = ProgressStyle::with_template("{msg}: [{wide_bar}] {pos}/{len}")
.unwrap()
.progress_chars("█▉▊▋▌▍▎▏ ");
let multi_clone = multi.clone();
task::spawn(async move {
let progress_bars = DashMap::new();
loop {
CHANNELS.iter().for_each(|entry| {
let id = entry.key();
if !progress_bars.contains_key(id) {
let pb = multi_clone.add(ProgressBar::new(entry.value().get_channel_length() as u64));
pb.set_style(style.clone());
pb.set_message(format!("Copepod::{}", id));
progress_bars.insert(id.clone(), pb);
}
});
for item in progress_bars.iter_mut() {
let channel_id = item.key();
if let Some(metadata) = CHANNELS.get(channel_id) {
let (_, qdepth) = metadata.value().get_queue_depth();
item.value().set_position(qdepth as u64);
}
}
tokio::time::sleep(Duration::from_millis(interval)).await;
}
});
}
/// Checks two things for a hooked channel: it is registered under its id, and its
/// reported queue depth tracks sends and receives exactly.
#[tokio::test]
async fn test_hook_channel() {
    let size = 10;
    let (tx, mut rx) = tokio::sync::mpsc::channel::<usize>(size);
    hook_channel(tx.clone(), "test_channel", size);

    // Clone the `Arc` out of the map. A DashMap `Ref` holds a shard read lock, and that lock
    // must not be held across the `.await` points below (other tests may insert concurrently).
    let metadata = CHANNELS
        .get("test_channel")
        .expect("Channel not registered")
        .value()
        .clone();

    let (id, qdepth) = metadata.get_queue_depth();
    assert_eq!(id, "test_channel");
    assert_eq!(qdepth, 0, "a freshly created channel should be empty");
    assert_eq!(metadata.get_channel_length(), size);

    tx.send(1).await.unwrap();
    assert_eq!(metadata.get_queue_depth().1, 1);

    assert_eq!(rx.recv().await, Some(1));
    assert_eq!(metadata.get_queue_depth().1, 0);
}