ryo-mutations 0.1.0

[experimental] Code transformation primitives for Rust source code
Documentation
//! BoolSimplifyMutation: Simplify boolean comparisons
//!
//! Transforms:
//! - `x == true` → `x`
//! - `x == false` → `!x`
//! - `true == x` → `x`
//! - `false == x` → `!x`
//! - `x != true` → `!x`
//! - `x != false` → `x`
//!
//! Corresponds to Clippy lint: `clippy::bool_comparison`

use ryo_source::pure::{PureBlock, PureExpr, PureStmt};
use ryo_symbol::SymbolId;

use crate::Mutation;

/// Rewrites comparisons against a boolean literal into idiomatic Rust.
///
/// `flag == true` becomes `flag` and `flag == false` becomes `!flag`; the
/// mirrored forms (`true == flag`) and the `!=` forms are handled likewise.
///
/// # Example
///
/// ```rust,ignore
/// use ryo_mutations::idiom::BoolSimplifyMutation;
///
/// let mutation = BoolSimplifyMutation::new();
/// // Transforms: if x == true { ... }
/// // Into:       if x { ... }
/// ```
#[derive(Clone, Debug, Default)]
pub struct BoolSimplifyMutation {
    /// Optional function filter; `None` is intended to mean "every function".
    /// NOTE(review): nothing in the code shown alongside this type reads the
    /// field — confirm the filter is honoured wherever the mutation is applied.
    pub target_fn: Option<SymbolId>,
}

impl BoolSimplifyMutation {
    /// Creates a mutation with no function filter (`target_fn` is `None`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Only apply in a specific function.
    ///
    /// Builder-style setter: stores `id` in `target_fn` and returns `self`.
    pub fn in_function(mut self, id: SymbolId) -> Self {
        self.target_fn = Some(id);
        self
    }

    /// Returns `Some(value)` when `expr` is the boolean literal `true` or
    /// `false`, and `None` for anything else.
    ///
    /// The keyword is recognised both as a `PureExpr::Lit` and as a
    /// `PureExpr::Path`; presumably the parser may emit either form —
    /// NOTE(review): confirm against `ryo_source`.
    fn is_bool_literal(expr: &PureExpr) -> Option<bool> {
        let text = match expr {
            PureExpr::Lit(text) | PureExpr::Path(text) => text.as_str(),
            _ => return None,
        };
        match text {
            "true" => Some(true),
            "false" => Some(false),
            _ => None,
        }
    }

    /// Simplifies every comparison against a boolean literal inside `expr`,
    /// rewriting the tree in place.
    ///
    /// Returns the number of comparisons that were simplified. The rewritten
    /// expression is left in `expr`; nothing else is returned.
    fn transform_expr(&self, expr: &mut PureExpr) -> usize {
        // Comparison patterns: `x == true`, `false != x`, etc.
        if let PureExpr::Binary { op, left, right } = expr {
            let is_eq = op == "==";
            if is_eq || op == "!=" {
                // Pick the operand that survives the rewrite. A literal on the
                // left is checked first, so `true == false` keeps the right side.
                let hit = match (Self::is_bool_literal(left), Self::is_bool_literal(right)) {
                    (Some(bool_val), _) => Some((bool_val, right)),
                    (None, Some(bool_val)) => Some((bool_val, left)),
                    (None, None) => None,
                };

                if let Some((bool_val, operand_slot)) = hit {
                    let mut operand = std::mem::replace(
                        operand_slot.as_mut(),
                        PureExpr::Path("__placeholder".to_string()),
                    );
                    // The surviving operand may itself contain comparisons
                    // (`(x == true) == false`, `f(y == true) == true`); simplify
                    // them too instead of returning early and missing them.
                    let nested = self.transform_expr(&mut operand);
                    *expr = Self::simplify(operand, bool_val, is_eq);
                    return 1 + nested;
                }
            }
        }

        // Otherwise recurse into sub-expressions.
        let mut changes = 0;
        match expr {
            PureExpr::Binary { left, right, .. } => {
                changes += self.transform_expr(left);
                changes += self.transform_expr(right);
            }
            PureExpr::Unary { expr: inner, .. } => {
                changes += self.transform_expr(inner);
            }
            PureExpr::Call { func, args } => {
                changes += self.transform_expr(func);
                for arg in args {
                    changes += self.transform_expr(arg);
                }
            }
            PureExpr::MethodCall { receiver, args, .. } => {
                changes += self.transform_expr(receiver);
                for arg in args {
                    changes += self.transform_expr(arg);
                }
            }
            PureExpr::Field { expr: inner, .. } => {
                changes += self.transform_expr(inner);
            }
            PureExpr::Index { expr: inner, index } => {
                changes += self.transform_expr(inner);
                changes += self.transform_expr(index);
            }
            PureExpr::Block { block, .. } => {
                changes += self.transform_block(block);
            }
            PureExpr::If {
                cond,
                then_branch,
                else_branch,
            } => {
                changes += self.transform_expr(cond);
                changes += self.transform_block(then_branch);
                if let Some(else_expr) = else_branch {
                    changes += self.transform_expr(else_expr);
                }
            }
            PureExpr::Match { expr: e, arms } => {
                // Only the scrutinee and each arm body are visited.
                changes += self.transform_expr(e);
                for arm in arms {
                    changes += self.transform_expr(&mut arm.body);
                }
            }
            // NOTE(review): only the body of `While` is visited. If the variant
            // carries a loop condition (e.g. `while flag == true`), it is not
            // simplified here; the field name is not visible in this file.
            PureExpr::Loop { body: block, .. } | PureExpr::While { body: block, .. } => {
                changes += self.transform_block(block);
            }
            PureExpr::For {
                expr: iter_expr,
                body,
                ..
            } => {
                changes += self.transform_expr(iter_expr);
                changes += self.transform_block(body);
            }
            PureExpr::Closure { body, .. } => {
                changes += self.transform_expr(body);
            }
            PureExpr::Tuple(exprs) | PureExpr::Array(exprs) => {
                for e in exprs {
                    changes += self.transform_expr(e);
                }
            }
            PureExpr::Struct { fields, .. } => {
                for (_, e) in fields {
                    changes += self.transform_expr(e);
                }
            }
            PureExpr::Ref { expr: inner, .. } => {
                changes += self.transform_expr(inner);
            }
            PureExpr::Return(Some(inner)) => {
                changes += self.transform_expr(inner);
            }
            PureExpr::Try(inner) | PureExpr::Await(inner) => {
                changes += self.transform_expr(inner);
            }
            _ => {}
        }

        changes
    }

    /// Builds the simplified form of `expr <op> bool_val`, where `<op>` is
    /// `==` when `is_eq` and `!=` otherwise.
    ///
    /// | comparison   | result |
    /// |--------------|--------|
    /// | `x == true`  | `x`    |
    /// | `x == false` | `!x`   |
    /// | `x != true`  | `!x`   |
    /// | `x != false` | `x`    |
    fn simplify(expr: PureExpr, bool_val: bool, is_eq: bool) -> PureExpr {
        // The operand is kept as-is exactly when `is_eq` and `bool_val` agree.
        if is_eq == bool_val {
            expr
        } else {
            // NOTE(review): if `expr` is a binary expression (e.g. `a < b`),
            // correct source output relies on the printer parenthesizing the
            // operand of `!`; that logic is not visible here.
            PureExpr::Unary {
                op: "!".to_string(),
                expr: Box::new(expr),
            }
        }
    }

    /// Simplifies boolean comparisons in every statement of `block`.
    ///
    /// Returns the total number of comparisons simplified.
    pub fn transform_block(&self, block: &mut PureBlock) -> usize {
        let mut changes = 0;
        for stmt in &mut block.stmts {
            changes += self.transform_stmt(stmt);
        }
        changes
    }

    /// Simplifies comparisons in a single statement.
    ///
    /// Visits `let` initializers and expression statements (with or without a
    /// trailing `;`); a `let` without an initializer and any other statement
    /// kind contribute zero changes.
    fn transform_stmt(&self, stmt: &mut PureStmt) -> usize {
        match stmt {
            PureStmt::Local { init: Some(e), .. } => self.transform_expr(e),
            PureStmt::Semi(e) | PureStmt::Expr(e) => self.transform_expr(e),
            _ => 0,
        }
    }
}

impl Mutation for BoolSimplifyMutation {
    /// Stable identifier for this mutation kind.
    fn mutation_type(&self) -> &'static str {
        "BoolSimplify"
    }

    /// Human-readable summary of what the mutation does.
    fn describe(&self) -> String {
        String::from("Simplify boolean comparisons (x == true → x)")
    }

    /// Clones this mutation into a trait object.
    fn box_clone(&self) -> Box<dyn Mutation> {
        let copy: Self = Clone::clone(self);
        Box::new(copy)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `PureExpr::Path` shorthand.
    fn path(name: &str) -> PureExpr {
        PureExpr::Path(name.to_string())
    }

    /// `PureExpr::Lit` shorthand.
    fn lit(text: &str) -> PureExpr {
        PureExpr::Lit(text.to_string())
    }

    /// `PureExpr::Binary` shorthand.
    fn binary(op: &str, left: PureExpr, right: PureExpr) -> PureExpr {
        PureExpr::Binary {
            op: op.to_string(),
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// True when `expr` is exactly the path `name`.
    fn is_path(expr: &PureExpr, name: &str) -> bool {
        matches!(expr, PureExpr::Path(s) if s == name)
    }

    /// True when `expr` is `!name`.
    fn is_not_of(expr: &PureExpr, name: &str) -> bool {
        matches!(
            expr,
            PureExpr::Unary { op, expr: inner } if op == "!" && is_path(&**inner, name)
        )
    }

    #[test]
    fn test_is_bool_literal() {
        assert_eq!(
            BoolSimplifyMutation::is_bool_literal(&PureExpr::Lit("true".to_string())),
            Some(true)
        );
        assert_eq!(
            BoolSimplifyMutation::is_bool_literal(&PureExpr::Lit("false".to_string())),
            Some(false)
        );
        assert_eq!(
            BoolSimplifyMutation::is_bool_literal(&PureExpr::Path("true".to_string())),
            Some(true)
        );
        assert_eq!(
            BoolSimplifyMutation::is_bool_literal(&PureExpr::Lit("42".to_string())),
            None
        );
    }

    #[test]
    fn test_is_bool_literal_other_forms() {
        assert_eq!(BoolSimplifyMutation::is_bool_literal(&path("false")), Some(false));
        assert_eq!(BoolSimplifyMutation::is_bool_literal(&path("x")), None);
        assert_eq!(BoolSimplifyMutation::is_bool_literal(&lit("\"true\"")), None);
    }

    #[test]
    fn test_simplify_eq_true() {
        // x == true => x
        let expr = PureExpr::Path("x".to_string());
        let result = BoolSimplifyMutation::simplify(expr, true, true);
        assert!(matches!(result, PureExpr::Path(s) if s == "x"));
    }

    #[test]
    fn test_simplify_eq_false() {
        // x == false => !x
        let expr = PureExpr::Path("x".to_string());
        let result = BoolSimplifyMutation::simplify(expr, false, true);
        assert!(matches!(result, PureExpr::Unary { op, .. } if op == "!"));
    }

    #[test]
    fn test_simplify_neq_true() {
        // x != true => !x
        let expr = PureExpr::Path("x".to_string());
        let result = BoolSimplifyMutation::simplify(expr, true, false);
        assert!(matches!(result, PureExpr::Unary { op, .. } if op == "!"));
    }

    #[test]
    fn test_simplify_neq_false() {
        // x != false => x
        let expr = PureExpr::Path("x".to_string());
        let result = BoolSimplifyMutation::simplify(expr, false, false);
        assert!(matches!(result, PureExpr::Path(s) if s == "x"));
    }

    #[test]
    fn test_transform_literal_on_right() {
        // x == true => x
        let mutation = BoolSimplifyMutation::new();
        let mut expr = binary("==", path("x"), lit("true"));
        assert_eq!(mutation.transform_expr(&mut expr), 1);
        assert!(is_path(&expr, "x"));
    }

    #[test]
    fn test_transform_literal_on_left() {
        // false == x => !x
        let mutation = BoolSimplifyMutation::new();
        let mut expr = binary("==", lit("false"), path("x"));
        assert_eq!(mutation.transform_expr(&mut expr), 1);
        assert!(is_not_of(&expr, "x"));
    }

    #[test]
    fn test_transform_neq_true() {
        // x != true => !x
        let mutation = BoolSimplifyMutation::new();
        let mut expr = binary("!=", path("x"), path("true"));
        assert_eq!(mutation.transform_expr(&mut expr), 1);
        assert!(is_not_of(&expr, "x"));
    }

    #[test]
    fn test_transform_leaves_non_literal_comparison() {
        // x == y has no boolean literal and must stay untouched.
        let mutation = BoolSimplifyMutation::new();
        let mut expr = binary("==", path("x"), path("y"));
        assert_eq!(mutation.transform_expr(&mut expr), 0);
        assert!(matches!(expr, PureExpr::Binary { op, .. } if op == "=="));
    }

    #[test]
    fn test_transform_ignores_other_operators() {
        // x && true is not a comparison; only == / != are rewritten.
        let mutation = BoolSimplifyMutation::new();
        let mut expr = binary("&&", path("x"), lit("true"));
        assert_eq!(mutation.transform_expr(&mut expr), 0);
        assert!(matches!(expr, PureExpr::Binary { op, .. } if op == "&&"));
    }
}