1use super::{ExecutionError, StageExecutor};
36use noether_core::effects::Effect;
37use noether_core::stage::StageId;
38use noether_store::StageStore;
39use serde_json::Value;
40use std::collections::HashMap;
41use std::sync::atomic::{AtomicU64, Ordering};
42use std::sync::Arc;
43
44pub fn build_cost_map(
51 node: &crate::lagrange::CompositionNode,
52 store: &(impl StageStore + ?Sized),
53) -> HashMap<StageId, u64> {
54 let mut map = HashMap::new();
55 collect_costs(node, store, &mut map);
56 map
57}
58
59fn collect_costs(
60 node: &crate::lagrange::CompositionNode,
61 store: &(impl StageStore + ?Sized),
62 map: &mut HashMap<StageId, u64>,
63) {
64 use crate::lagrange::CompositionNode::*;
65 match node {
66 Stage { id, .. } => {
67 if let Ok(Some(stage)) = store.get(id) {
68 let total: u64 = stage
69 .signature
70 .effects
71 .iter()
72 .filter_map(|e| {
73 if let Effect::Cost { cents } = e {
74 Some(*cents)
75 } else {
76 None
77 }
78 })
79 .sum();
80 if total > 0 {
81 map.insert(id.clone(), total);
82 }
83 }
84 }
85 RemoteStage { .. } | Const { .. } => {}
86 Sequential { stages } => {
87 for s in stages {
88 collect_costs(s, store, map);
89 }
90 }
91 Parallel { branches } => {
92 for b in branches.values() {
93 collect_costs(b, store, map);
94 }
95 }
96 Branch {
97 predicate,
98 if_true,
99 if_false,
100 } => {
101 collect_costs(predicate, store, map);
102 collect_costs(if_true, store, map);
103 collect_costs(if_false, store, map);
104 }
105 Fanout { source, targets } => {
106 collect_costs(source, store, map);
107 for t in targets {
108 collect_costs(t, store, map);
109 }
110 }
111 Merge { sources, target } => {
112 for s in sources {
113 collect_costs(s, store, map);
114 }
115 collect_costs(target, store, map);
116 }
117 Retry { stage, .. } => collect_costs(stage, store, map),
118 Let { bindings, body } => {
119 for b in bindings.values() {
120 collect_costs(b, store, map);
121 }
122 collect_costs(body, store, map);
123 }
124 }
125}
126
127pub struct BudgetedExecutor<E: StageExecutor> {
134 inner: E,
135 cost_map: HashMap<StageId, u64>,
137 spent_cents: Arc<AtomicU64>,
139 budget_cents: u64,
141}
142
143impl<E: StageExecutor> BudgetedExecutor<E> {
144 pub fn new(inner: E, cost_map: HashMap<StageId, u64>, budget_cents: u64) -> Self {
151 Self {
152 inner,
153 cost_map,
154 spent_cents: Arc::new(AtomicU64::new(0)),
155 budget_cents,
156 }
157 }
158
159 pub fn spent_cents(&self) -> u64 {
161 self.spent_cents.load(Ordering::Relaxed)
162 }
163}
164
165impl<E: StageExecutor + Sync> StageExecutor for BudgetedExecutor<E> {
166 fn execute(&self, stage_id: &StageId, input: &Value) -> Result<Value, ExecutionError> {
167 let cost = self.cost_map.get(stage_id).copied().unwrap_or(0);
168
169 if cost > 0 {
170 let prev = self.spent_cents.fetch_add(cost, Ordering::SeqCst);
173 let newly_spent = prev + cost;
174
175 if newly_spent > self.budget_cents {
176 self.spent_cents.fetch_sub(cost, Ordering::SeqCst);
178 return Err(ExecutionError::BudgetExceeded {
179 spent_cents: prev,
180 budget_cents: self.budget_cents,
181 });
182 }
183 }
184
185 self.inner.execute(stage_id, input)
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::mock::MockExecutor;
    use crate::lagrange::{CompositionNode, Pinning};
    use noether_core::effects::{Effect, EffectSet};
    use noether_core::stage::{CostEstimate, Stage, StageId, StageLifecycle, StageSignature};
    use noether_core::types::NType;
    use noether_store::MemoryStore;
    use serde_json::json;
    use std::collections::BTreeSet;

    /// Builds an active stage with `NType::Any` input/output and every optional
    /// field empty. Only the fields that differ between fixtures are parameters.
    fn make_stage(
        id: &str,
        effects: EffectSet,
        implementation_hash: String,
        description: String,
    ) -> Stage {
        Stage {
            id: StageId(id.into()),
            signature_id: None,
            signature: StageSignature {
                input: NType::Any,
                output: NType::Any,
                effects,
                implementation_hash,
            },
            capabilities: BTreeSet::new(),
            cost: CostEstimate {
                time_ms_p50: None,
                tokens_est: None,
                memory_mb: None,
            },
            description,
            examples: vec![],
            lifecycle: StageLifecycle::Active,
            ed25519_signature: None,
            signer_public_key: None,
            implementation_code: None,
            implementation_language: None,
            ui_style: None,
            tags: vec![],
            aliases: vec![],
            name: None,
            properties: Vec::new(),
        }
    }

    /// A stage declaring `Effect::Cost { cents }` alongside an LLM effect.
    fn make_costly_stage(id: &str, cents: u64) -> Stage {
        make_stage(
            id,
            EffectSet::new([
                Effect::Cost { cents },
                Effect::Llm {
                    model: "gpt".into(),
                },
            ]),
            format!("impl_{id}"),
            format!("costly stage {id}"),
        )
    }

    /// A `Stage` composition node for `id`, signature-pinned, with no config.
    fn stage_node(id: &str) -> CompositionNode {
        CompositionNode::Stage {
            id: StageId(id.into()),
            pinning: Pinning::Signature,
            config: None,
        }
    }

    #[test]
    fn no_cost_stages_pass_through() {
        let executor = MockExecutor::new().with_output(&StageId("a".into()), json!(1));
        let budgeted = BudgetedExecutor::new(executor, HashMap::new(), 0);
        let result = budgeted.execute(&StageId("a".into()), &json!(null));
        assert_eq!(result.unwrap(), json!(1));
        assert_eq!(budgeted.spent_cents(), 0);
    }

    #[test]
    fn within_budget_executes_and_tracks_cost() {
        let id = StageId("llm".into());
        let executor = MockExecutor::new().with_output(&id, json!("ok"));
        let cost_map = HashMap::from([(id.clone(), 10u64)]);
        let budgeted = BudgetedExecutor::new(executor, cost_map, 100);
        assert!(budgeted.execute(&id, &json!(null)).is_ok());
        assert_eq!(budgeted.spent_cents(), 10);
    }

    #[test]
    fn spending_exactly_the_budget_is_allowed() {
        let id = StageId("exact".into());
        let executor = MockExecutor::new().with_output(&id, json!("ok"));
        let cost_map = HashMap::from([(id.clone(), 100u64)]);
        let budgeted = BudgetedExecutor::new(executor, cost_map, 100);
        assert!(budgeted.execute(&id, &json!(null)).is_ok());
        assert_eq!(budgeted.spent_cents(), 100);
    }

    #[test]
    fn over_budget_returns_error_and_rolls_back() {
        let id = StageId("expensive".into());
        let executor = MockExecutor::new().with_output(&id, json!("ok"));
        let cost_map = HashMap::from([(id.clone(), 50u64)]);
        let budgeted = BudgetedExecutor::new(executor, cost_map, 49);

        let err = budgeted.execute(&id, &json!(null)).unwrap_err();
        assert!(
            matches!(
                err,
                ExecutionError::BudgetExceeded {
                    spent_cents: 0,
                    budget_cents: 49
                }
            ),
            "expected BudgetExceeded, got {err:?}"
        );
        assert_eq!(budgeted.spent_cents(), 0);
    }

    #[test]
    fn second_stage_pushes_over_budget() {
        let a = StageId("a".into());
        let b = StageId("b".into());
        let executor = MockExecutor::new()
            .with_output(&a, json!(1))
            .with_output(&b, json!(2));
        let cost_map = HashMap::from([(a.clone(), 60u64), (b.clone(), 50u64)]);
        let budgeted = BudgetedExecutor::new(executor, cost_map, 100);

        assert!(budgeted.execute(&a, &json!(null)).is_ok());
        assert_eq!(budgeted.spent_cents(), 60);

        let err = budgeted.execute(&b, &json!(null)).unwrap_err();
        assert!(
            matches!(
                err,
                ExecutionError::BudgetExceeded {
                    spent_cents: 60,
                    budget_cents: 100
                }
            ),
            "got {err:?}"
        );
        assert_eq!(budgeted.spent_cents(), 60);
    }

    #[test]
    fn rejected_stage_does_not_consume_budget() {
        let a = StageId("a".into());
        let b = StageId("b".into());
        let c = StageId("c".into());
        let executor = MockExecutor::new()
            .with_output(&a, json!(1))
            .with_output(&b, json!(2))
            .with_output(&c, json!(3));
        let cost_map = HashMap::from([(a.clone(), 60u64), (b.clone(), 50u64), (c.clone(), 40u64)]);
        let budgeted = BudgetedExecutor::new(executor, cost_map, 100);

        assert!(budgeted.execute(&a, &json!(null)).is_ok());
        assert!(budgeted.execute(&b, &json!(null)).is_err());
        // The rejected stage left no charge behind, so the cheaper one still fits.
        assert_eq!(budgeted.execute(&c, &json!(null)).unwrap(), json!(3));
        assert_eq!(budgeted.spent_cents(), 100);
    }

    #[test]
    fn build_cost_map_extracts_costs_from_store() {
        let mut store = MemoryStore::new();
        store.put(make_costly_stage("s1", 25)).unwrap();
        store.put(make_costly_stage("s2", 75)).unwrap();

        let node = CompositionNode::Sequential {
            stages: vec![stage_node("s1"), stage_node("s2")],
        };

        let map = build_cost_map(&node, &store);
        assert_eq!(map[&StageId("s1".into())], 25);
        assert_eq!(map[&StageId("s2".into())], 75);
    }

    #[test]
    fn build_cost_map_ignores_free_stages() {
        let mut store = MemoryStore::new();
        let free = make_stage(
            "free",
            EffectSet::pure(),
            "impl".into(),
            "free stage".into(),
        );
        store.put(free).unwrap();

        let map = build_cost_map(&stage_node("free"), &store);
        assert!(map.is_empty(), "free stage should not appear in cost map");
    }

    #[test]
    fn build_cost_map_ignores_unknown_stages() {
        let store = MemoryStore::new();
        let map = build_cost_map(&stage_node("missing"), &store);
        assert!(map.is_empty(), "unresolvable stage should not appear in cost map");
    }
}