use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Instant;
use async_trait::async_trait;
use derive_getters::Dissolve;
use dynamo_async_openai::{Client, config::OpenAIConfig, error::OpenAIError};
use futures::Stream;
use serde_json::Value;
use tokio_util::sync::CancellationToken;
use tracing;
use uuid::Uuid;
use crate::protocols::Annotated;
use crate::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
};
use dynamo_runtime::engine::{
AsyncEngineContext, AsyncEngineContextProvider, AsyncEngineStream, Data, DataStream,
};
/// Configuration shared by all HTTP client flavors in this module.
#[derive(Clone, Default)]
pub struct HttpClientConfig {
    /// Connection settings (base URL, API key, ...) forwarded to the
    /// underlying `dynamo_async_openai` client.
    pub openai_config: OpenAIConfig,
    /// When true, clients emit a `tracing::info!` line per request start.
    pub verbose: bool,
}
/// Errors produced by the streaming chat clients below.
#[derive(Debug, thiserror::Error)]
pub enum HttpClientError {
    /// Failure surfaced by the underlying OpenAI-compatible client.
    #[error("OpenAI API error: {0}")]
    OpenAI(#[from] OpenAIError),
    /// Request exceeded its deadline.
    /// NOTE(review): no code in this file constructs this variant — presumably
    /// used by callers elsewhere; confirm before removing.
    #[error("Request timeout")]
    Timeout,
    /// Request was cancelled.
    /// NOTE(review): also not constructed in this file — verify usage.
    #[error("Request cancelled")]
    Cancelled,
    /// Request failed local validation (e.g. missing `'stream': true`).
    #[error("Invalid request: {0}")]
    InvalidRequest(String),
}
/// Per-request context: identity, cancellation, timing, and linked children.
///
/// Cloning is shallow — clones share the same `stopped` flag and child list
/// (both behind `Arc`), and the same cancellation token.
#[derive(Clone)]
pub struct HttpRequestContext {
    // Unique request id (UUID string unless caller-supplied).
    id: String,
    // Cancelling this token also cancels all child tokens derived from it.
    cancel_token: CancellationToken,
    // Creation time, used by `elapsed()`.
    created_at: Instant,
    // Set once any of stop/stop_generating/kill fires; shared across clones.
    stopped: Arc<std::sync::atomic::AtomicBool>,
    // Contexts registered via `link_child`; shutdown calls fan out to these.
    child_context: Arc<Mutex<Vec<Arc<dyn AsyncEngineContext>>>>,
}
impl HttpRequestContext {
    /// Shared constructor. All four public constructors funnel through here so
    /// the field initialization is written exactly once (previously it was
    /// duplicated in `new`, `with_id`, `child`, and `child_with_id`).
    fn build(id: String, cancel_token: CancellationToken) -> Self {
        Self {
            id,
            cancel_token,
            created_at: Instant::now(),
            stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            child_context: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Root context with a freshly generated UUID id and its own token.
    pub fn new() -> Self {
        Self::with_id(Uuid::new_v4().to_string())
    }

    /// Root context with a caller-supplied id and its own token.
    pub fn with_id(id: String) -> Self {
        Self::build(id, CancellationToken::new())
    }

    /// Child context with a generated id. Its token is derived from this
    /// context's token, so cancelling the parent cancels the child.
    pub fn child(&self) -> Self {
        self.child_with_id(Uuid::new_v4().to_string())
    }

    /// Child context with a caller-supplied id; token parentage as in `child`.
    pub fn child_with_id(&self, id: String) -> Self {
        Self::build(id, self.cancel_token.child_token())
    }

    /// Clone of this context's cancellation token.
    pub fn cancellation_token(&self) -> CancellationToken {
        self.cancel_token.clone()
    }

    /// Time elapsed since this context was created.
    pub fn elapsed(&self) -> std::time::Duration {
        self.created_at.elapsed()
    }
}
impl Default for HttpRequestContext {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for HttpRequestContext {
    /// Manual `Debug`: the child-context list and token internals are omitted;
    /// instead the derived states (stopped/killed/cancelled) are reported.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("HttpRequestContext");
        dbg.field("id", &self.id);
        dbg.field("created_at", &self.created_at);
        dbg.field("is_stopped", &self.is_stopped());
        dbg.field("is_killed", &self.is_killed());
        dbg.field("is_cancelled", &self.cancel_token.is_cancelled());
        dbg.finish()
    }
}
impl HttpRequestContext {
    /// Shared shutdown path for `stop`, `stop_generating`, and `kill`
    /// (previously three byte-identical bodies). Snapshots the child list so
    /// the mutex is NOT held while invoking `op` on each child (a child's
    /// callback could otherwise re-enter `link_child` and deadlock), then
    /// marks this context stopped and cancels its token.
    fn propagate_shutdown(&self, op: impl Fn(&dyn AsyncEngineContext)) {
        let children = self
            .child_context
            .lock()
            .expect("Failed to lock child context")
            .iter()
            .cloned()
            .collect::<Vec<_>>();
        for child in children {
            op(child.as_ref());
        }
        self.stopped
            .store(true, std::sync::atomic::Ordering::Release);
        self.cancel_token.cancel();
    }
}

#[async_trait]
impl AsyncEngineContext for HttpRequestContext {
    fn id(&self) -> &str {
        &self.id
    }

    /// Stops this context and recursively stops every linked child.
    fn stop(&self) {
        self.propagate_shutdown(|child| child.stop());
    }

    /// Like `stop`, but propagates `stop_generating` to children.
    fn stop_generating(&self) {
        self.propagate_shutdown(|child| child.stop_generating());
    }

    /// Like `stop`, but propagates `kill` to children.
    fn kill(&self) {
        self.propagate_shutdown(|child| child.kill());
    }

    fn is_stopped(&self) -> bool {
        self.stopped.load(std::sync::atomic::Ordering::Acquire)
    }

    // NOTE(review): stop and kill share the single `stopped` flag, so
    // `is_killed()` also returns true after a plain `stop()`. Confirm that
    // distinguishing the two is not required by the engine contract.
    fn is_killed(&self) -> bool {
        self.stopped.load(std::sync::atomic::Ordering::Acquire)
    }

    /// Resolves once the cancellation token fires (stop, kill, or parent
    /// cancellation).
    async fn stopped(&self) {
        self.cancel_token.cancelled().await;
    }

    /// Same wait condition as `stopped` — both await token cancellation.
    async fn killed(&self) {
        self.cancel_token.cancelled().await;
    }

    /// Registers a child so future shutdown calls fan out to it.
    fn link_child(&self, child: Arc<dyn AsyncEngineContext>) {
        self.child_context
            .lock()
            .expect("Failed to lock child context")
            .push(child);
    }
}
/// Shared plumbing for the concrete clients below: the HTTP client itself,
/// its configuration, and a root context that parents every request context.
pub struct BaseHttpClient {
    // OpenAI-compatible HTTP client built from `config.openai_config`.
    client: Client<OpenAIConfig>,
    // Retained so `is_verbose()` can be answered after construction.
    config: HttpClientConfig,
    // Cancelling this context cancels all contexts created via `create_context*`.
    root_context: HttpRequestContext,
}
impl BaseHttpClient {
    /// Builds a client from `config` and creates the root request context.
    pub fn new(config: HttpClientConfig) -> Self {
        Self {
            client: Client::with_config(config.openai_config.clone()),
            config,
            root_context: HttpRequestContext::new(),
        }
    }

    /// Borrow of the underlying OpenAI-compatible client.
    pub fn client(&self) -> &Client<OpenAIConfig> {
        &self.client
    }

    /// Fresh per-request context, parented (via token) to the root context.
    pub fn create_context(&self) -> HttpRequestContext {
        self.root_context.child()
    }

    /// Same as [`Self::create_context`] but with a caller-chosen id.
    pub fn create_context_with_id(&self, id: String) -> HttpRequestContext {
        self.root_context.child_with_id(id)
    }

    /// Root context; cancelling it cancels every derived request context.
    pub fn root_context(&self) -> &HttpRequestContext {
        &self.root_context
    }

    /// Whether per-request start logging is enabled.
    pub fn is_verbose(&self) -> bool {
        self.config.verbose
    }
}
/// Stream of NVIDIA-extended chat completion chunks (with annotations).
pub type NvChatResponseStream =
    DataStream<Result<Annotated<NvCreateChatCompletionStreamResponse>, OpenAIError>>;
/// Stream of raw JSON values ("bring your own types").
pub type ByotResponseStream = DataStream<Result<Value, OpenAIError>>;
/// Stream of stock OpenAI chat completion chunks.
pub type OpenAIChatResponseStream =
    DataStream<Result<dynamo_async_openai::types::CreateChatCompletionStreamResponse, OpenAIError>>;
/// A response stream paired with the engine context that governs it,
/// so callers can cancel/inspect the request while consuming items.
#[derive(Dissolve)]
pub struct HttpResponseStream<T> {
    // The item stream produced by the HTTP client.
    pub stream: DataStream<T>,
    // Context for the request that produced `stream`.
    pub context: Arc<dyn AsyncEngineContext>,
}

impl<T> HttpResponseStream<T> {
    /// Pairs `stream` with its governing `context`.
    pub fn new(stream: DataStream<T>, context: Arc<dyn AsyncEngineContext>) -> Self {
        Self { stream, context }
    }
}
// Delegate polling to the inner stream; the wrapper adds no buffering.
impl<T: Data> Stream for HttpResponseStream<T> {
    type Item = T;
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.stream).poll_next(cx)
    }
}

// Expose the request context so engine machinery can cancel/observe it.
impl<T: Data> AsyncEngineContextProvider for HttpResponseStream<T> {
    fn context(&self) -> Arc<dyn AsyncEngineContext> {
        self.context.clone()
    }
}
impl<T: Data> HttpResponseStream<T> {
    /// Converts this pair into a boxed [`AsyncEngineStream`], preserving both
    /// the item stream and its governing context.
    pub fn into_async_engine_stream(self) -> Pin<Box<dyn AsyncEngineStream<T>>>
    where
        T: 'static,
    {
        let Self { stream, context } = self;
        Box::pin(AsyncEngineStreamWrapper { stream, context })
    }
}
/// Private adapter that lets an `HttpResponseStream` be used where a
/// `dyn AsyncEngineStream<T>` is required.
struct AsyncEngineStreamWrapper<T> {
    // Item stream, polled through directly.
    stream: DataStream<T>,
    // Context reported by the provider impl below.
    context: Arc<dyn AsyncEngineContext>,
}

// Delegate polling to the inner stream.
impl<T: Data> Stream for AsyncEngineStreamWrapper<T> {
    type Item = T;
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.stream).poll_next(cx)
    }
}

impl<T: Data> AsyncEngineContextProvider for AsyncEngineStreamWrapper<T> {
    fn context(&self) -> Arc<dyn AsyncEngineContext> {
        self.context.clone()
    }
}

// Marker impl: AsyncEngineStream has no extra methods beyond its supertraits.
impl<T: Data> AsyncEngineStream<T> for AsyncEngineStreamWrapper<T> {}
// Manual Debug: the stream itself is not Debug, so only the context is shown.
impl<T> std::fmt::Debug for AsyncEngineStreamWrapper<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AsyncEngineStreamWrapper")
            .field("context", &self.context)
            .finish()
    }
}
// Manual Debug: the stream itself is not Debug, so only the context is shown.
impl<T: Data> std::fmt::Debug for HttpResponseStream<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HttpResponseStream")
            .field("context", &self.context)
            .finish()
    }
}
/// Context-carrying stream of NVIDIA-extended chat completion chunks.
pub type NvHttpResponseStream =
    HttpResponseStream<Result<Annotated<NvCreateChatCompletionStreamResponse>, OpenAIError>>;
/// Context-carrying stream of raw JSON chunks.
pub type ByotHttpResponseStream = HttpResponseStream<Result<Value, OpenAIError>>;
/// Context-carrying stream of stock OpenAI chat completion chunks.
pub type OpenAIHttpResponseStream = HttpResponseStream<
    Result<dynamo_async_openai::types::CreateChatCompletionStreamResponse, OpenAIError>,
>;
/// Client that speaks the stock OpenAI chat-completions protocol.
pub struct PureOpenAIClient {
    base: BaseHttpClient,
}

impl PureOpenAIClient {
    /// Builds the client from shared configuration.
    pub fn new(config: HttpClientConfig) -> Self {
        let base = BaseHttpClient::new(config);
        Self { base }
    }

    /// Streams a chat completion under a freshly created request context.
    pub async fn chat_stream(
        &self,
        request: dynamo_async_openai::types::CreateChatCompletionRequest,
    ) -> Result<OpenAIHttpResponseStream, HttpClientError> {
        let context = self.base.create_context();
        self.chat_stream_with_context(request, context).await
    }

    /// Streams a chat completion under a caller-provided context.
    ///
    /// # Errors
    /// [`HttpClientError::InvalidRequest`] unless the request has
    /// `stream == Some(true)`; [`HttpClientError::OpenAI`] on transport/API
    /// failure.
    pub async fn chat_stream_with_context(
        &self,
        request: dynamo_async_openai::types::CreateChatCompletionRequest,
        context: HttpRequestContext,
    ) -> Result<OpenAIHttpResponseStream, HttpClientError> {
        let ctx_arc: Arc<dyn AsyncEngineContext> = Arc::new(context.clone());

        // Only streaming requests are supported by this entry point.
        let wants_stream = request.stream.unwrap_or(false);
        if !wants_stream {
            return Err(HttpClientError::InvalidRequest(
                "chat_stream requires the request to have 'stream': true".to_string(),
            ));
        }

        if self.base.is_verbose() {
            tracing::info!(
                "Starting pure OpenAI chat stream for request {}",
                context.id()
            );
        }

        let chat_api = self.base.client().chat();
        let stream = chat_api
            .create_stream(request)
            .await
            .map_err(HttpClientError::OpenAI)?;
        Ok(HttpResponseStream::new(stream, ctx_arc))
    }
}
/// Client for the NVIDIA-extended chat-completions protocol.
pub struct NvCustomClient {
    base: BaseHttpClient,
}

impl NvCustomClient {
    /// Builds the client from shared configuration.
    pub fn new(config: HttpClientConfig) -> Self {
        let base = BaseHttpClient::new(config);
        Self { base }
    }

    /// Streams a chat completion under a freshly created request context.
    pub async fn chat_stream(
        &self,
        request: NvCreateChatCompletionRequest,
    ) -> Result<NvHttpResponseStream, HttpClientError> {
        let context = self.base.create_context();
        self.chat_stream_with_context(request, context).await
    }

    /// Streams a chat completion under a caller-provided context.
    ///
    /// # Errors
    /// [`HttpClientError::InvalidRequest`] unless the inner request has
    /// `stream == Some(true)`; [`HttpClientError::OpenAI`] on transport/API
    /// failure.
    pub async fn chat_stream_with_context(
        &self,
        request: NvCreateChatCompletionRequest,
        context: HttpRequestContext,
    ) -> Result<NvHttpResponseStream, HttpClientError> {
        let ctx_arc: Arc<dyn AsyncEngineContext> = Arc::new(context.clone());

        // Only streaming requests are supported by this entry point.
        let wants_stream = request.inner.stream.unwrap_or(false);
        if !wants_stream {
            return Err(HttpClientError::InvalidRequest(
                "chat_stream requires the request to have 'stream': true".to_string(),
            ));
        }

        if self.base.is_verbose() {
            tracing::info!(
                "Starting NV custom chat stream for request {}",
                context.id()
            );
        }

        let chat_api = self.base.client().chat();
        let stream = chat_api
            .create_stream_byot(request)
            .await
            .map_err(HttpClientError::OpenAI)?;
        Ok(HttpResponseStream::new(stream, ctx_arc))
    }
}
/// Client that sends raw JSON requests ("bring your own types").
pub struct GenericBYOTClient {
    base: BaseHttpClient,
}

impl GenericBYOTClient {
    /// Builds the client from shared configuration.
    pub fn new(config: HttpClientConfig) -> Self {
        Self {
            base: BaseHttpClient::new(config),
        }
    }

    /// Streams a chat completion under a freshly created request context.
    pub async fn chat_stream(
        &self,
        request: Value,
    ) -> Result<ByotHttpResponseStream, HttpClientError> {
        let ctx = self.base.create_context();
        self.chat_stream_with_context(request, ctx).await
    }

    /// Streams a chat completion under a caller-provided context.
    ///
    /// # Errors
    /// [`HttpClientError::InvalidRequest`] if the request lacks a `stream`
    /// field or it is not `true`; [`HttpClientError::OpenAI`] on
    /// transport/API failure.
    pub async fn chat_stream_with_context(
        &self,
        request: Value,
        context: HttpRequestContext,
    ) -> Result<ByotHttpResponseStream, HttpClientError> {
        // Validate before doing anything else, matching PureOpenAIClient and
        // NvCustomClient. (Previously the verbose "Starting ..." log fired
        // even for requests that were immediately rejected.)
        match request.get("stream") {
            Some(stream_val) if stream_val.as_bool().unwrap_or(false) => {}
            Some(_) => {
                return Err(HttpClientError::InvalidRequest(
                    "Request must have 'stream': true for streaming".to_string(),
                ));
            }
            None => {
                return Err(HttpClientError::InvalidRequest(
                    "Request must include 'stream' field".to_string(),
                ));
            }
        }

        let ctx_arc: Arc<dyn AsyncEngineContext> = Arc::new(context.clone());
        if self.base.is_verbose() {
            tracing::info!(
                "Starting generic BYOT chat stream for request {}",
                context.id()
            );
        }

        let stream = self
            .base
            .client()
            .chat()
            .create_stream_byot(request)
            .await
            .map_err(HttpClientError::OpenAI)?;
        Ok(HttpResponseStream::new(stream, ctx_arc))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{Duration, sleep};

    // A fresh context has a non-empty UUID id and no stop/kill state.
    #[tokio::test]
    async fn test_http_request_context_creation() {
        let ctx = HttpRequestContext::new();
        assert!(!ctx.id().is_empty());
        assert!(!ctx.is_stopped());
        assert!(!ctx.is_killed());
    }

    // A child gets its own id, and stopping the parent cancels the child's
    // token (token parentage), even without an explicit link_child.
    #[tokio::test]
    async fn test_http_request_context_child() {
        let parent = HttpRequestContext::new();
        let child = parent.child();
        assert_ne!(parent.id(), child.id());
        assert!(!child.is_stopped());
        parent.stop();
        assert!(parent.is_stopped());
        assert!(child.cancellation_token().is_cancelled());
    }

    // child_with_id keeps the caller-supplied id and still inherits
    // cancellation from the parent.
    #[tokio::test]
    async fn test_http_request_context_child_with_id() {
        let parent = HttpRequestContext::new();
        let child_id = "test-child";
        let child = parent.child_with_id(child_id.to_string());
        assert_eq!(child.id(), child_id);
        assert!(!child.is_stopped());
        parent.stop();
        assert!(child.cancellation_token().is_cancelled());
    }

    // stop() sets the stopped flag and cancels the token.
    #[tokio::test]
    async fn test_http_request_context_cancellation() {
        let ctx = HttpRequestContext::new();
        let cancel_token = ctx.cancellation_token();
        assert!(!ctx.is_stopped());
        ctx.stop();
        assert!(ctx.is_stopped());
        assert!(cancel_token.is_cancelled());
    }

    // kill() reports both killed and stopped (they share one flag).
    #[tokio::test]
    async fn test_http_request_context_kill() {
        let ctx = HttpRequestContext::new();
        assert!(!ctx.is_killed());
        ctx.kill();
        assert!(ctx.is_killed());
        assert!(ctx.is_stopped());
    }

    // The async stopped() future resolves once stop() fires on a clone
    // (clones share the same token).
    #[tokio::test]
    async fn test_http_request_context_async_cancellation() {
        let ctx = HttpRequestContext::new();
        let ctx_clone = ctx.clone();
        let task = tokio::spawn(async move {
            ctx_clone.stopped().await;
        });
        sleep(Duration::from_millis(10)).await;
        ctx.stop();
        task.await.unwrap();
    }

    // Default config yields a non-verbose client with a root context.
    #[test]
    fn test_base_http_client_creation() {
        let config = HttpClientConfig::default();
        let client = BaseHttpClient::new(config);
        assert!(!client.is_verbose());
        assert!(!client.root_context().id().is_empty());
    }

    // Contexts get distinct ids, and stopping the root cancels all of them.
    #[test]
    fn test_base_http_client_context_creation() {
        let config = HttpClientConfig::default();
        let client = BaseHttpClient::new(config);
        let ctx1 = client.create_context();
        let ctx2 = client.create_context();
        assert_ne!(ctx1.id(), ctx2.id());
        client.root_context().stop();
        assert!(ctx1.cancellation_token().is_cancelled());
        assert!(ctx2.cancellation_token().is_cancelled());
    }

    // Caller-supplied ids are preserved; root cancellation still propagates.
    #[test]
    fn test_base_http_client_context_with_id() {
        let config = HttpClientConfig::default();
        let client = BaseHttpClient::new(config);
        let custom_id = "custom-request-id";
        let ctx = client.create_context_with_id(custom_id.to_string());
        assert_eq!(ctx.id(), custom_id);
        client.root_context().stop();
        assert!(ctx.cancellation_token().is_cancelled());
    }

    // Default config is non-verbose.
    #[test]
    fn test_http_client_config_defaults() {
        let config = HttpClientConfig::default();
        assert!(!config.verbose);
    }

    // Smoke tests: each client flavor constructs from a default config.
    #[test]
    fn test_pure_openai_client_creation() {
        let config = HttpClientConfig::default();
        let _client = PureOpenAIClient::new(config);
    }

    #[test]
    fn test_nv_custom_client_creation() {
        let config = HttpClientConfig::default();
        let _client = NvCustomClient::new(config);
    }

    #[test]
    fn test_generic_byot_client_creation() {
        let config = HttpClientConfig::default();
        let _client = GenericBYOTClient::new(config);
    }
}