Skip to main content

noether_engine/executor/
budget.rs

1//! Runtime cost-budget enforcement for composition execution.
2//!
3//! [`BudgetedExecutor`] wraps any [`StageExecutor`] and tracks actual cost
4//! consumed by each stage using its declared [`Effect::Cost`] effects.
5//! The accounting uses an `AtomicU64` so parallel branches are handled
6//! correctly without a mutex.
7//!
8//! ## Semantics
9//!
//! Cost is **reserved before** a stage runs.  If adding a stage's declared
//! cost would push `spent_cents` past `budget_cents`, the call returns
//! [`ExecutionError::BudgetExceeded`] immediately — the stage is never
//! invoked and nothing is charged for it.  This is conservative: once a
//! stage's cost has been reserved it is not refunded, even if the stage fails.
15//!
//! Parallel branches share one atomic counter, so every reservation sees the
//! running total of all branches.  A branch is only allowed to run if its
//! reservation keeps that total within `budget_cents`, so concurrent branches
//! can never collectively commit more than the budget; a branch whose cost
//! would cross the limit is rejected instead of being run.
20//!
21//! ## Usage
22//!
23//! ```no_run
24//! use noether_engine::executor::budget::{BudgetedExecutor, build_cost_map};
25//! use noether_engine::executor::mock::MockExecutor;
26//! use noether_engine::lagrange::CompositionNode;
27//! use noether_store::MemoryStore;
28//!
29//! let store = MemoryStore::new();
30//! let cost_map = build_cost_map(&CompositionNode::Const { value: serde_json::Value::Null }, &store);
31//! let inner = MockExecutor::new();
32//! let budgeted = BudgetedExecutor::new(inner, cost_map, 100 /* cents */);
33//! ```
34
35use super::{ExecutionError, StageExecutor};
36use noether_core::effects::Effect;
37use noether_core::stage::StageId;
38use noether_store::StageStore;
39use serde_json::Value;
40use std::collections::HashMap;
41use std::sync::atomic::{AtomicU64, Ordering};
42use std::sync::Arc;
43
44// ── Cost map ─────────────────────────────────────────────────────────────────
45
46/// Walk a composition graph and build a map of `StageId → declared_cents`.
47///
48/// Only stages that declare at least one `Effect::Cost { cents }` appear in
49/// the map.  Stages not in the store (e.g. RemoteStage) are ignored.
50pub fn build_cost_map(
51    node: &crate::lagrange::CompositionNode,
52    store: &(impl StageStore + ?Sized),
53) -> HashMap<StageId, u64> {
54    let mut map = HashMap::new();
55    collect_costs(node, store, &mut map);
56    map
57}
58
59fn collect_costs(
60    node: &crate::lagrange::CompositionNode,
61    store: &(impl StageStore + ?Sized),
62    map: &mut HashMap<StageId, u64>,
63) {
64    use crate::lagrange::CompositionNode::*;
65    match node {
66        Stage { id, .. } => {
67            if let Ok(Some(stage)) = store.get(id) {
68                let total: u64 = stage
69                    .signature
70                    .effects
71                    .iter()
72                    .filter_map(|e| {
73                        if let Effect::Cost { cents } = e {
74                            Some(*cents)
75                        } else {
76                            None
77                        }
78                    })
79                    .sum();
80                if total > 0 {
81                    map.insert(id.clone(), total);
82                }
83            }
84        }
85        RemoteStage { .. } | Const { .. } => {}
86        Sequential { stages } => {
87            for s in stages {
88                collect_costs(s, store, map);
89            }
90        }
91        Parallel { branches } => {
92            for b in branches.values() {
93                collect_costs(b, store, map);
94            }
95        }
96        Branch {
97            predicate,
98            if_true,
99            if_false,
100        } => {
101            collect_costs(predicate, store, map);
102            collect_costs(if_true, store, map);
103            collect_costs(if_false, store, map);
104        }
105        Fanout { source, targets } => {
106            collect_costs(source, store, map);
107            for t in targets {
108                collect_costs(t, store, map);
109            }
110        }
111        Merge { sources, target } => {
112            for s in sources {
113                collect_costs(s, store, map);
114            }
115            collect_costs(target, store, map);
116        }
117        Retry { stage, .. } => collect_costs(stage, store, map),
118        Let { bindings, body } => {
119            for b in bindings.values() {
120                collect_costs(b, store, map);
121            }
122            collect_costs(body, store, map);
123        }
124    }
125}
126
127// ── BudgetedExecutor ──────────────────────────────────────────────────────────
128
129/// An executor wrapper that enforces a runtime cost budget.
130///
131/// Maintains an `Arc<AtomicU64>` counter of cents spent so that concurrent
132/// parallel branches all see the same running total.
133pub struct BudgetedExecutor<E: StageExecutor> {
134    inner: E,
135    /// Declared cost in cents per stage id.
136    cost_map: HashMap<StageId, u64>,
137    /// Running total shared with all clones / concurrent uses.
138    spent_cents: Arc<AtomicU64>,
139    /// Hard limit in cents.
140    budget_cents: u64,
141}
142
143impl<E: StageExecutor> BudgetedExecutor<E> {
144    /// Create a new budgeted executor wrapping `inner`.
145    ///
146    /// `cost_map` maps stage ids to their declared cost in cents
147    /// (build it with [`build_cost_map`]).
148    /// `budget_cents` is the hard limit; execution aborts when it would
149    /// be exceeded.
150    pub fn new(inner: E, cost_map: HashMap<StageId, u64>, budget_cents: u64) -> Self {
151        Self {
152            inner,
153            cost_map,
154            spent_cents: Arc::new(AtomicU64::new(0)),
155            budget_cents,
156        }
157    }
158
159    /// Return a snapshot of cents spent so far.
160    pub fn spent_cents(&self) -> u64 {
161        self.spent_cents.load(Ordering::Relaxed)
162    }
163}
164
165impl<E: StageExecutor + Sync> StageExecutor for BudgetedExecutor<E> {
166    fn execute(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
167        let cost = self.cost_map.get(stage_id).copied().unwrap_or(0);
168
169        if cost > 0 {
170            // Atomically reserve the cost before executing.
171            // fetch_add returns the *previous* value, so newly_spent = prev + cost.
172            let prev = self.spent_cents.fetch_add(cost, Ordering::SeqCst);
173            let newly_spent = prev + cost;
174
175            if newly_spent > self.budget_cents {
176                // Roll back: we're not going to run this stage.
177                self.spent_cents.fetch_sub(cost, Ordering::SeqCst);
178                return Err(ExecutionError::BudgetExceeded {
179                    spent_cents: prev,
180                    budget_cents: self.budget_cents,
181                });
182            }
183        }
184
185        self.inner.execute(stage_id, input)
186    }
187}
188
189// ── Tests ─────────────────────────────────────────────────────────────────────
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::mock::MockExecutor;
    use crate::lagrange::{CompositionNode, Pinning};
    use noether_core::effects::{Effect, EffectSet};
    use noether_core::stage::{CostEstimate, Stage, StageId, StageLifecycle, StageSignature};
    use noether_core::types::NType;
    use noether_store::MemoryStore;
    use serde_json::json;
    use std::collections::BTreeSet;

    /// Build an `Active`, `Any → Any` stage with the given effects, leaving
    /// every optional field empty.
    fn stage_with_effects(
        id: &str,
        effects: EffectSet,
        implementation_hash: String,
        description: String,
    ) -> Stage {
        Stage {
            id: StageId(id.into()),
            signature_id: None,
            signature: StageSignature {
                input: NType::Any,
                output: NType::Any,
                effects,
                implementation_hash,
            },
            capabilities: BTreeSet::new(),
            cost: CostEstimate {
                time_ms_p50: None,
                tokens_est: None,
                memory_mb: None,
            },
            description,
            examples: vec![],
            lifecycle: StageLifecycle::Active,
            ed25519_signature: None,
            signer_public_key: None,
            implementation_code: None,
            implementation_language: None,
            ui_style: None,
            tags: vec![],
            aliases: vec![],
            name: None,
            properties: Vec::new(),
        }
    }

    /// A stage declaring `cents` of cost alongside an LLM effect.
    fn make_costly_stage(id: &str, cents: u64) -> Stage {
        let effects = EffectSet::new([
            Effect::Cost { cents },
            Effect::Llm {
                model: "gpt".into(),
            },
        ]);
        stage_with_effects(id, effects, format!("impl_{id}"), format!("costly stage {id}"))
    }

    /// A plain, signature-pinned `Stage` composition node referencing `id`.
    fn stage_node(id: &str) -> CompositionNode {
        CompositionNode::Stage {
            id: StageId(id.into()),
            pinning: Pinning::Signature,
            config: None,
        }
    }

    #[test]
    fn no_cost_stages_pass_through() {
        let stage = StageId("a".into());
        let budgeted = BudgetedExecutor::new(
            MockExecutor::new().with_output(&stage, json!(1)),
            HashMap::new(),
            0,
        );
        let output = budgeted.execute(&stage, &json!(null)).unwrap();
        assert_eq!(output, json!(1));
        assert_eq!(budgeted.spent_cents(), 0);
    }

    #[test]
    fn within_budget_executes_and_tracks_cost() {
        let llm = StageId("llm".into());
        let costs = HashMap::from([(llm.clone(), 10u64)]);
        let budgeted =
            BudgetedExecutor::new(MockExecutor::new().with_output(&llm, json!("ok")), costs, 100);
        assert!(budgeted.execute(&llm, &json!(null)).is_ok());
        assert_eq!(budgeted.spent_cents(), 10);
    }

    #[test]
    fn over_budget_returns_error_and_rolls_back() {
        let expensive = StageId("expensive".into());
        let costs = HashMap::from([(expensive.clone(), 50u64)]);
        let budgeted = BudgetedExecutor::new(
            MockExecutor::new().with_output(&expensive, json!("ok")),
            costs,
            49,
        );

        let err = budgeted.execute(&expensive, &json!(null)).unwrap_err();
        assert!(
            matches!(
                err,
                ExecutionError::BudgetExceeded {
                    spent_cents: 0,
                    budget_cents: 49
                }
            ),
            "expected BudgetExceeded, got {err:?}"
        );
        // Nothing was charged for the refused stage.
        assert_eq!(budgeted.spent_cents(), 0);
    }

    #[test]
    fn second_stage_pushes_over_budget() {
        let (a, b) = (StageId("a".into()), StageId("b".into()));
        let executor = MockExecutor::new()
            .with_output(&a, json!(1))
            .with_output(&b, json!(2));
        let costs = HashMap::from([(a.clone(), 60u64), (b.clone(), 50u64)]);
        let budgeted = BudgetedExecutor::new(executor, costs, 100);

        // 60¢ fits inside the 100¢ budget.
        assert!(budgeted.execute(&a, &json!(null)).is_ok());
        assert_eq!(budgeted.spent_cents(), 60);

        // 60¢ + 50¢ = 110¢ would overshoot, so `b` is refused.
        let err = budgeted.execute(&b, &json!(null)).unwrap_err();
        assert!(
            matches!(
                err,
                ExecutionError::BudgetExceeded {
                    spent_cents: 60,
                    budget_cents: 100
                }
            ),
            "got {err:?}"
        );
        // The refused reservation left the running total untouched.
        assert_eq!(budgeted.spent_cents(), 60);
    }

    #[test]
    fn build_cost_map_extracts_costs_from_store() {
        let mut store = MemoryStore::new();
        for (id, cents) in [("s1", 25), ("s2", 75)] {
            store.put(make_costly_stage(id, cents)).unwrap();
        }

        let node = CompositionNode::Sequential {
            stages: vec![stage_node("s1"), stage_node("s2")],
        };

        let map = build_cost_map(&node, &store);
        assert_eq!(map[&StageId("s1".into())], 25);
        assert_eq!(map[&StageId("s2".into())], 75);
    }

    #[test]
    fn build_cost_map_ignores_free_stages() {
        let mut store = MemoryStore::new();
        // A stage that declares no `Cost` effect at all.
        let free = stage_with_effects(
            "free",
            EffectSet::pure(),
            "impl".into(),
            "free stage".into(),
        );
        store.put(free).unwrap();

        let map = build_cost_map(&stage_node("free"), &store);
        assert!(map.is_empty(), "free stage should not appear in cost map");
    }
}