use reqwest::Client;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use crate::llm_router::error::LlmRouterError;
/// An API key for an LLM provider, plus the provider it belongs to and an
/// optional base URL.
///
/// `Debug` is implemented by hand so that `api_key` is never written to logs or
/// panic messages; the other fields are printed normally.
#[derive(Clone)]
pub struct ProviderKey
{
    /// Provider identifier, either inferred by
    /// [`ProviderKey::detect_provider_from_key`] or supplied by the key server.
    pub provider: String,
    /// The secret API key. Redacted in `Debug` output.
    pub api_key: String,
    /// Optional base URL for the provider API (presumably `None` means the
    /// provider's default endpoint — confirm against callers).
    pub base_url: Option<String>,
}
impl std::fmt::Debug for ProviderKey
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result
    {
        // Never print the secret itself; keep the rest useful for diagnostics.
        f.debug_struct("ProviderKey")
            .field("provider", &self.provider)
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .finish()
    }
}
impl ProviderKey
{
    /// Guesses the provider from the shape of `api_key`.
    ///
    /// Returns `"anthropic"` for keys starting with `sk-ant-`; every other key
    /// (including the empty string) is assumed to be `"openai"`.
    pub fn detect_provider_from_key(api_key: &str) -> &'static str
    {
        if api_key.starts_with("sk-ant-")
        {
            "anthropic"
        }
        else
        {
            "openai"
        }
    }
}
/// A fetched provider key paired with the instant it was obtained, so the
/// fetcher can tell whether the entry is still within its TTL.
struct CachedKey
{
    /// Moment the key was stored in the cache.
    fetched_at: Instant,
    /// The key returned by the key server.
    key: ProviderKey,
}
/// Supplies the provider API key, either from a fixed value given at
/// construction or by querying a key server and caching the answer for a
/// configurable TTL.
pub struct KeyFetcher
{
    /// When `Some`, this key is returned as-is and no server is contacted.
    static_key: Option<ProviderKey>,
    /// Base URL of the key server.
    server_url: String,
    /// Token sent in the body of each key-server request (presumably an
    /// authentication credential — confirm with the server API).
    ic_token: String,
    /// HTTP client used for key-server requests.
    client: Client,
    /// Most recently fetched key, behind an async read/write lock.
    cache: Arc<RwLock<Option<CachedKey>>>,
    /// How long a fetched key is served from `cache` before refetching.
    cache_ttl: Duration,
}
impl KeyFetcher
{
pub fn new(server_url: String, ic_token: String, cache_ttl_seconds: u64) -> Self
{
let client = Client::builder()
.timeout(Duration::from_secs(30))
.build()
.expect("LOUD FAILURE: Failed to create HTTP client");
Self {
server_url,
ic_token,
client,
cache: Arc::new(RwLock::new(None)),
cache_ttl: Duration::from_secs(cache_ttl_seconds),
static_key: None,
}
}
pub fn new_static(api_key: String, base_url: Option<String>) -> Self
{
let provider = ProviderKey::detect_provider_from_key(&api_key).to_string();
let static_key = ProviderKey {
provider,
api_key,
base_url,
};
Self {
server_url: String::new(),
ic_token: String::new(),
client: Client::new(),
cache: Arc::new(RwLock::new(None)),
cache_ttl: Duration::from_secs(0),
static_key: Some(static_key),
}
}
pub async fn get_key(&self) -> Result<ProviderKey, LlmRouterError>
{
if let Some(ref key) = self.static_key {
return Ok(key.clone());
}
{
let cache = self.cache.read().await;
if let Some(cached) = cache.as_ref()
{
if cached.fetched_at.elapsed() < self.cache_ttl
{
return Ok(cached.key.clone());
}
}
}
let key = self.fetch_from_server().await?;
{
let mut cache = self.cache.write().await;
*cache = Some(CachedKey {
key: key.clone(),
fetched_at: Instant::now(),
});
}
Ok(key)
}
async fn fetch_from_server(&self) -> Result<ProviderKey, LlmRouterError>
{
let url = format!("{}/api/v1/agents/provider-key", self.server_url);
#[derive(serde::Serialize)]
struct ProviderKeyRequest<'a>
{
ic_token: &'a str,
}
let response = self
.client
.post(&url)
.header("Content-Type", "application/json")
.json(&ProviderKeyRequest { ic_token: &self.ic_token })
.send()
.await
.map_err(|e| LlmRouterError::KeyFetch(e.to_string()))?;
if !response.status().is_success()
{
let status = response.status();
let error_msg = match response.json::<serde_json::Value>().await {
Ok(json) => {
let error = json.get("error").and_then(|v| v.as_str()).unwrap_or("Unknown error");
let code = json.get("code").and_then(|v| v.as_str()).unwrap_or("UNKNOWN");
format!("{}: {} ({})", status, error, code)
}
Err(_) => format!("Server returned status {}", status),
};
return Err(LlmRouterError::KeyFetch(error_msg));
}
#[derive(serde::Deserialize)]
struct KeyResponse
{
provider_key: String,
provider: String,
#[serde(default)]
base_url: Option<String>,
}
let data: KeyResponse = response
.json()
.await
.map_err(|e| LlmRouterError::KeyFetch(e.to_string()))?;
Ok(ProviderKey {
provider: data.provider,
api_key: data.provider_key,
base_url: data.base_url,
})
}
}