use crate::install::install_err;
use crate::PgmqError;
use regex::Regex;
use sqlx::encode::IsNull;
use sqlx::error::BoxDynError;
use sqlx::{Database, Decode, Encode, Type};
use std::cmp::Ordering;
use std::fmt::{Display, Formatter};
use std::str::FromStr;
use std::sync::OnceLock;
static VERSION_REGEX: OnceLock<Result<Regex, regex::Error>> = OnceLock::new();
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Version {
pub major: u32,
pub minor: u32,
pub patch: u32,
}
impl<'r, DB: Database> Decode<'r, DB> for Version
where
&'r str: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
let value = <&str as Decode<DB>>::decode(value)?;
let value = Self::from_str(value)?;
Ok(value)
}
}
impl<'q, DB: Database> Encode<'q, DB> for Version
where
String: Encode<'q, DB>,
{
fn encode_by_ref(
&self,
buf: &mut <DB as Database>::ArgumentBuffer<'q>,
) -> Result<IsNull, BoxDynError> {
let value = self.to_string();
<String as sqlx::Encode<'_, DB>>::encode_by_ref(&value, buf)
}
}
impl<DB: Database> Type<DB> for Version
where
String: Type<DB>,
{
fn type_info() -> DB::TypeInfo {
<String as sqlx::Type<DB>>::type_info()
}
}
impl Version {
pub fn get_version_from_extension_config(extension_config: &str) -> Result<Self, PgmqError> {
let version_line = extension_config
.lines()
.find(|line| line.trim_start().starts_with("default_version"))
.ok_or_else(|| install_err("Version is not present in extension config"))?;
let re =
Regex::new(r"^\s*default_version\s*=\s*'(?<version>.*)'\s*$").map_err(install_err)?;
let captures = re.captures(version_line).ok_or_else(|| {
install_err(format!(
"Unable to extract version from extension config: {}",
version_line
))
})?;
Self::from_str(&captures["version"])
}
}
impl FromStr for Version {
type Err = PgmqError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let captures = VERSION_REGEX
.get_or_init(|| Regex::new(r"^v?(?<major>\d+)\.(?<minor>\d+)\.(?<patch>\d+)$"))
.as_ref()
.map_err(install_err)?
.captures(s)
.ok_or_else(|| install_err(format!("Invalid version: '{}'", s)))?;
Ok(Self {
major: u32::from_str(&captures["major"]).map_err(install_err)?,
minor: u32::from_str(&captures["minor"]).map_err(install_err)?,
patch: u32::from_str(&captures["patch"]).map_err(install_err)?,
})
}
}
impl Display for Version {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
}
}
impl Ord for Version {
fn cmp(&self, other: &Self) -> Ordering {
let cmp = self.major.cmp(&other.major);
match cmp {
Ordering::Less | Ordering::Greater => {
return cmp;
}
Ordering::Equal => {}
}
let cmp = self.minor.cmp(&other.minor);
match cmp {
Ordering::Less | Ordering::Greater => {
return cmp;
}
Ordering::Equal => {}
}
self.patch.cmp(&other.patch)
}
}
impl PartialOrd for Version {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
#[cfg(test)]
mod tests {
use super::Version;
use insta::assert_debug_snapshot;
use itertools::Itertools;
use std::str::FromStr;
#[test]
fn get_pgmq_version() {
let extension_config = "default_version = '1.11.0'";
let pgmq_version = Version::get_version_from_extension_config(extension_config).unwrap();
assert_debug_snapshot!(pgmq_version);
}
#[test]
fn get_pgmq_version_extra_whitespace() {
let extension_config = " default_version = '1.11.0' ";
let pgmq_version = Version::get_version_from_extension_config(extension_config).unwrap();
assert_debug_snapshot!(pgmq_version);
}
#[test]
fn get_pgmq_version_err_invalid_version() {
let extension_config = "default_version = 'a.b.c'";
let pgmq_version = Version::get_version_from_extension_config(extension_config);
assert_debug_snapshot!(pgmq_version);
}
#[test]
fn get_pgmq_version_err_version_not_present() {
let extension_config = "";
let pgmq_version = Version::get_version_from_extension_config(extension_config);
assert_debug_snapshot!(pgmq_version);
}
#[test]
fn get_pgmq_version_err_missing_quotes() {
let extension_config = "default_version = 1.11.0";
let pgmq_version = Version::get_version_from_extension_config(extension_config);
assert_debug_snapshot!(pgmq_version);
}
#[test]
fn get_pgmq_version_no_whitespace() {
let extension_config = "default_version='1.11.0'";
let pgmq_version = Version::get_version_from_extension_config(extension_config).unwrap();
assert_debug_snapshot!(pgmq_version);
}
#[test]
fn from_str() {
let version = Version::from_str("1.11.0").unwrap();
assert_eq!(
version,
Version {
major: 1,
minor: 11,
patch: 0
}
);
}
#[test]
fn from_str_v_prefix() {
let version = Version::from_str("v1.11.0").unwrap();
assert_eq!(
version,
Version {
major: 1,
minor: 11,
patch: 0
}
);
}
#[test]
fn from_str_err_not_enough_segments() {
let version = Version::from_str("1.2");
assert_debug_snapshot!(version);
}
#[test]
fn from_str_err_invalid_segment_values() {
let version = Version::from_str("a.b.c");
assert_debug_snapshot!(version);
}
#[test]
fn sort_and_unique() {
let versions = [
"0.1.0", "1.1.1", "2.0.1", "2.0.0", "2.0.0", "1.11.1", "1.0.1",
]
.iter()
.map(|version| Version::from_str(version).unwrap())
.sorted()
.unique()
.collect::<Vec<Version>>();
assert_debug_snapshot!(versions);
}
}