dynamo_llm/protocols/openai/embeddings/
aggregator.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use super::NvCreateEmbeddingResponse;
17use crate::protocols::{
18    codec::{Message, SseCodecError},
19    convert_sse_stream, Annotated,
20};
21
22use futures::{Stream, StreamExt};
23use std::pin::Pin;
24
/// A heap-allocated, pinned, dynamically-dispatched [`Stream`] yielding items of type `T`.
///
/// The trait object is additionally bounded by `Send + Sync`, so any concrete
/// stream boxed into this alias must satisfy both.
type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
27
/// Accumulator state used to fold a stream of [`NvCreateEmbeddingResponse`]s
/// into a single [`NvCreateEmbeddingResponse`].
///
/// Embedding backends usually answer with one complete response, so in the
/// common case the aggregate is simply that response; when several responses
/// arrive, [`DeltaAggregator::apply`] merges them into one.
pub struct DeltaAggregator {
    /// The response accumulated so far; `None` until a stream item carrying
    /// data has been seen.
    response: Option<NvCreateEmbeddingResponse>,
    /// Error message captured from an errored stream item, if any. When set,
    /// [`DeltaAggregator::apply`] returns it as `Err` instead of the response.
    error: Option<String>,
}
37
38impl Default for DeltaAggregator {
39    /// Provides a default implementation for `DeltaAggregator` by calling [`DeltaAggregator::new`].
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl DeltaAggregator {
46    /// Creates a new, empty [`DeltaAggregator`] instance.
47    pub fn new() -> Self {
48        Self {
49            response: None,
50            error: None,
51        }
52    }
53
54    /// Aggregates a stream of [`NvCreateEmbeddingResponse`]s into a single
55    /// [`NvCreateEmbeddingResponse`].
56    ///
57    /// # Arguments
58    /// * `stream` - A stream of annotated embedding responses.
59    ///
60    /// # Returns
61    /// * `Ok(NvCreateEmbeddingResponse)` if aggregation is successful.
62    /// * `Err(String)` if an error occurs during processing.
63    pub async fn apply(
64        stream: DataStream<Annotated<NvCreateEmbeddingResponse>>,
65    ) -> Result<NvCreateEmbeddingResponse, String> {
66        let aggregator = stream
67            .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
68                // Attempt to unwrap the delta, capturing any errors.
69                let delta = match delta.ok() {
70                    Ok(delta) => delta,
71                    Err(error) => {
72                        aggregator.error = Some(error);
73                        return aggregator;
74                    }
75                };
76
77                if aggregator.error.is_none() {
78                    if let Some(response) = delta.data {
79                        // For embeddings, we typically expect a single complete response
80                        // or we accumulate data from multiple responses
81                        match &mut aggregator.response {
82                            Some(existing) => {
83                                // Merge embedding data if we have multiple responses
84                                existing.inner.data.extend(response.inner.data);
85
86                                // Update usage statistics
87                                existing.inner.usage.prompt_tokens +=
88                                    response.inner.usage.prompt_tokens;
89                                existing.inner.usage.total_tokens +=
90                                    response.inner.usage.total_tokens;
91                            }
92                            None => {
93                                aggregator.response = Some(response);
94                            }
95                        }
96                    }
97                }
98                aggregator
99            })
100            .await;
101
102        // Return early if an error was encountered.
103        if let Some(error) = aggregator.error {
104            return Err(error);
105        }
106
107        // Return the aggregated response or an empty response if none was found.
108        Ok(aggregator
109            .response
110            .unwrap_or_else(NvCreateEmbeddingResponse::empty))
111    }
112}
113
114impl NvCreateEmbeddingResponse {
115    /// Converts an SSE stream into a [`NvCreateEmbeddingResponse`].
116    ///
117    /// # Arguments
118    /// * `stream` - A stream of SSE messages containing embedding responses.
119    ///
120    /// # Returns
121    /// * `Ok(NvCreateEmbeddingResponse)` if aggregation succeeds.
122    /// * `Err(String)` if an error occurs.
123    pub async fn from_sse_stream(
124        stream: DataStream<Result<Message, SseCodecError>>,
125    ) -> Result<NvCreateEmbeddingResponse, String> {
126        let stream = convert_sse_stream::<NvCreateEmbeddingResponse>(stream);
127        NvCreateEmbeddingResponse::from_annotated_stream(stream).await
128    }
129
130    /// Aggregates an annotated stream of embedding responses into a final response.
131    ///
132    /// # Arguments
133    /// * `stream` - A stream of annotated embedding responses.
134    ///
135    /// # Returns
136    /// * `Ok(NvCreateEmbeddingResponse)` if aggregation succeeds.
137    /// * `Err(String)` if an error occurs.
138    pub async fn from_annotated_stream(
139        stream: DataStream<Annotated<NvCreateEmbeddingResponse>>,
140    ) -> Result<NvCreateEmbeddingResponse, String> {
141        DeltaAggregator::apply(stream).await
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use async_openai::types::{CreateEmbeddingResponse, Embedding, EmbeddingUsage};
    use futures::stream;

    /// Builds an annotated embedding response for model `"test-model"` that
    /// carries `embeddings` and the given usage counters.
    fn create_test_embedding_response(
        embeddings: Vec<Embedding>,
        prompt_tokens: u32,
        total_tokens: u32,
    ) -> Annotated<NvCreateEmbeddingResponse> {
        let usage = EmbeddingUsage {
            prompt_tokens,
            total_tokens,
        };
        let inner = CreateEmbeddingResponse {
            object: "list".to_string(),
            model: "test-model".to_string(),
            data: embeddings,
            usage,
        };

        Annotated::from_data(NvCreateEmbeddingResponse { inner })
    }

    /// Runs the aggregator over `items`, delivered in order as a stream.
    async fn aggregate(
        items: Vec<Annotated<NvCreateEmbeddingResponse>>,
    ) -> Result<NvCreateEmbeddingResponse, String> {
        DeltaAggregator::apply(Box::pin(stream::iter(items))).await
    }

    /// An empty stream yields the empty placeholder response.
    #[tokio::test]
    async fn test_empty_stream() {
        let no_items = stream::empty::<Annotated<NvCreateEmbeddingResponse>>();
        let result = DeltaAggregator::apply(Box::pin(no_items)).await;

        assert!(result.is_ok());
        let aggregated = result.unwrap();
        assert!(aggregated.inner.data.is_empty());
        assert_eq!(aggregated.inner.object, "list");
        assert_eq!(aggregated.inner.model, "embedding");
    }

    /// A single response passes through aggregation unchanged.
    #[tokio::test]
    async fn test_single_embedding() {
        let only = Embedding {
            index: 0,
            object: "embedding".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
        };

        let result = aggregate(vec![create_test_embedding_response(vec![only], 10, 10)]).await;

        assert!(result.is_ok());
        let aggregated = result.unwrap();
        let data = &aggregated.inner.data;
        assert_eq!(data.len(), 1);
        assert_eq!(data[0].index, 0);
        assert_eq!(data[0].embedding, vec![0.1, 0.2, 0.3]);

        let usage = &aggregated.inner.usage;
        assert_eq!(usage.prompt_tokens, 10);
        assert_eq!(usage.total_tokens, 10);
    }

    /// Several responses are merged: data is concatenated in arrival order and
    /// usage counters are summed.
    #[tokio::test]
    async fn test_multiple_embeddings() {
        let first = Embedding {
            index: 0,
            object: "embedding".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
        };
        let second = Embedding {
            index: 1,
            object: "embedding".to_string(),
            embedding: vec![0.4, 0.5, 0.6],
        };

        let items = vec![
            create_test_embedding_response(vec![first], 5, 5),
            create_test_embedding_response(vec![second], 7, 7),
        ];
        let result = aggregate(items).await;

        assert!(result.is_ok());
        let aggregated = result.unwrap();
        let data = &aggregated.inner.data;
        assert_eq!(data.len(), 2);
        assert_eq!(data[0].index, 0);
        assert_eq!(data[1].index, 1);

        // Both counters are the sum of 5 and 7.
        let usage = &aggregated.inner.usage;
        assert_eq!(usage.prompt_tokens, 12);
        assert_eq!(usage.total_tokens, 12);
    }

    /// An errored item makes the whole aggregation fail with its message.
    #[tokio::test]
    async fn test_error_in_stream() {
        let failure = Annotated::<NvCreateEmbeddingResponse>::from_error("Test error".to_string());

        let result = aggregate(vec![failure]).await;

        assert!(result.is_err());
        let message = result.unwrap_err();
        assert!(message.contains("Test error"));
    }
}