use crate::{config::Data, error::Error, fetch::fetch_object_http, traits::Object};
use chrono::{DateTime, Duration as ChronoDuration, Utc};
use serde::{Deserialize, Serialize};
use std::{
fmt::{Debug, Display, Formatter},
marker::PhantomData,
str::FromStr,
};
use url::Url;
/// Allows an [`ObjectId`] to be produced from a string slice via [`FromStr`]
/// (and therefore via `str::parse`).
impl<Kind> FromStr for ObjectId<Kind>
where
    Kind: Object + Send + Debug + 'static,
    for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{
    /// Parsing fails with the same error type that [`ObjectId::parse`] reports.
    type Err = url::ParseError;

    /// Forwards to [`ObjectId::parse`]; `s` must be a valid URL.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::parse(s)
    }
}
/// Typed wrapper around the [`Url`] that identifies an object of type `Kind`.
///
/// The URL is stored boxed. The `PhantomData<Kind>` marker only records the
/// target type at compile time and holds no data.
///
/// Because of `#[serde(transparent)]`, the wrapper is serialized and
/// deserialized as the inner URL alone rather than as a tuple.
#[derive(Serialize, Deserialize)]
#[serde(transparent)]
pub struct ObjectId<Kind>(Box<Url>, PhantomData<Kind>)
where
Kind: Object,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>;
impl<Kind> ObjectId<Kind>
where
Kind: Object + Send + Debug + 'static,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{
pub fn parse(url: &str) -> Result<Self, url::ParseError> {
Ok(Self(Box::new(Url::parse(url)?), PhantomData::<Kind>))
}
pub fn inner(&self) -> &Url {
&self.0
}
pub fn into_inner(self) -> Url {
*self.0
}
pub async fn dereference(
&self,
data: &Data<<Kind as Object>::DataType>,
) -> Result<Kind, <Kind as Object>::Error>
where
<Kind as Object>::Error: From<Error>,
{
let db_object = self.dereference_from_db(data).await?;
if let Some(object) = db_object {
if let Some(last_refreshed_at) = object.last_refreshed_at() {
let is_local = self.is_local(data);
if !is_local && should_refetch_object(last_refreshed_at) {
return self.dereference_from_http(data, Some(object)).await;
}
}
Ok(object)
}
else {
self.dereference_from_http(data, None).await
}
}
pub async fn dereference_forced(
&self,
data: &Data<<Kind as Object>::DataType>,
) -> Result<Kind, <Kind as Object>::Error>
where
<Kind as Object>::Error: From<Error>,
{
if data.config.is_local_url(&self.0) {
self.dereference_from_db(data)
.await
.map(|o| o.ok_or(Error::NotFound.into()))?
} else {
self.dereference_from_http(data, None).await
}
}
pub async fn dereference_local(
&self,
data: &Data<<Kind as Object>::DataType>,
) -> Result<Kind, <Kind as Object>::Error>
where
<Kind as Object>::Error: From<Error>,
{
let object = self.dereference_from_db(data).await?;
object.ok_or_else(|| Error::NotFound.into())
}
async fn dereference_from_db(
&self,
data: &Data<<Kind as Object>::DataType>,
) -> Result<Option<Kind>, <Kind as Object>::Error> {
let id = self.0.clone();
Object::read_from_id(*id, data).await
}
async fn dereference_from_http(
&self,
data: &Data<<Kind as Object>::DataType>,
db_object: Option<Kind>,
) -> Result<Kind, <Kind as Object>::Error>
where
<Kind as Object>::Error: From<Error>,
{
let res = Box::pin(fetch_object_http(&self.0, data)).await;
if let Err(Error::ObjectDeleted(url)) = res {
if let Some(db_object) = db_object {
db_object.delete(data).await?;
}
return Err(Error::ObjectDeleted(url).into());
}
if let (Err(_), Some(db_object)) = (&res, db_object) {
return Ok(db_object);
}
let res = res?;
let redirect_url = &res.url;
if data.config.is_local_url(redirect_url) {
return self
.dereference_from_db(data)
.await?
.ok_or(Error::NotFound.into());
}
Box::pin(Kind::verify(&res.object, redirect_url, data)).await?;
Box::pin(Kind::from_json(res.object, data)).await
}
pub fn is_local(&self, data: &Data<<Kind as Object>::DataType>) -> bool {
data.config.is_local_url(&self.0)
}
}
impl<Kind> Clone for ObjectId<Kind>
where
Kind: Object,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{
fn clone(&self) -> Self {
ObjectId(self.0.clone(), self.1)
}
}
/// Refetch interval in seconds for builds without debug assertions (one day).
static ACTOR_REFETCH_INTERVAL_SECONDS: i64 = 24 * 60 * 60;
/// Refetch interval in seconds for builds with debug assertions enabled.
static ACTOR_REFETCH_INTERVAL_SECONDS_DEBUG: i64 = 20;

/// Returns `true` when `last_refreshed` lies further in the past than the
/// refetch interval that applies to the current build profile.
fn should_refetch_object(last_refreshed: DateTime<Utc>) -> bool {
    let interval_seconds = if cfg!(debug_assertions) {
        ACTOR_REFETCH_INTERVAL_SECONDS_DEBUG
    } else {
        ACTOR_REFETCH_INTERVAL_SECONDS
    };
    let update_interval = ChronoDuration::try_seconds(interval_seconds).expect("valid duration");
    // Anything refreshed before `now - interval` is considered stale.
    last_refreshed < Utc::now() - update_interval
}
/// Displays the id as its plain URL string.
impl<Kind> Display for ObjectId<Kind>
where
    Kind: Object,
    for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let url: &str = self.0.as_str();
        f.write_str(url)
    }
}
/// Debug output is the raw URL string only, the same text `Display` produces,
/// with no struct wrapper or marker field.
impl<Kind> Debug for ObjectId<Kind>
where
    Kind: Object,
    for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let url: &str = self.0.as_str();
        f.write_str(url)
    }
}
impl<Kind> From<ObjectId<Kind>> for Url
where
Kind: Object,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{
fn from(id: ObjectId<Kind>) -> Self {
*id.0
}
}
/// Wraps an already parsed [`Url`] as an id for objects of type `Kind`.
impl<Kind> From<Url> for ObjectId<Kind>
where
    Kind: Object + Send + 'static,
    for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{
    fn from(url: Url) -> Self {
        let boxed_url = Box::new(url);
        Self(boxed_url, PhantomData)
    }
}
/// Two ids are equal exactly when their URLs are equal.
impl<Kind> PartialEq for ObjectId<Kind>
where
    Kind: Object,
    for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{
    fn eq(&self, other: &Self) -> bool {
        // The `PhantomData` markers always compare equal, so the URLs alone
        // decide the outcome.
        *self.0 == *other.0
    }
}
// Diesel integration, compiled only when the `diesel` feature is enabled.
// The impls live inside an anonymous `const` block so that the diesel imports
// they need stay scoped to this block. Together they let `ObjectId` be used
// where a `String` column value would be used: it is written as its URL string
// and read back by parsing a string with `ObjectId::parse`.
#[cfg(feature = "diesel")]
const _IMPL_DIESEL_NEW_TYPE_FOR_OBJECT_ID: () = {
use diesel::{
backend::Backend,
deserialize::{FromSql, FromStaticSqlRow},
expression::AsExpression,
internal::derives::as_expression::Bound,
pg::Pg,
query_builder::QueryId,
serialize,
serialize::{Output, ToSql},
sql_types::{HasSqlType, SingleValue, Text},
Expression,
Queryable,
};
// Serialization for the Postgres backend: the URL is converted to an owned
// `String` and written through `String`'s `ToSql<Text, Pg>` implementation.
impl<Kind, ST> ToSql<ST, Pg> for ObjectId<Kind>
where
Kind: Object,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
String: ToSql<ST, Pg>,
{
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
let v = self.0.to_string();
// `v` is a local, so the output is reborrowed for the duration of this call
// instead of being used with the full `'b` lifetime.
<String as ToSql<Text, Pg>>::to_sql(&v, &mut out.reborrow())
}
}
// Using a borrowed id as a query expression: binds the URL as a borrowed
// `&str`, so no string is allocated.
impl<'expr, Kind, ST> AsExpression<ST> for &'expr ObjectId<Kind>
where
Kind: Object,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
Bound<ST, String>: Expression<SqlType = ST>,
ST: SingleValue,
{
type Expression = Bound<ST, &'expr str>;
fn as_expression(self) -> Self::Expression {
Bound::new(self.0.as_str())
}
}
// Using an owned id as a query expression: binds an owned `String` copy of
// the URL.
impl<Kind, ST> AsExpression<ST> for ObjectId<Kind>
where
Kind: Object,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
Bound<ST, String>: Expression<SqlType = ST>,
ST: SingleValue,
{
type Expression = Bound<ST, String>;
fn as_expression(self) -> Self::Expression {
Bound::new(self.0.to_string())
}
}
// Deserialization from a raw SQL value on any backend where `String` can be
// read: the value is read as a `String` and then parsed as a URL.
impl<Kind, ST, DB> FromSql<ST, DB> for ObjectId<Kind>
where
Kind: Object + Send + 'static,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
String: FromSql<ST, DB>,
DB: Backend,
DB: HasSqlType<ST>,
{
fn from_sql(
raw: DB::RawValue<'_>,
) -> Result<Self, Box<dyn ::std::error::Error + Send + Sync>> {
let string: String = FromSql::<ST, DB>::from_sql(raw)?;
// A stored value that is not a valid URL surfaces as a boxed parse error.
Ok(ObjectId::parse(&string)?)
}
}
// Loading from a query result row: the row type is `String`, which is parsed
// into an `ObjectId`; a parse failure becomes the deserialize error.
impl<Kind, ST, DB> Queryable<ST, DB> for ObjectId<Kind>
where
Kind: Object + Send + 'static,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
String: FromStaticSqlRow<ST, DB>,
DB: Backend,
DB: HasSqlType<ST>,
{
type Row = String;
fn build(row: Self::Row) -> diesel::deserialize::Result<Self> {
Ok(ObjectId::parse(&row)?)
}
}
// `QueryId` implementation that uses the type itself as its query id.
impl<Kind> QueryId for ObjectId<Kind>
where
Kind: Object + 'static,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{
type QueryId = Self;
}
};
#[cfg(test)]
#[allow(clippy::unwrap_used)]
pub mod tests {
    use super::*;
    use crate::traits::tests::DbUser;

    /// An id serializes to a bare JSON string and deserializes back to an
    /// equal id.
    #[test]
    fn test_deserialize() {
        let original = ObjectId::<DbUser>::parse("http://test.com/").unwrap();

        let json = serde_json::to_string(&original).unwrap();
        assert_eq!("\"http://test.com/\"", json);

        let round_tripped: ObjectId<DbUser> = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped, original);
    }

    /// A timestamp one second old is still fresh, while one two days old is
    /// stale under either refetch interval.
    #[test]
    fn test_should_refetch_object() {
        let now = Utc::now();

        let fresh = now - ChronoDuration::try_seconds(1).unwrap();
        assert!(!should_refetch_object(fresh));

        let stale = now - ChronoDuration::try_days(2).unwrap();
        assert!(should_refetch_object(stale));
    }
}