prax_schema/ast/
datasource.rs

//! Datasource and PostgreSQL extension definitions.
2
3use serde::{Deserialize, Serialize};
4use smol_str::SmolStr;
5
6use super::Span;
7
8/// Database provider type.
9#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
10pub enum DatabaseProvider {
11    /// PostgreSQL database.
12    PostgreSQL,
13    /// MySQL database.
14    MySQL,
15    /// SQLite database.
16    SQLite,
17    /// MongoDB database.
18    MongoDB,
19}
20
21impl DatabaseProvider {
22    /// Parse a provider from a string.
23    #[allow(clippy::should_implement_trait)]
24    pub fn from_str(s: &str) -> Option<Self> {
25        match s.to_lowercase().as_str() {
26            "postgresql" | "postgres" => Some(Self::PostgreSQL),
27            "mysql" => Some(Self::MySQL),
28            "sqlite" => Some(Self::SQLite),
29            "mongodb" => Some(Self::MongoDB),
30            _ => None,
31        }
32    }
33
34    /// Get the provider as a string.
35    pub fn as_str(&self) -> &'static str {
36        match self {
37            Self::PostgreSQL => "postgresql",
38            Self::MySQL => "mysql",
39            Self::SQLite => "sqlite",
40            Self::MongoDB => "mongodb",
41        }
42    }
43
44    /// Check if this provider supports extensions.
45    pub fn supports_extensions(&self) -> bool {
46        matches!(self, Self::PostgreSQL)
47    }
48}
49
50impl std::fmt::Display for DatabaseProvider {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        write!(f, "{}", self.as_str())
53    }
54}
55
56/// A PostgreSQL extension.
57#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
58pub struct PostgresExtension {
59    /// Extension name (e.g., "pg_trgm", "vector", "uuid-ossp").
60    pub name: SmolStr,
61    /// Optional schema to install the extension into.
62    pub schema: Option<SmolStr>,
63    /// Optional version constraint.
64    pub version: Option<SmolStr>,
65    /// Source span for error reporting.
66    pub span: Span,
67}
68
69impl PostgresExtension {
70    /// Create a new extension.
71    pub fn new(name: impl Into<SmolStr>, span: Span) -> Self {
72        Self {
73            name: name.into(),
74            schema: None,
75            version: None,
76            span,
77        }
78    }
79
80    /// Set the schema for this extension.
81    pub fn with_schema(mut self, schema: impl Into<SmolStr>) -> Self {
82        self.schema = Some(schema.into());
83        self
84    }
85
86    /// Set the version for this extension.
87    pub fn with_version(mut self, version: impl Into<SmolStr>) -> Self {
88        self.version = Some(version.into());
89        self
90    }
91
92    /// Get the extension name.
93    pub fn name(&self) -> &str {
94        &self.name
95    }
96
97    /// Generate the CREATE EXTENSION SQL.
98    pub fn to_create_sql(&self) -> String {
99        let mut sql = format!("CREATE EXTENSION IF NOT EXISTS \"{}\"", self.name);
100        if let Some(schema) = &self.schema {
101            sql.push_str(&format!(" SCHEMA \"{}\"", schema));
102        }
103        if let Some(version) = &self.version {
104            sql.push_str(&format!(" VERSION '{}'", version));
105        }
106        sql.push(';');
107        sql
108    }
109
110    /// Generate the DROP EXTENSION SQL.
111    pub fn to_drop_sql(&self) -> String {
112        format!("DROP EXTENSION IF EXISTS \"{}\" CASCADE;", self.name)
113    }
114
115    /// Check if this is a known extension that provides custom types.
116    pub fn provides_custom_types(&self) -> bool {
117        matches!(
118            self.name.as_str(),
119            "vector" | "pgvector" | "postgis" | "hstore" | "ltree" | "cube" | "citext"
120        )
121    }
122}
123
124impl std::fmt::Display for PostgresExtension {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        write!(f, "{}", self.name)
127    }
128}
129
130/// Well-known PostgreSQL extensions with their capabilities.
131#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
132pub enum WellKnownExtension {
133    /// pg_trgm - Trigram text similarity search.
134    PgTrgm,
135    /// vector/pgvector - Vector similarity search for AI/ML embeddings.
136    Vector,
137    /// uuid-ossp - UUID generation functions.
138    UuidOssp,
139    /// pgcrypto - Cryptographic functions.
140    PgCrypto,
141    /// postgis - Geographic objects and spatial queries.
142    PostGIS,
143    /// hstore - Key-value store.
144    HStore,
145    /// ltree - Hierarchical tree-like data.
146    LTree,
147    /// citext - Case-insensitive text.
148    Citext,
149    /// cube - Multi-dimensional cubes.
150    Cube,
151    /// pg_stat_statements - Query statistics.
152    PgStatStatements,
153    /// aws_lambda - AWS Lambda integration.
154    AwsLambda,
155    /// aws_s3 - AWS S3 integration.
156    AwsS3,
157    /// plpgsql - PL/pgSQL procedural language.
158    PlPgSQL,
159}
160
161impl WellKnownExtension {
162    /// Get the extension name as used in CREATE EXTENSION.
163    pub fn extension_name(&self) -> &'static str {
164        match self {
165            Self::PgTrgm => "pg_trgm",
166            Self::Vector => "vector",
167            Self::UuidOssp => "uuid-ossp",
168            Self::PgCrypto => "pgcrypto",
169            Self::PostGIS => "postgis",
170            Self::HStore => "hstore",
171            Self::LTree => "ltree",
172            Self::Citext => "citext",
173            Self::Cube => "cube",
174            Self::PgStatStatements => "pg_stat_statements",
175            Self::AwsLambda => "aws_lambda",
176            Self::AwsS3 => "aws_s3",
177            Self::PlPgSQL => "plpgsql",
178        }
179    }
180
181    /// Parse a well-known extension from a string.
182    #[allow(clippy::should_implement_trait)]
183    pub fn from_str(s: &str) -> Option<Self> {
184        match s {
185            "pg_trgm" => Some(Self::PgTrgm),
186            "vector" | "pgvector" => Some(Self::Vector),
187            "uuid-ossp" | "uuid_ossp" => Some(Self::UuidOssp),
188            "pgcrypto" => Some(Self::PgCrypto),
189            "postgis" => Some(Self::PostGIS),
190            "hstore" => Some(Self::HStore),
191            "ltree" => Some(Self::LTree),
192            "citext" => Some(Self::Citext),
193            "cube" => Some(Self::Cube),
194            "pg_stat_statements" => Some(Self::PgStatStatements),
195            "aws_lambda" => Some(Self::AwsLambda),
196            "aws_s3" => Some(Self::AwsS3),
197            "plpgsql" => Some(Self::PlPgSQL),
198            _ => None,
199        }
200    }
201
202    /// Get a description of what this extension provides.
203    pub fn description(&self) -> &'static str {
204        match self {
205            Self::PgTrgm => "Trigram-based text similarity search",
206            Self::Vector => "Vector similarity search for AI/ML embeddings",
207            Self::UuidOssp => "UUID generation functions",
208            Self::PgCrypto => "Cryptographic functions",
209            Self::PostGIS => "Geographic objects and spatial queries",
210            Self::HStore => "Key-value store type",
211            Self::LTree => "Hierarchical tree-like data",
212            Self::Citext => "Case-insensitive text type",
213            Self::Cube => "Multi-dimensional cube data type",
214            Self::PgStatStatements => "Query execution statistics",
215            Self::AwsLambda => "AWS Lambda function invocation",
216            Self::AwsS3 => "AWS S3 storage integration",
217            Self::PlPgSQL => "PL/pgSQL procedural language",
218        }
219    }
220}
221
222/// Datasource configuration block.
223#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
224pub struct Datasource {
225    /// Datasource name (usually "db").
226    pub name: SmolStr,
227    /// Database provider.
228    pub provider: DatabaseProvider,
229    /// Connection URL (can be an env var reference).
230    pub url: Option<SmolStr>,
231    /// Environment variable name for the URL.
232    pub url_env: Option<SmolStr>,
233    /// PostgreSQL extensions to enable.
234    pub extensions: Vec<PostgresExtension>,
235    /// Additional provider-specific properties.
236    pub properties: Vec<(SmolStr, SmolStr)>,
237    /// Source span for error reporting.
238    pub span: Span,
239}
240
241impl Datasource {
242    /// Create a new datasource.
243    pub fn new(name: impl Into<SmolStr>, provider: DatabaseProvider, span: Span) -> Self {
244        Self {
245            name: name.into(),
246            provider,
247            url: None,
248            url_env: None,
249            extensions: Vec::new(),
250            properties: Vec::new(),
251            span,
252        }
253    }
254
255    /// Set the URL.
256    pub fn with_url(mut self, url: impl Into<SmolStr>) -> Self {
257        self.url = Some(url.into());
258        self
259    }
260
261    /// Set the URL from an environment variable.
262    pub fn with_url_env(mut self, env_var: impl Into<SmolStr>) -> Self {
263        self.url_env = Some(env_var.into());
264        self
265    }
266
267    /// Add an extension.
268    pub fn add_extension(&mut self, ext: PostgresExtension) {
269        self.extensions.push(ext);
270    }
271
272    /// Add a property.
273    pub fn add_property(&mut self, key: impl Into<SmolStr>, value: impl Into<SmolStr>) {
274        self.properties.push((key.into(), value.into()));
275    }
276
277    /// Check if this datasource has a specific extension.
278    pub fn has_extension(&self, name: &str) -> bool {
279        self.extensions.iter().any(|e| e.name == name)
280    }
281
282    /// Get extension by name.
283    pub fn get_extension(&self, name: &str) -> Option<&PostgresExtension> {
284        self.extensions.iter().find(|e| e.name == name)
285    }
286
287    /// Check if vector extension is enabled.
288    pub fn has_vector_support(&self) -> bool {
289        self.has_extension("vector") || self.has_extension("pgvector")
290    }
291
292    /// Generate SQL to create all extensions.
293    pub fn extensions_create_sql(&self) -> Vec<String> {
294        self.extensions.iter().map(|e| e.to_create_sql()).collect()
295    }
296}
297
298impl Default for Datasource {
299    fn default() -> Self {
300        Self {
301            name: SmolStr::new("db"),
302            provider: DatabaseProvider::PostgreSQL,
303            url: None,
304            url_env: None,
305            extensions: Vec::new(),
306            properties: Vec::new(),
307            span: Span::new(0, 0),
308        }
309    }
310}
311
312impl std::fmt::Display for Datasource {
313    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314        write!(
315            f,
316            "datasource {} {{ provider = {} }}",
317            self.name, self.provider
318        )
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Arbitrary span shared by every fixture; no test inspects its value.
    fn make_span() -> Span {
        Span::new(0, 10)
    }

    /// A fresh PostgreSQL datasource named `db` with nothing configured.
    fn pg_datasource() -> Datasource {
        Datasource::new("db", DatabaseProvider::PostgreSQL, make_span())
    }

    /// An extension with only a name set.
    fn ext(name: &str) -> PostgresExtension {
        PostgresExtension::new(name, make_span())
    }

    // ---------------------------------------------------------------
    // DatabaseProvider
    // ---------------------------------------------------------------

    #[test]
    fn test_database_provider_from_str() {
        let cases = [
            ("postgresql", Some(DatabaseProvider::PostgreSQL)),
            ("postgres", Some(DatabaseProvider::PostgreSQL)),
            ("PostgreSQL", Some(DatabaseProvider::PostgreSQL)),
            ("mysql", Some(DatabaseProvider::MySQL)),
            ("sqlite", Some(DatabaseProvider::SQLite)),
            ("mongodb", Some(DatabaseProvider::MongoDB)),
            ("unknown", None),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(
                &DatabaseProvider::from_str(input),
                expected,
                "input {:?}",
                input
            );
        }
    }

    #[test]
    fn test_database_provider_as_str() {
        let cases = [
            (DatabaseProvider::PostgreSQL, "postgresql"),
            (DatabaseProvider::MySQL, "mysql"),
            (DatabaseProvider::SQLite, "sqlite"),
            (DatabaseProvider::MongoDB, "mongodb"),
        ];
        for (provider, name) in cases.iter() {
            assert_eq!(provider.as_str(), *name);
        }
    }

    #[test]
    fn test_database_provider_supports_extensions() {
        assert!(DatabaseProvider::PostgreSQL.supports_extensions());
        let unsupported = [
            DatabaseProvider::MySQL,
            DatabaseProvider::SQLite,
            DatabaseProvider::MongoDB,
        ];
        for provider in unsupported.iter() {
            assert!(!provider.supports_extensions(), "{:?}", provider);
        }
    }

    // ---------------------------------------------------------------
    // PostgresExtension
    // ---------------------------------------------------------------

    #[test]
    fn test_postgres_extension_new() {
        let vector = ext("vector");
        assert_eq!(vector.name(), "vector");
        assert_eq!(vector.schema, None);
        assert_eq!(vector.version, None);
    }

    #[test]
    fn test_postgres_extension_with_schema() {
        let postgis = ext("postgis").with_schema("public");
        assert_eq!(postgis.schema.as_deref(), Some("public"));
    }

    #[test]
    fn test_postgres_extension_with_version() {
        let vector = ext("vector").with_version("0.5.0");
        assert_eq!(vector.version.as_deref(), Some("0.5.0"));
    }

    #[test]
    fn test_postgres_extension_to_create_sql() {
        let cases = [
            (
                ext("pg_trgm"),
                "CREATE EXTENSION IF NOT EXISTS \"pg_trgm\";",
            ),
            (
                ext("postgis").with_schema("extensions"),
                "CREATE EXTENSION IF NOT EXISTS \"postgis\" SCHEMA \"extensions\";",
            ),
            (
                ext("vector").with_version("0.5.0"),
                "CREATE EXTENSION IF NOT EXISTS \"vector\" VERSION '0.5.0';",
            ),
        ];
        for (extension, sql) in cases.iter() {
            assert_eq!(extension.to_create_sql(), *sql);
        }
    }

    #[test]
    fn test_postgres_extension_to_drop_sql() {
        assert_eq!(
            ext("vector").to_drop_sql(),
            "DROP EXTENSION IF EXISTS \"vector\" CASCADE;"
        );
    }

    #[test]
    fn test_postgres_extension_provides_custom_types() {
        let cases = [
            ("vector", true),
            ("postgis", true),
            ("hstore", true),
            ("pg_trgm", false),
        ];
        for (name, expected) in cases.iter() {
            assert_eq!(
                ext(name).provides_custom_types(),
                *expected,
                "extension {:?}",
                name
            );
        }
    }

    // ---------------------------------------------------------------
    // WellKnownExtension
    // ---------------------------------------------------------------

    #[test]
    fn test_well_known_extension_from_str() {
        let cases = [
            ("vector", Some(WellKnownExtension::Vector)),
            ("pgvector", Some(WellKnownExtension::Vector)),
            ("pg_trgm", Some(WellKnownExtension::PgTrgm)),
            ("uuid-ossp", Some(WellKnownExtension::UuidOssp)),
            ("unknown", None),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(
                WellKnownExtension::from_str(input),
                *expected,
                "input {:?}",
                input
            );
        }
    }

    #[test]
    fn test_well_known_extension_name() {
        let cases = [
            (WellKnownExtension::Vector, "vector"),
            (WellKnownExtension::PgTrgm, "pg_trgm"),
            (WellKnownExtension::UuidOssp, "uuid-ossp"),
        ];
        for (extension, name) in cases.iter() {
            assert_eq!(extension.extension_name(), *name);
        }
    }

    // ---------------------------------------------------------------
    // Datasource
    // ---------------------------------------------------------------

    #[test]
    fn test_datasource_new() {
        let ds = pg_datasource();
        assert_eq!(ds.name.as_str(), "db");
        assert_eq!(ds.provider, DatabaseProvider::PostgreSQL);
        assert!(ds.extensions.is_empty());
    }

    #[test]
    fn test_datasource_with_url() {
        let ds = pg_datasource().with_url("postgresql://localhost/mydb");
        assert_eq!(ds.url.as_deref(), Some("postgresql://localhost/mydb"));
    }

    #[test]
    fn test_datasource_with_url_env() {
        let ds = pg_datasource().with_url_env("DATABASE_URL");
        assert_eq!(ds.url_env.as_deref(), Some("DATABASE_URL"));
    }

    #[test]
    fn test_datasource_add_extension() {
        let mut ds = pg_datasource();
        for name in ["vector", "pg_trgm"].iter() {
            ds.add_extension(ext(name));
        }

        assert_eq!(ds.extensions.len(), 2);
        assert!(ds.has_extension("vector"));
        assert!(ds.has_extension("pg_trgm"));
        assert!(!ds.has_extension("postgis"));
    }

    #[test]
    fn test_datasource_has_vector_support() {
        let mut ds = pg_datasource();
        assert!(!ds.has_vector_support());

        ds.add_extension(ext("vector"));
        assert!(ds.has_vector_support());
    }

    #[test]
    fn test_datasource_extensions_create_sql() {
        let mut ds = pg_datasource();
        ds.add_extension(ext("vector"));
        ds.add_extension(ext("pg_trgm"));

        let statements = ds.extensions_create_sql();
        assert_eq!(statements.len(), 2);
        assert!(statements[0].contains("vector"));
        assert!(statements[1].contains("pg_trgm"));
    }

    #[test]
    fn test_datasource_default() {
        let ds = Datasource::default();
        assert_eq!(ds.name.as_str(), "db");
        assert_eq!(ds.provider, DatabaseProvider::PostgreSQL);
    }
}