1use futures::sink::SinkExt;
2use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
3use tokio::net::{TcpStream, UnixStream};
4use tokio_stream::StreamExt;
5
6use std::pin::Pin;
7use std::str;
8use std::task::{Context, Poll};
9use tokio_util::bytes::{Buf, BytesMut};
10use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite};
11
12use crate::util::{AdbError, Result};
13
/// Width of the hexadecimal length prefix written in front of every request payload.
const ADB_REQUEST_HEADER_LENGTH: usize = 4;
/// Width of the status word (`OKAY` / `FAIL`) that opens every response.
const ADB_RESPONSE_STATUS_LENGTH: usize = 4;
/// Status word followed by a 4-digit hexadecimal length field.
const ADB_RESPONSE_HEADER_LENGTH: usize = ADB_RESPONSE_STATUS_LENGTH + 4;

/// Upper bound (8 MiB) on the size of a single frame handled by the codecs.
pub(crate) const MAX_MESSAGE_SIZE: usize = 8 << 20;
19
/// A request sent to the peer; the payload is the raw bytes of a command string.
#[derive(Debug)]
pub(crate) struct AdbRequest {
    /// Command bytes; the request encoder writes them after a hex length prefix.
    payload: Vec<u8>,
}

impl AdbRequest {
    /// Builds a request whose payload is the UTF-8 encoding of `cmd`.
    pub(crate) fn new(cmd: &str) -> AdbRequest {
        let payload = Vec::from(cmd.as_bytes());
        Self { payload }
    }
}
32
/// A decoded response frame, tagged by its status word.
#[derive(Debug, PartialEq)]
pub(crate) enum AdbResponse {
    /// Status `OKAY`; `message` holds any payload that accompanied it (may be empty).
    OKAY { message: String },
    /// Status `FAIL`; `message` holds the error text sent with it.
    FAIL { message: String },
}
38
/// Framing strategy used by `AdbResponseDecoder` to split the byte stream into responses.
#[derive(Debug)]
pub(crate) enum AdbResponseDecoderImpl {
    /// Bare `OKAY`, or `FAIL` followed by a 4-digit hex length and that many message bytes.
    Status,
    /// Status word, 4-digit hex length, then a payload of that length.
    StatusLengthPayload,
    /// Status word followed by a payload terminated by `\n`.
    StatusPayloadNewline,
}

/// Response decoder whose framing can be switched at runtime through `decoder_impl`.
#[derive(Debug)]
pub(crate) struct AdbResponseDecoder {
    pub(crate) decoder_impl: AdbResponseDecoderImpl,
}

impl AdbResponseDecoder {
    /// Creates a decoder that starts in `StatusLengthPayload` mode.
    pub(crate) fn new() -> AdbResponseDecoder {
        let decoder_impl = AdbResponseDecoderImpl::StatusLengthPayload;
        Self { decoder_impl }
    }
}
58
59impl AdbResponseDecoder {
60 fn decode_status(&mut self, src: &mut BytesMut) -> Result<Option<AdbResponse>> {
61 if src.len() < ADB_RESPONSE_STATUS_LENGTH {
62 return Ok(None);
64 }
65
66 let status = src[0..4].to_vec();
67 let status = str::from_utf8(&status)?;
68
69 match status {
70 "OKAY" => {
71 src.advance(ADB_RESPONSE_STATUS_LENGTH);
72 Ok(Some(AdbResponse::OKAY {
73 message: "".to_string(),
74 }))
75 },
76 "FAIL" => {
77 if src.len() < ADB_RESPONSE_HEADER_LENGTH {
78 return Ok(None);
79 }
80
81 let length: [u8; 4] = [src[4], src[5], src[6], src[7]];
82 let length: usize = usize::from_str_radix(std::str::from_utf8(&length)?, 16)?;
83
84 if src.len() < ADB_RESPONSE_HEADER_LENGTH + length {
85 return Ok(None);
86 }
87
88 let message: Vec<u8> = src[ADB_RESPONSE_HEADER_LENGTH..ADB_RESPONSE_HEADER_LENGTH + length].to_vec();
89 let message = String::from_utf8_lossy(&message).to_string();
90 src.advance(ADB_RESPONSE_HEADER_LENGTH + length);
91
92 Ok(Some(AdbResponse::FAIL { message }))
93 }
94 _ => Err(AdbError::UnknownResponseStatus(status.into())),
95 }
96 }
97
98 fn decode_status_and_payload(&mut self, src: &mut BytesMut) -> Result<Option<AdbResponse>> {
99 if src.len() < ADB_RESPONSE_HEADER_LENGTH {
100 return Ok(None);
102 }
103
104 let length: [u8; 4] = [src[4], src[5], src[6], src[7]];
106 let length: usize = usize::from_str_radix(std::str::from_utf8(&length)?, 16)?;
107
108 if length > MAX_MESSAGE_SIZE {
111 return Err(AdbError::IOError(std::io::Error::new(
112 std::io::ErrorKind::InvalidData,
113 format!("Frame of length {} is too large.", length),
114 )));
115 }
116
117 if src.len() < ADB_RESPONSE_HEADER_LENGTH + length {
118 src.reserve(ADB_RESPONSE_HEADER_LENGTH + length - src.len());
123
124 return Ok(None);
127 }
128
129 let status = src[0..4].to_vec();
132 let status = str::from_utf8(&status)?;
133 let payload: Vec<u8> =
134 src[ADB_RESPONSE_HEADER_LENGTH..ADB_RESPONSE_HEADER_LENGTH + length].to_vec();
135 let message = String::from_utf8_lossy(&payload).to_string();
136 src.advance(ADB_RESPONSE_HEADER_LENGTH + length);
137
138 match status {
140 "OKAY" => Ok(Some(AdbResponse::OKAY { message })),
141 "FAIL" => Ok(Some(AdbResponse::FAIL { message })),
142 _ => Err(AdbError::UnknownResponseStatus(status.into())),
143 }
144 }
145
146 fn decode_status_and_read_until_new_line(
147 &mut self,
148 src: &mut BytesMut,
149 ) -> Result<Option<AdbResponse>> {
150 if src.len() < ADB_RESPONSE_STATUS_LENGTH {
151 return Ok(None);
153 }
154
155 if src.len() > MAX_MESSAGE_SIZE {
158 return Err(AdbError::IOError(std::io::Error::new(
159 std::io::ErrorKind::InvalidData,
160 format!("Frame of length {} is too large.", src.len()),
161 )));
162 }
163
164 let status = src[0..ADB_RESPONSE_STATUS_LENGTH].to_vec();
165 let status = str::from_utf8(&status)?;
166
167 let newline_offset = src[ADB_RESPONSE_STATUS_LENGTH..src.len()]
168 .iter()
169 .position(|b| *b == b'\n');
170
171 match newline_offset {
172 Some(offset) => {
173 let message =
174 src[ADB_RESPONSE_STATUS_LENGTH..ADB_RESPONSE_STATUS_LENGTH + offset].to_vec();
175 let message = String::from_utf8_lossy(&message).to_string();
176 src.advance(ADB_RESPONSE_STATUS_LENGTH + offset + 1);
177
178 match status {
179 "OKAY" => Ok(Some(AdbResponse::OKAY { message })),
180 "FAIL" => Ok(Some(AdbResponse::FAIL { message })),
181 _ => Err(AdbError::UnknownResponseStatus(status.into())),
182 }
183 }
184 None => Ok(None),
185 }
186 }
187}
188
189impl Decoder for AdbResponseDecoder {
190 type Item = AdbResponse;
191 type Error = AdbError;
192
193 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
194 if src.len() == 0
195 {
196 return Ok(None);
197 }
198
199 let response = match self.decoder_impl {
200 AdbResponseDecoderImpl::Status => self.decode_status(src),
201 AdbResponseDecoderImpl::StatusLengthPayload => self.decode_status_and_payload(src),
202 AdbResponseDecoderImpl::StatusPayloadNewline => {
203 self.decode_status_and_read_until_new_line(src)
204 }
205 };
206
207 println!("decofing:\n{}", pretty_hex::pretty_hex(&src));
208
209 response
210 }
211}
212
/// Encoder for `AdbRequest` frames; it has no fields and therefore no state.
#[derive(Debug)]
pub(crate) struct AdbRequestEncoder {}

impl AdbRequestEncoder {
    /// Creates an encoder; there is nothing to configure.
    pub(crate) fn new() -> AdbRequestEncoder {
        Self {}
    }
}
221
222impl Encoder<AdbRequest> for AdbRequestEncoder {
223 type Error = AdbError;
224
225 fn encode(&mut self, msg: AdbRequest, dst: &mut BytesMut) -> Result<()> {
226 let length = msg.payload.len();
229 if length > MAX_MESSAGE_SIZE {
230 return Err(AdbError::IOError(std::io::Error::new(
231 std::io::ErrorKind::InvalidData,
232 format!("Frame of length {} is too large.", length),
233 )));
234 }
235
236 dst.reserve(ADB_REQUEST_HEADER_LENGTH + length);
238
239 let length_hex = format!("{:04x}", length);
240
241 dst.extend_from_slice(&length_hex.as_bytes());
243 dst.extend_from_slice(&msg.payload);
244
245 println!("sending {}", pretty_hex::pretty_hex(&dst));
246
247 Ok(())
248 }
249}
250
251#[derive(Debug)]
253pub enum AdbClientStream {
254 Tcp(TcpStream),
256 Unix(UnixStream),
258}
259
260impl AsyncRead for AdbClientStream {
261 fn poll_read(
262 self: std::pin::Pin<&mut Self>,
263 cx: &mut std::task::Context<'_>,
264 buf: &mut ReadBuf<'_>,
265 ) -> Poll<std::io::Result<()>> {
266 match self.get_mut() {
267 AdbClientStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
268 AdbClientStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
269 }
270 }
271}
272
273impl AsyncWrite for AdbClientStream {
274 fn poll_write(
275 self: Pin<&mut Self>,
276 cx: &mut Context<'_>,
277 buf: &[u8],
278 ) -> Poll<std::result::Result<usize, std::io::Error>> {
279 match self.get_mut() {
280 AdbClientStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
281 AdbClientStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
282 }
283 }
284
285 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::result::Result<(), std::io::Error>> {
286 match self.get_mut() {
287 AdbClientStream::Tcp(s) => Pin::new(s).poll_flush(cx),
288 AdbClientStream::Unix(s) => Pin::new(s).poll_flush(cx),
289 }
290 }
291
292 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::result::Result<(), std::io::Error>> {
293 match self.get_mut() {
294 AdbClientStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
295 AdbClientStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
296 }
297 }
298}
299
300#[derive(Debug)]
302pub enum AdbClientStreamOwnedReadHalf {
303 Tcp(tokio::net::tcp::OwnedReadHalf),
305 Unix(tokio::net::unix::OwnedReadHalf),
307}
308
309#[derive(Debug)]
311pub enum AdbClientStreamOwnedWriteHalf {
312 Tcp(tokio::net::tcp::OwnedWriteHalf),
314 Unix(tokio::net::unix::OwnedWriteHalf),
316}
317
318impl AdbClientStream {
319
320 pub fn into_split(self) -> (AdbClientStreamOwnedReadHalf, AdbClientStreamOwnedWriteHalf) {
322 match self {
323 AdbClientStream::Tcp(s) => {
324 let (r, w) = s.into_split();
325 (AdbClientStreamOwnedReadHalf::Tcp(r), AdbClientStreamOwnedWriteHalf::Tcp(w))
326 }
327 AdbClientStream::Unix(s) => {
328 let (r, w) = s.into_split();
329 (AdbClientStreamOwnedReadHalf::Unix(r), AdbClientStreamOwnedWriteHalf::Unix(w))
330 }
331 }
332 }
333}
334
335impl AdbClientStreamOwnedReadHalf {
336 pub fn reunite(self, w: AdbClientStreamOwnedWriteHalf) -> Result<AdbClientStream> {
338 match self {
339 AdbClientStreamOwnedReadHalf::Tcp(r) => {
340 let w = match w {
341 AdbClientStreamOwnedWriteHalf::Tcp(w) => w,
342 _ => panic!("Invalid write half"),
343 };
344 Ok(AdbClientStream::Tcp(r.reunite(w).unwrap()))
345 }
346 AdbClientStreamOwnedReadHalf::Unix(r) => {
347 let w = match w {
348 AdbClientStreamOwnedWriteHalf::Unix(w) => w,
349 _ => panic!("Invalid write half"),
350 };
351 Ok(AdbClientStream::Unix(r.reunite(w).unwrap()))
352 }
353 }
354
355 }
356}
357
358impl AsyncRead for AdbClientStreamOwnedReadHalf {
359 fn poll_read(
360 self: std::pin::Pin<&mut Self>,
361 cx: &mut std::task::Context<'_>,
362 buf: &mut ReadBuf<'_>,
363 ) -> Poll<std::io::Result<()>> {
364 match self.get_mut() {
365 AdbClientStreamOwnedReadHalf::Tcp(s) => Pin::new(s).poll_read(cx, buf),
366 AdbClientStreamOwnedReadHalf::Unix(s) => Pin::new(s).poll_read(cx, buf),
367 }
368 }
369}
370
371impl AsyncWrite for AdbClientStreamOwnedWriteHalf {
372 #[inline]
373 fn poll_write(
374 self: Pin<&mut Self>,
375 cx: &mut Context<'_>,
376 buf: &[u8],
377 ) -> Poll<std::io::Result<usize>> {
378 match self.get_mut() {
379 AdbClientStreamOwnedWriteHalf::Unix(x) => Pin::new(x).poll_write(cx, buf),
380 AdbClientStreamOwnedWriteHalf::Tcp(x) => Pin::new(x).poll_write(cx, buf),
381 }
382 }
383
384 #[inline]
385 fn poll_write_vectored(
386 self: Pin<&mut Self>,
387 cx: &mut Context<'_>,
388 bufs: &[std::io::IoSlice<'_>],
389 ) -> Poll<std::io::Result<usize>> {
390 match self.get_mut() {
391 AdbClientStreamOwnedWriteHalf::Unix(x) => Pin::new(x).poll_write_vectored(cx, bufs),
392 AdbClientStreamOwnedWriteHalf::Tcp(x) => Pin::new(x).poll_write_vectored(cx, bufs),
393 }
394 }
395
396 #[inline]
397 fn is_write_vectored(&self) -> bool {
398 match self {
399 AdbClientStreamOwnedWriteHalf::Unix(x) => x.is_write_vectored(),
400 AdbClientStreamOwnedWriteHalf::Tcp(x) => x.is_write_vectored(),
401 }
402 }
403
404 #[inline]
405 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
406 match self.get_mut() {
407 AdbClientStreamOwnedWriteHalf::Unix(x) => Pin::new(x).poll_flush(cx),
408 AdbClientStreamOwnedWriteHalf::Tcp(x) => Pin::new(x).poll_flush(cx),
409 }
410 }
411
412 #[inline]
413 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
414 match self.get_mut() {
415 AdbClientStreamOwnedWriteHalf::Unix(x) => Pin::new(x).poll_shutdown(cx),
416 AdbClientStreamOwnedWriteHalf::Tcp(x) => Pin::new(x).poll_shutdown(cx),
417 }
418 }
419}
420
421#[derive(Debug)]
422pub(crate) struct AdbClientConnection
423{
424 pub(crate) reader: FramedRead<AdbClientStreamOwnedReadHalf, AdbResponseDecoder>,
425 pub(crate) writer: FramedWrite<AdbClientStreamOwnedWriteHalf, AdbRequestEncoder>,
426}
427
428impl<'a> AdbClientConnection
429{
430 pub(crate) fn new(socket: AdbClientStream) -> AdbClientConnection
431 {
432 let (r, w) = socket.into_split();
433
434 let reader = FramedRead::new(r, AdbResponseDecoder::new());
435 let writer = FramedWrite::new(w, AdbRequestEncoder::new());
436
437 return AdbClientConnection { reader, writer };
438 }
439
440 pub(crate) async fn send(&mut self, request: AdbRequest) -> Result<()> {
441 self.writer.send(request).await
442 }
443
444 pub(crate) async fn next(&mut self) -> Result<String> {
445 match self.reader.next().await {
446 Some(Ok(AdbResponse::OKAY { message })) => Ok(message),
447 Some(Ok(AdbResponse::FAIL { message })) => Err(AdbError::FailedResponseStatus(message)),
448 Some(Err(e)) => Err(e),
449 None => Err(AdbError::FailedResponseStatus("No response".into())),
450 }
451 }
452}