// harn_vm/llm/model_test.rs

1use std::time::Instant;
2
3use serde::Serialize;
4
5use super::api::{
6    vm_call_llm_full_streaming, LlmCallOptions, LlmRoutePolicy, OutputFormat, ThinkingConfig,
7};
8use crate::value::{VmError, VmValue};
9
/// Token cap for smoke-test completions: enough to prove the model streams
/// output without incurring meaningful cost or latency.
const SMOKE_TEST_MAX_TOKENS: i64 = 32;

/// Input for a single model smoke test.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ModelSmokeTestOptions {
    /// Model name or alias; resolved to a concrete id via `llm_config`.
    pub model: String,
    /// Optional provider override; when `None` (or blank after trimming),
    /// the provider resolved from the model configuration is used.
    pub provider: Option<String>,
    /// Prompt sent as the sole user message of the test conversation.
    pub prompt: String,
}
18
/// Outcome of a successful model smoke test, serializable for reporting.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ModelSmokeTestResult {
    /// Concrete model id the call was executed against.
    pub model_id: String,
    /// Provider that served the request.
    pub provider: String,
    /// Wall-clock time for the full streaming call, in milliseconds.
    pub latency_ms: u64,
    /// Time to the first streamed delta, in milliseconds; `None` when the
    /// call completed without streaming any delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub first_token_ms: Option<u64>,
    /// Prompt token count reported by the call result.
    pub input_tokens: i64,
    /// Completion token count reported by the call result.
    pub output_tokens: i64,
    /// Cost estimate derived from the provider/model pricing tables.
    pub estimated_cost_usd: f64,
}
30
/// Runs a minimal end-to-end smoke test against a single model.
///
/// Resolves the model/provider, optionally probes model readiness, then
/// issues one short streaming completion and reports total latency,
/// first-token latency, token counts, and an estimated cost.
///
/// Errors are returned as human-readable `String` messages (extracted from
/// `VmError` via [`vm_error_message`]).
pub async fn run_model_smoke_test(
    options: ModelSmokeTestOptions,
) -> Result<ModelSmokeTestResult, String> {
    super::provider::register_default_providers();

    // Resolve any alias to a concrete model id plus its default provider.
    let resolved = crate::llm_config::resolve_model_info(&options.model);
    let model_id = resolved.id;
    // An explicit, non-blank provider override wins over the resolved one.
    let provider = options
        .provider
        .as_deref()
        .map(str::trim)
        .filter(|provider| !provider.is_empty())
        .map(str::to_string)
        .unwrap_or(resolved.provider);
    let api_key = super::helpers::resolve_api_key(&provider).map_err(vm_error_message)?;

    // Fail fast on a missing model or bad endpoint URL before paying for a
    // completion, for providers that support a readiness probe. Other probe
    // categories (e.g. transient/auth issues) fall through to the real call.
    if let Some(def) = crate::llm_config::provider_config(&provider) {
        if super::supports_model_readiness_probe(&def) {
            let readiness =
                super::probe_openai_compatible_model(&provider, &model_id, &api_key).await;
            if readiness.category == "model_missing" || readiness.category == "invalid_url" {
                return Err(readiness.message);
            }
        }
    }

    // Minimal single-message streaming call: tools, thinking, caching, and
    // structured output all disabled; output capped at SMOKE_TEST_MAX_TOKENS.
    let opts = LlmCallOptions {
        provider: provider.clone(),
        model: model_id.clone(),
        api_key,
        route_policy: LlmRoutePolicy::Manual,
        fallback_chain: Vec::new(),
        route_fallbacks: Vec::new(),
        routing_decision: None,
        routing_policy: None,
        session_id: None,
        messages: vec![serde_json::json!({
            "role": "user",
            "content": options.prompt,
        })],
        system: None,
        transcript_summary: None,
        max_tokens: SMOKE_TEST_MAX_TOKENS,
        temperature: None,
        top_p: None,
        top_k: None,
        logprobs: false,
        top_logprobs: None,
        stop: None,
        seed: None,
        frequency_penalty: None,
        presence_penalty: None,
        output_format: OutputFormat::Text,
        response_format: None,
        json_schema: None,
        output_schema: None,
        output_validation: None,
        thinking: ThinkingConfig::Disabled,
        anthropic_beta_features: Vec::new(),
        vision: false,
        tools: None,
        native_tools: None,
        tool_choice: None,
        tool_search: None,
        cache: false,
        timeout: None,
        idle_timeout: None,
        stream: true,
        provider_overrides: None,
        budget: None,
        prefill: None,
        structural_experiment: None,
        applied_structural_experiment: None,
    };

    // Measure first-token latency with a one-shot receive on the delta
    // channel: the spawned task resolves with the elapsed time of the first
    // streamed delta, or `None` if the sender is dropped before any arrives.
    let (delta_tx, mut delta_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
    let started = Instant::now();
    let first_delta = tokio::spawn(async move { delta_rx.recv().await.map(|_| started.elapsed()) });
    let result = vm_call_llm_full_streaming(&opts, delta_tx)
        .await
        .map_err(vm_error_message);
    let latency_ms = duration_ms(started.elapsed());
    // The call owns the only sender, so once it returns this await cannot
    // hang; `ok()` discards a join error (task panic) as "no first token".
    let first_token_ms = first_delta.await.ok().flatten().map(duration_ms);
    // `?` is deferred until after timing so the timer task is always joined.
    let result = result?;

    Ok(ModelSmokeTestResult {
        model_id: result.model.clone(),
        provider: result.provider.clone(),
        latency_ms,
        first_token_ms,
        input_tokens: result.input_tokens,
        output_tokens: result.output_tokens,
        estimated_cost_usd: super::calculate_cost_for_provider(
            &result.provider,
            &result.model,
            result.input_tokens,
            result.output_tokens,
        ),
    })
}
131
/// Converts a `Duration` to whole milliseconds, saturating at `u64::MAX`.
fn duration_ms(duration: std::time::Duration) -> u64 {
    // `as_millis` yields a u128; clamp rather than truncate on overflow.
    duration.as_millis().min(u64::MAX as u128) as u64
}
135
136fn vm_error_message(error: VmError) -> String {
137    match error {
138        VmError::CategorizedError { message, .. } => message,
139        VmError::Thrown(VmValue::String(message)) => message.to_string(),
140        VmError::Thrown(VmValue::Dict(dict)) => dict
141            .get("message")
142            .map(VmValue::display)
143            .unwrap_or_else(|| VmError::Thrown(VmValue::Dict(dict)).to_string()),
144        other => other.to_string(),
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::{run_model_smoke_test, ModelSmokeTestOptions};

    /// The in-process mock provider should satisfy a smoke test without any
    /// network access and report timing, token counts, and a zero cost.
    #[tokio::test]
    async fn mock_provider_smoke_test_reports_timing_tokens_and_cost() {
        crate::llm::reset_llm_state();

        let options = ModelSmokeTestOptions {
            model: String::from("mock"),
            provider: Some(String::from("mock")),
            prompt: String::from("ping"),
        };
        let report = run_model_smoke_test(options)
            .await
            .expect("mock provider smoke test should not require network");

        assert_eq!(report.model_id, "mock");
        assert_eq!(report.provider, "mock");
        assert_eq!(report.input_tokens, 4);
        assert_eq!(report.output_tokens, 30);
        assert_eq!(report.estimated_cost_usd, 0.0);
        assert!(report.first_token_ms.is_some());
    }
}
170}