use super::*;
use futures::stream::StreamExt;
use sqlx::any::install_default_drivers;
use sqlx::Acquire;
use std::sync::{Arc, Mutex};
use tokio::runtime::Runtime;
use url::Url;
use uuid::Uuid;
/// URL string of the external data source the database reads records from.
type DatabaseURL = String;
/// Name under which an index is registered in the database state.
type IndexName = String;
/// Path of the file an index is persisted to.
type IndexFile = PathBuf;
/// Shared, lock-protected handle to a boxed index implementation.
type Index = Arc<Mutex<Box<dyn VectorIndex>>>;
/// Lock-protected map of in-memory index handles, keyed by index name.
type IndicesPool = Mutex<HashMap<IndexName, Index>>;
/// Handle to a database directory: its state plus a pool of in-memory indices.
pub struct Database {
/// Root directory; holds the `odbstate` file and the `indices` and `tmp`
/// subdirectories.
root: PathBuf,
/// In-memory copy of the database state (data source and index references).
state: Mutex<DatabaseState>,
/// Index handles currently kept in memory, keyed by index name.
pool: IndicesPool,
}
impl Database {
pub fn open(
root: impl Into<PathBuf>,
source_url: Option<impl Into<DatabaseURL>>,
) -> Result<Database, Error> {
let root_dir: PathBuf = root.into();
let indices_dir = root_dir.join("indices");
if !indices_dir.try_exists()? {
fs::create_dir_all(&indices_dir)?;
}
let tmp_dir = root_dir.join("tmp");
if !tmp_dir.try_exists()? {
fs::create_dir_all(&tmp_dir)?;
}
let state_file = root_dir.join("odbstate");
let state = if state_file.try_exists()? {
let mut state = DatabaseState::restore(&state_file)?;
if let Some(source) = source_url {
state.with_source(source)?;
}
state
} else {
let source = source_url.ok_or_else(|| {
let code = ErrorCode::MissingSource;
let message = "Data source is required for a new database.";
Error::new(code, message)
})?;
let indices = HashMap::new();
let source = source.into();
DatabaseState::validate_source(&source)?;
let state = DatabaseState { source, indices };
file::write_binary_file(tmp_dir, state_file, &state)?;
state
};
install_default_drivers();
let state = Mutex::new(state);
let pool: IndicesPool = Mutex::new(HashMap::new());
Ok(Self { root: root_dir, state, pool })
}
pub async fn async_create_index(
&self,
name: impl Into<IndexName>,
algorithm: IndexAlgorithm,
config: SourceConfig,
) -> Result<(), Error> {
let index_name: IndexName = name.into();
if self.get_index_ref(&index_name).is_some() {
let code = ErrorCode::RequestError;
let message = format!("Index already exists: {index_name}.");
return Err(Error::new(code, message));
}
let query = config.to_query();
let mut conn = self.state()?.async_connect().await?;
let mut stream = sqlx::query(&query).fetch(conn.acquire().await?);
let mut records = HashMap::new();
while let Some(row) = stream.next().await {
let row = row?;
let (id, record) = config.to_record(&row)?;
records.insert(id, record);
}
let index_file = {
let uuid = Uuid::new_v4().to_string();
self.indices_dir().join(uuid)
};
let mut index = algorithm.initialize()?;
index.build(records)?;
let tmp_dir = self.tmp_dir();
algorithm.persist_index(tmp_dir, &index_file, index.as_ref())?;
{
let mut pool = self.pool.lock()?;
pool.insert(index_name.clone(), Arc::new(Mutex::new(index)));
}
{
let mut state = self.state.lock()?;
let index_ref = IndexRef { algorithm, config, file: index_file };
state.indices.insert(index_name, index_ref);
}
self.persist_state()?;
Ok(())
}
/// Blocking wrapper around `async_create_index`.
///
/// Creates a new Tokio runtime for the call and blocks on the async
/// implementation.
pub fn create_index(
    &self,
    name: impl Into<IndexName>,
    algorithm: IndexAlgorithm,
    config: SourceConfig,
) -> Result<(), Error> {
    let runtime = Runtime::new()?;
    let task = self.async_create_index(name, algorithm, config);
    runtime.block_on(task)
}
/// Returns a copy of the reference registered for index `name`.
///
/// Yields `None` when the name is unknown or the state lock cannot be
/// acquired.
pub fn get_index_ref(&self, name: impl AsRef<str>) -> Option<IndexRef> {
    let guard = self.state.lock().ok()?;
    guard.indices.get(name.as_ref()).cloned()
}
/// Returns the handle of index `name`, loading it from its file on a miss.
///
/// The index must be registered in the state. A handle already in the pool
/// is returned as is; otherwise the index is loaded with
/// `algorithm.load_index` and cached in the pool. Any failure (unknown
/// name, lock error, load error) yields `None`.
pub fn get_index(&self, name: impl AsRef<str>) -> Option<Index> {
    let name = name.as_ref();
    let IndexRef { algorithm, file, .. } = self.get_index_ref(name)?;

    let mut pool = self.pool.lock().ok()?;
    if let Some(cached) = pool.get(name) {
        return Some(Arc::clone(cached));
    }

    let loaded: Index = Arc::new(Mutex::new(algorithm.load_index(file).ok()?));
    pool.insert(name.to_owned(), Arc::clone(&loaded));
    Some(loaded)
}
pub fn try_get_index(&self, name: impl AsRef<str>) -> Result<Index, Error> {
let name = name.as_ref();
self.get_index(name).ok_or_else(|| {
let code = ErrorCode::NotFound;
let message = format!("Index not found in database: {name}.");
Error::new(code, message)
})
}
/// Pulls new records from the data source into the index `name`.
///
/// Reads the index's `last_inserted` checkpoint (falling back to the
/// default record ID when unset), runs the query built by
/// `SourceConfig::to_query_after` for that checkpoint, and inserts the
/// fetched records through `insert_into_index`.
///
/// # Errors
///
/// Returns `ErrorCode::NotFound` when the index is not registered or its
/// handle cannot be obtained (for example when loading its file fails);
/// propagates lock, connection, query, and insertion errors.
pub async fn async_refresh_index(
    &self,
    name: impl AsRef<str>,
) -> Result<(), Error> {
    let name = name.as_ref();
    let IndexRef { config, .. } = self.get_index_ref(name).ok_or_else(|| {
        let code = ErrorCode::NotFound;
        let message = format!("Index not found: {name}.");
        Error::new(code, message)
    })?;

    let query = {
        // A registered index can still fail to load from disk, so surface
        // that as an error instead of unwrapping.
        let index: Index = self.try_get_index(name)?;
        // The guard lives only inside this block and is released before
        // the first `.await` below.
        let index = index.lock()?;
        let meta = index.metadata();
        let checkpoint = meta.last_inserted.unwrap_or_default();
        config.to_query_after(&checkpoint)
    };

    let mut conn = self.state()?.async_connect().await?;
    let mut stream = sqlx::query(&query).fetch(conn.acquire().await?);
    let mut records = HashMap::new();
    while let Some(row) = stream.next().await {
        let row = row?;
        let (id, record) = config.to_record(&row)?;
        records.insert(id, record);
    }

    self.insert_into_index(name, records)
}
/// Blocking wrapper around `async_refresh_index`.
///
/// Creates a new Tokio runtime for the call and blocks on the async
/// implementation.
pub fn refresh_index(&self, name: impl AsRef<str>) -> Result<(), Error> {
    let runtime = Runtime::new()?;
    let task = self.async_refresh_index(name);
    runtime.block_on(task)
}
/// Searches index `name` with `query`, passing `k` and `filters` through.
///
/// # Errors
///
/// Returns `ErrorCode::NotFound` when the index cannot be obtained, and
/// propagates lock and search errors.
pub fn search_index(
    &self,
    name: impl AsRef<str>,
    query: impl Into<Vector>,
    k: usize,
    filters: impl Into<Filters>,
) -> Result<Vec<SearchResult>, Error> {
    let handle: Index = self.try_get_index(name)?;
    let guard = handle.lock()?;
    guard.search(query.into(), k, filters.into())
}
/// Inserts `records` into index `name` and persists the index to its file.
///
/// The index lock is held for both the insert and the persist.
pub fn insert_into_index(
    &self,
    name: impl AsRef<str>,
    records: HashMap<RecordID, Record>,
) -> Result<(), Error> {
    let handle: Index = self.try_get_index(name.as_ref())?;
    let mut guard = handle.lock()?;
    guard.insert(records)?;
    self.persist_existing_index(name, guard.as_ref())
}
/// Applies `records` as updates to index `name` and persists the index.
///
/// The index lock is held for both the update and the persist.
pub fn update_index(
    &self,
    name: impl AsRef<str>,
    records: HashMap<RecordID, Record>,
) -> Result<(), Error> {
    let handle: Index = self.try_get_index(name.as_ref())?;
    let mut guard = handle.lock()?;
    guard.update(records)?;
    self.persist_existing_index(name, guard.as_ref())
}
/// Deletes the records with the given `ids` from index `name` and persists
/// the index.
///
/// The index lock is held for both the delete and the persist.
pub fn delete_from_index(
    &self,
    name: impl AsRef<str>,
    ids: Vec<RecordID>,
) -> Result<(), Error> {
    let handle: Index = self.try_get_index(name.as_ref())?;
    let mut guard = handle.lock()?;
    guard.delete(ids)?;
    self.persist_existing_index(name, guard.as_ref())
}
/// Removes index `name` from the database and deletes its file.
///
/// The index is dropped from the in-memory state and the pool, the state
/// is persisted, and only then is the index file unlinked.
///
/// # Errors
///
/// Returns `ErrorCode::NotFound` when no index named `name` is registered;
/// propagates lock, persistence, and file removal errors.
pub fn delete_index(&self, name: impl AsRef<str>) -> Result<(), Error> {
    let name = name.as_ref();
    let index_ref = {
        let mut state = self.state.lock()?;
        state.indices.remove(name).ok_or_else(|| {
            let code = ErrorCode::NotFound;
            let message = format!("Index doesn't exist: {name}.");
            Error::new(code, message)
        })?
    };

    self.release_indices(vec![name])?;

    // Persist before unlinking: the stored state must never reference a
    // file that is already gone. If the removal fails afterwards, the worst
    // outcome is an orphaned file, not an index that cannot be deleted.
    self.persist_state()?;
    fs::remove_file(index_ref.file())?;
    Ok(())
}
/// Loads the named indices into the in-memory pool.
///
/// All names are checked against the state first; nothing is loaded unless
/// every name is registered.
///
/// # Errors
///
/// Returns `ErrorCode::NotFound` when any name is not registered or when
/// an index handle cannot be obtained (for example when its file fails to
/// load); propagates lock errors.
pub fn load_indices(
    &self,
    names: Vec<impl AsRef<str>>,
) -> Result<(), Error> {
    {
        // Scoped so the state lock is released before the loads below,
        // which lock the state again through `get_index_ref`.
        let state = self.state.lock()?;
        let all_known = names
            .iter()
            .all(|name| state.indices.contains_key(name.as_ref()));
        if !all_known {
            let code = ErrorCode::NotFound;
            let message = "Some indices are not found in the database.";
            return Err(Error::new(code, message));
        }
    }

    for name in names {
        // Report load failures instead of silently discarding them.
        self.try_get_index(name)?;
    }

    Ok(())
}
/// Drops the named indices from the in-memory pool.
///
/// Names that are not in the pool are ignored.
pub fn release_indices(
    &self,
    names: Vec<impl AsRef<str>>,
) -> Result<(), Error> {
    let mut pool = self.pool.lock()?;
    names.iter().for_each(|name| {
        pool.remove(name.as_ref());
    });
    Ok(())
}
/// Returns a clone of the current database state.
pub fn state(&self) -> Result<DatabaseState, Error> {
    let guard = self.state.lock()?;
    let snapshot = guard.clone();
    Ok(snapshot)
}
/// Writes a snapshot of the current state to the state file, passing the
/// database's temporary directory to `file::write_binary_file`.
pub fn persist_state(&self) -> Result<(), Error> {
    let tmp_dir = self.tmp_dir();
    let target = self.state_file();
    let snapshot = self.state()?;
    file::write_binary_file(tmp_dir, target, &snapshot)
}
}
impl Database {
fn state_file(&self) -> PathBuf {
self.root.join("odbstate")
}
fn indices_dir(&self) -> PathBuf {
self.root.join("indices")
}
fn tmp_dir(&self) -> PathBuf {
self.root.join("tmp")
}
fn persist_existing_index(
&self,
name: impl AsRef<str>,
index: &dyn VectorIndex,
) -> Result<(), Error> {
let name = name.as_ref();
let IndexRef { algorithm, file, .. } =
self.get_index_ref(name).ok_or_else(|| {
let code = ErrorCode::NotFound;
let message = format!("Index might not exists: {name}.");
Error::new(code, message)
})?;
let tmp_dir = self.tmp_dir();
algorithm.persist_index(tmp_dir, file, index)
}
}
/// Serializable state of a database: its data source and its indices.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseState {
/// URL of the data source; checked by `validate_source` whenever it is set
/// through `Database::open` or `with_source`.
source: DatabaseURL,
/// References to the registered indices, keyed by index name.
indices: HashMap<IndexName, IndexRef>,
}
impl DatabaseState {
/// Reads a previously written state from the binary file at `path`.
pub fn restore(path: impl AsRef<Path>) -> Result<DatabaseState, Error> {
    let restored: DatabaseState = file::read_binary_file(path)?;
    Ok(restored)
}
/// Replaces the data source URL after validating it.
///
/// The stored source is left untouched when validation fails.
pub fn with_source(
    &mut self,
    source: impl Into<DatabaseURL>,
) -> Result<(), Error> {
    let candidate: DatabaseURL = source.into();
    Self::validate_source(&candidate)?;
    self.source = candidate;
    Ok(())
}
/// Opens a connection to the configured data source.
pub async fn async_connect(&self) -> Result<SourceConnection, Error> {
    let connection = SourceConnection::connect(&self.source).await?;
    Ok(connection)
}
/// Blocking wrapper around `async_connect`; creates a new Tokio runtime
/// for the call.
pub fn connect(&self) -> Result<SourceConnection, Error> {
    let runtime = Runtime::new()?;
    let task = self.async_connect();
    runtime.block_on(task)
}
/// Closes `conn`, consuming it.
pub async fn async_disconnect(conn: SourceConnection) -> Result<(), Error> {
    conn.close().await?;
    Ok(())
}
/// Blocking wrapper around `async_disconnect`; creates a new Tokio runtime
/// for the call.
pub fn disconnect(conn: SourceConnection) -> Result<(), Error> {
    let runtime = Runtime::new()?;
    let task = Self::async_disconnect(conn);
    runtime.block_on(task)
}
/// Returns the source type derived from the scheme of the source URL.
///
/// # Panics
///
/// Panics when the stored source does not parse as a URL.
pub fn source_type(&self) -> SourceType {
    let parsed: Url = self.source.parse().unwrap();
    let kind: SourceType = parsed.scheme().into();
    kind
}
pub fn validate_source(url: impl Into<DatabaseURL>) -> Result<(), Error> {
let url = url.into();
let url = url.parse::<Url>().map_err(|_| {
let code = ErrorCode::InvalidSource;
let message = "Invalid database source URL.";
Error::new(code, message)
})?;
let valid_schemes = ["sqlite", "mysql", "postgresql"];
if !valid_schemes.contains(&url.scheme()) {
let code = ErrorCode::InvalidSource;
let message = format!(
"Unsupported database scheme. Choose between: {}.",
valid_schemes.join(", ")
);
return Err(Error::new(code, message));
}
Ok(())
}
}
/// Serializable record of a registered index: how its records are sourced,
/// which algorithm it uses, and where it is persisted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexRef {
/// Source configuration used to build queries and convert rows to records.
config: SourceConfig,
/// Algorithm used to initialize, load, and persist the index.
algorithm: IndexAlgorithm,
/// File the index is persisted to.
file: IndexFile,
}
impl IndexRef {
/// Returns the source configuration of the index.
pub fn config(&self) -> &SourceConfig {
&self.config
}
/// Returns the algorithm of the index.
pub fn algorithm(&self) -> &IndexAlgorithm {
&self.algorithm
}
/// Returns the path of the file the index is persisted to.
pub fn file(&self) -> &IndexFile {
&self.file
}
}
#[cfg(test)]
mod tests {
use super::*;
use sqlx::{Executor, Row};
use std::env;
use std::sync::MutexGuard;
/// Name of the table created in the test data source.
const TABLE: &str = "embeddings";
/// Name of the index every test database is created with.
const TEST_INDEX: &str = "test_index";
/// The full test fixture (database, source, and index) can be created.
#[test]
fn test_database_open() {
    let opened = create_test_database();
    assert!(opened.is_ok());
}
/// A freshly created index holds all 100 source rows and records the last
/// inserted ID.
#[test]
fn test_database_create_index() -> Result<(), Error> {
    let db = create_test_database()?;
    let handle: Index = db.try_get_index(TEST_INDEX)?;
    let guard = handle.lock()?;

    assert_eq!(guard.len(), 100);
    assert_eq!(guard.metadata().last_inserted, Some(RecordID(100)));
    Ok(())
}
/// Refreshing picks up the 10 rows added to the source after the index was
/// built.
#[test]
fn test_database_refresh_index() -> Result<(), Error> {
    let db = create_test_database()?;

    let insert_sql = generate_insert_query(100, 10);
    let runtime = Runtime::new()?;
    runtime.block_on(db.async_execute_sql(insert_sql))?;

    db.refresh_index(TEST_INDEX).unwrap();

    let handle: Index = db.try_get_index(TEST_INDEX)?;
    let guard = handle.lock()?;
    assert_eq!(guard.len(), 110);
    assert_eq!(guard.metadata().last_inserted, Some(RecordID(110)));
    Ok(())
}
/// An unfiltered search for the zero vector returns five results, with
/// record 1 first at distance 0.
#[test]
fn test_database_search_index_basic() {
    let db = create_test_database().unwrap();
    let query = vec![0.0; 128];
    let hits = db.search_index(TEST_INDEX, query, 5, Filters::NONE).unwrap();

    assert_eq!(hits.len(), 5);
    let nearest = &hits[0];
    assert_eq!(nearest.id, RecordID(1));
    assert_eq!(nearest.distance, 0.0);
}
/// A search filtered on `data >= 1050` returns five results, with record 51
/// first.
#[test]
fn test_database_search_index_advanced() {
    let db = create_test_database().unwrap();
    let query = vec![0.0; 128];
    let hits = db.search_index(TEST_INDEX, query, 5, "data >= 1050").unwrap();

    assert_eq!(hits.len(), 5);
    assert_eq!(hits[0].id, RecordID(51));
}
/// Inserting one record directly grows the index from 100 to 101 entries.
// NOTE(review): the record uses the metadata key "number" while the test
// index is configured with the "data" column — confirm this is intended.
#[test]
fn test_database_insert_into_index() -> Result<(), Error> {
    let record = Record {
        vector: Vector::from(vec![100.0; 128]),
        data: HashMap::from([(
            "number".to_string(),
            Some(DataValue::Integer(1100)),
        )]),
    };
    let records = HashMap::from([(RecordID(101), record)]);

    let db = create_test_database()?;
    db.insert_into_index(TEST_INDEX, records)?;

    let handle: Index = db.try_get_index(TEST_INDEX)?;
    let guard = handle.lock()?;
    assert_eq!(guard.len(), 101);
    Ok(())
}
/// Deleting two records shrinks the index from 100 to 98 entries.
#[test]
fn test_database_delete_from_index() -> Result<(), Error> {
    let db = create_test_database()?;
    db.delete_from_index(TEST_INDEX, vec![RecordID(1), RecordID(2)])?;

    let handle: Index = db.try_get_index(TEST_INDEX)?;
    let guard = handle.lock()?;
    assert_eq!(guard.len(), 98);
    Ok(())
}
/// A deleted index is no longer listed in the database state.
#[test]
fn test_database_delete_index() {
    let db = create_test_database().unwrap();
    db.delete_index(TEST_INDEX).unwrap();

    let snapshot = db.state().unwrap();
    let still_registered = snapshot.indices.contains_key(TEST_INDEX);
    assert!(!still_registered);
}
/// Releasing removes the index from the pool; loading puts it back.
#[test]
fn test_database_indices_pool() -> Result<(), Error> {
    let db = create_test_database()?;

    // Each pool guard is a temporary that is dropped before the next call
    // locks the pool again.
    db.release_indices(vec![TEST_INDEX])?;
    assert!(!db.pool()?.contains_key(TEST_INDEX));

    db.load_indices(vec![TEST_INDEX])?;
    assert!(db.pool()?.contains_key(TEST_INDEX));

    Ok(())
}
/// Builds the fixture shared by the tests: a fresh database in `odb_test`
/// backed by a SQLite file in the temp directory, with the test table
/// populated and the test index created.
// NOTE(review): every test uses the same `odb_test` directory and SQLite
// file, so tests running in parallel may interfere — confirm how the suite
// is run.
fn create_test_database() -> Result<Database, Error> {
    let root = PathBuf::from("odb_test");
    if root.try_exists()? {
        fs::remove_dir_all(&root)?;
    }

    let sqlite_file = get_tmp_dir()?.join("sqlite.db");
    let source_url = format!("sqlite://{}?mode=rwc", sqlite_file.display());

    let mut db = Database::open(root, Some(source_url.clone()))?;
    assert_eq!(db.state()?.source_type(), SourceType::SQLITE);

    let runtime = Runtime::new()?;
    runtime.block_on(setup_test_source(&source_url))?;

    create_test_index(&mut db)?;
    Ok(db)
}
/// Creates the flat test index over the `id`, `vector`, and `data` columns
/// and checks that it was registered with the `FLAT` algorithm.
fn create_test_index(db: &mut Database) -> Result<(), Error> {
    let config =
        SourceConfig::new(TABLE, "id", "vector").with_metadata(vec!["data"]);
    let algorithm = IndexAlgorithm::Flat(ParamsFlat::default());
    db.create_index(TEST_INDEX, algorithm, config)?;

    let registered = db.get_index_ref(TEST_INDEX).unwrap();
    assert_eq!(registered.algorithm().name(), "FLAT");
    Ok(())
}
/// Builds an `INSERT` statement adding `count` rows to the test table.
///
/// For each `i` in `start..start + count` the row holds a 128-dimensional
/// vector filled with `i` (serialized as JSON) and the integer `1000 + i`.
/// The arithmetic is done in `u16` so `start + count` cannot overflow.
fn generate_insert_query(start: u8, count: u8) -> String {
    let start = u16::from(start);
    let end = start + u16::from(count);

    let values = (start..end)
        .map(|i| {
            let vector = vec![f32::from(i); 128];
            let vector = serde_json::to_string(&vector).unwrap();
            let data = 1000 + i;
            // SQL string literals use single quotes. The JSON text is only
            // digits, dots, commas, and brackets, so it needs no escaping.
            format!("('{vector}', {data})")
        })
        .collect::<Vec<_>>()
        .join(",\n");

    format!(
        "INSERT INTO {TABLE} (vector, data)
VALUES {values}"
    )
}
/// Returns the `oasysdb` directory under the system temp directory,
/// creating it when it does not exist.
pub fn get_tmp_dir() -> Result<PathBuf, Error> {
    let dir = env::temp_dir().join("oasysdb");
    let exists = dir.try_exists()?;
    if !exists {
        fs::create_dir_all(&dir)?;
    }
    Ok(dir)
}
/// Resets the test table in the data source at `url`: drops it, recreates
/// it, inserts 100 rows, and asserts the row count.
async fn setup_test_source(
    url: impl Into<DatabaseURL>,
) -> Result<(), Error> {
    let url: DatabaseURL = url.into();
    let mut conn = SourceConnection::connect(&url).await?;

    let drop_table = format!("DROP TABLE IF EXISTS {TABLE}");
    conn.execute(drop_table.as_str()).await?;

    let create_table = format!(
        "CREATE TABLE IF NOT EXISTS {TABLE} (
id INTEGER PRIMARY KEY,
vector JSON NOT NULL,
data INTEGER NOT NULL
)"
    );
    conn.execute(create_table.as_str()).await?;

    let insert_records = generate_insert_query(0, 100);
    conn.execute(insert_records.as_str()).await?;

    let count_query = format!("SELECT COUNT(*) FROM {TABLE}");
    let row = conn.fetch_one(count_query.as_str()).await?;
    let count = row.get::<i64, usize>(0);
    assert_eq!(count, 100);
    Ok(())
}
impl Database {
    /// Test-only access to the locked indices pool.
    fn pool(&self) -> Result<MutexGuard<HashMap<IndexName, Index>>, Error> {
        let guard = self.pool.lock()?;
        Ok(guard)
    }

    /// Test-only helper that executes a raw SQL statement against the data
    /// source.
    async fn async_execute_sql(
        &self,
        query: impl AsRef<str>,
    ) -> Result<(), Error> {
        let sql = query.as_ref();
        let mut conn = self.state()?.async_connect().await?;
        conn.execute(sql).await?;
        Ok(())
    }
}
}