use std::net::TcpListener;
use std::sync::{Arc, Once};
use tokio::sync::oneshot;
use iron_cost::budget::CostController;
use iron_telemetry::{init_logging, LogLevel};
/// Guards the one-time logging setup performed by [`ensure_logging`].
static LOGGING_INIT: Once = Once::new();

/// Initialises logging at `Info` level the first time it is called.
///
/// Later calls are no-ops. Whatever `init_logging` returns is discarded.
fn ensure_logging() {
    /// Body of the one-time initialisation; the outcome is deliberately ignored.
    fn init_once() {
        let _ = init_logging(LogLevel::Info);
    }

    LOGGING_INIT.call_once(init_once);
}
use crate::llm_router::key_fetcher::KeyFetcher;
use crate::llm_router::proxy::{run_proxy, ProxyConfig};
#[cfg(feature = "analytics")]
use iron_runtime_analytics::EventStore;
#[cfg(feature = "analytics")]
use iron_runtime_analytics::{SyncClient, SyncConfig, SyncHandle};
/// A local LLM proxy bound to `127.0.0.1`, together with the runtime that
/// drives it and the bookkeeping needed to stop it again.
///
/// NOTE: fields are dropped in declaration order, so do not reorder them
/// without checking the shutdown sequence.
pub struct LlmRouter {
    /// Local TCP port the proxy was configured to listen on.
    port: u16,
    /// Token supplied by the caller; also handed to the proxy as its `ic_token`.
    api_key: String,
    /// Base URL of the control server; empty when no server is used.
    #[allow(dead_code)]
    server_url: String,
    /// Provider name reported by the key fetcher, or `"unknown"` if the
    /// lookup failed at construction time.
    provider: String,
    /// Tokio runtime the proxy task (and, with analytics, the sync task) is
    /// started on.
    #[allow(dead_code)]
    runtime: tokio::runtime::Runtime,
    /// Sender used to signal the proxy to stop; `None` once the router has
    /// been stopped.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Budget and spend tracker shared with the proxy.
    cost_controller: Option<Arc<CostController>>,
    /// Analytics event store shared with the proxy.
    #[cfg(feature = "analytics")]
    event_store: Arc<EventStore>,
    /// Agent identifier forwarded to the proxy for analytics.
    #[cfg(feature = "analytics")]
    #[allow(dead_code)]
    agent_id: Option<Arc<str>>,
    /// Provider identifier forwarded to the proxy for analytics.
    #[cfg(feature = "analytics")]
    #[allow(dead_code)]
    provider_id: Option<Arc<str>>,
    /// Handle of the background analytics sync; present only when a server
    /// URL was configured.
    #[cfg(feature = "analytics")]
    #[allow(dead_code)]
    sync_handle: Option<SyncHandle>,
    /// Lease obtained from the budget handshake; consumed when the router is
    /// stopped and the budget is returned.
    lease_id: Option<String>,
}
impl LlmRouter {
pub fn create(
api_key: String,
server_url: String,
cache_ttl_seconds: u64,
) -> Result<Self, String> {
Self::create_inner(api_key, server_url, cache_ttl_seconds, None, None)
}
pub fn create_with_budget(
api_key: String,
server_url: String,
cache_ttl_seconds: u64,
budget: f64,
) -> Result<Self, String> {
Self::create_inner(api_key, server_url, cache_ttl_seconds, Some(budget), None)
}
pub fn create_with_provider_key(
provider_key: String,
budget: Option<f64>,
) -> Result<Self, String> {
Self::create_inner(
"direct".to_string(),
String::new(),
0,
budget,
Some(provider_key),
)
}
pub fn create_full(
api_key: String,
server_url: String,
cache_ttl_seconds: u64,
budget: Option<f64>,
provider_key: Option<String>,
) -> Result<Self, String> {
Self::create_inner(api_key, server_url, cache_ttl_seconds, budget, provider_key)
}
pub fn get_base_url(&self) -> String {
format!("http://127.0.0.1:{}/v1", self.port)
}
pub fn get_api_key(&self) -> &str {
&self.api_key
}
pub fn get_port(&self) -> u16 {
self.port
}
pub fn get_provider(&self) -> &str {
&self.provider
}
pub fn is_running(&self) -> bool {
self.shutdown_tx.is_some()
}
pub fn total_spent(&self) -> f64 {
self.cost_controller
.as_ref()
.map(|c| c.total_spent() as f64 / 1_000_000.0)
.unwrap_or(0.0)
}
pub fn set_budget(&self, amount_usd: f64) {
if let Some(ref controller) = self.cost_controller {
let budget_micros = (amount_usd * 1_000_000.0) as i64;
controller.set_budget(budget_micros);
}
}
pub fn get_budget(&self) -> Option<f64> {
self.cost_controller
.as_ref()
.map(|c| c.budget_limit() as f64 / 1_000_000.0)
}
pub fn get_budget_status(&self) -> Option<(f64, f64)> {
self.cost_controller.as_ref().map(|c| {
let (spent_micros, limit_micros) = c.get_status();
(
spent_micros as f64 / 1_000_000.0,
limit_micros as f64 / 1_000_000.0,
)
})
}
pub fn shutdown(&mut self) {
self.stop_inner();
}
fn create_inner(
api_key: String,
server_url: String,
cache_ttl_seconds: u64,
budget: Option<f64>,
provider_key: Option<String>,
) -> Result<Self, String> {
ensure_logging();
let port = find_free_port().map_err(|e| format!("Failed to find free port: {}", e))?;
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(2)
.enable_all()
.build()
.map_err(|e| format!("Failed to create runtime: {}", e))?;
let key_fetcher = Arc::new(if let Some(ref pk) = provider_key {
KeyFetcher::new_static(pk.clone(), None)
} else {
KeyFetcher::new(server_url.clone(), api_key.clone(), cache_ttl_seconds)
});
let provider = runtime.block_on(async {
key_fetcher
.get_key()
.await
.map(|k| k.provider)
.unwrap_or_else(|_| "unknown".to_string())
});
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let (effective_budget_micros, lease_id): (i64, Option<String>) = if let Some(b) = budget {
((b * 1_000_000.0) as i64, None)
} else if !server_url.is_empty() {
match runtime.block_on(async { fetch_budget_from_handshake(&server_url, &api_key).await }) {
Some(result) => (result.budget, Some(result.lease_id)),
None => (0, None),
}
} else {
(0, None)
};
let cost_controller = Some(Arc::new(CostController::new(effective_budget_micros)));
#[cfg(feature = "analytics")]
let event_store = Arc::new(EventStore::new());
#[cfg(feature = "analytics")]
let agent_id_arc: Option<Arc<str>> = None;
#[cfg(feature = "analytics")]
let provider_id_arc: Option<Arc<str>> = None;
#[cfg(feature = "analytics")]
event_store.record_router_started(port);
let config = ProxyConfig {
port,
ic_token: api_key.clone(),
server_url: server_url.clone(),
cache_ttl_seconds,
cost_controller: cost_controller.clone(),
provider_key: provider_key.clone(),
#[cfg(feature = "analytics")]
event_store: event_store.clone(),
#[cfg(feature = "analytics")]
agent_id: agent_id_arc.clone(),
#[cfg(feature = "analytics")]
provider_id: provider_id_arc.clone(),
};
runtime.spawn(async move {
if let Err(e) = run_proxy(config, shutdown_rx).await {
tracing::error!("Proxy server error: {}", e);
}
});
#[cfg(feature = "analytics")]
let sync_handle = if !server_url.is_empty() {
let sync_config = SyncConfig::new(&server_url, &api_key);
let sync_client = SyncClient::new(event_store.clone(), sync_config);
Some(sync_client.start(runtime.handle()))
} else {
None
};
std::thread::sleep(std::time::Duration::from_millis(50));
Ok(Self {
port,
api_key,
server_url,
provider,
runtime,
shutdown_tx: Some(shutdown_tx),
cost_controller,
#[cfg(feature = "analytics")]
event_store,
#[cfg(feature = "analytics")]
agent_id: agent_id_arc,
#[cfg(feature = "analytics")]
provider_id: provider_id_arc,
#[cfg(feature = "analytics")]
sync_handle,
lease_id,
})
}
fn stop_inner(&mut self) {
if let Some(tx) = self.shutdown_tx.take() {
if let Some(lease_id) = self.lease_id.take() {
if !self.server_url.is_empty() {
let spent_microdollars = self
.cost_controller
.as_ref()
.map(|cc| cc.total_spent())
.unwrap_or(0);
let url = format!("{}/api/v1/budget/return", self.server_url);
let client = reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(5))
.build();
match client {
Ok(client) => {
match client
.post(&url)
.header("Content-Type", "application/json")
.json(&serde_json::json!({"lease_id": lease_id, "spent_microdollars": spent_microdollars}))
.send()
{
Ok(resp) if resp.status().is_success() => {
if let Ok(body) = resp.json::<serde_json::Value>() {
if let Some(returned) = body.get("returned").and_then(|v| v.as_i64()) {
tracing::info!(
"Budget returned to server: ${:.6} (spent: ${:.6})",
returned as f64 / 1_000_000.0,
spent_microdollars as f64 / 1_000_000.0
);
}
}
}
Ok(resp) => {
tracing::warn!("Budget return failed with status: {}", resp.status());
}
Err(e) => {
tracing::warn!("Budget return request failed: {}", e);
}
}
}
Err(e) => {
tracing::warn!("Failed to create HTTP client for budget return: {}", e);
}
}
}
}
#[cfg(feature = "analytics")]
self.event_store.record_router_stopped();
#[cfg(feature = "analytics")]
if let Some(handle) = self.sync_handle.take() {
handle.stop(); std::thread::sleep(std::time::Duration::from_millis(500));
}
let _ = tx.send(());
}
}
}
/// Stops the router when it goes out of scope, so an explicit
/// [`LlmRouter::shutdown`] call is optional.
impl Drop for LlmRouter {
    fn drop(&mut self) {
        // `shutdown` forwards to `stop_inner`, which is a no-op if the router
        // was already stopped.
        self.shutdown();
    }
}
/// Asks the OS for an unused loopback port by binding to port `0` on
/// `127.0.0.1` and reading back the assigned port.
///
/// The probing listener is closed before this function returns.
///
/// NOTE(review): because the listener is released, another process could
/// claim the port before the caller binds it.
///
/// # Errors
/// Propagates the I/O error if binding or querying the local address fails.
fn find_free_port() -> std::io::Result<u16> {
    TcpListener::bind("127.0.0.1:0")
        .and_then(|probe| probe.local_addr())
        .map(|addr| addr.port())
}
/// Outcome of a successful budget handshake with the server.
struct HandshakeResult {
    /// Budget granted by the server, in micro-dollars.
    budget: i64,
    /// Identifier of the lease the grant was issued under.
    lease_id: String,
}
/// Requests a budget grant from `<server_url>/api/v1/budget/handshake`.
///
/// Sends `ic_token` both as a bearer token and in the JSON body, together
/// with the fixed provider `"openai"`. On success the granted budget
/// (micro-dollars) and lease id are returned.
///
/// The request is bounded by a 10-second timeout. The router constructor
/// blocks on this call, so without a timeout an unresponsive server would
/// hang construction indefinitely.
///
/// Returns `None`, after logging a warning, when the HTTP client cannot be
/// built, the request fails or times out, the server answers with a
/// non-success status, or the response body cannot be parsed.
async fn fetch_budget_from_handshake(server_url: &str, ic_token: &str) -> Option<HandshakeResult> {
    /// Wire format of a successful handshake response.
    #[derive(serde::Deserialize)]
    struct HandshakeResponse {
        budget_granted: i64,
        lease_id: String,
    }

    /// Upper bound for the whole handshake request.
    const HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    // Built via the builder so that a client initialisation failure is
    // reported instead of panicking, and so the timeout can be set.
    let client = match reqwest::Client::builder().timeout(HANDSHAKE_TIMEOUT).build() {
        Ok(client) => client,
        Err(e) => {
            tracing::warn!("Failed to create HTTP client for handshake: {}", e);
            return None;
        }
    };

    let url = format!("{}/api/v1/budget/handshake", server_url);
    let response = match client
        .post(&url)
        .header("Authorization", format!("Bearer {}", ic_token))
        .json(&serde_json::json!({
            "ic_token": ic_token,
            "provider": "openai"
        }))
        .send()
        .await
    {
        Ok(r) => r,
        Err(e) => {
            tracing::warn!("Failed to connect to server for handshake: {}", e);
            return None;
        }
    };

    if !response.status().is_success() {
        tracing::warn!("Handshake failed with status: {}", response.status());
        return None;
    }

    match response.json::<HandshakeResponse>().await {
        Ok(data) => {
            tracing::info!(
                "Budget from server handshake: ${:.6} ({}μ$), lease_id: {}",
                data.budget_granted as f64 / 1_000_000.0,
                data.budget_granted,
                data.lease_id
            );
            Some(HandshakeResult {
                budget: data.budget_granted,
                lease_id: data.lease_id,
            })
        }
        Err(e) => {
            tracing::warn!("Failed to parse handshake response: {}", e);
            None
        }
    }
}