use std::time::Duration;
use futures::StreamExt;
use futures::stream::BoxStream;
use tracing::warn;
use crate::error::ApiError;
use crate::msg::LlmEvent;
use crate::request::{Provider, Request};
use crate::types::CompleteResponse;
use super::translated::Translated;
/// Connection settings for one upstream in a fallback chain.
///
/// `Debug` is implemented by hand so that the API key is never written to
/// logs, panic messages, or error context that formats a spec with `{:?}`.
#[derive(Clone)]
pub struct UpstreamSpec {
    /// Provider that requests built from this spec are addressed to.
    pub provider: Provider,
    /// Credential handed to the request builder for this provider.
    pub api_key: String,
    /// Base URL override; when `None` no override is set on the request.
    pub base_url: Option<String>,
    /// Model override; when `None` the model named by the client is used.
    pub model: Option<String>,
    /// Upper bound used by the streaming fallback path while waiting for an
    /// upstream to open its stream and yield a first event.
    pub pre_commit_timeout: Duration,
}

impl std::fmt::Debug for UpstreamSpec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UpstreamSpec")
            .field("provider", &self.provider)
            // Never print the secret itself; only show that a value is held.
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("pre_commit_timeout", &self.pre_commit_timeout)
            .finish()
    }
}
impl UpstreamSpec {
    /// Value given to `pre_commit_timeout` by [`UpstreamSpec::new`].
    pub const DEFAULT_PRE_COMMIT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Creates a spec for `provider` authenticated with `api_key`.
    ///
    /// No base URL or model override is set, and the pre-commit timeout starts
    /// at [`UpstreamSpec::DEFAULT_PRE_COMMIT_TIMEOUT`].
    #[must_use]
    pub fn new(provider: Provider, api_key: impl Into<String>) -> Self {
        Self {
            provider,
            api_key: api_key.into(),
            base_url: None,
            model: None,
            pre_commit_timeout: Self::DEFAULT_PRE_COMMIT_TIMEOUT,
        }
    }

    /// Returns the spec with its base URL override set to `base_url`.
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    /// Returns the spec with its model override set to `model`.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Returns the spec with its pre-commit timeout replaced by `timeout`.
    ///
    /// Provided so the timeout can be tuned in the same builder chain as the
    /// other optional settings instead of mutating the field afterwards.
    #[must_use]
    pub fn with_pre_commit_timeout(mut self, timeout: Duration) -> Self {
        self.pre_commit_timeout = timeout;
        self
    }
}
/// Assembles the [`Request`] sent to one upstream.
///
/// `spec` supplies the provider, API key, and the optional base URL and model
/// overrides; everything else is copied from the translated client payload
/// `t`. Optional settings are only applied when they are present.
fn build_request(spec: &UpstreamSpec, t: &Translated) -> Request {
    // A model pinned on the upstream wins over the one the client asked for.
    let model = match &spec.model {
        Some(pinned) => pinned.clone(),
        None => t.model_from_client.clone(),
    };

    let mut request = Request::new(spec.provider, spec.api_key.clone())
        .model(model)
        .max_tokens(t.max_tokens);

    if let Some(base_url) = spec.base_url.as_ref() {
        request = request.base_url(base_url.clone());
    }
    if let Some(prompt) = t.system_prompt.as_ref() {
        request = request.system_prompt(prompt.clone());
    }
    if let Some(temperature) = t.temperature {
        request = request.temperature(temperature);
    }
    if let Some(effort) = t.reasoning_effort {
        request = request.reasoning_effort(effort);
    }

    // Messages and tools are always forwarded, even when empty.
    request = request
        .messages(t.messages.clone())
        .tools(t.tools.clone());

    // The tool choice is written to the field directly and is left untouched
    // when the client did not specify one.
    if let Some(choice) = t.tool_choice.as_ref() {
        request.tool_choice = Some(choice.clone());
    }

    if t.extra_body.is_empty() {
        request
    } else {
        request.extra_body(t.extra_body.clone())
    }
}
/// Runs a non-streaming completion against each upstream in `chain`, in order.
///
/// The first successful response is returned immediately. Each failure is
/// logged with the upstream's index and provider before the next upstream is
/// tried.
///
/// # Errors
///
/// Returns the error from the last upstream attempted when every upstream
/// fails, or `ApiError::Other("no upstreams configured")` when `chain` is
/// empty.
pub async fn complete_with_fallback(
    chain: &[UpstreamSpec],
    translated: &Translated,
    http: &reqwest::Client,
) -> Result<CompleteResponse, ApiError> {
    let mut most_recent_failure = None;

    for (position, upstream) in chain.iter().enumerate() {
        let request = build_request(upstream, translated);
        let failure = match request.complete(http).await {
            Ok(response) => return Ok(response),
            Err(failure) => failure,
        };

        warn!(
            upstream_index = position,
            provider = ?upstream.provider,
            error = %failure,
            "non-streaming upstream failed, trying next"
        );
        most_recent_failure = Some(failure);
    }

    match most_recent_failure {
        Some(failure) => Err(failure),
        None => Err(ApiError::Other("no upstreams configured".into())),
    }
}
/// Opens a streaming completion against each upstream in `chain`, in order,
/// and returns the first stream that yields a non-error first event.
///
/// An upstream is skipped, with a warning, when any of the following happens
/// before that first event:
///
/// * opening the stream fails,
/// * the stream ends without producing an event,
/// * the first event is `LlmEvent::Error`,
/// * the upstream's `pre_commit_timeout` elapses.
///
/// `pre_commit_timeout` is a single budget for the whole pre-commit phase of
/// one upstream: time spent opening the stream is subtracted from the time
/// allowed for the first event, so one upstream can delay fallback by at most
/// that duration. Once a first event has been received the stream is returned
/// as-is, with that event re-emitted at its head; later events are not
/// inspected here.
///
/// # Errors
///
/// Returns the error recorded for the last upstream attempted when none of
/// them commits, or `ApiError::Other("no upstreams configured")` when `chain`
/// is empty.
pub async fn stream_with_fallback(
    chain: Vec<UpstreamSpec>,
    translated: Translated,
    http: reqwest::Client,
) -> Result<BoxStream<'static, LlmEvent>, ApiError> {
    let mut last_err: Option<ApiError> = None;

    for (i, spec) in chain.iter().enumerate() {
        let req = build_request(spec, &translated);
        let provider_label = format!("{:?}", spec.provider);

        // Start the clock before opening so both waits share one budget.
        let budget = spec.pre_commit_timeout;
        let started = tokio::time::Instant::now();

        let mut stream = match tokio::time::timeout(budget, req.stream(&http)).await {
            Ok(Ok(s)) => s,
            Ok(Err(e)) => {
                warn!(upstream_index = i, provider = %provider_label, error = %e, "stream open failed, trying next");
                last_err = Some(e);
                continue;
            }
            Err(_) => {
                warn!(upstream_index = i, provider = %provider_label, "stream open timed out, trying next");
                last_err = Some(ApiError::Other(format!(
                    "upstream {} timed out before first event",
                    provider_label
                )));
                continue;
            }
        };

        // Only what is left of the budget may be spent waiting for the first
        // event. `saturating_sub` clamps at zero; an event that is already
        // ready is still accepted because the timeout polls the future first.
        let remaining = budget.saturating_sub(started.elapsed());
        let first = match tokio::time::timeout(remaining, stream.next()).await {
            Ok(Some(LlmEvent::Error(e))) => {
                warn!(upstream_index = i, provider = %provider_label, error = %e, "first event was Error, trying next");
                last_err = Some(ApiError::Llm(e));
                continue;
            }
            Ok(Some(ev)) => ev,
            Ok(None) => {
                warn!(upstream_index = i, provider = %provider_label, "stream ended before any event, trying next");
                last_err = Some(ApiError::Other(format!(
                    "upstream {} produced no events",
                    provider_label
                )));
                continue;
            }
            Err(_) => {
                warn!(upstream_index = i, provider = %provider_label, "first event timed out, trying next");
                last_err = Some(ApiError::Other(format!(
                    "upstream {} timed out before first event",
                    provider_label
                )));
                continue;
            }
        };

        // Committed: put the consumed first event back in front of the rest.
        let head = futures::stream::iter(std::iter::once(first));
        return Ok(head.chain(stream).boxed());
    }

    Err(last_err.unwrap_or_else(|| ApiError::Other("no upstreams configured".into())))
}