use tokio::sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use crossbeam_channel::{unbounded, Sender as CbSender};
use prettytable::{Cell, Row, Table};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, OnceLock, RwLock};
use std::time::Instant;
use tiny_http::{Response, Server};
mod wrappers;
use wrappers::{wrap_channel, wrap_oneshot, wrap_unbounded};
/// Kind of channel being instrumented.
///
/// The `Display` impl (also used for serialization) renders the variants as
/// `bounded[N]`, `unbounded` and `oneshot`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelType {
/// Bounded channel; the payload is the number printed inside `bounded[...]`
/// (presumably the channel capacity — TODO confirm against the wrappers).
Bounded(usize),
/// Unbounded channel.
Unbounded,
/// Oneshot channel.
Oneshot,
}
impl std::fmt::Display for ChannelType {
    /// Writes the textual form of the channel type: `bounded[N]`,
    /// `unbounded`, or `oneshot`. Width/padding flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            ChannelType::Bounded(capacity) => write!(f, "bounded[{}]", capacity),
            // The fixed names need no formatting machinery.
            ChannelType::Unbounded => f.write_str("unbounded"),
            ChannelType::Oneshot => f.write_str("oneshot"),
        }
    }
}
impl Serialize for ChannelType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&self.to_string())
}
}
impl<'de> Deserialize<'de> for ChannelType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
match s.as_str() {
"unbounded" => Ok(ChannelType::Unbounded),
"oneshot" => Ok(ChannelType::Oneshot),
_ => {
if let Some(inner) = s.strip_prefix("bounded[").and_then(|x| x.strip_suffix(']')) {
let size = inner
.parse()
.map_err(|_| serde::de::Error::custom("invalid bounded size"))?;
Ok(ChannelType::Bounded(size))
} else {
Err(serde::de::Error::custom("invalid channel type"))
}
}
}
}
}
/// Output format used by `ChannelsGuard` when it prints statistics on drop.
#[derive(Clone, Copy, Debug, Default)]
pub enum Format {
/// Human-readable table (the default).
#[default]
Table,
/// Compact JSON produced with `serde_json::to_string`.
Json,
/// Indented JSON produced with `serde_json::to_string_pretty`.
JsonPretty,
}
/// State tracked for an instrumented channel.
///
/// Rendered (and serialized) as the lowercase strings returned by `as_str`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ChannelState {
/// Default state; also restored by `ChannelStats::update_state` when the
/// received count has caught up with the sent count.
#[default]
Active,
/// Set when a `StatsEvent::Closed` is processed; never overwritten by
/// `update_state`.
Closed,
/// Set by `ChannelStats::update_state` while more messages have been sent
/// than received.
Full,
/// Set when a `StatsEvent::Notified` is processed; never overwritten by
/// `update_state`.
Notified,
}
impl std::fmt::Display for ChannelState {
    /// Writes the lowercase state name returned by `as_str`.
    /// Width/padding flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
impl ChannelState {
pub fn as_str(&self) -> &'static str {
match self {
ChannelState::Active => "active",
ChannelState::Closed => "closed",
ChannelState::Full => "full",
ChannelState::Notified => "notified",
}
}
}
impl Serialize for ChannelState {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(self.as_str())
}
}
impl<'de> Deserialize<'de> for ChannelState {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
match s.as_str() {
"active" => Ok(ChannelState::Active),
"closed" => Ok(ChannelState::Closed),
"full" => Ok(ChannelState::Full),
"notified" => Ok(ChannelState::Notified),
_ => Err(serde::de::Error::custom("invalid channel state")),
}
}
}
/// Counters and metadata tracked for one instrumented channel.
#[derive(Debug, Clone)]
pub(crate) struct ChannelStats {
/// Key under which the stats are stored; the `instrument!` macro passes
/// `file!():line!()` of the call site.
pub(crate) id: &'static str,
/// Optional explicit label; when `None`, `resolve_label` derives one from `id`.
pub(crate) label: Option<&'static str>,
/// Kind of channel (bounded / unbounded / oneshot).
pub(crate) channel_type: ChannelType,
/// Current state, see `ChannelState`.
pub(crate) state: ChannelState,
/// Number of `StatsEvent::MessageSent` events processed for this channel.
pub(crate) sent_count: u64,
/// Number of `StatsEvent::MessageReceived` events processed for this channel.
pub(crate) received_count: u64,
/// Name of the message type (presumably `std::any::type_name` — confirm in wrappers).
pub(crate) type_name: &'static str,
/// Per-message size used for the byte estimates (`count * type_size`);
/// presumably `size_of::<T>()` — confirm in wrappers.
pub(crate) type_size: usize,
}
impl ChannelStats {
    /// Number of messages sent but not yet received.
    ///
    /// Saturates at zero if `received_count` is ahead of `sent_count`.
    pub fn queued(&self) -> u64 {
        self.sent_count.saturating_sub(self.received_count)
    }

    /// Estimated bytes for every message ever sent (`sent_count * type_size`).
    ///
    /// Saturates at `u64::MAX` instead of overflowing, so reporting can never
    /// panic (debug builds) or wrap to a small number (release builds).
    pub fn total_bytes(&self) -> u64 {
        // usize -> u64 is lossless on every target with pointers <= 64 bits.
        self.sent_count.saturating_mul(self.type_size as u64)
    }

    /// Estimated bytes for the messages currently queued
    /// (`queued() * type_size`), saturating at `u64::MAX`.
    pub fn queued_bytes(&self) -> u64 {
        self.queued().saturating_mul(self.type_size as u64)
    }
}
/// Owned, serde-friendly snapshot of a `ChannelStats` entry.
///
/// Built through `From<&ChannelStats>`; derived values (`label`, `queued`,
/// `total_bytes`, `queued_bytes`) are computed at conversion time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableChannelStats {
/// Channel id (copied from `ChannelStats::id`).
pub id: String,
/// Display label as produced by `resolve_label`.
pub label: String,
/// Kind of channel.
pub channel_type: ChannelType,
/// State at the time of the snapshot.
pub state: ChannelState,
/// Messages sent so far.
pub sent_count: u64,
/// Messages received so far.
pub received_count: u64,
/// Value of `ChannelStats::queued()` at snapshot time.
pub queued: u64,
/// Name of the message type.
pub type_name: String,
/// Per-message size used for the byte estimates.
pub type_size: usize,
/// Value of `ChannelStats::total_bytes()` at snapshot time.
pub total_bytes: u64,
/// Value of `ChannelStats::queued_bytes()` at snapshot time.
pub queued_bytes: u64,
}
impl From<&ChannelStats> for SerializableChannelStats {
fn from(stats: &ChannelStats) -> Self {
let label = resolve_label(stats.id, stats.label);
Self {
id: stats.id.to_string(),
label,
channel_type: stats.channel_type,
state: stats.state,
sent_count: stats.sent_count,
received_count: stats.received_count,
queued: stats.queued(),
type_name: stats.type_name.to_string(),
type_size: stats.type_size,
total_bytes: stats.total_bytes(),
queued_bytes: stats.queued_bytes(),
}
}
}
impl ChannelStats {
fn new(
id: &'static str,
label: Option<&'static str>,
channel_type: ChannelType,
type_name: &'static str,
type_size: usize,
) -> Self {
Self {
id,
label,
channel_type,
state: ChannelState::default(),
sent_count: 0,
received_count: 0,
type_name,
type_size,
}
}
fn update_state(&mut self) {
if self.state == ChannelState::Closed || self.state == ChannelState::Notified {
return;
}
if self.sent_count > self.received_count {
self.state = ChannelState::Full;
} else {
self.state = ChannelState::Active;
}
}
}
/// Events consumed by the `channel-stats-collector` thread to keep the
/// shared stats map up to date.
#[derive(Debug)]
pub(crate) enum StatsEvent {
/// Registers a channel: the collector inserts a fresh `ChannelStats`
/// under `id` (replacing any existing entry with the same key).
Created {
id: &'static str,
display_label: Option<&'static str>,
channel_type: ChannelType,
type_name: &'static str,
type_size: usize,
},
/// Increments `sent_count` for `id` and re-derives the state.
MessageSent {
id: &'static str,
},
/// Increments `received_count` for `id` and re-derives the state.
MessageReceived {
id: &'static str,
},
/// Marks the channel `id` as `ChannelState::Closed`.
Closed {
id: &'static str,
},
/// Marks the channel `id` as `ChannelState::Notified`.
Notified {
id: &'static str,
},
}
/// Process-wide instrumentation state: the sending half of the event channel
/// and the shared map of per-channel stats, keyed by channel id.
type StatsState = (
CbSender<StatsEvent>,
Arc<RwLock<HashMap<&'static str, ChannelStats>>>,
);
// Lazily initialised by `init_stats_state`; stays unset until then, in which
// case `get_channel_stats` returns an empty map.
static STATS_STATE: OnceLock<StatsState> = OnceLock::new();
/// Returns the process-wide stats state, creating it on first use.
///
/// The first call:
/// 1. creates an unbounded crossbeam channel for `StatsEvent`s and the shared
///    stats map;
/// 2. spawns the `channel-stats-collector` thread, which applies each event
///    to the map under the write lock;
/// 3. spawns a thread running `start_metrics_server` on `127.0.0.1:<port>`,
///    where `<port>` comes from the `channels_console_METRICS_PORT`
///    environment variable (falling back to 6770 when unset or unparsable).
///
/// # Panics
///
/// Panics if the collector thread cannot be spawned.
fn init_stats_state() -> &'static StatsState {
    STATS_STATE.get_or_init(|| {
        let (event_tx, event_rx) = unbounded::<StatsEvent>();
        let registry = Arc::new(RwLock::new(HashMap::<&'static str, ChannelStats>::new()));
        let collector_view = Arc::clone(&registry);

        std::thread::Builder::new()
            .name("channel-stats-collector".into())
            .spawn(move || {
                // Runs until every sender is gone and `recv` reports an error.
                while let Ok(event) = event_rx.recv() {
                    let mut channels = collector_view.write().unwrap();
                    match event {
                        StatsEvent::Created {
                            id,
                            display_label,
                            channel_type,
                            type_name,
                            type_size,
                        } => {
                            let fresh = ChannelStats::new(
                                id,
                                display_label,
                                channel_type,
                                type_name,
                                type_size,
                            );
                            channels.insert(id, fresh);
                        }
                        // Events for ids that were never registered are ignored.
                        StatsEvent::MessageSent { id } => {
                            let Some(entry) = channels.get_mut(id) else {
                                continue;
                            };
                            entry.sent_count += 1;
                            entry.update_state();
                        }
                        StatsEvent::MessageReceived { id } => {
                            let Some(entry) = channels.get_mut(id) else {
                                continue;
                            };
                            entry.received_count += 1;
                            entry.update_state();
                        }
                        StatsEvent::Closed { id } => {
                            let Some(entry) = channels.get_mut(id) else {
                                continue;
                            };
                            entry.state = ChannelState::Closed;
                        }
                        StatsEvent::Notified { id } => {
                            let Some(entry) = channels.get_mut(id) else {
                                continue;
                            };
                            entry.state = ChannelState::Notified;
                        }
                    }
                }
            })
            .expect("Failed to spawn channel-stats-collector thread");

        let port = match std::env::var("channels_console_METRICS_PORT") {
            Ok(raw) => raw.parse::<u16>().unwrap_or(6770),
            Err(_) => 6770,
        };
        let addr = format!("127.0.0.1:{}", port);
        std::thread::spawn(move || start_metrics_server(&addr));

        (event_tx, registry)
    })
}
/// Picks the label shown for a channel.
///
/// An explicitly `provided` label wins. Otherwise the label is derived from
/// `id`: the text after the last `:` is kept as-is and the part before it is
/// shortened with `extract_filename`; an `id` without `:` is shortened whole.
fn resolve_label(id: &'static str, provided: Option<&'static str>) -> String {
    match (provided, id.rsplit_once(':')) {
        (Some(explicit), _) => explicit.to_owned(),
        (None, Some((path, line))) => format!("{}:{}", extract_filename(path), line),
        (None, None) => extract_filename(id),
    }
}
/// Shortens a source path to its last two components (`parent/file`).
///
/// Both `/` and `\` are treated as separators, because `file!()` yields
/// platform-native paths and would otherwise never be shortened on Windows.
/// The returned text is a suffix of `path`, so the original separator is
/// preserved. Paths with fewer than two separators are returned unchanged.
fn extract_filename(path: &str) -> String {
    // Walk separators from the right; the second one found marks where the
    // parent component begins.
    let mut separators = path
        .rmatch_indices(|c: char| c == '/' || c == '\\')
        .map(|(idx, _)| idx);

    match separators.nth(1) {
        // Both separators are single-byte ASCII, so `idx + 1` is a char boundary.
        Some(idx) => path[idx + 1..].to_string(),
        None => path.to_string(),
    }
}
/// Formats a byte count for display using 1024-based steps and the unit
/// labels `B`, `KB`, `MB`, `GB`, `TB`.
///
/// Values below 1024 are printed as whole bytes (`"512 B"`); larger values
/// are scaled and printed with one decimal (`"1.5 KB"`). `TB` is the largest
/// unit, so bigger values keep growing in `TB`.
pub fn format_bytes(bytes: u64) -> String {
    const UNITS: [&str; 5] = ["B", "KB", "MB", "GB", "TB"];

    // Anything that fits in the first unit is printed as an exact integer.
    if bytes < 1024 {
        return format!("{} {}", bytes, UNITS[0]);
    }

    let mut scaled = bytes as f64;
    let mut unit = 0usize;
    while scaled >= 1024.0 && unit + 1 < UNITS.len() {
        scaled /= 1024.0;
        unit += 1;
    }
    format!("{:.1} {}", scaled, UNITS[unit])
}
/// Implementation detail of the `instrument!` macro: wraps a channel
/// (a sender/receiver pair) so that its activity is recorded.
#[doc(hidden)]
pub trait Instrument {
/// Type handed back to the caller after wrapping.
type Output;
/// Wraps `self`, registering it under `channel_id` with an optional
/// display `label`.
fn instrument(self, channel_id: &'static str, label: Option<&'static str>) -> Self::Output;
}
/// Instruments a tokio bounded mpsc channel pair by delegating to
/// `wrap_channel`; the caller gets back a pair of the same type.
impl<T> Instrument for (Sender<T>, Receiver<T>)
where
    T: Send + 'static,
{
    type Output = Self;

    fn instrument(self, channel_id: &'static str, label: Option<&'static str>) -> Self::Output {
        wrap_channel(self, channel_id, label)
    }
}
/// Instruments a tokio unbounded mpsc channel pair by delegating to
/// `wrap_unbounded`; the caller gets back a pair of the same type.
impl<T> Instrument for (UnboundedSender<T>, UnboundedReceiver<T>)
where
    T: Send + 'static,
{
    type Output = Self;

    fn instrument(self, channel_id: &'static str, label: Option<&'static str>) -> Self::Output {
        wrap_unbounded(self, channel_id, label)
    }
}
/// Instruments a tokio oneshot channel pair by delegating to
/// `wrap_oneshot`; the caller gets back a pair of the same type.
impl<T> Instrument for (oneshot::Sender<T>, oneshot::Receiver<T>)
where
    T: Send + 'static,
{
    type Output = Self;

    fn instrument(self, channel_id: &'static str, label: Option<&'static str>) -> Self::Output {
        wrap_oneshot(self, channel_id, label)
    }
}
/// Wraps a channel expression with instrumentation.
///
/// The channel id is built at compile time as `file!():line!()`. An optional
/// `label = "..."` literal is forwarded as the display label; without it the
/// label is `None`.
///
/// Accepts any expression whose type implements `Instrument` — in this file
/// that is the tokio bounded, unbounded and oneshot sender/receiver pairs.
#[macro_export]
macro_rules! instrument {
// Form 1: `instrument!(expr)` — no explicit label.
($expr:expr) => {{
const CHANNEL_ID: &'static str = concat!(file!(), ":", line!());
$crate::Instrument::instrument($expr, CHANNEL_ID, None)
}};
// Form 2: `instrument!(expr, label = "name")` — explicit display label.
($expr:expr, label = $label:literal) => {{
const CHANNEL_ID: &'static str = concat!(file!(), ":", line!());
$crate::Instrument::instrument($expr, CHANNEL_ID, Some($label))
}};
}
/// Returns a cloned snapshot of the stats map, or an empty map when the
/// global state has not been initialised yet.
///
/// # Panics
///
/// Panics if the stats lock is poisoned.
fn get_channel_stats() -> HashMap<&'static str, ChannelStats> {
    STATS_STATE
        .get()
        .map(|state| state.1.read().unwrap().clone())
        .unwrap_or_default()
}
/// Takes a snapshot of all channel stats, converts each entry into its
/// serializable form and returns them sorted by channel id.
fn get_serializable_stats() -> Vec<SerializableChannelStats> {
    let snapshot = get_channel_stats();

    let mut rows = Vec::with_capacity(snapshot.len());
    for channel in snapshot.values() {
        rows.push(SerializableChannelStats::from(channel));
    }

    rows.sort_by(|lhs, rhs| lhs.id.cmp(&rhs.id));
    rows
}
fn start_metrics_server(addr: &str) {
let server = match Server::http(addr) {
Ok(s) => s,
Err(e) => {
panic!("Failed to bind metrics server to {}: {}. Customize the port using the channels_console_METRICS_PORT environment variable.", addr, e);
}
};
println!("Channel metrics server listening on http://{}", addr);
for request in server.incoming_requests() {
if request.url() == "/metrics" {
let stats = get_serializable_stats();
match serde_json::to_string(&stats) {
Ok(json) => {
let response = Response::from_string(json).with_header(
tiny_http::Header::from_bytes(
&b"Content-Type"[..],
&b"application/json"[..],
)
.unwrap(),
);
let _ = request.respond(response);
}
Err(e) => {
eprintln!("Failed to serialize metrics: {}", e);
let response = Response::from_string(format!("Internal server error: {}", e))
.with_status_code(500);
let _ = request.respond(response);
}
}
} else {
let response = Response::from_string("Not found").with_status_code(404);
let _ = request.respond(response);
}
}
}
/// Builder for [`ChannelsGuard`] that lets the output format be chosen
/// before the guard (and its runtime clock) is created.
pub struct ChannelsGuardBuilder {
    format: Format,
}

impl ChannelsGuardBuilder {
    /// Creates a builder with the default output format.
    pub fn new() -> Self {
        ChannelsGuardBuilder {
            format: Format::default(),
        }
    }

    /// Sets the output format used when the guard reports.
    pub fn format(self, format: Format) -> Self {
        let mut builder = self;
        builder.format = format;
        builder
    }

    /// Builds the guard; its start time is taken at this moment.
    pub fn build(self) -> ChannelsGuard {
        let Self { format } = self;
        ChannelsGuard {
            start_time: Instant::now(),
            format,
        }
    }
}

impl Default for ChannelsGuardBuilder {
    /// Same as [`ChannelsGuardBuilder::new`].
    fn default() -> Self {
        ChannelsGuardBuilder::new()
    }
}
/// Guard that remembers when it was created and in which format the channel
/// statistics should be reported.
pub struct ChannelsGuard {
    start_time: Instant,
    format: Format,
}

impl ChannelsGuard {
    /// Creates a guard with the default format; the runtime clock starts now.
    pub fn new() -> Self {
        let format = Format::default();
        let start_time = Instant::now();
        ChannelsGuard { start_time, format }
    }

    /// Sets the output format and returns the same guard.
    pub fn format(self, format: Format) -> Self {
        // Move the guard rather than rebuilding it, so the original value is
        // never dropped along the way.
        let mut guard = self;
        guard.format = format;
        guard
    }
}

impl Default for ChannelsGuard {
    /// Same as [`ChannelsGuard::new`].
    fn default() -> Self {
        ChannelsGuard::new()
    }
}
impl Drop for ChannelsGuard {
fn drop(&mut self) {
let elapsed = self.start_time.elapsed();
let stats = get_channel_stats();
if stats.is_empty() {
println!("\nNo instrumented channels found.");
return;
}
match self.format {
Format::Table => {
let mut table = Table::new();
table.add_row(Row::new(vec![
Cell::new("Channel"),
Cell::new("Type"),
Cell::new("State"),
Cell::new("Sent"),
Cell::new("Mem"),
Cell::new("Received"),
Cell::new("Queued"),
Cell::new("Mem"),
]));
let mut sorted_stats: Vec<_> = stats.into_iter().collect();
sorted_stats.sort_by(|a, b| {
let la = resolve_label(a.1.id, a.1.label);
let lb = resolve_label(b.1.id, b.1.label);
la.cmp(&lb)
});
for (_key, channel_stats) in sorted_stats {
let label = resolve_label(channel_stats.id, channel_stats.label);
table.add_row(Row::new(vec![
Cell::new(&label),
Cell::new(&channel_stats.channel_type.to_string()),
Cell::new(channel_stats.state.as_str()),
Cell::new(&channel_stats.sent_count.to_string()),
Cell::new(&format_bytes(channel_stats.total_bytes())),
Cell::new(&channel_stats.received_count.to_string()),
Cell::new(&channel_stats.queued().to_string()),
Cell::new(&format_bytes(channel_stats.queued_bytes())),
]));
}
println!(
"\n=== Channel Statistics (runtime: {:.2}s) ===",
elapsed.as_secs_f64()
);
table.printstd();
}
Format::Json => {
let serializable_stats = get_serializable_stats();
match serde_json::to_string(&serializable_stats) {
Ok(json) => println!("{}", json),
Err(e) => eprintln!("Failed to serialize statistics to JSON: {}", e),
}
}
Format::JsonPretty => {
let serializable_stats = get_serializable_stats();
match serde_json::to_string_pretty(&serializable_stats) {
Ok(json) => println!("{}", json),
Err(e) => eprintln!("Failed to serialize statistics to pretty JSON: {}", e),
}
}
}
}
}