Skip to main content

hyperinfer_client/
lib.rs

1//! HyperInfer Client Library - Data Plane
2
3pub mod cache;
4pub mod http_client;
5pub mod mirroring;
6pub mod router;
7pub mod router_engine;
8pub mod telemetry;
9pub mod telemetry_otlp;
10mod util;
11
12pub use cache::ExactMatchCache;
13pub use http_client::HttpCaller;
14pub use mirroring::{MirrorConfig, MirrorHandle};
15pub use router::Router;
16pub use router_engine::RouterEngine;
17pub use telemetry::Telemetry;
18pub use telemetry_otlp::{
19    init_langfuse_telemetry, init_telemetry, init_telemetry_with_headers, set_gen_ai_attributes,
20    set_gen_ai_response, set_gen_ai_usage, shutdown_telemetry,
21};
22
23use futures::{Stream, StreamExt};
24use hyperinfer_core::{
25    rate_limiting::RateLimiter, ChatChunk, ChatRequest, ChatResponse, Config, HyperInferError,
26    Provider,
27};
28use hyperinfer_providers::ProviderRegistry;
29use std::pin::Pin;
30use std::sync::{Arc, LazyLock};
31use std::task::{Context, Poll};
32
33static HTTP_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(reqwest::Client::new);
34use tokio::sync::RwLock;
35use tracing::Instrument as _;
36
37/// Wraps a provider `ChatChunk` stream and performs the same accounting as
38/// `chat()` once the stream terminates (naturally or via an error):
39///
40/// - Fires Redis telemetry off the critical path via `tokio::spawn`.
41/// - Records output-token usage in the rate-limiter bucket.
42/// - Sets OTel span usage / response attributes.
43///
44/// The accounting is triggered exactly once, either when a `[DONE]`-equivalent
45/// chunk with a `finish_reason` is seen **or** when the stream signals
46/// `Poll::Ready(None)`.
47struct AccountedStream {
48    inner: Pin<Box<dyn Stream<Item = Result<ChatChunk, HyperInferError>> + Send>>,
49    telemetry: Telemetry,
50    rate_limiter: RateLimiter,
51    key: String,
52    model: String,
53    start: std::time::Instant,
54    /// Accumulated token counts from the stream's usage chunk (if any).
55    input_tokens: u32,
56    output_tokens: u32,
57    /// Guards against running the accounting block more than once.
58    accounted: bool,
59    /// OTel span that lives for the full stream lifetime.
60    span: tracing::Span,
61}
62
63impl AccountedStream {
64    /// Run the accounting side-effects exactly once.
65    fn account(&mut self) {
66        if self.accounted {
67            return;
68        }
69        self.accounted = true;
70
71        let elapsed = self.start.elapsed().as_millis() as u64;
72        let input_tokens = self.input_tokens;
73        let output_tokens = self.output_tokens;
74
75        let _enter = self.span.clone().entered();
76        crate::telemetry_otlp::set_gen_ai_usage(&self.span, input_tokens, output_tokens);
77
78        // Telemetry write is off the critical path.
79        let telemetry = self.telemetry.clone();
80        let key = self.key.clone();
81        let model = self.model.clone();
82        tokio::spawn(async move {
83            if let Err(e) = telemetry
84                .record_with_tokens(&key, &model, input_tokens, output_tokens, elapsed)
85                .await
86            {
87                tracing::warn!(error = %e, "stream telemetry record failed");
88            }
89        });
90
91        // Rate-limiter token-bucket update is lightweight and synchronous-ish;
92        // run it in a spawn to avoid blocking the poll path.
93        let rate_limiter = self.rate_limiter.clone();
94        let key2 = self.key.clone();
95        let total = (input_tokens + output_tokens) as u64;
96        tokio::spawn(async move {
97            let _ = rate_limiter.record_usage(&key2, total).await;
98        });
99    }
100}
101
102impl Drop for AccountedStream {
103    fn drop(&mut self) {
104        self.account();
105    }
106}
107
108impl Stream for AccountedStream {
109    type Item = Result<ChatChunk, HyperInferError>;
110
111    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
112        // Clone the span so the Entered guard holds no borrow into `self`,
113        // allowing subsequent mutable borrows of other fields.
114        let _enter = self.span.clone().entered();
115        match self.inner.as_mut().poll_next(cx) {
116            Poll::Ready(Some(Ok(chunk))) => {
117                // Capture usage from whatever chunk carries it (typically the last).
118                if let Some(ref u) = chunk.usage {
119                    self.input_tokens = u.input_tokens;
120                    self.output_tokens = u.output_tokens;
121                }
122                // If this chunk has a finish_reason the stream is done; account now
123                // so the span attributes are set while the span is still open.
124                if chunk.finish_reason.is_some() {
125                    self.account();
126                }
127                Poll::Ready(Some(Ok(chunk)))
128            }
129            Poll::Ready(Some(Err(e))) => {
130                self.account();
131                Poll::Ready(Some(Err(e)))
132            }
133            Poll::Ready(None) => {
134                self.account();
135                Poll::Ready(None)
136            }
137            Poll::Pending => Poll::Pending,
138        }
139    }
140}
141
142pub struct HyperInferClient {
143    config: Arc<RwLock<Config>>,
144    http_caller: Arc<HttpCaller>,
145    router: Arc<Router>,
146    router_engine: Arc<RouterEngine>,
147    rate_limiter: RateLimiter,
148    telemetry: Telemetry,
149    cache: ExactMatchCache,
150    mirror: MirrorHandle,
151    provider_registry: Arc<RwLock<Arc<ProviderRegistry>>>,
152}
153
154impl HyperInferClient {
155    pub async fn new(redis_url: &str, config: Config) -> Result<Self, HyperInferError> {
156        let http_caller = Arc::new(HttpCaller::new().map_err(HyperInferError::Http)?);
157        let router = Arc::new(
158            Router::new(config.routing_rules.clone())
159                .with_aliases(config.model_aliases.clone())
160                .with_default_provider(config.default_provider.clone()),
161        );
162        let rate_limiter = RateLimiter::new(Some(redis_url))
163            .await
164            .map_err(|e| HyperInferError::Config(std::io::Error::other(e.to_string())))?;
165        let telemetry = Telemetry::new(redis_url)
166            .await
167            .map_err(|e| HyperInferError::Config(std::io::Error::other(e.to_string())))?;
168        let cache = ExactMatchCache::new(redis_url, "default").await;
169        let mirror: MirrorHandle = Arc::new(RwLock::new(None));
170        let config = Arc::new(RwLock::new(config));
171
172        let provider_registry_inner = Arc::new(ProviderRegistry::new());
173        hyperinfer_providers::init_default_registry(&provider_registry_inner);
174        let provider_registry = Arc::new(RwLock::new(provider_registry_inner));
175
176        Ok(Self {
177            config,
178            http_caller,
179            router,
180            router_engine: Arc::new(RouterEngine::new().await),
181            rate_limiter,
182            telemetry,
183            cache,
184            mirror,
185            provider_registry,
186        })
187    }
188
189    /// Configure traffic mirroring.  Pass `None` to disable.
190    pub async fn set_mirror(&self, cfg: Option<MirrorConfig>) {
191        let mut guard = self.mirror.write().await;
192        *guard = cfg;
193    }
194
195    pub async fn inject_provider_registry(&self, external_registry: Arc<ProviderRegistry>) {
196        let mut guard = self.provider_registry.write().await;
197        *guard = external_registry;
198    }
199
200    /// Load deployments into the router engine for deployment-based routing
201    pub async fn load_deployments(&self, deployments: Vec<hyperinfer_core::Deployment>) {
202        self.router_engine.load_deployments(deployments).await;
203    }
204
205    /// Subscribe to Redis Pub/Sub for live deployment config changes.
206    /// When a message arrives on "hyperinfer:config_updates", calls the provided
207    /// fetcher function to re-fetch deployments and rebuilds the routing pool.
208    pub async fn subscribe_config_updates<F, Fut>(
209        &self,
210        redis_url: &str,
211        fetcher: F,
212    ) -> Result<(), HyperInferError>
213    where
214        F: Fn() -> Fut + Send + Sync + 'static,
215        Fut: std::future::Future<Output = Result<Vec<hyperinfer_core::Deployment>, String>> + Send,
216    {
217        let client = redis::Client::open(redis_url)
218            .map_err(|e| HyperInferError::Config(std::io::Error::other(e.to_string())))?;
219        let mut pubsub = client
220            .get_async_pubsub()
221            .await
222            .map_err(|e| HyperInferError::Config(std::io::Error::other(e.to_string())))?;
223
224        pubsub
225            .subscribe("hyperinfer:config_updates")
226            .await
227            .map_err(|e| HyperInferError::Config(std::io::Error::other(e.to_string())))?;
228
229        let engine = self.router_engine.clone();
230        let _handle = tokio::spawn(async move {
231            let mut stream = pubsub.on_message();
232            loop {
233                match stream.next().await {
234                    Some(_msg) => {
235                        tracing::info!(
236                            "Received config update notification, re-fetching deployments"
237                        );
238                        match fetcher().await {
239                            Ok(deployments) => {
240                                engine.rebuild_pool(deployments).await;
241                                tracing::info!("Rebuilt deployment pool after config update");
242                            }
243                            Err(e) => {
244                                tracing::warn!(error = %e, "Failed to re-fetch deployments after config update");
245                            }
246                        }
247                    }
248                    None => {
249                        tracing::info!("Pub/Sub stream ended");
250                        break;
251                    }
252                }
253            }
254        });
255
256        Ok(())
257    }
258
259    /// Get a reference to the router engine
260    pub fn router_engine(&self) -> &Arc<RouterEngine> {
261        &self.router_engine
262    }
263
264    pub async fn chat(
265        &self,
266        key: &str,
267        request: ChatRequest,
268    ) -> Result<ChatResponse, HyperInferError> {
269        request.validate()?;
270
271        // 0. Exact-match cache lookup (before rate-limiting to avoid wasting quota).
272        if let Some(cached) = self.cache.get(&request).await {
273            return Ok(cached);
274        }
275
276        // Create a root OTel span following the GenAI Semantic Conventions.
277        // We use `.instrument(span)` on the inner async block so the span is
278        // properly propagated across every `.await` point (using `span.enter()`
279        // in an async function is unsafe — the guard can survive suspension).
280        let span = tracing::info_span!(
281            "gen_ai.chat",
282            gen_ai.operation.name = "chat",
283            gen_ai.request.model = %request.model,
284        );
285
286        async move {
287            let start = std::time::Instant::now();
288
289            // 1. Check rate limit
290            let allowed = self.rate_limiter.is_allowed(key, 1).await;
291            if let Err(e) = allowed {
292                return Err(HyperInferError::RateLimit(e.to_string()));
293            }
294            if !allowed.unwrap() {
295                return Err(HyperInferError::RateLimit(
296                    "Rate limit exceeded".to_string(),
297                ));
298            }
299
300            // 2. Try deployment-based routing first (if deployments are loaded)
301            let deployment_result = self.router_engine.select_deployment(&request).await;
302            if let Ok(routing_result) = deployment_result {
303                let deployment = &routing_result.deployment;
304                let default_url = match &deployment.provider {
305                    Provider::Anthropic => "https://api.anthropic.com/v1",
306                    _ => "https://api.openai.com/v1",
307                };
308                let base_url = deployment.base_url.as_deref().unwrap_or(default_url);
309                let api_key = &deployment.api_key_ref;
310
311                let url = format!("{}/chat/completions", base_url.trim_end_matches('/'));
312
313                let mut headers = reqwest::header::HeaderMap::new();
314                headers.insert("content-type", "application/json".parse().unwrap());
315                if !api_key.is_empty() {
316                    match &deployment.provider {
317                        Provider::Anthropic => {
318                            headers.insert("x-api-key", api_key.parse().unwrap());
319                            headers.insert("anthropic-version", "2023-06-01".parse().unwrap());
320                        }
321                        _ => {
322                            headers.insert(
323                                "authorization",
324                                format!("Bearer {}", api_key).parse().unwrap(),
325                            );
326                        }
327                    }
328                }
329
330                match HTTP_CLIENT
331                    .post(&url)
332                    .headers(headers)
333                    .json(&request)
334                    .send()
335                    .await
336                {
337                    Ok(response) => {
338                        let status = response.status();
339                        if status.is_success() {
340                            if let Ok(body) = response.json::<ChatResponse>().await {
341                                // Record success metrics
342                                let latency = start.elapsed().as_secs_f64() * 1000.0;
343                                let tokens =
344                                    (body.usage.input_tokens + body.usage.output_tokens) as u64;
345                                self.router_engine
346                                    .record_success(&deployment.id, latency, tokens)
347                                    .await;
348
349                                // Record telemetry
350                                let elapsed = start.elapsed().as_millis() as u64;
351                                let telemetry = self.telemetry.clone();
352                                let key_owned = key.to_string();
353                                let model_owned = request.model.clone();
354                                tokio::spawn(async move {
355                                    let _ = telemetry
356                                        .record_with_tokens(
357                                            &key_owned,
358                                            &model_owned,
359                                            body.usage.input_tokens,
360                                            body.usage.output_tokens,
361                                            elapsed,
362                                        )
363                                        .await;
364                                });
365
366                                return Ok(body);
367                            }
368                        }
369                        // If request failed, record failure and fall through to fallback
370                        self.router_engine.record_failure(&deployment.id).await;
371                    }
372                    Err(_) => {
373                        // Request failed, record failure and fall through to fallback
374                        self.router_engine.record_failure(&deployment.id).await;
375                    }
376                }
377            }
378
379            // 3. Fallback to existing Router-based flow
380            let (model, provider, api_key, config_snapshot) = {
381                let config = self.config.read().await;
382                let resolved = self.router.resolve(&request.model, &config);
383
384                let (model, provider) = resolved.ok_or_else(|| {
385                    HyperInferError::Config(std::io::Error::new(
386                        std::io::ErrorKind::NotFound,
387                        format!(
388                            "Unknown model: '{}'. No routing rule or alias found.",
389                            request.model
390                        ),
391                    ))
392                })?;
393
394                let api_key = config
395                    .api_keys
396                    .get(&provider.to_string())
397                    .cloned()
398                    .ok_or_else(|| {
399                        HyperInferError::Config(std::io::Error::new(
400                            std::io::ErrorKind::NotFound,
401                            format!("API key not found for provider: {:?}", provider),
402                        ))
403                    })?;
404
405                (model, provider, api_key, Arc::new(config.clone()))
406            };
407
408            // Enrich span with the resolved provider and final model name.
409            let provider_name = provider.to_string();
410            crate::telemetry_otlp::set_gen_ai_attributes(
411                &tracing::Span::current(),
412                &provider_name,
413                &model,
414                "chat",
415            );
416
417            // 3. Execute HTTP call via provider registry
418            let llm_provider = {
419                let registry = self.provider_registry.read().await;
420                registry.get(&provider_name).ok_or_else(|| {
421                    HyperInferError::Config(std::io::Error::new(
422                        std::io::ErrorKind::NotFound,
423                        format!("Provider '{}' not found in registry", provider_name),
424                    ))
425                })?
426            };
427
428            let mut resolved_request = request.clone();
429            resolved_request.model = model.clone();
430            let response = llm_provider.chat(&resolved_request, &api_key).await?;
431
432            // 4. Record OTel usage and response attributes on the span.
433            let elapsed = start.elapsed().as_millis() as u64;
434            let input_tokens = response.usage.input_tokens;
435            let output_tokens = response.usage.output_tokens;
436
437            crate::telemetry_otlp::set_gen_ai_usage(
438                &tracing::Span::current(),
439                input_tokens,
440                output_tokens,
441            );
442
443            let finish_reason = response
444                .choices
445                .first()
446                .and_then(|c| c.finish_reason.as_deref())
447                .unwrap_or("unknown");
448            crate::telemetry_otlp::set_gen_ai_response(
449                &tracing::Span::current(),
450                &response.id,
451                finish_reason,
452            );
453
454            // Store successful response in exact-match cache.
455            self.cache.set(&request, &response).await;
456
457            // Record async Redis telemetry off the critical path.
458            let telemetry = self.telemetry.clone();
459            let key_owned = key.to_string();
460            let model_owned = model.clone();
461            tokio::spawn(async move {
462                if let Err(e) = telemetry
463                    .record_with_tokens(
464                        &key_owned,
465                        &model_owned,
466                        input_tokens,
467                        output_tokens,
468                        elapsed,
469                    )
470                    .await
471                {
472                    tracing::warn!(error = %e, "telemetry record failed");
473                }
474            });
475
476            // Record usage for rate-limiter token bucket.
477            let total_tokens = response.usage.input_tokens + response.usage.output_tokens;
478            let _ = self
479                .rate_limiter
480                .record_usage(key, total_tokens as u64)
481                .await;
482
483            // 5. Fire-and-forget traffic mirror (if configured).
484            mirroring::maybe_mirror(
485                self.mirror.clone(),
486                self.http_caller.clone(),
487                self.router.clone(),
488                config_snapshot,
489                key.to_string(),
490                request,
491            );
492
493            // 6. Return response
494            Ok(response)
495        }
496        .instrument(span)
497        .await
498    }
499
500    /// Stream token chunks for a chat request.
501    ///
502    /// Returns a `Stream` of `ChatChunk` items.  The caller is responsible for
503    /// collecting `delta` fields and assembling the final text.  The last chunk
504    /// in the stream has a non-`None` `finish_reason` and may carry `usage`.
505    ///
506    /// Rate-limiting and routing follow the same logic as `chat()`.
507    pub async fn chat_stream(
508        &self,
509        key: &str,
510        request: ChatRequest,
511    ) -> Result<
512        Pin<Box<dyn Stream<Item = Result<ChatChunk, HyperInferError>> + Send>>,
513        HyperInferError,
514    > {
515        request.validate()?;
516
517        // 1. Rate limit check (same as non-streaming path).
518        let allowed = self.rate_limiter.is_allowed(key, 1).await;
519        if let Err(e) = allowed {
520            return Err(HyperInferError::RateLimit(e.to_string()));
521        }
522        if !allowed.unwrap() {
523            return Err(HyperInferError::RateLimit(
524                "Rate limit exceeded".to_string(),
525            ));
526        }
527
528        // 2. Resolve model / provider / api key.
529        let (model, provider_name, api_key) = {
530            let config = self.config.read().await;
531            let resolved = self.router.resolve(&request.model, &config);
532
533            let (model, provider) = resolved.ok_or_else(|| {
534                HyperInferError::Config(std::io::Error::new(
535                    std::io::ErrorKind::NotFound,
536                    format!(
537                        "Unknown model: '{}'. No routing rule or alias found.",
538                        request.model
539                    ),
540                ))
541            })?;
542
543            let provider_name = provider.to_string();
544            let api_key = config
545                .api_keys
546                .get(&provider_name)
547                .cloned()
548                .ok_or_else(|| {
549                    HyperInferError::Config(std::io::Error::new(
550                        std::io::ErrorKind::NotFound,
551                        format!("API key not found for provider: {:?}", provider),
552                    ))
553                })?;
554
555            (model, provider_name, api_key)
556        };
557
558        // 3. Get streaming provider from registry (already checks supports_streaming)
559        let streaming_provider = {
560            let registry = self.provider_registry.read().await;
561            registry.get_streaming(&provider_name).ok_or_else(|| {
562                HyperInferError::Config(std::io::Error::new(
563                    std::io::ErrorKind::NotFound,
564                    format!(
565                        "Provider '{}' not found in registry or does not support streaming",
566                        provider_name
567                    ),
568                ))
569            })?
570        };
571
572        let mut resolved_request = request.clone();
573        resolved_request.model = model.clone();
574        let provider_stream: Pin<
575            Box<dyn Stream<Item = Result<ChatChunk, HyperInferError>> + Send>,
576        > = streaming_provider.into_stream(&resolved_request, &api_key);
577        // Note: streaming responses are not cached — the stream is consumed
578        // incrementally by the caller so we cannot inspect it here.
579
580        // 4. Create an OTel span for the stream lifetime and enrich it with
581        //    resolved provider / model information (mirrors chat()).
582        let span = tracing::info_span!(
583            "gen_ai.chat_stream",
584            gen_ai.operation.name = "chat_stream",
585            gen_ai.request.model = %request.model,
586        );
587        crate::telemetry_otlp::set_gen_ai_attributes(&span, &provider_name, &model, "chat_stream");
588
589        // 5. Wrap the provider stream so usage/telemetry are recorded on
590        //    termination — the same accounting chat() performs, but deferred
591        //    to when the last chunk (or an error) is polled.
592        //    The span is stored inside the wrapper; poll_next enters it on
593        //    every poll so it covers the full stream lifetime.
594        let stream = AccountedStream {
595            inner: provider_stream,
596            telemetry: self.telemetry.clone(),
597            rate_limiter: self.rate_limiter.clone(),
598            key: key.to_string(),
599            model,
600            start: std::time::Instant::now(),
601            input_tokens: 0,
602            output_tokens: 0,
603            accounted: false,
604            span,
605        };
606
607        Ok(Box::pin(stream))
608    }
609}