Skip to main content

hyperstack_interpreter/
scheduler.rs

1use crate::ast::{ComparisonOp, ResolverCondition, UrlTemplatePart};
2use crate::vm::ScheduledCallback;
3use percent_encoding::{utf8_percent_encode, AsciiSet, NON_ALPHANUMERIC};
4use serde_json::Value;
5use std::collections::{BTreeMap, HashMap, HashSet};
6
/// Retry ceiling of 100. Nothing in the visible part of this file reads it; presumably it is
/// enforced by whichever caller re-registers callbacks — TODO confirm against callers.
pub const MAX_RETRIES: u32 = 100;

/// Identity of a pending callback, as produced by `SlotScheduler::dedup_key`:
/// `(entity_name, primary key rendered as a string, "<resolver JSON>:<condition JSON>")`.
type DedupKey = (String, String, String);
10
/// Queue of callbacks keyed by the slot at which they become due.
///
/// At most one callback per dedup key is kept pending; registering an equivalent callback
/// again replaces the earlier entry.
pub struct SlotScheduler {
    /// Pending callbacks grouped by target slot. Ordered by slot so that `take_due` can split
    /// off everything up to the current slot in one operation.
    callbacks: BTreeMap<u64, Vec<ScheduledCallback>>,
    /// Dedup keys of all pending callbacks. Kept in step with `slot_index` by `register` and
    /// `take_due`. NOTE(review): only written, never read, in the code visible here — confirm
    /// whether anything else in this module still needs it.
    registered: HashSet<DedupKey>,
    /// Reverse index: dedup key → slot the callback currently sits in, so `register()` can go
    /// straight to the stale entry's slot instead of scanning every slot.
    slot_index: HashMap<DedupKey, u64>,
}
17
18impl Default for SlotScheduler {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl SlotScheduler {
25    pub fn new() -> Self {
26        Self {
27            callbacks: BTreeMap::new(),
28            registered: HashSet::new(),
29            slot_index: HashMap::new(),
30        }
31    }
32
33    pub fn register(&mut self, target_slot: u64, callback: ScheduledCallback) {
34        let dedup_key = Self::dedup_key(&callback);
35        if let Some(old_slot) = self.slot_index.remove(&dedup_key) {
36            if let Some(cbs) = self.callbacks.get_mut(&old_slot) {
37                cbs.retain(|cb| Self::dedup_key(cb) != dedup_key);
38                if cbs.is_empty() {
39                    self.callbacks.remove(&old_slot);
40                }
41            }
42        }
43        self.registered.insert(dedup_key.clone());
44        self.slot_index.insert(dedup_key, target_slot);
45        self.callbacks
46            .entry(target_slot)
47            .or_default()
48            .push(callback);
49    }
50
51    pub fn take_due(&mut self, current_slot: u64) -> Vec<ScheduledCallback> {
52        let future = self.callbacks.split_off(&current_slot.saturating_add(1));
53        let due = std::mem::replace(&mut self.callbacks, future);
54
55        let mut result = Vec::new();
56        for (_slot, callbacks) in due {
57            for cb in callbacks {
58                let dedup_key = Self::dedup_key(&cb);
59                self.registered.remove(&dedup_key);
60                self.slot_index.remove(&dedup_key);
61                result.push(cb);
62            }
63        }
64        result
65    }
66
67    pub fn re_register(&mut self, callback: ScheduledCallback, next_slot: u64) {
68        self.register(next_slot, callback);
69    }
70
71    pub fn pending_count(&self) -> usize {
72        self.callbacks.values().map(|v| v.len()).sum()
73    }
74
75    fn dedup_key(cb: &ScheduledCallback) -> DedupKey {
76        let resolver_key = serde_json::to_string(&cb.resolver).unwrap_or_default();
77        let condition_key = cb
78            .condition
79            .as_ref()
80            .map(|c| serde_json::to_string(c).unwrap_or_default())
81            .unwrap_or_default();
82        let pk_key = cb.primary_key.to_string();
83        (
84            cb.entity_name.clone(),
85            pk_key,
86            format!("{}:{}", resolver_key, condition_key),
87        )
88    }
89}
90
91pub fn evaluate_condition(condition: &ResolverCondition, state: &Value) -> bool {
92    let field_val = get_value_at_path(state, &condition.field_path).unwrap_or(Value::Null);
93    evaluate_comparison(&field_val, &condition.op, &condition.value)
94}
95
/// Percent-encoding set applied to field values substituted into a URL template.
///
/// This is `NON_ALPHANUMERIC` minus the RFC 3986 unreserved chars (`-`, `.`, `_`, `~`), which
/// are safe in URLs and therefore pass through unescaped.
///
/// NOTE: This uses path-segment encoding for all field references, including those in query
/// parameter positions. Query strings permit additional chars (`!`, `$`, `'`, `(`, `)`, `+`,
/// etc.) that will be over-encoded. This is safe for the current numeric/base58 use-cases but
/// may need a path-vs-query split if general-purpose URL templates are needed.
const URL_SEGMENT_SET: &AsciiSet = &NON_ALPHANUMERIC
    .remove(b'-')
    .remove(b'.')
    .remove(b'_')
    .remove(b'~');
106
107pub fn build_url_from_template(template: &[UrlTemplatePart], state: &Value) -> Option<String> {
108    let mut url = String::new();
109    for part in template {
110        match part {
111            UrlTemplatePart::Literal(s) => url.push_str(s),
112            UrlTemplatePart::FieldRef(path) => {
113                let val = get_value_at_path(state, path)?;
114                if val.is_null() {
115                    return None;
116                }
117                let raw = match val.as_str() {
118                    Some(s) => s.to_string(),
119                    None => val.to_string().trim_matches('"').to_string(),
120                };
121                let encoded = utf8_percent_encode(&raw, URL_SEGMENT_SET).to_string();
122                url.push_str(&encoded);
123            }
124        }
125    }
126    Some(url)
127}
128
129pub fn get_value_at_path(value: &Value, path: &str) -> Option<Value> {
130    let mut current = value;
131    for segment in path.split('.') {
132        current = current.get(segment)?;
133    }
134    Some(current.clone())
135}
136
137/// Mirrors the VM's i64 → u64 → f64 comparison cascade to avoid divergent
138/// condition evaluation between registration time and execution time.
139fn evaluate_comparison(field_value: &Value, op: &ComparisonOp, condition_value: &Value) -> bool {
140    match op {
141        ComparisonOp::Equal => field_value == condition_value,
142        ComparisonOp::NotEqual => field_value != condition_value,
143        ComparisonOp::GreaterThan => compare_numeric(
144            field_value,
145            condition_value,
146            |a, b| a > b,
147            |a, b| a > b,
148            |a, b| a > b,
149        ),
150        ComparisonOp::GreaterThanOrEqual => compare_numeric(
151            field_value,
152            condition_value,
153            |a, b| a >= b,
154            |a, b| a >= b,
155            |a, b| a >= b,
156        ),
157        ComparisonOp::LessThan => compare_numeric(
158            field_value,
159            condition_value,
160            |a, b| a < b,
161            |a, b| a < b,
162            |a, b| a < b,
163        ),
164        ComparisonOp::LessThanOrEqual => compare_numeric(
165            field_value,
166            condition_value,
167            |a, b| a <= b,
168            |a, b| a <= b,
169            |a, b| a <= b,
170        ),
171    }
172}
173
174fn compare_numeric(
175    a: &Value,
176    b: &Value,
177    cmp_i64: fn(i64, i64) -> bool,
178    cmp_u64: fn(u64, u64) -> bool,
179    cmp_f64: fn(f64, f64) -> bool,
180) -> bool {
181    match (a.as_i64(), b.as_i64()) {
182        (Some(a), Some(b)) => cmp_i64(a, b),
183        _ => match (a.as_u64(), b.as_u64()) {
184            (Some(a), Some(b)) => cmp_u64(a, b),
185            _ => match (a.as_f64(), b.as_f64()) {
186                (Some(a), Some(b)) => cmp_f64(a, b),
187                _ => false,
188            },
189        },
190    }
191}