1use serde::{Deserialize, Serialize};
4use smol_str::SmolStr;
5
6use super::Span;
7
8#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
10pub enum DatabaseProvider {
11 PostgreSQL,
13 MySQL,
15 SQLite,
17 MongoDB,
19}
20
21impl DatabaseProvider {
22 #[allow(clippy::should_implement_trait)]
24 pub fn from_str(s: &str) -> Option<Self> {
25 match s.to_lowercase().as_str() {
26 "postgresql" | "postgres" => Some(Self::PostgreSQL),
27 "mysql" => Some(Self::MySQL),
28 "sqlite" => Some(Self::SQLite),
29 "mongodb" => Some(Self::MongoDB),
30 _ => None,
31 }
32 }
33
34 pub fn as_str(&self) -> &'static str {
36 match self {
37 Self::PostgreSQL => "postgresql",
38 Self::MySQL => "mysql",
39 Self::SQLite => "sqlite",
40 Self::MongoDB => "mongodb",
41 }
42 }
43
44 pub fn supports_extensions(&self) -> bool {
46 matches!(self, Self::PostgreSQL)
47 }
48}
49
50impl std::fmt::Display for DatabaseProvider {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 write!(f, "{}", self.as_str())
53 }
54}
55
56#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
58pub struct PostgresExtension {
59 pub name: SmolStr,
61 pub schema: Option<SmolStr>,
63 pub version: Option<SmolStr>,
65 pub span: Span,
67}
68
69impl PostgresExtension {
70 pub fn new(name: impl Into<SmolStr>, span: Span) -> Self {
72 Self {
73 name: name.into(),
74 schema: None,
75 version: None,
76 span,
77 }
78 }
79
80 pub fn with_schema(mut self, schema: impl Into<SmolStr>) -> Self {
82 self.schema = Some(schema.into());
83 self
84 }
85
86 pub fn with_version(mut self, version: impl Into<SmolStr>) -> Self {
88 self.version = Some(version.into());
89 self
90 }
91
92 pub fn name(&self) -> &str {
94 &self.name
95 }
96
97 pub fn to_create_sql(&self) -> String {
99 let mut sql = format!("CREATE EXTENSION IF NOT EXISTS \"{}\"", self.name);
100 if let Some(schema) = &self.schema {
101 sql.push_str(&format!(" SCHEMA \"{}\"", schema));
102 }
103 if let Some(version) = &self.version {
104 sql.push_str(&format!(" VERSION '{}'", version));
105 }
106 sql.push(';');
107 sql
108 }
109
110 pub fn to_drop_sql(&self) -> String {
112 format!("DROP EXTENSION IF EXISTS \"{}\" CASCADE;", self.name)
113 }
114
115 pub fn provides_custom_types(&self) -> bool {
117 matches!(
118 self.name.as_str(),
119 "vector" | "pgvector" | "postgis" | "hstore" | "ltree" | "cube" | "citext"
120 )
121 }
122}
123
124impl std::fmt::Display for PostgresExtension {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 write!(f, "{}", self.name)
127 }
128}
129
130#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
132pub enum WellKnownExtension {
133 PgTrgm,
135 Vector,
137 UuidOssp,
139 PgCrypto,
141 PostGIS,
143 HStore,
145 LTree,
147 Citext,
149 Cube,
151 PgStatStatements,
153 AwsLambda,
155 AwsS3,
157 PlPgSQL,
159}
160
161impl WellKnownExtension {
162 pub fn extension_name(&self) -> &'static str {
164 match self {
165 Self::PgTrgm => "pg_trgm",
166 Self::Vector => "vector",
167 Self::UuidOssp => "uuid-ossp",
168 Self::PgCrypto => "pgcrypto",
169 Self::PostGIS => "postgis",
170 Self::HStore => "hstore",
171 Self::LTree => "ltree",
172 Self::Citext => "citext",
173 Self::Cube => "cube",
174 Self::PgStatStatements => "pg_stat_statements",
175 Self::AwsLambda => "aws_lambda",
176 Self::AwsS3 => "aws_s3",
177 Self::PlPgSQL => "plpgsql",
178 }
179 }
180
181 #[allow(clippy::should_implement_trait)]
183 pub fn from_str(s: &str) -> Option<Self> {
184 match s {
185 "pg_trgm" => Some(Self::PgTrgm),
186 "vector" | "pgvector" => Some(Self::Vector),
187 "uuid-ossp" | "uuid_ossp" => Some(Self::UuidOssp),
188 "pgcrypto" => Some(Self::PgCrypto),
189 "postgis" => Some(Self::PostGIS),
190 "hstore" => Some(Self::HStore),
191 "ltree" => Some(Self::LTree),
192 "citext" => Some(Self::Citext),
193 "cube" => Some(Self::Cube),
194 "pg_stat_statements" => Some(Self::PgStatStatements),
195 "aws_lambda" => Some(Self::AwsLambda),
196 "aws_s3" => Some(Self::AwsS3),
197 "plpgsql" => Some(Self::PlPgSQL),
198 _ => None,
199 }
200 }
201
202 pub fn description(&self) -> &'static str {
204 match self {
205 Self::PgTrgm => "Trigram-based text similarity search",
206 Self::Vector => "Vector similarity search for AI/ML embeddings",
207 Self::UuidOssp => "UUID generation functions",
208 Self::PgCrypto => "Cryptographic functions",
209 Self::PostGIS => "Geographic objects and spatial queries",
210 Self::HStore => "Key-value store type",
211 Self::LTree => "Hierarchical tree-like data",
212 Self::Citext => "Case-insensitive text type",
213 Self::Cube => "Multi-dimensional cube data type",
214 Self::PgStatStatements => "Query execution statistics",
215 Self::AwsLambda => "AWS Lambda function invocation",
216 Self::AwsS3 => "AWS S3 storage integration",
217 Self::PlPgSQL => "PL/pgSQL procedural language",
218 }
219 }
220}
221
222#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
224pub struct Datasource {
225 pub name: SmolStr,
227 pub provider: DatabaseProvider,
229 pub url: Option<SmolStr>,
231 pub url_env: Option<SmolStr>,
233 pub extensions: Vec<PostgresExtension>,
235 pub properties: Vec<(SmolStr, SmolStr)>,
237 pub span: Span,
239 #[serde(default, skip_serializing_if = "Option::is_none")]
241 pub source_id: Option<crate::loader::SourceId>,
242}
243
244impl Datasource {
245 pub fn new(name: impl Into<SmolStr>, provider: DatabaseProvider, span: Span) -> Self {
247 Self {
248 name: name.into(),
249 provider,
250 url: None,
251 url_env: None,
252 extensions: Vec::new(),
253 properties: Vec::new(),
254 span,
255 source_id: None,
256 }
257 }
258
259 pub fn with_url(mut self, url: impl Into<SmolStr>) -> Self {
261 self.url = Some(url.into());
262 self
263 }
264
265 pub fn with_url_env(mut self, env_var: impl Into<SmolStr>) -> Self {
267 self.url_env = Some(env_var.into());
268 self
269 }
270
271 pub fn add_extension(&mut self, ext: PostgresExtension) {
273 self.extensions.push(ext);
274 }
275
276 pub fn add_property(&mut self, key: impl Into<SmolStr>, value: impl Into<SmolStr>) {
278 self.properties.push((key.into(), value.into()));
279 }
280
281 pub fn has_extension(&self, name: &str) -> bool {
283 self.extensions.iter().any(|e| e.name == name)
284 }
285
286 pub fn get_extension(&self, name: &str) -> Option<&PostgresExtension> {
288 self.extensions.iter().find(|e| e.name == name)
289 }
290
291 pub fn has_vector_support(&self) -> bool {
293 self.has_extension("vector") || self.has_extension("pgvector")
294 }
295
296 pub fn extensions_create_sql(&self) -> Vec<String> {
298 self.extensions.iter().map(|e| e.to_create_sql()).collect()
299 }
300}
301
302impl Default for Datasource {
303 fn default() -> Self {
304 Self {
305 name: SmolStr::new("db"),
306 provider: DatabaseProvider::PostgreSQL,
307 url: None,
308 url_env: None,
309 extensions: Vec::new(),
310 properties: Vec::new(),
311 span: Span::new(0, 0),
312 source_id: None,
313 }
314 }
315}
316
317impl std::fmt::Display for Datasource {
318 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319 write!(
320 f,
321 "datasource {} {{ provider = {} }}",
322 self.name, self.provider
323 )
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// Span shared by every fixture; its exact value does not matter here.
    fn make_span() -> Span {
        Span::new(0, 10)
    }

    /// A bare PostgreSQL datasource named `db`.
    fn postgres_datasource() -> Datasource {
        Datasource::new("db", DatabaseProvider::PostgreSQL, make_span())
    }

    // ---- DatabaseProvider ----

    #[test]
    fn test_database_provider_from_str() {
        let cases = [
            ("postgresql", Some(DatabaseProvider::PostgreSQL)),
            ("postgres", Some(DatabaseProvider::PostgreSQL)),
            ("PostgreSQL", Some(DatabaseProvider::PostgreSQL)),
            ("mysql", Some(DatabaseProvider::MySQL)),
            ("sqlite", Some(DatabaseProvider::SQLite)),
            ("mongodb", Some(DatabaseProvider::MongoDB)),
            ("unknown", None),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(
                &DatabaseProvider::from_str(input),
                expected,
                "input: {}",
                input
            );
        }
    }

    #[test]
    fn test_database_provider_as_str() {
        let cases = [
            (DatabaseProvider::PostgreSQL, "postgresql"),
            (DatabaseProvider::MySQL, "mysql"),
            (DatabaseProvider::SQLite, "sqlite"),
            (DatabaseProvider::MongoDB, "mongodb"),
        ];
        for (provider, name) in cases.iter() {
            assert_eq!(provider.as_str(), *name);
        }
    }

    #[test]
    fn test_database_provider_supports_extensions() {
        assert!(DatabaseProvider::PostgreSQL.supports_extensions());
        let without = [
            DatabaseProvider::MySQL,
            DatabaseProvider::SQLite,
            DatabaseProvider::MongoDB,
        ];
        for provider in without.iter() {
            assert!(
                !provider.supports_extensions(),
                "{} should not support extensions",
                provider
            );
        }
    }

    // ---- PostgresExtension ----

    #[test]
    fn test_postgres_extension_new() {
        let ext = PostgresExtension::new("vector", make_span());
        assert_eq!(ext.name(), "vector");
        assert_eq!(ext.schema, None);
        assert_eq!(ext.version, None);
    }

    #[test]
    fn test_postgres_extension_with_schema() {
        let ext = PostgresExtension::new("postgis", make_span()).with_schema("public");
        assert_eq!(ext.schema, Some(SmolStr::new("public")));
    }

    #[test]
    fn test_postgres_extension_with_version() {
        let ext = PostgresExtension::new("vector", make_span()).with_version("0.5.0");
        assert_eq!(ext.version, Some(SmolStr::new("0.5.0")));
    }

    #[test]
    fn test_postgres_extension_to_create_sql() {
        let cases = [
            (
                PostgresExtension::new("pg_trgm", make_span()),
                "CREATE EXTENSION IF NOT EXISTS \"pg_trgm\";",
            ),
            (
                PostgresExtension::new("postgis", make_span()).with_schema("extensions"),
                "CREATE EXTENSION IF NOT EXISTS \"postgis\" SCHEMA \"extensions\";",
            ),
            (
                PostgresExtension::new("vector", make_span()).with_version("0.5.0"),
                "CREATE EXTENSION IF NOT EXISTS \"vector\" VERSION '0.5.0';",
            ),
        ];
        for (ext, expected) in cases.iter() {
            assert_eq!(ext.to_create_sql(), *expected);
        }
    }

    #[test]
    fn test_postgres_extension_to_drop_sql() {
        let sql = PostgresExtension::new("vector", make_span()).to_drop_sql();
        assert_eq!(sql, "DROP EXTENSION IF EXISTS \"vector\" CASCADE;");
    }

    #[test]
    fn test_postgres_extension_provides_custom_types() {
        for name in ["vector", "postgis", "hstore"].iter() {
            assert!(
                PostgresExtension::new(*name, make_span()).provides_custom_types(),
                "{} should provide custom types",
                name
            );
        }
        assert!(!PostgresExtension::new("pg_trgm", make_span()).provides_custom_types());
    }

    // ---- WellKnownExtension ----

    #[test]
    fn test_well_known_extension_from_str() {
        let cases = [
            ("vector", Some(WellKnownExtension::Vector)),
            ("pgvector", Some(WellKnownExtension::Vector)),
            ("pg_trgm", Some(WellKnownExtension::PgTrgm)),
            ("uuid-ossp", Some(WellKnownExtension::UuidOssp)),
            ("unknown", None),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(
                WellKnownExtension::from_str(input),
                expected,
                "input: {}",
                input
            );
        }
    }

    #[test]
    fn test_well_known_extension_name() {
        let cases = [
            (WellKnownExtension::Vector, "vector"),
            (WellKnownExtension::PgTrgm, "pg_trgm"),
            (WellKnownExtension::UuidOssp, "uuid-ossp"),
        ];
        for &(ext, name) in cases.iter() {
            assert_eq!(ext.extension_name(), name);
        }
    }

    // ---- Datasource ----

    #[test]
    fn test_datasource_new() {
        let ds = postgres_datasource();
        assert_eq!(ds.name.as_str(), "db");
        assert_eq!(ds.provider, DatabaseProvider::PostgreSQL);
        assert!(ds.extensions.is_empty());
    }

    #[test]
    fn test_datasource_with_url() {
        let ds = postgres_datasource().with_url("postgresql://localhost/mydb");
        assert_eq!(ds.url, Some(SmolStr::new("postgresql://localhost/mydb")));
    }

    #[test]
    fn test_datasource_with_url_env() {
        let ds = postgres_datasource().with_url_env("DATABASE_URL");
        assert_eq!(ds.url_env, Some(SmolStr::new("DATABASE_URL")));
    }

    #[test]
    fn test_datasource_add_extension() {
        let mut ds = postgres_datasource();
        for name in ["vector", "pg_trgm"].iter() {
            ds.add_extension(PostgresExtension::new(*name, make_span()));
        }

        assert_eq!(ds.extensions.len(), 2);
        assert!(ds.has_extension("vector"));
        assert!(ds.has_extension("pg_trgm"));
        assert!(!ds.has_extension("postgis"));
    }

    #[test]
    fn test_datasource_has_vector_support() {
        let mut ds = postgres_datasource();
        assert!(!ds.has_vector_support());

        ds.add_extension(PostgresExtension::new("vector", make_span()));
        assert!(ds.has_vector_support());
    }

    #[test]
    fn test_datasource_extensions_create_sql() {
        let mut ds = postgres_datasource();
        for name in ["vector", "pg_trgm"].iter() {
            ds.add_extension(PostgresExtension::new(*name, make_span()));
        }

        let statements = ds.extensions_create_sql();
        assert_eq!(statements.len(), 2);
        assert!(statements[0].contains("vector"));
        assert!(statements[1].contains("pg_trgm"));
    }

    #[test]
    fn test_datasource_default() {
        let ds = Datasource::default();
        assert_eq!(ds.name.as_str(), "db");
        assert_eq!(ds.provider, DatabaseProvider::PostgreSQL);
    }
}