use crate::context::StatContext;
use crate::error::StatError;
use crate::graph::StatGraph;
use crate::resolved::ResolvedStat;
use crate::source::StatSource;
use crate::stat_id::StatId;
use crate::transform::StatTransform;
use std::collections::HashMap;
/// Resolves stat values from registered sources and transforms, caching the
/// results per stat.
///
/// A stat's base value is the sum of its sources; its transforms are then
/// applied in registration order, and may read the resolved values of other
/// stats they declare as dependencies.
pub struct StatResolver {
/// Sources registered for each stat; their values are summed into the base value.
sources: HashMap<StatId, Vec<Box<dyn StatSource>>>,
/// Transforms registered for each stat, applied in the order they were registered.
transforms: HashMap<StatId, Vec<Box<dyn StatTransform>>>,
/// Previously resolved stats, keyed by stat id. The key does not include the
/// `StatContext` that was used to compute the entry.
cache: HashMap<StatId, ResolvedStat>,
}
impl StatResolver {
    /// Creates an empty resolver with no sources, transforms, or cached results.
    pub fn new() -> Self {
        Self {
            sources: HashMap::new(),
            transforms: HashMap::new(),
            cache: HashMap::new(),
        }
    }

    /// Registers an additional value source for `stat_id`.
    ///
    /// The cached result for `stat_id` and for every stat that (transitively)
    /// depends on it is dropped, so the next resolution reflects the new source.
    pub fn register_source(&mut self, stat_id: StatId, source: Box<dyn StatSource>) {
        // Invalidate first so `stat_id` can be moved into `entry` without a clone.
        self.invalidate(&stat_id);
        self.sources.entry(stat_id).or_default().push(source);
    }

    /// Registers an additional transform for `stat_id`; transforms are applied
    /// in registration order.
    ///
    /// The cached result for `stat_id` and for every stat that (transitively)
    /// depends on it is dropped, so the next resolution reflects the new transform.
    pub fn register_transform(&mut self, stat_id: StatId, transform: Box<dyn StatTransform>) {
        // Invalidate first so `stat_id` can be moved into `entry` without a clone.
        self.invalidate(&stat_id);
        self.transforms.entry(stat_id).or_default().push(transform);
    }

    /// Returns the resolved value of `stat_id`, computing it if it is not cached.
    ///
    /// On a cache miss every registered stat that is not yet cached is resolved
    /// in dependency order, not only the requested one.
    ///
    /// The cache is keyed by stat id only: a cached result is returned even if
    /// `context` differs from the one it was computed with. Call
    /// [`invalidate_all`](Self::invalidate_all) when the context changes.
    ///
    /// # Errors
    ///
    /// * [`StatError::MissingSource`] if `stat_id` is still absent from the
    ///   cache after resolution (e.g. nothing was ever registered for it).
    /// * Any error reported while ordering the dependency graph (such as a
    ///   dependency cycle), [`StatError::MissingDependency`] if a transform's
    ///   dependency has no resolved value, or an error returned by a transform.
    pub fn resolve(
        &mut self,
        stat_id: &StatId,
        context: &StatContext,
    ) -> Result<ResolvedStat, StatError> {
        if let Some(cached) = self.cache.get(stat_id) {
            return Ok(cached.clone());
        }
        self.resolve_uncached(context)?;
        self.cache
            .get(stat_id)
            .cloned()
            .ok_or_else(|| StatError::MissingSource(stat_id.clone()))
    }

    /// Resolves every registered stat and returns a copy of the whole cache.
    ///
    /// Stats that are already cached are not recomputed, regardless of `context`.
    ///
    /// # Errors
    ///
    /// Fails with the same resolution errors as [`resolve`](Self::resolve)
    /// (graph ordering, missing dependency, transform failure).
    pub fn resolve_all(
        &mut self,
        context: &StatContext,
    ) -> Result<HashMap<StatId, ResolvedStat>, StatError> {
        self.resolve_uncached(context)?;
        Ok(self.cache.clone())
    }

    /// Drops the cached result for `stat_id` and for every stat whose
    /// transforms depend on it, directly or transitively.
    ///
    /// Dependents must be evicted too: their cached values were computed from
    /// the value being invalidated and would otherwise be served stale.
    pub fn invalidate(&mut self, stat_id: &StatId) {
        let mut pending = vec![stat_id.clone()];
        // Guards against revisiting stats (and against looping on cyclic dependencies).
        let mut visited = std::collections::HashSet::new();
        while let Some(current) = pending.pop() {
            if !visited.insert(current.clone()) {
                continue;
            }
            self.cache.remove(&current);
            // Walk dependents even when `current` itself was not cached: a
            // stat further up the chain may still hold a derived value.
            for (dependent, transforms) in &self.transforms {
                if visited.contains(dependent) {
                    continue;
                }
                let depends_on_current = transforms
                    .iter()
                    .any(|transform| transform.depends_on().into_iter().any(|dep| dep == current));
                if depends_on_current {
                    pending.push(dependent.clone());
                }
            }
        }
    }

    /// Drops every cached result.
    pub fn invalidate_all(&mut self) {
        self.cache.clear();
    }

    /// Returns the cached resolution of `stat_id`, or `None` if it has not
    /// been resolved since it was last invalidated. Never triggers resolution.
    pub fn get_breakdown(&self, stat_id: &StatId) -> Option<&ResolvedStat> {
        self.cache.get(stat_id)
    }

    /// Resolves, in dependency order, every registered stat that is not cached yet.
    fn resolve_uncached(&mut self, context: &StatContext) -> Result<(), StatError> {
        let graph = self.build_graph()?;
        let resolution_order = graph.topological_sort()?;
        for stat_id in &resolution_order {
            if self.cache.contains_key(stat_id) {
                continue;
            }
            let resolved = self.resolve_stat_internal(stat_id, context)?;
            self.cache.insert(stat_id.clone(), resolved);
        }
        Ok(())
    }

    /// Builds the dependency graph: one node per stat that has a source or a
    /// transform, and one edge from a stat to each dependency of its transforms.
    fn build_graph(&self) -> Result<StatGraph, StatError> {
        let mut graph = StatGraph::new();
        for stat_id in self.sources.keys().chain(self.transforms.keys()) {
            graph.add_node(stat_id.clone());
        }
        for (stat_id, transforms) in &self.transforms {
            for transform in transforms {
                for dep in transform.depends_on() {
                    graph.add_edge(stat_id.clone(), dep);
                }
            }
        }
        Ok(graph)
    }

    /// Computes a single stat: sums its sources, then applies its transforms in
    /// order, recording each step in the returned breakdown.
    ///
    /// Dependencies are read from the cache, so they must already be resolved;
    /// otherwise [`StatError::MissingDependency`] is returned.
    fn resolve_stat_internal(
        &self,
        stat_id: &StatId,
        context: &StatContext,
    ) -> Result<ResolvedStat, StatError> {
        let mut resolved = ResolvedStat::new(stat_id.clone(), 0.0);
        let mut current_value = 0.0;

        match self.sources.get(stat_id) {
            Some(sources) => {
                for (idx, source) in sources.iter().enumerate() {
                    let value = source.get_value(stat_id, context);
                    current_value += value;
                    resolved.add_source(format!("Source #{}", idx + 1), value);
                }
            }
            None => {
                // A stat with transforms but no sources starts from zero.
                resolved.add_source("Default", 0.0);
            }
        }

        if let Some(transforms) = self.transforms.get(stat_id) {
            for transform in transforms {
                let mut dependencies = HashMap::new();
                for dep_id in transform.depends_on() {
                    let dep_value = self
                        .cache
                        .get(&dep_id)
                        .map(|r| r.value)
                        .ok_or_else(|| StatError::MissingDependency(dep_id.clone()))?;
                    dependencies.insert(dep_id, dep_value);
                }
                // Each transform receives the output of the previous one.
                current_value = transform.apply(current_value, &dependencies, context)?;
                resolved.add_transform(transform.description(), current_value);
            }
        }

        resolved.value = current_value;
        Ok(resolved)
    }
}
impl Default for StatResolver {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::source::ConstantSource;
    use crate::transform::{MultiplicativeTransform, ScalingTransform};

    /// A single constant source resolves to exactly that constant.
    #[test]
    fn test_resolve_simple_source() {
        let mut resolver = StatResolver::new();
        let hp_id = StatId::from_str("HP");
        resolver.register_source(hp_id.clone(), Box::new(ConstantSource(100.0)));
        let context = StatContext::new();
        let resolved = resolver.resolve(&hp_id, &context).unwrap();
        assert_eq!(resolved.value, 100.0);
        assert_eq!(resolved.stat_id, hp_id);
    }

    /// Multiple sources for one stat are summed and each is recorded.
    #[test]
    fn test_resolve_multiple_sources() {
        let mut resolver = StatResolver::new();
        let hp_id = StatId::from_str("HP");
        resolver.register_source(hp_id.clone(), Box::new(ConstantSource(100.0)));
        resolver.register_source(hp_id.clone(), Box::new(ConstantSource(50.0)));
        let context = StatContext::new();
        let resolved = resolver.resolve(&hp_id, &context).unwrap();
        assert_eq!(resolved.value, 150.0);
        assert_eq!(resolved.sources.len(), 2);
    }

    /// A transform is applied on top of the summed base value.
    #[test]
    fn test_resolve_with_transform() {
        let mut resolver = StatResolver::new();
        let atk_id = StatId::from_str("ATK");
        resolver.register_source(atk_id.clone(), Box::new(ConstantSource(100.0)));
        resolver.register_transform(atk_id.clone(), Box::new(MultiplicativeTransform::new(1.5)));
        let context = StatContext::new();
        let resolved = resolver.resolve(&atk_id, &context).unwrap();
        assert_eq!(resolved.value, 150.0);
        assert_eq!(resolved.transforms.len(), 1);
    }

    /// A transform can read another stat's resolved value.
    #[test]
    fn test_resolve_with_dependency() {
        let mut resolver = StatResolver::new();
        let str_id = StatId::from_str("STR");
        let atk_id = StatId::from_str("ATK");
        resolver.register_source(str_id.clone(), Box::new(ConstantSource(10.0)));
        resolver.register_source(atk_id.clone(), Box::new(ConstantSource(50.0)));
        resolver.register_transform(
            atk_id.clone(),
            Box::new(ScalingTransform::new(str_id.clone(), 2.0)),
        );
        let context = StatContext::new();
        let resolved = resolver.resolve(&atk_id, &context).unwrap();
        assert_eq!(resolved.value, 70.0);
    }

    /// Resolving a stat that was never registered reports `MissingSource`
    /// for that stat.
    #[test]
    fn test_resolve_missing_source() {
        let mut resolver = StatResolver::new();
        let hp_id = StatId::from_str("HP");
        let context = StatContext::new();
        let result = resolver.resolve(&hp_id, &context);
        assert!(
            matches!(result, Err(StatError::MissingSource(ref id)) if *id == hp_id),
            "Expected MissingSource error for the requested stat"
        );
    }

    /// Cached results are reused until invalidated, and re-resolution picks up
    /// sources registered afterwards.
    #[test]
    fn test_cache_invalidation() {
        let mut resolver = StatResolver::new();
        let hp_id = StatId::from_str("HP");
        resolver.register_source(hp_id.clone(), Box::new(ConstantSource(100.0)));
        let context = StatContext::new();
        let resolved1 = resolver.resolve(&hp_id, &context).unwrap();
        assert_eq!(resolved1.value, 100.0);
        let resolved2 = resolver.resolve(&hp_id, &context).unwrap();
        assert_eq!(resolved2.value, 100.0);
        resolver.invalidate(&hp_id);
        resolver.register_source(hp_id.clone(), Box::new(ConstantSource(50.0)));
        let resolved3 = resolver.resolve(&hp_id, &context).unwrap();
        assert_eq!(resolved3.value, 150.0);
    }

    /// Two stats that scale with each other form a cycle, which must be
    /// reported as an error instead of being resolved.
    #[test]
    fn test_cycle_detection() {
        let mut resolver = StatResolver::new();
        let a_id = StatId::from_str("A");
        let b_id = StatId::from_str("B");
        resolver.register_source(a_id.clone(), Box::new(ConstantSource(1.0)));
        resolver.register_source(b_id.clone(), Box::new(ConstantSource(1.0)));
        resolver.register_transform(
            a_id.clone(),
            Box::new(ScalingTransform::new(b_id.clone(), 1.0)),
        );
        resolver.register_transform(
            b_id.clone(),
            Box::new(ScalingTransform::new(a_id.clone(), 1.0)),
        );
        let context = StatContext::new();
        let result = resolver.resolve(&a_id, &context);
        assert!(
            matches!(result, Err(StatError::CycleDetected(_))),
            "Expected CycleDetected error"
        );
    }
}