use std::io::{Error, ErrorKind, Read, Result, Seek, SeekFrom, Write};
use super::MemoryView;
use crate::types::{umem, Address};
/// Adapter that exposes a [`MemoryView`] through the [`Read`], [`Write`] and
/// [`Seek`] traits.
///
/// The cursor wraps a memory object and remembers the address at which the
/// next read or write takes place.
pub struct MemoryCursor<T> {
    /// Memory object that reads and writes are forwarded to.
    mem: T,
    /// Address used by the next read/write; advanced after each successful transfer.
    address: Address,
}
impl<T: MemoryView> MemoryCursor<T> {
pub fn new(mem: T) -> Self {
Self {
mem,
address: Address::NULL,
}
}
pub fn at(mem: T, address: Address) -> Self {
Self { mem, address }
}
pub fn into_inner(self) -> T {
self.mem
}
pub fn get_ref(&self) -> &T {
&self.mem
}
pub fn get_mut(&mut self) -> &mut T {
&mut self.mem
}
pub fn address(&self) -> Address {
self.address
}
pub fn set_address(&mut self, address: Address) {
self.address = address;
}
}
impl<T: MemoryView> Read for MemoryCursor<T> {
fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
self.mem
.read_raw_into(self.address, buf)
.map_err(|err| Error::new(ErrorKind::UnexpectedEof, err))?;
self.address = (self.address.to_umem() + buf.len() as umem).into();
Ok(buf.len())
}
}
impl<T: MemoryView> Write for MemoryCursor<T> {
fn write(&mut self, buf: &[u8]) -> Result<usize> {
self.mem
.write_raw(self.address, buf)
.map_err(|err| Error::new(ErrorKind::UnexpectedEof, err))?;
self.address = (self.address.to_umem() + buf.len() as umem).into();
Ok(buf.len())
}
fn flush(&mut self) -> Result<()> {
Ok(())
}
}
impl<T: MemoryView> Seek for MemoryCursor<T> {
    /// Moves the cursor and returns the new absolute address.
    ///
    /// - `Start(n)` jumps to address `n`.
    /// - `End(d)` is relative to one past the view's reported `max_address`.
    /// - `Current(d)` is relative to the current address.
    ///
    /// Relative arithmetic wraps around the address space instead of failing,
    /// so this never returns an error.
    fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
        let new_pos = match pos {
            SeekFrom::Start(absolute) => absolute,
            SeekFrom::End(delta) => {
                let end = self.mem.metadata().max_address.to_umem().wrapping_add(1);
                end.wrapping_add(delta as umem) as u64
            }
            SeekFrom::Current(delta) => {
                let here = self.address.to_umem();
                here.wrapping_add(delta as umem) as u64
            }
        };
        self.address = new_pos.into();
        Ok(new_pos)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::architecture::x86::{x64, X86VirtualTranslate};
    use crate::dummy::{DummyMemory, DummyOs};
    use crate::mem::{DirectTranslate, PhysicalMemory, VirtualDma};
    use crate::types::{mem, size};

    /// Bytes written and then read back by the round-trip tests.
    const PATTERN: [u8; 4] = [0xA, 0xB, 0xC, 0xD];

    /// Fresh 1 MB dummy physical memory for the physical-view tests.
    fn dummy_phys_mem() -> DummyMemory {
        DummyMemory::new(size::mb(1))
    }

    /// Builds an x64 virtual view over 1 MB of dummy physical memory.
    ///
    /// Returns the view together with the base virtual address returned by
    /// `alloc_dtb`.
    fn dummy_virt_mem() -> (
        VirtualDma<DummyMemory, DirectTranslate, X86VirtualTranslate>,
        Address,
    ) {
        let mut os = DummyOs::new(DummyMemory::new(size::mb(1)));
        let (dtb, virt_base) = os.alloc_dtb(size::mb(1), &[]);
        let translator = x64::new_translator(dtb);
        (VirtualDma::new(os.into_inner(), x64::ARCH, translator), virt_base)
    }

    /// Relative and absolute seeks shared by the physical and virtual tests.
    ///
    /// Expects a freshly created cursor (position 0).
    fn check_common_seeks<S: Seek>(cursor: &mut S) {
        assert_eq!(cursor.stream_position().unwrap(), 0);
        assert_eq!(cursor.seek(SeekFrom::Current(1024)).unwrap(), 1024);
        assert_eq!(cursor.seek(SeekFrom::Current(1024)).unwrap(), 2048);
        assert_eq!(cursor.seek(SeekFrom::Current(-1024)).unwrap(), 1024);
        assert_eq!(cursor.seek(SeekFrom::Start(512)).unwrap(), 512);
    }

    /// Writes `PATTERN` at the cursor's current position `start` and checks that
    /// the cursor moved past it. Then seeks back to `start`, reads the bytes back
    /// and compares them.
    fn check_write_then_read<C: Read + Write + Seek>(cursor: &mut C, start: u64) {
        let written = cursor.write(&PATTERN).unwrap();
        assert_eq!(written, 4);
        assert_eq!(cursor.stream_position().unwrap(), start + 4);

        let mut readback = [0u8; 4];
        assert_eq!(cursor.seek(SeekFrom::Start(start)).unwrap(), start);
        assert_eq!(cursor.read(&mut readback).unwrap(), 4);
        assert_eq!(readback, PATTERN);
    }

    #[test]
    fn physical_seek() {
        let mut phys_mem = dummy_phys_mem();
        let mut cursor = MemoryCursor::new(phys_mem.phys_view());
        check_common_seeks(&mut cursor);

        // `End` is one past the last byte, so End(-512) lands 512 bytes below 1 MB.
        let expected_end = mem::mb(1) as u64 - 512;
        assert_eq!(cursor.seek(SeekFrom::End(-512)).unwrap(), expected_end);
    }

    #[test]
    fn physical_read_write() {
        let mut phys_mem = dummy_phys_mem();
        let mut cursor = MemoryCursor::new(phys_mem.phys_view());
        check_write_then_read(&mut cursor, 0);
    }

    #[test]
    fn physical_read_write_seek() {
        let mut phys_mem = dummy_phys_mem();
        let mut cursor = MemoryCursor::new(phys_mem.phys_view());
        assert_eq!(cursor.seek(SeekFrom::Start(512)).unwrap(), 512);
        check_write_then_read(&mut cursor, 512);
    }

    #[test]
    fn virtual_seek() {
        let (virt_mem, _) = dummy_virt_mem();
        let mut cursor = MemoryCursor::new(virt_mem);
        check_common_seeks(&mut cursor);
    }

    #[test]
    fn virtual_read_write() {
        let (virt_mem, virt_base) = dummy_virt_mem();
        let mut cursor = MemoryCursor::new(virt_mem);
        let start = virt_base.to_umem() as u64;
        assert_eq!(cursor.seek(SeekFrom::Start(start)).unwrap(), start);
        check_write_then_read(&mut cursor, start);
    }

    #[test]
    fn virtual_read_write_seek() {
        let (virt_mem, virt_base) = dummy_virt_mem();
        let mut cursor = MemoryCursor::new(virt_mem);
        let start = virt_base.to_umem() as u64 + 512;
        assert_eq!(cursor.seek(SeekFrom::Start(start)).unwrap(), start);
        check_write_then_read(&mut cursor, start);
    }
}