objectiveai_api/chat/completions/upstream/
client.rs1use crate::{ctx, util::StreamOnce};
4use futures::{Stream, StreamExt, TryStreamExt};
5use std::{sync::Arc, time::Duration};
6
/// Upstream-facing client: fans chat/vector completion requests out to
/// upstream providers. Currently the only supported upstream is OpenRouter.
#[derive(Debug, Clone)]
pub struct Client {
    /// Underlying OpenRouter API client used for all upstream calls.
    pub openrouter_client: super::openrouter::Client,
}
15
16impl Client {
17 pub fn new(openrouter_client: super::openrouter::Client) -> Self {
19 Self { openrouter_client }
20 }
21
22 pub async fn create_streaming(
27 &self,
28 ctx: ctx::Context<impl ctx::ContextExt + Send + Sync + 'static>,
29 id: String,
30 first_chunk_timeout: Duration,
31 other_chunk_timeout: Duration,
32 ensemble_llm: Arc<objectiveai::ensemble_llm::EnsembleLlm>,
33 request: super::Params,
34 ) -> Result<
35 Option<
36 impl Stream<
37 Item = Result<
38 objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
39 super::Error,
40 >,
41 > + Send
42 + Unpin
43 + 'static,
44 >,
45 super::Error,
46 >{
47 let mut errors = Vec::new();
48
49 for upstream in super::upstreams(&ensemble_llm, request.clone()) {
51 let byok = ctx
53 .ext
54 .get_byok(upstream)
55 .await
56 .map_err(super::Error::FetchByok)?;
57
58 if let Some(byok) = byok {
60 match self
61 .upstream_create_streaming(
62 upstream,
63 id.clone(),
64 Some(byok),
65 ctx.cost_multiplier,
66 first_chunk_timeout,
67 other_chunk_timeout,
68 ensemble_llm.clone(),
69 request.clone(),
70 )
71 .await
72 {
73 Ok(stream) => {
74 return Ok(Some(stream));
75 }
76 Err(e) => {
77 errors.push(e);
78 }
79 }
80 }
81
82 match self
84 .upstream_create_streaming(
85 upstream,
86 id.clone(),
87 None,
88 ctx.cost_multiplier,
89 first_chunk_timeout,
90 other_chunk_timeout,
91 ensemble_llm.clone(),
92 request.clone(),
93 )
94 .await
95 {
96 Ok(stream) => {
97 return Ok(Some(stream));
98 }
99 Err(e) => {
100 errors.push(e);
101 }
102 }
103 }
104
105 if errors.is_empty() {
106 Ok(None)
107 } else {
108 Err(super::Error::MultipleErrors(errors))
109 }
110 }
111
112 async fn upstream_create_streaming(
114 &self,
115 upstream: super::Upstream,
116 id: String,
117 byok: Option<String>,
118 cost_multiplier: rust_decimal::Decimal,
119 first_chunk_timeout: Duration,
120 other_chunk_timeout: Duration,
121 ensemble_llm: Arc<objectiveai::ensemble_llm::EnsembleLlm>,
122 request: super::Params,
123 ) -> Result<
124 impl Stream<
125 Item = Result<
126 objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
127 super::Error,
128 >,
129 > + Send
130 + Unpin
131 + 'static,
132 super::Error,
133 >{
134 let mut stream = match request {
135 super::Params::Chat { request } => self
136 .create_streaming_for_chat(
137 upstream,
138 id,
139 byok.as_deref(),
140 cost_multiplier,
141 first_chunk_timeout,
142 other_chunk_timeout,
143 &ensemble_llm,
144 &request,
145 )
146 .boxed(),
147 super::Params::Vector {
148 request,
149 vector_pfx_indices,
150 } => self
151 .create_streaming_for_vector(
152 upstream,
153 id,
154 byok.as_deref(),
155 cost_multiplier,
156 first_chunk_timeout,
157 other_chunk_timeout,
158 &ensemble_llm,
159 &request,
160 &vector_pfx_indices,
161 )
162 .boxed(),
163 };
164 match stream.try_next().await {
165 Ok(Some(chunk)) => Ok(StreamOnce::new(Ok(chunk)).chain(stream)),
166 Ok(None) => Err(super::Error::EmptyStream),
167 Err(e) => Err(e),
168 }
169 }
170
171 fn create_streaming_for_chat(
173 &self,
174 upstream: super::Upstream,
175 id: String,
176 byok: Option<&str>,
177 cost_multiplier: rust_decimal::Decimal,
178 first_chunk_timeout: Duration,
179 other_chunk_timeout: Duration,
180 ensemble_llm: &objectiveai::ensemble_llm::EnsembleLlm,
181 request: &objectiveai::chat::completions::request::ChatCompletionCreateParams,
182 ) -> impl Stream<
183 Item = Result<
184 objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
185 super::Error,
186 >,
187 > + Send
188 + 'static{
189 match upstream {
190 super::Upstream::OpenRouter => self
191 .openrouter_client
192 .create_streaming_for_chat(
193 id,
194 byok,
195 cost_multiplier,
196 first_chunk_timeout,
197 other_chunk_timeout,
198 ensemble_llm,
199 request,
200 )
201 .map_err(super::Error::from),
202 }
203 }
204
205 fn create_streaming_for_vector(
209 &self,
210 upstream: super::Upstream,
211 id: String,
212 byok: Option<&str>,
213 cost_multiplier: rust_decimal::Decimal,
214 first_chunk_timeout: Duration,
215 other_chunk_timeout: Duration,
216 ensemble_llm: &objectiveai::ensemble_llm::EnsembleLlm,
217 request: &objectiveai::vector::completions::request::VectorCompletionCreateParams,
218 vector_pfx_indices: &[(String, usize)],
219 ) -> impl Stream<
220 Item = Result<
221 objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
222 super::Error,
223 >,
224 > + Send
225 + 'static{
226 match upstream {
227 super::Upstream::OpenRouter => self
228 .openrouter_client
229 .create_streaming_for_vector(
230 id,
231 byok,
232 cost_multiplier,
233 first_chunk_timeout,
234 other_chunk_timeout,
235 ensemble_llm,
236 request,
237 vector_pfx_indices,
238 )
239 .map_err(super::Error::from),
240 }
241 }
242}