harn-vm 0.8.1

Async bytecode virtual machine for the Harn programming language
Documentation
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap};
use std::sync::{Mutex, OnceLock};
use std::time::Duration;

use serde::{Deserialize, Serialize};

use super::api::{LlmRequestPayload, LlmResult};
use super::cost::calculate_cost;

/// Optional resource limits for a single trigger-predicate evaluation.
///
/// Each limit is independent. A `None` field places no bound on that dimension.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct TriggerPredicateBudget {
    /// Spend ceiling in USD. The budget counts as exceeded once the accumulated cost of
    /// recorded calls is strictly greater than this value.
    pub max_cost_usd: Option<f64>,
    /// Token ceiling over combined input + output tokens. It is exceeded once the running
    /// total is strictly greater than this value.
    pub tokens_max: Option<u64>,
    /// Timeout in milliseconds. Use [`TriggerPredicateBudget::timeout`] for a `Duration`.
    pub timeout_ms: Option<u64>,
}

impl TriggerPredicateBudget {
    /// Converts `timeout_ms` into a [`Duration`]. Returns `None` when no timeout is configured.
    pub fn timeout(&self) -> Option<Duration> {
        let millis = self.timeout_ms?;
        Some(Duration::from_millis(millis))
    }
}

/// One LLM call observed during a predicate evaluation, keyed by its request hash.
///
/// Entries are produced by `finish_predicate_evaluation` (via the guard's `finish`) and can
/// be passed back into `start_predicate_evaluation` as replay entries, where they are looked
/// up by `request_hash` before the process-wide cache is consulted.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub(crate) struct PredicateCacheEntry {
    /// Hex digest produced by `request_hash` for the originating request.
    pub request_hash: String,
    /// Result recorded for that request, whether it was live (`note_result`) or served from
    /// a cache (`lookup_cached_result`).
    pub(crate) result: LlmResult,
}

/// Summary of one predicate evaluation, returned when the evaluation is finished.
///
/// If no evaluation was active when finishing, every field holds its `Default` value.
#[derive(Clone, Debug, Default)]
pub struct PredicateEvaluationCapture {
    /// Request/result pairs seen during the evaluation. There is one entry per distinct
    /// request hash (a later result for the same hash overwrites an earlier one), ordered by hash.
    pub entries: Vec<PredicateCacheEntry>,
    /// Sum of input + output tokens over results recorded by `note_result`. Cache hits served
    /// by `lookup_cached_result` do not add to this total.
    pub total_tokens: u64,
    /// Sum of `calculate_cost` over results recorded by `note_result`.
    pub total_cost_usd: f64,
    /// `true` if at least one lookup was served from the replay entries or the shared cache.
    pub cached: bool,
    /// `true` once the recorded totals went strictly above `tokens_max` or `max_cost_usd`.
    pub budget_exceeded: bool,
}

/// Mutable bookkeeping for the evaluation currently active on this thread.
#[derive(Clone, Debug, Default)]
struct PredicateEvaluationState {
    /// Limits checked by `note_result` after each recorded call.
    budget: TriggerPredicateBudget,
    /// Results supplied up front by the caller, keyed by request hash. These are consulted
    /// before the process-wide cache.
    replay_cache: HashMap<String, LlmResult>,
    /// Everything observed so far, keyed by request hash. A `BTreeMap` so the finished
    /// capture lists entries in a deterministic (sorted-by-hash) order.
    entries: BTreeMap<String, LlmResult>,
    /// Running input + output token total for live (noted) calls.
    total_tokens: u64,
    /// Running cost total for live (noted) calls.
    total_cost_usd: f64,
    /// Set when any lookup was satisfied from a cache.
    cached: bool,
    /// Latched to `true` once a limit in `budget` is exceeded; never reset within an evaluation.
    budget_exceeded: bool,
}

thread_local! {
    /// Per-thread slot for the predicate evaluation in progress; `None` when none is active.
    ///
    /// NOTE(review): because this is thread-local, state recorded here is only visible to
    /// code running on the same OS thread that called `start_predicate_evaluation`. Confirm
    /// that async callers cannot migrate threads between start and finish.
    static ACTIVE_PREDICATE_EVALUATION: RefCell<Option<PredicateEvaluationState>> = const { RefCell::new(None) };
}

/// Process-wide memo of LLM results keyed by `request_hash`.
///
/// Lazily created on first use and shared by every thread and every predicate evaluation.
/// Entries are never evicted except by `reset_trigger_predicate_state`.
fn request_cache() -> &'static Mutex<HashMap<String, LlmResult>> {
    static SHARED: OnceLock<Mutex<HashMap<String, LlmResult>>> = OnceLock::new();
    SHARED.get_or_init(Default::default)
}

/// Clears all trigger-predicate state: this thread's active evaluation (if any) and the
/// process-wide request cache.
///
/// Clearing the shared cache is best-effort. If its mutex is poisoned, the cache is left as is.
pub(crate) fn reset_trigger_predicate_state() {
    // Move the active evaluation (if any) out of the slot; it is dropped right away.
    ACTIVE_PREDICATE_EVALUATION.with(|slot| slot.borrow_mut().take());

    let shared = request_cache().lock();
    if let Ok(mut memo) = shared {
        memo.clear();
    }
}

/// Computes a 16-hex-digit fingerprint of an LLM request, used as the cache/replay key.
///
/// The request's fields are collected into a JSON value, serialized to a string, and that
/// string is fed through `DefaultHasher`. If serialization ever failed, the empty string
/// would be hashed instead, so every such request would share one key.
///
/// NOTE(review): `std::collections::hash_map::DefaultHasher` documents that its algorithm
/// is unspecified and may change between Rust releases. These hashes are stored in
/// `PredicateCacheEntry` (which is `Serialize`/`Deserialize`) and matched against replay
/// entries later, so a toolchain upgrade could make previously recorded hashes stop matching.
/// Consider moving to an explicitly specified hash. Doing so will invalidate existing
/// recordings once.
pub(crate) fn request_hash(request: &LlmRequestPayload) -> String {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let canonical = serde_json::json!({
        "provider": request.provider,
        "model": request.model,
        "messages": request.messages,
        "system": request.system,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "top_k": request.top_k,
        "stop": request.stop,
        "seed": request.seed,
        "frequency_penalty": request.frequency_penalty,
        "presence_penalty": request.presence_penalty,
        "output_format": request.output_format,
        "response_format": request.response_format,
        "json_schema": request.json_schema,
        "thinking": request.thinking,
        "anthropic_beta_features": request.anthropic_beta_features,
        "native_tools": request.native_tools,
        "tool_choice": request.tool_choice,
        "cache": request.cache,
        "timeout": request.timeout,
        "stream": request.stream,
        "provider_overrides": request.provider_overrides,
        "prefill": request.prefill,
    });
    let encoded: String = serde_json::to_string(&canonical).unwrap_or_default();

    let mut state = DefaultHasher::new();
    encoded.hash(&mut state);
    let digest = state.finish();
    format!("{digest:016x}")
}

/// RAII guard for the predicate evaluation opened by `start_predicate_evaluation`.
///
/// While an evaluation is active on this thread, `lookup_cached_result` and `note_result`
/// record into it. Dropping the guard discards that state, so call
/// [`PredicateEvaluationGuard::finish`] first to collect what was recorded.
///
/// NOTE(review): `Drop` clears the thread-local slot unconditionally. If a second evaluation
/// is started on the same thread while this guard is alive, it replaces the first one, and
/// dropping either guard clears whichever evaluation is active at that moment.
// `must_use`: discarding the guard (e.g. `start_predicate_evaluation(..);` as a bare
// statement) drops it immediately and silently ends the evaluation it just started.
#[must_use = "dropping this guard immediately discards the predicate evaluation it started"]
pub(crate) struct PredicateEvaluationGuard;

impl PredicateEvaluationGuard {
    /// Ends the evaluation on this thread and returns everything it recorded.
    ///
    /// If no evaluation is active, a default (empty) capture is returned. The slot is
    /// emptied here, so the `Drop` that runs when `self` goes out of scope has nothing
    /// left to clear.
    pub fn finish(self) -> PredicateEvaluationCapture {
        finish_predicate_evaluation()
    }
}

impl Drop for PredicateEvaluationGuard {
    /// Discards any evaluation state still active on this thread.
    fn drop(&mut self) {
        ACTIVE_PREDICATE_EVALUATION.with(|slot| {
            *slot.borrow_mut() = None;
        });
    }
}

pub(crate) fn start_predicate_evaluation(
    budget: TriggerPredicateBudget,
    replay_entries: Vec<PredicateCacheEntry>,
) -> PredicateEvaluationGuard {
    ACTIVE_PREDICATE_EVALUATION.with(|slot| {
        *slot.borrow_mut() = Some(PredicateEvaluationState {
            budget,
            replay_cache: replay_entries
                .into_iter()
                .map(|entry| (entry.request_hash, entry.result))
                .collect(),
            ..Default::default()
        });
    });
    PredicateEvaluationGuard
}

fn finish_predicate_evaluation() -> PredicateEvaluationCapture {
    ACTIVE_PREDICATE_EVALUATION.with(|slot| {
        let Some(state) = slot.borrow_mut().take() else {
            return PredicateEvaluationCapture::default();
        };
        PredicateEvaluationCapture {
            entries: state
                .entries
                .into_iter()
                .map(|(request_hash, result)| PredicateCacheEntry {
                    request_hash,
                    result,
                })
                .collect(),
            total_tokens: state.total_tokens,
            total_cost_usd: state.total_cost_usd,
            cached: state.cached,
            budget_exceeded: state.budget_exceeded,
        }
    })
}

/// Returns a previously seen result for `request`, if the active evaluation can use one.
///
/// Lookup order is this evaluation's replay entries first, then the process-wide cache.
/// A poisoned shared-cache lock is treated as a miss. On a hit, the evaluation is marked
/// `cached` and the pair is recorded in its entries. Token and cost totals are not changed.
///
/// Returns `None` when no evaluation is active on this thread, when the evaluation's
/// budget has already been exceeded, or when neither cache holds the request.
pub(crate) fn lookup_cached_result(request: &LlmRequestPayload) -> Option<LlmResult> {
    ACTIVE_PREDICATE_EVALUATION.with(|slot| {
        let mut borrowed = slot.borrow_mut();
        let state = borrowed.as_mut()?;
        if state.budget_exceeded {
            return None;
        }
        let hash = request_hash(request);
        // Each cache lookup already yields an owned result. Clone it once more only for the
        // entry we record, instead of cloning the whole Option and then its contents again.
        let result = state.replay_cache.get(&hash).cloned().or_else(|| {
            request_cache()
                .lock()
                .ok()
                .and_then(|cache| cache.get(&hash).cloned())
        })?;
        state.cached = true;
        state.entries.insert(hash, result.clone());
        Some(result)
    })
}

/// Records a live LLM result in the active evaluation and charges it against the budget.
///
/// Does nothing when no evaluation is active on this thread. Otherwise it:
/// - records the pair in the evaluation's entries and, best-effort, in the process-wide
///   cache (skipped if that mutex is poisoned);
/// - adds the call's input + output tokens (clamped at zero) and its `calculate_cost` to
///   the running totals;
/// - latches `budget_exceeded` when either total is strictly above its configured limit.
///
/// The call that crosses a limit is itself still recorded; the flag only affects later lookups.
pub(crate) fn note_result(request: &LlmRequestPayload, result: &LlmResult) {
    ACTIVE_PREDICATE_EVALUATION.with(|slot| {
        let mut guard = slot.borrow_mut();
        if let Some(state) = guard.as_mut() {
            let key = request_hash(request);
            state.entries.insert(key.clone(), result.clone());
            if let Ok(mut shared) = request_cache().lock() {
                shared.insert(key, result.clone());
            }

            let spent_tokens = result
                .input_tokens
                .saturating_add(result.output_tokens)
                .max(0) as u64;
            state.total_tokens = state.total_tokens.saturating_add(spent_tokens);
            let spent_usd =
                calculate_cost(&result.model, result.input_tokens, result.output_tokens);
            state.total_cost_usd += spent_usd;

            let limits = &state.budget;
            let over_tokens = limits
                .tokens_max
                .is_some_and(|cap| state.total_tokens > cap);
            let over_cost = limits
                .max_cost_usd
                .is_some_and(|cap| state.total_cost_usd > cap);
            if over_tokens || over_cost {
                state.budget_exceeded = true;
            }
        }
    });
}