jdwp_client/
client.rs

1use binrw::{BinRead, BinWrite};
2use std::collections::HashMap;
3use std::io;
4use std::io::Cursor;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
8use tokio::sync::{Mutex, oneshot};
9use tokio::time::timeout;
10
11use crate::{
12    AllClassesReply, Command, CommandPacketHeader, IdSizesReply, JdwpIdSizes, ReplyPacketHeader,
13    VersionReply, result,
14};
15
/// Asynchronous JDWP (Java Debug Wire Protocol) client over a transport `T`.
///
/// Requests are written through a shared, mutex-guarded writer. A background task (spawned
/// by `JdwpClient::new`) reads packets from the transport and hands each reply to the
/// request waiting on its packet id via a oneshot channel.
pub struct JdwpClient<T> {
    /// Write half of the transport; locked while a single request is being written.
    writer: Arc<Mutex<WriteHalf<T>>>,
    /// In-flight requests keyed by packet id. The reader task removes the matching entry
    /// and sends the reply through it.
    pending_requests: Arc<Mutex<HashMap<u32, oneshot::Sender<ReplyPacket>>>>,
    /// Last packet id handed out; incremented (wrapping) before each new request.
    packet_id: Arc<Mutex<u32>>,
    /// Handle to the spawned reader task.
    _reader_handle: tokio::task::JoinHandle<()>,
    /// ID sizes reported by the VM; `None` until fetched during construction.
    sizes: Option<JdwpIdSizes>,
    /// How long a request waits for its reply before failing with a timeout.
    timeout_duration: Duration,
}
24
/// A JDWP reply packet: the parsed header plus the raw payload bytes that follow it.
struct ReplyPacket {
    /// Parsed reply header (read from a buffer of `ReplyPacketHeader::get_length()` bytes).
    header: ReplyPacketHeader,
    /// Payload bytes following the header, not yet decoded into a reply type.
    data: Vec<u8>,
}
29
30impl<T> JdwpClient<T>
31where
32    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
33{
34    pub async fn new(mut stream: T) -> result::Result<Self> {
35        Self::do_handshake(&mut stream).await?;
36
37        let (reader, writer) = tokio::io::split(stream);
38
39        let pending_requests = Arc::new(Mutex::new(HashMap::new()));
40        let writer_arc = Arc::new(Mutex::new(writer));
41        let packet_id = Arc::new(Mutex::new(0));
42
43        // Spawn reader task
44        let pending_clone = pending_requests.clone();
45        let reader_handle = tokio::spawn(async move {
46            Self::reader_loop(reader, pending_clone).await;
47        });
48
49        let mut client = JdwpClient {
50            writer: writer_arc,
51            pending_requests,
52            packet_id,
53            _reader_handle: reader_handle,
54            sizes: None,
55            timeout_duration: Duration::from_secs(5),
56        };
57        client.get_id_sizes().await?;
58        Ok(client)
59    }
60
61    async fn reader_loop(
62        mut reader: ReadHalf<T>,
63        pending_requests: Arc<Mutex<HashMap<u32, oneshot::Sender<ReplyPacket>>>>,
64    ) {
65        loop {
66            // TODO: Handle command packets coming from the VM
67            match Self::read_reply_packet(&mut reader).await {
68                Ok(reply_packet) => {
69                    let mut pending = pending_requests.lock().await;
70                    if let Some(sender) = pending.remove(&reply_packet.header.id) {
71                        let _ = sender.send(reply_packet);
72                    }
73                }
74                Err(e) => {
75                    eprintln!("Reader task error: {:?}", e);
76                    // Notify all pending requests about the error
77                    let mut pending = pending_requests.lock().await;
78                    for (_, sender) in pending.drain() {
79                        let _ = sender.send(ReplyPacket {
80                            header: ReplyPacketHeader::default(),
81                            data: Vec::new(),
82                        });
83                    }
84                    break;
85                }
86            }
87        }
88    }
89
90    async fn read_reply_packet(reader: &mut ReadHalf<T>) -> result::Result<ReplyPacket> {
91        // Read header
92        let mut header_buffer = vec![0u8; ReplyPacketHeader::get_length()];
93        reader.read_exact(&mut header_buffer).await?;
94
95        let mut cursor = Cursor::new(&header_buffer);
96        let header =
97            ReplyPacketHeader::read_be(&mut cursor).map_err(|e| result::Error::ParsingError {
98                message: format!("Parsing error: {:?}", e),
99            })?;
100
101        // Read data
102        let data_length = header.length as usize - ReplyPacketHeader::get_length();
103        let mut data = vec![0u8; data_length];
104        reader.read_exact(&mut data).await?;
105
106        Ok(ReplyPacket { header, data })
107    }
108
109    async fn write_request(
110        writer: &mut WriteHalf<T>,
111        header: &CommandPacketHeader,
112        data: &[u8],
113    ) -> result::Result<()> {
114        // Write header
115        let mut header_buffer = Vec::with_capacity(CommandPacketHeader::get_length());
116        let mut cursor = Cursor::new(&mut header_buffer);
117        header
118            .write_be(&mut cursor)
119            .map_err(|e| result::Error::ParsingError {
120                message: format!("Serialization error: {:?}", e),
121            })?;
122
123        writer.write_all(&header_buffer).await?;
124        writer.write_all(data).await?;
125        writer.flush().await?;
126
127        Ok(())
128    }
129
130    async fn next_packet_id(&self) -> u32 {
131        let mut id = self.packet_id.lock().await;
132        *id = id.wrapping_add(1);
133        *id
134    }
135
136    async fn send_request_with_timeout(
137        &self,
138        command: Command,
139        data: Vec<u8>,
140        timeout_duration: Duration,
141    ) -> result::Result<ReplyPacket> {
142        let id = self.next_packet_id().await;
143        let (tx, rx) = oneshot::channel();
144
145        // Register pending request
146        {
147            let mut pending = self.pending_requests.lock().await;
148            pending.insert(id, tx);
149        }
150
151        // Create header
152        let header = CommandPacketHeader {
153            length: CommandPacketHeader::get_length() as u32 + data.len() as u32,
154            id,
155            flags: 0,
156            command,
157        };
158
159        // Send request
160        {
161            let mut writer = self.writer.lock().await;
162            Self::write_request(&mut *writer, &header, &data).await?;
163        }
164
165        // Wait for reply with timeout
166        match timeout(timeout_duration, rx).await {
167            Ok(Ok(reply)) => Ok(reply),
168            Ok(Err(_)) => Err(result::Error::IoError(io::Error::new(
169                io::ErrorKind::Other,
170                "Reply channel closed",
171            ))),
172            Err(_) => {
173                // Timeout - clean up pending request
174                let mut pending = self.pending_requests.lock().await;
175                pending.remove(&id);
176                Err(result::Error::IoError(io::Error::new(
177                    io::ErrorKind::TimedOut,
178                    "Request timed out",
179                )))
180            }
181        }
182    }
183
184    async fn send_bodyless<TReply: for<'a> BinRead<Args<'a> = ()>>(
185        &self,
186        cmd: Command,
187        timeout_duration: Duration,
188    ) -> result::Result<TReply>
189    where
190        for<'a> <TReply as BinRead>::Args<'a>: Default,
191    {
192        let reply_packet = self
193            .send_request_with_timeout(cmd, Vec::new(), timeout_duration)
194            .await?;
195
196        let mut cursor = Cursor::new(&reply_packet.data);
197        let reply = TReply::read_be(&mut cursor).map_err(|e| result::Error::ParsingError {
198            message: format!("Binary parsing error: {:?}", e),
199        })?;
200
201        Ok(reply)
202    }
203
204    async fn send_bodyless_variable<TReply: for<'a> BinRead<Args<'a> = JdwpIdSizes>>(
205        &self,
206        cmd: Command,
207        timeout_duration: Duration,
208    ) -> result::Result<TReply> {
209        let reply_packet = self
210            .send_request_with_timeout(cmd, Vec::new(), timeout_duration)
211            .await?;
212
213        let mut cursor = Cursor::new(&reply_packet.data);
214        let reply = TReply::read_be_args(
215            &mut cursor,
216            self.sizes.ok_or(result::Error::IdSizesUnknown)?,
217        )
218        .map_err(|e| result::Error::ParsingError {
219            message: format!("Binary parsing error: {:?}", e),
220        })?;
221
222        Ok(reply)
223    }
224
225    async fn do_handshake(stream: &mut T) -> result::Result<()> {
226        const HANDSHAKE_STR: &str = "JDWP-Handshake";
227
228        let handshake_bytes = HANDSHAKE_STR.as_bytes();
229        stream.write_all(handshake_bytes).await?;
230        stream.flush().await?;
231
232        let mut buffer = [0u8; 14];
233        stream.read_exact(&mut buffer).await?;
234
235        let received = std::str::from_utf8(&buffer)
236            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid UTF-8"))?;
237
238        if received != HANDSHAKE_STR {
239            return Err(result::Error::ParsingError {
240                message: format!(
241                    "Invalid handshake: expected '{}', got '{}'",
242                    HANDSHAKE_STR, received
243                ),
244            });
245        }
246
247        Ok(())
248    }
249    async fn get_id_sizes(&mut self) -> result::Result<()> {
250        let sizes = self.vm_get_id_sizes().await?;
251        let field_id: u8 = sizes
252            .field_id_size
253            .try_into()
254            .map_err(|_| result::Error::IdSizesTruncated)?;
255        let method_id: u8 = sizes
256            .method_id_size
257            .try_into()
258            .map_err(|_| result::Error::IdSizesTruncated)?;
259        let object_id: u8 = sizes
260            .object_id_size
261            .try_into()
262            .map_err(|_| result::Error::IdSizesTruncated)?;
263        let ref_id: u8 = sizes
264            .reference_type_id_size
265            .try_into()
266            .map_err(|_| result::Error::IdSizesTruncated)?;
267        let frame_id: u8 = sizes
268            .frame_id_size
269            .try_into()
270            .map_err(|_| result::Error::IdSizesTruncated)?;
271        self.sizes = Some(JdwpIdSizes {
272            field_id_size: field_id,
273            method_id_size: method_id,
274            object_id_size: object_id,
275            reference_type_id_size: ref_id,
276            frame_id_size: frame_id,
277        });
278        Ok(())
279    }
280
281    pub async fn vm_get_version(&self) -> result::Result<VersionReply> {
282        self.send_bodyless(Command::VirtualMachineVersion, self.timeout_duration)
283            .await
284    }
285
286    pub async fn vm_get_all_classes(&self) -> result::Result<AllClassesReply> {
287        self.send_bodyless_variable(Command::VirtualMachineAllClasses, self.timeout_duration)
288            .await
289    }
290
291    pub async fn vm_get_id_sizes(&self) -> result::Result<IdSizesReply> {
292        self.send_bodyless(Command::VirtualMachineIDSizes, self.timeout_duration)
293            .await
294    }
295}