use lancedb::Connection;
use std::collections::HashMap;
use std::error::Error;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
use tokio::sync::Mutex;
/// Boxed, thread-safe error type returned by the fallible functions in this file.
pub type BoxError = Box<dyn Error + Send + Sync + 'static>;
/// Settings that decide where databases live and how connections to them are opened.
#[derive(Debug, Clone)]
pub struct DatabaseRuntimeConfig {
/// Path or URI of the database used when no database name is given.
pub default_db_path: String,
/// Root path or URI under which named databases are placed; resolving a
/// named database fails when this is `None`.
pub db_root: Option<String>,
/// Read consistency interval in milliseconds; when set, it is converted to a
/// `Duration` and handed to the connection builder.
pub read_consistency_interval_ms: Option<u64>,
}
impl DatabaseRuntimeConfig {
    /// Reports whether `default_db_path` is a plain path, i.e. it does not
    /// contain a `://` scheme separator.
    pub fn is_local_default_db_path(&self) -> bool {
        let is_remote = looks_like_uri(self.default_db_path.as_str());
        !is_remote
    }

    /// Returns the configured read consistency interval as a `Duration`, or
    /// `None` when no interval is configured.
    pub fn read_consistency_interval(&self) -> Option<std::time::Duration> {
        let millis = self.read_consistency_interval_ms?;
        Some(std::time::Duration::from_millis(millis))
    }

    /// Returns a warning message when the default database lives on a plain
    /// `s3://` URI, and `None` for every other location.
    pub fn concurrent_write_warning(&self) -> Option<&'static str> {
        is_plain_s3_uri(&self.default_db_path).then_some(
            "plain s3:// storage is not safe for concurrent LanceDB writers; use s3+ddb:// for multi-writer deployments or ensure this service is the only writer for each table",
        )
    }

    /// Resolves the storage location for `database_name`.
    ///
    /// A missing, empty, or whitespace-only name resolves to
    /// `default_db_path`. Any other name is trimmed, validated as a single
    /// path segment, and appended to `db_root`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error when the name fails validation or when
    /// a name is given but `db_root` is not configured.
    pub fn database_path_for_name(&self, database_name: Option<&str>) -> Result<String, BoxError> {
        // Blank names are treated exactly like "no name".
        let database_name = match database_name.map(str::trim) {
            Some(name) if !name.is_empty() => name,
            _ => return Ok(self.default_db_path.clone()),
        };

        // Validate before looking at the root so bad names are always reported.
        validate_database_name(database_name)?;

        match self.db_root.as_deref() {
            Some(root) => Ok(join_database_root(root, database_name)),
            None => Err(invalid_input(format!(
                "database `{database_name}` requires a database root to be configured"
            ))),
        }
    }
}
/// Opens database connections on demand and caches them by database key.
///
/// Both fields sit behind `Arc`s, so clones of the manager share the same
/// configuration and the same connection cache.
#[derive(Clone)]
pub struct DatabaseManager {
/// Runtime configuration shared by every clone of this manager.
config: Arc<DatabaseRuntimeConfig>,
/// Cached connections keyed by the normalized database name (`"default"`
/// when no name is given), guarded by an async mutex.
connections: Arc<Mutex<HashMap<String, Arc<Connection>>>>,
}
impl DatabaseManager {
pub fn new(config: DatabaseRuntimeConfig) -> Self {
Self {
config: Arc::new(config),
connections: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn config(&self) -> &DatabaseRuntimeConfig {
self.config.as_ref()
}
pub async fn open_default(&self) -> Result<Arc<Connection>, BoxError> {
self.open_named(None).await
}
pub async fn open_named(
&self,
database_name: Option<&str>,
) -> Result<Arc<Connection>, BoxError> {
let database_key = normalize_database_key(database_name);
let target_path = self.config.database_path_for_name(database_name)?;
{
let guard = self.connections.lock().await;
if let Some(existing) = guard.get(&database_key) {
return Ok(Arc::clone(existing));
}
}
if !target_path.contains("://") {
tokio::fs::create_dir_all(&target_path).await?;
}
let mut builder = lancedb::connect(&target_path);
if let Some(interval) = self.config.read_consistency_interval() {
builder = builder.read_consistency_interval(interval);
}
let connection = Arc::new(builder.execute().await?);
let mut guard = self.connections.lock().await;
let entry = guard
.entry(database_key)
.or_insert_with(|| Arc::clone(&connection));
Ok(Arc::clone(entry))
}
pub fn database_path_for_name(&self, database_name: Option<&str>) -> Result<String, BoxError> {
self.config.database_path_for_name(database_name)
}
pub async fn cached_database_keys(&self) -> Vec<String> {
let guard = self.connections.lock().await;
let mut keys = guard.keys().cloned().collect::<Vec<_>>();
keys.sort();
keys
}
}
/// Maps an optional database name to its cache key.
///
/// The name is trimmed; a missing or blank name yields `"default"`.
fn normalize_database_key(database_name: Option<&str>) -> String {
    match database_name.map(str::trim) {
        Some(name) if !name.is_empty() => name.to_owned(),
        _ => String::from("default"),
    }
}
/// Reports whether `value` contains a `://` scheme separator anywhere.
fn looks_like_uri(value: &str) -> bool {
    value.split_once("://").is_some()
}
/// Reports whether `value` starts with the `s3://` scheme, ignoring ASCII case.
fn is_plain_s3_uri(value: &str) -> bool {
    const SCHEME: &[u8] = b"s3://";
    // Compare only the leading bytes; inputs shorter than the scheme never match.
    matches!(
        value.as_bytes().get(..SCHEME.len()),
        Some(prefix) if prefix.eq_ignore_ascii_case(SCHEME)
    )
}
/// Appends `database_name` to `db_root`.
///
/// Local roots are joined with the platform path rules. URI roots have their
/// trailing slashes removed and the name appended after a single `/`.
fn join_database_root(db_root: &str, database_name: &str) -> String {
    if !looks_like_uri(db_root) {
        let local = Path::new(db_root).join(database_name);
        return local.to_string_lossy().into_owned();
    }

    let mut joined = db_root.trim_end_matches('/').to_owned();
    joined.push('/');
    joined.push_str(database_name);
    joined
}
/// Checks that `database_name` is usable as a single directory name.
///
/// # Errors
///
/// Returns an invalid-input error when the name contains `/` or `\`, or when
/// it does not parse as exactly one normal path component (for example `..`).
fn validate_database_name(database_name: &str) -> Result<(), BoxError> {
    // Both separators are rejected regardless of the host platform.
    let has_separator = database_name.chars().any(|ch| matches!(ch, '/' | '\\'));
    if has_separator {
        return Err(invalid_input(format!(
            "database `{database_name}` must not contain path separators"
        )));
    }

    let mut parts = Path::new(database_name).components();
    let first = parts.next();
    let second = parts.next();
    // Exactly one component, and it must be a normal one.
    if matches!((first, second), (Some(Component::Normal(_)), None)) {
        Ok(())
    } else {
        Err(invalid_input(format!(
            "database `{database_name}` must be a single path segment"
        )))
    }
}
/// Wraps `message` in a boxed `std::io::Error` of kind `InvalidInput`.
fn invalid_input(message: impl Into<String>) -> BoxError {
    let text: String = message.into();
    std::io::Error::new(std::io::ErrorKind::InvalidInput, text).into()
}
#[cfg(test)]
mod tests {
    use super::{DatabaseManager, DatabaseRuntimeConfig};

    /// Local-filesystem configuration shared by the tests below.
    fn sample_config() -> DatabaseRuntimeConfig {
        DatabaseRuntimeConfig {
            default_db_path: "/srv/vldb/default".to_string(),
            db_root: Some("/srv/vldb/databases".to_string()),
            read_consistency_interval_ms: Some(0),
        }
    }

    #[test]
    fn default_database_path_uses_default_db_path() {
        let manager = DatabaseManager::new(sample_config());
        assert_eq!(
            manager
                .database_path_for_name(None)
                .expect("default database path should resolve"),
            "/srv/vldb/default"
        );
    }

    #[test]
    fn blank_database_name_falls_back_to_default_db_path() {
        let manager = DatabaseManager::new(sample_config());
        for blank in ["", "   ", "\t"] {
            assert_eq!(
                manager
                    .database_path_for_name(Some(blank))
                    .expect("blank database name should resolve to the default"),
                "/srv/vldb/default"
            );
        }
    }

    #[test]
    fn named_database_path_uses_database_root() {
        let manager = DatabaseManager::new(sample_config());
        let expected = std::path::PathBuf::from("/srv/vldb/databases")
            .join("memory")
            .to_string_lossy()
            .to_string();
        assert_eq!(
            manager
                .database_path_for_name(Some("memory"))
                .expect("named database path should resolve"),
            expected
        );
        // Surrounding whitespace is trimmed before the name is used.
        assert_eq!(
            manager
                .database_path_for_name(Some("  memory  "))
                .expect("padded database name should resolve"),
            expected
        );
    }

    #[test]
    fn named_database_path_joins_uri_root_with_single_slash() {
        let manager = DatabaseManager::new(DatabaseRuntimeConfig {
            db_root: Some("s3://bucket/prefix/".to_string()),
            ..sample_config()
        });
        assert_eq!(
            manager
                .database_path_for_name(Some("memory"))
                .expect("named database path should resolve under a URI root"),
            "s3://bucket/prefix/memory"
        );
    }

    #[test]
    fn named_database_requires_database_root() {
        let manager = DatabaseManager::new(DatabaseRuntimeConfig {
            db_root: None,
            ..sample_config()
        });
        let error = manager
            .database_path_for_name(Some("memory"))
            .expect_err("named database without a root should be rejected");
        assert!(error.to_string().contains("requires a database root"));
    }

    #[test]
    fn named_database_path_rejects_traversal_segments() {
        let manager = DatabaseManager::new(sample_config());
        let error = manager
            .database_path_for_name(Some("../escape"))
            .expect_err("traversal name should be rejected");
        assert!(error.to_string().contains("path separators"));
    }

    #[test]
    fn named_database_path_rejects_bare_dot_segments() {
        // These names contain no separator, so they reach the component check.
        let manager = DatabaseManager::new(sample_config());
        for name in ["..", "."] {
            let error = manager
                .database_path_for_name(Some(name))
                .expect_err("dot segment should be rejected");
            assert!(error.to_string().contains("single path segment"));
        }
    }

    #[test]
    fn named_database_path_rejects_path_separators() {
        let manager = DatabaseManager::new(sample_config());
        for name in ["nested/name", "nested\\name"] {
            let error = manager
                .database_path_for_name(Some(name))
                .expect_err("nested path should be rejected");
            assert!(error.to_string().contains("path separators"));
        }
    }
}