use std::str::FromStr;
use snafu::{ResultExt as _, Snafu, ensure};
use url::PathSegmentsMut;
#[derive(
Debug,
Clone,
PartialEq,
Eq,
PartialOrd,
Ord,
Hash,
derive_more::AsRef,
derive_more::Display,
derive_more::Into,
serde::Serialize,
serde::Deserialize,
)]
pub struct Url(url::Url);
/// Error returned when a value cannot be converted into an OParl [`Url`].
#[derive(Debug, Snafu)]
pub enum TryFromUrlError {
    /// The input string could not be parsed as a URL by the `url` crate.
    #[snafu(display("Failed to parse URL: {source}"))]
    Parse { source: url::ParseError },
    /// The input parsed as a URL, but `url::Url::cannot_be_a_base` returned
    /// `true` for it. Such URLs are rejected so that path-segment
    /// manipulation on [`Url`] is always possible.
    #[snafu(display("OParl urls must be usable as a base"))]
    CannotBeABase,
}
impl TryFrom<url::Url> for Url {
type Error = TryFromUrlError;
fn try_from(value: url::Url) -> Result<Self, Self::Error> {
ensure!(!value.cannot_be_a_base(), CannotBeABaseSnafu);
Ok(Self(value))
}
}
impl FromStr for Url {
    type Err = TryFromUrlError;

    /// Parses `s` and then applies the same checks as `TryFrom<url::Url>`.
    ///
    /// # Errors
    ///
    /// [`TryFromUrlError::Parse`] if `s` is not a valid URL, or
    /// [`TryFromUrlError::CannotBeABase`] if it cannot be used as a base.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let parsed = url::Url::parse(s).context(ParseSnafu)?;
        Self::try_from(parsed)
    }
}
impl Url {
    /// Returns the serialized URL as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Returns a mutable handle to the path segments of this URL.
    ///
    /// # Panics
    ///
    /// Panics if the wrapped URL cannot be used as a base. Construction via
    /// `TryFrom`/`FromStr` rejects such URLs, so a panic here means the
    /// type's invariant was broken.
    pub fn path_segments_mut(&mut self) -> PathSegmentsMut<'_> {
        self.0
            .path_segments_mut()
            .expect("Url invariant violated: OParl urls must be usable as a base")
    }

    /// Returns the raw query string (without the leading `?`), if any.
    pub fn query(&self) -> Option<&str> {
        self.0.query()
    }

    /// Replaces the raw query string; `None` removes the query entirely.
    pub fn set_query(&mut self, query: Option<&str>) {
        self.0.set_query(query)
    }

    /// Sets the query parameter `name` to `value`.
    ///
    /// Every existing pair named `name` has its value replaced in place, so
    /// duplicated keys stay duplicated but all carry `value`. If no such pair
    /// exists, `name=value` is appended. The relative order of the other
    /// pairs is preserved. The whole query is re-serialized with
    /// `application/x-www-form-urlencoded` encoding.
    pub fn set_query_field(&mut self, name: &str, value: &str) {
        let mut replaced = false;
        let mut pairs: Vec<(String, String)> = self
            .0
            .query_pairs()
            .map(|(found_name, found_value)| {
                let new_value = if found_name == name {
                    replaced = true;
                    value.to_owned()
                } else {
                    found_value.into_owned()
                };
                (found_name.into_owned(), new_value)
            })
            .collect();
        if !replaced {
            pairs.push((name.to_owned(), value.to_owned()));
        }
        self.0.query_pairs_mut().clear().extend_pairs(&pairs);
    }

    /// Removes every query parameter named `name`.
    ///
    /// If no pair named `name` exists, the URL is left untouched (it is not
    /// re-encoded and no `?` is added). If removing leaves no pairs, the
    /// query is dropped entirely instead of leaving a dangling `?`.
    pub fn remove_query_field(&mut self, name: &str) {
        let mut removed = false;
        let kept: Vec<(String, String)> = self
            .0
            .query_pairs()
            .filter(|(found_name, _)| {
                let matches = found_name == name;
                removed |= matches;
                !matches
            })
            .map(|(found_name, found_value)| (found_name.into_owned(), found_value.into_owned()))
            .collect();

        if !removed {
            return;
        }

        if kept.is_empty() {
            // `query_pairs_mut` always (re)creates the `?`, which would leave
            // e.g. `https://host/path?`; clear the query instead.
            self.0.set_query(None);
        } else {
            self.0.query_pairs_mut().clear().extend_pairs(&kept);
        }
    }
}
impl TryFrom<String> for Url {
type Error = TryFromUrlError;
fn try_from(value: String) -> Result<Self, Self::Error> {
value.parse()
}
}
#[cfg(feature = "sea-orm")]
mod sea_orm_impls {
    //! SeaORM integration for [`Url`]: the URL is stored and loaded as its
    //! string serialization, delegating column metadata to `String`.

    use sea_orm::{
        ActiveValue, ColumnType, DbErr, IntoActiveValue, TryGetError, TryGetable, Value,
        sea_query::{ArrayType, Nullable, ValueType, ValueTypeErr},
    };

    use super::Url;

    /// Converts the URL into a string database value.
    impl From<Url> for Value {
        fn from(url: Url) -> Self {
            Value::from(url.as_str())
        }
    }

    /// Reads a string column and validates it as a [`Url`].
    impl TryGetable for Url {
        fn try_get_by<I: sea_orm::ColIdx>(
            res: &sea_orm::QueryResult,
            index: I,
        ) -> Result<Self, TryGetError> {
            let raw = <String as TryGetable>::try_get_by(res, index)?;
            // Fully qualified: `ValueType::try_from` is also in scope here.
            <Url as TryFrom<String>>::try_from(raw).map_err(|source| {
                let db_err = DbErr::TryIntoErr {
                    from: "String",
                    into: "Url",
                    source: Box::new(source),
                };
                db_err.into()
            })
        }
    }

    /// Value-level conversion; reports as and behaves like a `String` column.
    impl ValueType for Url {
        fn try_from(v: Value) -> Result<Self, ValueTypeErr> {
            let text = <String as ValueType>::try_from(v)?;
            <Url as TryFrom<String>>::try_from(text).map_err(|_| ValueTypeErr)
        }

        fn type_name() -> String {
            String::from("String")
        }

        fn array_type() -> ArrayType {
            <String as ValueType>::array_type()
        }

        fn column_type() -> ColumnType {
            <String as ValueType>::column_type()
        }
    }

    /// A NULL `Url` column is a NULL string value.
    impl Nullable for Url {
        fn null() -> Value {
            Value::String(None)
        }
    }

    /// Lets a `Url` be assigned directly to an active-model field.
    impl IntoActiveValue<Url> for Url {
        fn into_active_value(self) -> ActiveValue<Url> {
            ActiveValue::Set(self)
        }
    }
}
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use super::Url;

    /// Parses a fixture URL, panicking if it is not a valid OParl URL.
    fn parse_url(input: &str) -> Url {
        input.parse().unwrap()
    }

    #[test]
    fn from_str() {
        let expected = Url("https://abc.example.com/qwert".parse().unwrap());
        let parsed: Url = "https://abc.example.com/qwert"
            .parse()
            .expect("value must be a parseable url");
        assert_eq!(expected, parsed);
    }

    #[test]
    fn set_query_field() {
        let mut subject = parse_url("https://abc.example.com/qwert?hello=world");
        // Each step mutates `subject` further, so order matters.
        let steps = [
            (
                "per_page",
                "30",
                "https://abc.example.com/qwert?hello=world&per_page=30",
            ),
            (
                "per_page",
                "300",
                "https://abc.example.com/qwert?hello=world&per_page=300",
            ),
            (
                "hello",
                "everybody",
                "https://abc.example.com/qwert?hello=everybody&per_page=300",
            ),
        ];
        for (name, value, expected) in steps {
            subject.set_query_field(name, value);
            assert_eq!(subject, parse_url(expected));
        }
    }

    #[test]
    fn remove_query_field() {
        let mut subject =
            parse_url("https://abc.example.com/qwert?hello=world&per_page=30&page=5");
        subject.remove_query_field("per_page");
        assert_eq!(
            subject,
            parse_url("https://abc.example.com/qwert?hello=world&page=5")
        );
    }
}
#[cfg(test)]
mod serde_tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::Url;

    /// Well-formed URL shared by the round-trip tests.
    const EXAMPLE: &str = "https://abc.example.com/qwert";

    #[test]
    fn serialize() {
        let url: Url = EXAMPLE.parse().unwrap();
        assert_eq!(json!(url), json!(EXAMPLE));
    }

    #[test]
    fn deserialize_good() {
        let expected: Url = EXAMPLE.parse().unwrap();
        let actual = serde_json::from_value::<Url>(json!(EXAMPLE))
            .expect("value must be deserializable as Url");
        assert_eq!(actual, expected);
    }

    #[test]
    fn deserialize_bad() {
        let rejected = [json!([]), json!({}), json!("hello"), json!(true), json!(123)];
        for input in rejected {
            assert!(serde_json::from_value::<Url>(input).is_err());
        }
    }
}