use once_cell::sync::OnceCell;
use spin::RwLock;
use crate::{
WCError,
WCResult,
alloc::{
format,
string::String,
sync::Arc,
},
prelude::*,
pretrained::factory::{
vocab_description::{
LabeledVocab,
VocabDescription,
VocabListing,
},
vocab_provider::VocabProvider,
vocab_query::VocabQuery,
},
support::resources::ResourceLoader,
};
/// Lazily-initialised global vocabulary factory.
///
/// Populated on first access by `get_vocab_factory` (via `init_factory`) and
/// guarded by a `spin::RwLock` so it can be read and mutated after start-up.
static FACTORY: OnceCell<RwLock<VocabFactory>> = OnceCell::new();
/// Link-time registration record for a [`VocabProvider`].
///
/// Instances are collected with `inventory`; when the global factory is first
/// built, each hook's `builder` is invoked once and the resulting provider is
/// registered.
pub struct VocabProviderInventoryHook {
    /// Constructor for the provider this hook registers.
    pub builder: fn() -> Arc<dyn VocabProvider>,
}

impl VocabProviderInventoryHook {
    /// Wraps a provider constructor; `const` so it can be used inside
    /// `inventory::submit!` static initialisers.
    pub const fn new(builder: fn() -> Arc<dyn VocabProvider>) -> Self {
        VocabProviderInventoryHook { builder }
    }
}

inventory::collect!(VocabProviderInventoryHook);
/// Returns the global vocabulary factory lock.
///
/// The factory is built on first call from every inventory-registered
/// provider; later calls return the same `'static` instance.
pub fn get_vocab_factory() -> &'static RwLock<VocabFactory> {
    FACTORY.get_or_init(|| {
        let initial = init_factory();
        RwLock::new(initial)
    })
}
/// Builds a fresh [`VocabFactory`] holding one provider per submitted
/// [`VocabProviderInventoryHook`], in `inventory` iteration order.
///
/// # Panics
///
/// Panics if two submitted providers have names that collide under the
/// case-insensitive comparison used by [`VocabFactory::register_provider`].
/// A collision like this presumably reflects a packaging mistake, not a
/// recoverable runtime condition.
fn init_factory() -> VocabFactory {
    let mut factory = VocabFactory::default();
    for hook in inventory::iter::<VocabProviderInventoryHook> {
        factory
            .register_provider((hook.builder)())
            .expect("inventory-registered VocabProviders must have unique (case-insensitive) names");
    }
    factory
}
/// Runs `func` with exclusive (write) access to the global [`VocabFactory`]
/// and returns whatever `func` returns.
///
/// NOTE: the write lock is held for the entire call to `func`. Re-entering any
/// factory accessor from inside `func` will likely deadlock, because spin
/// locks are not reentrant.
pub fn with_vocab_factory_mut<F, V>(func: &mut F) -> V
where
    F: FnMut(&mut VocabFactory) -> V,
{
    let lock = get_vocab_factory();
    let mut writer = lock.write();
    func(&mut *writer)
}
/// Runs `func` with shared (read) access to the global [`VocabFactory`] and
/// returns whatever `func` returns.
///
/// NOTE: the read lock is held for the entire call to `func`. Requesting the
/// write lock from inside `func` will likely deadlock.
pub fn with_vocab_factory<F, V>(func: &mut F) -> V
where
    F: FnMut(&VocabFactory) -> V,
{
    let lock = get_vocab_factory();
    let reader = lock.read();
    func(&*reader)
}
/// Lists, per registered provider, the vocabularies it offers, using the
/// global factory.
pub fn list_vocabs() -> Vec<VocabListing> {
    let mut collect_listings = |factory: &VocabFactory| factory.list_vocabs();
    with_vocab_factory(&mut collect_listings)
}
/// Resolves `name` to a vocabulary description using the global factory.
///
/// # Errors
///
/// Propagates the error from [`VocabFactory::resolve_vocab`].
pub fn resolve_vocab(name: &str) -> WCResult<VocabDescription> {
    let mut resolve = |factory: &VocabFactory| factory.resolve_vocab(name);
    with_vocab_factory(&mut resolve)
}
/// Loads the vocabulary named `name` through the global factory, fetching
/// any resources with `loader`.
///
/// # Errors
///
/// Propagates the error from [`VocabFactory::load_vocab`].
pub fn load_vocab(
    name: &str,
    loader: &mut dyn ResourceLoader,
) -> WCResult<LabeledVocab<u32>> {
    let mut load = move |factory: &VocabFactory| factory.load_vocab(name, loader);
    with_vocab_factory(&mut load)
}
/// Returns the id of every vocabulary offered by every registered provider.
///
/// Ids are returned in provider order, then in each provider's own listing
/// order.
pub fn list_models() -> Vec<String> {
    let mut ids = Vec::new();
    for listing in list_vocabs() {
        ids.extend(
            listing
                .vocabs()
                .into_iter()
                .map(|descr| descr.id().to_string()),
        );
    }
    ids
}
/// Ordered registry of [`VocabProvider`]s.
///
/// Provider names are compared case-insensitively everywhere in this type:
/// registration rejects a name that collides with an existing provider, and
/// lookup and removal by name ignore case. Resolution and loading query the
/// providers in registration order.
#[derive(Default)]
pub struct VocabFactory {
    /// Registered providers, in registration order.
    providers: Vec<Arc<dyn VocabProvider>>,
}

impl VocabFactory {
    /// Returns the registered providers in registration order.
    pub fn providers(&self) -> &[Arc<dyn VocabProvider>] {
        &self.providers
    }

    /// Index of the provider whose name equals `id`, ignoring case.
    fn provider_index(&self, id: &str) -> Option<usize> {
        // Lowercase the needle once instead of once per provider.
        let needle = id.to_lowercase();
        self.providers
            .iter()
            .position(|p| p.name().to_lowercase() == needle)
    }

    /// Finds the provider named `id`, ignoring case.
    pub fn find_provider(
        &self,
        id: &str,
    ) -> Option<&Arc<dyn VocabProvider>> {
        self.provider_index(id).and_then(|i| self.providers.get(i))
    }

    /// Appends `provider` to the registry.
    ///
    /// # Errors
    ///
    /// Returns [`WCError::DuplicatedResource`] if a provider with the same
    /// name (ignoring case) is already registered; the registry is left
    /// unchanged in that case.
    pub fn register_provider(
        &mut self,
        provider: Arc<dyn VocabProvider>,
    ) -> WCResult<()> {
        if self.provider_index(&provider.name()).is_some() {
            let id = provider.name().to_lowercase();
            return Err(WCError::DuplicatedResource(format!(
                "Vocabulary provider with id '{id}' already exists",
            )));
        }
        self.providers.push(provider);
        Ok(())
    }

    /// Removes and returns the provider named `id`, ignoring case.
    ///
    /// This uses the same case-insensitive matching as
    /// [`Self::find_provider`] and [`Self::register_provider`], so any
    /// provider that can be found can also be removed. Returns `None` if no
    /// provider matches.
    pub fn remove_provider(
        &mut self,
        id: &str,
    ) -> Option<Arc<dyn VocabProvider>> {
        self.provider_index(id).map(|i| self.providers.remove(i))
    }

    /// Builds one [`VocabListing`] per provider, in registration order.
    pub fn list_vocabs(&self) -> Vec<VocabListing> {
        self.providers
            .iter()
            .map(|provider| {
                VocabListing::new(
                    &provider.name(),
                    &provider.description(),
                    provider.list_vocabs(),
                )
            })
            .collect()
    }

    /// Resolves `query` against each provider in registration order and
    /// returns the first successful description.
    ///
    /// # Errors
    ///
    /// A provider's [`WCError::ResourceNotFound`] is skipped, and the next
    /// provider is tried. Any other provider error is returned immediately.
    /// If no provider resolves the query, returns
    /// [`WCError::ResourceNotFound`] carrying the query's string form.
    pub fn resolve_vocab<Q>(
        &self,
        query: Q,
    ) -> WCResult<VocabDescription>
    where
        Q: Into<VocabQuery>,
    {
        let query = query.into();
        for provider in &self.providers {
            match provider.resolve_vocab(&query) {
                Ok(vocab) => return Ok(vocab),
                Err(WCError::ResourceNotFound(_)) => (),
                Err(err) => return Err(err),
            }
        }
        Err(WCError::ResourceNotFound(query.to_string()))
    }

    /// Loads `query` from the first provider, in registration order, that can
    /// supply it. Resources are fetched through `loader`.
    ///
    /// # Errors
    ///
    /// Uses the same fallthrough rules as [`Self::resolve_vocab`]:
    /// `ResourceNotFound` moves on to the next provider, any other error is
    /// returned immediately, and `ResourceNotFound` is returned if every
    /// provider misses.
    pub fn load_vocab<Q>(
        &self,
        query: Q,
        loader: &mut dyn ResourceLoader,
    ) -> WCResult<LabeledVocab<u32>>
    where
        Q: Into<VocabQuery>,
    {
        let query = query.into();
        for provider in &self.providers {
            match provider.load_vocab(&query, loader) {
                Ok(vocab) => return Ok(vocab),
                Err(WCError::ResourceNotFound(_)) => (),
                Err(err) => return Err(err),
            }
        }
        Err(WCError::ResourceNotFound(query.to_string()))
    }
}