use std::{
any::{Any, TypeId},
borrow::Borrow,
collections::HashMap,
rc::Rc,
};
use lunar_lib::trace;
use sled::{
CompareAndSwapError, Db, IVec, Transactional, Tree,
transaction::{ConflictableTransactionError, TransactionError, TransactionalTree},
};
use crate::{
database::{
DatabaseEntry, DatabaseError, DbKey, EntryId, Patchable, deserialize_from_ivec, library_db,
serialize_to_ivec, sled_get_raw, validator::DatabaseReferenceError,
},
library::{
album::{Album, AlbumId, TrackReference},
artist::{ArtistGroup, ArtistId},
track::TrackId,
},
};
/// One staged compare-and-swap for a single key.
///
/// `old` is the raw value that was stored under the key when the key was first staged in the
/// transaction (`None` if the key was absent); `new` is what will be written at commit time,
/// where `None` means the key is removed.
#[derive(Debug)]
pub struct CompareAndSwapValue<T: DatabaseEntry> {
    /// Raw bytes expected under the key at commit time (`None` = key expected to be absent).
    pub old: Option<IVec>,
    /// Value to store on commit; `None` removes the key.
    pub new: Option<T>,
}
impl<T: DatabaseEntry> CompareAndSwapValue<T> {
    /// Builds a swap that expects `old` to be stored and replaces it with `new`
    /// (`None` for `new` means "remove the key").
    #[must_use]
    pub fn new(old: Option<IVec>, new: Option<T>) -> Self {
        Self { new, old }
    }
}
/// All staged compare-and-swaps for entries of type `T`, together with the sled tree they
/// will be applied to.
#[derive(Debug)]
pub struct TreeCompareAndSwap<T: DatabaseEntry> {
    /// The tree returned by `T::tree(db)` for the transaction's database.
    tree: Tree,
    /// Pending swaps, keyed by the entry id's `as_bytes()` key.
    swaps: HashMap<DbKey, CompareAndSwapValue<T>>,
}
impl<T: DatabaseEntry> TreeCompareAndSwap<T> {
    /// Starts an empty swap set bound to `T`'s tree inside `db`.
    fn new(db: &Db) -> Self {
        let tree = T::tree(db);
        Self {
            swaps: HashMap::new(),
            tree,
        }
    }

    /// The sled tree these swaps target.
    #[must_use]
    pub fn tree(&self) -> &Tree {
        let Self { tree, .. } = self;
        tree
    }
}
/// Type-erased view of a [`TreeCompareAndSwap`], so swap sets for different entry types can be
/// kept in one map and applied together in a single sled transaction.
pub trait GenericCompareAndSwap: Any + std::fmt::Debug {
    /// The sled tree the staged swaps target.
    fn tree(&self) -> &Tree;
    /// Upcasts to `&dyn Any` so callers can downcast back to the concrete swap set.
    fn as_any(&self) -> &dyn Any;
    /// Mutable counterpart of [`GenericCompareAndSwap::as_any`].
    fn as_any_mut(&mut self) -> &mut dyn Any;
    /// Applies every staged swap to `tx_tree`; implementations are expected to abort with a
    /// [`CompareAndSwapError`] when a key's current value is not the expected one.
    fn apply(
        &self,
        tx_tree: &TransactionalTree,
    ) -> Result<(), ConflictableTransactionError<CompareAndSwapError>>;
}
impl<T: DatabaseEntry> GenericCompareAndSwap for TreeCompareAndSwap<T> {
    fn tree(&self) -> &Tree {
        &self.tree
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    /// Verifies each staged key against its expected bytes, then writes or removes it.
    ///
    /// Keys are handled one at a time in map order. The first key whose current value in
    /// `tx_tree` differs from the expected `old` aborts the transaction with a
    /// [`CompareAndSwapError`] carrying the current bytes and the serialized proposal.
    fn apply(
        &self,
        tx_tree: &TransactionalTree,
    ) -> Result<(), ConflictableTransactionError<CompareAndSwapError>> {
        for (key, swap) in &self.swaps {
            let current = tx_tree.get(key)?;
            if current != swap.old {
                let proposed = swap.new.as_ref().map(serialize_to_ivec);
                return Err(ConflictableTransactionError::Abort(CompareAndSwapError {
                    current,
                    proposed,
                }));
            }
            match swap.new.as_ref() {
                Some(value) => {
                    tx_tree.insert(key, serialize_to_ivec(value))?;
                }
                None => {
                    tx_tree.remove(key)?;
                }
            }
        }
        Ok(())
    }
}
/// A set of staged writes across several entry types, committed together by [`apply_cas_tx`].
///
/// Each staged write remembers the raw value it expects to replace, so the commit fails if the
/// database changed underneath the transaction.
#[derive(Debug)]
pub struct CompareAndSwapTransaction {
    /// One boxed `TreeCompareAndSwap<T>` per entry type, keyed by `TypeId::of::<T>()`.
    swaps: HashMap<TypeId, Box<dyn GenericCompareAndSwap>>,
    /// Database used for reads and for opening the per-type trees.
    database: Rc<Db>,
}
impl CompareAndSwapTransaction {
    /// Creates an empty transaction backed by a new handle obtained from `library_db()`.
    #[must_use]
    pub(crate) fn new() -> Self {
        Self {
            swaps: HashMap::new(),
            database: Rc::new(library_db()),
        }
    }

    /// Creates an empty transaction backed by the given shared database handle.
    #[must_use]
    pub(crate) fn with_db(database: Rc<Db>) -> Self {
        Self {
            swaps: HashMap::new(),
            database,
        }
    }

    /// Stages `item` merged onto the entry that currently has the same id.
    ///
    /// If [`Self::tx_get`] finds an entry for `item.id()`, `item` is applied to it with
    /// [`Patchable::patch`] and the patched entry is staged; otherwise `item` itself is staged.
    ///
    /// # Errors
    /// Propagates errors from [`Self::tx_get`] and [`Self::tx_upsert`].
    pub(crate) fn tx_patch<T: Patchable<T> + DatabaseEntry + 'static>(
        &mut self,
        item: T,
    ) -> Result<(), DatabaseError> {
        if let Some(mut old_item) = self.tx_get(item.id())? {
            let item_id = item.id();
            old_item.patch(item);
            self.tx_upsert(item_id, Some(old_item))?;
        } else {
            self.tx_upsert(item.id(), Some(item))?;
        }
        Ok(())
    }

    /// Reads the entry for `id` as this transaction sees it.
    ///
    /// A value staged earlier in this transaction takes precedence over the stored one. That
    /// includes a staged removal, which yields `Ok(None)`. Otherwise the value is read from the
    /// entry type's tree.
    ///
    /// # Errors
    /// Propagates errors from `sled_get_raw`.
    pub fn tx_get<Id: EntryId>(&self, id: Id) -> Result<Option<Id::Entry>, DatabaseError> {
        if let Some(boxed) = self.swaps.get(&TypeId::of::<Id::Entry>()) {
            let cas_tree = boxed
                .as_any()
                .downcast_ref::<TreeCompareAndSwap<Id::Entry>>()
                .expect("swaps entry for TypeId::of::<T>() is always a TreeCompareAndSwap<T>");
            if let Some(pending) = cas_tree.swaps.get(id.as_bytes()) {
                return Ok(pending.new.clone());
            }
        }
        let tree = Id::Entry::tree(&self.database);
        let raw = sled_get_raw(&tree, id.as_bytes())?;
        Ok(raw.map(deserialize_from_ivec))
    }

    /// Reads every id in `items` through [`Self::tx_get`], preserving order.
    ///
    /// # Errors
    /// [`DatabaseError::MissingEntry`] as soon as one id has no entry; otherwise propagates
    /// errors from [`Self::tx_get`].
    pub(crate) fn tx_get_batch<I, A, Entry: DatabaseEntry>(
        &self,
        items: I,
    ) -> Result<Vec<Entry>, DatabaseError>
    where
        I: IntoIterator<Item = A>,
        A: Borrow<Entry::Id>,
    {
        items
            .into_iter()
            .map(|id| {
                self.tx_get(*id.borrow())?
                    .ok_or(DatabaseError::MissingEntry)
            })
            .collect()
    }

    /// Stages removal of the entry stored under `key`.
    ///
    /// If the key was already staged, only its new value is changed. The expected old value
    /// recorded the first time is kept.
    ///
    /// # Errors
    /// Propagates errors from `sled_get_raw` when reading the current value.
    pub fn tx_remove<Id: EntryId>(&mut self, key: Id) -> Result<(), DatabaseError> {
        let request = self.get_or_new_request::<Id::Entry>();
        let key = *key.as_bytes();
        if let Some(pending) = request.swaps.get_mut(&key) {
            pending.new = None;
        } else {
            // First time this key is touched: remember what is stored now, so the commit can
            // detect concurrent modification. The request already holds this type's tree.
            let old = sled_get_raw(&request.tree, &key)?;
            request
                .swaps
                .insert(key, CompareAndSwapValue { old, new: None });
        }
        Ok(())
    }

    /// Stages `new` under `key` (`None` stages a removal).
    ///
    /// The new value's `pre_upsert` hook runs first, with access to this transaction. If the
    /// key was already staged, only its new value is replaced; otherwise the currently stored
    /// bytes are recorded as the expected old value.
    ///
    /// # Errors
    /// Propagates errors from `pre_upsert` and from `sled_get_raw`.
    pub fn tx_upsert<T: DatabaseEntry>(
        &mut self,
        key: T::Id,
        mut new: Option<T>,
    ) -> Result<(), DatabaseError> {
        if let Some(entry) = &mut new {
            entry.pre_upsert(self)?;
        }
        let request = self.get_or_new_request::<T>();
        let key = *key.as_bytes();
        if let Some(pending) = request.swaps.get_mut(&key) {
            pending.new = new;
        } else {
            // The request's tree is `T::tree(&self.database)`; reuse it instead of reopening.
            let old = sled_get_raw(&request.tree, &key)?;
            request.swaps.insert(key, CompareAndSwapValue { old, new });
        }
        Ok(())
    }

    /// Stages `item` as a new entry.
    ///
    /// # Errors
    /// [`DatabaseError::AlreadyInDatabase`] if an entry with the same id is already visible to
    /// this transaction; otherwise propagates errors from [`Self::tx_upsert`].
    pub fn tx_insert<T: DatabaseEntry>(&mut self, item: T) -> Result<(), DatabaseError> {
        if self.tx_get(item.id())?.is_some() {
            return Err(DatabaseError::AlreadyInDatabase);
        }
        self.tx_upsert(item.id(), Some(item))?;
        Ok(())
    }

    /// Returns the swap set for entry type `T`, creating an empty one on first use.
    pub fn get_or_new_request<T: DatabaseEntry>(&mut self) -> &mut TreeCompareAndSwap<T> {
        self.swaps
            .entry(TypeId::of::<T>())
            .or_insert_with(|| Box::new(TreeCompareAndSwap::<T>::new(&self.database)))
            .as_any_mut()
            .downcast_mut::<TreeCompareAndSwap<T>>()
            .expect("swaps entry for TypeId::of::<T>() is always a TreeCompareAndSwap<T>")
    }

    /// The trees touched by this transaction, in the same order as `swaps.values()`.
    #[must_use]
    pub fn trees(&self) -> Vec<&Tree> {
        self.swaps.values().map(|a| a.tree()).collect()
    }
}
/// Commits every staged swap in `tx` atomically inside one sled multi-tree transaction.
///
/// Each tree's swaps are checked against the currently stored bytes. Any mismatch aborts the
/// whole transaction with the offending [`CompareAndSwapError`]. When `flush` is set, `flush()`
/// is called on every transactional tree involved.
///
/// A transaction with nothing staged is a no-op and returns `Ok(())`.
///
/// # Errors
/// `TransactionError::Abort` on a compare-and-swap mismatch; `TransactionError::Storage` on
/// sled errors.
pub fn apply_cas_tx(
    tx: CompareAndSwapTransaction,
    flush: bool,
) -> Result<(), TransactionError<CompareAndSwapError>> {
    // Nothing was staged, e.g. the caller only read or bailed out early. There is nothing to
    // commit, so skip opening a sled transaction over an empty tree list. sled's multi-tree
    // commit indexes its first tree. NOTE(review): confirm against the sled version in use.
    if tx.swaps.is_empty() {
        return Ok(());
    }
    // `trees()` and `swaps.values()` walk the same unmodified map, so their orders match.
    tx.trees().transaction(|tx_trees| {
        for (tree, cas) in tx_trees.iter().zip(tx.swaps.values()) {
            cas.apply(tree)?;
            if flush {
                tree.flush();
            }
        }
        Ok(())
    })
}
pub fn db_transaction<F, E>(mut f: F, db: Option<Db>, flush: bool) -> Result<(), E>
where
F: FnMut(&mut CompareAndSwapTransaction) -> Result<(), E>,
E: From<TransactionError<CompareAndSwapError>>,
{
let db = db.map(Rc::new);
loop {
let mut cas_tx = if let Some(db) = db.clone() {
CompareAndSwapTransaction::with_db(db)
} else {
CompareAndSwapTransaction::new()
};
f(&mut cas_tx)?;
match apply_cas_tx(cas_tx, flush) {
Ok(()) => return Ok(()),
Err(TransactionError::Abort(CompareAndSwapError {
current: _,
proposed: _,
})) => {
trace!("Transaction (Not sync) ran into a CAS error and is retrying.");
}
Err(err) => return Err(err.into()),
}
}
}
impl CompareAndSwapTransaction {
pub fn relink_track_to_album(
&mut self,
track_id: TrackId,
album: Option<AlbumId>,
) -> Result<bool, DatabaseError> {
let Some(mut track) = self.tx_get(track_id)? else {
return Ok(false);
};
if track.metadata.album == album {
return Ok(false);
}
let old_album_id = track.metadata.album;
track.metadata.album = album;
self.tx_upsert(track.id(), Some(track.clone()))?;
if let Some(old_album_id) = old_album_id {
let mut old_album = self.tx_get(old_album_id)?.ok_or({
DatabaseReferenceError::TrackDanglingAlbumRef {
track: track_id,
album: old_album_id,
}
})?;
old_album.tracks.retain(|t| t.id != track_id);
self.tx_upsert(old_album.id(), Some(old_album))?;
}
if let Some(new_album_id) = album {
let mut new_album = self
.tx_get(new_album_id)?
.ok_or(DatabaseError::MissingEntry)?;
new_album.tracks.push(TrackReference {
id: track_id,
track_num: None,
disc_num: None,
});
self.tx_upsert(new_album.id(), Some(new_album))?;
}
Ok(true)
}
pub fn album_set_and_relink_artists(
&mut self,
album_id: AlbumId,
artists: &[ArtistId],
) -> Result<bool, DatabaseError> {
let mut album = self.tx_get(album_id)?.ok_or(DatabaseError::MissingEntry)?;
let old_artists: Vec<ArtistId> = album.artist_group.artist_ids().to_vec();
album.artist_group = ArtistGroup::from_artist_ids(artists.iter().cloned());
let removed_artists: Vec<ArtistId> = old_artists
.into_iter()
.filter(|old_artist| !artists.contains(old_artist))
.collect();
self.artists_add_album(album_id, artists)?;
self.artists_remove_album(album_id, &removed_artists)?;
self.tx_upsert(album_id, Some(album))?;
Ok(true)
}
pub fn album_set_and_relink_tracks(
&mut self,
album_id: AlbumId,
tracks: &[TrackId],
) -> Result<bool, DatabaseError> {
let album = self.tx_get(album_id)?.ok_or(DatabaseError::MissingEntry)?;
let old_tracks: Vec<TrackId> = album.tracks.iter().map(|t| t.id).collect();
let removed_tracks: Vec<TrackId> = old_tracks
.iter()
.filter(|old_track| !tracks.contains(old_track))
.cloned()
.collect();
self.album_set_tracks(album, tracks)?;
self.tracks_set_album(Some(album_id), tracks)?;
self.tracks_set_album(None, &removed_tracks)?;
Ok(true)
}
}
impl CompareAndSwapTransaction {
    /// Stages `album` with its track list replaced by `tracks`, in the given order.
    ///
    /// For tracks already on the album, the existing [`TrackReference`] is reused, keeping its
    /// track and disc numbers. New tracks get a reference with no track or disc number.
    ///
    /// # Errors
    /// Propagates errors from [`Self::tx_upsert`].
    pub(crate) fn album_set_tracks(
        &mut self,
        mut album: Album,
        tracks: &[TrackId],
    ) -> Result<(), DatabaseError> {
        album.tracks = tracks
            .iter()
            .map(|t| {
                album
                    .tracks
                    .iter()
                    .find(|old| old.id == *t)
                    .cloned()
                    .unwrap_or(TrackReference {
                        id: *t,
                        track_num: None,
                        disc_num: None,
                    })
            })
            .collect();
        self.tx_upsert(album.id(), Some(album))?;
        Ok(())
    }

    /// Sets the album of every track in `tracks` to `album_id` and updates each track's artists.
    ///
    /// For each artist of a track:
    /// - If `album_id` is `Some` and the artist's albums contain it, the track is removed from
    ///   the artist's direct track list, because the album already covers it.
    /// - Otherwise the track is added to that list if it is not already there.
    ///
    /// # Errors
    /// [`DatabaseError::MissingEntry`] if a track or one of its artists does not exist.
    pub(crate) fn tracks_set_album<'a>(
        &mut self,
        album_id: Option<AlbumId>,
        tracks: impl IntoIterator<Item = &'a TrackId>,
    ) -> Result<(), DatabaseError> {
        for track_id in tracks {
            let Some(mut track) = self.tx_get(*track_id)? else {
                return Err(DatabaseError::MissingEntry);
            };
            track.metadata.album = album_id;
            for artist_id in track.metadata.artists.artist_ids() {
                let Some(mut artist) = self.tx_get(*artist_id)? else {
                    return Err(DatabaseError::MissingEntry);
                };
                // Decide from the *new* album. Delegating to `artist_add_tracks` here would
                // re-read the track via `tx_get`, which still reports the old album because the
                // updated track is staged only after this loop. When a track was removed from
                // an album, that skipped re-adding it to the artist.
                let covered_by_album =
                    album_id.is_some_and(|album| artist.albums.contains(&album));
                if covered_by_album {
                    artist.tracks.retain(|t| t != track_id);
                } else if !artist.tracks.contains(track_id) {
                    artist.tracks.push(*track_id);
                }
                self.tx_upsert(*artist_id, Some(artist))?;
            }
            self.tx_upsert(*track_id, Some(track))?;
        }
        Ok(())
    }

    /// Removes `album_id` from the album list of every artist in `artists`.
    ///
    /// # Errors
    /// [`DatabaseError::MissingEntry`] if any artist does not exist.
    pub(crate) fn artists_remove_album(
        &mut self,
        album_id: AlbumId,
        artists: &[ArtistId],
    ) -> Result<(), DatabaseError> {
        for artist_id in artists {
            let Some(mut artist) = self.tx_get(*artist_id)? else {
                return Err(DatabaseError::MissingEntry);
            };
            artist.albums.retain(|a| *a != album_id);
            self.tx_upsert(*artist_id, Some(artist))?;
        }
        Ok(())
    }

    /// Adds `album_id` to the album list of every artist in `artists`, skipping artists that
    /// already have it.
    ///
    /// # Errors
    /// [`DatabaseError::MissingEntry`] if any artist does not exist.
    pub(crate) fn artists_add_album(
        &mut self,
        album_id: AlbumId,
        artists: &[ArtistId],
    ) -> Result<(), DatabaseError> {
        for artist_id in artists {
            let Some(mut artist) = self.tx_get(*artist_id)? else {
                return Err(DatabaseError::MissingEntry);
            };
            if !artist.albums.contains(&album_id) {
                artist.albums.push(album_id);
            }
            self.tx_upsert(*artist_id, Some(artist))?;
        }
        Ok(())
    }

    /// Adds each track in `tracks` to the artist's direct track list.
    ///
    /// A track is skipped if it is already listed, or if its current album, as seen through
    /// [`Self::tx_get`], is one of the artist's albums.
    ///
    /// # Errors
    /// [`DatabaseError::MissingEntry`] if the artist or any track does not exist.
    pub(crate) fn artist_add_tracks(
        &mut self,
        artist_id: ArtistId,
        tracks: &[TrackId],
    ) -> Result<(), DatabaseError> {
        let Some(mut artist) = self.tx_get(artist_id)? else {
            return Err(DatabaseError::MissingEntry);
        };
        for track_id in tracks {
            let Some(track) = self.tx_get(*track_id)? else {
                return Err(DatabaseError::MissingEntry);
            };
            if let Some(album_id) = track.metadata.album
                && artist.albums.contains(&album_id)
            {
                continue;
            }
            if !artist.tracks.contains(track_id) {
                artist.tracks.push(*track_id);
            }
        }
        self.tx_upsert(artist_id, Some(artist))?;
        Ok(())
    }
}