use crate::context::StatContext;
use crate::error::StatError;
use crate::stat_id::StatId;
use std::collections::HashMap;
/// A transformation applied to a single stat value.
///
/// Implementors receive the current value plus any dependency values they asked
/// for via [`StatTransform::depends_on`], and produce a new value. The
/// `Send + Sync` supertraits allow boxed transforms to be shared across threads.
pub trait StatTransform: Send + Sync {
    /// Stat IDs whose values this transform reads from the `dependencies` map
    /// passed to [`StatTransform::apply`]. An empty list means the transform
    /// needs no other stats.
    fn depends_on(&self) -> Vec<StatId>;

    /// Computes the transformed value from `input`.
    ///
    /// `dependencies` should contain an entry for each ID returned by
    /// [`StatTransform::depends_on`]; `context` carries situational state that
    /// conditional transforms may inspect.
    ///
    /// # Errors
    ///
    /// Implementation-defined. For example, [`ScalingTransform`] returns
    /// `StatError::MissingDependency` when its dependency is absent.
    fn apply(
        &self,
        input: f64,
        dependencies: &HashMap<StatId, f64>,
        context: &StatContext,
    ) -> Result<f64, StatError>;

    /// Short human-readable summary of the transform (e.g. `×1.50`).
    fn description(&self) -> String;
}
/// Scales a stat value by a constant factor.
#[derive(Debug, Clone)]
pub struct MultiplicativeTransform {
    /// Factor the incoming value is multiplied by.
    multiplier: f64,
}

impl MultiplicativeTransform {
    /// Creates a transform that multiplies its input by `multiplier`.
    pub fn new(multiplier: f64) -> Self {
        MultiplicativeTransform { multiplier }
    }
}
impl StatTransform for MultiplicativeTransform {
    /// A constant factor reads no other stats.
    fn depends_on(&self) -> Vec<StatId> {
        vec![]
    }

    /// Returns `input` multiplied by the stored factor; never fails.
    fn apply(
        &self,
        input: f64,
        _dependencies: &HashMap<StatId, f64>,
        _context: &StatContext,
    ) -> Result<f64, StatError> {
        let scaled = input * self.multiplier;
        Ok(scaled)
    }

    /// Renders as `×` followed by the factor to two decimals, e.g. `×1.50`.
    fn description(&self) -> String {
        let factor = self.multiplier;
        format!("×{:.2}", factor)
    }
}
/// Adds a flat amount to a stat value (a negative amount acts as a penalty).
#[derive(Debug, Clone)]
pub struct AdditiveTransform {
    /// Amount added to the incoming value.
    bonus: f64,
}

impl AdditiveTransform {
    /// Creates a transform that adds `bonus` to its input.
    pub fn new(bonus: f64) -> Self {
        AdditiveTransform { bonus }
    }
}
impl StatTransform for AdditiveTransform {
    /// A flat bonus reads no other stats.
    fn depends_on(&self) -> Vec<StatId> {
        Vec::new()
    }

    /// Returns `input + bonus`; never fails.
    fn apply(
        &self,
        input: f64,
        _dependencies: &HashMap<StatId, f64>,
        _context: &StatContext,
    ) -> Result<f64, StatError> {
        Ok(input + self.bonus)
    }

    /// Renders the bonus with an explicit sign and two decimals,
    /// e.g. `+25.00` or `-5.00`.
    fn description(&self) -> String {
        // The `+` flag prints the sign itself; a hard-coded leading '+' would
        // render negative bonuses as "+-5.00".
        format!("{:+.2}", self.bonus)
    }
}
/// Bounds a stat value to the inclusive range `[min, max]`.
#[derive(Debug, Clone)]
pub struct ClampTransform {
    /// Lower bound.
    min: f64,
    /// Upper bound.
    max: f64,
}

impl ClampTransform {
    /// Creates a clamp with the given bounds.
    ///
    /// The bounds are stored exactly as given; no check that `min <= max`
    /// is performed here.
    pub fn new(min: f64, max: f64) -> Self {
        ClampTransform { max, min }
    }
}
impl StatTransform for ClampTransform {
    /// A clamp reads no other stats.
    fn depends_on(&self) -> Vec<StatId> {
        Vec::new()
    }

    /// Restricts `input` to `[min, max]`; never fails and never panics.
    ///
    /// [`f64::clamp`] asserts that `min <= max` and that neither bound is NaN.
    /// `ClampTransform::new` does not validate its bounds, so calling
    /// `f64::clamp` here would turn a misconfigured transform into a panic
    /// inside a `Result`-returning method. This version uses the same
    /// comparisons without the assertion:
    /// - for valid bounds the result is identical to `f64::clamp`;
    /// - a NaN `input` is returned unchanged (as `f64::clamp` does);
    /// - a NaN bound never triggers, because its comparison is always false;
    /// - if `min > max`, the upper bound is applied last and wins.
    fn apply(
        &self,
        input: f64,
        _dependencies: &HashMap<StatId, f64>,
        _context: &StatContext,
    ) -> Result<f64, StatError> {
        let mut value = input;
        if value < self.min {
            value = self.min;
        }
        if value > self.max {
            value = self.max;
        }
        Ok(value)
    }

    /// Renders as `clamp(min, max)` with two decimals per bound.
    fn description(&self) -> String {
        format!("clamp({:.2}, {:.2})", self.min, self.max)
    }
}
/// Applies a wrapped transform only while a predicate on the [`StatContext`]
/// returns `true`; otherwise the input passes through unchanged.
pub struct ConditionalTransform {
    /// Predicate evaluated against the context on each `apply` call.
    condition: Box<dyn Fn(&StatContext) -> bool + Send + Sync>,
    /// Transform applied when `condition` returns `true`.
    transform: Box<dyn StatTransform>,
    /// Text returned verbatim by `description()`.
    description: String,
}

impl ConditionalTransform {
    /// Wraps `transform` so that it only runs when `condition(context)` is `true`.
    ///
    /// `description` is stored as given; it is not derived from the wrapped
    /// transform's own description.
    pub fn new<F>(
        condition: F,
        transform: Box<dyn StatTransform>,
        description: impl Into<String>,
    ) -> Self
    where
        F: Fn(&StatContext) -> bool + Send + Sync + 'static,
    {
        Self {
            condition: Box::new(condition),
            transform,
            description: description.into(),
        }
    }
}

// The boxed predicate cannot implement `Debug`, so `#[derive(Debug)]` (used by
// every other transform in this file) is unavailable. Show the descriptive
// fields and elide the closure.
impl std::fmt::Debug for ConditionalTransform {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ConditionalTransform")
            .field("description", &self.description)
            .field("transform", &self.transform.description())
            .finish_non_exhaustive()
    }
}
impl StatTransform for ConditionalTransform {
    /// Reports the wrapped transform's dependencies whether or not the
    /// condition currently holds.
    fn depends_on(&self) -> Vec<StatId> {
        self.transform.depends_on()
    }

    /// Evaluates the condition against `context`. When it is `false`, returns
    /// `input` unchanged; otherwise delegates to the wrapped transform and
    /// propagates its result (including errors).
    fn apply(
        &self,
        input: f64,
        dependencies: &HashMap<StatId, f64>,
        context: &StatContext,
    ) -> Result<f64, StatError> {
        let active = (self.condition)(context);
        if !active {
            return Ok(input);
        }
        self.transform.apply(input, dependencies, context)
    }

    /// Returns the description supplied at construction.
    fn description(&self) -> String {
        self.description.to_owned()
    }
}
/// Adds a multiple of another stat's value to the input.
#[derive(Debug, Clone)]
pub struct ScalingTransform {
    /// Stat whose value is read from the dependency map.
    dependency: StatId,
    /// Multiplier applied to the dependency's value before it is added.
    scale_factor: f64,
}

impl ScalingTransform {
    /// Creates a transform that adds `scale_factor * value_of(dependency)`
    /// to its input.
    pub fn new(dependency: StatId, scale_factor: f64) -> Self {
        ScalingTransform {
            scale_factor,
            dependency,
        }
    }
}
impl StatTransform for ScalingTransform {
    /// The single stat this transform scales from.
    fn depends_on(&self) -> Vec<StatId> {
        let id = self.dependency.clone();
        vec![id]
    }

    /// Returns `input + dependency_value * scale_factor`.
    ///
    /// # Errors
    ///
    /// `StatError::MissingDependency` (carrying the dependency's ID) when
    /// `dependencies` has no entry for it.
    fn apply(
        &self,
        input: f64,
        dependencies: &HashMap<StatId, f64>,
        _context: &StatContext,
    ) -> Result<f64, StatError> {
        match dependencies.get(&self.dependency) {
            Some(&dep_value) => Ok(input + dep_value * self.scale_factor),
            None => Err(StatError::MissingDependency(self.dependency.clone())),
        }
    }

    /// Renders as `scale(<dependency>, <factor>)` with the factor to two decimals.
    fn description(&self) -> String {
        format!(
            "scale({dep}, {factor:.2})",
            dep = self.dependency,
            factor = self.scale_factor
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh empty context and an empty dependency map.
    fn fixtures() -> (StatContext, HashMap<StatId, f64>) {
        (StatContext::new(), HashMap::new())
    }

    #[test]
    fn test_multiplicative_transform() {
        let (ctx, deps) = fixtures();
        let doubled_and_half = MultiplicativeTransform::new(1.5);
        let result = doubled_and_half.apply(100.0, &deps, &ctx).unwrap();
        assert_eq!(result, 150.0);
    }

    #[test]
    fn test_additive_transform() {
        let (ctx, deps) = fixtures();
        let plus_25 = AdditiveTransform::new(25.0);
        let result = plus_25.apply(100.0, &deps, &ctx).unwrap();
        assert_eq!(result, 125.0);
    }

    #[test]
    fn test_clamp_transform() {
        let (ctx, deps) = fixtures();
        let clamp = ClampTransform::new(0.0, 100.0);
        // (input, expected): above, below, and inside the range.
        for (input, expected) in [(150.0, 100.0), (-10.0, 0.0), (50.0, 50.0)] {
            assert_eq!(clamp.apply(input, &deps, &ctx).unwrap(), expected);
        }
    }

    #[test]
    fn test_scaling_transform() {
        let (ctx, mut deps) = fixtures();
        let strength = StatId::from_str("STR");
        let scaling = ScalingTransform::new(strength.clone(), 2.0);
        deps.insert(strength.clone(), 10.0);

        assert_eq!(scaling.depends_on(), vec![strength]);
        // 100 + 10 * 2
        assert_eq!(scaling.apply(100.0, &deps, &ctx).unwrap(), 120.0);
    }

    #[test]
    fn test_scaling_transform_missing_dependency() {
        let (ctx, deps) = fixtures();
        let scaling = ScalingTransform::new(StatId::from_str("STR"), 2.0);
        let outcome = scaling.apply(100.0, &deps, &ctx);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_conditional_transform() {
        let (mut ctx, deps) = fixtures();
        ctx.set("in_combat", true);

        let combat_bonus = ConditionalTransform::new(
            |ctx| ctx.get::<bool>("in_combat").unwrap_or(false),
            Box::new(MultiplicativeTransform::new(1.2)),
            "combat bonus",
        );

        // Condition holds: the wrapped ×1.2 applies.
        assert_eq!(combat_bonus.apply(100.0, &deps, &ctx).unwrap(), 120.0);

        // Condition fails: the input passes through untouched.
        ctx.set("in_combat", false);
        assert_eq!(combat_bonus.apply(100.0, &deps, &ctx).unwrap(), 100.0);
    }
}