use std::fmt::Debug;
use super::SelectionLogic;
use crate::exploration::map::{ExplorationMap, GraphMap, MapNodeId, MapState};
use crate::exploration::mutation::ActionExtractor;
use crate::learn::{LearnedProvider, LearningQuery};
use crate::online_stats::SwarmStats;
/// Greedy selection strategy.
///
/// Holds the multiplier applied to the learned bonus when a candidate action
/// is scored.
#[derive(Debug, Clone)]
pub struct Greedy {
    /// Multiplier applied to the provider-supplied bonus.
    weight: f64,
}

impl Default for Greedy {
    /// Builds a strategy with weight `1.0`, the same value `Greedy::new` uses.
    ///
    /// A derived `Default` would zero the weight, which multiplies the learned
    /// bonus away and makes every candidate score identical.
    fn default() -> Self {
        Self { weight: 1.0 }
    }
}
impl Greedy {
    /// Score every candidate starts from before the weighted bonus is added.
    const BASE_SCORE: f64 = 0.5;

    /// Creates a strategy whose bonus weight is `1.0`.
    pub fn new() -> Self {
        Self::with_weight(1.0)
    }

    /// Creates a strategy with a caller-chosen bonus weight.
    pub fn with_weight(weight: f64) -> Self {
        Self { weight }
    }

    /// Scores `action` (optionally aimed at `target`) without any
    /// preceding-action context.
    ///
    /// Equivalent to [`Greedy::compute_score_with_context`] with both previous
    /// actions set to `None`.
    pub fn compute_score(
        &self,
        action: &str,
        target: Option<&str>,
        provider: &dyn LearnedProvider,
    ) -> f64 {
        self.compute_score_with_context(action, target, provider, None, None)
    }

    /// Scores `action` using the provider's confidence for the action, its
    /// optional `target`, and up to two preceding actions.
    ///
    /// The result is `0.5 + bonus * weight`, where `bonus` is the score the
    /// provider returns for the confidence query.
    pub fn compute_score_with_context(
        &self,
        action: &str,
        target: Option<&str>,
        provider: &dyn LearnedProvider,
        prev_action: Option<&str>,
        prev_prev_action: Option<&str>,
    ) -> f64 {
        let query =
            LearningQuery::confidence_with_context(action, target, prev_action, prev_prev_action);
        let bonus = provider.query(query).score();
        Self::BASE_SCORE + bonus * self.weight
    }
}
impl<N, E, S> SelectionLogic<N, E, S> for Greedy
where
    N: Debug + Clone + ActionExtractor,
    E: Debug + Clone,
    S: MapState,
{
    /// Returns the top-ranked frontier node, or `None` when `select` yields
    /// nothing.
    fn next(
        &self,
        map: &GraphMap<N, E, S>,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> Option<MapNodeId> {
        self.select(map, 1, stats, provider).into_iter().next()
    }

    /// Returns up to `count` frontier node ids, highest score first.
    ///
    /// Each frontier node is scored from its own action/target plus the action
    /// names of its parent and grandparent (when present). Frontier ids that
    /// `map.get` cannot resolve are skipped. Ties keep frontier order because
    /// the sort is stable. NaN scores rank after every real score.
    fn select(
        &self,
        map: &GraphMap<N, E, S>,
        count: usize,
        _stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> Vec<MapNodeId> {
        /// Descending comparison that places NaN after every real score.
        ///
        /// `partial_cmp(..).unwrap_or(Equal)` alone is not a total order when
        /// NaN is present, and `sort_by` may then produce an unspecified order
        /// or panic.
        fn by_score_desc(lhs: f64, rhs: f64) -> std::cmp::Ordering {
            use std::cmp::Ordering;
            match (lhs.is_nan(), rhs.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
                // Neither side is NaN here, so the fallback is never taken.
                (false, false) => rhs.partial_cmp(&lhs).unwrap_or(Ordering::Equal),
            }
        }

        let frontiers = map.frontiers();
        if frontiers.is_empty() || count == 0 {
            return vec![];
        }

        let mut scored: Vec<_> = frontiers
            .iter()
            .filter_map(|&id| {
                map.get(id).map(|node| {
                    let (action, target) = node.data.extract();

                    // Walk up to two ancestors to build the action context.
                    let parent_id = map.parent(id);
                    let prev_action = parent_id.and_then(|pid| {
                        map.get(pid)
                            .map(|parent_node| parent_node.data.action_name().to_string())
                    });
                    let prev_prev_action =
                        parent_id.and_then(|pid| map.parent(pid)).and_then(|ppid| {
                            map.get(ppid).map(|grandparent_node| {
                                grandparent_node.data.action_name().to_string()
                            })
                        });

                    let score = self.compute_score_with_context(
                        action,
                        target,
                        provider,
                        prev_action.as_deref(),
                        prev_prev_action.as_deref(),
                    );
                    (id, score)
                })
            })
            .collect();

        scored.sort_by(|a, b| by_score_desc(a.1, b.1));
        scored.into_iter().take(count).map(|(id, _)| id).collect()
    }

    /// Scores a single action/target pair without preceding-action context.
    /// The swarm statistics are not consulted.
    fn score(
        &self,
        action: &str,
        target: Option<&str>,
        _stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> f64 {
        self.compute_score(action, target, provider)
    }

    /// Human-readable name of this strategy.
    fn name(&self) -> &str {
        "Greedy"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::NullProvider;

    /// Score expected in every case below: the bare baseline with no bonus.
    const BASELINE: f64 = 0.5;

    /// An action the strategy has never seen scores the baseline.
    #[test]
    fn test_greedy_default_score() {
        let strategy = Greedy::new();
        let score = strategy.compute_score("unknown", None, &NullProvider);
        assert_eq!(score, BASELINE);
    }

    /// With `NullProvider`, the score is the baseline regardless of action or target.
    #[test]
    fn test_greedy_with_null_provider() {
        let strategy = Greedy::new();
        let cases: [(&str, Option<&str>); 3] =
            [("grep", None), ("glob", Some("svc1")), ("other", None)];
        for (action, target) in cases {
            assert_eq!(
                strategy.compute_score(action, target, &NullProvider),
                BASELINE
            );
        }
    }

    /// A custom weight does not change the score under `NullProvider`.
    #[test]
    fn test_greedy_with_weight() {
        let strategy = Greedy::with_weight(2.0);
        let score = strategy.compute_score("action", None, &NullProvider);
        assert_eq!(score, BASELINE);
    }
}