use std::net::IpAddr;
use std::num::{NonZeroU32, NonZeroUsize};
use std::path::PathBuf;
use std::str::FromStr;
use clap::Parser;
use hdbconnect_mcp::config::{self, AllowedOperations, TransportMode};
#[cfg(feature = "cache")]
use hdbconnect_mcp::create_cache;
use hdbconnect_mcp::observability::{init_observability, shutdown_observability};
use hdbconnect_mcp::security::SchemaFilter;
use hdbconnect_mcp::transport::run_transport;
use hdbconnect_mcp::{ServerHandler, create_pool};
use url::Url;
// Command-line arguments. Values given here are layered on top of the configuration loaded
// from `--config` (or the default config location) in `main`.
//
// Plain `//` comments are used at struct level on purpose: clap turns `///` doc comments into
// `about`/`long_about`, and those are set explicitly below. Field-level `///` comments become
// each argument's `--help` text.
#[derive(Parser, Debug)]
#[command(name = "hdbconnect-mcp")]
#[command(about = "MCP server for SAP HANA database", long_about = None)]
#[command(version)]
// Each bool is an independent on/off CLI switch; grouping them into enums would not make the
// command line clearer.
#[allow(clippy::struct_excessive_bools)]
struct Args {
    /// SAP HANA connection URL (overrides the configured URL)
    #[arg(short, long, env = "HANA_URL")]
    url: Option<String>,

    /// Run in read-only mode (default); pass `--read-only false` to disable
    // A plain `bool` defaults to `ArgAction::SetTrue`, which combined with a `true` default
    // can never yield `false`. `Set` + an optional value keeps bare `-r`/`--read-only`
    // working while making the switch actually switchable.
    #[arg(
        short,
        long,
        default_value_t = true,
        action = clap::ArgAction::Set,
        num_args = 0..=1,
        default_missing_value = "true"
    )]
    read_only: bool,

    /// Maximum rows returned per query (0 leaves the limit unset)
    #[arg(short = 'l', long, default_value = "10000")]
    row_limit: u32,

    /// Connection pool size (0 falls back to 4)
    #[arg(short, long, default_value = "4")]
    pool_size: usize,

    /// Enable debug-level logging
    #[arg(short, long)]
    verbose: bool,

    /// Path to a configuration file
    #[arg(short, long)]
    config: Option<PathBuf>,

    /// MCP transport mode
    #[arg(long, default_value = "stdio")]
    transport: String,

    /// Host address for the HTTP transport
    #[arg(long, default_value = "127.0.0.1")]
    http_host: IpAddr,

    /// Port for the HTTP transport
    #[arg(long, default_value = "8080")]
    http_port: u16,

    /// Schema filter mode; `none` applies no schema filter from the command line
    #[arg(long, default_value = "none")]
    schema_filter_mode: String,

    /// Comma-separated schema names for the schema filter (matched upper-cased)
    #[arg(long)]
    schema_filter_schemas: Option<String>,

    /// Query timeout in seconds
    #[arg(long, default_value = "30")]
    query_timeout: u64,

    /// Emit logs as JSON
    #[arg(long)]
    json_logs: bool,

    /// Allow DML statements
    #[arg(long)]
    allow_dml: bool,

    /// Do not require confirmation before executing DML
    #[arg(long)]
    no_dml_confirm: bool,

    /// Maximum rows a DML statement may affect (0 leaves the limit unset)
    #[arg(long, default_value = "1000")]
    dml_max_rows: u32,

    /// Do not require a WHERE clause on DML statements
    #[arg(long)]
    no_where_clause: bool,

    /// Allowed DML operations
    #[arg(long)]
    dml_ops: Option<String>,

    /// Allow calling stored procedures
    #[arg(long)]
    allow_procedures: bool,

    /// Do not require confirmation before calling procedures
    #[arg(long)]
    no_procedure_confirm: bool,

    /// Maximum result sets returned by a procedure call (0 leaves the limit unset)
    #[arg(long, default_value = "10")]
    procedure_max_result_sets: u32,

    /// Maximum rows per procedure result set (0 leaves the limit unset)
    #[arg(long, default_value = "1000")]
    procedure_max_rows: u32,
}
#[tokio::main]
#[allow(clippy::too_many_lines)]
async fn main() -> anyhow::Result<()> {
let args = Args::parse();
let mut builder = if let Some(ref path) = args.config {
config::load_config_from_path(path)?
} else {
config::load_config()?
};
if let Some(ref url_str) = args.url {
let url =
Url::parse(url_str).map_err(|e| anyhow::anyhow!("Invalid connection URL: {e}"))?;
builder = builder.connection_url(url);
}
let transport_mode: TransportMode = args.transport.parse().unwrap_or_default();
builder = builder
.pool_size(NonZeroUsize::new(args.pool_size).unwrap_or(NonZeroUsize::MIN.saturating_add(3)))
.read_only(args.read_only)
.row_limit(NonZeroU32::new(args.row_limit))
.query_timeout(std::time::Duration::from_secs(args.query_timeout))
.transport_mode(transport_mode)
.http_host(args.http_host)
.http_port(args.http_port)
.json_logs(args.json_logs);
if args.schema_filter_mode != "none" {
let schemas: Vec<String> = args
.schema_filter_schemas
.as_deref()
.unwrap_or("")
.split(',')
.map(|s| s.trim().to_uppercase())
.filter(|s| !s.is_empty())
.collect();
let filter = SchemaFilter::from_config(&args.schema_filter_mode, &schemas)?;
builder = builder.schema_filter(filter);
}
if args.verbose {
builder = builder.log_level("debug".to_string());
}
if args.allow_dml {
builder = builder.allow_dml(true);
}
if args.no_dml_confirm {
builder = builder.require_dml_confirmation(false);
}
builder = builder.max_affected_rows(NonZeroU32::new(args.dml_max_rows));
if args.no_where_clause {
builder = builder.require_where_clause(false);
}
if let Some(ref ops_str) = args.dml_ops {
let ops = AllowedOperations::from_str(ops_str).unwrap_or_default();
builder = builder.allowed_operations(ops);
}
if args.allow_procedures {
builder = builder.allow_procedures(true);
}
if args.no_procedure_confirm {
builder = builder.require_procedure_confirmation(false);
}
builder = builder.max_result_sets(NonZeroU32::new(args.procedure_max_result_sets));
builder = builder.max_rows_per_result_set(NonZeroU32::new(args.procedure_max_rows));
let config = builder.build()?;
init_observability(&config.telemetry)?;
let pool = create_pool(config.connection_url.to_string(), config.pool_size.get());
#[cfg(feature = "cache")]
let handler = {
let cache = create_cache(config.cache());
tracing::info!(
"Cache enabled: {}, backend: {:?}",
config.cache.enabled,
config.cache.backend
);
ServerHandler::new(pool, config.clone(), cache)
};
#[cfg(not(feature = "cache"))]
let handler = ServerHandler::new(pool, config.clone());
tracing::info!("Starting MCP server for SAP HANA");
tracing::info!("Transport: {:?}", config.transport.mode);
tracing::info!("Read-only mode: {}", config.read_only);
tracing::info!("Row limit: {:?}", config.row_limit);
tracing::info!("Query timeout: {:?}", config.query_timeout);
tracing::info!("DML enabled: {}", config.dml.allow_dml);
if config.dml.allow_dml {
tracing::info!(
"DML confirmation required: {}",
config.dml.require_confirmation
);
tracing::info!("DML max affected rows: {:?}", config.dml.max_affected_rows);
tracing::info!(
"DML WHERE clause required: {}",
config.dml.require_where_clause
);
}
tracing::info!("Procedures enabled: {}", config.procedure.allow_procedures);
if config.procedure.allow_procedures {
tracing::info!(
"Procedure confirmation required: {}",
config.procedure.require_confirmation
);
tracing::info!(
"Procedure max result sets: {:?}",
config.procedure.max_result_sets
);
tracing::info!(
"Procedure max rows per result set: {:?}",
config.procedure.max_rows_per_result_set
);
}
let shutdown = async {
tokio::signal::ctrl_c()
.await
.expect("Failed to install Ctrl+C handler");
tracing::info!("Shutdown signal received");
};
let result = run_transport(handler, &config, shutdown).await;
shutdown_observability();
result.map_err(Into::into)
}