Skip to main content

harn_vm/llm/
model_test.rs

1use std::time::Instant;
2
3use serde::Serialize;
4
5use super::api::{
6    vm_call_llm_full_streaming, LlmApiMode, LlmCallOptions, LlmRoutePolicy, OutputFormat,
7    ThinkingConfig,
8};
9use crate::value::{VmError, VmValue};
10
/// Output-token cap (`max_tokens`) sent with every smoke-test request.
const SMOKE_TEST_MAX_TOKENS: i64 = 32;

/// Inputs for [`run_model_smoke_test`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ModelSmokeTestOptions {
    /// Model name passed to the LLM config's model resolution.
    pub model: String,
    /// Optional provider override; `None` or a blank/whitespace-only value falls back to
    /// the provider resolved from `model`.
    pub provider: Option<String>,
    /// Text sent as the single user message of the request.
    pub prompt: String,
}
19
20#[derive(Clone, Debug, PartialEq, Serialize)]
21pub struct ModelSmokeTestResult {
22    pub model_id: String,
23    pub provider: String,
24    pub latency_ms: u64,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub first_token_ms: Option<u64>,
27    pub input_tokens: i64,
28    pub output_tokens: i64,
29    pub estimated_cost_usd: f64,
30}
31
32pub async fn run_model_smoke_test(
33    options: ModelSmokeTestOptions,
34) -> Result<ModelSmokeTestResult, String> {
35    super::provider::register_default_providers();
36
37    let resolved = crate::llm_config::resolve_model_info(&options.model);
38    let model_id = resolved.id;
39    let provider = options
40        .provider
41        .as_deref()
42        .map(str::trim)
43        .filter(|provider| !provider.is_empty())
44        .map(str::to_string)
45        .unwrap_or(resolved.provider);
46    let api_key = super::helpers::resolve_api_key(&provider).map_err(vm_error_message)?;
47
48    if let Some(def) = crate::llm_config::provider_config(&provider) {
49        if super::supports_model_readiness_probe(&def) {
50            let readiness = super::readiness::probe_provider_readiness_with_options(
51                &provider,
52                super::readiness::ProviderReadinessOptions {
53                    requested_model: Some(&model_id),
54                    base_url_override: None,
55                    api_key_override: Some(&api_key),
56                },
57            )
58            .await;
59            if matches!(
60                readiness.status,
61                super::readiness::ReadinessStatus::ModelMissing
62                    | super::readiness::ReadinessStatus::InvalidUrl
63            ) {
64                return Err(readiness.message);
65            }
66        }
67    }
68
69    let opts = LlmCallOptions {
70        provider: provider.clone(),
71        model: model_id.clone(),
72        api_key,
73        api_mode: LlmApiMode::ChatCompletions,
74        route_policy: LlmRoutePolicy::Manual,
75        fallback_chain: Vec::new(),
76        route_fallbacks: Vec::new(),
77        routing_decision: None,
78        routing_policy: None,
79        region: None,
80        session_id: None,
81        dispatch_provenance: None,
82        reminders: None,
83        reminder_lifecycle: Vec::new(),
84        messages: vec![serde_json::json!({
85            "role": "user",
86            "content": options.prompt,
87        })],
88        system: None,
89        transcript_summary: None,
90        max_tokens: SMOKE_TEST_MAX_TOKENS,
91        temperature: None,
92        top_p: None,
93        top_k: None,
94        logprobs: false,
95        top_logprobs: None,
96        stop: None,
97        seed: None,
98        frequency_penalty: None,
99        presence_penalty: None,
100        fast: false,
101        output_format: OutputFormat::Text,
102        response_format: None,
103        json_schema: None,
104        output_schema: None,
105        output_validation: None,
106        schema_stream_abort: false,
107        thinking: ThinkingConfig::Disabled,
108        anthropic_beta_features: Vec::new(),
109        vision: false,
110        tools: None,
111        native_tools: None,
112        provider_tools: Vec::new(),
113        tool_choice: None,
114        tool_search: None,
115        cache: false,
116        timeout: None,
117        idle_timeout: None,
118        stream: true,
119        provider_overrides: None,
120        previous_response_id: None,
121        store: None,
122        background: None,
123        truncation: None,
124        compact: None,
125        include: None,
126        max_tool_calls: None,
127        budget: None,
128        prefill: None,
129        structural_experiment: None,
130        applied_structural_experiment: None,
131    };
132
133    let (delta_tx, mut delta_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
134    let started = Instant::now();
135    let first_delta = tokio::spawn(async move { delta_rx.recv().await.map(|_| started.elapsed()) });
136    let result = vm_call_llm_full_streaming(&opts, delta_tx)
137        .await
138        .map_err(vm_error_message);
139    let latency_ms = duration_ms(started.elapsed());
140    let first_token_ms = first_delta.await.ok().flatten().map(duration_ms);
141    let result = result?;
142
143    Ok(ModelSmokeTestResult {
144        model_id: result.model.clone(),
145        provider: result.provider.clone(),
146        latency_ms,
147        first_token_ms,
148        input_tokens: result.input_tokens,
149        output_tokens: result.output_tokens,
150        estimated_cost_usd: super::calculate_cost_for_provider(
151            &result.provider,
152            &result.model,
153            result.input_tokens,
154            result.output_tokens,
155        ),
156    })
157}
158
/// Converts a duration to whole milliseconds, saturating at `u64::MAX` when the `u128`
/// millisecond count does not fit in a `u64`.
fn duration_ms(duration: std::time::Duration) -> u64 {
    match u64::try_from(duration.as_millis()) {
        Ok(millis) => millis,
        Err(_) => u64::MAX,
    }
}
162
163fn vm_error_message(error: VmError) -> String {
164    match error {
165        VmError::CategorizedError { message, .. } => message,
166        VmError::Thrown(VmValue::String(message)) => message.to_string(),
167        VmError::Thrown(VmValue::Dict(dict)) => dict
168            .get("message")
169            .map(VmValue::display)
170            .unwrap_or_else(|| VmError::Thrown(VmValue::Dict(dict)).to_string()),
171        other => other.to_string(),
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::{run_model_smoke_test, ModelSmokeTestOptions};

    /// A smoke test against the `mock` provider completes without network access and
    /// reports the mock's token counts, zero cost, and a time to first token.
    #[tokio::test]
    async fn mock_provider_smoke_test_reports_timing_tokens_and_cost() {
        crate::llm::reset_llm_state();
        let options = ModelSmokeTestOptions {
            model: "mock".to_string(),
            provider: Some("mock".to_string()),
            prompt: "ping".to_string(),
        };

        let outcome = run_model_smoke_test(options)
            .await
            .expect("mock provider smoke test should not require network");

        assert_eq!(outcome.model_id, "mock");
        assert_eq!(outcome.provider, "mock");
        assert_eq!(outcome.input_tokens, 4);
        assert_eq!(outcome.output_tokens, 30);
        assert_eq!(outcome.estimated_cost_usd, 0.0);
        assert!(outcome.first_token_ms.is_some());
    }
}