Skip to main content

pgmq/
util.rs

1use std::fmt::Display;
2
3use crate::{errors::PgmqError, types::Message};
4
5use log::LevelFilter;
6use serde::Deserialize;
7use sqlx::error::Error;
8use sqlx::postgres::PgRow;
9use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
10use sqlx::ConnectOptions;
11use sqlx::Row;
12use sqlx::{Pool, Postgres};
13use url::{ParseError, Url};
14
15#[cfg(feature = "cli")]
16use futures_util::stream::StreamExt;
17#[cfg(feature = "cli")]
18use sqlx::Executor;
19// Configure connection options
20pub fn conn_options(url: &str) -> Result<PgConnectOptions, ParseError> {
21    // Parse url
22    let parsed = Url::parse(url)?;
23    let options = PgConnectOptions::new()
24        .host(parsed.host_str().ok_or(ParseError::EmptyHost)?)
25        .port(parsed.port().ok_or(ParseError::InvalidPort)?)
26        .username(parsed.username())
27        .password(parsed.password().ok_or(ParseError::IdnaError)?)
28        .database(parsed.path().trim_start_matches('/'))
29        .log_statements(LevelFilter::Debug);
30    Ok(options)
31}
32
33/// Connect to the database
34pub async fn connect(url: &str, max_connections: u32) -> Result<Pool<Postgres>, PgmqError> {
35    let options = conn_options(url)?;
36    let pgp = PgPoolOptions::new()
37        .acquire_timeout(std::time::Duration::from_secs(10))
38        .max_connections(max_connections)
39        .connect_with(options)
40        .await?;
41    Ok(pgp)
42}
43
44// Executes a query and returns a single row
45// If the query returns no rows, None is returned
46// This function is intended for internal use.
47pub async fn fetch_one_message<T: for<'de> Deserialize<'de>>(
48    query: &str,
49    connection: &Pool<Postgres>,
50) -> Result<Option<Message<T>>, PgmqError> {
51    // explore: .fetch_optional()
52    let row: Result<PgRow, Error> = sqlx::query(query).fetch_one(connection).await;
53    match row {
54        Ok(row) => {
55            // happy path - successfully read a message
56            let raw_msg = row.get("message");
57            let parsed_msg = serde_json::from_value::<T>(raw_msg);
58            match parsed_msg {
59                Ok(parsed_msg) => Ok(Some(Message {
60                    msg_id: row.get("msg_id"),
61                    vt: row.get("vt"),
62                    read_ct: row.get("read_ct"),
63                    enqueued_at: row.get("enqueued_at"),
64                    message: parsed_msg,
65                })),
66                Err(e) => Err(PgmqError::JsonParsingError(e)),
67            }
68        }
69        Err(sqlx::error::Error::RowNotFound) => Ok(None),
70        Err(e) => Err(e)?,
71    }
72}
73
74/// A string that is known to be formed of only ASCII alphanumeric or an underscore;
75#[derive(Clone, Copy)]
76pub struct CheckedName<'a>(&'a str);
77
78impl<'a> CheckedName<'a> {
79    /// Accepts `input` as a CheckedName if it is a valid queue identifier
80    pub fn new(input: &'a str) -> Result<Self, PgmqError> {
81        check_input(input)?;
82
83        Ok(Self(input))
84    }
85}
86
87impl AsRef<str> for CheckedName<'_> {
88    fn as_ref(&self) -> &str {
89        self.0
90    }
91}
92
93impl Display for CheckedName<'_> {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        f.write_str(self.0)
96    }
97}
98
99/// panics if input is invalid. otherwise does nothing.
100pub fn check_input(input: &str) -> Result<(), PgmqError> {
101    // Docs:
102    // https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
103
104    // Default value of `NAMEDATALEN`, set in `src/include/pg_config_manual.h`
105    const NAMEDATALEN: usize = 64;
106
107    // The maximum length of an identifier.
108    // Longer names can be used in commands, but they'll be truncated
109    const MAX_IDENTIFIER_LEN: usize = NAMEDATALEN - 1;
110    const BIGGEST_CONCAT: &str = "archived_at_idx_";
111
112    // The max length of the name of a PGMQ queue, considering that the biggest
113    // postgres identifier created by PGMQ is the index on archived_at
114    const MAX_PGMQ_QUEUE_LEN: usize = MAX_IDENTIFIER_LEN - BIGGEST_CONCAT.len();
115
116    let is_short_enough = input.len() <= MAX_PGMQ_QUEUE_LEN;
117    let has_valid_characters = input
118        .as_bytes()
119        .iter()
120        .all(|&c| c.is_ascii_alphanumeric() || c == b'_');
121    let valid = is_short_enough && has_valid_characters;
122    match valid {
123        true => Ok(()),
124        false => Err(PgmqError::InvalidQueueName {
125            name: input.to_owned(),
126        }),
127    }
128}
129
/// Queries the GitHub releases API for the latest PGMQ release and returns its tag name.
///
/// # Errors
///
/// Fails on request/decoding errors, or when GitHub answers with a
/// non-success HTTP status.
#[cfg(feature = "cli")]
async fn get_latest_release_tag() -> Result<String, PgmqError> {
    const LATEST_RELEASE_API: &str = "https://api.github.com/repos/pgmq/pgmq/releases/latest";

    log::info!("Getting latest PGMQ release...");

    let request = reqwest::Client::new()
        .get(LATEST_RELEASE_API)
        .header("User-Agent", "pgmq-cli");
    let response = request.send().await?;

    let status = response.status();
    if !status.is_success() {
        let reason = format!("Failed to fetch latest release: HTTP {}", status);
        return Err(reason.into());
    }

    let release: GitHubRelease = response.json().await?;
    let tag = release.tag_name;
    log::info!("Latest release: {}", tag);
    Ok(tag)
}
150
/// Downloads the PGMQ installation script (`pgmq-extension/sql/pgmq.sql`).
///
/// `version` may be a release tag (with or without a leading `v`) or a commit
/// hash of 7–64 ASCII hex digits. When it is `None`, the latest GitHub release
/// tag is looked up and used.
///
/// # Errors
///
/// Fails on request errors or when the download returns a non-success HTTP status.
#[cfg(feature = "cli")]
async fn get_install_sql(version: Option<&String>) -> Result<String, PgmqError> {
    let requested = match version {
        Some(v) => v.clone(),
        None => get_latest_release_tag().await?,
    };

    let sql_url = install_sql_url(&requested);
    log::info!("Fetching SQL from: {sql_url}");

    let response = reqwest::Client::new().get(&sql_url).send().await?;
    let status = response.status();
    if !status.is_success() {
        return Err(format!("Failed to download SQL file: HTTP {}", status).into());
    }
    Ok(response.text().await?)
}

/// Returns the raw.githubusercontent.com URL of `pgmq.sql` for `version`.
///
/// Strings of 7–64 ASCII hex digits are treated as commit hashes and used as-is;
/// anything else is treated as a release tag and gets a leading `v` if missing.
#[cfg(feature = "cli")]
fn install_sql_url(version: &str) -> String {
    let looks_like_commit =
        (7..=64).contains(&version.len()) && version.chars().all(|c| c.is_ascii_hexdigit());

    if looks_like_commit {
        return format!(
            "https://raw.githubusercontent.com/pgmq/pgmq/{version}/pgmq-extension/sql/pgmq.sql"
        );
    }

    let tag = if version.starts_with('v') {
        version.to_owned()
    } else {
        format!("v{version}")
    };
    format!(
        "https://raw.githubusercontent.com/pgmq/pgmq/refs/tags/{tag}/pgmq-extension/sql/pgmq.sql"
    )
}
189
/// Installs PGMQ into the database behind `pool`.
///
/// Downloads the installation SQL for `version` (the latest GitHub release
/// when `None`) and runs it via `execute_sql_statements`, which wraps the
/// whole script in one transaction.
///
/// # Errors
///
/// Fails if the script cannot be downloaded or if any statement in it fails.
#[cfg(feature = "cli")]
pub async fn install_pgmq(
    pool: &Pool<Postgres>,
    version: Option<&String>,
) -> Result<(), PgmqError> {
    log::info!("Installing PGMQ...");
    let install_script = get_install_sql(version).await?;

    log::info!("Executing PGMQ installation SQL...");
    execute_sql_statements(pool, install_script.as_str()).await?;

    log::info!("PGMQ installation completed successfully!");
    Ok(())
}
205
/// Runs every statement in `multi_query` inside a single transaction.
///
/// The script is sent with `fetch_many`, and all produced results are
/// discarded. On the first failing step the error is returned immediately and
/// the transaction is dropped without being committed; otherwise it is committed.
#[cfg(feature = "cli")]
async fn execute_sql_statements(pool: &Pool<Postgres>, multi_query: &str) -> Result<(), Error> {
    let mut tx = pool.begin().await?;

    {
        // Scoped so the stream's borrow of `tx` ends before `commit` consumes it.
        let mut results = tx.fetch_many(multi_query);
        while let Some(step) = results.next().await {
            if let Err(err) = step {
                return Err(err);
            }
        }
    }

    tx.commit().await
}
222
223#[cfg(feature = "cli")]
224use serde::Serialize;
/// Fields read from the GitHub "latest release" API response.
#[cfg(feature = "cli")]
#[derive(Deserialize, Serialize)]
struct GitHubRelease {
    /// Release tag; returned by `get_latest_release_tag` and used as the install version.
    tag_name: String,
    /// Release name as reported by GitHub (not read by this module).
    name: String,
}