adb_client_tokio/
connection.rs

1use futures::sink::SinkExt;
2use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
3use tokio::net::{TcpStream, UnixStream};
4use tokio_stream::StreamExt;
5
6use std::pin::Pin;
7use std::str;
8use std::task::{Context, Poll};
9use tokio_util::bytes::{Buf, BytesMut};
10use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite};
11
12use crate::util::{AdbError, Result};
13
/// Number of ASCII hex digits in the length prefix written before each request payload.
const ADB_REQUEST_HEADER_LENGTH: usize = 4;
/// Number of bytes in the status word (`OKAY` / `FAIL`) that starts every response.
const ADB_RESPONSE_STATUS_LENGTH: usize = 4;
/// Status word plus the four ASCII hex digits of payload length that may follow it.
const ADB_RESPONSE_HEADER_LENGTH: usize = ADB_RESPONSE_STATUS_LENGTH + 4;

/// Largest frame, in bytes, that the codecs in this module accept; bounds buffer growth.
pub(crate) const MAX_MESSAGE_SIZE: usize = 8 << 20;
19
/// One command destined for the ADB server.
///
/// Only the raw command bytes are stored here; the hex length prefix is
/// produced when the request is encoded.
#[derive(Debug)]
pub(crate) struct AdbRequest {
    /// Command bytes, without any length prefix.
    payload: Vec<u8>,
}

impl AdbRequest {
    /// Creates a request whose payload is the UTF-8 bytes of `cmd`.
    pub(crate) fn new(cmd: &str) -> AdbRequest {
        let payload = Vec::from(cmd.as_bytes());
        Self { payload }
    }
}
32
/// A reply decoded from the ADB server.
///
/// `message` holds the payload text that accompanied the status word (decoded
/// lossily as UTF-8); it is empty when the framing carried no payload.
#[derive(PartialEq, Debug)]
pub(crate) enum AdbResponse {
    /// The server answered with the `OKAY` status word.
    OKAY { message: String },
    /// The server answered with the `FAIL` status word.
    FAIL { message: String },
}
38
/// Framing that `AdbResponseDecoder` applies to the next response.
#[derive(Debug)]
pub(crate) enum AdbResponseDecoderImpl {
    /// Bare status word; a `FAIL` is followed by a hex length and a message.
    Status,
    /// Status word, four hex digits of length, then that many payload bytes.
    StatusLengthPayload,
    /// Status word, then a payload terminated by `\n`.
    StatusPayloadNewline,
}

/// Decoder turning raw bytes from the ADB server into `AdbResponse` values.
///
/// The framing is selected by `decoder_impl`; the field is `pub(crate)` so
/// other code in the crate can switch modes between responses.
#[derive(Debug)]
pub(crate) struct AdbResponseDecoder {
    /// Framing used for the next decoded response.
    pub(crate) decoder_impl: AdbResponseDecoderImpl,
}

impl AdbResponseDecoder {
    /// Creates a decoder that starts out in `StatusLengthPayload` mode.
    pub(crate) fn new() -> AdbResponseDecoder {
        let decoder_impl = AdbResponseDecoderImpl::StatusLengthPayload;
        Self { decoder_impl }
    }
}
58
59impl AdbResponseDecoder {
60    fn decode_status(&mut self, src: &mut BytesMut) -> Result<Option<AdbResponse>> {
61        if src.len() < ADB_RESPONSE_STATUS_LENGTH {
62            // Not enough data to read length marker.
63            return Ok(None);
64        }
65
66        let status = src[0..4].to_vec();
67        let status = str::from_utf8(&status)?;
68
69        match status {
70            "OKAY" => {
71                src.advance(ADB_RESPONSE_STATUS_LENGTH);
72                Ok(Some(AdbResponse::OKAY {
73                    message: "".to_string(),
74                }))
75            },
76            "FAIL" => {
77                if src.len() < ADB_RESPONSE_HEADER_LENGTH {
78                    return Ok(None);
79                }
80
81                let length: [u8; 4] = [src[4], src[5], src[6], src[7]];
82                let length: usize = usize::from_str_radix(std::str::from_utf8(&length)?, 16)?;
83
84                if src.len() < ADB_RESPONSE_HEADER_LENGTH + length {
85                    return Ok(None);
86                }
87
88                let message: Vec<u8> = src[ADB_RESPONSE_HEADER_LENGTH..ADB_RESPONSE_HEADER_LENGTH + length].to_vec();
89                let message = String::from_utf8_lossy(&message).to_string();
90                src.advance(ADB_RESPONSE_HEADER_LENGTH + length);
91
92                Ok(Some(AdbResponse::FAIL { message }))
93            }
94            _ => Err(AdbError::UnknownResponseStatus(status.into())),
95        }
96    }
97
98    fn decode_status_and_payload(&mut self, src: &mut BytesMut) -> Result<Option<AdbResponse>> {
99        if src.len() < ADB_RESPONSE_HEADER_LENGTH {
100            // Not enough data to read length marker.
101            return Ok(None);
102        }
103
104        // Read length marker.
105        let length: [u8; 4] = [src[4], src[5], src[6], src[7]];
106        let length: usize = usize::from_str_radix(std::str::from_utf8(&length)?, 16)?;
107
108        // Check that the length is not too large to avoid a denial of
109        // service attack where the server runs out of memory.
110        if length > MAX_MESSAGE_SIZE {
111            return Err(AdbError::IOError(std::io::Error::new(
112                std::io::ErrorKind::InvalidData,
113                format!("Frame of length {} is too large.", length),
114            )));
115        }
116
117        if src.len() < ADB_RESPONSE_HEADER_LENGTH + length {
118            // The full string has not yet arrived.
119            //
120            // We reserve more space in the buffer. This is not strictly
121            // necessary, but is a good idea performance-wise.
122            src.reserve(ADB_RESPONSE_HEADER_LENGTH + length - src.len());
123
124            // We inform the Framed that we need more bytes to form the next
125            // frame.
126            return Ok(None);
127        }
128
129        // Use advance to modify src such that it no longer contains
130        // this frame.
131        let status = src[0..4].to_vec();
132        let status = str::from_utf8(&status)?;
133        let payload: Vec<u8> =
134            src[ADB_RESPONSE_HEADER_LENGTH..ADB_RESPONSE_HEADER_LENGTH + length].to_vec();
135        let message = String::from_utf8_lossy(&payload).to_string();
136        src.advance(ADB_RESPONSE_HEADER_LENGTH + length);
137
138        // Read string from src
139        match status {
140            "OKAY" => Ok(Some(AdbResponse::OKAY { message })),
141            "FAIL" => Ok(Some(AdbResponse::FAIL { message })),
142            _ => Err(AdbError::UnknownResponseStatus(status.into())),
143        }
144    }
145
146    fn decode_status_and_read_until_new_line(
147        &mut self,
148        src: &mut BytesMut,
149    ) -> Result<Option<AdbResponse>> {
150        if src.len() < ADB_RESPONSE_STATUS_LENGTH {
151            // Not enough data to read length marker.
152            return Ok(None);
153        }
154
155        // Check that the length is not too large to avoid a denial of
156        // service attack where the server runs out of memory.
157        if src.len() > MAX_MESSAGE_SIZE {
158            return Err(AdbError::IOError(std::io::Error::new(
159                std::io::ErrorKind::InvalidData,
160                format!("Frame of length {} is too large.", src.len()),
161            )));
162        }
163
164        let status = src[0..ADB_RESPONSE_STATUS_LENGTH].to_vec();
165        let status = str::from_utf8(&status)?;
166
167        let newline_offset = src[ADB_RESPONSE_STATUS_LENGTH..src.len()]
168            .iter()
169            .position(|b| *b == b'\n');
170
171        match newline_offset {
172            Some(offset) => {
173                let message =
174                    src[ADB_RESPONSE_STATUS_LENGTH..ADB_RESPONSE_STATUS_LENGTH + offset].to_vec();
175                let message = String::from_utf8_lossy(&message).to_string();
176                src.advance(ADB_RESPONSE_STATUS_LENGTH + offset + 1);
177
178                match status {
179                    "OKAY" => Ok(Some(AdbResponse::OKAY { message })),
180                    "FAIL" => Ok(Some(AdbResponse::FAIL { message })),
181                    _ => Err(AdbError::UnknownResponseStatus(status.into())),
182                }
183            }
184            None => Ok(None),
185        }
186    }
187}
188
189impl Decoder for AdbResponseDecoder {
190    type Item = AdbResponse;
191    type Error = AdbError;
192
193    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
194        if src.len() == 0
195        {
196            return Ok(None);
197        }
198
199        let response = match self.decoder_impl {
200            AdbResponseDecoderImpl::Status => self.decode_status(src),
201            AdbResponseDecoderImpl::StatusLengthPayload => self.decode_status_and_payload(src),
202            AdbResponseDecoderImpl::StatusPayloadNewline => {
203                self.decode_status_and_read_until_new_line(src)
204            }
205        };
206
207        println!("decofing:\n{}", pretty_hex::pretty_hex(&src));
208
209        response
210    }
211}
212
/// Encoder that frames `AdbRequest` values for the ADB server.
///
/// It has no fields, so every request is encoded independently of the others.
#[derive(Debug)]
pub(crate) struct AdbRequestEncoder {}

impl AdbRequestEncoder {
    /// Returns a fresh (stateless) encoder.
    pub(crate) fn new() -> AdbRequestEncoder {
        Self {}
    }
}
221
222impl Encoder<AdbRequest> for AdbRequestEncoder {
223    type Error = AdbError;
224
225    fn encode(&mut self, msg: AdbRequest, dst: &mut BytesMut) -> Result<()> {
226        // Don't send a string if it is longer than the other end will
227        // accept.
228        let length = msg.payload.len();
229        if length > MAX_MESSAGE_SIZE {
230            return Err(AdbError::IOError(std::io::Error::new(
231                std::io::ErrorKind::InvalidData,
232                format!("Frame of length {} is too large.", length),
233            )));
234        }
235
236        // Reserve space in the buffer.
237        dst.reserve(ADB_REQUEST_HEADER_LENGTH + length);
238
239        let length_hex = format!("{:04x}", length);
240
241        // Write the length and string to the buffer.
242        dst.extend_from_slice(&length_hex.as_bytes());
243        dst.extend_from_slice(&msg.payload);
244
245        println!("sending {}", pretty_hex::pretty_hex(&dst));
246
247        Ok(())
248    }
249}
250
251/// AdbClientStream represents a stream that can be used to communicate with an ADB server.
252#[derive(Debug)]
253pub enum AdbClientStream {
254    /// A TCP stream
255    Tcp(TcpStream),
256    /// A Unix stream
257    Unix(UnixStream),
258}
259
260impl AsyncRead for AdbClientStream {
261    fn poll_read(
262        self: std::pin::Pin<&mut Self>,
263        cx: &mut std::task::Context<'_>,
264        buf: &mut ReadBuf<'_>,
265    ) -> Poll<std::io::Result<()>> {
266        match self.get_mut() {
267            AdbClientStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
268            AdbClientStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
269        }
270    }
271}
272
273impl AsyncWrite for AdbClientStream {
274    fn poll_write(
275        self: Pin<&mut Self>,
276        cx: &mut Context<'_>,
277        buf: &[u8],
278    ) -> Poll<std::result::Result<usize, std::io::Error>> {
279        match self.get_mut() {
280            AdbClientStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
281            AdbClientStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
282        }
283    }
284
285    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::result::Result<(), std::io::Error>> {
286        match self.get_mut() {
287            AdbClientStream::Tcp(s) => Pin::new(s).poll_flush(cx),
288            AdbClientStream::Unix(s) => Pin::new(s).poll_flush(cx),
289        }
290    }
291
292    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::result::Result<(), std::io::Error>> {
293        match self.get_mut() {
294            AdbClientStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
295            AdbClientStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
296        }
297    }
298}
299
300/// AdbClientStreamOwnedReadHalf represents the read half of an AdbClientStream.
301#[derive(Debug)]
302pub enum AdbClientStreamOwnedReadHalf {
303    /// A TCP read half
304    Tcp(tokio::net::tcp::OwnedReadHalf),
305    /// A Unix read half
306    Unix(tokio::net::unix::OwnedReadHalf),
307}
308
309/// AdbClientStreamOwnedWriteHalf represents the write half of an AdbClientStream.
310#[derive(Debug)]
311pub enum AdbClientStreamOwnedWriteHalf {
312    /// A TCP write half
313    Tcp(tokio::net::tcp::OwnedWriteHalf),
314    /// A Unix write half
315    Unix(tokio::net::unix::OwnedWriteHalf),
316}
317
318impl AdbClientStream {
319
320    /// Split the stream into a read half and a write half.
321    pub fn into_split(self) -> (AdbClientStreamOwnedReadHalf, AdbClientStreamOwnedWriteHalf) {
322        match self {
323            AdbClientStream::Tcp(s) => {
324                let (r, w) = s.into_split();
325                (AdbClientStreamOwnedReadHalf::Tcp(r), AdbClientStreamOwnedWriteHalf::Tcp(w))
326            }
327            AdbClientStream::Unix(s) => {
328                let (r, w) = s.into_split();
329                (AdbClientStreamOwnedReadHalf::Unix(r), AdbClientStreamOwnedWriteHalf::Unix(w))
330            }
331        }
332    }
333}
334
335impl AdbClientStreamOwnedReadHalf {
336    /// Reunite the read half with the write half to recreate the original stream.
337    pub fn reunite(self, w: AdbClientStreamOwnedWriteHalf) -> Result<AdbClientStream> {
338        match self {
339            AdbClientStreamOwnedReadHalf::Tcp(r) => {
340                let w = match w {
341                    AdbClientStreamOwnedWriteHalf::Tcp(w) => w,
342                    _ => panic!("Invalid write half"),
343                };
344                Ok(AdbClientStream::Tcp(r.reunite(w).unwrap()))
345            }
346            AdbClientStreamOwnedReadHalf::Unix(r) => {
347                let w = match w {
348                    AdbClientStreamOwnedWriteHalf::Unix(w) => w,
349                    _ => panic!("Invalid write half"),
350                };
351                Ok(AdbClientStream::Unix(r.reunite(w).unwrap()))
352            }
353        }
354
355    }
356}
357
358impl AsyncRead for AdbClientStreamOwnedReadHalf {
359    fn poll_read(
360        self: std::pin::Pin<&mut Self>,
361        cx: &mut std::task::Context<'_>,
362        buf: &mut ReadBuf<'_>,
363    ) -> Poll<std::io::Result<()>> {
364        match self.get_mut() {
365            AdbClientStreamOwnedReadHalf::Tcp(s) => Pin::new(s).poll_read(cx, buf),
366            AdbClientStreamOwnedReadHalf::Unix(s) => Pin::new(s).poll_read(cx, buf),
367        }
368    }
369}
370
371impl AsyncWrite for AdbClientStreamOwnedWriteHalf {
372    #[inline]
373    fn poll_write(
374        self: Pin<&mut Self>,
375        cx: &mut Context<'_>,
376        buf: &[u8],
377    ) -> Poll<std::io::Result<usize>> {
378        match self.get_mut() {
379            AdbClientStreamOwnedWriteHalf::Unix(x) => Pin::new(x).poll_write(cx, buf),
380            AdbClientStreamOwnedWriteHalf::Tcp(x) => Pin::new(x).poll_write(cx, buf),
381        }
382    }
383
384    #[inline]
385    fn poll_write_vectored(
386        self: Pin<&mut Self>,
387        cx: &mut Context<'_>,
388        bufs: &[std::io::IoSlice<'_>],
389    ) -> Poll<std::io::Result<usize>> {
390        match self.get_mut() {
391            AdbClientStreamOwnedWriteHalf::Unix(x) => Pin::new(x).poll_write_vectored(cx, bufs),
392            AdbClientStreamOwnedWriteHalf::Tcp(x) => Pin::new(x).poll_write_vectored(cx, bufs),
393        }
394    }
395
396    #[inline]
397    fn is_write_vectored(&self) -> bool {
398        match self {
399            AdbClientStreamOwnedWriteHalf::Unix(x) => x.is_write_vectored(),
400            AdbClientStreamOwnedWriteHalf::Tcp(x) => x.is_write_vectored(),
401        }
402    }
403
404    #[inline]
405    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
406        match self.get_mut() {
407            AdbClientStreamOwnedWriteHalf::Unix(x) => Pin::new(x).poll_flush(cx),
408            AdbClientStreamOwnedWriteHalf::Tcp(x) => Pin::new(x).poll_flush(cx),
409        }
410    }
411
412    #[inline]
413    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
414        match self.get_mut() {
415            AdbClientStreamOwnedWriteHalf::Unix(x) => Pin::new(x).poll_shutdown(cx),
416            AdbClientStreamOwnedWriteHalf::Tcp(x) => Pin::new(x).poll_shutdown(cx),
417        }
418    }
419}
420
421#[derive(Debug)]
422pub(crate)  struct AdbClientConnection
423{
424    pub(crate) reader: FramedRead<AdbClientStreamOwnedReadHalf, AdbResponseDecoder>,
425    pub(crate) writer: FramedWrite<AdbClientStreamOwnedWriteHalf, AdbRequestEncoder>,
426}
427
428impl<'a> AdbClientConnection
429{
430    pub(crate) fn new(socket: AdbClientStream) -> AdbClientConnection
431    {
432        let (r, w) = socket.into_split();
433
434        let reader = FramedRead::new(r, AdbResponseDecoder::new());
435        let writer = FramedWrite::new(w, AdbRequestEncoder::new());
436
437        return AdbClientConnection { reader, writer };
438    }
439
440    pub(crate) async fn send(&mut self, request: AdbRequest) -> Result<()> {
441        self.writer.send(request).await
442    }
443
444    pub(crate) async fn next(&mut self) -> Result<String> {
445        match self.reader.next().await {
446            Some(Ok(AdbResponse::OKAY { message })) => Ok(message),
447            Some(Ok(AdbResponse::FAIL { message })) => Err(AdbError::FailedResponseStatus(message)),
448            Some(Err(e)) => Err(e),
449            None => Err(AdbError::FailedResponseStatus("No response".into())),
450        }
451    }
452}