use std::collections::{HashMap, HashSet};
/// Read-only view of a rankable entry: an identifier, a grouping category, a
/// numeric score proxy, optional prerequisite ids and free-form string metadata.
///
/// Implementors must be `Send + Sync` so trait objects can be shared across threads.
pub trait Scoreable: Send + Sync {
    /// Identifier of this entry.
    fn id(&self) -> &str;

    /// Grouping label for this entry.
    fn category(&self) -> &str;

    /// Numeric score proxy.
    ///
    /// The trait itself places no bounds on the value (`Item::new` clamps it to
    /// `[0.0, 1.0]`, but other implementors are not required to).
    fn score_proxy(&self) -> f32;

    /// Ids of entries this one depends on. Defaults to an empty slice.
    fn prerequisites(&self) -> &[String] {
        &[]
    }

    /// Arbitrary key/value annotations attached to this entry.
    fn metadata(&self) -> &HashMap<String, String>;
}
/// Owned, concrete [`Scoreable`] entry.
#[derive(Debug, Clone, PartialEq)]
pub struct Item {
    /// Unique identifier; `ItemRegistry` uses it as the map key.
    pub id: String,
    /// Score proxy. [`Item::new`] guarantees a value in `[0.0, 1.0]`; the field is
    /// public, so direct writes are not checked.
    pub score_proxy: f32,
    /// Grouping label.
    pub category: String,
    /// Ids that must all be resolved before this item is considered eligible.
    pub prerequisites: Vec<String>,
    /// Free-form key/value annotations.
    pub metadata: HashMap<String, String>,
}

impl Item {
    /// Creates an item with no prerequisites and empty metadata.
    ///
    /// `score_proxy` is clamped to `[0.0, 1.0]`; a NaN score becomes `0.0`.
    pub fn new(id: impl Into<String>, score_proxy: f32, category: impl Into<String>) -> Self {
        // `f32::clamp` returns NaN unchanged, which would silently violate the
        // [0, 1] invariant (and poison any ordering by score). Treat NaN as the
        // lowest possible score instead.
        let score_proxy = if score_proxy.is_nan() {
            0.0
        } else {
            score_proxy.clamp(0.0, 1.0)
        };
        Self {
            id: id.into(),
            score_proxy,
            category: category.into(),
            prerequisites: Vec::new(),
            metadata: HashMap::new(),
        }
    }

    /// Replaces the prerequisite list and returns the updated item.
    #[must_use]
    pub fn with_prereqs(mut self, prereqs: Vec<String>) -> Self {
        self.prerequisites = prereqs;
        self
    }

    /// Replaces the metadata map and returns the updated item.
    #[must_use]
    pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
        self.metadata = metadata;
        self
    }
}

impl Scoreable for Item {
    fn id(&self) -> &str {
        self.id.as_str()
    }

    fn score_proxy(&self) -> f32 {
        self.score_proxy
    }

    fn category(&self) -> &str {
        self.category.as_str()
    }

    fn prerequisites(&self) -> &[String] {
        self.prerequisites.as_slice()
    }

    fn metadata(&self) -> &HashMap<String, String> {
        &self.metadata
    }
}
/// Collection of [`Item`]s keyed by id whose prerequisite graph is kept acyclic.
pub struct ItemRegistry {
    items: HashMap<String, Item>,
}

impl ItemRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            items: HashMap::new(),
        }
    }

    /// Adds `items`, replacing any existing entry with the same id, then checks
    /// that the prerequisite graph contains no cycle.
    ///
    /// The call is all-or-nothing: if validation fails, every insertion from this
    /// batch is undone and the registry is left exactly as it was before the call.
    ///
    /// # Errors
    ///
    /// Returns `AriaError::CyclicPrerequisite` when the resulting graph contains a
    /// cycle.
    pub fn register(&mut self, items: Vec<Item>) -> Result<(), crate::error::AriaError> {
        // Remember what each insertion displaced so a failed validation can be undone;
        // otherwise the cyclic items would stay in the map and poison every later call.
        let mut undo_log: Vec<(String, Option<Item>)> = Vec::with_capacity(items.len());
        for item in items {
            let id = item.id.clone();
            let displaced = self.items.insert(id.clone(), item);
            undo_log.push((id, displaced));
        }
        if let Err(err) = self.validate_prereq_graph() {
            self.rollback(undo_log);
            return Err(err);
        }
        Ok(())
    }

    /// Reverts the insertions recorded in `undo_log`.
    fn rollback(&mut self, undo_log: Vec<(String, Option<Item>)>) {
        // Replay newest-first so an id that appears several times in one batch ends
        // up with the value it had before the batch started.
        for (id, displaced) in undo_log.into_iter().rev() {
            match displaced {
                Some(previous) => {
                    self.items.insert(id, previous);
                }
                None => {
                    self.items.remove(&id);
                }
            }
        }
    }

    /// Returns every item whose prerequisites are all contained in `resolved_set`.
    ///
    /// Items with no prerequisites are always returned, and items that are
    /// themselves already in `resolved_set` are not excluded. The order follows
    /// `HashMap` iteration and is unspecified.
    pub fn eligible<'a>(&'a self, resolved_set: &HashSet<String>) -> Vec<&'a Item> {
        self.items
            .values()
            .filter(|item| {
                item.prerequisites
                    .iter()
                    .all(|prereq| resolved_set.contains(prereq))
            })
            .collect()
    }

    /// Looks up an item by id.
    pub fn get(&self, id: &str) -> Option<&Item> {
        self.items.get(id)
    }

    /// Number of registered items.
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// `true` when no items are registered.
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Runs a depth-first search from every unvisited item and fails on the first
    /// back edge found (a cycle).
    fn validate_prereq_graph(&self) -> Result<(), crate::error::AriaError> {
        let mut visited: HashSet<&str> = HashSet::new();
        let mut in_stack: HashSet<&str> = HashSet::new();
        for id in self.items.keys() {
            if !visited.contains(id.as_str()) {
                self.dfs(id, &mut visited, &mut in_stack)?;
            }
        }
        Ok(())
    }

    /// Recursive DFS step.
    ///
    /// `in_stack` holds the ids on the current path. Reaching one of them again means
    /// a cycle, and the prerequisite id that closes it is reported. Prerequisite ids
    /// that are not registered have no outgoing edges, so they are treated as leaves.
    /// Recursion depth equals the length of the longest prerequisite chain.
    fn dfs<'a>(
        &'a self,
        id: &'a str,
        visited: &mut HashSet<&'a str>,
        in_stack: &mut HashSet<&'a str>,
    ) -> Result<(), crate::error::AriaError> {
        visited.insert(id);
        in_stack.insert(id);
        if let Some(item) = self.items.get(id) {
            for prereq in &item.prerequisites {
                if !visited.contains(prereq.as_str()) {
                    self.dfs(prereq, visited, in_stack)?;
                } else if in_stack.contains(prereq.as_str()) {
                    return Err(crate::error::AriaError::CyclicPrerequisite(prereq.clone()));
                }
            }
        }
        in_stack.remove(id);
        Ok(())
    }
}

impl Default for ItemRegistry {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Registry holding `a` (no prerequisites) and `b` (requires `a`).
    fn registry_with_chain() -> ItemRegistry {
        let mut registry = ItemRegistry::new();
        registry
            .register(vec![
                Item::new("a", 0.3, "cat"),
                Item::new("b", 0.5, "cat").with_prereqs(vec!["a".into()]),
            ])
            .unwrap();
        registry
    }

    #[test]
    fn cycle_detection() {
        // `a` and `b` each require the other, so registration must be rejected.
        let mutual = vec![
            Item::new("a", 0.3, "cat").with_prereqs(vec!["b".into()]),
            Item::new("b", 0.5, "cat").with_prereqs(vec!["a".into()]),
        ];
        let outcome = ItemRegistry::new().register(mutual);
        assert!(outcome.is_err());
    }

    #[test]
    fn eligible_filters_unsatisfied_prereqs() {
        let registry = registry_with_chain();
        let nothing_resolved = HashSet::<String>::new();
        let candidates = registry.eligible(&nothing_resolved);
        assert_eq!(candidates.len(), 1);
        assert_eq!(candidates[0].id(), "a");
    }

    #[test]
    fn eligible_after_prereq_resolved() {
        let registry = registry_with_chain();
        let resolved: HashSet<String> = std::iter::once("a".to_string()).collect();
        assert_eq!(registry.eligible(&resolved).len(), 2);
    }
}