use crate::{
UnifiedTokenVocab,
WCError,
WCResult,
alloc::sync::Arc,
prelude::*,
pretrained::{
LabeledVocab,
VocabDescription,
VocabQuery,
factory::VocabProviderInventoryHook,
},
support::resources::ResourceLoader,
};
/// A source of pretrained vocabularies that can be listed, resolved by query, and loaded.
///
/// Implementors must be `Sync + Send` so they can be shared across threads.
pub trait VocabProvider: Sync + Send {
    /// The short name of this provider.
    fn name(&self) -> String;

    /// A human-readable description of this provider.
    fn description(&self) -> String;

    /// Describe every vocabulary this provider offers.
    fn list_vocabs(&self) -> Vec<VocabDescription>;

    /// Resolve `query` to a vocabulary description.
    ///
    /// The default implementation scans [`VocabProvider::list_vocabs`] in order. It returns
    /// the first description whose id fuzzy-matches `query`.
    ///
    /// # Errors
    /// Returns [`WCError::ResourceNotFound`], carrying `query`'s string form, when no listed
    /// vocabulary matches.
    fn resolve_vocab(
        &self,
        query: &VocabQuery,
    ) -> WCResult<VocabDescription> {
        self.list_vocabs()
            .into_iter()
            .find(|candidate| candidate.id().fuzzy_match(query))
            .ok_or_else(|| WCError::ResourceNotFound(query.to_string()))
    }

    /// Load the vocabulary selected by `query`.
    ///
    /// `loader` is available for fetching any resources the vocabulary is built from.
    fn load_vocab(
        &self,
        query: &VocabQuery,
        loader: &mut dyn ResourceLoader,
    ) -> WCResult<LabeledVocab<u32>>;
}
/// Signature of a builtin vocabulary loader.
///
/// A loader builds the [`UnifiedTokenVocab`] for the given [`VocabDescription`]. It is handed
/// a [`ResourceLoader`] for fetching whatever backing data it needs.
///
/// The alias replaces a repeated fn-pointer type that previously needed
/// `#[allow(clippy::type_complexity)]` at each use.
pub type BuiltinVocabLoaderFn =
    fn(&VocabDescription, &mut dyn ResourceLoader) -> WCResult<UnifiedTokenVocab<u32>>;

/// Link-time registration record for a builtin pretrained vocabulary.
///
/// Records are collected with `inventory` and enumerated by [`BuiltinVocabProvider`].
pub struct BuiltinPretrainedVocabHook {
    /// Identifier handed to `descr_fn` to build this hook's description.
    id: &'static str,
    /// Builds the [`VocabDescription`] for `id`.
    descr_fn: fn(&str) -> VocabDescription,
    /// Builds the vocabulary for a resolved description.
    vocab_fn: BuiltinVocabLoaderFn,
}

// Make hooks collectable, so `inventory::iter::<BuiltinPretrainedVocabHook>` can see
// every submitted instance.
inventory::collect!(BuiltinPretrainedVocabHook);

impl BuiltinPretrainedVocabHook {
    /// Create a hook from its id, description builder, and vocabulary loader.
    ///
    /// This is a `const fn`, so it can be used in constant/static contexts such as
    /// `inventory::submit!`.
    pub const fn new(
        id: &'static str,
        descr_fn: fn(&str) -> VocabDescription,
        vocab_fn: BuiltinVocabLoaderFn,
    ) -> Self {
        Self {
            id,
            descr_fn,
            vocab_fn,
        }
    }

    /// The hook's identifier.
    ///
    /// The id is stored as `&'static str`, so callers may keep it beyond the borrow of `self`.
    pub fn id(&self) -> &'static str {
        self.id
    }

    /// Build this hook's description by calling `descr_fn` with the hook's id.
    ///
    /// The description is recomputed on every call; nothing is cached.
    pub fn description(&self) -> VocabDescription {
        (self.descr_fn)(self.id)
    }

    /// Borrow the vocabulary loader function registered for this hook.
    pub fn vocab_fn(&self) -> &BuiltinVocabLoaderFn {
        &self.vocab_fn
    }
}
/// A [`VocabProvider`] that serves every [`BuiltinPretrainedVocabHook`] registered
/// through `inventory` (that is, vocabularies linked into the binary).
pub struct BuiltinVocabProvider {}
// Submit a `VocabProviderInventoryHook` (from `pretrained::factory`). Its closure builds a
// fresh `Arc`-wrapped `BuiltinVocabProvider` each time it is invoked.
// NOTE(review): presumably the factory iterates these hooks to enumerate providers —
// confirm against `pretrained::factory`.
inventory::submit! {
    VocabProviderInventoryHook::new(|| Arc::new(BuiltinVocabProvider{}))
}
impl VocabProvider for BuiltinVocabProvider {
fn name(&self) -> String {
"builtin".to_string()
}
fn description(&self) -> String {
"Link-registered vocabularies".to_string()
}
fn list_vocabs(&self) -> Vec<VocabDescription> {
let mut res = Vec::new();
for hook in inventory::iter::<BuiltinPretrainedVocabHook> {
res.push(hook.description());
}
res
}
fn load_vocab(
&self,
query: &VocabQuery,
loader: &mut dyn ResourceLoader,
) -> WCResult<LabeledVocab<u32>> {
for hook in inventory::iter::<BuiltinPretrainedVocabHook> {
let description = hook.description();
if description.id().fuzzy_match(query) {
let vocab = (hook.vocab_fn)(&description, loader)?;
return Ok(LabeledVocab::new(description, vocab.into()));
}
}
Err(WCError::ResourceNotFound(query.to_string()))
}
}