mod error;
mod handlers;
mod models;
mod net;
mod redis_repo;
mod state;
use axum::{
routing::{get, post},
Router,
};
use clap::Parser;
use std::net::SocketAddr;
use tower::ServiceBuilder;
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use crate::redis_repo::{LuaScripts, RedisRepo};
// Command-line options for the server binary.
//
// NOTE: this header is deliberately a plain `//` comment. clap's derive can use `///` doc
// comments as help text, and the struct-level `about` is requested via `#[command(about)]`
// below, so a doc comment here is avoided to leave that output untouched.
#[derive(Parser, Debug)]
#[command(author, version, about)]
struct Args {
    /// IP address the HTTP server binds to
    #[arg(long, default_value = "0.0.0.0")]
    host: String,
    /// TCP port the HTTP server listens on
    #[arg(long, default_value_t = 3000)]
    port: u16,
    /// Redis connection URL (the NANOCTRL_REDIS_URL environment variable, when set, takes precedence)
    #[arg(long, default_value = "redis://127.0.0.1:6379")]
    redis_url: String,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::new(
std::env::var("NANOCTRL_RUST_LOG").unwrap_or_else(|_| "info".into()),
))
.with(tracing_subscriber::fmt::layer())
.init();
let args = Args::parse();
let redis_url = std::env::var("NANOCTRL_REDIS_URL").unwrap_or(args.redis_url);
tracing::info!("Using Redis URL: {}", redis_url);
let scripts = LuaScripts::load()?;
tracing::info!("Loaded Lua scripts from lua/ directory");
let repo = RedisRepo::new(&redis_url, scripts)?;
tracing::info!("Redis connection pool initialized");
{
tracing::info!("Warming up Redis connection...");
let mut conn = repo.conn().await.map_err(|e| {
anyhow::anyhow!(
"Failed to connect to Redis at {redis_url}. Please start Redis or set \
NANOCTRL_REDIS_URL / --redis-url to a reachable Redis instance. Details: {e}"
)
})?;
let _: String = redis::cmd("PING").query_async(&mut *conn).await.map_err(|e| {
anyhow::anyhow!(
"Failed to ping Redis at {redis_url}. Please verify the Redis service is healthy. \
Details: {e}"
)
})?;
tracing::info!("Redis connection established successfully");
}
let app = Router::new()
.route("/", get(handlers::util::root))
.route("/heartbeat", post(handlers::util::heartbeat))
.route("/heartbeat_engine", post(handlers::util::heartbeat))
.route("/heartbeat_agent", post(handlers::util::heartbeat))
.route("/start_peer_agent", post(handlers::peer::start_peer_agent))
.route("/query", post(handlers::peer::query))
.route("/cleanup", post(handlers::peer::cleanup))
.route(
"/v1/desired_topology/:agent_id",
post(handlers::rdma::set_desired_topology),
)
.route("/register_mr", post(handlers::rdma::register_mr))
.route("/get_mr_info", post(handlers::rdma::get_mr_info))
.route("/register_engine", post(handlers::engine::register_engine))
.route(
"/unregister_engine",
post(handlers::engine::unregister_engine),
)
.route("/get_engine_info", post(handlers::engine::get_engine_info))
.route("/list_engines", post(handlers::engine::list_engines))
.route(
"/get_redis_address",
post(handlers::util::get_redis_address),
)
.layer(
ServiceBuilder::new().layer(
TraceLayer::new_for_http()
.make_span_with(|request: &axum::http::Request<_>| {
tracing::info_span!(
"http_request",
method = %request.method(),
uri = %request.uri(),
)
})
.on_request(|request: &axum::http::Request<_>, _span: &tracing::Span| {
tracing::debug!("Incoming request: {} {}", request.method(), request.uri());
})
.on_response(
|response: &axum::http::Response<_>,
latency: std::time::Duration,
_span: &tracing::Span| {
tracing::info!(
status = %response.status(),
latency_us = latency.as_micros(),
"api done"
);
},
)
.on_failure(
|error: tower_http::classify::ServerErrorsFailureClass,
latency: std::time::Duration,
_span: &tracing::Span| {
tracing::error!("Request failed: {:?}, latency={:?}", error, latency);
},
),
),
)
.with_state(repo);
let addr: SocketAddr = format!("{}:{}", args.host, args.port)
.parse()
.map_err(|e| {
anyhow::anyhow!("Invalid server address {}:{}: {}", args.host, args.port, e)
})?;
tracing::info!("listening on {}", addr);
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await?;
Ok(())
}