1use super::{ExecutionError, StageExecutor};
36use noether_core::effects::Effect;
37use noether_core::stage::StageId;
38use noether_store::StageStore;
39use serde_json::Value;
40use std::collections::HashMap;
41use std::sync::atomic::{AtomicU64, Ordering};
42use std::sync::Arc;
43
44pub fn build_cost_map(
51 node: &crate::lagrange::CompositionNode,
52 store: &(impl StageStore + ?Sized),
53) -> HashMap<StageId, u64> {
54 let mut map = HashMap::new();
55 collect_costs(node, store, &mut map);
56 map
57}
58
59fn collect_costs(
60 node: &crate::lagrange::CompositionNode,
61 store: &(impl StageStore + ?Sized),
62 map: &mut HashMap<StageId, u64>,
63) {
64 use crate::lagrange::CompositionNode::*;
65 match node {
66 Stage { id, .. } => {
67 if let Ok(Some(stage)) = store.get(id) {
68 let total: u64 = stage
69 .signature
70 .effects
71 .iter()
72 .filter_map(|e| {
73 if let Effect::Cost { cents } = e {
74 Some(*cents)
75 } else {
76 None
77 }
78 })
79 .sum();
80 if total > 0 {
81 map.insert(id.clone(), total);
82 }
83 }
84 }
85 RemoteStage { .. } | Const { .. } => {}
86 Sequential { stages } => {
87 for s in stages {
88 collect_costs(s, store, map);
89 }
90 }
91 Parallel { branches } => {
92 for b in branches.values() {
93 collect_costs(b, store, map);
94 }
95 }
96 Branch {
97 predicate,
98 if_true,
99 if_false,
100 } => {
101 collect_costs(predicate, store, map);
102 collect_costs(if_true, store, map);
103 collect_costs(if_false, store, map);
104 }
105 Fanout { source, targets } => {
106 collect_costs(source, store, map);
107 for t in targets {
108 collect_costs(t, store, map);
109 }
110 }
111 Merge { sources, target } => {
112 for s in sources {
113 collect_costs(s, store, map);
114 }
115 collect_costs(target, store, map);
116 }
117 Retry { stage, .. } => collect_costs(stage, store, map),
118 Let { bindings, body } => {
119 for b in bindings.values() {
120 collect_costs(b, store, map);
121 }
122 collect_costs(body, store, map);
123 }
124 }
125}
126
127pub struct BudgetedExecutor<E: StageExecutor> {
134 inner: E,
135 cost_map: HashMap<StageId, u64>,
137 spent_cents: Arc<AtomicU64>,
139 budget_cents: u64,
141}
142
143impl<E: StageExecutor> BudgetedExecutor<E> {
144 pub fn new(inner: E, cost_map: HashMap<StageId, u64>, budget_cents: u64) -> Self {
151 Self {
152 inner,
153 cost_map,
154 spent_cents: Arc::new(AtomicU64::new(0)),
155 budget_cents,
156 }
157 }
158
159 pub fn spent_cents(&self) -> u64 {
161 self.spent_cents.load(Ordering::Relaxed)
162 }
163}
164
165impl<E: StageExecutor + Sync> StageExecutor for BudgetedExecutor<E> {
166 fn execute(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
167 let cost = self.cost_map.get(stage_id).copied().unwrap_or(0);
168
169 if cost > 0 {
170 let prev = self.spent_cents.fetch_add(cost, Ordering::SeqCst);
173 let newly_spent = prev + cost;
174
175 if newly_spent > self.budget_cents {
176 self.spent_cents.fetch_sub(cost, Ordering::SeqCst);
178 return Err(ExecutionError::BudgetExceeded {
179 spent_cents: prev,
180 budget_cents: self.budget_cents,
181 });
182 }
183 }
184
185 self.inner.execute(stage_id, input)
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::mock::MockExecutor;
    use crate::lagrange::CompositionNode;
    use noether_core::effects::{Effect, EffectSet};
    use noether_core::stage::{CostEstimate, Stage, StageId, StageLifecycle, StageSignature};
    use noether_core::types::NType;
    use noether_store::MemoryStore;
    use serde_json::json;
    use std::collections::BTreeSet;

    /// Shorthand for turning a string literal into a `StageId`.
    fn sid(raw: &str) -> StageId {
        StageId(raw.into())
    }

    /// Builds an active `Any -> Any` stage with the given effects, leaving
    /// every optional field empty.
    fn stage_with(
        id: &str,
        effects: EffectSet,
        implementation_hash: String,
        description: String,
    ) -> Stage {
        Stage {
            id: sid(id),
            canonical_id: None,
            signature: StageSignature {
                input: NType::Any,
                output: NType::Any,
                effects,
                implementation_hash,
            },
            capabilities: BTreeSet::new(),
            cost: CostEstimate {
                time_ms_p50: None,
                tokens_est: None,
                memory_mb: None,
            },
            description,
            examples: vec![],
            lifecycle: StageLifecycle::Active,
            ed25519_signature: None,
            signer_public_key: None,
            implementation_code: None,
            implementation_language: None,
            ui_style: None,
            tags: vec![],
            aliases: vec![],
            name: None,
        }
    }

    /// Builds a stage declaring a cost of `cents` plus an LLM effect.
    fn make_costly_stage(id: &str, cents: u64) -> Stage {
        let effects = EffectSet::new([
            Effect::Cost { cents },
            Effect::Llm {
                model: "gpt".into(),
            },
        ]);
        stage_with(
            id,
            effects,
            format!("impl_{id}"),
            format!("costly stage {id}"),
        )
    }

    #[test]
    fn no_cost_stages_pass_through() {
        let stage = sid("a");
        let mock = MockExecutor::new().with_output(&stage, json!(1));
        let budgeted = BudgetedExecutor::new(mock, HashMap::new(), 0);

        let outcome = budgeted.execute(&stage, &json!(null));

        assert_eq!(outcome.unwrap(), json!(1));
        assert_eq!(budgeted.spent_cents(), 0);
    }

    #[test]
    fn within_budget_executes_and_tracks_cost() {
        let stage = sid("llm");
        let mock = MockExecutor::new().with_output(&stage, json!("ok"));
        let costs = HashMap::from([(stage.clone(), 10u64)]);
        let budgeted = BudgetedExecutor::new(mock, costs, 100);

        assert!(budgeted.execute(&stage, &json!(null)).is_ok());
        assert_eq!(budgeted.spent_cents(), 10);
    }

    #[test]
    fn over_budget_returns_error_and_rolls_back() {
        let stage = sid("expensive");
        let mock = MockExecutor::new().with_output(&stage, json!("ok"));
        let costs = HashMap::from([(stage.clone(), 50u64)]);
        let budgeted = BudgetedExecutor::new(mock, costs, 49);

        let err = budgeted.execute(&stage, &json!(null)).unwrap_err();
        assert!(
            matches!(
                err,
                ExecutionError::BudgetExceeded {
                    spent_cents: 0,
                    budget_cents: 49
                }
            ),
            "expected BudgetExceeded, got {err:?}"
        );
        // The rejected charge must not remain on the counter.
        assert_eq!(budgeted.spent_cents(), 0);
    }

    #[test]
    fn second_stage_pushes_over_budget() {
        let first = sid("a");
        let second = sid("b");
        let mock = MockExecutor::new()
            .with_output(&first, json!(1))
            .with_output(&second, json!(2));
        let costs = HashMap::from([(first.clone(), 60u64), (second.clone(), 50u64)]);
        let budgeted = BudgetedExecutor::new(mock, costs, 100);

        // 60 of 100 cents: fits.
        assert!(budgeted.execute(&first, &json!(null)).is_ok());
        assert_eq!(budgeted.spent_cents(), 60);

        // 60 + 50 exceeds 100: rejected.
        let err = budgeted.execute(&second, &json!(null)).unwrap_err();
        assert!(
            matches!(
                err,
                ExecutionError::BudgetExceeded {
                    spent_cents: 60,
                    budget_cents: 100
                }
            ),
            "got {err:?}"
        );
        // Only the first stage's charge remains.
        assert_eq!(budgeted.spent_cents(), 60);
    }

    #[test]
    fn build_cost_map_extracts_costs_from_store() {
        let mut store = MemoryStore::new();
        for (name, cents) in [("s1", 25u64), ("s2", 75u64)] {
            store.put(make_costly_stage(name, cents)).unwrap();
        }

        let graph = CompositionNode::Sequential {
            stages: vec![
                CompositionNode::Stage {
                    id: sid("s1"),
                    config: None,
                },
                CompositionNode::Stage {
                    id: sid("s2"),
                    config: None,
                },
            ],
        };

        let costs = build_cost_map(&graph, &store);
        assert_eq!(costs.get(&sid("s1")), Some(&25));
        assert_eq!(costs.get(&sid("s2")), Some(&75));
    }

    #[test]
    fn build_cost_map_ignores_free_stages() {
        let mut store = MemoryStore::new();
        // A pure stage declares no cost effect at all.
        let free = stage_with(
            "free",
            EffectSet::pure(),
            "impl".into(),
            "free stage".into(),
        );
        store.put(free).unwrap();

        let graph = CompositionNode::Stage {
            id: sid("free"),
            config: None,
        };
        let costs = build_cost_map(&graph, &store);
        assert!(costs.is_empty(), "free stage should not appear in cost map");
    }
}