Skip to main content

ryo_mutations/idiom/
lock.rs

1//! Lock Optimization Mutations
2//!
3//! Performance and safety mutations for lock usage:
4//!
5//! - `UseAtomicMutation`: Replace Mutex with atomic types for simple fields
6//! - `UseRwLockMutation`: Replace Mutex with RwLock for read-heavy access
7//! - `LockScopeMutation`: Detect locks held across await points
8//!
9//! # Note
10//!
11//! These mutations currently only detect opportunities. The actual refactoring
12//! is not yet implemented (TODO).
13
14use ryo_source::pure::{PureFields, PureFile, PureItem, PureType};
15use ryo_symbol::SymbolId;
16
17use super::detect::{Detect, DetectCategory, DetectLocation, DetectOperation, DetectOpportunity};
18use crate::Mutation;
19
20// ============================================================================
21// UseAtomicMutation
22// ============================================================================
23
/// Detector that proposes swapping `Mutex<T>` for a `std::sync::atomic` type
/// when a field looks like a simple counter or flag.
///
/// Detection looks at `Mutex` fields whose names read like counters, flags or
/// ids — the kind of state an `AtomicUsize`, `AtomicBool`, … can hold directly.
///
/// # Example
///
/// ```rust,ignore
/// // Before
/// struct Counter {
///     count: Mutex<usize>,
///     is_ready: Mutex<bool>,
/// }
///
/// // After (suggested)
/// struct Counter {
///     count: AtomicUsize,
///     is_ready: AtomicBool,
/// }
/// ```
#[derive(Clone, Debug, Default)]
pub struct UseAtomicMutation {
    /// When `Some`, restricts detection to the struct with exactly this name;
    /// `None` scans every struct.
    pub target_struct: Option<String>,
}
48
49impl UseAtomicMutation {
50    pub fn new() -> Self {
51        Self::default()
52    }
53
54    /// Only apply to a specific struct
55    pub fn for_struct(mut self, name: impl Into<String>) -> Self {
56        self.target_struct = Some(name.into());
57        self
58    }
59
60    /// Check if field name suggests atomic usage
61    fn is_atomic_candidate(field_name: &str) -> Option<&'static str> {
62        let lower = field_name.to_lowercase();
63
64        // Counter-like
65        if lower.contains("count")
66            || lower.contains("counter")
67            || lower.contains("num")
68            || lower.contains("total")
69            || lower.contains("size")
70            || lower.contains("len")
71        {
72            return Some("AtomicUsize");
73        }
74
75        // Flag-like
76        if lower.contains("flag")
77            || lower.contains("enabled")
78            || lower.contains("active")
79            || lower.contains("ready")
80            || lower.contains("done")
81            || lower.starts_with("is_")
82        {
83            return Some("AtomicBool");
84        }
85
86        // ID-like
87        if lower.contains("id") || lower.contains("index") || lower.contains("seq") {
88            return Some("AtomicU64");
89        }
90
91        None
92    }
93
94    /// Detect atomic opportunities in a file
95    fn detect_opportunities(&self, file: &PureFile) -> Vec<AtomicOpportunity> {
96        let mut opportunities = Vec::new();
97
98        for item in &file.items {
99            if let PureItem::Struct(s) = item {
100                // Apply target filter
101                if let Some(ref target) = self.target_struct {
102                    if &s.name != target {
103                        continue;
104                    }
105                }
106
107                if let PureFields::Named(fields) = &s.fields {
108                    for field in fields {
109                        let type_str = match &field.ty {
110                            PureType::Path(p) => p.as_str(),
111                            _ => continue,
112                        };
113
114                        if type_str.contains("Mutex<") {
115                            if let Some(atomic_type) = Self::is_atomic_candidate(&field.name) {
116                                opportunities.push(AtomicOpportunity {
117                                    struct_name: s.name.clone(),
118                                    field_name: field.name.clone(),
119                                    suggested_type: atomic_type.to_string(),
120                                });
121                            }
122                        }
123                    }
124                }
125            }
126        }
127
128        opportunities
129    }
130}
131
/// One `Mutex` field that could be replaced by an atomic type.
#[derive(Debug)]
struct AtomicOpportunity {
    /// Name of the struct that owns the field.
    struct_name: String,
    /// Name of the `Mutex`-typed field.
    field_name: String,
    /// Atomic type proposed as the replacement (e.g. `AtomicUsize`).
    suggested_type: String,
}
138
139impl Mutation for UseAtomicMutation {
140    fn describe(&self) -> String {
141        "Replace Mutex<T> with atomic types for simple counter/flag fields".to_string()
142    }
143
144    fn mutation_type(&self) -> &'static str {
145        "UseAtomic"
146    }
147
148    fn box_clone(&self) -> Box<dyn Mutation> {
149        Box::new(self.clone())
150    }
151}
152
153impl Detect for UseAtomicMutation {
154    fn detect(&self, file: &PureFile) -> Vec<DetectOpportunity> {
155        self.detect_opportunities(file)
156            .into_iter()
157            .map(|o| {
158                DetectOpportunity::new(
159                    DetectLocation::struct_item(&o.struct_name),
160                    format!(
161                        "Consider using {} for field '{}' instead of Mutex",
162                        o.suggested_type, o.field_name
163                    ),
164                )
165                .with_operations(vec![DetectOperation::Refactor])
166                .with_confidence(0.7)
167                .with_context(format!(
168                    "field:{},suggested:{}",
169                    o.field_name, o.suggested_type
170                ))
171            })
172            .collect()
173    }
174
175    fn category(&self) -> DetectCategory {
176        DetectCategory::Performance
177    }
178
179    fn detect_name(&self) -> &'static str {
180        "UseAtomic"
181    }
182
183    fn detect_description(&self) -> &str {
184        "Replace Mutex<T> with atomic types for simple counter/flag fields"
185    }
186}
187
188// ============================================================================
189// UseRwLockMutation
190// ============================================================================
191
/// Detector that proposes swapping `Mutex` for `RwLock` where access is
/// expected to be read-mostly.
///
/// Detection looks at `Mutex` fields wrapping collections (`HashMap`, `Vec`, …)
/// or fields named like caches/registries/stores, which are typically read more
/// often than they are written.
///
/// # Example
///
/// ```rust,ignore
/// // Before
/// struct Cache {
///     data: Mutex<HashMap<String, Value>>,
/// }
///
/// // After (suggested)
/// struct Cache {
///     data: RwLock<HashMap<String, Value>>,
/// }
/// ```
#[derive(Clone, Debug, Default)]
pub struct UseRwLockMutation {
    /// When `Some`, restricts detection to the struct with exactly this name;
    /// `None` scans every struct.
    pub target_struct: Option<String>,
}
215
216impl UseRwLockMutation {
217    pub fn new() -> Self {
218        Self::default()
219    }
220
221    /// Only apply to a specific struct
222    pub fn for_struct(mut self, name: impl Into<String>) -> Self {
223        self.target_struct = Some(name.into());
224        self
225    }
226
227    /// Detect RwLock opportunities in a file
228    fn detect_opportunities(&self, file: &PureFile) -> Vec<RwLockOpportunity> {
229        let mut opportunities = Vec::new();
230
231        for item in &file.items {
232            if let PureItem::Struct(s) = item {
233                // Apply target filter
234                if let Some(ref target) = self.target_struct {
235                    if &s.name != target {
236                        continue;
237                    }
238                }
239
240                if let PureFields::Named(fields) = &s.fields {
241                    for field in fields {
242                        let type_str = match &field.ty {
243                            PureType::Path(p) => p.as_str(),
244                            _ => continue,
245                        };
246
247                        // Check for Mutex<Collection> patterns
248                        if type_str.contains("Mutex<")
249                            && (type_str.contains("HashMap")
250                                || type_str.contains("BTreeMap")
251                                || type_str.contains("Vec<")
252                                || type_str.contains("HashSet")
253                                || field.name.to_lowercase().contains("cache")
254                                || field.name.to_lowercase().contains("registry")
255                                || field.name.to_lowercase().contains("store"))
256                        {
257                            opportunities.push(RwLockOpportunity {
258                                struct_name: s.name.clone(),
259                                field_name: field.name.clone(),
260                            });
261                        }
262                    }
263                }
264            }
265        }
266
267        opportunities
268    }
269}
270
/// One `Mutex` field that could be replaced by an `RwLock`.
#[derive(Debug)]
struct RwLockOpportunity {
    /// Name of the struct that owns the field.
    struct_name: String,
    /// Name of the `Mutex`-typed field.
    field_name: String,
}
276
277impl Mutation for UseRwLockMutation {
278    fn describe(&self) -> String {
279        "Replace Mutex with RwLock for read-heavy data structures".to_string()
280    }
281
282    fn mutation_type(&self) -> &'static str {
283        "UseRwLock"
284    }
285
286    fn box_clone(&self) -> Box<dyn Mutation> {
287        Box::new(self.clone())
288    }
289}
290
291impl Detect for UseRwLockMutation {
292    fn detect(&self, file: &PureFile) -> Vec<DetectOpportunity> {
293        self.detect_opportunities(file)
294            .into_iter()
295            .map(|o| {
296                DetectOpportunity::new(
297                    DetectLocation::struct_item(&o.struct_name),
298                    format!(
299                        "Consider using RwLock for field '{}' if reads outnumber writes",
300                        o.field_name
301                    ),
302                )
303                .with_operations(vec![DetectOperation::Refactor])
304                .with_confidence(0.5)
305                .with_context(format!("field:{}", o.field_name))
306            })
307            .collect()
308    }
309
310    fn category(&self) -> DetectCategory {
311        DetectCategory::Performance
312    }
313
314    fn detect_name(&self) -> &'static str {
315        "UseRwLock"
316    }
317
318    fn detect_description(&self) -> &str {
319        "Replace Mutex with RwLock for read-heavy data structures"
320    }
321}
322
323// ============================================================================
324// LockScopeMutation
325// ============================================================================
326
327/// Detects locks held across await points or with unnecessarily wide scope.
328///
329/// This is a safety pattern that helps prevent deadlocks and improves
330/// concurrency by reducing lock hold times.
331///
332/// # Example
333///
334/// ```rust,ignore
335/// // Problematic: lock held across await
336/// async fn bad() {
337///     let guard = self.data.lock().unwrap();
338///     some_async_operation().await; // Lock still held!
339///     drop(guard);
340/// }
341///
342/// // Better: release lock before await
343/// async fn good() {
344///     let value = {
345///         let guard = self.data.lock().unwrap();
346///         guard.clone()
347///     }; // Lock released
348///     some_async_operation().await;
349/// }
350/// ```
351#[derive(Debug, Clone, Default)]
352pub struct LockScopeMutation {
353    /// Target function SymbolId. If None, applies to all functions.
354    pub target_fn: Option<SymbolId>,
355}
356
357impl LockScopeMutation {
358    pub fn new() -> Self {
359        Self::default()
360    }
361
362    /// Only apply to a specific function
363    pub fn for_fn(mut self, id: SymbolId) -> Self {
364        self.target_fn = Some(id);
365        self
366    }
367
368    /// Detect lock scope issues in a file
369    fn detect_opportunities(&self, file: &PureFile) -> Vec<LockScopeOpportunity> {
370        let mut opportunities = Vec::new();
371
372        for item in &file.items {
373            if let PureItem::Impl(impl_block) = item {
374                for impl_item in &impl_block.items {
375                    if let ryo_source::pure::PureImplItem::Fn(func) = impl_item {
376                        // Note: target_fn filtering requires SymbolId comparison at executor layer.
377                        // This method is called from executor with pre-filtered functions.
378
379                        if func.is_async {
380                            // Heuristic: check if body contains .lock() and .await
381                            let body_str = format!("{:?}", func.body);
382
383                            if body_str.contains("lock()")
384                                && (body_str.contains(".await") || body_str.contains("await"))
385                            {
386                                opportunities.push(LockScopeOpportunity {
387                                    impl_type: impl_block.self_ty.clone(),
388                                    fn_name: func.name.clone(),
389                                    issue: "lock_across_await".to_string(),
390                                });
391                            }
392                        }
393                    }
394                }
395            }
396        }
397
398        opportunities
399    }
400}
401
/// One async method suspected of holding a lock too long.
#[derive(Debug)]
struct LockScopeOpportunity {
    /// The `self_ty` text of the enclosing `impl` block.
    impl_type: String,
    /// Name of the async method.
    fn_name: String,
    /// Short issue code (currently always `lock_across_await`).
    issue: String,
}
408
409impl Mutation for LockScopeMutation {
410    fn describe(&self) -> String {
411        "Detect locks held across await points or with unnecessarily wide scope".to_string()
412    }
413
414    fn mutation_type(&self) -> &'static str {
415        "LockScope"
416    }
417
418    fn box_clone(&self) -> Box<dyn Mutation> {
419        Box::new(self.clone())
420    }
421}
422
423impl Detect for LockScopeMutation {
424    fn detect(&self, file: &PureFile) -> Vec<DetectOpportunity> {
425        self.detect_opportunities(file)
426            .into_iter()
427            .map(|o| {
428                DetectOpportunity::new(
429                    DetectLocation::fn_item(&o.fn_name),
430                    format!(
431                        "Async method '{}::{}' may hold lock across await point",
432                        o.impl_type, o.fn_name
433                    ),
434                )
435                .with_operations(vec![DetectOperation::Refactor])
436                .with_confidence(0.6)
437                .with_context(o.issue)
438            })
439            .collect()
440    }
441
442    fn category(&self) -> DetectCategory {
443        DetectCategory::Safety
444    }
445
446    fn detect_name(&self) -> &'static str {
447        "LockScope"
448    }
449
450    fn detect_description(&self) -> &str {
451        "Detect locks held across await points or with unnecessarily wide scope"
452    }
453}