km_to_sql/
postgres.rs

1use crate::{error::Result, metadata::TableMetadata};
2use tokio_postgres::types::Json;
3
4const INIT_SQL: &str = include_str!("./postgres_schema.sql");
5
6/// Initialize the schema in the database.
7pub async fn init_schema(client: &tokio_postgres::Client) -> Result<()> {
8    // Execute the SQL commands to create the schema
9    client.batch_execute(INIT_SQL).await?;
10    Ok(())
11}
12
13/// Insert or update metadata for a table.
14pub async fn upsert(
15    client: &tokio_postgres::Client,
16    table: &str,
17    metadata: &TableMetadata,
18) -> Result<()> {
19    // Prepare the SQL statement for upsert
20    let stmt = client
21        .prepare(
22            r#"
23            INSERT INTO "datasets"
24                ("table_name", "metadata")
25            VALUES ($1, $2)
26            ON CONFLICT ("table_name")
27                DO UPDATE
28                SET "metadata" = EXCLUDED."metadata"
29            "#,
30        )
31        .await?;
32
33    // Execute the upsert
34    client.execute(&stmt, &[&table, &Json(metadata)]).await?;
35    Ok(())
36}
37
38/// Get metadata for a list of tables.
39///
40/// Returns a vector of tuples containing the table name and its metadata.
41pub async fn get(
42    client: &tokio_postgres::Client,
43    table: &[&str],
44) -> Result<Vec<(String, TableMetadata)>> {
45    // Prepare the SQL statement for getting metadata
46    let stmt = client
47        .prepare(
48            r#"
49            SELECT
50                "table_name",
51                "metadata"
52            FROM "datasets"
53            WHERE "table_name" = ANY($1)
54            "#,
55        )
56        .await?;
57
58    // Execute the query
59    let rows = client.query(&stmt, &[&table]).await?;
60
61    // Map the rows to TableMetadata and return the Vec
62    let mut result = Vec::new();
63    for row in rows {
64        let table_name: String = row.get(0);
65        let metadata: Json<TableMetadata> = row.get(1);
66        result.push((table_name, metadata.0));
67    }
68    Ok(result)
69}
70
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use tokio_postgres::NoTls;

    /// Drop the `"datasets"` table (if present) to clean up after a test.
    async fn drop_table(client: &tokio_postgres::Client) -> Result<()> {
        client
            .execute(r#"DROP TABLE IF EXISTS "datasets""#, &[])
            .await?;
        Ok(())
    }

    /// Round-trips metadata through `upsert` and `get` against a live database.
    ///
    /// Requires `POSTGRES_CONN_STR_TEST` to hold a connection string for a
    /// disposable database: the test drops the `"datasets"` table when done.
    #[tokio::test]
    async fn test_upsert_and_get() -> Result<()> {
        let connect_str = env::var("POSTGRES_CONN_STR_TEST")
            .expect("POSTGRES_CONN_STR_TEST must be set to a test database connection string");
        let (client, connection) = tokio_postgres::connect(&connect_str, NoTls).await?;
        // The connection object drives the socket; it must be polled on its own task.
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                eprintln!("Connection error: {}", e);
            }
        });

        init_schema(&client).await?;

        let metadata = TableMetadata {
            name: "example_table".to_string(),
            desc: Some("An example table".to_string()),
            source: Some("example_source".to_string()),
            source_url: None,
            license: Some("MIT".to_string()),
            license_url: None,
            primary_key: Some("id".to_string()),
            columns: vec![],
        };

        upsert(&client, "example_table", &metadata).await?;

        // Tables without stored metadata are omitted from the result.
        let result = get(&client, &["example_table", "missing_table"]).await?;
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, "example_table");
        assert_eq!(result[0].1.name, "example_table");

        // A second upsert for the same table must overwrite, not duplicate.
        let updated = TableMetadata {
            desc: Some("An updated description".to_string()),
            ..metadata
        };
        upsert(&client, "example_table", &updated).await?;

        let result = get(&client, &["example_table"]).await?;
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].1.desc.as_deref(), Some("An updated description"));

        drop_table(&client).await?;

        Ok(())
    }
}