Skip to main content

objectiveai_api/chat/completions/
client.rs

1//! Chat completions client implementation.
2
3use futures::{StreamExt, TryStreamExt};
4
5use crate::{ctx, util::StreamOnce};
6use std::{sync::Arc, time::Duration};
7
8/// Generates a unique response ID for a chat completion.
9pub fn response_id(created: u64) -> String {
10    let uuid = uuid::Uuid::new_v4();
11    format!("chtcpl-{}-{}", uuid.simple(), created)
12}
13
/// Client for creating chat completions.
///
/// Handles Ensemble LLM fetching, upstream provider selection with fallbacks,
/// retry logic with exponential backoff, and usage tracking.
///
/// Type parameters:
/// * `CTXEXT` - context extension type carried inside [`ctx::Context`].
/// * `FENSLLM` - underlying fetcher wrapped by the caching Ensemble LLM fetcher.
/// * `CUSG` - usage handler invoked after a completion (or its stream) finishes.
#[derive(Debug, Clone)]
pub struct Client<CTXEXT, FENSLLM, CUSG> {
    /// Caching fetcher for Ensemble LLM definitions.
    pub ensemble_llm_fetcher:
        Arc<crate::ensemble_llm::fetcher::CachingFetcher<CTXEXT, FENSLLM>>,
    /// Handler for tracking usage after completion.
    pub usage_handler: Arc<CUSG>,
    /// Client for communicating with upstream providers.
    pub upstream_client: super::upstream::Client,

    /// Current backoff interval for retry logic.
    // NOTE(review): passed verbatim as `ExponentialBackoff::current_interval`
    // (the starting interval); typically equal to the initial interval --
    // confirm intended configuration.
    pub backoff_current_interval: Duration,
    /// Initial backoff interval for retry logic.
    pub backoff_initial_interval: Duration,
    /// Randomization factor for backoff jitter.
    pub backoff_randomization_factor: f64,
    /// Multiplier for exponential backoff growth.
    pub backoff_multiplier: f64,
    /// Maximum backoff interval.
    pub backoff_max_interval: Duration,
    /// Maximum total time to spend on retries (default; a request may
    /// override it with its own `backoff_max_elapsed_time`).
    pub backoff_max_elapsed_time: Duration,
}
41
42impl<CTXEXT, FENSLLM, CUSG> Client<CTXEXT, FENSLLM, CUSG> {
43    /// Creates a new chat completions client.
44    pub fn new(
45        ensemble_llm_fetcher: Arc<
46            crate::ensemble_llm::fetcher::CachingFetcher<CTXEXT, FENSLLM>,
47        >,
48        usage_handler: Arc<CUSG>,
49        upstream_client: super::upstream::Client,
50        backoff_current_interval: Duration,
51        backoff_initial_interval: Duration,
52        backoff_randomization_factor: f64,
53        backoff_multiplier: f64,
54        backoff_max_interval: Duration,
55        backoff_max_elapsed_time: Duration,
56    ) -> Self {
57        Self {
58            ensemble_llm_fetcher,
59            usage_handler,
60            upstream_client,
61            backoff_current_interval,
62            backoff_initial_interval,
63            backoff_randomization_factor,
64            backoff_multiplier,
65            backoff_max_interval,
66            backoff_max_elapsed_time,
67        }
68    }
69}
70
impl<CTXEXT, FENSLLM, CUSG> Client<CTXEXT, FENSLLM, CUSG>
where
    CTXEXT: ctx::ContextExt + Send + Sync + 'static,
    FENSLLM:
        crate::ensemble_llm::fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
    CUSG: super::usage_handler::UsageHandler<CTXEXT> + Send + Sync + 'static,
{
    /// Creates a unary chat completion, tracking usage after completion.
    ///
    /// Internally streams the response and aggregates chunks into a single response.
    ///
    /// # Errors
    ///
    /// Propagates any error from opening the stream or from any streamed item.
    pub async fn create_unary_for_chat_handle_usage(
        self: Arc<Self>,
        ctx: ctx::Context<CTXEXT>,
        request: Arc<
            objectiveai::chat::completions::request::ChatCompletionCreateParams,
        >,
    ) -> Result<
        objectiveai::chat::completions::response::unary::ChatCompletion,
        super::Error,
    > {
        // Running aggregate of all chunks seen so far; `None` until the first
        // chunk arrives.
        let mut aggregate: Option<
            objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
        > = None;
        let mut stream = self
            .create_streaming_for_chat_handle_usage(ctx, request)
            .await?;
        while let Some(chunk) = stream.try_next().await? {
            match &mut aggregate {
                // Fold each subsequent chunk into the first one.
                Some(aggregate) => aggregate.push(&chunk),
                None => {
                    aggregate = Some(chunk);
                }
            }
        }
        // `create_streaming_for_chat_handle_usage` only returns `Ok` after it
        // observed a first chunk (see its trailing `match`), so `aggregate`
        // is `Some` whenever the loop finishes without `?` returning early.
        Ok(aggregate.unwrap().into())
    }

    /// Creates a streaming chat completion, tracking usage after the stream ends.
    ///
    /// Spawns a background task that drains the upstream stream, forwarding
    /// every item through an unbounded channel while aggregating chunks; once
    /// the stream ends without any error, the aggregate is reported to the
    /// usage handler. The returned stream always starts with the first
    /// successfully received chunk.
    ///
    /// # Errors
    ///
    /// Returns the stream-creation error, or the first item's error if the
    /// stream begins with a failure.
    pub async fn create_streaming_for_chat_handle_usage(
        self: Arc<Self>,
        ctx: ctx::Context<CTXEXT>,
        request: Arc<objectiveai::chat::completions::request::ChatCompletionCreateParams>,
    ) -> Result<
        impl futures::Stream<
            Item = Result<
                objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
                super::Error,
            >,
        > + Send
        + Unpin
        + 'static,
        super::Error,
    >{
        // Unbounded channel: the producer task forwards items without ever
        // blocking on the consumer.
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let _ = tokio::spawn(async move {
            // Aggregate of all successful chunks, used for usage reporting.
            let mut aggregate: Option<
                objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
            > = None;
            // Set on the first failed item; suppresses usage reporting below.
            let mut error = false;
            let stream = match self
                .clone()
                .create_streaming_for_chat(ctx.clone(), request.clone())
                .await
            {
                Ok(stream) => stream,
                Err(e) => {
                    // Surface the creation error to the caller and stop; the
                    // receiver side maps this to an `Err` return value.
                    let _ = tx.send(Err(e));
                    return;
                }
            };
            futures::pin_mut!(stream);
            while let Some(result) = stream.next().await {
                // Aggregate before forwarding so the send can move `result`.
                match &result {
                    Ok(chunk) => match &mut aggregate {
                        Some(aggregate) => aggregate.push(chunk),
                        None => {
                            aggregate = Some(chunk.clone());
                        }
                    },
                    Err(_) => {
                        error = true;
                    }
                }
                // Receiver may have been dropped; ignore send failures and
                // keep draining so usage is still recorded.
                let _ = tx.send(result);
            }
            drop(stream);
            // Close the channel before the (potentially slow) usage call so
            // the consumer sees end-of-stream promptly.
            drop(tx);
            if !error {
                // NOTE(review): assumes the upstream stream yielded at least
                // one chunk when no error occurred; an empty stream would
                // panic here (and hit `unreachable!()` below) -- confirm the
                // upstream guarantees a non-empty stream.
                self.usage_handler
                    .handle_usage(ctx, Some(request), aggregate.unwrap().into())
                    .await;
            }
        });
        let mut stream =
            tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
        // Peek the first item: an initial error becomes this function's `Err`;
        // an initial chunk is re-chained in front of the remaining stream.
        match stream.next().await {
            Some(Ok(chunk)) => Ok(StreamOnce::new(Ok(chunk)).chain(stream)),
            Some(Err(e)) => Err(e),
            None => unreachable!(),
        }
    }

    /// Creates a streaming completion for vector voting, tracking usage after the stream ends.
    ///
    /// Used internally by vector completions to generate LLM votes. Same
    /// forwarding/aggregation scheme as
    /// [`Self::create_streaming_for_chat_handle_usage`], but usage is reported
    /// with `None` as the request (vector requests are not chat requests).
    ///
    /// # Errors
    ///
    /// Returns the stream-creation error, or the first item's error if the
    /// stream begins with a failure.
    pub async fn create_streaming_for_vector_handle_usage(
        self: Arc<Self>,
        ctx: ctx::Context<CTXEXT>,
        request: Arc<
            objectiveai::vector::completions::request::VectorCompletionCreateParams,
        >,
        vector_pfx_indices: Vec<Arc<Vec<(String, usize)>>>,
        ensemble_llm: objectiveai::ensemble_llm::EnsembleLlmWithFallbacksAndCount,
    ) -> Result<
        impl futures::Stream<
            Item = Result<
                objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
                super::Error,
            >,
        > + Send
        + Unpin
        + 'static,
        super::Error,
    >{
        // Unbounded channel: the producer task forwards items without ever
        // blocking on the consumer.
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let _ = tokio::spawn(async move {
            // Aggregate of all successful chunks, used for usage reporting.
            let mut aggregate: Option<
                objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
            > = None;
            // Set on the first failed item; suppresses usage reporting below.
            let mut error = false;
            let stream = match self
                .clone()
                .create_streaming_for_vector(
                    ctx.clone(),
                    request,
                    vector_pfx_indices,
                    ensemble_llm,
                )
                .await
            {
                Ok(stream) => stream,
                Err(e) => {
                    // Surface the creation error to the caller and stop.
                    let _ = tx.send(Err(e));
                    return;
                }
            };
            futures::pin_mut!(stream);
            while let Some(result) = stream.next().await {
                // Aggregate before forwarding so the send can move `result`.
                match &result {
                    Ok(chunk) => match &mut aggregate {
                        Some(aggregate) => aggregate.push(chunk),
                        None => {
                            aggregate = Some(chunk.clone());
                        }
                    },
                    Err(_) => {
                        error = true;
                    }
                }
                // Receiver may have been dropped; ignore send failures and
                // keep draining so usage is still recorded.
                let _ = tx.send(result);
            }
            drop(stream);
            // Close the channel before the usage call so the consumer sees
            // end-of-stream promptly.
            drop(tx);
            if !error {
                // NOTE(review): same non-empty-stream assumption as the chat
                // variant above -- confirm upstream guarantees.
                self.usage_handler
                    .handle_usage(ctx, None, aggregate.unwrap().into())
                    .await;
            }
        });
        let mut stream =
            tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
        // Peek the first item: an initial error becomes this function's `Err`;
        // an initial chunk is re-chained in front of the remaining stream.
        match stream.next().await {
            Some(Ok(chunk)) => Ok(StreamOnce::new(Ok(chunk)).chain(stream)),
            Some(Err(e)) => Err(e),
            None => unreachable!(),
        }
    }
}
249
250impl<CTXEXT, FENSLLM, CUSG> Client<CTXEXT, FENSLLM, CUSG>
251where
252    CTXEXT: ctx::ContextExt + Send + Sync + 'static,
253    FENSLLM:
254        crate::ensemble_llm::fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
255{
256    /// Creates a streaming chat completion without usage tracking.
257    ///
258    /// Handles model validation, Ensemble LLM fetching, fallback logic,
259    /// and retry with exponential backoff.
260    pub async fn create_streaming_for_chat(
261        &self,
262        ctx: ctx::Context<CTXEXT>,
263        request: Arc<objectiveai::chat::completions::request::ChatCompletionCreateParams>,
264    ) -> Result<
265        impl futures::Stream<
266            Item = Result<
267                objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
268                super::Error,
269            >,
270        > + Send
271        + Unpin
272        + 'static,
273        super::Error,
274    >{
275        // timestamp and identify the completion
276        let created = std::time::SystemTime::now()
277            .duration_since(std::time::UNIX_EPOCH)
278            .unwrap()
279            .as_secs();
280        let response_id = response_id(created);
281
282        // validate models IDs
283        if let objectiveai::chat::completions::request::Model::Id(id) =
284            &request.model
285        {
286            if id.len() != 22 {
287                return Err(super::Error::InvalidEnsembleLlm(format!(
288                    "invalid ID: {}",
289                    id
290                )));
291            }
292        }
293        if let Some(models) = &request.models {
294            for model in models {
295                if let objectiveai::chat::completions::request::Model::Id(id) =
296                    model
297                {
298                    if id.len() != 22 {
299                        return Err(super::Error::InvalidEnsembleLlm(format!(
300                            "invalid ID: {}",
301                            id
302                        )));
303                    }
304                }
305            }
306        }
307
308        // collect all Ensemble LLMs
309        let mut models = Vec::with_capacity(
310            1 + request.models.as_ref().map(Vec::len).unwrap_or_default(),
311        );
312        models.push(&request.model);
313        if let Some(request_models) = &request.models {
314            models.extend(request_models.iter());
315        }
316
317        // spawn fetches for all Ensemble LLMs
318        self.ensemble_llm_fetcher.spawn_fetches(
319            ctx.clone(),
320            models.iter().filter_map(|model| {
321                if let objectiveai::chat::completions::request::Model::Id(id) =
322                    model
323                {
324                    Some(id.as_str())
325                } else {
326                    None
327                }
328            }),
329        );
330
331        // backoff and timeouts
332        let backoff = backoff::ExponentialBackoff {
333            current_interval: self.backoff_current_interval,
334            initial_interval: self.backoff_initial_interval,
335            randomization_factor: self.backoff_randomization_factor,
336            multiplier: self.backoff_multiplier,
337            max_interval: self.backoff_max_interval,
338            start_time: std::time::Instant::now(),
339            max_elapsed_time: Some(
340                request
341                    .backoff_max_elapsed_time
342                    .map(|ms| ms.min(600_000)) // at most 10 minutes
343                    .map(Duration::from_millis)
344                    .unwrap_or(self.backoff_max_elapsed_time),
345            ),
346            clock: backoff::SystemClock::default(),
347        };
348        let first_chunk_timeout = Duration::from_millis(
349            request
350                .first_chunk_timeout
351                .unwrap_or(10_000) // default 10 seconds
352                .min(10_000) // at least 10 seconds
353                .max(120_000), // at most 2 minutes
354        );
355        let other_chunk_timeout = Duration::from_millis(
356            request
357                .other_chunk_timeout
358                .unwrap_or(40_000) // default 40 seconds
359                .min(40_000) // at least 40 seconds
360                .max(120_000), // at most 2 minutes
361        );
362
363        // try each model in order
364        backoff::future::retry(backoff, || async {
365            let mut errors = Vec::new();
366            for model in &models {
367                // fetch or validate Ensemble LLM
368                let ensemble_llm = Arc::new(match model {
369                    objectiveai::chat::completions::request::Model::Id(id) => {
370                        match self
371                            .ensemble_llm_fetcher
372                            .fetch(ctx.clone(), id)
373                            .await
374                        {
375                            Ok(Some((ensemble_llm, _))) => ensemble_llm,
376                            Ok(None) => {
377                                errors.push(super::Error::EnsembleLlmNotFound);
378                                continue;
379                            }
380                            Err(e) => {
381                                errors.push(super::Error::FetchEnsembleLlm(e));
382                                continue;
383                            }
384                        }
385                    }
386                    objectiveai::chat::completions::request::Model::Provided(ensemble_llm_base) => {
387                        match ensemble_llm_base.clone().try_into() {
388                            Ok(ensemble_llm) => ensemble_llm,
389                            Err(msg) => {
390                                errors.push(super::Error::InvalidEnsembleLlm(msg));
391                                continue;
392                            }
393                        }
394                    }
395                });
396                // try to create streaming completion
397                match self.upstream_client.create_streaming(
398                    ctx.clone(),
399                    response_id.clone(),
400                    first_chunk_timeout,
401                    other_chunk_timeout,
402                    ensemble_llm,
403                    super::upstream::Params::Chat {
404                        request: request.clone(),
405                    },
406                ).await {
407                    Ok(Some(stream)) => {
408                        return Ok(stream.map_err(super::Error::UpstreamError));
409                    }
410                    Ok(None) => {}
411                    Err(e) => {
412                        errors.push(super::Error::UpstreamError(e));
413                    }
414                }
415            }
416            if errors.is_empty() {
417                Err(backoff::Error::permanent(super::Error::NoUpstreamsFound))
418            } else {
419                Err(backoff::Error::transient(super::Error::MultipleErrors(
420                    errors,
421                )))
422            }
423        })
424        .await
425    }
426
427    /// Creates a streaming completion for vector voting without usage tracking.
428    ///
429    /// Used internally by vector completions. Handles fallback logic
430    /// and retry with exponential backoff.
431    pub async fn create_streaming_for_vector(
432        &self,
433        ctx: ctx::Context<CTXEXT>,
434        request: Arc<
435            objectiveai::vector::completions::request::VectorCompletionCreateParams,
436        >,
437        vector_pfx_indices: Vec<Arc<Vec<(String, usize)>>>,
438        ensemble_llm: objectiveai::ensemble_llm::EnsembleLlmWithFallbacksAndCount,
439    ) -> Result<
440        impl futures::Stream<
441            Item = Result<
442                objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
443                super::Error,
444            >,
445        > + Send
446        + Unpin
447        + 'static,
448        super::Error,
449    >{
450        // timestamp and identify the completion
451        let created = std::time::SystemTime::now()
452            .duration_since(std::time::UNIX_EPOCH)
453            .unwrap()
454            .as_secs();
455        let response_id = response_id(created);
456
457        // collect all Ensemble LLMs
458        let mut models = Vec::with_capacity(
459            1 + ensemble_llm
460                .fallbacks
461                .as_ref()
462                .map(Vec::len)
463                .unwrap_or_default(),
464        );
465        models.push(Arc::new(ensemble_llm.inner));
466        if let Some(fallbacks) = ensemble_llm.fallbacks {
467            models.extend(fallbacks.into_iter().map(Arc::new));
468        }
469
470        // backoff and timeouts
471        let backoff = backoff::ExponentialBackoff {
472            current_interval: self.backoff_current_interval,
473            initial_interval: self.backoff_initial_interval,
474            randomization_factor: self.backoff_randomization_factor,
475            multiplier: self.backoff_multiplier,
476            max_interval: self.backoff_max_interval,
477            start_time: std::time::Instant::now(),
478            max_elapsed_time: Some(
479                request
480                    .backoff_max_elapsed_time
481                    .map(|ms| ms.min(600_000)) // at most 10 minutes
482                    .map(Duration::from_millis)
483                    .unwrap_or(self.backoff_max_elapsed_time),
484            ),
485            clock: backoff::SystemClock::default(),
486        };
487        let first_chunk_timeout = Duration::from_millis(
488            request
489                .first_chunk_timeout
490                .unwrap_or(10_000) // default 10 seconds
491                .min(10_000) // at least 10 seconds
492                .max(120_000), // at most 2 minutes
493        );
494        let other_chunk_timeout = Duration::from_millis(
495            request
496                .other_chunk_timeout
497                .unwrap_or(40_000) // default 40 seconds
498                .min(40_000) // at least 40 seconds
499                .max(120_000), // at most 2 minutes
500        );
501
502        // try each model in order
503        backoff::future::retry(backoff, || async {
504            let mut errors = Vec::new();
505            for (i, ensemble_llm) in models.iter().cloned().enumerate() {
506                // try to create streaming completion
507                match self
508                    .upstream_client
509                    .create_streaming(
510                        ctx.clone(),
511                        response_id.clone(),
512                        first_chunk_timeout,
513                        other_chunk_timeout,
514                        ensemble_llm.clone(),
515                        super::upstream::Params::Vector {
516                            request: request.clone(),
517                            vector_pfx_indices: vector_pfx_indices[i].clone(),
518                        },
519                    )
520                    .await
521                {
522                    Ok(Some(stream)) => {
523                        return Ok(stream.map_err(super::Error::UpstreamError));
524                    }
525                    Ok(None) => {}
526                    Err(e) => {
527                        errors.push(super::Error::UpstreamError(e));
528                    }
529                }
530            }
531            if errors.is_empty() {
532                Err(backoff::Error::permanent(super::Error::NoUpstreamsFound))
533            } else {
534                Err(backoff::Error::transient(super::Error::MultipleErrors(
535                    errors,
536                )))
537            }
538        })
539        .await
540    }
541}