use std::net::SocketAddr;
use std::time::Duration;
use anyhow::Context;
use clap::Parser;
use dotenvy::dotenv;
use sqlx::postgres::PgPoolOptions;
use tokio::net::TcpListener;
use tokio::signal;
use tokio_util::sync::CancellationToken;
use tracing::{Level, info};
use tracing_subscriber::FmtSubscriber;
use ceres_client::{EmbeddingConfig, EmbeddingProviderEnum};
use ceres_core::traits::EmbeddingProvider;
use ceres_core::{
TracingReporter, TracingWorkerReporter, WorkerConfig, WorkerService, load_portals_config,
};
use ceres_db::DatasetRepository;
use ceres_server::{AppState, ServerConfig, create_router};
/// Entry point for the Ceres API server.
///
/// Startup sequence:
/// 1. Load `.env`.
/// 2. Install the tracing subscriber.
/// 3. Parse [`ServerConfig`] via clap.
/// 4. Connect to Postgres.
/// 5. Build the embedding provider and validate its dimension through [`DatasetRepository`].
/// 6. Load the optional portals configuration.
/// 7. Bind the listener.
/// 8. Spawn the background harvest worker.
/// 9. Serve the HTTP API until `shutdown_signal` resolves.
///
/// # Errors
///
/// Returns an error if any startup step fails. Those steps are:
/// - subscriber installation
/// - database connection
/// - embedding provider construction or validation
/// - loading an explicitly requested portals config
/// - address parsing or binding
///
/// It also returns an error if the server exits with one. In that case the
/// background worker is still cancelled and awaited before the error is returned.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // A missing `.env` file is not an error.
    dotenv().ok();

    let subscriber = FmtSubscriber::builder()
        .with_max_level(Level::INFO)
        .finish();
    tracing::subscriber::set_global_default(subscriber)?;

    let config = ServerConfig::parse();

    info!("Connecting to database...");
    let pool = PgPoolOptions::new()
        .max_connections(config.max_connections)
        .connect(&config.database_url)
        .await
        .context("Failed to connect to database")?;
    info!("Database connection established");

    let embedding_client = EmbeddingProviderEnum::from_config(&EmbeddingConfig {
        provider: config.embedding_provider.clone(),
        gemini_api_key: config.gemini_api_key.clone(),
        openai_api_key: config.openai_api_key.clone(),
        embedding_model: config.embedding_model.clone(),
        ollama_endpoint: config.ollama_endpoint.clone(),
    })?;

    // Refuse to start if the repository rejects the provider's vector size.
    let repo = DatasetRepository::new(pool.clone());
    repo.validate_embedding_dimension(embedding_client.dimension())
        .await
        .context("Embedding provider validation failed")?;
    info!(
        "Using {} embedding provider ({} dimensions)",
        embedding_client.name(),
        embedding_client.dimension()
    );

    // An explicitly requested portals file must load. The default lookup stays
    // best-effort, but its failure is logged instead of being silently discarded.
    // NOTE(review): presumably a missing default file yields Ok(None), so this
    // warning should only fire on real load errors — confirm.
    let portals_config = match &config.portals_config {
        Some(path) => load_portals_config(Some(path.clone()))?,
        None => load_portals_config(None).unwrap_or_else(|e| {
            tracing::warn!(
                "Failed to load default portals configuration; continuing without it: {:#}",
                e
            );
            None
        }),
    };
    if let Some(portals_cfg) = &portals_config {
        info!(
            "Loaded {} portals from configuration",
            portals_cfg.portals.len()
        );
    }

    if config.admin_token.is_some() {
        info!("Admin authentication: enabled");
    } else {
        info!("Admin authentication: disabled (set CERES_ADMIN_TOKEN to enable)");
    }

    // Shared by the HTTP app state, the background worker and the signal handler.
    let shutdown_token = CancellationToken::new();
    let app_state = AppState::new(
        pool,
        embedding_client,
        portals_config,
        shutdown_token.clone(),
        config.admin_token.clone(),
    );
    let app = create_router(app_state.clone(), &config);

    // Resolve and bind before spawning the worker, so an address problem aborts
    // startup without leaving a worker running in the background.
    let addr = parse_bind_addr(&config.host, &config.port).context("Invalid address")?;
    let listener = TcpListener::bind(addr)
        .await
        .context("Failed to bind to address")?;

    let worker_shutdown = shutdown_token.clone();
    let worker_handle = {
        let worker = WorkerService::new(
            app_state.job_repo.clone(),
            app_state.harvest_pipeline.clone(),
            WorkerConfig::default(),
        );
        tokio::spawn(async move {
            info!("Starting background worker for harvest jobs");
            if let Err(e) = worker
                .run(worker_shutdown, &TracingWorkerReporter, &TracingReporter)
                .await
            {
                tracing::error!("Worker error: {}", e);
            }
            info!("Background worker stopped");
        })
    };

    info!("Starting Ceres API server on http://{}", addr);
    info!("Swagger UI available at http://{}/swagger-ui", addr);
    info!(
        "Rate limiting: {} req/s, burst size {}",
        config.rate_limit_rps, config.rate_limit_burst
    );

    let serve_result = axum::serve(
        listener,
        app.into_make_service_with_connect_info::<SocketAddr>(),
    )
    .with_graceful_shutdown(shutdown_signal(shutdown_token.clone()))
    .await;

    // After a signal-driven exit the token is already cancelled. After a server
    // error it is not, so cancel here so the worker stops either way.
    // Cancelling an already-cancelled token is a no-op.
    shutdown_token.cancel();
    if let Err(e) = worker_handle.await {
        tracing::error!("Worker task failed: {:?}", e);
    }

    serve_result.context("Server error")?;
    info!("Server shutdown complete");
    Ok(())
}

/// Builds the listen address from a host and a port.
///
/// `SocketAddr` parsing accepts IPv6 only in `[addr]:port` form. Bare IPv6
/// literals such as `::` or `::1` (any host containing `:` without a leading
/// `[`) are therefore wrapped in brackets. Every other host is joined as
/// `host:port`.
fn parse_bind_addr(
    host: impl std::fmt::Display,
    port: impl std::fmt::Display,
) -> Result<SocketAddr, std::net::AddrParseError> {
    let host = host.to_string();
    let candidate = if host.contains(':') && !host.starts_with('[') {
        format!("[{}]:{}", host, port)
    } else {
        format!("{}:{}", host, port)
    };
    candidate.parse()
}
/// Waits for Ctrl+C (or, on Unix, SIGTERM), then triggers application shutdown.
///
/// When either signal arrives it:
/// 1. logs the event,
/// 2. cancels `shutdown_token`,
/// 3. sleeps for two seconds before returning.
///
/// NOTE(review): the delay presumably gives tasks that observe the token time
/// to wind down before the caller proceeds — confirm intent.
///
/// # Panics
///
/// Panics if the Ctrl+C or SIGTERM handler cannot be installed.
async fn shutdown_signal(shutdown_token: CancellationToken) {
    const DRAIN_DELAY: Duration = Duration::from_secs(2);

    let interrupt = async {
        signal::ctrl_c()
            .await
            .expect("Failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let sigterm = async {
        let mut stream = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("Failed to install SIGTERM handler");
        stream.recv().await;
    };

    // Non-Unix targets have no SIGTERM; this branch never completes.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    tokio::select! {
        () = interrupt => {}
        () = sigterm => {}
    }

    info!("Shutdown signal received, starting graceful shutdown...");
    shutdown_token.cancel();
    tokio::time::sleep(DRAIN_DELAY).await;
}