citra-scripting 0.1.1

A Rust client for the Citra emulator's scripting interface.
Documentation
//! Basic implementation of the Citra scripting interface for Rust.
//!
//! Based on the Python implementation here:
//! <https://github.com/citra-emu/citra/commit/04dd91be822aa2358e2160370f6082ab81ec4a2b>
#![deny(missing_docs)]

extern crate byteorder;
extern crate rand;
extern crate zmq;

use std::io::Cursor;
use std::io::Error as IoError;
use std::io::Write;

use zmq::Socket;

use rand::prelude::*;

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};

/// Version of the request protocol spoken by this crate; replies must carry the same value.
const CURRENT_REQUEST_VERSION: u32 = 1;
/// Maximum number of memory bytes read or written by a single request; larger transfers are
/// split into chunks of at most this size.
const MAX_REQUEST_DATA_SIZE: u32 = 32;

/// Local TCP port on which Citra's scripting server is contacted.
const CITRA_PORT: u32 = 45_987;

/// Different request types that can be sent.
///
/// Derives `Debug`/`PartialEq`/`Eq`/`Hash` so callers can print and compare request kinds.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum RequestType {
    /// A request to read from a memory region.
    ReadMemory,
    /// A request to write to a memory region.
    WriteMemory,
}

impl RequestType {
    /// Returns the numeric protocol ID written into request headers for this request type.
    fn get_id(self) -> u32 {
        match self {
            RequestType::ReadMemory => 1,
            RequestType::WriteMemory => 2,
        }
    }
}

/// Builds the 16-byte header that precedes every request sent to Citra.
///
/// The header is four little-endian `u32` fields: protocol version, request ID, request type
/// and payload size.
///
/// # Params
///
/// *request_type*: The kind of request the header describes.
/// *data_size*: The number of payload (not header) bytes that will follow the header.
///
/// Returns the encoded header together with the randomly chosen request ID, so the caller can
/// match the reply against it.
fn generate_header(request_type: RequestType, data_size: u32) -> ([u8; 4 * 4], u32) {
    let request_id = random::<u32>();

    // Field values in wire order, each paired with the message used if its write fails.
    let fields = [
        (CURRENT_REQUEST_VERSION, "Failed to write request version"),
        (request_id, "Failed to write request ID"),
        (request_type.get_id(), "Failed to write request type"),
        (data_size, "Failed to write request size"),
    ];

    let mut header = [0 as u8; 4 * 4];
    {
        // The buffer holds exactly four u32s, so none of these writes can run out of space.
        let mut writer: Cursor<&mut [u8]> = Cursor::new(&mut header);
        for &(value, context) in fields.iter() {
            writer.write_u32::<LittleEndian>(value).expect(context);
        }
    }

    (header, request_id)
}

/// Parses the header of a reply from Citra and checks that it answers the expected request.
///
/// # Params
///
/// *raw_reply*: Data just received from a socket.
/// *expected_id*: The request ID that this reply should answer.
/// *expected_type*: The request type that this reply should answer.
///
/// # Errors
///
/// Returns an `Err` if the reply is shorter than a header, or if its version, request ID,
/// request type, or declared payload size do not match what was expected.
///
/// On success, returns the payload that follows the 16-byte header.
fn read_and_validate_header(
    raw_reply: &[u8],
    expected_id: u32,
    expected_type: RequestType,
) -> Result<&[u8], String> {
    // Version, request ID, request type and payload size, each a little-endian u32.
    const HEADER_SIZE: usize = 4 * 4;

    if raw_reply.len() < HEADER_SIZE {
        return Err(format!(
            "Payload is smaller than minimum (got {}, expected at least {})",
            raw_reply.len(),
            HEADER_SIZE
        ));
    }

    let (header, payload) = raw_reply.split_at(HEADER_SIZE);
    let mut cursor = Cursor::new(header);

    let expected_type = expected_type.get_id();

    let reply_version = translate_io_error(cursor.read_u32::<LittleEndian>())?;
    let reply_id = translate_io_error(cursor.read_u32::<LittleEndian>())?;
    let reply_type = translate_io_error(cursor.read_u32::<LittleEndian>())?;
    let reply_data_size = translate_io_error(cursor.read_u32::<LittleEndian>())?;

    if reply_version != CURRENT_REQUEST_VERSION {
        return Err(format!(
            "Bad request version (got {}, expected {})",
            reply_version, CURRENT_REQUEST_VERSION
        ));
    }

    if reply_id != expected_id {
        return Err(format!(
            "Bad request ID (got {}, expected {})",
            reply_id, expected_id
        ));
    }

    if reply_type != expected_type {
        return Err(format!(
            "Bad request type (got {}, expected {})",
            reply_type, expected_type
        ));
    }

    // Compare in usize: casting the payload length down to u32 could truncate an oversized
    // reply into a false match.
    if reply_data_size as usize != payload.len() {
        return Err(format!(
            "Bad request size (got {}, expected {})",
            reply_data_size,
            payload.len()
        ));
    }

    Ok(payload)
}

/// Translates a ZMQ error to a generic String one.
fn translate_zmq_error<T>(payload: Result<T, zmq::Error>) -> Result<T, String> {
    // Successes pass through untouched; failures become a debug-formatted message.
    match payload {
        Ok(value) => Ok(value),
        Err(err) => Err(format!("ZeroMQ error: {:?}", err)),
    }
}

/// Translates an I/O error to a generic String one.
fn translate_io_error<T>(payload: Result<T, IoError>) -> Result<T, String> {
    // Only the error arm is rewritten; an `Ok` value is returned as-is.
    payload.or_else(|err| Err(format!("I/O error: {:?}", err)))
}

/// The main interface to Citra. Adds a level of abstraction on the ZMQ socket.
pub struct CitraConnection {
    socket: Socket,
}

impl CitraConnection {
    /// Makes a request to Citra, returning the response (if any).
    fn make_request(&self, request_kind: RequestType, data: &[u8]) -> Result<Vec<u8>, String> {
        let (request, request_id) = generate_header(request_kind, data.len() as _);

        let mut outgoing_buffer = Vec::with_capacity(request.len() + data.len());
        outgoing_buffer.extend_from_slice(&request);
        outgoing_buffer.extend_from_slice(data);

        translate_zmq_error(self.socket.send(&outgoing_buffer, 0))?;

        let req_reply = translate_zmq_error(self.socket.recv_bytes(0))?;

        let data = read_and_validate_header(&req_reply, request_id, request_kind)?;

        Ok(data.to_vec())
    }

    /// Reads a region of memory.
    ///
    /// # Params
    ///
    /// *read_address*: The remote memory pointer to read from.
    /// *read_size*: The amount of data to read, in bytes.
    ///
    /// # Example
    ///
    /// ```rust
    /// use citra_scripting::CitraConnection;
    ///
    /// let connection = CitraConnection::connect().unwrap();
    /// connection.read_memory(0x100000, 4);
    /// ```
    pub fn read_memory(
        &self,
        mut read_address: u32,
        mut read_size: u32,
    ) -> Result<Vec<u8>, String> {
        let mut result = Vec::with_capacity(read_size as _);

        while read_size > 0 {
            let temp_read_size = if read_size > MAX_REQUEST_DATA_SIZE {
                MAX_REQUEST_DATA_SIZE
            } else {
                read_size
            };

            let mut request_data = [0 as u8; 2 * 4];

            {
                let mut cursor: Cursor<&mut [u8]> = Cursor::new(&mut request_data);

                cursor
                    .write_u32::<LittleEndian>(read_address)
                    .expect("Failed to write read address");
                cursor
                    .write_u32::<LittleEndian>(temp_read_size)
                    .expect("Failed to write read size");
            }

            let data = self.make_request(RequestType::ReadMemory, &request_data)?;
            result.extend_from_slice(&data);

            read_size -= temp_read_size;
            read_address += temp_read_size;
        }

        Ok(result)
    }

    /// Reads a region of memory.
    ///
    /// # Params
    ///
    /// *write_address*: The remote memory pointer to write to.
    /// *data*: The data to write.
    ///
    /// # Example
    ///
    /// ```rust
    /// use citra_scripting::CitraConnection;
    ///
    /// let connection = CitraConnection::connect().unwrap();
    /// connection.write_memory(0x100000, &[0xff as u8; 4]);
    /// ```
    pub fn write_memory(&self, mut write_address: u32, mut data: &[u8]) -> Result<(), String> {
        while !data.is_empty() {
            let temp_write_size = if data.len() as u32 > MAX_REQUEST_DATA_SIZE {
                MAX_REQUEST_DATA_SIZE
            } else {
                data.len() as u32
            };

            let mut request_data = Vec::with_capacity(2 * 4 + temp_write_size as usize);

            {
                let mut cursor = Cursor::new(&mut request_data);

                cursor
                    .write_u32::<LittleEndian>(write_address)
                    .expect("Failed to write write address");
                cursor
                    .write_u32::<LittleEndian>(temp_write_size)
                    .expect("Failed to write write size");
                cursor
                    .write_all(&data[0..temp_write_size as usize])
                    .expect("Failed to write write data");
            }

            let incoming_data = self.make_request(RequestType::WriteMemory, &request_data)?;

            if !incoming_data.is_empty() {
                return Err(format!(
                    "Unexpected response payload of {} bytes",
                    incoming_data.len()
                ));
            }

            data = &data[temp_write_size as usize..];
            write_address += temp_write_size;
        }

        Ok(())
    }

    /// Connects to the current Citra client, assuming defaults.
    pub fn connect() -> Result<Self, String> {
        let ctx = zmq::Context::new();

        let socket = translate_zmq_error(ctx.socket(zmq::REQ))?;

        translate_zmq_error(socket.connect(&format!("tcp://127.0.0.1:{}", CITRA_PORT)))?;

        Ok(CitraConnection { socket })
    }
}

/// Tests need an active Citra client running with its scripting server on the local machine.
///
/// Without one, the requests below would block waiting for a reply, so the tests are ignored
/// by default; run them with `cargo test -- --ignored`.
#[cfg(test)]
mod tests {
    use super::CitraConnection;

    #[test]
    #[ignore]
    fn read_memory() {
        let connection = CitraConnection::connect().expect("Got error while connecting");

        let memory = connection
            .read_memory(0x100000, 4)
            .expect("Failed to read memory");

        assert_eq!(memory.len(), 4);
    }

    #[test]
    #[ignore]
    fn overwrite_memory() {
        let connection = CitraConnection::connect().expect("Got error while connecting");

        let memory_slice = [0xff as u8; 4];
        let ptr = 0x0010_0000;

        connection
            .write_memory(ptr, &memory_slice)
            .expect("Failed to write memory");

        // Read back what was just written to confirm the write took effect.
        let memory = connection
            .read_memory(ptr, memory_slice.len() as _)
            .expect("Failed to read memory");

        assert_eq!(&memory_slice, memory.as_slice());
    }
}