Skip to main content

harn_vm/llm/
trigger_predicate.rs

1use std::cell::RefCell;
2use std::collections::{BTreeMap, HashMap};
3use std::sync::{Mutex, OnceLock};
4use std::time::Duration;
5
6use serde::{Deserialize, Serialize};
7
8use super::api::{LlmRequestPayload, LlmResult};
9use super::cost::calculate_cost;
10
11#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
12pub struct TriggerPredicateBudget {
13    pub max_cost_usd: Option<f64>,
14    pub tokens_max: Option<u64>,
15    pub timeout_ms: Option<u64>,
16}
17
18impl TriggerPredicateBudget {
19    pub fn timeout(&self) -> Option<Duration> {
20        self.timeout_ms.map(Duration::from_millis)
21    }
22}
23
24#[derive(Clone, Debug, Serialize, Deserialize)]
25pub(crate) struct PredicateCacheEntry {
26    pub request_hash: String,
27    pub(crate) result: LlmResult,
28}
29
30#[derive(Clone, Debug, Default)]
31pub struct PredicateEvaluationCapture {
32    pub entries: Vec<PredicateCacheEntry>,
33    pub total_tokens: u64,
34    pub total_cost_usd: f64,
35    pub cached: bool,
36    pub budget_exceeded: bool,
37}
38
39#[derive(Clone, Debug, Default)]
40struct PredicateEvaluationState {
41    budget: TriggerPredicateBudget,
42    replay_cache: HashMap<String, LlmResult>,
43    entries: BTreeMap<String, LlmResult>,
44    total_tokens: u64,
45    total_cost_usd: f64,
46    cached: bool,
47    budget_exceeded: bool,
48}
49
50thread_local! {
51    static ACTIVE_PREDICATE_EVALUATION: RefCell<Option<PredicateEvaluationState>> = const { RefCell::new(None) };
52}
53
54fn request_cache() -> &'static Mutex<HashMap<String, LlmResult>> {
55    static CACHE: OnceLock<Mutex<HashMap<String, LlmResult>>> = OnceLock::new();
56    CACHE.get_or_init(|| Mutex::new(HashMap::new()))
57}
58
59pub(crate) fn reset_trigger_predicate_state() {
60    ACTIVE_PREDICATE_EVALUATION.with(|slot| {
61        *slot.borrow_mut() = None;
62    });
63    if let Ok(mut cache) = request_cache().lock() {
64        cache.clear();
65    }
66}
67
68pub(crate) fn request_hash(request: &LlmRequestPayload) -> String {
69    use std::hash::{Hash, Hasher};
70
71    let canonical = serde_json::json!({
72        "provider": request.provider,
73        "model": request.model,
74        "messages": request.messages,
75        "system": request.system,
76        "max_tokens": request.max_tokens,
77        "temperature": request.temperature,
78        "top_p": request.top_p,
79        "top_k": request.top_k,
80        "stop": request.stop,
81        "seed": request.seed,
82        "frequency_penalty": request.frequency_penalty,
83        "presence_penalty": request.presence_penalty,
84        "response_format": request.response_format,
85        "json_schema": request.json_schema,
86        "thinking": request.thinking,
87        "native_tools": request.native_tools,
88        "tool_choice": request.tool_choice,
89        "cache": request.cache,
90        "timeout": request.timeout,
91        "stream": request.stream,
92        "provider_overrides": request.provider_overrides,
93        "prefill": request.prefill,
94    });
95    let mut hasher = std::collections::hash_map::DefaultHasher::new();
96    serde_json::to_string(&canonical)
97        .unwrap_or_default()
98        .hash(&mut hasher);
99    format!("{:016x}", hasher.finish())
100}
101
102pub(crate) struct PredicateEvaluationGuard;
103
104impl PredicateEvaluationGuard {
105    pub fn finish(self) -> PredicateEvaluationCapture {
106        finish_predicate_evaluation()
107    }
108}
109
110impl Drop for PredicateEvaluationGuard {
111    fn drop(&mut self) {
112        ACTIVE_PREDICATE_EVALUATION.with(|slot| {
113            *slot.borrow_mut() = None;
114        });
115    }
116}
117
118pub(crate) fn start_predicate_evaluation(
119    budget: TriggerPredicateBudget,
120    replay_entries: Vec<PredicateCacheEntry>,
121) -> PredicateEvaluationGuard {
122    ACTIVE_PREDICATE_EVALUATION.with(|slot| {
123        *slot.borrow_mut() = Some(PredicateEvaluationState {
124            budget,
125            replay_cache: replay_entries
126                .into_iter()
127                .map(|entry| (entry.request_hash, entry.result))
128                .collect(),
129            ..Default::default()
130        });
131    });
132    PredicateEvaluationGuard
133}
134
135fn finish_predicate_evaluation() -> PredicateEvaluationCapture {
136    ACTIVE_PREDICATE_EVALUATION.with(|slot| {
137        let Some(state) = slot.borrow_mut().take() else {
138            return PredicateEvaluationCapture::default();
139        };
140        PredicateEvaluationCapture {
141            entries: state
142                .entries
143                .into_iter()
144                .map(|(request_hash, result)| PredicateCacheEntry {
145                    request_hash,
146                    result,
147                })
148                .collect(),
149            total_tokens: state.total_tokens,
150            total_cost_usd: state.total_cost_usd,
151            cached: state.cached,
152            budget_exceeded: state.budget_exceeded,
153        }
154    })
155}
156
157pub(crate) fn lookup_cached_result(request: &LlmRequestPayload) -> Option<LlmResult> {
158    ACTIVE_PREDICATE_EVALUATION.with(|slot| {
159        let mut borrowed = slot.borrow_mut();
160        let state = borrowed.as_mut()?;
161        if state.budget_exceeded {
162            return None;
163        }
164        let hash = request_hash(request);
165        let cached = state.replay_cache.get(&hash).cloned().or_else(|| {
166            request_cache()
167                .lock()
168                .ok()
169                .and_then(|cache| cache.get(&hash).cloned())
170        });
171        if let Some(result) = cached.clone() {
172            state.cached = true;
173            state.entries.insert(hash, result.clone());
174            return Some(result);
175        }
176        None
177    })
178}
179
180pub(crate) fn note_result(request: &LlmRequestPayload, result: &LlmResult) {
181    ACTIVE_PREDICATE_EVALUATION.with(|slot| {
182        let mut borrowed = slot.borrow_mut();
183        let Some(state) = borrowed.as_mut() else {
184            return;
185        };
186        let hash = request_hash(request);
187        state.entries.insert(hash.clone(), result.clone());
188        if let Ok(mut cache) = request_cache().lock() {
189            cache.insert(hash, result.clone());
190        }
191        let call_tokens = result
192            .input_tokens
193            .saturating_add(result.output_tokens)
194            .max(0) as u64;
195        state.total_tokens = state.total_tokens.saturating_add(call_tokens);
196        state.total_cost_usd +=
197            calculate_cost(&result.model, result.input_tokens, result.output_tokens);
198        if state
199            .budget
200            .tokens_max
201            .is_some_and(|limit| state.total_tokens > limit)
202        {
203            state.budget_exceeded = true;
204        }
205        if state
206            .budget
207            .max_cost_usd
208            .is_some_and(|limit| state.total_cost_usd > limit)
209        {
210            state.budget_exceeded = true;
211        }
212    });
213}