dynamo_llm/http/
client.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! HTTP clients for streaming LLM responses with performance recording
5//!
6//! This module provides HTTP clients that leverage async-openai with BYOT (Bring Your Own Types)
7//! feature to work with OpenAI-compatible APIs. The clients support recording streaming responses
8//! for performance analysis.
9
10use std::pin::Pin;
11use std::sync::{Arc, Mutex};
12use std::task::{Context, Poll};
13use std::time::Instant;
14
15use async_trait::async_trait;
16use derive_getters::Dissolve;
17use dynamo_async_openai::{Client, config::OpenAIConfig, error::OpenAIError};
18use futures::Stream;
19use serde_json::Value;
20use tokio_util::sync::CancellationToken;
21use tracing;
22use uuid::Uuid;
23
24// Import our existing recording infrastructure
25use crate::protocols::Annotated;
26use crate::protocols::openai::chat_completions::{
27    NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
28};
29use dynamo_runtime::engine::{
30    AsyncEngineContext, AsyncEngineContextProvider, AsyncEngineStream, Data, DataStream,
31};
32
/// Configuration for HTTP clients
///
/// `Default` yields the default `OpenAIConfig` with `verbose` disabled
/// (see `test_http_client_config_defaults`).
#[derive(Clone, Default)]
pub struct HttpClientConfig {
    /// OpenAI API configuration; cloned into the underlying async-openai
    /// client by `BaseHttpClient::new`
    pub openai_config: OpenAIConfig,
    /// Whether to enable detailed logging (per-request `tracing::info!` lines)
    pub verbose: bool,
}
41
/// Error types for HTTP clients
#[derive(Debug, thiserror::Error)]
pub enum HttpClientError {
    /// Error surfaced by the underlying async-openai client.
    #[error("OpenAI API error: {0}")]
    OpenAI(#[from] OpenAIError),
    /// Request did not complete in time.
    /// NOTE(review): not constructed anywhere in this file yet — presumably
    /// reserved for the Phase 3 cancellation/timeout integration; confirm.
    #[error("Request timeout")]
    Timeout,
    /// Request was cancelled before completion.
    /// NOTE(review): also not constructed in this file yet.
    #[error("Request cancelled")]
    Cancelled,
    /// The caller supplied a request that cannot be streamed
    /// (e.g. missing `stream: true`).
    #[error("Invalid request: {0}")]
    InvalidRequest(String),
}
54
/// Context for HTTP client requests that supports cancellation
/// This bridges AsyncEngineContext and reqwest cancellation
///
/// Cloning is shallow for the shared state: `stopped` and `child_context`
/// are behind `Arc`s, so clones observe the same stop flag and child list.
#[derive(Clone)]
pub struct HttpRequestContext {
    /// Unique request identifier
    id: String,
    /// Tokio cancellation token for reqwest integration;
    /// children hold a `child_token()` derived from this one
    cancel_token: CancellationToken,
    /// When this context was created (basis for `elapsed()`)
    created_at: Instant,
    /// Whether the request has been stopped
    /// (set by `stop`/`stop_generating`/`kill`)
    stopped: Arc<std::sync::atomic::AtomicBool>,
    /// Child contexts to be stopped if this is stopped
    child_context: Arc<Mutex<Vec<Arc<dyn AsyncEngineContext>>>>,
}
70
71impl HttpRequestContext {
72    /// Create a new HTTP request context
73    pub fn new() -> Self {
74        Self {
75            id: Uuid::new_v4().to_string(),
76            cancel_token: CancellationToken::new(),
77            created_at: Instant::now(),
78            stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
79            child_context: Arc::new(Mutex::new(Vec::new())),
80        }
81    }
82
83    /// Create a new context with a specific ID
84    pub fn with_id(id: String) -> Self {
85        Self {
86            id,
87            cancel_token: CancellationToken::new(),
88            created_at: Instant::now(),
89            stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
90            child_context: Arc::new(Mutex::new(Vec::new())),
91        }
92    }
93
94    /// Create a child context from this parent context
95    /// The child will be cancelled when the parent is cancelled
96    pub fn child(&self) -> Self {
97        Self {
98            id: Uuid::new_v4().to_string(),
99            cancel_token: self.cancel_token.child_token(),
100            created_at: Instant::now(),
101            stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
102            child_context: Arc::new(Mutex::new(Vec::new())),
103        }
104    }
105
106    /// Create a child context with a specific ID
107    pub fn child_with_id(&self, id: String) -> Self {
108        Self {
109            id,
110            cancel_token: self.cancel_token.child_token(),
111            created_at: Instant::now(),
112            stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
113            child_context: Arc::new(Mutex::new(Vec::new())),
114        }
115    }
116
117    /// Get the cancellation token for use with reqwest
118    pub fn cancellation_token(&self) -> CancellationToken {
119        self.cancel_token.clone()
120    }
121
122    /// Get the elapsed time since context creation
123    pub fn elapsed(&self) -> std::time::Duration {
124        self.created_at.elapsed()
125    }
126}
127
impl Default for HttpRequestContext {
    /// Equivalent to [`HttpRequestContext::new`]: random ID, fresh token.
    fn default() -> Self {
        Self::new()
    }
}
133
/// Manual `Debug`: reports derived state (stopped/killed/cancelled) instead
/// of dumping the raw token and child list, which are not `Debug`-friendly.
impl std::fmt::Debug for HttpRequestContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpRequestContext")
            .field("id", &self.id)
            .field("created_at", &self.created_at)
            .field("is_stopped", &self.is_stopped())
            .field("is_killed", &self.is_killed())
            .field("is_cancelled", &self.cancel_token.is_cancelled())
            .finish()
    }
}
145
146#[async_trait]
147impl AsyncEngineContext for HttpRequestContext {
148    fn id(&self) -> &str {
149        &self.id
150    }
151
152    fn stop(&self) {
153        // Clone child Arcs to avoid deadlock if parent is accidentally linked under child
154        let children = self
155            .child_context
156            .lock()
157            .expect("Failed to lock child context")
158            .iter()
159            .cloned()
160            .collect::<Vec<_>>();
161        for child in children {
162            child.stop();
163        }
164
165        self.stopped
166            .store(true, std::sync::atomic::Ordering::Release);
167        self.cancel_token.cancel();
168    }
169
170    fn stop_generating(&self) {
171        // Clone child Arcs to avoid deadlock if parent is accidentally linked under child
172        let children = self
173            .child_context
174            .lock()
175            .expect("Failed to lock child context")
176            .iter()
177            .cloned()
178            .collect::<Vec<_>>();
179        for child in children {
180            child.stop_generating();
181        }
182
183        // For HTTP clients, stop_generating is the same as stop
184        self.stopped
185            .store(true, std::sync::atomic::Ordering::Release);
186        self.cancel_token.cancel();
187    }
188
189    fn kill(&self) {
190        // Clone child Arcs to avoid deadlock if parent is accidentally linked under child
191        let children = self
192            .child_context
193            .lock()
194            .expect("Failed to lock child context")
195            .iter()
196            .cloned()
197            .collect::<Vec<_>>();
198        for child in children {
199            child.kill();
200        }
201
202        self.stopped
203            .store(true, std::sync::atomic::Ordering::Release);
204        self.cancel_token.cancel();
205    }
206
207    fn is_stopped(&self) -> bool {
208        self.stopped.load(std::sync::atomic::Ordering::Acquire)
209    }
210
211    fn is_killed(&self) -> bool {
212        self.stopped.load(std::sync::atomic::Ordering::Acquire)
213    }
214
215    async fn stopped(&self) {
216        self.cancel_token.cancelled().await;
217    }
218
219    async fn killed(&self) {
220        // For HTTP clients, killed is the same as stopped
221        self.cancel_token.cancelled().await;
222    }
223
224    fn link_child(&self, child: Arc<dyn AsyncEngineContext>) {
225        self.child_context
226            .lock()
227            .expect("Failed to lock child context")
228            .push(child);
229    }
230}
231
/// Base HTTP client with common functionality
///
/// Owns the async-openai client plus a root [`HttpRequestContext`]; request
/// contexts are created as children of the root, so stopping the root
/// cancels the tokens of all contexts created through this client.
pub struct BaseHttpClient {
    /// async-openai client
    client: Client<OpenAIConfig>,
    /// Client configuration
    config: HttpClientConfig,
    /// Root context for this client
    root_context: HttpRequestContext,
}
241
242impl BaseHttpClient {
243    /// Create a new base HTTP client
244    pub fn new(config: HttpClientConfig) -> Self {
245        let client = Client::with_config(config.openai_config.clone());
246        Self {
247            client,
248            config,
249            root_context: HttpRequestContext::new(),
250        }
251    }
252
253    /// Get a reference to the underlying async-openai client
254    pub fn client(&self) -> &Client<OpenAIConfig> {
255        &self.client
256    }
257
258    /// Create a new request context as a child of the root context
259    pub fn create_context(&self) -> HttpRequestContext {
260        self.root_context.child()
261    }
262
263    /// Create a new request context with a specific ID as a child of the root context
264    pub fn create_context_with_id(&self, id: String) -> HttpRequestContext {
265        self.root_context.child_with_id(id)
266    }
267
268    /// Get the root context for this client
269    pub fn root_context(&self) -> &HttpRequestContext {
270        &self.root_context
271    }
272
273    /// Check if verbose logging is enabled
274    pub fn is_verbose(&self) -> bool {
275        self.config.verbose
276    }
277}
278
/// Type alias for NV chat response stream: each item is either an
/// [`Annotated`] NV stream response chunk or an [`OpenAIError`]
pub type NvChatResponseStream =
    DataStream<Result<Annotated<NvCreateChatCompletionStreamResponse>, OpenAIError>>;

/// Type alias for generic BYOT response stream of raw JSON values
pub type ByotResponseStream = DataStream<Result<Value, OpenAIError>>;

/// Type alias for pure OpenAI chat response stream using the standard
/// async-openai stream response type
pub type OpenAIChatResponseStream =
    DataStream<Result<dynamo_async_openai::types::CreateChatCompletionStreamResponse, OpenAIError>>;
289
/// A wrapped HTTP response stream that combines a stream with its context
/// This provides a unified interface for HTTP client responses
///
/// `#[derive(Dissolve)]` generates a `dissolve()` that splits this back
/// into its `(stream, context)` parts.
#[derive(Dissolve)]
pub struct HttpResponseStream<T> {
    /// The underlying stream of responses
    pub stream: DataStream<T>,
    /// The context for this request
    pub context: Arc<dyn AsyncEngineContext>,
}

impl<T> HttpResponseStream<T> {
    /// Create a new HttpResponseStream from its two parts.
    pub fn new(stream: DataStream<T>, context: Arc<dyn AsyncEngineContext>) -> Self {
        Self { stream, context }
    }
}
306
/// Forward `Stream` to the inner stream; the context plays no role in
/// polling (cancellation wiring is deferred — see the Phase 3 TODOs).
impl<T: Data> Stream for HttpResponseStream<T> {
    type Item = T;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.stream).poll_next(cx)
    }
}

/// Expose the request context so callers can stop/kill the request.
impl<T: Data> AsyncEngineContextProvider for HttpResponseStream<T> {
    fn context(&self) -> Arc<dyn AsyncEngineContext> {
        self.context.clone()
    }
}
320
321impl<T: Data> HttpResponseStream<T> {
322    /// Convert this HttpResponseStream to a Pin<Box<dyn AsyncEngineStream<T>>>
323    /// This requires the stream to be Send + Sync, which may not be true for all streams
324    pub fn into_async_engine_stream(self) -> Pin<Box<dyn AsyncEngineStream<T>>>
325    where
326        T: 'static,
327    {
328        // This will only work if the underlying stream is actually Send + Sync
329        // For now, we create a wrapper that assumes this
330        Box::pin(AsyncEngineStreamWrapper {
331            stream: self.stream,
332            context: self.context,
333        })
334    }
335}
336
/// A wrapper that implements AsyncEngineStream for streams that are Send + Sync
struct AsyncEngineStreamWrapper<T> {
    // The underlying stream being adapted.
    stream: DataStream<T>,
    // Context handed out by the `AsyncEngineContextProvider` impl below.
    context: Arc<dyn AsyncEngineContext>,
}

/// Delegate polling straight to the wrapped stream.
impl<T: Data> Stream for AsyncEngineStreamWrapper<T> {
    type Item = T;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.stream).poll_next(cx)
    }
}

impl<T: Data> AsyncEngineContextProvider for AsyncEngineStreamWrapper<T> {
    fn context(&self) -> Arc<dyn AsyncEngineContext> {
        self.context.clone()
    }
}

/// Marker impl: no additional methods are required here beyond
/// `Stream` + `AsyncEngineContextProvider`.
impl<T: Data> AsyncEngineStream<T> for AsyncEngineStreamWrapper<T> {}

/// Debug shows only the context; the stream itself is not `Debug`.
impl<T> std::fmt::Debug for AsyncEngineStreamWrapper<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AsyncEngineStreamWrapper")
            .field("context", &self.context)
            .finish()
    }
}
366
/// Debug shows only the context; the stream itself is not `Debug`.
impl<T: Data> std::fmt::Debug for HttpResponseStream<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpResponseStream")
            .field("context", &self.context)
            .finish()
    }
}
374
/// Type alias for HttpResponseStream with NV chat completion responses
/// (returned by [`NvCustomClient`])
pub type NvHttpResponseStream =
    HttpResponseStream<Result<Annotated<NvCreateChatCompletionStreamResponse>, OpenAIError>>;

/// Type alias for HttpResponseStream with BYOT responses
/// (returned by [`GenericBYOTClient`])
pub type ByotHttpResponseStream = HttpResponseStream<Result<Value, OpenAIError>>;

/// Type alias for HttpResponseStream with pure OpenAI responses
/// (returned by [`PureOpenAIClient`])
pub type OpenAIHttpResponseStream = HttpResponseStream<
    Result<dynamo_async_openai::types::CreateChatCompletionStreamResponse, OpenAIError>,
>;
386
387/// Pure OpenAI client using standard async-openai types
388pub struct PureOpenAIClient {
389    base: BaseHttpClient,
390}
391
392impl PureOpenAIClient {
393    /// Create a new pure OpenAI client
394    pub fn new(config: HttpClientConfig) -> Self {
395        Self {
396            base: BaseHttpClient::new(config),
397        }
398    }
399
400    /// Create streaming chat completions using standard OpenAI types
401    /// Uses a client-managed context
402    pub async fn chat_stream(
403        &self,
404        request: dynamo_async_openai::types::CreateChatCompletionRequest,
405    ) -> Result<OpenAIHttpResponseStream, HttpClientError> {
406        let ctx = self.base.create_context();
407        self.chat_stream_with_context(request, ctx).await
408    }
409
410    /// Create streaming chat completions with a custom context
411    pub async fn chat_stream_with_context(
412        &self,
413        request: dynamo_async_openai::types::CreateChatCompletionRequest,
414        context: HttpRequestContext,
415    ) -> Result<OpenAIHttpResponseStream, HttpClientError> {
416        let ctx_arc: Arc<dyn AsyncEngineContext> = Arc::new(context.clone());
417
418        if !request.stream.unwrap_or(false) {
419            return Err(HttpClientError::InvalidRequest(
420                "chat_stream requires the request to have 'stream': true".to_string(),
421            ));
422        }
423
424        if self.base.is_verbose() {
425            tracing::info!(
426                "Starting pure OpenAI chat stream for request {}",
427                context.id()
428            );
429        }
430
431        // Create the stream with cancellation support
432        let stream = self
433            .base
434            .client()
435            .chat()
436            .create_stream(request)
437            .await
438            .map_err(HttpClientError::OpenAI)?;
439
440        // TODO: In Phase 3, we'll add cancellation integration with reqwest
441        // For now, return the stream as-is
442        Ok(HttpResponseStream::new(stream, ctx_arc))
443    }
444}
445
446/// NV Custom client using NvCreateChatCompletionRequest with Annotated responses
447pub struct NvCustomClient {
448    base: BaseHttpClient,
449}
450
451impl NvCustomClient {
452    /// Create a new NV custom client
453    pub fn new(config: HttpClientConfig) -> Self {
454        Self {
455            base: BaseHttpClient::new(config),
456        }
457    }
458
459    /// Create streaming chat completions using NV custom types
460    /// Uses a client-managed context
461    pub async fn chat_stream(
462        &self,
463        request: NvCreateChatCompletionRequest,
464    ) -> Result<NvHttpResponseStream, HttpClientError> {
465        let ctx = self.base.create_context();
466        self.chat_stream_with_context(request, ctx).await
467    }
468
469    /// Create streaming chat completions with a custom context
470    pub async fn chat_stream_with_context(
471        &self,
472        request: NvCreateChatCompletionRequest,
473        context: HttpRequestContext,
474    ) -> Result<NvHttpResponseStream, HttpClientError> {
475        let ctx_arc: Arc<dyn AsyncEngineContext> = Arc::new(context.clone());
476
477        if !request.inner.stream.unwrap_or(false) {
478            return Err(HttpClientError::InvalidRequest(
479                "chat_stream requires the request to have 'stream': true".to_string(),
480            ));
481        }
482
483        if self.base.is_verbose() {
484            tracing::info!(
485                "Starting NV custom chat stream for request {}",
486                context.id()
487            );
488        }
489
490        // Use BYOT feature to send NvCreateChatCompletionRequest
491        // The stream type is explicitly specified to deserialize directly into Annotated<NvCreateChatCompletionStreamResponse>
492        let stream = self
493            .base
494            .client()
495            .chat()
496            .create_stream_byot(request)
497            .await
498            .map_err(HttpClientError::OpenAI)?;
499
500        Ok(HttpResponseStream::new(stream, ctx_arc))
501    }
502}
503
504/// Generic BYOT client using serde_json::Value for maximum flexibility
505pub struct GenericBYOTClient {
506    base: BaseHttpClient,
507}
508
509impl GenericBYOTClient {
510    /// Create a new generic BYOT client
511    pub fn new(config: HttpClientConfig) -> Self {
512        Self {
513            base: BaseHttpClient::new(config),
514        }
515    }
516
517    /// Create streaming chat completions using arbitrary JSON values
518    /// Uses a client-managed context
519    pub async fn chat_stream(
520        &self,
521        request: Value,
522    ) -> Result<ByotHttpResponseStream, HttpClientError> {
523        let ctx = self.base.create_context();
524        self.chat_stream_with_context(request, ctx).await
525    }
526
527    /// Create streaming chat completions with a custom context
528    pub async fn chat_stream_with_context(
529        &self,
530        request: Value,
531        context: HttpRequestContext,
532    ) -> Result<ByotHttpResponseStream, HttpClientError> {
533        let ctx_arc: Arc<dyn AsyncEngineContext> = Arc::new(context.clone());
534
535        if self.base.is_verbose() {
536            tracing::info!(
537                "Starting generic BYOT chat stream for request {}",
538                context.id()
539            );
540        }
541
542        // Validate that the request has stream: true
543        if let Some(stream_val) = request.get("stream") {
544            if !stream_val.as_bool().unwrap_or(false) {
545                return Err(HttpClientError::InvalidRequest(
546                    "Request must have 'stream': true for streaming".to_string(),
547                ));
548            }
549        } else {
550            return Err(HttpClientError::InvalidRequest(
551                "Request must include 'stream' field".to_string(),
552            ));
553        }
554
555        // Use BYOT feature with raw JSON
556        // The stream type is explicitly specified to deserialize directly into serde_json::Value
557        let stream = self
558            .base
559            .client()
560            .chat()
561            .create_stream_byot(request)
562            .await
563            .map_err(HttpClientError::OpenAI)?;
564
565        Ok(HttpResponseStream::new(stream, ctx_arc))
566    }
567}
568
569// TODO: Implement recording integration in Phase 3:
570// - Recording wrapper functions
571// - Capacity hints from request parameters
572// - Integration with existing recording infrastructure
573
// Unit tests for the context hierarchy and client construction. The async
// cancellation tests depend on the exact order of spawn/sleep/stop, so the
// bodies are left as-is; comments only.
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{Duration, sleep};

    // A fresh context has a non-empty random ID and is neither stopped nor killed.
    #[tokio::test]
    async fn test_http_request_context_creation() {
        let ctx = HttpRequestContext::new();
        assert!(!ctx.id().is_empty());
        assert!(!ctx.is_stopped());
        assert!(!ctx.is_killed());
    }

    // Parent cancellation propagates to the child via the child token.
    #[tokio::test]
    async fn test_http_request_context_child() {
        let parent = HttpRequestContext::new();
        let child = parent.child();

        // Child should have different ID
        assert_ne!(parent.id(), child.id());

        // Child should not be stopped initially
        assert!(!child.is_stopped());

        // When parent is stopped, child should be cancelled via token
        // (note: the child's `stopped` flag is NOT set — only its token fires)
        parent.stop();
        assert!(parent.is_stopped());
        assert!(child.cancellation_token().is_cancelled());
    }

    // Caller-supplied child IDs are preserved; cancellation still cascades.
    #[tokio::test]
    async fn test_http_request_context_child_with_id() {
        let parent = HttpRequestContext::new();
        let child_id = "test-child";
        let child = parent.child_with_id(child_id.to_string());

        assert_eq!(child.id(), child_id);
        assert!(!child.is_stopped());

        // Test hierarchical cancellation
        parent.stop();
        assert!(child.cancellation_token().is_cancelled());
    }

    // stop() sets the stopped flag and fires the cancellation token.
    #[tokio::test]
    async fn test_http_request_context_cancellation() {
        let ctx = HttpRequestContext::new();
        let cancel_token = ctx.cancellation_token();

        // Test stop functionality
        assert!(!ctx.is_stopped());
        ctx.stop();
        assert!(ctx.is_stopped());
        assert!(cancel_token.is_cancelled());
    }

    // kill() is equivalent to stop() for HTTP contexts: both flags read true.
    #[tokio::test]
    async fn test_http_request_context_kill() {
        let ctx = HttpRequestContext::new();

        // Test kill functionality
        assert!(!ctx.is_killed());
        ctx.kill();
        assert!(ctx.is_killed());
        assert!(ctx.is_stopped());
    }

    // `stopped().await` resolves once stop() fires the token on a clone
    // (clones share the same underlying token/flag state).
    #[tokio::test]
    async fn test_http_request_context_async_cancellation() {
        let ctx = HttpRequestContext::new();

        // Test async cancellation
        let ctx_clone = ctx.clone();
        let task = tokio::spawn(async move {
            ctx_clone.stopped().await;
        });

        // Give a moment for the task to start waiting
        sleep(Duration::from_millis(10)).await;

        // Cancel the context
        ctx.stop();

        // The task should complete
        task.await.unwrap();
    }

    // Default config yields a non-verbose client with a usable root context.
    #[test]
    fn test_base_http_client_creation() {
        let config = HttpClientConfig::default();
        let client = BaseHttpClient::new(config);
        assert!(!client.is_verbose());

        // Test that client has a root context
        assert!(!client.root_context().id().is_empty());
    }

    // Contexts minted by the client are distinct and parented to the root.
    #[test]
    fn test_base_http_client_context_creation() {
        let config = HttpClientConfig::default();
        let client = BaseHttpClient::new(config);

        // Test creating child contexts
        let ctx1 = client.create_context();
        let ctx2 = client.create_context();

        // Should have different IDs
        assert_ne!(ctx1.id(), ctx2.id());

        // Should be children of root context
        client.root_context().stop();
        assert!(ctx1.cancellation_token().is_cancelled());
        assert!(ctx2.cancellation_token().is_cancelled());
    }

    // Caller-chosen IDs survive; root cancellation still reaches the child.
    #[test]
    fn test_base_http_client_context_with_id() {
        let config = HttpClientConfig::default();
        let client = BaseHttpClient::new(config);

        let custom_id = "custom-request-id";
        let ctx = client.create_context_with_id(custom_id.to_string());

        assert_eq!(ctx.id(), custom_id);

        // Should still be child of root
        client.root_context().stop();
        assert!(ctx.cancellation_token().is_cancelled());
    }

    // Default configuration disables verbose logging.
    #[test]
    fn test_http_client_config_defaults() {
        let config = HttpClientConfig::default();
        assert!(!config.verbose);
    }

    // Smoke test: construction of each client flavor does not panic.
    #[test]
    fn test_pure_openai_client_creation() {
        let config = HttpClientConfig::default();
        let _client = PureOpenAIClient::new(config);
        // If we get here, creation succeeded
    }

    #[test]
    fn test_nv_custom_client_creation() {
        let config = HttpClientConfig::default();
        let _client = NvCustomClient::new(config);
        // If we get here, creation succeeded
    }

    #[test]
    fn test_generic_byot_client_creation() {
        let config = HttpClientConfig::default();
        let _client = GenericBYOTClient::new(config);
        // If we get here, creation succeeded
    }
}