// dynamo_llm/protocols/openai/embeddings/aggregator.rs
aggregator.rs1use super::NvCreateEmbeddingResponse;
5use crate::protocols::{
6 Annotated,
7 codec::{Message, SseCodecError},
8 convert_sse_stream,
9};
10
11use dynamo_runtime::engine::DataStream;
12use futures::{Stream, StreamExt};
13
14pub struct DeltaAggregator {
18 response: Option<NvCreateEmbeddingResponse>,
20 error: Option<String>,
22}
23
24impl Default for DeltaAggregator {
25 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl DeltaAggregator {
32 pub fn new() -> Self {
34 Self {
35 response: None,
36 error: None,
37 }
38 }
39
40 pub async fn apply(
50 stream: impl Stream<Item = Annotated<NvCreateEmbeddingResponse>>,
51 ) -> Result<NvCreateEmbeddingResponse, String> {
52 let aggregator = stream
53 .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
54 let delta = match delta.ok() {
56 Ok(delta) => delta,
57 Err(error) => {
58 aggregator.error = Some(error);
59 return aggregator;
60 }
61 };
62
63 if aggregator.error.is_none()
64 && let Some(response) = delta.data
65 {
66 match &mut aggregator.response {
69 Some(existing) => {
70 existing.inner.data.extend(response.inner.data);
72
73 existing.inner.usage.prompt_tokens +=
75 response.inner.usage.prompt_tokens;
76 existing.inner.usage.total_tokens += response.inner.usage.total_tokens;
77 }
78 None => {
79 aggregator.response = Some(response);
80 }
81 }
82 }
83 aggregator
84 })
85 .await;
86
87 if let Some(error) = aggregator.error {
89 return Err(error);
90 }
91
92 Ok(aggregator
94 .response
95 .unwrap_or_else(NvCreateEmbeddingResponse::empty))
96 }
97}
98
99impl NvCreateEmbeddingResponse {
100 pub async fn from_sse_stream(
109 stream: DataStream<Result<Message, SseCodecError>>,
110 ) -> Result<NvCreateEmbeddingResponse, String> {
111 let stream = convert_sse_stream::<NvCreateEmbeddingResponse>(stream);
112 NvCreateEmbeddingResponse::from_annotated_stream(stream).await
113 }
114
115 pub async fn from_annotated_stream(
124 stream: impl Stream<Item = Annotated<NvCreateEmbeddingResponse>>,
125 ) -> Result<NvCreateEmbeddingResponse, String> {
126 DeltaAggregator::apply(stream).await
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    /// Wraps `embeddings` in a data-bearing `Annotated` response for model
    /// `"test-model"` with the given usage counts.
    fn create_test_embedding_response(
        embeddings: Vec<dynamo_async_openai::types::Embedding>,
        prompt_tokens: u32,
        total_tokens: u32,
    ) -> Annotated<NvCreateEmbeddingResponse> {
        let response = NvCreateEmbeddingResponse {
            inner: dynamo_async_openai::types::CreateEmbeddingResponse {
                object: "list".to_string(),
                model: "test-model".to_string(),
                data: embeddings,
                usage: dynamo_async_openai::types::EmbeddingUsage {
                    prompt_tokens,
                    total_tokens,
                },
            },
        };

        Annotated::from_data(response)
    }

    #[tokio::test]
    async fn test_empty_stream() {
        let stream = stream::empty();
        let result = DeltaAggregator::apply(Box::pin(stream)).await;

        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.inner.data.len(), 0);
        assert_eq!(response.inner.object, "list");
        assert_eq!(response.inner.model, "embedding");
    }

    #[tokio::test]
    async fn test_single_embedding() {
        let embedding = dynamo_async_openai::types::Embedding {
            index: 0,
            object: "embedding".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
        };

        let annotated = create_test_embedding_response(vec![embedding], 10, 10);
        let stream = stream::iter(vec![annotated]);

        let result = DeltaAggregator::apply(Box::pin(stream)).await;

        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.inner.data.len(), 1);
        assert_eq!(response.inner.data[0].index, 0);
        assert_eq!(response.inner.data[0].embedding, vec![0.1, 0.2, 0.3]);
        assert_eq!(response.inner.usage.prompt_tokens, 10);
        assert_eq!(response.inner.usage.total_tokens, 10);
    }

    #[tokio::test]
    async fn test_multiple_embeddings() {
        let embedding1 = dynamo_async_openai::types::Embedding {
            index: 0,
            object: "embedding".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
        };

        let embedding2 = dynamo_async_openai::types::Embedding {
            index: 1,
            object: "embedding".to_string(),
            embedding: vec![0.4, 0.5, 0.6],
        };

        let annotated1 = create_test_embedding_response(vec![embedding1], 5, 5);
        let annotated2 = create_test_embedding_response(vec![embedding2], 7, 7);
        let stream = stream::iter(vec![annotated1, annotated2]);

        let result = DeltaAggregator::apply(Box::pin(stream)).await;

        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.inner.data.len(), 2);
        assert_eq!(response.inner.data[0].index, 0);
        assert_eq!(response.inner.data[1].index, 1);
        assert_eq!(response.inner.usage.prompt_tokens, 12);
        assert_eq!(response.inner.usage.total_tokens, 12);
    }

    #[tokio::test]
    async fn test_error_in_stream() {
        let error_annotated =
            Annotated::<NvCreateEmbeddingResponse>::from_error("Test error".to_string());
        let stream = stream::iter(vec![error_annotated]);

        let result = DeltaAggregator::apply(Box::pin(stream)).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Test error"));
    }

    #[tokio::test]
    async fn test_error_after_data_discards_partial_response() {
        let embedding = dynamo_async_openai::types::Embedding {
            index: 0,
            object: "embedding".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
        };
        let data = create_test_embedding_response(vec![embedding], 3, 3);
        let error =
            Annotated::<NvCreateEmbeddingResponse>::from_error("late failure".to_string());
        let stream = stream::iter(vec![data, error]);

        let result = DeltaAggregator::apply(Box::pin(stream)).await;

        // Data already merged must not be returned once a later item fails.
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("late failure"));
    }

    #[tokio::test]
    async fn test_data_after_error_is_not_returned() {
        let error =
            Annotated::<NvCreateEmbeddingResponse>::from_error("early failure".to_string());
        let embedding = dynamo_async_openai::types::Embedding {
            index: 0,
            object: "embedding".to_string(),
            embedding: vec![0.4, 0.5, 0.6],
        };
        let data = create_test_embedding_response(vec![embedding], 4, 4);
        let stream = stream::iter(vec![error, data]);

        let result = DeltaAggregator::apply(Box::pin(stream)).await;

        // Data arriving after a failure must not mask the error.
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("early failure"));
    }
}