channels_console/
lib.rs

1use tokio::sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender};
2use tokio::sync::oneshot;
3
4use crossbeam_channel::{unbounded, Sender as CbSender};
5use prettytable::{Cell, Row, Table};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::{Arc, OnceLock, RwLock};
9use std::time::Instant;
10use tiny_http::{Response, Server};
11
12mod wrappers;
13use wrappers::{wrap_channel, wrap_oneshot, wrap_unbounded};
14
/// Kind of an instrumented channel, as shown in reports.
///
/// Bounded channels carry their capacity so it can be displayed next to the
/// live counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelType {
    /// Bounded channel with the given capacity.
    Bounded(usize),
    /// Channel without a capacity limit.
    Unbounded,
    /// Single-message channel.
    Oneshot,
}

impl std::fmt::Display for ChannelType {
    /// Renders the canonical text form: `bounded[N]`, `unbounded` or `oneshot`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ChannelType::Bounded(capacity) => write!(f, "bounded[{}]", capacity),
            ChannelType::Unbounded => f.write_str("unbounded"),
            ChannelType::Oneshot => f.write_str("oneshot"),
        }
    }
}
32
33impl Serialize for ChannelType {
34    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
35    where
36        S: serde::Serializer,
37    {
38        serializer.serialize_str(&self.to_string())
39    }
40}
41
42impl<'de> Deserialize<'de> for ChannelType {
43    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
44    where
45        D: serde::Deserializer<'de>,
46    {
47        let s = String::deserialize(deserializer)?;
48
49        match s.as_str() {
50            "unbounded" => Ok(ChannelType::Unbounded),
51            "oneshot" => Ok(ChannelType::Oneshot),
52            _ => {
53                // try: bounded[123]
54                if let Some(inner) = s.strip_prefix("bounded[").and_then(|x| x.strip_suffix(']')) {
55                    let size = inner
56                        .parse()
57                        .map_err(|_| serde::de::Error::custom("invalid bounded size"))?;
58                    Ok(ChannelType::Bounded(size))
59                } else {
60                    Err(serde::de::Error::custom("invalid channel type"))
61                }
62            }
63        }
64    }
65}
66
/// Format of the output produced by `ChannelsGuard` on drop.
///
/// Derives `PartialEq`/`Eq`/`Hash` like the other public enums in this crate,
/// so callers can compare or key on a format.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum Format {
    /// Human-readable table on stdout (the default).
    #[default]
    Table,
    /// Compact JSON array (`serde_json::to_string`).
    Json,
    /// Indented JSON array (`serde_json::to_string_pretty`).
    JsonPretty,
}
75
/// State of an instrumented channel, as tracked by the statistics collector.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ChannelState {
    /// Initial state; also used whenever no messages are queued.
    #[default]
    Active,
    /// A `Closed` event was recorded for the channel.
    Closed,
    /// At least one sent message has not been received yet.
    Full,
    /// A `Notified` event was recorded for the channel (raised from the
    /// `wrappers` module — NOTE(review): confirm exact trigger there).
    Notified,
}

impl ChannelState {
    /// Lower-case name of the state; this exact text is what `Display`
    /// prints and what the serde representation uses.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Active => "active",
            Self::Closed => "closed",
            Self::Full => "full",
            Self::Notified => "notified",
        }
    }
}

impl std::fmt::Display for ChannelState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
102
103impl Serialize for ChannelState {
104    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
105    where
106        S: serde::Serializer,
107    {
108        serializer.serialize_str(self.as_str())
109    }
110}
111
112impl<'de> Deserialize<'de> for ChannelState {
113    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
114    where
115        D: serde::Deserializer<'de>,
116    {
117        let s = String::deserialize(deserializer)?;
118        match s.as_str() {
119            "active" => Ok(ChannelState::Active),
120            "closed" => Ok(ChannelState::Closed),
121            "full" => Ok(ChannelState::Full),
122            "notified" => Ok(ChannelState::Notified),
123            _ => Err(serde::de::Error::custom("invalid channel state")),
124        }
125    }
126}
127
/// Statistics for a single instrumented channel.
///
/// Entries live in the global stats map (see `STATS_STATE`). In this file they
/// are mutated only by the stats-collector thread while it holds the write
/// lock; readers take cloned snapshots.
#[derive(Debug, Clone)]
pub(crate) struct ChannelStats {
    /// ID of the channel, also its key in the stats map. When created through
    /// `instrument!` this is `file!():line!()` of the call site.
    pub(crate) id: &'static str,
    /// Optional user label; if None, display derives from `id`.
    pub(crate) label: Option<&'static str>,
    /// Type of channel.
    pub(crate) channel_type: ChannelType,
    /// Current state of the channel.
    pub(crate) state: ChannelState,
    /// Number of messages sent through this channel.
    pub(crate) sent_count: u64,
    /// Number of messages received from this channel.
    pub(crate) received_count: u64,
    /// Type name of messages in this channel.
    pub(crate) type_name: &'static str,
    /// Size in bytes of the message type (presumably `size_of::<T>()`, i.e.
    /// heap contents are not counted — NOTE(review): confirm in `wrappers`).
    pub(crate) type_size: usize,
}
148
149impl ChannelStats {
150    pub fn queued(&self) -> u64 {
151        self.sent_count.saturating_sub(self.received_count)
152    }
153
154    /// Calculate total bytes sent through this channel.
155    pub fn total_bytes(&self) -> u64 {
156        self.sent_count * self.type_size as u64
157    }
158
159    /// Calculate bytes currently queued in this channel.
160    pub fn queued_bytes(&self) -> u64 {
161        self.queued() * self.type_size as u64
162    }
163}
164
/// Serializable version of channel statistics for JSON responses.
///
/// Built from the internal stats via `From<&ChannelStats>`; this is what the
/// `/metrics` endpoint returns and what `ChannelsGuard` prints in the JSON formats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableChannelStats {
    /// ID of the channel (normally `file:line` of the `instrument!` call site).
    pub id: String,
    /// Display label: the user-provided label if one was given, otherwise a
    /// label derived from `id` (last two path components plus the line number).
    pub label: String,
    /// Type of channel (includes capacity for bounded channels).
    pub channel_type: ChannelType,
    /// Current state of the channel.
    pub state: ChannelState,
    /// Number of messages sent through this channel.
    pub sent_count: u64,
    /// Number of messages received from this channel.
    pub received_count: u64,
    /// Current queue size (sent - received, saturating at 0).
    pub queued: u64,
    /// Type name of messages in this channel.
    pub type_name: String,
    /// Size in bytes of the message type.
    pub type_size: usize,
    /// Total bytes sent through this channel.
    pub total_bytes: u64,
    /// Bytes currently queued in this channel.
    pub queued_bytes: u64,
}
191
192impl From<&ChannelStats> for SerializableChannelStats {
193    fn from(stats: &ChannelStats) -> Self {
194        let label = resolve_label(stats.id, stats.label);
195        Self {
196            id: stats.id.to_string(),
197            label,
198            channel_type: stats.channel_type,
199            state: stats.state,
200            sent_count: stats.sent_count,
201            received_count: stats.received_count,
202            queued: stats.queued(),
203            type_name: stats.type_name.to_string(),
204            type_size: stats.type_size,
205            total_bytes: stats.total_bytes(),
206            queued_bytes: stats.queued_bytes(),
207        }
208    }
209}
210
211impl ChannelStats {
212    fn new(
213        id: &'static str,
214        label: Option<&'static str>,
215        channel_type: ChannelType,
216        type_name: &'static str,
217        type_size: usize,
218    ) -> Self {
219        Self {
220            id,
221            label,
222            channel_type,
223            state: ChannelState::default(),
224            sent_count: 0,
225            received_count: 0,
226            type_name,
227            type_size,
228        }
229    }
230
231    /// Update the channel state based on sent/received counts.
232    /// Sets state to Full if sent > received, otherwise Active (unless explicitly closed).
233    fn update_state(&mut self) {
234        if self.state == ChannelState::Closed || self.state == ChannelState::Notified {
235            return;
236        }
237
238        if self.sent_count > self.received_count {
239            self.state = ChannelState::Full;
240        } else {
241            self.state = ChannelState::Active;
242        }
243    }
244}
245
/// Events sent to the background statistics collection thread.
///
/// Each event names the channel by its `id`. The collector ignores every event
/// except `Created` for ids it has no entry for.
#[derive(Debug)]
pub(crate) enum StatsEvent {
    /// A channel was instrumented. The collector inserts a fresh entry,
    /// replacing (and thus resetting) any existing entry with the same `id`.
    Created {
        id: &'static str,
        /// User label, if any (presumably the `label = "..."` given to
        /// `instrument!` — NOTE(review): confirm in `wrappers`).
        display_label: Option<&'static str>,
        channel_type: ChannelType,
        type_name: &'static str,
        type_size: usize,
    },
    /// One message was sent; increments `sent_count` and refreshes the state.
    MessageSent {
        id: &'static str,
    },
    /// One message was received; increments `received_count` and refreshes the state.
    MessageReceived {
        id: &'static str,
    },
    /// Marks the channel `Closed`.
    Closed {
        id: &'static str,
    },
    /// Marks the channel `Notified`.
    Notified {
        id: &'static str,
    },
}
269
/// Global collector state: the sending half of the event channel drained by
/// the `channel-stats-collector` thread, and the shared per-channel stats map
/// (keyed by channel id) that the thread updates.
type StatsState = (
    CbSender<StatsEvent>,
    Arc<RwLock<HashMap<&'static str, ChannelStats>>>,
);

/// Global state for statistics collection; populated on first use by
/// `init_stats_state` and never reset afterwards.
static STATS_STATE: OnceLock<StatsState> = OnceLock::new();
277
278/// Initialize the statistics collection system (called on first instrumented channel).
279/// Returns a reference to the global state.
280fn init_stats_state() -> &'static StatsState {
281    STATS_STATE.get_or_init(|| {
282        let (tx, rx) = unbounded::<StatsEvent>();
283        let stats_map = Arc::new(RwLock::new(HashMap::<&'static str, ChannelStats>::new()));
284        let stats_map_clone = Arc::clone(&stats_map);
285
286        std::thread::Builder::new()
287            .name("channel-stats-collector".into())
288            .spawn(move || {
289                while let Ok(event) = rx.recv() {
290                    let mut stats = stats_map_clone.write().unwrap();
291                    match event {
292                        StatsEvent::Created {
293                            id: key,
294                            display_label,
295                            channel_type,
296                            type_name,
297                            type_size,
298                        } => {
299                            stats.insert(
300                                key,
301                                ChannelStats::new(
302                                    key,
303                                    display_label,
304                                    channel_type,
305                                    type_name,
306                                    type_size,
307                                ),
308                            );
309                        }
310                        StatsEvent::MessageSent { id } => {
311                            if let Some(channel_stats) = stats.get_mut(id) {
312                                channel_stats.sent_count += 1;
313                                channel_stats.update_state();
314                            }
315                        }
316                        StatsEvent::MessageReceived { id } => {
317                            if let Some(channel_stats) = stats.get_mut(id) {
318                                channel_stats.received_count += 1;
319                                channel_stats.update_state();
320                            }
321                        }
322                        StatsEvent::Closed { id } => {
323                            if let Some(channel_stats) = stats.get_mut(id) {
324                                channel_stats.state = ChannelState::Closed;
325                            }
326                        }
327                        StatsEvent::Notified { id } => {
328                            if let Some(channel_stats) = stats.get_mut(id) {
329                                channel_stats.state = ChannelState::Notified;
330                            }
331                        }
332                    }
333                }
334            })
335            .expect("Failed to spawn channel-stats-collector thread");
336
337        // Spawn the metrics HTTP server in the background
338        // Check environment variable for custom port, default to 6770
339        let port = std::env::var("channels_console_METRICS_PORT")
340            .ok()
341            .and_then(|p| p.parse::<u16>().ok())
342            .unwrap_or(6770);
343        let addr = format!("127.0.0.1:{}", port);
344
345        std::thread::spawn(move || {
346            start_metrics_server(&addr);
347        });
348
349        (tx, stats_map)
350    })
351}
352
/// Returns the human-readable label for a channel.
///
/// A user-provided label always wins. Otherwise the label is derived from the
/// channel `id` (normally `file!():line!()` of the `instrument!` call site):
/// the path is shortened to its last two components and the line is kept.
/// For example, `src/workers/pool.rs:42` becomes `workers/pool.rs:42`.
fn resolve_label(id: &'static str, provided: Option<&'static str>) -> String {
    if let Some(label) = provided {
        return label.to_string();
    }
    // Split on the *last* ':' so a Windows drive prefix such as `C:` stays in the path.
    match id.rsplit_once(':') {
        Some((path, line)) => format!("{}:{}", extract_filename(path), line),
        None => extract_filename(id),
    }
}

/// Shortens `path` to its last two components (`parent/file`), keeping the
/// original separator.
///
/// Both `/` and `\` are treated as separators, because `file!()` can contain
/// backslashes on Windows. Paths with fewer than two components are returned
/// unchanged.
fn extract_filename(path: &str) -> String {
    const SEPARATORS: &[char] = &['/', '\\'];
    // Start right after the second-to-last separator, or at 0 if there isn't one.
    let tail_start = path
        .rfind(SEPARATORS)
        .and_then(|last| path[..last].rfind(SEPARATORS))
        .map_or(0, |prev| prev + 1);
    path[tail_start..].to_string()
}
378
/// Format bytes into human-readable units (B, KB, MB, GB, TB).
///
/// Units step by 1024. Values below 1 KB are printed as an exact integer
/// (`"512 B"`); larger values get one decimal place (`"1.5 KB"`). Anything
/// beyond the TB range stays expressed in TB.
pub fn format_bytes(bytes: u64) -> String {
    const UNITS: [&str; 5] = ["B", "KB", "MB", "GB", "TB"];

    if bytes < 1024 {
        return format!("{} {}", bytes, UNITS[0]);
    }

    let mut scaled = bytes as f64;
    let mut unit = UNITS[0];
    for &candidate in &UNITS[1..] {
        if scaled < 1024.0 {
            break;
        }
        scaled /= 1024.0;
        unit = candidate;
    }
    format!("{:.1} {}", scaled, unit)
}
400
401/// Trait for instrumenting channels.
402///
403/// This trait is not intended for direct use. Use the `instrument!` macro instead.
404#[doc(hidden)]
405pub trait Instrument {
406    type Output;
407    fn instrument(self, channel_id: &'static str, label: Option<&'static str>) -> Self::Output;
408}
409
410impl<T: Send + 'static> Instrument for (Sender<T>, Receiver<T>) {
411    type Output = (Sender<T>, Receiver<T>);
412    fn instrument(self, channel_id: &'static str, label: Option<&'static str>) -> Self::Output {
413        wrap_channel(self, channel_id, label)
414    }
415}
416
417impl<T: Send + 'static> Instrument for (UnboundedSender<T>, UnboundedReceiver<T>) {
418    type Output = (UnboundedSender<T>, UnboundedReceiver<T>);
419    fn instrument(self, channel_id: &'static str, label: Option<&'static str>) -> Self::Output {
420        wrap_unbounded(self, channel_id, label)
421    }
422}
423
424impl<T: Send + 'static> Instrument for (oneshot::Sender<T>, oneshot::Receiver<T>) {
425    type Output = (oneshot::Sender<T>, oneshot::Receiver<T>);
426    fn instrument(self, channel_id: &'static str, label: Option<&'static str>) -> Self::Output {
427        wrap_oneshot(self, channel_id, label)
428    }
429}
430
/// Instrument a channel creation to wrap it with debugging proxies.
/// Currently only supports bounded, unbounded and oneshot channels.
///
/// The channel id is the call site (`file!():line!()`), so two invocations on
/// the same source line share an id. Both forms accept a trailing comma.
/// `label` must be a string literal.
///
/// # Examples
///
/// ```
/// use tokio::sync::mpsc;
/// use channels_console::instrument;
///
/// #[tokio::main]
/// async fn main() {
///
///    // Create channels normally
///    let (tx, rx) = mpsc::channel::<String>(100);
///
///    // Instrument them only when the feature is enabled
///    #[cfg(feature = "channels-console")]
///    let (tx, rx) = channels_console::instrument!((tx, rx));
///
///    // The channel works exactly the same way
///    tx.send("Hello".to_string()).await.unwrap();
/// }
/// ```
///
/// By default, channels are labeled with their file location and line number (e.g., `src/worker.rs:25`). You can provide custom labels for easier identification:
///
/// ```rust,no_run
/// use tokio::sync::mpsc;
/// use channels_console::instrument;
/// let (tx, rx) = mpsc::channel::<String>(10);
/// #[cfg(feature = "channels-console")]
/// let (tx, rx) = channels_console::instrument!((tx, rx), label = "task-queue");
/// ```
///
#[macro_export]
macro_rules! instrument {
    // The labelled form comes first so that `expr,` (trailing comma only)
    // falls through cleanly to the unlabelled arm below.
    ($expr:expr, label = $label:literal $(,)?) => {{
        const CHANNEL_ID: &str = concat!(file!(), ":", line!());
        $crate::Instrument::instrument($expr, CHANNEL_ID, Some($label))
    }};

    ($expr:expr $(,)?) => {{
        const CHANNEL_ID: &str = concat!(file!(), ":", line!());
        $crate::Instrument::instrument($expr, CHANNEL_ID, None)
    }};
}
477
478fn get_channel_stats() -> HashMap<&'static str, ChannelStats> {
479    if let Some((_, stats_map)) = STATS_STATE.get() {
480        stats_map.read().unwrap().clone()
481    } else {
482        HashMap::new()
483    }
484}
485
486fn get_serializable_stats() -> Vec<SerializableChannelStats> {
487    let mut stats: Vec<SerializableChannelStats> = get_channel_stats()
488        .values()
489        .map(SerializableChannelStats::from)
490        .collect();
491
492    stats.sort_by(|a, b| a.id.cmp(&b.id));
493    stats
494}
495
496fn start_metrics_server(addr: &str) {
497    let server = match Server::http(addr) {
498        Ok(s) => s,
499        Err(e) => {
500            panic!("Failed to bind metrics server to {}: {}. Customize the port using the channels_console_METRICS_PORT environment variable.", addr, e);
501        }
502    };
503
504    println!("Channel metrics server listening on http://{}", addr);
505
506    for request in server.incoming_requests() {
507        if request.url() == "/metrics" {
508            let stats = get_serializable_stats();
509            match serde_json::to_string(&stats) {
510                Ok(json) => {
511                    let response = Response::from_string(json).with_header(
512                        tiny_http::Header::from_bytes(
513                            &b"Content-Type"[..],
514                            &b"application/json"[..],
515                        )
516                        .unwrap(),
517                    );
518                    let _ = request.respond(response);
519                }
520                Err(e) => {
521                    eprintln!("Failed to serialize metrics: {}", e);
522                    let response = Response::from_string(format!("Internal server error: {}", e))
523                        .with_status_code(500);
524                    let _ = request.respond(response);
525                }
526            }
527        } else {
528            let response = Response::from_string("Not found").with_status_code(404);
529            let _ = request.respond(response);
530        }
531    }
532}
533
534/// Builder for creating a ChannelsGuard with custom configuration.
535///
536/// # Examples
537///
538/// ```no_run
539/// use channels_console::{ChannelsGuardBuilder, Format};
540///
541/// let _guard = ChannelsGuardBuilder::new()
542///     .format(Format::JsonPretty)
543///     .build();
544/// // Statistics will be printed as pretty JSON when _guard is dropped
545/// ```
546pub struct ChannelsGuardBuilder {
547    format: Format,
548}
549
550impl ChannelsGuardBuilder {
551    /// Create a new channels guard builder.
552    pub fn new() -> Self {
553        Self {
554            format: Format::default(),
555        }
556    }
557
558    /// Set the output format for statistics.
559    ///
560    /// # Examples
561    ///
562    /// ```no_run
563    /// use channels_console::{ChannelsGuardBuilder, Format};
564    ///
565    /// let _guard = ChannelsGuardBuilder::new()
566    ///     .format(Format::Json)
567    ///     .build();
568    /// ```
569    pub fn format(mut self, format: Format) -> Self {
570        self.format = format;
571        self
572    }
573
574    /// Build and return the ChannelsGuard.
575    /// Statistics will be printed when the guard is dropped.
576    pub fn build(self) -> ChannelsGuard {
577        ChannelsGuard {
578            start_time: Instant::now(),
579            format: self.format,
580        }
581    }
582}
583
584impl Default for ChannelsGuardBuilder {
585    fn default() -> Self {
586        Self::new()
587    }
588}
589
590/// Guard for channel statistics collection.
591/// When dropped, prints a summary of all instrumented channels and their statistics.
592///
593/// Use `ChannelsGuardBuilder` to create a guard with custom configuration.
594///
595/// # Examples
596///
597/// ```no_run
598/// use channels_console::ChannelsGuard;
599///
600/// let _guard = ChannelsGuard::new();
601/// // Your code with instrumented channels here
602/// // Statistics will be printed when _guard is dropped
603/// ```
604pub struct ChannelsGuard {
605    start_time: Instant,
606    format: Format,
607}
608
609impl ChannelsGuard {
610    /// Create a new channels guard with default settings (table format).
611    /// Statistics will be printed when this guard is dropped.
612    ///
613    /// For custom configuration, use `ChannelsGuardBuilder::new()` instead.
614    pub fn new() -> Self {
615        Self {
616            start_time: Instant::now(),
617            format: Format::default(),
618        }
619    }
620
621    /// Set the output format for statistics.
622    /// This is a convenience method for backward compatibility.
623    ///
624    /// # Examples
625    ///
626    /// ```no_run
627    /// use channels_console::{ChannelsGuard, Format};
628    ///
629    /// let _guard = ChannelsGuard::new().format(Format::Json);
630    /// ```
631    pub fn format(mut self, format: Format) -> Self {
632        self.format = format;
633        self
634    }
635}
636
637impl Default for ChannelsGuard {
638    fn default() -> Self {
639        Self::new()
640    }
641}
642
643impl Drop for ChannelsGuard {
644    fn drop(&mut self) {
645        let elapsed = self.start_time.elapsed();
646        let stats = get_channel_stats();
647
648        if stats.is_empty() {
649            println!("\nNo instrumented channels found.");
650            return;
651        }
652
653        match self.format {
654            Format::Table => {
655                let mut table = Table::new();
656
657                table.add_row(Row::new(vec![
658                    Cell::new("Channel"),
659                    Cell::new("Type"),
660                    Cell::new("State"),
661                    Cell::new("Sent"),
662                    Cell::new("Mem"),
663                    Cell::new("Received"),
664                    Cell::new("Queued"),
665                    Cell::new("Mem"),
666                ]));
667
668                let mut sorted_stats: Vec<_> = stats.into_iter().collect();
669                sorted_stats.sort_by(|a, b| {
670                    let la = resolve_label(a.1.id, a.1.label);
671                    let lb = resolve_label(b.1.id, b.1.label);
672                    la.cmp(&lb)
673                });
674
675                for (_key, channel_stats) in sorted_stats {
676                    let label = resolve_label(channel_stats.id, channel_stats.label);
677                    table.add_row(Row::new(vec![
678                        Cell::new(&label),
679                        Cell::new(&channel_stats.channel_type.to_string()),
680                        Cell::new(channel_stats.state.as_str()),
681                        Cell::new(&channel_stats.sent_count.to_string()),
682                        Cell::new(&format_bytes(channel_stats.total_bytes())),
683                        Cell::new(&channel_stats.received_count.to_string()),
684                        Cell::new(&channel_stats.queued().to_string()),
685                        Cell::new(&format_bytes(channel_stats.queued_bytes())),
686                    ]));
687                }
688
689                println!(
690                    "\n=== Channel Statistics (runtime: {:.2}s) ===",
691                    elapsed.as_secs_f64()
692                );
693                table.printstd();
694            }
695            Format::Json => {
696                let serializable_stats = get_serializable_stats();
697                match serde_json::to_string(&serializable_stats) {
698                    Ok(json) => println!("{}", json),
699                    Err(e) => eprintln!("Failed to serialize statistics to JSON: {}", e),
700                }
701            }
702            Format::JsonPretty => {
703                let serializable_stats = get_serializable_stats();
704                match serde_json::to_string_pretty(&serializable_stats) {
705                    Ok(json) => println!("{}", json),
706                    Err(e) => eprintln!("Failed to serialize statistics to pretty JSON: {}", e),
707                }
708            }
709        }
710    }
711}