dynamo_llm/protocols/openai/completions/
aggregator.rs1use std::collections::HashMap;
5
6use anyhow::Result;
7use futures::{Stream, StreamExt};
8
9use super::NvCreateCompletionResponse;
10use crate::protocols::{
11 Annotated, DataStream,
12 codec::{Message, SseCodecError},
13 common::FinishReason,
14 convert_sse_stream,
15 openai::ParsingOptions,
16};
17
18pub struct DeltaAggregator {
20 id: String,
21 model: String,
22 created: u32,
23 usage: Option<dynamo_async_openai::types::CompletionUsage>,
24 system_fingerprint: Option<String>,
25 choices: HashMap<u32, DeltaChoice>,
26 error: Option<String>,
27}
28
29struct DeltaChoice {
30 index: u32,
31 text: String,
32 finish_reason: Option<FinishReason>,
33 logprobs: Option<dynamo_async_openai::types::Logprobs>,
34}
35
36impl Default for DeltaAggregator {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42impl DeltaAggregator {
43 pub fn new() -> Self {
44 Self {
45 id: "".to_string(),
46 model: "".to_string(),
47 created: 0,
48 usage: None,
49 system_fingerprint: None,
50 choices: HashMap::new(),
51 error: None,
52 }
53 }
54
55 pub async fn apply(
57 stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
58 parsing_options: ParsingOptions,
59 ) -> Result<NvCreateCompletionResponse> {
60 tracing::debug!("Tool Call Parser: {:?}", parsing_options.tool_call_parser); let aggregator = stream
62 .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
63 let delta = match delta.ok() {
64 Ok(delta) => delta,
65 Err(error) => {
66 aggregator.error = Some(error);
67 return aggregator;
68 }
69 };
70
71 if aggregator.error.is_none() && delta.data.is_some() {
72 let delta = delta.data.unwrap();
78 aggregator.id = delta.inner.id;
79 aggregator.model = delta.inner.model;
80 aggregator.created = delta.inner.created;
81 if let Some(usage) = delta.inner.usage {
82 aggregator.usage = Some(usage);
83 }
84 if let Some(system_fingerprint) = delta.inner.system_fingerprint {
85 aggregator.system_fingerprint = Some(system_fingerprint);
86 }
87
88 for choice in delta.inner.choices {
90 let state_choice =
91 aggregator
92 .choices
93 .entry(choice.index)
94 .or_insert(DeltaChoice {
95 index: choice.index,
96 text: "".to_string(),
97 finish_reason: None,
98 logprobs: None,
99 });
100
101 state_choice.text.push_str(&choice.text);
102
103 state_choice.finish_reason = match choice.finish_reason {
107 Some(dynamo_async_openai::types::CompletionFinishReason::Stop) => {
108 Some(FinishReason::Stop)
109 }
110 Some(dynamo_async_openai::types::CompletionFinishReason::Length) => {
111 Some(FinishReason::Length)
112 }
113 Some(
114 dynamo_async_openai::types::CompletionFinishReason::ContentFilter,
115 ) => Some(FinishReason::ContentFilter),
116 None => None,
117 };
118
119 if let Some(logprobs) = &choice.logprobs {
121 let state_lps = state_choice.logprobs.get_or_insert(
122 dynamo_async_openai::types::Logprobs {
123 tokens: Vec::new(),
124 token_logprobs: Vec::new(),
125 top_logprobs: Vec::new(),
126 text_offset: Vec::new(),
127 },
128 );
129 state_lps.tokens.extend(logprobs.tokens.clone());
130 state_lps
131 .token_logprobs
132 .extend(logprobs.token_logprobs.clone());
133 state_lps.top_logprobs.extend(logprobs.top_logprobs.clone());
134 state_lps.text_offset.extend(logprobs.text_offset.clone());
135 }
136 }
137 }
138 aggregator
139 })
140 .await;
141
142 let aggregator = if let Some(error) = aggregator.error {
144 return Err(anyhow::anyhow!(error));
145 } else {
146 aggregator
147 };
148
149 let mut choices: Vec<_> = aggregator
151 .choices
152 .into_values()
153 .map(dynamo_async_openai::types::Choice::from)
154 .collect();
155
156 choices.sort_by(|a, b| a.index.cmp(&b.index));
157
158 let inner = dynamo_async_openai::types::CreateCompletionResponse {
159 id: aggregator.id,
160 created: aggregator.created,
161 usage: aggregator.usage,
162 model: aggregator.model,
163 object: "text_completion".to_string(),
164 system_fingerprint: aggregator.system_fingerprint,
165 choices,
166 };
167
168 let response = NvCreateCompletionResponse { inner };
169
170 Ok(response)
171 }
172}
173
174impl From<DeltaChoice> for dynamo_async_openai::types::Choice {
175 fn from(delta: DeltaChoice) -> Self {
176 let finish_reason = delta.finish_reason.map(Into::into);
177
178 dynamo_async_openai::types::Choice {
179 index: delta.index,
180 text: delta.text,
181 finish_reason,
182 logprobs: delta.logprobs,
183 }
184 }
185}
186
187impl NvCreateCompletionResponse {
188 pub async fn from_sse_stream(
189 stream: DataStream<Result<Message, SseCodecError>>,
190 parsing_options: ParsingOptions,
191 ) -> Result<NvCreateCompletionResponse> {
192 let stream = convert_sse_stream::<NvCreateCompletionResponse>(stream);
193 NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options).await
194 }
195
196 pub async fn from_annotated_stream(
197 stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
198 parsing_options: ParsingOptions,
199 ) -> Result<NvCreateCompletionResponse> {
200 DeltaAggregator::apply(stream, parsing_options).await
201 }
202}
203
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use futures::stream;

    use super::*;
    use crate::protocols::openai::completions::NvCreateCompletionResponse;

    /// Builds an annotated delta holding exactly one choice.
    ///
    /// `finish_reason` is parsed with `FinishReason::from_str`; a string that
    /// fails to parse yields no finish reason. When `logprob` is given, a
    /// single-token `Logprobs` entry is attached using `text` as the token.
    fn create_test_delta(
        index: u32,
        text: &str,
        finish_reason: Option<String>,
        logprob: Option<f32>,
    ) -> Annotated<NvCreateCompletionResponse> {
        let finish_reason = finish_reason
            .as_deref()
            .and_then(|s| FinishReason::from_str(s).ok())
            .map(Into::into);

        let logprobs = logprob.map(|lp| dynamo_async_openai::types::Logprobs {
            tokens: vec![text.to_string()],
            token_logprobs: vec![Some(lp)],
            top_logprobs: vec![
                serde_json::to_value(dynamo_async_openai::types::TopLogprobs {
                    token: text.to_string(),
                    logprob: lp,
                    bytes: None,
                })
                .unwrap(),
            ],
            text_offset: vec![0],
        });

        let inner = dynamo_async_openai::types::CreateCompletionResponse {
            id: "test_id".to_string(),
            model: "meta/llama-3.1-8b".to_string(),
            created: 1234567890,
            usage: None,
            system_fingerprint: None,
            choices: vec![dynamo_async_openai::types::Choice {
                index,
                text: text.to_string(),
                finish_reason,
                logprobs,
            }],
            object: "text_completion".to_string(),
        };

        Annotated {
            data: Some(NvCreateCompletionResponse { inner }),
            id: Some("test_id".to_string()),
            event: None,
            comment: None,
        }
    }

    /// An empty stream aggregates to an empty response.
    #[tokio::test]
    async fn test_empty_stream() {
        let stream: DataStream<Annotated<NvCreateCompletionResponse>> = Box::pin(stream::empty());

        // `expect` surfaces the actual error if aggregation fails.
        let response = DeltaAggregator::apply(stream, ParsingOptions::default())
            .await
            .expect("aggregating an empty stream should succeed");

        assert_eq!(response.inner.id, "");
        assert_eq!(response.inner.model, "");
        assert_eq!(response.inner.created, 0);
        assert!(response.inner.usage.is_none());
        assert!(response.inner.system_fingerprint.is_none());
        assert_eq!(response.inner.choices.len(), 0);
    }

    /// A single delta is passed through unchanged.
    #[tokio::test]
    async fn test_single_delta() {
        let annotated_delta = create_test_delta(0, "Hello,", Some("length".to_string()), None);
        let stream = Box::pin(stream::iter(vec![annotated_delta]));

        let response = DeltaAggregator::apply(stream, ParsingOptions::default())
            .await
            .expect("aggregating a single delta should succeed");

        assert_eq!(response.inner.id, "test_id");
        assert_eq!(response.inner.model, "meta/llama-3.1-8b");
        assert_eq!(response.inner.created, 1234567890);
        assert!(response.inner.usage.is_none());
        assert!(response.inner.system_fingerprint.is_none());
        assert_eq!(response.inner.choices.len(), 1);

        let choice = &response.inner.choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.text, "Hello,");
        assert_eq!(
            choice.finish_reason,
            Some(dynamo_async_openai::types::CompletionFinishReason::Length)
        );
        assert!(choice.logprobs.is_none());
    }

    /// Deltas sharing a choice index are merged: text and logprobs concatenate.
    #[tokio::test]
    async fn test_multiple_deltas_same_choice() {
        let annotated_deltas = vec![
            create_test_delta(0, "Hello,", None, Some(-0.1)),
            create_test_delta(0, " world!", Some("stop".to_string()), Some(-0.2)),
        ];
        let stream = Box::pin(stream::iter(annotated_deltas));

        let response = DeltaAggregator::apply(stream, ParsingOptions::default())
            .await
            .expect("aggregating deltas for one choice should succeed");

        assert_eq!(response.inner.choices.len(), 1);

        let choice = &response.inner.choices[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.text, "Hello, world!");
        assert_eq!(
            choice.finish_reason,
            Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
        );

        let logprobs = choice
            .logprobs
            .as_ref()
            .expect("logprobs should have been aggregated");
        assert_eq!(logprobs.tokens.len(), 2);
        assert_eq!(logprobs.token_logprobs, vec![Some(-0.1), Some(-0.2)]);
        assert_eq!(logprobs.top_logprobs.len(), 2);
        assert_eq!(logprobs.text_offset.len(), 2);
    }

    /// A delta with several choices yields one aggregated choice per index,
    /// returned in ascending index order.
    #[tokio::test]
    async fn test_multiple_choices() {
        let inner = dynamo_async_openai::types::CreateCompletionResponse {
            id: "test_id".to_string(),
            model: "meta/llama-3.1-8b".to_string(),
            created: 1234567890,
            usage: None,
            system_fingerprint: None,
            choices: vec![
                dynamo_async_openai::types::Choice {
                    index: 0,
                    text: "Choice 0".to_string(),
                    finish_reason: Some(dynamo_async_openai::types::CompletionFinishReason::Stop),
                    logprobs: None,
                },
                dynamo_async_openai::types::Choice {
                    index: 1,
                    text: "Choice 1".to_string(),
                    finish_reason: Some(dynamo_async_openai::types::CompletionFinishReason::Stop),
                    logprobs: None,
                },
            ],
            object: "text_completion".to_string(),
        };

        let annotated_delta = Annotated {
            data: Some(NvCreateCompletionResponse { inner }),
            id: Some("test_id".to_string()),
            event: None,
            comment: None,
        };

        let stream = Box::pin(stream::iter(vec![annotated_delta]));

        let response = DeltaAggregator::apply(stream, ParsingOptions::default())
            .await
            .expect("aggregating multiple choices should succeed");

        assert_eq!(response.inner.choices.len(), 2);

        // No re-sorting here: `apply` itself returns the choices ordered by
        // index, and this loop verifies that ordering directly.
        for (expected_index, choice) in (0u32..).zip(&response.inner.choices) {
            assert_eq!(choice.index, expected_index);
            assert_eq!(choice.text, format!("Choice {}", expected_index));
            assert_eq!(
                choice.finish_reason,
                Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
            );
        }
    }
}