use std::cmp::PartialEq;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::{Mutex, OwnedMutexGuard, RwLock};
use crate::movement::{Movement, MovementId, MovementStatus, MovementSubsystem};
use crate::movement::error::MovementError;
use crate::movement::update::MovementUpdate;
use crate::notification::NotificationDispatch;
use crate::persist::BarkPersister;
use crate::subsystem::Subsystem;
/// Central coordinator for creating, updating and finishing [`Movement`]s.
///
/// Every change is written through the [`BarkPersister`] and then reported via
/// [`NotificationDispatch`]. Movements that are still pending are kept in an
/// in-memory cache, each behind its own mutex, so that concurrent updates to the
/// same movement are serialized. Entries are evicted once a movement leaves the
/// pending state.
pub struct MovementManager {
    /// Storage backend every movement change is written to.
    db: Arc<dyn BarkPersister>,
    /// Subsystems registered through [`MovementManager::register_subsystem`].
    subsystem_ids: RwLock<HashSet<Subsystem>>,
    /// Cache of pending movements keyed by id; each entry has its own mutex.
    active_movements: RwLock<HashMap<MovementId, Arc<Mutex<Movement>>>>,
    /// Sink for movement created/updated events.
    notifications: NotificationDispatch,
}

impl MovementManager {
    /// Creates a manager that persists movements to `db` and reports changes
    /// through `notifications`.
    pub(crate) fn new(db: Arc<dyn BarkPersister>, notifications: NotificationDispatch) -> Self {
        Self {
            db,
            notifications,
            subsystem_ids: RwLock::new(HashSet::new()),
            active_movements: RwLock::new(HashMap::new()),
        }
    }

    /// Registers a subsystem identifier with this manager.
    ///
    /// # Errors
    ///
    /// Returns [`MovementError::SubsystemError`] if `id` was already registered.
    pub async fn register_subsystem(&self, id: Subsystem) -> anyhow::Result<(), MovementError> {
        let mut registered = self.subsystem_ids.write().await;
        if registered.contains(&id) {
            return Err(MovementError::SubsystemError {
                id,
                error: "Subsystem already registered".into(),
            });
        }
        registered.insert(id);
        Ok(())
    }

    /// Writes a brand-new pending movement for `subsystem_id` / `movement_kind`
    /// to the database and returns its id. No notification is sent.
    async fn persist_new_movement(
        &self,
        subsystem_id: Subsystem,
        movement_kind: impl Into<String>,
    ) -> anyhow::Result<MovementId, MovementError> {
        self.db
            .create_new_movement(
                MovementStatus::Pending,
                &MovementSubsystem {
                    name: subsystem_id.as_name().to_string(),
                    kind: movement_kind.into(),
                },
                chrono::Local::now(),
            )
            .await
            .map_err(|e| MovementError::CreationError { e })
    }

    /// Reads movement `id` straight from the database, bypassing the cache.
    async fn load_movement(&self, id: MovementId) -> anyhow::Result<Movement, MovementError> {
        self.db
            .get_movement_by_id(id)
            .await
            .map_err(|e| MovementError::LoadError { id, e })
    }

    /// Creates a pending movement and announces it via `dispatch_movement_created`.
    ///
    /// # Errors
    ///
    /// [`MovementError::CreationError`] if the movement cannot be written, or
    /// [`MovementError::LoadError`] if it cannot be read back.
    pub async fn new_movement(
        &self,
        subsystem_id: Subsystem,
        movement_kind: impl Into<String>,
    ) -> anyhow::Result<MovementId, MovementError> {
        let id = self.persist_new_movement(subsystem_id, movement_kind).await?;
        let movement = self.load_movement(id).await?;
        self.notifications.dispatch_movement_created(movement);
        Ok(id)
    }

    /// Like [`MovementManager::new_movement`], but wraps the id in a
    /// [`MovementGuard`] that finishes the movement with `on_drop` if the guard
    /// is dropped before being finished explicitly.
    pub async fn new_guarded_movement(
        self: &Arc<Self>,
        subsystem_id: Subsystem,
        movement_kind: impl Into<String>,
        on_drop: OnDropStatus,
    ) -> anyhow::Result<MovementGuard, MovementError> {
        let id = self.new_movement(subsystem_id, movement_kind).await?;
        Ok(MovementGuard::new(id, Arc::clone(self), on_drop))
    }

    /// Creates a movement, applies `update` to it and persists the result before
    /// announcing it.
    ///
    /// The update is applied directly to the freshly loaded movement, so
    /// listeners see a single `created` event carrying the updated state. They
    /// never get an `updated` event for a movement they have not yet been told
    /// exists.
    ///
    /// # Errors
    ///
    /// [`MovementError::CreationError`], [`MovementError::LoadError`] or
    /// [`MovementError::PersisterError`] depending on which database step fails.
    pub async fn new_movement_with_update(
        &self,
        subsystem_id: Subsystem,
        movement_kind: impl Into<String>,
        update: MovementUpdate,
    ) -> anyhow::Result<MovementId, MovementError> {
        let id = self.persist_new_movement(subsystem_id, movement_kind).await?;
        let mut movement = self.load_movement(id).await?;
        update.apply_to(&mut movement, chrono::Local::now());
        self.db
            .update_movement(&movement)
            .await
            .map_err(|e| MovementError::PersisterError { id, e })?;
        self.notifications.dispatch_movement_created(movement);
        Ok(id)
    }

    /// Guarded variant of [`MovementManager::new_movement_with_update`].
    pub async fn new_guarded_movement_with_update(
        self: &Arc<Self>,
        subsystem_id: Subsystem,
        movement_kind: impl Into<String>,
        on_drop: OnDropStatus,
        update: MovementUpdate,
    ) -> anyhow::Result<MovementGuard, MovementError> {
        let id = self
            .new_movement_with_update(subsystem_id, movement_kind, update)
            .await?;
        Ok(MovementGuard::new(id, Arc::clone(self), on_drop))
    }

    /// Creates a movement that is already in the terminal `status`, with
    /// `details` applied and `completed_at` set, then announces it.
    ///
    /// # Errors
    ///
    /// [`MovementError::IncorrectPendingStatus`] if `status` is `Pending`. Otherwise,
    /// the same database errors as [`MovementManager::new_movement_with_update`].
    pub async fn new_finished_movement(
        &self,
        subsystem_id: Subsystem,
        movement_kind: impl Into<String>,
        status: MovementStatus,
        details: MovementUpdate,
    ) -> anyhow::Result<MovementId, MovementError> {
        if status == MovementStatus::Pending {
            return Err(MovementError::IncorrectPendingStatus);
        }
        let id = self.persist_new_movement(subsystem_id, movement_kind).await?;
        let mut movement = self.load_movement(id).await?;
        let at = chrono::Local::now();
        details.apply_to(&mut movement, at);
        movement.status = status;
        movement.time.completed_at = Some(at);
        self.db
            .update_movement(&movement)
            .await
            .map_err(|e| MovementError::PersisterError { id, e })?;
        self.notifications.dispatch_movement_created(movement);
        Ok(id)
    }

    /// Applies `update` to movement `id`, persists it and sends an `updated`
    /// notification. The movement is evicted from the cache if it is no longer
    /// pending afterwards.
    ///
    /// # Errors
    ///
    /// [`MovementError::LoadError`] if the movement is not cached and cannot be
    /// loaded, or [`MovementError::PersisterError`] if the write fails. In the
    /// latter case the cached movement is left unchanged.
    pub async fn update_movement(
        &self,
        id: MovementId,
        update: MovementUpdate,
    ) -> anyhow::Result<(), MovementError> {
        self.modify_cached_movement(id, move |movement| {
            update.apply_to(movement, chrono::Local::now());
        })
        .await
    }

    /// Moves movement `id` into the terminal `new_status`, stamps `completed_at`,
    /// persists it, notifies listeners and evicts it from the cache.
    ///
    /// # Errors
    ///
    /// [`MovementError::IncorrectPendingStatus`] if `new_status` is `Pending`.
    /// Otherwise, the same errors as [`MovementManager::update_movement`].
    pub async fn finish_movement(
        &self,
        id: MovementId,
        new_status: MovementStatus,
    ) -> anyhow::Result<(), MovementError> {
        if new_status == MovementStatus::Pending {
            return Err(MovementError::IncorrectPendingStatus);
        }
        self.modify_cached_movement(id, move |movement| {
            movement.status = new_status;
            movement.time.completed_at = Some(chrono::Local::now());
        })
        .await
    }

    /// Like [`MovementManager::finish_movement`], but first applies `update`.
    ///
    /// A single timestamp is used for both the update and `completed_at`,
    /// matching [`MovementManager::new_finished_movement`].
    pub async fn finish_movement_with_update(
        &self,
        id: MovementId,
        new_status: MovementStatus,
        update: MovementUpdate,
    ) -> anyhow::Result<(), MovementError> {
        if new_status == MovementStatus::Pending {
            return Err(MovementError::IncorrectPendingStatus);
        }
        self.modify_cached_movement(id, move |movement| {
            let at = chrono::Local::now();
            update.apply_to(movement, at);
            movement.status = new_status;
            movement.time.completed_at = Some(at);
        })
        .await
    }

    /// Shared read-modify-write path for cached movements.
    ///
    /// `mutate` runs on a clone of the cached movement. The clone is persisted
    /// and only then written back into the cache, so a persister failure cannot
    /// leave the cache ahead of the database. The `updated` notification is sent
    /// while the movement's mutex is still held, keeping events for one movement
    /// in commit order. A movement that is no longer pending is evicted afterwards.
    async fn modify_cached_movement(
        &self,
        id: MovementId,
        mutate: impl FnOnce(&mut Movement),
    ) -> anyhow::Result<(), MovementError> {
        let mut guard = self.get_cached_movement(id).await?;
        let mut movement = (*guard).clone();
        mutate(&mut movement);
        self.db
            .update_movement(&movement)
            .await
            .map_err(|e| MovementError::PersisterError { id, e })?;
        let still_pending = movement.status == MovementStatus::Pending;
        *guard = movement.clone();
        self.notifications.dispatch_movement_updated(movement);
        // Release the movement mutex before touching the cache's write lock.
        drop(guard);
        if !still_pending {
            self.unload_movement_from_cache(id).await?;
        }
        Ok(())
    }

    /// Returns the locked, cached movement `id`, loading it from the database
    /// and inserting it into the cache on a miss.
    ///
    /// # Errors
    ///
    /// [`MovementError::LoadError`] if the movement has to be loaded and cannot be.
    async fn get_cached_movement(
        &self,
        id: MovementId,
    ) -> anyhow::Result<OwnedMutexGuard<Movement>, MovementError> {
        // Fast path. Bind the Arc first so the read guard is released before we
        // wait on the movement's own mutex.
        let cached = self.active_movements.read().await.get(&id).cloned();
        if let Some(lock) = cached {
            return Ok(lock.lock_owned().await);
        }
        // Slow path. Re-check under the write lock so concurrent callers end up
        // sharing one entry, and load from the database only if still missing.
        // The write lock is held across the load, which blocks other cache access.
        let movement_lock = {
            let mut active = self.active_movements.write().await;
            if let Some(lock) = active.get(&id).cloned() {
                lock
            } else {
                let lock = Arc::new(Mutex::new(self.load_movement(id).await?));
                active.insert(id, Arc::clone(&lock));
                lock
            }
        };
        Ok(movement_lock.lock_owned().await)
    }

    /// Removes movement `id` from the cache. Removing an absent id is a no-op.
    async fn unload_movement_from_cache(&self, id: MovementId) -> anyhow::Result<(), MovementError> {
        self.active_movements.write().await.remove(&id);
        Ok(())
    }
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum OnDropStatus {
Canceled,
Failed,
}
impl From<OnDropStatus> for MovementStatus {
fn from(status: OnDropStatus) -> Self {
match status {
OnDropStatus::Canceled => MovementStatus::Canceled,
OnDropStatus::Failed => MovementStatus::Failed,
}
}
}
pub struct MovementGuard {
id: MovementId,
manager: Arc<MovementManager>,
on_drop: OnDropStatus,
has_finished: bool,
}
impl<'a> MovementGuard {
pub fn new(
id: MovementId,
manager: Arc<MovementManager>,
on_drop: OnDropStatus,
) -> Self {
Self {
id,
manager,
on_drop,
has_finished: false,
}
}
pub fn id(&self) -> MovementId {
self.id
}
pub fn set_on_drop_status(&mut self, status: OnDropStatus) {
self.on_drop = status;
}
pub async fn apply_update(
&self,
update: MovementUpdate,
) -> anyhow::Result<(), MovementError> {
self.manager.update_movement(self.id, update).await
}
pub async fn cancel(&mut self) -> anyhow::Result<(), MovementError> {
self.stop();
self.manager.finish_movement(self.id, MovementStatus::Canceled).await
}
pub async fn fail(&mut self) -> anyhow::Result<(), MovementError> {
self.stop();
self.manager.finish_movement(self.id, MovementStatus::Failed).await
}
pub async fn success(
&mut self,
) -> anyhow::Result<(), MovementError> {
self.stop();
self.manager.finish_movement(self.id, MovementStatus::Successful).await
}
pub fn stop(&mut self) {
self.has_finished = true;
}
}
impl Drop for MovementGuard {
fn drop(&mut self) {
if !self.has_finished {
let manager = self.manager.clone();
let id = self.id;
let on_drop = self.on_drop;
crate::utils::spawn(async move {
if let Err(e) = manager.finish_movement(id, on_drop.into()).await {
log::error!("An error occurred in MovementGuard::drop(): {:#}", e);
}
});
}
}
}