use koda_core::config::{ModelSettings, ProviderType};
use koda_core::providers::openai_compat::OpenAiCompatProvider;
use koda_core::providers::{ChatMessage, LlmProvider};
use koda_core::runtime_env;
use koda_test_utils::ENV_MUTEX;
use koda_test_utils::network::FakeLlmServer;
use serde_json::{Value, json};
use std::sync::OnceLock;
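
// Proxy-related variables that can leak in from the developer's shell; every
// test masks all of them up front so the results are hermetic.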
const PROXY_ENV_KEYS: &[&str] = &[
    "HTTP_PROXY",
    "HTTPS_PROXY",
    "http_proxy",
    "https_proxy",
    // NO_PROXY is masked too: an ambient `NO_PROXY=*` would otherwise let the
    // "routes through proxy" test pass without ever touching the proxy.
    "NO_PROXY",
    "no_proxy",
    "PROXY_USER",
    "PROXY_PASS",
];
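
/// Scoped env-var guard: records every key it sets or masks via
/// `runtime_env` and restores the previous state on drop, so a panicking
/// assertion cannot leak proxy settings into the next test.
///
/// A minimal usage sketch (the proxy address is a placeholder):
///
/// ```ignore
/// let mut env = EnvGuard::new();
/// env.mask_all_proxy_env();
/// env.set("HTTP_PROXY", "http://127.0.0.1:9");
/// // on drop: HTTP_PROXY is removed, the masked keys are unmasked
/// ```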
struct EnvGuard<'a> {
    set_keys: Vec<&'a str>,
    masked_keys: Vec<&'a str>,
}

impl<'a> EnvGuard<'a> {
    fn new() -> Self {
        Self {
            set_keys: Vec::new(),
            masked_keys: Vec::new(),
        }
    }

    fn set(&mut self, key: &'a str, value: &str) {
        runtime_env::set(key, value);
        self.set_keys.push(key);
    }

    fn mask(&mut self, key: &'a str) {
        runtime_env::mask(key);
        self.masked_keys.push(key);
    }

    fn mask_all_proxy_env(&mut self) {
        for k in PROXY_ENV_KEYS {
            self.mask(k);
        }
    }
}

impl Drop for EnvGuard<'_> {
    fn drop(&mut self) {
        for k in &self.set_keys {
            runtime_env::remove(k);
        }
        for k in &self.masked_keys {
            runtime_env::unmask(k);
        }
    }
}
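
/// Minimal OpenAI-compatible chat completion body for the fake server to
/// return; only the fields the provider is expected to read are filled in.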
fn ok_chat_body() -> Value {
    json!({
        "id": "chatcmpl-test",
        "object": "chat.completion",
        "created": 1_700_000_000,
        "model": "gpt-4o",
        "choices": [{
            "index": 0,
            "message": { "role": "assistant", "content": "ok" },
            "finish_reason": "stop"
        }],
        "usage": { "prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2 }
    })
}
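
/// Default `ModelSettings` for the test model, built once and cloned per call.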
fn settings() -> ModelSettings {
    static S: OnceLock<ModelSettings> = OnceLock::new();
    S.get_or_init(|| ModelSettings::defaults_for("gpt-4o", &ProviderType::OpenAI))
        .clone()
}
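
/// A single trivial user message to drive each chat call.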
fn user_msg() -> ChatMessage {
    ChatMessage::text("user", "hi")
}
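
// With HTTP_PROXY set, a request to a non-local upstream must be routed
// through the proxy. The upstream host is a `.invalid` name on purpose: if
// the proxy were ignored, the request could not resolve and the test fails.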
#[tokio::test]
async fn http_proxy_routes_remote_requests_through_proxy() {
    let _g = ENV_MUTEX.lock().await;
    let proxy = FakeLlmServer::spawn().await;
    proxy.mount_chat_ok(ok_chat_body()).await;

    let mut env = EnvGuard::new();
    env.mask_all_proxy_env();
    env.set("HTTP_PROXY", &proxy.url());

    let provider =
        OpenAiCompatProvider::new("http://upstream.test.invalid", Some("sk-test".into()));
    provider
        .chat(&[user_msg()], &[], &settings())
        .await
        .expect("chat must succeed via proxy");

    let proxied = proxy.received_requests().await;
    assert_eq!(proxied.len(), 1, "request must be routed through the proxy");
}
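
// Loopback targets must bypass the proxy even when HTTP_PROXY is set. This
// assumes the provider carries a built-in no-proxy rule for localhost; the
// two fake servers let the test observe which side actually got the request.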
#[tokio::test]
async fn localhost_traffic_bypasses_proxy_even_when_set() {
    let _g = ENV_MUTEX.lock().await;
    let proxy = FakeLlmServer::spawn().await;
    proxy.mount_chat_ok(ok_chat_body()).await;
    let upstream = FakeLlmServer::spawn().await;
    upstream.mount_chat_ok(ok_chat_body()).await;

    let mut env = EnvGuard::new();
    env.mask_all_proxy_env();
    env.set("HTTP_PROXY", &proxy.url());

    let provider = OpenAiCompatProvider::new(&upstream.url(), Some("sk-test".into()));
    provider
        .chat(&[user_msg()], &[], &settings())
        .await
        .expect("chat must succeed directly to localhost upstream");

    assert_eq!(
        proxy.received_requests().await.len(),
        0,
        "proxy must be bypassed for localhost targets"
    );
    assert_eq!(
        upstream.received_requests().await.len(),
        1,
        "upstream must receive the request directly"
    );
}
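
// PROXY_USER/PROXY_PASS must be attached as HTTP basic auth on the proxied
// request: base64("alice:s3cret") == "YWxpY2U6czNjcmV0".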
#[tokio::test]
async fn proxy_basic_auth_from_env_vars_attaches_proxy_authorization_header() {
    let _g = ENV_MUTEX.lock().await;
    let proxy = FakeLlmServer::spawn().await;
    proxy.mount_chat_ok(ok_chat_body()).await;

    let mut env = EnvGuard::new();
    env.mask_all_proxy_env();
    env.set("HTTP_PROXY", &proxy.url());
    env.set("PROXY_USER", "alice");
    env.set("PROXY_PASS", "s3cret");

    let provider =
        OpenAiCompatProvider::new("http://upstream.test.invalid", Some("sk-test".into()));
    provider
        .chat(&[user_msg()], &[], &settings())
        .await
        .expect("chat must succeed via authenticated proxy");

    let proxied = proxy.received_requests().await;
    assert_eq!(proxied.len(), 1);
    let auth = proxied[0]
        .headers
        .get("proxy-authorization")
        .expect("Proxy-Authorization header must be set when PROXY_USER/PROXY_PASS are present");
    assert_eq!(auth, "Basic YWxpY2U6czNjcmV0");
}
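
// A malformed HTTP_PROXY must not take the provider down: per the test name,
// the expected behavior is to fall back to direct connections.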
#[tokio::test]
async fn invalid_proxy_url_degrades_gracefully_to_no_proxy() {
    let _g = ENV_MUTEX.lock().await;
    let upstream = FakeLlmServer::spawn().await;
    upstream.mount_chat_ok(ok_chat_body()).await;

    let mut env = EnvGuard::new();
    env.mask_all_proxy_env();
    env.set("HTTP_PROXY", "://this is not a url@@@");

    let provider = OpenAiCompatProvider::new(&upstream.url(), Some("sk-test".into()));
    provider
        .chat(&[user_msg()], &[], &settings())
        .await
        .expect("chat must still succeed when HTTP_PROXY is malformed");

    assert_eq!(
        upstream.received_requests().await.len(),
        1,
        "malformed proxy URL must not block legitimate localhost traffic"
    );
}
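
// A server that accepts the connection but stays silent for 3s must be cut
// off by the 1s read timeout. The 2.5s ceiling leaves slack for scheduling
// jitter while still catching the 3s full-delay path.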
#[tokio::test]
async fn read_timeout_aborts_silent_servers_quickly() {
    let _g = ENV_MUTEX.lock().await;
    let upstream = FakeLlmServer::spawn().await;
    upstream
        .mount_chat_delayed(std::time::Duration::from_secs(3), ok_chat_body())
        .await;

    let mut env = EnvGuard::new();
    env.mask_all_proxy_env();
    env.set("KODA_READ_TIMEOUT_SECS", "1");

    let provider = OpenAiCompatProvider::new(&upstream.url(), Some("sk-test".into()));
    let started = std::time::Instant::now();
    let result = provider.chat(&[user_msg()], &[], &settings()).await;
    let elapsed = started.elapsed();

    assert!(result.is_err(), "stalled server must produce an error");
    assert!(
        elapsed < std::time::Duration::from_millis(2500),
        "should error in ~1s, took {elapsed:?}; read timeout likely not applied"
    );

    // The match is deliberately loose: platforms and HTTP stacks phrase the
    // timeout error differently (e.g. "operation timed out").
    let msg = format!("{:?}", result.unwrap_err()).to_lowercase();
    assert!(
        msg.contains("timeout") || msg.contains("timed out") || msg.contains("operation"),
        "error should mention timeout, got: {msg}"
    );
}