1use crate::util::rand_f64;
23use crate::HttpCaller;
24use crate::Router;
25use hyperinfer_core::{types::Provider, ChatRequest, Config};
26use std::sync::{Arc, OnceLock};
27use tokio::sync::{RwLock, Semaphore};
28use tracing::warn;
29
30const MIRROR_CONCURRENCY_LIMIT: usize = 100;
32
33fn mirror_semaphore() -> &'static Arc<Semaphore> {
36 static SEM: OnceLock<Arc<Semaphore>> = OnceLock::new();
37 SEM.get_or_init(|| Arc::new(Semaphore::new(MIRROR_CONCURRENCY_LIMIT)))
38}
39
/// Settings that control mirroring of requests to a secondary model.
#[derive(Debug, Clone)]
pub struct MirrorConfig {
    /// Model name written into the mirrored copy of a request.
    pub model: String,
    /// Sampling rate for mirroring. `maybe_mirror` skips mirroring entirely
    /// unless this is greater than `0.0`, and skips the random draw when it is
    /// `1.0` or more. [`MirrorConfig::new`] keeps it inside `[0.0, 1.0]`.
    pub sample_rate: f64,
}

impl MirrorConfig {
    /// Builds a config, forcing `sample_rate` into `[0.0, 1.0]`.
    ///
    /// Out-of-range values are clamped to the nearest bound. `NaN` is mapped
    /// to `0.0` (mirroring disabled): `f64::clamp` would otherwise pass `NaN`
    /// straight through and leave an invalid rate in the config.
    pub fn new(model: String, sample_rate: f64) -> Self {
        let sample_rate = if sample_rate.is_nan() {
            0.0
        } else {
            sample_rate.clamp(0.0, 1.0)
        };
        Self { model, sample_rate }
    }
}
58
/// Shared, lock-protected handle to the optional mirror configuration.
///
/// `maybe_mirror` reads it with a non-blocking `try_read` and treats `None`
/// as "mirroring disabled".
pub type MirrorHandle = Arc<RwLock<Option<MirrorConfig>>>;
61
62pub fn maybe_mirror(
67 mirror_handle: MirrorHandle,
68 http_caller: Arc<HttpCaller>,
69 router: Arc<Router>,
70 config_snapshot: Arc<Config>,
71 _key: String,
72 mut request: ChatRequest,
73) {
74 let mirror_cfg = match mirror_handle.try_read() {
76 Ok(guard) => match guard.as_ref() {
77 Some(cfg) if cfg.sample_rate > 0.0 => cfg.clone(),
78 _ => return,
79 },
80 Err(_) => return,
81 };
82
83 if mirror_cfg.sample_rate < 1.0 {
85 let roll: f64 = rand_f64();
86 if roll > mirror_cfg.sample_rate {
87 tracing::debug!(
88 "Mirror skipped (sample_rate={:.2}, roll={:.2})",
89 mirror_cfg.sample_rate,
90 roll
91 );
92 return;
93 }
94 }
95
96 request.model = mirror_cfg.model.clone();
98
99 let resolved = router.resolve(&request.model, &config_snapshot);
101 let (model, provider) = match resolved {
102 Some(r) => r,
103 None => {
104 warn!(
105 "Mirror: could not resolve model '{}', skipping",
106 request.model
107 );
108 return;
109 }
110 };
111
112 let api_key = match config_snapshot.api_keys.get(&provider.to_string()) {
114 Some(k) => k.clone(),
115 None => {
116 warn!("Mirror: no API key for provider {:?}", provider);
117 return;
118 }
119 };
120
121 match provider {
122 Provider::OpenAI | Provider::Anthropic => {}
123 _ => {
124 warn!("Mirror: unsupported provider {:?}", provider);
125 return;
126 }
127 }
128
129 let permit = match mirror_semaphore().clone().try_acquire_owned() {
131 Ok(p) => p,
132 Err(_) => {
133 tracing::debug!("Mirror skipped: concurrency limit reached");
134 return;
135 }
136 };
137
138 tokio::spawn(async move {
139 let _permit = permit;
140
141 let result = match provider {
142 Provider::OpenAI => http_caller.call_openai(&model, &api_key, &request).await,
143 Provider::Anthropic => http_caller.call_anthropic(&model, &api_key, &request).await,
144 _ => unreachable!(),
145 };
146
147 match result {
148 Ok(resp) => {
149 let content_len = resp
150 .choices
151 .first()
152 .map(|c| c.message.content.len())
153 .unwrap_or(0);
154 tracing::debug!(
155 mirror_model = %model,
156 input_tokens = resp.usage.input_tokens,
157 output_tokens = resp.usage.output_tokens,
158 content_len,
159 "Mirror response received",
160 );
161 }
162 Err(e) => {
163 warn!("Mirror request failed: {:?}", e);
164 }
165 }
166 });
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use hyperinfer_core::types::Config;
    use std::collections::HashMap;

    /// Returns a `Config` whose collections are all empty and which has no
    /// default provider.
    fn empty_config() -> Config {
        Config {
            api_keys: HashMap::new(),
            routing_rules: Vec::new(),
            quotas: HashMap::new(),
            model_aliases: HashMap::new(),
            default_provider: None,
        }
    }

    /// Minimal single-message chat request shared by the tests below.
    fn sample_request() -> hyperinfer_core::ChatRequest {
        use hyperinfer_core::types::{ChatMessage, MessageRole};

        hyperinfer_core::ChatRequest {
            model: "gpt-4".to_string(),
            messages: vec![ChatMessage {
                role: MessageRole::User,
                content: "hello".to_string(),
            }],
            max_tokens: Some(10),
            temperature: None,
            stream: None,
            stop: None,
        }
    }

    /// Wraps an optional mirror configuration in a fresh handle.
    fn handle_for(cfg: Option<MirrorConfig>) -> MirrorHandle {
        Arc::new(RwLock::new(cfg))
    }

    /// Calls `maybe_mirror` with a default HTTP caller, a router constructed
    /// from an empty list, an empty config and the sample request.
    fn invoke_mirror(handle: MirrorHandle) {
        let caller = Arc::new(HttpCaller::new().unwrap());
        let router = Arc::new(Router::new(vec![]));
        let snapshot = Arc::new(empty_config());

        maybe_mirror(
            handle,
            caller,
            router,
            snapshot,
            "key".to_string(),
            sample_request(),
        );
    }

    /// Lazily created mutex used to serialize tests that touch the global
    /// mirror semaphore.
    fn get_serialize_mutex() -> &'static tokio::sync::Mutex<()> {
        static SERIALIZE_MIRROR_TEST: OnceLock<tokio::sync::Mutex<()>> = OnceLock::new();
        SERIALIZE_MIRROR_TEST.get_or_init(|| tokio::sync::Mutex::new(()))
    }

    #[test]
    fn test_mirror_config_clone() {
        let original = MirrorConfig {
            model: "gpt-4o".to_string(),
            sample_rate: 0.5,
        };
        let copy = original.clone();

        assert_eq!(copy.model, "gpt-4o");
        assert!((copy.sample_rate - 0.5).abs() < 1e-9);
    }

    /// A sample rate of zero must make `maybe_mirror` return without panicking.
    #[tokio::test]
    async fn test_maybe_mirror_disabled_no_panic() {
        let handle = handle_for(Some(MirrorConfig {
            model: "gpt-4o".to_string(),
            sample_rate: 0.0,
        }));

        invoke_mirror(handle);
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    }

    /// An unset mirror config must make `maybe_mirror` return without panicking.
    #[tokio::test]
    async fn test_maybe_mirror_none_config_no_panic() {
        let handle = handle_for(None);

        invoke_mirror(handle);
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    }

    /// A mirror model pointing at an unknown name must not panic.
    #[tokio::test]
    async fn test_maybe_mirror_unresolvable_model_no_panic() {
        let handle = handle_for(Some(MirrorConfig {
            model: "unknown-llm-xyz".to_string(),
            sample_rate: 1.0,
        }));

        invoke_mirror(handle);
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
    }

    /// Holds every permit of the global semaphore, calls `maybe_mirror`, and
    /// checks that all permits are available again once released.
    ///
    /// NOTE(review): `empty_config()` contains no API keys, so `maybe_mirror`
    /// returns before it ever consults the semaphore; the final assertion
    /// therefore holds regardless of the limit logic. Consider a config that
    /// actually reaches the permit check.
    #[tokio::test]
    async fn test_maybe_mirror_concurrency_limit_no_panic() {
        let _serial = get_serialize_mutex().lock().await;

        let semaphore = mirror_semaphore();
        let mut held: Vec<tokio::sync::SemaphorePermit<'_>> =
            Vec::with_capacity(MIRROR_CONCURRENCY_LIMIT);
        while held.len() < MIRROR_CONCURRENCY_LIMIT {
            held.push(semaphore.acquire().await.expect("should acquire permit"));
        }

        assert_eq!(semaphore.available_permits(), 0);

        let handle = handle_for(Some(MirrorConfig {
            model: "gpt-4o".to_string(),
            sample_rate: 1.0,
        }));
        invoke_mirror(handle);

        drop(held);

        assert_eq!(
            semaphore.available_permits(),
            MIRROR_CONCURRENCY_LIMIT,
            "maybe_mirror should not have acquired a permit when at capacity"
        );
    }
}