1use serde::{Deserialize, Serialize};
4use smol_str::SmolStr;
5
6use super::Span;
7
/// Database backend that a datasource targets.
///
/// Parsed case-insensitively by `DatabaseProvider::from_str` and rendered as a
/// lowercase name by `as_str` / `Display`. Only `PostgreSQL` reports extension
/// support (see `supports_extensions`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DatabaseProvider {
    /// PostgreSQL; `from_str` also accepts the `postgres` spelling.
    PostgreSQL,
    /// MySQL.
    MySQL,
    /// SQLite.
    SQLite,
    /// MongoDB.
    MongoDB,
}
20
21impl DatabaseProvider {
22 #[allow(clippy::should_implement_trait)]
24 pub fn from_str(s: &str) -> Option<Self> {
25 match s.to_lowercase().as_str() {
26 "postgresql" | "postgres" => Some(Self::PostgreSQL),
27 "mysql" => Some(Self::MySQL),
28 "sqlite" => Some(Self::SQLite),
29 "mongodb" => Some(Self::MongoDB),
30 _ => None,
31 }
32 }
33
34 pub fn as_str(&self) -> &'static str {
36 match self {
37 Self::PostgreSQL => "postgresql",
38 Self::MySQL => "mysql",
39 Self::SQLite => "sqlite",
40 Self::MongoDB => "mongodb",
41 }
42 }
43
44 pub fn supports_extensions(&self) -> bool {
46 matches!(self, Self::PostgreSQL)
47 }
48}
49
50impl std::fmt::Display for DatabaseProvider {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 write!(f, "{}", self.as_str())
53 }
54}
55
/// A PostgreSQL extension declared on a datasource.
///
/// Rendered into `CREATE EXTENSION` / `DROP EXTENSION` statements by
/// `to_create_sql` and `to_drop_sql`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PostgresExtension {
    /// Extension name as used in `CREATE EXTENSION` (e.g. `vector`, `pg_trgm`).
    pub name: SmolStr,
    /// Optional target schema, emitted as the `SCHEMA` clause.
    pub schema: Option<SmolStr>,
    /// Optional version string, emitted as the `VERSION` clause.
    pub version: Option<SmolStr>,
    /// Source span of the declaration. Presumably its location in the parsed
    /// schema text (TODO confirm against the parser).
    pub span: Span,
}
68
69impl PostgresExtension {
70 pub fn new(name: impl Into<SmolStr>, span: Span) -> Self {
72 Self {
73 name: name.into(),
74 schema: None,
75 version: None,
76 span,
77 }
78 }
79
80 pub fn with_schema(mut self, schema: impl Into<SmolStr>) -> Self {
82 self.schema = Some(schema.into());
83 self
84 }
85
86 pub fn with_version(mut self, version: impl Into<SmolStr>) -> Self {
88 self.version = Some(version.into());
89 self
90 }
91
92 pub fn name(&self) -> &str {
94 &self.name
95 }
96
97 pub fn to_create_sql(&self) -> String {
99 let mut sql = format!("CREATE EXTENSION IF NOT EXISTS \"{}\"", self.name);
100 if let Some(schema) = &self.schema {
101 sql.push_str(&format!(" SCHEMA \"{}\"", schema));
102 }
103 if let Some(version) = &self.version {
104 sql.push_str(&format!(" VERSION '{}'", version));
105 }
106 sql.push(';');
107 sql
108 }
109
110 pub fn to_drop_sql(&self) -> String {
112 format!("DROP EXTENSION IF EXISTS \"{}\" CASCADE;", self.name)
113 }
114
115 pub fn provides_custom_types(&self) -> bool {
117 matches!(
118 self.name.as_str(),
119 "vector" | "pgvector" | "postgis" | "hstore" | "ltree" | "cube" | "citext"
120 )
121 }
122}
123
124impl std::fmt::Display for PostgresExtension {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 write!(f, "{}", self.name)
127 }
128}
129
/// PostgreSQL extensions recognised by name.
///
/// Each variant maps to its `CREATE EXTENSION` name via `extension_name` and
/// to a short human-readable summary via `description`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WellKnownExtension {
    /// `pg_trgm`: trigram-based text similarity search.
    PgTrgm,
    /// `vector` (also parsed from `pgvector`): vector similarity search.
    Vector,
    /// `uuid-ossp` (also parsed from `uuid_ossp`): UUID generation functions.
    UuidOssp,
    /// `pgcrypto`: cryptographic functions.
    PgCrypto,
    /// `postgis`: geographic objects and spatial queries.
    PostGIS,
    /// `hstore`: key-value store type.
    HStore,
    /// `ltree`: hierarchical tree-like data.
    LTree,
    /// `citext`: case-insensitive text type.
    Citext,
    /// `cube`: multi-dimensional cube data type.
    Cube,
    /// `pg_stat_statements`: query execution statistics.
    PgStatStatements,
    /// `aws_lambda`: AWS Lambda function invocation.
    AwsLambda,
    /// `aws_s3`: AWS S3 storage integration.
    AwsS3,
    /// `plpgsql`: the PL/pgSQL procedural language.
    PlPgSQL,
}
160
161impl WellKnownExtension {
162 pub fn extension_name(&self) -> &'static str {
164 match self {
165 Self::PgTrgm => "pg_trgm",
166 Self::Vector => "vector",
167 Self::UuidOssp => "uuid-ossp",
168 Self::PgCrypto => "pgcrypto",
169 Self::PostGIS => "postgis",
170 Self::HStore => "hstore",
171 Self::LTree => "ltree",
172 Self::Citext => "citext",
173 Self::Cube => "cube",
174 Self::PgStatStatements => "pg_stat_statements",
175 Self::AwsLambda => "aws_lambda",
176 Self::AwsS3 => "aws_s3",
177 Self::PlPgSQL => "plpgsql",
178 }
179 }
180
181 #[allow(clippy::should_implement_trait)]
183 pub fn from_str(s: &str) -> Option<Self> {
184 match s {
185 "pg_trgm" => Some(Self::PgTrgm),
186 "vector" | "pgvector" => Some(Self::Vector),
187 "uuid-ossp" | "uuid_ossp" => Some(Self::UuidOssp),
188 "pgcrypto" => Some(Self::PgCrypto),
189 "postgis" => Some(Self::PostGIS),
190 "hstore" => Some(Self::HStore),
191 "ltree" => Some(Self::LTree),
192 "citext" => Some(Self::Citext),
193 "cube" => Some(Self::Cube),
194 "pg_stat_statements" => Some(Self::PgStatStatements),
195 "aws_lambda" => Some(Self::AwsLambda),
196 "aws_s3" => Some(Self::AwsS3),
197 "plpgsql" => Some(Self::PlPgSQL),
198 _ => None,
199 }
200 }
201
202 pub fn description(&self) -> &'static str {
204 match self {
205 Self::PgTrgm => "Trigram-based text similarity search",
206 Self::Vector => "Vector similarity search for AI/ML embeddings",
207 Self::UuidOssp => "UUID generation functions",
208 Self::PgCrypto => "Cryptographic functions",
209 Self::PostGIS => "Geographic objects and spatial queries",
210 Self::HStore => "Key-value store type",
211 Self::LTree => "Hierarchical tree-like data",
212 Self::Citext => "Case-insensitive text type",
213 Self::Cube => "Multi-dimensional cube data type",
214 Self::PgStatStatements => "Query execution statistics",
215 Self::AwsLambda => "AWS Lambda function invocation",
216 Self::AwsS3 => "AWS S3 storage integration",
217 Self::PlPgSQL => "PL/pgSQL procedural language",
218 }
219 }
220}
221
222#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
224pub struct Datasource {
225 pub name: SmolStr,
227 pub provider: DatabaseProvider,
229 pub url: Option<SmolStr>,
231 pub url_env: Option<SmolStr>,
233 pub extensions: Vec<PostgresExtension>,
235 pub properties: Vec<(SmolStr, SmolStr)>,
237 pub span: Span,
239}
240
241impl Datasource {
242 pub fn new(name: impl Into<SmolStr>, provider: DatabaseProvider, span: Span) -> Self {
244 Self {
245 name: name.into(),
246 provider,
247 url: None,
248 url_env: None,
249 extensions: Vec::new(),
250 properties: Vec::new(),
251 span,
252 }
253 }
254
255 pub fn with_url(mut self, url: impl Into<SmolStr>) -> Self {
257 self.url = Some(url.into());
258 self
259 }
260
261 pub fn with_url_env(mut self, env_var: impl Into<SmolStr>) -> Self {
263 self.url_env = Some(env_var.into());
264 self
265 }
266
267 pub fn add_extension(&mut self, ext: PostgresExtension) {
269 self.extensions.push(ext);
270 }
271
272 pub fn add_property(&mut self, key: impl Into<SmolStr>, value: impl Into<SmolStr>) {
274 self.properties.push((key.into(), value.into()));
275 }
276
277 pub fn has_extension(&self, name: &str) -> bool {
279 self.extensions.iter().any(|e| e.name == name)
280 }
281
282 pub fn get_extension(&self, name: &str) -> Option<&PostgresExtension> {
284 self.extensions.iter().find(|e| e.name == name)
285 }
286
287 pub fn has_vector_support(&self) -> bool {
289 self.has_extension("vector") || self.has_extension("pgvector")
290 }
291
292 pub fn extensions_create_sql(&self) -> Vec<String> {
294 self.extensions.iter().map(|e| e.to_create_sql()).collect()
295 }
296}
297
298impl Default for Datasource {
299 fn default() -> Self {
300 Self {
301 name: SmolStr::new("db"),
302 provider: DatabaseProvider::PostgreSQL,
303 url: None,
304 url_env: None,
305 extensions: Vec::new(),
306 properties: Vec::new(),
307 span: Span::new(0, 0),
308 }
309 }
310}
311
312impl std::fmt::Display for Datasource {
313 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314 write!(
315 f,
316 "datasource {} {{ provider = {} }}",
317 self.name, self.provider
318 )
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Arbitrary span shared by all fixtures; its value is never asserted on.
    fn make_span() -> Span {
        Span::new(0, 10)
    }

    // ---- DatabaseProvider ----

    #[test]
    fn test_database_provider_from_str() {
        assert_eq!(
            DatabaseProvider::from_str("postgresql"),
            Some(DatabaseProvider::PostgreSQL)
        );
        assert_eq!(
            DatabaseProvider::from_str("postgres"),
            Some(DatabaseProvider::PostgreSQL)
        );
        assert_eq!(
            DatabaseProvider::from_str("PostgreSQL"),
            Some(DatabaseProvider::PostgreSQL)
        );
        assert_eq!(
            DatabaseProvider::from_str("mysql"),
            Some(DatabaseProvider::MySQL)
        );
        assert_eq!(
            DatabaseProvider::from_str("sqlite"),
            Some(DatabaseProvider::SQLite)
        );
        assert_eq!(
            DatabaseProvider::from_str("mongodb"),
            Some(DatabaseProvider::MongoDB)
        );
        assert_eq!(DatabaseProvider::from_str("unknown"), None);
    }

    #[test]
    fn test_database_provider_as_str() {
        assert_eq!(DatabaseProvider::PostgreSQL.as_str(), "postgresql");
        assert_eq!(DatabaseProvider::MySQL.as_str(), "mysql");
        assert_eq!(DatabaseProvider::SQLite.as_str(), "sqlite");
        assert_eq!(DatabaseProvider::MongoDB.as_str(), "mongodb");
    }

    #[test]
    fn test_database_provider_round_trip_and_display() {
        let providers = [
            DatabaseProvider::PostgreSQL,
            DatabaseProvider::MySQL,
            DatabaseProvider::SQLite,
            DatabaseProvider::MongoDB,
        ];
        for provider in &providers {
            // The canonical name must parse back to the same variant...
            assert_eq!(
                DatabaseProvider::from_str(provider.as_str()),
                Some(provider.clone())
            );
            // ...and Display must print exactly that canonical name.
            assert_eq!(provider.to_string(), provider.as_str());
        }
        assert_eq!(
            DatabaseProvider::from_str("MYSQL"),
            Some(DatabaseProvider::MySQL)
        );
    }

    #[test]
    fn test_database_provider_supports_extensions() {
        assert!(DatabaseProvider::PostgreSQL.supports_extensions());
        assert!(!DatabaseProvider::MySQL.supports_extensions());
        assert!(!DatabaseProvider::SQLite.supports_extensions());
        assert!(!DatabaseProvider::MongoDB.supports_extensions());
    }

    // ---- PostgresExtension ----

    #[test]
    fn test_postgres_extension_new() {
        let ext = PostgresExtension::new("vector", make_span());
        assert_eq!(ext.name(), "vector");
        assert!(ext.schema.is_none());
        assert!(ext.version.is_none());
    }

    #[test]
    fn test_postgres_extension_with_schema() {
        let ext = PostgresExtension::new("postgis", make_span()).with_schema("public");
        assert_eq!(ext.schema, Some(SmolStr::new("public")));
    }

    #[test]
    fn test_postgres_extension_with_version() {
        let ext = PostgresExtension::new("vector", make_span()).with_version("0.5.0");
        assert_eq!(ext.version, Some(SmolStr::new("0.5.0")));
    }

    #[test]
    fn test_postgres_extension_to_create_sql() {
        let ext = PostgresExtension::new("pg_trgm", make_span());
        assert_eq!(
            ext.to_create_sql(),
            "CREATE EXTENSION IF NOT EXISTS \"pg_trgm\";"
        );

        let ext_with_schema =
            PostgresExtension::new("postgis", make_span()).with_schema("extensions");
        assert_eq!(
            ext_with_schema.to_create_sql(),
            "CREATE EXTENSION IF NOT EXISTS \"postgis\" SCHEMA \"extensions\";"
        );

        let ext_with_version = PostgresExtension::new("vector", make_span()).with_version("0.5.0");
        assert_eq!(
            ext_with_version.to_create_sql(),
            "CREATE EXTENSION IF NOT EXISTS \"vector\" VERSION '0.5.0';"
        );
    }

    #[test]
    fn test_postgres_extension_to_create_sql_with_schema_and_version() {
        let ext = PostgresExtension::new("postgis", make_span())
            .with_schema("extensions")
            .with_version("3.4.0");
        // SCHEMA must precede VERSION.
        assert_eq!(
            ext.to_create_sql(),
            "CREATE EXTENSION IF NOT EXISTS \"postgis\" SCHEMA \"extensions\" VERSION '3.4.0';"
        );
    }

    #[test]
    fn test_postgres_extension_to_drop_sql() {
        let ext = PostgresExtension::new("vector", make_span());
        assert_eq!(
            ext.to_drop_sql(),
            "DROP EXTENSION IF EXISTS \"vector\" CASCADE;"
        );
    }

    #[test]
    fn test_postgres_extension_provides_custom_types() {
        assert!(PostgresExtension::new("vector", make_span()).provides_custom_types());
        assert!(PostgresExtension::new("postgis", make_span()).provides_custom_types());
        assert!(PostgresExtension::new("hstore", make_span()).provides_custom_types());
        assert!(!PostgresExtension::new("pg_trgm", make_span()).provides_custom_types());
    }

    #[test]
    fn test_postgres_extension_provides_custom_types_full_list() {
        let type_providers = ["vector", "pgvector", "postgis", "hstore", "ltree", "cube", "citext"];
        for name in type_providers.iter() {
            assert!(
                PostgresExtension::new(*name, make_span()).provides_custom_types(),
                "{} should provide custom types",
                name
            );
        }
        assert!(!PostgresExtension::new("uuid-ossp", make_span()).provides_custom_types());
    }

    #[test]
    fn test_postgres_extension_display() {
        let ext = PostgresExtension::new("pg_trgm", make_span()).with_version("1.6");
        // Display shows only the name, never schema/version.
        assert_eq!(ext.to_string(), "pg_trgm");
    }

    // ---- WellKnownExtension ----

    #[test]
    fn test_well_known_extension_from_str() {
        assert_eq!(
            WellKnownExtension::from_str("vector"),
            Some(WellKnownExtension::Vector)
        );
        assert_eq!(
            WellKnownExtension::from_str("pgvector"),
            Some(WellKnownExtension::Vector)
        );
        assert_eq!(
            WellKnownExtension::from_str("pg_trgm"),
            Some(WellKnownExtension::PgTrgm)
        );
        assert_eq!(
            WellKnownExtension::from_str("uuid-ossp"),
            Some(WellKnownExtension::UuidOssp)
        );
        assert_eq!(
            WellKnownExtension::from_str("uuid_ossp"),
            Some(WellKnownExtension::UuidOssp)
        );
        assert_eq!(WellKnownExtension::from_str("unknown"), None);
    }

    #[test]
    fn test_well_known_extension_name() {
        assert_eq!(WellKnownExtension::Vector.extension_name(), "vector");
        assert_eq!(WellKnownExtension::PgTrgm.extension_name(), "pg_trgm");
        assert_eq!(WellKnownExtension::UuidOssp.extension_name(), "uuid-ossp");
    }

    #[test]
    fn test_well_known_extension_round_trip() {
        let all = [
            WellKnownExtension::PgTrgm,
            WellKnownExtension::Vector,
            WellKnownExtension::UuidOssp,
            WellKnownExtension::PgCrypto,
            WellKnownExtension::PostGIS,
            WellKnownExtension::HStore,
            WellKnownExtension::LTree,
            WellKnownExtension::Citext,
            WellKnownExtension::Cube,
            WellKnownExtension::PgStatStatements,
            WellKnownExtension::AwsLambda,
            WellKnownExtension::AwsS3,
            WellKnownExtension::PlPgSQL,
        ];
        for ext in all.iter() {
            // Every canonical name must parse back to its own variant.
            assert_eq!(
                WellKnownExtension::from_str(ext.extension_name()),
                Some(*ext)
            );
            assert!(!ext.description().is_empty());
        }
    }

    // ---- Datasource ----

    #[test]
    fn test_datasource_new() {
        let ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span());
        assert_eq!(ds.name.as_str(), "db");
        assert_eq!(ds.provider, DatabaseProvider::PostgreSQL);
        assert!(ds.extensions.is_empty());
    }

    #[test]
    fn test_datasource_with_url() {
        let ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span())
            .with_url("postgresql://localhost/mydb");
        assert_eq!(ds.url, Some(SmolStr::new("postgresql://localhost/mydb")));
    }

    #[test]
    fn test_datasource_with_url_env() {
        let ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span())
            .with_url_env("DATABASE_URL");
        assert_eq!(ds.url_env, Some(SmolStr::new("DATABASE_URL")));
    }

    #[test]
    fn test_datasource_add_extension() {
        let mut ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span());
        ds.add_extension(PostgresExtension::new("vector", make_span()));
        ds.add_extension(PostgresExtension::new("pg_trgm", make_span()));

        assert_eq!(ds.extensions.len(), 2);
        assert!(ds.has_extension("vector"));
        assert!(ds.has_extension("pg_trgm"));
        assert!(!ds.has_extension("postgis"));
    }

    #[test]
    fn test_datasource_get_extension() {
        let mut ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span());
        ds.add_extension(PostgresExtension::new("vector", make_span()).with_version("0.5.0"));

        let ext = ds.get_extension("vector").expect("vector was just added");
        assert_eq!(ext.version, Some(SmolStr::new("0.5.0")));
        assert!(ds.get_extension("postgis").is_none());
    }

    #[test]
    fn test_datasource_add_property() {
        let mut ds = Datasource::default();
        ds.add_property("schemas", "public");
        assert_eq!(
            ds.properties,
            vec![(SmolStr::new("schemas"), SmolStr::new("public"))]
        );
    }

    #[test]
    fn test_datasource_has_vector_support() {
        let mut ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span());
        assert!(!ds.has_vector_support());

        ds.add_extension(PostgresExtension::new("vector", make_span()));
        assert!(ds.has_vector_support());
    }

    #[test]
    fn test_datasource_has_vector_support_via_pgvector_alias() {
        let mut ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span());
        ds.add_extension(PostgresExtension::new("pgvector", make_span()));
        assert!(ds.has_vector_support());
    }

    #[test]
    fn test_datasource_extensions_create_sql() {
        let mut ds = Datasource::new("db", DatabaseProvider::PostgreSQL, make_span());
        ds.add_extension(PostgresExtension::new("vector", make_span()));
        ds.add_extension(PostgresExtension::new("pg_trgm", make_span()));

        let sqls = ds.extensions_create_sql();
        assert_eq!(sqls.len(), 2);
        assert!(sqls[0].contains("vector"));
        assert!(sqls[1].contains("pg_trgm"));
    }

    #[test]
    fn test_datasource_default() {
        let ds = Datasource::default();
        assert_eq!(ds.name.as_str(), "db");
        assert_eq!(ds.provider, DatabaseProvider::PostgreSQL);
        assert!(ds.url.is_none());
        assert!(ds.url_env.is_none());
        assert!(ds.extensions.is_empty());
        assert!(ds.properties.is_empty());
    }

    #[test]
    fn test_datasource_display() {
        let ds = Datasource::new("main", DatabaseProvider::MySQL, make_span())
            .with_url("mysql://localhost/app");
        // Only the name and provider appear in the summary.
        assert_eq!(ds.to_string(), "datasource main { provider = mysql }");
    }
}