use super::var_id::VarId;
use crate::symbol::SymbolId;
use std::collections::HashMap;
/// The kind of lock (or interior-mutability cell) a guard was obtained from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LockType {
    /// Mutex acquisition.
    Mutex,
    /// Shared (read) acquisition of an `RwLock`.
    RwLockRead,
    /// Exclusive (write) acquisition of an `RwLock`.
    RwLockWrite,
    /// Shared `RefCell` borrow.
    RefCell,
    /// Mutable `RefCell` borrow.
    RefCellMut,
    /// parking_lot mutex acquisition.
    ParkingLotMutex,
    /// parking_lot `RwLock` acquisition (read/write mode is not recorded).
    ParkingLotRwLock,
    /// Tokio mutex acquisition.
    TokioMutex,
    /// Tokio `RwLock` acquisition (read/write mode is not recorded).
    TokioRwLock,
}

impl LockType {
    /// Returns `true` only for acquisitions that grant shared access:
    /// [`LockType::RwLockRead`] and [`LockType::RefCell`].
    ///
    /// The match is deliberately exhaustive so that adding a variant forces a
    /// decision here.
    pub fn is_read_only(&self) -> bool {
        match self {
            LockType::RwLockRead | LockType::RefCell => true,
            LockType::Mutex
            | LockType::RwLockWrite
            | LockType::RefCellMut
            | LockType::ParkingLotMutex
            | LockType::ParkingLotRwLock
            | LockType::TokioMutex
            | LockType::TokioRwLock => false,
        }
    }

    /// Returns `true` for every acquisition that is not read-only.
    ///
    /// `ParkingLotRwLock` and `TokioRwLock` are reported as exclusive because
    /// this enum does not distinguish their read and write modes.
    pub fn is_exclusive(&self) -> bool {
        !self.is_read_only()
    }
}
/// How a field is touched inside a critical section.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccessKind {
    Read,
    Write,
    ReadWrite,
}

impl AccessKind {
    /// Combines two observations of the same field.
    ///
    /// Two identical kinds are kept as they are. Any mixture (including
    /// anything combined with `ReadWrite`) yields [`AccessKind::ReadWrite`].
    pub fn merge(self, other: Self) -> Self {
        if self == other {
            self
        } else {
            AccessKind::ReadWrite
        }
    }
}
/// A lock-usage improvement proposed by the analysis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LockSuggestion {
    /// Use `suggested_type` for `field` instead of guarding it with a lock.
    UseAtomic {
        field: String,
        current_type: Option<String>,
        suggested_type: String,
        line: u32,
    },
    /// Split `lock_name` into several locks. The first element of each pair
    /// in `suggested_splits` is the field name shown in messages.
    SplitLock {
        lock_name: String,
        suggested_splits: Vec<(String, String)>,
        line: u32,
    },
    /// Shrink the region in which `guard_name` is held, for `reason`.
    ReduceScope {
        guard_name: String,
        current_span: (u32, u32),
        suggested_span: (u32, u32),
        reason: String,
    },
    /// Switch `lock_name` to an `RwLock`, given observed read/write counts.
    UseRwLock {
        lock_name: String,
        read_count: usize,
        write_count: usize,
        line: u32,
    },
    /// `guard_name`, taken at `lock_line`, is still held at the await on `await_line`.
    LockAcrossAwait {
        guard_name: String,
        lock_line: u32,
        await_line: u32,
    },
    /// `lock_name` may be unnecessary, for `reason`.
    RemoveLock {
        lock_name: String,
        reason: String,
        line: u32,
    },
}

impl LockSuggestion {
    /// Returns the fixed severity score for this suggestion's kind, ranging
    /// from 10 (`LockAcrossAwait`) down to 3 (`RemoveLock`).
    pub fn severity(&self) -> u8 {
        match self {
            Self::LockAcrossAwait { .. } => 10,
            Self::UseAtomic { .. } => 8,
            Self::SplitLock { .. } => 6,
            Self::ReduceScope { .. } => 5,
            Self::UseRwLock { .. } => 4,
            Self::RemoveLock { .. } => 3,
        }
    }

    /// Returns a full-sentence, human-readable description of the suggestion.
    pub fn description(&self) -> String {
        match self {
            Self::UseAtomic {
                field,
                suggested_type,
                ..
            } => format!("Consider using {} for field '{}'", suggested_type, field),
            Self::SplitLock {
                lock_name,
                suggested_splits,
                ..
            } => {
                let mut field_names: Vec<&str> = Vec::with_capacity(suggested_splits.len());
                for pair in suggested_splits {
                    field_names.push(pair.0.as_str());
                }
                format!(
                    "Consider splitting lock '{}' for fields: {:?}",
                    lock_name, field_names
                )
            }
            Self::ReduceScope {
                guard_name,
                reason,
                ..
            } => format!("Reduce scope of guard '{}': {}", guard_name, reason),
            Self::UseRwLock {
                lock_name,
                read_count,
                write_count,
                ..
            } => format!(
                "Consider RwLock for '{}' ({} reads, {} writes)",
                lock_name, read_count, write_count
            ),
            Self::LockAcrossAwait { guard_name, .. } => {
                format!("Lock '{}' is held across await point", guard_name)
            }
            Self::RemoveLock {
                lock_name,
                reason,
                ..
            } => format!("Lock '{}' may be unnecessary: {}", lock_name, reason),
        }
    }

    /// Returns a terse, imperative summary of the suggestion.
    pub fn short_description(&self) -> String {
        match self {
            Self::UseAtomic {
                field,
                suggested_type,
                ..
            } => format!("Use {} for field '{}'", suggested_type, field),
            Self::SplitLock {
                lock_name,
                suggested_splits,
                ..
            } => {
                let split_count = suggested_splits.len();
                format!("Split '{}' into {} separate locks", lock_name, split_count)
            }
            Self::ReduceScope {
                guard_name,
                reason,
                ..
            } => format!("Reduce scope of '{}': {}", guard_name, reason),
            Self::UseRwLock {
                lock_name,
                read_count,
                write_count,
                ..
            } => format!(
                "Use RwLock for '{}' ({} reads, {} writes)",
                lock_name, read_count, write_count
            ),
            Self::LockAcrossAwait {
                guard_name,
                lock_line,
                await_line,
            } => format!(
                "Lock '{}' held across await (lines {}-{})",
                guard_name, lock_line, await_line
            ),
            Self::RemoveLock {
                lock_name,
                reason,
                ..
            } => format!("Remove unnecessary '{}': {}", lock_name, reason),
        }
    }
}
/// One lock acquisition: the lock variable, the guard variable bound to the
/// result, the lock kind, and the source line where it happened.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LockAcquisitionV2 {
    pub lock_var: VarId,
    pub guard_var: VarId,
    pub lock_type: LockType,
    pub line: u32,
    /// Set by [`LockAcquisitionV2::with_try`]; `false` by default.
    pub is_try: bool,
    pub lock_name: String,
    pub guard_name: String,
    /// Set by [`LockAcquisitionV2::with_owner_fn`]; `None` by default.
    pub owner_fn: Option<SymbolId>,
}

impl LockAcquisitionV2 {
    /// Creates an acquisition with `is_try == false` and no owning function.
    pub fn new(
        lock_var: VarId,
        guard_var: VarId,
        lock_type: LockType,
        line: u32,
        lock_name: impl Into<String>,
        guard_name: impl Into<String>,
    ) -> Self {
        let lock_name = lock_name.into();
        let guard_name = guard_name.into();
        Self {
            lock_var,
            guard_var,
            lock_type,
            line,
            is_try: false,
            lock_name,
            guard_name,
            owner_fn: None,
        }
    }

    /// Returns this acquisition with `owner_fn` set to `owner`.
    pub fn with_owner_fn(self, owner: SymbolId) -> Self {
        Self {
            owner_fn: Some(owner),
            ..self
        }
    }

    /// Returns this acquisition with `is_try` set to `true`.
    pub fn with_try(self) -> Self {
        Self {
            is_try: true,
            ..self
        }
    }
}
/// A single access to a named field, with its kind and source line.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FieldAccessV2 {
    pub field_name: String,
    pub access_kind: AccessKind,
    pub line: u32,
}

impl FieldAccessV2 {
    /// Builds an access record, converting `field_name` into an owned `String`.
    pub fn new(field_name: impl Into<String>, access_kind: AccessKind, line: u32) -> Self {
        let owned_name: String = field_name.into();
        Self {
            line,
            access_kind,
            field_name: owned_name,
        }
    }
}
/// The region of code during which a lock guard is held.
#[derive(Debug, Clone)]
pub struct CriticalSectionV2 {
    /// The acquisition that opened this section.
    pub acquisition: LockAcquisitionV2,
    /// Copied from `acquisition.line` by [`CriticalSectionV2::new`].
    pub start_line: u32,
    /// Set by [`CriticalSectionV2::end_at`]; `None` until then.
    pub end_line: Option<u32>,
    /// Accesses recorded while the section was open, in recording order.
    pub field_accesses: Vec<FieldAccessV2>,
    /// Set by [`CriticalSectionV2::mark_expensive`].
    pub contains_expensive_ops: bool,
    /// Set by [`CriticalSectionV2::mark_await`].
    pub contains_await: bool,
}

impl CriticalSectionV2 {
    /// Opens a section starting at the acquisition's line, with no accesses
    /// and no flags set.
    pub fn new(acquisition: LockAcquisitionV2) -> Self {
        Self {
            start_line: acquisition.line,
            acquisition,
            end_line: None,
            field_accesses: Vec::new(),
            contains_expensive_ops: false,
            contains_await: false,
        }
    }

    /// Records `line` as the end of the section (overwriting any earlier end).
    pub fn end_at(&mut self, line: u32) {
        self.end_line = Some(line);
    }

    /// Appends one field access.
    pub fn add_field_access(&mut self, access: FieldAccessV2) {
        self.field_accesses.push(access);
    }

    /// Flags the section as containing expensive operations.
    pub fn mark_expensive(&mut self) {
        self.contains_expensive_ops = true;
    }

    /// Flags the section as containing an await point.
    pub fn mark_await(&mut self) {
        self.contains_await = true;
    }

    /// Returns the distinct accessed field names, sorted ascending.
    pub fn unique_fields(&self) -> Vec<&str> {
        let mut names: Vec<&str> = Vec::with_capacity(self.field_accesses.len());
        for access in &self.field_accesses {
            names.push(access.field_name.as_str());
        }
        names.sort();
        names.dedup();
        names
    }

    /// Returns the merged [`AccessKind`] of all accesses to `field`, or `None`
    /// if the field was never accessed in this section.
    pub fn field_access_kind(&self, field: &str) -> Option<AccessKind> {
        let mut kinds = self
            .field_accesses
            .iter()
            .filter(|access| access.field_name == field)
            .map(|access| access.access_kind);
        let first = kinds.next()?;
        Some(kinds.fold(first, AccessKind::merge))
    }

    /// Returns `true` if no recorded access is anything other than a read.
    /// A section with no accesses counts as read-only.
    pub fn is_read_only(&self) -> bool {
        !self
            .field_accesses
            .iter()
            .any(|access| access.access_kind != AccessKind::Read)
    }

    /// Returns `end_line - start_line` (saturating at 0), or `None` if the
    /// section has no end line.
    pub fn span(&self) -> Option<u32> {
        let end = self.end_line?;
        Some(end.saturating_sub(self.start_line))
    }
}
/// Records lock acquisitions and the critical sections they open.
///
/// A section is *active* from [`LockTrackerV2::acquire`] until one of the
/// following happens:
/// - its guard is passed to [`LockTrackerV2::release`], which records the end line;
/// - it is displaced by a new acquisition for the same guard variable;
/// - [`LockTrackerV2::flush_active_sections`] drains it.
///
/// In the last two cases the section is completed without an end line.
#[derive(Debug, Clone, Default)]
pub struct LockTrackerV2 {
    /// Every acquisition passed to `acquire`, in call order.
    acquisitions: Vec<LockAcquisitionV2>,
    /// Open sections keyed by guard variable. The `usize` is the index of the
    /// section's acquisition in `acquisitions`. It lets sections be drained in
    /// acquisition order, because `HashMap` iteration order is unspecified.
    active_sections: HashMap<VarId, (usize, CriticalSectionV2)>,
    /// Sections that were released, displaced, or flushed.
    completed_sections: Vec<CriticalSectionV2>,
}

impl LockTrackerV2 {
    /// Creates an empty tracker.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records `acquisition` and opens a critical section for its guard.
    ///
    /// If a section is already open for the same guard variable, it is moved
    /// to the completed list with no end line (the same convention as
    /// [`LockTrackerV2::flush_active_sections`]). It is not silently
    /// overwritten, which would drop its recorded accesses and flags.
    pub fn acquire(&mut self, acquisition: LockAcquisitionV2) {
        let guard_var = acquisition.guard_var;
        let order = self.acquisitions.len();
        self.acquisitions.push(acquisition.clone());
        let section = CriticalSectionV2::new(acquisition);
        if let Some((_, displaced)) = self.active_sections.insert(guard_var, (order, section)) {
            self.completed_sections.push(displaced);
        }
    }

    /// Returns the open section for `guard_var`, if any.
    fn active_mut(&mut self, guard_var: VarId) -> Option<&mut CriticalSectionV2> {
        self.active_sections.get_mut(&guard_var).map(|(_, cs)| cs)
    }

    /// Records a field access in the section open for `guard_var`.
    /// Does nothing if no section is open for that guard.
    pub fn record_field_access(
        &mut self,
        guard_var: VarId,
        field_name: &str,
        access_kind: AccessKind,
        line: u32,
    ) {
        if let Some(cs) = self.active_mut(guard_var) {
            cs.add_field_access(FieldAccessV2::new(field_name, access_kind, line));
        }
    }

    /// Flags the section open for `guard_var` as containing expensive operations.
    /// Does nothing if no section is open for that guard.
    pub fn mark_expensive(&mut self, guard_var: VarId) {
        if let Some(cs) = self.active_mut(guard_var) {
            cs.mark_expensive();
        }
    }

    /// Flags the section open for `guard_var` as containing an await point.
    /// The await line is currently not stored. Does nothing if no section is
    /// open for that guard.
    pub fn mark_await(&mut self, guard_var: VarId, _await_line: u32) {
        if let Some(cs) = self.active_mut(guard_var) {
            cs.mark_await();
        }
    }

    /// Closes the section open for `guard_var` at `line` and moves it to the
    /// completed list. Does nothing if no section is open for that guard.
    pub fn release(&mut self, guard_var: VarId, line: u32) {
        if let Some((_, mut cs)) = self.active_sections.remove(&guard_var) {
            cs.end_at(line);
            self.completed_sections.push(cs);
        }
    }

    /// Completed sections, in the order they were completed.
    pub fn critical_sections(&self) -> &[CriticalSectionV2] {
        &self.completed_sections
    }

    /// All recorded acquisitions, in call order.
    pub fn acquisitions(&self) -> &[LockAcquisitionV2] {
        &self.acquisitions
    }

    /// Acquisitions whose `owner_fn` equals `owner`, in call order.
    pub fn acquisitions_by_owner(&self, owner: SymbolId) -> Vec<&LockAcquisitionV2> {
        self.acquisitions
            .iter()
            .filter(|a| a.owner_fn == Some(owner))
            .collect()
    }

    /// Currently open sections, in unspecified order.
    pub fn active_sections(&self) -> impl Iterator<Item = &CriticalSectionV2> {
        self.active_sections.values().map(|(_, cs)| cs)
    }

    /// Returns `true` if a section is open for `guard_var`.
    pub fn is_active(&self, guard_var: VarId) -> bool {
        self.active_sections.contains_key(&guard_var)
    }

    /// Discards all acquisitions and all active and completed sections.
    pub fn clear(&mut self) {
        self.acquisitions.clear();
        self.active_sections.clear();
        self.completed_sections.clear();
    }

    /// Number of completed sections.
    pub fn completed_count(&self) -> usize {
        self.completed_sections.len()
    }

    /// Number of open sections.
    pub fn active_count(&self) -> usize {
        self.active_sections.len()
    }

    /// Moves every open section to the completed list without setting an end
    /// line.
    ///
    /// Sections are appended in acquisition order, so the resulting
    /// `critical_sections()` sequence is reproducible across runs. Draining
    /// the `HashMap` directly would not guarantee this, because its iteration
    /// order is unspecified.
    pub fn flush_active_sections(&mut self) {
        let mut pending: Vec<(usize, CriticalSectionV2)> = self
            .active_sections
            .drain()
            .map(|(_, entry)| entry)
            .collect();
        // Acquisition indices are unique, so an unstable sort is deterministic.
        pending.sort_unstable_by_key(|(order, _)| *order);
        self.completed_sections
            .extend(pending.into_iter().map(|(_, cs)| cs));
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol::SymbolId;
    use slotmap::SlotMap;

    /// Hands out fresh `VarId`s backed by newly inserted symbols.
    struct TestVars {
        symbols: SlotMap<SymbolId, &'static str>,
        mapping: super::super::var_id::VarSymbolMapping,
    }

    impl TestVars {
        fn new() -> Self {
            TestVars {
                symbols: SlotMap::with_key(),
                mapping: super::super::var_id::VarSymbolMapping::new(),
            }
        }

        /// Inserts a symbol called `name` and registers it as a variable.
        fn var(&mut self, name: &'static str) -> VarId {
            let symbol = self.symbols.insert(name);
            self.mapping.register(symbol)
        }

        /// Allocates a lock variable, then its guard variable.
        fn lock_and_guard(&mut self, lock: &'static str, guard: &'static str) -> (VarId, VarId) {
            let lock_var = self.var(lock);
            let guard_var = self.var(guard);
            (lock_var, guard_var)
        }
    }

    /// A `Mutex` acquisition named `"mutex"` / `"guard"` at `line`.
    fn mutex_acquisition(lock: VarId, guard: VarId, line: u32) -> LockAcquisitionV2 {
        LockAcquisitionV2::new(lock, guard, LockType::Mutex, line, "mutex", "guard")
    }

    #[test]
    fn test_lock_acquisition_v2() {
        let mut vars = TestVars::new();
        let (lock, guard) = vars.lock_and_guard("lock", "guard");
        let acquisition = mutex_acquisition(lock, guard, 10);
        assert_eq!(acquisition.lock_var, lock);
        assert_eq!(acquisition.guard_var, guard);
        assert_eq!(acquisition.lock_type, LockType::Mutex);
        assert_eq!(acquisition.line, 10);
        assert!(!acquisition.is_try);
        let try_acquisition = acquisition.with_try();
        assert!(try_acquisition.is_try);
    }

    #[test]
    fn test_critical_section_field_tracking() {
        let mut vars = TestVars::new();
        let (lock, guard) = vars.lock_and_guard("lock", "guard");
        let mut section = CriticalSectionV2::new(mutex_acquisition(lock, guard, 10));
        for &(field, kind, line) in &[
            ("counter", AccessKind::Read, 11),
            ("counter", AccessKind::Write, 12),
            ("name", AccessKind::Read, 13),
        ] {
            section.add_field_access(FieldAccessV2::new(field, kind, line));
        }
        assert_eq!(section.unique_fields().len(), 2);
        assert_eq!(section.field_access_kind("counter"), Some(AccessKind::ReadWrite));
        assert_eq!(section.field_access_kind("name"), Some(AccessKind::Read));
    }

    #[test]
    fn test_critical_section_is_read_only() {
        let mut vars = TestVars::new();
        let (lock, guard) = vars.lock_and_guard("lock", "guard");
        let mut section = CriticalSectionV2::new(mutex_acquisition(lock, guard, 10));
        section.add_field_access(FieldAccessV2::new("a", AccessKind::Read, 11));
        section.add_field_access(FieldAccessV2::new("b", AccessKind::Read, 12));
        assert!(section.is_read_only());
        section.add_field_access(FieldAccessV2::new("c", AccessKind::Write, 13));
        assert!(!section.is_read_only());
    }

    #[test]
    fn test_lock_tracker_lifecycle() {
        let mut tracker = LockTrackerV2::new();
        let mut vars = TestVars::new();
        let (lock, guard) = vars.lock_and_guard("lock", "guard");

        tracker.acquire(mutex_acquisition(lock, guard, 10));
        assert!(tracker.is_active(guard));
        assert_eq!(tracker.acquisitions().len(), 1);
        assert_eq!(tracker.active_count(), 1);

        tracker.record_field_access(guard, "counter", AccessKind::Write, 11);
        tracker.release(guard, 15);
        assert!(!tracker.is_active(guard));
        assert_eq!(tracker.critical_sections().len(), 1);
        assert_eq!(tracker.completed_count(), 1);
        assert_eq!(tracker.active_count(), 0);

        let section = &tracker.critical_sections()[0];
        assert_eq!(section.start_line, 10);
        assert_eq!(section.end_line, Some(15));
        assert_eq!(section.field_accesses.len(), 1);
    }

    #[test]
    fn test_lock_tracker_mark_expensive_await() {
        let mut tracker = LockTrackerV2::new();
        let mut vars = TestVars::new();
        let (lock, guard) = vars.lock_and_guard("lock", "guard");
        let acquisition =
            LockAcquisitionV2::new(lock, guard, LockType::TokioMutex, 10, "mutex", "guard");

        tracker.acquire(acquisition);
        tracker.mark_expensive(guard);
        tracker.mark_await(guard, 12);
        tracker.release(guard, 15);

        let section = &tracker.critical_sections()[0];
        assert!(section.contains_expensive_ops);
        assert!(section.contains_await);
    }

    #[test]
    fn test_lock_tracker_clear() {
        let mut tracker = LockTrackerV2::new();
        let mut vars = TestVars::new();
        let (lock, guard) = vars.lock_and_guard("lock", "guard");

        tracker.acquire(mutex_acquisition(lock, guard, 10));
        tracker.release(guard, 15);
        assert_eq!(tracker.completed_count(), 1);

        tracker.clear();
        assert_eq!(tracker.completed_count(), 0);
        assert_eq!(tracker.acquisitions().len(), 0);
    }

    #[test]
    fn test_flush_active_sections() {
        let mut tracker = LockTrackerV2::new();
        let mut vars = TestVars::new();
        let (lock1, guard1) = vars.lock_and_guard("lock1", "guard1");
        let (lock2, guard2) = vars.lock_and_guard("lock2", "guard2");

        tracker.acquire(LockAcquisitionV2::new(lock1, guard1, LockType::Mutex, 0, "m1", "g1"));
        tracker.acquire(LockAcquisitionV2::new(
            lock2,
            guard2,
            LockType::RwLockRead,
            0,
            "r1",
            "g2",
        ));
        assert_eq!(tracker.active_count(), 2);
        assert_eq!(tracker.completed_count(), 0);

        tracker.flush_active_sections();
        assert_eq!(tracker.active_count(), 0);
        assert_eq!(tracker.completed_count(), 2);
        assert!(tracker
            .critical_sections()
            .iter()
            .all(|section| section.end_line.is_none()));
    }

    #[test]
    fn test_flush_preserves_acquisitions() {
        let mut tracker = LockTrackerV2::new();
        let mut vars = TestVars::new();
        let (lock, guard) = vars.lock_and_guard("lock", "guard");

        tracker.acquire(mutex_acquisition(lock, guard, 0));
        tracker.flush_active_sections();

        assert_eq!(tracker.acquisitions().len(), 1);
        assert_eq!(tracker.completed_count(), 1);
    }
}