harn-vm 0.8.1

Async bytecode virtual machine for the Harn programming language
Documentation
use std::time::Instant;

use serde::Serialize;

use super::api::{
    vm_call_llm_full_streaming, LlmCallOptions, LlmRoutePolicy, OutputFormat, ThinkingConfig,
};
use crate::value::{VmError, VmValue};

/// Output-token cap (`max_tokens`) applied to every smoke-test request. It is kept
/// deliberately small: the test only needs evidence that the model answers.
const SMOKE_TEST_MAX_TOKENS: i64 = 32;

/// Input accepted by [`run_model_smoke_test`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ModelSmokeTestOptions {
    /// Model name or alias, resolved through `llm_config::resolve_model_info`.
    pub model: String,
    /// Explicit provider override. `None`, or a string that is empty after trimming,
    /// falls back to the provider the model resolves to.
    pub provider: Option<String>,
    /// Text sent as the single user message of the smoke-test request.
    pub prompt: String,
}

/// Measurements reported by [`run_model_smoke_test`].
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct ModelSmokeTestResult {
    /// Model identifier reported back by the completed LLM call.
    pub model_id: String,
    /// Provider reported back by the completed LLM call.
    pub provider: String,
    /// Milliseconds from just before the request was issued until the call returned.
    pub latency_ms: u64,
    /// Milliseconds until the first streamed delta arrived. Omitted from serialized
    /// output when no delta was observed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub first_token_ms: Option<u64>,
    /// Input token count reported by the call.
    pub input_tokens: i64,
    /// Output token count reported by the call.
    pub output_tokens: i64,
    /// Cost estimate in USD from `calculate_cost_for_provider`.
    pub estimated_cost_usd: f64,
}

/// Sends one small streaming completion to `options.model`. This checks that the model
/// is reachable and measures latency, time to first delta, token usage and estimated cost.
///
/// The provider is taken from `options.provider` when it is non-blank after trimming.
/// Otherwise the model's resolved provider is used. If the provider's configuration
/// supports a readiness probe, the probe runs first. A `"model_missing"` or `"invalid_url"`
/// outcome aborts early with the probe's message.
///
/// The request sends `options.prompt` as a single user message. It is capped at
/// [`SMOKE_TEST_MAX_TOKENS`] output tokens, with thinking, tools and caching disabled.
/// Timing starts just before the LLM call, so the readiness probe is not counted.
///
/// # Errors
///
/// Returns a human-readable message when any of the following happens:
/// - the API key for the provider cannot be resolved;
/// - the readiness probe reports the model as missing or the URL as invalid;
/// - the LLM call itself fails.
pub async fn run_model_smoke_test(
    options: ModelSmokeTestOptions,
) -> Result<ModelSmokeTestResult, String> {
    super::provider::register_default_providers();

    let resolved = crate::llm_config::resolve_model_info(&options.model);
    let model_id = resolved.id;
    // A blank override (e.g. an empty CLI flag) is treated as "not provided".
    let provider = options
        .provider
        .as_deref()
        .map(str::trim)
        .filter(|provider| !provider.is_empty())
        .map(str::to_string)
        .unwrap_or(resolved.provider);
    let api_key = super::helpers::resolve_api_key(&provider).map_err(vm_error_message)?;

    if let Some(def) = crate::llm_config::provider_config(&provider) {
        if super::supports_model_readiness_probe(&def) {
            let readiness =
                super::probe_openai_compatible_model(&provider, &model_id, &api_key).await;
            // Only these two outcomes are fatal. Any other probe category falls through
            // to the real call, which will report its own error if it fails.
            if readiness.category == "model_missing" || readiness.category == "invalid_url" {
                return Err(readiness.message);
            }
        }
    }

    let opts = LlmCallOptions {
        provider: provider.clone(),
        model: model_id.clone(),
        api_key,
        route_policy: LlmRoutePolicy::Manual,
        fallback_chain: Vec::new(),
        route_fallbacks: Vec::new(),
        routing_decision: None,
        session_id: None,
        messages: vec![serde_json::json!({
            "role": "user",
            "content": options.prompt,
        })],
        system: None,
        transcript_summary: None,
        max_tokens: SMOKE_TEST_MAX_TOKENS,
        temperature: None,
        top_p: None,
        top_k: None,
        logprobs: false,
        top_logprobs: None,
        stop: None,
        seed: None,
        frequency_penalty: None,
        presence_penalty: None,
        output_format: OutputFormat::Text,
        response_format: None,
        json_schema: None,
        output_schema: None,
        output_validation: None,
        thinking: ThinkingConfig::Disabled,
        anthropic_beta_features: Vec::new(),
        vision: false,
        tools: None,
        native_tools: None,
        tool_choice: None,
        tool_search: None,
        cache: false,
        timeout: None,
        idle_timeout: None,
        stream: true,
        provider_overrides: None,
        budget: None,
        prefill: None,
        structural_experiment: None,
        applied_structural_experiment: None,
    };

    let (delta_tx, mut delta_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
    let started = Instant::now();
    // Record when the first streamed delta arrives, but keep draining until the sender is
    // dropped. Returning after the first delta would drop the receiver, making every
    // later `send` fail. The streaming call could interpret that as a vanished consumer
    // and stop early.
    let first_delta = tokio::spawn(async move {
        let mut first_at = None;
        while delta_rx.recv().await.is_some() {
            if first_at.is_none() {
                first_at = Some(started.elapsed());
            }
        }
        first_at
    });
    let result = vm_call_llm_full_streaming(&opts, delta_tx)
        .await
        .map_err(vm_error_message);
    let latency_ms = duration_ms(started.elapsed());
    // `delta_tx` was moved into the call, so the drain task finishes once the call has
    // dropped it. A `JoinError` (task panic/cancellation) is reported as "no delta seen".
    let first_token_ms = first_delta.await.ok().flatten().map(duration_ms);
    let result = result?;

    Ok(ModelSmokeTestResult {
        model_id: result.model.clone(),
        provider: result.provider.clone(),
        latency_ms,
        first_token_ms,
        input_tokens: result.input_tokens,
        output_tokens: result.output_tokens,
        estimated_cost_usd: super::calculate_cost_for_provider(
            &result.provider,
            &result.model,
            result.input_tokens,
            result.output_tokens,
        ),
    })
}

/// Converts `duration` to whole milliseconds (truncating sub-millisecond parts).
/// Saturates at `u64::MAX` when the value does not fit in a `u64`.
fn duration_ms(duration: std::time::Duration) -> u64 {
    let millis = duration.as_millis();
    match u64::try_from(millis) {
        Ok(fits) => fits,
        Err(_) => u64::MAX,
    }
}

/// Turns a `VmError` into a plain message suitable for a `Result<_, String>`.
///
/// - Categorized errors yield their `message` field.
/// - Thrown strings yield the string itself.
/// - Thrown dicts yield the display form of their `"message"` entry. If that entry is
///   absent, the whole error's `Display` output is used.
/// - Every other error uses its `Display` output.
fn vm_error_message(error: VmError) -> String {
    match error {
        VmError::CategorizedError { message, .. } => message,
        VmError::Thrown(VmValue::String(text)) => text.to_string(),
        VmError::Thrown(VmValue::Dict(dict)) => match dict.get("message") {
            Some(entry) => VmValue::display(entry),
            None => VmError::Thrown(VmValue::Dict(dict)).to_string(),
        },
        fallback => fallback.to_string(),
    }
}

#[cfg(test)]
mod tests {
    use super::{run_model_smoke_test, ModelSmokeTestOptions};

    /// The built-in `mock` provider must finish the smoke test without network access
    /// and report fixed token counts, zero cost, and an observed first delta.
    #[tokio::test]
    async fn mock_provider_smoke_test_reports_timing_tokens_and_cost() {
        crate::llm::reset_llm_state();

        let options = ModelSmokeTestOptions {
            model: "mock".to_string(),
            provider: Some("mock".to_string()),
            prompt: "ping".to_string(),
        };
        let report = run_model_smoke_test(options)
            .await
            .expect("mock provider smoke test should not require network");

        assert_eq!(report.model_id, "mock");
        assert_eq!(report.provider, "mock");
        assert_eq!(report.input_tokens, 4);
        assert_eq!(report.output_tokens, 30);
        assert_eq!(report.estimated_cost_usd, 0.0);
        assert!(report.first_token_ms.is_some());
    }
}