use std::net::SocketAddr;
use std::time::Duration;
use anyhow::Context;
use clap::Parser;
use dotenvy::dotenv;
use sqlx::postgres::PgPoolOptions;
use tokio::net::TcpListener;
use tokio::signal;
use tokio_util::sync::CancellationToken;
use tracing::{Level, info};
use tracing_subscriber::FmtSubscriber;
use ceres_client::EmbeddingProviderEnum;
use ceres_core::config::EmbeddingProviderType;
use ceres_core::traits::EmbeddingProvider;
use ceres_core::{
TracingReporter, TracingWorkerReporter, WorkerConfig, WorkerService, load_portals_config,
};
use ceres_db::DatasetRepository;
use ceres_server::{AppState, ServerConfig, create_router};
#[tokio::main]
async fn main() -> anyhow::Result<()> {
dotenv().ok();
let subscriber = FmtSubscriber::builder()
.with_max_level(Level::INFO)
.finish();
tracing::subscriber::set_global_default(subscriber)?;
let config = ServerConfig::parse();
info!("Connecting to database...");
let pool = PgPoolOptions::new()
.max_connections(10)
.connect(&config.database_url)
.await
.context("Failed to connect to database")?;
info!("Database connection established");
let embedding_client = create_embedding_provider(&config)?;
let repo = DatasetRepository::new(pool.clone());
repo.validate_embedding_dimension(embedding_client.dimension())
.await
.context("Embedding provider validation failed")?;
info!(
"Using {} embedding provider ({} dimensions)",
embedding_client.name(),
embedding_client.dimension()
);
let portals_config = if let Some(path) = &config.portals_config {
load_portals_config(Some(path.clone()))?
} else {
load_portals_config(None).unwrap_or(None)
};
if let Some(ref portals_cfg) = portals_config {
info!(
"Loaded {} portals from configuration",
portals_cfg.portals.len()
);
}
let shutdown_token = CancellationToken::new();
let app_state = AppState::new(
pool,
embedding_client,
portals_config,
shutdown_token.clone(),
);
let app = create_router(app_state.clone(), &config);
let worker_shutdown = shutdown_token.clone();
let worker_handle = {
let worker = WorkerService::new(
app_state.job_repo.clone(),
app_state.harvest_service.clone(),
WorkerConfig::default(),
);
tokio::spawn(async move {
info!("Starting background worker for harvest jobs");
if let Err(e) = worker
.run(worker_shutdown, &TracingWorkerReporter, &TracingReporter)
.await
{
tracing::error!("Worker error: {}", e);
}
info!("Background worker stopped");
})
};
let addr: SocketAddr = format!("{}:{}", config.host, config.port)
.parse()
.context("Invalid address")?;
let listener = TcpListener::bind(addr)
.await
.context("Failed to bind to address")?;
info!("Starting Ceres API server on http://{}", addr);
info!("Swagger UI available at http://{}/swagger-ui", addr);
info!(
"Rate limiting: {} req/s, burst size {}",
config.rate_limit_rps, config.rate_limit_burst
);
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(shutdown_signal(shutdown_token))
.await
.context("Server error")?;
let _ = worker_handle.await;
info!("Server shutdown complete");
Ok(())
}
/// Completes once the process receives Ctrl+C (or SIGTERM on Unix).
///
/// On receipt it logs, cancels `shutdown_token`, and then waits two more
/// seconds before returning. That delay is presumably there so that tasks
/// watching the token can react before the HTTP server finishes its graceful
/// shutdown.
///
/// # Panics
///
/// Panics if the Ctrl+C handler (or, on Unix, the SIGTERM handler) cannot be
/// installed.
async fn shutdown_signal(shutdown_token: CancellationToken) {
    // Pause between cancelling the token and letting this future complete.
    const CANCEL_GRACE: Duration = Duration::from_secs(2);

    let interrupt = async {
        signal::ctrl_c()
            .await
            .expect("Failed to install Ctrl+C handler")
    };

    #[cfg(unix)]
    let sigterm = async {
        let mut term_stream = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("Failed to install SIGTERM handler");
        term_stream.recv().await;
    };

    // Off Unix there is no SIGTERM to listen for. A never-resolving future
    // keeps the `select!` below identical on every platform.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    tokio::select! {
        () = interrupt => {},
        () = sigterm => {},
    }

    info!("Shutdown signal received, starting graceful shutdown...");
    shutdown_token.cancel();
    tokio::time::sleep(CANCEL_GRACE).await;
}
fn create_embedding_provider(config: &ServerConfig) -> anyhow::Result<EmbeddingProviderEnum> {
let provider_type: EmbeddingProviderType = config
.embedding_provider
.parse()
.context("Invalid embedding provider")?;
match provider_type {
EmbeddingProviderType::Gemini => {
let api_key = config.gemini_api_key.as_ref().ok_or_else(|| {
anyhow::anyhow!("GEMINI_API_KEY required when using gemini provider")
})?;
EmbeddingProviderEnum::gemini(api_key).context("Failed to initialize Gemini client")
}
EmbeddingProviderType::OpenAI => {
let api_key = config.openai_api_key.as_ref().ok_or_else(|| {
anyhow::anyhow!("OPENAI_API_KEY required when using openai provider")
})?;
if let Some(model) = &config.embedding_model {
EmbeddingProviderEnum::openai_with_model(api_key, model)
.context("Failed to initialize OpenAI client")
} else {
EmbeddingProviderEnum::openai(api_key).context("Failed to initialize OpenAI client")
}
}
}
}