pjrt 0.2.0

Safe PJRT C API bindings for Rust
Documentation
use std::collections::HashMap;
use std::sync::{Mutex, OnceLock};

use bon::builder;
use libloading::Library;
use pjrt_sys::PJRT_Api;

use crate::{Api, Error, Result};

/// Signature of the `GetPjrtApi` entry point that is looked up by name in a
/// plugin library; calling it yields a raw pointer to the plugin's `PJRT_Api`.
type GetPjrtApi = unsafe extern "C" fn() -> *const PJRT_Api;

/// Registry of loaded PJRT plugin libraries and the aliases given to them.
///
/// Both maps sit behind their own `Mutex` so the registry can be used through
/// a shared reference.
struct PluginManager {
    /// Loaded plugins, keyed by the library string passed to `load_plugin`.
    /// The `Library` handle is stored next to its `Api` so the handle is not
    /// dropped while the entry exists.
    plugins: Mutex<HashMap<String, (Library, Api)>>,
    /// `Api` handles keyed by the caller-supplied alias.
    aliases: Mutex<HashMap<String, Api>>,
}

impl PluginManager {
    fn new() -> Self {
        PluginManager {
            plugins: Mutex::new(HashMap::new()),
            aliases: Mutex::new(HashMap::new()),
        }
    }

    pub fn load_plugin(&self, library: String, alias: Option<String>) -> Result<Api> {
        let mut libraries = self
            .plugins
            .lock()
            .map_err(|err| Error::PoisonError(err.to_string()))?;
        if let Some((_, api)) = libraries.get(library.as_str()) {
            return Ok(api.clone());
        }
        let lib = unsafe { Library::new(library.as_str())? };
        let get_api_func: libloading::Symbol<GetPjrtApi> = unsafe { lib.get(b"GetPjrtApi")? };
        let ptr = unsafe { get_api_func() };
        let api = Api::wrap(ptr);
        libraries.insert(library, (lib, api.clone()));
        if let Some(alias) = alias {
            let mut aliases = self
                .aliases
                .lock()
                .map_err(|err| Error::PoisonError(err.to_string()))?;
            aliases.insert(alias, api.clone());
        }
        Ok(api)
    }

    pub fn get_plugin(&self, alias: &str) -> Option<Api> {
        let aliases = self
            .aliases
            .lock()
            .map_err(|err| Error::PoisonError(err.to_string()))
            .ok()?;
        aliases.get(alias).cloned()
    }
}

/// Process-wide plugin registry, initialised on first use by the free
/// functions `plugin` and `get_plugin` below.
static PLUGIN_MANAGER: OnceLock<PluginManager> = OnceLock::new();

/// Starts a builder that loads a PJRT plugin through the global registry.
///
/// `library` is the positional argument of the builder's start function;
/// `alias` is an optional setter. Calling `.load()` finishes the builder and
/// forwards both values to `PluginManager::load_plugin`, returning its result.
#[builder(finish_fn = "load")]
pub fn plugin(
    #[builder(start_fn, into)] library: String,
    #[builder(into)] alias: Option<String>,
) -> Result<Api> {
    // The registry is created lazily the first time any caller needs it.
    PLUGIN_MANAGER
        .get_or_init(PluginManager::new)
        .load_plugin(library, alias)
}

#[allow(dead_code)]
pub fn get_plugin(alias: &str) -> Result<Api> {
    let manager = PLUGIN_MANAGER.get_or_init(PluginManager::new);
    manager
        .get_plugin(alias)
        .ok_or_else(|| Error::PluginNotFound(alias.to_string()))
}