use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result};
use boltr::server::BoltServer;
use clap::{Parser, ValueEnum};
use tracing_subscriber::EnvFilter;
use kglite::api::load_file;
use crate::backend::KgliteBackend;
mod auth;
mod backend;
mod error_map;
mod value_adapter;
/// Authentication scheme selectable with `--auth`.
///
/// Variant doc comments are surfaced by clap as the help text for each
/// possible value of `--auth`.
#[derive(Copy, Clone, Debug, ValueEnum)]
enum AuthScheme {
    /// Do not install an authentication validator.
    None,
    /// Username/password authentication (requires --auth-user and --auth-pass).
    Basic,
}
// Command-line interface for the server. Field doc comments below are used by
// clap's derive as the per-flag `--help` text; the command-level description
// comes from the explicit `about` / `long_about` attributes.
#[derive(Parser, Debug)]
#[command(
    name = "kglite-bolt-server",
    about = "Bolt v5.x protocol server for kglite knowledge graphs.",
    long_about = "Loads a .kgl file and serves it over the Neo4j Bolt wire protocol \
        so any Neo4j-aware client (Cypher Shell, Neo4j Browser, the official \
        drivers, BloodHound, LangChain's Neo4jGraph, ...) can query it as if \
        it were a Neo4j instance. See bolt_implementation.md for the phase plan."
)]
struct Cli {
    /// Path to the .kgl graph file to load and serve.
    #[arg(long, value_name = "PATH")]
    graph: PathBuf,

    /// IP address to listen on.
    #[arg(long, default_value = "127.0.0.1")]
    bind: IpAddr,

    /// TCP port to listen on (7687 is the conventional Bolt port).
    #[arg(long, default_value_t = 7687)]
    port: u16,

    /// Serve the graph in read-only mode.
    #[arg(long, default_value_t = false)]
    readonly: bool,

    /// Authentication scheme clients must use.
    #[arg(long, value_enum, default_value_t = AuthScheme::None)]
    auth: AuthScheme,

    /// Username for `--auth basic` (requires --auth-pass).
    #[arg(long, requires = "auth_pass")]
    auth_user: Option<String>,

    /// Password for `--auth basic` (requires --auth-user).
    ///
    /// Note: command-line arguments are visible to other local users via the
    /// process list.
    #[arg(long, requires = "auth_user")]
    auth_pass: Option<String>,

    /// Session idle timeout in seconds; no idle timeout is configured when omitted.
    #[arg(long, value_name = "SECS")]
    idle_timeout: Option<u64>,

    /// Maximum number of concurrent sessions.
    #[arg(long, default_value_t = 256)]
    max_sessions: usize,

    /// Maximum accepted Bolt message size in bytes.
    #[arg(long, value_name = "BYTES", default_value_t = 16 * 1024 * 1024)]
    max_message_size: usize,

    /// Address advertised to clients (defaults to the bind address and port).
    #[arg(long, value_name = "HOST:PORT")]
    advertise_addr: Option<String>,

    /// PEM certificate (chain) file enabling TLS (requires --tls-key).
    #[arg(long, value_name = "PATH", requires = "tls_key")]
    tls_cert: Option<PathBuf>,

    /// PEM private key file for TLS (requires --tls-cert).
    #[arg(long, value_name = "PATH", requires = "tls_cert")]
    tls_key: Option<PathBuf>,
}
/// Install the process-wide `tracing` subscriber.
///
/// The filter comes from the default environment variable (`RUST_LOG`) when it
/// is set and parses. Otherwise a built-in default applies: `info` for this
/// crate, `warn` for `boltr`, and `warn` for everything else. Event targets are
/// omitted from the formatted output.
fn init_tracing() {
    const DEFAULT_DIRECTIVES: &str = "kglite_bolt_server=info,boltr=warn,warn";

    let env_filter = match EnvFilter::try_from_default_env() {
        Ok(from_env) => from_env,
        Err(_) => EnvFilter::new(DEFAULT_DIRECTIVES),
    };

    let subscriber = tracing_subscriber::fmt()
        .with_env_filter(env_filter)
        .with_target(false);
    subscriber.init();
}
#[tokio::main]
async fn main() -> Result<()> {
init_tracing();
let cli = Cli::parse();
if !cli.graph.exists() {
anyhow::bail!("--graph {} does not exist", cli.graph.display());
}
tracing::info!(path = %cli.graph.display(), "loading graph");
let dir_arc = load_file(&cli.graph.to_string_lossy())
.map_err(|e| anyhow::anyhow!("kglite::load_file failed: {}", e))
.with_context(|| format!("loading {}", cli.graph.display()))?;
tracing::info!("graph loaded; constructing Bolt server");
let dir = Arc::try_unwrap(dir_arc).unwrap_or_else(|arc| (*arc).clone());
let advertised_addr = cli
.advertise_addr
.clone()
.unwrap_or_else(|| format!("{}:{}", cli.bind, cli.port));
let backend = KgliteBackend::new(dir, cli.readonly, advertised_addr);
let addr = SocketAddr::new(cli.bind, cli.port);
let mut builder = BoltServer::builder(backend)
.max_sessions(cli.max_sessions)
.max_message_size(cli.max_message_size)
.shutdown(async {
let _ = tokio::signal::ctrl_c().await;
tracing::info!("SIGINT received; shutting down");
});
if let Some(secs) = cli.idle_timeout {
builder = builder.idle_timeout(Duration::from_secs(secs));
}
if let (Some(cert_path), Some(key_path)) = (cli.tls_cert.as_ref(), cli.tls_key.as_ref()) {
let _ = rustls::crypto::ring::default_provider().install_default();
let cert_pem = std::fs::read(cert_path)
.with_context(|| format!("reading TLS cert {}", cert_path.display()))?;
let key_pem = std::fs::read(key_path)
.with_context(|| format!("reading TLS key {}", key_path.display()))?;
let tls_config = boltr::server::TlsConfig::from_pem(&cert_pem, &key_pem)
.map_err(|e| anyhow::anyhow!("invalid TLS cert/key: {}", e))?;
builder = builder.tls(tls_config);
tracing::info!(
cert = %cert_path.display(),
key = %key_path.display(),
"TLS enabled — clients must connect via bolt+s:// or neo4j+s://"
);
}
if matches!(cli.auth, AuthScheme::Basic) {
let user = cli.auth_user.clone().ok_or_else(|| {
anyhow::anyhow!("--auth basic requires both --auth-user and --auth-pass")
})?;
let pass = cli.auth_pass.clone().ok_or_else(|| {
anyhow::anyhow!("--auth basic requires both --auth-user and --auth-pass")
})?;
builder = builder.auth(crate::auth::BasicAuthValidator::new(user, pass));
tracing::info!(user = %cli.auth_user.as_deref().unwrap_or(""), "wired --auth basic validator");
}
tracing::info!(%addr, readonly = cli.readonly, "Bolt server starting");
builder
.serve(addr)
.await
.map_err(|e| anyhow::anyhow!("BoltServer::serve failed: {}", e))?;
tracing::info!("Bolt server stopped");
Ok(())
}