Skip to main content

prax_schema/ast/
datasource.rs

1//! Datasource and PostgreSQL extension definitions.
2
3use serde::{Deserialize, Serialize};
4use smol_str::SmolStr;
5
6use super::Span;
7
/// Database provider type.
///
/// A field-less enum naming the database backend a datasource targets.
/// The textual form of each variant is defined by `DatabaseProvider::as_str`
/// and parsed back by `DatabaseProvider::from_str`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DatabaseProvider {
    /// PostgreSQL database (the only provider for which
    /// `supports_extensions` returns `true`).
    PostgreSQL,
    /// MySQL database.
    MySQL,
    /// SQLite database.
    SQLite,
    /// MongoDB database.
    MongoDB,
}
20
21impl DatabaseProvider {
22    /// Parse a provider from a string.
23    #[allow(clippy::should_implement_trait)]
24    pub fn from_str(s: &str) -> Option<Self> {
25        match s.to_lowercase().as_str() {
26            "postgresql" | "postgres" => Some(Self::PostgreSQL),
27            "mysql" => Some(Self::MySQL),
28            "sqlite" => Some(Self::SQLite),
29            "mongodb" => Some(Self::MongoDB),
30            _ => None,
31        }
32    }
33
34    /// Get the provider as a string.
35    pub fn as_str(&self) -> &'static str {
36        match self {
37            Self::PostgreSQL => "postgresql",
38            Self::MySQL => "mysql",
39            Self::SQLite => "sqlite",
40            Self::MongoDB => "mongodb",
41        }
42    }
43
44    /// Check if this provider supports extensions.
45    pub fn supports_extensions(&self) -> bool {
46        matches!(self, Self::PostgreSQL)
47    }
48}
49
50impl std::fmt::Display for DatabaseProvider {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        write!(f, "{}", self.as_str())
53    }
54}
55
/// A PostgreSQL extension.
///
/// Holds what is needed to render `CREATE EXTENSION` / `DROP EXTENSION`
/// statements (see `to_create_sql` and `to_drop_sql`), plus the source
/// location the declaration came from.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PostgresExtension {
    /// Extension name (e.g., "pg_trgm", "vector", "uuid-ossp").
    pub name: SmolStr,
    /// Optional schema to install the extension into.
    /// When set, `to_create_sql` adds a `SCHEMA` clause.
    pub schema: Option<SmolStr>,
    /// Optional version constraint.
    /// When set, `to_create_sql` adds a `VERSION` clause.
    pub version: Option<SmolStr>,
    /// Source span for error reporting.
    pub span: Span,
}
68
69impl PostgresExtension {
70    /// Create a new extension.
71    pub fn new(name: impl Into<SmolStr>, span: Span) -> Self {
72        Self {
73            name: name.into(),
74            schema: None,
75            version: None,
76            span,
77        }
78    }
79
80    /// Set the schema for this extension.
81    pub fn with_schema(mut self, schema: impl Into<SmolStr>) -> Self {
82        self.schema = Some(schema.into());
83        self
84    }
85
86    /// Set the version for this extension.
87    pub fn with_version(mut self, version: impl Into<SmolStr>) -> Self {
88        self.version = Some(version.into());
89        self
90    }
91
92    /// Get the extension name.
93    pub fn name(&self) -> &str {
94        &self.name
95    }
96
97    /// Generate the CREATE EXTENSION SQL.
98    pub fn to_create_sql(&self) -> String {
99        let mut sql = format!("CREATE EXTENSION IF NOT EXISTS \"{}\"", self.name);
100        if let Some(schema) = &self.schema {
101            sql.push_str(&format!(" SCHEMA \"{}\"", schema));
102        }
103        if let Some(version) = &self.version {
104            sql.push_str(&format!(" VERSION '{}'", version));
105        }
106        sql.push(';');
107        sql
108    }
109
110    /// Generate the DROP EXTENSION SQL.
111    pub fn to_drop_sql(&self) -> String {
112        format!("DROP EXTENSION IF EXISTS \"{}\" CASCADE;", self.name)
113    }
114
115    /// Check if this is a known extension that provides custom types.
116    pub fn provides_custom_types(&self) -> bool {
117        matches!(
118            self.name.as_str(),
119            "vector" | "pgvector" | "postgis" | "hstore" | "ltree" | "cube" | "citext"
120        )
121    }
122}
123
124impl std::fmt::Display for PostgresExtension {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        write!(f, "{}", self.name)
127    }
128}
129
/// Well-known PostgreSQL extensions with their capabilities.
///
/// Each variant maps to one extension name (see `extension_name`) and a short
/// human-readable description (see `description`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WellKnownExtension {
    /// pg_trgm - Trigram text similarity search.
    PgTrgm,
    /// vector/pgvector - Vector similarity search for AI/ML embeddings.
    /// `from_str` accepts both spellings; `extension_name` returns "vector".
    Vector,
    /// uuid-ossp - UUID generation functions.
    /// `from_str` also accepts the underscore spelling "uuid_ossp".
    UuidOssp,
    /// pgcrypto - Cryptographic functions.
    PgCrypto,
    /// postgis - Geographic objects and spatial queries.
    PostGIS,
    /// hstore - Key-value store.
    HStore,
    /// ltree - Hierarchical tree-like data.
    LTree,
    /// citext - Case-insensitive text.
    Citext,
    /// cube - Multi-dimensional cubes.
    Cube,
    /// pg_stat_statements - Query statistics.
    PgStatStatements,
    /// aws_lambda - AWS Lambda integration.
    AwsLambda,
    /// aws_s3 - AWS S3 integration.
    AwsS3,
    /// plpgsql - PL/pgSQL procedural language.
    PlPgSQL,
}
160
161impl WellKnownExtension {
162    /// Get the extension name as used in CREATE EXTENSION.
163    pub fn extension_name(&self) -> &'static str {
164        match self {
165            Self::PgTrgm => "pg_trgm",
166            Self::Vector => "vector",
167            Self::UuidOssp => "uuid-ossp",
168            Self::PgCrypto => "pgcrypto",
169            Self::PostGIS => "postgis",
170            Self::HStore => "hstore",
171            Self::LTree => "ltree",
172            Self::Citext => "citext",
173            Self::Cube => "cube",
174            Self::PgStatStatements => "pg_stat_statements",
175            Self::AwsLambda => "aws_lambda",
176            Self::AwsS3 => "aws_s3",
177            Self::PlPgSQL => "plpgsql",
178        }
179    }
180
181    /// Parse a well-known extension from a string.
182    #[allow(clippy::should_implement_trait)]
183    pub fn from_str(s: &str) -> Option<Self> {
184        match s {
185            "pg_trgm" => Some(Self::PgTrgm),
186            "vector" | "pgvector" => Some(Self::Vector),
187            "uuid-ossp" | "uuid_ossp" => Some(Self::UuidOssp),
188            "pgcrypto" => Some(Self::PgCrypto),
189            "postgis" => Some(Self::PostGIS),
190            "hstore" => Some(Self::HStore),
191            "ltree" => Some(Self::LTree),
192            "citext" => Some(Self::Citext),
193            "cube" => Some(Self::Cube),
194            "pg_stat_statements" => Some(Self::PgStatStatements),
195            "aws_lambda" => Some(Self::AwsLambda),
196            "aws_s3" => Some(Self::AwsS3),
197            "plpgsql" => Some(Self::PlPgSQL),
198            _ => None,
199        }
200    }
201
202    /// Get a description of what this extension provides.
203    pub fn description(&self) -> &'static str {
204        match self {
205            Self::PgTrgm => "Trigram-based text similarity search",
206            Self::Vector => "Vector similarity search for AI/ML embeddings",
207            Self::UuidOssp => "UUID generation functions",
208            Self::PgCrypto => "Cryptographic functions",
209            Self::PostGIS => "Geographic objects and spatial queries",
210            Self::HStore => "Key-value store type",
211            Self::LTree => "Hierarchical tree-like data",
212            Self::Citext => "Case-insensitive text type",
213            Self::Cube => "Multi-dimensional cube data type",
214            Self::PgStatStatements => "Query execution statistics",
215            Self::AwsLambda => "AWS Lambda function invocation",
216            Self::AwsS3 => "AWS S3 storage integration",
217            Self::PlPgSQL => "PL/pgSQL procedural language",
218        }
219    }
220}
221
/// Datasource configuration block.
///
/// Describes one database connection: which provider it targets, where the
/// connection URL comes from, and which PostgreSQL extensions and extra
/// properties were declared alongside it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Datasource {
    /// Datasource name (usually "db").
    pub name: SmolStr,
    /// Database provider.
    pub provider: DatabaseProvider,
    /// Connection URL (can be an env var reference).
    pub url: Option<SmolStr>,
    /// Environment variable name for the URL.
    pub url_env: Option<SmolStr>,
    /// PostgreSQL extensions to enable, in the order they were added.
    pub extensions: Vec<PostgresExtension>,
    /// Additional provider-specific properties, stored as `(key, value)`
    /// pairs in insertion order.
    pub properties: Vec<(SmolStr, SmolStr)>,
    /// Source span for error reporting.
    pub span: Span,
    /// Source file this datasource was parsed from (None for single-file path).
    // Left out of serialized output when `None`, and defaulted to `None` when
    // the field is missing on deserialization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source_id: Option<crate::loader::SourceId>,
}
243
244impl Datasource {
245    /// Create a new datasource.
246    pub fn new(name: impl Into<SmolStr>, provider: DatabaseProvider, span: Span) -> Self {
247        Self {
248            name: name.into(),
249            provider,
250            url: None,
251            url_env: None,
252            extensions: Vec::new(),
253            properties: Vec::new(),
254            span,
255            source_id: None,
256        }
257    }
258
259    /// Set the URL.
260    pub fn with_url(mut self, url: impl Into<SmolStr>) -> Self {
261        self.url = Some(url.into());
262        self
263    }
264
265    /// Set the URL from an environment variable.
266    pub fn with_url_env(mut self, env_var: impl Into<SmolStr>) -> Self {
267        self.url_env = Some(env_var.into());
268        self
269    }
270
271    /// Add an extension.
272    pub fn add_extension(&mut self, ext: PostgresExtension) {
273        self.extensions.push(ext);
274    }
275
276    /// Add a property.
277    pub fn add_property(&mut self, key: impl Into<SmolStr>, value: impl Into<SmolStr>) {
278        self.properties.push((key.into(), value.into()));
279    }
280
281    /// Check if this datasource has a specific extension.
282    pub fn has_extension(&self, name: &str) -> bool {
283        self.extensions.iter().any(|e| e.name == name)
284    }
285
286    /// Get extension by name.
287    pub fn get_extension(&self, name: &str) -> Option<&PostgresExtension> {
288        self.extensions.iter().find(|e| e.name == name)
289    }
290
291    /// Check if vector extension is enabled.
292    pub fn has_vector_support(&self) -> bool {
293        self.has_extension("vector") || self.has_extension("pgvector")
294    }
295
296    /// Generate SQL to create all extensions.
297    pub fn extensions_create_sql(&self) -> Vec<String> {
298        self.extensions.iter().map(|e| e.to_create_sql()).collect()
299    }
300}
301
302impl Default for Datasource {
303    fn default() -> Self {
304        Self {
305            name: SmolStr::new("db"),
306            provider: DatabaseProvider::PostgreSQL,
307            url: None,
308            url_env: None,
309            extensions: Vec::new(),
310            properties: Vec::new(),
311            span: Span::new(0, 0),
312            source_id: None,
313        }
314    }
315}
316
317impl std::fmt::Display for Datasource {
318    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319        write!(
320            f,
321            "datasource {} {{ provider = {} }}",
322            self.name, self.provider
323        )
324    }
325}
326
#[cfg(test)]
mod tests {
    //! Unit tests for the datasource AST types.
    //!
    //! Besides the basic constructor/accessor checks, these cover the
    //! `Display` impls, parsing aliases, `get_extension`, `add_property` and
    //! SQL generation with schema and version combined.

    use super::*;

    /// Arbitrary span; its value is irrelevant to the behaviour under test.
    fn make_span() -> Span {
        Span::new(0, 10)
    }

    // ==================== DatabaseProvider Tests ====================

    #[test]
    fn test_database_provider_from_str() {
        assert_eq!(
            DatabaseProvider::from_str("postgresql"),
            Some(DatabaseProvider::PostgreSQL)
        );
        assert_eq!(
            DatabaseProvider::from_str("postgres"),
            Some(DatabaseProvider::PostgreSQL)
        );
        assert_eq!(
            DatabaseProvider::from_str("PostgreSQL"),
            Some(DatabaseProvider::PostgreSQL)
        );
        assert_eq!(
            DatabaseProvider::from_str("mysql"),
            Some(DatabaseProvider::MySQL)
        );
        assert_eq!(
            DatabaseProvider::from_str("sqlite"),
            Some(DatabaseProvider::SQLite)
        );
        assert_eq!(
            DatabaseProvider::from_str("mongodb"),
            Some(DatabaseProvider::MongoDB)
        );
        assert_eq!(DatabaseProvider::from_str("unknown"), None);
    }

    #[test]
    fn test_database_provider_from_str_mixed_case_and_empty() {
        assert_eq!(
            DatabaseProvider::from_str("MYSQL"),
            Some(DatabaseProvider::MySQL)
        );
        assert_eq!(
            DatabaseProvider::from_str("SQLite"),
            Some(DatabaseProvider::SQLite)
        );
        assert_eq!(
            DatabaseProvider::from_str("MongoDB"),
            Some(DatabaseProvider::MongoDB)
        );
        assert_eq!(DatabaseProvider::from_str(""), None);
        // Surrounding whitespace is not trimmed.
        assert_eq!(DatabaseProvider::from_str(" mysql"), None);
    }

    #[test]
    fn test_database_provider_as_str() {
        assert_eq!(DatabaseProvider::PostgreSQL.as_str(), "postgresql");
        assert_eq!(DatabaseProvider::MySQL.as_str(), "mysql");
        assert_eq!(DatabaseProvider::SQLite.as_str(), "sqlite");
        assert_eq!(DatabaseProvider::MongoDB.as_str(), "mongodb");
    }

    #[test]
    fn test_database_provider_display_matches_as_str() {
        for provider in [
            DatabaseProvider::PostgreSQL,
            DatabaseProvider::MySQL,
            DatabaseProvider::SQLite,
            DatabaseProvider::MongoDB,
        ] {
            assert_eq!(provider.to_string(), provider.as_str());
            // The canonical name must parse back to the same variant.
            assert_eq!(
                DatabaseProvider::from_str(provider.as_str()),
                Some(provider)
            );
        }
    }

    #[test]
    fn test_database_provider_supports_extensions() {
        assert!(DatabaseProvider::PostgreSQL.supports_extensions());
        assert!(!DatabaseProvider::MySQL.supports_extensions());
        assert!(!DatabaseProvider::SQLite.supports_extensions());
        assert!(!DatabaseProvider::MongoDB.supports_extensions());
    }

    // ==================== PostgresExtension Tests ====================

    #[test]
    fn test_postgres_extension_new() {
        let ext = PostgresExtension::new("vector", make_span());
        assert_eq!(ext.name(), "vector");
        assert!(ext.schema.is_none());
        assert!(ext.version.is_none());
    }

    #[test]
    fn test_postgres_extension_with_schema() {
        let ext = PostgresExtension::new("postgis", make_span()).with_schema("public");
        assert_eq!(ext.schema, Some(SmolStr::new("public")));
    }

    #[test]
    fn test_postgres_extension_with_version() {
        let ext = PostgresExtension::new("vector", make_span()).with_version("0.5.0");
        assert_eq!(ext.version, Some(SmolStr::new("0.5.0")));
    }

    #[test]
    fn test_postgres_extension_to_create_sql() {
        let ext = PostgresExtension::new("pg_trgm", make_span());
        assert_eq!(
            ext.to_create_sql(),
            "CREATE EXTENSION IF NOT EXISTS \"pg_trgm\";"
        );

        let ext_with_schema =
            PostgresExtension::new("postgis", make_span()).with_schema("extensions");
        assert_eq!(
            ext_with_schema.to_create_sql(),
            "CREATE EXTENSION IF NOT EXISTS \"postgis\" SCHEMA \"extensions\";"
        );

        let ext_with_version = PostgresExtension::new("vector", make_span()).with_version("0.5.0");
        assert_eq!(
            ext_with_version.to_create_sql(),
            "CREATE EXTENSION IF NOT EXISTS \"vector\" VERSION '0.5.0';"
        );
    }

    #[test]
    fn test_postgres_extension_to_create_sql_schema_and_version() {
        // SCHEMA must precede VERSION when both are present.
        let ext = PostgresExtension::new("postgis", make_span())
            .with_schema("extensions")
            .with_version("3.4");
        assert_eq!(
            ext.to_create_sql(),
            "CREATE EXTENSION IF NOT EXISTS \"postgis\" SCHEMA \"extensions\" VERSION '3.4';"
        );
    }

    #[test]
    fn test_postgres_extension_to_drop_sql() {
        let ext = PostgresExtension::new("vector", make_span());
        assert_eq!(
            ext.to_drop_sql(),
            "DROP EXTENSION IF EXISTS \"vector\" CASCADE;"
        );
    }

    #[test]
    fn test_postgres_extension_provides_custom_types() {
        assert!(PostgresExtension::new("vector", make_span()).provides_custom_types());
        assert!(PostgresExtension::new("postgis", make_span()).provides_custom_types());
        assert!(PostgresExtension::new("hstore", make_span()).provides_custom_types());
        assert!(!PostgresExtension::new("pg_trgm", make_span()).provides_custom_types());
    }

    #[test]
    fn test_postgres_extension_display() {
        let ext = PostgresExtension::new("uuid-ossp", make_span()).with_schema("public");
        assert_eq!(ext.to_string(), "uuid-ossp");
    }

    // ==================== WellKnownExtension Tests ====================

    #[test]
    fn test_well_known_extension_from_str() {
        assert_eq!(
            WellKnownExtension::from_str("vector"),
            Some(WellKnownExtension::Vector)
        );
        assert_eq!(
            WellKnownExtension::from_str("pgvector"),
            Some(WellKnownExtension::Vector)
        );
        assert_eq!(
            WellKnownExtension::from_str("pg_trgm"),
            Some(WellKnownExtension::PgTrgm)
        );
        assert_eq!(
            WellKnownExtension::from_str("uuid-ossp"),
            Some(WellKnownExtension::UuidOssp)
        );
        assert_eq!(
            WellKnownExtension::from_str("uuid_ossp"),
            Some(WellKnownExtension::UuidOssp)
        );
        assert_eq!(WellKnownExtension::from_str("unknown"), None);
    }

    #[test]
    fn test_well_known_extension_name() {
        assert_eq!(WellKnownExtension::Vector.extension_name(), "vector");
        assert_eq!(WellKnownExtension::PgTrgm.extension_name(), "pg_trgm");
        assert_eq!(WellKnownExtension::UuidOssp.extension_name(), "uuid-ossp");
    }

    #[test]
    fn test_well_known_extension_round_trip_and_description() {
        for ext in [
            WellKnownExtension::PgTrgm,
            WellKnownExtension::Vector,
            WellKnownExtension::UuidOssp,
            WellKnownExtension::PgCrypto,
            WellKnownExtension::PostGIS,
            WellKnownExtension::HStore,
            WellKnownExtension::LTree,
            WellKnownExtension::Citext,
            WellKnownExtension::Cube,
            WellKnownExtension::PgStatStatements,
            WellKnownExtension::AwsLambda,
            WellKnownExtension::AwsS3,
            WellKnownExtension::PlPgSQL,
        ] {
            assert_eq!(
                WellKnownExtension::from_str(ext.extension_name()),
                Some(ext)
            );
            assert!(!ext.description().is_empty());
        }
    }

    // ==================== Datasource Tests ====================

    #[test]
    fn test_datasource_new() {
        let ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span());
        assert_eq!(ds.name.as_str(), "db");
        assert_eq!(ds.provider, DatabaseProvider::PostgreSQL);
        assert!(ds.extensions.is_empty());
    }

    #[test]
    fn test_datasource_with_url() {
        let ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span())
            .with_url("postgresql://localhost/mydb");
        assert_eq!(ds.url, Some(SmolStr::new("postgresql://localhost/mydb")));
    }

    #[test]
    fn test_datasource_with_url_env() {
        let ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span())
            .with_url_env("DATABASE_URL");
        assert_eq!(ds.url_env, Some(SmolStr::new("DATABASE_URL")));
    }

    #[test]
    fn test_datasource_add_extension() {
        let mut ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span());
        ds.add_extension(PostgresExtension::new("vector", make_span()));
        ds.add_extension(PostgresExtension::new("pg_trgm", make_span()));

        assert_eq!(ds.extensions.len(), 2);
        assert!(ds.has_extension("vector"));
        assert!(ds.has_extension("pg_trgm"));
        assert!(!ds.has_extension("postgis"));
    }

    #[test]
    fn test_datasource_get_extension() {
        let mut ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span());
        ds.add_extension(PostgresExtension::new("vector", make_span()).with_version("0.5.0"));

        let found = ds.get_extension("vector").expect("extension was just added");
        assert_eq!(found.name(), "vector");
        assert_eq!(found.version, Some(SmolStr::new("0.5.0")));
        assert!(ds.get_extension("postgis").is_none());
    }

    #[test]
    fn test_datasource_add_property() {
        let mut ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span());
        ds.add_property("shadowDatabaseUrl", "postgresql://localhost/shadow");
        ds.add_property("relationMode", "prisma");

        assert_eq!(
            ds.properties,
            vec![
                (
                    SmolStr::new("shadowDatabaseUrl"),
                    SmolStr::new("postgresql://localhost/shadow")
                ),
                (SmolStr::new("relationMode"), SmolStr::new("prisma")),
            ]
        );
    }

    #[test]
    fn test_datasource_has_vector_support() {
        let mut ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span());
        assert!(!ds.has_vector_support());

        ds.add_extension(PostgresExtension::new("vector", make_span()));
        assert!(ds.has_vector_support());
    }

    #[test]
    fn test_datasource_has_vector_support_pgvector_alias() {
        let mut ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span());
        ds.add_extension(PostgresExtension::new("pgvector", make_span()));
        assert!(ds.has_vector_support());
        assert!(!ds.has_extension("vector"));
    }

    #[test]
    fn test_datasource_extensions_create_sql() {
        let mut ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span());
        ds.add_extension(PostgresExtension::new("vector", make_span()));
        ds.add_extension(PostgresExtension::new("pg_trgm", make_span()));

        let sqls = ds.extensions_create_sql();
        assert_eq!(sqls.len(), 2);
        assert!(sqls[0].contains("vector"));
        assert!(sqls[1].contains("pg_trgm"));
    }

    #[test]
    fn test_datasource_default() {
        let ds = Datasource::default();
        assert_eq!(ds.name.as_str(), "db");
        assert_eq!(ds.provider, DatabaseProvider::PostgreSQL);
        assert!(ds.url.is_none());
        assert!(ds.url_env.is_none());
        assert!(ds.extensions.is_empty());
        assert!(ds.properties.is_empty());
        assert!(ds.source_id.is_none());
    }

    #[test]
    fn test_datasource_display() {
        let ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span())
            .with_url("postgresql://localhost/mydb");
        assert_eq!(ds.to_string(), "datasource db { provider = postgresql }");
    }
}
536}