use serde::{Deserialize, Serialize};
use crate::sql::DatabaseType;
/// A PostgreSQL extension (also usable as a SQLite loadable extension name) that can be
/// rendered into install/remove statements by the methods on `impl Extension`.
///
/// Usually constructed through `Extension::new(..)` (which returns an `ExtensionBuilder`)
/// or one of the preset constructors such as `Extension::postgis()`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Extension {
/// Extension name as used in `CREATE EXTENSION` / `DROP EXTENSION` / `load_extension`.
pub name: String,
/// Optional schema emitted as `SCHEMA <schema>` in the create statement.
pub schema: Option<String>,
/// Optional version emitted as `VERSION '<version>'` in the create statement.
pub version: Option<String>,
/// When `true`, `CASCADE` is appended to both the create and drop statements.
pub cascade: bool,
}
impl Extension {
    /// Start building an extension named `name`; finish with `ExtensionBuilder::build`.
    pub fn new(name: impl Into<String>) -> ExtensionBuilder {
        ExtensionBuilder::new(name)
    }

    /// `postgis` with default options.
    pub fn postgis() -> Self {
        Self::new("postgis").build()
    }

    /// pgvector, which is installed under the extension name `vector`.
    pub fn pgvector() -> Self {
        Self::new("vector").build()
    }

    /// `uuid-ossp` with default options.
    pub fn uuid_ossp() -> Self {
        Self::new("uuid-ossp").build()
    }

    /// `pgcrypto` with default options.
    pub fn pgcrypto() -> Self {
        Self::new("pgcrypto").build()
    }

    /// `pg_trgm` with default options.
    pub fn pg_trgm() -> Self {
        Self::new("pg_trgm").build()
    }

    /// `hstore` with default options.
    pub fn hstore() -> Self {
        Self::new("hstore").build()
    }

    /// `ltree` with default options.
    pub fn ltree() -> Self {
        Self::new("ltree").build()
    }

    /// PostgreSQL `CREATE EXTENSION IF NOT EXISTS` statement for this extension.
    ///
    /// The name is emitted as a double-quoted identifier and the version as a
    /// single-quoted literal; embedded quote characters are doubled so a value can
    /// never terminate its quoting early. The schema is emitted verbatim, so it must
    /// already be a valid (optionally pre-quoted) identifier.
    pub fn to_postgres_create(&self) -> String {
        let mut sql = format!(
            "CREATE EXTENSION IF NOT EXISTS {}",
            Self::quote_ident(&self.name)
        );
        if let Some(schema) = &self.schema {
            sql.push_str(" SCHEMA ");
            sql.push_str(schema);
        }
        if let Some(version) = &self.version {
            sql.push_str(" VERSION ");
            sql.push_str(&Self::quote_literal(version));
        }
        if self.cascade {
            sql.push_str(" CASCADE");
        }
        sql
    }

    /// PostgreSQL `DROP EXTENSION IF EXISTS` statement, with `CASCADE` when requested.
    pub fn to_postgres_drop(&self) -> String {
        let mut sql = format!("DROP EXTENSION IF EXISTS {}", Self::quote_ident(&self.name));
        if self.cascade {
            sql.push_str(" CASCADE");
        }
        sql
    }

    /// SQLite `SELECT load_extension('<name>')` statement; the name is quoted as a string literal.
    pub fn to_sqlite_load(&self) -> String {
        format!("SELECT load_extension({})", Self::quote_literal(&self.name))
    }

    /// Double-quote an SQL identifier, doubling any embedded `"`.
    fn quote_ident(ident: &str) -> String {
        format!("\"{}\"", ident.replace('"', "\"\""))
    }

    /// Single-quote an SQL string literal, doubling any embedded `'`.
    fn quote_literal(value: &str) -> String {
        format!("'{}'", value.replace('\'', "''"))
    }
}
/// Builder for [`Extension`]: start with `ExtensionBuilder::new` (or `Extension::new`),
/// chain the optional setters, then call `build`.
#[derive(Debug, Clone)]
pub struct ExtensionBuilder {
/// Extension name.
name: String,
/// Target schema, if one was set.
schema: Option<String>,
/// Requested version, if one was set.
version: Option<String>,
/// Whether `CASCADE` should be emitted; `false` unless `cascade()` is called.
cascade: bool,
}
impl ExtensionBuilder {
    /// Fresh builder for `name`: no schema, no version, no `CASCADE`.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            schema: None,
            version: None,
            cascade: false,
        }
    }

    /// Set the schema the extension is installed into.
    pub fn schema(self, schema: impl Into<String>) -> Self {
        Self {
            schema: Some(schema.into()),
            ..self
        }
    }

    /// Set the extension version to request.
    pub fn version(self, version: impl Into<String>) -> Self {
        Self {
            version: Some(version.into()),
            ..self
        }
    }

    /// Request `CASCADE` on the generated create/drop statements.
    pub fn cascade(self) -> Self {
        Self {
            cascade: true,
            ..self
        }
    }

    /// Finish building, moving every collected setting into an [`Extension`].
    pub fn build(self) -> Extension {
        let Self {
            name,
            schema,
            version,
            cascade,
        } = self;
        Extension {
            name,
            schema,
            version,
            cascade,
        }
    }
}
/// A single position stored x-first (`longitude`, then `latitude`) with an optional SRID.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct Point {
/// X coordinate; emitted first in PostGIS/MySQL/GeoJSON output.
pub longitude: f64,
/// Y coordinate; emitted second, except in SQL Server's latitude-first `geography::Point`.
pub latitude: f64,
/// Spatial reference identifier. When `None`, the PostGIS and MySQL renderings carry no
/// explicit SRID, while `to_mssql` falls back to 4326.
pub srid: Option<i32>,
}
impl Point {
    /// Point without an SRID.
    pub fn new(longitude: f64, latitude: f64) -> Self {
        Self {
            longitude,
            latitude,
            srid: None,
        }
    }

    /// Point tagged with an explicit SRID.
    pub fn with_srid(longitude: f64, latitude: f64, srid: i32) -> Self {
        Self {
            longitude,
            latitude,
            srid: Some(srid),
        }
    }

    /// Point in WGS 84 (SRID 4326).
    pub fn wgs84(longitude: f64, latitude: f64) -> Self {
        Self::with_srid(longitude, latitude, 4326)
    }

    /// PostGIS expression: `ST_MakePoint(x, y)`, wrapped in `ST_SetSRID(.., srid)` when an
    /// SRID is set.
    pub fn to_postgis(&self) -> String {
        if let Some(srid) = self.srid {
            format!(
                "ST_SetSRID(ST_MakePoint({}, {}), {})",
                self.longitude, self.latitude, srid
            )
        } else {
            format!("ST_MakePoint({}, {})", self.longitude, self.latitude)
        }
    }

    /// MySQL expression built from `POINT(longitude latitude)` WKT, passing the SRID when set.
    ///
    /// NOTE(review): MySQL 8.0 reads WKT for geographic SRSs such as 4326 in the SRS-defined
    /// axis order (latitude first) unless `'axis-order=long-lat'` is supplied — confirm the
    /// target server version before relying on SRID 4326 here.
    pub fn to_mysql(&self) -> String {
        if let Some(srid) = self.srid {
            format!(
                "ST_GeomFromText('POINT({} {})', {})",
                self.longitude, self.latitude, srid
            )
        } else {
            format!(
                "ST_GeomFromText('POINT({} {})')",
                self.longitude, self.latitude
            )
        }
    }

    /// SQL Server `geography::Point(lat, long, srid)` — note the latitude-first argument order.
    /// The SRID argument is required there, so 4326 is used when none is set.
    pub fn to_mssql(&self) -> String {
        format!(
            "geography::Point({}, {}, {})",
            self.latitude,
            self.longitude,
            self.srid.unwrap_or(4326)
        )
    }

    /// SpatiaLite `MakePoint(x, y[, srid])`.
    ///
    /// The SRID is forwarded when set. SpatiaLite's spatial predicates compare SRIDs, so a
    /// point built without one does not match geometries stored with an SRID.
    pub fn to_spatialite(&self) -> String {
        match self.srid {
            Some(srid) => format!(
                "MakePoint({}, {}, {})",
                self.longitude, self.latitude, srid
            ),
            None => format!("MakePoint({}, {})", self.longitude, self.latitude),
        }
    }

    /// GeoJSON `Point` geometry with `[longitude, latitude]` coordinates (the SRID is not
    /// included).
    pub fn to_geojson(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "Point",
            "coordinates": [self.longitude, self.latitude]
        })
    }

    /// Dialect-specific point constructor expression.
    pub fn to_sql(&self, db_type: DatabaseType) -> String {
        match db_type {
            DatabaseType::PostgreSQL => self.to_postgis(),
            DatabaseType::MySQL => self.to_mysql(),
            DatabaseType::MSSQL => self.to_mssql(),
            DatabaseType::SQLite => self.to_spatialite(),
        }
    }
}
/// A polygon given as an exterior ring plus zero or more interior rings (holes).
///
/// Positions are `(x, y)` pairs — presumably `(longitude, latitude)` to match `Point`;
/// confirm with callers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Polygon {
/// Outer boundary ring.
pub exterior: Vec<(f64, f64)>,
/// Hole rings, emitted after the exterior ring in WKT and GeoJSON output.
pub interiors: Vec<Vec<(f64, f64)>>,
/// Optional SRID passed to `ST_GeomFromText` by `to_postgis`.
pub srid: Option<i32>,
}
impl Polygon {
    /// Polygon with the given exterior ring, no holes and no SRID.
    pub fn new(exterior: Vec<(f64, f64)>) -> Self {
        Self {
            exterior,
            interiors: Vec::new(),
            srid: None,
        }
    }

    /// Add an interior ring (hole).
    pub fn with_hole(mut self, hole: Vec<(f64, f64)>) -> Self {
        self.interiors.push(hole);
        self
    }

    /// Set the SRID used by `to_postgis`.
    pub fn with_srid(mut self, srid: i32) -> Self {
        self.srid = Some(srid);
        self
    }

    /// Well-Known Text, e.g. `POLYGON((0 0, 1 0, 1 1, 0 0), (..hole..))`.
    ///
    /// Rings are closed automatically (the first position is repeated at the end if needed),
    /// because WKT requires closed rings. Empty holes are skipped. An empty exterior yields
    /// `POLYGON EMPTY` rather than the invalid `POLYGON(())`.
    pub fn to_wkt(&self) -> String {
        let rings = self.closed_rings();
        if rings.is_empty() {
            return "POLYGON EMPTY".to_string();
        }
        let body = rings
            .iter()
            .map(|ring| {
                let coords: Vec<String> =
                    ring.iter().map(|(x, y)| format!("{} {}", x, y)).collect();
                format!("({})", coords.join(", "))
            })
            .collect::<Vec<_>>()
            .join(", ");
        format!("POLYGON({})", body)
    }

    /// PostGIS `ST_GeomFromText(<wkt>[, srid])` expression.
    pub fn to_postgis(&self) -> String {
        if let Some(srid) = self.srid {
            format!("ST_GeomFromText('{}', {})", self.to_wkt(), srid)
        } else {
            format!("ST_GeomFromText('{}')", self.to_wkt())
        }
    }

    /// GeoJSON `Polygon` geometry. Rings are closed and empty holes are skipped, as in
    /// [`Polygon::to_wkt`]. An empty exterior yields `"coordinates": []`.
    pub fn to_geojson(&self) -> serde_json::Value {
        let coordinates: Vec<Vec<[f64; 2]>> = self
            .closed_rings()
            .iter()
            .map(|ring| ring.iter().map(|&(x, y)| [x, y]).collect())
            .collect();
        serde_json::json!({
            "type": "Polygon",
            "coordinates": coordinates
        })
    }

    /// Rings to emit: the exterior, then each non-empty hole, all closed. Returns nothing when
    /// the exterior is empty.
    fn closed_rings(&self) -> Vec<Vec<(f64, f64)>> {
        if self.exterior.is_empty() {
            return Vec::new();
        }
        std::iter::once(&self.exterior)
            .chain(self.interiors.iter().filter(|ring| !ring.is_empty()))
            .map(|ring| Self::close_ring(ring))
            .collect()
    }

    /// Copy of `ring` with its first position appended if it differs from the last.
    fn close_ring(ring: &[(f64, f64)]) -> Vec<(f64, f64)> {
        let mut closed = ring.to_vec();
        if let (Some(&first), Some(&last)) = (ring.first(), ring.last()) {
            if first != last {
                closed.push(first);
            }
        }
        closed
    }
}
pub mod geo {
    use super::*;

    /// Unit in which a distance can be expressed.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
    pub enum DistanceUnit {
        Meters,
        Kilometers,
        Miles,
        Feet,
    }

    impl DistanceUnit {
        /// Factor converting a distance in meters into this unit (`meters * factor`).
        #[allow(clippy::wrong_self_convention)] // public name kept for compatibility
        pub fn from_meters(&self) -> f64 {
            match self {
                Self::Meters => 1.0,
                Self::Kilometers => 0.001,
                Self::Miles => 0.000621371,
                Self::Feet => 3.28084,
            }
        }
    }

    /// SQL expression for the distance between two spatial expressions.
    ///
    /// Every dialect is asked for a meters-based distance. This assumes geographic
    /// (longitude/latitude, e.g. SRID 4326) data:
    /// * PostgreSQL: `ST_Distance` on `::geography` casts.
    /// * MySQL: `ST_Distance_Sphere`.
    /// * SQL Server: `STDistance`, in the SRID's unit of measure (meters for 4326 geography).
    /// * SQLite/SpatiaLite: `Distance(a, b, 1)`. The third argument requests a geodesic result
    ///   in meters; without it SpatiaLite returns CRS units (degrees for 4326).
    pub fn distance_sql(col1: &str, col2: &str, db_type: DatabaseType) -> String {
        match db_type {
            DatabaseType::PostgreSQL => {
                format!("ST_Distance({}::geography, {}::geography)", col1, col2)
            }
            DatabaseType::MySQL => format!("ST_Distance_Sphere({}, {})", col1, col2),
            DatabaseType::MSSQL => format!("{}.STDistance({})", col1, col2),
            DatabaseType::SQLite => format!("Distance({}, {}, 1)", col1, col2),
        }
    }

    /// Distance between column `col` and `point`; same dialect rules as [`distance_sql`].
    pub fn distance_from_point_sql(col: &str, point: &Point, db_type: DatabaseType) -> String {
        distance_sql(col, &point.to_sql(db_type), db_type)
    }

    /// Predicate: `col` lies within `distance_meters` of `point`.
    ///
    /// PostgreSQL uses `ST_DWithin` on geography. The other dialects compare the
    /// [`distance_sql`] expression (meters) against the threshold.
    pub fn within_distance_sql(
        col: &str,
        point: &Point,
        distance_meters: f64,
        db_type: DatabaseType,
    ) -> String {
        let point_sql = point.to_sql(db_type);
        match db_type {
            DatabaseType::PostgreSQL => format!(
                "ST_DWithin({}::geography, {}::geography, {})",
                col, point_sql, distance_meters
            ),
            DatabaseType::MySQL | DatabaseType::MSSQL | DatabaseType::SQLite => format!(
                "{} <= {}",
                distance_sql(col, &point_sql, db_type),
                distance_meters
            ),
        }
    }

    /// Predicate: the geometry in `geom_col` contains `point`.
    ///
    /// SQL Server's `STContains` returns a `bit`, which T-SQL does not accept as a condition,
    /// so it is compared with `= 1`. SpatiaLite's `Contains` may return -1 (truthy) for invalid
    /// input, so it is also compared with `= 1`.
    pub fn contains_sql(geom_col: &str, point: &Point, db_type: DatabaseType) -> String {
        let point_sql = point.to_sql(db_type);
        match db_type {
            DatabaseType::PostgreSQL | DatabaseType::MySQL => {
                format!("ST_Contains({}, {})", geom_col, point_sql)
            }
            DatabaseType::MSSQL => format!("{}.STContains({}) = 1", geom_col, point_sql),
            DatabaseType::SQLite => format!("Contains({}, {}) = 1", geom_col, point_sql),
        }
    }

    /// Predicate restricting `col` to a longitude/latitude bounding box (SRID 4326).
    ///
    /// * PostgreSQL: bounding-box overlap (`&&`) with `ST_MakeEnvelope`.
    /// * MySQL: `MBRContains(box, col)`.
    /// * SQL Server: `STIntersects` against a 4326 geography polygon. This assumes a geography
    ///   column, as produced by `Point::to_mssql`. East-west edges are great-circle arcs, so the
    ///   box is approximate away from the equator.
    /// * SQLite/SpatiaLite: `MbrIntersects` with `BuildMbr(.., 4326)`.
    pub fn bbox_sql(
        col: &str,
        min_lon: f64,
        min_lat: f64,
        max_lon: f64,
        max_lat: f64,
        db_type: DatabaseType,
    ) -> String {
        match db_type {
            DatabaseType::PostgreSQL => format!(
                "{} && ST_MakeEnvelope({}, {}, {}, {}, 4326)",
                col, min_lon, min_lat, max_lon, max_lat
            ),
            DatabaseType::MySQL => format!(
                "MBRContains(ST_GeomFromText('{}'), {})",
                bbox_ring_wkt(min_lon, min_lat, max_lon, max_lat),
                col
            ),
            DatabaseType::MSSQL => format!(
                "{}.STIntersects(geography::STGeomFromText('{}', 4326)) = 1",
                col,
                bbox_ring_wkt(min_lon, min_lat, max_lon, max_lat)
            ),
            DatabaseType::SQLite => format!(
                "MbrIntersects({}, BuildMbr({}, {}, {}, {}, 4326)) = 1",
                col, min_lon, min_lat, max_lon, max_lat
            ),
        }
    }

    /// Closed, counter-clockwise WKT polygon for the box, longitude first.
    fn bbox_ring_wkt(min_lon: f64, min_lat: f64, max_lon: f64, max_lat: f64) -> String {
        format!(
            "POLYGON(({0} {1}, {2} {1}, {2} {3}, {0} {3}, {0} {1}))",
            min_lon, min_lat, max_lon, max_lat
        )
    }
}
pub mod uuid {
    use super::*;

    /// SQL expression that generates a random UUID on `db_type`.
    ///
    /// SQLite has no UUID function. There the value is assembled from `randomblob()` as
    /// lowercase 8-4-4-4-12 hex text, with version nibble `4` and variant nibble `8`/`9`/`a`/`b`.
    /// The variant is chosen with `random() & 3` rather than `abs(random()) % 4`, because SQLite's
    /// `abs()` raises an overflow error for the minimum 64-bit integer.
    pub fn generate_v4(db_type: DatabaseType) -> String {
        match db_type {
            DatabaseType::PostgreSQL => "gen_random_uuid()".to_string(),
            DatabaseType::MySQL => "UUID()".to_string(),
            DatabaseType::MSSQL => "NEWID()".to_string(),
            DatabaseType::SQLite => {
                "lower(hex(randomblob(4))) || '-' || lower(hex(randomblob(2))) || '-4' || \
                 substr(lower(hex(randomblob(2))), 2) || '-' || \
                 substr('89ab', (random() & 3) + 1, 1) || \
                 substr(lower(hex(randomblob(2))), 2) || '-' || lower(hex(randomblob(6)))"
                    .to_string()
            }
        }
    }

    /// `uuid_generate_v7()` call for PostgreSQL.
    ///
    /// NOTE(review): this function is not part of core PostgreSQL; it presumably comes from an
    /// installed extension — confirm one is present in the target database.
    pub fn generate_v7_postgres() -> String {
        "uuid_generate_v7()".to_string()
    }

    /// Convert the textual UUID `value` into the dialect's UUID representation.
    ///
    /// `value` is emitted as a properly escaped string literal, so arbitrary input cannot break
    /// out of the quotes.
    pub fn from_string(value: &str, db_type: DatabaseType) -> String {
        let literal = quote_literal(value, db_type);
        match db_type {
            DatabaseType::PostgreSQL => format!("{}::uuid", literal),
            DatabaseType::MySQL => format!("UUID_TO_BIN({})", literal),
            DatabaseType::MSSQL => format!("CONVERT(UNIQUEIDENTIFIER, {})", literal),
            DatabaseType::SQLite => literal,
        }
    }

    /// Predicate checking that `col` looks like a UUID.
    ///
    /// PostgreSQL (`~*`, case-insensitive) and MySQL (`REGEXP`) match the canonical 8-4-4-4-12
    /// hex layout. SQL Server (`LEN`) and SQLite (`length`) only check for the 36-character
    /// length.
    pub fn is_valid_sql(col: &str, db_type: DatabaseType) -> String {
        const PATTERN: &str =
            "^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$";
        match db_type {
            DatabaseType::PostgreSQL => format!("{} ~* '{}'", col, PATTERN),
            DatabaseType::MySQL => format!("{} REGEXP '{}'", col, PATTERN),
            DatabaseType::MSSQL => format!("LEN({}) = 36", col),
            DatabaseType::SQLite => format!("length({}) = 36", col),
        }
    }

    /// Single-quoted SQL string literal for `value`.
    ///
    /// Embedded `'` is doubled on every dialect. On MySQL, backslash is also escaped because it
    /// is an escape character in MySQL string literals by default.
    fn quote_literal(value: &str, db_type: DatabaseType) -> String {
        let mut escaped = value.replace('\'', "''");
        if matches!(db_type, DatabaseType::MySQL) {
            escaped = escaped.replace('\\', "\\\\");
        }
        format!("'{}'", escaped)
    }
}
pub mod crypto {
    use super::*;

    /// Digest algorithms supported by [`hash_sql`].
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
    pub enum HashAlgorithm {
        Md5,
        Sha1,
        Sha256,
        Sha384,
        Sha512,
    }

    impl HashAlgorithm {
        /// Lower-case algorithm name as passed to pgcrypto's `digest()`.
        pub fn postgres_name(&self) -> &'static str {
            match self {
                Self::Md5 => "md5",
                Self::Sha1 => "sha1",
                Self::Sha256 => "sha256",
                Self::Sha384 => "sha384",
                Self::Sha512 => "sha512",
            }
        }
    }

    /// SQL expression producing the hex digest of `expr`.
    ///
    /// * PostgreSQL: built-in `md5()`, or pgcrypto's `digest()` hex-encoded for the SHA family.
    /// * MySQL: `MD5` / `SHA1` / `SHA2(.., bits)`.
    /// * SQL Server: `HASHBYTES`, converted to hex text with style 2.
    /// * SQLite: there are no built-in digest functions. This emits `lower(hex(<algo>(expr)))`,
    ///   which assumes a loaded extension registering `md5`/`sha1`/`sha256`/`sha384`/`sha512`
    ///   that return the raw digest (e.g. the sqlean `crypto` extension — confirm for your
    ///   deployment). If no such function exists, the statement fails to prepare instead of
    ///   silently producing broken SQL.
    pub fn hash_sql(expr: &str, algorithm: HashAlgorithm, db_type: DatabaseType) -> String {
        match db_type {
            DatabaseType::PostgreSQL => match algorithm {
                HashAlgorithm::Md5 => format!("md5({})", expr),
                other => format!(
                    "encode(digest({}, '{}'), 'hex')",
                    expr,
                    other.postgres_name()
                ),
            },
            DatabaseType::MySQL => match algorithm {
                HashAlgorithm::Md5 => format!("MD5({})", expr),
                HashAlgorithm::Sha1 => format!("SHA1({})", expr),
                HashAlgorithm::Sha256 => format!("SHA2({}, 256)", expr),
                HashAlgorithm::Sha384 => format!("SHA2({}, 384)", expr),
                HashAlgorithm::Sha512 => format!("SHA2({}, 512)", expr),
            },
            DatabaseType::MSSQL => {
                let algo = match algorithm {
                    HashAlgorithm::Md5 => "MD5",
                    HashAlgorithm::Sha1 => "SHA1",
                    HashAlgorithm::Sha256 => "SHA2_256",
                    HashAlgorithm::Sha384 => "SHA2_384",
                    HashAlgorithm::Sha512 => "SHA2_512",
                };
                format!("CONVERT(VARCHAR(MAX), HASHBYTES('{}', {}), 2)", algo, expr)
            }
            DatabaseType::SQLite => {
                let func = match algorithm {
                    HashAlgorithm::Md5 => "md5",
                    HashAlgorithm::Sha1 => "sha1",
                    HashAlgorithm::Sha256 => "sha256",
                    HashAlgorithm::Sha384 => "sha384",
                    HashAlgorithm::Sha512 => "sha512",
                };
                format!("lower(hex({}({})))", func, expr)
            }
        }
    }

    /// pgcrypto bcrypt hash of the literal `password` with a fresh `bf` salt.
    ///
    /// The password is escaped as a string literal. Prefer a bind parameter where possible,
    /// since literals can end up in query logs.
    pub fn bcrypt_hash_postgres(password: &str) -> String {
        format!("crypt({}, gen_salt('bf'))", quote_literal(password))
    }

    /// Predicate: `password` matches the bcrypt hash stored in `hash_col`.
    pub fn bcrypt_verify_postgres(password: &str, hash_col: &str) -> String {
        format!(
            "{} = crypt({}, {})",
            hash_col,
            quote_literal(password),
            hash_col
        )
    }

    /// SQL expression returning `length` random bytes.
    pub fn random_bytes_sql(length: usize, db_type: DatabaseType) -> String {
        match db_type {
            DatabaseType::PostgreSQL => format!("gen_random_bytes({})", length),
            DatabaseType::MySQL => format!("RANDOM_BYTES({})", length),
            DatabaseType::MSSQL => format!("CRYPT_GEN_RANDOM({})", length),
            DatabaseType::SQLite => format!("randomblob({})", length),
        }
    }

    /// pgcrypto `pgp_sym_encrypt` of the literal `data` under `key` (both escaped literals).
    pub fn aes_encrypt_postgres(data: &str, key: &str) -> String {
        format!(
            "pgp_sym_encrypt({}, {})",
            quote_literal(data),
            quote_literal(key)
        )
    }

    /// pgcrypto `pgp_sym_decrypt` of the column `encrypted_col` under the literal `key`.
    pub fn aes_decrypt_postgres(encrypted_col: &str, key: &str) -> String {
        format!("pgp_sym_decrypt({}, {})", encrypted_col, quote_literal(key))
    }

    /// PostgreSQL string literal with embedded `'` doubled.
    ///
    /// This assumes `standard_conforming_strings = on` (the PostgreSQL default), under which
    /// backslashes are literal characters.
    fn quote_literal(value: &str) -> String {
        format!("'{}'", value.replace('\'', "''"))
    }
}
pub mod vector {
    use super::*;

    /// Dense embedding used for similarity search.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    pub struct Vector {
        /// The vector's components, in order.
        pub dimensions: Vec<f32>,
    }

    impl Vector {
        /// Wrap an owned list of components.
        pub fn new(dimensions: Vec<f32>) -> Self {
            Vector { dimensions }
        }

        /// Copy the components out of a slice.
        pub fn from_slice(slice: &[f32]) -> Self {
            Self::new(slice.to_owned())
        }

        /// Number of components.
        pub fn len(&self) -> usize {
            self.dimensions.len()
        }

        /// `true` when there are no components.
        pub fn is_empty(&self) -> bool {
            self.len() == 0
        }

        /// pgvector literal such as `'[0.1,0.2]'::vector` (comma-separated, no spaces).
        pub fn to_pgvector(&self) -> String {
            let mut body = String::new();
            for (position, component) in self.dimensions.iter().enumerate() {
                if position > 0 {
                    body.push(',');
                }
                body.push_str(&component.to_string());
            }
            format!("'[{}]'::vector", body)
        }

        /// Components as a JSON array.
        pub fn to_mongodb(&self) -> serde_json::Value {
            serde_json::json!(self.dimensions)
        }
    }

    /// Distance/similarity measure used for search and indexing.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
    pub enum SimilarityMetric {
        L2,
        InnerProduct,
        Cosine,
    }

    impl SimilarityMetric {
        /// pgvector distance operator for this metric.
        pub fn postgres_operator(&self) -> &'static str {
            match self {
                Self::L2 => "<->",
                Self::InnerProduct => "<#>",
                Self::Cosine => "<=>",
            }
        }

        /// Name used in MongoDB Atlas vector index definitions.
        pub fn mongodb_name(&self) -> &'static str {
            match self {
                Self::L2 => "euclidean",
                Self::InnerProduct => "dotProduct",
                Self::Cosine => "cosine",
            }
        }

        /// pgvector operator class for an index built for this metric.
        fn pgvector_opclass(self) -> &'static str {
            match self {
                Self::L2 => "vector_l2_ops",
                Self::InnerProduct => "vector_ip_ops",
                Self::Cosine => "vector_cosine_ops",
            }
        }
    }

    /// Nearest-neighbour query template ordered by distance.
    ///
    /// The output contains a literal `{table}` placeholder that the caller must substitute.
    pub fn similarity_search_postgres(
        col: &str,
        query_vector: &Vector,
        metric: SimilarityMetric,
        limit: usize,
    ) -> String {
        let distance = distance_sql(col, query_vector, metric);
        format!(
            "SELECT *, {} AS distance FROM {{table}} ORDER BY distance LIMIT {}",
            distance, limit
        )
    }

    /// `<col> <operator> <vector literal>` distance expression.
    pub fn distance_sql(col: &str, query_vector: &Vector, metric: SimilarityMetric) -> String {
        let literal = query_vector.to_pgvector();
        format!("{} {} {}", col, metric.postgres_operator(), literal)
    }

    /// Renders `CREATE INDEX .. USING <method> (<column> <opclass>)`, plus a `WITH (..)` clause
    /// when `params` is non-empty.
    fn render_vector_index(
        index_name: &str,
        table: &str,
        column: &str,
        method: &str,
        metric: SimilarityMetric,
        params: &[String],
    ) -> String {
        let with_clause = if params.is_empty() {
            String::new()
        } else {
            format!(" WITH ({})", params.join(", "))
        };
        format!(
            "CREATE INDEX {} ON {} USING {} ({} {}){}",
            index_name,
            table,
            method,
            column,
            metric.pgvector_opclass(),
            with_clause
        )
    }

    /// pgvector IVFFlat index, with `lists = n` when `lists` is given.
    pub fn create_index_postgres(
        index_name: &str,
        table: &str,
        column: &str,
        metric: SimilarityMetric,
        lists: Option<usize>,
    ) -> String {
        let params: Vec<String> = lists
            .into_iter()
            .map(|count| format!("lists = {}", count))
            .collect();
        render_vector_index(index_name, table, column, "ivfflat", metric, &params)
    }

    /// pgvector HNSW index, with `m` and `ef_construction` storage parameters when given.
    pub fn create_hnsw_index_postgres(
        index_name: &str,
        table: &str,
        column: &str,
        metric: SimilarityMetric,
        m: Option<usize>,
        ef_construction: Option<usize>,
    ) -> String {
        let params: Vec<String> = m
            .map(|value| format!("m = {}", value))
            .into_iter()
            .chain(ef_construction.map(|value| format!("ef_construction = {}", value)))
            .collect();
        render_vector_index(index_name, table, column, "hnsw", metric, &params)
    }
}
pub mod mongodb {
    use serde::{Deserialize, Serialize};
    use serde_json::Value as JsonValue;
    use super::vector::SimilarityMetric;

    /// `numCandidates` used by a freshly created [`VectorSearchBuilder`].
    const DEFAULT_NUM_CANDIDATES: usize = 100;
    /// `limit` used by a freshly created [`VectorSearchBuilder`].
    const DEFAULT_LIMIT: usize = 10;

    /// Parameters of a `$vectorSearch` aggregation stage.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    pub struct VectorSearch {
        /// Name of the vector search index to query.
        pub index: String,
        /// Document field holding the embeddings.
        pub path: String,
        /// Query embedding.
        pub query_vector: Vec<f32>,
        /// Emitted as `numCandidates`.
        pub num_candidates: usize,
        /// Emitted as `limit`.
        pub limit: usize,
        /// Optional filter document, emitted as `filter` when present.
        pub filter: Option<JsonValue>,
    }

    impl VectorSearch {
        /// Start a builder with default candidate count and limit.
        pub fn new(
            index: impl Into<String>,
            path: impl Into<String>,
            query: Vec<f32>,
        ) -> VectorSearchBuilder {
            VectorSearchBuilder::new(index, path, query)
        }

        /// The `{ "$vectorSearch": { .. } }` pipeline stage.
        pub fn to_stage(&self) -> JsonValue {
            let mut body = serde_json::json!({
                "index": self.index,
                "path": self.path,
                "queryVector": self.query_vector,
                "numCandidates": self.num_candidates,
                "limit": self.limit
            });
            if let Some(filter) = self.filter.as_ref() {
                body["filter"] = filter.clone();
            }
            serde_json::json!({ "$vectorSearch": body })
        }
    }

    /// Builder for [`VectorSearch`].
    #[derive(Debug, Clone)]
    pub struct VectorSearchBuilder {
        index: String,
        path: String,
        query_vector: Vec<f32>,
        num_candidates: usize,
        limit: usize,
        filter: Option<JsonValue>,
    }

    impl VectorSearchBuilder {
        /// New builder using the default candidate count and limit and no filter.
        pub fn new(index: impl Into<String>, path: impl Into<String>, query: Vec<f32>) -> Self {
            let index = index.into();
            let path = path.into();
            Self {
                index,
                path,
                query_vector: query,
                num_candidates: DEFAULT_NUM_CANDIDATES,
                limit: DEFAULT_LIMIT,
                filter: None,
            }
        }

        /// Override `numCandidates`.
        pub fn num_candidates(self, n: usize) -> Self {
            Self {
                num_candidates: n,
                ..self
            }
        }

        /// Override `limit`.
        pub fn limit(self, n: usize) -> Self {
            Self { limit: n, ..self }
        }

        /// Attach a filter document.
        pub fn filter(self, filter: JsonValue) -> Self {
            Self {
                filter: Some(filter),
                ..self
            }
        }

        /// Finish building.
        pub fn build(self) -> VectorSearch {
            let Self {
                index,
                path,
                query_vector,
                num_candidates,
                limit,
                filter,
            } = self;
            VectorSearch {
                index,
                path,
                query_vector,
                num_candidates,
                limit,
                filter,
            }
        }
    }

    /// A vector search index over one collection.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    pub struct VectorIndex {
        pub name: String,
        pub collection: String,
        pub fields: Vec<VectorField>,
    }

    /// One indexed vector field.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    pub struct VectorField {
        pub path: String,
        pub dimensions: usize,
        pub similarity: String,
    }

    impl VectorIndex {
        /// Start a builder for index `name` on `collection`.
        pub fn new(name: impl Into<String>, collection: impl Into<String>) -> VectorIndexBuilder {
            VectorIndexBuilder::new(name, collection)
        }

        /// Index definition document: `name`, `type: "vectorSearch"` and a top-level `fields`
        /// array. The collection is not part of the document.
        ///
        /// NOTE(review): Atlas's `createSearchIndexes` command nests `fields` under a
        /// `definition` key — confirm which shape the consumer of this value expects.
        pub fn to_definition(&self) -> JsonValue {
            let mut fields = Vec::with_capacity(self.fields.len());
            for field in &self.fields {
                fields.push(serde_json::json!({
                    "type": "vector",
                    "path": field.path,
                    "numDimensions": field.dimensions,
                    "similarity": field.similarity
                }));
            }
            serde_json::json!({
                "name": self.name,
                "type": "vectorSearch",
                "fields": fields
            })
        }
    }

    /// Builder for [`VectorIndex`].
    #[derive(Debug, Clone)]
    pub struct VectorIndexBuilder {
        name: String,
        collection: String,
        fields: Vec<VectorField>,
    }

    impl VectorIndexBuilder {
        /// Builder with no fields yet.
        pub fn new(name: impl Into<String>, collection: impl Into<String>) -> Self {
            let name = name.into();
            let collection = collection.into();
            Self {
                name,
                collection,
                fields: Vec::new(),
            }
        }

        /// Add a vector field, storing the metric's MongoDB similarity name.
        pub fn field(
            mut self,
            path: impl Into<String>,
            dimensions: usize,
            similarity: SimilarityMetric,
        ) -> Self {
            let entry = VectorField {
                path: path.into(),
                dimensions,
                similarity: String::from(similarity.mongodb_name()),
            };
            self.fields.push(entry);
            self
        }

        /// Finish building.
        pub fn build(self) -> VectorIndex {
            let Self {
                name,
                collection,
                fields,
            } = self;
            VectorIndex {
                name,
                collection,
                fields,
            }
        }
    }

    /// Shorthand for `VectorSearch::new(index, path, query)`.
    pub fn vector_search(index: &str, path: &str, query: Vec<f32>) -> VectorSearchBuilder {
        VectorSearchBuilder::new(index, path, query)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails, naming the missing piece, unless `haystack` contains every entry of `needles`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(
                haystack.contains(needle),
                "expected `{}` to contain `{}`",
                haystack,
                needle
            );
        }
    }

    #[test]
    fn test_extension_postgres() {
        let sql = Extension::new("postgis")
            .schema("public")
            .cascade()
            .build()
            .to_postgres_create();
        assert_contains_all(
            &sql,
            &[
                "CREATE EXTENSION IF NOT EXISTS \"postgis\"",
                "SCHEMA public",
                "CASCADE",
            ],
        );
    }

    #[test]
    fn test_extension_drop() {
        let sql = Extension::postgis().to_postgres_drop();
        assert_contains_all(&sql, &["DROP EXTENSION IF EXISTS \"postgis\""]);
    }

    #[test]
    fn test_point_postgis() {
        let sql = Point::wgs84(-122.4194, 37.7749).to_postgis();
        assert_contains_all(&sql, &["ST_SetSRID", "-122.4194", "37.7749", "4326"]);
    }

    #[test]
    fn test_point_geojson() {
        let geojson = Point::new(-122.4194, 37.7749).to_geojson();
        assert_eq!(geojson["type"], "Point");
        assert_eq!(geojson["coordinates"][0], -122.4194);
    }

    #[test]
    fn test_polygon_wkt() {
        let square = vec![
            (0.0, 0.0),
            (10.0, 0.0),
            (10.0, 10.0),
            (0.0, 10.0),
            (0.0, 0.0),
        ];
        let wkt = Polygon::new(square).to_wkt();
        assert!(wkt.starts_with("POLYGON(("));
    }

    #[test]
    fn test_distance_sql() {
        let sql = geo::distance_sql("location", "target", DatabaseType::PostgreSQL);
        assert_contains_all(&sql, &["ST_Distance"]);
    }

    #[test]
    fn test_within_distance() {
        let origin = Point::wgs84(-122.4194, 37.7749);
        let sql = geo::within_distance_sql("location", &origin, 1000.0, DatabaseType::PostgreSQL);
        assert_contains_all(&sql, &["ST_DWithin", "1000"]);
    }

    #[test]
    fn test_uuid_generation() {
        let cases = [
            (DatabaseType::PostgreSQL, "gen_random_uuid()"),
            (DatabaseType::MySQL, "UUID()"),
            (DatabaseType::MSSQL, "NEWID()"),
        ];
        for (db, expected) in cases.iter() {
            assert_eq!(uuid::generate_v4(*db), *expected);
        }
    }

    #[test]
    fn test_hash_sql() {
        let algo = crypto::HashAlgorithm::Sha256;
        let pg = crypto::hash_sql("password", algo, DatabaseType::PostgreSQL);
        assert_contains_all(&pg, &["digest", "sha256"]);
        let mysql = crypto::hash_sql("password", algo, DatabaseType::MySQL);
        assert_contains_all(&mysql, &["SHA2", "256"]);
    }

    #[test]
    fn test_vector_pgvector() {
        let embedding = vector::Vector::new(vec![0.1, 0.2, 0.3, 0.4]);
        assert_contains_all(&embedding.to_pgvector(), &["'[0.1,0.2,0.3,0.4]'::vector"]);
    }

    #[test]
    fn test_vector_index() {
        let sql = vector::create_index_postgres(
            "embeddings_idx",
            "documents",
            "embedding",
            vector::SimilarityMetric::Cosine,
            Some(100),
        );
        assert_contains_all(
            &sql,
            &[
                "CREATE INDEX embeddings_idx",
                "USING ivfflat",
                "vector_cosine_ops",
                "lists = 100",
            ],
        );
    }

    #[test]
    fn test_hnsw_index() {
        let sql = vector::create_hnsw_index_postgres(
            "embeddings_hnsw",
            "documents",
            "embedding",
            vector::SimilarityMetric::L2,
            Some(16),
            Some(64),
        );
        assert_contains_all(&sql, &["USING hnsw", "m = 16", "ef_construction = 64"]);
    }

    mod mongodb_tests {
        use super::super::mongodb::*;
        use super::super::vector::SimilarityMetric;

        #[test]
        fn test_vector_search() {
            let stage = vector_search("vector_index", "embedding", vec![0.1, 0.2, 0.3])
                .num_candidates(200)
                .limit(20)
                .build()
                .to_stage();
            let body = &stage["$vectorSearch"];
            assert!(body["index"].is_string());
            assert_eq!(body["numCandidates"], 200);
            assert_eq!(body["limit"], 20);
        }

        #[test]
        fn test_vector_index_definition() {
            let def = VectorIndex::new("my_vector_index", "documents")
                .field("embedding", 1536, SimilarityMetric::Cosine)
                .build()
                .to_definition();
            assert_eq!(def["name"], "my_vector_index");
            assert_eq!(def["type"], "vectorSearch");
            assert!(def["fields"].is_array());
        }
    }
}