use num_traits::Float;
#[derive(Debug, Clone)]
pub struct PowerOptConfig {
pub max_optimized_exponent: i32,
pub unsafe_optimizations: bool,
pub optimize_negative_exponents: bool,
}
impl Default for PowerOptConfig {
fn default() -> Self {
Self {
max_optimized_exponent: 10,
unsafe_optimizations: false,
optimize_negative_exponents: true,
}
}
}
pub fn try_convert_to_integer<T: Float>(value: T, tolerance: Option<f64>) -> Option<i32> {
let float_val = value.to_f64().unwrap_or(0.0);
let tol = tolerance.unwrap_or(1e-12);
if float_val.fract().abs() < tol && float_val.abs() <= 100.0 {
Some(float_val.round() as i32)
} else {
None
}
}
#[must_use]
pub fn generate_integer_power_string(
base_expr: &str,
exponent: i32,
config: &PowerOptConfig,
) -> String {
if exponent.abs() > config.max_optimized_exponent {
return format!("{base_expr}.powi({exponent})");
}
match exponent {
0 => "1.0".to_string(),
1 => base_expr.to_string(),
-1 if config.optimize_negative_exponents => format!("1.0 / {base_expr}"),
2 => format!("{base_expr} * {base_expr}"),
-2 if config.optimize_negative_exponents => {
format!("1.0 / ({base_expr} * {base_expr})")
}
3 => format!("{base_expr} * {base_expr} * {base_expr}"),
4 => {
if config.unsafe_optimizations {
format!("{{ let temp = {base_expr} * {base_expr}; temp * temp }}")
} else {
format!("{base_expr} * {base_expr} * {base_expr} * {base_expr}")
}
}
5 => format!("{{ let temp = {base_expr} * {base_expr}; temp * temp * {base_expr} }}"),
6 => format!("{{ let temp = {base_expr} * {base_expr} * {base_expr}; temp * temp }}"),
8 => format!(
"{{ let temp2 = {base_expr} * {base_expr}; let temp4 = temp2 * temp2; temp4 * temp4 }}"
),
10 => format!(
"{{ let temp5 = {base_expr} * {base_expr} * {base_expr} * {base_expr} * {base_expr}; temp5 * temp5 }}"
),
exp if exp < 0 && config.optimize_negative_exponents => {
format!(
"1.0 / ({})",
generate_integer_power_string(base_expr, -exp, config)
)
}
_ => format!("{base_expr}.powi({exponent})"),
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PowerStrategy {
MultiplicationChain,
RepeatedSquaring,
Generic,
}
#[must_use]
pub fn determine_power_strategy(exponent: i32, config: &PowerOptConfig) -> PowerStrategy {
let abs_exp = exponent.abs();
if abs_exp <= 6 {
PowerStrategy::MultiplicationChain
} else if abs_exp <= config.max_optimized_exponent && (abs_exp & (abs_exp - 1)) == 0 {
PowerStrategy::RepeatedSquaring
} else {
PowerStrategy::Generic
}
}
#[must_use]
pub fn generate_repeated_squaring_string(base_expr: &str, exponent: i32) -> String {
let abs_exp = exponent.abs();
if abs_exp == 1 {
return base_expr.to_string();
}
let mut result = format!("{{ let mut temp = {base_expr}; ");
let mut current_power = 1;
while current_power < abs_exp {
result.push_str("temp = temp * temp; ");
current_power *= 2;
}
result.push_str("temp }");
if exponent < 0 {
format!("1.0 / ({result})")
} else {
result
}
}
#[derive(Debug, Clone)]
pub struct PowerOptimizationInfo {
pub strategy: PowerStrategy,
pub optimized: bool,
pub exponent: i32,
pub performance_gain: f64,
}
#[must_use]
pub fn analyze_power_optimization(exponent: i32, config: &PowerOptConfig) -> PowerOptimizationInfo {
let strategy = determine_power_strategy(exponent, config);
let abs_exp = exponent.abs();
let (optimized, performance_gain) = match strategy {
PowerStrategy::MultiplicationChain => (abs_exp <= 6, 2.0 - (f64::from(abs_exp) * 0.1)),
PowerStrategy::RepeatedSquaring => (abs_exp <= config.max_optimized_exponent, 1.5),
PowerStrategy::Generic => (false, 1.0),
};
PowerOptimizationInfo {
strategy,
optimized,
exponent,
performance_gain: performance_gain.max(1.0),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_try_convert_to_integer() {
assert_eq!(try_convert_to_integer(2.0_f64, None), Some(2));
assert_eq!(try_convert_to_integer(2.1_f64, None), None);
assert_eq!(try_convert_to_integer(-3.0_f64, None), Some(-3));
assert_eq!(try_convert_to_integer(0.0_f64, None), Some(0));
}
#[test]
fn test_generate_integer_power_string() {
let config = PowerOptConfig::default();
assert_eq!(generate_integer_power_string("x", 0, &config), "1.0");
assert_eq!(generate_integer_power_string("x", 1, &config), "x");
assert_eq!(generate_integer_power_string("x", 2, &config), "x * x");
assert_eq!(generate_integer_power_string("x", -1, &config), "1.0 / x");
let unsafe_config = PowerOptConfig {
unsafe_optimizations: true,
..Default::default()
};
let result = generate_integer_power_string("x", 4, &unsafe_config);
assert!(result.contains("let temp"));
}
#[test]
fn test_determine_power_strategy() {
let config = PowerOptConfig::default();
assert_eq!(
determine_power_strategy(2, &config),
PowerStrategy::MultiplicationChain
);
assert_eq!(
determine_power_strategy(8, &config),
PowerStrategy::RepeatedSquaring
);
assert_eq!(
determine_power_strategy(15, &config),
PowerStrategy::Generic
);
}
#[test]
fn test_analyze_power_optimization() {
let config = PowerOptConfig::default();
let info = analyze_power_optimization(2, &config);
assert!(info.optimized);
assert_eq!(info.strategy, PowerStrategy::MultiplicationChain);
assert!(info.performance_gain > 1.0);
let info = analyze_power_optimization(100, &config);
assert!(!info.optimized);
assert_eq!(info.strategy, PowerStrategy::Generic);
}
#[test]
fn test_repeated_squaring() {
let result = generate_repeated_squaring_string("x", 8);
assert!(result.contains("temp = temp * temp"));
let result = generate_repeated_squaring_string("x", -8);
assert!(result.starts_with("1.0 / ("));
}
}