1use crate::{
4 chat, ctx,
5 util::{ChoiceIndexer, StreamOnce},
6};
7use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
8use rand::Rng;
9use rust_decimal::Decimal;
10use std::{collections::HashMap, sync::Arc, time};
11
12pub fn response_id(created: u64) -> String {
14 let uuid = uuid::Uuid::new_v4();
15 format!("vctcpl-{}-{}", uuid.simple(), created)
16}
17
/// Vector-completions client: fans one request out to an ensemble of LLMs
/// (via the inner chat client), merges their votes with cached / retried /
/// synthetic votes, and reports usage.
pub struct Client<CTXEXT, FENSLLM, CUSG, FENS, FVVOTE, FCVOTE, VUSG> {
    // Underlying chat-completions client used to run each ensemble LLM.
    pub chat_client: Arc<chat::completions::Client<CTXEXT, FENSLLM, CUSG>>,
    // Resolves ensemble ids to ensemble definitions (with caching).
    pub ensemble_fetcher:
        Arc<crate::ensemble::fetcher::CachingFetcher<CTXEXT, FENS>>,
    // Loads the votes of a previous completion when the request is a retry.
    pub completion_votes_fetcher: Arc<FVVOTE>,
    // Looks up previously-computed votes for an identical (model, prompt,
    // tools, responses) tuple so LLM calls can be skipped.
    pub cache_vote_fetcher: Arc<FCVOTE>,
    // Receives the final aggregated response for usage/billing accounting.
    pub usage_handler: Arc<VUSG>,
}
35
36impl<CTXEXT, FENSLLM, CUSG, FENS, FVVOTE, FCVOTE, VUSG>
37 Client<CTXEXT, FENSLLM, CUSG, FENS, FVVOTE, FCVOTE, VUSG>
38{
39 pub fn new(
41 chat_client: Arc<chat::completions::Client<CTXEXT, FENSLLM, CUSG>>,
42 ensemble_fetcher: Arc<
43 crate::ensemble::fetcher::CachingFetcher<CTXEXT, FENS>,
44 >,
45 completion_votes_fetcher: Arc<FVVOTE>,
46 cache_vote_fetcher: Arc<FCVOTE>,
47 usage_handler: Arc<VUSG>,
48 ) -> Self {
49 Self {
50 chat_client,
51 ensemble_fetcher,
52 completion_votes_fetcher,
53 cache_vote_fetcher,
54 usage_handler,
55 }
56 }
57}
58
impl<CTXEXT, FENSLLM, CUSG, FENS, FVVOTE, FCVOTE, VUSG>
    Client<CTXEXT, FENSLLM, CUSG, FENS, FVVOTE, FCVOTE, VUSG>
where
    CTXEXT: ctx::ContextExt + Send + Sync + 'static,
    FENSLLM: crate::ensemble_llm::fetcher::Fetcher<CTXEXT>
        + Send
        + Sync
        + 'static,
    CUSG: chat::completions::usage_handler::UsageHandler<CTXEXT>
        + Send
        + Sync
        + 'static,
    FENS: crate::ensemble::fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
    FVVOTE: super::completion_votes_fetcher::Fetcher<CTXEXT>
        + Send
        + Sync
        + 'static,
    FCVOTE: super::cache_vote_fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
    VUSG: super::usage_handler::UsageHandler<CTXEXT> + Send + Sync + 'static,
{
    /// Runs the streaming completion to exhaustion and folds every chunk
    /// into a single unary `VectorCompletion`.
    ///
    /// Errors are only those surfaced while *creating* the stream; once the
    /// stream exists, chunks are infallible (see
    /// [`Self::create_streaming_handle_usage`]).
    pub async fn create_unary_handle_usage(
        self: Arc<Self>,
        ctx: ctx::Context<CTXEXT>,
        request: Arc<objectiveai::vector::completions::request::VectorCompletionCreateParams>,
    ) -> Result<
        objectiveai::vector::completions::response::unary::VectorCompletion,
        super::Error,
    > {
        // Running aggregate of all chunks seen so far; `None` until the
        // first chunk arrives.
        let mut aggregate: Option<
            objectiveai::vector::completions::response::streaming::VectorCompletionChunk,
        > = None;
        let mut stream =
            self.create_streaming_handle_usage(ctx, request).await?;
        while let Some(chunk) = stream.next().await {
            match &mut aggregate {
                Some(aggregate) => aggregate.push(&chunk),
                None => {
                    aggregate = Some(chunk);
                }
            }
        }
        // NOTE(review): panics if the stream yields zero chunks. The
        // streaming layer below always emits at least one chunk before
        // returning Ok (it awaits the first message), so this should be
        // unreachable — confirm if that invariant ever changes.
        Ok(aggregate.unwrap().into())
    }

    /// Creates the vector-completion stream and, in a spawned background
    /// task, mirrors every chunk into an aggregate so that — once the
    /// stream is drained — usage can be reported to `usage_handler`.
    ///
    /// Channel protocol (relied on below): the spawned task sends at most
    /// one `Err`, and only as its *first* message (stream-creation
    /// failure); after the first `Ok`, every subsequent message is `Ok`.
    pub async fn create_streaming_handle_usage(
        self: Arc<Self>,
        ctx: ctx::Context<CTXEXT>,
        request: Arc<objectiveai::vector::completions::request::VectorCompletionCreateParams>,
    ) -> Result<
        impl Stream<Item = objectiveai::vector::completions::response::streaming::VectorCompletionChunk>
            + Send
            + Unpin
            + 'static,
        super::Error,
    >{
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        tokio::spawn(async move {
            // Aggregate of all chunks, used to build the unary response
            // that is handed to the usage handler at the end.
            let mut aggregate: Option<
                objectiveai::vector::completions::response::streaming::VectorCompletionChunk,
            > = None;
            let stream = match self
                .clone()
                .create_streaming(ctx.clone(), request.clone())
                .await
            {
                Ok(stream) => stream,
                Err(e) => {
                    // Creation failed: forward the error as the only
                    // message and stop.
                    let _ = tx.send(Err(e));
                    return;
                }
            };
            futures::pin_mut!(stream);
            while let Some(chunk) = stream.next().await {
                match &mut aggregate {
                    Some(aggregate) => aggregate.push(&chunk),
                    None => aggregate = Some(chunk.clone()),
                }
                // Receiver may have gone away; aggregation continues so
                // usage is still reported.
                let _ = tx.send(Ok(chunk));
            }
            // Explicitly close the channel before the (potentially slow)
            // usage call so the consumer sees end-of-stream immediately.
            drop(stream);
            drop(tx);
            let response: objectiveai::vector::completions::response::unary::VectorCompletion =
                aggregate.unwrap().into();
            // A response whose id equals the retry id, or is empty, was
            // served entirely from retry/cache/rng votes (no fresh LLM
            // work) — see the static-votes-only path in `create_streaming`.
            let all_retry_or_cached_or_rng = request
                .retry
                .as_deref()
                .is_some_and(|id| id == response.id.as_str())
                || response.id.is_empty();
            let any_ok_completions =
                response.completions.iter().any(|c| c.error.is_none());
            // Only bill when real, successful LLM work happened.
            if any_ok_completions && !all_retry_or_cached_or_rng {
                self.usage_handler
                    .handle_usage(ctx, request, response)
                    .await;
            }
        });
        let mut stream =
            tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
        // Await the first message to convert a creation failure into a
        // `Result::Err` for the caller.
        match stream.next().await {
            Some(Ok(chunk)) => {
                // Safe: per the channel protocol above, an `Err` can only
                // be the first message, and the first message was `Ok`.
                Ok(StreamOnce::new(chunk).chain(stream.map(Result::unwrap)))
            }
            Some(Err(e)) => Err(e),
            // NOTE(review): reachable only if the spawned task panics
            // before sending anything — confirm that is acceptable.
            None => unreachable!(),
        }
    }
}
172
impl<CTXEXT, FENSLLM, CUSG, FENS, FVVOTE, FCVOTE, VUSG>
    Client<CTXEXT, FENSLLM, CUSG, FENS, FVVOTE, FCVOTE, VUSG>
where
    CTXEXT: ctx::ContextExt + Send + Sync + 'static,
    FENSLLM: crate::ensemble_llm::fetcher::Fetcher<CTXEXT>
        + Send
        + Sync
        + 'static,
    CUSG: chat::completions::usage_handler::UsageHandler<CTXEXT>
        + Send
        + Sync
        + 'static,
    FENS: crate::ensemble::fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
    FVVOTE: super::completion_votes_fetcher::Fetcher<CTXEXT>
        + Send
        + Sync
        + 'static,
    FCVOTE: super::cache_vote_fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
    VUSG: Send + Sync + 'static,
{
    /// Core streaming implementation.
    ///
    /// Pipeline:
    /// 1. validate the request (>= 2 candidate responses, profile shape);
    /// 2. resolve the ensemble and collect "static" votes (retry votes,
    ///    cache hits, optional RNG votes) that need no LLM call;
    /// 3. stream the remaining LLMs concurrently, converting their chunks
    ///    into `VectorCompletionChunk`s with running `weights`/`scores`;
    /// 4. attach accumulated `usage` to the final chunk only.
    pub async fn create_streaming(
        self: Arc<Self>,
        ctx: ctx::Context<CTXEXT>,
        request: Arc<objectiveai::vector::completions::request::VectorCompletionCreateParams>,
    ) -> Result<
        impl Stream<Item = objectiveai::vector::completions::response::streaming::VectorCompletionChunk>
            + Send
            + 'static,
        super::Error,
    >{
        // Creation timestamp (unix seconds) shared by every chunk and the id.
        let created = time::SystemTime::now()
            .duration_since(time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let response_id = response_id(created);

        // A vector completion ranks candidate responses against each other;
        // fewer than two candidates is meaningless.
        let request_responses_len = request.responses.len();
        if request_responses_len < 2 {
            return Err(super::Error::ExpectedTwoOrMoreRequestVectorResponses(
                request_responses_len,
            ));
        }

        // Resolve the ensemble (by id or provided inline) and, on retry,
        // load the previous completion's votes as pre-seeded static votes.
        let (ensemble, mut static_votes) = match (
            &request.ensemble,
            &request.retry,
        ) {
            (
                objectiveai::vector::completions::request::Ensemble::Id(
                    ensemble_id,
                ),
                Some(retry),
            ) => {
                // Fetch ensemble and retry votes concurrently.
                let (ensemble, mut votes) = tokio::try_join!(
                    self.ensemble_fetcher.fetch(ctx.clone(), ensemble_id).map(
                        |result| {
                            match result {
                                Ok(Some((ensemble, _))) => Ok(ensemble),
                                Ok(None) => Err(super::Error::EnsembleNotFound),
                                Err(e) => Err(super::Error::FetchEnsemble(e)),
                            }
                        }
                    ),
                    self.completion_votes_fetcher
                        .fetch(ctx.clone(), retry)
                        .map(|result| {
                            match result {
                                Ok(Some(votes)) => Ok(votes),
                                Ok(None) => Err(super::Error::RetryNotFound),
                                Err(e) => Err(super::Error::FetchRetry(e)),
                            }
                        }),
                )?;
                // Mark as retried/cached so they are not billed as fresh work.
                votes.iter_mut().for_each(|vote| {
                    vote.retry = Some(true);
                    vote.from_cache = Some(true);
                    vote.from_rng = None;
                    vote.completion_index = None;
                });
                (ensemble, votes)
            }
            (
                objectiveai::vector::completions::request::Ensemble::Provided(
                    ensemble_base,
                ),
                Some(retry),
            ) => {
                let ensemble = ensemble_base
                    .clone()
                    .try_into()
                    .map_err(super::Error::InvalidEnsemble)?;
                let mut votes = self
                    .completion_votes_fetcher
                    .fetch(ctx.clone(), retry)
                    .map(|result| match result {
                        Ok(Some(votes)) => Ok(votes),
                        Ok(None) => Err(super::Error::RetryNotFound),
                        Err(e) => Err(super::Error::FetchRetry(e)),
                    })
                    .await?;
                // Same marking as the Id+retry arm above.
                votes.iter_mut().for_each(|vote| {
                    vote.retry = Some(true);
                    vote.from_cache = Some(true);
                    vote.from_rng = None;
                    vote.completion_index = None;
                });
                (ensemble, votes)
            }
            (
                objectiveai::vector::completions::request::Ensemble::Id(
                    ensemble_id,
                ),
                None,
            ) => {
                let ensemble = self
                    .ensemble_fetcher
                    .fetch(ctx.clone(), ensemble_id)
                    .map(|result| match result {
                        Ok(Some((ensemble, _))) => Ok(ensemble),
                        Ok(None) => Err(super::Error::EnsembleNotFound),
                        Err(e) => Err(super::Error::FetchEnsemble(e)),
                    })
                    .await?;
                (ensemble, Vec::new())
            }
            (
                objectiveai::vector::completions::request::Ensemble::Provided(
                    ensemble_base,
                ),
                None,
            ) => {
                let ensemble = ensemble_base
                    .clone()
                    .try_into()
                    .map_err(super::Error::InvalidEnsemble)?;
                (ensemble, Vec::new())
            }
        };

        // Retry votes whose vote vector length no longer matches the
        // current candidate count are unusable — drop them.
        static_votes.retain(|vote| vote.vote.len() == request_responses_len);

        // The profile assigns one weight per ensemble LLM.
        if request.profile.len() != ensemble.llms.len() {
            return Err(super::Error::InvalidProfile(
                "profile length must match ensemble length".to_string(),
            ));
        }
        let mut positive_weight_count = 0;
        for weight in &request.profile {
            // NOTE(review): the outer `> ZERO` guard means negative weights
            // never reach the range check below (the `< ZERO` arm is dead
            // code) and the inner `else if > ZERO` is always true here.
            // Negatives are instead silently filtered out of `llms` further
            // down — confirm whether silently skipping (vs erroring) is the
            // intended contract.
            if *weight > Decimal::ZERO {
                if *weight > Decimal::ONE || *weight < Decimal::ZERO {
                    return Err(super::Error::InvalidProfile(
                        "profile weights must be between 0 and 1".to_string(),
                    ));
                } else if *weight > Decimal::ZERO {
                    positive_weight_count += 1;
                }
            }
        }
        // With fewer than two participating LLMs there is no ensemble to vote.
        if positive_weight_count < 2 {
            return Err(super::Error::InvalidProfile(
                "profile must have two or more positive weights".to_string(),
            ));
        }

        // Stable content ids for the prompt, tools, and each candidate
        // response — used to match cached votes and label fresh ones.
        let prompt_id = {
            let mut prompt = request.messages.clone();
            objectiveai::chat::completions::request::prompt::prepare(
                &mut prompt,
            );
            objectiveai::chat::completions::request::prompt::id(&prompt)
        };
        let tools_id = match &request.tools {
            Some(tools) if !tools.is_empty() => {
                Some(objectiveai::chat::completions::request::tools::id(tools))
            }
            _ => None,
        };
        let responses_ids = {
            let mut responses = request.responses.clone();
            let mut responses_ids = Vec::with_capacity(responses.len());
            for response in &mut responses {
                response.prepare();
                responses_ids.push(response.id());
            }
            responses_ids
        };

        // Flatten the ensemble: each LLM is repeated `count` times; the
        // position in this flattened sequence is the `flat_ensemble_index`.
        // LLMs with non-positive weight or an already-present static vote
        // are excluded from fresh execution.
        let mut llms = ensemble
            .llms
            .into_iter()
            .enumerate()
            .flat_map(|(ensemble_index, llm)| {
                let count = llm.count as usize;
                std::iter::repeat_n(
                    (ensemble_index, llm, request.profile[ensemble_index]),
                    count,
                )
            })
            .enumerate()
            .filter_map(
                |(flat_ensemble_index, (ensemble_index, llm, weight))| {
                    if weight <= Decimal::ZERO {
                        None
                    } else if static_votes.iter().any(|v| {
                        v.flat_ensemble_index == flat_ensemble_index as u64
                    }) {
                        None
                    } else {
                        Some((flat_ensemble_index, ensemble_index, llm, weight))
                    }
                },
            )
            .collect::<Vec<_>>();

        // Optional cache lookup: try to satisfy each remaining LLM slot
        // from previously-computed votes instead of calling the model.
        if request.from_cache.is_some_and(|bool| bool) {
            // Borrow-friendly (model, fallbacks) snapshots so the async
            // fetch futures below don't borrow `llms` elements directly.
            let mut model_refs = Vec::with_capacity(llms.len());
            for (_, _, llm, _) in &llms {
                let model =
                    objectiveai::chat::completions::request::Model::Provided(
                        llm.inner.base.clone(),
                    );
                let models = llm.fallbacks.as_ref().map(|fallbacks| {
                    fallbacks
                        .iter()
                        .map(|fallback| objectiveai::chat::completions::request::Model::Provided(
                            fallback.base.clone(),
                        ))
                        .collect::<Vec<_>>()
                });
                model_refs.push((model, models));
            }
            let mut futs = Vec::with_capacity(llms.len());
            for (
                (flat_ensemble_index, ensemble_index, _, weight),
                (model, models),
            ) in llms.iter().zip(model_refs.iter())
            {
                let cache_vote_fetcher = self.cache_vote_fetcher.clone();
                let request = request.clone();
                let ctx = ctx.clone();
                let responses_ids = responses_ids.clone();
                futs.push(async move {
                    match cache_vote_fetcher.fetch(
                        ctx,
                        model,
                        models.as_deref(),
                        &request.messages,
                        request.tools.as_deref(),
                        &request.responses,
                    ).await {
                        Ok(Some(mut vote)) => {
                            // Re-home the cached vote onto this request's
                            // ensemble slot.
                            vote.ensemble_index = *ensemble_index as u64;
                            vote.flat_ensemble_index = *flat_ensemble_index as u64;
                            vote.weight = *weight;
                            vote.retry = None;
                            vote.from_cache = Some(true);
                            vote.completion_index = None;

                            // The cached vote may list responses in a
                            // different order; realign it to this request's
                            // `responses_ids` order.
                            let mut rearranged_vote = vec![
                                Decimal::ZERO;
                                request_responses_len
                            ];
                            for (i, response_id) in
                                responses_ids.iter().enumerate()
                            {
                                let pos = vote
                                    .responses_ids
                                    .iter()
                                    .position(|id| id == response_id)
                                    .expect(
                                        "data integrity error: response ID not found in vote responses IDs",
                                    );
                                rearranged_vote[i] = vote.vote[pos];
                            }
                            vote.vote = rearranged_vote;
                            vote.responses_ids = responses_ids;

                            Ok(Some(vote))
                        }
                        Ok(None) => Ok(None),
                        Err(e) => Err(super::Error::FetchCacheVote(e))
                    }
                });
            }
            let cached_votes = futures::future::try_join_all(futs).await?;
            static_votes.reserve(cached_votes.iter().flatten().count());
            for vote in cached_votes.into_iter().flatten() {
                static_votes.push(vote);
            }
        }

        // Remove LLM slots now satisfied by a cached vote.
        llms.retain(|(flat_ensemble_index, _, _, _)| {
            !static_votes
                .iter()
                .any(|v| v.flat_ensemble_index == *flat_ensemble_index as u64)
        });

        // Optional RNG mode: replace every remaining LLM call with a random
        // vote, normalized so its components sum to 1.
        if request.from_rng.is_some_and(|bool| bool) {
            let mut rng = rand::rng();
            for (flat_ensemble_index, ensemble_index, llm, weight) in &llms {
                let mut vote = vec![Decimal::ZERO; request_responses_len];
                let mut sum = Decimal::ZERO;
                for i in 0..request_responses_len {
                    // Uniform draw in [0, 1] as a Decimal ratio.
                    let v = Decimal::from(rng.random_range(0..=u64::MAX))
                        / Decimal::from(u64::MAX);
                    vote[i] = v;
                    sum += v;
                }
                // NOTE(review): divides by `sum` without a zero check; all
                // draws being exactly 0 is astronomically unlikely but would
                // panic here — confirm acceptable.
                for v in &mut vote {
                    *v /= sum;
                }
                static_votes.push(
                    objectiveai::vector::completions::response::Vote {
                        model: llm.inner.id.clone(),
                        ensemble_index: *ensemble_index as u64,
                        flat_ensemble_index: *flat_ensemble_index as u64,
                        prompt_id: prompt_id.clone(),
                        tools_id: tools_id.clone(),
                        responses_ids: responses_ids.clone(),
                        vote,
                        weight: *weight,
                        retry: None,
                        from_cache: None,
                        from_rng: Some(true),
                        completion_index: None,
                    },
                );
            }
        }

        // Remove LLM slots now satisfied by an RNG vote.
        llms.retain(|(flat_ensemble_index, _, _, _)| {
            !static_votes
                .iter()
                .any(|v| v.flat_ensemble_index == *flat_ensemble_index as u64)
        });

        // Deterministic ordering of static votes by flattened slot.
        static_votes.sort_by_key(|vote| vote.flat_ensemble_index);

        // Accumulated token usage across all chat completions; attached
        // only to the final chunk.
        let mut usage =
            objectiveai::vector::completions::response::Usage::default();

        // Running weighted vote mass per candidate, and normalized scores
        // (initialized uniform until the first vote arrives).
        let mut weights = vec![Decimal::ZERO; request_responses_len];
        let mut scores = vec![
            Decimal::ONE
                / Decimal::from(request_responses_len);
            request_responses_len
        ];

        // Assigns each LLM stream a stable completion index as it appears.
        let indexer = Arc::new(ChoiceIndexer::new(0));

        // One chunk stream per remaining LLM, merged into a single stream.
        let mut vote_stream =
            futures::stream::select_all(llms.into_iter().map(
                |(flat_ensemble_index, ensemble_index, llm, weight)| {
                    futures::stream::once(self.clone().llm_create_streaming(
                        ctx.clone(),
                        response_id.clone(),
                        created,
                        ensemble.id.clone(),
                        indexer.clone(),
                        llm,
                        ensemble_index,
                        flat_ensemble_index,
                        weight,
                        request.clone(),
                        prompt_id.clone(),
                        tools_id.clone(),
                        responses_ids.clone(),
                    ))
                    .flatten()
                    .boxed()
                },
            ));

        // Fully satisfied from static votes: emit one synthetic chunk and
        // finish. Its id is the retry id (or empty), which the usage layer
        // uses to recognize "no fresh LLM work was done".
        if vote_stream.len() == 0 {
            if static_votes.len() > 0 {
                for vote in &static_votes {
                    for (i, v) in vote.vote.iter().enumerate() {
                        weights[i] += *v * vote.weight;
                    }
                }
                let weight_sum: Decimal = weights.iter().sum();
                if weight_sum > Decimal::ZERO {
                    for (i, score) in scores.iter_mut().enumerate() {
                        *score = weights[i] / weight_sum;
                    }
                }
                return Ok(futures::future::Either::Left(StreamOnce::new(
                    objectiveai::vector::completions::response::streaming::VectorCompletionChunk {
                        id: request.retry.clone().unwrap_or_default(),
                        completions: Vec::new(),
                        votes: static_votes,
                        scores,
                        weights,
                        created,
                        ensemble: ensemble.id,
                        object: objectiveai::vector::completions::response::streaming::Object::VectorCompletionChunk,
                        usage: None,
                    }
                )));
            } else {
                // Profile validation guarantees >= 2 positive weights, so
                // an empty stream with no static votes cannot occur.
                unreachable!()
            }
        }

        // Pull the first chunk now so the stream below can look one chunk
        // ahead (needed to attach `usage` to the *last* chunk).
        let mut next_chunk = match vote_stream.next().await {
            Some(chunk) => Some(chunk),
            None => {
                // `vote_stream.len() != 0` was checked above.
                unreachable!()
            }
        };

        Ok(futures::future::Either::Right(async_stream::stream! {
            while let Some(mut chunk) = next_chunk.take() {
                // Look-ahead: fetch the successor before yielding, so we
                // know whether `chunk` is the final one.
                next_chunk = vote_stream.next().await;

                // Prepend all pending static votes onto the first chunk
                // that goes out (and any chunk thereafter until drained).
                if static_votes.len() > 0 {
                    for vote in chunk.votes.drain(..) {
                        static_votes.push(vote);
                    }
                    chunk.votes = std::mem::take(&mut static_votes);
                }

                // Fold per-completion token usage into the running total.
                for completion in &chunk.completions
                {
                    if let Some(completion_usage) = &completion.inner.usage {
                        usage.push_chat_completion_usage(&completion_usage);
                    }
                }

                // Accumulate weighted vote mass from this chunk's votes.
                let mut vote_found = false;
                for vote in &chunk.votes {
                    vote_found = true;
                    for (i, v) in vote.vote.iter().enumerate() {
                        weights[i] += *v * vote.weight;
                    }
                }

                // Renormalize scores only when new votes arrived.
                if vote_found {
                    let weight_sum: Decimal = weights.iter().sum();
                    if weight_sum > Decimal::ZERO {
                        for (i, score) in scores.iter_mut().enumerate() {
                            *score = weights[i] / weight_sum;
                        }
                    }
                }

                chunk.weights = weights.clone();
                chunk.scores = scores.clone();

                // Final chunk carries the accumulated usage.
                if next_chunk.is_none() {
                    chunk.usage = Some(usage.clone());
                }

                yield chunk;
            }
        }))
    }

    /// Streams one (possibly fallback-backed) ensemble LLM, converting each
    /// chat chunk into a `VectorCompletionChunk` and extracting votes from
    /// the aggregated completion once its stream ends.
    ///
    /// Never fails: creation/stream errors are converted into a single
    /// error chunk via [`Self::llm_create_streaming_vector_error`].
    async fn llm_create_streaming(
        self: Arc<Self>,
        ctx: ctx::Context<CTXEXT>,
        id: String,
        created: u64,
        ensemble: String,
        indexer: Arc<ChoiceIndexer>,
        llm: objectiveai::ensemble_llm::EnsembleLlmWithFallbacksAndCount,
        ensemble_index: usize,
        flat_ensemble_index: usize,
        weight: Decimal,
        request: Arc<objectiveai::vector::completions::request::VectorCompletionCreateParams>,
        prompt_id: String,
        tools_id: Option<String>,
        responses_ids: Vec<String>,
    ) -> impl Stream<Item = objectiveai::vector::completions::response::streaming::VectorCompletionChunk> + Send + 'static
    {
        let request_responses_len = request.responses.len();

        // Per-model prefix trees/regex patterns used later to parse votes
        // out of the model output; one entry per primary + fallback model,
        // keyed by model id.
        let (vector_pfx_data, vector_pfx_indices) = {
            let mut rng = rand::rng();
            let mut vector_pfx_data = HashMap::with_capacity(
                1 + llm.fallbacks.as_ref().map(Vec::len).unwrap_or(0),
            );
            let mut vector_pfx_indices = Vec::with_capacity(
                1 + llm.fallbacks.as_ref().map(Vec::len).unwrap_or(0),
            );
            for llm in std::iter::once(&llm.inner).chain(
                llm.fallbacks
                    .iter()
                    .map(|fallbacks| fallbacks.iter())
                    .flatten(),
            ) {
                let pfx_tree = super::PfxTree::new(
                    &mut rng,
                    request_responses_len,
                    // top_logprobs of 0/1/None falls back to 20.
                    match llm.base.top_logprobs {
                        Some(0) | Some(1) | None => 20,
                        Some(top_logprobs) => top_logprobs as usize,
                    },
                );

                let pfx_indices =
                    pfx_tree.pfx_indices(&mut rng, request_responses_len);

                let (
                    responses_key_pattern,
                    responses_key_pattern_stripped,
                ) = pfx_tree.regex_patterns(&pfx_indices);

                vector_pfx_data.insert(
                    llm.id.clone(),
                    super::PfxData {
                        pfx_tree,
                        responses_key_pattern,
                        responses_key_pattern_stripped,
                    },
                );
                vector_pfx_indices.push(Arc::new(pfx_indices));
            }
            (vector_pfx_data, vector_pfx_indices)
        };

        // Start the underlying chat stream; a creation error becomes a
        // one-chunk error stream.
        let mut stream = match self
            .chat_client
            .clone()
            .create_streaming_for_vector_handle_usage(
                ctx,
                request,
                vector_pfx_indices,
                llm,
            )
            .await
        {
            Ok(stream) => stream,
            Err(e) => {
                return futures::future::Either::Left(
                    Self::llm_create_streaming_vector_error(
                        id,
                        indexer.get(flat_ensemble_index),
                        e,
                        created,
                        ensemble,
                    ),
                );
            }
        };

        // Pull the first chat chunk eagerly (same look-ahead pattern as
        // `create_streaming`) so an immediate error can short-circuit.
        let mut next_chat_chunk = match stream.try_next().await {
            Ok(Some(chunk)) => Some(chunk),
            Err(e) => {
                return futures::future::Either::Left(
                    Self::llm_create_streaming_vector_error(
                        id,
                        indexer.get(flat_ensemble_index),
                        e,
                        created,
                        ensemble,
                    ),
                );
            }
            Ok(None) => {
                // NOTE(review): assumes the chat stream always yields at
                // least one chunk before ending — confirm upstream.
                unreachable!()
            }
        };

        // Aggregate of all chunks for this LLM; used to extract votes when
        // the stream ends.
        let mut aggregate: Option<
            objectiveai::vector::completions::response::streaming::VectorCompletionChunk,
        > = None;

        futures::future::Either::Right(async_stream::stream! {
            while let Some(chat_chunk) = next_chat_chunk.take() {
                // Look ahead; a mid-stream error is folded into the
                // *current* chunk's `error` field and ends the stream.
                let error = match stream.next().await {
                    Some(Ok(ncc)) => {
                        next_chat_chunk = Some(ncc);
                        None
                    }
                    Some(Err(e)) => {
                        Some(objectiveai::error::ResponseError::from(&e))
                    }
                    None => {
                        None
                    }
                };

                let mut chunk = objectiveai::vector::completions::response::streaming::VectorCompletionChunk {
                    id: id.clone(),
                    completions: vec![
                        objectiveai::vector::completions::response::streaming::ChatCompletionChunk {
                            index: indexer.get(flat_ensemble_index),
                            inner: chat_chunk,
                            error,
                        },
                    ],
                    votes: Vec::new(),
                    scores: Vec::new(),
                    weights: Vec::new(),
                    created,
                    ensemble: ensemble.clone(),
                    object: objectiveai::vector::completions::response::streaming::Object::VectorCompletionChunk,
                    usage: None,
                };

                match aggregate {
                    Some(ref mut aggregate) => {
                        aggregate.push(&chunk);
                    }
                    None => {
                        aggregate = Some(chunk.clone());
                    }
                }

                // Final chunk for this LLM: parse votes out of the fully
                // aggregated completions and attach them to this chunk.
                if next_chat_chunk.is_none() {
                    let aggregate = aggregate.take().unwrap();
                    for completion in aggregate.completions {
                        let super::PfxData {
                            pfx_tree,
                            responses_key_pattern,
                            responses_key_pattern_stripped,
                        } = &vector_pfx_data[&completion.inner.model];

                        for choice in completion.inner.choices {
                            if let Some(vote) = super::get_vote(
                                pfx_tree.clone(),
                                &responses_key_pattern,
                                &responses_key_pattern_stripped,
                                request_responses_len,
                                &choice,
                            ) {
                                chunk.votes.push(objectiveai::vector::completions::response::Vote {
                                    model: completion.inner.model.clone(),
                                    ensemble_index: ensemble_index as u64,
                                    flat_ensemble_index: flat_ensemble_index as u64,
                                    prompt_id: prompt_id.clone(),
                                    tools_id: tools_id.clone(),
                                    responses_ids: responses_ids.clone(),
                                    vote,
                                    weight,
                                    retry: None,
                                    from_cache: None,
                                    from_rng: None,
                                    completion_index: Some(completion.index),
                                });
                            }
                        }
                    }
                }

                yield chunk;
            }
        })
    }

    /// Builds a single-chunk stream carrying only an error for one LLM
    /// slot: a default (empty) inner chat chunk with the error attached,
    /// no votes/scores/weights/usage.
    fn llm_create_streaming_vector_error(
        id: String,
        completion_index: u64,
        error: chat::completions::Error,
        created: u64,
        ensemble: String,
    ) -> impl Stream<Item = objectiveai::vector::completions::response::streaming::VectorCompletionChunk>
           + Send
           + Unpin
           + 'static
    {
        StreamOnce::new(
            objectiveai::vector::completions::response::streaming::VectorCompletionChunk {
                id,
                completions: vec![
                    objectiveai::vector::completions::response::streaming::ChatCompletionChunk {
                        index: completion_index,
                        inner: objectiveai::chat::completions::response::streaming::ChatCompletionChunk::default(),
                        error: Some(objectiveai::error::ResponseError::from(&error)),
                    },
                ],
                votes: Vec::new(),
                scores: Vec::new(),
                weights: Vec::new(),
                created,
                ensemble,
                object: objectiveai::vector::completions::response::streaming::Object::VectorCompletionChunk,
                usage: None,
            }
        )
    }
}