proto_tower_util/
debug_io_layer.rs1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4use tokio::io::{simplex, AsyncReadExt, AsyncWriteExt, ReadHalf, SimplexStream, WriteHalf};
5use tokio::task::{JoinError, JoinHandle};
6use tower::{Layer, Service};
7
/// Capacity, in bytes, given to each in-memory simplex pipe created per request.
const MAX_BUF_SIZE: usize = 1 << 10;
/// Prefix attached to every diagnostic line this layer writes to stderr.
const IDENTIFIER: &str = "[DEBUG_LAYER]";
10
/// Service that hands its inner service one end of each of two in-memory pipes and relays the
/// caller's reader/writer pair through them, logging the relayed bytes to stderr.
///
/// Construct it with `DebugIoLayer`. The trait bounds live on the `impl` blocks that need them
/// rather than on the struct, so the type can be named, cloned and debug-printed without
/// repeating the full `Service` bound.
#[derive(Clone, Debug)]
pub struct DebugIoService<InnerService> {
    /// The wrapped service that receives the pipe halves.
    inner: InnerService,
}
20
21impl<InnerService, Reader, Writer> Service<(Reader, Writer)> for DebugIoService<InnerService>
22where
23 InnerService: Service<(ReadHalf<SimplexStream>, WriteHalf<SimplexStream>), Response = ()> + Clone + Send + 'static,
24 InnerService::Future: Future<Output = Result<InnerService::Response, InnerService::Error>> + Send + 'static,
25 InnerService::Error: Send + 'static,
26 Reader: AsyncReadExt + Send + Unpin + 'static,
28 Writer: AsyncWriteExt + Send + Unpin + 'static,
29{
30 type Response = ();
32 type Error = InnerService::Error;
33 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
34
35 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
36 self.inner.poll_ready(cx).map(|result| result.map_err(|err| err.into()))
37 }
38
39 fn call(&mut self, (mut input_reader, mut input_writer): (Reader, Writer)) -> Self::Future {
40 let mut inner = self.inner.clone();
41 Box::pin(async move {
42 let (read_svc, mut write_this) = simplex(MAX_BUF_SIZE);
44 let (mut read_this, write_svc) = simplex(MAX_BUF_SIZE);
45
46 let task: JoinHandle<Result<InnerService::Response, InnerService::Error>> = tokio::spawn(inner.call((read_svc, write_svc)));
50
51 let mut input_read_buffer = [0u8; 1024];
56 let mut output_read_buffer = [0u8; 1024];
57
58 loop {
59 tokio::select! {
60 result_sz = input_reader.read(&mut input_read_buffer) => {
61 match result_sz {
62 Ok(0) | Err(_) => {
63 eprintln!("{} Failed to read from input reader", IDENTIFIER);
64 break;
65 }
66 Ok(sz) => {
67 let have_read = &input_read_buffer[..sz];
68 eprintln!("{} Read {} bytes\n{}", IDENTIFIER, sz, escape_bytes_hex(have_read));
69 if let Err(e) = write_this.write_all(have_read).await {
70 eprintln!("{} Failed to write to downstream service: {:?}", IDENTIFIER, e);
71 break;
72 }
73 }
74 }
75 }
76 result_sz = read_this.read(&mut output_read_buffer) => {
77 match result_sz {
78 Ok(0) | Err(_) => {
79 eprintln!("{} Failed to read from downstream reader", IDENTIFIER);
80 break;
81 }
82 Ok(sz) => {
83 let have_read = &output_read_buffer[..sz];
84 eprintln!("{} Read {} bytes: {}", IDENTIFIER, sz, escape_bytes_hex(have_read));
85 if let Err(e) = input_writer.write_all(have_read).await {
86 eprintln!("{} Failed to write to upstream service: {:?}", IDENTIFIER, e);
87 break;
88 }
89 }
90 }
91 }
92 }
93 if task.is_finished() {
94 break;
95 }
96 }
97 eprintln!("{} Going into shutdown", IDENTIFIER);
98 drop(input_reader);
99 drop(input_writer);
100 drop(read_this);
101 drop(write_this);
102
103 match task.await {
105 Ok(s) => {
106 eprintln!("{} Task completed", IDENTIFIER);
107 s
108 }
109 Err(e) => {
110 eprintln!("{} Task failed: {:?}", IDENTIFIER, e);
111 let r: Result<(), JoinError> = Err(e);
112 r.unwrap();
113 unreachable!("We just unwrapped a known error");
114 }
115 }
116 })
117 }
118}
119
/// Tower layer that wraps a service in a `DebugIoService`, which logs the bytes relayed between
/// the caller's reader/writer pair and the wrapped service to stderr.
///
/// The layer carries no state, so it is cheap to clone and can wrap any number of services.
#[derive(Clone, Debug, Default)]
pub struct DebugIoLayer {}
124
125impl<InnerService> Layer<InnerService> for DebugIoLayer
126where
127 InnerService: Service<(ReadHalf<SimplexStream>, WriteHalf<SimplexStream>), Response = ()> + Clone + Send + 'static,
128 InnerService::Future: Future<Output = Result<InnerService::Response, InnerService::Error>> + Send + 'static,
129 InnerService::Error: Send + 'static,
130{
131 type Service = DebugIoService<InnerService>;
132
133 fn layer(&self, inner: InnerService) -> Self::Service {
134 DebugIoService { inner }
135 }
136}
137
/// Renders `input` as printable text for the debug log.
///
/// ASCII graphic characters and the space character are emitted as-is. A backslash is doubled,
/// so that it cannot be mistaken for the start of an escape. Every other byte becomes a `\xNN`
/// escape with two lowercase hex digits.
fn escape_bytes_hex(input: &[u8]) -> String {
    use std::fmt::Write as _;

    // Every input byte produces at least one output byte.
    let mut escaped = String::with_capacity(input.len());
    for &byte in input {
        match byte {
            b'\\' => escaped.push_str("\\\\"),
            _ if byte == b' ' || byte.is_ascii_graphic() => escaped.push(char::from(byte)),
            _ => write!(escaped, "\\x{:02x}", byte).expect("writing to a String cannot fail"),
        }
    }
    escaped
}