use std::borrow::Cow;
use crate::internal::*;
#[cfg(feature = "blas")]
use crate::ops::einsum::as_blas::AsBlas;
use crate::ops::matmul::de_block_quant::BlockQuantTransform;
use std::fmt::Debug;
use tract_data::TractResult;
use crate::floats::FloatPrecisionTranslator;
use crate::ops::nn::{Softmax, SoftmaxExp, SoftmaxKind, TypedModel};
/// Early-returns `Ok(None)` from the enclosing function unless `$cond` holds.
///
/// Only usable in functions whose return type accepts `Ok(None)`
/// (e.g. `TractResult<Option<_>>`).
#[macro_export]
macro_rules! rule_if {
    ($cond:expr) => {
        if !$cond {
            return Ok(None);
        }
    };
}
/// Binds `$pat` from `$expr` with a `let ... else`, early-returning `Ok(None)`
/// from the enclosing function when the pattern does not match.
///
/// The bindings introduced by `$pat` stay in scope after the macro call.
#[macro_export]
macro_rules! rule_if_let {
    ($pat:pat = $expr:expr) => {
        let $pat = $expr else {
            return Ok(None);
        };
    };
}
/// Binds `$pat` from the `Some` payload of `$expr`, early-returning `Ok(None)`
/// from the enclosing function when `$expr` is `None` (or the payload does not
/// match `$pat`).
#[macro_export]
macro_rules! rule_if_some {
    ($pat:pat = $expr:expr) => {
        let Some($pat) = $expr else {
            return Ok(None);
        };
    };
}
/// Selects graph nodes by substring matches on their names.
///
/// A name is selected when it contains at least one `include` pattern (or
/// `include` is `None`) and contains none of the `exclude` patterns (or
/// `exclude` is `None`). The default value selects every node; an empty
/// `include` list selects none.
#[derive(Debug, Clone, Default)]
pub struct NodeFilter {
    /// Substrings of which a name must contain at least one; `None` admits every name.
    pub include: Option<Vec<String>>,
    /// Substrings none of which a name may contain; `None` rejects nothing.
    pub exclude: Option<Vec<String>>,
}

impl NodeFilter {
    /// Returns `true` when `name` passes both the include and the exclude lists.
    pub fn matches(&self, name: &str) -> bool {
        let hits = |patterns: &[String]| patterns.iter().any(|p| name.contains(p.as_str()));
        let included = self.include.as_deref().is_none_or(hits);
        let excluded = self.exclude.as_deref().is_some_and(hits);
        included && !excluded
    }

    /// Returns `true` when neither list is set, i.e. the filter applies no restriction.
    pub fn is_pass_through(&self) -> bool {
        matches!((&self.include, &self.exclude), (None, None))
    }
}
pub fn parse_legacy_filter(filter: Option<&str>) -> TractResult<NodeFilter> {
let Some(filter) = filter.filter(|f| !f.is_empty()) else {
return Ok(NodeFilter::default());
};
if let Some(patterns) = filter.strip_prefix("!=") {
let patterns = patterns.split(',').map(|it| it.trim().to_string()).collect();
Ok(NodeFilter { exclude: Some(patterns), ..Default::default() })
} else if let Some(patterns) = filter.strip_prefix("==") {
let patterns = patterns.split(',').map(|it| it.trim().to_string()).collect();
Ok(NodeFilter { include: Some(patterns), ..Default::default() })
} else {
Ok(NodeFilter::default())
}
}
/// Builds a float-precision translator converting `from_dt` to `to_dt`.
///
/// When `filter` restricts nothing, the unfiltered translator is used. Otherwise
/// the translator is given a per-node predicate that checks the node name
/// against `filter`.
pub fn build_float_translator(
    from_dt: DatumType,
    to_dt: DatumType,
    filter: NodeFilter,
) -> Box<dyn ModelTransform> {
    if filter.is_pass_through() {
        Box::new(FloatPrecisionTranslator::new(from_dt, to_dt))
    } else {
        Box::new(FloatPrecisionTranslator::with_filter(from_dt, to_dt, move |node| {
            filter.matches(&node.name)
        }))
    }
}
/// A rewrite applied to a whole [`TypedModel`].
pub trait ModelTransform: Debug {
    /// Identifier reported for this transform.
    fn name(&self) -> StaticName;

    /// Rewrites `model` in place.
    fn transform(&self, model: &mut TypedModel) -> TractResult<()>;

    /// Consuming variant of [`ModelTransform::transform`]: takes ownership of
    /// `model`, rewrites it and hands it back.
    ///
    /// # Errors
    /// Propagates any error from [`ModelTransform::transform`]; the model is
    /// dropped in that case.
    fn transform_into(&self, model: TypedModel) -> TractResult<TypedModel> {
        let mut owned = model;
        self.transform(&mut owned).map(|()| owned)
    }
}
/// Transform that switches every `SoftmaxKind::Softmax` node to the
/// `SoftmaxExp::FastCompact` exponential.
#[derive(Debug)]
struct SoftmaxFastCompact;

impl ModelTransform for SoftmaxFastCompact {
    fn name(&self) -> StaticName {
        "softmax_fast_compact".into()
    }

    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
        let softmaxes = model.nodes.iter_mut().filter_map(|node| node.op_as_mut::<Softmax>());
        for softmax in softmaxes {
            // Only the `SoftmaxKind::Softmax` variant is rewritten; other kinds are left untouched.
            if let SoftmaxKind::Softmax(exp) = &mut softmax.kind {
                *exp = SoftmaxExp::FastCompact;
            }
        }
        Ok(())
    }
}
/// Parameters accepted by the `f32_to_f16` and `f16_to_f32` transforms.
///
/// Either explicit `include` / `exclude` substring lists, or a legacy `filter`
/// string (see [`parse_legacy_filter`]). As soon as one explicit list is
/// present, `filter` is ignored.
#[derive(Default, Debug, serde::Deserialize)]
pub struct FloatTranslatorConfig {
    /// Legacy filter string, e.g. `"==a,b"` or `"!=a,b"`.
    #[serde(default)]
    pub filter: Option<String>,
    /// Node-name substrings to include.
    #[serde(default)]
    pub include: Option<Vec<String>>,
    /// Node-name substrings to exclude.
    #[serde(default)]
    pub exclude: Option<Vec<String>>,
}

impl FloatTranslatorConfig {
    /// Converts this configuration into a [`NodeFilter`].
    ///
    /// # Errors
    /// Propagates errors from [`parse_legacy_filter`] when neither explicit
    /// list is given and the legacy `filter` string is used instead.
    pub fn into_node_filter(self) -> TractResult<NodeFilter> {
        let Self { filter, include, exclude } = self;
        match (include, exclude) {
            (None, None) => parse_legacy_filter(filter.as_deref()),
            (include, exclude) => Ok(NodeFilter { include, exclude }),
        }
    }
}
/// Parameters of the `float_precision` transform.
///
/// `from` and `to` are datum type names that the factory parses into
/// `DatumType`s; `include` / `exclude` form the node filter.
#[derive(serde::Deserialize, Debug)]
pub struct FloatPrecisionConfig {
    /// Name of the datum type to convert from.
    pub from: String,
    /// Name of the datum type to convert to.
    pub to: String,
    /// Node-name substrings to include; absent means no include restriction.
    #[serde(default)]
    pub include: Option<Vec<String>>,
    /// Node-name substrings to exclude; absent means nothing is excluded.
    #[serde(default)]
    pub exclude: Option<Vec<String>>,
}
/// Registry entry describing how to build a named [`ModelTransform`].
///
/// Entries are submitted with `inventory::submit!` (directly or through the
/// registration macros of this module) and looked up by `name` in
/// `get_transform` / `get_transform_with_params`.
pub struct ModelTransformFactory {
    /// Lookup key of the transform.
    pub name: &'static str,
    /// Builds the transform without user-supplied parameters.
    pub build_default: fn() -> TractResult<Box<dyn ModelTransform>>,
    /// Builds the transform from a type-erased parameter deserializer.
    pub build: fn(&mut dyn erased_serde::Deserializer) -> TractResult<Box<dyn ModelTransform>>,
}
// Declares the global collection walked by `inventory::iter::<ModelTransformFactory>`.
inventory::collect!(ModelTransformFactory);
/// Registers a transform that takes no configuration.
///
/// `$type` is an expression producing the transform value; it is evaluated anew
/// by both `build_default` and `build`. `build` never reads its deserializer,
/// so any parameters supplied for such a transform are ignored.
#[macro_export]
macro_rules! register_simple_model_transform {
    ($name: expr, $type: expr) => {
        $crate::internal::inventory::submit! {
            $crate::transform::ModelTransformFactory {
                name: $name,
                build_default: || Ok(Box::new($type)),
                build: |_de| Ok(Box::new($type)),
            }
        }
    };
}
/// Registers a transform built from a deserializable configuration type.
///
/// * `$name`: key under which the factory is registered.
/// * `$config`: configuration type. `build_default` uses `<$config>::default()`;
///   `build` deserializes it through `erased_serde` and wraps failures as
///   "deserializing transform config: ...".
/// * `$builder`: turns a `$config` into the boxed transform. It is bound to a
///   plain `fn` pointer, so it must be a function item or a non-capturing closure.
///
/// NOTE(review): `erased_serde` is referenced by a bare path rather than through
/// `$crate`, so crates invoking this macro presumably need `erased_serde` as a
/// direct dependency. Confirm.
#[macro_export]
macro_rules! register_model_transform {
    ($name:expr, $config:ty, $builder:expr) => {
        $crate::internal::inventory::submit! {
            $crate::transform::ModelTransformFactory {
                name: $name,
                build_default: || {
                    let config = <$config>::default();
                    let builder: fn($config) -> $crate::prelude::TractResult<Box<dyn $crate::transform::ModelTransform>> = $builder;
                    builder(config)
                },
                build: |de: &mut dyn erased_serde::Deserializer| {
                    let config: $config = erased_serde::deserialize(de)
                        .map_err(|e| $crate::internal::anyhow!("deserializing transform config: {e}"))?;
                    let builder: fn($config) -> $crate::prelude::TractResult<Box<dyn $crate::transform::ModelTransform>> = $builder;
                    builder(config)
                },
            }
        }
    };
}
/// Splits a transform spec such as `f32-to-f16(include: ["conv"])` into its
/// normalized name and its raw parameter suffix.
///
/// The name is everything before the first `(`, with hyphens replaced by
/// underscores so that `as-blas` and `as_blas` designate the same transform.
/// The parameter part starts at that `(` (inclusive) and is returned untouched,
/// or is `""` when the spec has no parameters. The name is only allocated when
/// a hyphen actually has to be rewritten.
pub fn split_spec(spec: &str) -> (Cow<'_, str>, &str) {
    let (name, params) = match spec.find('(') {
        Some(pos) => spec.split_at(pos),
        None => (spec, ""),
    };
    // Normalize the name even when parameters follow; otherwise `a-b(...)` could
    // never match a factory registered as `a_b`.
    let name = if name.contains('-') {
        Cow::Owned(name.replace('-', "_"))
    } else {
        Cow::Borrowed(name)
    };
    (name, params)
}
/// Builds the registered transform named by `name` with its default parameters.
///
/// `name` may be a full spec; only its normalized name part (see `split_spec`)
/// is used, and any parameter suffix is ignored. Returns `Ok(None)` when no
/// factory is registered under that name.
///
/// # Errors
/// Propagates the factory's `build_default` error.
pub fn get_transform(name: &str) -> TractResult<Option<Box<dyn ModelTransform>>> {
    let wanted = split_spec(name).0;
    for factory in inventory::iter::<ModelTransformFactory>() {
        if factory.name != &*wanted {
            continue;
        }
        return (factory.build_default)().map(Some);
    }
    Ok(None)
}
pub fn get_transform_with_params(
name: &str,
de: &mut dyn erased_serde::Deserializer,
) -> TractResult<Option<Box<dyn ModelTransform>>> {
for factory in inventory::iter::<ModelTransformFactory>() {
if factory.name == name {
return Ok(Some((factory.build)(de)?));
}
}
Ok(None)
}
/// Parameters of the `concretize_symbols` transform.
#[derive(Default, Debug, serde::Deserialize)]
pub struct ConcretizeSymbolsConfig {
    /// Symbol name mapped to the integer value it is given.
    pub values: std::collections::HashMap<String, i64>,
}

/// Replaces the model's symbolic dimensions by the values of its configuration.
#[derive(Debug)]
struct ConcretizeSymbolsTransform(ConcretizeSymbolsConfig);

impl ModelTransform for ConcretizeSymbolsTransform {
    fn name(&self) -> StaticName {
        "concretize_symbols".into()
    }

    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
        let bindings = self
            .0
            .values
            .iter()
            .fold(SymbolValues::default(), |acc, (sym_name, value)| {
                acc.with(&model.symbols.sym(sym_name), *value)
            });
        *model = model.concretize_dims(&bindings)?;
        Ok(())
    }
}
// `concretize_symbols` is configured by `ConcretizeSymbolsConfig` (symbol name -> value).
register_model_transform!("concretize_symbols", ConcretizeSymbolsConfig, |config| Ok(Box::new(
    ConcretizeSymbolsTransform(config)
)));
// Parameterless transforms: `register_simple_model_transform!` ignores any supplied parameters.
register_simple_model_transform!("softmax_fast_compact", SoftmaxFastCompact);
// Only registered when the crate is built with the `blas` feature.
#[cfg(feature = "blas")]
register_simple_model_transform!("as_blas", AsBlas);
register_simple_model_transform!("block_quant", BlockQuantTransform);
/// Parameters of the `select_outputs` transform.
#[derive(Default, Debug, serde::Deserialize)]
pub struct SelectOutputsConfig {
    /// Output names handed to `TypedModel::select_outputs_by_name`.
    pub outputs: Vec<String>,
}

/// Re-targets the model outputs to the configured names.
#[derive(Debug)]
struct SelectOutputsTransform(SelectOutputsConfig);

impl ModelTransform for SelectOutputsTransform {
    fn name(&self) -> StaticName {
        "select_outputs".into()
    }

    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
        let SelectOutputsConfig { outputs } = &self.0;
        model.select_outputs_by_name(outputs.iter())
    }
}
// `select_outputs` is configured by `SelectOutputsConfig` (list of output names).
register_model_transform!("select_outputs", SelectOutputsConfig, |config| Ok(Box::new(
    SelectOutputsTransform(config)
)));
// `f32_to_f16`: float translator from F32 to F16. Parameters follow
// `FloatTranslatorConfig` (explicit `include` / `exclude` lists or a legacy
// `filter` string); without parameters no node filter is applied.
inventory::submit! {
    ModelTransformFactory {
        name: "f32_to_f16",
        build_default: || Ok(build_float_translator(DatumType::F32, DatumType::F16, NodeFilter::default())),
        build: |de| {
            let config: FloatTranslatorConfig = erased_serde::deserialize(de)
                .map_err(|e| anyhow::anyhow!("deserializing f32_to_f16 config: {e}"))?;
            Ok(build_float_translator(DatumType::F32, DatumType::F16, config.into_node_filter()?))
        },
    }
}
// `f16_to_f32`: float translator from F16 to F32. Parameters follow
// `FloatTranslatorConfig` (explicit `include` / `exclude` lists or a legacy
// `filter` string); without parameters no node filter is applied.
inventory::submit! {
    ModelTransformFactory {
        name: "f16_to_f32",
        build_default: || Ok(build_float_translator(DatumType::F16, DatumType::F32, NodeFilter::default())),
        build: |de| {
            let config: FloatTranslatorConfig = erased_serde::deserialize(de)
                .map_err(|e| anyhow::anyhow!("deserializing f16_to_f32 config: {e}"))?;
            Ok(build_float_translator(DatumType::F16, DatumType::F32, config.into_node_filter()?))
        },
    }
}
// `float_precision`: generic float translator configured by `FloatPrecisionConfig`.
// `from` / `to` are mandatory, so building it without parameters fails.
inventory::submit! {
    ModelTransformFactory {
        name: "float_precision",
        build_default: || {
            anyhow::bail!("float_precision transform requires 'from' and 'to' parameters")
        },
        build: |de| {
            let config: FloatPrecisionConfig = erased_serde::deserialize(de)
                .map_err(|e| anyhow::anyhow!("deserializing float_precision config: {e}"))?;
            let from_dt: DatumType = config
                .from
                .parse()
                .map_err(|e| anyhow::anyhow!("parsing 'from' datum type: {e}"))?;
            let to_dt: DatumType = config
                .to
                .parse()
                .map_err(|e| anyhow::anyhow!("parsing 'to' datum type: {e}"))?;
            // `from`/`to` come from user configuration: reject non-float datum types
            // up front instead of building a meaningless precision translator.
            anyhow::ensure!(
                from_dt.is_float() && to_dt.is_float(),
                "float_precision only converts between float types, got {from_dt:?} -> {to_dt:?}"
            );
            let filter = NodeFilter { include: config.include, exclude: config.exclude };
            Ok(build_float_translator(from_dt, to_dt, filter))
        },
    }
}