use crate::error::Result;
use crate::rest;
use crate::state::AppState;
use axum::Router;
use std::net::SocketAddr;
use std::path::PathBuf;
use axum::extract::DefaultBodyLimit;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use tracing::info;
/// Configuration for the Córtex HTTP API server ([`CortexServer`]).
///
/// `Default` binds to `127.0.0.1:19090`; [`CortexConfig::public`] keeps the same
/// defaults but binds to `0.0.0.0`. Use the `with_*` builders to override fields.
#[derive(Debug, Clone)]
pub struct CortexConfig {
/// Host the server binds to; combined with `port` to form the listen address.
pub host: String,
/// TCP port the server listens on (default `19090`).
pub port: u16,
/// CORS allow-list. Empty (the default) installs no CORS layer; exactly `["*"]`
/// allows any origin; otherwise each entry is parsed as an origin header value and
/// entries that fail to parse are skipped.
pub cors_allowed_origins: Vec<String>,
/// Not read in this file (default `false`). NOTE(review): presumably toggles a
/// GraphQL playground endpoint — confirm where it is consumed.
pub graphql_playground: bool,
/// When `true` (the default), the router is wrapped in `TraceLayer::new_for_http()`.
pub tracing: bool,
/// Whether the `RateLimiter` middleware is installed (default `true`).
pub rate_limit_enabled: bool,
/// Rate passed to `RateLimiter::new` (presumably requests per minute, per the
/// name — confirm in the middleware); also used as the limiter's burst capacity.
pub rate_limit_rpm: u32,
/// Optional audit-log file, passed through to `AppState::with_db_path`.
pub audit_log_path: Option<PathBuf>,
/// Maximum request body size in bytes, enforced with `DefaultBodyLimit` (default 1 MiB).
pub max_body_size: usize,
/// Not read in this file (default `300`). NOTE(review): presumably a periodic
/// flush interval in seconds — confirm where it is consumed.
pub flush_interval_secs: u64,
/// Graph database location. `None` selects `~/.aingle/cortex/graph.sled`;
/// `":memory:"` and any other explicit value are passed through unchanged.
pub db_path: Option<String>,
}
impl Default for CortexConfig {
    /// Loopback-only defaults: `127.0.0.1:19090`, no CORS layer, playground off,
    /// HTTP tracing on, rate limiting enabled with `rate_limit_rpm = 100`, a 1 MiB
    /// request body cap, a 300-second flush interval, and no explicit database or
    /// audit-log path.
    fn default() -> Self {
        const ONE_MIB: usize = 1 << 20;

        Self {
            host: String::from("127.0.0.1"),
            port: 19_090,
            cors_allowed_origins: Vec::new(),
            graphql_playground: false,
            tracing: true,
            rate_limit_enabled: true,
            rate_limit_rpm: 100,
            audit_log_path: None,
            max_body_size: ONE_MIB,
            flush_interval_secs: 5 * 60,
            db_path: None,
        }
    }
}
impl CortexConfig {
    /// Same as [`Default`], but bound to every interface (`0.0.0.0`) instead of loopback.
    pub fn public() -> Self {
        Self::default().with_host("0.0.0.0")
    }

    /// Returns this configuration with the listen port replaced by `port`.
    pub fn with_port(self, port: u16) -> Self {
        Self { port, ..self }
    }

    /// Returns this configuration with the bind host replaced by `host`.
    pub fn with_host(self, host: impl Into<String>) -> Self {
        Self {
            host: host.into(),
            ..self
        }
    }
}
/// The Córtex HTTP API server: its [`CortexConfig`] plus the [`AppState`] that is
/// cloned into the router built by `build_router`.
pub struct CortexServer {
config: CortexConfig,
state: AppState,
}
impl CortexServer {
pub fn new(config: CortexConfig) -> Result<Self> {
let db_path = resolve_db_path(&config.db_path);
let state = AppState::with_db_path(&db_path, config.audit_log_path.clone())?;
info!("Graph database: {}", db_path);
Ok(Self { config, state })
}
pub fn with_state(config: CortexConfig, state: AppState) -> Self {
Self { config, state }
}
pub fn state(&self) -> &AppState {
&self.state
}
pub fn state_mut(&mut self) -> &mut AppState {
&mut self.state
}
pub fn build_router(&self) -> Router {
let mut app: Router<AppState> = Router::new();
app = app.merge(rest::router());
#[cfg(feature = "sparql")]
{
app = app.merge(crate::sparql::router());
}
#[cfg(feature = "auth")]
{
app = app.merge(crate::auth::router());
}
#[cfg(feature = "auth")]
let app = {
use crate::middleware::namespace_extractor;
app.layer(axum::middleware::from_fn(namespace_extractor))
};
let app = app.with_state(self.state.clone());
let app = if self.config.rate_limit_enabled {
use crate::middleware::RateLimiter;
let rate_limiter = RateLimiter::new(self.config.rate_limit_rpm)
.with_burst_capacity(self.config.rate_limit_rpm);
app.layer(rate_limiter.into_layer())
} else {
app
};
let app = app.layer(DefaultBodyLimit::max(self.config.max_body_size));
let app = if !self.config.cors_allowed_origins.is_empty() {
use tower_http::cors::{Any, AllowOrigin};
let cors = if self.config.cors_allowed_origins == ["*"] {
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any)
} else {
let origins: Vec<_> = self
.config
.cors_allowed_origins
.iter()
.filter_map(|o| o.parse().ok())
.collect();
CorsLayer::new()
.allow_origin(AllowOrigin::list(origins))
.allow_methods(Any)
.allow_headers(Any)
};
app.layer(cors)
} else {
app
};
if self.config.tracing {
app.layer(TraceLayer::new_for_http())
} else {
app
}
}
pub async fn run(self) -> Result<()> {
let addr: SocketAddr = format!("{}:{}", self.config.host, self.config.port)
.parse()
.map_err(|e| crate::error::Error::Internal(format!("Invalid address: {}", e)))?;
let router = self.build_router();
#[cfg(feature = "cluster")]
if let Some(ref tls_config) = self.state.tls_server_config {
info!("Starting Córtex API server on https://{}", addr);
let tls_acceptor =
tokio_rustls::TlsAcceptor::from(tls_config.clone());
let tcp_listener = tokio::net::TcpListener::bind(addr).await?;
let tls_listener = TlsListener {
inner: tcp_listener,
acceptor: tls_acceptor,
};
axum::serve(tls_listener, router.into_make_service()).await?;
return Ok(());
}
info!("Starting Córtex API server on http://{}", addr);
info!("REST API: http://{}/api/v1", addr);
#[cfg(feature = "graphql")]
info!("GraphQL: http://{}/graphql", addr);
#[cfg(feature = "sparql")]
info!("SPARQL: http://{}/sparql", addr);
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(
listener,
router.into_make_service_with_connect_info::<SocketAddr>(),
)
.await?;
Ok(())
}
pub async fn run_with_shutdown<F>(self, shutdown_signal: F) -> Result<()>
where
F: std::future::Future<Output = ()> + Send + 'static,
{
let addr: SocketAddr = format!("{}:{}", self.config.host, self.config.port)
.parse()
.map_err(|e| crate::error::Error::Internal(format!("Invalid address: {}", e)))?;
let router = self.build_router();
#[cfg(feature = "cluster")]
if let Some(ref tls_config) = self.state.tls_server_config {
info!("Starting Córtex API server on https://{}", addr);
let tls_acceptor =
tokio_rustls::TlsAcceptor::from(tls_config.clone());
let tcp_listener = tokio::net::TcpListener::bind(addr).await?;
let tls_listener = TlsListener {
inner: tcp_listener,
acceptor: tls_acceptor,
};
axum::serve(
tls_listener,
router.into_make_service(),
)
.with_graceful_shutdown(shutdown_signal)
.await?;
info!("Córtex API server stopped");
return Ok(());
}
info!("Starting Córtex API server on http://{}", addr);
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(
listener,
router.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(shutdown_signal)
.await?;
info!("Córtex API server stopped");
Ok(())
}
}
fn resolve_db_path(db_path: &Option<String>) -> String {
match db_path {
Some(p) if p == ":memory:" => ":memory:".to_string(),
Some(p) => p.clone(),
None => {
let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
let default_dir = home.join(".aingle").join("cortex");
std::fs::create_dir_all(&default_dir).ok();
default_dir.join("graph.sled").to_string_lossy().to_string()
}
}
}
/// Listener for `axum::serve` that accepts TCP connections and completes the
/// server-side TLS handshake before handing each stream to axum.
/// Only compiled with the `cluster` feature.
#[cfg(feature = "cluster")]
struct TlsListener {
/// Underlying TCP listener that accepts raw connections.
inner: tokio::net::TcpListener,
/// rustls acceptor that performs the handshake on each accepted stream.
acceptor: tokio_rustls::TlsAcceptor,
}
// Adapts `TlsListener` to axum 0.8's `serve::Listener` trait.
#[cfg(feature = "cluster")]
impl axum::serve::Listener for TlsListener {
    type Io = tokio_rustls::server::TlsStream<tokio::net::TcpStream>;
    type Addr = SocketAddr;

    /// Returns the next connection whose TLS handshake succeeded, with its peer address.
    ///
    /// Failed or timed-out handshakes are logged at debug level and skipped. After a
    /// failed TCP accept (e.g. descriptor exhaustion), the loop backs off briefly
    /// instead of spinning.
    ///
    /// NOTE(review): handshakes still run one at a time inside `accept`. The timeout
    /// bounds how long one silent client can stall new connections; it does not remove
    /// the stall. Fully concurrent handshakes would need extra state in `TlsListener`.
    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
        // Without a bound, a peer that opens TCP and never speaks TLS would block
        // every later accept indefinitely.
        const HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
        const ACCEPT_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_millis(100);

        loop {
            let (stream, addr) = match self.inner.accept().await {
                Ok(conn) => conn,
                Err(e) => {
                    tracing::debug!("TCP accept failed: {e}");
                    tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
                    continue;
                }
            };

            match tokio::time::timeout(HANDSHAKE_TIMEOUT, self.acceptor.accept(stream)).await {
                Ok(Ok(tls_stream)) => return (tls_stream, addr),
                Ok(Err(e)) => {
                    tracing::debug!("TLS handshake failed from {addr}: {e}");
                }
                Err(_) => {
                    tracing::debug!(
                        "TLS handshake from {} timed out after {:?}",
                        addr,
                        HANDSHAKE_TIMEOUT
                    );
                }
            }
        }
    }

    fn local_addr(&self) -> std::io::Result<Self::Addr> {
        self.inner.local_addr()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let CortexConfig {
            host,
            port,
            cors_allowed_origins,
            ..
        } = CortexConfig::default();
        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, 19090);
        assert!(cors_allowed_origins.is_empty());
    }

    #[test]
    fn test_config_public() {
        assert_eq!(CortexConfig::public().host, "0.0.0.0");
    }

    #[test]
    fn test_config_builder() {
        let built = CortexConfig::default()
            .with_port(9090)
            .with_host("localhost");
        assert_eq!(built.host, "localhost");
        assert_eq!(built.port, 9090);
    }
}