// burn_fusion/stream/execution/validator.rs
use burn_ir::OperationIr;

use crate::stream::store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger};

5#[derive(Debug)]
11pub(crate) struct OperationsValidator<ID> {
12 pub(crate) id: ID,
14 pub(crate) state: ValidatorState,
16}
17
/// Outcome of comparing added operations against a reference operation list.
#[derive(Debug)]
pub(crate) enum ValidatorState {
    /// Every operation of the reference list has been matched; `size` is the length of
    /// that list.
    Found { size: usize },
    /// An added operation differed from the reference list, or was added at a position
    /// past the end of that list.
    Invalidated,
    /// Nothing has contradicted the reference list so far, but it has not been matched
    /// up to its last operation yet.
    Validating,
}
28
29pub(crate) trait OperationsStore {
31 type Id: Copy;
33
34 fn get(&self, id: Self::Id) -> &[OperationIr];
36}
37
38impl<ID> OperationsValidator<ID> {
39 pub(crate) fn new(id: ID) -> Self {
41 Self {
42 id,
43 state: ValidatorState::Validating,
44 }
45 }
46
47 pub(crate) fn update<S>(&mut self, added: &OperationIr, added_position: usize, store: &S)
49 where
50 S: OperationsStore<Id = ID>,
51 ID: PartialEq + Copy,
52 {
53 match &self.state {
54 ValidatorState::Found { size: _ } => return,
55 ValidatorState::Invalidated => return,
56 ValidatorState::Validating => {}
57 };
58
59 let item = store.get(self.id);
60 let operation_candidate = match item.get(added_position) {
61 Some(val) => val,
62 None => {
63 self.state = ValidatorState::Invalidated;
64 return;
65 }
66 };
67
68 if operation_candidate != added {
69 self.state = ValidatorState::Invalidated;
70 return;
71 }
72
73 if item.len() == added_position + 1 {
75 self.state = ValidatorState::Found { size: item.len() };
76 }
77 }
78}
79
80#[derive(new)]
82pub(crate) struct TriggerOperationsStore<'a, O> {
83 id: ExecutionPlanId,
84 store: &'a ExecutionPlanStore<O>,
85}
86
87#[derive(Debug)]
89pub(crate) enum TriggerValidator {
90 OnOperations {
91 matching: OperationsValidator<TriggerId>,
92 progress: TriggerProgress,
93 },
94 Always,
95 OnSync,
96}
97
/// Progress marker stored next to the validator of an operation-based trigger.
#[derive(Debug)]
pub(crate) enum TriggerProgress {
    /// No progress has been recorded yet.
    NotInit,
    /// A count of checked items -- presumably the number of operations already fed to the
    /// validator; confirm against the code that updates it.
    NumChecked(usize),
}
106
/// Identifier of a trigger: its index in the `triggers` collection of an execution plan.
pub(crate) type TriggerId = usize;
110
111impl<O> OperationsStore for TriggerOperationsStore<'_, O> {
112 type Id = TriggerId;
113
114 fn get(&self, id: Self::Id) -> &[OperationIr] {
115 match &self.store.get_unchecked(self.id).triggers[id] {
116 ExecutionTrigger::OnOperations(operations) => operations,
117 ExecutionTrigger::OnSync => &[],
118 ExecutionTrigger::Always => &[],
119 }
120 }
121}
122
123#[derive(new)]
126pub(crate) struct ExecutionPlanOperationsStore<'a, O> {
127 store: &'a ExecutionPlanStore<O>,
128}
129
130impl<O> OperationsStore for ExecutionPlanOperationsStore<'_, O> {
131 type Id = ExecutionPlanId;
132
133 fn get(&self, id: Self::Id) -> &[OperationIr] {
134 &self.store.get_unchecked(id).operations
135 }
136}