Skip to main content

awaken_runtime/policies/
plugin.rs

1use std::sync::Arc;
2
3use crate::plugins::{Plugin, PluginDescriptor, PluginRegistrar};
4use awaken_contract::StateError;
5use awaken_contract::model::Phase;
6
7use super::hook::StopConditionHook;
8use super::policy::{MaxRoundsPolicy, StopPolicy};
9use super::state::StopConditionStatsKey;
10
11/// Plugin that evaluates stop policies after each inference step.
12pub struct StopConditionPlugin {
13    policies: Vec<Arc<dyn StopPolicy>>,
14}
15
16impl StopConditionPlugin {
17    pub fn new(policies: Vec<Arc<dyn StopPolicy>>) -> Self {
18        Self { policies }
19    }
20}
21
22impl Plugin for StopConditionPlugin {
23    fn descriptor(&self) -> PluginDescriptor {
24        PluginDescriptor {
25            name: "stop-condition",
26        }
27    }
28
29    fn register(&self, registrar: &mut PluginRegistrar) -> Result<(), StateError> {
30        registrar.register_key::<StopConditionStatsKey>(crate::state::StateKeyOptions::default())?;
31        registrar.register_phase_hook(
32            "stop-condition",
33            Phase::AfterInference,
34            StopConditionHook {
35                policies: self.policies.clone(),
36            },
37        )
38    }
39}
40
/// Convenience plugin that terminates the run after a maximum number of rounds.
///
/// It installs a stop-condition hook whose only policy is a [`MaxRoundsPolicy`]
/// built from `max_rounds`. It does not hold or wrap a `StopConditionPlugin`.
/// Instead it registers its own hook under the id `"stop-condition:max-rounds"`.
#[derive(Debug, Clone)]
pub struct MaxRoundsPlugin {
    /// Round limit passed to [`MaxRoundsPolicy::new`] at registration time.
    max_rounds: usize,
}

impl MaxRoundsPlugin {
    /// Creates a plugin with the given round limit.
    ///
    /// The exact stopping rule (e.g. whether the limit is inclusive) is defined
    /// by [`MaxRoundsPolicy`].
    #[must_use]
    pub fn new(max_rounds: usize) -> Self {
        Self { max_rounds }
    }
}
53
54impl Plugin for MaxRoundsPlugin {
55    fn descriptor(&self) -> PluginDescriptor {
56        PluginDescriptor {
57            name: "stop-condition:max-rounds",
58        }
59    }
60
61    fn register(&self, registrar: &mut PluginRegistrar) -> Result<(), StateError> {
62        // Delegate to StopConditionPlugin internals
63        let policies: Vec<Arc<dyn StopPolicy>> =
64            vec![Arc::new(MaxRoundsPolicy::new(self.max_rounds))];
65        registrar.register_key::<StopConditionStatsKey>(crate::state::StateKeyOptions::default())?;
66        registrar.register_phase_hook(
67            "stop-condition:max-rounds",
68            Phase::AfterInference,
69            StopConditionHook { policies },
70        )
71    }
72}