Skip to main content

droidrun_adb/
connection.rs

//! Low-level ADB wire protocol connection.
//!
//! ADB host protocol framing:
//!
//! ```text
//! Request:  [4-char hex length][payload]
//! Response: "OKAY" | "FAIL"[4-char hex length][error message]
//! ```
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::net::TcpStream;
8use tracing::trace;
9
10use crate::error::{AdbError, Result};
11
12/// A single TCP connection to the ADB server.
13pub struct AdbConnection {
14    stream: TcpStream,
15}
16
17impl AdbConnection {
18    /// Connect to the ADB server at the given address.
19    pub async fn connect(host: &str, port: u16) -> Result<Self> {
20        let stream = TcpStream::connect((host, port)).await.map_err(|e| {
21            if e.kind() == std::io::ErrorKind::ConnectionRefused {
22                AdbError::ConnectionRefused
23            } else {
24                AdbError::Io(e)
25            }
26        })?;
27        Ok(Self { stream })
28    }
29
30    /// Send a command to the ADB server using the wire protocol.
31    ///
32    /// Format: `{length:04X}{command}`
33    pub async fn send_command(&mut self, cmd: &str) -> Result<()> {
34        let msg = format!("{:04X}{}", cmd.len(), cmd);
35        trace!("ADB send: {msg}");
36        self.stream.write_all(msg.as_bytes()).await?;
37        Ok(())
38    }
39
40    /// Read the OKAY/FAIL status response.
41    pub async fn read_status(&mut self) -> Result<()> {
42        let mut buf = [0u8; 4];
43        self.stream.read_exact(&mut buf).await?;
44        match &buf {
45            b"OKAY" => Ok(()),
46            b"FAIL" => {
47                let msg = self.read_length_prefixed_string().await?;
48                Err(AdbError::ServerFailed(msg))
49            }
50            other => {
51                let s = String::from_utf8_lossy(other).to_string();
52                Err(AdbError::Protocol(format!("expected OKAY/FAIL, got: {s}")))
53            }
54        }
55    }
56
57    /// Send a command and expect OKAY.
58    pub async fn send_and_okay(&mut self, cmd: &str) -> Result<()> {
59        self.send_command(cmd).await?;
60        self.read_status()
61            .await
62            .map_err(|e| AdbError::ServerFailed(format!("command '{cmd}' failed: {e}")))
63    }
64
65    /// Read a length-prefixed string response.
66    ///
67    /// Format: `[4-char hex length][data]`
68    pub async fn read_length_prefixed_string(&mut self) -> Result<String> {
69        let mut len_buf = [0u8; 4];
70        self.stream.read_exact(&mut len_buf).await?;
71        let len_str = std::str::from_utf8(&len_buf)
72            .map_err(|_| AdbError::Protocol("invalid length bytes".into()))?;
73        let len = usize::from_str_radix(len_str, 16)
74            .map_err(|_| AdbError::Protocol(format!("invalid hex length: {len_str}")))?;
75
76        if len == 0 {
77            return Ok(String::new());
78        }
79
80        let mut buf = vec![0u8; len];
81        self.stream.read_exact(&mut buf).await?;
82        Ok(String::from_utf8(buf)?)
83    }
84
85    /// Read all remaining data as a String until the connection closes.
86    pub async fn read_until_close_string(&mut self) -> Result<String> {
87        let bytes = self.read_until_close_bytes().await?;
88        Ok(String::from_utf8(bytes)?)
89    }
90
91    /// Read all remaining data as bytes until the connection closes.
92    pub async fn read_until_close_bytes(&mut self) -> Result<Vec<u8>> {
93        let mut buf = Vec::with_capacity(4096);
94        self.stream.read_to_end(&mut buf).await?;
95        Ok(buf)
96    }
97
98    /// Expose the inner stream for advanced operations (e.g., sync protocol).
99    pub fn into_stream(self) -> TcpStream {
100        self.stream
101    }
102
103    /// Get a mutable reference to the inner stream.
104    pub fn stream_mut(&mut self) -> &mut TcpStream {
105        &mut self.stream
106    }
107}
108
#[cfg(test)]
mod tests {
    /// Builds an ADB request frame: a 4-digit uppercase hex byte length
    /// followed by the command text.
    fn frame(command: &str) -> String {
        let length = command.len();
        format!("{length:04X}{command}")
    }

    #[test]
    fn test_command_format() {
        // "host:version" is 12 bytes -> 0x000C.
        let expected = "000Chost:version";
        assert_eq!(frame("host:version"), expected);
    }

    #[test]
    fn test_short_command_format() {
        // "host:devices" is also 12 bytes -> 0x000C.
        let expected = "000Chost:devices";
        assert_eq!(frame("host:devices"), expected);
    }
}