use std::net::SocketAddr;
use std::path::PathBuf;
use std::time::Duration;
use crate::audit::AuditConfig;
use crate::error::{Result, ServerError};
use crate::memory::MemoryBudgetConfig;
use crate::metrics::MetricsConfig;
use crate::rate_limit::RateLimitConfig;
use crate::shutdown::ShutdownConfig;
/// Overrides `ServerConfig::bind_addr`; parsed as a `SocketAddr` (e.g. `0.0.0.0:9200`).
pub const ENV_BIND_ADDR: &str = "NEUMANN_BIND_ADDR";
/// Overrides `ServerConfig::max_message_size`; parsed as a `usize`.
pub const ENV_MAX_MESSAGE_SIZE: &str = "NEUMANN_MAX_MESSAGE_SIZE";
/// Overrides `ServerConfig::max_upload_size`; parsed as a `usize`.
pub const ENV_MAX_UPLOAD_SIZE: &str = "NEUMANN_MAX_UPLOAD_SIZE";
/// Overrides `ServerConfig::enable_grpc_web`; accepts `true/false/1/0/yes/no/on/off`
/// (case-insensitive).
pub const ENV_ENABLE_GRPC_WEB: &str = "NEUMANN_ENABLE_GRPC_WEB";
/// Overrides `ServerConfig::enable_reflection`; same boolean syntax as
/// `ENV_ENABLE_GRPC_WEB`.
pub const ENV_ENABLE_REFLECTION: &str = "NEUMANN_ENABLE_REFLECTION";
/// TLS certificate path. TLS is configured from the environment only when this
/// and `ENV_TLS_KEY_PATH` are both set.
pub const ENV_TLS_CERT_PATH: &str = "NEUMANN_TLS_CERT_PATH";
/// TLS key path; see `ENV_TLS_CERT_PATH`.
pub const ENV_TLS_KEY_PATH: &str = "NEUMANN_TLS_KEY_PATH";
/// Optional CA certificate path; only read when TLS is configured.
pub const ENV_TLS_CA_CERT_PATH: &str = "NEUMANN_TLS_CA_CERT_PATH";
/// Boolean for `TlsConfig::require_client_cert`; only read when TLS is configured.
pub const ENV_TLS_REQUIRE_CLIENT_CERT: &str = "NEUMANN_TLS_REQUIRE_CLIENT_CERT";
/// Rate-limit `max_requests` (`u32`). Setting any `NEUMANN_RATE_LIMIT_*`
/// variable enables rate limiting, with defaults for the unset fields.
pub const ENV_RATE_LIMIT_MAX_REQUESTS: &str = "NEUMANN_RATE_LIMIT_MAX_REQUESTS";
/// Rate-limit `max_queries` (`u32`).
pub const ENV_RATE_LIMIT_MAX_QUERIES: &str = "NEUMANN_RATE_LIMIT_MAX_QUERIES";
/// Rate-limit window in whole seconds.
pub const ENV_RATE_LIMIT_WINDOW_SECS: &str = "NEUMANN_RATE_LIMIT_WINDOW_SECS";
/// Shutdown drain timeout in whole seconds. Setting either `NEUMANN_SHUTDOWN_*`
/// variable enables the shutdown section, with defaults for the unset field.
pub const ENV_SHUTDOWN_DRAIN_TIMEOUT_SECS: &str = "NEUMANN_SHUTDOWN_DRAIN_TIMEOUT_SECS";
/// Shutdown grace period in whole seconds.
pub const ENV_SHUTDOWN_GRACE_PERIOD_SECS: &str = "NEUMANN_SHUTDOWN_GRACE_PERIOD_SECS";
/// Overrides `ServerConfig::blob_chunk_size`; parsed as a `usize`.
pub const ENV_BLOB_CHUNK_SIZE: &str = "NEUMANN_BLOB_CHUNK_SIZE";
/// Overrides `ServerConfig::stream_channel_capacity`; parsed as a `usize`.
pub const ENV_STREAM_CHANNEL_CAPACITY: &str = "NEUMANN_STREAM_CHANNEL_CAPACITY";
/// Sets `ServerConfig::max_concurrent_connections` (`usize`).
pub const ENV_MAX_CONCURRENT_CONNECTIONS: &str = "NEUMANN_MAX_CONCURRENT_CONNECTIONS";
/// Sets `ServerConfig::max_concurrent_streams_per_connection` (`u32`).
pub const ENV_MAX_CONCURRENT_STREAMS: &str = "NEUMANN_MAX_CONCURRENT_STREAMS";
/// Sets `ServerConfig::initial_window_size` (`u32`).
pub const ENV_INITIAL_WINDOW_SIZE: &str = "NEUMANN_INITIAL_WINDOW_SIZE";
/// Sets `ServerConfig::initial_connection_window_size` (`u32`).
pub const ENV_INITIAL_CONNECTION_WINDOW_SIZE: &str = "NEUMANN_INITIAL_CONNECTION_WINDOW_SIZE";
/// Sets `ServerConfig::request_timeout` in whole seconds.
pub const ENV_REQUEST_TIMEOUT_SECS: &str = "NEUMANN_REQUEST_TIMEOUT_SECS";
/// Memory budget `max_bytes` (`usize`). Setting either `NEUMANN_MEMORY_BUDGET_*`
/// variable enables the memory budget, with defaults for the unset field.
pub const ENV_MEMORY_BUDGET_MAX_BYTES: &str = "NEUMANN_MEMORY_BUDGET_MAX_BYTES";
/// Boolean for the memory budget's `enable_load_shedding`.
pub const ENV_MEMORY_BUDGET_LOAD_SHEDDING: &str = "NEUMANN_MEMORY_BUDGET_LOAD_SHEDDING";
/// Node id; setting it enables cluster configuration, which then also requires
/// `ENV_CLUSTER_BIND_ADDR` and `ENV_DATA_DIR`.
pub const ENV_CLUSTER_NODE_ID: &str = "NEUMANN_CLUSTER_NODE_ID";
/// Socket address stored as `ClusterConfig::raft_bind_addr`.
pub const ENV_CLUSTER_BIND_ADDR: &str = "NEUMANN_CLUSTER_BIND_ADDR";
/// Optional comma-separated list of `node_id=addr` peer entries.
pub const ENV_CLUSTER_PEERS: &str = "NEUMANN_CLUSTER_PEERS";
/// Data directory stored as `ClusterConfig::data_dir`; required when
/// `ENV_CLUSTER_NODE_ID` is set.
pub const ENV_DATA_DIR: &str = "NEUMANN_DATA_DIR";
/// Typed readers for environment variables.
///
/// Every reader returns `None` when the variable is not set, so callers keep
/// their built-in default, and `Some(Err(_))` when it is set but cannot be
/// interpreted. A value that is not valid UTF-8 is reported as an error rather
/// than being silently treated as "not set".
mod env_parse {
    use std::env::VarError;
    use std::fmt::Display;
    use std::net::SocketAddr;
    use std::path::PathBuf;
    use std::str::FromStr;
    use std::time::Duration;

    use super::{Result, ServerError};

    /// Reads `key`, distinguishing "not set" (`None`) from "set but not UTF-8"
    /// (`Some(Err(_))`).
    fn read_var(key: &str) -> Option<Result<String>> {
        match std::env::var(key) {
            Ok(val) => Some(Ok(val)),
            Err(VarError::NotPresent) => None,
            Err(VarError::NotUnicode(_)) => Some(Err(ServerError::Config(format!(
                "invalid {key}: value is not valid UTF-8"
            )))),
        }
    }

    /// Reads `key` and parses it with `T`'s [`FromStr`] implementation.
    fn parse_from_str<T>(key: &str) -> Option<Result<T>>
    where
        T: FromStr,
        T::Err: Display,
    {
        read_var(key).map(|val| {
            val.and_then(|v| {
                v.parse::<T>()
                    .map_err(|e| ServerError::Config(format!("invalid {key}: {e}")))
            })
        })
    }

    /// Parses `key` as a socket address such as `0.0.0.0:9200`.
    pub fn parse_socket_addr(key: &str) -> Option<Result<SocketAddr>> {
        parse_from_str(key)
    }

    /// Parses `key` as a `usize`.
    pub fn parse_usize(key: &str) -> Option<Result<usize>> {
        parse_from_str(key)
    }

    /// Parses `key` as a `u32`.
    pub fn parse_u32(key: &str) -> Option<Result<u32>> {
        parse_from_str(key)
    }

    /// Parses `key` as a boolean. Matching is case-insensitive and accepts
    /// `true/1/yes/on` and `false/0/no/off`.
    pub fn parse_bool(key: &str) -> Option<Result<bool>> {
        read_var(key).map(|val| {
            val.and_then(|v| match v.to_lowercase().as_str() {
                "true" | "1" | "yes" | "on" => Ok(true),
                "false" | "0" | "no" | "off" => Ok(false),
                _ => Err(ServerError::Config(format!(
                    "invalid {key}: expected boolean (true/false/1/0/yes/no/on/off)"
                ))),
            })
        })
    }

    /// Parses `key` as a whole number of seconds.
    pub fn parse_duration_secs(key: &str) -> Option<Result<Duration>> {
        parse_from_str::<u64>(key).map(|res| res.map(Duration::from_secs))
    }

    /// Reads `key` as a filesystem path. `var_os` is used so that paths which
    /// are not valid UTF-8 are still accepted.
    pub fn parse_path(key: &str) -> Option<PathBuf> {
        std::env::var_os(key).map(PathBuf::from)
    }
}
/// Cluster settings. `ServerConfig::from_env` populates this only when
/// `NEUMANN_CLUSTER_NODE_ID` is set.
#[derive(Debug, Clone)]
pub struct ClusterConfig {
    /// This node's identifier (`NEUMANN_CLUSTER_NODE_ID`).
    pub node_id: String,
    /// Raft bind address (`NEUMANN_CLUSTER_BIND_ADDR`).
    pub raft_bind_addr: SocketAddr,
    /// Peers as `(node_id, address)` pairs, parsed from the comma-separated
    /// `node_id=addr` entries of `NEUMANN_CLUSTER_PEERS`; empty when unset.
    pub peers: Vec<(String, SocketAddr)>,
    /// Data directory (`NEUMANN_DATA_DIR`).
    pub data_dir: PathBuf,
}
impl ClusterConfig {
    /// Reads the cluster configuration from the environment.
    ///
    /// Cluster mode is opt-in. The function returns `Ok(None)` when
    /// `NEUMANN_CLUSTER_NODE_ID` is not set, or is not valid Unicode. Once it
    /// is set:
    /// - `NEUMANN_CLUSTER_BIND_ADDR` and `NEUMANN_DATA_DIR` become mandatory.
    /// - `NEUMANN_CLUSTER_PEERS` is an optional comma-separated list of
    ///   `node_id=addr` entries. Whitespace around entries, ids and addresses
    ///   is ignored, and blank entries are skipped.
    ///
    /// # Errors
    ///
    /// Returns [`ServerError::Config`] if the node id is blank, a required
    /// variable is missing or unparsable, or a peer entry is malformed.
    fn from_env() -> Result<Option<Self>> {
        let Ok(raw_node_id) = std::env::var(ENV_CLUSTER_NODE_ID) else {
            return Ok(None);
        };
        // Reject a blank id instead of running with an empty node identifier.
        let node_id = raw_node_id.trim();
        if node_id.is_empty() {
            return Err(ServerError::Config(format!(
                "{ENV_CLUSTER_NODE_ID} must not be empty"
            )));
        }
        let node_id = node_id.to_string();

        let raft_bind_addr =
            env_parse::parse_socket_addr(ENV_CLUSTER_BIND_ADDR).ok_or_else(|| {
                ServerError::Config(format!(
                    "{ENV_CLUSTER_BIND_ADDR} required when {ENV_CLUSTER_NODE_ID} is set"
                ))
            })??;

        let peers = Self::parse_peers(&std::env::var(ENV_CLUSTER_PEERS).unwrap_or_default())?;

        let data_dir = env_parse::parse_path(ENV_DATA_DIR).ok_or_else(|| {
            ServerError::Config(format!(
                "{ENV_DATA_DIR} required when {ENV_CLUSTER_NODE_ID} is set"
            ))
        })?;

        Ok(Some(Self {
            node_id,
            raft_bind_addr,
            peers,
            data_dir,
        }))
    }

    /// Parses a comma-separated `node_id=addr` list. Blank entries are skipped,
    /// and the first malformed entry aborts parsing.
    fn parse_peers(raw: &str) -> Result<Vec<(String, SocketAddr)>> {
        raw.split(',')
            .map(str::trim)
            .filter(|entry| !entry.is_empty())
            .map(Self::parse_peer)
            .collect()
    }

    /// Parses a single trimmed, non-empty `node_id=addr` entry.
    fn parse_peer(entry: &str) -> Result<(String, SocketAddr)> {
        let (id, addr_str) = entry.split_once('=').ok_or_else(|| {
            ServerError::Config(format!(
                "invalid peer format '{entry}', expected 'node_id=addr'"
            ))
        })?;
        // Allow `id = addr`. Without trimming, the address would fail to parse
        // and the id would keep a trailing space.
        let (id, addr_str) = (id.trim(), addr_str.trim());
        if id.is_empty() {
            return Err(ServerError::Config(format!(
                "invalid peer format '{entry}', node_id must not be empty"
            )));
        }
        let addr = addr_str.parse::<SocketAddr>().map_err(|e| {
            ServerError::Config(format!("invalid peer address '{addr_str}': {e}"))
        })?;
        Ok((id.to_string(), addr))
    }
}
/// Top-level server configuration.
///
/// Build it with `ServerConfig::new()` (or `Default`) plus the `with_*`
/// builders, or from `NEUMANN_*` environment variables with
/// `ServerConfig::from_env()`. Neither path validates; call
/// `ServerConfig::validate()` before use.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Address the server binds to (default `127.0.0.1:9200`).
    pub bind_addr: SocketAddr,
    /// TLS settings; `None` when TLS is not configured.
    pub tls: Option<TlsConfig>,
    /// API-key authentication settings; `None` when not configured.
    pub auth: Option<AuthConfig>,
    /// Maximum message size (default `64 * 1024 * 1024`); `validate` rejects 0.
    pub max_message_size: usize,
    /// Maximum upload size (default `512 * 1024 * 1024`).
    pub max_upload_size: usize,
    /// Whether gRPC-Web is enabled (default `true`).
    pub enable_grpc_web: bool,
    /// Whether reflection is enabled (default `true`).
    pub enable_reflection: bool,
    /// Blob chunk size (default `64 * 1024`); `validate` rejects 0.
    pub blob_chunk_size: usize,
    /// Stream channel capacity (default `32`); `validate` rejects 0.
    pub stream_channel_capacity: usize,
    /// Rate-limit settings; `None` unless configured.
    pub rate_limit: Option<RateLimitConfig>,
    /// Audit settings; `None` unless configured.
    pub audit: Option<AuditConfig>,
    /// Shutdown settings; `None` unless configured.
    pub shutdown: Option<ShutdownConfig>,
    /// Metrics settings; `None` unless configured.
    pub metrics: Option<MetricsConfig>,
    /// Optional cap on concurrent connections.
    pub max_concurrent_connections: Option<usize>,
    /// Optional cap on concurrent streams per connection.
    pub max_concurrent_streams_per_connection: Option<u32>,
    /// Optional initial (per-stream) window size override.
    pub initial_window_size: Option<u32>,
    /// Optional initial connection-level window size override.
    pub initial_connection_window_size: Option<u32>,
    /// Optional request timeout.
    pub request_timeout: Option<Duration>,
    /// Memory budget settings; `None` unless configured.
    pub memory_budget: Option<MemoryBudgetConfig>,
    /// Optional REST listener address; `None` when not set.
    pub rest_addr: Option<SocketAddr>,
    /// Optional web listener address; `None` when not set.
    pub web_addr: Option<SocketAddr>,
    /// Streaming settings; `None` unless configured.
    pub streaming: Option<StreamingConfig>,
    /// Cluster settings; filled by `from_env` when `NEUMANN_CLUSTER_NODE_ID` is set.
    pub cluster: Option<ClusterConfig>,
    /// REST settings (default `RestConfig::default()`).
    pub rest_config: crate::rest::RestConfig,
}
impl Default for ServerConfig {
    /// Default settings:
    /// - binds `127.0.0.1:9200`, with no TLS or auth;
    /// - message size limit `64 * 1024 * 1024`, upload limit `512 * 1024 * 1024`;
    /// - gRPC-Web and reflection enabled;
    /// - blob chunks of `64 * 1024`, stream channel capacity `32`;
    /// - every other `Option` field is `None` and `rest_config` is
    ///   `RestConfig::default()`.
    fn default() -> Self {
        Self {
            // Built directly rather than by parsing a string literal at
            // runtime, which could never fail and only added a dead fallback.
            bind_addr: SocketAddr::from(([127, 0, 0, 1], 9200)),
            tls: None,
            auth: None,
            max_message_size: 64 * 1024 * 1024,
            max_upload_size: 512 * 1024 * 1024,
            enable_grpc_web: true,
            enable_reflection: true,
            blob_chunk_size: 64 * 1024,
            stream_channel_capacity: 32,
            rate_limit: None,
            audit: None,
            shutdown: None,
            metrics: None,
            max_concurrent_connections: None,
            max_concurrent_streams_per_connection: None,
            initial_window_size: None,
            initial_connection_window_size: None,
            request_timeout: None,
            memory_budget: None,
            rest_addr: None,
            web_addr: None,
            streaming: None,
            cluster: None,
            rest_config: crate::rest::RestConfig::default(),
        }
    }
}
impl ServerConfig {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn from_env() -> Result<Self> {
let mut config = Self::default();
if let Some(result) = env_parse::parse_socket_addr(ENV_BIND_ADDR) {
config.bind_addr = result?;
}
if let Some(result) = env_parse::parse_usize(ENV_MAX_MESSAGE_SIZE) {
config.max_message_size = result?;
}
if let Some(result) = env_parse::parse_usize(ENV_MAX_UPLOAD_SIZE) {
config.max_upload_size = result?;
}
if let Some(result) = env_parse::parse_bool(ENV_ENABLE_GRPC_WEB) {
config.enable_grpc_web = result?;
}
if let Some(result) = env_parse::parse_bool(ENV_ENABLE_REFLECTION) {
config.enable_reflection = result?;
}
if let Some(result) = env_parse::parse_usize(ENV_BLOB_CHUNK_SIZE) {
config.blob_chunk_size = result?;
}
if let Some(result) = env_parse::parse_usize(ENV_STREAM_CHANNEL_CAPACITY) {
config.stream_channel_capacity = result?;
}
let cert_path = env_parse::parse_path(ENV_TLS_CERT_PATH);
let key_path = env_parse::parse_path(ENV_TLS_KEY_PATH);
if let (Some(cert), Some(key)) = (cert_path, key_path) {
let mut tls = TlsConfig::new(cert, key);
if let Some(ca) = env_parse::parse_path(ENV_TLS_CA_CERT_PATH) {
tls = tls.with_ca_cert(ca);
}
if let Some(result) = env_parse::parse_bool(ENV_TLS_REQUIRE_CLIENT_CERT) {
tls = tls.with_required_client_cert(result?);
}
config.tls = Some(tls);
}
let has_rate_limit = std::env::var(ENV_RATE_LIMIT_MAX_REQUESTS).is_ok()
|| std::env::var(ENV_RATE_LIMIT_MAX_QUERIES).is_ok()
|| std::env::var(ENV_RATE_LIMIT_WINDOW_SECS).is_ok();
if has_rate_limit {
let mut rate_limit = RateLimitConfig::default();
if let Some(result) = env_parse::parse_u32(ENV_RATE_LIMIT_MAX_REQUESTS) {
rate_limit.max_requests = result?;
}
if let Some(result) = env_parse::parse_u32(ENV_RATE_LIMIT_MAX_QUERIES) {
rate_limit.max_queries = result?;
}
if let Some(result) = env_parse::parse_duration_secs(ENV_RATE_LIMIT_WINDOW_SECS) {
rate_limit.window = result?;
}
config.rate_limit = Some(rate_limit);
}
let has_shutdown = std::env::var(ENV_SHUTDOWN_DRAIN_TIMEOUT_SECS).is_ok()
|| std::env::var(ENV_SHUTDOWN_GRACE_PERIOD_SECS).is_ok();
if has_shutdown {
let mut shutdown = ShutdownConfig::default();
if let Some(result) = env_parse::parse_duration_secs(ENV_SHUTDOWN_DRAIN_TIMEOUT_SECS) {
shutdown.drain_timeout = result?;
}
if let Some(result) = env_parse::parse_duration_secs(ENV_SHUTDOWN_GRACE_PERIOD_SECS) {
shutdown.grace_period = result?;
}
config.shutdown = Some(shutdown);
}
if let Some(result) = env_parse::parse_usize(ENV_MAX_CONCURRENT_CONNECTIONS) {
config.max_concurrent_connections = Some(result?);
}
if let Some(result) = env_parse::parse_u32(ENV_MAX_CONCURRENT_STREAMS) {
config.max_concurrent_streams_per_connection = Some(result?);
}
if let Some(result) = env_parse::parse_u32(ENV_INITIAL_WINDOW_SIZE) {
config.initial_window_size = Some(result?);
}
if let Some(result) = env_parse::parse_u32(ENV_INITIAL_CONNECTION_WINDOW_SIZE) {
config.initial_connection_window_size = Some(result?);
}
if let Some(result) = env_parse::parse_duration_secs(ENV_REQUEST_TIMEOUT_SECS) {
config.request_timeout = Some(result?);
}
let has_memory = std::env::var(ENV_MEMORY_BUDGET_MAX_BYTES).is_ok()
|| std::env::var(ENV_MEMORY_BUDGET_LOAD_SHEDDING).is_ok();
if has_memory {
let mut memory = MemoryBudgetConfig::default();
if let Some(result) = env_parse::parse_usize(ENV_MEMORY_BUDGET_MAX_BYTES) {
memory.max_bytes = result?;
}
if let Some(result) = env_parse::parse_bool(ENV_MEMORY_BUDGET_LOAD_SHEDDING) {
memory.enable_load_shedding = result?;
}
config.memory_budget = Some(memory);
}
config.cluster = ClusterConfig::from_env()?;
Ok(config)
}
#[must_use]
pub const fn with_bind_addr(mut self, addr: SocketAddr) -> Self {
self.bind_addr = addr;
self
}
#[must_use]
pub fn with_tls(mut self, tls: TlsConfig) -> Self {
self.tls = Some(tls);
self
}
#[must_use]
pub fn with_auth(mut self, auth: AuthConfig) -> Self {
self.auth = Some(auth);
self
}
#[must_use]
pub const fn with_max_message_size(mut self, size: usize) -> Self {
self.max_message_size = size;
self
}
#[must_use]
pub const fn with_grpc_web(mut self, enabled: bool) -> Self {
self.enable_grpc_web = enabled;
self
}
#[must_use]
pub const fn with_reflection(mut self, enabled: bool) -> Self {
self.enable_reflection = enabled;
self
}
#[must_use]
pub const fn with_blob_chunk_size(mut self, size: usize) -> Self {
self.blob_chunk_size = size;
self
}
#[must_use]
pub const fn with_max_upload_size(mut self, size: usize) -> Self {
self.max_upload_size = size;
self
}
#[must_use]
pub const fn with_stream_channel_capacity(mut self, capacity: usize) -> Self {
self.stream_channel_capacity = capacity;
self
}
#[must_use]
pub const fn with_rate_limit(mut self, config: RateLimitConfig) -> Self {
self.rate_limit = Some(config);
self
}
#[must_use]
pub const fn with_audit(mut self, config: AuditConfig) -> Self {
self.audit = Some(config);
self
}
#[must_use]
pub const fn with_shutdown(mut self, config: ShutdownConfig) -> Self {
self.shutdown = Some(config);
self
}
#[must_use]
pub fn with_metrics(mut self, config: MetricsConfig) -> Self {
self.metrics = Some(config);
self
}
#[must_use]
pub const fn with_max_concurrent_connections(mut self, max: usize) -> Self {
self.max_concurrent_connections = Some(max);
self
}
#[must_use]
pub const fn with_max_concurrent_streams_per_connection(mut self, max: u32) -> Self {
self.max_concurrent_streams_per_connection = Some(max);
self
}
#[must_use]
pub const fn with_initial_window_size(mut self, size: u32) -> Self {
self.initial_window_size = Some(size);
self
}
#[must_use]
pub const fn with_initial_connection_window_size(mut self, size: u32) -> Self {
self.initial_connection_window_size = Some(size);
self
}
#[must_use]
pub const fn with_request_timeout(mut self, timeout: Duration) -> Self {
self.request_timeout = Some(timeout);
self
}
#[must_use]
pub const fn with_memory_budget(mut self, config: MemoryBudgetConfig) -> Self {
self.memory_budget = Some(config);
self
}
#[must_use]
pub const fn with_rest_addr(mut self, addr: SocketAddr) -> Self {
self.rest_addr = Some(addr);
self
}
#[must_use]
pub const fn with_web_addr(mut self, addr: SocketAddr) -> Self {
self.web_addr = Some(addr);
self
}
#[must_use]
pub const fn with_streaming(mut self, config: StreamingConfig) -> Self {
self.streaming = Some(config);
self
}
#[must_use]
pub fn with_rest_config(mut self, config: crate::rest::RestConfig) -> Self {
self.rest_config = config;
self
}
pub fn validate(&self) -> Result<()> {
if self.max_message_size == 0 {
return Err(ServerError::Config(
"max_message_size must be greater than 0".to_string(),
));
}
if self.blob_chunk_size == 0 {
return Err(ServerError::Config(
"blob_chunk_size must be greater than 0".to_string(),
));
}
if self.stream_channel_capacity == 0 {
return Err(ServerError::Config(
"stream_channel_capacity must be greater than 0".to_string(),
));
}
if let Some(ref tls) = self.tls {
tls.validate()?;
}
if let Some(ref auth) = self.auth {
auth.validate()?;
}
Ok(())
}
}
/// Streaming settings, attached to a server via `ServerConfig::with_streaming`.
///
/// Defaults (see the `Default` impl): `channel_capacity = 32`,
/// `max_stream_items = 10_000`, `slow_consumer_timeout = 30s`.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Channel capacity for streamed items.
    pub channel_capacity: usize,
    /// Maximum number of items per stream (name-derived; confirm against the
    /// consumer of this field).
    pub max_stream_items: usize,
    /// Timeout associated with slow consumers.
    pub slow_consumer_timeout: Duration,
}
impl Default for StreamingConfig {
    /// `channel_capacity` 32, `max_stream_items` 10 000 and a 30-second
    /// `slow_consumer_timeout`.
    fn default() -> Self {
        let slow_consumer_timeout = Duration::from_secs(30);
        Self {
            channel_capacity: 32,
            max_stream_items: 10_000,
            slow_consumer_timeout,
        }
    }
}
impl StreamingConfig {
    /// Returns the default streaming settings.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns a copy with `channel_capacity` replaced.
    #[must_use]
    pub const fn with_channel_capacity(self, capacity: usize) -> Self {
        Self {
            channel_capacity: capacity,
            ..self
        }
    }

    /// Returns a copy with `max_stream_items` replaced.
    #[must_use]
    pub const fn with_max_stream_items(self, max: usize) -> Self {
        Self {
            max_stream_items: max,
            ..self
        }
    }

    /// Returns a copy with `slow_consumer_timeout` replaced.
    #[must_use]
    pub const fn with_slow_consumer_timeout(self, timeout: Duration) -> Self {
        Self {
            slow_consumer_timeout: timeout,
            ..self
        }
    }
}
/// TLS settings: certificate/key paths plus optional client-certificate
/// verification. `validate` checks that the referenced files exist.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Certificate file path.
    pub cert_path: PathBuf,
    /// Key file path.
    pub key_path: PathBuf,
    /// Optional CA certificate path; `validate` requires it when
    /// `require_client_cert` is `true`.
    pub ca_cert_path: Option<PathBuf>,
    /// Whether client certificates are required (default `false` via `new`).
    pub require_client_cert: bool,
}
impl TlsConfig {
    /// Creates a TLS configuration from a certificate and key path. It has no
    /// CA certificate, and client certificates are not required.
    #[must_use]
    pub const fn new(cert_path: PathBuf, key_path: PathBuf) -> Self {
        Self {
            cert_path,
            key_path,
            ca_cert_path: None,
            require_client_cert: false,
        }
    }

    /// Sets the CA certificate path.
    #[must_use]
    pub fn with_ca_cert(mut self, path: PathBuf) -> Self {
        self.ca_cert_path = Some(path);
        self
    }

    /// Sets whether client certificates are required. A CA certificate must
    /// also be set for [`TlsConfig::validate`] to accept `true`.
    #[must_use]
    pub const fn with_required_client_cert(mut self, required: bool) -> Self {
        self.require_client_cert = required;
        self
    }

    /// Checks that the configured files exist and the options are coherent.
    ///
    /// # Errors
    ///
    /// Returns [`ServerError::Config`] in any of these cases:
    /// - the certificate, key or (when set) CA path is missing or is not a
    ///   regular file;
    /// - `require_client_cert` is set without a CA certificate.
    pub fn validate(&self) -> Result<()> {
        Self::require_file(&self.cert_path, "certificate")?;
        Self::require_file(&self.key_path, "key")?;
        if let Some(ref ca_path) = self.ca_cert_path {
            Self::require_file(ca_path, "CA certificate")?;
        }
        if self.require_client_cert && self.ca_cert_path.is_none() {
            return Err(ServerError::Config(
                "require_client_cert is true but ca_cert_path is not set; \
                 cannot verify client certificates without a CA"
                    .to_string(),
            ));
        }
        Ok(())
    }

    /// Fails unless `path` names an existing regular file. Symlinks are
    /// followed, so a symlink to a file is accepted.
    fn require_file(path: &std::path::Path, label: &str) -> Result<()> {
        if !path.exists() {
            return Err(ServerError::Config(format!(
                "{label} file not found: {}",
                path.display()
            )));
        }
        // `exists()` alone also accepts directories, which cannot be read as a
        // certificate or key.
        if !path.is_file() {
            return Err(ServerError::Config(format!(
                "{label} path is not a regular file: {}",
                path.display()
            )));
        }
        Ok(())
    }
}
/// API-key authentication settings.
#[derive(Debug, Clone)]
pub struct AuthConfig {
    /// Accepted keys; `validate_key` maps a presented key to its identity.
    pub api_keys: Vec<ApiKey>,
    /// Header name expected to carry the API key (default `x-api-key`);
    /// `validate` rejects an empty name.
    pub api_key_header: String,
    /// When `false`, `validate` requires at least one API key (default `false`).
    pub allow_anonymous: bool,
}
impl Default for AuthConfig {
    /// No API keys, the `x-api-key` header, and anonymous access disabled.
    fn default() -> Self {
        let api_key_header = "x-api-key".to_owned();
        Self {
            allow_anonymous: false,
            api_keys: Vec::default(),
            api_key_header,
        }
    }
}
impl AuthConfig {
    /// Creates an empty configuration; equivalent to [`AuthConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds an accepted API key.
    #[must_use]
    pub fn with_api_key(mut self, key: ApiKey) -> Self {
        self.api_keys.push(key);
        self
    }

    /// Sets the header name the API key is expected in.
    #[must_use]
    pub fn with_header(mut self, header: String) -> Self {
        self.api_key_header = header;
        self
    }

    /// Sets `allow_anonymous`. When it is `false`, [`AuthConfig::validate`]
    /// requires at least one API key.
    #[must_use]
    pub const fn with_anonymous(mut self, allowed: bool) -> Self {
        self.allow_anonymous = allowed;
        self
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    ///
    /// Returns [`ServerError::Config`] in any of these cases:
    /// - the header name is empty;
    /// - anonymous access is disabled and no key is configured;
    /// - any key fails [`ApiKey::validate`];
    /// - the same key is configured for two different identities.
    pub fn validate(&self) -> Result<()> {
        if self.api_key_header.is_empty() {
            return Err(ServerError::Config(
                "api_key_header cannot be empty".to_string(),
            ));
        }
        if !self.allow_anonymous && self.api_keys.is_empty() {
            return Err(ServerError::Config(
                "at least one API key required when anonymous access is disabled".to_string(),
            ));
        }
        for key in &self.api_keys {
            key.validate()?;
        }
        // `validate_key` returns the last matching entry. If one secret maps to
        // two identities, that choice would be silent and order-dependent, so
        // reject the ambiguity here. The key itself is not put in the message.
        let mut seen: std::collections::HashMap<&str, &str> =
            std::collections::HashMap::with_capacity(self.api_keys.len());
        for key in &self.api_keys {
            if let Some(previous) = seen.insert(key.key.as_str(), key.identity.as_str()) {
                if previous != key.identity {
                    return Err(ServerError::Config(format!(
                        "duplicate API key configured for identities '{previous}' and '{}'",
                        key.identity
                    )));
                }
            }
        }
        Ok(())
    }

    /// Returns the identity for `key`, or `None` if no configured key matches.
    ///
    /// Every configured key is checked, even after a match, and each
    /// comparison has no early exit on the first differing byte. Running time
    /// still depends on the lengths of the presented and stored keys. If
    /// several entries share the same key, the last one wins.
    #[must_use]
    pub fn validate_key(&self, key: &str) -> Option<&str> {
        let presented = key.as_bytes();
        let mut found_identity: Option<&str> = None;
        for api_key in &self.api_keys {
            if Self::bytes_equal_no_early_exit(api_key.key.as_bytes(), presented) {
                found_identity = Some(api_key.identity.as_str());
            }
        }
        found_identity
    }

    /// Byte-string equality that visits every position instead of returning at
    /// the first mismatch.
    fn bytes_equal_no_early_exit(stored: &[u8], presented: &[u8]) -> bool {
        let len = stored.len().max(presented.len());
        // The shorter input is zero-padded so both are walked over the same
        // range. The explicit length check rejects inputs that only match
        // because of the padding.
        let diff = (0..len).fold(0u8, |acc, i| {
            let a = stored.get(i).copied().unwrap_or(0);
            let b = presented.get(i).copied().unwrap_or(0);
            acc | (a ^ b)
        });
        (diff == 0) & (stored.len() == presented.len())
    }
}
/// An API key accepted by `AuthConfig`, mapped to the identity it
/// authenticates as.
///
/// `Debug` is implemented by hand so that the secret `key` is redacted.
/// `AuthConfig` and `ServerConfig` derive `Debug`, so a derived impl here
/// would print every configured key whenever a config is logged with `{:?}`.
#[derive(Clone)]
pub struct ApiKey {
    /// Secret value that `AuthConfig::validate_key` compares presented keys against.
    pub key: String,
    /// Identity returned when this key matches.
    pub identity: String,
    /// Optional free-form description.
    pub description: Option<String>,
}

impl std::fmt::Debug for ApiKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiKey")
            .field("key", &"<redacted>")
            .field("identity", &self.identity)
            .field("description", &self.description)
            .finish()
    }
}
impl ApiKey {
    /// Creates a key for `identity` without a description.
    #[must_use]
    pub const fn new(key: String, identity: String) -> Self {
        Self {
            key,
            identity,
            description: None,
        }
    }

    /// Attaches a free-form description.
    #[must_use]
    pub fn with_description(mut self, desc: String) -> Self {
        self.description = Some(desc);
        self
    }

    /// Checks the key and identity.
    ///
    /// # Errors
    ///
    /// Returns [`ServerError::Config`] if the key is empty or shorter than 16
    /// characters, or if the identity is empty.
    pub fn validate(&self) -> Result<()> {
        const MIN_KEY_CHARS: usize = 16;

        if self.key.is_empty() {
            return Err(ServerError::Config("API key cannot be empty".to_string()));
        }
        // Count characters, not bytes, so the check matches the error message.
        // With `len()`, eight 2-byte characters would pass as 16.
        if self.key.chars().count() < MIN_KEY_CHARS {
            return Err(ServerError::Config(
                "API key must be at least 16 characters".to_string(),
            ));
        }
        if self.identity.is_empty() {
            return Err(ServerError::Config(
                "API key identity cannot be empty".to_string(),
            ));
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_default_config() {
let config = ServerConfig::default();
assert_eq!(config.bind_addr.port(), 9200);
assert!(config.tls.is_none());
assert!(config.auth.is_none());
assert!(config.enable_grpc_web);
assert!(config.enable_reflection);
}
#[test]
fn test_config_builder() {
let config = ServerConfig::new()
.with_bind_addr("0.0.0.0:8080".parse().unwrap())
.with_max_message_size(128 * 1024 * 1024)
.with_grpc_web(false)
.with_reflection(false)
.with_blob_chunk_size(32 * 1024);
assert_eq!(config.bind_addr.port(), 8080);
assert_eq!(config.max_message_size, 128 * 1024 * 1024);
assert!(!config.enable_grpc_web);
assert!(!config.enable_reflection);
assert_eq!(config.blob_chunk_size, 32 * 1024);
}
#[test]
fn test_config_validation() {
let config = ServerConfig::default();
assert!(config.validate().is_ok());
let invalid_config = ServerConfig::new().with_max_message_size(0);
assert!(invalid_config.validate().is_err());
let invalid_config = ServerConfig::new().with_blob_chunk_size(0);
assert!(invalid_config.validate().is_err());
}
#[test]
fn test_auth_config_validation() {
let auth = AuthConfig::new().with_anonymous(true);
assert!(auth.validate().is_ok());
let auth = AuthConfig::new().with_anonymous(false);
assert!(auth.validate().is_err());
let auth = AuthConfig::new()
.with_anonymous(false)
.with_api_key(ApiKey::new(
"test-api-key-12345678".to_string(),
"user:test".to_string(),
));
assert!(auth.validate().is_ok());
}
#[test]
fn test_api_key_validation() {
let key = ApiKey::new("test-api-key-12345678".to_string(), "user:test".to_string());
assert!(key.validate().is_ok());
let key = ApiKey::new("short".to_string(), "user:test".to_string());
assert!(key.validate().is_err());
let key = ApiKey::new("test-api-key-12345678".to_string(), String::new());
assert!(key.validate().is_err());
}
#[test]
fn test_validate_key() {
let auth = AuthConfig::new().with_api_key(ApiKey::new(
"test-api-key-12345678".to_string(),
"user:alice".to_string(),
));
assert_eq!(
auth.validate_key("test-api-key-12345678"),
Some("user:alice")
);
assert_eq!(auth.validate_key("wrong-key"), None);
}
#[test]
fn test_validate_key_different_lengths() {
let auth = AuthConfig::new().with_api_key(ApiKey::new(
"test-api-key-12345678".to_string(),
"user:alice".to_string(),
));
assert_eq!(auth.validate_key("test-api-key"), None);
assert_eq!(auth.validate_key("test-api-key-12345678-extra"), None);
assert_eq!(auth.validate_key("test-api-key-XXXXXXXX"), None);
assert_eq!(
auth.validate_key("test-api-key-12345678"),
Some("user:alice")
);
}
#[test]
fn test_validate_key_multiple_keys() {
let auth = AuthConfig::new()
.with_api_key(ApiKey::new(
"first-key-12345678".to_string(),
"user:first".to_string(),
))
.with_api_key(ApiKey::new(
"second-key-1234567".to_string(),
"user:second".to_string(),
))
.with_api_key(ApiKey::new(
"third-key-12345678".to_string(),
"user:third".to_string(),
));
assert_eq!(auth.validate_key("first-key-12345678"), Some("user:first"));
assert_eq!(auth.validate_key("second-key-1234567"), Some("user:second"));
assert_eq!(auth.validate_key("third-key-12345678"), Some("user:third"));
assert_eq!(auth.validate_key("unknown-key-12345"), None);
}
#[test]
fn test_tls_config() {
use tempfile::NamedTempFile;
let cert_file = NamedTempFile::new().unwrap();
let key_file = NamedTempFile::new().unwrap();
let tls = TlsConfig::new(
cert_file.path().to_path_buf(),
key_file.path().to_path_buf(),
);
assert!(tls.validate().is_ok());
let tls = TlsConfig::new(
PathBuf::from("/nonexistent/cert.pem"),
key_file.path().to_path_buf(),
);
assert!(tls.validate().is_err());
}
#[test]
fn test_api_key_with_description() {
let key = ApiKey::new("test-api-key-12345678".to_string(), "user:test".to_string())
.with_description("Test API key".to_string());
assert_eq!(key.description, Some("Test API key".to_string()));
}
#[test]
fn test_tls_config_with_ca() {
use tempfile::NamedTempFile;
let cert_file = NamedTempFile::new().unwrap();
let key_file = NamedTempFile::new().unwrap();
let ca_file = NamedTempFile::new().unwrap();
let tls = TlsConfig::new(
cert_file.path().to_path_buf(),
key_file.path().to_path_buf(),
)
.with_ca_cert(ca_file.path().to_path_buf())
.with_required_client_cert(true);
assert!(tls.validate().is_ok());
assert!(tls.require_client_cert);
}
#[test]
fn test_require_client_cert_without_ca_fails_validation() {
use tempfile::NamedTempFile;
let cert_file = NamedTempFile::new().unwrap();
let key_file = NamedTempFile::new().unwrap();
let tls = TlsConfig::new(
cert_file.path().to_path_buf(),
key_file.path().to_path_buf(),
)
.with_required_client_cert(true);
let result = tls.validate();
assert!(result.is_err());
let config = ServerConfig::new()
.with_rest_config(crate::rest::RestConfig::new().with_max_body_size(32 * 1024 * 1024));
assert_eq!(config.rest_config.max_body_size, 32 * 1024 * 1024);
let default_config = ServerConfig::default();
assert_eq!(default_config.rest_config.max_body_size, 16 * 1024 * 1024);
let err = result.unwrap_err().to_string();
assert!(err.contains("require_client_cert"));
assert!(err.contains("ca_cert_path"));
}
#[test]
fn test_require_client_cert_with_ca_passes_validation() {
use tempfile::NamedTempFile;
let cert_file = NamedTempFile::new().unwrap();
let key_file = NamedTempFile::new().unwrap();
let ca_file = NamedTempFile::new().unwrap();
let tls = TlsConfig::new(
cert_file.path().to_path_buf(),
key_file.path().to_path_buf(),
)
.with_ca_cert(ca_file.path().to_path_buf())
.with_required_client_cert(true);
assert!(tls.validate().is_ok());
}
#[test]
fn test_optional_client_cert_with_ca_passes_validation() {
use tempfile::NamedTempFile;
let cert_file = NamedTempFile::new().unwrap();
let key_file = NamedTempFile::new().unwrap();
let ca_file = NamedTempFile::new().unwrap();
let tls = TlsConfig::new(
cert_file.path().to_path_buf(),
key_file.path().to_path_buf(),
)
.with_ca_cert(ca_file.path().to_path_buf())
.with_required_client_cert(false);
assert!(tls.validate().is_ok());
assert!(!tls.require_client_cert);
}
#[test]
fn test_auth_config_with_header() {
let auth = AuthConfig::new()
.with_header("Authorization".to_string())
.with_anonymous(true);
assert_eq!(auth.api_key_header, "Authorization");
assert!(auth.validate().is_ok());
}
#[test]
fn test_empty_header_validation() {
let auth = AuthConfig {
api_keys: Vec::new(),
api_key_header: String::new(),
allow_anonymous: true,
};
assert!(auth.validate().is_err());
}
#[test]
fn test_server_config_with_rate_limit() {
let config = ServerConfig::new().with_rate_limit(RateLimitConfig::default());
assert!(config.rate_limit.is_some());
let rate_limit = config.rate_limit.as_ref().unwrap();
assert_eq!(rate_limit.max_requests, 1000);
}
#[test]
fn test_server_config_with_audit() {
let config = ServerConfig::new().with_audit(AuditConfig::default());
assert!(config.audit.is_some());
let audit = config.audit.as_ref().unwrap();
assert!(audit.enabled);
}
#[test]
fn test_server_config_with_both() {
let config = ServerConfig::new()
.with_rate_limit(RateLimitConfig::strict())
.with_audit(AuditConfig::default().with_query_logging());
assert!(config.rate_limit.is_some());
assert!(config.audit.is_some());
assert_eq!(config.rate_limit.as_ref().unwrap().max_requests, 10);
assert!(config.audit.as_ref().unwrap().log_queries);
}
#[test]
fn test_default_config_no_rate_limit_or_audit() {
let config = ServerConfig::default();
assert!(config.rate_limit.is_none());
assert!(config.audit.is_none());
}
#[test]
fn test_server_config_with_shutdown() {
use std::time::Duration;
let config = ServerConfig::new()
.with_shutdown(ShutdownConfig::default().with_drain_timeout(Duration::from_secs(60)));
assert!(config.shutdown.is_some());
let shutdown = config.shutdown.as_ref().unwrap();
assert_eq!(shutdown.drain_timeout, Duration::from_secs(60));
}
#[test]
fn test_server_config_with_metrics() {
let config = ServerConfig::new()
.with_metrics(MetricsConfig::default().with_service_name("test_service".to_string()));
assert!(config.metrics.is_some());
let metrics = config.metrics.as_ref().unwrap();
assert_eq!(metrics.service_name, "test_service");
}
#[test]
fn test_default_config_no_shutdown_or_metrics() {
let config = ServerConfig::default();
assert!(config.shutdown.is_none());
assert!(config.metrics.is_none());
}
#[test]
fn test_server_config_with_max_concurrent_connections() {
let config = ServerConfig::new().with_max_concurrent_connections(100);
assert_eq!(config.max_concurrent_connections, Some(100));
}
#[test]
fn test_server_config_with_max_concurrent_streams() {
let config = ServerConfig::new().with_max_concurrent_streams_per_connection(50);
assert_eq!(config.max_concurrent_streams_per_connection, Some(50));
}
#[test]
fn test_server_config_with_window_sizes() {
let config = ServerConfig::new()
.with_initial_window_size(65535)
.with_initial_connection_window_size(1024 * 1024);
assert_eq!(config.initial_window_size, Some(65535));
assert_eq!(config.initial_connection_window_size, Some(1024 * 1024));
}
#[test]
fn test_server_config_with_request_timeout() {
let config = ServerConfig::new().with_request_timeout(Duration::from_secs(30));
assert_eq!(config.request_timeout, Some(Duration::from_secs(30)));
}
#[test]
fn test_server_config_with_memory_budget() {
let config = ServerConfig::new().with_memory_budget(
MemoryBudgetConfig::new()
.with_max_bytes(512 * 1024 * 1024)
.with_load_shedding(true),
);
assert!(config.memory_budget.is_some());
let budget = config.memory_budget.as_ref().unwrap();
assert_eq!(budget.max_bytes, 512 * 1024 * 1024);
assert!(budget.enable_load_shedding);
}
#[test]
fn test_default_config_no_connection_limits() {
let config = ServerConfig::default();
assert!(config.max_concurrent_connections.is_none());
assert!(config.max_concurrent_streams_per_connection.is_none());
assert!(config.initial_window_size.is_none());
assert!(config.initial_connection_window_size.is_none());
assert!(config.request_timeout.is_none());
assert!(config.memory_budget.is_none());
}
#[test]
fn test_server_config_all_connection_limits() {
let config = ServerConfig::new()
.with_max_concurrent_connections(200)
.with_max_concurrent_streams_per_connection(100)
.with_initial_window_size(65535)
.with_initial_connection_window_size(2 * 1024 * 1024)
.with_request_timeout(Duration::from_secs(60))
.with_memory_budget(MemoryBudgetConfig::default());
assert_eq!(config.max_concurrent_connections, Some(200));
assert_eq!(config.max_concurrent_streams_per_connection, Some(100));
assert_eq!(config.initial_window_size, Some(65535));
assert_eq!(config.initial_connection_window_size, Some(2 * 1024 * 1024));
assert_eq!(config.request_timeout, Some(Duration::from_secs(60)));
assert!(config.memory_budget.is_some());
}
mod env_tests {
use super::*;
use std::sync::Mutex;
static ENV_MUTEX: Mutex<()> = Mutex::new(());
fn with_env_vars<F, R>(vars: &[(&str, &str)], f: F) -> R
where
F: FnOnce() -> R,
{
let _guard = ENV_MUTEX
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let saved: Vec<_> = vars
.iter()
.map(|(k, v)| {
let old = std::env::var(k).ok();
std::env::set_var(k, v);
(*k, old)
})
.collect();
let result = f();
for (k, old) in saved {
match old {
Some(v) => std::env::set_var(k, v),
None => std::env::remove_var(k),
}
}
result
}
fn without_env_vars<F, R>(keys: &[&str], f: F) -> R
where
F: FnOnce() -> R,
{
let _guard = ENV_MUTEX
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let saved: Vec<_> = keys
.iter()
.map(|k| {
let old = std::env::var(k).ok();
std::env::remove_var(k);
(*k, old)
})
.collect();
let result = f();
for (k, old) in saved {
if let Some(v) = old {
std::env::set_var(k, v);
}
}
result
}
#[test]
fn test_from_env_defaults() {
without_env_vars(
&[
ENV_BIND_ADDR,
ENV_MAX_MESSAGE_SIZE,
ENV_ENABLE_GRPC_WEB,
ENV_RATE_LIMIT_MAX_REQUESTS,
],
|| {
let config = ServerConfig::from_env().unwrap();
assert_eq!(config.bind_addr.port(), 9200);
assert!(config.enable_grpc_web);
assert!(config.enable_reflection);
assert!(config.rate_limit.is_none());
},
);
}
#[test]
fn test_from_env_bind_addr() {
with_env_vars(&[(ENV_BIND_ADDR, "0.0.0.0:8080")], || {
let config = ServerConfig::from_env().unwrap();
assert_eq!(config.bind_addr.to_string(), "0.0.0.0:8080");
});
}
#[test]
fn test_from_env_invalid_bind_addr() {
with_env_vars(&[(ENV_BIND_ADDR, "not-an-address")], || {
let result = ServerConfig::from_env();
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains(ENV_BIND_ADDR));
});
}
#[test]
fn test_from_env_max_message_size() {
with_env_vars(&[(ENV_MAX_MESSAGE_SIZE, "1048576")], || {
let config = ServerConfig::from_env().unwrap();
assert_eq!(config.max_message_size, 1048576);
});
}
#[test]
fn test_from_env_invalid_max_message_size() {
with_env_vars(&[(ENV_MAX_MESSAGE_SIZE, "not-a-number")], || {
let result = ServerConfig::from_env();
assert!(result.is_err());
});
}
#[test]
fn test_from_env_bool_true_variants() {
for val in &["true", "1", "yes", "on", "TRUE", "YES", "ON"] {
with_env_vars(&[(ENV_ENABLE_GRPC_WEB, val)], || {
let config = ServerConfig::from_env().unwrap();
assert!(config.enable_grpc_web, "failed for value: {val}");
});
}
}
#[test]
fn test_from_env_bool_false_variants() {
for val in &["false", "0", "no", "off", "FALSE", "NO", "OFF"] {
with_env_vars(&[(ENV_ENABLE_GRPC_WEB, val)], || {
let config = ServerConfig::from_env().unwrap();
assert!(!config.enable_grpc_web, "failed for value: {val}");
});
}
}
#[test]
fn test_from_env_invalid_bool() {
with_env_vars(&[(ENV_ENABLE_GRPC_WEB, "maybe")], || {
let result = ServerConfig::from_env();
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("boolean"));
});
}
#[test]
fn test_from_env_rate_limit() {
with_env_vars(
&[
(ENV_RATE_LIMIT_MAX_REQUESTS, "500"),
(ENV_RATE_LIMIT_MAX_QUERIES, "100"),
(ENV_RATE_LIMIT_WINDOW_SECS, "120"),
],
|| {
let config = ServerConfig::from_env().unwrap();
assert!(config.rate_limit.is_some());
let rate = config.rate_limit.unwrap();
assert_eq!(rate.max_requests, 500);
assert_eq!(rate.max_queries, 100);
assert_eq!(rate.window, Duration::from_secs(120));
},
);
}
#[test]
fn test_from_env_shutdown() {
    with_env_vars(
        &[
            (ENV_SHUTDOWN_DRAIN_TIMEOUT_SECS, "60"),
            (ENV_SHUTDOWN_GRACE_PERIOD_SECS, "10"),
        ],
        || {
            let config = ServerConfig::from_env().unwrap();
            // `expect` replaces the separate `is_some()` assert + `unwrap()`.
            let shutdown = config
                .shutdown
                .expect("shutdown should be configured by the shutdown env vars");
            assert_eq!(shutdown.drain_timeout, Duration::from_secs(60));
            assert_eq!(shutdown.grace_period, Duration::from_secs(10));
        },
    );
}
#[test]
fn test_from_env_connection_limits() {
    with_env_vars(
        &[
            (ENV_MAX_CONCURRENT_CONNECTIONS, "100"),
            (ENV_MAX_CONCURRENT_STREAMS, "50"),
            (ENV_INITIAL_WINDOW_SIZE, "65535"),
            (ENV_REQUEST_TIMEOUT_SECS, "30"),
        ],
        || {
            let cfg = ServerConfig::from_env().unwrap();
            // Each variable populates its corresponding optional field.
            assert_eq!(cfg.request_timeout, Some(Duration::from_secs(30)));
            assert_eq!(cfg.initial_window_size, Some(65535));
            assert_eq!(cfg.max_concurrent_streams_per_connection, Some(50));
            assert_eq!(cfg.max_concurrent_connections, Some(100));
        },
    );
}
#[test]
fn test_from_env_memory_budget() {
    with_env_vars(
        &[
            (ENV_MEMORY_BUDGET_MAX_BYTES, "536870912"),
            (ENV_MEMORY_BUDGET_LOAD_SHEDDING, "true"),
        ],
        || {
            let config = ServerConfig::from_env().unwrap();
            // `expect` replaces the separate `is_some()` assert + `unwrap()`.
            let memory = config
                .memory_budget
                .expect("memory budget should be configured by the memory-budget env vars");
            // 512 MiB.
            assert_eq!(memory.max_bytes, 536_870_912);
            assert!(memory.enable_load_shedding);
        },
    );
}
#[test]
fn test_from_env_tls_requires_both_paths() {
    // A certificate path alone, or a key path alone, must not enable TLS.
    with_env_vars(&[(ENV_TLS_CERT_PATH, "/path/to/cert.pem")], || {
        let config = ServerConfig::from_env().unwrap();
        assert!(config.tls.is_none(), "TLS enabled with only a certificate path");
    });
    with_env_vars(&[(ENV_TLS_KEY_PATH, "/path/to/key.pem")], || {
        let config = ServerConfig::from_env().unwrap();
        assert!(config.tls.is_none(), "TLS enabled with only a key path");
    });
    // With both paths set, TLS is configured and both paths are carried through.
    with_env_vars(
        &[
            (ENV_TLS_CERT_PATH, "/path/to/cert.pem"),
            (ENV_TLS_KEY_PATH, "/path/to/key.pem"),
        ],
        || {
            let config = ServerConfig::from_env().unwrap();
            let tls = config
                .tls
                .expect("TLS should be configured when both cert and key paths are set");
            assert_eq!(tls.cert_path.to_string_lossy(), "/path/to/cert.pem");
            assert_eq!(tls.key_path.to_string_lossy(), "/path/to/key.pem");
        },
    );
}
#[test]
fn test_from_env_tls_with_ca() {
    with_env_vars(
        &[
            (ENV_TLS_CERT_PATH, "/path/to/cert.pem"),
            (ENV_TLS_KEY_PATH, "/path/to/key.pem"),
            (ENV_TLS_CA_CERT_PATH, "/path/to/ca.pem"),
        ],
        || {
            let config = ServerConfig::from_env().unwrap();
            // `expect` replaces the `is_some()` assert + `unwrap()` pairs and says
            // which step failed.
            let tls = config
                .tls
                .expect("TLS should be configured when both cert and key paths are set");
            let ca_path = tls
                .ca_cert_path
                .expect("CA certificate path should be read from the environment");
            assert_eq!(ca_path.to_string_lossy(), "/path/to/ca.pem");
        },
    );
}
#[test]
fn test_from_env_require_client_cert() {
    // Explicitly enabled (alongside a CA certificate).
    with_env_vars(
        &[
            (ENV_TLS_CERT_PATH, "/path/to/cert.pem"),
            (ENV_TLS_KEY_PATH, "/path/to/key.pem"),
            (ENV_TLS_CA_CERT_PATH, "/path/to/ca.pem"),
            (ENV_TLS_REQUIRE_CLIENT_CERT, "true"),
        ],
        || {
            let config = ServerConfig::from_env().unwrap();
            let tls = config
                .tls
                .expect("TLS should be configured when both cert and key paths are set");
            assert!(tls.require_client_cert);
        },
    );
    // Explicitly disabled.
    with_env_vars(
        &[
            (ENV_TLS_CERT_PATH, "/path/to/cert.pem"),
            (ENV_TLS_KEY_PATH, "/path/to/key.pem"),
            (ENV_TLS_CA_CERT_PATH, "/path/to/ca.pem"),
            (ENV_TLS_REQUIRE_CLIENT_CERT, "false"),
        ],
        || {
            let config = ServerConfig::from_env().unwrap();
            let tls = config
                .tls
                .expect("TLS should be configured when both cert and key paths are set");
            assert!(!tls.require_client_cert);
        },
    );
    // Not set at all: client certificates are not required.
    with_env_vars(
        &[
            (ENV_TLS_CERT_PATH, "/path/to/cert.pem"),
            (ENV_TLS_KEY_PATH, "/path/to/key.pem"),
        ],
        || {
            let config = ServerConfig::from_env().unwrap();
            let tls = config
                .tls
                .expect("TLS should be configured when both cert and key paths are set");
            assert!(!tls.require_client_cert);
        },
    );
}
#[test]
fn test_env_var_constants() {
    // The `ENV_*` variable names declared for this config. Add new constants here
    // so their naming is checked too.
    let names = [
        ENV_BIND_ADDR,
        ENV_MAX_MESSAGE_SIZE,
        ENV_MAX_UPLOAD_SIZE,
        ENV_ENABLE_GRPC_WEB,
        ENV_ENABLE_REFLECTION,
        ENV_TLS_CERT_PATH,
        ENV_TLS_KEY_PATH,
        ENV_TLS_CA_CERT_PATH,
        ENV_TLS_REQUIRE_CLIENT_CERT,
        ENV_RATE_LIMIT_MAX_REQUESTS,
        ENV_RATE_LIMIT_MAX_QUERIES,
        ENV_RATE_LIMIT_WINDOW_SECS,
        ENV_SHUTDOWN_DRAIN_TIMEOUT_SECS,
        ENV_SHUTDOWN_GRACE_PERIOD_SECS,
        ENV_BLOB_CHUNK_SIZE,
        ENV_STREAM_CHANNEL_CAPACITY,
        ENV_MAX_CONCURRENT_CONNECTIONS,
        ENV_MAX_CONCURRENT_STREAMS,
        ENV_INITIAL_WINDOW_SIZE,
        ENV_INITIAL_CONNECTION_WINDOW_SIZE,
        ENV_REQUEST_TIMEOUT_SECS,
        ENV_MEMORY_BUDGET_MAX_BYTES,
        ENV_MEMORY_BUDGET_LOAD_SHEDDING,
    ];
    for name in names {
        assert!(name.starts_with("NEUMANN_"), "{name} lacks the NEUMANN_ prefix");
        // Uppercase letters, digits and underscores keep the names portable across shells.
        assert!(
            name.bytes()
                .all(|b| b.is_ascii_uppercase() || b.is_ascii_digit() || b == b'_'),
            "{name} is not a portable environment variable name"
        );
    }
    // Two settings sharing one variable name would silently alias each other.
    for (i, first) in names.iter().enumerate() {
        for second in &names[i + 1..] {
            assert_ne!(first, second, "duplicate environment variable name {first}");
        }
    }
}
}
#[test]
fn test_streaming_config_default() {
    // Pin the default values so an accidental change is caught.
    let StreamingConfig {
        channel_capacity,
        max_stream_items,
        slow_consumer_timeout,
        ..
    } = StreamingConfig::default();
    assert_eq!(channel_capacity, 32);
    assert_eq!(max_stream_items, 10_000);
    assert_eq!(slow_consumer_timeout, Duration::from_secs(30));
}
#[test]
fn test_streaming_config_new() {
    let fresh = StreamingConfig::new();
    // `new()` starts from the same capacity and item limit as the defaults.
    assert_eq!(
        (fresh.channel_capacity, fresh.max_stream_items),
        (32, 10_000)
    );
}
#[test]
fn test_streaming_config_with_channel_capacity() {
    let capacity = StreamingConfig::new()
        .with_channel_capacity(64)
        .channel_capacity;
    assert_eq!(capacity, 64);
}
#[test]
fn test_streaming_config_with_max_stream_items() {
    let limit = StreamingConfig::new()
        .with_max_stream_items(5000)
        .max_stream_items;
    assert_eq!(limit, 5000);
}
#[test]
fn test_streaming_config_with_slow_consumer_timeout() {
    let one_minute = Duration::from_secs(60);
    let config = StreamingConfig::new().with_slow_consumer_timeout(one_minute);
    assert_eq!(config.slow_consumer_timeout, one_minute);
}
#[test]
fn test_streaming_config_builder_chain() {
    // Chained builder calls should each keep the values set by the others.
    let StreamingConfig {
        channel_capacity,
        max_stream_items,
        slow_consumer_timeout,
        ..
    } = StreamingConfig::new()
        .with_channel_capacity(128)
        .with_max_stream_items(20_000)
        .with_slow_consumer_timeout(Duration::from_secs(120));
    assert_eq!(channel_capacity, 128);
    assert_eq!(max_stream_items, 20_000);
    assert_eq!(slow_consumer_timeout, Duration::from_secs(120));
}
#[test]
fn test_server_config_with_streaming() {
    let streaming = StreamingConfig::new()
        .with_channel_capacity(64)
        .with_max_stream_items(5000);
    let config = ServerConfig::new().with_streaming(streaming);
    // `expect` replaces the separate `is_some()` assert + `as_ref().unwrap()`.
    let stored = config
        .streaming
        .as_ref()
        .expect("with_streaming should set the streaming config");
    assert_eq!(stored.channel_capacity, 64);
    assert_eq!(stored.max_stream_items, 5000);
}
#[test]
fn test_server_config_with_rest_addr() {
    let addr = "0.0.0.0:8080".parse().unwrap();
    let config = ServerConfig::new().with_rest_addr(addr);
    // Compare the whole stored address (IP and port) in one assertion instead of
    // `is_some()` + `unwrap().port()`, which ignored the IP.
    assert_eq!(config.rest_addr, Some(addr));
}
#[test]
fn test_server_config_with_web_addr() {
    let addr = "0.0.0.0:9000".parse().unwrap();
    let config = ServerConfig::new().with_web_addr(addr);
    // Compare the whole stored address (IP and port) in one assertion instead of
    // `is_some()` + `unwrap().port()`, which ignored the IP.
    assert_eq!(config.web_addr, Some(addr));
}
#[test]
fn test_server_config_with_max_upload_size() {
    // 1 GiB (2^30 bytes).
    let config = ServerConfig::new().with_max_upload_size(1 << 30);
    assert_eq!(config.max_upload_size, 1 << 30);
}
#[test]
fn test_server_config_with_stream_channel_capacity() {
    let capacity = ServerConfig::new()
        .with_stream_channel_capacity(64)
        .stream_channel_capacity;
    assert_eq!(capacity, 64);
}
#[test]
fn test_server_config_validate_stream_channel_capacity_zero() {
    // A zero-capacity stream channel must be rejected by `validate()`.
    let zero_capacity = ServerConfig::new().with_stream_channel_capacity(0);
    assert!(zero_capacity.validate().is_err());
}
#[test]
fn test_tls_config_missing_key_file() {
    use tempfile::NamedTempFile;
    // The certificate file exists on disk; the key path is expected not to exist,
    // so validation should fail.
    let existing_cert = NamedTempFile::new().unwrap();
    let missing_key = PathBuf::from("/nonexistent/key.pem");
    let tls = TlsConfig::new(existing_cert.path().to_path_buf(), missing_key);
    assert!(tls.validate().is_err());
}
#[test]
fn test_tls_config_missing_ca_file() {
    use tempfile::NamedTempFile;
    // Certificate and key files exist; the CA path is expected not to exist,
    // so validation should fail.
    let existing_cert = NamedTempFile::new().unwrap();
    let existing_key = NamedTempFile::new().unwrap();
    let missing_ca = PathBuf::from("/nonexistent/ca.pem");
    let tls = TlsConfig::new(
        existing_cert.path().to_path_buf(),
        existing_key.path().to_path_buf(),
    )
    .with_ca_cert(missing_ca);
    assert!(tls.validate().is_err());
}
#[test]
fn test_api_key_empty() {
    // An empty key string must not pass validation.
    let blank = ApiKey::new(String::new(), String::from("user:test"));
    assert!(blank.validate().is_err());
}
#[test]
fn test_streaming_config_debug() {
    // The Debug output should name the type and at least one field.
    let rendered = format!("{:?}", StreamingConfig::default());
    for needle in ["StreamingConfig", "channel_capacity"] {
        assert!(rendered.contains(needle));
    }
}
#[test]
fn test_streaming_config_clone() {
    let source = StreamingConfig::new()
        .with_channel_capacity(100)
        .with_max_stream_items(500);
    // The clone must carry over the customised values.
    let duplicate = source.clone();
    assert_eq!(
        (duplicate.channel_capacity, duplicate.max_stream_items),
        (100, 500)
    );
}
}