Skip to main content

noether_engine/executor/
budget.rs

//! Runtime cost-budget enforcement for composition execution.
//!
//! [`BudgetedExecutor`] wraps any [`StageExecutor`] and charges each stage
//! the cost it declares through its [`Effect::Cost`] effects.
//! The accounting uses an `AtomicU64`, so parallel branches share one
//! running total without needing a mutex.
//!
//! ## Semantics
//!
//! Cost is **reserved before** a stage runs.  If adding a stage's declared
//! cost would push `spent_cents` past `budget_cents`, the call returns
//! [`ExecutionError::BudgetExceeded`] immediately — the stage is never
//! invoked and nothing is charged.  This is conservative: a stage that runs
//! and then fails does not refund its cost.
//!
//! Parallel branches share the same atomic counter, so the total charged
//! for stages that were allowed to run never exceeds the budget.  Once the
//! remaining budget cannot cover a branch's next stage, that call aborts
//! with `BudgetExceeded`; which branch hits the limit first depends on
//! scheduling.
//!
//! ## Usage
//!
//! ```no_run
//! use noether_engine::executor::budget::{BudgetedExecutor, build_cost_map};
//! use noether_engine::executor::mock::MockExecutor;
//! use noether_engine::lagrange::CompositionNode;
//! use noether_store::MemoryStore;
//!
//! let store = MemoryStore::new();
//! let cost_map = build_cost_map(&CompositionNode::Const { value: serde_json::Value::Null }, &store);
//! let inner = MockExecutor::new();
//! let budgeted = BudgetedExecutor::new(inner, cost_map, 100 /* cents */);
//! ```

35use super::{ExecutionError, StageExecutor};
36use noether_core::effects::Effect;
37use noether_core::stage::StageId;
38use noether_store::StageStore;
39use serde_json::Value;
40use std::collections::HashMap;
41use std::sync::atomic::{AtomicU64, Ordering};
42use std::sync::Arc;
43
44// ── Cost map ─────────────────────────────────────────────────────────────────
45
46/// Walk a composition graph and build a map of `StageId → declared_cents`.
47///
48/// Only stages that declare at least one `Effect::Cost { cents }` appear in
49/// the map.  Stages not in the store (e.g. RemoteStage) are ignored.
50pub fn build_cost_map(
51    node: &crate::lagrange::CompositionNode,
52    store: &(impl StageStore + ?Sized),
53) -> HashMap<StageId, u64> {
54    let mut map = HashMap::new();
55    collect_costs(node, store, &mut map);
56    map
57}
58
59fn collect_costs(
60    node: &crate::lagrange::CompositionNode,
61    store: &(impl StageStore + ?Sized),
62    map: &mut HashMap<StageId, u64>,
63) {
64    use crate::lagrange::CompositionNode::*;
65    match node {
66        Stage { id, .. } => {
67            if let Ok(Some(stage)) = store.get(id) {
68                let total: u64 = stage
69                    .signature
70                    .effects
71                    .iter()
72                    .filter_map(|e| {
73                        if let Effect::Cost { cents } = e {
74                            Some(*cents)
75                        } else {
76                            None
77                        }
78                    })
79                    .sum();
80                if total > 0 {
81                    map.insert(id.clone(), total);
82                }
83            }
84        }
85        RemoteStage { .. } | Const { .. } => {}
86        Sequential { stages } => {
87            for s in stages {
88                collect_costs(s, store, map);
89            }
90        }
91        Parallel { branches } => {
92            for b in branches.values() {
93                collect_costs(b, store, map);
94            }
95        }
96        Branch {
97            predicate,
98            if_true,
99            if_false,
100        } => {
101            collect_costs(predicate, store, map);
102            collect_costs(if_true, store, map);
103            collect_costs(if_false, store, map);
104        }
105        Fanout { source, targets } => {
106            collect_costs(source, store, map);
107            for t in targets {
108                collect_costs(t, store, map);
109            }
110        }
111        Merge { sources, target } => {
112            for s in sources {
113                collect_costs(s, store, map);
114            }
115            collect_costs(target, store, map);
116        }
117        Retry { stage, .. } => collect_costs(stage, store, map),
118        Let { bindings, body } => {
119            for b in bindings.values() {
120                collect_costs(b, store, map);
121            }
122            collect_costs(body, store, map);
123        }
124    }
125}
126
127// ── BudgetedExecutor ──────────────────────────────────────────────────────────
128
129/// An executor wrapper that enforces a runtime cost budget.
130///
131/// Maintains an `Arc<AtomicU64>` counter of cents spent so that concurrent
132/// parallel branches all see the same running total.
133pub struct BudgetedExecutor<E: StageExecutor> {
134    inner: E,
135    /// Declared cost in cents per stage id.
136    cost_map: HashMap<StageId, u64>,
137    /// Running total shared with all clones / concurrent uses.
138    spent_cents: Arc<AtomicU64>,
139    /// Hard limit in cents.
140    budget_cents: u64,
141}
142
143impl<E: StageExecutor> BudgetedExecutor<E> {
144    /// Create a new budgeted executor wrapping `inner`.
145    ///
146    /// `cost_map` maps stage ids to their declared cost in cents
147    /// (build it with [`build_cost_map`]).
148    /// `budget_cents` is the hard limit; execution aborts when it would
149    /// be exceeded.
150    pub fn new(inner: E, cost_map: HashMap<StageId, u64>, budget_cents: u64) -> Self {
151        Self {
152            inner,
153            cost_map,
154            spent_cents: Arc::new(AtomicU64::new(0)),
155            budget_cents,
156        }
157    }
158
159    /// Return a snapshot of cents spent so far.
160    pub fn spent_cents(&self) -> u64 {
161        self.spent_cents.load(Ordering::Relaxed)
162    }
163}
164
165impl<E: StageExecutor + Sync> StageExecutor for BudgetedExecutor<E> {
166    fn execute(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
167        let cost = self.cost_map.get(stage_id).copied().unwrap_or(0);
168
169        if cost > 0 {
170            // Atomically reserve the cost before executing.
171            // fetch_add returns the *previous* value, so newly_spent = prev + cost.
172            let prev = self.spent_cents.fetch_add(cost, Ordering::SeqCst);
173            let newly_spent = prev + cost;
174
175            if newly_spent > self.budget_cents {
176                // Roll back: we're not going to run this stage.
177                self.spent_cents.fetch_sub(cost, Ordering::SeqCst);
178                return Err(ExecutionError::BudgetExceeded {
179                    spent_cents: prev,
180                    budget_cents: self.budget_cents,
181                });
182            }
183        }
184
185        self.inner.execute(stage_id, input)
186    }
187}
188
189// ── Tests ─────────────────────────────────────────────────────────────────────
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::mock::MockExecutor;
    use crate::lagrange::CompositionNode;
    use noether_core::effects::{Effect, EffectSet};
    use noether_core::stage::{CostEstimate, Stage, StageId, StageLifecycle, StageSignature};
    use noether_core::types::NType;
    use noether_store::MemoryStore;
    use serde_json::json;
    use std::collections::BTreeSet;

    /// Build an active test stage with the given effects.
    ///
    /// Every other field is empty or `None`; `build_cost_map` only reads
    /// `signature.effects`.
    fn make_stage(id: &str, effects: EffectSet, description: String) -> Stage {
        Stage {
            id: StageId(id.into()),
            canonical_id: None,
            signature: StageSignature {
                input: NType::Any,
                output: NType::Any,
                effects,
                implementation_hash: format!("impl_{id}"),
            },
            capabilities: BTreeSet::new(),
            cost: CostEstimate {
                time_ms_p50: None,
                tokens_est: None,
                memory_mb: None,
            },
            description,
            examples: vec![],
            lifecycle: StageLifecycle::Active,
            ed25519_signature: None,
            signer_public_key: None,
            implementation_code: None,
            implementation_language: None,
            ui_style: None,
            tags: vec![],
            aliases: vec![],
            name: None,
        }
    }

    /// A stage declaring `cents` of cost plus an LLM effect.
    fn make_costly_stage(id: &str, cents: u64) -> Stage {
        make_stage(
            id,
            EffectSet::new([
                Effect::Cost { cents },
                Effect::Llm {
                    model: "gpt".into(),
                },
            ]),
            format!("costly stage {id}"),
        )
    }

    #[test]
    fn no_cost_stages_pass_through() {
        let executor = MockExecutor::new().with_output(&StageId("a".into()), json!(1));
        let budgeted = BudgetedExecutor::new(executor, HashMap::new(), 0);
        let result = budgeted.execute(&StageId("a".into()), &json!(null));
        assert_eq!(result.unwrap(), json!(1));
        assert_eq!(budgeted.spent_cents(), 0);
    }

    #[test]
    fn within_budget_executes_and_tracks_cost() {
        let id = StageId("llm".into());
        let executor = MockExecutor::new().with_output(&id, json!("ok"));
        let cost_map = HashMap::from([(id.clone(), 10u64)]);
        let budgeted = BudgetedExecutor::new(executor, cost_map, 100);
        assert!(budgeted.execute(&id, &json!(null)).is_ok());
        assert_eq!(budgeted.spent_cents(), 10);
    }

    #[test]
    fn exact_budget_is_allowed() {
        // The limit is inclusive: spending exactly the budget must succeed.
        let id = StageId("exact".into());
        let executor = MockExecutor::new().with_output(&id, json!("ok"));
        let cost_map = HashMap::from([(id.clone(), 100u64)]);
        let budgeted = BudgetedExecutor::new(executor, cost_map, 100);
        assert!(budgeted.execute(&id, &json!(null)).is_ok());
        assert_eq!(budgeted.spent_cents(), 100);
    }

    #[test]
    fn over_budget_returns_error_and_rolls_back() {
        let id = StageId("expensive".into());
        let executor = MockExecutor::new().with_output(&id, json!("ok"));
        let cost_map = HashMap::from([(id.clone(), 50u64)]);
        let budgeted = BudgetedExecutor::new(executor, cost_map, 49);

        let err = budgeted.execute(&id, &json!(null)).unwrap_err();
        assert!(
            matches!(
                err,
                ExecutionError::BudgetExceeded {
                    spent_cents: 0,
                    budget_cents: 49
                }
            ),
            "expected BudgetExceeded, got {err:?}"
        );
        // Nothing was charged for the rejected stage.
        assert_eq!(budgeted.spent_cents(), 0);
    }

    #[test]
    fn second_stage_pushes_over_budget() {
        let a = StageId("a".into());
        let b = StageId("b".into());
        let executor = MockExecutor::new()
            .with_output(&a, json!(1))
            .with_output(&b, json!(2));
        let cost_map = HashMap::from([(a.clone(), 60u64), (b.clone(), 50u64)]);
        let budgeted = BudgetedExecutor::new(executor, cost_map, 100);

        // First call: 60¢ → within 100¢ budget.
        assert!(budgeted.execute(&a, &json!(null)).is_ok());
        assert_eq!(budgeted.spent_cents(), 60);

        // Second call: 60 + 50 = 110¢ > 100¢ → abort.
        let err = budgeted.execute(&b, &json!(null)).unwrap_err();
        assert!(
            matches!(
                err,
                ExecutionError::BudgetExceeded {
                    spent_cents: 60,
                    budget_cents: 100
                }
            ),
            "got {err:?}"
        );
        // The rejected stage left the total untouched.
        assert_eq!(budgeted.spent_cents(), 60);
    }

    #[test]
    fn build_cost_map_extracts_costs_from_store() {
        let mut store = MemoryStore::new();
        store.put(make_costly_stage("s1", 25)).unwrap();
        store.put(make_costly_stage("s2", 75)).unwrap();

        let node = CompositionNode::Sequential {
            stages: vec![
                CompositionNode::Stage {
                    id: StageId("s1".into()),
                    config: None,
                },
                CompositionNode::Stage {
                    id: StageId("s2".into()),
                    config: None,
                },
            ],
        };

        let map = build_cost_map(&node, &store);
        assert_eq!(map[&StageId("s1".into())], 25);
        assert_eq!(map[&StageId("s2".into())], 75);
    }

    #[test]
    fn build_cost_map_ignores_free_stages() {
        let mut store = MemoryStore::new();
        // Stage with no Cost effect.
        store
            .put(make_stage("free", EffectSet::pure(), "free stage".into()))
            .unwrap();

        let node = CompositionNode::Stage {
            id: StageId("free".into()),
            config: None,
        };
        let map = build_cost_map(&node, &store);
        assert!(map.is_empty(), "free stage should not appear in cost map");
    }

    #[test]
    fn build_cost_map_ignores_stages_missing_from_store() {
        let store = MemoryStore::new();
        let node = CompositionNode::Stage {
            id: StageId("ghost".into()),
            config: None,
        };
        let map = build_cost_map(&node, &store);
        assert!(map.is_empty(), "unknown stage should not appear in cost map");
    }
}