Skip to main content

harn_vm/llm/
trigger_predicate.rs

1use std::cell::RefCell;
2use std::collections::{BTreeMap, HashMap};
3use std::sync::{Mutex, OnceLock};
4use std::time::Duration;
5
6use serde::{Deserialize, Serialize};
7
8use super::api::{LlmRequestPayload, LlmResult};
9use super::cost::calculate_cost;
10
11#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
12pub struct TriggerPredicateBudget {
13    pub max_cost_usd: Option<f64>,
14    pub tokens_max: Option<u64>,
15    pub timeout_ms: Option<u64>,
16}
17
18impl TriggerPredicateBudget {
19    pub fn timeout(&self) -> Option<Duration> {
20        self.timeout_ms.map(Duration::from_millis)
21    }
22}
23
24#[derive(Clone, Debug, Serialize, Deserialize)]
25pub(crate) struct PredicateCacheEntry {
26    pub request_hash: String,
27    pub(crate) result: LlmResult,
28}
29
30#[derive(Clone, Debug, Default)]
31pub struct PredicateEvaluationCapture {
32    pub entries: Vec<PredicateCacheEntry>,
33    pub total_tokens: u64,
34    pub total_cost_usd: f64,
35    pub cached: bool,
36    pub budget_exceeded: bool,
37}
38
39#[derive(Clone, Debug, Default)]
40struct PredicateEvaluationState {
41    budget: TriggerPredicateBudget,
42    replay_cache: HashMap<String, LlmResult>,
43    entries: BTreeMap<String, LlmResult>,
44    total_tokens: u64,
45    total_cost_usd: f64,
46    cached: bool,
47    budget_exceeded: bool,
48}
49
// Per-thread slot holding the state of the predicate evaluation in progress.
// `None` means no evaluation is active on this thread. The slot is filled by
// `start_predicate_evaluation` and emptied by finishing or dropping the guard.
thread_local! {
    static ACTIVE_PREDICATE_EVALUATION: RefCell<Option<PredicateEvaluationState>> = const { RefCell::new(None) };
}
53
54fn request_cache() -> &'static Mutex<HashMap<String, LlmResult>> {
55    static CACHE: OnceLock<Mutex<HashMap<String, LlmResult>>> = OnceLock::new();
56    CACHE.get_or_init(|| Mutex::new(HashMap::new()))
57}
58
59pub(crate) fn reset_trigger_predicate_state() {
60    ACTIVE_PREDICATE_EVALUATION.with(|slot| {
61        *slot.borrow_mut() = None;
62    });
63    if let Ok(mut cache) = request_cache().lock() {
64        cache.clear();
65    }
66}
67
68pub(crate) fn request_hash(request: &LlmRequestPayload) -> String {
69    use std::hash::{Hash, Hasher};
70
71    let canonical = serde_json::json!({
72        "provider": request.provider,
73        "model": request.model,
74        "messages": request.messages,
75        "system": request.system,
76        "max_tokens": request.max_tokens,
77        "temperature": request.temperature,
78        "top_p": request.top_p,
79        "top_k": request.top_k,
80        "stop": request.stop,
81        "seed": request.seed,
82        "frequency_penalty": request.frequency_penalty,
83        "presence_penalty": request.presence_penalty,
84        "output_format": request.output_format,
85        "response_format": request.response_format,
86        "json_schema": request.json_schema,
87        "thinking": request.thinking,
88        "anthropic_beta_features": request.anthropic_beta_features,
89        "native_tools": request.native_tools,
90        "tool_choice": request.tool_choice,
91        "cache": request.cache,
92        "timeout": request.timeout,
93        "stream": request.stream,
94        "provider_overrides": request.provider_overrides,
95        "prefill": request.prefill,
96    });
97    let mut hasher = std::collections::hash_map::DefaultHasher::new();
98    serde_json::to_string(&canonical)
99        .unwrap_or_default()
100        .hash(&mut hasher);
101    format!("{:016x}", hasher.finish())
102}
103
104pub(crate) struct PredicateEvaluationGuard;
105
106impl PredicateEvaluationGuard {
107    pub fn finish(self) -> PredicateEvaluationCapture {
108        finish_predicate_evaluation()
109    }
110}
111
112impl Drop for PredicateEvaluationGuard {
113    fn drop(&mut self) {
114        ACTIVE_PREDICATE_EVALUATION.with(|slot| {
115            *slot.borrow_mut() = None;
116        });
117    }
118}
119
120pub(crate) fn start_predicate_evaluation(
121    budget: TriggerPredicateBudget,
122    replay_entries: Vec<PredicateCacheEntry>,
123) -> PredicateEvaluationGuard {
124    ACTIVE_PREDICATE_EVALUATION.with(|slot| {
125        *slot.borrow_mut() = Some(PredicateEvaluationState {
126            budget,
127            replay_cache: replay_entries
128                .into_iter()
129                .map(|entry| (entry.request_hash, entry.result))
130                .collect(),
131            ..Default::default()
132        });
133    });
134    PredicateEvaluationGuard
135}
136
137fn finish_predicate_evaluation() -> PredicateEvaluationCapture {
138    ACTIVE_PREDICATE_EVALUATION.with(|slot| {
139        let Some(state) = slot.borrow_mut().take() else {
140            return PredicateEvaluationCapture::default();
141        };
142        PredicateEvaluationCapture {
143            entries: state
144                .entries
145                .into_iter()
146                .map(|(request_hash, result)| PredicateCacheEntry {
147                    request_hash,
148                    result,
149                })
150                .collect(),
151            total_tokens: state.total_tokens,
152            total_cost_usd: state.total_cost_usd,
153            cached: state.cached,
154            budget_exceeded: state.budget_exceeded,
155        }
156    })
157}
158
159pub(crate) fn lookup_cached_result(request: &LlmRequestPayload) -> Option<LlmResult> {
160    ACTIVE_PREDICATE_EVALUATION.with(|slot| {
161        let mut borrowed = slot.borrow_mut();
162        let state = borrowed.as_mut()?;
163        if state.budget_exceeded {
164            return None;
165        }
166        let hash = request_hash(request);
167        let cached = state.replay_cache.get(&hash).cloned().or_else(|| {
168            request_cache()
169                .lock()
170                .ok()
171                .and_then(|cache| cache.get(&hash).cloned())
172        });
173        if let Some(result) = cached.clone() {
174            state.cached = true;
175            state.entries.insert(hash, result.clone());
176            return Some(result);
177        }
178        None
179    })
180}
181
182pub(crate) fn note_result(request: &LlmRequestPayload, result: &LlmResult) {
183    ACTIVE_PREDICATE_EVALUATION.with(|slot| {
184        let mut borrowed = slot.borrow_mut();
185        let Some(state) = borrowed.as_mut() else {
186            return;
187        };
188        let hash = request_hash(request);
189        state.entries.insert(hash.clone(), result.clone());
190        if let Ok(mut cache) = request_cache().lock() {
191            cache.insert(hash, result.clone());
192        }
193        let call_tokens = result
194            .input_tokens
195            .saturating_add(result.output_tokens)
196            .max(0) as u64;
197        state.total_tokens = state.total_tokens.saturating_add(call_tokens);
198        state.total_cost_usd +=
199            calculate_cost(&result.model, result.input_tokens, result.output_tokens);
200        if state
201            .budget
202            .tokens_max
203            .is_some_and(|limit| state.total_tokens > limit)
204        {
205            state.budget_exceeded = true;
206        }
207        if state
208            .budget
209            .max_cost_usd
210            .is_some_and(|limit| state.total_cost_usd > limit)
211        {
212            state.budget_exceeded = true;
213        }
214    });
215}