1use binrw::{BinRead, BinWrite};
2use std::collections::HashMap;
3use std::io;
4use std::io::Cursor;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
8use tokio::sync::{Mutex, oneshot};
9use tokio::time::timeout;
10
11use crate::{
12 AllClassesReply, Command, CommandPacketHeader, IdSizesReply, JdwpIdSizes, ReplyPacketHeader,
13 VersionReply, result,
14};
15
16pub struct JdwpClient<T> {
17 writer: Arc<Mutex<WriteHalf<T>>>,
18 pending_requests: Arc<Mutex<HashMap<u32, oneshot::Sender<ReplyPacket>>>>,
19 packet_id: Arc<Mutex<u32>>,
20 _reader_handle: tokio::task::JoinHandle<()>,
21 sizes: Option<JdwpIdSizes>,
22 timeout_duration: Duration,
23}
24
25struct ReplyPacket {
26 header: ReplyPacketHeader,
27 data: Vec<u8>,
28}
29
30impl<T> JdwpClient<T>
31where
32 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
33{
34 pub async fn new(mut stream: T) -> result::Result<Self> {
35 Self::do_handshake(&mut stream).await?;
36
37 let (reader, writer) = tokio::io::split(stream);
38
39 let pending_requests = Arc::new(Mutex::new(HashMap::new()));
40 let writer_arc = Arc::new(Mutex::new(writer));
41 let packet_id = Arc::new(Mutex::new(0));
42
43 let pending_clone = pending_requests.clone();
45 let reader_handle = tokio::spawn(async move {
46 Self::reader_loop(reader, pending_clone).await;
47 });
48
49 let mut client = JdwpClient {
50 writer: writer_arc,
51 pending_requests,
52 packet_id,
53 _reader_handle: reader_handle,
54 sizes: None,
55 timeout_duration: Duration::from_secs(5),
56 };
57 client.get_id_sizes().await?;
58 Ok(client)
59 }
60
61 async fn reader_loop(
62 mut reader: ReadHalf<T>,
63 pending_requests: Arc<Mutex<HashMap<u32, oneshot::Sender<ReplyPacket>>>>,
64 ) {
65 loop {
66 match Self::read_reply_packet(&mut reader).await {
68 Ok(reply_packet) => {
69 let mut pending = pending_requests.lock().await;
70 if let Some(sender) = pending.remove(&reply_packet.header.id) {
71 let _ = sender.send(reply_packet);
72 }
73 }
74 Err(e) => {
75 eprintln!("Reader task error: {:?}", e);
76 let mut pending = pending_requests.lock().await;
78 for (_, sender) in pending.drain() {
79 let _ = sender.send(ReplyPacket {
80 header: ReplyPacketHeader::default(),
81 data: Vec::new(),
82 });
83 }
84 break;
85 }
86 }
87 }
88 }
89
90 async fn read_reply_packet(reader: &mut ReadHalf<T>) -> result::Result<ReplyPacket> {
91 let mut header_buffer = vec![0u8; ReplyPacketHeader::get_length()];
93 reader.read_exact(&mut header_buffer).await?;
94
95 let mut cursor = Cursor::new(&header_buffer);
96 let header =
97 ReplyPacketHeader::read_be(&mut cursor).map_err(|e| result::Error::ParsingError {
98 message: format!("Parsing error: {:?}", e),
99 })?;
100
101 let data_length = header.length as usize - ReplyPacketHeader::get_length();
103 let mut data = vec![0u8; data_length];
104 reader.read_exact(&mut data).await?;
105
106 Ok(ReplyPacket { header, data })
107 }
108
109 async fn write_request(
110 writer: &mut WriteHalf<T>,
111 header: &CommandPacketHeader,
112 data: &[u8],
113 ) -> result::Result<()> {
114 let mut header_buffer = Vec::with_capacity(CommandPacketHeader::get_length());
116 let mut cursor = Cursor::new(&mut header_buffer);
117 header
118 .write_be(&mut cursor)
119 .map_err(|e| result::Error::ParsingError {
120 message: format!("Serialization error: {:?}", e),
121 })?;
122
123 writer.write_all(&header_buffer).await?;
124 writer.write_all(data).await?;
125 writer.flush().await?;
126
127 Ok(())
128 }
129
130 async fn next_packet_id(&self) -> u32 {
131 let mut id = self.packet_id.lock().await;
132 *id = id.wrapping_add(1);
133 *id
134 }
135
136 async fn send_request_with_timeout(
137 &self,
138 command: Command,
139 data: Vec<u8>,
140 timeout_duration: Duration,
141 ) -> result::Result<ReplyPacket> {
142 let id = self.next_packet_id().await;
143 let (tx, rx) = oneshot::channel();
144
145 {
147 let mut pending = self.pending_requests.lock().await;
148 pending.insert(id, tx);
149 }
150
151 let header = CommandPacketHeader {
153 length: CommandPacketHeader::get_length() as u32 + data.len() as u32,
154 id,
155 flags: 0,
156 command,
157 };
158
159 {
161 let mut writer = self.writer.lock().await;
162 Self::write_request(&mut *writer, &header, &data).await?;
163 }
164
165 match timeout(timeout_duration, rx).await {
167 Ok(Ok(reply)) => Ok(reply),
168 Ok(Err(_)) => Err(result::Error::IoError(io::Error::new(
169 io::ErrorKind::Other,
170 "Reply channel closed",
171 ))),
172 Err(_) => {
173 let mut pending = self.pending_requests.lock().await;
175 pending.remove(&id);
176 Err(result::Error::IoError(io::Error::new(
177 io::ErrorKind::TimedOut,
178 "Request timed out",
179 )))
180 }
181 }
182 }
183
184 async fn send_bodyless<TReply: for<'a> BinRead<Args<'a> = ()>>(
185 &self,
186 cmd: Command,
187 timeout_duration: Duration,
188 ) -> result::Result<TReply>
189 where
190 for<'a> <TReply as BinRead>::Args<'a>: Default,
191 {
192 let reply_packet = self
193 .send_request_with_timeout(cmd, Vec::new(), timeout_duration)
194 .await?;
195
196 let mut cursor = Cursor::new(&reply_packet.data);
197 let reply = TReply::read_be(&mut cursor).map_err(|e| result::Error::ParsingError {
198 message: format!("Binary parsing error: {:?}", e),
199 })?;
200
201 Ok(reply)
202 }
203
204 async fn send_bodyless_variable<TReply: for<'a> BinRead<Args<'a> = JdwpIdSizes>>(
205 &self,
206 cmd: Command,
207 timeout_duration: Duration,
208 ) -> result::Result<TReply> {
209 let reply_packet = self
210 .send_request_with_timeout(cmd, Vec::new(), timeout_duration)
211 .await?;
212
213 let mut cursor = Cursor::new(&reply_packet.data);
214 let reply = TReply::read_be_args(
215 &mut cursor,
216 self.sizes.ok_or(result::Error::IdSizesUnknown)?,
217 )
218 .map_err(|e| result::Error::ParsingError {
219 message: format!("Binary parsing error: {:?}", e),
220 })?;
221
222 Ok(reply)
223 }
224
225 async fn do_handshake(stream: &mut T) -> result::Result<()> {
226 const HANDSHAKE_STR: &str = "JDWP-Handshake";
227
228 let handshake_bytes = HANDSHAKE_STR.as_bytes();
229 stream.write_all(handshake_bytes).await?;
230 stream.flush().await?;
231
232 let mut buffer = [0u8; 14];
233 stream.read_exact(&mut buffer).await?;
234
235 let received = std::str::from_utf8(&buffer)
236 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid UTF-8"))?;
237
238 if received != HANDSHAKE_STR {
239 return Err(result::Error::ParsingError {
240 message: format!(
241 "Invalid handshake: expected '{}', got '{}'",
242 HANDSHAKE_STR, received
243 ),
244 });
245 }
246
247 Ok(())
248 }
249 async fn get_id_sizes(&mut self) -> result::Result<()> {
250 let sizes = self.vm_get_id_sizes().await?;
251 let field_id: u8 = sizes
252 .field_id_size
253 .try_into()
254 .map_err(|_| result::Error::IdSizesTruncated)?;
255 let method_id: u8 = sizes
256 .method_id_size
257 .try_into()
258 .map_err(|_| result::Error::IdSizesTruncated)?;
259 let object_id: u8 = sizes
260 .object_id_size
261 .try_into()
262 .map_err(|_| result::Error::IdSizesTruncated)?;
263 let ref_id: u8 = sizes
264 .reference_type_id_size
265 .try_into()
266 .map_err(|_| result::Error::IdSizesTruncated)?;
267 let frame_id: u8 = sizes
268 .frame_id_size
269 .try_into()
270 .map_err(|_| result::Error::IdSizesTruncated)?;
271 self.sizes = Some(JdwpIdSizes {
272 field_id_size: field_id,
273 method_id_size: method_id,
274 object_id_size: object_id,
275 reference_type_id_size: ref_id,
276 frame_id_size: frame_id,
277 });
278 Ok(())
279 }
280
281 pub async fn vm_get_version(&self) -> result::Result<VersionReply> {
282 self.send_bodyless(Command::VirtualMachineVersion, self.timeout_duration)
283 .await
284 }
285
286 pub async fn vm_get_all_classes(&self) -> result::Result<AllClassesReply> {
287 self.send_bodyless_variable(Command::VirtualMachineAllClasses, self.timeout_duration)
288 .await
289 }
290
291 pub async fn vm_get_id_sizes(&self) -> result::Result<IdSizesReply> {
292 self.send_bodyless(Command::VirtualMachineIDSizes, self.timeout_duration)
293 .await
294 }
295}