1pub mod cache;
4pub mod http_client;
5pub mod mirroring;
6pub mod router;
7pub mod router_engine;
8pub mod telemetry;
9pub mod telemetry_otlp;
10mod util;
11
12pub use cache::ExactMatchCache;
13pub use http_client::HttpCaller;
14pub use mirroring::{MirrorConfig, MirrorHandle};
15pub use router::Router;
16pub use router_engine::RouterEngine;
17pub use telemetry::Telemetry;
18pub use telemetry_otlp::{
19 init_langfuse_telemetry, init_telemetry, init_telemetry_with_headers, set_gen_ai_attributes,
20 set_gen_ai_response, set_gen_ai_usage, shutdown_telemetry,
21};
22
23use futures::{Stream, StreamExt};
24use hyperinfer_core::{
25 rate_limiting::RateLimiter, ChatChunk, ChatRequest, ChatResponse, Config, HyperInferError,
26 Provider,
27};
28use hyperinfer_providers::ProviderRegistry;
29use std::pin::Pin;
30use std::sync::{Arc, LazyLock};
31use std::task::{Context, Poll};
32
33static HTTP_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(reqwest::Client::new);
34use tokio::sync::RwLock;
35use tracing::Instrument as _;
36
/// Wraps a provider's chat-chunk stream and accounts for its usage exactly once.
///
/// Token counts are taken from the most recent chunk that carries `usage`.
/// Accounting (span usage attributes, telemetry record, rate-limiter usage) is
/// triggered by whichever happens first: a chunk with a `finish_reason`, an
/// error item, the end of the stream, or the wrapper being dropped.
struct AccountedStream {
    /// Underlying provider stream whose items are forwarded to the caller.
    inner: Pin<Box<dyn Stream<Item = Result<ChatChunk, HyperInferError>> + Send>>,
    /// Recorder that receives the per-request token/latency record.
    telemetry: Telemetry,
    /// Rate limiter charged with the stream's total token count.
    rate_limiter: RateLimiter,
    /// Caller key used for telemetry and rate limiting.
    key: String,
    /// Model name reported to telemetry (the router-resolved model).
    model: String,
    /// Creation time of the wrapper; latency is measured from here.
    start: std::time::Instant,
    /// Latest input-token count reported by the provider (0 until reported).
    input_tokens: u32,
    /// Latest output-token count reported by the provider (0 until reported).
    output_tokens: u32,
    /// Set once accounting has run so it is never performed twice.
    accounted: bool,
    /// Span entered while polling and accounting; usage attributes are set on it.
    span: tracing::Span,
}
62
63impl AccountedStream {
64 fn account(&mut self) {
66 if self.accounted {
67 return;
68 }
69 self.accounted = true;
70
71 let elapsed = self.start.elapsed().as_millis() as u64;
72 let input_tokens = self.input_tokens;
73 let output_tokens = self.output_tokens;
74
75 let _enter = self.span.clone().entered();
76 crate::telemetry_otlp::set_gen_ai_usage(&self.span, input_tokens, output_tokens);
77
78 let telemetry = self.telemetry.clone();
80 let key = self.key.clone();
81 let model = self.model.clone();
82 tokio::spawn(async move {
83 if let Err(e) = telemetry
84 .record_with_tokens(&key, &model, input_tokens, output_tokens, elapsed)
85 .await
86 {
87 tracing::warn!(error = %e, "stream telemetry record failed");
88 }
89 });
90
91 let rate_limiter = self.rate_limiter.clone();
94 let key2 = self.key.clone();
95 let total = (input_tokens + output_tokens) as u64;
96 tokio::spawn(async move {
97 let _ = rate_limiter.record_usage(&key2, total).await;
98 });
99 }
100}
101
102impl Drop for AccountedStream {
103 fn drop(&mut self) {
104 self.account();
105 }
106}
107
108impl Stream for AccountedStream {
109 type Item = Result<ChatChunk, HyperInferError>;
110
111 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
112 let _enter = self.span.clone().entered();
115 match self.inner.as_mut().poll_next(cx) {
116 Poll::Ready(Some(Ok(chunk))) => {
117 if let Some(ref u) = chunk.usage {
119 self.input_tokens = u.input_tokens;
120 self.output_tokens = u.output_tokens;
121 }
122 if chunk.finish_reason.is_some() {
125 self.account();
126 }
127 Poll::Ready(Some(Ok(chunk)))
128 }
129 Poll::Ready(Some(Err(e))) => {
130 self.account();
131 Poll::Ready(Some(Err(e)))
132 }
133 Poll::Ready(None) => {
134 self.account();
135 Poll::Ready(None)
136 }
137 Poll::Pending => Poll::Pending,
138 }
139 }
140}
141
/// Client that sends chat requests to LLM providers with response caching,
/// per-key rate limiting, telemetry and optional traffic mirroring.
///
/// Non-streaming requests are first offered to the [`RouterEngine`]'s
/// deployments. If none is selected, or the deployment call fails, the request
/// is resolved through the config-based [`Router`] and sent via the
/// [`ProviderRegistry`].
pub struct HyperInferClient {
    /// Client configuration: routing rules, model aliases, default provider and
    /// per-provider API keys.
    config: Arc<RwLock<Config>>,
    /// Shared HTTP caller; in this file it is only passed to mirroring.
    http_caller: Arc<HttpCaller>,
    /// Resolves a requested model name to a `(model, provider)` pair from the config.
    router: Arc<Router>,
    /// Deployment pool that records per-deployment success/failure; tried
    /// before config routing.
    router_engine: Arc<RouterEngine>,
    /// Per-key request admission and token-usage limiter.
    rate_limiter: RateLimiter,
    /// Per-key token/latency recorder.
    telemetry: Telemetry,
    /// Exact-match response cache consulted before any upstream call.
    cache: ExactMatchCache,
    /// Current mirroring configuration, replaced by `set_mirror`.
    mirror: MirrorHandle,
    /// Provider implementations; the inner `Arc` is swapped wholesale by
    /// `inject_provider_registry`.
    provider_registry: Arc<RwLock<Arc<ProviderRegistry>>>,
}
153
154impl HyperInferClient {
155 pub async fn new(redis_url: &str, config: Config) -> Result<Self, HyperInferError> {
156 let http_caller = Arc::new(HttpCaller::new().map_err(HyperInferError::Http)?);
157 let router = Arc::new(
158 Router::new(config.routing_rules.clone())
159 .with_aliases(config.model_aliases.clone())
160 .with_default_provider(config.default_provider.clone()),
161 );
162 let rate_limiter = RateLimiter::new(Some(redis_url))
163 .await
164 .map_err(|e| HyperInferError::Config(std::io::Error::other(e.to_string())))?;
165 let telemetry = Telemetry::new(redis_url)
166 .await
167 .map_err(|e| HyperInferError::Config(std::io::Error::other(e.to_string())))?;
168 let cache = ExactMatchCache::new(redis_url, "default").await;
169 let mirror: MirrorHandle = Arc::new(RwLock::new(None));
170 let config = Arc::new(RwLock::new(config));
171
172 let provider_registry_inner = Arc::new(ProviderRegistry::new());
173 hyperinfer_providers::init_default_registry(&provider_registry_inner);
174 let provider_registry = Arc::new(RwLock::new(provider_registry_inner));
175
176 Ok(Self {
177 config,
178 http_caller,
179 router,
180 router_engine: Arc::new(RouterEngine::new().await),
181 rate_limiter,
182 telemetry,
183 cache,
184 mirror,
185 provider_registry,
186 })
187 }
188
189 pub async fn set_mirror(&self, cfg: Option<MirrorConfig>) {
191 let mut guard = self.mirror.write().await;
192 *guard = cfg;
193 }
194
195 pub async fn inject_provider_registry(&self, external_registry: Arc<ProviderRegistry>) {
196 let mut guard = self.provider_registry.write().await;
197 *guard = external_registry;
198 }
199
200 pub async fn load_deployments(&self, deployments: Vec<hyperinfer_core::Deployment>) {
202 self.router_engine.load_deployments(deployments).await;
203 }
204
205 pub async fn subscribe_config_updates<F, Fut>(
209 &self,
210 redis_url: &str,
211 fetcher: F,
212 ) -> Result<(), HyperInferError>
213 where
214 F: Fn() -> Fut + Send + Sync + 'static,
215 Fut: std::future::Future<Output = Result<Vec<hyperinfer_core::Deployment>, String>> + Send,
216 {
217 let client = redis::Client::open(redis_url)
218 .map_err(|e| HyperInferError::Config(std::io::Error::other(e.to_string())))?;
219 let mut pubsub = client
220 .get_async_pubsub()
221 .await
222 .map_err(|e| HyperInferError::Config(std::io::Error::other(e.to_string())))?;
223
224 pubsub
225 .subscribe("hyperinfer:config_updates")
226 .await
227 .map_err(|e| HyperInferError::Config(std::io::Error::other(e.to_string())))?;
228
229 let engine = self.router_engine.clone();
230 let _handle = tokio::spawn(async move {
231 let mut stream = pubsub.on_message();
232 loop {
233 match stream.next().await {
234 Some(_msg) => {
235 tracing::info!(
236 "Received config update notification, re-fetching deployments"
237 );
238 match fetcher().await {
239 Ok(deployments) => {
240 engine.rebuild_pool(deployments).await;
241 tracing::info!("Rebuilt deployment pool after config update");
242 }
243 Err(e) => {
244 tracing::warn!(error = %e, "Failed to re-fetch deployments after config update");
245 }
246 }
247 }
248 None => {
249 tracing::info!("Pub/Sub stream ended");
250 break;
251 }
252 }
253 }
254 });
255
256 Ok(())
257 }
258
259 pub fn router_engine(&self) -> &Arc<RouterEngine> {
261 &self.router_engine
262 }
263
264 pub async fn chat(
265 &self,
266 key: &str,
267 request: ChatRequest,
268 ) -> Result<ChatResponse, HyperInferError> {
269 request.validate()?;
270
271 if let Some(cached) = self.cache.get(&request).await {
273 return Ok(cached);
274 }
275
276 let span = tracing::info_span!(
281 "gen_ai.chat",
282 gen_ai.operation.name = "chat",
283 gen_ai.request.model = %request.model,
284 );
285
286 async move {
287 let start = std::time::Instant::now();
288
289 let allowed = self.rate_limiter.is_allowed(key, 1).await;
291 if let Err(e) = allowed {
292 return Err(HyperInferError::RateLimit(e.to_string()));
293 }
294 if !allowed.unwrap() {
295 return Err(HyperInferError::RateLimit(
296 "Rate limit exceeded".to_string(),
297 ));
298 }
299
300 let deployment_result = self.router_engine.select_deployment(&request).await;
302 if let Ok(routing_result) = deployment_result {
303 let deployment = &routing_result.deployment;
304 let default_url = match &deployment.provider {
305 Provider::Anthropic => "https://api.anthropic.com/v1",
306 _ => "https://api.openai.com/v1",
307 };
308 let base_url = deployment.base_url.as_deref().unwrap_or(default_url);
309 let api_key = &deployment.api_key_ref;
310
311 let url = format!("{}/chat/completions", base_url.trim_end_matches('/'));
312
313 let mut headers = reqwest::header::HeaderMap::new();
314 headers.insert("content-type", "application/json".parse().unwrap());
315 if !api_key.is_empty() {
316 match &deployment.provider {
317 Provider::Anthropic => {
318 headers.insert("x-api-key", api_key.parse().unwrap());
319 headers.insert("anthropic-version", "2023-06-01".parse().unwrap());
320 }
321 _ => {
322 headers.insert(
323 "authorization",
324 format!("Bearer {}", api_key).parse().unwrap(),
325 );
326 }
327 }
328 }
329
330 match HTTP_CLIENT
331 .post(&url)
332 .headers(headers)
333 .json(&request)
334 .send()
335 .await
336 {
337 Ok(response) => {
338 let status = response.status();
339 if status.is_success() {
340 if let Ok(body) = response.json::<ChatResponse>().await {
341 let latency = start.elapsed().as_secs_f64() * 1000.0;
343 let tokens =
344 (body.usage.input_tokens + body.usage.output_tokens) as u64;
345 self.router_engine
346 .record_success(&deployment.id, latency, tokens)
347 .await;
348
349 let elapsed = start.elapsed().as_millis() as u64;
351 let telemetry = self.telemetry.clone();
352 let key_owned = key.to_string();
353 let model_owned = request.model.clone();
354 tokio::spawn(async move {
355 let _ = telemetry
356 .record_with_tokens(
357 &key_owned,
358 &model_owned,
359 body.usage.input_tokens,
360 body.usage.output_tokens,
361 elapsed,
362 )
363 .await;
364 });
365
366 return Ok(body);
367 }
368 }
369 self.router_engine.record_failure(&deployment.id).await;
371 }
372 Err(_) => {
373 self.router_engine.record_failure(&deployment.id).await;
375 }
376 }
377 }
378
379 let (model, provider, api_key, config_snapshot) = {
381 let config = self.config.read().await;
382 let resolved = self.router.resolve(&request.model, &config);
383
384 let (model, provider) = resolved.ok_or_else(|| {
385 HyperInferError::Config(std::io::Error::new(
386 std::io::ErrorKind::NotFound,
387 format!(
388 "Unknown model: '{}'. No routing rule or alias found.",
389 request.model
390 ),
391 ))
392 })?;
393
394 let api_key = config
395 .api_keys
396 .get(&provider.to_string())
397 .cloned()
398 .ok_or_else(|| {
399 HyperInferError::Config(std::io::Error::new(
400 std::io::ErrorKind::NotFound,
401 format!("API key not found for provider: {:?}", provider),
402 ))
403 })?;
404
405 (model, provider, api_key, Arc::new(config.clone()))
406 };
407
408 let provider_name = provider.to_string();
410 crate::telemetry_otlp::set_gen_ai_attributes(
411 &tracing::Span::current(),
412 &provider_name,
413 &model,
414 "chat",
415 );
416
417 let llm_provider = {
419 let registry = self.provider_registry.read().await;
420 registry.get(&provider_name).ok_or_else(|| {
421 HyperInferError::Config(std::io::Error::new(
422 std::io::ErrorKind::NotFound,
423 format!("Provider '{}' not found in registry", provider_name),
424 ))
425 })?
426 };
427
428 let mut resolved_request = request.clone();
429 resolved_request.model = model.clone();
430 let response = llm_provider.chat(&resolved_request, &api_key).await?;
431
432 let elapsed = start.elapsed().as_millis() as u64;
434 let input_tokens = response.usage.input_tokens;
435 let output_tokens = response.usage.output_tokens;
436
437 crate::telemetry_otlp::set_gen_ai_usage(
438 &tracing::Span::current(),
439 input_tokens,
440 output_tokens,
441 );
442
443 let finish_reason = response
444 .choices
445 .first()
446 .and_then(|c| c.finish_reason.as_deref())
447 .unwrap_or("unknown");
448 crate::telemetry_otlp::set_gen_ai_response(
449 &tracing::Span::current(),
450 &response.id,
451 finish_reason,
452 );
453
454 self.cache.set(&request, &response).await;
456
457 let telemetry = self.telemetry.clone();
459 let key_owned = key.to_string();
460 let model_owned = model.clone();
461 tokio::spawn(async move {
462 if let Err(e) = telemetry
463 .record_with_tokens(
464 &key_owned,
465 &model_owned,
466 input_tokens,
467 output_tokens,
468 elapsed,
469 )
470 .await
471 {
472 tracing::warn!(error = %e, "telemetry record failed");
473 }
474 });
475
476 let total_tokens = response.usage.input_tokens + response.usage.output_tokens;
478 let _ = self
479 .rate_limiter
480 .record_usage(key, total_tokens as u64)
481 .await;
482
483 mirroring::maybe_mirror(
485 self.mirror.clone(),
486 self.http_caller.clone(),
487 self.router.clone(),
488 config_snapshot,
489 key.to_string(),
490 request,
491 );
492
493 Ok(response)
495 }
496 .instrument(span)
497 .await
498 }
499
500 pub async fn chat_stream(
508 &self,
509 key: &str,
510 request: ChatRequest,
511 ) -> Result<
512 Pin<Box<dyn Stream<Item = Result<ChatChunk, HyperInferError>> + Send>>,
513 HyperInferError,
514 > {
515 request.validate()?;
516
517 let allowed = self.rate_limiter.is_allowed(key, 1).await;
519 if let Err(e) = allowed {
520 return Err(HyperInferError::RateLimit(e.to_string()));
521 }
522 if !allowed.unwrap() {
523 return Err(HyperInferError::RateLimit(
524 "Rate limit exceeded".to_string(),
525 ));
526 }
527
528 let (model, provider_name, api_key) = {
530 let config = self.config.read().await;
531 let resolved = self.router.resolve(&request.model, &config);
532
533 let (model, provider) = resolved.ok_or_else(|| {
534 HyperInferError::Config(std::io::Error::new(
535 std::io::ErrorKind::NotFound,
536 format!(
537 "Unknown model: '{}'. No routing rule or alias found.",
538 request.model
539 ),
540 ))
541 })?;
542
543 let provider_name = provider.to_string();
544 let api_key = config
545 .api_keys
546 .get(&provider_name)
547 .cloned()
548 .ok_or_else(|| {
549 HyperInferError::Config(std::io::Error::new(
550 std::io::ErrorKind::NotFound,
551 format!("API key not found for provider: {:?}", provider),
552 ))
553 })?;
554
555 (model, provider_name, api_key)
556 };
557
558 let streaming_provider = {
560 let registry = self.provider_registry.read().await;
561 registry.get_streaming(&provider_name).ok_or_else(|| {
562 HyperInferError::Config(std::io::Error::new(
563 std::io::ErrorKind::NotFound,
564 format!(
565 "Provider '{}' not found in registry or does not support streaming",
566 provider_name
567 ),
568 ))
569 })?
570 };
571
572 let mut resolved_request = request.clone();
573 resolved_request.model = model.clone();
574 let provider_stream: Pin<
575 Box<dyn Stream<Item = Result<ChatChunk, HyperInferError>> + Send>,
576 > = streaming_provider.into_stream(&resolved_request, &api_key);
577 let span = tracing::info_span!(
583 "gen_ai.chat_stream",
584 gen_ai.operation.name = "chat_stream",
585 gen_ai.request.model = %request.model,
586 );
587 crate::telemetry_otlp::set_gen_ai_attributes(&span, &provider_name, &model, "chat_stream");
588
589 let stream = AccountedStream {
595 inner: provider_stream,
596 telemetry: self.telemetry.clone(),
597 rate_limiter: self.rate_limiter.clone(),
598 key: key.to_string(),
599 model,
600 start: std::time::Instant::now(),
601 input_tokens: 0,
602 output_tokens: 0,
603 accounted: false,
604 span,
605 };
606
607 Ok(Box::pin(stream))
608 }
609}