1use std::pin::Pin;
11use std::sync::{Arc, Mutex};
12use std::task::{Context, Poll};
13use std::time::Instant;
14
15use async_trait::async_trait;
16use derive_getters::Dissolve;
17use dynamo_async_openai::{Client, config::OpenAIConfig, error::OpenAIError};
18use futures::Stream;
19use serde_json::Value;
20use tokio_util::sync::CancellationToken;
21use tracing;
22use uuid::Uuid;
23
24use crate::protocols::Annotated;
26use crate::protocols::openai::chat_completions::{
27 NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
28};
29use dynamo_runtime::engine::{
30 AsyncEngineContext, AsyncEngineContextProvider, AsyncEngineStream, Data, DataStream,
31};
32
/// Configuration shared by the HTTP clients in this module.
#[derive(Clone, Default)]
pub struct HttpClientConfig {
    /// Connection settings (base URL, API key, …) handed to the underlying
    /// `dynamo_async_openai` client.
    pub openai_config: OpenAIConfig,
    /// When true, emit an info-level trace line when each request starts.
    pub verbose: bool,
}
41
/// Errors produced by the HTTP clients in this module.
#[derive(Debug, thiserror::Error)]
pub enum HttpClientError {
    /// Failure reported by the underlying OpenAI-compatible client.
    #[error("OpenAI API error: {0}")]
    OpenAI(#[from] OpenAIError),
    /// The request did not complete in time.
    /// NOTE(review): not constructed anywhere in this file — presumably used
    /// by callers elsewhere; confirm before removing.
    #[error("Request timeout")]
    Timeout,
    /// The request's context was cancelled before completion.
    /// NOTE(review): also not constructed in this file.
    #[error("Request cancelled")]
    Cancelled,
    /// The request failed client-side validation (e.g. missing `"stream": true`).
    #[error("Invalid request: {0}")]
    InvalidRequest(String),
}
54
/// Identity and cancellation state for one HTTP request.
///
/// Cloning is shallow: clones share the same cancellation token, stop flag,
/// and linked-children list.
#[derive(Clone)]
pub struct HttpRequestContext {
    // Unique request id (a UUID string unless supplied by the caller).
    id: String,
    // Cancelled by stop()/stop_generating()/kill(); children derive
    // child tokens from it, so cancelling a parent cancels its children.
    cancel_token: CancellationToken,
    // Creation timestamp backing `elapsed()`.
    created_at: Instant,
    // Set once any of stop()/stop_generating()/kill() has run; shared
    // across clones via the Arc.
    stopped: Arc<std::sync::atomic::AtomicBool>,
    // Contexts registered via `link_child`; stop/kill calls are propagated
    // to each of them before this context is terminated.
    child_context: Arc<Mutex<Vec<Arc<dyn AsyncEngineContext>>>>,
}
70
71impl HttpRequestContext {
72 pub fn new() -> Self {
74 Self {
75 id: Uuid::new_v4().to_string(),
76 cancel_token: CancellationToken::new(),
77 created_at: Instant::now(),
78 stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
79 child_context: Arc::new(Mutex::new(Vec::new())),
80 }
81 }
82
83 pub fn with_id(id: String) -> Self {
85 Self {
86 id,
87 cancel_token: CancellationToken::new(),
88 created_at: Instant::now(),
89 stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
90 child_context: Arc::new(Mutex::new(Vec::new())),
91 }
92 }
93
94 pub fn child(&self) -> Self {
97 Self {
98 id: Uuid::new_v4().to_string(),
99 cancel_token: self.cancel_token.child_token(),
100 created_at: Instant::now(),
101 stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
102 child_context: Arc::new(Mutex::new(Vec::new())),
103 }
104 }
105
106 pub fn child_with_id(&self, id: String) -> Self {
108 Self {
109 id,
110 cancel_token: self.cancel_token.child_token(),
111 created_at: Instant::now(),
112 stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
113 child_context: Arc::new(Mutex::new(Vec::new())),
114 }
115 }
116
117 pub fn cancellation_token(&self) -> CancellationToken {
119 self.cancel_token.clone()
120 }
121
122 pub fn elapsed(&self) -> std::time::Duration {
124 self.created_at.elapsed()
125 }
126}
127
128impl Default for HttpRequestContext {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
// Manual Debug: the struct cannot derive it because `child_context` holds
// trait objects; show the id, age, and current stop/kill/cancel state instead.
impl std::fmt::Debug for HttpRequestContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpRequestContext")
            .field("id", &self.id)
            .field("created_at", &self.created_at)
            .field("is_stopped", &self.is_stopped())
            .field("is_killed", &self.is_killed())
            .field("is_cancelled", &self.cancel_token.is_cancelled())
            .finish()
    }
}
145
146#[async_trait]
147impl AsyncEngineContext for HttpRequestContext {
148 fn id(&self) -> &str {
149 &self.id
150 }
151
152 fn stop(&self) {
153 let children = self
155 .child_context
156 .lock()
157 .expect("Failed to lock child context")
158 .iter()
159 .cloned()
160 .collect::<Vec<_>>();
161 for child in children {
162 child.stop();
163 }
164
165 self.stopped
166 .store(true, std::sync::atomic::Ordering::Release);
167 self.cancel_token.cancel();
168 }
169
170 fn stop_generating(&self) {
171 let children = self
173 .child_context
174 .lock()
175 .expect("Failed to lock child context")
176 .iter()
177 .cloned()
178 .collect::<Vec<_>>();
179 for child in children {
180 child.stop_generating();
181 }
182
183 self.stopped
185 .store(true, std::sync::atomic::Ordering::Release);
186 self.cancel_token.cancel();
187 }
188
189 fn kill(&self) {
190 let children = self
192 .child_context
193 .lock()
194 .expect("Failed to lock child context")
195 .iter()
196 .cloned()
197 .collect::<Vec<_>>();
198 for child in children {
199 child.kill();
200 }
201
202 self.stopped
203 .store(true, std::sync::atomic::Ordering::Release);
204 self.cancel_token.cancel();
205 }
206
207 fn is_stopped(&self) -> bool {
208 self.stopped.load(std::sync::atomic::Ordering::Acquire)
209 }
210
211 fn is_killed(&self) -> bool {
212 self.stopped.load(std::sync::atomic::Ordering::Acquire)
213 }
214
215 async fn stopped(&self) {
216 self.cancel_token.cancelled().await;
217 }
218
219 async fn killed(&self) {
220 self.cancel_token.cancelled().await;
222 }
223
224 fn link_child(&self, child: Arc<dyn AsyncEngineContext>) {
225 self.child_context
226 .lock()
227 .expect("Failed to lock child context")
228 .push(child);
229 }
230}
231
/// Shared plumbing for the concrete clients below: the configured OpenAI
/// client plus a root context from which per-request contexts are derived.
pub struct BaseHttpClient {
    // Configured dynamo_async_openai client used for every request.
    client: Client<OpenAIConfig>,
    // Original configuration, kept for flags such as `verbose`.
    config: HttpClientConfig,
    // Root context; stopping it cancels all derived per-request contexts.
    root_context: HttpRequestContext,
}
241
242impl BaseHttpClient {
243 pub fn new(config: HttpClientConfig) -> Self {
245 let client = Client::with_config(config.openai_config.clone());
246 Self {
247 client,
248 config,
249 root_context: HttpRequestContext::new(),
250 }
251 }
252
253 pub fn client(&self) -> &Client<OpenAIConfig> {
255 &self.client
256 }
257
258 pub fn create_context(&self) -> HttpRequestContext {
260 self.root_context.child()
261 }
262
263 pub fn create_context_with_id(&self, id: String) -> HttpRequestContext {
265 self.root_context.child_with_id(id)
266 }
267
268 pub fn root_context(&self) -> &HttpRequestContext {
270 &self.root_context
271 }
272
273 pub fn is_verbose(&self) -> bool {
275 self.config.verbose
276 }
277}
278
/// Stream of NVIDIA-extended chat completion chunks (annotated).
pub type NvChatResponseStream =
    DataStream<Result<Annotated<NvCreateChatCompletionStreamResponse>, OpenAIError>>;

/// Stream of raw JSON chunks for bring-your-own-type requests.
pub type ByotResponseStream = DataStream<Result<Value, OpenAIError>>;

/// Stream of plain OpenAI chat completion chunks.
pub type OpenAIChatResponseStream =
    DataStream<Result<dynamo_async_openai::types::CreateChatCompletionStreamResponse, OpenAIError>>;
289
/// A response stream paired with the engine context that controls it.
#[derive(Dissolve)]
pub struct HttpResponseStream<T> {
    /// The underlying item stream.
    pub stream: DataStream<T>,
    /// Context used to cancel/observe the request producing the stream.
    pub context: Arc<dyn AsyncEngineContext>,
}
299
300impl<T> HttpResponseStream<T> {
301 pub fn new(stream: DataStream<T>, context: Arc<dyn AsyncEngineContext>) -> Self {
303 Self { stream, context }
304 }
305}
306
impl<T: Data> Stream for HttpResponseStream<T> {
    type Item = T;

    // Delegate polling to the inner stream. NOTE(review): `Pin::new` here
    // requires `DataStream<T>` to be `Unpin` (true if it is a boxed stream)
    // — confirm against its definition in dynamo_runtime.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.stream).poll_next(cx)
    }
}
314
impl<T: Data> AsyncEngineContextProvider for HttpResponseStream<T> {
    // Hand out a shared handle to the request's context (Arc clone, cheap).
    fn context(&self) -> Arc<dyn AsyncEngineContext> {
        self.context.clone()
    }
}
320
321impl<T: Data> HttpResponseStream<T> {
322 pub fn into_async_engine_stream(self) -> Pin<Box<dyn AsyncEngineStream<T>>>
325 where
326 T: 'static,
327 {
328 Box::pin(AsyncEngineStreamWrapper {
331 stream: self.stream,
332 context: self.context,
333 })
334 }
335}
336
// Private adapter that lets an HttpResponseStream be exposed as a
// `dyn AsyncEngineStream` (stream + context provider in one object).
struct AsyncEngineStreamWrapper<T> {
    // The wrapped item stream.
    stream: DataStream<T>,
    // Context returned by the AsyncEngineContextProvider impl below.
    context: Arc<dyn AsyncEngineContext>,
}
342
impl<T: Data> Stream for AsyncEngineStreamWrapper<T> {
    type Item = T;

    // Pure delegation to the wrapped stream; see the NOTE on the
    // HttpResponseStream Stream impl about the `Unpin` requirement.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.stream).poll_next(cx)
    }
}
350
impl<T: Data> AsyncEngineContextProvider for AsyncEngineStreamWrapper<T> {
    // Cheap Arc clone of the stored context.
    fn context(&self) -> Arc<dyn AsyncEngineContext> {
        self.context.clone()
    }
}
356
// Marker impl: AsyncEngineStream is satisfied by Stream + context provider above.
impl<T: Data> AsyncEngineStream<T> for AsyncEngineStreamWrapper<T> {}
358
// Manual Debug: the inner stream is a trait object, so only the context
// (which is Debug) is shown.
impl<T> std::fmt::Debug for AsyncEngineStreamWrapper<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AsyncEngineStreamWrapper")
            .field("context", &self.context)
            .finish()
    }
}
366
// Manual Debug mirroring AsyncEngineStreamWrapper: the stream itself is not
// Debug, so only the context is printed.
impl<T: Data> std::fmt::Debug for HttpResponseStream<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpResponseStream")
            .field("context", &self.context)
            .finish()
    }
}
374
/// NVIDIA-extended chat stream plus its controlling context.
pub type NvHttpResponseStream =
    HttpResponseStream<Result<Annotated<NvCreateChatCompletionStreamResponse>, OpenAIError>>;

/// Raw-JSON (BYOT) chat stream plus its controlling context.
pub type ByotHttpResponseStream = HttpResponseStream<Result<Value, OpenAIError>>;

/// Plain OpenAI chat stream plus its controlling context.
pub type OpenAIHttpResponseStream = HttpResponseStream<
    Result<dynamo_async_openai::types::CreateChatCompletionStreamResponse, OpenAIError>,
>;
386
/// Client for the plain (unextended) OpenAI chat-completions API.
pub struct PureOpenAIClient {
    // Shared client/context plumbing.
    base: BaseHttpClient,
}
391
392impl PureOpenAIClient {
393 pub fn new(config: HttpClientConfig) -> Self {
395 Self {
396 base: BaseHttpClient::new(config),
397 }
398 }
399
400 pub async fn chat_stream(
403 &self,
404 request: dynamo_async_openai::types::CreateChatCompletionRequest,
405 ) -> Result<OpenAIHttpResponseStream, HttpClientError> {
406 let ctx = self.base.create_context();
407 self.chat_stream_with_context(request, ctx).await
408 }
409
410 pub async fn chat_stream_with_context(
412 &self,
413 request: dynamo_async_openai::types::CreateChatCompletionRequest,
414 context: HttpRequestContext,
415 ) -> Result<OpenAIHttpResponseStream, HttpClientError> {
416 let ctx_arc: Arc<dyn AsyncEngineContext> = Arc::new(context.clone());
417
418 if !request.stream.unwrap_or(false) {
419 return Err(HttpClientError::InvalidRequest(
420 "chat_stream requires the request to have 'stream': true".to_string(),
421 ));
422 }
423
424 if self.base.is_verbose() {
425 tracing::info!(
426 "Starting pure OpenAI chat stream for request {}",
427 context.id()
428 );
429 }
430
431 let stream = self
433 .base
434 .client()
435 .chat()
436 .create_stream(request)
437 .await
438 .map_err(HttpClientError::OpenAI)?;
439
440 Ok(HttpResponseStream::new(stream, ctx_arc))
443 }
444}
445
/// Client for the NVIDIA-extended chat-completions API.
pub struct NvCustomClient {
    // Shared client/context plumbing.
    base: BaseHttpClient,
}
450
451impl NvCustomClient {
452 pub fn new(config: HttpClientConfig) -> Self {
454 Self {
455 base: BaseHttpClient::new(config),
456 }
457 }
458
459 pub async fn chat_stream(
462 &self,
463 request: NvCreateChatCompletionRequest,
464 ) -> Result<NvHttpResponseStream, HttpClientError> {
465 let ctx = self.base.create_context();
466 self.chat_stream_with_context(request, ctx).await
467 }
468
469 pub async fn chat_stream_with_context(
471 &self,
472 request: NvCreateChatCompletionRequest,
473 context: HttpRequestContext,
474 ) -> Result<NvHttpResponseStream, HttpClientError> {
475 let ctx_arc: Arc<dyn AsyncEngineContext> = Arc::new(context.clone());
476
477 if !request.inner.stream.unwrap_or(false) {
478 return Err(HttpClientError::InvalidRequest(
479 "chat_stream requires the request to have 'stream': true".to_string(),
480 ));
481 }
482
483 if self.base.is_verbose() {
484 tracing::info!(
485 "Starting NV custom chat stream for request {}",
486 context.id()
487 );
488 }
489
490 let stream = self
493 .base
494 .client()
495 .chat()
496 .create_stream_byot(request)
497 .await
498 .map_err(HttpClientError::OpenAI)?;
499
500 Ok(HttpResponseStream::new(stream, ctx_arc))
501 }
502}
503
/// Client that forwards arbitrary JSON ("bring your own type") chat requests.
pub struct GenericBYOTClient {
    // Shared client/context plumbing.
    base: BaseHttpClient,
}
508
509impl GenericBYOTClient {
510 pub fn new(config: HttpClientConfig) -> Self {
512 Self {
513 base: BaseHttpClient::new(config),
514 }
515 }
516
517 pub async fn chat_stream(
520 &self,
521 request: Value,
522 ) -> Result<ByotHttpResponseStream, HttpClientError> {
523 let ctx = self.base.create_context();
524 self.chat_stream_with_context(request, ctx).await
525 }
526
527 pub async fn chat_stream_with_context(
529 &self,
530 request: Value,
531 context: HttpRequestContext,
532 ) -> Result<ByotHttpResponseStream, HttpClientError> {
533 let ctx_arc: Arc<dyn AsyncEngineContext> = Arc::new(context.clone());
534
535 if self.base.is_verbose() {
536 tracing::info!(
537 "Starting generic BYOT chat stream for request {}",
538 context.id()
539 );
540 }
541
542 if let Some(stream_val) = request.get("stream") {
544 if !stream_val.as_bool().unwrap_or(false) {
545 return Err(HttpClientError::InvalidRequest(
546 "Request must have 'stream': true for streaming".to_string(),
547 ));
548 }
549 } else {
550 return Err(HttpClientError::InvalidRequest(
551 "Request must include 'stream' field".to_string(),
552 ));
553 }
554
555 let stream = self
558 .base
559 .client()
560 .chat()
561 .create_stream_byot(request)
562 .await
563 .map_err(HttpClientError::OpenAI)?;
564
565 Ok(HttpResponseStream::new(stream, ctx_arc))
566 }
567}
568
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{Duration, sleep};

    // A fresh context has a non-empty UUID id and is neither stopped nor killed.
    #[tokio::test]
    async fn test_http_request_context_creation() {
        let ctx = HttpRequestContext::new();
        assert!(!ctx.id().is_empty());
        assert!(!ctx.is_stopped());
        assert!(!ctx.is_killed());
    }

    // Stopping a parent cancels the child's token; ids are independent.
    #[tokio::test]
    async fn test_http_request_context_child() {
        let parent = HttpRequestContext::new();
        let child = parent.child();

        assert_ne!(parent.id(), child.id());

        assert!(!child.is_stopped());

        parent.stop();
        assert!(parent.is_stopped());
        assert!(child.cancellation_token().is_cancelled());
    }

    // child_with_id preserves the supplied id and still links cancellation.
    #[tokio::test]
    async fn test_http_request_context_child_with_id() {
        let parent = HttpRequestContext::new();
        let child_id = "test-child";
        let child = parent.child_with_id(child_id.to_string());

        assert_eq!(child.id(), child_id);
        assert!(!child.is_stopped());

        parent.stop();
        assert!(child.cancellation_token().is_cancelled());
    }

    // stop() sets the stopped flag and cancels the token.
    #[tokio::test]
    async fn test_http_request_context_cancellation() {
        let ctx = HttpRequestContext::new();
        let cancel_token = ctx.cancellation_token();

        assert!(!ctx.is_stopped());
        ctx.stop();
        assert!(ctx.is_stopped());
        assert!(cancel_token.is_cancelled());
    }

    // kill() marks the context both killed and stopped (single shared flag).
    #[tokio::test]
    async fn test_http_request_context_kill() {
        let ctx = HttpRequestContext::new();

        assert!(!ctx.is_killed());
        ctx.kill();
        assert!(ctx.is_killed());
        assert!(ctx.is_stopped());
    }

    // An awaiting task is woken when stop() fires the cancellation token.
    #[tokio::test]
    async fn test_http_request_context_async_cancellation() {
        let ctx = HttpRequestContext::new();

        let ctx_clone = ctx.clone();
        let task = tokio::spawn(async move {
            ctx_clone.stopped().await;
        });

        // Give the spawned task a chance to start awaiting before stopping.
        sleep(Duration::from_millis(10)).await;

        ctx.stop();

        task.await.unwrap();
    }

    // Default config yields a non-verbose client with a usable root context.
    #[test]
    fn test_base_http_client_creation() {
        let config = HttpClientConfig::default();
        let client = BaseHttpClient::new(config);
        assert!(!client.is_verbose());

        assert!(!client.root_context().id().is_empty());
    }

    // Contexts from the same client are distinct but share root cancellation.
    #[test]
    fn test_base_http_client_context_creation() {
        let config = HttpClientConfig::default();
        let client = BaseHttpClient::new(config);

        let ctx1 = client.create_context();
        let ctx2 = client.create_context();

        assert_ne!(ctx1.id(), ctx2.id());

        client.root_context().stop();
        assert!(ctx1.cancellation_token().is_cancelled());
        assert!(ctx2.cancellation_token().is_cancelled());
    }

    // create_context_with_id keeps the caller's id and root cancellation link.
    #[test]
    fn test_base_http_client_context_with_id() {
        let config = HttpClientConfig::default();
        let client = BaseHttpClient::new(config);

        let custom_id = "custom-request-id";
        let ctx = client.create_context_with_id(custom_id.to_string());

        assert_eq!(ctx.id(), custom_id);

        client.root_context().stop();
        assert!(ctx.cancellation_token().is_cancelled());
    }

    // Default config is non-verbose.
    #[test]
    fn test_http_client_config_defaults() {
        let config = HttpClientConfig::default();
        assert!(!config.verbose);
    }

    // Smoke test: constructing each client type does not panic.
    #[test]
    fn test_pure_openai_client_creation() {
        let config = HttpClientConfig::default();
        let _client = PureOpenAIClient::new(config);
    }

    #[test]
    fn test_nv_custom_client_creation() {
        let config = HttpClientConfig::default();
        let _client = NvCustomClient::new(config);
    }

    #[test]
    fn test_generic_byot_client_creation() {
        let config = HttpClientConfig::default();
        let _client = GenericBYOTClient::new(config);
    }
}