dynamo_llm/protocols/openai/embeddings/
aggregator.rs1use super::NvCreateEmbeddingResponse;
17use crate::protocols::{
18 codec::{Message, SseCodecError},
19 convert_sse_stream, Annotated,
20};
21
22use futures::{Stream, StreamExt};
23use std::pin::Pin;
24
25type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
27
28pub struct DeltaAggregator {
32 response: Option<NvCreateEmbeddingResponse>,
34 error: Option<String>,
36}
37
38impl Default for DeltaAggregator {
39 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl DeltaAggregator {
46 pub fn new() -> Self {
48 Self {
49 response: None,
50 error: None,
51 }
52 }
53
54 pub async fn apply(
64 stream: DataStream<Annotated<NvCreateEmbeddingResponse>>,
65 ) -> Result<NvCreateEmbeddingResponse, String> {
66 let aggregator = stream
67 .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
68 let delta = match delta.ok() {
70 Ok(delta) => delta,
71 Err(error) => {
72 aggregator.error = Some(error);
73 return aggregator;
74 }
75 };
76
77 if aggregator.error.is_none() {
78 if let Some(response) = delta.data {
79 match &mut aggregator.response {
82 Some(existing) => {
83 existing.inner.data.extend(response.inner.data);
85
86 existing.inner.usage.prompt_tokens +=
88 response.inner.usage.prompt_tokens;
89 existing.inner.usage.total_tokens +=
90 response.inner.usage.total_tokens;
91 }
92 None => {
93 aggregator.response = Some(response);
94 }
95 }
96 }
97 }
98 aggregator
99 })
100 .await;
101
102 if let Some(error) = aggregator.error {
104 return Err(error);
105 }
106
107 Ok(aggregator
109 .response
110 .unwrap_or_else(NvCreateEmbeddingResponse::empty))
111 }
112}
113
114impl NvCreateEmbeddingResponse {
115 pub async fn from_sse_stream(
124 stream: DataStream<Result<Message, SseCodecError>>,
125 ) -> Result<NvCreateEmbeddingResponse, String> {
126 let stream = convert_sse_stream::<NvCreateEmbeddingResponse>(stream);
127 NvCreateEmbeddingResponse::from_annotated_stream(stream).await
128 }
129
130 pub async fn from_annotated_stream(
139 stream: DataStream<Annotated<NvCreateEmbeddingResponse>>,
140 ) -> Result<NvCreateEmbeddingResponse, String> {
141 DeltaAggregator::apply(stream).await
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use async_openai::types::{CreateEmbeddingResponse, Embedding, EmbeddingUsage};
    use futures::stream;

    /// Wraps the given embeddings and usage counters in an annotated response
    /// with object `"list"` and model `"test-model"`.
    fn create_test_embedding_response(
        embeddings: Vec<Embedding>,
        prompt_tokens: u32,
        total_tokens: u32,
    ) -> Annotated<NvCreateEmbeddingResponse> {
        let usage = EmbeddingUsage {
            prompt_tokens,
            total_tokens,
        };
        let inner = CreateEmbeddingResponse {
            object: "list".to_string(),
            model: "test-model".to_string(),
            data: embeddings,
            usage,
        };

        Annotated::from_data(NvCreateEmbeddingResponse { inner })
    }

    /// An empty stream aggregates to the empty default response.
    #[tokio::test]
    async fn test_empty_stream() {
        let empty = stream::empty();
        let outcome = DeltaAggregator::apply(Box::pin(empty)).await;

        assert!(outcome.is_ok());
        let aggregated = outcome.unwrap();
        assert_eq!(aggregated.inner.data.len(), 0);
        assert_eq!(aggregated.inner.object, "list");
        assert_eq!(aggregated.inner.model, "embedding");
    }

    /// A single item passes through with its embedding and usage unchanged.
    #[tokio::test]
    async fn test_single_embedding() {
        let only = Embedding {
            index: 0,
            object: "embedding".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
        };

        let items = vec![create_test_embedding_response(vec![only], 10, 10)];
        let outcome = DeltaAggregator::apply(Box::pin(stream::iter(items))).await;

        assert!(outcome.is_ok());
        let aggregated = outcome.unwrap();
        assert_eq!(aggregated.inner.data.len(), 1);
        assert_eq!(aggregated.inner.data[0].index, 0);
        assert_eq!(aggregated.inner.data[0].embedding, vec![0.1, 0.2, 0.3]);
        assert_eq!(aggregated.inner.usage.prompt_tokens, 10);
        assert_eq!(aggregated.inner.usage.total_tokens, 10);
    }

    /// Two items are merged: embeddings are concatenated and usage is summed.
    #[tokio::test]
    async fn test_multiple_embeddings() {
        let first = Embedding {
            index: 0,
            object: "embedding".to_string(),
            embedding: vec![0.1, 0.2, 0.3],
        };
        let second = Embedding {
            index: 1,
            object: "embedding".to_string(),
            embedding: vec![0.4, 0.5, 0.6],
        };

        let items = vec![
            create_test_embedding_response(vec![first], 5, 5),
            create_test_embedding_response(vec![second], 7, 7),
        ];
        let outcome = DeltaAggregator::apply(Box::pin(stream::iter(items))).await;

        assert!(outcome.is_ok());
        let aggregated = outcome.unwrap();
        assert_eq!(aggregated.inner.data.len(), 2);
        assert_eq!(aggregated.inner.data[0].index, 0);
        assert_eq!(aggregated.inner.data[1].index, 1);
        assert_eq!(aggregated.inner.usage.prompt_tokens, 12);
        assert_eq!(aggregated.inner.usage.total_tokens, 12);
    }

    /// An error item in the stream surfaces as `Err` carrying its message.
    #[tokio::test]
    async fn test_error_in_stream() {
        let failure = Annotated::<NvCreateEmbeddingResponse>::from_error("Test error".to_string());
        let items = vec![failure];

        let outcome = DeltaAggregator::apply(Box::pin(stream::iter(items))).await;

        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().contains("Test error"));
    }
}