dynamo_llm/protocols/openai/embeddings/
aggregator.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::NvCreateEmbeddingResponse;
5use crate::protocols::{
6    Annotated,
7    codec::{Message, SseCodecError},
8    convert_sse_stream,
9};
10
11use dynamo_runtime::engine::DataStream;
12use futures::{Stream, StreamExt};
13
14/// Aggregates a stream of [`NvCreateEmbeddingResponse`]s into a single
15/// [`NvCreateEmbeddingResponse`]. For embeddings, this is typically simpler
16/// than text generation as embeddings are usually returned as a complete response.
17pub struct DeltaAggregator {
18    /// The accumulated embeddings response.
19    response: Option<NvCreateEmbeddingResponse>,
20    /// Optional error message if an error occurs during aggregation.
21    error: Option<String>,
22}
23
24impl Default for DeltaAggregator {
25    /// Provides a default implementation for `DeltaAggregator` by calling [`DeltaAggregator::new`].
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
31impl DeltaAggregator {
32    /// Creates a new, empty [`DeltaAggregator`] instance.
33    pub fn new() -> Self {
34        Self {
35            response: None,
36            error: None,
37        }
38    }
39
40    /// Aggregates a stream of [`NvCreateEmbeddingResponse`]s into a single
41    /// [`NvCreateEmbeddingResponse`].
42    ///
43    /// # Arguments
44    /// * `stream` - A stream of annotated embedding responses.
45    ///
46    /// # Returns
47    /// * `Ok(NvCreateEmbeddingResponse)` if aggregation is successful.
48    /// * `Err(String)` if an error occurs during processing.
49    pub async fn apply(
50        stream: impl Stream<Item = Annotated<NvCreateEmbeddingResponse>>,
51    ) -> Result<NvCreateEmbeddingResponse, String> {
52        let aggregator = stream
53            .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
54                // Attempt to unwrap the delta, capturing any errors.
55                let delta = match delta.ok() {
56                    Ok(delta) => delta,
57                    Err(error) => {
58                        aggregator.error = Some(error);
59                        return aggregator;
60                    }
61                };
62
63                if aggregator.error.is_none()
64                    && let Some(response) = delta.data
65                {
66                    // For embeddings, we typically expect a single complete response
67                    // or we accumulate data from multiple responses
68                    match &mut aggregator.response {
69                        Some(existing) => {
70                            // Merge embedding data if we have multiple responses
71                            existing.inner.data.extend(response.inner.data);
72
73                            // Update usage statistics
74                            existing.inner.usage.prompt_tokens +=
75                                response.inner.usage.prompt_tokens;
76                            existing.inner.usage.total_tokens += response.inner.usage.total_tokens;
77                        }
78                        None => {
79                            aggregator.response = Some(response);
80                        }
81                    }
82                }
83                aggregator
84            })
85            .await;
86
87        // Return early if an error was encountered.
88        if let Some(error) = aggregator.error {
89            return Err(error);
90        }
91
92        // Return the aggregated response or an empty response if none was found.
93        Ok(aggregator
94            .response
95            .unwrap_or_else(NvCreateEmbeddingResponse::empty))
96    }
97}
98
99impl NvCreateEmbeddingResponse {
100    /// Converts an SSE stream into a [`NvCreateEmbeddingResponse`].
101    ///
102    /// # Arguments
103    /// * `stream` - A stream of SSE messages containing embedding responses.
104    ///
105    /// # Returns
106    /// * `Ok(NvCreateEmbeddingResponse)` if aggregation succeeds.
107    /// * `Err(String)` if an error occurs.
108    pub async fn from_sse_stream(
109        stream: DataStream<Result<Message, SseCodecError>>,
110    ) -> Result<NvCreateEmbeddingResponse, String> {
111        let stream = convert_sse_stream::<NvCreateEmbeddingResponse>(stream);
112        NvCreateEmbeddingResponse::from_annotated_stream(stream).await
113    }
114
115    /// Aggregates an annotated stream of embedding responses into a final response.
116    ///
117    /// # Arguments
118    /// * `stream` - A stream of annotated embedding responses.
119    ///
120    /// # Returns
121    /// * `Ok(NvCreateEmbeddingResponse)` if aggregation succeeds.
122    /// * `Err(String)` if an error occurs.
123    pub async fn from_annotated_stream(
124        stream: impl Stream<Item = Annotated<NvCreateEmbeddingResponse>>,
125    ) -> Result<NvCreateEmbeddingResponse, String> {
126        DeltaAggregator::apply(stream).await
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use dynamo_async_openai::types::{CreateEmbeddingResponse, Embedding, EmbeddingUsage};
    use futures::stream;

    /// Wraps `embeddings` and the given token counts in an annotated response
    /// for the fixed model name `"test-model"`.
    fn create_test_embedding_response(
        embeddings: Vec<Embedding>,
        prompt_tokens: u32,
        total_tokens: u32,
    ) -> Annotated<NvCreateEmbeddingResponse> {
        let inner = CreateEmbeddingResponse {
            object: "list".to_string(),
            model: "test-model".to_string(),
            data: embeddings,
            usage: EmbeddingUsage {
                prompt_tokens,
                total_tokens,
            },
        };

        Annotated::from_data(NvCreateEmbeddingResponse { inner })
    }

    /// An empty stream yields the empty response instead of an error.
    #[tokio::test]
    async fn test_empty_stream() {
        let result = DeltaAggregator::apply(Box::pin(stream::empty())).await;

        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.inner.data.len(), 0);
        assert_eq!(response.inner.object, "list");
        assert_eq!(response.inner.model, "embedding");
    }

    /// A single response passes through aggregation unchanged.
    #[tokio::test]
    async fn test_single_embedding() {
        let items = vec![create_test_embedding_response(
            vec![Embedding {
                index: 0,
                object: "embedding".to_string(),
                embedding: vec![0.1, 0.2, 0.3],
            }],
            10,
            10,
        )];

        let result = DeltaAggregator::apply(Box::pin(stream::iter(items))).await;

        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.inner.data.len(), 1);
        assert_eq!(response.inner.data[0].index, 0);
        assert_eq!(response.inner.data[0].embedding, vec![0.1, 0.2, 0.3]);
        assert_eq!(response.inner.usage.prompt_tokens, 10);
        assert_eq!(response.inner.usage.total_tokens, 10);
    }

    /// Two responses are merged: data is concatenated in stream order and the
    /// usage counters are added up.
    #[tokio::test]
    async fn test_multiple_embeddings() {
        let first = create_test_embedding_response(
            vec![Embedding {
                index: 0,
                object: "embedding".to_string(),
                embedding: vec![0.1, 0.2, 0.3],
            }],
            5,
            5,
        );
        let second = create_test_embedding_response(
            vec![Embedding {
                index: 1,
                object: "embedding".to_string(),
                embedding: vec![0.4, 0.5, 0.6],
            }],
            7,
            7,
        );

        let result = DeltaAggregator::apply(Box::pin(stream::iter(vec![first, second]))).await;

        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.inner.data.len(), 2);
        assert_eq!(response.inner.data[0].index, 0);
        assert_eq!(response.inner.data[1].index, 1);
        assert_eq!(response.inner.usage.prompt_tokens, 12); // sum of 5 and 7
        assert_eq!(response.inner.usage.total_tokens, 12); // sum of 5 and 7
    }

    /// An error item in the stream surfaces as `Err` carrying its message.
    #[tokio::test]
    async fn test_error_in_stream() {
        let items = vec![Annotated::<NvCreateEmbeddingResponse>::from_error(
            "Test error".to_string(),
        )];

        let result = DeltaAggregator::apply(Box::pin(stream::iter(items))).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Test error"));
    }
}