1use std::time::Instant;
2
3use serde::Serialize;
4
5use super::api::{
6 vm_call_llm_full_streaming, LlmCallOptions, LlmRoutePolicy, OutputFormat, ThinkingConfig,
7};
8use crate::value::{VmError, VmValue};
9
/// Output-token cap for the probe request: keeps the smoke test cheap and
/// fast while still forcing the provider to stream at least one delta.
const SMOKE_TEST_MAX_TOKENS: i64 = 32;
11
/// Inputs for a one-shot model smoke test.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ModelSmokeTestOptions {
    /// Model name or alias; resolved to a concrete id via `llm_config`.
    pub model: String,
    /// Optional provider override; blank/whitespace-only values are ignored
    /// and the provider inferred from the model is used instead.
    pub provider: Option<String>,
    /// Prompt text sent as a single user message.
    pub prompt: String,
}
18
/// Outcome of a successful smoke test: which model/provider actually served
/// the request, timing, token usage, and the estimated dollar cost.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ModelSmokeTestResult {
    /// Model id reported back by the provider (may differ from the alias
    /// the caller passed in).
    pub model_id: String,
    /// Provider that handled the call.
    pub provider: String,
    /// Total wall-clock time for the full streamed completion, in ms.
    pub latency_ms: u64,
    /// Time until the first streamed delta arrived, in ms; `None` (and
    /// omitted from JSON) when the stream produced no deltas.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub first_token_ms: Option<u64>,
    /// Prompt tokens consumed, as reported by the provider.
    pub input_tokens: i64,
    /// Completion tokens produced, as reported by the provider.
    pub output_tokens: i64,
    /// Cost estimate derived from the provider's per-token pricing.
    pub estimated_cost_usd: f64,
}
30
31pub async fn run_model_smoke_test(
32 options: ModelSmokeTestOptions,
33) -> Result<ModelSmokeTestResult, String> {
34 super::provider::register_default_providers();
35
36 let resolved = crate::llm_config::resolve_model_info(&options.model);
37 let model_id = resolved.id;
38 let provider = options
39 .provider
40 .as_deref()
41 .map(str::trim)
42 .filter(|provider| !provider.is_empty())
43 .map(str::to_string)
44 .unwrap_or(resolved.provider);
45 let api_key = super::helpers::resolve_api_key(&provider).map_err(vm_error_message)?;
46
47 if let Some(def) = crate::llm_config::provider_config(&provider) {
48 if super::supports_model_readiness_probe(&def) {
49 let readiness =
50 super::probe_openai_compatible_model(&provider, &model_id, &api_key).await;
51 if readiness.category == "model_missing" || readiness.category == "invalid_url" {
52 return Err(readiness.message);
53 }
54 }
55 }
56
57 let opts = LlmCallOptions {
58 provider: provider.clone(),
59 model: model_id.clone(),
60 api_key,
61 route_policy: LlmRoutePolicy::Manual,
62 fallback_chain: Vec::new(),
63 route_fallbacks: Vec::new(),
64 routing_decision: None,
65 routing_policy: None,
66 session_id: None,
67 messages: vec![serde_json::json!({
68 "role": "user",
69 "content": options.prompt,
70 })],
71 system: None,
72 transcript_summary: None,
73 max_tokens: SMOKE_TEST_MAX_TOKENS,
74 temperature: None,
75 top_p: None,
76 top_k: None,
77 logprobs: false,
78 top_logprobs: None,
79 stop: None,
80 seed: None,
81 frequency_penalty: None,
82 presence_penalty: None,
83 output_format: OutputFormat::Text,
84 response_format: None,
85 json_schema: None,
86 output_schema: None,
87 output_validation: None,
88 thinking: ThinkingConfig::Disabled,
89 anthropic_beta_features: Vec::new(),
90 vision: false,
91 tools: None,
92 native_tools: None,
93 tool_choice: None,
94 tool_search: None,
95 cache: false,
96 timeout: None,
97 idle_timeout: None,
98 stream: true,
99 provider_overrides: None,
100 budget: None,
101 prefill: None,
102 structural_experiment: None,
103 applied_structural_experiment: None,
104 };
105
106 let (delta_tx, mut delta_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
107 let started = Instant::now();
108 let first_delta = tokio::spawn(async move { delta_rx.recv().await.map(|_| started.elapsed()) });
109 let result = vm_call_llm_full_streaming(&opts, delta_tx)
110 .await
111 .map_err(vm_error_message);
112 let latency_ms = duration_ms(started.elapsed());
113 let first_token_ms = first_delta.await.ok().flatten().map(duration_ms);
114 let result = result?;
115
116 Ok(ModelSmokeTestResult {
117 model_id: result.model.clone(),
118 provider: result.provider.clone(),
119 latency_ms,
120 first_token_ms,
121 input_tokens: result.input_tokens,
122 output_tokens: result.output_tokens,
123 estimated_cost_usd: super::calculate_cost_for_provider(
124 &result.provider,
125 &result.model,
126 result.input_tokens,
127 result.output_tokens,
128 ),
129 })
130}
131
/// Converts a `Duration` to whole milliseconds, saturating at `u64::MAX`
/// instead of failing when the 128-bit millisecond count does not fit.
fn duration_ms(duration: std::time::Duration) -> u64 {
    match u64::try_from(duration.as_millis()) {
        Ok(millis) => millis,
        Err(_) => u64::MAX,
    }
}
135
136fn vm_error_message(error: VmError) -> String {
137 match error {
138 VmError::CategorizedError { message, .. } => message,
139 VmError::Thrown(VmValue::String(message)) => message.to_string(),
140 VmError::Thrown(VmValue::Dict(dict)) => dict
141 .get("message")
142 .map(VmValue::display)
143 .unwrap_or_else(|| VmError::Thrown(VmValue::Dict(dict)).to_string()),
144 other => other.to_string(),
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::{run_model_smoke_test, ModelSmokeTestOptions};

    // End-to-end smoke test against the in-process mock provider: no network
    // access is required, and the mock reports fixed token counts.
    #[tokio::test]
    async fn mock_provider_smoke_test_reports_timing_tokens_and_cost() {
        // Clear any provider/LLM state left over from other tests.
        crate::llm::reset_llm_state();
        let result = run_model_smoke_test(ModelSmokeTestOptions {
            model: "mock".to_string(),
            provider: Some("mock".to_string()),
            prompt: "ping".to_string(),
        })
        .await
        .expect("mock provider smoke test should not require network");

        assert_eq!(result.model_id, "mock");
        assert_eq!(result.provider, "mock");
        // Token counts are fixed by the mock provider — TODO confirm these
        // constants against the mock implementation if they ever drift.
        assert_eq!(result.input_tokens, 4);
        assert_eq!(result.output_tokens, 30);
        // The mock provider has no pricing, so the estimate is zero.
        assert_eq!(result.estimated_cost_usd, 0.0);
        // Streaming must have produced at least one delta.
        assert!(result.first_token_ms.is_some());
    }
}
170}