#![deny(missing_docs)]
extern crate byteorder;
extern crate rand;
extern crate zmq;
use std::io::Cursor;
use std::io::Error as IoError;
use std::io::Write;
use zmq::Socket;
use rand::prelude::*;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
/// TCP port on 127.0.0.1 that `CitraConnection::connect` dials.
const CITRA_PORT: u32 = 45_987;
/// Protocol version written into every request header and required in every reply header.
const CURRENT_REQUEST_VERSION: u32 = 1;
/// Upper bound, in bytes, used to split memory reads and writes into separate requests.
const MAX_REQUEST_DATA_SIZE: u32 = 32;
/// The kinds of request this client can send to Citra's RPC server.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RequestType {
    /// Read a block of emulated memory.
    ReadMemory,
    /// Write a block of emulated memory.
    WriteMemory,
}
impl RequestType {
    /// Numeric code written into (and expected back in) the request-type field of the
    /// packet header.
    fn get_id(self) -> u32 {
        match self {
            RequestType::ReadMemory => 1,
            RequestType::WriteMemory => 2,
        }
    }
}
/// Builds the 16-byte request header for a request of `request_type` whose payload is
/// `data_size` bytes long.
///
/// The header is four little-endian `u32` fields, in order: protocol version, a freshly
/// drawn random request ID, the request-type code and the payload size.
///
/// Returns the encoded header together with the request ID, so the caller can match the
/// reply against it.
fn generate_header(request_type: RequestType, data_size: u32) -> ([u8; 4 * 4], u32) {
    let request_id: u32 = random();
    let mut header = [0u8; 4 * 4];
    {
        let mut writer = Cursor::new(&mut header[..]);
        let fields = [
            (CURRENT_REQUEST_VERSION, "Failed to write request version"),
            (request_id, "Failed to write request ID"),
            (request_type.get_id(), "Failed to write request type"),
            (data_size, "Failed to write request size"),
        ];
        // Four u32 fields exactly fill the 16-byte buffer, so these writes cannot run out of room.
        for &(value, failure) in fields.iter() {
            writer.write_u32::<LittleEndian>(value).expect(failure);
        }
    }
    (header, request_id)
}
/// Parses the 16-byte reply header in `raw_reply` and checks that it answers the request
/// identified by `expected_id` and `expected_type`.
///
/// The header is four little-endian `u32` fields: protocol version, request ID, request-type
/// code and payload size. On success the payload that follows the header is returned.
///
/// # Errors
///
/// Returns a description of the first problem found:
/// - the reply is shorter than a header;
/// - the version differs from `CURRENT_REQUEST_VERSION`;
/// - the ID or type does not match the request;
/// - the declared payload size disagrees with the number of bytes actually received.
fn read_and_validate_header(
    raw_reply: &[u8],
    expected_id: u32,
    expected_type: RequestType,
) -> Result<&[u8], String> {
    const HEADER_SIZE: usize = 4 * 4;
    if raw_reply.len() < HEADER_SIZE {
        return Err(format!(
            "Payload is smaller than minimum (got {}, expected at least {})",
            raw_reply.len(),
            HEADER_SIZE
        ));
    }
    let (header, payload) = raw_reply.split_at(HEADER_SIZE);
    let mut cursor = Cursor::new(header);
    let expected_type = expected_type.get_id();
    let reply_version = translate_io_error(cursor.read_u32::<LittleEndian>())?;
    let reply_id = translate_io_error(cursor.read_u32::<LittleEndian>())?;
    let reply_type = translate_io_error(cursor.read_u32::<LittleEndian>())?;
    let reply_data_size = translate_io_error(cursor.read_u32::<LittleEndian>())?;
    if reply_version != CURRENT_REQUEST_VERSION {
        return Err(format!(
            "Bad request version (got {}, expected {})",
            reply_version, CURRENT_REQUEST_VERSION
        ));
    }
    if reply_id != expected_id {
        return Err(format!(
            "Bad request ID (got {}, expected {})",
            reply_id, expected_id
        ));
    }
    if reply_type != expected_type {
        return Err(format!(
            "Bad request type (got {}, expected {})",
            reply_type, expected_type
        ));
    }
    // Compare in u64: casting the payload length down to u32 could truncate an oversized
    // payload into a value that falsely matches the declared size.
    if u64::from(reply_data_size) != payload.len() as u64 {
        return Err(format!(
            "Bad request size (got {}, expected {})",
            reply_data_size,
            payload.len()
        ));
    }
    Ok(payload)
}
/// Converts a ZeroMQ result into the `String`-error convention used by this crate.
///
/// The error text is the `Debug` rendering of the ZeroMQ error, prefixed with "ZeroMQ error: ".
fn translate_zmq_error<T>(payload: Result<T, zmq::Error>) -> Result<T, String> {
    match payload {
        Ok(value) => Ok(value),
        Err(zmq_err) => Err(format!("ZeroMQ error: {:?}", zmq_err)),
    }
}
/// Converts a `std::io` result into the `String`-error convention used by this crate.
///
/// The error text is the `Debug` rendering of the I/O error, prefixed with "I/O error: ".
fn translate_io_error<T>(payload: Result<T, IoError>) -> Result<T, String> {
    match payload {
        Ok(value) => Ok(value),
        Err(io_err) => Err(format!("I/O error: {:?}", io_err)),
    }
}
/// A connection to Citra's RPC server, used to read and write emulated memory.
///
/// Create one with `CitraConnection::connect`. Each request is sent over the wrapped ZeroMQ
/// socket, and the reply is awaited before the next request is sent.
pub struct CitraConnection {
    socket: Socket,
}
impl CitraConnection {
fn make_request(&self, request_kind: RequestType, data: &[u8]) -> Result<Vec<u8>, String> {
let (request, request_id) = generate_header(request_kind, data.len() as _);
let mut outgoing_buffer = Vec::with_capacity(request.len() + data.len());
outgoing_buffer.extend_from_slice(&request);
outgoing_buffer.extend_from_slice(data);
translate_zmq_error(self.socket.send(&outgoing_buffer, 0))?;
let req_reply = translate_zmq_error(self.socket.recv_bytes(0))?;
let data = read_and_validate_header(&req_reply, request_id, request_kind)?;
Ok(data.to_vec())
}
pub fn read_memory(
&self,
mut read_address: u32,
mut read_size: u32,
) -> Result<Vec<u8>, String> {
let mut result = Vec::with_capacity(read_size as _);
while read_size > 0 {
let temp_read_size = if read_size > MAX_REQUEST_DATA_SIZE {
MAX_REQUEST_DATA_SIZE
} else {
read_size
};
let mut request_data = [0 as u8; 2 * 4];
{
let mut cursor: Cursor<&mut [u8]> = Cursor::new(&mut request_data);
cursor
.write_u32::<LittleEndian>(read_address)
.expect("Failed to write read address");
cursor
.write_u32::<LittleEndian>(temp_read_size)
.expect("Failed to write read size");
}
let data = self.make_request(RequestType::ReadMemory, &request_data)?;
result.extend_from_slice(&data);
read_size -= temp_read_size;
read_address += temp_read_size;
}
Ok(result)
}
pub fn write_memory(&self, mut write_address: u32, mut data: &[u8]) -> Result<(), String> {
while !data.is_empty() {
let temp_write_size = if data.len() as u32 > MAX_REQUEST_DATA_SIZE {
MAX_REQUEST_DATA_SIZE
} else {
data.len() as u32
};
let mut request_data = Vec::with_capacity(2 * 4 + temp_write_size as usize);
{
let mut cursor = Cursor::new(&mut request_data);
cursor
.write_u32::<LittleEndian>(write_address)
.expect("Failed to write write address");
cursor
.write_u32::<LittleEndian>(temp_write_size)
.expect("Failed to write write size");
cursor
.write_all(&data[0..temp_write_size as usize])
.expect("Failed to write write data");
}
let incoming_data = self.make_request(RequestType::WriteMemory, &request_data)?;
if !incoming_data.is_empty() {
return Err(format!(
"Unexpected response payload of {} bytes",
incoming_data.len()
));
}
data = &data[temp_write_size as usize..];
write_address += temp_write_size;
}
Ok(())
}
pub fn connect() -> Result<Self, String> {
let ctx = zmq::Context::new();
let socket = translate_zmq_error(ctx.socket(zmq::REQ))?;
translate_zmq_error(socket.connect(&format!("tcp://127.0.0.1:{}", CITRA_PORT)))?;
Ok(CitraConnection { socket })
}
}
#[cfg(test)]
mod tests {
    // `super::` resolves the same way in every Rust edition.
    use super::CitraConnection;

    /// Emulated address exercised by these tests; they expect a running Citra instance
    /// with a game loaded.
    const TEST_ADDRESS: u32 = 0x0010_0000;

    #[test]
    fn read_memory() {
        let connection = CitraConnection::connect().expect("Got error while connecting");
        let memory = connection
            .read_memory(TEST_ADDRESS, 4)
            .expect("Failed to read memory");
        assert_eq!(memory.len(), 4);
    }

    #[test]
    fn read_memory_spanning_multiple_requests() {
        // 100 bytes cannot fit in one request, so this exercises the chunking loop.
        let connection = CitraConnection::connect().expect("Got error while connecting");
        let memory = connection
            .read_memory(TEST_ADDRESS, 100)
            .expect("Failed to read memory");
        assert_eq!(memory.len(), 100);
    }

    #[test]
    fn overwrite_memory() {
        let connection = CitraConnection::connect().expect("Got error while connecting");
        let memory_slice = [0xffu8; 4];
        let original = connection
            .read_memory(TEST_ADDRESS, memory_slice.len() as u32)
            .expect("Failed to read original memory");
        connection
            .write_memory(TEST_ADDRESS, &memory_slice)
            .expect("Failed to write memory");
        let memory = connection
            .read_memory(TEST_ADDRESS, memory_slice.len() as u32)
            .expect("Failed to read memory");
        // Put the original bytes back before asserting, so a failing comparison does not
        // leave the emulated program's memory clobbered.
        connection
            .write_memory(TEST_ADDRESS, &original)
            .expect("Failed to restore original memory");
        assert_eq!(&memory_slice, memory.as_slice());
    }
}