defmt_logger_tcp/
lib.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
//! # A defmt logger that sends logs over TCP.
//!
//! ## Usage
//!
//! ```rust
//! use defmt::info;
//! use std::thread;
//!
//! thread::spawn(|| {
//!     defmt_logger_tcp::init().unwrap();
//! });
//!
//! info!("Hello, world!");
//! ```

use defmt::Encoder;

#[cfg(feature = "std")]
use std::{
    io::{self, Write},
    net::{SocketAddr, TcpListener, TcpStream},
    sync::{
        atomic::{AtomicBool, Ordering},
        Mutex,
    },
    time::Duration,
};

/// Set by `Logger::acquire` and cleared by `Logger::release`; `acquire` panics when it finds
/// the flag already set.
static TAKEN: AtomicBool = AtomicBool::new(false);
/// Connections accepted by `init`, each paired with a fresh `Encoder`. They wait here until
/// `Logger::release` moves them into `STREAMS`.
static PENDING_STREAMS: Mutex<Vec<(TcpStream, Encoder)>> = Mutex::new(Vec::new());
/// Connections that `on_all_streams` runs every logger operation on, each with its own `Encoder`.
static STREAMS: Mutex<Vec<(TcpStream, Encoder)>> = Mutex::new(Vec::new());

/// Start the logger's TCP server on `localhost:19021` and accept connections.
///
/// Every accepted connection is given a 100 ms write timeout and queued in `PENDING_STREAMS`
/// together with a fresh `Encoder`.
///
/// This function keeps accepting connections and only returns when an error occurs, so it is
/// meant to be run on its own thread.
///
/// # Errors
///
/// Returns the underlying I/O error if binding the listener, accepting a connection, or
/// setting the write timeout fails.
///
/// # Panics
///
/// Panics if the `PENDING_STREAMS` mutex is poisoned.
pub fn init() -> io::Result<()> {
    // Don't block excessively on writes.
    const WRITE_TIMEOUT: Duration = Duration::from_millis(100);

    TcpListener::bind("localhost:19021")?
        .incoming()
        .try_for_each(|connection| -> io::Result<()> {
            let connection = connection?;
            connection.set_write_timeout(Some(WRITE_TIMEOUT))?;

            PENDING_STREAMS
                .lock()
                .unwrap()
                .push((connection, Encoder::new()));
            Ok(())
        })
}

/// Unit type registered as the global defmt logger; its `defmt::Logger` implementation
/// forwards encoded log data to the connected TCP streams.
#[defmt::global_logger]
struct Logger;

// NOTE(review): exclusivity between `acquire` and `release` rests solely on the `TAKEN` flag.
// A second `acquire` (reentrant or from another thread) panics instead of waiting.
unsafe impl defmt::Logger for Logger {
    /// Take the logger and start a new frame on every active stream.
    ///
    /// # Panics
    ///
    /// Panics if the logger is already taken.
    fn acquire() {
        // `swap` performs the check and the set as one atomic operation, so two threads can
        // never both observe `false` and both proceed (a separate load + store could).
        if TAKEN.swap(true, Ordering::Acquire) {
            panic!("defmt logger taken reentrantly");
        }

        on_all_streams(|stream, encoder| {
            let mut result: io::Result<()> = Ok(());
            encoder.start_frame(|bytes| write_stream(stream, bytes, &mut result));
            result
        });
    }

    /// End the current frame on every active stream, promote pending connections to active
    /// ones, and give the logger back.
    ///
    /// # Safety
    ///
    /// Must only be called after a matching `acquire`.
    unsafe fn release() {
        on_all_streams(|stream, encoder| {
            let mut result: io::Result<()> = Ok(());
            encoder.end_frame(|bytes| write_stream(stream, bytes, &mut result));
            result
        });

        // Move pending streams to active streams. Doing this only between frames means a new
        // connection never receives the tail of a frame it did not see the start of.
        STREAMS
            .lock()
            .unwrap()
            .extend(PENDING_STREAMS.lock().unwrap().drain(..));

        // Pairs with the `Acquire` swap in `acquire`.
        TAKEN.store(false, Ordering::Release);
    }

    /// Encode `bytes` into the current frame and send the result to every active stream.
    ///
    /// # Safety
    ///
    /// Must only be called between `acquire` and `release`.
    unsafe fn write(bytes: &[u8]) {
        on_all_streams(|stream, encoder| {
            let mut result: io::Result<()> = Ok(());
            encoder.write(bytes, |encoded| write_stream(stream, encoded, &mut result));
            result
        });
    }

    /// Flush every active stream.
    ///
    /// # Safety
    ///
    /// Must only be called between `acquire` and `release`.
    unsafe fn flush() {
        on_all_streams(|stream, _| stream.flush());
    }
}

fn on_all_streams(op: impl Fn(&mut TcpStream, &mut Encoder) -> io::Result<()>) {
    let mut streams = STREAMS.lock().unwrap();

    let mut streams_to_drop: Vec<SocketAddr> = Vec::new();
    for (stream, encoder) in streams.iter_mut() {
        if op(stream, encoder).is_err() {
            streams_to_drop.push(stream.peer_addr().unwrap());
        }
    }

    for stream in streams_to_drop {
        streams.retain(|(s, _)| s.peer_addr().unwrap() != stream);
    }
}

/// Write `bytes` to `stream`, recording a failure in `result`.
///
/// `result` accumulates the outcome across several calls for the same stream: it is only ever
/// changed from `Ok` to `Err`, never reset to `Ok`. Once a failure has been recorded, later
/// calls return immediately without touching the stream, so a broken connection is not
/// retried (and cannot stall on the write timeout) for every remaining chunk.
fn write_stream(stream: &mut TcpStream, bytes: &[u8], result: &mut io::Result<()>) {
    if result.is_err() {
        return;
    }

    if let Err(e) = stream.write_all(bytes) {
        *result = Err(e);
    }
}