1use futures::{StreamExt, TryStreamExt};
4
5use crate::{ctx, util::StreamOnce};
6use std::{sync::Arc, time::Duration};
7
8pub fn response_id(created: u64) -> String {
10 let uuid = uuid::Uuid::new_v4();
11 format!("chtcpl-{}-{}", uuid.simple(), created)
12}
13
/// Completion client that wraps an upstream client with ensemble-LLM
/// resolution, usage reporting, and exponential-backoff retry
/// configuration.
#[derive(Debug, Clone)]
pub struct Client<CTXEXT, FENSLLM, CUSG> {
    /// Caching fetcher used to resolve ensemble-LLM definitions by ID.
    pub ensemble_llm_fetcher:
        Arc<crate::ensemble_llm::fetcher::CachingFetcher<CTXEXT, FENSLLM>>,
    /// Receives the aggregated chunks once a stream completes without error.
    pub usage_handler: Arc<CUSG>,
    /// Client used to open completion streams against upstreams.
    pub upstream_client: super::upstream::Client,

    // The fields below are forwarded into `backoff::ExponentialBackoff`
    // when retrying failed upstream attempts.
    /// Seed for the current retry interval.
    pub backoff_current_interval: Duration,
    /// Initial retry interval.
    pub backoff_initial_interval: Duration,
    /// Jitter factor applied to each retry interval.
    pub backoff_randomization_factor: f64,
    /// Multiplier applied to the interval after each retry.
    pub backoff_multiplier: f64,
    /// Upper bound on a single retry interval.
    pub backoff_max_interval: Duration,
    /// Default cap on total retry time; requests may override it
    /// (capped at 600s) via their `backoff_max_elapsed_time` field.
    pub backoff_max_elapsed_time: Duration,
}
41
42impl<CTXEXT, FENSLLM, CUSG> Client<CTXEXT, FENSLLM, CUSG> {
43 pub fn new(
45 ensemble_llm_fetcher: Arc<
46 crate::ensemble_llm::fetcher::CachingFetcher<CTXEXT, FENSLLM>,
47 >,
48 usage_handler: Arc<CUSG>,
49 upstream_client: super::upstream::Client,
50 backoff_current_interval: Duration,
51 backoff_initial_interval: Duration,
52 backoff_randomization_factor: f64,
53 backoff_multiplier: f64,
54 backoff_max_interval: Duration,
55 backoff_max_elapsed_time: Duration,
56 ) -> Self {
57 Self {
58 ensemble_llm_fetcher,
59 usage_handler,
60 upstream_client,
61 backoff_current_interval,
62 backoff_initial_interval,
63 backoff_randomization_factor,
64 backoff_multiplier,
65 backoff_max_interval,
66 backoff_max_elapsed_time,
67 }
68 }
69}
70
71impl<CTXEXT, FENSLLM, CUSG> Client<CTXEXT, FENSLLM, CUSG>
72where
73 CTXEXT: ctx::ContextExt + Send + Sync + 'static,
74 FENSLLM:
75 crate::ensemble_llm::fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
76 CUSG: super::usage_handler::UsageHandler<CTXEXT> + Send + Sync + 'static,
77{
78 pub async fn create_unary_for_chat_handle_usage(
82 self: Arc<Self>,
83 ctx: ctx::Context<CTXEXT>,
84 request: Arc<
85 objectiveai::chat::completions::request::ChatCompletionCreateParams,
86 >,
87 ) -> Result<
88 objectiveai::chat::completions::response::unary::ChatCompletion,
89 super::Error,
90 > {
91 let mut aggregate: Option<
92 objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
93 > = None;
94 let mut stream = self
95 .create_streaming_for_chat_handle_usage(ctx, request)
96 .await?;
97 while let Some(chunk) = stream.try_next().await? {
98 match &mut aggregate {
99 Some(aggregate) => aggregate.push(&chunk),
100 None => {
101 aggregate = Some(chunk);
102 }
103 }
104 }
105 Ok(aggregate.unwrap().into())
106 }
107
108 pub async fn create_streaming_for_chat_handle_usage(
110 self: Arc<Self>,
111 ctx: ctx::Context<CTXEXT>,
112 request: Arc<objectiveai::chat::completions::request::ChatCompletionCreateParams>,
113 ) -> Result<
114 impl futures::Stream<
115 Item = Result<
116 objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
117 super::Error,
118 >,
119 > + Send
120 + Unpin
121 + 'static,
122 super::Error,
123 >{
124 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
125 let _ = tokio::spawn(async move {
126 let mut aggregate: Option<
127 objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
128 > = None;
129 let mut error = false;
130 let stream = match self
131 .clone()
132 .create_streaming_for_chat(ctx.clone(), request.clone())
133 .await
134 {
135 Ok(stream) => stream,
136 Err(e) => {
137 let _ = tx.send(Err(e));
138 return;
139 }
140 };
141 futures::pin_mut!(stream);
142 while let Some(result) = stream.next().await {
143 match &result {
144 Ok(chunk) => match &mut aggregate {
145 Some(aggregate) => aggregate.push(chunk),
146 None => {
147 aggregate = Some(chunk.clone());
148 }
149 },
150 Err(_) => {
151 error = true;
152 }
153 }
154 let _ = tx.send(result);
155 }
156 drop(stream);
157 drop(tx);
158 if !error {
159 self.usage_handler
160 .handle_usage(ctx, Some(request), aggregate.unwrap().into())
161 .await;
162 }
163 });
164 let mut stream =
165 tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
166 match stream.next().await {
167 Some(Ok(chunk)) => Ok(StreamOnce::new(Ok(chunk)).chain(stream)),
168 Some(Err(e)) => Err(e),
169 None => unreachable!(),
170 }
171 }
172
173 pub async fn create_streaming_for_vector_handle_usage(
177 self: Arc<Self>,
178 ctx: ctx::Context<CTXEXT>,
179 request: Arc<
180 objectiveai::vector::completions::request::VectorCompletionCreateParams,
181 >,
182 vector_pfx_indices: Vec<Arc<Vec<(String, usize)>>>,
183 ensemble_llm: objectiveai::ensemble_llm::EnsembleLlmWithFallbacksAndCount,
184 ) -> Result<
185 impl futures::Stream<
186 Item = Result<
187 objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
188 super::Error,
189 >,
190 > + Send
191 + Unpin
192 + 'static,
193 super::Error,
194 >{
195 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
196 let _ = tokio::spawn(async move {
197 let mut aggregate: Option<
198 objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
199 > = None;
200 let mut error = false;
201 let stream = match self
202 .clone()
203 .create_streaming_for_vector(
204 ctx.clone(),
205 request,
206 vector_pfx_indices,
207 ensemble_llm,
208 )
209 .await
210 {
211 Ok(stream) => stream,
212 Err(e) => {
213 let _ = tx.send(Err(e));
214 return;
215 }
216 };
217 futures::pin_mut!(stream);
218 while let Some(result) = stream.next().await {
219 match &result {
220 Ok(chunk) => match &mut aggregate {
221 Some(aggregate) => aggregate.push(chunk),
222 None => {
223 aggregate = Some(chunk.clone());
224 }
225 },
226 Err(_) => {
227 error = true;
228 }
229 }
230 let _ = tx.send(result);
231 }
232 drop(stream);
233 drop(tx);
234 if !error {
235 self.usage_handler
236 .handle_usage(ctx, None, aggregate.unwrap().into())
237 .await;
238 }
239 });
240 let mut stream =
241 tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
242 match stream.next().await {
243 Some(Ok(chunk)) => Ok(StreamOnce::new(Ok(chunk)).chain(stream)),
244 Some(Err(e)) => Err(e),
245 None => unreachable!(),
246 }
247 }
248}
249
250impl<CTXEXT, FENSLLM, CUSG> Client<CTXEXT, FENSLLM, CUSG>
251where
252 CTXEXT: ctx::ContextExt + Send + Sync + 'static,
253 FENSLLM:
254 crate::ensemble_llm::fetcher::Fetcher<CTXEXT> + Send + Sync + 'static,
255{
256 pub async fn create_streaming_for_chat(
261 &self,
262 ctx: ctx::Context<CTXEXT>,
263 request: Arc<objectiveai::chat::completions::request::ChatCompletionCreateParams>,
264 ) -> Result<
265 impl futures::Stream<
266 Item = Result<
267 objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
268 super::Error,
269 >,
270 > + Send
271 + Unpin
272 + 'static,
273 super::Error,
274 >{
275 let created = std::time::SystemTime::now()
277 .duration_since(std::time::UNIX_EPOCH)
278 .unwrap()
279 .as_secs();
280 let response_id = response_id(created);
281
282 if let objectiveai::chat::completions::request::Model::Id(id) =
284 &request.model
285 {
286 if id.len() != 22 {
287 return Err(super::Error::InvalidEnsembleLlm(format!(
288 "invalid ID: {}",
289 id
290 )));
291 }
292 }
293 if let Some(models) = &request.models {
294 for model in models {
295 if let objectiveai::chat::completions::request::Model::Id(id) =
296 model
297 {
298 if id.len() != 22 {
299 return Err(super::Error::InvalidEnsembleLlm(format!(
300 "invalid ID: {}",
301 id
302 )));
303 }
304 }
305 }
306 }
307
308 let mut models = Vec::with_capacity(
310 1 + request.models.as_ref().map(Vec::len).unwrap_or_default(),
311 );
312 models.push(&request.model);
313 if let Some(request_models) = &request.models {
314 models.extend(request_models.iter());
315 }
316
317 self.ensemble_llm_fetcher.spawn_fetches(
319 ctx.clone(),
320 models.iter().filter_map(|model| {
321 if let objectiveai::chat::completions::request::Model::Id(id) =
322 model
323 {
324 Some(id.as_str())
325 } else {
326 None
327 }
328 }),
329 );
330
331 let backoff = backoff::ExponentialBackoff {
333 current_interval: self.backoff_current_interval,
334 initial_interval: self.backoff_initial_interval,
335 randomization_factor: self.backoff_randomization_factor,
336 multiplier: self.backoff_multiplier,
337 max_interval: self.backoff_max_interval,
338 start_time: std::time::Instant::now(),
339 max_elapsed_time: Some(
340 request
341 .backoff_max_elapsed_time
342 .map(|ms| ms.min(600_000)) .map(Duration::from_millis)
344 .unwrap_or(self.backoff_max_elapsed_time),
345 ),
346 clock: backoff::SystemClock::default(),
347 };
348 let first_chunk_timeout = Duration::from_millis(
349 request
350 .first_chunk_timeout
351 .unwrap_or(10_000) .min(10_000) .max(120_000), );
355 let other_chunk_timeout = Duration::from_millis(
356 request
357 .other_chunk_timeout
358 .unwrap_or(40_000) .min(40_000) .max(120_000), );
362
363 backoff::future::retry(backoff, || async {
365 let mut errors = Vec::new();
366 for model in &models {
367 let ensemble_llm = Arc::new(match model {
369 objectiveai::chat::completions::request::Model::Id(id) => {
370 match self
371 .ensemble_llm_fetcher
372 .fetch(ctx.clone(), id)
373 .await
374 {
375 Ok(Some((ensemble_llm, _))) => ensemble_llm,
376 Ok(None) => {
377 errors.push(super::Error::EnsembleLlmNotFound);
378 continue;
379 }
380 Err(e) => {
381 errors.push(super::Error::FetchEnsembleLlm(e));
382 continue;
383 }
384 }
385 }
386 objectiveai::chat::completions::request::Model::Provided(ensemble_llm_base) => {
387 match ensemble_llm_base.clone().try_into() {
388 Ok(ensemble_llm) => ensemble_llm,
389 Err(msg) => {
390 errors.push(super::Error::InvalidEnsembleLlm(msg));
391 continue;
392 }
393 }
394 }
395 });
396 match self.upstream_client.create_streaming(
398 ctx.clone(),
399 response_id.clone(),
400 first_chunk_timeout,
401 other_chunk_timeout,
402 ensemble_llm,
403 super::upstream::Params::Chat {
404 request: request.clone(),
405 },
406 ).await {
407 Ok(Some(stream)) => {
408 return Ok(stream.map_err(super::Error::UpstreamError));
409 }
410 Ok(None) => {}
411 Err(e) => {
412 errors.push(super::Error::UpstreamError(e));
413 }
414 }
415 }
416 if errors.is_empty() {
417 Err(backoff::Error::permanent(super::Error::NoUpstreamsFound))
418 } else {
419 Err(backoff::Error::transient(super::Error::MultipleErrors(
420 errors,
421 )))
422 }
423 })
424 .await
425 }
426
427 pub async fn create_streaming_for_vector(
432 &self,
433 ctx: ctx::Context<CTXEXT>,
434 request: Arc<
435 objectiveai::vector::completions::request::VectorCompletionCreateParams,
436 >,
437 vector_pfx_indices: Vec<Arc<Vec<(String, usize)>>>,
438 ensemble_llm: objectiveai::ensemble_llm::EnsembleLlmWithFallbacksAndCount,
439 ) -> Result<
440 impl futures::Stream<
441 Item = Result<
442 objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
443 super::Error,
444 >,
445 > + Send
446 + Unpin
447 + 'static,
448 super::Error,
449 >{
450 let created = std::time::SystemTime::now()
452 .duration_since(std::time::UNIX_EPOCH)
453 .unwrap()
454 .as_secs();
455 let response_id = response_id(created);
456
457 let mut models = Vec::with_capacity(
459 1 + ensemble_llm
460 .fallbacks
461 .as_ref()
462 .map(Vec::len)
463 .unwrap_or_default(),
464 );
465 models.push(Arc::new(ensemble_llm.inner));
466 if let Some(fallbacks) = ensemble_llm.fallbacks {
467 models.extend(fallbacks.into_iter().map(Arc::new));
468 }
469
470 let backoff = backoff::ExponentialBackoff {
472 current_interval: self.backoff_current_interval,
473 initial_interval: self.backoff_initial_interval,
474 randomization_factor: self.backoff_randomization_factor,
475 multiplier: self.backoff_multiplier,
476 max_interval: self.backoff_max_interval,
477 start_time: std::time::Instant::now(),
478 max_elapsed_time: Some(
479 request
480 .backoff_max_elapsed_time
481 .map(|ms| ms.min(600_000)) .map(Duration::from_millis)
483 .unwrap_or(self.backoff_max_elapsed_time),
484 ),
485 clock: backoff::SystemClock::default(),
486 };
487 let first_chunk_timeout = Duration::from_millis(
488 request
489 .first_chunk_timeout
490 .unwrap_or(10_000) .min(10_000) .max(120_000), );
494 let other_chunk_timeout = Duration::from_millis(
495 request
496 .other_chunk_timeout
497 .unwrap_or(40_000) .min(40_000) .max(120_000), );
501
502 backoff::future::retry(backoff, || async {
504 let mut errors = Vec::new();
505 for (i, ensemble_llm) in models.iter().cloned().enumerate() {
506 match self
508 .upstream_client
509 .create_streaming(
510 ctx.clone(),
511 response_id.clone(),
512 first_chunk_timeout,
513 other_chunk_timeout,
514 ensemble_llm.clone(),
515 super::upstream::Params::Vector {
516 request: request.clone(),
517 vector_pfx_indices: vector_pfx_indices[i].clone(),
518 },
519 )
520 .await
521 {
522 Ok(Some(stream)) => {
523 return Ok(stream.map_err(super::Error::UpstreamError));
524 }
525 Ok(None) => {}
526 Err(e) => {
527 errors.push(super::Error::UpstreamError(e));
528 }
529 }
530 }
531 if errors.is_empty() {
532 Err(backoff::Error::permanent(super::Error::NoUpstreamsFound))
533 } else {
534 Err(backoff::Error::transient(super::Error::MultipleErrors(
535 errors,
536 )))
537 }
538 })
539 .await
540 }
541}