// harn_vm/llm/trigger_predicate.rs
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap};
use std::sync::{Mutex, OnceLock, PoisonError};
use std::time::Duration;

use serde::{Deserialize, Serialize};

use super::api::{LlmRequestPayload, LlmResult};
use super::cost::calculate_cost;

11#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
12pub struct TriggerPredicateBudget {
13 pub max_cost_usd: Option<f64>,
14 pub tokens_max: Option<u64>,
15 pub timeout_ms: Option<u64>,
16}
17
18impl TriggerPredicateBudget {
19 pub fn timeout(&self) -> Option<Duration> {
20 self.timeout_ms.map(Duration::from_millis)
21 }
22}
23
24#[derive(Clone, Debug, Serialize, Deserialize)]
25pub(crate) struct PredicateCacheEntry {
26 pub request_hash: String,
27 pub(crate) result: LlmResult,
28}
29
30#[derive(Clone, Debug, Default)]
31pub struct PredicateEvaluationCapture {
32 pub entries: Vec<PredicateCacheEntry>,
33 pub total_tokens: u64,
34 pub total_cost_usd: f64,
35 pub cached: bool,
36 pub budget_exceeded: bool,
37}
38
39#[derive(Clone, Debug, Default)]
40struct PredicateEvaluationState {
41 budget: TriggerPredicateBudget,
42 replay_cache: HashMap<String, LlmResult>,
43 entries: BTreeMap<String, LlmResult>,
44 total_tokens: u64,
45 total_cost_usd: f64,
46 cached: bool,
47 budget_exceeded: bool,
48}
49
50thread_local! {
51 static ACTIVE_PREDICATE_EVALUATION: RefCell<Option<PredicateEvaluationState>> = const { RefCell::new(None) };
52}
53
54fn request_cache() -> &'static Mutex<HashMap<String, LlmResult>> {
55 static CACHE: OnceLock<Mutex<HashMap<String, LlmResult>>> = OnceLock::new();
56 CACHE.get_or_init(|| Mutex::new(HashMap::new()))
57}
58
59pub(crate) fn reset_trigger_predicate_state() {
60 ACTIVE_PREDICATE_EVALUATION.with(|slot| {
61 *slot.borrow_mut() = None;
62 });
63 if let Ok(mut cache) = request_cache().lock() {
64 cache.clear();
65 }
66}
67
68pub(crate) fn request_hash(request: &LlmRequestPayload) -> String {
69 use std::hash::{Hash, Hasher};
70
71 let canonical = serde_json::json!({
72 "provider": request.provider,
73 "model": request.model,
74 "messages": request.messages,
75 "system": request.system,
76 "max_tokens": request.max_tokens,
77 "temperature": request.temperature,
78 "top_p": request.top_p,
79 "top_k": request.top_k,
80 "stop": request.stop,
81 "seed": request.seed,
82 "frequency_penalty": request.frequency_penalty,
83 "presence_penalty": request.presence_penalty,
84 "output_format": request.output_format,
85 "response_format": request.response_format,
86 "json_schema": request.json_schema,
87 "thinking": request.thinking,
88 "anthropic_beta_features": request.anthropic_beta_features,
89 "native_tools": request.native_tools,
90 "tool_choice": request.tool_choice,
91 "cache": request.cache,
92 "timeout": request.timeout,
93 "stream": request.stream,
94 "provider_overrides": request.provider_overrides,
95 "prefill": request.prefill,
96 });
97 let mut hasher = std::collections::hash_map::DefaultHasher::new();
98 serde_json::to_string(&canonical)
99 .unwrap_or_default()
100 .hash(&mut hasher);
101 format!("{:016x}", hasher.finish())
102}
103
/// RAII guard for the predicate evaluation running on the current thread.
///
/// Returned by `start_predicate_evaluation`. Dropping the guard clears the
/// thread's active evaluation, so call [`PredicateEvaluationGuard::finish`]
/// first to collect what was recorded. Marked `#[must_use]` because a
/// discarded guard drops immediately and ends the evaluation before any LLM
/// call happens.
#[must_use = "dropping the guard immediately ends the predicate evaluation"]
pub(crate) struct PredicateEvaluationGuard;

106impl PredicateEvaluationGuard {
107 pub fn finish(self) -> PredicateEvaluationCapture {
108 finish_predicate_evaluation()
109 }
110}
111
112impl Drop for PredicateEvaluationGuard {
113 fn drop(&mut self) {
114 ACTIVE_PREDICATE_EVALUATION.with(|slot| {
115 *slot.borrow_mut() = None;
116 });
117 }
118}
119
120pub(crate) fn start_predicate_evaluation(
121 budget: TriggerPredicateBudget,
122 replay_entries: Vec<PredicateCacheEntry>,
123) -> PredicateEvaluationGuard {
124 ACTIVE_PREDICATE_EVALUATION.with(|slot| {
125 *slot.borrow_mut() = Some(PredicateEvaluationState {
126 budget,
127 replay_cache: replay_entries
128 .into_iter()
129 .map(|entry| (entry.request_hash, entry.result))
130 .collect(),
131 ..Default::default()
132 });
133 });
134 PredicateEvaluationGuard
135}
136
137fn finish_predicate_evaluation() -> PredicateEvaluationCapture {
138 ACTIVE_PREDICATE_EVALUATION.with(|slot| {
139 let Some(state) = slot.borrow_mut().take() else {
140 return PredicateEvaluationCapture::default();
141 };
142 PredicateEvaluationCapture {
143 entries: state
144 .entries
145 .into_iter()
146 .map(|(request_hash, result)| PredicateCacheEntry {
147 request_hash,
148 result,
149 })
150 .collect(),
151 total_tokens: state.total_tokens,
152 total_cost_usd: state.total_cost_usd,
153 cached: state.cached,
154 budget_exceeded: state.budget_exceeded,
155 }
156 })
157}
158
159pub(crate) fn lookup_cached_result(request: &LlmRequestPayload) -> Option<LlmResult> {
160 ACTIVE_PREDICATE_EVALUATION.with(|slot| {
161 let mut borrowed = slot.borrow_mut();
162 let state = borrowed.as_mut()?;
163 if state.budget_exceeded {
164 return None;
165 }
166 let hash = request_hash(request);
167 let cached = state.replay_cache.get(&hash).cloned().or_else(|| {
168 request_cache()
169 .lock()
170 .ok()
171 .and_then(|cache| cache.get(&hash).cloned())
172 });
173 if let Some(result) = cached.clone() {
174 state.cached = true;
175 state.entries.insert(hash, result.clone());
176 return Some(result);
177 }
178 None
179 })
180}
181
182pub(crate) fn note_result(request: &LlmRequestPayload, result: &LlmResult) {
183 ACTIVE_PREDICATE_EVALUATION.with(|slot| {
184 let mut borrowed = slot.borrow_mut();
185 let Some(state) = borrowed.as_mut() else {
186 return;
187 };
188 let hash = request_hash(request);
189 state.entries.insert(hash.clone(), result.clone());
190 if let Ok(mut cache) = request_cache().lock() {
191 cache.insert(hash, result.clone());
192 }
193 let call_tokens = result
194 .input_tokens
195 .saturating_add(result.output_tokens)
196 .max(0) as u64;
197 state.total_tokens = state.total_tokens.saturating_add(call_tokens);
198 state.total_cost_usd +=
199 calculate_cost(&result.model, result.input_tokens, result.output_tokens);
200 if state
201 .budget
202 .tokens_max
203 .is_some_and(|limit| state.total_tokens > limit)
204 {
205 state.budget_exceeded = true;
206 }
207 if state
208 .budget
209 .max_cost_usd
210 .is_some_and(|limit| state.total_cost_usd > limit)
211 {
212 state.budget_exceeded = true;
213 }
214 });
215}