#![cfg(feature = "tensorrt-int8")]
use std::sync::Arc;
use crate::error::TrtError;
/// Calibration algorithm reported by a [`Calibrator`] through `algorithm()`.
///
/// NOTE(review): the variants presumably mirror TensorRT's `CalibrationAlgoType`
/// (`kENTROPY_CALIBRATION_2`, `kENTROPY_CALIBRATION`, `kMINMAX_CALIBRATION`,
/// `kLEGACY_CALIBRATION`); confirm against the layer that consumes this enum.
///
/// `Hash` is derived so the value can key maps/sets (e.g. per-algorithm cache lookup).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CalibrationAlgo {
    /// Second-generation entropy calibration; the default for [`EntropyCalibrator`].
    EntropyV2,
    /// First-generation entropy calibration; selected by [`EntropyCalibrator::legacy`].
    Entropy,
    /// Min/max range calibration; reported by [`MinMaxCalibrator`].
    MinMax,
    /// Legacy calibration. None of the calibrators defined alongside this enum report it.
    Legacy,
}
/// Supplies calibration batches, and optionally a calibration cache, to whoever drives it.
///
/// Implementors must be `Send + Sync`. The cache hooks have "no cache" defaults:
/// `read_cache` yields `None` and `write_cache` discards what it is given.
pub trait Calibrator: Send + Sync {
    /// The calibration algorithm this calibrator requests.
    fn algorithm(&self) -> CalibrationAlgo;

    /// Produces the next batch of input bindings; `None` signals there are no more.
    fn next_batch(&mut self) -> Option<Vec<CalibrationBinding>>;

    /// Returns a stored calibration cache, if one is available.
    ///
    /// Default: never has one.
    fn read_cache(&self) -> Option<Vec<u8>> {
        Option::None
    }

    /// Accepts a calibration cache blob for storage.
    ///
    /// Default: the blob is ignored.
    fn write_cache(&mut self, blob: &[u8]) {
        let _ = blob;
    }
}
/// One named input binding within a calibration batch.
///
/// Derives `PartialEq`/`Eq`/`Hash` so batches returned by `next_batch` can be compared
/// and asserted on directly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CalibrationBinding {
    /// Name of the binding the buffer is attached to.
    pub name: String,
    /// Raw address of the binding's data, stored as an integer.
    ///
    /// NOTE(review): presumably a device pointer owned by the caller; this type does not
    /// manage its lifetime — confirm with the producer.
    pub device_ptr: u64,
    /// Size of the buffer at `device_ptr` (assumed to be in bytes, per the field name — confirm).
    pub bytes: usize,
}
/// Calibrator that replays a fixed list of batches and requests [`CalibrationAlgo::MinMax`].
#[derive(Debug)]
pub struct MinMaxCalibrator {
    /// Batches to hand out, in order.
    batches: Vec<Vec<CalibrationBinding>>,
    /// Index of the batch the next `next_batch` call will return.
    cursor: usize,
    /// Cache returned by `read_cache`; set via `with_cache` or `write_cache`.
    cache: Option<Vec<u8>>,
}
impl MinMaxCalibrator {
    /// Creates a calibrator that yields `batches` in order, starting without a cache.
    pub fn new(batches: Vec<Vec<CalibrationBinding>>) -> Self {
        let cursor = 0;
        let cache = None;
        Self {
            batches,
            cursor,
            cache,
        }
    }

    /// Seeds the calibration cache with `blob`, replacing any existing cache.
    pub fn with_cache(self, blob: Vec<u8>) -> Self {
        Self {
            cache: Some(blob),
            ..self
        }
    }

    /// Moves the calibrator into a shared, lockable `dyn Calibrator`.
    pub fn into_arc(self) -> Arc<parking_lot::Mutex<dyn Calibrator>> {
        let guarded = parking_lot::Mutex::new(self);
        Arc::new(guarded)
    }
}
impl Calibrator for MinMaxCalibrator {
    /// Always [`CalibrationAlgo::MinMax`].
    fn algorithm(&self) -> CalibrationAlgo {
        CalibrationAlgo::MinMax
    }

    /// Returns a copy of the batch under the cursor and advances it.
    ///
    /// Once the list is exhausted the cursor stops moving and every call returns `None`.
    fn next_batch(&mut self) -> Option<Vec<CalibrationBinding>> {
        let batch = self.batches.get(self.cursor)?.clone();
        self.cursor += 1;
        Some(batch)
    }

    /// Returns a copy of the stored cache, if any.
    fn read_cache(&self) -> Option<Vec<u8>> {
        self.cache.as_ref().cloned()
    }

    /// Replaces the stored cache with a copy of `blob`.
    fn write_cache(&mut self, blob: &[u8]) {
        self.cache = Some(Vec::from(blob));
    }
}
/// Calibrator that replays a fixed list of batches and requests entropy calibration
/// ([`CalibrationAlgo::EntropyV2`] by default, [`CalibrationAlgo::Entropy`] when `legacy`).
#[derive(Debug)]
pub struct EntropyCalibrator {
    /// Batches to hand out, in order.
    batches: Vec<Vec<CalibrationBinding>>,
    /// Index of the batch the next `next_batch` call will return.
    cursor: usize,
    /// Cache returned by `read_cache`; set via `write_cache`.
    cache: Option<Vec<u8>>,
    /// When true, report `Entropy` instead of `EntropyV2` (this is not `CalibrationAlgo::Legacy`).
    legacy: bool,
}
impl EntropyCalibrator {
pub fn new(batches: Vec<Vec<CalibrationBinding>>) -> Self {
Self {
batches,
cursor: 0,
cache: None,
legacy: false,
}
}
pub fn legacy(mut self) -> Self {
self.legacy = true;
self
}
}
impl Calibrator for EntropyCalibrator {
    /// [`CalibrationAlgo::Entropy`] when built with `legacy()`, else [`CalibrationAlgo::EntropyV2`].
    fn algorithm(&self) -> CalibrationAlgo {
        match self.legacy {
            true => CalibrationAlgo::Entropy,
            false => CalibrationAlgo::EntropyV2,
        }
    }

    /// Returns a copy of the batch under the cursor and advances it.
    ///
    /// Once exhausted, the cursor stays put and `None` is returned from then on.
    fn next_batch(&mut self) -> Option<Vec<CalibrationBinding>> {
        let index = self.cursor;
        let batch = self.batches.get(index).cloned();
        if batch.is_some() {
            self.cursor = index + 1;
        }
        batch
    }

    /// Returns a copy of the stored cache, if any.
    fn read_cache(&self) -> Option<Vec<u8>> {
        self.cache.as_deref().map(<[u8]>::to_vec)
    }

    /// Replaces the stored cache with a copy of `blob`.
    fn write_cache(&mut self, blob: &[u8]) {
        self.cache = Some(blob.to_owned());
    }
}
/// Calibrator gated behind the `tensorrt-fp8` feature; replays a fixed list of batches and
/// reports [`CalibrationAlgo::EntropyV2`].
///
/// Because this whole module is also gated on `tensorrt-int8`, the type exists only when
/// both features are enabled. It holds no calibration cache.
#[cfg(feature = "tensorrt-fp8")]
#[derive(Debug)]
pub struct Fp8Calibrator {
    /// Batches to hand out, in order.
    batches: Vec<Vec<CalibrationBinding>>,
    /// Index of the batch the next `next_batch` call will return.
    cursor: usize,
}
#[cfg(feature = "tensorrt-fp8")]
impl Fp8Calibrator {
    /// Creates a calibrator that yields `batches` in order, starting from the first.
    pub fn new(batches: Vec<Vec<CalibrationBinding>>) -> Self {
        Fp8Calibrator {
            cursor: 0,
            batches,
        }
    }
}
#[cfg(feature = "tensorrt-fp8")]
impl Calibrator for Fp8Calibrator {
    /// Always [`CalibrationAlgo::EntropyV2`].
    fn algorithm(&self) -> CalibrationAlgo {
        CalibrationAlgo::EntropyV2
    }

    /// Hands out batches in order, then `None` forever.
    ///
    /// The cache hooks are not overridden, so the trait's no-cache defaults apply.
    fn next_batch(&mut self) -> Option<Vec<CalibrationBinding>> {
        match self.batches.get(self.cursor) {
            Some(batch) => {
                self.cursor += 1;
                Some(batch.clone())
            }
            None => None,
        }
    }
}
/// Pulls every remaining batch out of `c` and returns the total number of bindings seen.
///
/// This counts bindings, not batches: a batch carrying two inputs contributes 2.
/// `C` may be unsized, so a `&mut dyn Calibrator` (for example, from locking the value
/// returned by `into_arc`) can be drained directly.
///
/// # Errors
///
/// Currently never returns `Err`; the `Result` is part of the public signature.
pub fn drain<C: Calibrator + ?Sized>(c: &mut C) -> Result<usize, TrtError> {
    let bindings: usize = std::iter::from_fn(|| c.next_batch())
        .map(|batch| batch.len())
        .sum();
    Ok(bindings)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` single-binding batches whose pointers and sizes differ per batch.
    fn fixture(n: usize) -> Vec<Vec<CalibrationBinding>> {
        (0..n)
            .map(|i| {
                vec![CalibrationBinding {
                    name: "input".into(),
                    device_ptr: 0xCAFE_0000 + i as u64,
                    bytes: 1024 * (i + 1),
                }]
            })
            .collect()
    }

    #[test]
    fn int8_minmax_calibrator_constructs() {
        let mut c = MinMaxCalibrator::new(fixture(3));
        assert_eq!(c.algorithm(), CalibrationAlgo::MinMax);
        assert_eq!(drain(&mut c).unwrap(), 3);
        assert!(c.next_batch().is_none());
        c.write_cache(&[0xAB, 0xCD]);
        assert_eq!(c.read_cache().as_deref(), Some(&[0xAB, 0xCD][..]));
    }

    #[test]
    fn entropy_v2_default_and_legacy() {
        let c = EntropyCalibrator::new(fixture(1));
        assert_eq!(c.algorithm(), CalibrationAlgo::EntropyV2);
        let c = EntropyCalibrator::new(fixture(1)).legacy();
        assert_eq!(c.algorithm(), CalibrationAlgo::Entropy);
    }

    #[test]
    fn entropy_calibrator_yields_in_order_and_stores_cache() {
        let mut c = EntropyCalibrator::new(fixture(2));
        assert!(c.read_cache().is_none());
        let first = c.next_batch().expect("first batch");
        assert_eq!(first[0].device_ptr, 0xCAFE_0000);
        let second = c.next_batch().expect("second batch");
        assert_eq!(second[0].bytes, 2048);
        // Exhaustion is sticky.
        assert!(c.next_batch().is_none());
        assert!(c.next_batch().is_none());
        c.write_cache(&[1, 2, 3]);
        assert_eq!(c.read_cache().as_deref(), Some(&[1u8, 2, 3][..]));
    }

    #[test]
    fn minmax_with_cache_through_shared_trait_object() {
        let shared = MinMaxCalibrator::new(fixture(2))
            .with_cache(vec![7])
            .into_arc();
        let mut guard = shared.lock();
        assert_eq!(guard.algorithm(), CalibrationAlgo::MinMax);
        assert_eq!(guard.read_cache(), Some(vec![7]));
        assert!(guard.next_batch().is_some());
        assert!(guard.next_batch().is_some());
        assert!(guard.next_batch().is_none());
    }

    #[test]
    fn drain_counts_bindings_not_batches() {
        let binding = CalibrationBinding {
            name: "input".into(),
            device_ptr: 0xCAFE_0000,
            bytes: 16,
        };
        let batches = vec![vec![binding.clone(), binding.clone()], vec![binding]];
        let mut c = MinMaxCalibrator::new(batches);
        assert_eq!(drain(&mut c).unwrap(), 3);
    }

    #[test]
    fn drain_on_empty_calibrator_is_zero() {
        let mut c = MinMaxCalibrator::new(Vec::new());
        assert_eq!(drain(&mut c).unwrap(), 0);
        assert!(c.next_batch().is_none());
    }

    #[cfg(feature = "tensorrt-fp8")]
    #[test]
    fn fp8_calibrator_uses_entropy_v2() {
        let c = Fp8Calibrator::new(fixture(2));
        assert_eq!(c.algorithm(), CalibrationAlgo::EntropyV2);
    }
}