Skip to main content

harn_vm/llm/
model_test.rs

1use std::time::Instant;
2
3use serde::Serialize;
4
5use super::api::{
6    vm_call_llm_full_streaming, LlmApiMode, LlmCallOptions, LlmRoutePolicy, OutputFormat,
7    ThinkingConfig,
8};
9use crate::value::{VmError, VmValue};
10
/// Upper bound on output tokens requested by a smoke test. Only a short reply
/// is needed to show the model answers (presumably kept small to bound latency
/// and cost).
const SMOKE_TEST_MAX_TOKENS: i64 = 32;
12
/// Inputs for [`run_model_smoke_test`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelSmokeTestOptions {
    /// Model name to test; resolved through `crate::llm_config::resolve_model_info`.
    pub model: String,
    /// Explicit provider to use. `None`, or a value that is empty after
    /// trimming, falls back to the provider resolved for `model`.
    pub provider: Option<String>,
    /// Text sent as the single user message of the test request.
    pub prompt: String,
}
19
20#[derive(Clone, Debug, PartialEq, Serialize)]
21pub struct ModelSmokeTestResult {
22    pub model_id: String,
23    pub provider: String,
24    pub latency_ms: u64,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub first_token_ms: Option<u64>,
27    pub input_tokens: i64,
28    pub output_tokens: i64,
29    pub estimated_cost_usd: f64,
30}
31
32pub async fn run_model_smoke_test(
33    options: ModelSmokeTestOptions,
34) -> Result<ModelSmokeTestResult, String> {
35    super::provider::register_default_providers();
36
37    let resolved = crate::llm_config::resolve_model_info(&options.model);
38    let model_id = resolved.id;
39    let provider = options
40        .provider
41        .as_deref()
42        .map(str::trim)
43        .filter(|provider| !provider.is_empty())
44        .map(str::to_string)
45        .unwrap_or(resolved.provider);
46    let api_key = super::helpers::resolve_api_key(&provider).map_err(vm_error_message)?;
47
48    if let Some(def) = crate::llm_config::provider_config(&provider) {
49        if super::supports_model_readiness_probe(&def) {
50            let readiness =
51                super::probe_openai_compatible_model(&provider, &model_id, &api_key).await;
52            if readiness.category == "model_missing" || readiness.category == "invalid_url" {
53                return Err(readiness.message);
54            }
55        }
56    }
57
58    let opts = LlmCallOptions {
59        provider: provider.clone(),
60        model: model_id.clone(),
61        api_key,
62        api_mode: LlmApiMode::ChatCompletions,
63        route_policy: LlmRoutePolicy::Manual,
64        fallback_chain: Vec::new(),
65        route_fallbacks: Vec::new(),
66        routing_decision: None,
67        routing_policy: None,
68        session_id: None,
69        reminders: None,
70        reminder_lifecycle: Vec::new(),
71        messages: vec![serde_json::json!({
72            "role": "user",
73            "content": options.prompt,
74        })],
75        system: None,
76        transcript_summary: None,
77        max_tokens: SMOKE_TEST_MAX_TOKENS,
78        temperature: None,
79        top_p: None,
80        top_k: None,
81        logprobs: false,
82        top_logprobs: None,
83        stop: None,
84        seed: None,
85        frequency_penalty: None,
86        presence_penalty: None,
87        output_format: OutputFormat::Text,
88        response_format: None,
89        json_schema: None,
90        output_schema: None,
91        output_validation: None,
92        schema_stream_abort: false,
93        thinking: ThinkingConfig::Disabled,
94        anthropic_beta_features: Vec::new(),
95        vision: false,
96        tools: None,
97        native_tools: None,
98        provider_tools: Vec::new(),
99        tool_choice: None,
100        tool_search: None,
101        cache: false,
102        timeout: None,
103        idle_timeout: None,
104        stream: true,
105        provider_overrides: None,
106        previous_response_id: None,
107        store: None,
108        background: None,
109        truncation: None,
110        compact: None,
111        include: None,
112        max_tool_calls: None,
113        budget: None,
114        prefill: None,
115        structural_experiment: None,
116        applied_structural_experiment: None,
117    };
118
119    let (delta_tx, mut delta_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
120    let started = Instant::now();
121    let first_delta = tokio::spawn(async move { delta_rx.recv().await.map(|_| started.elapsed()) });
122    let result = vm_call_llm_full_streaming(&opts, delta_tx)
123        .await
124        .map_err(vm_error_message);
125    let latency_ms = duration_ms(started.elapsed());
126    let first_token_ms = first_delta.await.ok().flatten().map(duration_ms);
127    let result = result?;
128
129    Ok(ModelSmokeTestResult {
130        model_id: result.model.clone(),
131        provider: result.provider.clone(),
132        latency_ms,
133        first_token_ms,
134        input_tokens: result.input_tokens,
135        output_tokens: result.output_tokens,
136        estimated_cost_usd: super::calculate_cost_for_provider(
137            &result.provider,
138            &result.model,
139            result.input_tokens,
140            result.output_tokens,
141        ),
142    })
143}
144
/// Converts a duration to whole milliseconds, saturating at `u64::MAX` because
/// `Duration::as_millis` yields a `u128` that may not fit.
fn duration_ms(duration: std::time::Duration) -> u64 {
    let millis = duration.as_millis();
    if millis > u128::from(u64::MAX) {
        u64::MAX
    } else {
        millis as u64
    }
}
148
149fn vm_error_message(error: VmError) -> String {
150    match error {
151        VmError::CategorizedError { message, .. } => message,
152        VmError::Thrown(VmValue::String(message)) => message.to_string(),
153        VmError::Thrown(VmValue::Dict(dict)) => dict
154            .get("message")
155            .map(VmValue::display)
156            .unwrap_or_else(|| VmError::Thrown(VmValue::Dict(dict)).to_string()),
157        other => other.to_string(),
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::{run_model_smoke_test, ModelSmokeTestOptions};

    // The `mock` provider is expected to answer without network access, so this
    // drives the whole smoke-test path: timing, token usage and pricing.
    #[tokio::test]
    async fn mock_provider_smoke_test_reports_timing_tokens_and_cost() {
        crate::llm::reset_llm_state();

        let options = ModelSmokeTestOptions {
            model: String::from("mock"),
            provider: Some(String::from("mock")),
            prompt: String::from("ping"),
        };
        let report = run_model_smoke_test(options)
            .await
            .expect("mock provider smoke test should not require network");

        assert_eq!(report.model_id, "mock");
        assert_eq!(report.provider, "mock");
        assert_eq!(report.input_tokens, 4);
        assert_eq!(report.output_tokens, 30);
        assert_eq!(report.estimated_cost_usd, 0.0);
        assert!(report.first_token_ms.is_some());
    }
}