//! Client for the Citra emulator's RPC interface.
//!
//! Connects over ZeroMQ to a Citra instance on the local machine and reads or
//! writes its emulated memory, one request/reply exchange at a time.
#![deny(missing_docs)]
6
7extern crate byteorder;
8extern crate rand;
9extern crate zmq;
10
11use std::io::Cursor;
12use std::io::Error as IoError;
13use std::io::Write;
14
15use zmq::Socket;
16
17use rand::prelude::*;
18
19use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
20
/// Protocol version written into every request header and required in every
/// reply header.
const CURRENT_REQUEST_VERSION: u32 = 1;
/// Largest number of data bytes carried by (or requested in) one request;
/// longer reads and writes are split into chunks of at most this size.
const MAX_REQUEST_DATA_SIZE: u32 = 32;

/// TCP port on 127.0.0.1 at which the Citra RPC server is contacted.
/// Typed `u16` because that is the full range of valid TCP ports.
const CITRA_PORT: u16 = 45987;
28
/// Kind of operation carried by a request to the Citra RPC server.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RequestType {
    /// Read a range of emulated memory.
    ReadMemory,
    /// Write a range of emulated memory.
    WriteMemory,
}

impl RequestType {
    /// Returns the numeric identifier written into a request header's type
    /// field, and required to come back unchanged in the reply header.
    fn get_id(self) -> u32 {
        match self {
            RequestType::ReadMemory => 1,
            RequestType::WriteMemory => 2,
        }
    }
}
47
48fn generate_header(request_type: RequestType, data_size: u32) -> ([u8; 4 * 4], u32) {
55 let mut buf = [0 as u8; 4 * 4];
56
57 let request_id = random::<u32>();
58
59 {
60 let request_type = request_type.get_id();
61
62 let mut cursor: Cursor<&mut [u8]> = Cursor::new(&mut buf);
63 cursor
64 .write_u32::<LittleEndian>(CURRENT_REQUEST_VERSION)
65 .expect("Failed to write request version");
66 cursor
67 .write_u32::<LittleEndian>(request_id)
68 .expect("Failed to write request ID");
69 cursor
70 .write_u32::<LittleEndian>(request_type)
71 .expect("Failed to write request type");
72 cursor
73 .write_u32::<LittleEndian>(data_size)
74 .expect("Failed to write request size");
75 }
76
77 (buf, request_id)
78}
79
80fn read_and_validate_header(
88 raw_reply: &[u8],
89 expected_id: u32,
90 expected_type: RequestType,
91) -> Result<&[u8], String> {
92 if raw_reply.len() < 4 * 4 {
93 return Err(format!(
94 "Payload is smaller than minimum (got {}, expected at least {})",
95 raw_reply.len(),
96 4 * 4
97 ));
98 }
99
100 let mut cursor = Cursor::new(raw_reply);
101
102 let expected_type = expected_type.get_id();
103
104 let reply_version = translate_io_error(cursor.read_u32::<LittleEndian>())?;
105 let reply_id = translate_io_error(cursor.read_u32::<LittleEndian>())?;
106 let reply_type = translate_io_error(cursor.read_u32::<LittleEndian>())?;
107 let reply_data_size = translate_io_error(cursor.read_u32::<LittleEndian>())?;
108
109 if reply_version != CURRENT_REQUEST_VERSION {
110 return Err(format!(
111 "Bad request version (got {}, expected {})",
112 reply_version, CURRENT_REQUEST_VERSION
113 ));
114 }
115
116 if reply_id != expected_id {
117 return Err(format!(
118 "Bad request ID (got {}, expected {})",
119 reply_id, expected_id
120 ));
121 }
122
123 if reply_type != expected_type {
124 return Err(format!(
125 "Bad request type (got {}, expected {})",
126 reply_type, expected_type
127 ));
128 }
129
130 if reply_data_size != (raw_reply.len() - 4 * 4) as u32 {
131 return Err(format!(
132 "Bad request size (got {}, expected {})",
133 reply_data_size,
134 raw_reply.len() - 4 * 4
135 ));
136 }
137
138 Ok(&raw_reply[4 * 4..])
139}
140
141fn translate_zmq_error<T>(payload: Result<T, zmq::Error>) -> Result<T, String> {
143 payload.map_err(|x| format!("ZeroMQ error: {:?}", x))
144}
145
/// Converts an I/O result into one with a `String` error. The message is
/// "I/O error: " followed by the error's `Debug` representation.
fn translate_io_error<T>(payload: Result<T, IoError>) -> Result<T, String> {
    match payload {
        Ok(value) => Ok(value),
        Err(err) => Err(format!("I/O error: {:?}", err)),
    }
}
150
151pub struct CitraConnection {
153 socket: Socket,
154}
155
156impl CitraConnection {
157 fn make_request(&self, request_kind: RequestType, data: &[u8]) -> Result<Vec<u8>, String> {
159 let (request, request_id) = generate_header(request_kind, data.len() as _);
160
161 let mut outgoing_buffer = Vec::with_capacity(request.len() + data.len());
162 outgoing_buffer.extend_from_slice(&request);
163 outgoing_buffer.extend_from_slice(data);
164
165 translate_zmq_error(self.socket.send(&outgoing_buffer, 0))?;
166
167 let req_reply = translate_zmq_error(self.socket.recv_bytes(0))?;
168
169 let data = read_and_validate_header(&req_reply, request_id, request_kind)?;
170
171 Ok(data.to_vec())
172 }
173
174 pub fn read_memory(
190 &self,
191 mut read_address: u32,
192 mut read_size: u32,
193 ) -> Result<Vec<u8>, String> {
194 let mut result = Vec::with_capacity(read_size as _);
195
196 while read_size > 0 {
197 let temp_read_size = if read_size > MAX_REQUEST_DATA_SIZE {
198 MAX_REQUEST_DATA_SIZE
199 } else {
200 read_size
201 };
202
203 let mut request_data = [0 as u8; 2 * 4];
204
205 {
206 let mut cursor: Cursor<&mut [u8]> = Cursor::new(&mut request_data);
207
208 cursor
209 .write_u32::<LittleEndian>(read_address)
210 .expect("Failed to write read address");
211 cursor
212 .write_u32::<LittleEndian>(temp_read_size)
213 .expect("Failed to write read size");
214 }
215
216 let data = self.make_request(RequestType::ReadMemory, &request_data)?;
217 result.extend_from_slice(&data);
218
219 read_size -= temp_read_size;
220 read_address += temp_read_size;
221 }
222
223 Ok(result)
224 }
225
226 pub fn write_memory(&self, mut write_address: u32, mut data: &[u8]) -> Result<(), String> {
242 while !data.is_empty() {
243 let temp_write_size = if data.len() as u32 > MAX_REQUEST_DATA_SIZE {
244 MAX_REQUEST_DATA_SIZE
245 } else {
246 data.len() as u32
247 };
248
249 let mut request_data = Vec::with_capacity(2 * 4 + temp_write_size as usize);
250
251 {
252 let mut cursor = Cursor::new(&mut request_data);
253
254 cursor
255 .write_u32::<LittleEndian>(write_address)
256 .expect("Failed to write write address");
257 cursor
258 .write_u32::<LittleEndian>(temp_write_size)
259 .expect("Failed to write write size");
260 cursor
261 .write_all(&data[0..temp_write_size as usize])
262 .expect("Failed to write write data");
263 }
264
265 let incoming_data = self.make_request(RequestType::WriteMemory, &request_data)?;
266
267 if !incoming_data.is_empty() {
268 return Err(format!(
269 "Unexpected response payload of {} bytes",
270 incoming_data.len()
271 ));
272 }
273
274 data = &data[temp_write_size as usize..];
275 write_address += temp_write_size;
276 }
277
278 Ok(())
279 }
280
281 pub fn connect() -> Result<Self, String> {
283 let ctx = zmq::Context::new();
284
285 let socket = translate_zmq_error(ctx.socket(zmq::REQ))?;
286
287 translate_zmq_error(socket.connect(&format!("tcp://127.0.0.1:{}", CITRA_PORT)))?;
288
289 Ok(CitraConnection { socket })
290 }
291}
292
/// Integration tests against a live Citra instance.
///
/// These tests need Citra running with its RPC server reachable through
/// `CitraConnection::connect`. Without a server, a request can block
/// indefinitely because no receive timeout is configured. The tests are
/// therefore ignored by default; run them with `cargo test -- --ignored`.
#[cfg(test)]
mod tests {
    use super::CitraConnection;

    /// Reads four bytes at 0x100000 and checks that exactly four came back.
    #[test]
    #[ignore]
    fn read_memory() {
        let connection = CitraConnection::connect().expect("Got error while connecting");

        let memory = connection
            .read_memory(0x100000, 4)
            .expect("Failed to read memory");

        assert_eq!(memory.len(), 4);
    }

    /// Writes four 0xFF bytes at 0x100000 and reads them back.
    ///
    /// NOTE: this overwrites the emulated program's memory at that address.
    #[test]
    #[ignore]
    fn overwrite_memory() {
        let connection = CitraConnection::connect().expect("Got error while connecting");

        let memory_slice = [0xffu8; 4];
        let address = 0x0010_0000;

        connection
            .write_memory(address, &memory_slice)
            .expect("Failed to write memory");

        let memory = connection
            .read_memory(address, memory_slice.len() as _)
            .expect("Failed to read memory");

        assert_eq!(&memory_slice, memory.as_slice());
    }
}