prax_query/
extension.rs

1//! Database extensions and plugins support.
2//!
3//! This module provides types for managing database extensions and
4//! specialized functionality like geospatial, UUID, cryptography, and vector search.
5//!
6//! # Database Support
7//!
8//! | Feature        | PostgreSQL       | MySQL      | SQLite        | MSSQL      | MongoDB        |
9//! |----------------|------------------|------------|---------------|------------|----------------|
10//! | Extensions     | ✅ CREATE EXT    | ❌         | ✅ load_ext   | ❌         | ❌             |
11//! | Geospatial     | ✅ PostGIS       | ✅ Spatial | ✅ SpatiaLite | ✅         | ✅ GeoJSON     |
12//! | UUID           | ✅ uuid-ossp     | ✅ built-in| ❌            | ✅ NEWID() | ✅ UUID()      |
13//! | Cryptography   | ✅ pgcrypto      | ✅ built-in| ❌            | ✅         | ✅             |
14//! | Vector Search  | ✅ pgvector      | ❌         | ❌            | ❌         | ✅ Atlas Vector|
15
16use serde::{Deserialize, Serialize};
17
18use crate::sql::DatabaseType;
19
20// ============================================================================
21// Extension Management
22// ============================================================================
23
24/// A database extension.
25#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
26pub struct Extension {
27    /// Extension name.
28    pub name: String,
29    /// Schema to install in (PostgreSQL).
30    pub schema: Option<String>,
31    /// Version to install.
32    pub version: Option<String>,
33    /// Whether to cascade dependencies.
34    pub cascade: bool,
35}
36
37impl Extension {
38    /// Create a new extension.
39    pub fn new(name: impl Into<String>) -> ExtensionBuilder {
40        ExtensionBuilder::new(name)
41    }
42
43    /// Common PostgreSQL extensions.
44    pub fn postgis() -> Self {
45        Self::new("postgis").build()
46    }
47
48    pub fn pgvector() -> Self {
49        Self::new("vector").build()
50    }
51
52    pub fn uuid_ossp() -> Self {
53        Self::new("uuid-ossp").build()
54    }
55
56    pub fn pgcrypto() -> Self {
57        Self::new("pgcrypto").build()
58    }
59
60    pub fn pg_trgm() -> Self {
61        Self::new("pg_trgm").build()
62    }
63
64    pub fn hstore() -> Self {
65        Self::new("hstore").build()
66    }
67
68    pub fn ltree() -> Self {
69        Self::new("ltree").build()
70    }
71
72    /// Generate PostgreSQL CREATE EXTENSION SQL.
73    pub fn to_postgres_create(&self) -> String {
74        let mut sql = format!("CREATE EXTENSION IF NOT EXISTS \"{}\"", self.name);
75
76        if let Some(ref schema) = self.schema {
77            sql.push_str(&format!(" SCHEMA {}", schema));
78        }
79
80        if let Some(ref version) = self.version {
81            sql.push_str(&format!(" VERSION '{}'", version));
82        }
83
84        if self.cascade {
85            sql.push_str(" CASCADE");
86        }
87
88        sql
89    }
90
91    /// Generate DROP EXTENSION SQL.
92    pub fn to_postgres_drop(&self) -> String {
93        let mut sql = format!("DROP EXTENSION IF EXISTS \"{}\"", self.name);
94        if self.cascade {
95            sql.push_str(" CASCADE");
96        }
97        sql
98    }
99
100    /// Generate SQLite load extension command.
101    pub fn to_sqlite_load(&self) -> String {
102        format!("SELECT load_extension('{}')", self.name)
103    }
104}
105
106/// Builder for extensions.
107#[derive(Debug, Clone)]
108pub struct ExtensionBuilder {
109    name: String,
110    schema: Option<String>,
111    version: Option<String>,
112    cascade: bool,
113}
114
115impl ExtensionBuilder {
116    /// Create a new builder.
117    pub fn new(name: impl Into<String>) -> Self {
118        Self {
119            name: name.into(),
120            schema: None,
121            version: None,
122            cascade: false,
123        }
124    }
125
126    /// Set the schema.
127    pub fn schema(mut self, schema: impl Into<String>) -> Self {
128        self.schema = Some(schema.into());
129        self
130    }
131
132    /// Set the version.
133    pub fn version(mut self, version: impl Into<String>) -> Self {
134        self.version = Some(version.into());
135        self
136    }
137
138    /// Enable CASCADE.
139    pub fn cascade(mut self) -> Self {
140        self.cascade = true;
141        self
142    }
143
144    /// Build the extension.
145    pub fn build(self) -> Extension {
146        Extension {
147            name: self.name,
148            schema: self.schema,
149            version: self.version,
150            cascade: self.cascade,
151        }
152    }
153}
154
155// ============================================================================
156// Geospatial Types
157// ============================================================================
158
159/// A geographic point (longitude, latitude).
160#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
161pub struct Point {
162    /// Longitude (-180 to 180).
163    pub longitude: f64,
164    /// Latitude (-90 to 90).
165    pub latitude: f64,
166    /// Optional SRID (spatial reference ID).
167    pub srid: Option<i32>,
168}
169
170impl Point {
171    /// Create a new point.
172    pub fn new(longitude: f64, latitude: f64) -> Self {
173        Self {
174            longitude,
175            latitude,
176            srid: None,
177        }
178    }
179
180    /// Create with SRID.
181    pub fn with_srid(longitude: f64, latitude: f64, srid: i32) -> Self {
182        Self {
183            longitude,
184            latitude,
185            srid: Some(srid),
186        }
187    }
188
189    /// WGS84 SRID (standard GPS).
190    pub fn wgs84(longitude: f64, latitude: f64) -> Self {
191        Self::with_srid(longitude, latitude, 4326)
192    }
193
194    /// Generate PostGIS point.
195    pub fn to_postgis(&self) -> String {
196        if let Some(srid) = self.srid {
197            format!("ST_SetSRID(ST_MakePoint({}, {}), {})", self.longitude, self.latitude, srid)
198        } else {
199            format!("ST_MakePoint({}, {})", self.longitude, self.latitude)
200        }
201    }
202
203    /// Generate MySQL spatial point.
204    pub fn to_mysql(&self) -> String {
205        if let Some(srid) = self.srid {
206            format!("ST_GeomFromText('POINT({} {})', {})", self.longitude, self.latitude, srid)
207        } else {
208            format!("ST_GeomFromText('POINT({} {})')", self.longitude, self.latitude)
209        }
210    }
211
212    /// Generate MSSQL geography point.
213    pub fn to_mssql(&self) -> String {
214        format!(
215            "geography::Point({}, {}, {})",
216            self.latitude, self.longitude, self.srid.unwrap_or(4326)
217        )
218    }
219
220    /// Generate GeoJSON.
221    pub fn to_geojson(&self) -> serde_json::Value {
222        serde_json::json!({
223            "type": "Point",
224            "coordinates": [self.longitude, self.latitude]
225        })
226    }
227
228    /// Generate SQL for database type.
229    pub fn to_sql(&self, db_type: DatabaseType) -> String {
230        match db_type {
231            DatabaseType::PostgreSQL => self.to_postgis(),
232            DatabaseType::MySQL => self.to_mysql(),
233            DatabaseType::MSSQL => self.to_mssql(),
234            DatabaseType::SQLite => format!("MakePoint({}, {})", self.longitude, self.latitude),
235        }
236    }
237}
238
239/// A polygon (list of points forming a closed ring).
240#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
241pub struct Polygon {
242    /// Exterior ring coordinates.
243    pub exterior: Vec<(f64, f64)>,
244    /// Interior rings (holes).
245    pub interiors: Vec<Vec<(f64, f64)>>,
246    /// SRID.
247    pub srid: Option<i32>,
248}
249
250impl Polygon {
251    /// Create a new polygon from coordinates.
252    pub fn new(exterior: Vec<(f64, f64)>) -> Self {
253        Self {
254            exterior,
255            interiors: Vec::new(),
256            srid: None,
257        }
258    }
259
260    /// Add an interior ring (hole).
261    pub fn with_hole(mut self, hole: Vec<(f64, f64)>) -> Self {
262        self.interiors.push(hole);
263        self
264    }
265
266    /// Set SRID.
267    pub fn with_srid(mut self, srid: i32) -> Self {
268        self.srid = Some(srid);
269        self
270    }
271
272    /// Generate WKT (Well-Known Text).
273    pub fn to_wkt(&self) -> String {
274        let ext_coords: Vec<String> = self
275            .exterior
276            .iter()
277            .map(|(x, y)| format!("{} {}", x, y))
278            .collect();
279
280        let mut wkt = format!("POLYGON(({})", ext_coords.join(", "));
281
282        for interior in &self.interiors {
283            let int_coords: Vec<String> = interior
284                .iter()
285                .map(|(x, y)| format!("{} {}", x, y))
286                .collect();
287            wkt.push_str(&format!(", ({})", int_coords.join(", ")));
288        }
289
290        wkt.push(')');
291        wkt
292    }
293
294    /// Generate PostGIS polygon.
295    pub fn to_postgis(&self) -> String {
296        if let Some(srid) = self.srid {
297            format!("ST_GeomFromText('{}', {})", self.to_wkt(), srid)
298        } else {
299            format!("ST_GeomFromText('{}')", self.to_wkt())
300        }
301    }
302
303    /// Generate GeoJSON.
304    pub fn to_geojson(&self) -> serde_json::Value {
305        let mut coordinates = vec![self
306            .exterior
307            .iter()
308            .map(|(x, y)| vec![*x, *y])
309            .collect::<Vec<_>>()];
310
311        for interior in &self.interiors {
312            coordinates.push(interior.iter().map(|(x, y)| vec![*x, *y]).collect());
313        }
314
315        serde_json::json!({
316            "type": "Polygon",
317            "coordinates": coordinates
318        })
319    }
320}
321
322/// Geospatial operations.
323pub mod geo {
324    use super::*;
325
326    /// Distance calculation.
327    #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
328    pub enum DistanceUnit {
329        /// Meters.
330        Meters,
331        /// Kilometers.
332        Kilometers,
333        /// Miles.
334        Miles,
335        /// Feet.
336        Feet,
337    }
338
339    impl DistanceUnit {
340        /// Conversion factor from meters.
341        pub fn from_meters(&self) -> f64 {
342            match self {
343                Self::Meters => 1.0,
344                Self::Kilometers => 0.001,
345                Self::Miles => 0.000621371,
346                Self::Feet => 3.28084,
347            }
348        }
349    }
350
351    /// Generate distance SQL between two columns.
352    pub fn distance_sql(col1: &str, col2: &str, db_type: DatabaseType) -> String {
353        match db_type {
354            DatabaseType::PostgreSQL => format!("ST_Distance({}::geography, {}::geography)", col1, col2),
355            DatabaseType::MySQL => format!("ST_Distance_Sphere({}, {})", col1, col2),
356            DatabaseType::MSSQL => format!("{}.STDistance({})", col1, col2),
357            DatabaseType::SQLite => format!("Distance({}, {})", col1, col2),
358        }
359    }
360
361    /// Generate distance from point SQL.
362    pub fn distance_from_point_sql(col: &str, point: &Point, db_type: DatabaseType) -> String {
363        let point_sql = point.to_sql(db_type);
364        match db_type {
365            DatabaseType::PostgreSQL => {
366                format!("ST_Distance({}::geography, {}::geography)", col, point_sql)
367            }
368            DatabaseType::MySQL => format!("ST_Distance_Sphere({}, {})", col, point_sql),
369            DatabaseType::MSSQL => format!("{}.STDistance({})", col, point_sql),
370            DatabaseType::SQLite => format!("Distance({}, {})", col, point_sql),
371        }
372    }
373
374    /// Generate "within distance" filter SQL.
375    pub fn within_distance_sql(
376        col: &str,
377        point: &Point,
378        distance_meters: f64,
379        db_type: DatabaseType,
380    ) -> String {
381        let point_sql = point.to_sql(db_type);
382        match db_type {
383            DatabaseType::PostgreSQL => {
384                format!(
385                    "ST_DWithin({}::geography, {}::geography, {})",
386                    col, point_sql, distance_meters
387                )
388            }
389            DatabaseType::MySQL => {
390                format!(
391                    "ST_Distance_Sphere({}, {}) <= {}",
392                    col, point_sql, distance_meters
393                )
394            }
395            DatabaseType::MSSQL => {
396                format!("{}.STDistance({}) <= {}", col, point_sql, distance_meters)
397            }
398            DatabaseType::SQLite => {
399                format!("Distance({}, {}) <= {}", col, point_sql, distance_meters)
400            }
401        }
402    }
403
404    /// Generate "contains" filter SQL.
405    pub fn contains_sql(geom_col: &str, point: &Point, db_type: DatabaseType) -> String {
406        let point_sql = point.to_sql(db_type);
407        match db_type {
408            DatabaseType::PostgreSQL => format!("ST_Contains({}, {})", geom_col, point_sql),
409            DatabaseType::MySQL => format!("ST_Contains({}, {})", geom_col, point_sql),
410            DatabaseType::MSSQL => format!("{}.STContains({})", geom_col, point_sql),
411            DatabaseType::SQLite => format!("Contains({}, {})", geom_col, point_sql),
412        }
413    }
414
415    /// Generate bounding box filter SQL.
416    pub fn bbox_sql(
417        col: &str,
418        min_lon: f64,
419        min_lat: f64,
420        max_lon: f64,
421        max_lat: f64,
422        db_type: DatabaseType,
423    ) -> String {
424        match db_type {
425            DatabaseType::PostgreSQL => {
426                format!(
427                    "{} && ST_MakeEnvelope({}, {}, {}, {}, 4326)",
428                    col, min_lon, min_lat, max_lon, max_lat
429                )
430            }
431            DatabaseType::MySQL => {
432                format!(
433                    "MBRContains(ST_GeomFromText('POLYGON(({} {}, {} {}, {} {}, {} {}, {} {}))'), {})",
434                    min_lon, min_lat, max_lon, min_lat, max_lon, max_lat, min_lon, max_lat, min_lon, min_lat, col
435                )
436            }
437            _ => "1=1".to_string(),
438        }
439    }
440}
441
442// ============================================================================
443// UUID Support
444// ============================================================================
445
446/// UUID generation helpers.
447pub mod uuid {
448    use super::*;
449
450    /// Generate UUID v4 SQL.
451    pub fn generate_v4(db_type: DatabaseType) -> String {
452        match db_type {
453            DatabaseType::PostgreSQL => "gen_random_uuid()".to_string(),
454            DatabaseType::MySQL => "UUID()".to_string(),
455            DatabaseType::MSSQL => "NEWID()".to_string(),
456            DatabaseType::SQLite => {
457                // SQLite needs custom function or hex/randomblob
458                "lower(hex(randomblob(4))) || '-' || lower(hex(randomblob(2))) || '-4' || \
459                 substr(lower(hex(randomblob(2))), 2) || '-' || \
460                 substr('89ab', abs(random()) % 4 + 1, 1) || \
461                 substr(lower(hex(randomblob(2))), 2) || '-' || lower(hex(randomblob(6)))"
462                    .to_string()
463            }
464        }
465    }
466
467    /// Generate UUID v7 SQL (PostgreSQL with uuid-ossp or pg_uuidv7).
468    pub fn generate_v7_postgres() -> String {
469        "uuid_generate_v7()".to_string()
470    }
471
472    /// Generate UUID from string SQL.
473    pub fn from_string(value: &str, db_type: DatabaseType) -> String {
474        match db_type {
475            DatabaseType::PostgreSQL => format!("'{}'::uuid", value),
476            DatabaseType::MySQL => format!("UUID_TO_BIN('{}')", value),
477            DatabaseType::MSSQL => format!("CONVERT(UNIQUEIDENTIFIER, '{}')", value),
478            DatabaseType::SQLite => format!("'{}'", value),
479        }
480    }
481
482    /// Check if valid UUID SQL.
483    pub fn is_valid_sql(col: &str, db_type: DatabaseType) -> String {
484        match db_type {
485            DatabaseType::PostgreSQL => format!(
486                "{} ~ '^[0-9a-f]{{8}}-[0-9a-f]{{4}}-[0-9a-f]{{4}}-[0-9a-f]{{4}}-[0-9a-f]{{12}}$'",
487                col
488            ),
489            DatabaseType::MySQL => format!(
490                "{} REGEXP '^[0-9a-f]{{8}}-[0-9a-f]{{4}}-[0-9a-f]{{4}}-[0-9a-f]{{4}}-[0-9a-f]{{12}}$'",
491                col
492            ),
493            _ => format!("LEN({}) = 36", col),
494        }
495    }
496}
497
498// ============================================================================
499// Cryptography
500// ============================================================================
501
502/// Cryptographic functions.
503pub mod crypto {
504    use super::*;
505
506    /// Hash algorithms.
507    #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
508    pub enum HashAlgorithm {
509        Md5,
510        Sha1,
511        Sha256,
512        Sha384,
513        Sha512,
514    }
515
516    impl HashAlgorithm {
517        /// PostgreSQL algorithm name.
518        pub fn postgres_name(&self) -> &'static str {
519            match self {
520                Self::Md5 => "md5",
521                Self::Sha1 => "sha1",
522                Self::Sha256 => "sha256",
523                Self::Sha384 => "sha384",
524                Self::Sha512 => "sha512",
525            }
526        }
527    }
528
529    /// Generate hash SQL.
530    pub fn hash_sql(expr: &str, algorithm: HashAlgorithm, db_type: DatabaseType) -> String {
531        match db_type {
532            DatabaseType::PostgreSQL => {
533                if algorithm == HashAlgorithm::Md5 {
534                    format!("md5({})", expr)
535                } else {
536                    format!("encode(digest({}, '{}'), 'hex')", expr, algorithm.postgres_name())
537                }
538            }
539            DatabaseType::MySQL => match algorithm {
540                HashAlgorithm::Md5 => format!("MD5({})", expr),
541                HashAlgorithm::Sha1 => format!("SHA1({})", expr),
542                HashAlgorithm::Sha256 => format!("SHA2({}, 256)", expr),
543                HashAlgorithm::Sha384 => format!("SHA2({}, 384)", expr),
544                HashAlgorithm::Sha512 => format!("SHA2({}, 512)", expr),
545            },
546            DatabaseType::MSSQL => {
547                let algo = match algorithm {
548                    HashAlgorithm::Md5 => "MD5",
549                    HashAlgorithm::Sha1 => "SHA1",
550                    HashAlgorithm::Sha256 => "SHA2_256",
551                    HashAlgorithm::Sha384 => "SHA2_384",
552                    HashAlgorithm::Sha512 => "SHA2_512",
553                };
554                format!("CONVERT(VARCHAR(MAX), HASHBYTES('{}', {}), 2)", algo, expr)
555            }
556            DatabaseType::SQLite => {
557                // SQLite doesn't have built-in hashing
558                format!("-- SQLite requires extension for hashing: {}", expr)
559            }
560        }
561    }
562
563    /// Generate bcrypt hash SQL (PostgreSQL with pgcrypto).
564    pub fn bcrypt_hash_postgres(password: &str) -> String {
565        format!("crypt('{}', gen_salt('bf'))", password)
566    }
567
568    /// Generate bcrypt verify SQL (PostgreSQL).
569    pub fn bcrypt_verify_postgres(password: &str, hash_col: &str) -> String {
570        format!("{} = crypt('{}', {})", hash_col, password, hash_col)
571    }
572
573    /// Generate random bytes SQL.
574    pub fn random_bytes_sql(length: usize, db_type: DatabaseType) -> String {
575        match db_type {
576            DatabaseType::PostgreSQL => format!("gen_random_bytes({})", length),
577            DatabaseType::MySQL => format!("RANDOM_BYTES({})", length),
578            DatabaseType::MSSQL => format!("CRYPT_GEN_RANDOM({})", length),
579            DatabaseType::SQLite => format!("randomblob({})", length),
580        }
581    }
582
583    /// Generate AES encrypt SQL (PostgreSQL with pgcrypto).
584    pub fn aes_encrypt_postgres(data: &str, key: &str) -> String {
585        format!("pgp_sym_encrypt('{}', '{}')", data, key)
586    }
587
588    /// Generate AES decrypt SQL (PostgreSQL with pgcrypto).
589    pub fn aes_decrypt_postgres(encrypted_col: &str, key: &str) -> String {
590        format!("pgp_sym_decrypt({}, '{}')", encrypted_col, key)
591    }
592}
593
594// ============================================================================
595// Vector / Embeddings (pgvector, MongoDB Atlas Vector)
596// ============================================================================
597
598/// Vector operations for AI/ML embeddings.
599pub mod vector {
600    use super::*;
601
602    /// A vector embedding.
603    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
604    pub struct Vector {
605        /// Vector dimensions.
606        pub dimensions: Vec<f32>,
607    }
608
609    impl Vector {
610        /// Create a new vector.
611        pub fn new(dimensions: Vec<f32>) -> Self {
612            Self { dimensions }
613        }
614
615        /// Create from slice.
616        pub fn from_slice(slice: &[f32]) -> Self {
617            Self {
618                dimensions: slice.to_vec(),
619            }
620        }
621
622        /// Get dimension count.
623        pub fn len(&self) -> usize {
624            self.dimensions.len()
625        }
626
627        /// Check if empty.
628        pub fn is_empty(&self) -> bool {
629            self.dimensions.is_empty()
630        }
631
632        /// Generate PostgreSQL pgvector literal.
633        pub fn to_pgvector(&self) -> String {
634            let nums: Vec<String> = self.dimensions.iter().map(|f| f.to_string()).collect();
635            format!("'[{}]'::vector", nums.join(","))
636        }
637
638        /// Generate MongoDB array.
639        pub fn to_mongodb(&self) -> serde_json::Value {
640            serde_json::json!(self.dimensions)
641        }
642    }
643
644    /// Vector similarity metrics.
645    #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
646    pub enum SimilarityMetric {
647        /// Euclidean distance (L2).
648        L2,
649        /// Inner product.
650        InnerProduct,
651        /// Cosine similarity.
652        Cosine,
653    }
654
655    impl SimilarityMetric {
656        /// PostgreSQL operator.
657        pub fn postgres_operator(&self) -> &'static str {
658            match self {
659                Self::L2 => "<->",
660                Self::InnerProduct => "<#>",
661                Self::Cosine => "<=>",
662            }
663        }
664
665        /// MongoDB $vectorSearch similarity.
666        pub fn mongodb_name(&self) -> &'static str {
667            match self {
668                Self::L2 => "euclidean",
669                Self::InnerProduct => "dotProduct",
670                Self::Cosine => "cosine",
671            }
672        }
673    }
674
675    /// Generate vector similarity search SQL (PostgreSQL pgvector).
676    pub fn similarity_search_postgres(
677        col: &str,
678        query_vector: &Vector,
679        metric: SimilarityMetric,
680        limit: usize,
681    ) -> String {
682        format!(
683            "SELECT *, {} {} {} AS distance FROM {{table}} ORDER BY distance LIMIT {}",
684            col,
685            metric.postgres_operator(),
686            query_vector.to_pgvector(),
687            limit
688        )
689    }
690
691    /// Generate vector distance SQL.
692    pub fn distance_sql(col: &str, query_vector: &Vector, metric: SimilarityMetric) -> String {
693        format!(
694            "{} {} {}",
695            col,
696            metric.postgres_operator(),
697            query_vector.to_pgvector()
698        )
699    }
700
701    /// Generate vector index SQL (PostgreSQL).
702    pub fn create_index_postgres(
703        index_name: &str,
704        table: &str,
705        column: &str,
706        metric: SimilarityMetric,
707        lists: Option<usize>,
708    ) -> String {
709        let ops = match metric {
710            SimilarityMetric::L2 => "vector_l2_ops",
711            SimilarityMetric::InnerProduct => "vector_ip_ops",
712            SimilarityMetric::Cosine => "vector_cosine_ops",
713        };
714
715        let lists_clause = lists
716            .map(|l| format!(" WITH (lists = {})", l))
717            .unwrap_or_default();
718
719        format!(
720            "CREATE INDEX {} ON {} USING ivfflat ({} {}){}",
721            index_name, table, column, ops, lists_clause
722        )
723    }
724
725    /// Create HNSW index (PostgreSQL pgvector 0.5+).
726    pub fn create_hnsw_index_postgres(
727        index_name: &str,
728        table: &str,
729        column: &str,
730        metric: SimilarityMetric,
731        m: Option<usize>,
732        ef_construction: Option<usize>,
733    ) -> String {
734        let ops = match metric {
735            SimilarityMetric::L2 => "vector_l2_ops",
736            SimilarityMetric::InnerProduct => "vector_ip_ops",
737            SimilarityMetric::Cosine => "vector_cosine_ops",
738        };
739
740        let mut with_clauses = Vec::new();
741        if let Some(m_val) = m {
742            with_clauses.push(format!("m = {}", m_val));
743        }
744        if let Some(ef) = ef_construction {
745            with_clauses.push(format!("ef_construction = {}", ef));
746        }
747
748        let with_clause = if with_clauses.is_empty() {
749            String::new()
750        } else {
751            format!(" WITH ({})", with_clauses.join(", "))
752        };
753
754        format!(
755            "CREATE INDEX {} ON {} USING hnsw ({} {}){}",
756            index_name, table, column, ops, with_clause
757        )
758    }
759}
760
761/// MongoDB Atlas Vector Search support.
762pub mod mongodb {
763    use serde::{Deserialize, Serialize};
764    use serde_json::Value as JsonValue;
765
766    use super::vector::SimilarityMetric;
767
768    /// MongoDB Atlas Vector Search query.
769    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
770    pub struct VectorSearch {
771        /// Index name.
772        pub index: String,
773        /// Path to vector field.
774        pub path: String,
775        /// Query vector.
776        pub query_vector: Vec<f32>,
777        /// Number of results.
778        pub num_candidates: usize,
779        /// Limit.
780        pub limit: usize,
781        /// Optional filter.
782        pub filter: Option<JsonValue>,
783    }
784
785    impl VectorSearch {
786        /// Create a new vector search.
787        pub fn new(index: impl Into<String>, path: impl Into<String>, query: Vec<f32>) -> VectorSearchBuilder {
788            VectorSearchBuilder::new(index, path, query)
789        }
790
791        /// Convert to $vectorSearch stage.
792        pub fn to_stage(&self) -> JsonValue {
793            let mut search = serde_json::json!({
794                "index": self.index,
795                "path": self.path,
796                "queryVector": self.query_vector,
797                "numCandidates": self.num_candidates,
798                "limit": self.limit
799            });
800
801            if let Some(ref filter) = self.filter {
802                search["filter"] = filter.clone();
803            }
804
805            serde_json::json!({ "$vectorSearch": search })
806        }
807    }
808
809    /// Builder for vector search.
810    #[derive(Debug, Clone)]
811    pub struct VectorSearchBuilder {
812        index: String,
813        path: String,
814        query_vector: Vec<f32>,
815        num_candidates: usize,
816        limit: usize,
817        filter: Option<JsonValue>,
818    }
819
820    impl VectorSearchBuilder {
821        /// Create a new builder.
822        pub fn new(index: impl Into<String>, path: impl Into<String>, query: Vec<f32>) -> Self {
823            Self {
824                index: index.into(),
825                path: path.into(),
826                query_vector: query,
827                num_candidates: 100,
828                limit: 10,
829                filter: None,
830            }
831        }
832
833        /// Set number of candidates.
834        pub fn num_candidates(mut self, n: usize) -> Self {
835            self.num_candidates = n;
836            self
837        }
838
839        /// Set limit.
840        pub fn limit(mut self, n: usize) -> Self {
841            self.limit = n;
842            self
843        }
844
845        /// Add filter.
846        pub fn filter(mut self, filter: JsonValue) -> Self {
847            self.filter = Some(filter);
848            self
849        }
850
851        /// Build the search.
852        pub fn build(self) -> VectorSearch {
853            VectorSearch {
854                index: self.index,
855                path: self.path,
856                query_vector: self.query_vector,
857                num_candidates: self.num_candidates,
858                limit: self.limit,
859                filter: self.filter,
860            }
861        }
862    }
863
864    /// MongoDB Atlas Search index definition for vectors.
865    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
866    pub struct VectorIndex {
867        /// Index name.
868        pub name: String,
869        /// Collection name.
870        pub collection: String,
871        /// Vector field definitions.
872        pub fields: Vec<VectorField>,
873    }
874
875    /// Vector field definition.
876    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
877    pub struct VectorField {
878        /// Field path.
879        pub path: String,
880        /// Number of dimensions.
881        pub dimensions: usize,
882        /// Similarity metric.
883        pub similarity: String,
884    }
885
886    impl VectorIndex {
887        /// Create a new vector index definition.
888        pub fn new(name: impl Into<String>, collection: impl Into<String>) -> VectorIndexBuilder {
889            VectorIndexBuilder::new(name, collection)
890        }
891
892        /// Convert to index definition.
893        pub fn to_definition(&self) -> JsonValue {
894            let fields: Vec<JsonValue> = self
895                .fields
896                .iter()
897                .map(|f| {
898                    serde_json::json!({
899                        "type": "vector",
900                        "path": f.path,
901                        "numDimensions": f.dimensions,
902                        "similarity": f.similarity
903                    })
904                })
905                .collect();
906
907            serde_json::json!({
908                "name": self.name,
909                "type": "vectorSearch",
910                "fields": fields
911            })
912        }
913    }
914
915    /// Builder for vector index.
916    #[derive(Debug, Clone)]
917    pub struct VectorIndexBuilder {
918        name: String,
919        collection: String,
920        fields: Vec<VectorField>,
921    }
922
923    impl VectorIndexBuilder {
924        /// Create a new builder.
925        pub fn new(name: impl Into<String>, collection: impl Into<String>) -> Self {
926            Self {
927                name: name.into(),
928                collection: collection.into(),
929                fields: Vec::new(),
930            }
931        }
932
933        /// Add a vector field.
934        pub fn field(
935            mut self,
936            path: impl Into<String>,
937            dimensions: usize,
938            similarity: SimilarityMetric,
939        ) -> Self {
940            self.fields.push(VectorField {
941                path: path.into(),
942                dimensions,
943                similarity: similarity.mongodb_name().to_string(),
944            });
945            self
946        }
947
948        /// Build the index.
949        pub fn build(self) -> VectorIndex {
950            VectorIndex {
951                name: self.name,
952                collection: self.collection,
953                fields: self.fields,
954            }
955        }
956    }
957
958    /// Helper to create a vector search.
959    pub fn vector_search(index: &str, path: &str, query: Vec<f32>) -> VectorSearchBuilder {
960        VectorSearch::new(index, path, query)
961    }
962}
963
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extension_postgres() {
        let ext = Extension::new("postgis").schema("public").cascade().build();
        let sql = ext.to_postgres_create();

        assert!(sql.contains("CREATE EXTENSION IF NOT EXISTS \"postgis\""));
        assert!(sql.contains("SCHEMA public"));
        assert!(sql.contains("CASCADE"));
    }

    #[test]
    fn test_extension_drop() {
        let ext = Extension::postgis();
        let sql = ext.to_postgres_drop();

        assert!(sql.contains("DROP EXTENSION IF EXISTS \"postgis\""));
    }

    #[test]
    fn test_point_postgis() {
        let point = Point::wgs84(-122.4194, 37.7749);
        let sql = point.to_postgis();

        assert!(sql.contains("ST_SetSRID"));
        assert!(sql.contains("-122.4194"));
        assert!(sql.contains("37.7749"));
        assert!(sql.contains("4326"));
    }

    #[test]
    fn test_point_geojson() {
        let point = Point::new(-122.4194, 37.7749);
        let geojson = point.to_geojson();

        assert_eq!(geojson["type"], "Point");
        assert_eq!(geojson["coordinates"][0], -122.4194);
    }

    #[test]
    fn test_polygon_wkt() {
        let polygon = Polygon::new(vec![
            (0.0, 0.0),
            (10.0, 0.0),
            (10.0, 10.0),
            (0.0, 10.0),
            (0.0, 0.0),
        ]);

        let wkt = polygon.to_wkt();
        assert!(wkt.starts_with("POLYGON(("));
    }

    #[test]
    fn test_distance_sql() {
        let sql = geo::distance_sql("location", "target", DatabaseType::PostgreSQL);
        assert!(sql.contains("ST_Distance"));
    }

    #[test]
    fn test_within_distance() {
        let point = Point::wgs84(-122.4194, 37.7749);
        let sql = geo::within_distance_sql("location", &point, 1000.0, DatabaseType::PostgreSQL);

        assert!(sql.contains("ST_DWithin"));
        assert!(sql.contains("1000"));
    }

    #[test]
    fn test_uuid_generation() {
        let pg = uuid::generate_v4(DatabaseType::PostgreSQL);
        assert_eq!(pg, "gen_random_uuid()");

        let mysql = uuid::generate_v4(DatabaseType::MySQL);
        assert_eq!(mysql, "UUID()");

        let mssql = uuid::generate_v4(DatabaseType::MSSQL);
        assert_eq!(mssql, "NEWID()");
    }

    #[test]
    fn test_hash_sql() {
        let pg = crypto::hash_sql("password", crypto::HashAlgorithm::Sha256, DatabaseType::PostgreSQL);
        assert!(pg.contains("digest"));
        assert!(pg.contains("sha256"));

        let mysql = crypto::hash_sql("password", crypto::HashAlgorithm::Sha256, DatabaseType::MySQL);
        assert!(mysql.contains("SHA2"));
        assert!(mysql.contains("256"));
    }

    #[test]
    fn test_vector_pgvector() {
        let vec = vector::Vector::new(vec![0.1, 0.2, 0.3, 0.4]);
        let sql = vec.to_pgvector();

        assert!(sql.contains("'[0.1,0.2,0.3,0.4]'::vector"));
    }

    #[test]
    fn test_vector_index() {
        let sql = vector::create_index_postgres(
            "embeddings_idx",
            "documents",
            "embedding",
            vector::SimilarityMetric::Cosine,
            Some(100),
        );

        assert!(sql.contains("CREATE INDEX embeddings_idx"));
        assert!(sql.contains("USING ivfflat"));
        assert!(sql.contains("vector_cosine_ops"));
        assert!(sql.contains("lists = 100"));
    }

    #[test]
    fn test_hnsw_index() {
        let sql = vector::create_hnsw_index_postgres(
            "embeddings_hnsw",
            "documents",
            "embedding",
            vector::SimilarityMetric::L2,
            Some(16),
            Some(64),
        );

        assert!(sql.contains("USING hnsw"));
        assert!(sql.contains("m = 16"));
        assert!(sql.contains("ef_construction = 64"));
    }

    mod mongodb_tests {
        use super::super::mongodb::*;
        use super::super::vector::SimilarityMetric;

        #[test]
        fn test_vector_search() {
            let search = vector_search("vector_index", "embedding", vec![0.1, 0.2, 0.3])
                .num_candidates(200)
                .limit(20)
                .build();

            let stage = search.to_stage();
            assert!(stage["$vectorSearch"]["index"].is_string());
            assert_eq!(stage["$vectorSearch"]["numCandidates"], 200);
            assert_eq!(stage["$vectorSearch"]["limit"], 20);
        }

        #[test]
        fn test_vector_index_definition() {
            let index = VectorIndex::new("my_vector_index", "documents")
                .field("embedding", 1536, SimilarityMetric::Cosine)
                .build();

            let def = index.to_definition();
            assert_eq!(def["name"], "my_vector_index");
            assert_eq!(def["type"], "vectorSearch");
            assert!(def["fields"].is_array());

            // Pin the contents of the single field entry, not just the array shape.
            let fields = def["fields"]
                .as_array()
                .expect("fields should be a JSON array");
            assert_eq!(fields.len(), 1);
            assert_eq!(fields[0]["type"], "vector");
            assert_eq!(fields[0]["path"], "embedding");
            assert_eq!(fields[0]["numDimensions"], 1536);
            assert!(fields[0]["similarity"].is_string());
        }

        #[test]
        fn test_vector_index_fields_in_order() {
            // A builder with no fields yields an empty field list.
            let empty = VectorIndex::new("empty_index", "documents").build();
            let empty_def = empty.to_definition();
            assert!(empty_def["fields"]
                .as_array()
                .expect("fields should be a JSON array")
                .is_empty());

            // Fields are emitted in the order they were added.
            let index = VectorIndex::new("multi_index", "documents")
                .field("title_embedding", 384, SimilarityMetric::Cosine)
                .field("body_embedding", 768, SimilarityMetric::L2)
                .build();
            let def = index.to_definition();
            let fields = def["fields"]
                .as_array()
                .expect("fields should be a JSON array");
            assert_eq!(fields.len(), 2);
            assert_eq!(fields[0]["path"], "title_embedding");
            assert_eq!(fields[0]["numDimensions"], 384);
            assert_eq!(fields[1]["path"], "body_embedding");
            assert_eq!(fields[1]["numDimensions"], 768);
        }
    }
}
1128
1129