citra_scripting/
lib.rs

1//! Basic implementation of the Citra scripting interface for Rust.
2//!
3//! Based on the Python implementation here:
4//! <https://github.com/citra-emu/citra/commit/04dd91be822aa2358e2160370f6082ab81ec4a2b>
5#![deny(missing_docs)]
6
7extern crate byteorder;
8extern crate rand;
9extern crate zmq;
10
11use std::io::Cursor;
12use std::io::Error as IoError;
13use std::io::Write;
14
15use zmq::Socket;
16
17use rand::prelude::*;
18
19use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
20
/// Version of the scripting protocol spoken by this client; replies must carry the same value.
const CURRENT_REQUEST_VERSION: u32 = 1;
/// Upper bound, in bytes, on the payload (everything after the header) of a single request.
const MAX_REQUEST_DATA_SIZE: u32 = 32;

/// TCP port on which Citra's scripting server listens.
const CITRA_PORT: u32 = 45_987;
28
/// Different request types that can be sent.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum RequestType {
    /// A request to read from a memory region.
    ReadMemory,
    /// A request to write to a memory region.
    WriteMemory,
}

impl RequestType {
    /// Returns the numeric ID that identifies this request type in the
    /// request/reply headers exchanged with Citra.
    fn get_id(self) -> u32 {
        match self {
            RequestType::ReadMemory => 1,
            RequestType::WriteMemory => 2,
        }
    }
}
47
48/// Generates the outgoing header to be sent to Citra.
49///
50/// # Params
51///
52/// *request_type*: The kind of request to generate a header for.
53/// *data_size*: the amount of payload (not header) data to be sent.
54fn generate_header(request_type: RequestType, data_size: u32) -> ([u8; 4 * 4], u32) {
55    let mut buf = [0 as u8; 4 * 4];
56
57    let request_id = random::<u32>();
58
59    {
60        let request_type = request_type.get_id();
61
62        let mut cursor: Cursor<&mut [u8]> = Cursor::new(&mut buf);
63        cursor
64            .write_u32::<LittleEndian>(CURRENT_REQUEST_VERSION)
65            .expect("Failed to write request version");
66        cursor
67            .write_u32::<LittleEndian>(request_id)
68            .expect("Failed to write request ID");
69        cursor
70            .write_u32::<LittleEndian>(request_type)
71            .expect("Failed to write request type");
72        cursor
73            .write_u32::<LittleEndian>(data_size)
74            .expect("Failed to write request size");
75    }
76
77    (buf, request_id)
78}
79
80/// Generates the outgoing header to be sent to Citra.
81///
82/// # Params
83///
84/// *raw_reply*: Data just received from a socket.
85/// *expected_id*: The request ID for which this payload should satisfy.
86/// *expected_type*: The expected type of this incoming payload.
87fn read_and_validate_header(
88    raw_reply: &[u8],
89    expected_id: u32,
90    expected_type: RequestType,
91) -> Result<&[u8], String> {
92    if raw_reply.len() < 4 * 4 {
93        return Err(format!(
94            "Payload is smaller than minimum (got {}, expected at least {})",
95            raw_reply.len(),
96            4 * 4
97        ));
98    }
99
100    let mut cursor = Cursor::new(raw_reply);
101
102    let expected_type = expected_type.get_id();
103
104    let reply_version = translate_io_error(cursor.read_u32::<LittleEndian>())?;
105    let reply_id = translate_io_error(cursor.read_u32::<LittleEndian>())?;
106    let reply_type = translate_io_error(cursor.read_u32::<LittleEndian>())?;
107    let reply_data_size = translate_io_error(cursor.read_u32::<LittleEndian>())?;
108
109    if reply_version != CURRENT_REQUEST_VERSION {
110        return Err(format!(
111            "Bad request version (got {}, expected {})",
112            reply_version, CURRENT_REQUEST_VERSION
113        ));
114    }
115
116    if reply_id != expected_id {
117        return Err(format!(
118            "Bad request ID (got {}, expected {})",
119            reply_id, expected_id
120        ));
121    }
122
123    if reply_type != expected_type {
124        return Err(format!(
125            "Bad request type (got {}, expected {})",
126            reply_type, expected_type
127        ));
128    }
129
130    if reply_data_size != (raw_reply.len() - 4 * 4) as u32 {
131        return Err(format!(
132            "Bad request size (got {}, expected {})",
133            reply_data_size,
134            raw_reply.len() - 4 * 4
135        ));
136    }
137
138    Ok(&raw_reply[4 * 4..])
139}
140
141/// Translates a ZMQ error to a generic String one.
142fn translate_zmq_error<T>(payload: Result<T, zmq::Error>) -> Result<T, String> {
143    payload.map_err(|x| format!("ZeroMQ error: {:?}", x))
144}
145
/// Converts an I/O result into one whose error is a human-readable `String`,
/// passing successful values through untouched.
fn translate_io_error<T>(payload: Result<T, IoError>) -> Result<T, String> {
    match payload {
        Ok(value) => Ok(value),
        Err(err) => Err(format!("I/O error: {:?}", err)),
    }
}
150
151/// The main interface to Citra. Adds a level of abstraction on the ZMQ socket.
152pub struct CitraConnection {
153    socket: Socket,
154}
155
156impl CitraConnection {
157    /// Makes a request to Citra, returning the response (if any).
158    fn make_request(&self, request_kind: RequestType, data: &[u8]) -> Result<Vec<u8>, String> {
159        let (request, request_id) = generate_header(request_kind, data.len() as _);
160
161        let mut outgoing_buffer = Vec::with_capacity(request.len() + data.len());
162        outgoing_buffer.extend_from_slice(&request);
163        outgoing_buffer.extend_from_slice(data);
164
165        translate_zmq_error(self.socket.send(&outgoing_buffer, 0))?;
166
167        let req_reply = translate_zmq_error(self.socket.recv_bytes(0))?;
168
169        let data = read_and_validate_header(&req_reply, request_id, request_kind)?;
170
171        Ok(data.to_vec())
172    }
173
174    /// Reads a region of memory.
175    ///
176    /// # Params
177    ///
178    /// *read_address*: The remote memory pointer to read from.
179    /// *read_size*: The amount of data to read, in bytes.
180    ///
181    /// # Example
182    ///
183    /// ```rust
184    /// use citra_scripting::CitraConnection;
185    ///
186    /// let connection = CitraConnection::connect().unwrap();
187    /// connection.read_memory(0x100000, 4);
188    /// ```
189    pub fn read_memory(
190        &self,
191        mut read_address: u32,
192        mut read_size: u32,
193    ) -> Result<Vec<u8>, String> {
194        let mut result = Vec::with_capacity(read_size as _);
195
196        while read_size > 0 {
197            let temp_read_size = if read_size > MAX_REQUEST_DATA_SIZE {
198                MAX_REQUEST_DATA_SIZE
199            } else {
200                read_size
201            };
202
203            let mut request_data = [0 as u8; 2 * 4];
204
205            {
206                let mut cursor: Cursor<&mut [u8]> = Cursor::new(&mut request_data);
207
208                cursor
209                    .write_u32::<LittleEndian>(read_address)
210                    .expect("Failed to write read address");
211                cursor
212                    .write_u32::<LittleEndian>(temp_read_size)
213                    .expect("Failed to write read size");
214            }
215
216            let data = self.make_request(RequestType::ReadMemory, &request_data)?;
217            result.extend_from_slice(&data);
218
219            read_size -= temp_read_size;
220            read_address += temp_read_size;
221        }
222
223        Ok(result)
224    }
225
226    /// Reads a region of memory.
227    ///
228    /// # Params
229    ///
230    /// *write_address*: The remote memory pointer to write to.
231    /// *data*: The data to write.
232    ///
233    /// # Example
234    ///
235    /// ```rust
236    /// use citra_scripting::CitraConnection;
237    ///
238    /// let connection = CitraConnection::connect().unwrap();
239    /// connection.write_memory(0x100000, &[0xff as u8; 4]);
240    /// ```
241    pub fn write_memory(&self, mut write_address: u32, mut data: &[u8]) -> Result<(), String> {
242        while !data.is_empty() {
243            let temp_write_size = if data.len() as u32 > MAX_REQUEST_DATA_SIZE {
244                MAX_REQUEST_DATA_SIZE
245            } else {
246                data.len() as u32
247            };
248
249            let mut request_data = Vec::with_capacity(2 * 4 + temp_write_size as usize);
250
251            {
252                let mut cursor = Cursor::new(&mut request_data);
253
254                cursor
255                    .write_u32::<LittleEndian>(write_address)
256                    .expect("Failed to write write address");
257                cursor
258                    .write_u32::<LittleEndian>(temp_write_size)
259                    .expect("Failed to write write size");
260                cursor
261                    .write_all(&data[0..temp_write_size as usize])
262                    .expect("Failed to write write data");
263            }
264
265            let incoming_data = self.make_request(RequestType::WriteMemory, &request_data)?;
266
267            if !incoming_data.is_empty() {
268                return Err(format!(
269                    "Unexpected response payload of {} bytes",
270                    incoming_data.len()
271                ));
272            }
273
274            data = &data[temp_write_size as usize..];
275            write_address += temp_write_size;
276        }
277
278        Ok(())
279    }
280
281    /// Connects to the current Citra client, assuming defaults.
282    pub fn connect() -> Result<Self, String> {
283        let ctx = zmq::Context::new();
284
285        let socket = translate_zmq_error(ctx.socket(zmq::REQ))?;
286
287        translate_zmq_error(socket.connect(&format!("tcp://127.0.0.1:{}", CITRA_PORT)))?;
288
289        Ok(CitraConnection { socket })
290    }
291}
292
/// Integration tests. They need an active Citra instance to talk to (without
/// one, requests block waiting for a reply), so they are ignored by default.
/// Run them with `cargo test -- --ignored`.
#[cfg(test)]
mod tests {
    use super::CitraConnection;

    /// Address exercised by the tests; anything overwritten is restored.
    const TEST_ADDRESS: u32 = 0x0010_0000;

    #[test]
    #[ignore]
    fn read_memory() {
        let connection = CitraConnection::connect().expect("Got error while connecting");

        let memory = connection
            .read_memory(TEST_ADDRESS, 4)
            .expect("Failed to read memory");

        assert_eq!(memory.len(), 4);
    }

    #[test]
    #[ignore]
    fn read_memory_multiple_chunks() {
        let connection = CitraConnection::connect().expect("Got error while connecting");

        // Larger than a single 32-byte read request.
        let memory = connection
            .read_memory(TEST_ADDRESS, 100)
            .expect("Failed to read memory");

        assert_eq!(memory.len(), 100);
    }

    #[test]
    #[ignore]
    fn overwrite_memory() {
        let connection = CitraConnection::connect().expect("Got error while connecting");

        // Larger than a single 24-byte write request, to exercise chunking.
        let memory_slice = [0xff as u8; 32];

        let original = connection
            .read_memory(TEST_ADDRESS, memory_slice.len() as _)
            .expect("Failed to read original memory");

        connection
            .write_memory(TEST_ADDRESS, &memory_slice)
            .expect("Failed to write memory");

        let memory = connection
            .read_memory(TEST_ADDRESS, memory_slice.len() as _)
            .expect("Failed to read memory");

        // Restore before asserting so a failure does not leave memory clobbered.
        connection
            .write_memory(TEST_ADDRESS, &original)
            .expect("Failed to restore memory");

        assert_eq!(&memory_slice[..], memory.as_slice());
    }
}
324        assert_eq!(&memory_slice, memory.as_slice());
325    }
326}