1use serde::{Deserialize, Serialize};
17
18use crate::sql::DatabaseType;
19
20#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
26pub struct Extension {
27 pub name: String,
29 pub schema: Option<String>,
31 pub version: Option<String>,
33 pub cascade: bool,
35}
36
37impl Extension {
38 pub fn new(name: impl Into<String>) -> ExtensionBuilder {
40 ExtensionBuilder::new(name)
41 }
42
43 pub fn postgis() -> Self {
45 Self::new("postgis").build()
46 }
47
48 pub fn pgvector() -> Self {
49 Self::new("vector").build()
50 }
51
52 pub fn uuid_ossp() -> Self {
53 Self::new("uuid-ossp").build()
54 }
55
56 pub fn pgcrypto() -> Self {
57 Self::new("pgcrypto").build()
58 }
59
60 pub fn pg_trgm() -> Self {
61 Self::new("pg_trgm").build()
62 }
63
64 pub fn hstore() -> Self {
65 Self::new("hstore").build()
66 }
67
68 pub fn ltree() -> Self {
69 Self::new("ltree").build()
70 }
71
72 pub fn to_postgres_create(&self) -> String {
74 let mut sql = format!("CREATE EXTENSION IF NOT EXISTS \"{}\"", self.name);
75
76 if let Some(ref schema) = self.schema {
77 sql.push_str(&format!(" SCHEMA {}", schema));
78 }
79
80 if let Some(ref version) = self.version {
81 sql.push_str(&format!(" VERSION '{}'", version));
82 }
83
84 if self.cascade {
85 sql.push_str(" CASCADE");
86 }
87
88 sql
89 }
90
91 pub fn to_postgres_drop(&self) -> String {
93 let mut sql = format!("DROP EXTENSION IF EXISTS \"{}\"", self.name);
94 if self.cascade {
95 sql.push_str(" CASCADE");
96 }
97 sql
98 }
99
100 pub fn to_sqlite_load(&self) -> String {
102 format!("SELECT load_extension('{}')", self.name)
103 }
104}
105
106#[derive(Debug, Clone)]
108pub struct ExtensionBuilder {
109 name: String,
110 schema: Option<String>,
111 version: Option<String>,
112 cascade: bool,
113}
114
115impl ExtensionBuilder {
116 pub fn new(name: impl Into<String>) -> Self {
118 Self {
119 name: name.into(),
120 schema: None,
121 version: None,
122 cascade: false,
123 }
124 }
125
126 pub fn schema(mut self, schema: impl Into<String>) -> Self {
128 self.schema = Some(schema.into());
129 self
130 }
131
132 pub fn version(mut self, version: impl Into<String>) -> Self {
134 self.version = Some(version.into());
135 self
136 }
137
138 pub fn cascade(mut self) -> Self {
140 self.cascade = true;
141 self
142 }
143
144 pub fn build(self) -> Extension {
146 Extension {
147 name: self.name,
148 schema: self.schema,
149 version: self.version,
150 cascade: self.cascade,
151 }
152 }
153}
154
155#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
161pub struct Point {
162 pub longitude: f64,
164 pub latitude: f64,
166 pub srid: Option<i32>,
168}
169
170impl Point {
171 pub fn new(longitude: f64, latitude: f64) -> Self {
173 Self {
174 longitude,
175 latitude,
176 srid: None,
177 }
178 }
179
180 pub fn with_srid(longitude: f64, latitude: f64, srid: i32) -> Self {
182 Self {
183 longitude,
184 latitude,
185 srid: Some(srid),
186 }
187 }
188
189 pub fn wgs84(longitude: f64, latitude: f64) -> Self {
191 Self::with_srid(longitude, latitude, 4326)
192 }
193
194 pub fn to_postgis(&self) -> String {
196 if let Some(srid) = self.srid {
197 format!(
198 "ST_SetSRID(ST_MakePoint({}, {}), {})",
199 self.longitude, self.latitude, srid
200 )
201 } else {
202 format!("ST_MakePoint({}, {})", self.longitude, self.latitude)
203 }
204 }
205
206 pub fn to_mysql(&self) -> String {
208 if let Some(srid) = self.srid {
209 format!(
210 "ST_GeomFromText('POINT({} {})', {})",
211 self.longitude, self.latitude, srid
212 )
213 } else {
214 format!(
215 "ST_GeomFromText('POINT({} {})')",
216 self.longitude, self.latitude
217 )
218 }
219 }
220
221 pub fn to_mssql(&self) -> String {
223 format!(
224 "geography::Point({}, {}, {})",
225 self.latitude,
226 self.longitude,
227 self.srid.unwrap_or(4326)
228 )
229 }
230
231 pub fn to_geojson(&self) -> serde_json::Value {
233 serde_json::json!({
234 "type": "Point",
235 "coordinates": [self.longitude, self.latitude]
236 })
237 }
238
239 pub fn to_sql(&self, db_type: DatabaseType) -> String {
241 match db_type {
242 DatabaseType::PostgreSQL => self.to_postgis(),
243 DatabaseType::MySQL => self.to_mysql(),
244 DatabaseType::MSSQL => self.to_mssql(),
245 DatabaseType::SQLite => format!("MakePoint({}, {})", self.longitude, self.latitude),
246 }
247 }
248}
249
250#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
252pub struct Polygon {
253 pub exterior: Vec<(f64, f64)>,
255 pub interiors: Vec<Vec<(f64, f64)>>,
257 pub srid: Option<i32>,
259}
260
261impl Polygon {
262 pub fn new(exterior: Vec<(f64, f64)>) -> Self {
264 Self {
265 exterior,
266 interiors: Vec::new(),
267 srid: None,
268 }
269 }
270
271 pub fn with_hole(mut self, hole: Vec<(f64, f64)>) -> Self {
273 self.interiors.push(hole);
274 self
275 }
276
277 pub fn with_srid(mut self, srid: i32) -> Self {
279 self.srid = Some(srid);
280 self
281 }
282
283 pub fn to_wkt(&self) -> String {
285 let ext_coords: Vec<String> = self
286 .exterior
287 .iter()
288 .map(|(x, y)| format!("{} {}", x, y))
289 .collect();
290
291 let mut wkt = format!("POLYGON(({})", ext_coords.join(", "));
292
293 for interior in &self.interiors {
294 let int_coords: Vec<String> = interior
295 .iter()
296 .map(|(x, y)| format!("{} {}", x, y))
297 .collect();
298 wkt.push_str(&format!(", ({})", int_coords.join(", ")));
299 }
300
301 wkt.push(')');
302 wkt
303 }
304
305 pub fn to_postgis(&self) -> String {
307 if let Some(srid) = self.srid {
308 format!("ST_GeomFromText('{}', {})", self.to_wkt(), srid)
309 } else {
310 format!("ST_GeomFromText('{}')", self.to_wkt())
311 }
312 }
313
314 pub fn to_geojson(&self) -> serde_json::Value {
316 let mut coordinates = vec![
317 self.exterior
318 .iter()
319 .map(|(x, y)| vec![*x, *y])
320 .collect::<Vec<_>>(),
321 ];
322
323 for interior in &self.interiors {
324 coordinates.push(interior.iter().map(|(x, y)| vec![*x, *y]).collect());
325 }
326
327 serde_json::json!({
328 "type": "Polygon",
329 "coordinates": coordinates
330 })
331 }
332}
333
334pub mod geo {
336 use super::*;
337
338 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
340 pub enum DistanceUnit {
341 Meters,
343 Kilometers,
345 Miles,
347 Feet,
349 }
350
351 impl DistanceUnit {
352 pub fn from_meters(&self) -> f64 {
354 match self {
355 Self::Meters => 1.0,
356 Self::Kilometers => 0.001,
357 Self::Miles => 0.000621371,
358 Self::Feet => 3.28084,
359 }
360 }
361 }
362
363 pub fn distance_sql(col1: &str, col2: &str, db_type: DatabaseType) -> String {
365 match db_type {
366 DatabaseType::PostgreSQL => {
367 format!("ST_Distance({}::geography, {}::geography)", col1, col2)
368 }
369 DatabaseType::MySQL => format!("ST_Distance_Sphere({}, {})", col1, col2),
370 DatabaseType::MSSQL => format!("{}.STDistance({})", col1, col2),
371 DatabaseType::SQLite => format!("Distance({}, {})", col1, col2),
372 }
373 }
374
375 pub fn distance_from_point_sql(col: &str, point: &Point, db_type: DatabaseType) -> String {
377 let point_sql = point.to_sql(db_type);
378 match db_type {
379 DatabaseType::PostgreSQL => {
380 format!("ST_Distance({}::geography, {}::geography)", col, point_sql)
381 }
382 DatabaseType::MySQL => format!("ST_Distance_Sphere({}, {})", col, point_sql),
383 DatabaseType::MSSQL => format!("{}.STDistance({})", col, point_sql),
384 DatabaseType::SQLite => format!("Distance({}, {})", col, point_sql),
385 }
386 }
387
388 pub fn within_distance_sql(
390 col: &str,
391 point: &Point,
392 distance_meters: f64,
393 db_type: DatabaseType,
394 ) -> String {
395 let point_sql = point.to_sql(db_type);
396 match db_type {
397 DatabaseType::PostgreSQL => {
398 format!(
399 "ST_DWithin({}::geography, {}::geography, {})",
400 col, point_sql, distance_meters
401 )
402 }
403 DatabaseType::MySQL => {
404 format!(
405 "ST_Distance_Sphere({}, {}) <= {}",
406 col, point_sql, distance_meters
407 )
408 }
409 DatabaseType::MSSQL => {
410 format!("{}.STDistance({}) <= {}", col, point_sql, distance_meters)
411 }
412 DatabaseType::SQLite => {
413 format!("Distance({}, {}) <= {}", col, point_sql, distance_meters)
414 }
415 }
416 }
417
418 pub fn contains_sql(geom_col: &str, point: &Point, db_type: DatabaseType) -> String {
420 let point_sql = point.to_sql(db_type);
421 match db_type {
422 DatabaseType::PostgreSQL => format!("ST_Contains({}, {})", geom_col, point_sql),
423 DatabaseType::MySQL => format!("ST_Contains({}, {})", geom_col, point_sql),
424 DatabaseType::MSSQL => format!("{}.STContains({})", geom_col, point_sql),
425 DatabaseType::SQLite => format!("Contains({}, {})", geom_col, point_sql),
426 }
427 }
428
429 pub fn bbox_sql(
431 col: &str,
432 min_lon: f64,
433 min_lat: f64,
434 max_lon: f64,
435 max_lat: f64,
436 db_type: DatabaseType,
437 ) -> String {
438 match db_type {
439 DatabaseType::PostgreSQL => {
440 format!(
441 "{} && ST_MakeEnvelope({}, {}, {}, {}, 4326)",
442 col, min_lon, min_lat, max_lon, max_lat
443 )
444 }
445 DatabaseType::MySQL => {
446 format!(
447 "MBRContains(ST_GeomFromText('POLYGON(({} {}, {} {}, {} {}, {} {}, {} {}))'), {})",
448 min_lon,
449 min_lat,
450 max_lon,
451 min_lat,
452 max_lon,
453 max_lat,
454 min_lon,
455 max_lat,
456 min_lon,
457 min_lat,
458 col
459 )
460 }
461 _ => "1=1".to_string(),
462 }
463 }
464}
465
466pub mod uuid {
472 use super::*;
473
474 pub fn generate_v4(db_type: DatabaseType) -> String {
476 match db_type {
477 DatabaseType::PostgreSQL => "gen_random_uuid()".to_string(),
478 DatabaseType::MySQL => "UUID()".to_string(),
479 DatabaseType::MSSQL => "NEWID()".to_string(),
480 DatabaseType::SQLite => {
481 "lower(hex(randomblob(4))) || '-' || lower(hex(randomblob(2))) || '-4' || \
483 substr(lower(hex(randomblob(2))), 2) || '-' || \
484 substr('89ab', abs(random()) % 4 + 1, 1) || \
485 substr(lower(hex(randomblob(2))), 2) || '-' || lower(hex(randomblob(6)))"
486 .to_string()
487 }
488 }
489 }
490
491 pub fn generate_v7_postgres() -> String {
493 "uuid_generate_v7()".to_string()
494 }
495
496 pub fn from_string(value: &str, db_type: DatabaseType) -> String {
498 match db_type {
499 DatabaseType::PostgreSQL => format!("'{}'::uuid", value),
500 DatabaseType::MySQL => format!("UUID_TO_BIN('{}')", value),
501 DatabaseType::MSSQL => format!("CONVERT(UNIQUEIDENTIFIER, '{}')", value),
502 DatabaseType::SQLite => format!("'{}'", value),
503 }
504 }
505
506 pub fn is_valid_sql(col: &str, db_type: DatabaseType) -> String {
508 match db_type {
509 DatabaseType::PostgreSQL => format!(
510 "{} ~ '^[0-9a-f]{{8}}-[0-9a-f]{{4}}-[0-9a-f]{{4}}-[0-9a-f]{{4}}-[0-9a-f]{{12}}$'",
511 col
512 ),
513 DatabaseType::MySQL => format!(
514 "{} REGEXP '^[0-9a-f]{{8}}-[0-9a-f]{{4}}-[0-9a-f]{{4}}-[0-9a-f]{{4}}-[0-9a-f]{{12}}$'",
515 col
516 ),
517 _ => format!("LEN({}) = 36", col),
518 }
519 }
520}
521
522pub mod crypto {
528 use super::*;
529
530 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
532 pub enum HashAlgorithm {
533 Md5,
534 Sha1,
535 Sha256,
536 Sha384,
537 Sha512,
538 }
539
540 impl HashAlgorithm {
541 pub fn postgres_name(&self) -> &'static str {
543 match self {
544 Self::Md5 => "md5",
545 Self::Sha1 => "sha1",
546 Self::Sha256 => "sha256",
547 Self::Sha384 => "sha384",
548 Self::Sha512 => "sha512",
549 }
550 }
551 }
552
553 pub fn hash_sql(expr: &str, algorithm: HashAlgorithm, db_type: DatabaseType) -> String {
555 match db_type {
556 DatabaseType::PostgreSQL => {
557 if algorithm == HashAlgorithm::Md5 {
558 format!("md5({})", expr)
559 } else {
560 format!(
561 "encode(digest({}, '{}'), 'hex')",
562 expr,
563 algorithm.postgres_name()
564 )
565 }
566 }
567 DatabaseType::MySQL => match algorithm {
568 HashAlgorithm::Md5 => format!("MD5({})", expr),
569 HashAlgorithm::Sha1 => format!("SHA1({})", expr),
570 HashAlgorithm::Sha256 => format!("SHA2({}, 256)", expr),
571 HashAlgorithm::Sha384 => format!("SHA2({}, 384)", expr),
572 HashAlgorithm::Sha512 => format!("SHA2({}, 512)", expr),
573 },
574 DatabaseType::MSSQL => {
575 let algo = match algorithm {
576 HashAlgorithm::Md5 => "MD5",
577 HashAlgorithm::Sha1 => "SHA1",
578 HashAlgorithm::Sha256 => "SHA2_256",
579 HashAlgorithm::Sha384 => "SHA2_384",
580 HashAlgorithm::Sha512 => "SHA2_512",
581 };
582 format!("CONVERT(VARCHAR(MAX), HASHBYTES('{}', {}), 2)", algo, expr)
583 }
584 DatabaseType::SQLite => {
585 format!("-- SQLite requires extension for hashing: {}", expr)
587 }
588 }
589 }
590
591 pub fn bcrypt_hash_postgres(password: &str) -> String {
593 format!("crypt('{}', gen_salt('bf'))", password)
594 }
595
596 pub fn bcrypt_verify_postgres(password: &str, hash_col: &str) -> String {
598 format!("{} = crypt('{}', {})", hash_col, password, hash_col)
599 }
600
601 pub fn random_bytes_sql(length: usize, db_type: DatabaseType) -> String {
603 match db_type {
604 DatabaseType::PostgreSQL => format!("gen_random_bytes({})", length),
605 DatabaseType::MySQL => format!("RANDOM_BYTES({})", length),
606 DatabaseType::MSSQL => format!("CRYPT_GEN_RANDOM({})", length),
607 DatabaseType::SQLite => format!("randomblob({})", length),
608 }
609 }
610
611 pub fn aes_encrypt_postgres(data: &str, key: &str) -> String {
613 format!("pgp_sym_encrypt('{}', '{}')", data, key)
614 }
615
616 pub fn aes_decrypt_postgres(encrypted_col: &str, key: &str) -> String {
618 format!("pgp_sym_decrypt({}, '{}')", encrypted_col, key)
619 }
620}
621
622pub mod vector {
628 use super::*;
629
630 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
632 pub struct Vector {
633 pub dimensions: Vec<f32>,
635 }
636
637 impl Vector {
638 pub fn new(dimensions: Vec<f32>) -> Self {
640 Self { dimensions }
641 }
642
643 pub fn from_slice(slice: &[f32]) -> Self {
645 Self {
646 dimensions: slice.to_vec(),
647 }
648 }
649
650 pub fn len(&self) -> usize {
652 self.dimensions.len()
653 }
654
655 pub fn is_empty(&self) -> bool {
657 self.dimensions.is_empty()
658 }
659
660 pub fn to_pgvector(&self) -> String {
662 let nums: Vec<String> = self.dimensions.iter().map(|f| f.to_string()).collect();
663 format!("'[{}]'::vector", nums.join(","))
664 }
665
666 pub fn to_mongodb(&self) -> serde_json::Value {
668 serde_json::json!(self.dimensions)
669 }
670 }
671
672 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
674 pub enum SimilarityMetric {
675 L2,
677 InnerProduct,
679 Cosine,
681 }
682
683 impl SimilarityMetric {
684 pub fn postgres_operator(&self) -> &'static str {
686 match self {
687 Self::L2 => "<->",
688 Self::InnerProduct => "<#>",
689 Self::Cosine => "<=>",
690 }
691 }
692
693 pub fn mongodb_name(&self) -> &'static str {
695 match self {
696 Self::L2 => "euclidean",
697 Self::InnerProduct => "dotProduct",
698 Self::Cosine => "cosine",
699 }
700 }
701 }
702
703 pub fn similarity_search_postgres(
705 col: &str,
706 query_vector: &Vector,
707 metric: SimilarityMetric,
708 limit: usize,
709 ) -> String {
710 format!(
711 "SELECT *, {} {} {} AS distance FROM {{table}} ORDER BY distance LIMIT {}",
712 col,
713 metric.postgres_operator(),
714 query_vector.to_pgvector(),
715 limit
716 )
717 }
718
719 pub fn distance_sql(col: &str, query_vector: &Vector, metric: SimilarityMetric) -> String {
721 format!(
722 "{} {} {}",
723 col,
724 metric.postgres_operator(),
725 query_vector.to_pgvector()
726 )
727 }
728
729 pub fn create_index_postgres(
731 index_name: &str,
732 table: &str,
733 column: &str,
734 metric: SimilarityMetric,
735 lists: Option<usize>,
736 ) -> String {
737 let ops = match metric {
738 SimilarityMetric::L2 => "vector_l2_ops",
739 SimilarityMetric::InnerProduct => "vector_ip_ops",
740 SimilarityMetric::Cosine => "vector_cosine_ops",
741 };
742
743 let lists_clause = lists
744 .map(|l| format!(" WITH (lists = {})", l))
745 .unwrap_or_default();
746
747 format!(
748 "CREATE INDEX {} ON {} USING ivfflat ({} {}){}",
749 index_name, table, column, ops, lists_clause
750 )
751 }
752
753 pub fn create_hnsw_index_postgres(
755 index_name: &str,
756 table: &str,
757 column: &str,
758 metric: SimilarityMetric,
759 m: Option<usize>,
760 ef_construction: Option<usize>,
761 ) -> String {
762 let ops = match metric {
763 SimilarityMetric::L2 => "vector_l2_ops",
764 SimilarityMetric::InnerProduct => "vector_ip_ops",
765 SimilarityMetric::Cosine => "vector_cosine_ops",
766 };
767
768 let mut with_clauses = Vec::new();
769 if let Some(m_val) = m {
770 with_clauses.push(format!("m = {}", m_val));
771 }
772 if let Some(ef) = ef_construction {
773 with_clauses.push(format!("ef_construction = {}", ef));
774 }
775
776 let with_clause = if with_clauses.is_empty() {
777 String::new()
778 } else {
779 format!(" WITH ({})", with_clauses.join(", "))
780 };
781
782 format!(
783 "CREATE INDEX {} ON {} USING hnsw ({} {}){}",
784 index_name, table, column, ops, with_clause
785 )
786 }
787}
788
789pub mod mongodb {
791 use serde::{Deserialize, Serialize};
792 use serde_json::Value as JsonValue;
793
794 use super::vector::SimilarityMetric;
795
796 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
798 pub struct VectorSearch {
799 pub index: String,
801 pub path: String,
803 pub query_vector: Vec<f32>,
805 pub num_candidates: usize,
807 pub limit: usize,
809 pub filter: Option<JsonValue>,
811 }
812
813 impl VectorSearch {
814 pub fn new(
816 index: impl Into<String>,
817 path: impl Into<String>,
818 query: Vec<f32>,
819 ) -> VectorSearchBuilder {
820 VectorSearchBuilder::new(index, path, query)
821 }
822
823 pub fn to_stage(&self) -> JsonValue {
825 let mut search = serde_json::json!({
826 "index": self.index,
827 "path": self.path,
828 "queryVector": self.query_vector,
829 "numCandidates": self.num_candidates,
830 "limit": self.limit
831 });
832
833 if let Some(ref filter) = self.filter {
834 search["filter"] = filter.clone();
835 }
836
837 serde_json::json!({ "$vectorSearch": search })
838 }
839 }
840
841 #[derive(Debug, Clone)]
843 pub struct VectorSearchBuilder {
844 index: String,
845 path: String,
846 query_vector: Vec<f32>,
847 num_candidates: usize,
848 limit: usize,
849 filter: Option<JsonValue>,
850 }
851
852 impl VectorSearchBuilder {
853 pub fn new(index: impl Into<String>, path: impl Into<String>, query: Vec<f32>) -> Self {
855 Self {
856 index: index.into(),
857 path: path.into(),
858 query_vector: query,
859 num_candidates: 100,
860 limit: 10,
861 filter: None,
862 }
863 }
864
865 pub fn num_candidates(mut self, n: usize) -> Self {
867 self.num_candidates = n;
868 self
869 }
870
871 pub fn limit(mut self, n: usize) -> Self {
873 self.limit = n;
874 self
875 }
876
877 pub fn filter(mut self, filter: JsonValue) -> Self {
879 self.filter = Some(filter);
880 self
881 }
882
883 pub fn build(self) -> VectorSearch {
885 VectorSearch {
886 index: self.index,
887 path: self.path,
888 query_vector: self.query_vector,
889 num_candidates: self.num_candidates,
890 limit: self.limit,
891 filter: self.filter,
892 }
893 }
894 }
895
896 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
898 pub struct VectorIndex {
899 pub name: String,
901 pub collection: String,
903 pub fields: Vec<VectorField>,
905 }
906
907 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
909 pub struct VectorField {
910 pub path: String,
912 pub dimensions: usize,
914 pub similarity: String,
916 }
917
918 impl VectorIndex {
919 pub fn new(name: impl Into<String>, collection: impl Into<String>) -> VectorIndexBuilder {
921 VectorIndexBuilder::new(name, collection)
922 }
923
924 pub fn to_definition(&self) -> JsonValue {
926 let fields: Vec<JsonValue> = self
927 .fields
928 .iter()
929 .map(|f| {
930 serde_json::json!({
931 "type": "vector",
932 "path": f.path,
933 "numDimensions": f.dimensions,
934 "similarity": f.similarity
935 })
936 })
937 .collect();
938
939 serde_json::json!({
940 "name": self.name,
941 "type": "vectorSearch",
942 "fields": fields
943 })
944 }
945 }
946
947 #[derive(Debug, Clone)]
949 pub struct VectorIndexBuilder {
950 name: String,
951 collection: String,
952 fields: Vec<VectorField>,
953 }
954
955 impl VectorIndexBuilder {
956 pub fn new(name: impl Into<String>, collection: impl Into<String>) -> Self {
958 Self {
959 name: name.into(),
960 collection: collection.into(),
961 fields: Vec::new(),
962 }
963 }
964
965 pub fn field(
967 mut self,
968 path: impl Into<String>,
969 dimensions: usize,
970 similarity: SimilarityMetric,
971 ) -> Self {
972 self.fields.push(VectorField {
973 path: path.into(),
974 dimensions,
975 similarity: similarity.mongodb_name().to_string(),
976 });
977 self
978 }
979
980 pub fn build(self) -> VectorIndex {
982 VectorIndex {
983 name: self.name,
984 collection: self.collection,
985 fields: self.fields,
986 }
987 }
988 }
989
990 pub fn vector_search(index: &str, path: &str, query: Vec<f32>) -> VectorSearchBuilder {
992 VectorSearch::new(index, path, query)
993 }
994}
995
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every needle occurs somewhere in `haystack`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(
                haystack.contains(needle),
                "expected {:?} in {:?}",
                needle,
                haystack
            );
        }
    }

    #[test]
    fn test_extension_postgres() {
        let sql = Extension::new("postgis")
            .schema("public")
            .cascade()
            .build()
            .to_postgres_create();

        assert_contains_all(
            &sql,
            &[
                "CREATE EXTENSION IF NOT EXISTS \"postgis\"",
                "SCHEMA public",
                "CASCADE",
            ],
        );
    }

    #[test]
    fn test_extension_drop() {
        let sql = Extension::postgis().to_postgres_drop();
        assert_contains_all(&sql, &["DROP EXTENSION IF EXISTS \"postgis\""]);
    }

    #[test]
    fn test_point_postgis() {
        let sql = Point::wgs84(-122.4194, 37.7749).to_postgis();
        assert_contains_all(&sql, &["ST_SetSRID", "-122.4194", "37.7749", "4326"]);
    }

    #[test]
    fn test_point_geojson() {
        let geojson = Point::new(-122.4194, 37.7749).to_geojson();

        assert_eq!(geojson["type"], "Point");
        assert_eq!(geojson["coordinates"][0], -122.4194);
    }

    #[test]
    fn test_polygon_wkt() {
        let square = vec![
            (0.0, 0.0),
            (10.0, 0.0),
            (10.0, 10.0),
            (0.0, 10.0),
            (0.0, 0.0),
        ];

        assert!(Polygon::new(square).to_wkt().starts_with("POLYGON(("));
    }

    #[test]
    fn test_distance_sql() {
        let sql = geo::distance_sql("location", "target", DatabaseType::PostgreSQL);
        assert_contains_all(&sql, &["ST_Distance"]);
    }

    #[test]
    fn test_within_distance() {
        let origin = Point::wgs84(-122.4194, 37.7749);
        let sql = geo::within_distance_sql("location", &origin, 1000.0, DatabaseType::PostgreSQL);

        assert_contains_all(&sql, &["ST_DWithin", "1000"]);
    }

    #[test]
    fn test_uuid_generation() {
        let cases = [
            (DatabaseType::PostgreSQL, "gen_random_uuid()"),
            (DatabaseType::MySQL, "UUID()"),
            (DatabaseType::MSSQL, "NEWID()"),
        ];

        for &(db, expected) in &cases {
            assert_eq!(uuid::generate_v4(db), expected);
        }
    }

    #[test]
    fn test_hash_sql() {
        let sha256 = |db| crypto::hash_sql("password", crypto::HashAlgorithm::Sha256, db);

        let pg = sha256(DatabaseType::PostgreSQL);
        assert_contains_all(&pg, &["digest", "sha256"]);

        let mysql = sha256(DatabaseType::MySQL);
        assert_contains_all(&mysql, &["SHA2", "256"]);
    }

    #[test]
    fn test_vector_pgvector() {
        let sql = vector::Vector::new(vec![0.1, 0.2, 0.3, 0.4]).to_pgvector();
        assert_contains_all(&sql, &["'[0.1,0.2,0.3,0.4]'::vector"]);
    }

    #[test]
    fn test_vector_index() {
        let sql = vector::create_index_postgres(
            "embeddings_idx",
            "documents",
            "embedding",
            vector::SimilarityMetric::Cosine,
            Some(100),
        );

        assert_contains_all(
            &sql,
            &[
                "CREATE INDEX embeddings_idx",
                "USING ivfflat",
                "vector_cosine_ops",
                "lists = 100",
            ],
        );
    }

    #[test]
    fn test_hnsw_index() {
        let sql = vector::create_hnsw_index_postgres(
            "embeddings_hnsw",
            "documents",
            "embedding",
            vector::SimilarityMetric::L2,
            Some(16),
            Some(64),
        );

        assert_contains_all(&sql, &["USING hnsw", "m = 16", "ef_construction = 64"]);
    }

    mod mongodb_tests {
        use super::super::mongodb::*;
        use super::super::vector::SimilarityMetric;

        #[test]
        fn test_vector_search() {
            let stage = vector_search("vector_index", "embedding", vec![0.1, 0.2, 0.3])
                .num_candidates(200)
                .limit(20)
                .build()
                .to_stage();

            let inner = &stage["$vectorSearch"];
            assert!(inner["index"].is_string());
            assert_eq!(inner["numCandidates"], 200);
            assert_eq!(inner["limit"], 20);
        }

        #[test]
        fn test_vector_index_definition() {
            let def = VectorIndex::new("my_vector_index", "documents")
                .field("embedding", 1536, SimilarityMetric::Cosine)
                .build()
                .to_definition();

            assert_eq!(def["name"], "my_vector_index");
            assert_eq!(def["type"], "vectorSearch");
            assert!(def["fields"].is_array());
        }
    }
}