use alloc::{
collections::BTreeMap,
format,
string::{
String,
ToString,
},
sync::Arc,
};
use core::fmt::Debug;
use anyhow::bail;
use burn::config::Config;
use crate::cache::weights::{
PretrainedWeightsDescriptor,
PretrainedWeightsMap,
StaticPretrainedWeightsMap,
};
/// A named configuration preset built only from `'static` data (string slices, a plain
/// function pointer and an optional `'static` weights table), so it can be declared in a
/// `static` item.
///
/// Use [`StaticPreFabConfig::to_prefab`] to obtain the owned [`PreFabConfig`] form.
pub struct StaticPreFabConfig<C: 'static + Config + Debug + Clone> {
    /// Preset name; used as the lookup key by `StaticPreFabMap`.
    pub name: &'static str,

    /// Human-readable description of the preset.
    pub description: &'static str,

    /// Function that produces a fresh config value each time it is called.
    pub builder: fn() -> C,

    /// Optional static table of pretrained weights for this preset.
    pub weights: Option<&'static StaticPretrainedWeightsMap<'static>>,
}
impl<C> StaticPreFabConfig<C>
where
    C: 'static + Config + Debug + Clone,
{
    /// Builds the owned [`PreFabConfig`] equivalent of this static preset.
    ///
    /// The name and description are copied into `String`s. The function pointer is
    /// boxed into an `Arc`'d closure. The weights table, if any, is converted with
    /// `to_directory()`.
    pub fn to_prefab(&self) -> PreFabConfig<C> {
        let name = String::from(self.name);
        let description = String::from(self.description);
        let builder: Arc<dyn Fn() -> C + Send + Sync> = Arc::new(self.builder);
        let weights = self.weights.map(|table| table.to_directory());

        PreFabConfig {
            name,
            description,
            builder,
            weights,
        }
    }

    /// Invokes the preset's builder and returns the resulting config.
    pub fn to_config(&self) -> C {
        let build = self.builder;
        build()
    }
}
impl<C> From<&StaticPreFabConfig<C>> for PreFabConfig<C>
where
C: 'static + Config + Debug + Clone,
{
fn from(config: &StaticPreFabConfig<C>) -> Self {
config.to_prefab()
}
}
/// Debug-formats a static preset.
///
/// The output is identical to that of the equivalent owned `PreFabConfig` (same
/// `PreFabConfig<C>` label and fields). It is produced directly from the static fields,
/// so no owned copy is built. In particular, the weights table is not converted with
/// `to_directory()`, since weights are never printed.
impl<C> Debug for StaticPreFabConfig<C>
where
    C: 'static + Config + Debug + Clone,
{
    fn fmt(
        &self,
        f: &mut core::fmt::Formatter<'_>,
    ) -> core::fmt::Result {
        // Read the flag before `debug_struct` takes the mutable borrow of `f`.
        let pretty = f.alternate();
        // Label kept as `PreFabConfig<..>` so output matches the owned form exactly.
        let label = format!("PreFabConfig<{}>", core::any::type_name::<C>());
        let mut handle = f.debug_struct(&label);
        handle
            .field("name", &self.name)
            .field("description", &self.description);
        if pretty {
            // Only `{:#?}` invokes the builder to show the full config.
            handle.field("config", &self.to_config());
        }
        handle.finish()
    }
}
/// An owned, named configuration preset.
///
/// Cloning is cheap for the builder, which is shared through an `Arc`.
#[derive(Clone)]
pub struct PreFabConfig<C: 'static + Config + Debug + Clone> {
    /// Preset name.
    pub name: String,

    /// Human-readable description of the preset.
    pub description: String,

    /// Shared closure that produces a fresh config each time it is called.
    pub builder: Arc<dyn Fn() -> C + Send + Sync>,

    /// Optional table of pretrained weights, searched by `lookup_pretrained_weights`.
    pub weights: Option<PretrainedWeightsMap>,
}
/// Debug-formats a preset as `PreFabConfig<TypeName> { name, description }`.
///
/// With the alternate flag (`{:#?}`), the builder is invoked and the resulting config
/// is included as an extra `config` field. The weights table is never printed.
impl<C> Debug for PreFabConfig<C>
where
    C: 'static + Config + Debug + Clone,
{
    fn fmt(
        &self,
        f: &mut core::fmt::Formatter<'_>,
    ) -> core::fmt::Result {
        let show_config = f.alternate();
        let title = format!("PreFabConfig<{}>", core::any::type_name::<C>());

        let mut out = f.debug_struct(&title);
        out.field("name", &self.name);
        out.field("description", &self.description);
        if show_config {
            let config = self.to_config();
            out.field("config", &config);
        }
        out.finish()
    }
}
impl<C> PreFabConfig<C>
where
    C: 'static + Config + Debug + Clone,
{
    /// Invokes the preset's builder and returns the resulting config.
    pub fn to_config(&self) -> C {
        (self.builder)()
    }

    /// Looks up a pretrained-weights descriptor by `name`.
    ///
    /// Returns `None` when this preset has no weights table, or when the table has no
    /// entry for `name`.
    pub fn lookup_pretrained_weights(
        &self,
        name: &str,
    ) -> Option<PretrainedWeightsDescriptor> {
        self.weights
            .as_ref()
            .and_then(|m| m.lookup_by_name(name))
    }

    /// Like [`Self::lookup_pretrained_weights`], but reports a missing entry as an error.
    ///
    /// # Errors
    ///
    /// Returns an error naming `name` if no descriptor is found.
    pub fn try_lookup_pretrained_weights(
        &self,
        name: &str,
    ) -> anyhow::Result<PretrainedWeightsDescriptor> {
        match self.lookup_pretrained_weights(name) {
            Some(d) => Ok(d),
            None => bail!("Descriptor not found: {}", name),
        }
    }

    /// Like [`Self::try_lookup_pretrained_weights`], but panics on a missing entry.
    ///
    /// # Panics
    ///
    /// Panics if no descriptor named `name` exists. The panic location reported is the
    /// caller's (`#[track_caller]`), not this helper's.
    #[track_caller]
    pub fn expect_lookup_pretrained_weights(
        &self,
        name: &str,
    ) -> PretrainedWeightsDescriptor {
        match self.try_lookup_pretrained_weights(name) {
            Ok(p) => p,
            Err(e) => panic!("{}", e),
        }
    }
}
/// A named collection of static presets, suitable for declaring in a `static` item.
///
/// Lookups by name scan `items` in order.
#[derive(Debug)]
pub struct StaticPreFabMap<C: 'static + Config + Debug + Clone> {
    /// Collection name.
    pub name: &'static str,

    /// Human-readable description of the collection.
    pub description: &'static str,

    /// The presets in this collection.
    pub items: &'static [&'static StaticPreFabConfig<C>],
}
impl<C> StaticPreFabMap<C>
where
C: 'static + Config + Debug + Clone,
{
pub fn to_prefab_map(&self) -> PreFabMap<C> {
PreFabMap {
name: self.name.to_string(),
description: self.description.to_string(),
items: self
.items
.iter()
.map(|c| (c.name.to_string(), c.to_prefab()))
.collect(),
}
}
pub fn lookup_prefab(
&self,
name: &str,
) -> Option<PreFabConfig<C>> {
self.items
.iter()
.find(|c| c.name == name)
.map(|c| c.to_prefab())
}
pub fn try_lookup_prefab(
&self,
name: &str,
) -> anyhow::Result<PreFabConfig<C>> {
match self.lookup_prefab(name) {
Some(d) => Ok(d),
None => bail!("PreFab not found: {}", name),
}
}
pub fn expect_lookup_prefab(
&self,
name: &str,
) -> PreFabConfig<C> {
match self.try_lookup_prefab(name) {
Ok(p) => p,
Err(e) => panic!("{}", e),
}
}
}
/// An owned, named collection of presets, keyed and ordered by preset name.
#[derive(Debug, Clone)]
pub struct PreFabMap<C: 'static + Config + Debug + Clone> {
    /// Collection name.
    pub name: String,

    /// Human-readable description of the collection.
    pub description: String,

    /// Presets indexed by their name.
    pub items: BTreeMap<String, PreFabConfig<C>>,
}
impl<C> PreFabMap<C>
where
    C: 'static + Config + Debug + Clone,
{
    /// Returns a clone of the preset stored under `name`, if any.
    pub fn lookup_prefab(
        &self,
        name: &str,
    ) -> Option<PreFabConfig<C>> {
        self.items.get(name).cloned()
    }

    /// Like [`Self::lookup_prefab`], but reports a missing preset as an error.
    ///
    /// # Errors
    ///
    /// Returns an error naming `name` if no preset is stored under it.
    pub fn try_lookup_prefab(
        &self,
        name: &str,
    ) -> anyhow::Result<PreFabConfig<C>> {
        match self.lookup_prefab(name) {
            Some(d) => Ok(d),
            None => bail!("PreFab not found: {}", name),
        }
    }

    /// Like [`Self::try_lookup_prefab`], but panics on a missing preset.
    ///
    /// # Panics
    ///
    /// Panics if no preset is stored under `name`. The panic location reported is the
    /// caller's (`#[track_caller]`).
    #[track_caller]
    pub fn expect_lookup_prefab(
        &self,
        name: &str,
    ) -> PreFabConfig<C> {
        match self.try_lookup_prefab(name) {
            Ok(p) => p,
            Err(e) => panic!("{}", e),
        }
    }
}