use figment::{
providers::{Env, Format, Serialized, Toml},
Figment,
};
use serde::{Deserialize, Serialize};
use std::path::Path;
use thiserror::Error;
/// Serde default for the `k` field of `QuantizationType::PQ` when it is omitted.
///
/// NOTE(review): presumably the codebook size per PQ sub-space — confirm.
const fn default_k() -> usize {
    // 2^8, written as a shift to make the power-of-two value explicit.
    1 << 8
}
/// Serde default for the `oversampling` field of `QuantizationType::PQ`.
///
/// Always returns `Some(4)`.
// The `Option` wrapper is required: a serde `default = "..."` function must return
// exactly the field's type (`Option<u32>`), even though it never yields `None`.
#[allow(clippy::unnecessary_wraps)]
const fn default_oversampling() -> Option<u32> {
    let factor: u32 = 4;
    Some(factor)
}
/// Vector quantization scheme.
///
/// Serialized as an internally tagged table keyed by `type`, with variant names
/// lowercased, e.g. `{ type = "pq", m = 8 }` or `{ type = "sq8" }`.
#[derive(Clone, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase", tag = "type")]
pub enum QuantizationType {
    /// No quantization; the default.
    #[default]
    None,
    /// Tagged as `sq8`. NOTE(review): presumably 8-bit scalar quantization — confirm.
    #[serde(alias = "sq8")]
    SQ8,
    /// Tagged as `binary`.
    Binary,
    /// Product quantization, tagged as `pq`.
    #[serde(alias = "pq")]
    PQ {
        /// Required (no serde default).
        m: usize,
        /// Defaults to `default_k()` (256) when omitted.
        #[serde(default = "default_k")]
        k: usize,
        /// Defaults to `false` when omitted.
        #[serde(default)]
        opq_enabled: bool,
        /// Defaults to `default_oversampling()` (`Some(4)`) when omitted.
        #[serde(default = "default_oversampling")]
        oversampling: Option<u32>,
    },
    /// Tagged as `rabitq`.
    #[serde(alias = "rabitq")]
    RaBitQ,
}
impl QuantizationType {
#[must_use]
pub const fn is_pq(&self) -> bool {
matches!(self, Self::PQ { .. })
}
#[must_use]
pub const fn is_rabitq(&self) -> bool {
matches!(self, Self::RaBitQ)
}
}
#[derive(Error, Debug)]
pub enum ConfigError {
#[error("Failed to parse configuration: {0}")]
ParseError(String),
#[error("Invalid configuration value for '{key}': {message}")]
InvalidValue {
key: String,
message: String,
},
#[error("Configuration file not found: {0}")]
FileNotFound(String),
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
}
/// Accuracy/speed preset for search, written in snake_case in config files.
///
/// Each preset maps to a fixed `ef_search` value via `SearchMode::ef_search`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum SearchMode {
    /// `ef_search = 64`.
    Fast,
    /// `ef_search = 128`; the default.
    #[default]
    Balanced,
    /// `ef_search = 512`.
    Accurate,
    /// `ef_search = usize::MAX`.
    Perfect,
}
impl SearchMode {
#[must_use]
pub fn ef_search(&self) -> usize {
match self {
Self::Fast => 64,
Self::Balanced => 128,
Self::Accurate => 512,
Self::Perfect => usize::MAX, }
}
}
/// `[search]` section: default search preset and query limits.
///
/// Every field is optional in TOML; missing fields come from `Default`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(default)]
pub struct SearchConfig {
    /// Preset used when `ef_search` is not set explicitly.
    pub default_mode: SearchMode,
    /// Explicit override of the preset; `VelesConfig::validate` requires `[16, 4096]`.
    pub ef_search: Option<usize>,
    /// `VelesConfig::validate` requires `[1, 10000]`.
    pub max_results: usize,
    /// Query timeout in milliseconds.
    pub query_timeout_ms: u64,
}

impl Default for SearchConfig {
    fn default() -> Self {
        // 30 seconds, expressed in milliseconds.
        let query_timeout_ms: u64 = 30 * 1_000;
        Self {
            default_mode: SearchMode::Balanced,
            ef_search: None,
            max_results: 1_000,
            query_timeout_ms,
        }
    }
}
/// `[hnsw]` section: optional HNSW graph-construction overrides.
///
/// All fields default via the derived `Default` (`None` / `0`).
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct HnswConfig {
    /// `VelesConfig::validate` requires `[4, 128]` when set.
    pub m: Option<usize>,
    /// `VelesConfig::validate` requires `[100, 2000]` when set.
    pub ef_construction: Option<usize>,
    /// Not validated; defaults to `0`. NOTE(review): meaning of `0` not visible here.
    pub max_layers: usize,
}
/// Storage, HTTP-server, and logging sections of the top-level configuration.
pub mod server {
    use serde::{Deserialize, Serialize};

    /// `[storage]` section.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(default)]
    pub struct StorageConfig {
        /// Data directory; defaults to `./velesdb_data`.
        pub data_dir: String,
        /// Defaults to `"mmap"`; validation accepts only `"mmap"` or `"memory"`.
        pub storage_mode: String,
        /// Defaults to 1024. NOTE(review): presumably MiB per the `_mb` suffix.
        pub mmap_cache_mb: usize,
        /// Defaults to 64. NOTE(review): units presumably bytes — confirm.
        pub vector_alignment: usize,
    }

    impl Default for StorageConfig {
        fn default() -> Self {
            let data_dir = String::from("./velesdb_data");
            let storage_mode = String::from("mmap");
            Self {
                data_dir,
                storage_mode,
                mmap_cache_mb: 1_024,
                vector_alignment: 64,
            }
        }
    }

    /// `[server]` section.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(default)]
    pub struct ServerConfig {
        /// Bind address; defaults to loopback `127.0.0.1`.
        pub host: String,
        /// Defaults to 8080; validation requires `>= 1024`.
        pub port: u16,
        /// Defaults to 0. NOTE(review): presumably "pick automatically" — confirm.
        pub workers: usize,
        /// Defaults to 100 MiB (104_857_600).
        pub max_body_size: usize,
        /// CORS is disabled by default.
        pub cors_enabled: bool,
        /// Defaults to `["*"]`; only relevant when `cors_enabled` is true.
        pub cors_origins: Vec<String>,
    }

    impl Default for ServerConfig {
        fn default() -> Self {
            const MIB: usize = 1024 * 1024;
            Self {
                host: String::from("127.0.0.1"),
                port: 8080,
                workers: 0,
                max_body_size: 100 * MIB,
                cors_enabled: false,
                cors_origins: vec![String::from("*")],
            }
        }
    }

    /// `[logging]` section.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    #[serde(default)]
    pub struct LoggingConfig {
        /// Defaults to `"info"`; validation accepts error/warn/info/debug/trace.
        pub level: String,
        /// Defaults to `"text"`.
        pub format: String,
        /// Defaults to empty. NOTE(review): presumably "no log file" — confirm.
        pub file: String,
    }

    impl Default for LoggingConfig {
        fn default() -> Self {
            Self {
                level: String::from("info"),
                format: String::from("text"),
                file: String::default(),
            }
        }
    }
}
pub use server::{LoggingConfig, ServerConfig, StorageConfig};
/// `[limits]` section: capacity caps.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(default)]
pub struct LimitsConfig {
    /// Defaults to 4096; `VelesConfig::validate` requires `[1, 65536]`.
    pub max_dimensions: usize,
    /// Defaults to 100 million.
    pub max_vectors_per_collection: usize,
    /// Defaults to 1000.
    pub max_collections: usize,
    /// Defaults to 1 MiB (1_048_576). NOTE(review): presumably bytes.
    pub max_payload_size: usize,
    /// Defaults to 500_000.
    pub max_perfect_mode_vectors: usize,
}

impl Default for LimitsConfig {
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        Self {
            max_dimensions: 4_096,
            max_vectors_per_collection: 100_000_000,
            max_collections: 1_000,
            max_payload_size: MIB,
            max_perfect_mode_vectors: 500_000,
        }
    }
}
/// `[quantization]` section.
///
/// `Serialize` is derived and always writes `mode`. `Deserialize` is written by hand
/// so it can also accept a legacy `default_type` string.
#[derive(Clone, Debug, Serialize)]
pub struct QuantizationConfig {
    /// Active quantization scheme; defaults to `QuantizationType::None`.
    pub mode: QuantizationType,
    /// Defaults to `true`.
    pub rerank_enabled: bool,
    /// Defaults to 2.
    pub rerank_multiplier: usize,
    /// Enables `should_quantize`; defaults to `true`.
    pub auto_quantization: bool,
    /// Minimum vector count for auto-quantization; defaults to 10_000.
    pub auto_quantization_threshold: usize,
}

impl Default for QuantizationConfig {
    fn default() -> Self {
        Self {
            mode: QuantizationType::default(),
            rerank_enabled: true,
            rerank_multiplier: 2,
            auto_quantization: true,
            auto_quantization_threshold: 10_000,
        }
    }
}
impl QuantizationConfig {
    /// Borrows the configured quantization scheme.
    #[must_use]
    pub const fn mode(&self) -> &QuantizationType {
        let Self { mode, .. } = self;
        mode
    }

    /// Returns whether `vector_count` vectors warrant automatic quantization.
    ///
    /// Always `false` when `auto_quantization` is off. Otherwise `true` once
    /// `vector_count` reaches `auto_quantization_threshold` (inclusive).
    #[must_use]
    pub fn should_quantize(&self, vector_count: usize) -> bool {
        if !self.auto_quantization {
            return false;
        }
        vector_count >= self.auto_quantization_threshold
    }
}
impl<'de> Deserialize<'de> for QuantizationConfig {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
struct RawQuantizationConfig {
#[serde(default)]
mode: Option<QuantizationType>,
#[serde(default)]
default_type: Option<String>,
#[serde(default = "default_rerank_enabled")]
rerank_enabled: bool,
#[serde(default = "default_rerank_multiplier")]
rerank_multiplier: usize,
#[serde(default = "default_auto_quantization")]
auto_quantization: bool,
#[serde(default = "default_auto_quantization_threshold")]
auto_quantization_threshold: usize,
}
fn default_rerank_enabled() -> bool {
true
}
fn default_rerank_multiplier() -> usize {
2
}
fn default_auto_quantization() -> bool {
true
}
fn default_auto_quantization_threshold() -> usize {
10_000
}
let raw = RawQuantizationConfig::deserialize(deserializer)?;
let mode = if let Some(m) = raw.mode {
m
} else if let Some(ref s) = raw.default_type {
match s.as_str() {
"none" | "" => QuantizationType::None,
"sq8" => QuantizationType::SQ8,
"binary" => QuantizationType::Binary,
other => {
return Err(serde::de::Error::custom(format!(
"unknown quantization type: '{other}'"
)));
}
}
} else {
QuantizationType::None
};
Ok(Self {
mode,
rerank_enabled: raw.rerank_enabled,
rerank_multiplier: raw.rerank_multiplier,
auto_quantization: raw.auto_quantization,
auto_quantization_threshold: raw.auto_quantization_threshold,
})
}
}
/// Serde default for `WalBatchConfig::commit_delay_us`: 100 (µs, per the `_us` suffix).
const fn default_commit_delay_us() -> u64 {
    100
}

/// Serde default for `WalBatchConfig::max_batch_size`: 128.
const fn default_max_batch_size() -> usize {
    128
}

/// `[wal_batch]` section. Batching is disabled by default.
///
/// NOTE(review): presumably controls grouping of WAL commits — confirm against the
/// WAL writer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalBatchConfig {
    /// Defaults to `false`.
    #[serde(default)]
    pub enabled: bool,
    /// Defaults to `default_commit_delay_us()`.
    #[serde(default = "default_commit_delay_us")]
    pub commit_delay_us: u64,
    /// Defaults to `default_max_batch_size()`.
    #[serde(default = "default_max_batch_size")]
    pub max_batch_size: usize,
}

impl Default for WalBatchConfig {
    // Reuse the serde default functions so a missing field and a missing section
    // always agree, instead of repeating the literals.
    fn default() -> Self {
        Self {
            enabled: false,
            commit_delay_us: default_commit_delay_us(),
            max_batch_size: default_max_batch_size(),
        }
    }
}
/// Top-level VelesDB configuration, with one field per TOML section.
///
/// Every section is optional; a missing section takes its type's `Default`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct VelesConfig {
    /// `[search]`
    pub search: SearchConfig,
    /// `[hnsw]`
    pub hnsw: HnswConfig,
    /// `[storage]`
    pub storage: StorageConfig,
    /// `[limits]`
    pub limits: LimitsConfig,
    /// `[server]`
    pub server: ServerConfig,
    /// `[logging]`
    pub logging: LoggingConfig,
    /// `[quantization]`
    pub quantization: QuantizationConfig,
    /// `[wal_batch]`
    pub wal_batch: WalBatchConfig,
}
impl VelesConfig {
    /// Loads `velesdb.toml`, layered with built-in defaults and `VELESDB_*`
    /// environment variables.
    ///
    /// Equivalent to `VelesConfig::load_from_path("velesdb.toml")`. The result is not
    /// validated; call `validate` afterwards.
    ///
    /// # Errors
    ///
    /// Returns `ConfigError::ParseError` if the merged sources cannot be extracted
    /// into a `VelesConfig`.
    pub fn load() -> Result<Self, ConfigError> {
        Self::load_from_path("velesdb.toml")
    }

    /// Loads configuration from `path`, layered from lowest to highest precedence:
    /// 1. built-in defaults (`VelesConfig::default()`);
    /// 2. the TOML file at `path`;
    /// 3. environment variables prefixed with `VELESDB_`.
    ///
    /// For environment variables, the prefix is stripped and the rest is lowercased.
    /// It is then split on `_` into nested keys, so `VELESDB_SERVER_PORT=9000` sets
    /// `server.port`. Because `_` is the nesting separator, fields whose own names
    /// contain an underscore (e.g. `search.max_results`) cannot be set this way.
    ///
    /// The result is not validated; call `validate` afterwards.
    ///
    /// # Errors
    ///
    /// Returns `ConfigError::ParseError` if the file is malformed or the merged
    /// sources cannot be extracted into a `VelesConfig`.
    pub fn load_from_path<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
        let figment = Figment::new()
            .merge(Serialized::defaults(Self::default()))
            .merge(Toml::file(path.as_ref()))
            // Keep figment's default key lowercasing. With `.lowercase(false)`, the
            // conventional spelling `VELESDB_SERVER_PORT` produced the key
            // `SERVER.PORT`, which never matches the lower-case field names, so
            // environment overrides were silently ignored.
            .merge(Env::prefixed("VELESDB_").split("_"));
        figment
            .extract()
            .map_err(|e| ConfigError::ParseError(e.to_string()))
    }

    /// Parses configuration from a TOML string layered over built-in defaults.
    ///
    /// Environment variables are NOT consulted. The result is not validated.
    ///
    /// # Errors
    ///
    /// Returns `ConfigError::ParseError` if `toml_str` is malformed or does not
    /// match the configuration schema.
    pub fn from_toml(toml_str: &str) -> Result<Self, ConfigError> {
        let figment = Figment::new()
            .merge(Serialized::defaults(Self::default()))
            .merge(Toml::string(toml_str));
        figment
            .extract()
            .map_err(|e| ConfigError::ParseError(e.to_string()))
    }

    /// Checks value ranges and enumerated string settings.
    ///
    /// Stops at and reports the first violation found. The `quantization`,
    /// `wal_batch`, `server.workers` and CORS settings are not checked here.
    ///
    /// # Errors
    ///
    /// Returns `ConfigError::InvalidValue` naming the offending key when:
    /// - `search.ef_search` (if set) is outside `[16, 4096]`;
    /// - `search.max_results` is outside `[1, 10000]`;
    /// - `hnsw.m` (if set) is outside `[4, 128]`;
    /// - `hnsw.ef_construction` (if set) is outside `[100, 2000]`;
    /// - `limits.max_dimensions` is outside `[1, 65536]`;
    /// - `server.port` is below 1024;
    /// - `storage.storage_mode` is not `mmap`/`memory`;
    /// - `logging.level` is not one of error/warn/info/debug/trace.
    pub fn validate(&self) -> Result<(), ConfigError> {
        // --- search ---
        if let Some(ef) = self.search.ef_search {
            if !(16..=4096).contains(&ef) {
                return Err(ConfigError::InvalidValue {
                    key: "search.ef_search".to_string(),
                    message: format!("value {ef} is out of range [16, 4096]"),
                });
            }
        }
        if self.search.max_results == 0 || self.search.max_results > 10000 {
            return Err(ConfigError::InvalidValue {
                key: "search.max_results".to_string(),
                message: format!(
                    "value {} is out of range [1, 10000]",
                    self.search.max_results
                ),
            });
        }

        // --- hnsw ---
        if let Some(m) = self.hnsw.m {
            if !(4..=128).contains(&m) {
                return Err(ConfigError::InvalidValue {
                    key: "hnsw.m".to_string(),
                    message: format!("value {m} is out of range [4, 128]"),
                });
            }
        }
        if let Some(ef) = self.hnsw.ef_construction {
            if !(100..=2000).contains(&ef) {
                return Err(ConfigError::InvalidValue {
                    key: "hnsw.ef_construction".to_string(),
                    message: format!("value {ef} is out of range [100, 2000]"),
                });
            }
        }

        // --- limits ---
        if self.limits.max_dimensions == 0 || self.limits.max_dimensions > 65536 {
            return Err(ConfigError::InvalidValue {
                key: "limits.max_dimensions".to_string(),
                message: format!(
                    "value {} is out of range [1, 65536]",
                    self.limits.max_dimensions
                ),
            });
        }

        // --- server ---
        if self.server.port < 1024 {
            return Err(ConfigError::InvalidValue {
                key: "server.port".to_string(),
                message: format!("value {} must be >= 1024", self.server.port),
            });
        }

        // --- storage ---
        let valid_modes = ["mmap", "memory"];
        if !valid_modes.contains(&self.storage.storage_mode.as_str()) {
            return Err(ConfigError::InvalidValue {
                key: "storage.storage_mode".to_string(),
                message: format!(
                    "value '{}' is invalid, expected one of: {:?}",
                    self.storage.storage_mode, valid_modes
                ),
            });
        }

        // --- logging ---
        let valid_levels = ["error", "warn", "info", "debug", "trace"];
        if !valid_levels.contains(&self.logging.level.as_str()) {
            return Err(ConfigError::InvalidValue {
                key: "logging.level".to_string(),
                message: format!(
                    "value '{}' is invalid, expected one of: {:?}",
                    self.logging.level, valid_levels
                ),
            });
        }

        Ok(())
    }

    /// Returns the `ef_search` to use for queries.
    ///
    /// This is the explicit `search.ef_search` if set, otherwise the preset of
    /// `search.default_mode`. That preset is `usize::MAX` for `SearchMode::Perfect`.
    #[must_use]
    pub fn effective_ef_search(&self) -> usize {
        self.search
            .ef_search
            .unwrap_or_else(|| self.search.default_mode.ef_search())
    }

    /// Serializes the full configuration as pretty-printed TOML.
    ///
    /// # Errors
    ///
    /// Returns `ConfigError::ParseError` (the existing catch-all variant) if TOML
    /// serialization fails.
    pub fn to_toml(&self) -> Result<String, ConfigError> {
        toml::to_string_pretty(self).map_err(|e| ConfigError::ParseError(e.to_string()))
    }
}