use std::time::Duration;
use futures::StreamExt;
use futures::stream::BoxStream;
use tracing::warn;
use crate::error::ApiError;
use crate::msg::LlmEvent;
use crate::request::{Provider, Request};
use crate::types::CompleteResponse;
use super::translated::Translated;
#[derive(Debug, Clone)]
pub struct UpstreamSpec {
pub provider: Provider,
pub api_key: String,
pub base_url: Option<String>,
pub model: Option<String>,
pub match_patterns: Vec<String>,
pub pre_commit_timeout: Duration,
}
impl UpstreamSpec {
pub fn new(provider: Provider, api_key: impl Into<String>) -> Self {
Self {
provider,
api_key: api_key.into(),
base_url: None,
model: None,
match_patterns: Vec::new(),
pre_commit_timeout: Duration::from_secs(30),
}
}
pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
self.base_url = Some(base_url.into());
self
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = Some(model.into());
self
}
pub fn with_match(mut self, pattern: impl Into<String>) -> Self {
self.match_patterns.push(pattern.into());
self
}
pub fn accepts_model(&self, model: &str) -> bool {
if self.match_patterns.is_empty() {
return true;
}
self.match_patterns.iter().any(|p| glob_match(p, model))
}
}
pub trait ChainResolver: Send + Sync + 'static {
fn resolve(&self, model: &str) -> Vec<UpstreamSpec>;
fn list_all(&self) -> Vec<UpstreamSpec> {
self.resolve("*")
}
}
impl ChainResolver for Vec<UpstreamSpec> {
fn resolve(&self, model: &str) -> Vec<UpstreamSpec> {
self.iter()
.filter(|s| s.accepts_model(model))
.cloned()
.collect()
}
fn list_all(&self) -> Vec<UpstreamSpec> {
self.clone()
}
}
pub fn glob_match(pattern: &str, input: &str) -> bool {
if pattern == "*" || pattern.is_empty() {
return true;
}
let starts_wild = pattern.starts_with('*');
let ends_wild = pattern.ends_with('*');
let fragments: Vec<&str> = pattern.split('*').filter(|f| !f.is_empty()).collect();
if fragments.is_empty() {
return true;
}
let mut cursor = 0;
for (i, frag) in fragments.iter().enumerate() {
if i == 0 && !starts_wild {
if !input[cursor..].starts_with(frag) {
return false;
}
cursor += frag.len();
continue;
}
match input[cursor..].find(frag) {
Some(pos) => cursor += pos + frag.len(),
None => return false,
}
}
if !ends_wild && cursor != input.len() {
return false;
}
true
}
fn build_request(spec: &UpstreamSpec, t: &Translated) -> Request {
let mut req = Request::new(spec.provider, spec.api_key.clone());
let model = spec
.model
.clone()
.unwrap_or_else(|| t.model_from_client.clone());
req = req.model(model).max_tokens(t.max_tokens);
if let Some(url) = &spec.base_url {
req = req.base_url(url.clone());
}
if let Some(s) = &t.system_prompt {
req = req.system_prompt(s.clone());
}
if let Some(temp) = t.temperature {
req = req.temperature(temp);
}
if let Some(eff) = t.reasoning_effort {
req = req.reasoning_effort(eff);
}
req = req.messages(t.messages.clone());
req = req.tools(t.tools.clone());
if let Some(tc) = &t.tool_choice {
req.tool_choice = Some(tc.clone());
}
if !t.extra_body.is_empty() {
req = req.extra_body(t.extra_body.clone());
}
req
}
#[derive(Debug, Clone)]
pub struct CommittedUpstream {
pub index: usize,
pub provider: Provider,
pub model: String,
}
pub async fn complete_with_fallback(
chain: &[UpstreamSpec],
translated: &Translated,
http: &reqwest::Client,
) -> Result<(CompleteResponse, CommittedUpstream), ApiError> {
let model = translated.model_from_client.as_str();
if chain.is_empty() {
return Err(ApiError::Other(format!(
"no upstream available for model {model}"
)));
}
let mut last_err: Option<ApiError> = None;
for (i, spec) in chain.iter().enumerate() {
let req = build_request(spec, translated);
let model = req.model.clone();
match req.complete(http).await {
Ok(r) => {
return Ok((
r,
CommittedUpstream {
index: i,
provider: spec.provider,
model,
},
));
}
Err(e) => {
warn!(
upstream_index = i,
provider = ?spec.provider,
error = %e,
"non-streaming upstream failed, trying next"
);
last_err = Some(e);
}
}
}
Err(last_err.unwrap_or_else(|| ApiError::Other("no upstreams configured".into())))
}
pub async fn stream_with_fallback(
chain: Vec<UpstreamSpec>,
translated: Translated,
http: reqwest::Client,
) -> Result<(BoxStream<'static, LlmEvent>, CommittedUpstream), ApiError> {
let model = translated.model_from_client.as_str();
if chain.is_empty() {
return Err(ApiError::Other(format!(
"no upstream available for model {model}"
)));
}
let candidates: Vec<(usize, UpstreamSpec)> = chain.into_iter().enumerate().collect();
let mut last_err: Option<ApiError> = None;
for (i, spec) in &candidates {
let i = *i;
let req = build_request(spec, &translated);
let upstream_model = req.model.clone();
let provider_label = format!("{:?}", spec.provider);
let timeout = spec.pre_commit_timeout;
let stream_res = tokio::time::timeout(timeout, req.stream(&http)).await;
let mut stream = match stream_res {
Ok(Ok(s)) => s,
Ok(Err(e)) => {
warn!(upstream_index = i, provider = %provider_label, error = %e, "stream open failed, trying next");
last_err = Some(e);
continue;
}
Err(_) => {
warn!(upstream_index = i, provider = %provider_label, "stream open timed out, trying next");
last_err = Some(ApiError::Other(format!(
"upstream {} timed out before first event",
provider_label
)));
continue;
}
};
let first = match tokio::time::timeout(timeout, stream.next()).await {
Ok(Some(LlmEvent::Error(e))) => {
warn!(upstream_index = i, provider = %provider_label, error = %e, "first event was Error, trying next");
last_err = Some(ApiError::Llm(e));
continue;
}
Ok(Some(ev)) => ev,
Ok(None) => {
warn!(upstream_index = i, provider = %provider_label, "stream ended before any event, trying next");
last_err = Some(ApiError::Other(format!(
"upstream {} produced no events",
provider_label
)));
continue;
}
Err(_) => {
warn!(upstream_index = i, provider = %provider_label, "first event timed out, trying next");
last_err = Some(ApiError::Other(format!(
"upstream {} timed out before first event",
provider_label
)));
continue;
}
};
let head = futures::stream::iter(std::iter::once(first));
let combined = head.chain(stream);
return Ok((
combined.boxed(),
CommittedUpstream {
index: i,
provider: spec.provider,
model: upstream_model,
},
));
}
Err(last_err.unwrap_or_else(|| ApiError::Other("no upstreams configured".into())))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn glob_star_all() {
assert!(glob_match("*", "anything"));
assert!(glob_match("", "anything"));
}
#[test]
fn glob_prefix() {
assert!(glob_match("claude-*", "claude-sonnet-4-5"));
assert!(glob_match("claude-*", "claude-"));
assert!(!glob_match("claude-*", "gpt-4o"));
}
#[test]
fn glob_suffix() {
assert!(glob_match("*-turbo", "gpt-3.5-turbo"));
assert!(!glob_match("*-turbo", "gpt-3.5-turbo-preview"));
}
#[test]
fn glob_substring() {
assert!(glob_match("*sonnet*", "claude-sonnet-4-5"));
assert!(!glob_match("*sonnet*", "claude-haiku-4-5"));
}
#[test]
fn glob_exact() {
assert!(glob_match("gpt-4o", "gpt-4o"));
assert!(!glob_match("gpt-4o", "gpt-4o-mini"));
}
#[test]
fn glob_multi_wildcard() {
assert!(glob_match("claude-*-sonnet-*", "claude-3-5-sonnet-20240620"));
assert!(!glob_match("claude-*-opus-*", "claude-3-5-sonnet-20240620"));
}
#[test]
fn upstream_accepts_catchall() {
let s = UpstreamSpec::new(Provider::DeepSeek, "");
assert!(s.accepts_model("anything"));
}
#[test]
fn upstream_accepts_specific() {
let s = UpstreamSpec::new(Provider::Anthropic, "")
.with_match("claude-*");
assert!(s.accepts_model("claude-sonnet-4-5"));
assert!(!s.accepts_model("gpt-4o"));
}
#[test]
fn upstream_multiple_patterns_or() {
let s = UpstreamSpec::new(Provider::OpenRouter, "")
.with_match("claude-*")
.with_match("gpt-*");
assert!(s.accepts_model("claude-sonnet-4-5"));
assert!(s.accepts_model("gpt-4o"));
assert!(!s.accepts_model("deepseek-chat"));
}
}