use crate::error::{Error, Result};
use crate::process;
use crate::types::VirtAddr;
use nix::unistd::Pid;
/// Access condition that triggers a hardware watchpoint.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WatchpointType {
    /// Fire on data writes.
    Write,
    /// Fire on data reads or writes.
    ReadWrite,
    /// Fire on instruction execution.
    Execute,
}

impl WatchpointType {
    /// Two-bit R/W field value for this condition in DR7.
    fn dr7_condition(self) -> u64 {
        let bits: u64 = match self {
            Self::Execute => 0b00,
            Self::Write => 0b01,
            Self::ReadWrite => 0b11,
        };
        bits
    }
}

impl std::fmt::Display for WatchpointType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Write => "write",
            Self::ReadWrite => "rw",
            Self::Execute => "execute",
        };
        f.write_str(label)
    }
}
/// Width of the memory region covered by a hardware watchpoint.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WatchpointSize {
    /// One byte.
    Byte1,
    /// Two bytes.
    Byte2,
    /// Four bytes.
    Byte4,
    /// Eight bytes.
    Byte8,
}

impl WatchpointSize {
    /// Two-bit LEN field value for this width in DR7.
    ///
    /// The encoding is not ordered by width: 8 bytes is `0b10`, 4 bytes is `0b11`.
    fn dr7_length(self) -> u64 {
        match self {
            Self::Byte1 => 0b00,
            Self::Byte2 => 0b01,
            Self::Byte8 => 0b10,
            Self::Byte4 => 0b11,
        }
    }

    /// Number of bytes covered by this width.
    pub fn bytes(self) -> usize {
        let width: usize = match self {
            Self::Byte1 => 1,
            Self::Byte2 => 2,
            Self::Byte4 => 4,
            Self::Byte8 => 8,
        };
        width
    }

    /// Maps a byte count to a width; returns `None` unless `n` is 1, 2, 4 or 8.
    pub fn from_bytes(n: usize) -> Option<Self> {
        [Self::Byte1, Self::Byte2, Self::Byte4, Self::Byte8]
            .iter()
            .copied()
            .find(|candidate| candidate.bytes() == n)
    }
}

impl std::fmt::Display for WatchpointSize {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.bytes();
        write!(f, "{}", count)
    }
}
/// Bookkeeping record for one watchpoint held by [`WatchpointManager`].
#[derive(Clone, Debug)]
pub struct Watchpoint {
    /// Identifier returned by [`WatchpointManager::set`].
    pub id: u32,
    /// Watched address.
    pub addr: VirtAddr,
    /// Access condition that triggers the watchpoint.
    pub wp_type: WatchpointType,
    /// Width of the watched region.
    pub size: WatchpointSize,
    /// Index (0..=3) of the manager slot / debug address register in use.
    pub slot: usize,
}
pub struct WatchpointManager {
slots: [Option<Watchpoint>; 4],
next_id: u32,
}
impl Default for WatchpointManager {
fn default() -> Self {
Self {
slots: [None, None, None, None],
next_id: 1,
}
}
}
impl WatchpointManager {
    /// Creates a manager with all four debug-register slots free.
    pub fn new() -> Self {
        Self::default()
    }

    /// Installs a hardware watchpoint in the lowest-numbered free slot and
    /// returns its id.
    ///
    /// The slot's address register is written first; only then is the slot
    /// armed in DR7 with the R/W and LEN fields for `wp_type` and `size`.
    ///
    /// # Errors
    ///
    /// Returns an error if `wp_type` is [`WatchpointType::Execute`] with a size
    /// other than one byte, if `addr` is not aligned to `size`, if all four
    /// slots are in use, or if a debug-register read/write fails.
    pub fn set(
        &mut self,
        pid: Pid,
        addr: VirtAddr,
        wp_type: WatchpointType,
        size: WatchpointSize,
    ) -> Result<u32> {
        // x86 instruction breakpoints require LEN = 00 (1 byte); any other
        // LEN encoding is undefined. Reject it before touching any register.
        if wp_type == WatchpointType::Execute && size != WatchpointSize::Byte1 {
            return Err(Error::Other(format!(
                "execute watchpoints must be 1 byte (got {} bytes)",
                size
            )));
        }
        let align = size.bytes() as u64;
        if addr.addr() & (align - 1) != 0 {
            return Err(Error::Other(format!(
                "watchpoint address {:#x} must be aligned to {} bytes",
                addr.addr(),
                align
            )));
        }
        let slot = self.slots.iter().position(Option::is_none).ok_or_else(|| {
            Error::Other("no free hardware debug registers (max 4 watchpoints)".into())
        })?;

        // Program the address before arming the slot.
        process::write_debug_reg(pid, slot, addr.addr())?;
        let mut dr7 = process::read_debug_reg(pid, 7)?;
        dr7 &= !Self::dr7_slot_mask(slot);
        dr7 |= 1u64 << (slot * 2); // local-enable bit for this slot
        dr7 |= wp_type.dr7_condition() << (slot * 4 + 16); // R/W field
        dr7 |= size.dr7_length() << (slot * 4 + 18); // LEN field
        process::write_debug_reg(pid, 7, dr7)?;

        let id = self.next_id;
        self.next_id += 1;
        self.slots[slot] = Some(Watchpoint {
            id,
            addr,
            wp_type,
            size,
            slot,
        });
        Ok(id)
    }

    /// Removes the watchpoint with the given id and disarms its slot.
    ///
    /// # Errors
    ///
    /// Returns an error if no watchpoint has `id`, or if a debug-register
    /// read/write fails.
    pub fn remove(&mut self, pid: Pid, id: u32) -> Result<()> {
        let slot = self
            .slots
            .iter()
            .position(|s| s.as_ref().map(|w| w.id) == Some(id))
            .ok_or_else(|| Error::Other(format!("no watchpoint with id {}", id)))?;
        self.release_slot(pid, slot)
    }

    /// Removes the first watchpoint set at `addr` and disarms its slot.
    ///
    /// # Errors
    ///
    /// Returns an error if no watchpoint is set at `addr`, or if a
    /// debug-register read/write fails.
    pub fn remove_at(&mut self, pid: Pid, addr: VirtAddr) -> Result<()> {
        let slot = self
            .slots
            .iter()
            .position(|s| s.as_ref().map(|w| w.addr) == Some(addr))
            .ok_or_else(|| Error::Other(format!("no watchpoint at {}", addr)))?;
        self.release_slot(pid, slot)
    }

    /// Reports which tracked watchpoint (if any) DR6 says was hit, then
    /// resets DR6.
    ///
    /// Returns the lowest-numbered slot whose DR6 status bit is set and that
    /// holds a watchpoint, together with that watchpoint's address. DR6 is
    /// cleared in every successful path, hit or not.
    pub fn get_hit(&self, pid: Pid) -> Result<Option<(usize, VirtAddr)>> {
        let dr6 = process::read_debug_reg(pid, 6)?;
        let hit = self
            .slots
            .iter()
            .enumerate()
            .filter(|&(i, _)| dr6 & (1u64 << i) != 0)
            .find_map(|(i, slot)| slot.as_ref().map(|wp| (i, wp.addr)));
        // The processor does not clear DR6's status bits itself, so reset them
        // to keep the next stop from being attributed to this hit.
        process::write_debug_reg(pid, 6, 0)?;
        Ok(hit)
    }

    /// Returns all currently tracked watchpoints in slot order.
    pub fn list(&self) -> Vec<&Watchpoint> {
        self.slots.iter().filter_map(|s| s.as_ref()).collect()
    }

    /// Returns the watchpoint in `slot`, or `None` if the slot is free or
    /// out of range.
    pub fn get_at_slot(&self, slot: usize) -> Option<&Watchpoint> {
        self.slots.get(slot).and_then(|s| s.as_ref())
    }

    /// DR7 bits owned by `slot`: its two enable bits and its 4-bit R/W+LEN
    /// field.
    fn dr7_slot_mask(slot: usize) -> u64 {
        let enable_mask = 0b11u64 << (slot * 2);
        let config_mask = 0b1111u64 << (slot * 4 + 16);
        enable_mask | config_mask
    }

    /// Disarms `slot` in hardware and marks it free.
    fn release_slot(&mut self, pid: Pid, slot: usize) -> Result<()> {
        self.clear_slot(pid, slot)?;
        self.slots[slot] = None;
        Ok(())
    }

    /// Disables `slot` in DR7 and zeroes its address register.
    fn clear_slot(&self, pid: Pid, slot: usize) -> Result<()> {
        // Disable first (the reverse of `set`) so the slot is never armed
        // while its address register reads 0.
        let dr7 = process::read_debug_reg(pid, 7)?;
        process::write_debug_reg(pid, 7, dr7 & !Self::dr7_slot_mask(slot))?;
        process::write_debug_reg(pid, slot, 0)?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn watchpoint_type_dr7_condition() {
        let table: [(WatchpointType, u64); 3] = [
            (WatchpointType::Write, 0b01),
            (WatchpointType::ReadWrite, 0b11),
            (WatchpointType::Execute, 0b00),
        ];
        for &(ty, bits) in table.iter() {
            assert_eq!(ty.dr7_condition(), bits, "R/W bits for {}", ty);
        }
    }

    #[test]
    fn watchpoint_size_dr7_length() {
        let table: [(WatchpointSize, u64); 4] = [
            (WatchpointSize::Byte1, 0b00),
            (WatchpointSize::Byte2, 0b01),
            (WatchpointSize::Byte4, 0b11),
            (WatchpointSize::Byte8, 0b10),
        ];
        for &(size, bits) in table.iter() {
            assert_eq!(size.dr7_length(), bits, "LEN bits for {}-byte size", size);
        }
    }

    #[test]
    fn watchpoint_size_conversion() {
        let valid: [(usize, WatchpointSize); 4] = [
            (1, WatchpointSize::Byte1),
            (2, WatchpointSize::Byte2),
            (4, WatchpointSize::Byte4),
            (8, WatchpointSize::Byte8),
        ];
        for &(n, size) in valid.iter() {
            assert_eq!(WatchpointSize::from_bytes(n), Some(size));
        }
        for &n in [3usize, 16].iter() {
            assert_eq!(WatchpointSize::from_bytes(n), None);
        }
    }

    #[test]
    fn watchpoint_type_display() {
        let table = [
            (WatchpointType::Write, "write"),
            (WatchpointType::ReadWrite, "rw"),
            (WatchpointType::Execute, "execute"),
        ];
        for &(ty, text) in table.iter() {
            assert_eq!(ty.to_string(), text);
        }
    }

    #[test]
    fn watchpoint_manager_empty() {
        let mgr = WatchpointManager::new();
        assert!(mgr.list().is_empty());
        for &slot in [0usize, 4].iter() {
            assert!(mgr.get_at_slot(slot).is_none());
        }
    }
}