use super::NvCreateEmbeddingResponse;
use crate::protocols::{
Annotated,
codec::{Message, SseCodecError},
convert_sse_stream,
};
use dynamo_runtime::engine::DataStream;
use futures::{Stream, StreamExt};
/// Accumulator that merges a stream of annotated [`NvCreateEmbeddingResponse`]
/// deltas into a single response.
///
/// Both fields are private; aggregation is driven through
/// [`DeltaAggregator::apply`].
pub struct DeltaAggregator {
/// Response merged so far; stays `None` until the first data-carrying delta
/// has been stored.
response: Option<NvCreateEmbeddingResponse>,
/// Error message taken from an errored delta. Once this is set, data from
/// later deltas is no longer merged.
error: Option<String>,
}
impl Default for DeltaAggregator {
fn default() -> Self {
Self::new()
}
}
impl DeltaAggregator {
    /// Creates an aggregator that holds neither a response nor an error.
    pub fn new() -> Self {
        Self {
            response: None,
            error: None,
        }
    }

    /// Drains `stream` and merges every delta into one embedding response.
    ///
    /// The first data-carrying delta becomes the base response. Each later
    /// delta appends its `data` entries to the base and adds its
    /// `prompt_tokens` / `total_tokens` to the base usage counters; no other
    /// field of a later delta is copied over.
    ///
    /// The stream is always consumed to the end, even after an error.
    ///
    /// # Errors
    ///
    /// Returns `Err` with the error message if any delta reports an error via
    /// `Annotated::ok`. If several deltas report errors, the message of the
    /// most recent one is returned. Data arriving after an error is discarded.
    ///
    /// # Returns
    ///
    /// The merged response, or `NvCreateEmbeddingResponse::empty()` when no
    /// delta carried data.
    pub async fn apply(
        stream: impl Stream<Item = Annotated<NvCreateEmbeddingResponse>>,
    ) -> Result<NvCreateEmbeddingResponse, String> {
        let mut state = Self::new();

        {
            // Pin on the stack so `next()` can be called on an arbitrary
            // stream; the inner scope drops the stream before the result is
            // assembled.
            let mut deltas = std::pin::pin!(stream);

            while let Some(item) = deltas.next().await {
                let payload = match item.ok() {
                    Ok(annotated) => annotated.data,
                    Err(message) => {
                        // Each error overwrites the previously recorded one.
                        state.error = Some(message);
                        continue;
                    }
                };

                // After a failure the remaining items are still consumed, but
                // their data is not merged.
                if state.error.is_some() {
                    continue;
                }

                let Some(chunk) = payload else {
                    continue;
                };

                if let Some(merged) = state.response.as_mut() {
                    merged.inner.data.extend(chunk.inner.data);
                    merged.inner.usage.prompt_tokens += chunk.inner.usage.prompt_tokens;
                    merged.inner.usage.total_tokens += chunk.inner.usage.total_tokens;
                } else {
                    state.response = Some(chunk);
                }
            }
        }

        match state.error {
            Some(message) => Err(message),
            None => Ok(state
                .response
                .unwrap_or_else(NvCreateEmbeddingResponse::empty)),
        }
    }
}
impl NvCreateEmbeddingResponse {
    /// Builds a single embedding response from a stream of SSE messages.
    ///
    /// The messages are passed through `convert_sse_stream` to obtain
    /// annotated responses, which are then handed to
    /// [`NvCreateEmbeddingResponse::from_annotated_stream`].
    ///
    /// # Errors
    ///
    /// Returns whatever error string the aggregation step reports.
    pub async fn from_sse_stream(
        stream: DataStream<Result<Message, SseCodecError>>,
    ) -> Result<NvCreateEmbeddingResponse, String> {
        let annotated = convert_sse_stream::<Self>(stream);
        Self::from_annotated_stream(annotated).await
    }

    /// Builds a single embedding response from a stream of annotated deltas
    /// by delegating to [`DeltaAggregator::apply`].
    ///
    /// # Errors
    ///
    /// Returns the error string produced by [`DeltaAggregator::apply`].
    pub async fn from_annotated_stream(
        stream: impl Stream<Item = Annotated<NvCreateEmbeddingResponse>>,
    ) -> Result<NvCreateEmbeddingResponse, String> {
        DeltaAggregator::apply(stream).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use dynamo_async_openai::types::{CreateEmbeddingResponse, Embedding, EmbeddingUsage};
    use futures::stream;

    /// Wraps the given embeddings and usage counters in an annotated response
    /// with object `"list"` and model `"test-model"`.
    fn create_test_embedding_response(
        embeddings: Vec<Embedding>,
        prompt_tokens: u32,
        total_tokens: u32,
    ) -> Annotated<NvCreateEmbeddingResponse> {
        Annotated::from_data(NvCreateEmbeddingResponse {
            inner: CreateEmbeddingResponse {
                object: "list".to_string(),
                model: "test-model".to_string(),
                data: embeddings,
                usage: EmbeddingUsage {
                    prompt_tokens,
                    total_tokens,
                },
            },
        })
    }

    /// An empty stream yields the fallback "empty" response.
    #[tokio::test]
    async fn test_empty_stream() {
        let response = DeltaAggregator::apply(stream::empty()).await.unwrap();

        assert_eq!(response.inner.data.len(), 0);
        assert_eq!(response.inner.object, "list");
        assert_eq!(response.inner.model, "embedding");
    }

    /// A single delta is returned unchanged.
    #[tokio::test]
    async fn test_single_embedding() {
        let only = Embedding {
            index: 0,
            object: "embedding".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
        };
        let deltas = stream::iter(vec![create_test_embedding_response(vec![only], 10, 10)]);

        let response = DeltaAggregator::apply(deltas).await.unwrap();

        assert_eq!(response.inner.data.len(), 1);
        assert_eq!(response.inner.data[0].index, 0);
        assert_eq!(response.inner.data[0].embedding, vec![0.1, 0.2, 0.3]);
        assert_eq!(response.inner.usage.prompt_tokens, 10);
        assert_eq!(response.inner.usage.total_tokens, 10);
    }

    /// Two deltas are concatenated in arrival order and their usage is summed.
    #[tokio::test]
    async fn test_multiple_embeddings() {
        let first = Embedding {
            index: 0,
            object: "embedding".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
        };
        let second = Embedding {
            index: 1,
            object: "embedding".to_string(),
            embedding: vec![0.4, 0.5, 0.6],
        };
        let deltas = stream::iter(vec![
            create_test_embedding_response(vec![first], 5, 5),
            create_test_embedding_response(vec![second], 7, 7),
        ]);

        let response = DeltaAggregator::apply(deltas).await.unwrap();

        assert_eq!(response.inner.data.len(), 2);
        assert_eq!(response.inner.data[0].index, 0);
        assert_eq!(response.inner.data[1].index, 1);
        assert_eq!(response.inner.usage.prompt_tokens, 12);
        assert_eq!(response.inner.usage.total_tokens, 12);
    }

    /// An errored delta makes the whole aggregation fail with its message.
    #[tokio::test]
    async fn test_error_in_stream() {
        let failing = stream::iter(vec![Annotated::<NvCreateEmbeddingResponse>::from_error(
            "Test error".to_string(),
        )]);

        let message = DeltaAggregator::apply(failing).await.unwrap_err();

        assert!(message.contains("Test error"));
    }
}