use std::cmp::Reverse;
use super::{AccessKind, CriticalSectionV2, LockSuggestion, LockTrackerV2, LockType};
/// Proposes lock-granularity improvements for the critical sections recorded
/// by a [`LockTrackerV2`].
///
/// Suggestions cover replacing a lock with an atomic, splitting a lock by
/// field, switching a mutex to an `RwLock`, narrowing a guard's scope, and
/// guards held across `.await`.
pub struct LockGranularityAnalyzerV2<'a> {
    /// Source of the critical sections that are analyzed.
    tracker: &'a LockTrackerV2,
}

impl<'a> LockGranularityAnalyzerV2<'a> {
    /// Creates an analyzer that reads critical sections from `tracker`.
    pub fn new(tracker: &'a LockTrackerV2) -> Self {
        Self { tracker }
    }

    /// Runs every check on every critical section and returns the suggestions
    /// ordered by descending `severity()`.
    ///
    /// `sort_by_key` is stable, so suggestions of equal severity keep the order
    /// in which they were produced.
    pub fn analyze(&self) -> Vec<LockSuggestion> {
        let mut suggestions = Vec::new();
        for cs in self.tracker.critical_sections() {
            suggestions.extend(self.analyze_critical_section(cs));
        }
        suggestions.sort_by_key(|suggestion| Reverse(suggestion.severity()));
        suggestions
    }

    /// Collects every suggestion that applies to a single critical section, in
    /// a fixed order: await, atomic, split, `RwLock`, scope reduction.
    fn analyze_critical_section(&self, cs: &CriticalSectionV2) -> Vec<LockSuggestion> {
        let mut suggestions = Vec::new();
        if cs.contains_await {
            suggestions.push(LockSuggestion::LockAcrossAwait {
                guard_name: cs.acquisition.guard_name.clone(),
                lock_line: cs.acquisition.line,
                // NOTE(review): this reports the section's start line, not the
                // line of the `.await` itself — confirm whether
                // `CriticalSectionV2` records the await location.
                await_line: cs.start_line,
            });
        }
        // Each check yields at most one suggestion; extending a `Vec` with an
        // `Option` appends zero or one element.
        suggestions.extend(self.check_atomic_opportunity(cs));
        suggestions.extend(self.check_split_opportunity(cs));
        suggestions.extend(self.check_rwlock_opportunity(cs));
        suggestions.extend(self.check_scope_reduction(cs));
        suggestions
    }

    /// Suggests replacing the lock with an atomic when the section touches
    /// exactly one field whose name matches a known counter/flag/id pattern
    /// (see [`Self::suggest_atomic_type`]).
    fn check_atomic_opportunity(&self, cs: &CriticalSectionV2) -> Option<LockSuggestion> {
        let unique_fields = cs.unique_fields();
        if unique_fields.len() != 1 {
            return None;
        }
        let field = unique_fields[0];
        let access_kind = cs.field_access_kind(field);
        let suggested_type = self.suggest_atomic_type(field, access_kind)?;
        Some(LockSuggestion::UseAtomic {
            field: field.to_string(),
            current_type: None,
            suggested_type,
            line: cs.start_line,
        })
    }

    /// Maps a field name to an atomic type name using substring heuristics on
    /// the lower-cased name, or returns `None` when nothing matches.
    ///
    /// Groups are tried in order — integer-like (`AtomicUsize`), boolean-like
    /// (`AtomicBool`), then id/index-like (`AtomicU64`) — and the first match
    /// wins. `_access_kind` is currently unused.
    ///
    /// NOTE(review): matching is by substring, so names such as `width` or
    /// `valid` match `"id"` and get `AtomicU64`; word-boundary matching would
    /// be more precise.
    fn suggest_atomic_type(&self, field: &str, _access_kind: Option<AccessKind>) -> Option<String> {
        // "counter" is covered by "count", so it is not listed separately.
        const USIZE_HINTS: &[&str] = &["count", "num", "total", "size", "len"];
        const BOOL_HINTS: &[&str] = &["flag", "enabled", "active", "ready", "done", "is_"];
        const U64_HINTS: &[&str] = &["id", "index", "seq"];

        /// True when `name` contains any of `hints` as a substring.
        fn matches_any(name: &str, hints: &[&str]) -> bool {
            hints.iter().any(|hint| name.contains(*hint))
        }

        let field_lower = field.to_lowercase();
        let atomic = if matches_any(&field_lower, USIZE_HINTS) {
            "AtomicUsize"
        } else if matches_any(&field_lower, BOOL_HINTS) {
            "AtomicBool"
        } else if matches_any(&field_lower, U64_HINTS) {
            "AtomicU64"
        } else {
            return None;
        };
        Some(atomic.to_string())
    }

    /// Suggests splitting the lock when at least two distinct fields with a
    /// known access kind are touched.
    ///
    /// Read-only fields get `Arc<RwLock<_>>` and written fields get
    /// `Arc<Mutex<_>>`. Fields with no recorded access kind are skipped.
    fn check_split_opportunity(&self, cs: &CriticalSectionV2) -> Option<LockSuggestion> {
        let unique_fields = cs.unique_fields();
        if unique_fields.len() < 2 {
            return None;
        }
        let mut suggested_splits = Vec::new();
        for field in unique_fields {
            let wrapper = match cs.field_access_kind(field) {
                Some(AccessKind::Read) => "Arc<RwLock<_>>",
                Some(AccessKind::Write) | Some(AccessKind::ReadWrite) => "Arc<Mutex<_>>",
                None => continue,
            };
            suggested_splits.push((field.to_string(), wrapper.to_string()));
        }
        if suggested_splits.len() < 2 {
            return None;
        }
        Some(LockSuggestion::SplitLock {
            lock_name: cs.acquisition.lock_name.clone(),
            suggested_splits,
            line: cs.acquisition.line,
        })
    }

    /// Suggests converting a mutex (`Mutex`, `ParkingLotMutex`, `TokioMutex`)
    /// to an `RwLock` when reads outnumber writes more than 2:1 and there are
    /// at least three reads.
    ///
    /// A `ReadWrite` access counts as one read and one write.
    fn check_rwlock_opportunity(&self, cs: &CriticalSectionV2) -> Option<LockSuggestion> {
        if !matches!(
            cs.acquisition.lock_type,
            LockType::Mutex | LockType::ParkingLotMutex | LockType::TokioMutex
        ) {
            return None;
        }
        let mut read_count = 0;
        let mut write_count = 0;
        for access in &cs.field_accesses {
            match access.access_kind {
                AccessKind::Read => read_count += 1,
                AccessKind::Write => write_count += 1,
                AccessKind::ReadWrite => {
                    read_count += 1;
                    write_count += 1;
                }
            }
        }
        if read_count > write_count * 2 && read_count >= 3 {
            Some(LockSuggestion::UseRwLock {
                lock_name: cs.acquisition.lock_name.clone(),
                read_count,
                write_count,
                line: cs.acquisition.line,
            })
        } else {
            None
        }
    }

    /// Suggests narrowing a guard's scope when the section qualifies.
    ///
    /// A section qualifies when it is closed, spans more than 5 lines, and
    /// contains expensive operations. It must also hold the lock more than two
    /// lines before the first field access or more than two lines after the
    /// last one.
    ///
    /// The suggested span brackets the field accesses with one line on each
    /// side. It is clamped to the current span, so it never extends the lock's
    /// scope.
    fn check_scope_reduction(&self, cs: &CriticalSectionV2) -> Option<LockSuggestion> {
        let end_line = cs.end_line?;
        let span = end_line.saturating_sub(cs.start_line);
        if span <= 5 || !cs.contains_expensive_ops {
            return None;
        }
        let first_access = cs.field_accesses.iter().map(|a| a.line).min()?;
        let last_access = cs.field_accesses.iter().map(|a| a.line).max()?;
        let idle_before = first_access > cs.start_line.saturating_add(2);
        let idle_after = end_line > last_access.saturating_add(2);
        if !(idle_before || idle_after) {
            return None;
        }
        // Clamp to the current span. Otherwise an access on the section's
        // first (or last) line would produce a suggestion that starts before
        // the lock is acquired (or ends after it is released).
        let suggested_start = first_access.saturating_sub(1).max(cs.start_line);
        let suggested_end = last_access.saturating_add(1).min(end_line);
        Some(LockSuggestion::ReduceScope {
            guard_name: cs.acquisition.guard_name.clone(),
            current_span: (cs.start_line, end_line),
            suggested_span: (suggested_start, suggested_end),
            reason: "lock held across non-critical operations".to_string(),
        })
    }

    /// Counts the tracker's critical sections by lock family.
    ///
    /// Also totals their field accesses and records the largest span reported
    /// by `CriticalSectionV2::span()`.
    pub fn stats(&self) -> LockStatsV2 {
        let sections = self.tracker.critical_sections();
        let mut mutex_count = 0;
        let mut rwlock_count = 0;
        let mut refcell_count = 0;
        let mut total_field_accesses = 0;
        let mut max_cs_span = 0u32;
        for cs in sections {
            match cs.acquisition.lock_type {
                LockType::Mutex | LockType::ParkingLotMutex | LockType::TokioMutex => {
                    mutex_count += 1
                }
                LockType::RwLockRead
                | LockType::RwLockWrite
                | LockType::ParkingLotRwLock
                | LockType::TokioRwLock => rwlock_count += 1,
                LockType::RefCell | LockType::RefCellMut => refcell_count += 1,
            }
            total_field_accesses += cs.field_accesses.len();
            // Sections without a known span do not affect the maximum.
            if let Some(span) = cs.span() {
                max_cs_span = max_cs_span.max(span);
            }
        }
        LockStatsV2 {
            total_locks: sections.len(),
            mutex_count,
            rwlock_count,
            refcell_count,
            total_field_accesses,
            max_cs_span,
        }
    }

    /// Returns the tracker this analyzer reads from.
    pub fn tracker(&self) -> &LockTrackerV2 {
        self.tracker
    }
}
/// Aggregate counters over the critical sections recorded by a tracker, as
/// filled in by `LockGranularityAnalyzerV2::stats`.
#[derive(Clone, Debug, Default)]
pub struct LockStatsV2 {
    /// Number of critical sections examined.
    pub total_locks: usize,
    /// Sections guarded by `Mutex`, `ParkingLotMutex`, or `TokioMutex`.
    pub mutex_count: usize,
    /// Sections guarded by any `RwLock` variant (std read/write, parking_lot, tokio).
    pub rwlock_count: usize,
    /// Sections holding a `RefCell` borrow, shared or mutable.
    pub refcell_count: usize,
    /// Sum of recorded field accesses across all sections.
    pub total_field_accesses: usize,
    /// Largest section span seen; stays 0 when no span is known.
    pub max_cs_span: u32,
}
#[cfg(test)]
mod tests {
    use super::super::{LockAcquisitionV2, VarSymbolMapping};
    use super::*;
    use crate::symbol::SymbolId;
    use crate::VarId;
    use slotmap::SlotMap;

    /// Hands out fresh `VarId`s backed by throwaway symbols for test fixtures.
    struct TestVars {
        symbols: SlotMap<SymbolId, &'static str>,
        mapping: VarSymbolMapping,
    }

    impl TestVars {
        fn new() -> Self {
            TestVars {
                symbols: SlotMap::with_key(),
                mapping: VarSymbolMapping::new(),
            }
        }

        /// Registers a symbol named `name` and returns its variable id.
        fn var(&mut self, name: &'static str) -> VarId {
            let symbol = self.symbols.insert(name);
            self.mapping.register(symbol)
        }

        /// Allocates a `(lock, guard)` variable pair, lock first.
        fn lock_and_guard(&mut self, lock: &'static str, guard: &'static str) -> (VarId, VarId) {
            let lock_id = self.var(lock);
            let guard_id = self.var(guard);
            (lock_id, guard_id)
        }
    }

    /// Runs the analyzer over `tracker` and returns its sorted suggestions.
    fn suggestions_for(tracker: &LockTrackerV2) -> Vec<LockSuggestion> {
        LockGranularityAnalyzerV2::new(tracker).analyze()
    }

    #[test]
    fn test_atomic_suggestion_counter() {
        let mut vars = TestVars::new();
        let (lock_id, guard_id) = vars.lock_and_guard("lock", "guard");

        let mut tracker = LockTrackerV2::new();
        tracker.acquire(LockAcquisitionV2::new(
            lock_id,
            guard_id,
            LockType::Mutex,
            10,
            "mutex",
            "guard",
        ));
        tracker.record_field_access(guard_id, "counter", AccessKind::Write, 11);
        tracker.release(guard_id, 15);

        let found = suggestions_for(&tracker).iter().any(|s| {
            matches!(s, LockSuggestion::UseAtomic { field, .. } if field == "counter")
        });
        assert!(found);
    }

    #[test]
    fn test_rwlock_suggestion() {
        let mut vars = TestVars::new();
        let (lock_id, guard_id) = vars.lock_and_guard("lock", "guard");

        let mut tracker = LockTrackerV2::new();
        tracker.acquire(LockAcquisitionV2::new(
            lock_id,
            guard_id,
            LockType::Mutex,
            10,
            "cache",
            "guard",
        ));
        // Four reads on consecutive lines, then a single write.
        for line in 11..=14 {
            tracker.record_field_access(guard_id, "data", AccessKind::Read, line);
        }
        tracker.record_field_access(guard_id, "data", AccessKind::Write, 15);
        tracker.release(guard_id, 20);

        let found = suggestions_for(&tracker).iter().any(|s| {
            matches!(s, LockSuggestion::UseRwLock { read_count, write_count, .. }
                if *read_count == 4 && *write_count == 1)
        });
        assert!(found);
    }

    #[test]
    fn test_split_lock_suggestion() {
        let mut vars = TestVars::new();
        let (lock_id, guard_id) = vars.lock_and_guard("lock", "guard");

        let mut tracker = LockTrackerV2::new();
        tracker.acquire(LockAcquisitionV2::new(
            lock_id,
            guard_id,
            LockType::Mutex,
            10,
            "state",
            "guard",
        ));
        tracker.record_field_access(guard_id, "counter", AccessKind::Write, 11);
        tracker.record_field_access(guard_id, "name", AccessKind::Read, 12);
        tracker.record_field_access(guard_id, "config", AccessKind::Read, 13);
        tracker.release(guard_id, 20);

        let found = suggestions_for(&tracker).iter().any(|s| {
            matches!(s, LockSuggestion::SplitLock { suggested_splits, .. }
                if suggested_splits.len() == 3)
        });
        assert!(found);
    }

    #[test]
    fn test_lock_stats() {
        let mut vars = TestVars::new();
        let (mutex_lock, mutex_guard) = vars.lock_and_guard("lock1", "guard1");
        let (rw_lock, rw_guard) = vars.lock_and_guard("lock2", "guard2");

        let mut tracker = LockTrackerV2::new();
        tracker.acquire(LockAcquisitionV2::new(
            mutex_lock,
            mutex_guard,
            LockType::Mutex,
            10,
            "m1",
            "g1",
        ));
        tracker.record_field_access(mutex_guard, "field1", AccessKind::Write, 11);
        tracker.release(mutex_guard, 15);

        tracker.acquire(LockAcquisitionV2::new(
            rw_lock,
            rw_guard,
            LockType::RwLockRead,
            20,
            "r1",
            "g2",
        ));
        tracker.record_field_access(rw_guard, "field2", AccessKind::Read, 21);
        tracker.release(rw_guard, 25);

        let stats = LockGranularityAnalyzerV2::new(&tracker).stats();
        assert_eq!(stats.total_locks, 2);
        assert_eq!(stats.mutex_count, 1);
        assert_eq!(stats.rwlock_count, 1);
        assert_eq!(stats.total_field_accesses, 2);
    }
}