1use std::time::Instant;
2
3use serde::Serialize;
4
5use super::api::{
6 vm_call_llm_full_streaming, LlmCallOptions, LlmRoutePolicy, OutputFormat, ThinkingConfig,
7};
8use crate::value::{VmError, VmValue};
9
/// Hard cap on generated tokens for the probe call — keeps smoke tests short and cheap.
const SMOKE_TEST_MAX_TOKENS: i64 = 32;
11
/// Inputs for a one-off model smoke test: which model to hit, an optional
/// provider override, and the prompt to send.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ModelSmokeTestOptions {
    /// Model name or alias; resolved via `llm_config::resolve_model_info` before the call.
    pub model: String,
    /// Explicit provider override; blank/whitespace-only values are ignored
    /// and the provider from model resolution is used instead.
    pub provider: Option<String>,
    /// Prompt text sent as a single `user` message.
    pub prompt: String,
}
18
/// Outcome of a successful smoke test, serializable for reporting.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ModelSmokeTestResult {
    /// Model id actually used for the call (post-resolution).
    pub model_id: String,
    /// Provider that served the request.
    pub provider: String,
    /// Wall-clock duration of the full call, in milliseconds.
    pub latency_ms: u64,
    /// Time to the first streamed delta, in milliseconds; `None` when no
    /// delta arrived (omitted from serialized output in that case).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub first_token_ms: Option<u64>,
    /// Input token count from the call result.
    pub input_tokens: i64,
    /// Output token count from the call result.
    pub output_tokens: i64,
    /// Estimated cost in USD, per `calculate_cost_for_provider`.
    pub estimated_cost_usd: f64,
}
30
31pub async fn run_model_smoke_test(
32 options: ModelSmokeTestOptions,
33) -> Result<ModelSmokeTestResult, String> {
34 super::provider::register_default_providers();
35
36 let resolved = crate::llm_config::resolve_model_info(&options.model);
37 let model_id = resolved.id;
38 let provider = options
39 .provider
40 .as_deref()
41 .map(str::trim)
42 .filter(|provider| !provider.is_empty())
43 .map(str::to_string)
44 .unwrap_or(resolved.provider);
45 let api_key = super::helpers::resolve_api_key(&provider).map_err(vm_error_message)?;
46
47 if let Some(def) = crate::llm_config::provider_config(&provider) {
48 if super::supports_model_readiness_probe(&def) {
49 let readiness =
50 super::probe_openai_compatible_model(&provider, &model_id, &api_key).await;
51 if readiness.category == "model_missing" || readiness.category == "invalid_url" {
52 return Err(readiness.message);
53 }
54 }
55 }
56
57 let opts = LlmCallOptions {
58 provider: provider.clone(),
59 model: model_id.clone(),
60 api_key,
61 route_policy: LlmRoutePolicy::Manual,
62 fallback_chain: Vec::new(),
63 route_fallbacks: Vec::new(),
64 routing_decision: None,
65 session_id: None,
66 messages: vec![serde_json::json!({
67 "role": "user",
68 "content": options.prompt,
69 })],
70 system: None,
71 transcript_summary: None,
72 max_tokens: SMOKE_TEST_MAX_TOKENS,
73 temperature: None,
74 top_p: None,
75 top_k: None,
76 logprobs: false,
77 top_logprobs: None,
78 stop: None,
79 seed: None,
80 frequency_penalty: None,
81 presence_penalty: None,
82 output_format: OutputFormat::Text,
83 response_format: None,
84 json_schema: None,
85 output_schema: None,
86 output_validation: None,
87 thinking: ThinkingConfig::Disabled,
88 anthropic_beta_features: Vec::new(),
89 vision: false,
90 tools: None,
91 native_tools: None,
92 tool_choice: None,
93 tool_search: None,
94 cache: false,
95 timeout: None,
96 idle_timeout: None,
97 stream: true,
98 provider_overrides: None,
99 budget: None,
100 prefill: None,
101 structural_experiment: None,
102 applied_structural_experiment: None,
103 };
104
105 let (delta_tx, mut delta_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
106 let started = Instant::now();
107 let first_delta = tokio::spawn(async move { delta_rx.recv().await.map(|_| started.elapsed()) });
108 let result = vm_call_llm_full_streaming(&opts, delta_tx)
109 .await
110 .map_err(vm_error_message);
111 let latency_ms = duration_ms(started.elapsed());
112 let first_token_ms = first_delta.await.ok().flatten().map(duration_ms);
113 let result = result?;
114
115 Ok(ModelSmokeTestResult {
116 model_id: result.model.clone(),
117 provider: result.provider.clone(),
118 latency_ms,
119 first_token_ms,
120 input_tokens: result.input_tokens,
121 output_tokens: result.output_tokens,
122 estimated_cost_usd: super::calculate_cost_for_provider(
123 &result.provider,
124 &result.model,
125 result.input_tokens,
126 result.output_tokens,
127 ),
128 })
129}
130
/// Converts a duration to whole milliseconds, saturating at `u64::MAX` when
/// the 128-bit millisecond count does not fit in a `u64`.
fn duration_ms(duration: std::time::Duration) -> u64 {
    match u64::try_from(duration.as_millis()) {
        Ok(millis) => millis,
        Err(_) => u64::MAX,
    }
}
134
135fn vm_error_message(error: VmError) -> String {
136 match error {
137 VmError::CategorizedError { message, .. } => message,
138 VmError::Thrown(VmValue::String(message)) => message.to_string(),
139 VmError::Thrown(VmValue::Dict(dict)) => dict
140 .get("message")
141 .map(VmValue::display)
142 .unwrap_or_else(|| VmError::Thrown(VmValue::Dict(dict)).to_string()),
143 other => other.to_string(),
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::{run_model_smoke_test, ModelSmokeTestOptions};

    /// The built-in mock provider should complete offline and report timing,
    /// token counts, and a zero cost estimate.
    #[tokio::test]
    async fn mock_provider_smoke_test_reports_timing_tokens_and_cost() {
        crate::llm::reset_llm_state();

        let options = ModelSmokeTestOptions {
            model: "mock".to_string(),
            provider: Some("mock".to_string()),
            prompt: "ping".to_string(),
        };
        let outcome = run_model_smoke_test(options)
            .await
            .expect("mock provider smoke test should not require network");

        assert_eq!(outcome.model_id, "mock");
        assert_eq!(outcome.provider, "mock");
        assert_eq!(outcome.input_tokens, 4);
        assert_eq!(outcome.output_tokens, 30);
        assert_eq!(outcome.estimated_cost_usd, 0.0);
        assert!(outcome.first_token_ms.is_some());
    }
}