1use serde::{Deserialize, Serialize};
17
18use crate::sql::DatabaseType;
19
/// A database extension together with the options used when rendering its DDL.
///
/// Construct via [`Extension::new`] (which returns an [`ExtensionBuilder`]) or one of
/// the preset constructors such as [`Extension::postgis`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Extension {
    /// Extension name as used in `CREATE EXTENSION` / `DROP EXTENSION` / `load_extension`.
    pub name: String,
    /// Optional target schema, rendered as a `SCHEMA ...` clause on CREATE.
    pub schema: Option<String>,
    /// Optional version, rendered as a `VERSION '...'` clause on CREATE.
    pub version: Option<String>,
    /// When `true`, `CASCADE` is appended to BOTH the CREATE and the DROP statement.
    pub cascade: bool,
}
36
37impl Extension {
38 pub fn new(name: impl Into<String>) -> ExtensionBuilder {
40 ExtensionBuilder::new(name)
41 }
42
43 pub fn postgis() -> Self {
45 Self::new("postgis").build()
46 }
47
48 pub fn pgvector() -> Self {
49 Self::new("vector").build()
50 }
51
52 pub fn uuid_ossp() -> Self {
53 Self::new("uuid-ossp").build()
54 }
55
56 pub fn pgcrypto() -> Self {
57 Self::new("pgcrypto").build()
58 }
59
60 pub fn pg_trgm() -> Self {
61 Self::new("pg_trgm").build()
62 }
63
64 pub fn hstore() -> Self {
65 Self::new("hstore").build()
66 }
67
68 pub fn ltree() -> Self {
69 Self::new("ltree").build()
70 }
71
72 pub fn to_postgres_create(&self) -> String {
74 let mut sql = format!("CREATE EXTENSION IF NOT EXISTS \"{}\"", self.name);
75
76 if let Some(ref schema) = self.schema {
77 sql.push_str(&format!(" SCHEMA {}", schema));
78 }
79
80 if let Some(ref version) = self.version {
81 sql.push_str(&format!(" VERSION '{}'", version));
82 }
83
84 if self.cascade {
85 sql.push_str(" CASCADE");
86 }
87
88 sql
89 }
90
91 pub fn to_postgres_drop(&self) -> String {
93 let mut sql = format!("DROP EXTENSION IF EXISTS \"{}\"", self.name);
94 if self.cascade {
95 sql.push_str(" CASCADE");
96 }
97 sql
98 }
99
100 pub fn to_sqlite_load(&self) -> String {
102 format!("SELECT load_extension('{}')", self.name)
103 }
104}
105
/// Builder for [`Extension`].
///
/// Obtained from [`Extension::new`] or [`ExtensionBuilder::new`] and finished with
/// [`ExtensionBuilder::build`].
#[derive(Debug, Clone)]
pub struct ExtensionBuilder {
    /// Extension name; moved into `Extension::name` by `build()`.
    name: String,
    /// Optional `SCHEMA` clause; `None` until `schema()` is called.
    schema: Option<String>,
    /// Optional `VERSION` clause; `None` until `version()` is called.
    version: Option<String>,
    /// Set to `true` by `cascade()`; starts as `false`.
    cascade: bool,
}
114
115impl ExtensionBuilder {
116 pub fn new(name: impl Into<String>) -> Self {
118 Self {
119 name: name.into(),
120 schema: None,
121 version: None,
122 cascade: false,
123 }
124 }
125
126 pub fn schema(mut self, schema: impl Into<String>) -> Self {
128 self.schema = Some(schema.into());
129 self
130 }
131
132 pub fn version(mut self, version: impl Into<String>) -> Self {
134 self.version = Some(version.into());
135 self
136 }
137
138 pub fn cascade(mut self) -> Self {
140 self.cascade = true;
141 self
142 }
143
144 pub fn build(self) -> Extension {
146 Extension {
147 name: self.name,
148 schema: self.schema,
149 version: self.version,
150 cascade: self.cascade,
151 }
152 }
153}
154
/// A 2-D geographic point that can be rendered as a PostGIS, MySQL, SQL Server or
/// SQLite constructor expression, or as GeoJSON.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct Point {
    /// Longitude (X).
    ///
    /// Emitted first in the PostGIS, MySQL WKT and GeoJSON output. SQL Server's
    /// `geography::Point` output takes latitude first instead.
    pub longitude: f64,
    /// Latitude (Y).
    pub latitude: f64,
    /// Optional spatial reference system id (e.g. 4326 via `Point::wgs84`).
    pub srid: Option<i32>,
}
169
170impl Point {
171 pub fn new(longitude: f64, latitude: f64) -> Self {
173 Self {
174 longitude,
175 latitude,
176 srid: None,
177 }
178 }
179
180 pub fn with_srid(longitude: f64, latitude: f64, srid: i32) -> Self {
182 Self {
183 longitude,
184 latitude,
185 srid: Some(srid),
186 }
187 }
188
189 pub fn wgs84(longitude: f64, latitude: f64) -> Self {
191 Self::with_srid(longitude, latitude, 4326)
192 }
193
194 pub fn to_postgis(&self) -> String {
196 if let Some(srid) = self.srid {
197 format!("ST_SetSRID(ST_MakePoint({}, {}), {})", self.longitude, self.latitude, srid)
198 } else {
199 format!("ST_MakePoint({}, {})", self.longitude, self.latitude)
200 }
201 }
202
203 pub fn to_mysql(&self) -> String {
205 if let Some(srid) = self.srid {
206 format!("ST_GeomFromText('POINT({} {})', {})", self.longitude, self.latitude, srid)
207 } else {
208 format!("ST_GeomFromText('POINT({} {})')", self.longitude, self.latitude)
209 }
210 }
211
212 pub fn to_mssql(&self) -> String {
214 format!(
215 "geography::Point({}, {}, {})",
216 self.latitude, self.longitude, self.srid.unwrap_or(4326)
217 )
218 }
219
220 pub fn to_geojson(&self) -> serde_json::Value {
222 serde_json::json!({
223 "type": "Point",
224 "coordinates": [self.longitude, self.latitude]
225 })
226 }
227
228 pub fn to_sql(&self, db_type: DatabaseType) -> String {
230 match db_type {
231 DatabaseType::PostgreSQL => self.to_postgis(),
232 DatabaseType::MySQL => self.to_mysql(),
233 DatabaseType::MSSQL => self.to_mssql(),
234 DatabaseType::SQLite => format!("MakePoint({}, {})", self.longitude, self.latitude),
235 }
236 }
237}
238
/// A polygon made of one exterior ring plus zero or more interior rings (holes).
///
/// Renderable as WKT, as a PostGIS `ST_GeomFromText` call, or as GeoJSON.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Polygon {
    /// Exterior ring as `(x, y)` pairs. For geographic data this is presumably
    /// `(longitude, latitude)`, matching `Point`'s ordering (TODO confirm with callers).
    pub exterior: Vec<(f64, f64)>,
    /// Interior rings (holes), appended by `with_hole`, in the same `(x, y)` form.
    pub interiors: Vec<Vec<(f64, f64)>>,
    /// Optional SRID passed to `ST_GeomFromText` by `to_postgis`.
    pub srid: Option<i32>,
}
249
250impl Polygon {
251 pub fn new(exterior: Vec<(f64, f64)>) -> Self {
253 Self {
254 exterior,
255 interiors: Vec::new(),
256 srid: None,
257 }
258 }
259
260 pub fn with_hole(mut self, hole: Vec<(f64, f64)>) -> Self {
262 self.interiors.push(hole);
263 self
264 }
265
266 pub fn with_srid(mut self, srid: i32) -> Self {
268 self.srid = Some(srid);
269 self
270 }
271
272 pub fn to_wkt(&self) -> String {
274 let ext_coords: Vec<String> = self
275 .exterior
276 .iter()
277 .map(|(x, y)| format!("{} {}", x, y))
278 .collect();
279
280 let mut wkt = format!("POLYGON(({})", ext_coords.join(", "));
281
282 for interior in &self.interiors {
283 let int_coords: Vec<String> = interior
284 .iter()
285 .map(|(x, y)| format!("{} {}", x, y))
286 .collect();
287 wkt.push_str(&format!(", ({})", int_coords.join(", ")));
288 }
289
290 wkt.push(')');
291 wkt
292 }
293
294 pub fn to_postgis(&self) -> String {
296 if let Some(srid) = self.srid {
297 format!("ST_GeomFromText('{}', {})", self.to_wkt(), srid)
298 } else {
299 format!("ST_GeomFromText('{}')", self.to_wkt())
300 }
301 }
302
303 pub fn to_geojson(&self) -> serde_json::Value {
305 let mut coordinates = vec![self
306 .exterior
307 .iter()
308 .map(|(x, y)| vec![*x, *y])
309 .collect::<Vec<_>>()];
310
311 for interior in &self.interiors {
312 coordinates.push(interior.iter().map(|(x, y)| vec![*x, *y]).collect());
313 }
314
315 serde_json::json!({
316 "type": "Polygon",
317 "coordinates": coordinates
318 })
319 }
320}
321
322pub mod geo {
324 use super::*;
325
326 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
328 pub enum DistanceUnit {
329 Meters,
331 Kilometers,
333 Miles,
335 Feet,
337 }
338
339 impl DistanceUnit {
340 pub fn from_meters(&self) -> f64 {
342 match self {
343 Self::Meters => 1.0,
344 Self::Kilometers => 0.001,
345 Self::Miles => 0.000621371,
346 Self::Feet => 3.28084,
347 }
348 }
349 }
350
351 pub fn distance_sql(col1: &str, col2: &str, db_type: DatabaseType) -> String {
353 match db_type {
354 DatabaseType::PostgreSQL => format!("ST_Distance({}::geography, {}::geography)", col1, col2),
355 DatabaseType::MySQL => format!("ST_Distance_Sphere({}, {})", col1, col2),
356 DatabaseType::MSSQL => format!("{}.STDistance({})", col1, col2),
357 DatabaseType::SQLite => format!("Distance({}, {})", col1, col2),
358 }
359 }
360
361 pub fn distance_from_point_sql(col: &str, point: &Point, db_type: DatabaseType) -> String {
363 let point_sql = point.to_sql(db_type);
364 match db_type {
365 DatabaseType::PostgreSQL => {
366 format!("ST_Distance({}::geography, {}::geography)", col, point_sql)
367 }
368 DatabaseType::MySQL => format!("ST_Distance_Sphere({}, {})", col, point_sql),
369 DatabaseType::MSSQL => format!("{}.STDistance({})", col, point_sql),
370 DatabaseType::SQLite => format!("Distance({}, {})", col, point_sql),
371 }
372 }
373
374 pub fn within_distance_sql(
376 col: &str,
377 point: &Point,
378 distance_meters: f64,
379 db_type: DatabaseType,
380 ) -> String {
381 let point_sql = point.to_sql(db_type);
382 match db_type {
383 DatabaseType::PostgreSQL => {
384 format!(
385 "ST_DWithin({}::geography, {}::geography, {})",
386 col, point_sql, distance_meters
387 )
388 }
389 DatabaseType::MySQL => {
390 format!(
391 "ST_Distance_Sphere({}, {}) <= {}",
392 col, point_sql, distance_meters
393 )
394 }
395 DatabaseType::MSSQL => {
396 format!("{}.STDistance({}) <= {}", col, point_sql, distance_meters)
397 }
398 DatabaseType::SQLite => {
399 format!("Distance({}, {}) <= {}", col, point_sql, distance_meters)
400 }
401 }
402 }
403
404 pub fn contains_sql(geom_col: &str, point: &Point, db_type: DatabaseType) -> String {
406 let point_sql = point.to_sql(db_type);
407 match db_type {
408 DatabaseType::PostgreSQL => format!("ST_Contains({}, {})", geom_col, point_sql),
409 DatabaseType::MySQL => format!("ST_Contains({}, {})", geom_col, point_sql),
410 DatabaseType::MSSQL => format!("{}.STContains({})", geom_col, point_sql),
411 DatabaseType::SQLite => format!("Contains({}, {})", geom_col, point_sql),
412 }
413 }
414
415 pub fn bbox_sql(
417 col: &str,
418 min_lon: f64,
419 min_lat: f64,
420 max_lon: f64,
421 max_lat: f64,
422 db_type: DatabaseType,
423 ) -> String {
424 match db_type {
425 DatabaseType::PostgreSQL => {
426 format!(
427 "{} && ST_MakeEnvelope({}, {}, {}, {}, 4326)",
428 col, min_lon, min_lat, max_lon, max_lat
429 )
430 }
431 DatabaseType::MySQL => {
432 format!(
433 "MBRContains(ST_GeomFromText('POLYGON(({} {}, {} {}, {} {}, {} {}, {} {}))'), {})",
434 min_lon, min_lat, max_lon, min_lat, max_lon, max_lat, min_lon, max_lat, min_lon, min_lat, col
435 )
436 }
437 _ => "1=1".to_string(),
438 }
439 }
440}
441
442pub mod uuid {
448 use super::*;
449
450 pub fn generate_v4(db_type: DatabaseType) -> String {
452 match db_type {
453 DatabaseType::PostgreSQL => "gen_random_uuid()".to_string(),
454 DatabaseType::MySQL => "UUID()".to_string(),
455 DatabaseType::MSSQL => "NEWID()".to_string(),
456 DatabaseType::SQLite => {
457 "lower(hex(randomblob(4))) || '-' || lower(hex(randomblob(2))) || '-4' || \
459 substr(lower(hex(randomblob(2))), 2) || '-' || \
460 substr('89ab', abs(random()) % 4 + 1, 1) || \
461 substr(lower(hex(randomblob(2))), 2) || '-' || lower(hex(randomblob(6)))"
462 .to_string()
463 }
464 }
465 }
466
467 pub fn generate_v7_postgres() -> String {
469 "uuid_generate_v7()".to_string()
470 }
471
472 pub fn from_string(value: &str, db_type: DatabaseType) -> String {
474 match db_type {
475 DatabaseType::PostgreSQL => format!("'{}'::uuid", value),
476 DatabaseType::MySQL => format!("UUID_TO_BIN('{}')", value),
477 DatabaseType::MSSQL => format!("CONVERT(UNIQUEIDENTIFIER, '{}')", value),
478 DatabaseType::SQLite => format!("'{}'", value),
479 }
480 }
481
482 pub fn is_valid_sql(col: &str, db_type: DatabaseType) -> String {
484 match db_type {
485 DatabaseType::PostgreSQL => format!(
486 "{} ~ '^[0-9a-f]{{8}}-[0-9a-f]{{4}}-[0-9a-f]{{4}}-[0-9a-f]{{4}}-[0-9a-f]{{12}}$'",
487 col
488 ),
489 DatabaseType::MySQL => format!(
490 "{} REGEXP '^[0-9a-f]{{8}}-[0-9a-f]{{4}}-[0-9a-f]{{4}}-[0-9a-f]{{4}}-[0-9a-f]{{12}}$'",
491 col
492 ),
493 _ => format!("LEN({}) = 36", col),
494 }
495 }
496}
497
498pub mod crypto {
504 use super::*;
505
506 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
508 pub enum HashAlgorithm {
509 Md5,
510 Sha1,
511 Sha256,
512 Sha384,
513 Sha512,
514 }
515
516 impl HashAlgorithm {
517 pub fn postgres_name(&self) -> &'static str {
519 match self {
520 Self::Md5 => "md5",
521 Self::Sha1 => "sha1",
522 Self::Sha256 => "sha256",
523 Self::Sha384 => "sha384",
524 Self::Sha512 => "sha512",
525 }
526 }
527 }
528
529 pub fn hash_sql(expr: &str, algorithm: HashAlgorithm, db_type: DatabaseType) -> String {
531 match db_type {
532 DatabaseType::PostgreSQL => {
533 if algorithm == HashAlgorithm::Md5 {
534 format!("md5({})", expr)
535 } else {
536 format!("encode(digest({}, '{}'), 'hex')", expr, algorithm.postgres_name())
537 }
538 }
539 DatabaseType::MySQL => match algorithm {
540 HashAlgorithm::Md5 => format!("MD5({})", expr),
541 HashAlgorithm::Sha1 => format!("SHA1({})", expr),
542 HashAlgorithm::Sha256 => format!("SHA2({}, 256)", expr),
543 HashAlgorithm::Sha384 => format!("SHA2({}, 384)", expr),
544 HashAlgorithm::Sha512 => format!("SHA2({}, 512)", expr),
545 },
546 DatabaseType::MSSQL => {
547 let algo = match algorithm {
548 HashAlgorithm::Md5 => "MD5",
549 HashAlgorithm::Sha1 => "SHA1",
550 HashAlgorithm::Sha256 => "SHA2_256",
551 HashAlgorithm::Sha384 => "SHA2_384",
552 HashAlgorithm::Sha512 => "SHA2_512",
553 };
554 format!("CONVERT(VARCHAR(MAX), HASHBYTES('{}', {}), 2)", algo, expr)
555 }
556 DatabaseType::SQLite => {
557 format!("-- SQLite requires extension for hashing: {}", expr)
559 }
560 }
561 }
562
563 pub fn bcrypt_hash_postgres(password: &str) -> String {
565 format!("crypt('{}', gen_salt('bf'))", password)
566 }
567
568 pub fn bcrypt_verify_postgres(password: &str, hash_col: &str) -> String {
570 format!("{} = crypt('{}', {})", hash_col, password, hash_col)
571 }
572
573 pub fn random_bytes_sql(length: usize, db_type: DatabaseType) -> String {
575 match db_type {
576 DatabaseType::PostgreSQL => format!("gen_random_bytes({})", length),
577 DatabaseType::MySQL => format!("RANDOM_BYTES({})", length),
578 DatabaseType::MSSQL => format!("CRYPT_GEN_RANDOM({})", length),
579 DatabaseType::SQLite => format!("randomblob({})", length),
580 }
581 }
582
583 pub fn aes_encrypt_postgres(data: &str, key: &str) -> String {
585 format!("pgp_sym_encrypt('{}', '{}')", data, key)
586 }
587
588 pub fn aes_decrypt_postgres(encrypted_col: &str, key: &str) -> String {
590 format!("pgp_sym_decrypt({}, '{}')", encrypted_col, key)
591 }
592}
593
594pub mod vector {
600 use super::*;
601
602 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
604 pub struct Vector {
605 pub dimensions: Vec<f32>,
607 }
608
609 impl Vector {
610 pub fn new(dimensions: Vec<f32>) -> Self {
612 Self { dimensions }
613 }
614
615 pub fn from_slice(slice: &[f32]) -> Self {
617 Self {
618 dimensions: slice.to_vec(),
619 }
620 }
621
622 pub fn len(&self) -> usize {
624 self.dimensions.len()
625 }
626
627 pub fn is_empty(&self) -> bool {
629 self.dimensions.is_empty()
630 }
631
632 pub fn to_pgvector(&self) -> String {
634 let nums: Vec<String> = self.dimensions.iter().map(|f| f.to_string()).collect();
635 format!("'[{}]'::vector", nums.join(","))
636 }
637
638 pub fn to_mongodb(&self) -> serde_json::Value {
640 serde_json::json!(self.dimensions)
641 }
642 }
643
644 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
646 pub enum SimilarityMetric {
647 L2,
649 InnerProduct,
651 Cosine,
653 }
654
655 impl SimilarityMetric {
656 pub fn postgres_operator(&self) -> &'static str {
658 match self {
659 Self::L2 => "<->",
660 Self::InnerProduct => "<#>",
661 Self::Cosine => "<=>",
662 }
663 }
664
665 pub fn mongodb_name(&self) -> &'static str {
667 match self {
668 Self::L2 => "euclidean",
669 Self::InnerProduct => "dotProduct",
670 Self::Cosine => "cosine",
671 }
672 }
673 }
674
675 pub fn similarity_search_postgres(
677 col: &str,
678 query_vector: &Vector,
679 metric: SimilarityMetric,
680 limit: usize,
681 ) -> String {
682 format!(
683 "SELECT *, {} {} {} AS distance FROM {{table}} ORDER BY distance LIMIT {}",
684 col,
685 metric.postgres_operator(),
686 query_vector.to_pgvector(),
687 limit
688 )
689 }
690
691 pub fn distance_sql(col: &str, query_vector: &Vector, metric: SimilarityMetric) -> String {
693 format!(
694 "{} {} {}",
695 col,
696 metric.postgres_operator(),
697 query_vector.to_pgvector()
698 )
699 }
700
701 pub fn create_index_postgres(
703 index_name: &str,
704 table: &str,
705 column: &str,
706 metric: SimilarityMetric,
707 lists: Option<usize>,
708 ) -> String {
709 let ops = match metric {
710 SimilarityMetric::L2 => "vector_l2_ops",
711 SimilarityMetric::InnerProduct => "vector_ip_ops",
712 SimilarityMetric::Cosine => "vector_cosine_ops",
713 };
714
715 let lists_clause = lists
716 .map(|l| format!(" WITH (lists = {})", l))
717 .unwrap_or_default();
718
719 format!(
720 "CREATE INDEX {} ON {} USING ivfflat ({} {}){}",
721 index_name, table, column, ops, lists_clause
722 )
723 }
724
725 pub fn create_hnsw_index_postgres(
727 index_name: &str,
728 table: &str,
729 column: &str,
730 metric: SimilarityMetric,
731 m: Option<usize>,
732 ef_construction: Option<usize>,
733 ) -> String {
734 let ops = match metric {
735 SimilarityMetric::L2 => "vector_l2_ops",
736 SimilarityMetric::InnerProduct => "vector_ip_ops",
737 SimilarityMetric::Cosine => "vector_cosine_ops",
738 };
739
740 let mut with_clauses = Vec::new();
741 if let Some(m_val) = m {
742 with_clauses.push(format!("m = {}", m_val));
743 }
744 if let Some(ef) = ef_construction {
745 with_clauses.push(format!("ef_construction = {}", ef));
746 }
747
748 let with_clause = if with_clauses.is_empty() {
749 String::new()
750 } else {
751 format!(" WITH ({})", with_clauses.join(", "))
752 };
753
754 format!(
755 "CREATE INDEX {} ON {} USING hnsw ({} {}){}",
756 index_name, table, column, ops, with_clause
757 )
758 }
759}
760
761pub mod mongodb {
763 use serde::{Deserialize, Serialize};
764 use serde_json::Value as JsonValue;
765
766 use super::vector::SimilarityMetric;
767
768 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
770 pub struct VectorSearch {
771 pub index: String,
773 pub path: String,
775 pub query_vector: Vec<f32>,
777 pub num_candidates: usize,
779 pub limit: usize,
781 pub filter: Option<JsonValue>,
783 }
784
785 impl VectorSearch {
786 pub fn new(index: impl Into<String>, path: impl Into<String>, query: Vec<f32>) -> VectorSearchBuilder {
788 VectorSearchBuilder::new(index, path, query)
789 }
790
791 pub fn to_stage(&self) -> JsonValue {
793 let mut search = serde_json::json!({
794 "index": self.index,
795 "path": self.path,
796 "queryVector": self.query_vector,
797 "numCandidates": self.num_candidates,
798 "limit": self.limit
799 });
800
801 if let Some(ref filter) = self.filter {
802 search["filter"] = filter.clone();
803 }
804
805 serde_json::json!({ "$vectorSearch": search })
806 }
807 }
808
809 #[derive(Debug, Clone)]
811 pub struct VectorSearchBuilder {
812 index: String,
813 path: String,
814 query_vector: Vec<f32>,
815 num_candidates: usize,
816 limit: usize,
817 filter: Option<JsonValue>,
818 }
819
820 impl VectorSearchBuilder {
821 pub fn new(index: impl Into<String>, path: impl Into<String>, query: Vec<f32>) -> Self {
823 Self {
824 index: index.into(),
825 path: path.into(),
826 query_vector: query,
827 num_candidates: 100,
828 limit: 10,
829 filter: None,
830 }
831 }
832
833 pub fn num_candidates(mut self, n: usize) -> Self {
835 self.num_candidates = n;
836 self
837 }
838
839 pub fn limit(mut self, n: usize) -> Self {
841 self.limit = n;
842 self
843 }
844
845 pub fn filter(mut self, filter: JsonValue) -> Self {
847 self.filter = Some(filter);
848 self
849 }
850
851 pub fn build(self) -> VectorSearch {
853 VectorSearch {
854 index: self.index,
855 path: self.path,
856 query_vector: self.query_vector,
857 num_candidates: self.num_candidates,
858 limit: self.limit,
859 filter: self.filter,
860 }
861 }
862 }
863
864 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
866 pub struct VectorIndex {
867 pub name: String,
869 pub collection: String,
871 pub fields: Vec<VectorField>,
873 }
874
875 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
877 pub struct VectorField {
878 pub path: String,
880 pub dimensions: usize,
882 pub similarity: String,
884 }
885
886 impl VectorIndex {
887 pub fn new(name: impl Into<String>, collection: impl Into<String>) -> VectorIndexBuilder {
889 VectorIndexBuilder::new(name, collection)
890 }
891
892 pub fn to_definition(&self) -> JsonValue {
894 let fields: Vec<JsonValue> = self
895 .fields
896 .iter()
897 .map(|f| {
898 serde_json::json!({
899 "type": "vector",
900 "path": f.path,
901 "numDimensions": f.dimensions,
902 "similarity": f.similarity
903 })
904 })
905 .collect();
906
907 serde_json::json!({
908 "name": self.name,
909 "type": "vectorSearch",
910 "fields": fields
911 })
912 }
913 }
914
915 #[derive(Debug, Clone)]
917 pub struct VectorIndexBuilder {
918 name: String,
919 collection: String,
920 fields: Vec<VectorField>,
921 }
922
923 impl VectorIndexBuilder {
924 pub fn new(name: impl Into<String>, collection: impl Into<String>) -> Self {
926 Self {
927 name: name.into(),
928 collection: collection.into(),
929 fields: Vec::new(),
930 }
931 }
932
933 pub fn field(
935 mut self,
936 path: impl Into<String>,
937 dimensions: usize,
938 similarity: SimilarityMetric,
939 ) -> Self {
940 self.fields.push(VectorField {
941 path: path.into(),
942 dimensions,
943 similarity: similarity.mongodb_name().to_string(),
944 });
945 self
946 }
947
948 pub fn build(self) -> VectorIndex {
950 VectorIndex {
951 name: self.name,
952 collection: self.collection,
953 fields: self.fields,
954 }
955 }
956 }
957
958 pub fn vector_search(index: &str, path: &str, query: Vec<f32>) -> VectorSearchBuilder {
960 VectorSearch::new(index, path, query)
961 }
962}
963
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every needle occurs somewhere in `haystack`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                haystack.contains(*needle),
                "expected {:?} in {:?}",
                needle,
                haystack
            );
        }
    }

    #[test]
    fn test_extension_postgres() {
        let sql = Extension::new("postgis")
            .schema("public")
            .cascade()
            .build()
            .to_postgres_create();
        assert_contains_all(
            &sql,
            &["CREATE EXTENSION IF NOT EXISTS \"postgis\"", "SCHEMA public", "CASCADE"],
        );
    }

    #[test]
    fn test_extension_drop() {
        let sql = Extension::postgis().to_postgres_drop();
        assert_contains_all(&sql, &["DROP EXTENSION IF EXISTS \"postgis\""]);
    }

    #[test]
    fn test_point_postgis() {
        let sql = Point::wgs84(-122.4194, 37.7749).to_postgis();
        assert_contains_all(&sql, &["ST_SetSRID", "-122.4194", "37.7749", "4326"]);
    }

    #[test]
    fn test_point_geojson() {
        let geojson = Point::new(-122.4194, 37.7749).to_geojson();
        assert_eq!(geojson["type"], "Point");
        assert_eq!(geojson["coordinates"][0], -122.4194);
    }

    #[test]
    fn test_polygon_wkt() {
        let square = vec![(0.0, 0.0), (10.0, 0.0), (10.0, 10.0), (0.0, 10.0), (0.0, 0.0)];
        assert!(Polygon::new(square).to_wkt().starts_with("POLYGON(("));
    }

    #[test]
    fn test_distance_sql() {
        let sql = geo::distance_sql("location", "target", DatabaseType::PostgreSQL);
        assert_contains_all(&sql, &["ST_Distance"]);
    }

    #[test]
    fn test_within_distance() {
        let origin = Point::wgs84(-122.4194, 37.7749);
        let sql = geo::within_distance_sql("location", &origin, 1000.0, DatabaseType::PostgreSQL);
        assert_contains_all(&sql, &["ST_DWithin", "1000"]);
    }

    #[test]
    fn test_uuid_generation() {
        let cases = [
            (DatabaseType::PostgreSQL, "gen_random_uuid()"),
            (DatabaseType::MySQL, "UUID()"),
            (DatabaseType::MSSQL, "NEWID()"),
        ];
        for (db, expected) in cases {
            assert_eq!(uuid::generate_v4(db), expected);
        }
    }

    #[test]
    fn test_hash_sql() {
        let algo = crypto::HashAlgorithm::Sha256;

        let pg = crypto::hash_sql("password", algo, DatabaseType::PostgreSQL);
        assert_contains_all(&pg, &["digest", "sha256"]);

        let mysql = crypto::hash_sql("password", algo, DatabaseType::MySQL);
        assert_contains_all(&mysql, &["SHA2", "256"]);
    }

    #[test]
    fn test_vector_pgvector() {
        let sql = vector::Vector::new(vec![0.1, 0.2, 0.3, 0.4]).to_pgvector();
        assert_contains_all(&sql, &["'[0.1,0.2,0.3,0.4]'::vector"]);
    }

    #[test]
    fn test_vector_index() {
        let sql = vector::create_index_postgres(
            "embeddings_idx",
            "documents",
            "embedding",
            vector::SimilarityMetric::Cosine,
            Some(100),
        );
        assert_contains_all(
            &sql,
            &[
                "CREATE INDEX embeddings_idx",
                "USING ivfflat",
                "vector_cosine_ops",
                "lists = 100",
            ],
        );
    }

    #[test]
    fn test_hnsw_index() {
        let sql = vector::create_hnsw_index_postgres(
            "embeddings_hnsw",
            "documents",
            "embedding",
            vector::SimilarityMetric::L2,
            Some(16),
            Some(64),
        );
        assert_contains_all(&sql, &["USING hnsw", "m = 16", "ef_construction = 64"]);
    }

    mod mongodb_tests {
        use super::super::mongodb::*;
        use super::super::vector::SimilarityMetric;

        #[test]
        fn test_vector_search() {
            let stage = vector_search("vector_index", "embedding", vec![0.1, 0.2, 0.3])
                .num_candidates(200)
                .limit(20)
                .build()
                .to_stage();

            let body = &stage["$vectorSearch"];
            assert!(body["index"].is_string());
            assert_eq!(body["numCandidates"], 200);
            assert_eq!(body["limit"], 20);
        }

        #[test]
        fn test_vector_index_definition() {
            let def = VectorIndex::new("my_vector_index", "documents")
                .field("embedding", 1536, SimilarityMetric::Cosine)
                .build()
                .to_definition();

            assert_eq!(def["name"], "my_vector_index");
            assert_eq!(def["type"], "vectorSearch");
            assert!(def["fields"].is_array());
        }
    }
}
1128
1129