1use std::fmt::Display;
2
3use crate::{errors::PgmqError, types::Message};
4
5use log::LevelFilter;
6use serde::Deserialize;
7use sqlx::error::Error;
8use sqlx::postgres::PgRow;
9use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
10use sqlx::ConnectOptions;
11use sqlx::Row;
12use sqlx::{Pool, Postgres};
13use url::{ParseError, Url};
14
15#[cfg(feature = "cli")]
16use futures_util::stream::StreamExt;
17#[cfg(feature = "cli")]
18use sqlx::Executor;
19pub fn conn_options(url: &str) -> Result<PgConnectOptions, ParseError> {
21 let parsed = Url::parse(url)?;
23 let options = PgConnectOptions::new()
24 .host(parsed.host_str().ok_or(ParseError::EmptyHost)?)
25 .port(parsed.port().ok_or(ParseError::InvalidPort)?)
26 .username(parsed.username())
27 .password(parsed.password().ok_or(ParseError::IdnaError)?)
28 .database(parsed.path().trim_start_matches('/'))
29 .log_statements(LevelFilter::Debug);
30 Ok(options)
31}
32
33pub async fn connect(url: &str, max_connections: u32) -> Result<Pool<Postgres>, PgmqError> {
35 let options = conn_options(url)?;
36 let pgp = PgPoolOptions::new()
37 .acquire_timeout(std::time::Duration::from_secs(10))
38 .max_connections(max_connections)
39 .connect_with(options)
40 .await?;
41 Ok(pgp)
42}
43
44pub async fn fetch_one_message<T: for<'de> Deserialize<'de>>(
48 query: &str,
49 connection: &Pool<Postgres>,
50) -> Result<Option<Message<T>>, PgmqError> {
51 let row: Result<PgRow, Error> = sqlx::query(query).fetch_one(connection).await;
53 match row {
54 Ok(row) => {
55 let raw_msg = row.get("message");
57 let parsed_msg = serde_json::from_value::<T>(raw_msg);
58 match parsed_msg {
59 Ok(parsed_msg) => Ok(Some(Message {
60 msg_id: row.get("msg_id"),
61 vt: row.get("vt"),
62 read_ct: row.get("read_ct"),
63 enqueued_at: row.get("enqueued_at"),
64 message: parsed_msg,
65 })),
66 Err(e) => Err(PgmqError::JsonParsingError(e)),
67 }
68 }
69 Err(sqlx::error::Error::RowNotFound) => Ok(None),
70 Err(e) => Err(e)?,
71 }
72}
73
/// A queue name that was validated with [`check_input`] when built through
/// [`CheckedName::new`].
///
/// The wrapped field is private, so code outside this module can only obtain a
/// `CheckedName` through the validating constructor.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct CheckedName<'a>(&'a str);
77
78impl<'a> CheckedName<'a> {
79 pub fn new(input: &'a str) -> Result<Self, PgmqError> {
81 check_input(input)?;
82
83 Ok(Self(input))
84 }
85}
86
87impl AsRef<str> for CheckedName<'_> {
88 fn as_ref(&self) -> &str {
89 self.0
90 }
91}
92
93impl Display for CheckedName<'_> {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 f.write_str(self.0)
96 }
97}
98
99pub fn check_input(input: &str) -> Result<(), PgmqError> {
101 const NAMEDATALEN: usize = 64;
106
107 const MAX_IDENTIFIER_LEN: usize = NAMEDATALEN - 1;
110 const BIGGEST_CONCAT: &str = "archived_at_idx_";
111
112 const MAX_PGMQ_QUEUE_LEN: usize = MAX_IDENTIFIER_LEN - BIGGEST_CONCAT.len();
115
116 let is_short_enough = input.len() <= MAX_PGMQ_QUEUE_LEN;
117 let has_valid_characters = input
118 .as_bytes()
119 .iter()
120 .all(|&c| c.is_ascii_alphanumeric() || c == b'_');
121 let valid = is_short_enough && has_valid_characters;
122 match valid {
123 true => Ok(()),
124 false => Err(PgmqError::InvalidQueueName {
125 name: input.to_owned(),
126 }),
127 }
128}
129
#[cfg(feature = "cli")]
/// Asks the GitHub API for the newest `pgmq/pgmq` release and returns its tag name.
///
/// # Errors
///
/// Fails on transport errors, on a non-success HTTP status, or when the response body
/// cannot be decoded as a release.
async fn get_latest_release_tag() -> Result<String, PgmqError> {
    const LATEST_RELEASE_URL: &str = "https://api.github.com/repos/pgmq/pgmq/releases/latest";

    log::info!("Getting latest PGMQ release...");

    // GitHub's REST API expects every request to carry a User-Agent header.
    let response = reqwest::Client::new()
        .get(LATEST_RELEASE_URL)
        .header("User-Agent", "pgmq-cli")
        .send()
        .await?;

    let status = response.status();
    if !status.is_success() {
        return Err(format!("Failed to fetch latest release: HTTP {}", status).into());
    }

    let latest: GitHubRelease = response.json().await?;
    let tag_name = latest.tag_name;
    log::info!("Latest release: {}", tag_name);

    Ok(tag_name)
}
150
#[cfg(feature = "cli")]
/// Downloads the `pgmq.sql` installation script from the `pgmq/pgmq` GitHub repository.
///
/// `version` may be a release tag, with or without a leading `v`, or a git commit hash.
/// When it is `None`, the latest release tag is looked up first.
///
/// # Errors
///
/// Fails when the latest-release lookup fails, the download fails, or the server
/// answers with a non-success status.
async fn get_install_sql(version: Option<&String>) -> Result<String, PgmqError> {
    let version_to_use = if let Some(requested) = version {
        requested.clone()
    } else {
        get_latest_release_tag().await?
    };

    // A 7-64 character, all-hex string is treated as a git commit hash; anything else
    // is treated as a release tag.
    let looks_like_commit = (7..=64).contains(&version_to_use.len())
        && version_to_use.chars().all(|c| c.is_ascii_hexdigit());

    let sql_url = match looks_like_commit {
        true => format!(
            "https://raw.githubusercontent.com/pgmq/pgmq/{version_to_use}/pgmq-extension/sql/pgmq.sql",
        ),
        false => {
            // Release tags in the repository carry a `v` prefix; add it when missing.
            let version_tag = match version_to_use.starts_with('v') {
                true => version_to_use.clone(),
                false => format!("v{version_to_use}"),
            };
            format!(
                "https://raw.githubusercontent.com/pgmq/pgmq/refs/tags/{version_tag}/pgmq-extension/sql/pgmq.sql",
            )
        }
    };

    log::info!("Fetching SQL from: {sql_url}");

    let response = reqwest::Client::new().get(&sql_url).send().await?;
    let status = response.status();
    if !status.is_success() {
        return Err(format!("Failed to download SQL file: HTTP {}", status).into());
    }

    Ok(response.text().await?)
}
189
#[cfg(feature = "cli")]
/// Installs PGMQ into the database behind `pool`.
///
/// Fetches the installation SQL for `version` (or for the latest release when `None`)
/// and runs it through `execute_sql_statements`, which wraps it in a single transaction.
///
/// # Errors
///
/// Fails when the SQL cannot be downloaded or any statement fails to execute.
pub async fn install_pgmq(
    pool: &Pool<Postgres>,
    version: Option<&String>,
) -> Result<(), PgmqError> {
    log::info!("Installing PGMQ...");
    let install_script = get_install_sql(version).await?;

    log::info!("Executing PGMQ installation SQL...");
    execute_sql_statements(pool, &install_script).await?;

    log::info!("PGMQ installation completed successfully!");
    Ok(())
}
205
#[cfg(feature = "cli")]
/// Runs every statement in `multi_query` inside one transaction and commits it.
///
/// The first failing statement aborts the loop. The transaction is then dropped
/// without being committed.
async fn execute_sql_statements(pool: &Pool<Postgres>, multi_query: &str) -> Result<(), Error> {
    let mut transaction = pool.begin().await?;

    // Scope the result stream so its mutable borrow of the transaction ends before commit.
    {
        let mut results = transaction.fetch_many(multi_query);
        loop {
            match results.next().await {
                Some(Ok(_)) => continue,
                Some(Err(err)) => return Err(err),
                None => break,
            }
        }
    }

    transaction.commit().await?;
    Ok(())
}
222
223#[cfg(feature = "cli")]
224use serde::Serialize;
#[cfg(feature = "cli")]
/// The subset of GitHub's "latest release" API response that the CLI reads.
#[derive(Debug, Serialize, Deserialize)]
struct GitHubRelease {
    /// Git tag the release points at.
    tag_name: String,
    /// Human-readable release title. GitHub returns `null` for untitled releases, and
    /// a plain `String` would fail to deserialize in that case.
    name: Option<String>,
}