use crate::util::rand_f64;
use crate::HttpCaller;
use crate::Router;
use hyperinfer_core::{types::Provider, ChatRequest, Config};
use std::sync::{Arc, OnceLock};
use tokio::sync::{RwLock, Semaphore};
use tracing::warn;
/// Number of permits in the process-wide mirror semaphore, i.e. the maximum number of
/// mirror requests that may be in flight at the same time.
const MIRROR_CONCURRENCY_LIMIT: usize = 100;
/// Returns the process-wide semaphore that bounds concurrent mirror requests.
///
/// The semaphore is created on first use with `MIRROR_CONCURRENCY_LIMIT` permits and the
/// same instance is handed out on every subsequent call.
fn mirror_semaphore() -> &'static Arc<Semaphore> {
    static MIRROR_PERMITS: OnceLock<Arc<Semaphore>> = OnceLock::new();
    MIRROR_PERMITS.get_or_init(|| {
        let permits = Semaphore::new(MIRROR_CONCURRENCY_LIMIT);
        Arc::new(permits)
    })
}
/// Settings for mirroring a sampled fraction of traffic to a second model.
#[derive(Debug, Clone)]
pub struct MirrorConfig {
    /// Name of the model that mirrored requests are sent to.
    pub model: String,
    /// Fraction of requests to mirror. [`MirrorConfig::new`] keeps this within
    /// `0.0..=1.0`; values that are not greater than `0.0` disable mirroring.
    pub sample_rate: f64,
}

impl MirrorConfig {
    /// Creates a mirror configuration for `model`.
    ///
    /// `sample_rate` is clamped into `0.0..=1.0`. A NaN rate is stored as `0.0`:
    /// `f64::clamp` would otherwise pass NaN through unchanged, leaving the field
    /// outside the documented range.
    pub fn new(model: String, sample_rate: f64) -> Self {
        let sample_rate = if sample_rate.is_nan() {
            // NaN compares false against every threshold, so treat it as "disabled".
            0.0
        } else {
            sample_rate.clamp(0.0, 1.0)
        };
        Self { model, sample_rate }
    }
}
/// Shared, lock-protected slot holding the current mirror configuration.
///
/// `maybe_mirror` does nothing while the slot holds `None`.
pub type MirrorHandle = Arc<RwLock<Option<MirrorConfig>>>;
pub fn maybe_mirror(
mirror_handle: MirrorHandle,
http_caller: Arc<HttpCaller>,
router: Arc<Router>,
config_snapshot: Arc<Config>,
_key: String,
mut request: ChatRequest,
) {
let mirror_cfg = match mirror_handle.try_read() {
Ok(guard) => match guard.as_ref() {
Some(cfg) if cfg.sample_rate > 0.0 => cfg.clone(),
_ => return,
},
Err(_) => return,
};
if mirror_cfg.sample_rate < 1.0 {
let roll: f64 = rand_f64();
if roll > mirror_cfg.sample_rate {
tracing::debug!(
"Mirror skipped (sample_rate={:.2}, roll={:.2})",
mirror_cfg.sample_rate,
roll
);
return;
}
}
request.model = mirror_cfg.model.clone();
let resolved = router.resolve(&request.model, &config_snapshot);
let (model, provider) = match resolved {
Some(r) => r,
None => {
warn!(
"Mirror: could not resolve model '{}', skipping",
request.model
);
return;
}
};
let api_key = match config_snapshot.api_keys.get(&provider.to_string()) {
Some(k) => k.clone(),
None => {
warn!("Mirror: no API key for provider {:?}", provider);
return;
}
};
match provider {
Provider::OpenAI | Provider::Anthropic => {}
_ => {
warn!("Mirror: unsupported provider {:?}", provider);
return;
}
}
let permit = match mirror_semaphore().clone().try_acquire_owned() {
Ok(p) => p,
Err(_) => {
tracing::debug!("Mirror skipped: concurrency limit reached");
return;
}
};
tokio::spawn(async move {
let _permit = permit;
let result = match provider {
Provider::OpenAI => http_caller.call_openai(&model, &api_key, &request).await,
Provider::Anthropic => http_caller.call_anthropic(&model, &api_key, &request).await,
_ => unreachable!(),
};
match result {
Ok(resp) => {
let content_len = resp
.choices
.first()
.map(|c| c.message.content.len())
.unwrap_or(0);
tracing::debug!(
mirror_model = %model,
input_tokens = resp.usage.input_tokens,
output_tokens = resp.usage.output_tokens,
content_len,
"Mirror response received",
);
}
Err(e) => {
warn!("Mirror request failed: {:?}", e);
}
}
});
}
#[cfg(test)]
mod tests {
    use super::*;
    use hyperinfer_core::types::Config;
    use std::collections::HashMap;

    /// Builds a `Config` whose collections are all empty and that has no default provider.
    fn empty_config() -> Config {
        Config {
            api_keys: HashMap::new(),
            routing_rules: vec![],
            quotas: HashMap::new(),
            model_aliases: HashMap::new(),
            default_provider: None,
        }
    }

    /// Builds the single-message chat request shared by every `maybe_mirror` test.
    fn sample_request() -> hyperinfer_core::ChatRequest {
        use hyperinfer_core::types::{ChatMessage, MessageRole};

        let greeting = ChatMessage {
            role: MessageRole::User,
            content: "hello".to_string(),
        };
        hyperinfer_core::ChatRequest {
            model: "gpt-4".to_string(),
            messages: vec![greeting],
            max_tokens: Some(10),
            temperature: None,
            stream: None,
            stop: None,
        }
    }

    /// Calls `maybe_mirror` with the given mirror slot contents, a fresh HTTP caller,
    /// a router without rules, an empty config and the shared sample request.
    fn run_mirror(mirror: Option<MirrorConfig>) {
        let handle: MirrorHandle = Arc::new(RwLock::new(mirror));
        let http = Arc::new(HttpCaller::new().unwrap());
        let router = Arc::new(Router::new(vec![]));
        let config = Arc::new(empty_config());
        maybe_mirror(
            handle,
            http,
            router,
            config,
            "key".to_string(),
            sample_request(),
        );
    }

    /// Cloning keeps both the model name and the sample rate.
    #[test]
    fn test_mirror_config_clone() {
        let original = MirrorConfig {
            model: "gpt-4o".to_string(),
            sample_rate: 0.5,
        };
        let copy = original.clone();
        assert_eq!(copy.model, "gpt-4o");
        assert!((copy.sample_rate - 0.5).abs() < 1e-9);
    }

    /// A sample rate of zero must be a silent no-op.
    #[tokio::test]
    async fn test_maybe_mirror_disabled_no_panic() {
        run_mirror(Some(MirrorConfig {
            model: "gpt-4o".to_string(),
            sample_rate: 0.0,
        }));
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    }

    /// An empty mirror slot must be a silent no-op.
    #[tokio::test]
    async fn test_maybe_mirror_none_config_no_panic() {
        run_mirror(None);
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    }

    /// A mirror model that is not expected to resolve must not panic.
    #[tokio::test]
    async fn test_maybe_mirror_unresolvable_model_no_panic() {
        run_mirror(Some(MirrorConfig {
            model: "unknown-llm-xyz".to_string(),
            sample_rate: 1.0,
        }));
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
    }

    /// Lock taken by tests that drain the global mirror semaphore.
    static SERIALIZE_MIRROR_TEST: OnceLock<tokio::sync::Mutex<()>> = OnceLock::new();

    /// Returns the lazily created serialization mutex.
    fn get_serialize_mutex() -> &'static tokio::sync::Mutex<()> {
        SERIALIZE_MIRROR_TEST.get_or_init(|| tokio::sync::Mutex::new(()))
    }

    /// Drains every permit of the global semaphore, calls `maybe_mirror`, then checks
    /// that all permits are available again once the test releases its own.
    ///
    /// NOTE(review): `empty_config()` has no API keys, so `maybe_mirror` returns at the
    /// resolve or API-key check before it reaches the semaphore; the final assertion is
    /// therefore satisfied without exercising the capacity path. Confirm whether a
    /// config with a resolvable model and key should be used here.
    #[tokio::test]
    async fn test_maybe_mirror_concurrency_limit_no_panic() {
        let _guard = get_serialize_mutex().lock().await;

        let sem = mirror_semaphore();
        let mut held: Vec<tokio::sync::SemaphorePermit<'_>> =
            Vec::with_capacity(MIRROR_CONCURRENCY_LIMIT);
        while held.len() < MIRROR_CONCURRENCY_LIMIT {
            held.push(sem.acquire().await.expect("should acquire permit"));
        }
        assert_eq!(sem.available_permits(), 0);

        run_mirror(Some(MirrorConfig {
            model: "gpt-4o".to_string(),
            sample_rate: 1.0,
        }));

        drop(held);
        assert_eq!(
            sem.available_permits(),
            MIRROR_CONCURRENCY_LIMIT,
            "maybe_mirror should not have acquired a permit when at capacity"
        );
    }
}