use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
use crate::protocols::{
codec::{Message, SseCodecError},
convert_sse_stream, Annotated,
};
use futures::{Stream, StreamExt};
use std::{collections::HashMap, pin::Pin};
type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
pub struct DeltaAggregator {
id: String,
model: String,
created: u32,
usage: Option<async_openai::types::CompletionUsage>,
system_fingerprint: Option<String>,
choices: HashMap<u32, DeltaChoice>,
error: Option<String>,
service_tier: Option<async_openai::types::ServiceTierResponse>,
}
struct DeltaChoice {
index: u32,
text: String,
role: Option<async_openai::types::Role>,
finish_reason: Option<async_openai::types::FinishReason>,
logprobs: Option<async_openai::types::ChatChoiceLogprobs>,
}
impl Default for DeltaAggregator {
fn default() -> Self {
Self::new()
}
}
impl DeltaAggregator {
pub fn new() -> Self {
Self {
id: "".to_string(),
model: "".to_string(),
created: 0,
usage: None,
system_fingerprint: None,
choices: HashMap::new(),
error: None,
service_tier: None,
}
}
pub async fn apply(
stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
) -> Result<NvCreateChatCompletionResponse, String> {
let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
let delta = match delta.ok() {
Ok(delta) => delta,
Err(error) => {
aggregator.error = Some(error);
return aggregator;
}
};
if aggregator.error.is_none() && delta.data.is_some() {
let delta = delta.data.unwrap();
aggregator.id = delta.inner.id;
aggregator.model = delta.inner.model;
aggregator.created = delta.inner.created;
aggregator.service_tier = delta.inner.service_tier;
if let Some(usage) = delta.inner.usage {
aggregator.usage = Some(usage);
}
if let Some(system_fingerprint) = delta.inner.system_fingerprint {
aggregator.system_fingerprint = Some(system_fingerprint);
}
for choice in delta.inner.choices {
let state_choice =
aggregator
.choices
.entry(choice.index)
.or_insert(DeltaChoice {
index: choice.index,
text: "".to_string(),
role: choice.delta.role,
finish_reason: None,
logprobs: choice.logprobs,
});
if let Some(content) = &choice.delta.content {
state_choice.text.push_str(content);
}
if let Some(finish_reason) = choice.finish_reason {
state_choice.finish_reason = Some(finish_reason);
}
}
}
aggregator
})
.await;
let aggregator = if let Some(error) = aggregator.error {
return Err(error);
} else {
aggregator
};
let mut choices: Vec<_> = aggregator
.choices
.into_values()
.map(async_openai::types::ChatChoice::from)
.collect();
choices.sort_by(|a, b| a.index.cmp(&b.index));
let inner = async_openai::types::CreateChatCompletionResponse {
id: aggregator.id,
created: aggregator.created,
usage: aggregator.usage,
model: aggregator.model,
object: "chat.completion".to_string(),
system_fingerprint: aggregator.system_fingerprint,
choices,
service_tier: aggregator.service_tier,
};
let response = NvCreateChatCompletionResponse { inner };
Ok(response)
}
}
#[allow(deprecated)]
impl From<DeltaChoice> for async_openai::types::ChatChoice {
fn from(delta: DeltaChoice) -> Self {
async_openai::types::ChatChoice {
message: async_openai::types::ChatCompletionResponseMessage {
role: delta.role.expect("delta should have a Role"),
content: Some(delta.text),
tool_calls: None,
refusal: None,
function_call: None,
audio: None,
},
index: delta.index,
finish_reason: delta.finish_reason,
logprobs: delta.logprobs,
}
}
}
impl NvCreateChatCompletionResponse {
pub async fn from_sse_stream(
stream: DataStream<Result<Message, SseCodecError>>,
) -> Result<NvCreateChatCompletionResponse, String> {
let stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(stream);
NvCreateChatCompletionResponse::from_annotated_stream(stream).await
}
pub async fn from_annotated_stream(
stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
) -> Result<NvCreateChatCompletionResponse, String> {
DeltaAggregator::apply(stream).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use futures::stream;
#[allow(deprecated)]
fn create_test_delta(
index: u32,
text: &str,
role: Option<async_openai::types::Role>,
finish_reason: Option<async_openai::types::FinishReason>,
) -> Annotated<NvCreateChatCompletionStreamResponse> {
let delta = async_openai::types::ChatCompletionStreamResponseDelta {
content: Some(text.to_string()),
function_call: None,
tool_calls: None,
role,
refusal: None,
};
let choice = async_openai::types::ChatChoiceStream {
index,
delta,
finish_reason,
logprobs: None,
};
let inner = async_openai::types::CreateChatCompletionStreamResponse {
id: "test_id".to_string(),
model: "meta/llama-3.1-8b-instruct".to_string(),
created: 1234567890,
service_tier: None,
usage: None,
system_fingerprint: None,
choices: vec![choice],
object: "chat.completion".to_string(),
};
let data = NvCreateChatCompletionStreamResponse { inner };
Annotated {
data: Some(data),
id: Some("test_id".to_string()),
event: None,
comment: None,
}
}
#[tokio::test]
async fn test_empty_stream() {
let stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>> =
Box::pin(stream::empty());
let result = DeltaAggregator::apply(stream).await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.inner.id, "");
assert_eq!(response.inner.model, "");
assert_eq!(response.inner.created, 0);
assert!(response.inner.usage.is_none());
assert!(response.inner.system_fingerprint.is_none());
assert_eq!(response.inner.choices.len(), 0);
assert!(response.inner.service_tier.is_none());
}
#[tokio::test]
async fn test_single_delta() {
let annotated_delta =
create_test_delta(0, "Hello,", Some(async_openai::types::Role::User), None);
let stream = Box::pin(stream::iter(vec![annotated_delta]));
let result = DeltaAggregator::apply(stream).await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.inner.id, "test_id");
assert_eq!(response.inner.model, "meta/llama-3.1-8b-instruct");
assert_eq!(response.inner.created, 1234567890);
assert!(response.inner.usage.is_none());
assert!(response.inner.system_fingerprint.is_none());
assert_eq!(response.inner.choices.len(), 1);
let choice = &response.inner.choices[0];
assert_eq!(choice.index, 0);
assert_eq!(choice.message.content.as_ref().unwrap(), "Hello,");
assert!(choice.finish_reason.is_none());
assert_eq!(choice.message.role, async_openai::types::Role::User);
assert!(response.inner.service_tier.is_none());
}
#[tokio::test]
async fn test_multiple_deltas_same_choice() {
let annotated_delta1 =
create_test_delta(0, "Hello,", Some(async_openai::types::Role::User), None);
let annotated_delta2 = create_test_delta(
0,
" world!",
None,
Some(async_openai::types::FinishReason::Stop),
);
let annotated_deltas = vec![annotated_delta1, annotated_delta2];
let stream = Box::pin(stream::iter(annotated_deltas));
let result = DeltaAggregator::apply(stream).await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.inner.choices.len(), 1);
let choice = &response.inner.choices[0];
assert_eq!(choice.index, 0);
assert_eq!(choice.message.content.as_ref().unwrap(), "Hello, world!");
assert_eq!(
choice.finish_reason,
Some(async_openai::types::FinishReason::Stop)
);
assert_eq!(choice.message.role, async_openai::types::Role::User);
}
#[allow(deprecated)]
#[tokio::test]
async fn test_multiple_choices() {
let delta = async_openai::types::CreateChatCompletionStreamResponse {
id: "test_id".to_string(),
model: "test_model".to_string(),
created: 1234567890,
service_tier: None,
usage: None,
system_fingerprint: None,
choices: vec![
async_openai::types::ChatChoiceStream {
index: 0,
delta: async_openai::types::ChatCompletionStreamResponseDelta {
role: Some(async_openai::types::Role::Assistant),
content: Some("Choice 0".to_string()),
function_call: None,
tool_calls: None,
refusal: None,
},
finish_reason: Some(async_openai::types::FinishReason::Stop),
logprobs: None,
},
async_openai::types::ChatChoiceStream {
index: 1,
delta: async_openai::types::ChatCompletionStreamResponseDelta {
role: Some(async_openai::types::Role::Assistant),
content: Some("Choice 1".to_string()),
function_call: None,
tool_calls: None,
refusal: None,
},
finish_reason: Some(async_openai::types::FinishReason::Stop),
logprobs: None,
},
],
object: "chat.completion".to_string(),
};
let data = NvCreateChatCompletionStreamResponse { inner: delta };
let annotated_delta = Annotated {
data: Some(data),
id: Some("test_id".to_string()),
event: None,
comment: None,
};
let stream = Box::pin(stream::iter(vec![annotated_delta]));
let result = DeltaAggregator::apply(stream).await;
assert!(result.is_ok());
let mut response = result.unwrap();
assert_eq!(response.inner.choices.len(), 2);
response.inner.choices.sort_by(|a, b| a.index.cmp(&b.index)); let choice0 = &response.inner.choices[0];
assert_eq!(choice0.index, 0);
assert_eq!(choice0.message.content.as_ref().unwrap(), "Choice 0");
assert_eq!(
choice0.finish_reason,
Some(async_openai::types::FinishReason::Stop)
);
assert_eq!(choice0.message.role, async_openai::types::Role::Assistant);
let choice1 = &response.inner.choices[1];
assert_eq!(choice1.index, 1);
assert_eq!(choice1.message.content.as_ref().unwrap(), "Choice 1");
assert_eq!(
choice1.finish_reason,
Some(async_openai::types::FinishReason::Stop)
);
assert_eq!(choice1.message.role, async_openai::types::Role::Assistant);
}
}