// harn_vm/llm/trigger_predicate.rs
use std::cell::RefCell;
2use std::collections::{BTreeMap, HashMap};
3use std::sync::{Mutex, OnceLock};
4use std::time::Duration;
5
6use serde::{Deserialize, Serialize};
7
8use super::api::{LlmRequestPayload, LlmResult};
9use super::cost::calculate_cost;
10
/// Resource limits applied to a single trigger-predicate evaluation.
///
/// Every field is optional; `None` means "no limit" for that dimension.
/// Totals are accumulated across all LLM calls in one evaluation (see
/// `note_result`) and compared against these caps.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct TriggerPredicateBudget {
    /// Maximum cumulative spend in USD for the evaluation.
    pub max_cost_usd: Option<f64>,
    /// Maximum cumulative input+output tokens for the evaluation.
    pub tokens_max: Option<u64>,
    /// Timeout in milliseconds, exposed as a `Duration` via `timeout()`.
    /// Enforcement happens at the call site — presumably bounds the whole
    /// evaluation; confirm against the caller.
    pub timeout_ms: Option<u64>,
}
17
18impl TriggerPredicateBudget {
19 pub fn timeout(&self) -> Option<Duration> {
20 self.timeout_ms.map(Duration::from_millis)
21 }
22}
23
/// One recorded LLM call from a predicate evaluation: the canonical
/// request hash paired with the result it produced. Serialized so an
/// evaluation can later be replayed from these entries without live calls.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub(crate) struct PredicateCacheEntry {
    /// Hex digest of the canonicalized request payload (see `request_hash`).
    pub request_hash: String,
    /// The result to replay for a request with this hash.
    pub(crate) result: LlmResult,
}
29
/// Summary returned when a predicate evaluation finishes: every call
/// recorded (fresh or cached), plus aggregate usage and budget status.
#[derive(Clone, Debug, Default)]
pub struct PredicateEvaluationCapture {
    /// All request/result pairs recorded during the evaluation, in
    /// request-hash order (they come from a `BTreeMap`).
    pub entries: Vec<PredicateCacheEntry>,
    /// Sum of input + output tokens across recorded calls (saturating).
    pub total_tokens: u64,
    /// Sum of per-call costs (USD) across recorded calls.
    pub total_cost_usd: f64,
    /// True if at least one result was served from the replay map or the
    /// process-wide cache rather than computed fresh.
    pub cached: bool,
    /// True if the token or cost budget was exceeded at any point.
    pub budget_exceeded: bool,
}
38
/// Mutable bookkeeping for the predicate evaluation currently running on
/// this thread (stored in `ACTIVE_PREDICATE_EVALUATION`).
#[derive(Clone, Debug, Default)]
struct PredicateEvaluationState {
    /// Limits enforced while the evaluation runs.
    budget: TriggerPredicateBudget,
    /// Results supplied up front for replay, keyed by request hash;
    /// consulted before the process-wide cache.
    replay_cache: HashMap<String, LlmResult>,
    /// Everything observed so far, keyed by request hash. A `BTreeMap`
    /// so the final capture is emitted in deterministic order.
    entries: BTreeMap<String, LlmResult>,
    /// Running token total across calls (saturating adds).
    total_tokens: u64,
    /// Running cost total in USD across calls.
    total_cost_usd: f64,
    /// Set once any result is served from a cache.
    cached: bool,
    /// Set once a budget limit is crossed; cache lookups then stop.
    budget_exceeded: bool,
}
49
thread_local! {
    // The predicate evaluation currently running on this thread, if any.
    // Thread-local storage keeps evaluations on different threads fully
    // isolated from one another; `None` means no evaluation is active.
    static ACTIVE_PREDICATE_EVALUATION: RefCell<Option<PredicateEvaluationState>> = const { RefCell::new(None) };
}
53
54fn request_cache() -> &'static Mutex<HashMap<String, LlmResult>> {
55 static CACHE: OnceLock<Mutex<HashMap<String, LlmResult>>> = OnceLock::new();
56 CACHE.get_or_init(|| Mutex::new(HashMap::new()))
57}
58
59pub(crate) fn reset_trigger_predicate_state() {
60 ACTIVE_PREDICATE_EVALUATION.with(|slot| {
61 *slot.borrow_mut() = None;
62 });
63 if let Ok(mut cache) = request_cache().lock() {
64 cache.clear();
65 }
66}
67
68pub(crate) fn request_hash(request: &LlmRequestPayload) -> String {
69 use std::hash::{Hash, Hasher};
70
71 let canonical = serde_json::json!({
72 "provider": request.provider,
73 "model": request.model,
74 "messages": request.messages,
75 "system": request.system,
76 "max_tokens": request.max_tokens,
77 "temperature": request.temperature,
78 "top_p": request.top_p,
79 "top_k": request.top_k,
80 "stop": request.stop,
81 "seed": request.seed,
82 "frequency_penalty": request.frequency_penalty,
83 "presence_penalty": request.presence_penalty,
84 "response_format": request.response_format,
85 "json_schema": request.json_schema,
86 "thinking": request.thinking,
87 "native_tools": request.native_tools,
88 "tool_choice": request.tool_choice,
89 "cache": request.cache,
90 "timeout": request.timeout,
91 "stream": request.stream,
92 "provider_overrides": request.provider_overrides,
93 "prefill": request.prefill,
94 });
95 let mut hasher = std::collections::hash_map::DefaultHasher::new();
96 serde_json::to_string(&canonical)
97 .unwrap_or_default()
98 .hash(&mut hasher);
99 format!("{:016x}", hasher.finish())
100}
101
102pub(crate) struct PredicateEvaluationGuard;
103
104impl PredicateEvaluationGuard {
105 pub fn finish(self) -> PredicateEvaluationCapture {
106 finish_predicate_evaluation()
107 }
108}
109
110impl Drop for PredicateEvaluationGuard {
111 fn drop(&mut self) {
112 ACTIVE_PREDICATE_EVALUATION.with(|slot| {
113 *slot.borrow_mut() = None;
114 });
115 }
116}
117
118pub(crate) fn start_predicate_evaluation(
119 budget: TriggerPredicateBudget,
120 replay_entries: Vec<PredicateCacheEntry>,
121) -> PredicateEvaluationGuard {
122 ACTIVE_PREDICATE_EVALUATION.with(|slot| {
123 *slot.borrow_mut() = Some(PredicateEvaluationState {
124 budget,
125 replay_cache: replay_entries
126 .into_iter()
127 .map(|entry| (entry.request_hash, entry.result))
128 .collect(),
129 ..Default::default()
130 });
131 });
132 PredicateEvaluationGuard
133}
134
135fn finish_predicate_evaluation() -> PredicateEvaluationCapture {
136 ACTIVE_PREDICATE_EVALUATION.with(|slot| {
137 let Some(state) = slot.borrow_mut().take() else {
138 return PredicateEvaluationCapture::default();
139 };
140 PredicateEvaluationCapture {
141 entries: state
142 .entries
143 .into_iter()
144 .map(|(request_hash, result)| PredicateCacheEntry {
145 request_hash,
146 result,
147 })
148 .collect(),
149 total_tokens: state.total_tokens,
150 total_cost_usd: state.total_cost_usd,
151 cached: state.cached,
152 budget_exceeded: state.budget_exceeded,
153 }
154 })
155}
156
157pub(crate) fn lookup_cached_result(request: &LlmRequestPayload) -> Option<LlmResult> {
158 ACTIVE_PREDICATE_EVALUATION.with(|slot| {
159 let mut borrowed = slot.borrow_mut();
160 let state = borrowed.as_mut()?;
161 if state.budget_exceeded {
162 return None;
163 }
164 let hash = request_hash(request);
165 let cached = state.replay_cache.get(&hash).cloned().or_else(|| {
166 request_cache()
167 .lock()
168 .ok()
169 .and_then(|cache| cache.get(&hash).cloned())
170 });
171 if let Some(result) = cached.clone() {
172 state.cached = true;
173 state.entries.insert(hash, result.clone());
174 return Some(result);
175 }
176 None
177 })
178}
179
180pub(crate) fn note_result(request: &LlmRequestPayload, result: &LlmResult) {
181 ACTIVE_PREDICATE_EVALUATION.with(|slot| {
182 let mut borrowed = slot.borrow_mut();
183 let Some(state) = borrowed.as_mut() else {
184 return;
185 };
186 let hash = request_hash(request);
187 state.entries.insert(hash.clone(), result.clone());
188 if let Ok(mut cache) = request_cache().lock() {
189 cache.insert(hash, result.clone());
190 }
191 let call_tokens = result
192 .input_tokens
193 .saturating_add(result.output_tokens)
194 .max(0) as u64;
195 state.total_tokens = state.total_tokens.saturating_add(call_tokens);
196 state.total_cost_usd +=
197 calculate_cost(&result.model, result.input_tokens, result.output_tokens);
198 if state
199 .budget
200 .tokens_max
201 .is_some_and(|limit| state.total_tokens > limit)
202 {
203 state.budget_exceeded = true;
204 }
205 if state
206 .budget
207 .max_cost_usd
208 .is_some_and(|limit| state.total_cost_usd > limit)
209 {
210 state.budget_exceeded = true;
211 }
212 });
213}