use super::policies::Policies;
use super::state::PolicyState;
use super::types::{AggregationStrategy, Policy};
use crate::context::ResourceContext;
use std::collections::HashMap;
/// Stateless entry point for policy-effect calculations.
///
/// The type holds no data; it only groups the associated functions that
/// aggregate, look up, and apply policy effects.
#[derive(Clone, Debug, Default)]
pub struct PolicyService;
impl PolicyService {
    /// Creates a new `PolicyService`.
    ///
    /// The service has no fields; all other operations are associated functions.
    pub fn new() -> Self {
        Self
    }

    /// Folds the effects of `policies` into a single map keyed by effect name.
    ///
    /// Policies are visited in slice order. For every effect key the strategy is taken from
    /// `strategies`, falling back to `default_strategy` when the key has no entry. The first
    /// value seen for a key is combined with that strategy's `initial_value()`; each later
    /// value is combined with the running result via `aggregate`.
    ///
    /// Returns an empty map when `policies` is empty or none of them carries an effect.
    pub fn aggregate_effects(
        policies: &[&Policy],
        strategies: &HashMap<String, AggregationStrategy>,
        default_strategy: AggregationStrategy,
    ) -> HashMap<String, f32> {
        let mut aggregated: HashMap<String, f32> = HashMap::new();
        for policy in policies {
            for (key, value) in &policy.effects {
                let strategy = strategies.get(key).copied().unwrap_or(default_strategy);
                // Single hash lookup per effect: seed a missing key with the strategy's
                // starting value, then fold the new value into the slot in place.
                let slot = aggregated
                    .entry(key.clone())
                    .or_insert_with(|| strategy.initial_value());
                *slot = strategy.aggregate(*slot, *value);
            }
        }
        aggregated
    }

    /// Returns the value stored for `effect_name` in `effects`.
    ///
    /// When the effect is absent, returns the `initial_value()` of the strategy registered for
    /// `effect_name` in `strategies`, or of `default_strategy` if none is registered.
    pub fn get_effect(
        effects: &HashMap<String, f32>,
        effect_name: &str,
        strategies: &HashMap<String, AggregationStrategy>,
        default_strategy: AggregationStrategy,
    ) -> f32 {
        effects.get(effect_name).copied().unwrap_or_else(|| {
            strategies
                .get(effect_name)
                .copied()
                .unwrap_or(default_strategy)
                .initial_value()
        })
    }

    /// Combines `base_value` with `effect_value` using `strategy`.
    pub fn apply_effect(base_value: f32, effect_value: f32, strategy: AggregationStrategy) -> f32 {
        strategy.aggregate(base_value, effect_value)
    }

    /// Reduces `values` to one number, starting from `strategy.initial_value()` and combining
    /// left to right with `strategy.aggregate`.
    ///
    /// An empty slice yields the strategy's initial value.
    pub fn combine_values(values: &[f32], strategy: AggregationStrategy) -> f32 {
        values
            .iter()
            .copied()
            .fold(strategy.initial_value(), |acc, value| {
                strategy.aggregate(acc, value)
            })
    }

    /// Looks up `effect_name` on the currently active policy.
    ///
    /// Returns `1.0` when `resources` yields no `PolicyState` or no `Policies`, when no policy
    /// is active, when the active id is not found in `Policies`, or when the active policy
    /// does not define `effect_name`.
    ///
    /// NOTE(review): the fallback is a fixed `1.0` and does not consult an
    /// `AggregationStrategy`; confirm that is intended for effects that are not multipliers.
    pub async fn get_active_effect(effect_name: &str, resources: &ResourceContext) -> f32 {
        const FALLBACK: f32 = 1.0;

        let state = match resources.get::<PolicyState>().await {
            Some(state) => state,
            None => return FALLBACK,
        };
        let policies = match resources.get::<Policies>().await {
            Some(policies) => policies,
            None => return FALLBACK,
        };

        let active_policy = state
            .active_policy_id()
            .and_then(|active_id| policies.get(active_id));
        active_policy
            .and_then(|policy| policy.effects.get(effect_name).copied())
            .unwrap_or(FALLBACK)
    }

    /// Returns every effect of the currently active policy.
    ///
    /// The active policy's effects are passed through [`Self::aggregate_effects`] with no
    /// per-effect strategies and `AggregationStrategy::Multiply` as the default.
    ///
    /// Returns an empty map when `resources` yields no `PolicyState` or no `Policies`, when no
    /// policy is active, or when the active id is not found in `Policies`.
    pub async fn get_active_effects(resources: &ResourceContext) -> HashMap<String, f32> {
        let state = match resources.get::<PolicyState>().await {
            Some(state) => state,
            None => return HashMap::new(),
        };
        let policies_resource = match resources.get::<Policies>().await {
            Some(policies) => policies,
            None => return HashMap::new(),
        };

        let active_policy = state
            .active_policy_id()
            .and_then(|active_id| policies_resource.get(active_id));
        match active_policy {
            Some(policy) => Self::aggregate_effects(
                &[policy],
                &HashMap::new(),
                AggregationStrategy::Multiply,
            ),
            None => HashMap::new(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for results that come out of floating-point multiplication, where exact
    /// equality with a decimal literal only holds by rounding luck.
    const TOLERANCE: f32 = 0.001;

    /// Builds a policy via `Policy::new(id, id, "Test policy")` and attaches each effect.
    fn create_test_policy(id: &str, effects: Vec<(&str, f32)>) -> Policy {
        let mut policy = Policy::new(id, id, "Test policy");
        for (key, value) in effects {
            policy = policy.add_effect(key, value);
        }
        policy
    }

    /// Builds a per-effect strategy table from `(name, strategy)` pairs.
    fn strategy_map(
        entries: &[(&str, AggregationStrategy)],
    ) -> HashMap<String, AggregationStrategy> {
        entries
            .iter()
            .map(|(name, strategy)| ((*name).to_owned(), *strategy))
            .collect()
    }

    /// Aggregates two single-effect policies ("p1" then "p2") that share the effect `name`,
    /// with `strategy` registered for that effect and `Multiply` as the default.
    fn aggregate_pair(
        name: &str,
        first: f32,
        second: f32,
        strategy: AggregationStrategy,
    ) -> HashMap<String, f32> {
        let policy1 = create_test_policy("p1", vec![(name, first)]);
        let policy2 = create_test_policy("p2", vec![(name, second)]);
        let strategies = strategy_map(&[(name, strategy)]);
        PolicyService::aggregate_effects(
            &[&policy1, &policy2],
            &strategies,
            AggregationStrategy::Multiply,
        )
    }

    /// Fetches `name` from an aggregation result, failing the test if it is missing.
    fn effect(effects: &HashMap<String, f32>, name: &str) -> f32 {
        match effects.get(name) {
            Some(value) => *value,
            None => panic!("effect `{}` missing from aggregated result", name),
        }
    }

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {} (+/- {}), got {}",
            expected,
            TOLERANCE,
            actual
        );
    }

    #[test]
    fn test_aggregate_effects_multiply() {
        let effects = aggregate_pair("income_multiplier", 1.2, 1.1, AggregationStrategy::Multiply);
        // 1.2 * 1.1
        assert_close(effect(&effects, "income_multiplier"), 1.32);
    }

    #[test]
    fn test_aggregate_effects_add() {
        let effects = aggregate_pair("attack_bonus", 10.0, 5.0, AggregationStrategy::Add);
        // 10 + 5
        assert_eq!(effects.get("attack_bonus"), Some(&15.0));
    }

    #[test]
    fn test_aggregate_effects_min() {
        let effects = aggregate_pair("build_cost", 0.9, 0.8, AggregationStrategy::Min);
        // min(0.9, 0.8)
        assert_eq!(effects.get("build_cost"), Some(&0.8));
    }

    #[test]
    fn test_aggregate_effects_max() {
        let effects = aggregate_pair("max_speed", 1.2, 1.1, AggregationStrategy::Max);
        // max(1.2, 1.1)
        assert_eq!(effects.get("max_speed"), Some(&1.2));
    }

    #[test]
    fn test_aggregate_effects_mixed_strategies() {
        let policy1 = create_test_policy(
            "p1",
            vec![
                ("income_multiplier", 1.2),
                ("attack_bonus", 10.0),
                ("build_cost", 0.9),
            ],
        );
        let policy2 = create_test_policy(
            "p2",
            vec![
                ("income_multiplier", 1.1),
                ("attack_bonus", 5.0),
                ("build_cost", 0.8),
            ],
        );
        let strategies = strategy_map(&[
            ("income_multiplier", AggregationStrategy::Multiply),
            ("attack_bonus", AggregationStrategy::Add),
            ("build_cost", AggregationStrategy::Min),
        ]);

        let effects = PolicyService::aggregate_effects(
            &[&policy1, &policy2],
            &strategies,
            AggregationStrategy::Multiply,
        );

        // Each key is folded with its own strategy, independently of the others.
        assert_close(effect(&effects, "income_multiplier"), 1.32);
        assert_eq!(effects.get("attack_bonus"), Some(&15.0));
        assert_eq!(effects.get("build_cost"), Some(&0.8));
    }

    #[test]
    fn test_aggregate_effects_empty_policies() {
        let policies: Vec<&Policy> = Vec::new();
        let strategies = HashMap::new();
        let effects =
            PolicyService::aggregate_effects(&policies, &strategies, AggregationStrategy::Multiply);
        assert!(effects.is_empty());
    }

    #[test]
    fn test_get_effect_existing() {
        let mut effects: HashMap<String, f32> = HashMap::new();
        effects.insert("income_multiplier".into(), 1.5);
        let strategies = HashMap::new();
        let value = PolicyService::get_effect(
            &effects,
            "income_multiplier",
            &strategies,
            AggregationStrategy::Multiply,
        );
        assert_eq!(value, 1.5);
    }

    #[test]
    fn test_get_effect_fallback_multiply() {
        let effects: HashMap<String, f32> = HashMap::new();
        let strategies = HashMap::new();
        let value = PolicyService::get_effect(
            &effects,
            "income_multiplier",
            &strategies,
            AggregationStrategy::Multiply,
        );
        // Missing effect falls back to the default strategy's initial value.
        assert_eq!(value, 1.0);
    }

    #[test]
    fn test_get_effect_fallback_add() {
        let effects: HashMap<String, f32> = HashMap::new();
        let strategies = strategy_map(&[("attack_bonus", AggregationStrategy::Add)]);
        let value = PolicyService::get_effect(
            &effects,
            "attack_bonus",
            &strategies,
            AggregationStrategy::Multiply,
        );
        // The per-effect strategy (Add) wins over the default (Multiply).
        assert_eq!(value, 0.0);
    }

    #[test]
    fn test_apply_effect_multiply() {
        let result = PolicyService::apply_effect(100.0, 1.2, AggregationStrategy::Multiply);
        assert_close(result, 120.0);
    }

    #[test]
    fn test_apply_effect_add() {
        let result = PolicyService::apply_effect(50.0, 10.0, AggregationStrategy::Add);
        assert_eq!(result, 60.0);
    }

    #[test]
    fn test_apply_effect_min() {
        let result = PolicyService::apply_effect(0.9, 0.8, AggregationStrategy::Min);
        assert_eq!(result, 0.8);
    }

    #[test]
    fn test_apply_effect_max() {
        let result = PolicyService::apply_effect(1.2, 1.5, AggregationStrategy::Max);
        assert_eq!(result, 1.5);
    }

    #[test]
    fn test_combine_values_multiply() {
        let values = [1.2, 1.1, 1.15];
        let result = PolicyService::combine_values(&values, AggregationStrategy::Multiply);
        // 1.2 * 1.1 * 1.15
        assert_close(result, 1.518);
    }

    #[test]
    fn test_combine_values_add() {
        let values = [10.0, 5.0, 3.0];
        let result = PolicyService::combine_values(&values, AggregationStrategy::Add);
        assert_eq!(result, 18.0);
    }

    #[test]
    fn test_combine_values_min() {
        let values = [0.9, 0.8, 0.85];
        let result = PolicyService::combine_values(&values, AggregationStrategy::Min);
        assert_eq!(result, 0.8);
    }

    #[test]
    fn test_combine_values_max() {
        let values = [1.2, 1.5, 1.3];
        let result = PolicyService::combine_values(&values, AggregationStrategy::Max);
        assert_eq!(result, 1.5);
    }

    #[test]
    fn test_combine_values_empty() {
        let values: Vec<f32> = Vec::new();
        let result = PolicyService::combine_values(&values, AggregationStrategy::Multiply);
        assert_eq!(result, 1.0);
        let result = PolicyService::combine_values(&values, AggregationStrategy::Add);
        assert_eq!(result, 0.0);
    }
}