Skip to main content

pubky_homeserver/persistence/sql/
connection_string.rs

1use std::{fmt::Display, str::FromStr};
2
3use serde::{Deserialize, Serialize};
4
5/// A connection string for a  postgres database.
6/// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING-URIS
7#[derive(Debug, Clone, PartialEq)]
8pub struct ConnectionString(url::Url);
9
10impl ConnectionString {
11    /// Create a new connection string from a string.
12    /// This function validates that the connection string is a postgres connection string.
13    pub fn new(con_string: &str) -> anyhow::Result<Self> {
14        let con = Self(url::Url::parse(con_string)?);
15        if !con.is_postgres() {
16            anyhow::bail!("Only postgres database urls are supported");
17        }
18        Ok(con)
19    }
20
21    /// Get the connection string as a str.
22    pub fn as_str(&self) -> &str {
23        self.0.as_str()
24    }
25
26    fn is_postgres(&self) -> bool {
27        self.0.scheme() == "postgres" || self.0.scheme() == "postgresql"
28    }
29
30    /// Get the database name
31    /// For postgres, this is the database name directly
32    /// For sqlite, this is the path to the database file
33    pub fn database_name(&self) -> &str {
34        self.0.path().trim_start_matches("/")
35    }
36
37    /// Set the database name
38    pub fn set_database_name(&mut self, db_name: &str) {
39        self.0.set_path(db_name);
40    }
41}
42
#[cfg(any(test, feature = "testing"))]
impl ConnectionString {
    /// Returns a connection string for a test database.
    ///
    /// This does not point at a real database: the `pubky-test=true` query parameter marks
    /// it as an indicator for an ephemeral test database (see [`Self::is_test_db`]).
    pub fn default_test_db() -> Self {
        Self::new("postgres://postgres:postgres@localhost:5432/postgres?pubky-test=true")
            .expect("hard-coded test connection string is a valid postgres URL")
    }

    /// Returns `true` if the query string contains the pair `pubky-test=true`,
    /// i.e. this connection string is the ephemeral test database marker.
    pub fn is_test_db(&self) -> bool {
        self.0
            .query_pairs()
            .any(|(key, value)| key == "pubky-test" && value == "true")
    }
}
59
60impl From<url::Url> for ConnectionString {
61    fn from(url: url::Url) -> Self {
62        Self(url)
63    }
64}
65
66impl FromStr for ConnectionString {
67    type Err = url::ParseError;
68
69    fn from_str(s: &str) -> Result<Self, Self::Err> {
70        Ok(Self(url::Url::parse(s)?))
71    }
72}
73
74impl Display for ConnectionString {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        write!(f, "{}", self.0)
77    }
78}
79
80impl Serialize for ConnectionString {
81    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
82    where
83        S: serde::Serializer,
84    {
85        serializer.serialize_str(self.as_str())
86    }
87}
88
89impl<'de> Deserialize<'de> for ConnectionString {
90    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
91    where
92        D: serde::Deserializer<'de>,
93    {
94        let s = String::deserialize(deserializer)?;
95        Self::new(&s).map_err(serde::de::Error::custom)
96    }
97}
98
99impl Default for ConnectionString {
100    fn default() -> Self {
101        Self::new("postgres://localhost:5432/pubky_homeserver").unwrap()
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// `FromStr` only checks URL syntax, so any well-formed URL parses.
    /// (Pure parsing needs no async runtime or test database.)
    #[test]
    fn test_create_db() {
        let con_strings = [
            "postgres://localhost:5432/pubky_homeserver",
            "sqlite:///path/to/sqlite.db",
        ];
        for con_string in con_strings {
            let _: ConnectionString = con_string.parse().unwrap();
        }
    }

    /// `new` accepts both postgres schemes and rejects anything else.
    #[test]
    fn new_validates_scheme() {
        assert!(ConnectionString::new("postgres://localhost:5432/db").is_ok());
        assert!(ConnectionString::new("postgresql://localhost:5432/db").is_ok());
        assert!(ConnectionString::new("sqlite:///path/to/sqlite.db").is_err());
        assert!(ConnectionString::new("not a url").is_err());
    }

    #[test]
    fn database_name_round_trip() {
        let mut con = ConnectionString::default();
        assert_eq!(con.database_name(), "pubky_homeserver");

        con.set_database_name("other_db");
        assert_eq!(con.database_name(), "other_db");
        assert_eq!(con.as_str(), "postgres://localhost:5432/other_db");
    }

    #[test]
    fn default_test_db_is_flagged() {
        assert!(ConnectionString::default_test_db().is_test_db());
        assert!(!ConnectionString::default().is_test_db());
    }
}
120}