adb_rs/
client.rs

1use bytes::buf::FromBuf;
2use bytes::{Bytes, BytesMut};
3use crossbeam_channel::{bounded, select, unbounded, Receiver, Sender};
4use std::collections::HashMap;
5use std::io::prelude::*;
6use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
7use std::sync::{Arc, RwLock};
8use std::thread::{self, JoinHandle};
9
10pub use crate::message::Command;
11use crate::message::{Connect, Header};
12use crate::result::*;
13
/// Entry point for opening ADB connections.
///
/// The only state is the identity string that `connect` announces to the
/// device in the `CNXN` handshake.
#[derive(Debug)]
pub struct AdbClient {
  /// Identity sent to the device as part of the `CNXN` packet.
  system_identity: String,
}
18
19impl AdbClient {
20  pub fn new(system_identity: &str) -> Self {
21    AdbClient {
22      system_identity: system_identity.to_string(),
23    }
24  }
25
26  pub fn connect<T>(self, addr: T) -> AdbResult<AdbConnection>
27  where
28    T: ToSocketAddrs,
29  {
30    let addrs: Vec<_> = addr.to_socket_addrs()?.collect();
31
32    debug!("connecting to {:?}...", addrs);
33
34    let mut stream = TcpStream::connect(&addrs as &[SocketAddr])?;
35
36    debug!("connected. sending CNXN...");
37
38    Connect::new(&self.system_identity).encode(&mut stream)?;
39
40    let resp = Header::decode(&mut stream)?;
41    let data = match resp.get_command() {
42      Some(Command::A_CNXN) => resp.decode_data(&mut stream)?,
43      Some(Command::A_AUTH) => {
44        return Err(AdbError::AuthNotSupported);
45      }
46      Some(cmd) => {
47        return Err(AdbError::UnexpectedCommand(cmd));
48      }
49      None => return Err(AdbError::UnknownCommand(resp.command)),
50    };
51
52    let device_id = String::from_utf8_lossy(&data);
53
54    debug!(
55      "handshake ok: device_id = {}, version = 0x{:x}, max_data = 0x{:x}",
56      device_id, resp.arg0, resp.arg1
57    );
58
59    let streams = Arc::new(RwLock::new(HashMap::<u32, StreamContext>::new()));
60
61    let (conn_reader_s, conn_reader_r) = bounded::<ConnectionPacket>(0);
62    let (conn_writer_s, conn_writer_r) = bounded::<ConnectionPacket>(0);
63    let (conn_error_s, conn_error_r) = unbounded();
64
65    let reader_worker = thread::spawn({
66      let mut stream = stream.try_clone()?;
67      let error_s = conn_error_s.clone();
68      move || loop {
69        let res = Header::decode(&mut stream)
70          .and_then(|header| {
71            let mut payload = BytesMut::new();
72            if header.data_length > 0 {
73              payload.resize(header.data_length as usize, 0);
74              stream
75                .read_exact(&mut payload)
76                .map(move |_| ConnectionPacket {
77                  header,
78                  payload: payload.freeze(),
79                })
80                .map_err(Into::into)
81            } else {
82              Ok(ConnectionPacket {
83                header,
84                payload: payload.freeze(),
85              })
86            }
87          })
88          .and_then(|packet| {
89            conn_reader_s
90              .send(packet)
91              .map_err(|_| AdbError::Disconnected)
92          });
93
94        if let Err(err) = res {
95          debug!("AdbConnection: reader_worker exited: {}", err);
96          error_s.send(err).ok();
97          break;
98        }
99      }
100    });
101
102    let writer_worker = thread::spawn({
103      let mut stream = stream.try_clone()?;
104      let streams = streams.clone();
105      let error_s = conn_error_s.clone();
106      move || {
107        let mut closed_local_ids = vec![];
108        let mut conn_dead = false;
109        loop {
110          let packet = conn_writer_r.recv();
111          match packet {
112            Ok(packet) => {
113              let local_id = packet.header.arg0;
114              let locked = streams.read().unwrap();
115              if let Some(ctx) = locked.get(&local_id) {
116                let write = packet
117                  .header
118                  .encode(&mut stream)
119                  .and_then(|_| stream.write_all(&packet.payload).map_err(Into::into));
120                match write {
121                  Ok(_) => {
122                    if let Err(_) = ctx.write_result_s.send(Ok(())) {
123                      closed_local_ids.push(local_id);
124                    }
125                  }
126                  Err(err) => {
127                    if let Err(_) = ctx.write_result_s.send(Err(AdbError::Disconnected)) {
128                      closed_local_ids.push(local_id);
129                    }
130                    conn_dead = true;
131                    error_s.send(err).ok();
132                  }
133                }
134              } else {
135                warn!(
136                  "write packet discarded: cmd = {}, local_id = {}",
137                  packet.header.command, packet.header.arg0
138                );
139              }
140            }
141            Err(_) => {
142              break;
143            }
144          }
145
146          if !closed_local_ids.is_empty() {
147            let mut locked = streams.write().unwrap();
148            for id in &closed_local_ids {
149              debug!("remove stream: local_id = {}", id);
150              locked.remove(&id);
151            }
152            closed_local_ids.clear();
153          }
154
155          if conn_dead {
156            break;
157          }
158        }
159      }
160    });
161
162    let dispatch_worker = thread::spawn({
163      let streams = streams.clone();
164      move || {
165        let mut closed_local_ids = vec![];
166        loop {
167          select! {
168            recv(conn_reader_r) -> packet => {
169              match packet {
170                Ok(packet) => {
171                  let local_id = packet.header.arg1;
172                  let locked = streams.read().unwrap();
173                  match locked.get(&packet.header.arg1) {
174                    Some(ctx) => {
175                      if packet.header.get_command().is_some() {
176                        if let Err(_) = ctx.stream_reader_s.send(packet) {
177                          closed_local_ids.push(local_id);
178                        }
179                      } else {
180                        error!(
181                          "read packet discarded: unknown_cmd = 0x{:x}, local_id = {}",
182                          packet.header.command,
183                          packet.header.arg1
184                        );
185                      }
186                    },
187                    None => {
188                      warn!("read packet discarded: cmd = 0x{:x}, local_id = {}",
189                        packet.header.command,
190                        packet.header.arg1
191                      );
192                    },
193                  }
194                },
195                Err(err) => {
196                  error!("recv conn_reader_r: {}", err);
197                  break
198                },
199              }
200            },
201            recv(conn_error_r) -> err => {
202              match err {
203                Ok(_) => {
204                  break;
205                },
206                Err(recv_err) => {
207                  error!("recv conn_error_r: {}", recv_err);
208                  break
209                },
210              }
211            },
212          }
213
214          if !closed_local_ids.is_empty() {
215            let mut locked = streams.write().unwrap();
216            for id in &closed_local_ids {
217              debug!("remove stream: local_id = {}", id);
218              locked.remove(&id);
219            }
220            closed_local_ids.clear();
221          }
222        }
223        debug!("dispatch worker exited.");
224      }
225    });
226
227    Ok(AdbConnection {
228      system_identity: self.system_identity,
229      device_system_identity: device_id.to_string(),
230      device_version: resp.arg0,
231      device_max_data: resp.arg1,
232      tcp_stream: stream,
233      local_id_counter: 0,
234      workers: vec![reader_worker, writer_worker, dispatch_worker],
235      streams,
236      conn_writer_s,
237    })
238  }
239}
240
241struct ConnectionPacket {
242  header: Header,
243  payload: Bytes,
244}
245
246#[derive(Debug)]
247pub struct AdbStreamPacket {
248  pub command: Command,
249  pub payload: Bytes,
250}
251
252impl AdbStreamPacket {
253  pub fn new_write<T: AsRef<[u8]>>(payload: T) -> Self {
254    let bytes = payload.as_ref();
255    AdbStreamPacket {
256      command: Command::A_WRTE,
257      payload: Bytes::from_buf(bytes),
258    }
259  }
260
261  pub fn check_command(&self, cmd: Command) -> AdbResult<()> {
262    if self.command != cmd {
263      Err(AdbError::UnexpectedCommand(self.command))
264    } else {
265      Ok(())
266    }
267  }
268}
269
270#[derive(Debug, Clone)]
271struct StreamContext {
272  local_id: u32,
273  remote_id: u32,
274  stream_reader_s: Sender<ConnectionPacket>,
275  write_result_s: Sender<AdbResult<()>>,
276}
277
278#[derive(Debug)]
279pub struct AdbConnection {
280  system_identity: String,
281  device_system_identity: String,
282  device_version: u32,
283  device_max_data: u32,
284  tcp_stream: TcpStream,
285  local_id_counter: u32,
286  workers: Vec<JoinHandle<()>>,
287  streams: Arc<RwLock<HashMap<u32, StreamContext>>>,
288  conn_writer_s: Sender<ConnectionPacket>,
289}
290
291impl Drop for AdbConnection {
292  fn drop(&mut self) {
293    use std::net::Shutdown;
294    self.tcp_stream.shutdown(Shutdown::Both).ok();
295    let (conn_writer_s, _) = bounded::<ConnectionPacket>(0);
296    ::std::mem::replace(&mut self.conn_writer_s, conn_writer_s);
297    for w in ::std::mem::replace(&mut self.workers, vec![]) {
298      w.join().ok();
299    }
300  }
301}
302
303impl AdbConnection {
304  pub fn max_data_len(&self) -> usize {
305    self.device_max_data as usize
306  }
307
308  pub fn open_stream(&mut self, destination: &str) -> AdbResult<AdbStream> {
309    use bytes::BufMut;
310
311    self.local_id_counter = self.local_id_counter + 1;
312    let local_id = self.local_id_counter;
313    debug!(
314      "opening stream: local_id = {}, destination = {}...",
315      local_id, destination
316    );
317
318    let (write_result_s, write_result_r) = bounded::<AdbResult<()>>(1);
319    let (stream_reader_s, stream_reader_r) = bounded::<ConnectionPacket>(1);
320
321    let ctx = StreamContext {
322      local_id,
323      remote_id: 0,
324      stream_reader_s,
325      write_result_s,
326    };
327
328    self.streams.write().unwrap().insert(local_id, ctx);
329    debug!("register stream: local_id = {}", local_id);
330
331    let mut dst_bytes = BytesMut::with_capacity(destination.as_bytes().len() + 1);
332    dst_bytes.extend(destination.as_bytes());
333    dst_bytes.put_u8(0);
334    let dst_bytes = dst_bytes.freeze();
335
336    let open_packet = ConnectionPacket {
337      header: Header::new(Command::A_OPEN)
338        .arg0(local_id)
339        .data(&dst_bytes)
340        .finalize(),
341      payload: dst_bytes,
342    };
343
344    self
345      .conn_writer_s
346      .send(open_packet)
347      .map_err(|_| AdbError::Disconnected)?;
348
349    let open_packet = stream_reader_r.recv().map_err(|_| AdbError::Disconnected)?;
350    if open_packet.header.command != Command::A_OKAY as u32 {
351      if let Some(cmd) = open_packet.header.get_command() {
352        return Err(AdbError::UnexpectedCommand(cmd));
353      } else {
354        return Err(AdbError::UnknownCommand(open_packet.header.command));
355      }
356    }
357
358    let local_id = open_packet.header.arg1;
359    let remote_id = open_packet.header.arg0;
360    debug!("stream opened: {} -> {}", local_id, remote_id);
361
362    Ok(AdbStream {
363      local_id,
364      remote_id,
365      stream_reader: stream_reader_r,
366      writer: self.conn_writer_s.clone(),
367      write_result_r,
368    })
369  }
370}
371
372#[derive(Debug)]
373pub struct AdbStream {
374  local_id: u32,
375  remote_id: u32,
376  stream_reader: Receiver<ConnectionPacket>,
377  writer: Sender<ConnectionPacket>,
378  write_result_r: Receiver<AdbResult<()>>,
379}
380
381impl AdbStream {
382  pub fn send(&self, packet: AdbStreamPacket) -> AdbResult<()> {
383    self
384      .writer
385      .send(ConnectionPacket {
386        header: Header::new(packet.command)
387          .arg0(self.local_id)
388          .arg1(self.remote_id)
389          .data(&packet.payload)
390          .finalize(),
391        payload: packet.payload,
392      })
393      .map_err(|_| AdbError::Disconnected)
394      .and_then(|_| {
395        self
396          .write_result_r
397          .recv()
398          .map_err(|_| AdbError::Disconnected)
399          .and_then(|res| res.map(|_| ()))
400      })
401  }
402
403  pub fn recv(&self) -> AdbResult<AdbStreamPacket> {
404    let packet = self
405      .stream_reader
406      .recv()
407      .map_err(|_| AdbError::Disconnected)?;
408
409    Ok(AdbStreamPacket {
410      command: packet
411        .header
412        .get_command()
413        .ok_or_else(|| AdbError::UnknownCommand(packet.header.command))?,
414      payload: packet.payload,
415    })
416  }
417
418  pub fn try_recv(&self) -> AdbResult<Option<AdbStreamPacket>> {
419    use crossbeam_channel::TryRecvError;
420    match self.stream_reader.try_recv() {
421      Ok(packet) => Ok(Some(AdbStreamPacket {
422        command: packet
423          .header
424          .get_command()
425          .ok_or_else(|| AdbError::UnknownCommand(packet.header.command))?,
426        payload: packet.payload,
427      })),
428      Err(TryRecvError::Empty) => Ok(None),
429      Err(TryRecvError::Disconnected) => Err(AdbError::Disconnected),
430    }
431  }
432
433  pub fn send_ok(&self) -> AdbResult<()> {
434    self.send(AdbStreamPacket {
435      command: Command::A_OKAY,
436      payload: Bytes::new(),
437    })
438  }
439
440  pub fn recv_command(&self, cmd: Command) -> AdbResult<AdbStreamPacket> {
441    let packet = self.recv()?;
442    if packet.command != cmd {
443      return Err(AdbError::UnexpectedCommand(packet.command));
444    }
445    Ok(packet)
446  }
447
448  pub fn send_close(&self) -> AdbResult<()> {
449    self.send(AdbStreamPacket {
450      command: Command::A_CLSE,
451      payload: Bytes::new(),
452    })
453  }
454}