Skip to main content

harn_vm/llm/
model_test.rs

1use std::time::Instant;
2
3use serde::Serialize;
4
5use super::api::{
6    vm_call_llm_full_streaming, LlmApiMode, LlmCallOptions, LlmRoutePolicy, OutputFormat,
7    ThinkingConfig,
8};
9use crate::value::{VmError, VmValue};
10
/// Completion-token cap passed as `max_tokens` on every smoke-test LLM call.
const SMOKE_TEST_MAX_TOKENS: i64 = 32;
12
/// Inputs for a single model smoke test.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ModelSmokeTestOptions {
    /// Model name or alias; resolved through `crate::llm_config::resolve_model_info`.
    pub model: String,
    /// Optional provider override. `None`, or a value that is blank after trimming,
    /// falls back to the provider resolved from `model`.
    pub provider: Option<String>,
    /// Text sent as the single `user` message of the smoke-test request.
    pub prompt: String,
}
19
20#[derive(Clone, Debug, PartialEq, Serialize)]
21pub struct ModelSmokeTestResult {
22    pub model_id: String,
23    pub provider: String,
24    pub latency_ms: u64,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub first_token_ms: Option<u64>,
27    pub input_tokens: i64,
28    pub output_tokens: i64,
29    pub estimated_cost_usd: f64,
30}
31
32pub async fn run_model_smoke_test(
33    options: ModelSmokeTestOptions,
34) -> Result<ModelSmokeTestResult, String> {
35    super::provider::register_default_providers();
36
37    let resolved = crate::llm_config::resolve_model_info(&options.model);
38    let model_id = resolved.id;
39    let provider = options
40        .provider
41        .as_deref()
42        .map(str::trim)
43        .filter(|provider| !provider.is_empty())
44        .map(str::to_string)
45        .unwrap_or(resolved.provider);
46    let api_key = super::helpers::resolve_api_key(&provider).map_err(vm_error_message)?;
47
48    if let Some(def) = crate::llm_config::provider_config(&provider) {
49        if super::supports_model_readiness_probe(&def) {
50            let readiness = super::readiness::probe_provider_readiness_with_options(
51                &provider,
52                super::readiness::ProviderReadinessOptions {
53                    requested_model: Some(&model_id),
54                    base_url_override: None,
55                    api_key_override: Some(&api_key),
56                },
57            )
58            .await;
59            if matches!(
60                readiness.status,
61                super::readiness::ReadinessStatus::ModelMissing
62                    | super::readiness::ReadinessStatus::InvalidUrl
63            ) {
64                return Err(readiness.message);
65            }
66        }
67    }
68
69    let opts = LlmCallOptions {
70        provider: provider.clone(),
71        model: model_id.clone(),
72        api_key,
73        api_mode: LlmApiMode::ChatCompletions,
74        route_policy: LlmRoutePolicy::Manual,
75        fallback_chain: Vec::new(),
76        route_fallbacks: Vec::new(),
77        routing_decision: None,
78        routing_policy: None,
79        region: None,
80        session_id: None,
81        reminders: None,
82        reminder_lifecycle: Vec::new(),
83        messages: vec![serde_json::json!({
84            "role": "user",
85            "content": options.prompt,
86        })],
87        system: None,
88        transcript_summary: None,
89        max_tokens: SMOKE_TEST_MAX_TOKENS,
90        temperature: None,
91        top_p: None,
92        top_k: None,
93        logprobs: false,
94        top_logprobs: None,
95        stop: None,
96        seed: None,
97        frequency_penalty: None,
98        presence_penalty: None,
99        fast: false,
100        output_format: OutputFormat::Text,
101        response_format: None,
102        json_schema: None,
103        output_schema: None,
104        output_validation: None,
105        schema_stream_abort: false,
106        thinking: ThinkingConfig::Disabled,
107        anthropic_beta_features: Vec::new(),
108        vision: false,
109        tools: None,
110        native_tools: None,
111        provider_tools: Vec::new(),
112        tool_choice: None,
113        tool_search: None,
114        cache: false,
115        timeout: None,
116        idle_timeout: None,
117        stream: true,
118        provider_overrides: None,
119        previous_response_id: None,
120        store: None,
121        background: None,
122        truncation: None,
123        compact: None,
124        include: None,
125        max_tool_calls: None,
126        budget: None,
127        prefill: None,
128        structural_experiment: None,
129        applied_structural_experiment: None,
130    };
131
132    let (delta_tx, mut delta_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
133    let started = Instant::now();
134    let first_delta = tokio::spawn(async move { delta_rx.recv().await.map(|_| started.elapsed()) });
135    let result = vm_call_llm_full_streaming(&opts, delta_tx)
136        .await
137        .map_err(vm_error_message);
138    let latency_ms = duration_ms(started.elapsed());
139    let first_token_ms = first_delta.await.ok().flatten().map(duration_ms);
140    let result = result?;
141
142    Ok(ModelSmokeTestResult {
143        model_id: result.model.clone(),
144        provider: result.provider.clone(),
145        latency_ms,
146        first_token_ms,
147        input_tokens: result.input_tokens,
148        output_tokens: result.output_tokens,
149        estimated_cost_usd: super::calculate_cost_for_provider(
150            &result.provider,
151            &result.model,
152            result.input_tokens,
153            result.output_tokens,
154        ),
155    })
156}
157
/// Converts `duration` to whole milliseconds, dropping any sub-millisecond remainder.
/// Saturates at `u64::MAX` when the millisecond count does not fit in a `u64`.
fn duration_ms(duration: std::time::Duration) -> u64 {
    let whole_millis: u128 = duration.as_millis();
    u64::try_from(whole_millis).unwrap_or(u64::MAX)
}
161
162fn vm_error_message(error: VmError) -> String {
163    match error {
164        VmError::CategorizedError { message, .. } => message,
165        VmError::Thrown(VmValue::String(message)) => message.to_string(),
166        VmError::Thrown(VmValue::Dict(dict)) => dict
167            .get("message")
168            .map(VmValue::display)
169            .unwrap_or_else(|| VmError::Thrown(VmValue::Dict(dict)).to_string()),
170        other => other.to_string(),
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::{run_model_smoke_test, ModelSmokeTestOptions};

    /// Runs the smoke test against the `mock` provider end to end. Per the `expect`
    /// message below, that path must not need network access. The test checks the
    /// reported identity, token counts, cost, and first-token timing.
    #[tokio::test]
    async fn mock_provider_smoke_test_reports_timing_tokens_and_cost() {
        crate::llm::reset_llm_state();

        let request = ModelSmokeTestOptions {
            model: "mock".to_string(),
            provider: Some("mock".to_string()),
            prompt: "ping".to_string(),
        };
        let report = run_model_smoke_test(request)
            .await
            .expect("mock provider smoke test should not require network");

        // The model/provider that served the call.
        assert_eq!(report.model_id, "mock");
        assert_eq!(report.provider, "mock");
        // Token accounting expected from the mock provider for the "ping" prompt.
        assert_eq!(report.input_tokens, 4);
        assert_eq!(report.output_tokens, 30);
        // Mock usage is expected to be priced at zero.
        assert_eq!(report.estimated_cost_usd, 0.0);
        // At least one streamed delta must have been observed.
        assert!(report.first_token_ms.is_some());
    }
}