pubky_homeserver/persistence/sql/
connection_string.rs1use std::{fmt::Display, str::FromStr};
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq)]
8pub struct ConnectionString(url::Url);
9
10impl ConnectionString {
11 pub fn new(con_string: &str) -> anyhow::Result<Self> {
14 let con = Self(url::Url::parse(con_string)?);
15 if !con.is_postgres() {
16 anyhow::bail!("Only postgres database urls are supported");
17 }
18 Ok(con)
19 }
20
21 pub fn as_str(&self) -> &str {
23 self.0.as_str()
24 }
25
26 fn is_postgres(&self) -> bool {
27 self.0.scheme() == "postgres" || self.0.scheme() == "postgresql"
28 }
29
30 pub fn database_name(&self) -> &str {
34 self.0.path().trim_start_matches("/")
35 }
36
37 pub fn set_database_name(&mut self, db_name: &str) {
39 self.0.set_path(db_name);
40 }
41}
42
#[cfg(any(test, feature = "testing"))]
impl ConnectionString {
    /// Connection string for the local test postgres instance, tagged with
    /// the `pubky-test=true` query marker.
    pub fn default_test_db() -> Self {
        const TEST_DB_URL: &str =
            "postgres://postgres:postgres@localhost:5432/postgres?pubky-test=true";
        Self::new(TEST_DB_URL).unwrap()
    }

    /// `true` if any query pair is exactly `pubky-test=true`.
    pub fn is_test_db(&self) -> bool {
        self.0
            .query_pairs()
            .filter(|(key, _)| *key == "pubky-test")
            .any(|(_, value)| value == "true")
    }
}
59
60impl From<url::Url> for ConnectionString {
61 fn from(url: url::Url) -> Self {
62 Self(url)
63 }
64}
65
66impl FromStr for ConnectionString {
67 type Err = url::ParseError;
68
69 fn from_str(s: &str) -> Result<Self, Self::Err> {
70 Ok(Self(url::Url::parse(s)?))
71 }
72}
73
74impl Display for ConnectionString {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 write!(f, "{}", self.0)
77 }
78}
79
80impl Serialize for ConnectionString {
81 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
82 where
83 S: serde::Serializer,
84 {
85 serializer.serialize_str(self.as_str())
86 }
87}
88
89impl<'de> Deserialize<'de> for ConnectionString {
90 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
91 where
92 D: serde::Deserializer<'de>,
93 {
94 let s = String::deserialize(deserializer)?;
95 Self::new(&s).map_err(serde::de::Error::custom)
96 }
97}
98
99impl Default for ConnectionString {
100 fn default() -> Self {
101 Self::new("postgres://localhost:5432/pubky_homeserver").unwrap()
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// `FromStr` only validates URL syntax, so both postgres and non-postgres
    /// URLs parse, and they serialize back unchanged.
    #[tokio::test]
    #[pubky_test_utils::test]
    async fn test_create_db() {
        let con_strings = [
            "postgres://localhost:5432/pubky_homeserver",
            "sqlite:///path/to/sqlite.db",
        ];
        for con_string in con_strings {
            let con: ConnectionString = con_string.parse().unwrap();
            assert_eq!(con.as_str(), con_string);
        }
    }

    /// `new` enforces the postgres scheme restriction that `FromStr` does not.
    #[test]
    fn new_rejects_non_postgres_schemes() {
        assert!(ConnectionString::new("sqlite:///path/to/sqlite.db").is_err());
        assert!(ConnectionString::new("postgres://localhost:5432/db").is_ok());
        assert!(ConnectionString::new("postgresql://localhost:5432/db").is_ok());
    }

    /// Changing the database name rewrites only the path and keeps the query.
    #[test]
    fn set_database_name_replaces_path() {
        let mut con = ConnectionString::new(
            "postgres://localhost:5432/pubky_homeserver?sslmode=disable",
        )
        .unwrap();
        assert_eq!(con.database_name(), "pubky_homeserver");

        con.set_database_name("other_db");
        assert_eq!(con.database_name(), "other_db");
        assert_eq!(
            con.as_str(),
            "postgres://localhost:5432/other_db?sslmode=disable"
        );
    }
}