// awaken_runtime/policies/
plugin.rs1use std::sync::Arc;
2
3use crate::plugins::{Plugin, PluginDescriptor, PluginRegistrar};
4use awaken_contract::StateError;
5use awaken_contract::model::Phase;
6
7use super::hook::StopConditionHook;
8use super::policy::{MaxRoundsPolicy, StopPolicy};
9use super::state::StopConditionStatsKey;
10
11pub struct StopConditionPlugin {
13 policies: Vec<Arc<dyn StopPolicy>>,
14}
15
16impl StopConditionPlugin {
17 pub fn new(policies: Vec<Arc<dyn StopPolicy>>) -> Self {
18 Self { policies }
19 }
20}
21
22impl Plugin for StopConditionPlugin {
23 fn descriptor(&self) -> PluginDescriptor {
24 PluginDescriptor {
25 name: "stop-condition",
26 }
27 }
28
29 fn register(&self, registrar: &mut PluginRegistrar) -> Result<(), StateError> {
30 registrar.register_key::<StopConditionStatsKey>(crate::state::StateKeyOptions::default())?;
31 registrar.register_phase_hook(
32 "stop-condition",
33 Phase::AfterInference,
34 StopConditionHook {
35 policies: self.policies.clone(),
36 },
37 )
38 }
39}

/// Convenience plugin that installs a single [`MaxRoundsPolicy`] configured
/// with `max_rounds`.
#[derive(Debug, Clone)]
pub struct MaxRoundsPlugin {
    /// Limit passed to [`MaxRoundsPolicy::new`] at registration time.
    max_rounds: usize,
}

impl MaxRoundsPlugin {
    /// Creates a plugin whose policy will be built with `max_rounds`.
    #[must_use]
    pub fn new(max_rounds: usize) -> Self {
        Self { max_rounds }
    }
}
53
54impl Plugin for MaxRoundsPlugin {
55 fn descriptor(&self) -> PluginDescriptor {
56 PluginDescriptor {
57 name: "stop-condition:max-rounds",
58 }
59 }
60
61 fn register(&self, registrar: &mut PluginRegistrar) -> Result<(), StateError> {
62 let policies: Vec<Arc<dyn StopPolicy>> =
64 vec![Arc::new(MaxRoundsPolicy::new(self.max_rounds))];
65 registrar.register_key::<StopConditionStatsKey>(crate::state::StateKeyOptions::default())?;
66 registrar.register_phase_hook(
67 "stop-condition:max-rounds",
68 Phase::AfterInference,
69 StopConditionHook { policies },
70 )
71 }
72}