Skip to main content

prax_query/
extension.rs

1//! Database extensions and plugins support.
2//!
3//! This module provides types for managing database extensions and
4//! specialized functionality like geospatial, UUID, cryptography, and vector search.
5//!
6//! # Database Support
7//!
8//! | Feature        | PostgreSQL       | MySQL      | SQLite        | MSSQL      | MongoDB        |
9//! |----------------|------------------|------------|---------------|------------|----------------|
10//! | Extensions     | ✅ CREATE EXT    | ❌         | ✅ load_ext   | ❌         | ❌             |
11//! | Geospatial     | ✅ PostGIS       | ✅ Spatial | ✅ SpatiaLite | ✅         | ✅ GeoJSON     |
12//! | UUID           | ✅ uuid-ossp     | ✅ built-in| ❌            | ✅ NEWID() | ✅ UUID()      |
13//! | Cryptography   | ✅ pgcrypto      | ✅ built-in| ❌            | ✅         | ✅             |
14//! | Vector Search  | ✅ pgvector      | ❌         | ❌            | ❌         | ✅ Atlas Vector|
15
16use serde::{Deserialize, Serialize};
17
18use crate::sql::DatabaseType;
19
20// ============================================================================
21// Extension Management
22// ============================================================================
23
/// A database extension, described by its name and optional install settings.
///
/// [`Extension::new`] returns an [`ExtensionBuilder`] for constructing one.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Extension {
    /// Extension name, interpolated into the generated SQL.
    pub name: String,
    /// Schema to install in (PostgreSQL); `None` omits the `SCHEMA` clause.
    pub schema: Option<String>,
    /// Version to install; `None` omits the `VERSION` clause.
    pub version: Option<String>,
    /// Whether to append `CASCADE` to the generated CREATE/DROP statements.
    pub cascade: bool,
}
36
37impl Extension {
38    /// Create a new extension.
39    pub fn new(name: impl Into<String>) -> ExtensionBuilder {
40        ExtensionBuilder::new(name)
41    }
42
43    /// Common PostgreSQL extensions.
44    pub fn postgis() -> Self {
45        Self::new("postgis").build()
46    }
47
48    pub fn pgvector() -> Self {
49        Self::new("vector").build()
50    }
51
52    pub fn uuid_ossp() -> Self {
53        Self::new("uuid-ossp").build()
54    }
55
56    pub fn pgcrypto() -> Self {
57        Self::new("pgcrypto").build()
58    }
59
60    pub fn pg_trgm() -> Self {
61        Self::new("pg_trgm").build()
62    }
63
64    pub fn hstore() -> Self {
65        Self::new("hstore").build()
66    }
67
68    pub fn ltree() -> Self {
69        Self::new("ltree").build()
70    }
71
72    /// Generate PostgreSQL CREATE EXTENSION SQL.
73    pub fn to_postgres_create(&self) -> String {
74        let mut sql = format!("CREATE EXTENSION IF NOT EXISTS \"{}\"", self.name);
75
76        if let Some(ref schema) = self.schema {
77            sql.push_str(&format!(" SCHEMA {}", schema));
78        }
79
80        if let Some(ref version) = self.version {
81            sql.push_str(&format!(" VERSION '{}'", version));
82        }
83
84        if self.cascade {
85            sql.push_str(" CASCADE");
86        }
87
88        sql
89    }
90
91    /// Generate DROP EXTENSION SQL.
92    pub fn to_postgres_drop(&self) -> String {
93        let mut sql = format!("DROP EXTENSION IF EXISTS \"{}\"", self.name);
94        if self.cascade {
95            sql.push_str(" CASCADE");
96        }
97        sql
98    }
99
100    /// Generate SQLite load extension command.
101    pub fn to_sqlite_load(&self) -> String {
102        format!("SELECT load_extension('{}')", self.name)
103    }
104}
105
/// Builder for [`Extension`] values.
///
/// Holds the same four settings as [`Extension`] until `build` is called.
#[derive(Debug, Clone)]
pub struct ExtensionBuilder {
    // Extension name supplied to `new`.
    name: String,
    // Target schema, if one was set.
    schema: Option<String>,
    // Requested version, if one was set.
    version: Option<String>,
    // Whether CASCADE was requested.
    cascade: bool,
}
114
115impl ExtensionBuilder {
116    /// Create a new builder.
117    pub fn new(name: impl Into<String>) -> Self {
118        Self {
119            name: name.into(),
120            schema: None,
121            version: None,
122            cascade: false,
123        }
124    }
125
126    /// Set the schema.
127    pub fn schema(mut self, schema: impl Into<String>) -> Self {
128        self.schema = Some(schema.into());
129        self
130    }
131
132    /// Set the version.
133    pub fn version(mut self, version: impl Into<String>) -> Self {
134        self.version = Some(version.into());
135        self
136    }
137
138    /// Enable CASCADE.
139    pub fn cascade(mut self) -> Self {
140        self.cascade = true;
141        self
142    }
143
144    /// Build the extension.
145    pub fn build(self) -> Extension {
146        Extension {
147            name: self.name,
148            schema: self.schema,
149            version: self.version,
150            cascade: self.cascade,
151        }
152    }
153}
154
155// ============================================================================
156// Geospatial Types
157// ============================================================================
158
/// A geographic point stored as (longitude, latitude) with an optional SRID.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct Point {
    /// Longitude (-180 to 180). The range is documented, not enforced here.
    pub longitude: f64,
    /// Latitude (-90 to 90). The range is documented, not enforced here.
    pub latitude: f64,
    /// Optional SRID (spatial reference ID); `None` when unspecified.
    pub srid: Option<i32>,
}
169
170impl Point {
171    /// Create a new point.
172    pub fn new(longitude: f64, latitude: f64) -> Self {
173        Self {
174            longitude,
175            latitude,
176            srid: None,
177        }
178    }
179
180    /// Create with SRID.
181    pub fn with_srid(longitude: f64, latitude: f64, srid: i32) -> Self {
182        Self {
183            longitude,
184            latitude,
185            srid: Some(srid),
186        }
187    }
188
189    /// WGS84 SRID (standard GPS).
190    pub fn wgs84(longitude: f64, latitude: f64) -> Self {
191        Self::with_srid(longitude, latitude, 4326)
192    }
193
194    /// Generate PostGIS point.
195    pub fn to_postgis(&self) -> String {
196        if let Some(srid) = self.srid {
197            format!(
198                "ST_SetSRID(ST_MakePoint({}, {}), {})",
199                self.longitude, self.latitude, srid
200            )
201        } else {
202            format!("ST_MakePoint({}, {})", self.longitude, self.latitude)
203        }
204    }
205
206    /// Generate MySQL spatial point.
207    pub fn to_mysql(&self) -> String {
208        if let Some(srid) = self.srid {
209            format!(
210                "ST_GeomFromText('POINT({} {})', {})",
211                self.longitude, self.latitude, srid
212            )
213        } else {
214            format!(
215                "ST_GeomFromText('POINT({} {})')",
216                self.longitude, self.latitude
217            )
218        }
219    }
220
221    /// Generate MSSQL geography point.
222    pub fn to_mssql(&self) -> String {
223        format!(
224            "geography::Point({}, {}, {})",
225            self.latitude,
226            self.longitude,
227            self.srid.unwrap_or(4326)
228        )
229    }
230
231    /// Generate GeoJSON.
232    pub fn to_geojson(&self) -> serde_json::Value {
233        serde_json::json!({
234            "type": "Point",
235            "coordinates": [self.longitude, self.latitude]
236        })
237    }
238
239    /// Generate SQL for database type.
240    pub fn to_sql(&self, db_type: DatabaseType) -> String {
241        match db_type {
242            DatabaseType::PostgreSQL => self.to_postgis(),
243            DatabaseType::MySQL => self.to_mysql(),
244            DatabaseType::MSSQL => self.to_mssql(),
245            DatabaseType::SQLite => format!("MakePoint({}, {})", self.longitude, self.latitude),
246        }
247    }
248}
249
/// A polygon: one exterior ring plus zero or more interior rings (holes).
///
/// Coordinates are written out as `x y` pairs in the order given; nothing in
/// this type closes or validates the rings.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Polygon {
    /// Exterior ring coordinates.
    pub exterior: Vec<(f64, f64)>,
    /// Interior rings (holes).
    pub interiors: Vec<Vec<(f64, f64)>>,
    /// SRID; `None` when unspecified.
    pub srid: Option<i32>,
}
260
261impl Polygon {
262    /// Create a new polygon from coordinates.
263    pub fn new(exterior: Vec<(f64, f64)>) -> Self {
264        Self {
265            exterior,
266            interiors: Vec::new(),
267            srid: None,
268        }
269    }
270
271    /// Add an interior ring (hole).
272    pub fn with_hole(mut self, hole: Vec<(f64, f64)>) -> Self {
273        self.interiors.push(hole);
274        self
275    }
276
277    /// Set SRID.
278    pub fn with_srid(mut self, srid: i32) -> Self {
279        self.srid = Some(srid);
280        self
281    }
282
283    /// Generate WKT (Well-Known Text).
284    pub fn to_wkt(&self) -> String {
285        let ext_coords: Vec<String> = self
286            .exterior
287            .iter()
288            .map(|(x, y)| format!("{} {}", x, y))
289            .collect();
290
291        let mut wkt = format!("POLYGON(({})", ext_coords.join(", "));
292
293        for interior in &self.interiors {
294            let int_coords: Vec<String> = interior
295                .iter()
296                .map(|(x, y)| format!("{} {}", x, y))
297                .collect();
298            wkt.push_str(&format!(", ({})", int_coords.join(", ")));
299        }
300
301        wkt.push(')');
302        wkt
303    }
304
305    /// Generate PostGIS polygon.
306    pub fn to_postgis(&self) -> String {
307        if let Some(srid) = self.srid {
308            format!("ST_GeomFromText('{}', {})", self.to_wkt(), srid)
309        } else {
310            format!("ST_GeomFromText('{}')", self.to_wkt())
311        }
312    }
313
314    /// Generate GeoJSON.
315    pub fn to_geojson(&self) -> serde_json::Value {
316        let mut coordinates = vec![
317            self.exterior
318                .iter()
319                .map(|(x, y)| vec![*x, *y])
320                .collect::<Vec<_>>(),
321        ];
322
323        for interior in &self.interiors {
324            coordinates.push(interior.iter().map(|(x, y)| vec![*x, *y]).collect());
325        }
326
327        serde_json::json!({
328            "type": "Polygon",
329            "coordinates": coordinates
330        })
331    }
332}
333
334/// Geospatial operations.
335pub mod geo {
336    use super::*;
337
    /// Units in which a distance can be expressed.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
    pub enum DistanceUnit {
        /// Meters.
        Meters,
        /// Kilometers.
        Kilometers,
        /// Miles.
        Miles,
        /// Feet.
        Feet,
    }
350
351    impl DistanceUnit {
352        /// Conversion factor from meters.
353        pub fn from_meters(&self) -> f64 {
354            match self {
355                Self::Meters => 1.0,
356                Self::Kilometers => 0.001,
357                Self::Miles => 0.000621371,
358                Self::Feet => 3.28084,
359            }
360        }
361    }
362
363    /// Generate distance SQL between two columns.
364    pub fn distance_sql(col1: &str, col2: &str, db_type: DatabaseType) -> String {
365        match db_type {
366            DatabaseType::PostgreSQL => {
367                format!("ST_Distance({}::geography, {}::geography)", col1, col2)
368            }
369            DatabaseType::MySQL => format!("ST_Distance_Sphere({}, {})", col1, col2),
370            DatabaseType::MSSQL => format!("{}.STDistance({})", col1, col2),
371            DatabaseType::SQLite => format!("Distance({}, {})", col1, col2),
372        }
373    }
374
375    /// Generate distance from point SQL.
376    pub fn distance_from_point_sql(col: &str, point: &Point, db_type: DatabaseType) -> String {
377        let point_sql = point.to_sql(db_type);
378        match db_type {
379            DatabaseType::PostgreSQL => {
380                format!("ST_Distance({}::geography, {}::geography)", col, point_sql)
381            }
382            DatabaseType::MySQL => format!("ST_Distance_Sphere({}, {})", col, point_sql),
383            DatabaseType::MSSQL => format!("{}.STDistance({})", col, point_sql),
384            DatabaseType::SQLite => format!("Distance({}, {})", col, point_sql),
385        }
386    }
387
388    /// Generate "within distance" filter SQL.
389    pub fn within_distance_sql(
390        col: &str,
391        point: &Point,
392        distance_meters: f64,
393        db_type: DatabaseType,
394    ) -> String {
395        let point_sql = point.to_sql(db_type);
396        match db_type {
397            DatabaseType::PostgreSQL => {
398                format!(
399                    "ST_DWithin({}::geography, {}::geography, {})",
400                    col, point_sql, distance_meters
401                )
402            }
403            DatabaseType::MySQL => {
404                format!(
405                    "ST_Distance_Sphere({}, {}) <= {}",
406                    col, point_sql, distance_meters
407                )
408            }
409            DatabaseType::MSSQL => {
410                format!("{}.STDistance({}) <= {}", col, point_sql, distance_meters)
411            }
412            DatabaseType::SQLite => {
413                format!("Distance({}, {}) <= {}", col, point_sql, distance_meters)
414            }
415        }
416    }
417
418    /// Generate "contains" filter SQL.
419    pub fn contains_sql(geom_col: &str, point: &Point, db_type: DatabaseType) -> String {
420        let point_sql = point.to_sql(db_type);
421        match db_type {
422            DatabaseType::PostgreSQL => format!("ST_Contains({}, {})", geom_col, point_sql),
423            DatabaseType::MySQL => format!("ST_Contains({}, {})", geom_col, point_sql),
424            DatabaseType::MSSQL => format!("{}.STContains({})", geom_col, point_sql),
425            DatabaseType::SQLite => format!("Contains({}, {})", geom_col, point_sql),
426        }
427    }
428
429    /// Generate bounding box filter SQL.
430    pub fn bbox_sql(
431        col: &str,
432        min_lon: f64,
433        min_lat: f64,
434        max_lon: f64,
435        max_lat: f64,
436        db_type: DatabaseType,
437    ) -> String {
438        match db_type {
439            DatabaseType::PostgreSQL => {
440                format!(
441                    "{} && ST_MakeEnvelope({}, {}, {}, {}, 4326)",
442                    col, min_lon, min_lat, max_lon, max_lat
443                )
444            }
445            DatabaseType::MySQL => {
446                format!(
447                    "MBRContains(ST_GeomFromText('POLYGON(({} {}, {} {}, {} {}, {} {}, {} {}))'), {})",
448                    min_lon,
449                    min_lat,
450                    max_lon,
451                    min_lat,
452                    max_lon,
453                    max_lat,
454                    min_lon,
455                    max_lat,
456                    min_lon,
457                    min_lat,
458                    col
459                )
460            }
461            _ => "1=1".to_string(),
462        }
463    }
464}
465
466// ============================================================================
467// UUID Support
468// ============================================================================
469
470/// UUID generation helpers.
471pub mod uuid {
472    use super::*;
473
474    /// Generate UUID v4 SQL.
475    pub fn generate_v4(db_type: DatabaseType) -> String {
476        match db_type {
477            DatabaseType::PostgreSQL => "gen_random_uuid()".to_string(),
478            DatabaseType::MySQL => "UUID()".to_string(),
479            DatabaseType::MSSQL => "NEWID()".to_string(),
480            DatabaseType::SQLite => {
481                // SQLite needs custom function or hex/randomblob
482                "lower(hex(randomblob(4))) || '-' || lower(hex(randomblob(2))) || '-4' || \
483                 substr(lower(hex(randomblob(2))), 2) || '-' || \
484                 substr('89ab', abs(random()) % 4 + 1, 1) || \
485                 substr(lower(hex(randomblob(2))), 2) || '-' || lower(hex(randomblob(6)))"
486                    .to_string()
487            }
488        }
489    }
490
491    /// Generate UUID v7 SQL (PostgreSQL with uuid-ossp or pg_uuidv7).
492    pub fn generate_v7_postgres() -> String {
493        "uuid_generate_v7()".to_string()
494    }
495
496    /// Generate UUID from string SQL.
497    pub fn from_string(value: &str, db_type: DatabaseType) -> String {
498        match db_type {
499            DatabaseType::PostgreSQL => format!("'{}'::uuid", value),
500            DatabaseType::MySQL => format!("UUID_TO_BIN('{}')", value),
501            DatabaseType::MSSQL => format!("CONVERT(UNIQUEIDENTIFIER, '{}')", value),
502            DatabaseType::SQLite => format!("'{}'", value),
503        }
504    }
505
506    /// Check if valid UUID SQL.
507    pub fn is_valid_sql(col: &str, db_type: DatabaseType) -> String {
508        match db_type {
509            DatabaseType::PostgreSQL => format!(
510                "{} ~ '^[0-9a-f]{{8}}-[0-9a-f]{{4}}-[0-9a-f]{{4}}-[0-9a-f]{{4}}-[0-9a-f]{{12}}$'",
511                col
512            ),
513            DatabaseType::MySQL => format!(
514                "{} REGEXP '^[0-9a-f]{{8}}-[0-9a-f]{{4}}-[0-9a-f]{{4}}-[0-9a-f]{{4}}-[0-9a-f]{{12}}$'",
515                col
516            ),
517            _ => format!("LEN({}) = 36", col),
518        }
519    }
520}
521
522// ============================================================================
523// Cryptography
524// ============================================================================
525
526/// Cryptographic functions.
527pub mod crypto {
528    use super::*;
529
    /// Hash algorithms that `hash_sql` can emit SQL for.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
    pub enum HashAlgorithm {
        /// MD5.
        Md5,
        /// SHA-1.
        Sha1,
        /// SHA-256.
        Sha256,
        /// SHA-384.
        Sha384,
        /// SHA-512.
        Sha512,
    }
539
540    impl HashAlgorithm {
541        /// PostgreSQL algorithm name.
542        pub fn postgres_name(&self) -> &'static str {
543            match self {
544                Self::Md5 => "md5",
545                Self::Sha1 => "sha1",
546                Self::Sha256 => "sha256",
547                Self::Sha384 => "sha384",
548                Self::Sha512 => "sha512",
549            }
550        }
551    }
552
553    /// Generate hash SQL.
554    pub fn hash_sql(expr: &str, algorithm: HashAlgorithm, db_type: DatabaseType) -> String {
555        match db_type {
556            DatabaseType::PostgreSQL => {
557                if algorithm == HashAlgorithm::Md5 {
558                    format!("md5({})", expr)
559                } else {
560                    format!(
561                        "encode(digest({}, '{}'), 'hex')",
562                        expr,
563                        algorithm.postgres_name()
564                    )
565                }
566            }
567            DatabaseType::MySQL => match algorithm {
568                HashAlgorithm::Md5 => format!("MD5({})", expr),
569                HashAlgorithm::Sha1 => format!("SHA1({})", expr),
570                HashAlgorithm::Sha256 => format!("SHA2({}, 256)", expr),
571                HashAlgorithm::Sha384 => format!("SHA2({}, 384)", expr),
572                HashAlgorithm::Sha512 => format!("SHA2({}, 512)", expr),
573            },
574            DatabaseType::MSSQL => {
575                let algo = match algorithm {
576                    HashAlgorithm::Md5 => "MD5",
577                    HashAlgorithm::Sha1 => "SHA1",
578                    HashAlgorithm::Sha256 => "SHA2_256",
579                    HashAlgorithm::Sha384 => "SHA2_384",
580                    HashAlgorithm::Sha512 => "SHA2_512",
581                };
582                format!("CONVERT(VARCHAR(MAX), HASHBYTES('{}', {}), 2)", algo, expr)
583            }
584            DatabaseType::SQLite => {
585                // SQLite doesn't have built-in hashing
586                format!("-- SQLite requires extension for hashing: {}", expr)
587            }
588        }
589    }
590
591    /// Generate bcrypt hash SQL (PostgreSQL with pgcrypto).
592    pub fn bcrypt_hash_postgres(password: &str) -> String {
593        format!("crypt('{}', gen_salt('bf'))", password)
594    }
595
596    /// Generate bcrypt verify SQL (PostgreSQL).
597    pub fn bcrypt_verify_postgres(password: &str, hash_col: &str) -> String {
598        format!("{} = crypt('{}', {})", hash_col, password, hash_col)
599    }
600
601    /// Generate random bytes SQL.
602    pub fn random_bytes_sql(length: usize, db_type: DatabaseType) -> String {
603        match db_type {
604            DatabaseType::PostgreSQL => format!("gen_random_bytes({})", length),
605            DatabaseType::MySQL => format!("RANDOM_BYTES({})", length),
606            DatabaseType::MSSQL => format!("CRYPT_GEN_RANDOM({})", length),
607            DatabaseType::SQLite => format!("randomblob({})", length),
608        }
609    }
610
611    /// Generate AES encrypt SQL (PostgreSQL with pgcrypto).
612    pub fn aes_encrypt_postgres(data: &str, key: &str) -> String {
613        format!("pgp_sym_encrypt('{}', '{}')", data, key)
614    }
615
616    /// Generate AES decrypt SQL (PostgreSQL with pgcrypto).
617    pub fn aes_decrypt_postgres(encrypted_col: &str, key: &str) -> String {
618        format!("pgp_sym_decrypt({}, '{}')", encrypted_col, key)
619    }
620}
621
622// ============================================================================
623// Vector / Embeddings (pgvector, MongoDB Atlas Vector)
624// ============================================================================
625
626/// Vector operations for AI/ML embeddings.
627pub mod vector {
628    use super::*;
629
    /// A vector embedding: an ordered list of `f32` components.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    pub struct Vector {
        /// The vector's component values, one entry per dimension.
        pub dimensions: Vec<f32>,
    }
636
637    impl Vector {
638        /// Create a new vector.
639        pub fn new(dimensions: Vec<f32>) -> Self {
640            Self { dimensions }
641        }
642
643        /// Create from slice.
644        pub fn from_slice(slice: &[f32]) -> Self {
645            Self {
646                dimensions: slice.to_vec(),
647            }
648        }
649
650        /// Get dimension count.
651        pub fn len(&self) -> usize {
652            self.dimensions.len()
653        }
654
655        /// Check if empty.
656        pub fn is_empty(&self) -> bool {
657            self.dimensions.is_empty()
658        }
659
660        /// Generate PostgreSQL pgvector literal.
661        pub fn to_pgvector(&self) -> String {
662            let nums: Vec<String> = self.dimensions.iter().map(|f| f.to_string()).collect();
663            format!("'[{}]'::vector", nums.join(","))
664        }
665
666        /// Generate MongoDB array.
667        pub fn to_mongodb(&self) -> serde_json::Value {
668            serde_json::json!(self.dimensions)
669        }
670    }
671
    /// Vector similarity metrics.
    ///
    /// Each variant maps to a PostgreSQL pgvector operator and to a MongoDB
    /// `$vectorSearch` similarity name.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
    pub enum SimilarityMetric {
        /// Euclidean distance (L2).
        L2,
        /// Inner product.
        InnerProduct,
        /// Cosine.
        Cosine,
    }
682
683    impl SimilarityMetric {
684        /// PostgreSQL operator.
685        pub fn postgres_operator(&self) -> &'static str {
686            match self {
687                Self::L2 => "<->",
688                Self::InnerProduct => "<#>",
689                Self::Cosine => "<=>",
690            }
691        }
692
693        /// MongoDB $vectorSearch similarity.
694        pub fn mongodb_name(&self) -> &'static str {
695            match self {
696                Self::L2 => "euclidean",
697                Self::InnerProduct => "dotProduct",
698                Self::Cosine => "cosine",
699            }
700        }
701    }
702
703    /// Generate vector similarity search SQL (PostgreSQL pgvector).
704    pub fn similarity_search_postgres(
705        col: &str,
706        query_vector: &Vector,
707        metric: SimilarityMetric,
708        limit: usize,
709    ) -> String {
710        format!(
711            "SELECT *, {} {} {} AS distance FROM {{table}} ORDER BY distance LIMIT {}",
712            col,
713            metric.postgres_operator(),
714            query_vector.to_pgvector(),
715            limit
716        )
717    }
718
719    /// Generate vector distance SQL.
720    pub fn distance_sql(col: &str, query_vector: &Vector, metric: SimilarityMetric) -> String {
721        format!(
722            "{} {} {}",
723            col,
724            metric.postgres_operator(),
725            query_vector.to_pgvector()
726        )
727    }
728
729    /// Generate vector index SQL (PostgreSQL).
730    pub fn create_index_postgres(
731        index_name: &str,
732        table: &str,
733        column: &str,
734        metric: SimilarityMetric,
735        lists: Option<usize>,
736    ) -> String {
737        let ops = match metric {
738            SimilarityMetric::L2 => "vector_l2_ops",
739            SimilarityMetric::InnerProduct => "vector_ip_ops",
740            SimilarityMetric::Cosine => "vector_cosine_ops",
741        };
742
743        let lists_clause = lists
744            .map(|l| format!(" WITH (lists = {})", l))
745            .unwrap_or_default();
746
747        format!(
748            "CREATE INDEX {} ON {} USING ivfflat ({} {}){}",
749            index_name, table, column, ops, lists_clause
750        )
751    }
752
753    /// Create HNSW index (PostgreSQL pgvector 0.5+).
754    pub fn create_hnsw_index_postgres(
755        index_name: &str,
756        table: &str,
757        column: &str,
758        metric: SimilarityMetric,
759        m: Option<usize>,
760        ef_construction: Option<usize>,
761    ) -> String {
762        let ops = match metric {
763            SimilarityMetric::L2 => "vector_l2_ops",
764            SimilarityMetric::InnerProduct => "vector_ip_ops",
765            SimilarityMetric::Cosine => "vector_cosine_ops",
766        };
767
768        let mut with_clauses = Vec::new();
769        if let Some(m_val) = m {
770            with_clauses.push(format!("m = {}", m_val));
771        }
772        if let Some(ef) = ef_construction {
773            with_clauses.push(format!("ef_construction = {}", ef));
774        }
775
776        let with_clause = if with_clauses.is_empty() {
777            String::new()
778        } else {
779            format!(" WITH ({})", with_clauses.join(", "))
780        };
781
782        format!(
783            "CREATE INDEX {} ON {} USING hnsw ({} {}){}",
784            index_name, table, column, ops, with_clause
785        )
786    }
787}
788
789/// MongoDB Atlas Vector Search support.
790pub mod mongodb {
791    use serde::{Deserialize, Serialize};
792    use serde_json::Value as JsonValue;
793
794    use super::vector::SimilarityMetric;
795
    /// MongoDB Atlas Vector Search query, convertible to a `$vectorSearch`
    /// aggregation stage.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    pub struct VectorSearch {
        /// Index name.
        pub index: String,
        /// Path to vector field.
        pub path: String,
        /// Query vector.
        pub query_vector: Vec<f32>,
        /// Number of candidates, emitted as `numCandidates` in the stage.
        pub num_candidates: usize,
        /// Maximum number of results, emitted as `limit` in the stage.
        pub limit: usize,
        /// Optional filter document, emitted as `filter` when present.
        pub filter: Option<JsonValue>,
    }
812
813    impl VectorSearch {
814        /// Create a new vector search.
815        pub fn new(
816            index: impl Into<String>,
817            path: impl Into<String>,
818            query: Vec<f32>,
819        ) -> VectorSearchBuilder {
820            VectorSearchBuilder::new(index, path, query)
821        }
822
823        /// Convert to $vectorSearch stage.
824        pub fn to_stage(&self) -> JsonValue {
825            let mut search = serde_json::json!({
826                "index": self.index,
827                "path": self.path,
828                "queryVector": self.query_vector,
829                "numCandidates": self.num_candidates,
830                "limit": self.limit
831            });
832
833            if let Some(ref filter) = self.filter {
834                search["filter"] = filter.clone();
835            }
836
837            serde_json::json!({ "$vectorSearch": search })
838        }
839    }
840
    /// Builder for [`VectorSearch`].
    ///
    /// Mirrors the fields of [`VectorSearch`] until `build` is called.
    #[derive(Debug, Clone)]
    pub struct VectorSearchBuilder {
        // Index name.
        index: String,
        // Path to the vector field.
        path: String,
        // Query vector.
        query_vector: Vec<f32>,
        // Number of candidates.
        num_candidates: usize,
        // Maximum number of results.
        limit: usize,
        // Optional filter document.
        filter: Option<JsonValue>,
    }
851
852    impl VectorSearchBuilder {
853        /// Create a new builder.
854        pub fn new(index: impl Into<String>, path: impl Into<String>, query: Vec<f32>) -> Self {
855            Self {
856                index: index.into(),
857                path: path.into(),
858                query_vector: query,
859                num_candidates: 100,
860                limit: 10,
861                filter: None,
862            }
863        }
864
865        /// Set number of candidates.
866        pub fn num_candidates(mut self, n: usize) -> Self {
867            self.num_candidates = n;
868            self
869        }
870
871        /// Set limit.
872        pub fn limit(mut self, n: usize) -> Self {
873            self.limit = n;
874            self
875        }
876
877        /// Add filter.
878        pub fn filter(mut self, filter: JsonValue) -> Self {
879            self.filter = Some(filter);
880            self
881        }
882
883        /// Build the search.
884        pub fn build(self) -> VectorSearch {
885            VectorSearch {
886                index: self.index,
887                path: self.path,
888                query_vector: self.query_vector,
889                num_candidates: self.num_candidates,
890                limit: self.limit,
891                filter: self.filter,
892            }
893        }
894    }
895
    /// MongoDB Atlas Search index definition for vectors.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    pub struct VectorIndex {
        /// Index name.
        pub name: String,
        /// Collection name. Stored on the value but not emitted by
        /// `to_definition`.
        pub collection: String,
        /// Vector field definitions.
        pub fields: Vec<VectorField>,
    }
906
    /// Vector field definition within a [`VectorIndex`].
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    pub struct VectorField {
        /// Field path.
        pub path: String,
        /// Number of dimensions.
        pub dimensions: usize,
        /// Similarity metric name, stored as a string.
        pub similarity: String,
    }
917
918    impl VectorIndex {
919        /// Create a new vector index definition.
920        pub fn new(name: impl Into<String>, collection: impl Into<String>) -> VectorIndexBuilder {
921            VectorIndexBuilder::new(name, collection)
922        }
923
924        /// Convert to index definition.
925        pub fn to_definition(&self) -> JsonValue {
926            let fields: Vec<JsonValue> = self
927                .fields
928                .iter()
929                .map(|f| {
930                    serde_json::json!({
931                        "type": "vector",
932                        "path": f.path,
933                        "numDimensions": f.dimensions,
934                        "similarity": f.similarity
935                    })
936                })
937                .collect();
938
939            serde_json::json!({
940                "name": self.name,
941                "type": "vectorSearch",
942                "fields": fields
943            })
944        }
945    }
946
    /// Builder for [`VectorIndex`].
    #[derive(Debug, Clone)]
    pub struct VectorIndexBuilder {
        // Index name.
        name: String,
        // Collection name.
        collection: String,
        // Vector fields added so far, in insertion order.
        fields: Vec<VectorField>,
    }
954
955    impl VectorIndexBuilder {
956        /// Create a new builder.
957        pub fn new(name: impl Into<String>, collection: impl Into<String>) -> Self {
958            Self {
959                name: name.into(),
960                collection: collection.into(),
961                fields: Vec::new(),
962            }
963        }
964
965        /// Add a vector field.
966        pub fn field(
967            mut self,
968            path: impl Into<String>,
969            dimensions: usize,
970            similarity: SimilarityMetric,
971        ) -> Self {
972            self.fields.push(VectorField {
973                path: path.into(),
974                dimensions,
975                similarity: similarity.mongodb_name().to_string(),
976            });
977            self
978        }
979
980        /// Build the index.
981        pub fn build(self) -> VectorIndex {
982            VectorIndex {
983                name: self.name,
984                collection: self.collection,
985                fields: self.fields,
986            }
987        }
988    }
989
990    /// Helper to create a vector search.
991    pub fn vector_search(index: &str, path: &str, query: Vec<f32>) -> VectorSearchBuilder {
992        VectorSearch::new(index, path, query)
993    }
994}
995
#[cfg(test)]
mod tests {
    use super::*;

    /// `CREATE EXTENSION` output reflects the name, schema and cascade options.
    #[test]
    fn test_extension_postgres() {
        let ext = Extension::new("postgis").schema("public").cascade().build();
        let sql = ext.to_postgres_create();

        assert!(sql.contains("CREATE EXTENSION IF NOT EXISTS \"postgis\""));
        assert!(sql.contains("SCHEMA public"));
        assert!(sql.contains("CASCADE"));
    }

    /// `DROP EXTENSION` output names the extension.
    #[test]
    fn test_extension_drop() {
        let ext = Extension::postgis();
        let sql = ext.to_postgres_drop();

        assert!(sql.contains("DROP EXTENSION IF EXISTS \"postgis\""));
    }

    /// A WGS84 point renders both coordinates and the 4326 SRID for PostGIS.
    #[test]
    fn test_point_postgis() {
        let point = Point::wgs84(-122.4194, 37.7749);
        let sql = point.to_postgis();

        assert!(sql.contains("ST_SetSRID"));
        assert!(sql.contains("-122.4194"));
        assert!(sql.contains("37.7749"));
        assert!(sql.contains("4326"));
    }

    /// GeoJSON output is tagged as a `Point` and leads with the first coordinate.
    #[test]
    fn test_point_geojson() {
        let point = Point::new(-122.4194, 37.7749);
        let geojson = point.to_geojson();

        assert_eq!(geojson["type"], "Point");
        assert_eq!(geojson["coordinates"][0], -122.4194);
    }

    /// A closed ring renders as a WKT polygon.
    #[test]
    fn test_polygon_wkt() {
        let polygon = Polygon::new(vec![
            (0.0, 0.0),
            (10.0, 0.0),
            (10.0, 10.0),
            (0.0, 10.0),
            (0.0, 0.0),
        ]);

        let wkt = polygon.to_wkt();
        assert!(wkt.starts_with("POLYGON(("));
    }

    /// PostgreSQL distance SQL uses `ST_Distance`.
    #[test]
    fn test_distance_sql() {
        let sql = geo::distance_sql("location", "target", DatabaseType::PostgreSQL);
        assert!(sql.contains("ST_Distance"));
    }

    /// PostgreSQL radius filter uses `ST_DWithin` and carries the distance.
    #[test]
    fn test_within_distance() {
        let point = Point::wgs84(-122.4194, 37.7749);
        let sql = geo::within_distance_sql("location", &point, 1000.0, DatabaseType::PostgreSQL);

        assert!(sql.contains("ST_DWithin"));
        assert!(sql.contains("1000"));
    }

    /// Each SQL backend maps UUID v4 generation to its own function.
    #[test]
    fn test_uuid_generation() {
        let pg = uuid::generate_v4(DatabaseType::PostgreSQL);
        assert_eq!(pg, "gen_random_uuid()");

        let mysql = uuid::generate_v4(DatabaseType::MySQL);
        assert_eq!(mysql, "UUID()");

        let mssql = uuid::generate_v4(DatabaseType::MSSQL);
        assert_eq!(mssql, "NEWID()");
    }

    /// SHA-256 hashing renders backend-specific SQL.
    #[test]
    fn test_hash_sql() {
        let pg = crypto::hash_sql(
            "password",
            crypto::HashAlgorithm::Sha256,
            DatabaseType::PostgreSQL,
        );
        assert!(pg.contains("digest"));
        assert!(pg.contains("sha256"));

        let mysql = crypto::hash_sql(
            "password",
            crypto::HashAlgorithm::Sha256,
            DatabaseType::MySQL,
        );
        assert!(mysql.contains("SHA2"));
        assert!(mysql.contains("256"));
    }

    /// A vector renders as a pgvector literal with a `::vector` cast.
    #[test]
    fn test_vector_pgvector() {
        let vec = vector::Vector::new(vec![0.1, 0.2, 0.3, 0.4]);
        let sql = vec.to_pgvector();

        assert!(sql.contains("'[0.1,0.2,0.3,0.4]'::vector"));
    }

    /// IVFFlat index DDL carries the name, operator class and list count.
    #[test]
    fn test_vector_index() {
        let sql = vector::create_index_postgres(
            "embeddings_idx",
            "documents",
            "embedding",
            vector::SimilarityMetric::Cosine,
            Some(100),
        );

        assert!(sql.contains("CREATE INDEX embeddings_idx"));
        assert!(sql.contains("USING ivfflat"));
        assert!(sql.contains("vector_cosine_ops"));
        assert!(sql.contains("lists = 100"));
    }

    /// HNSW index DDL carries the `m` and `ef_construction` parameters.
    #[test]
    fn test_hnsw_index() {
        let sql = vector::create_hnsw_index_postgres(
            "embeddings_hnsw",
            "documents",
            "embedding",
            vector::SimilarityMetric::L2,
            Some(16),
            Some(64),
        );

        assert!(sql.contains("USING hnsw"));
        assert!(sql.contains("m = 16"));
        assert!(sql.contains("ef_construction = 64"));
    }

    mod mongodb_tests {
        use super::super::mongodb::*;
        use super::super::vector::SimilarityMetric;

        /// The `$vectorSearch` stage carries the index, candidate count and limit.
        #[test]
        fn test_vector_search() {
            let search = vector_search("vector_index", "embedding", vec![0.1, 0.2, 0.3])
                .num_candidates(200)
                .limit(20)
                .build();

            let stage = search.to_stage();
            assert!(stage["$vectorSearch"]["index"].is_string());
            assert_eq!(stage["$vectorSearch"]["numCandidates"], 200);
            assert_eq!(stage["$vectorSearch"]["limit"], 20);
        }

        /// The builder records the field, and the rendered definition exposes
        /// the name, type and the complete per-field payload.
        #[test]
        fn test_vector_index_definition() {
            let index = VectorIndex::new("my_vector_index", "documents")
                .field("embedding", 1536, SimilarityMetric::Cosine)
                .build();

            // The builder keeps the collection and stores the metric under
            // its MongoDB name.
            assert_eq!(index.collection, "documents");
            assert_eq!(index.fields.len(), 1);
            assert_eq!(
                index.fields[0].similarity,
                SimilarityMetric::Cosine.mongodb_name().to_string()
            );

            let def = index.to_definition();
            assert_eq!(def["name"], "my_vector_index");
            assert_eq!(def["type"], "vectorSearch");
            assert!(def["fields"].is_array());
            // The collection is not part of the rendered definition.
            assert!(def.get("collection").is_none());

            // Pin the per-field payload so a regression in any key is caught.
            let fields = def["fields"].as_array().expect("fields is an array");
            assert_eq!(fields.len(), 1);
            assert_eq!(fields[0]["type"], "vector");
            assert_eq!(fields[0]["path"], "embedding");
            assert_eq!(fields[0]["numDimensions"], 1536);
            assert_eq!(fields[0]["similarity"], index.fields[0].similarity);
        }
    }
}