// objectiveai_api/vector/completions/client.rs
//! Vector completion client implementation.

use crate::{
    chat, ctx,
    util::{ChoiceIndexer, StreamOnce},
};
use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
use rand::Rng;
use rust_decimal::Decimal;
use std::{collections::HashMap, sync::Arc, time};
11
12/// Generates a unique response ID for a vector completion.
13pub fn response_id(created: u64) -> String {
14    let uuid = uuid::Uuid::new_v4();
15    format!("vctcpl-{}-{}", uuid.simple(), created)
16}
17
/// Client for creating vector completions.
///
/// Orchestrates multiple LLM chat completions to vote on response options,
/// combining their votes using weights to produce final scores.
///
/// The generic parameters are the client's pluggable dependencies: `CTXEXT`
/// is the request-context extension; `FENSLLM` and `FENS` fetch ensemble-LLM
/// and Ensemble definitions; `CUSG` and `VUSG` handle chat-level and
/// vector-level usage tracking; `FVVOTE` and `FCVOTE` fetch votes from
/// historical completions and from the global cache, respectively.
pub struct Client<CTXEXT, FENSLLM, CUSG, FENS, FVVOTE, FCVOTE, VUSG> {
    /// The underlying chat completion client used to run each ensemble LLM.
    pub chat_client: Arc<chat::completions::Client<CTXEXT, FENSLLM, CUSG>>,
    /// Fetcher for Ensemble definitions (wrapped in a caching layer).
    pub ensemble_fetcher:
        Arc<crate::ensemble::fetcher::CachingFetcher<CTXEXT, FENS>>,
    /// Fetcher for votes from historical completions (used on retry).
    pub completion_votes_fetcher: Arc<FVVOTE>,
    /// Fetcher for votes from the global cache.
    pub cache_vote_fetcher: Arc<FCVOTE>,
    /// Handler for usage tracking.
    pub usage_handler: Arc<VUSG>,
}
35
36impl<CTXEXT, FENSLLM, CUSG, FENS, FVVOTE, FCVOTE, VUSG>
37    Client<CTXEXT, FENSLLM, CUSG, FENS, FVVOTE, FCVOTE, VUSG>
38{
39    /// Creates a new vector completion client.
40    pub fn new(
41        chat_client: Arc<chat::completions::Client<CTXEXT, FENSLLM, CUSG>>,
42        ensemble_fetcher: Arc<
43            crate::ensemble::fetcher::CachingFetcher<CTXEXT, FENS>,
44        >,
45        completion_votes_fetcher: Arc<FVVOTE>,
46        cache_vote_fetcher: Arc<FCVOTE>,
47        usage_handler: Arc<VUSG>,
48    ) -> Self {
49        Self {
50            chat_client,
51            ensemble_fetcher,
52            completion_votes_fetcher,
53            cache_vote_fetcher,
54            usage_handler,
55        }
56    }
57}
58
impl<CTXEXT, FENSLLM, CUSG, FENS, FVVOTE, FCVOTE, VUSG>
    Client<CTXEXT, FENSLLM, CUSG, FENS, FVVOTE, FCVOTE, VUSG>
where
    CTXEXT: ctx::ContextExt + Send + Sync + 'static,
    FENSLLM: crate::ensemble_llm::fetcher::Fetcher<CTXEXT>
        + Send
        + Sync
        + 'static,
    CUSG: chat::completions::usage_handler::UsageHandler<CTXEXT>
        + Send
        + Sync
        + 'static,
    FENS: crate::ensemble::fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
    FVVOTE: super::completion_votes_fetcher::Fetcher<CTXEXT>
        + Send
        + Sync
        + 'static,
    FCVOTE: super::cache_vote_fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
    VUSG: super::usage_handler::UsageHandler<CTXEXT> + Send + Sync + 'static,
{
    /// Creates a unary (non-streaming) vector completion with usage tracking.
    ///
    /// Collects all streaming chunks into a single response.
    ///
    /// # Errors
    ///
    /// Propagates any error produced while creating the underlying stream.
    pub async fn create_unary_handle_usage(
        self: Arc<Self>,
        ctx: ctx::Context<CTXEXT>,
        request: Arc<objectiveai::vector::completions::request::VectorCompletionCreateParams>,
    ) -> Result<
        objectiveai::vector::completions::response::unary::VectorCompletion,
        super::Error,
    > {
        // Running aggregate of all chunks seen so far; `None` until the
        // first chunk arrives.
        let mut aggregate: Option<
            objectiveai::vector::completions::response::streaming::VectorCompletionChunk,
        > = None;
        let mut stream =
            self.create_streaming_handle_usage(ctx, request).await?;
        while let Some(chunk) = stream.next().await {
            match &mut aggregate {
                // Fold each subsequent chunk into the aggregate.
                Some(aggregate) => aggregate.push(&chunk),
                None => {
                    aggregate = Some(chunk);
                }
            }
        }
        // Safe: `create_streaming_handle_usage` only returns `Ok` after its
        // first chunk has been received, so the stream yields at least one
        // chunk and `aggregate` is `Some` here.
        Ok(aggregate.unwrap().into())
    }

    /// Creates a streaming vector completion with usage tracking.
    ///
    /// Spawns a background task to track usage after the stream completes.
    /// The returned stream is fed through an unbounded channel so the
    /// spawned task can keep aggregating chunks (for usage accounting)
    /// independently of how fast the caller consumes them.
    pub async fn create_streaming_handle_usage(
        self: Arc<Self>,
        ctx: ctx::Context<CTXEXT>,
        request: Arc<objectiveai::vector::completions::request::VectorCompletionCreateParams>,
    ) -> Result<
        impl Stream<Item = objectiveai::vector::completions::response::streaming::VectorCompletionChunk>
        + Send
        + Unpin
        + 'static,
        super::Error,
    >{
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        tokio::spawn(async move {
            // Aggregate of every chunk produced, used to build the unary
            // response handed to the usage handler at the end.
            let mut aggregate: Option<
                objectiveai::vector::completions::response::streaming::VectorCompletionChunk,
            > = None;
            let stream = match self
                .clone()
                .create_streaming(ctx.clone(), request.clone())
                .await
            {
                Ok(stream) => stream,
                Err(e) => {
                    // Protocol: only the FIRST item sent on the channel may
                    // be an `Err`; the receiver below relies on this.
                    let _ = tx.send(Err(e));
                    return;
                }
            };
            futures::pin_mut!(stream);
            while let Some(chunk) = stream.next().await {
                match &mut aggregate {
                    Some(aggregate) => aggregate.push(&chunk),
                    None => aggregate = Some(chunk.clone()),
                }
                // Receiver may have been dropped; usage tracking proceeds
                // regardless, so the send result is intentionally ignored.
                let _ = tx.send(Ok(chunk));
            }
            // Close the channel before the (potentially slow) usage call so
            // the caller's stream terminates as soon as chunks stop.
            drop(stream);
            drop(tx);
            let response: objectiveai::vector::completions::response::unary::VectorCompletion =
                aggregate.unwrap().into();
            // An empty id, or an id equal to the retry id, indicates every
            // vote came from retry/cache/rng rather than fresh completions —
            // presumably nothing billable happened; verify against
            // `create_streaming`'s early-return path.
            let all_retry_or_cached_or_rng = request
                .retry
                .as_deref()
                .is_some_and(|id| id == response.id.as_str())
                || response.id.is_empty();
            let any_ok_completions =
                response.completions.iter().any(|c| c.error.is_none());
            // Record usage only when at least one completion succeeded and
            // fresh work was actually performed.
            if any_ok_completions && !all_retry_or_cached_or_rng {
                self.usage_handler
                    .handle_usage(ctx, request, response)
                    .await;
            }
        });
        let mut stream =
            tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
        // Await the first item to decide success/failure of this call; on
        // success, re-prepend it to the remaining stream.
        match stream.next().await {
            Some(Ok(chunk)) => {
                // Unwrap is safe: per the sender protocol above, only the
                // first item can be an `Err`, and it was `Ok`.
                Ok(StreamOnce::new(chunk).chain(stream.map(Result::unwrap)))
            }
            Some(Err(e)) => Err(e),
            None => unreachable!(),
        }
    }
}
172
173impl<CTXEXT, FENSLLM, CUSG, FENS, FVVOTE, FCVOTE, VUSG>
174    Client<CTXEXT, FENSLLM, CUSG, FENS, FVVOTE, FCVOTE, VUSG>
175where
176    CTXEXT: ctx::ContextExt + Send + Sync + 'static,
177    FENSLLM: crate::ensemble_llm::fetcher::Fetcher<CTXEXT>
178        + Send
179        + Sync
180        + 'static,
181    CUSG: chat::completions::usage_handler::UsageHandler<CTXEXT>
182        + Send
183        + Sync
184        + 'static,
185    FENS: crate::ensemble::fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
186    FVVOTE: super::completion_votes_fetcher::Fetcher<CTXEXT>
187        + Send
188        + Sync
189        + 'static,
190    FCVOTE: super::cache_vote_fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
191    VUSG: Send + Sync + 'static,
192{
193    /// Creates a streaming vector completion.
194    ///
195    /// Orchestrates chat completions across all LLMs in the ensemble, extracting
196    /// votes from each and combining them with weights to produce scores.
197    pub async fn create_streaming(
198        self: Arc<Self>,
199        ctx: ctx::Context<CTXEXT>,
200        request: Arc<objectiveai::vector::completions::request::VectorCompletionCreateParams>,
201    ) -> Result<
202        impl Stream<Item = objectiveai::vector::completions::response::streaming::VectorCompletionChunk>
203        + Send
204        + 'static,
205        super::Error,
206    >{
207        // timestamp and identify the completion
208        let created = time::SystemTime::now()
209            .duration_since(time::UNIX_EPOCH)
210            .unwrap()
211            .as_secs();
212        let response_id = response_id(created);
213
214        // validate response count
215        let request_responses_len = request.responses.len();
216        if request_responses_len < 2 {
217            return Err(super::Error::ExpectedTwoOrMoreRequestVectorResponses(
218                request_responses_len,
219            ));
220        }
221
222        // validate credits + fetch ensemble if needed + fetch retry votes if needed
223        let (ensemble, mut static_votes) = match (
224            &request.ensemble,
225            &request.retry,
226        ) {
227            (
228                objectiveai::vector::completions::request::Ensemble::Id(
229                    ensemble_id,
230                ),
231                Some(retry),
232            ) => {
233                let (ensemble, mut votes) = tokio::try_join!(
234                    self.ensemble_fetcher.fetch(ctx.clone(), ensemble_id).map(
235                        |result| {
236                            match result {
237                                Ok(Some((ensemble, _))) => Ok(ensemble),
238                                Ok(None) => Err(super::Error::EnsembleNotFound),
239                                Err(e) => Err(super::Error::FetchEnsemble(e)),
240                            }
241                        }
242                    ),
243                    self.completion_votes_fetcher
244                        .fetch(ctx.clone(), retry)
245                        .map(|result| {
246                            match result {
247                                Ok(Some(votes)) => Ok(votes),
248                                Ok(None) => Err(super::Error::RetryNotFound),
249                                Err(e) => Err(super::Error::FetchRetry(e)),
250                            }
251                        }),
252                )?;
253                votes.iter_mut().for_each(|vote| {
254                    vote.retry = Some(true);
255                    vote.from_cache = Some(true);
256                    vote.from_rng = None;
257                    vote.completion_index = None;
258                });
259                (ensemble, votes)
260            }
261            (
262                objectiveai::vector::completions::request::Ensemble::Provided(
263                    ensemble_base,
264                ),
265                Some(retry),
266            ) => {
267                let ensemble = ensemble_base
268                    .clone()
269                    .try_into()
270                    .map_err(super::Error::InvalidEnsemble)?;
271                let mut votes = self
272                    .completion_votes_fetcher
273                    .fetch(ctx.clone(), retry)
274                    .map(|result| match result {
275                        Ok(Some(votes)) => Ok(votes),
276                        Ok(None) => Err(super::Error::RetryNotFound),
277                        Err(e) => Err(super::Error::FetchRetry(e)),
278                    })
279                    .await?;
280                votes.iter_mut().for_each(|vote| {
281                    vote.retry = Some(true);
282                    vote.from_cache = Some(true);
283                    vote.from_rng = None;
284                    vote.completion_index = None;
285                });
286                (ensemble, votes)
287            }
288            (
289                objectiveai::vector::completions::request::Ensemble::Id(
290                    ensemble_id,
291                ),
292                None,
293            ) => {
294                let ensemble = self
295                    .ensemble_fetcher
296                    .fetch(ctx.clone(), ensemble_id)
297                    .map(|result| match result {
298                        Ok(Some((ensemble, _))) => Ok(ensemble),
299                        Ok(None) => Err(super::Error::EnsembleNotFound),
300                        Err(e) => Err(super::Error::FetchEnsemble(e)),
301                    })
302                    .await?;
303                (ensemble, Vec::new())
304            }
305            (
306                objectiveai::vector::completions::request::Ensemble::Provided(
307                    ensemble_base,
308                ),
309                None,
310            ) => {
311                let ensemble = ensemble_base
312                    .clone()
313                    .try_into()
314                    .map_err(super::Error::InvalidEnsemble)?;
315                (ensemble, Vec::new())
316            }
317        };
318
319        // prune votes that don't match responses length
320        static_votes.retain(|vote| vote.vote.len() == request_responses_len);
321
322        // validate profile
323        if request.profile.len() != ensemble.llms.len() {
324            return Err(super::Error::InvalidProfile(
325                "profile length must match ensemble length".to_string(),
326            ));
327        }
328        let mut positive_weight_count = 0;
329        for weight in &request.profile {
330            if *weight > Decimal::ZERO {
331                if *weight > Decimal::ONE || *weight < Decimal::ZERO {
332                    return Err(super::Error::InvalidProfile(
333                        "profile weights must be between 0 and 1".to_string(),
334                    ));
335                } else if *weight > Decimal::ZERO {
336                    positive_weight_count += 1;
337                }
338            }
339        }
340        if positive_weight_count < 2 {
341            return Err(super::Error::InvalidProfile(
342                "profile must have two or more positive weights".to_string(),
343            ));
344        }
345
346        // compute hash IDs
347        let prompt_id = {
348            let mut prompt = request.messages.clone();
349            objectiveai::chat::completions::request::prompt::prepare(
350                &mut prompt,
351            );
352            objectiveai::chat::completions::request::prompt::id(&prompt)
353        };
354        let tools_id = match &request.tools {
355            Some(tools) if !tools.is_empty() => {
356                Some(objectiveai::chat::completions::request::tools::id(tools))
357            }
358            _ => None,
359        };
360        let responses_ids = {
361            let mut responses = request.responses.clone();
362            let mut responses_ids = Vec::with_capacity(responses.len());
363            for response in &mut responses {
364                response.prepare();
365                responses_ids.push(response.id());
366            }
367            responses_ids
368        };
369
370        // create a vector of LLMs with useful info
371        // only ones that may stream
372        let mut llms = ensemble
373            .llms
374            .into_iter()
375            .enumerate()
376            .flat_map(|(ensemble_index, llm)| {
377                let count = llm.count as usize;
378                std::iter::repeat_n(
379                    (ensemble_index, llm, request.profile[ensemble_index]),
380                    count,
381                )
382            })
383            .enumerate()
384            .filter_map(
385                |(flat_ensemble_index, (ensemble_index, llm, weight))| {
386                    if weight <= Decimal::ZERO {
387                        // skip LLMs with zero weight
388                        None
389                    } else if static_votes.iter().any(|v| {
390                        v.flat_ensemble_index == flat_ensemble_index as u64
391                    }) {
392                        // skip LLMs that have votes already
393                        None
394                    } else {
395                        Some((flat_ensemble_index, ensemble_index, llm, weight))
396                    }
397                },
398            )
399            .collect::<Vec<_>>();
400
401        // fetch from cache if requested
402        if request.from_cache.is_some_and(|bool| bool) {
403            // collect model refs so they're owned here
404            let mut model_refs = Vec::with_capacity(llms.len());
405            for (_, _, llm, _) in &llms {
406                let model =
407                    objectiveai::chat::completions::request::Model::Provided(
408                        llm.inner.base.clone(),
409                    );
410                let models = llm.fallbacks.as_ref().map(|fallbacks| {
411                    fallbacks
412                        .iter()
413                        .map(|fallback| objectiveai::chat::completions::request::Model::Provided(
414                            fallback.base.clone(),
415                        ))
416                        .collect::<Vec<_>>()
417                });
418                model_refs.push((model, models));
419            }
420            // execute the futures
421            let mut futs = Vec::with_capacity(llms.len());
422            for (
423                (flat_ensemble_index, ensemble_index, _, weight),
424                (model, models),
425            ) in llms.iter().zip(model_refs.iter())
426            {
427                let cache_vote_fetcher = self.cache_vote_fetcher.clone();
428                let request = request.clone();
429                let ctx = ctx.clone();
430                let responses_ids = responses_ids.clone();
431                futs.push(async move {
432                    match cache_vote_fetcher.fetch(
433                        ctx,
434                        model,
435                        models.as_deref(),
436                        &request.messages,
437                        request.tools.as_deref(),
438                        &request.responses,
439                    ).await {
440                        Ok(Some(mut vote)) => {
441                            // update fields
442                            vote.ensemble_index = *ensemble_index as u64;
443                            vote.flat_ensemble_index = *flat_ensemble_index as u64;
444                            vote.weight = *weight;
445                            vote.retry = None;
446                            vote.from_cache = Some(true);
447                            vote.completion_index = None;
448
449                            // rearrange vote vector to match response order
450                            let mut rearranged_vote = vec![
451                                Decimal::ZERO;
452                                request_responses_len
453                            ];
454                            for (i, response_id) in
455                                responses_ids.iter().enumerate()
456                            {
457                                let pos = vote
458                                    .responses_ids
459                                    .iter()
460                                    .position(|id| id == response_id)
461                                    .expect(
462                                        "data integrity error: response ID not found in vote responses IDs",
463                                    );
464                                rearranged_vote[i] = vote.vote[pos];
465                            }
466                            vote.vote = rearranged_vote;
467                            vote.responses_ids = responses_ids;
468
469                            // return vote
470                            Ok(Some(vote))
471                        }
472                        Ok(None) => Ok(None),
473                        Err(e) => Err(super::Error::FetchCacheVote(e))
474                    }
475                });
476            }
477            let cached_votes = futures::future::try_join_all(futs).await?;
478            static_votes.reserve(cached_votes.iter().flatten().count());
479            for vote in cached_votes.into_iter().flatten() {
480                static_votes.push(vote);
481            }
482        }
483
484        // filter LLMs that now have votes from cache
485        llms.retain(|(flat_ensemble_index, _, _, _)| {
486            !static_votes
487                .iter()
488                .any(|v| v.flat_ensemble_index == *flat_ensemble_index as u64)
489        });
490
491        // generate votes with RNG if requested
492        if request.from_rng.is_some_and(|bool| bool) {
493            let mut rng = rand::rng();
494            for (flat_ensemble_index, ensemble_index, llm, weight) in &llms {
495                // initialize the vote vector
496                let mut vote = vec![Decimal::ZERO; request_responses_len];
497                // generate a random value for each entry
498                let mut sum = Decimal::ZERO;
499                for i in 0..request_responses_len {
500                    let v = Decimal::from(rng.random_range(0..=u64::MAX))
501                        / Decimal::from(u64::MAX);
502                    vote[i] = v;
503                    sum += v;
504                }
505                // normalize the vote vector
506                for v in &mut vote {
507                    *v /= sum;
508                }
509                // push the vote
510                static_votes.push(
511                    objectiveai::vector::completions::response::Vote {
512                        model: llm.inner.id.clone(),
513                        ensemble_index: *ensemble_index as u64,
514                        flat_ensemble_index: *flat_ensemble_index as u64,
515                        prompt_id: prompt_id.clone(),
516                        tools_id: tools_id.clone(),
517                        responses_ids: responses_ids.clone(),
518                        vote,
519                        weight: *weight,
520                        retry: None,
521                        from_cache: None,
522                        from_rng: Some(true),
523                        completion_index: None,
524                    },
525                );
526            }
527        }
528
529        // filter LLMs that now have votes from RNG
530        llms.retain(|(flat_ensemble_index, _, _, _)| {
531            !static_votes
532                .iter()
533                .any(|v| v.flat_ensemble_index == *flat_ensemble_index as u64)
534        });
535
536        // sort retry/cached/rng votes
537        static_votes.sort_by_key(|vote| vote.flat_ensemble_index);
538
539        // track usage
540        let mut usage =
541            objectiveai::vector::completions::response::Usage::default();
542
543        // track scores and weights
544        let mut weights = vec![Decimal::ZERO; request_responses_len];
545        let mut scores = vec![
546            Decimal::ONE
547                / Decimal::from(request_responses_len);
548            request_responses_len
549        ];
550
551        // completion chunk indices are first come first served
552        let indexer = Arc::new(ChoiceIndexer::new(0));
553
554        // stream votes from each LLM in the ensemble
555        let mut vote_stream =
556            futures::stream::select_all(llms.into_iter().map(
557                |(flat_ensemble_index, ensemble_index, llm, weight)| {
558                    futures::stream::once(self.clone().llm_create_streaming(
559                        ctx.clone(),
560                        response_id.clone(),
561                        created,
562                        ensemble.id.clone(),
563                        indexer.clone(),
564                        llm,
565                        ensemble_index,
566                        flat_ensemble_index,
567                        weight,
568                        request.clone(),
569                        prompt_id.clone(),
570                        tools_id.clone(),
571                        responses_ids.clone(),
572                    ))
573                    .flatten()
574                    .boxed()
575                },
576            ));
577
578        // validate there is at least one retried vote
579        if vote_stream.len() == 0 {
580            if static_votes.len() > 0 {
581                // update weights
582                for vote in &static_votes {
583                    for (i, v) in vote.vote.iter().enumerate() {
584                        weights[i] += *v * vote.weight;
585                    }
586                }
587                // update scores
588                let weight_sum: Decimal = weights.iter().sum();
589                if weight_sum > Decimal::ZERO {
590                    for (i, score) in scores.iter_mut().enumerate() {
591                        *score = weights[i] / weight_sum;
592                    }
593                }
594                // return stream of existing votes
595                return Ok(futures::future::Either::Left(StreamOnce::new(
596                    objectiveai::vector::completions::response::streaming::VectorCompletionChunk {
597                        id: request.retry.clone().unwrap_or_default(),
598                        completions: Vec::new(),
599                        votes: static_votes,
600                        scores,
601                        weights,
602                        created,
603                        ensemble: ensemble.id,
604                        object: objectiveai::vector::completions::response::streaming::Object::VectorCompletionChunk,
605                        usage: None,
606                    }
607                )));
608            } else {
609                unreachable!()
610            }
611        }
612
613        // initial chunk
614        let mut next_chunk = match vote_stream.next().await {
615            Some(chunk) => Some(chunk),
616            None => {
617                // should not happen as there should be at least one LLM
618                unreachable!()
619            }
620        };
621
622        Ok(futures::future::Either::Right(async_stream::stream! {
623            // stream all chunks
624            while let Some(mut chunk) = next_chunk.take() {
625                // prepare next chunk
626                next_chunk = vote_stream.next().await;
627
628                // if retry votes were provided, add them to the first chunk
629                if static_votes.len() > 0 {
630                    for vote in chunk.votes.drain(..) {
631                        static_votes.push(vote);
632                    }
633                    chunk.votes = std::mem::take(&mut static_votes);
634                }
635
636                // import usage from each completion
637                for completion in &chunk.completions
638                {
639                    if let Some(completion_usage) = &completion.inner.usage {
640                        usage.push_chat_completion_usage(&completion_usage);
641                    }
642                }
643
644                // update weights from votes
645                let mut vote_found = false;
646                for vote in &chunk.votes {
647                    vote_found = true;
648                    for (i, v) in vote.vote.iter().enumerate() {
649                        weights[i] += *v * vote.weight;
650                    }
651                }
652
653                // update scores if votes were found
654                if vote_found {
655                    let weight_sum: Decimal = weights.iter().sum();
656                    if weight_sum > Decimal::ZERO {
657                        for (i, score) in scores.iter_mut().enumerate() {
658                            *score = weights[i] / weight_sum;
659                        }
660                    }
661                }
662
663                // add weights and scores to chunk
664                chunk.weights = weights.clone();
665                chunk.scores = scores.clone();
666
667                // if on last chunk, add usage
668                if next_chunk.is_none() {
669                    chunk.usage = Some(usage.clone());
670                }
671
672                yield chunk;
673            }
674        }))
675    }
676
677    /// Creates a streaming completion for a single LLM in the ensemble.
678    ///
679    /// Generates prefix data for vote extraction, streams the chat completion,
680    /// and extracts votes from the LLM's response.
681    async fn llm_create_streaming(
682        self: Arc<Self>,
683        ctx: ctx::Context<CTXEXT>,
684        id: String,
685        created: u64,
686        ensemble: String,
687        indexer: Arc<ChoiceIndexer>,
688        llm: objectiveai::ensemble_llm::EnsembleLlmWithFallbacksAndCount,
689        ensemble_index: usize,
690        flat_ensemble_index: usize,
691        weight: Decimal,
692        request: Arc<objectiveai::vector::completions::request::VectorCompletionCreateParams>,
693        prompt_id: String,
694        tools_id: Option<String>,
695        responses_ids: Vec<String>,
696    ) -> impl Stream<Item = objectiveai::vector::completions::response::streaming::VectorCompletionChunk> + Send + 'static
697    {
698        let request_responses_len = request.responses.len();
699
700        // create pfx data for each LLM
701        let (vector_pfx_data, vector_pfx_indices) = {
702            let mut rng = rand::rng();
703            let mut vector_pfx_data = HashMap::with_capacity(
704                1 + llm.fallbacks.as_ref().map(Vec::len).unwrap_or(0),
705            );
706            let mut vector_pfx_indices = Vec::with_capacity(
707                1 + llm.fallbacks.as_ref().map(Vec::len).unwrap_or(0),
708            );
709            for llm in std::iter::once(&llm.inner).chain(
710                llm.fallbacks
711                    .iter()
712                    .map(|fallbacks| fallbacks.iter())
713                    .flatten(),
714            ) {
715                // create the prefixes
716                let pfx_tree = super::PfxTree::new(
717                    &mut rng,
718                    request_responses_len,
719                    match llm.base.top_logprobs {
720                        Some(0) | Some(1) | None => 20,
721                        Some(top_logprobs) => top_logprobs as usize,
722                    },
723                );
724
725                // map prefix to response index
726                let pfx_indices =
727                    pfx_tree.pfx_indices(&mut rng, request_responses_len);
728
729                let (
730                    // regex capture pattern matching response keys as-is
731                    responses_key_pattern,
732                    // regex capture pattern matching response keys stripped of first and last tick
733                    responses_key_pattern_stripped,
734                ) = pfx_tree.regex_patterns(&pfx_indices);
735
736                vector_pfx_data.insert(
737                    llm.id.clone(),
738                    super::PfxData {
739                        pfx_tree,
740                        responses_key_pattern,
741                        responses_key_pattern_stripped,
742                    },
743                );
744                vector_pfx_indices.push(Arc::new(pfx_indices));
745            }
746            (vector_pfx_data, vector_pfx_indices)
747        };
748
749        // stream
750        let mut stream = match self
751            .chat_client
752            .clone()
753            .create_streaming_for_vector_handle_usage(
754                ctx,
755                request,
756                vector_pfx_indices,
757                llm,
758            )
759            .await
760        {
761            Ok(stream) => stream,
762            Err(e) => {
763                return futures::future::Either::Left(
764                    Self::llm_create_streaming_vector_error(
765                        id,
766                        indexer.get(flat_ensemble_index),
767                        e,
768                        created,
769                        ensemble,
770                    ),
771                );
772            }
773        };
774
775        // only return error if the very first stream item is an error
776        let mut next_chat_chunk = match stream.try_next().await {
777            Ok(Some(chunk)) => Some(chunk),
778            Err(e) => {
779                return futures::future::Either::Left(
780                    Self::llm_create_streaming_vector_error(
781                        id,
782                        indexer.get(flat_ensemble_index),
783                        e,
784                        created,
785                        ensemble,
786                    ),
787                );
788            }
789            Ok(None) => {
790                // chat client will always yield at least 1 item
791                unreachable!()
792            }
793        };
794
795        // the aggregate of all chunks
796        let mut aggregate: Option<
797            objectiveai::vector::completions::response::streaming::VectorCompletionChunk,
798        > = None;
799
800        futures::future::Either::Right(async_stream::stream! {
801            while let Some(chat_chunk) = next_chat_chunk.take() {
802                // fetch the next chat chunk or error
803                let error = match stream.next().await {
804                    Some(Ok(ncc)) => {
805                        // set next chat chunk
806                        next_chat_chunk = Some(ncc);
807                        None
808                    }
809                    Some(Err(e)) => {
810                        // end the loop after this iteration
811                        // add error to choices
812                        Some(objectiveai::error::ResponseError::from(&e))
813                    }
814                    None => {
815                        // end the loop after this iteration
816                        None
817                    }
818                };
819
820                // construct the vector completions chunk from the chat completions chunk
821                let mut chunk = objectiveai::vector::completions::response::streaming::VectorCompletionChunk {
822                    id: id.clone(),
823                    completions: vec![
824                        objectiveai::vector::completions::response::streaming::ChatCompletionChunk {
825                            index: indexer.get(flat_ensemble_index),
826                            inner: chat_chunk,
827                            error,
828                        },
829                    ],
830                    votes: Vec::new(),
831                    scores: Vec::new(),
832                    weights: Vec::new(),
833                    created,
834                    ensemble: ensemble.clone(),
835                    object: objectiveai::vector::completions::response::streaming::Object::VectorCompletionChunk,
836                    usage: None,
837                };
838
839                // push the chunk into the aggregate
840                match aggregate {
841                    Some(ref mut aggregate) => {
842                        aggregate.push(&chunk);
843                    }
844                    None => {
845                        aggregate = Some(chunk.clone());
846                    }
847                }
848
849                // if last chunk, add votes
850                if next_chat_chunk.is_none() {
851                    let aggregate = aggregate.take().unwrap();
852                    for completion in aggregate.completions {
853                        // get pfx data for this LLM
854                        let super::PfxData {
855                            pfx_tree,
856                            responses_key_pattern,
857                            responses_key_pattern_stripped,
858                        } = &vector_pfx_data[&completion.inner.model];
859
860                        // try to get votes for each choice
861                        for choice in completion.inner.choices {
862                            if let Some(vote) = super::get_vote(
863                                pfx_tree.clone(),
864                                &responses_key_pattern,
865                                &responses_key_pattern_stripped,
866                                request_responses_len,
867                                &choice,
868                            ) {
869                                chunk.votes.push(objectiveai::vector::completions::response::Vote {
870                                    model: completion.inner.model.clone(),
871                                    ensemble_index: ensemble_index as u64,
872                                    flat_ensemble_index: flat_ensemble_index as u64,
873                                    prompt_id: prompt_id.clone(),
874                                    tools_id: tools_id.clone(),
875                                    responses_ids: responses_ids.clone(),
876                                    vote,
877                                    weight,
878                                    retry: None,
879                                    from_cache: None,
880                                    from_rng: None,
881                                    completion_index: Some(completion.index),
882                                });
883                            }
884                        }
885                    }
886                }
887
888                // yield chunk
889                yield chunk;
890            }
891        })
892    }
893
894    /// Creates an error response chunk for a failed LLM completion.
895    fn llm_create_streaming_vector_error(
896        id: String,
897        completion_index: u64,
898        error: chat::completions::Error,
899        created: u64,
900        ensemble: String,
901    ) -> impl Stream<Item = objectiveai::vector::completions::response::streaming::VectorCompletionChunk>
902    + Send
903    + Unpin
904    + 'static
905    {
906        StreamOnce::new(
907            objectiveai::vector::completions::response::streaming::VectorCompletionChunk {
908                id,
909                completions: vec![
910                    objectiveai::vector::completions::response::streaming::ChatCompletionChunk {
911                        index: completion_index,
912                        inner: objectiveai::chat::completions::response::streaming::ChatCompletionChunk::default(),
913                        error: Some(objectiveai::error::ResponseError::from(&error)),
914                    },
915                ],
916                votes: Vec::new(),
917                scores: Vec::new(),
918                weights: Vec::new(),
919                created,
920                ensemble,
921                object: objectiveai::vector::completions::response::streaming::Object::VectorCompletionChunk,
922                usage: None,
923            }
924        )
925    }
926}