use std::collections::HashSet;
use std::path::PathBuf;
use std::time::{Duration, Instant};
use ryo_analysis::SymbolId;
use serde::{Deserialize, Serialize};
/// Opaque numeric identifier for a goal (a newtype over `u64`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct GoalId(pub u64);
impl GoalId {
    /// Wraps the raw integer `raw` as a [`GoalId`].
    pub fn new(raw: u64) -> Self {
        GoalId(raw)
    }
}
/// Opaque numeric identifier for a wave (a newtype over `u64`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct WaveId(pub u64);
impl WaveId {
    /// Wraps the raw integer `raw` as a [`WaveId`].
    pub fn new(raw: u64) -> Self {
        WaveId(raw)
    }
}
/// A set of symbol-level changes, split into added / modified / removed lists.
///
/// After [`AcChanges::merge`] each list is sorted and free of duplicates; values built
/// directly through the public fields carry no such guarantee. The three lists are
/// independent, so one symbol may appear in more than one of them.
#[derive(Debug, Clone, Default)]
pub struct AcChanges {
    /// Symbols reported as added.
    pub added: Vec<SymbolId>,
    /// Symbols reported as modified.
    pub modified: Vec<SymbolId>,
    /// Symbols reported as removed (not yielded by `affected_symbols`).
    pub removed: Vec<SymbolId>,
}
impl AcChanges {
pub fn new() -> Self {
Self::default()
}
pub fn is_empty(&self) -> bool {
self.added.is_empty() && self.modified.is_empty() && self.removed.is_empty()
}
pub fn affected_symbols(&self) -> impl Iterator<Item = &SymbolId> {
self.added.iter().chain(self.modified.iter())
}
pub fn len(&self) -> usize {
self.added.len() + self.modified.len() + self.removed.len()
}
pub fn merge(&mut self, other: AcChanges) {
self.added.extend(other.added);
self.modified.extend(other.modified);
self.removed.extend(other.removed);
self.added.sort();
self.added.dedup();
self.modified.sort();
self.modified.dedup();
self.removed.sort();
self.removed.dedup();
}
}
/// An event that may prompt a suggestion evaluation.
///
/// Whether it actually does is decided by [`SuggestStrategy::should_evaluate`]; the
/// payload-free counterpart is [`TriggerKind`] (see [`SuggestTrigger::kind`]).
#[derive(Debug, Clone)]
pub enum SuggestTrigger {
    /// Goal `goal_id` completed with `changes` (meaning inferred from the name — confirm
    /// with the emitter).
    GoalCompleted { goal_id: GoalId, changes: AcChanges },
    /// Wave `wave_id` completed with `changes`; `goal_count` is presumably the number of
    /// goals in the wave — TODO confirm.
    WaveCompleted {
        wave_id: WaveId,
        goal_count: usize,
        changes: AcChanges,
    },
    /// An explicit evaluation request; accepted by every `EvalGranularity`.
    Manual,
    /// A periodic tick; NOTE(review): what `elapsed` is measured from is not visible here.
    Periodic { elapsed: Duration },
    /// Presumably emitted when the files at `paths` change — confirm with the producer.
    FileChanged { paths: Vec<PathBuf> },
}
impl SuggestTrigger {
    /// Returns the payload-free [`TriggerKind`] matching this trigger's variant.
    pub fn kind(&self) -> TriggerKind {
        match self {
            Self::GoalCompleted { .. } => TriggerKind::GoalCompleted,
            Self::WaveCompleted { .. } => TriggerKind::WaveCompleted,
            Self::Manual => TriggerKind::Manual,
            Self::Periodic { .. } => TriggerKind::Periodic,
            Self::FileChanged { .. } => TriggerKind::FileChanged,
        }
    }

    /// Returns the changes carried by this trigger, if it carries any.
    ///
    /// Only `GoalCompleted` and `WaveCompleted` have a `changes` payload; every other
    /// variant yields `None`.
    pub fn changes(&self) -> Option<&AcChanges> {
        // Listed exhaustively (no `_` arm) so that adding a variant forces a decision
        // here instead of silently reporting "no changes".
        match self {
            Self::GoalCompleted { changes, .. } | Self::WaveCompleted { changes, .. } => {
                Some(changes)
            }
            Self::Manual | Self::Periodic { .. } | Self::FileChanged { .. } => None,
        }
    }
}
/// Payload-free discriminant of [`SuggestTrigger`], as returned by
/// [`SuggestTrigger::kind`]; stored in `SuggestStrategy::enabled_triggers`.
///
/// NOTE(review): avoid reordering variants — the derived `Hash` and index-based serde
/// formats depend on variant order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TriggerKind {
    /// Matches `SuggestTrigger::GoalCompleted`.
    GoalCompleted,
    /// Matches `SuggestTrigger::WaveCompleted`.
    WaveCompleted,
    /// Matches `SuggestTrigger::Manual`.
    Manual,
    /// Matches `SuggestTrigger::Periodic`.
    Periodic,
    /// Matches `SuggestTrigger::FileChanged`.
    FileChanged,
}
/// Which (already enabled and unthrottled) triggers lead to an evaluation in
/// [`SuggestStrategy::should_evaluate`]. A `Manual` trigger is accepted by every variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum EvalGranularity {
    /// Evaluate on `GoalCompleted` or `Manual` triggers. This is the default.
    #[default]
    PerGoal,
    /// Evaluate on `WaveCompleted` or `Manual` triggers.
    PerWave,
    /// Evaluate on any enabled trigger once `PendingChanges::goal_count >= n`, or on
    /// `Manual`. With `n == 0` the goal condition is always true.
    EveryNGoals(usize),
    /// Evaluate only on `Manual` triggers.
    ManualOnly,
}
/// Policy deciding when accumulated changes should trigger a suggestion evaluation.
///
/// See [`SuggestStrategy::should_evaluate`] for how the fields combine. Presets:
/// [`SuggestStrategy::interactive`] (the default), [`SuggestStrategy::high_perf`] and
/// [`SuggestStrategy::batch`].
#[derive(Debug, Clone)]
pub struct SuggestStrategy {
    /// Which enabled triggers actually lead to an evaluation.
    pub granularity: EvalGranularity,
    /// Minimum time since `PendingChanges::last_eval` before another evaluation may run;
    /// ignored once the backlog reaches `max_pending_changes`.
    pub min_interval: Duration,
    /// Backlog size (`AcChanges::len` of the pending changes) at which `min_interval` is
    /// bypassed.
    pub max_pending_changes: usize,
    /// Trigger kinds considered at all; any other trigger is rejected outright.
    pub enabled_triggers: HashSet<TriggerKind>,
}
impl Default for SuggestStrategy {
fn default() -> Self {
Self {
granularity: EvalGranularity::PerGoal,
min_interval: Duration::from_millis(100),
max_pending_changes: 50,
enabled_triggers: [
TriggerKind::GoalCompleted,
TriggerKind::WaveCompleted,
TriggerKind::Manual,
]
.into_iter()
.collect(),
}
}
}
impl SuggestStrategy {
    /// Preset for interactive use; identical to [`SuggestStrategy::default`].
    pub fn interactive() -> Self {
        Self::default()
    }

    /// Preset that evaluates on completed waves (and manual requests).
    ///
    /// Uses a 500 ms minimum interval and a backlog limit of 200 pending changes.
    pub fn high_perf() -> Self {
        Self {
            granularity: EvalGranularity::PerWave,
            min_interval: Duration::from_millis(500),
            max_pending_changes: 200,
            enabled_triggers: [TriggerKind::WaveCompleted, TriggerKind::Manual]
                .into_iter()
                .collect(),
        }
    }

    /// Preset that evaluates only on manual requests.
    ///
    /// Uses a 1 s minimum interval and a backlog limit of 1000 pending changes.
    pub fn batch() -> Self {
        Self {
            granularity: EvalGranularity::ManualOnly,
            min_interval: Duration::from_secs(1),
            max_pending_changes: 1000,
            enabled_triggers: [TriggerKind::Manual].into_iter().collect(),
        }
    }

    /// Decides whether an evaluation should run in response to `trigger`.
    ///
    /// The checks are applied in order:
    /// 1. The trigger's kind must be in `enabled_triggers`.
    /// 2. Throttle: if less than `min_interval` has elapsed since `pending.last_eval`,
    ///    the trigger is rejected unless `pending.changes.len()` has reached
    ///    `max_pending_changes`.
    /// 3. `granularity` decides which of the remaining triggers qualify; see
    ///    [`EvalGranularity`].
    ///
    /// NOTE(review): the throttle also applies to `Manual` triggers. An explicit request
    /// made within `min_interval` of the last evaluation (or of `PendingChanges`
    /// creation) is therefore dropped. Confirm this is intended.
    pub fn should_evaluate(&self, trigger: &SuggestTrigger, pending: &PendingChanges) -> bool {
        if !self.enabled_triggers.contains(&trigger.kind()) {
            return false;
        }
        // Too soon after the last evaluation, and the backlog is not yet large enough
        // to justify evaluating anyway.
        if pending.last_eval.elapsed() < self.min_interval
            && pending.changes.len() < self.max_pending_changes
        {
            return false;
        }
        match self.granularity {
            EvalGranularity::PerGoal => matches!(
                trigger,
                SuggestTrigger::GoalCompleted { .. } | SuggestTrigger::Manual
            ),
            EvalGranularity::PerWave => matches!(
                trigger,
                SuggestTrigger::WaveCompleted { .. } | SuggestTrigger::Manual
            ),
            // Any enabled trigger qualifies once enough goals have accumulated.
            // With `n == 0` this is always true.
            EvalGranularity::EveryNGoals(n) => {
                pending.goal_count >= n || matches!(trigger, SuggestTrigger::Manual)
            }
            EvalGranularity::ManualOnly => matches!(trigger, SuggestTrigger::Manual),
        }
    }

    /// Builder: replaces the evaluation granularity.
    pub fn with_granularity(mut self, granularity: EvalGranularity) -> Self {
        self.granularity = granularity;
        self
    }

    /// Builder: replaces the minimum interval between evaluations.
    pub fn with_min_interval(mut self, interval: Duration) -> Self {
        self.min_interval = interval;
        self
    }

    /// Builder: replaces the backlog size at which `min_interval` is bypassed.
    pub fn with_max_pending_changes(mut self, max_pending_changes: usize) -> Self {
        self.max_pending_changes = max_pending_changes;
        self
    }

    /// Builder: enables `kind` (no-op if already enabled).
    pub fn enable_trigger(mut self, kind: TriggerKind) -> Self {
        self.enabled_triggers.insert(kind);
        self
    }

    /// Builder: disables `kind` (no-op if not enabled).
    pub fn disable_trigger(mut self, kind: TriggerKind) -> Self {
        self.enabled_triggers.remove(&kind);
        self
    }
}
/// Accumulates goal completions and their changes between evaluations.
#[derive(Debug)]
pub struct PendingChanges {
    /// Number of `record_goal` calls since creation or the last `take`/`reset`.
    pub goal_count: usize,
    /// Changes merged from every recorded goal since creation or the last `take`/`reset`.
    pub changes: AcChanges,
    /// Time of creation or of the last `take`/`reset`; used by the strategy's throttle.
    pub last_eval: Instant,
}
impl Default for PendingChanges {
fn default() -> Self {
Self::new()
}
}
impl PendingChanges {
pub fn new() -> Self {
Self {
goal_count: 0,
changes: AcChanges::default(),
last_eval: Instant::now(),
}
}
pub fn record_goal(&mut self, changes: AcChanges) {
self.goal_count += 1;
self.changes.merge(changes);
}
pub fn take(&mut self) -> (usize, AcChanges) {
let goal_count = self.goal_count;
let changes = std::mem::take(&mut self.changes);
self.goal_count = 0;
self.last_eval = Instant::now();
(goal_count, changes)
}
pub fn reset(&mut self) {
self.goal_count = 0;
self.changes = AcChanges::default();
self.last_eval = Instant::now();
}
pub fn has_pending(&self) -> bool {
self.goal_count > 0 || !self.changes.is_empty()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ac_changes_merge() {
        let mut c1 = AcChanges {
            added: vec![
                SymbolId::parse("1v1").unwrap(),
                SymbolId::parse("2v1").unwrap(),
            ],
            modified: vec![SymbolId::parse("3v1").unwrap()],
            removed: vec![],
        };
        let c2 = AcChanges {
            added: vec![
                SymbolId::parse("2v1").unwrap(),
                SymbolId::parse("4v1").unwrap(),
            ],
            modified: vec![
                SymbolId::parse("3v1").unwrap(),
                SymbolId::parse("5v1").unwrap(),
            ],
            removed: vec![SymbolId::parse("6v1").unwrap()],
        };
        c1.merge(c2);
        // "2v1" (added) and "3v1" (modified) appear in both inputs and are deduplicated.
        assert_eq!(c1.added.len(), 3);
        assert_eq!(c1.modified.len(), 2);
        assert_eq!(c1.removed.len(), 1);
        assert_eq!(c1.len(), 6);
        // Removed symbols are not "affected".
        assert_eq!(c1.affected_symbols().count(), 5);
    }

    #[test]
    fn test_ac_changes_empty() {
        let changes = AcChanges::new();
        assert!(changes.is_empty());
        assert_eq!(changes.len(), 0);
        assert_eq!(changes.affected_symbols().count(), 0);
    }

    #[test]
    fn test_trigger_kind_and_changes() {
        assert_eq!(SuggestTrigger::Manual.kind(), TriggerKind::Manual);
        assert!(SuggestTrigger::Manual.changes().is_none());
        let wave = SuggestTrigger::WaveCompleted {
            wave_id: WaveId::new(1),
            goal_count: 0,
            changes: AcChanges::default(),
        };
        assert_eq!(wave.kind(), TriggerKind::WaveCompleted);
        assert!(wave.changes().is_some());
    }

    #[test]
    fn test_strategy_per_goal() {
        let strategy = SuggestStrategy::interactive().with_min_interval(Duration::ZERO);
        let pending = PendingChanges::new();
        let trigger = SuggestTrigger::GoalCompleted {
            goal_id: GoalId::new(1),
            changes: AcChanges::default(),
        };
        assert!(strategy.should_evaluate(&trigger, &pending));
    }

    #[test]
    fn test_strategy_per_wave() {
        let strategy = SuggestStrategy::high_perf().with_min_interval(Duration::ZERO);
        let pending = PendingChanges::new();
        let goal_trigger = SuggestTrigger::GoalCompleted {
            goal_id: GoalId::new(1),
            changes: AcChanges::default(),
        };
        assert!(!strategy.should_evaluate(&goal_trigger, &pending));
        let wave_trigger = SuggestTrigger::WaveCompleted {
            wave_id: WaveId::new(1),
            goal_count: 5,
            changes: AcChanges::default(),
        };
        assert!(strategy.should_evaluate(&wave_trigger, &pending));
    }

    #[test]
    fn test_strategy_manual_only() {
        let strategy = SuggestStrategy::batch().with_min_interval(Duration::ZERO);
        let pending = PendingChanges::new();
        let goal_trigger = SuggestTrigger::GoalCompleted {
            goal_id: GoalId::new(1),
            changes: AcChanges::default(),
        };
        assert!(!strategy.should_evaluate(&goal_trigger, &pending));
        assert!(strategy.should_evaluate(&SuggestTrigger::Manual, &pending));
    }

    #[test]
    fn test_strategy_throttled_by_min_interval() {
        // An hour-long interval guarantees the fresh `PendingChanges` is "too recent".
        let strategy =
            SuggestStrategy::interactive().with_min_interval(Duration::from_secs(3600));
        let pending = PendingChanges::new();
        let trigger = SuggestTrigger::GoalCompleted {
            goal_id: GoalId::new(1),
            changes: AcChanges::default(),
        };
        assert!(!strategy.should_evaluate(&trigger, &pending));
    }

    #[test]
    fn test_strategy_backlog_overrides_min_interval() {
        let mut strategy =
            SuggestStrategy::interactive().with_min_interval(Duration::from_secs(3600));
        strategy.max_pending_changes = 1;
        let mut pending = PendingChanges::new();
        pending.record_goal(AcChanges {
            added: vec![SymbolId::parse("1v1").unwrap()],
            ..Default::default()
        });
        assert!(strategy.should_evaluate(&SuggestTrigger::Manual, &pending));
    }

    #[test]
    fn test_strategy_every_n_goals() {
        let strategy = SuggestStrategy::interactive()
            .with_granularity(EvalGranularity::EveryNGoals(2))
            .with_min_interval(Duration::ZERO);
        let mut pending = PendingChanges::new();
        let trigger = SuggestTrigger::GoalCompleted {
            goal_id: GoalId::new(1),
            changes: AcChanges::default(),
        };
        pending.record_goal(AcChanges::default());
        assert!(!strategy.should_evaluate(&trigger, &pending));
        pending.record_goal(AcChanges::default());
        assert!(strategy.should_evaluate(&trigger, &pending));
    }

    #[test]
    fn test_strategy_disabled_trigger_is_ignored() {
        let strategy = SuggestStrategy::interactive()
            .with_min_interval(Duration::ZERO)
            .disable_trigger(TriggerKind::Manual);
        let pending = PendingChanges::new();
        assert!(!strategy.should_evaluate(&SuggestTrigger::Manual, &pending));
    }

    #[test]
    fn test_pending_changes() {
        let mut pending = PendingChanges::new();
        assert!(!pending.has_pending());
        pending.record_goal(AcChanges {
            added: vec![SymbolId::parse("1v1").unwrap()],
            ..Default::default()
        });
        assert!(pending.has_pending());
        assert_eq!(pending.goal_count, 1);
        let (count, changes) = pending.take();
        assert_eq!(count, 1);
        assert_eq!(changes.added.len(), 1);
        assert!(!pending.has_pending());
    }

    #[test]
    fn test_pending_reset() {
        let mut pending = PendingChanges::new();
        pending.record_goal(AcChanges::default());
        assert!(pending.has_pending());
        pending.reset();
        assert!(!pending.has_pending());
        assert_eq!(pending.goal_count, 0);
    }
}