use anyhow::Result;
use arc_swap::ArcSwap;
use byokey_auth::AuthManager;
use byokey_config::{Config, ConfigWatcher, LogConfig, LogFormat};
use byokey_proxy::AppState;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::Notify;
use tracing_appender::non_blocking::WorkerGuard;
use tracing_appender::rolling;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::fmt::writer::BoxMakeWriter;
use crate::ServerArgs;
use crate::control_server::{self, ControlState};
/// Installs the global `tracing` subscriber for the server process.
///
/// Destination: `log_file` if given, otherwise `cfg.file`, otherwise stdout.
/// File output goes through a daily-rolling appender wrapped in a non-blocking
/// writer; ANSI colours are enabled only when writing to stdout. The level
/// filter comes from `RUST_LOG` (via `EnvFilter::try_from_default_env`) when it
/// is set and valid, else from `cfg.level`. `cfg.format` selects JSON or text.
///
/// Returns the non-blocking writer's `WorkerGuard` when logging to a file, and
/// `None` for stdout. Buffered records are only flushed when the guard is
/// dropped, so the caller must keep it alive for as long as it logs and drop it
/// (rather than leak it) before the process ends.
///
/// # Panics
///
/// Panics if a global subscriber has already been installed (`init`).
fn init_logging(cfg: &LogConfig, log_file: Option<PathBuf>) -> Option<WorkerGuard> {
    let env_filter =
        EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&cfg.level));

    // Keep the CLI path as a `PathBuf`: round-tripping through a lossy string
    // would silently corrupt non-UTF-8 file names.
    let path: Option<PathBuf> = log_file.or_else(|| cfg.file.as_ref().map(PathBuf::from));

    let (writer, guard): (BoxMakeWriter, Option<WorkerGuard>) = match &path {
        Some(p) => {
            // `Path::parent` yields `Some("")` for a bare file name such as
            // "byokey.log"; treat that like "no parent" and use the cwd.
            let dir = p
                .parent()
                .filter(|d| !d.as_os_str().is_empty())
                .unwrap_or_else(|| Path::new("."));
            let name = p
                .file_name()
                .unwrap_or_else(|| std::ffi::OsStr::new("byokey.log"));
            let (nb, g) = tracing_appender::non_blocking(rolling::daily(dir, name));
            (BoxMakeWriter::new(nb), Some(g))
        }
        None => (BoxMakeWriter::new(std::io::stdout), None),
    };

    let builder = tracing_subscriber::fmt()
        .with_env_filter(env_filter)
        .with_target(true)
        // Colour escape codes only when writing to stdout, never into files.
        .with_ansi(path.is_none())
        .with_writer(writer);
    match cfg.format {
        LogFormat::Json => builder.json().init(),
        LogFormat::Text => builder.init(),
    }
    guard
}
pub async fn cmd_serve(args: ServerArgs) -> Result<()> {
let ServerArgs {
config: config_path,
port,
host,
db,
log_file,
} = args;
let effective_path = config_path.or_else(|| {
let default = byokey_daemon::paths::config_path().ok()?;
if default.exists() {
Some(default)
} else {
None
}
});
let (config_arc, config_watcher): (Arc<ArcSwap<Config>>, Option<Arc<ConfigWatcher>>) =
if let Some(ref path) = effective_path {
let watcher = Arc::new(
ConfigWatcher::new(path.clone())
.map_err(|e| anyhow::anyhow!("config error: {e}"))?,
);
let arc = watcher.arc();
Arc::clone(&watcher).watch();
(arc, Some(watcher))
} else {
(Arc::new(ArcSwap::from_pointee(Config::default())), None)
};
let snapshot = config_arc.load();
let _log_guard = init_logging(&snapshot.log, log_file);
let effective_host = host.as_deref().unwrap_or(&snapshot.host).to_owned();
let effective_port = port.unwrap_or(snapshot.port);
let addr = format!("{effective_host}:{effective_port}");
let store = Arc::new(crate::open_store(db).await?);
let auth = Arc::new(AuthManager::new(store.clone(), rquest::Client::new()));
let _refresh_handle = auth.spawn_refresh_loop(
std::time::Duration::from_secs(60),
std::time::Duration::from_secs(300),
);
let versions = byokey_proxy::VersionStore::fetch(&rquest::Client::new()).await;
let usage_store: Arc<dyn byokey_types::UsageStore> = store;
let state = AppState::new(
Arc::clone(&config_arc),
auth,
Some(usage_store.clone()),
versions,
);
if let Ok(totals) = usage_store.totals(None, None).await {
for bucket in &totals {
state.usage.preload(
&bucket.model,
bucket.request_count,
bucket.input_tokens,
bucket.output_tokens,
);
}
}
let app = byokey_proxy::make_router(Arc::clone(&state));
let listener = match listenfd::ListenFd::from_env().take_tcp_listener(0) {
Ok(Some(l)) => {
tracing::info!("using inherited TCP listener from environment");
l.set_nonblocking(true)
.map_err(|e| anyhow::anyhow!("set_nonblocking: {e}"))?;
tokio::net::TcpListener::from_std(l).map_err(|e| anyhow::anyhow!("from_std: {e}"))?
}
_ => {
let parsed: std::net::SocketAddr = addr
.parse()
.map_err(|e| anyhow::anyhow!("invalid address {addr}: {e}"))?;
let std_listener = std::net::TcpListener::bind(parsed)
.map_err(|e| anyhow::anyhow!("bind {addr}: {e}"))?;
std_listener
.set_nonblocking(true)
.map_err(|e| anyhow::anyhow!("set_nonblocking: {e}"))?;
tokio::net::TcpListener::from_std(std_listener)
.map_err(|e| anyhow::anyhow!("from_std: {e}"))?
}
};
let shutdown = Arc::new(Notify::new());
let sock_path = byokey_daemon::paths::control_sock_path()
.map_err(|e| anyhow::anyhow!("control socket path: {e}"))?;
if byokey_daemon::control::is_alive() {
return Err(anyhow::anyhow!(
"another byokey serve is already running (control socket {} is live)",
sock_path.display()
));
}
let ctl_state = Arc::new(ControlState {
watcher: config_watcher,
shutdown: Arc::clone(&shutdown),
start: Instant::now(),
host: effective_host.clone(),
port: effective_port,
});
let ctl_handle = control_server::bind_and_serve(sock_path.clone(), ctl_state)
.map_err(|e| anyhow::anyhow!("bind control socket {}: {e}", sock_path.display()))?;
tracing::info!(socket = %sock_path.display(), "control socket ready");
spawn_signal_handler(Arc::clone(&shutdown));
drop(snapshot);
tracing::info!(addr = %addr, "byokey listening");
let shutdown_for_serve = Arc::clone(&shutdown);
let serve_result = axum::serve(listener, app)
.with_graceful_shutdown(async move {
shutdown_for_serve.notified().await;
})
.await
.map_err(anyhow::Error::from);
ctl_handle.cleanup();
if serve_result.is_ok() {
std::process::exit(0);
}
serve_result
}
fn spawn_signal_handler(shutdown: Arc<Notify>) {
tokio::spawn(async move {
#[cfg(unix)]
{
use tokio::signal::unix::{SignalKind, signal};
let mut sigterm = match signal(SignalKind::terminate()) {
Ok(s) => s,
Err(e) => {
tracing::warn!(error = %e, "install SIGTERM handler failed");
return;
}
};
let mut sigint = match signal(SignalKind::interrupt()) {
Ok(s) => s,
Err(e) => {
tracing::warn!(error = %e, "install SIGINT handler failed");
return;
}
};
tokio::select! {
_ = sigterm.recv() => tracing::info!("received SIGTERM"),
_ = sigint.recv() => tracing::info!("received SIGINT"),
}
}
#[cfg(not(unix))]
{
let _ = tokio::signal::ctrl_c().await;
tracing::info!("received Ctrl-C");
}
shutdown.notify_waiters();
});
}