1use std::time::Instant;
2
3use serde::Serialize;
4
5use super::api::{
6 vm_call_llm_full_streaming, LlmApiMode, LlmCallOptions, LlmRoutePolicy, OutputFormat,
7 ThinkingConfig,
8};
9use crate::value::{VmError, VmValue};
10
/// Upper bound on output tokens requested by the smoke test; it is passed as
/// `max_tokens` in the request options. A small cap keeps the probe call short.
const SMOKE_TEST_MAX_TOKENS: i64 = 32;
12
/// Caller-supplied parameters for a single model smoke test.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ModelSmokeTestOptions {
    /// Model name or alias, resolved through `crate::llm_config::resolve_model_info`.
    pub model: String,
    /// Optional provider override. `None`, empty, or whitespace-only values fall back
    /// to the provider resolved from `model`.
    pub provider: Option<String>,
    /// Text sent as the single user message of the smoke-test request.
    pub prompt: String,
}
19
20#[derive(Clone, Debug, PartialEq, Serialize)]
21pub struct ModelSmokeTestResult {
22 pub model_id: String,
23 pub provider: String,
24 pub latency_ms: u64,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub first_token_ms: Option<u64>,
27 pub input_tokens: i64,
28 pub output_tokens: i64,
29 pub estimated_cost_usd: f64,
30}
31
32pub async fn run_model_smoke_test(
33 options: ModelSmokeTestOptions,
34) -> Result<ModelSmokeTestResult, String> {
35 super::provider::register_default_providers();
36
37 let resolved = crate::llm_config::resolve_model_info(&options.model);
38 let model_id = resolved.id;
39 let provider = options
40 .provider
41 .as_deref()
42 .map(str::trim)
43 .filter(|provider| !provider.is_empty())
44 .map(str::to_string)
45 .unwrap_or(resolved.provider);
46 let api_key = super::helpers::resolve_api_key(&provider).map_err(vm_error_message)?;
47
48 if let Some(def) = crate::llm_config::provider_config(&provider) {
49 if super::supports_model_readiness_probe(&def) {
50 let readiness = super::readiness::probe_provider_readiness_with_options(
51 &provider,
52 super::readiness::ProviderReadinessOptions {
53 requested_model: Some(&model_id),
54 base_url_override: None,
55 api_key_override: Some(&api_key),
56 },
57 )
58 .await;
59 if matches!(
60 readiness.status,
61 super::readiness::ReadinessStatus::ModelMissing
62 | super::readiness::ReadinessStatus::InvalidUrl
63 ) {
64 return Err(readiness.message);
65 }
66 }
67 }
68
69 let opts = LlmCallOptions {
70 provider: provider.clone(),
71 model: model_id.clone(),
72 api_key,
73 api_mode: LlmApiMode::ChatCompletions,
74 route_policy: LlmRoutePolicy::Manual,
75 fallback_chain: Vec::new(),
76 route_fallbacks: Vec::new(),
77 routing_decision: None,
78 routing_policy: None,
79 region: None,
80 session_id: None,
81 dispatch_provenance: None,
82 reminders: None,
83 reminder_lifecycle: Vec::new(),
84 messages: vec![serde_json::json!({
85 "role": "user",
86 "content": options.prompt,
87 })],
88 system: None,
89 transcript_summary: None,
90 max_tokens: SMOKE_TEST_MAX_TOKENS,
91 temperature: None,
92 top_p: None,
93 top_k: None,
94 logprobs: false,
95 top_logprobs: None,
96 stop: None,
97 seed: None,
98 frequency_penalty: None,
99 presence_penalty: None,
100 fast: false,
101 output_format: OutputFormat::Text,
102 response_format: None,
103 json_schema: None,
104 output_schema: None,
105 output_validation: None,
106 schema_stream_abort: false,
107 thinking: ThinkingConfig::Disabled,
108 anthropic_beta_features: Vec::new(),
109 vision: false,
110 tools: None,
111 native_tools: None,
112 provider_tools: Vec::new(),
113 tool_choice: None,
114 tool_search: None,
115 cache: false,
116 timeout: None,
117 idle_timeout: None,
118 stream: true,
119 provider_overrides: None,
120 previous_response_id: None,
121 store: None,
122 background: None,
123 truncation: None,
124 compact: None,
125 include: None,
126 max_tool_calls: None,
127 budget: None,
128 prefill: None,
129 structural_experiment: None,
130 applied_structural_experiment: None,
131 };
132
133 let (delta_tx, mut delta_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
134 let started = Instant::now();
135 let first_delta = tokio::spawn(async move { delta_rx.recv().await.map(|_| started.elapsed()) });
136 let result = vm_call_llm_full_streaming(&opts, delta_tx)
137 .await
138 .map_err(vm_error_message);
139 let latency_ms = duration_ms(started.elapsed());
140 let first_token_ms = first_delta.await.ok().flatten().map(duration_ms);
141 let result = result?;
142
143 Ok(ModelSmokeTestResult {
144 model_id: result.model.clone(),
145 provider: result.provider.clone(),
146 latency_ms,
147 first_token_ms,
148 input_tokens: result.input_tokens,
149 output_tokens: result.output_tokens,
150 estimated_cost_usd: super::calculate_cost_for_provider(
151 &result.provider,
152 &result.model,
153 result.input_tokens,
154 result.output_tokens,
155 ),
156 })
157}
158
/// Whole milliseconds in `duration`, saturating at `u64::MAX`.
/// `Duration::as_millis` yields a `u128`, so it is clamped rather than truncated.
fn duration_ms(duration: std::time::Duration) -> u64 {
    let millis = duration.as_millis();
    if millis > u128::from(u64::MAX) {
        u64::MAX
    } else {
        // In range thanks to the check above, so this cast is lossless.
        millis as u64
    }
}
162
163fn vm_error_message(error: VmError) -> String {
164 match error {
165 VmError::CategorizedError { message, .. } => message,
166 VmError::Thrown(VmValue::String(message)) => message.to_string(),
167 VmError::Thrown(VmValue::Dict(dict)) => dict
168 .get("message")
169 .map(VmValue::display)
170 .unwrap_or_else(|| VmError::Thrown(VmValue::Dict(dict)).to_string()),
171 other => other.to_string(),
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::{run_model_smoke_test, ModelSmokeTestOptions};

    /// The built-in `mock` provider must complete a smoke test offline and report the
    /// fixed token counts, zero cost, and a first-token timestamp.
    #[tokio::test]
    async fn mock_provider_smoke_test_reports_timing_tokens_and_cost() {
        crate::llm::reset_llm_state();

        let options = ModelSmokeTestOptions {
            model: "mock".into(),
            provider: Some("mock".into()),
            prompt: "ping".into(),
        };
        let outcome = run_model_smoke_test(options)
            .await
            .expect("mock provider smoke test should not require network");

        assert_eq!(
            (outcome.model_id.as_str(), outcome.provider.as_str()),
            ("mock", "mock")
        );
        assert_eq!((outcome.input_tokens, outcome.output_tokens), (4, 30));
        assert_eq!(outcome.estimated_cost_usd, 0.0);
        assert!(outcome.first_token_ms.is_some());
    }
}