Skip to main content

burn_fusion/stream/execution/
validator.rs

1use burn_ir::OperationIr;
2
3use crate::stream::store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger};
4
5/// Compare each operation in the list of operations provided by the [store](OperationsStore)
6/// to verify if the newly added operations match the original list.
7///
8/// It is used by the [policy](crate::stream::execution::Policy) to check each candidate as well
9/// as to verify if a list of operations is optimal to execute based on their triggers.
10#[derive(Debug)]
11pub(crate) struct OperationsValidator<ID> {
12    /// The ID used to retrieve the operation list.
13    pub(crate) id: ID,
14    /// The current [state](MatchingState).
15    pub(crate) state: ValidatorState,
16}
17
/// Outcome of comparing the operations added so far against a stored operation list.
#[derive(Debug)]
pub(crate) enum ValidatorState {
    /// Every operation of the stored list has been matched.
    Found {
        /// Length of the matched operation list.
        size: usize,
    },
    /// The added operations cannot match the stored operation list.
    Invalidated,
    /// Nothing has mismatched so far; more operations are needed to reach a verdict.
    Validating,
}
28
29/// Provides a list of operations based on an Id.
30pub(crate) trait OperationsStore {
31    /// The type used for the identifier.
32    type Id: Copy;
33
34    /// retrieve the list of operations corresponding on the provided id.
35    fn get(&self, id: Self::Id) -> &[OperationIr];
36}
37
38impl<ID> OperationsValidator<ID> {
39    /// Create a new validator.
40    pub(crate) fn new(id: ID) -> Self {
41        Self {
42            id,
43            state: ValidatorState::Validating,
44        }
45    }
46
47    /// Update the state of the validator based on the newly added operation.
48    pub(crate) fn update<S>(&mut self, added: &OperationIr, added_position: usize, store: &S)
49    where
50        S: OperationsStore<Id = ID>,
51        ID: PartialEq + Copy,
52    {
53        match &self.state {
54            ValidatorState::Found { size: _ } => return,
55            ValidatorState::Invalidated => return,
56            ValidatorState::Validating => {}
57        };
58
59        let item = store.get(self.id);
60        let operation_candidate = match item.get(added_position) {
61            Some(val) => val,
62            None => {
63                self.state = ValidatorState::Invalidated;
64                return;
65            }
66        };
67
68        if operation_candidate != added {
69            self.state = ValidatorState::Invalidated;
70            return;
71        }
72
73        // Finished
74        if item.len() == added_position + 1 {
75            self.state = ValidatorState::Found { size: item.len() };
76        }
77    }
78}
79
80/// [Operations store](OperationsStore) used to retrieve the list of operations for a trigger.
81#[derive(new)]
82pub(crate) struct TriggerOperationsStore<'a, O> {
83    id: ExecutionPlanId,
84    store: &'a ExecutionPlanStore<O>,
85}
86
87/// Validates when operations match a trigger.
88#[derive(Debug)]
89pub(crate) enum TriggerValidator {
90    OnOperations {
91        matching: OperationsValidator<TriggerId>,
92        progress: TriggerProgress,
93    },
94    Always,
95    OnSync,
96}
97
/// How far the validation of a trigger has advanced.
#[derive(Debug)]
pub(crate) enum TriggerProgress {
    /// The validation has not started yet.
    NotInit,
    /// The validation has started; holds the number of operations checked so far.
    NumChecked(usize),
}
106
/// Identifies a trigger by its position in the list of triggers of an execution plan, since a
/// single plan can have many triggers.
pub(crate) type TriggerId = usize;
110
111impl<O> OperationsStore for TriggerOperationsStore<'_, O> {
112    type Id = TriggerId;
113
114    fn get(&self, id: Self::Id) -> &[OperationIr] {
115        match &self.store.get_unchecked(self.id).triggers[id] {
116            ExecutionTrigger::OnOperations(operations) => operations,
117            ExecutionTrigger::OnSync => &[],
118            ExecutionTrigger::Always => &[],
119        }
120    }
121}
122
123/// [Operations store](OperationsStore) used to retrieve the list of operations for an
124/// [execution plan](crate::stream::store::ExecutionPlan).
125#[derive(new)]
126pub(crate) struct ExecutionPlanOperationsStore<'a, O> {
127    store: &'a ExecutionPlanStore<O>,
128}
129
130impl<O> OperationsStore for ExecutionPlanOperationsStore<'_, O> {
131    type Id = ExecutionPlanId;
132
133    fn get(&self, id: Self::Id) -> &[OperationIr] {
134        &self.store.get_unchecked(id).operations
135    }
136}