1use crate::{error::Result, metadata::TableMetadata};
2use tokio_postgres::types::Json;
3
4const INIT_SQL: &str = include_str!("./postgres_schema.sql");
5
6pub async fn init_schema(client: &tokio_postgres::Client) -> Result<()> {
8 client.batch_execute(INIT_SQL).await?;
10 Ok(())
11}
12
13pub async fn upsert(
15 client: &tokio_postgres::Client,
16 table: &str,
17 metadata: &TableMetadata,
18) -> Result<()> {
19 let stmt = client
21 .prepare(
22 r#"
23 INSERT INTO "datasets"
24 ("table_name", "metadata")
25 VALUES ($1, $2)
26 ON CONFLICT ("table_name")
27 DO UPDATE
28 SET "metadata" = EXCLUDED."metadata"
29 "#,
30 )
31 .await?;
32
33 client.execute(&stmt, &[&table, &Json(metadata)]).await?;
35 Ok(())
36}
37
38pub async fn get(
42 client: &tokio_postgres::Client,
43 table: &[&str],
44) -> Result<Vec<(String, TableMetadata)>> {
45 let stmt = client
47 .prepare(
48 r#"
49 SELECT
50 "table_name",
51 "metadata"
52 FROM "datasets"
53 WHERE "table_name" = ANY($1)
54 "#,
55 )
56 .await?;
57
58 let rows = client.query(&stmt, &[&table]).await?;
60
61 let mut result = Vec::new();
63 for row in rows {
64 let table_name: String = row.get(0);
65 let metadata: Json<TableMetadata> = row.get(1);
66 result.push((table_name, metadata.0));
67 }
68 Ok(result)
69}
70
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use tokio_postgres::NoTls;

    /// Removes the `datasets` table so the test leaves no state behind.
    async fn drop_table(client: &tokio_postgres::Client) -> Result<()> {
        client
            .execute(r#"DROP TABLE IF EXISTS "datasets""#, &[])
            .await?;
        Ok(())
    }

    /// Builds the fixture record used by the test, varying only `desc`.
    fn example_metadata(desc: &str) -> TableMetadata {
        TableMetadata {
            name: "example_table".to_string(),
            desc: Some(desc.to_string()),
            source: Some("example_source".to_string()),
            source_url: None,
            license: Some("MIT".to_string()),
            license_url: None,
            primary_key: Some("id".to_string()),
            columns: vec![],
        }
    }

    /// What `get` returned after the first insert, after the overwriting
    /// upsert, and for a name that was never stored.
    type ScenarioResults = (
        Vec<(String, TableMetadata)>,
        Vec<(String, TableMetadata)>,
        Vec<(String, TableMetadata)>,
    );

    /// Exercises both the insert path and the `ON CONFLICT ... DO UPDATE` path
    /// of `upsert`, plus a lookup miss, and returns what `get` saw at each step.
    async fn run_scenario(client: &tokio_postgres::Client) -> Result<ScenarioResults> {
        upsert(client, "example_table", &example_metadata("An example table")).await?;
        let inserted = get(client, &["example_table"]).await?;

        upsert(client, "example_table", &example_metadata("An updated table")).await?;
        let updated = get(client, &["example_table"]).await?;

        let missing = get(client, &["no_such_table"]).await?;
        Ok((inserted, updated, missing))
    }

    #[tokio::test]
    async fn test_upsert_and_get() -> Result<()> {
        // NOTE: this test drops the `datasets` table when it finishes, so point
        // it at a disposable database only.
        let connect_str = env::var("POSTGRES_CONN_STR_TEST")
            .expect("POSTGRES_CONN_STR_TEST must be set to a test Postgres connection string");
        let (client, connection) = tokio_postgres::connect(&connect_str, NoTls).await?;
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                eprintln!("Connection error: {}", e);
            }
        });

        init_schema(&client).await?;

        // Run every query, then clean up before checking anything. This way a
        // failing query or assertion does not leave the table behind for the next run.
        let outcome = run_scenario(&client).await;
        drop_table(&client).await?;
        let (inserted, updated, missing) = outcome?;

        assert_eq!(inserted.len(), 1);
        assert_eq!(inserted[0].0, "example_table");
        assert_eq!(inserted[0].1.name, "example_table");
        assert_eq!(inserted[0].1.desc.as_deref(), Some("An example table"));

        // A second upsert for the same name must overwrite, not duplicate.
        assert_eq!(updated.len(), 1);
        assert_eq!(updated[0].0, "example_table");
        assert_eq!(updated[0].1.desc.as_deref(), Some("An updated table"));

        // Names without a stored row are omitted rather than reported as errors.
        assert!(missing.is_empty());

        Ok(())
    }
}