Skip to main content

objectiveai_api/chat/completions/upstream/
client.rs

//! Unified upstream client that dispatches to provider-specific clients.
2
3use crate::{ctx, util::StreamOnce};
4use futures::{Stream, StreamExt, TryStreamExt};
5use std::{sync::Arc, time::Duration};
6
7/// Client that manages connections to all upstream providers.
8///
9/// Handles provider selection, BYOK key injection, and fallback between providers.
10#[derive(Debug, Clone)]
11pub struct Client {
12    /// OpenRouter provider client.
13    pub openrouter_client: super::openrouter::Client,
14}
15
16impl Client {
17    /// Creates a new upstream client.
18    pub fn new(openrouter_client: super::openrouter::Client) -> Self {
19        Self { openrouter_client }
20    }
21
22    /// Creates a streaming completion, trying each upstream provider in order.
23    ///
24    /// First attempts with BYOK if available, then falls back to the default key.
25    /// Returns `Ok(None)` if no upstreams were available.
26    pub async fn create_streaming(
27        &self,
28        ctx: ctx::Context<impl ctx::ContextExt + Send + Sync + 'static>,
29        id: String,
30        first_chunk_timeout: Duration,
31        other_chunk_timeout: Duration,
32        ensemble_llm: Arc<objectiveai::ensemble_llm::EnsembleLlm>,
33        request: super::Params,
34    ) -> Result<
35        Option<
36            impl Stream<
37                Item = Result<
38                    objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
39                    super::Error,
40                >,
41            > + Send
42            + Unpin
43            + 'static,
44        >,
45        super::Error,
46    >{
47        let mut errors = Vec::new();
48
49        // try each upstream in order
50        for upstream in super::upstreams(&ensemble_llm, request.clone()) {
51            // fetch BYOK from context
52            let byok = ctx
53                .ext
54                .get_byok(upstream)
55                .await
56                .map_err(super::Error::FetchByok)?;
57
58            // first, try BYOK if available
59            if let Some(byok) = byok {
60                match self
61                    .upstream_create_streaming(
62                        upstream,
63                        id.clone(),
64                        Some(byok),
65                        ctx.cost_multiplier,
66                        first_chunk_timeout,
67                        other_chunk_timeout,
68                        ensemble_llm.clone(),
69                        request.clone(),
70                    )
71                    .await
72                {
73                    Ok(stream) => {
74                        return Ok(Some(stream));
75                    }
76                    Err(e) => {
77                        errors.push(e);
78                    }
79                }
80            }
81
82            // then, try without BYOK
83            match self
84                .upstream_create_streaming(
85                    upstream,
86                    id.clone(),
87                    None,
88                    ctx.cost_multiplier,
89                    first_chunk_timeout,
90                    other_chunk_timeout,
91                    ensemble_llm.clone(),
92                    request.clone(),
93                )
94                .await
95            {
96                Ok(stream) => {
97                    return Ok(Some(stream));
98                }
99                Err(e) => {
100                    errors.push(e);
101                }
102            }
103        }
104
105        if errors.is_empty() {
106            Ok(None)
107        } else {
108            Err(super::Error::MultipleErrors(errors))
109        }
110    }
111
112    /// Creates a streaming completion with a specific upstream provider.
113    async fn upstream_create_streaming(
114        &self,
115        upstream: super::Upstream,
116        id: String,
117        byok: Option<String>,
118        cost_multiplier: rust_decimal::Decimal,
119        first_chunk_timeout: Duration,
120        other_chunk_timeout: Duration,
121        ensemble_llm: Arc<objectiveai::ensemble_llm::EnsembleLlm>,
122        request: super::Params,
123    ) -> Result<
124        impl Stream<
125            Item = Result<
126                objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
127                super::Error,
128            >,
129        > + Send
130        + Unpin
131        + 'static,
132        super::Error,
133    >{
134        let mut stream = match request {
135            super::Params::Chat { request } => self
136                .create_streaming_for_chat(
137                    upstream,
138                    id,
139                    byok.as_deref(),
140                    cost_multiplier,
141                    first_chunk_timeout,
142                    other_chunk_timeout,
143                    &ensemble_llm,
144                    &request,
145                )
146                .boxed(),
147            super::Params::Vector {
148                request,
149                vector_pfx_indices,
150            } => self
151                .create_streaming_for_vector(
152                    upstream,
153                    id,
154                    byok.as_deref(),
155                    cost_multiplier,
156                    first_chunk_timeout,
157                    other_chunk_timeout,
158                    &ensemble_llm,
159                    &request,
160                    &vector_pfx_indices,
161                )
162                .boxed(),
163        };
164        match stream.try_next().await {
165            Ok(Some(chunk)) => Ok(StreamOnce::new(Ok(chunk)).chain(stream)),
166            Ok(None) => Err(super::Error::EmptyStream),
167            Err(e) => Err(e),
168        }
169    }
170
171    /// Creates a streaming chat completion with a specific upstream provider.
172    fn create_streaming_for_chat(
173        &self,
174        upstream: super::Upstream,
175        id: String,
176        byok: Option<&str>,
177        cost_multiplier: rust_decimal::Decimal,
178        first_chunk_timeout: Duration,
179        other_chunk_timeout: Duration,
180        ensemble_llm: &objectiveai::ensemble_llm::EnsembleLlm,
181        request: &objectiveai::chat::completions::request::ChatCompletionCreateParams,
182    ) -> impl Stream<
183        Item = Result<
184            objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
185            super::Error,
186        >,
187    > + Send
188    + 'static{
189        match upstream {
190            super::Upstream::OpenRouter => self
191                .openrouter_client
192                .create_streaming_for_chat(
193                    id,
194                    byok,
195                    cost_multiplier,
196                    first_chunk_timeout,
197                    other_chunk_timeout,
198                    ensemble_llm,
199                    request,
200                )
201                .map_err(super::Error::from),
202        }
203    }
204
205    /// Creates a streaming chat completion for vector voting with a specific upstream provider.
206    ///
207    /// The LLM sees responses labeled with prefix keys and responds with its choice.
208    fn create_streaming_for_vector(
209        &self,
210        upstream: super::Upstream,
211        id: String,
212        byok: Option<&str>,
213        cost_multiplier: rust_decimal::Decimal,
214        first_chunk_timeout: Duration,
215        other_chunk_timeout: Duration,
216        ensemble_llm: &objectiveai::ensemble_llm::EnsembleLlm,
217        request: &objectiveai::vector::completions::request::VectorCompletionCreateParams,
218        vector_pfx_indices: &[(String, usize)],
219    ) -> impl Stream<
220        Item = Result<
221            objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
222            super::Error,
223        >,
224    > + Send
225    + 'static{
226        match upstream {
227            super::Upstream::OpenRouter => self
228                .openrouter_client
229                .create_streaming_for_vector(
230                    id,
231                    byok,
232                    cost_multiplier,
233                    first_chunk_timeout,
234                    other_chunk_timeout,
235                    ensemble_llm,
236                    request,
237                    vector_pfx_indices,
238                )
239                .map_err(super::Error::from),
240        }
241    }
242}