proto_tower_util/
debug_io_layer.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4use tokio::io::{simplex, AsyncReadExt, AsyncWriteExt, ReadHalf, SimplexStream, WriteHalf};
5use tokio::task::{JoinError, JoinHandle};
6use tower::{Layer, Service};
7
/// Capacity, in bytes, passed to `simplex` for each in-memory pipe connecting this layer to the
/// inner service.
const MAX_BUF_SIZE: usize = 1024;
/// Prefix placed at the start of every diagnostic line this module writes to stderr.
const IDENTIFIER: &str = "[DEBUG_LAYER]";
10
/// Service wrapper that relays bytes between a caller-supplied `(reader, writer)` pair and an
/// inner service, logging everything that passes through.
///
/// The inner service is always handed a concrete
/// `(ReadHalf<SimplexStream>, WriteHalf<SimplexStream>)` pair, because this wrapper creates the
/// pipes itself. Those requirements are stated on the `Service` and `Layer` impls rather than on
/// the struct, so the type can be named, cloned and stored without repeating the bounds.
#[derive(Clone)]
pub struct DebugIoService<InnerService> {
    /// Downstream service that receives this wrapper's ends of the in-memory pipes.
    inner: InnerService,
}
20
21impl<InnerService, Reader, Writer> Service<(Reader, Writer)> for DebugIoService<InnerService>
22where
23    InnerService: Service<(ReadHalf<SimplexStream>, WriteHalf<SimplexStream>), Response = ()> + Clone + Send + 'static,
24    InnerService::Future: Future<Output = Result<InnerService::Response, InnerService::Error>> + Send + 'static,
25    InnerService::Error: Send + 'static,
26    // We need Unpin because otherwise we cannot access the methods of these traits
27    Reader: AsyncReadExt + Send + Unpin + 'static,
28    Writer: AsyncWriteExt + Send + Unpin + 'static,
29{
30    // Since all communication is done via the readers and writers, there isn't really a need for a return type
31    type Response = ();
32    type Error = InnerService::Error;
33    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
34
35    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
36        self.inner.poll_ready(cx).map(|result| result.map_err(|err| err.into()))
37    }
38
39    fn call(&mut self, (mut input_reader, mut input_writer): (Reader, Writer)) -> Self::Future {
40        let mut inner = self.inner.clone();
41        Box::pin(async move {
42            // We create the read/write pairs that we will use to communicate with the downstream service
43            let (read_svc, mut write_this) = simplex(MAX_BUF_SIZE);
44            let (mut read_this, write_svc) = simplex(MAX_BUF_SIZE);
45
46            // Now we spawn the downstream inner service because otherwise we would need to poll it to make it progress
47            // Calling await on it directly would block the current task, preventing us from relaying messages
48            // Because we have so many generics, my IDE isn't prompting with types, so I declared them explicitly here.
49            let task: JoinHandle<Result<InnerService::Response, InnerService::Error>> = tokio::spawn(inner.call((read_svc, write_svc)));
50
51            // Ideally everything below would be a loop, but we won't bother with that
52            // We would need to handle more conditions than we would like for the purpose of the example
53
54            // Read from the layer input
55            let mut input_read_buffer = [0u8; 1024];
56            let mut output_read_buffer = [0u8; 1024];
57
58            loop {
59                tokio::select! {
60                    result_sz = input_reader.read(&mut input_read_buffer) => {
61                        match result_sz {
62                            Ok(0) | Err(_) => {
63                                eprintln!("{} Failed to read from input reader", IDENTIFIER);
64                                break;
65                            }
66                            Ok(sz) => {
67                                let have_read = &input_read_buffer[..sz];
68                                eprintln!("{} Read {} bytes\n{}", IDENTIFIER, sz, escape_bytes_hex(have_read));
69                                if let Err(e) = write_this.write_all(have_read).await {
70                                    eprintln!("{} Failed to write to downstream service: {:?}", IDENTIFIER, e);
71                                    break;
72                                }
73                            }
74                        }
75                    }
76                    result_sz = read_this.read(&mut output_read_buffer) => {
77                        match result_sz {
78                            Ok(0) | Err(_) => {
79                                eprintln!("{} Failed to read from downstream reader", IDENTIFIER);
80                                break;
81                            }
82                            Ok(sz) => {
83                                let have_read = &output_read_buffer[..sz];
84                                eprintln!("{} Read {} bytes: {}", IDENTIFIER, sz, escape_bytes_hex(have_read));
85                                if let Err(e) = input_writer.write_all(have_read).await {
86                                    eprintln!("{} Failed to write to upstream service: {:?}", IDENTIFIER, e);
87                                    break;
88                                }
89                            }
90                        }
91                    }
92                }
93                if task.is_finished() {
94                    break;
95                }
96            }
97            eprintln!("{} Going into shutdown", IDENTIFIER);
98            drop(input_reader);
99            drop(input_writer);
100            drop(read_this);
101            drop(write_this);
102
103            // Let's politely wait for the task to complete in case it has errored
104            match task.await {
105                Ok(s) => {
106                    eprintln!("{} Task completed", IDENTIFIER);
107                    s
108                }
109                Err(e) => {
110                    eprintln!("{} Task failed: {:?}", IDENTIFIER, e);
111                    let r: Result<(), JoinError> = Err(e);
112                    r.unwrap();
113                    unreachable!("We just unwrapped a known error");
114                }
115            }
116        })
117    }
118}
119
/// Layer that wraps an inner service in a `DebugIoService`.
///
/// The I/O pattern layer takes a (read, write) pair (as it would for servers) and also sends a
/// (read, write) pair down to the inner service (as you would do for clients).
///
/// The layer holds no state, so it is cheap to copy and can be printed for diagnostics.
#[derive(Debug, Default, Clone, Copy)]
pub struct DebugIoLayer {}
124
125impl<InnerService> Layer<InnerService> for DebugIoLayer
126where
127    InnerService: Service<(ReadHalf<SimplexStream>, WriteHalf<SimplexStream>), Response = ()> + Clone + Send + 'static,
128    InnerService::Future: Future<Output = Result<InnerService::Response, InnerService::Error>> + Send + 'static,
129    InnerService::Error: Send + 'static,
130{
131    type Service = DebugIoService<InnerService>;
132
133    fn layer(&self, inner: InnerService) -> Self::Service {
134        DebugIoService { inner }
135    }
136}
137
/// Renders `input` as text suitable for a log line.
///
/// Printable ASCII (space through `~`) is emitted verbatim; every other byte becomes a
/// `\xNN` escape with two lowercase hex digits. A literal backslash is printable and is
/// therefore passed through unescaped.
fn escape_bytes_hex(input: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut rendered = String::with_capacity(input.len());
    for &byte in input {
        match byte {
            // 0x20..=0x7e: the space character plus every ASCII graphic character
            b' '..=b'~' => rendered.push(char::from(byte)),
            _ => {
                rendered.push_str("\\x");
                rendered.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                rendered.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
            }
        }
    }
    rendered
}