librashader-cache 0.11.1

RetroArch shaders for all.
Documentation
use crate::cacheable::Cacheable;
use crate::key::CacheKey;
use std::panic::{catch_unwind, AssertUnwindSafe};

pub(crate) mod internal {
    #[derive(Debug, Error)]
    enum CatchPanicError {
        #[error("a panic ocurred when loading the database")]
        Panic(Box<dyn Any + Send + 'static>),
    }

    use parking_lot::Mutex;
    use persy::{ByteVec, Config, Persy, ValueMode};
    use platform_dirs::AppDirs;
    use std::any::Any;
    use std::error::Error;
    use std::panic::catch_unwind;
    use std::path::PathBuf;
    use std::sync::OnceLock;
    use thiserror::Error;

    pub(crate) fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
        let cache_dir = if let Some(cache_dir) =
            AppDirs::new(Some("librashader"), false).map(|a| a.cache_dir)
        {
            cache_dir
        } else {
            let mut current_dir = std::env::current_dir()?;
            current_dir.push("librashader");
            current_dir
        };

        std::fs::create_dir_all(&cache_dir)?;

        Ok(cache_dir)
    }

    // pub(crate) fn get_cache() -> Result<Connection, Box<dyn Error>> {
    //     let cache_dir = get_cache_dir()?;
    //     let mut conn = Connection::open(&cache_dir.join("librashader.db"))?;
    //
    //     let tx = conn.transaction()?;
    //     tx.pragma_update(Some(DatabaseName::Main), "journal_mode", "wal2")?;
    //     tx.execute(
    //         r#"create table if not exists cache (
    //     type text not null,
    //     id blob not null,
    //     value blob not null unique,
    //     primary key (id, type)
    // )"#,
    //         [],
    //     )?;
    //     tx.commit()?;
    //     Ok(conn)
    // }

    pub(crate) fn remove_cache() {
        let Ok(cache_dir) = get_cache_dir() else {
            return;
        };
        let path = &cache_dir.join("librashader.db.1");
        let _ = std::fs::remove_file(path).ok();
    }

    pub(crate) fn get_cache() -> Result<Persy, Box<dyn Error>> {
        let cache_dir = get_cache_dir()?;
        static CACHE: OnceLock<Persy> = OnceLock::new();

        if let Some(persy) = CACHE.get() {
            return Ok(persy.clone());
        }

        let persy = match catch_unwind(|| {
            Persy::open_or_create_with(
                &cache_dir.join("librashader.db.1"),
                Config::new(),
                |persy| {
                    let tx = persy.begin()?;
                    tx.commit()?;
                    Ok(())
                },
            )
        }) {
            Ok(Ok(conn)) => Ok::<_, Box<dyn Error>>(conn),
            Ok(Err(e)) => {
                remove_cache();
                Err(e)?
            }
            Err(e) => {
                remove_cache();
                Err(CatchPanicError::Panic(e))?
            }
        }?;

        Ok(CACHE.get_or_init(move || persy).clone())
    }

    pub(crate) fn get_blob(
        conn: &Persy,
        index: &str,
        key: &[u8],
    ) -> Result<Option<Vec<u8>>, Box<dyn Error>> {
        if !conn.exists_index(index)? {
            return Ok(None);
        }

        let value = conn.get::<_, ByteVec>(index, &ByteVec::from(key))?.next();
        Ok(value.map(|v| v.to_vec()))
    }

    pub(crate) fn set_blob(
        conn: &Persy,
        index: &str,
        key: &[u8],
        value: &[u8],
    ) -> Result<(), Box<dyn Error>> {
        static WRITE_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        let write_lock = WRITE_LOCK.get_or_init(|| Mutex::new(()));
        let _guard = write_lock.lock();

        let mut tx = conn.begin()?;
        if !tx.exists_index(index)? {
            tx.create_index::<ByteVec, ByteVec>(index, ValueMode::Replace)?;
        }

        tx.put(index, ByteVec::from(key), ByteVec::from(value))?;
        tx.commit()?;

        Ok(())
    }
}

/// Cache a shader object (usually bytecode) created by the keyed objects.
///
/// The keys are hashed together with BLAKE3 to form the lookup key within `index`. On a hit, the
/// stored bytes are decoded with [`Cacheable::from_bytes`] and passed to `load`; if either step
/// fails, the object is recompiled as if the cache had missed.
///
/// - `factory` is the function that compiles the values passed as keys to a shader object.
/// - `load` tries to load a compiled shader object to a driver-specialized result.
/// - `bypass_cache` skips the cache entirely and always runs `factory` followed by `load`.
///
/// # Errors
///
/// Returns the error from `factory` or `load`. Failures of the cache itself (opening the
/// database, reading or writing a blob) are never reported; the cache is simply skipped.
///
/// # Panics
///
/// A panic raised while the cache path runs is caught, the on-disk cache is removed, and the
/// object is compiled and loaded once more without the cache. A panic during that retry
/// propagates to the caller.
pub fn cache_shader_object<E, T, R, H, const KEY_SIZE: usize>(
    index: &str,
    keys: &[H; KEY_SIZE],
    factory: impl Fn(&[H; KEY_SIZE]) -> Result<T, E> + std::panic::RefUnwindSafe,
    load: impl Fn(T) -> Result<R, E> + std::panic::RefUnwindSafe,
    bypass_cache: bool,
) -> Result<R, E>
where
    H: CacheKey + std::panic::RefUnwindSafe,
    T: Cacheable,
{
    if bypass_cache {
        return load(factory(keys)?);
    }

    catch_unwind(|| -> Result<R, E> {
        let Ok(cache) = internal::get_cache() else {
            return load(factory(keys)?);
        };

        let hashkey = {
            let mut hasher = blake3::Hasher::new();
            for subkeys in keys {
                hasher.update(subkeys.hash_bytes());
            }
            hasher.finalize()
        };

        if let Ok(Some(blob)) = internal::get_blob(&cache, index, hashkey.as_bytes()) {
            // A blob that fails to decode or load falls through to recompilation below,
            // which then tries to store a fresh copy.
            if let Some(Ok(res)) = T::from_bytes(&blob).map(&load) {
                return Ok(res);
            }
        }

        let blob = factory(keys)?;

        if let Some(slice) = T::to_bytes(&blob) {
            // Best effort: a failed write only costs a recompile next time.
            let _ = internal::set_blob(&cache, index, hashkey.as_bytes(), &slice);
        }
        load(blob)
    })
    .unwrap_or_else(|_| {
        internal::remove_cache();
        load(factory(keys)?)
    })
}

/// Like [`cache_shader_object`], but returns the load step as a closure instead of running it.
///
/// The shader object is obtained right away, either decoded from the cache or produced by
/// `factory` and then stored. It is moved, together with `load`, into the returned closure;
/// calling that closure performs the driver-specialized load. This split lets callers run
/// `factory` for many shaders in parallel and then create driver resources one after another,
/// for drivers whose object creation must not run concurrently.
///
/// Because `load` only runs later, a cached object is accepted as soon as
/// [`Cacheable::from_bytes`] succeeds. Unlike [`cache_shader_object`], a cached object that later
/// fails to load is *not* recompiled automatically.
///
/// # Errors
///
/// Returns the error from `factory`. Errors from `load` are returned when the closure is called.
///
/// # Panics
///
/// A panic while the cache is in use is caught, the on-disk cache is removed, and `factory` is
/// run once more without the cache.
pub fn cache_shader_object_deferred<'a, E, T, R, H, const KEY_SIZE: usize>(
    index: &str,
    keys: &[H; KEY_SIZE],
    factory: impl Fn(&[H; KEY_SIZE]) -> Result<T, E> + std::panic::RefUnwindSafe,
    load: impl FnOnce(T) -> Result<R, E> + Send + 'a,
    bypass_cache: bool,
) -> Result<Box<dyn FnOnce() -> Result<R, E> + Send + 'a>, E>
where
    H: CacheKey + std::panic::RefUnwindSafe,
    T: Cacheable + Send + 'a,
{
    let object = if bypass_cache {
        factory(keys)?
    } else {
        let attempt = catch_unwind(|| -> Result<T, E> {
            let cache = match internal::get_cache() {
                Ok(cache) => cache,
                Err(_) => return factory(keys),
            };

            let mut hasher = blake3::Hasher::new();
            keys.iter().for_each(|key| {
                hasher.update(key.hash_bytes());
            });
            let digest = hasher.finalize();

            let hit = internal::get_blob(&cache, index, digest.as_bytes())
                .ok()
                .flatten()
                .and_then(|bytes| T::from_bytes(&bytes));
            if let Some(object) = hit {
                return Ok(object);
            }

            let compiled = factory(keys)?;
            if let Some(bytes) = T::to_bytes(&compiled) {
                // Storing is opportunistic; a failed write is ignored.
                let _ = internal::set_blob(&cache, index, digest.as_bytes(), &bytes);
            }
            Ok(compiled)
        });

        match attempt {
            Ok(outcome) => outcome?,
            Err(_) => {
                internal::remove_cache();
                factory(keys)?
            }
        }
    };

    Ok(Box::new(move || load(object)))
}

/// Cache a pipeline state object.
///
/// Keys are not used to create the object and are only used to uniquely identify the pipeline state.
///
/// - `restore_pipeline` tries to restore the pipeline with either a cached binary pipeline state
///    cache, or create a new pipeline if no cached value is available.
/// - `fetch_pipeline_state` fetches the new pipeline state cache after the pipeline was created.
/// - `bypass_cache` skips the cache entirely and calls `restore_pipeline(None)`.
///
/// If restoring from the cached state fails, the pipeline is created again with
/// `restore_pipeline(None)`. After a pipeline is obtained, its state is fetched and written back
/// to the cache on every call.
///
/// # Errors
///
/// Returns the error from `restore_pipeline(None)`. The following are ignored:
/// - errors from a restore attempt that used cached state;
/// - errors from `fetch_pipeline_state`;
/// - errors from the cache database.
///
/// # Panics
///
/// A panic raised while the cache path runs is caught, the on-disk cache is removed, and
/// `restore_pipeline(None)` is called once more. A panic during that retry propagates.
pub fn cache_pipeline<E, T, R, const KEY_SIZE: usize>(
    index: &str,
    keys: &[&dyn CacheKey; KEY_SIZE],
    restore_pipeline: impl Fn(Option<Vec<u8>>) -> Result<R, E>,
    fetch_pipeline_state: impl Fn(&R) -> Result<T, E>,
    bypass_cache: bool,
) -> Result<R, E>
where
    T: Cacheable,
{
    if bypass_cache {
        return restore_pipeline(None);
    }

    // The callbacks carry no `RefUnwindSafe` bound, so unwind safety is asserted here. After a
    // panic, only a fresh `restore_pipeline(None)` is attempted.
    catch_unwind(AssertUnwindSafe(|| -> Result<R, E> {
        let Ok(cache) = internal::get_cache() else {
            return restore_pipeline(None);
        };

        let hashkey = {
            let mut hasher = blake3::Hasher::new();
            for subkeys in keys {
                hasher.update(subkeys.hash_bytes());
            }
            hasher.finalize()
        };

        // Try the cached state first; a missing or rejected blob falls back to a fresh build.
        let restored = internal::get_blob(&cache, index, hashkey.as_bytes())
            .ok()
            .flatten()
            .and_then(|blob| restore_pipeline(Some(blob)).ok());
        let pipeline = match restored {
            Some(pipeline) => pipeline,
            None => restore_pipeline(None)?,
        };

        // Update the stored state every time, just in case.
        if let Ok(state) = fetch_pipeline_state(&pipeline) {
            if let Some(slice) = T::to_bytes(&state) {
                // We don't really care if the transaction fails, just try again next time.
                let _ = internal::set_blob(&cache, index, hashkey.as_bytes(), &slice);
            }
        }

        Ok(pipeline)
    }))
    .unwrap_or_else(|_| {
        internal::remove_cache();
        restore_pipeline(None)
    })
}