use std::sync::OnceLock;
use torsh_core::Result as TorshResult;
use torsh_tensor::Tensor;
/// A scalar function sampled on a uniform grid, used to approximate the
/// function by table lookup instead of evaluating it directly.
///
/// Sample `i` holds `f(min_val + i * step)` for `i` in `0..table_size`.
#[derive(Debug, Clone)]
pub struct ActivationLookupTable {
    /// Input value of the first sample; inputs at or below it map to `values[0]`.
    pub min_val: f32,
    /// Upper end of the sampled range; inputs at or above it map to the last sample.
    pub max_val: f32,
    /// Number of samples stored in `values`.
    pub table_size: usize,
    /// Distance between two neighbouring sample points.
    pub step: f32,
    /// The sampled function values, in increasing input order.
    pub values: Vec<f32>,
}

impl ActivationLookupTable {
    /// Builds a table by evaluating `f` at `table_size` evenly spaced points
    /// starting at `min_val` and ending (up to rounding) at `max_val`.
    ///
    /// A table needs at least two samples to define a step, so `table_size`
    /// values of 0 or 1 are raised to 2. Without this, `table_size - 1` would
    /// underflow for 0 and the step would be a division by zero for 1.
    pub fn new<F>(min_val: f32, max_val: f32, table_size: usize, f: F) -> Self
    where
        F: Fn(f32) -> f32,
    {
        let table_size = table_size.max(2);
        let step = (max_val - min_val) / (table_size - 1) as f32;
        let values: Vec<f32> = (0..table_size)
            .map(|i| f(min_val + i as f32 * step))
            .collect();
        Self {
            min_val,
            max_val,
            table_size,
            step,
            values,
        }
    }

    /// Approximates the tabulated function at `x` by linear interpolation
    /// between the two surrounding samples.
    ///
    /// Inputs outside `[min_val, max_val]` are clamped to the first or last
    /// sample. A NaN input produces NaN.
    ///
    /// # Panics
    ///
    /// Panics if `values` holds fewer than `table_size` entries, which can only
    /// happen when the public fields were modified after construction.
    #[allow(dead_code)]
    pub fn lookup(&self, x: f32) -> f32 {
        // Explicit so the result does not depend on NaN arithmetic with table entries.
        if x.is_nan() {
            return x;
        }
        let last = self.table_size - 1;
        if x <= self.min_val {
            return self.values[0];
        }
        if x >= self.max_val {
            return self.values[last];
        }
        let position = (x - self.min_val) / self.step;
        let index = position as usize;
        // Rounding in `position` may land exactly on the final sample.
        if index >= last {
            return self.values[last];
        }
        let frac = position - index as f32;
        let v0 = self.values[index];
        let v1 = self.values[index + 1];
        v0 + frac * (v1 - v0)
    }

    /// Approximates the tabulated function at `x` by returning the nearest
    /// sample, without interpolation.
    ///
    /// Inputs outside `[min_val, max_val]` are clamped to the first or last
    /// sample. A NaN input is returned unchanged; previously the float-to-int
    /// cast turned NaN into index 0 and silently returned `values[0]`.
    ///
    /// # Panics
    ///
    /// Panics if `values` holds fewer than `table_size` entries, which can only
    /// happen when the public fields were modified after construction.
    #[allow(dead_code)]
    pub fn lookup_fast(&self, x: f32) -> f32 {
        if x.is_nan() {
            return x;
        }
        let last = self.table_size - 1;
        if x <= self.min_val {
            return self.values[0];
        }
        if x >= self.max_val {
            return self.values[last];
        }
        // Adding 0.5 before truncation rounds to the nearest sample index.
        let nearest = ((x - self.min_val) / self.step + 0.5) as usize;
        self.values[nearest.min(last)]
    }
}
/// Shared sigmoid table, built on first use by `init_sigmoid_table`.
static SIGMOID_TABLE: OnceLock<ActivationLookupTable> = OnceLock::new();
/// Shared tanh table, built on first use by `init_tanh_table`.
static TANH_TABLE: OnceLock<ActivationLookupTable> = OnceLock::new();
/// Shared softplus table, built on first use by `init_softplus_table`.
static SOFTPLUS_TABLE: OnceLock<ActivationLookupTable> = OnceLock::new();
/// Shared exponential table, built on first use by `init_exp_table`.
#[allow(dead_code)]
static EXP_TABLE: OnceLock<ActivationLookupTable> = OnceLock::new();
/// Options controlling the table-based activation functions.
///
/// The shared tables in this file are built once from `LookupConfig::default()`.
/// The `*_lookup` functions here read only `use_interpolation` and
/// `size_threshold` from a caller-supplied config; `min_val`, `max_val` and
/// `table_size` take effect only through the default used to build the tables.
#[derive(Debug, Clone, PartialEq)]
pub struct LookupConfig {
    /// Lower end of the tabulated input range.
    pub min_val: f32,
    /// Upper end of the tabulated input range.
    pub max_val: f32,
    /// Number of samples in the table.
    pub table_size: usize,
    /// When `true`, use linear interpolation between samples; otherwise use
    /// the nearest sample.
    pub use_interpolation: bool,
    /// Minimum number of tensor elements for which the table path is taken.
    pub size_threshold: usize,
}

impl Default for LookupConfig {
    /// Range `[-10, 10]` with 10000 samples, interpolation enabled, and the
    /// table path used for tensors of at least 1000 elements.
    fn default() -> Self {
        Self {
            min_val: -10.0,
            max_val: 10.0,
            table_size: 10000,
            use_interpolation: true,
            size_threshold: 1000,
        }
    }
}
/// Returns the shared sigmoid table, building it from the default
/// `LookupConfig` on the first call.
fn init_sigmoid_table() -> &'static ActivationLookupTable {
    fn logistic(x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }

    SIGMOID_TABLE.get_or_init(|| {
        let defaults = LookupConfig::default();
        ActivationLookupTable::new(
            defaults.min_val,
            defaults.max_val,
            defaults.table_size,
            logistic,
        )
    })
}
/// Returns the shared tanh table, building it from the default
/// `LookupConfig` on the first call.
fn init_tanh_table() -> &'static ActivationLookupTable {
    TANH_TABLE.get_or_init(|| {
        let defaults = LookupConfig::default();
        ActivationLookupTable::new(
            defaults.min_val,
            defaults.max_val,
            defaults.table_size,
            f32::tanh,
        )
    })
}
/// Returns the shared softplus table (`ln(1 + e^x)`), building it from the
/// default `LookupConfig` on the first call.
fn init_softplus_table() -> &'static ActivationLookupTable {
    SOFTPLUS_TABLE.get_or_init(|| {
        let config = LookupConfig::default();
        ActivationLookupTable::new(config.min_val, config.max_val, config.table_size, |x| {
            // `ln_1p` keeps precision when `e^x` is tiny: forming `1.0 + e^x`
            // in f32 first discards most of the significant digits of `e^x`
            // at the negative end of the table.
            x.exp().ln_1p()
        })
    })
}
/// Returns the shared exponential table, building it from the default
/// `LookupConfig` on the first call.
#[allow(dead_code)]
fn init_exp_table() -> &'static ActivationLookupTable {
    EXP_TABLE.get_or_init(|| {
        let defaults = LookupConfig::default();
        ActivationLookupTable::new(
            defaults.min_val,
            defaults.max_val,
            defaults.table_size,
            f32::exp,
        )
    })
}
/// Element-wise sigmoid that switches to the shared lookup table for large
/// tensors.
///
/// Tensors with fewer than `config.size_threshold` elements are passed to
/// `Tensor::sigmoid`. Larger tensors are mapped through the sigmoid table,
/// interpolated or nearest-sample according to `config.use_interpolation`.
/// `None` selects `LookupConfig::default()`.
///
/// # Errors
///
/// Propagates any error from reading the tensor data, from `Tensor::sigmoid`,
/// or from building the result tensor.
#[allow(dead_code)]
pub fn sigmoid_lookup(input: &Tensor, config: Option<LookupConfig>) -> TorshResult<Tensor> {
    let options = config.unwrap_or_default();
    let elements = input.data()?;
    if elements.len() < options.size_threshold {
        return input.sigmoid();
    }

    let table = init_sigmoid_table();
    let mapped: Vec<f32> = if options.use_interpolation {
        elements.iter().copied().map(|v| table.lookup(v)).collect()
    } else {
        elements
            .iter()
            .copied()
            .map(|v| table.lookup_fast(v))
            .collect()
    };
    Tensor::from_data(mapped, input.shape().dims().to_vec(), input.device())
}
/// Element-wise tanh that switches to the shared lookup table for large
/// tensors.
///
/// Tensors with fewer than `config.size_threshold` elements are passed to
/// `Tensor::tanh`. Larger tensors are mapped through the tanh table,
/// interpolated or nearest-sample according to `config.use_interpolation`.
/// `None` selects `LookupConfig::default()`.
///
/// # Errors
///
/// Propagates any error from reading the tensor data, from `Tensor::tanh`,
/// or from building the result tensor.
#[allow(dead_code)]
pub fn tanh_lookup(input: &Tensor, config: Option<LookupConfig>) -> TorshResult<Tensor> {
    let options = config.unwrap_or_default();
    let elements = input.data()?;
    if elements.len() < options.size_threshold {
        return input.tanh();
    }

    let table = init_tanh_table();
    let mapped: Vec<f32> = if options.use_interpolation {
        elements.iter().copied().map(|v| table.lookup(v)).collect()
    } else {
        elements
            .iter()
            .copied()
            .map(|v| table.lookup_fast(v))
            .collect()
    };
    Tensor::from_data(mapped, input.shape().dims().to_vec(), input.device())
}
/// Element-wise softplus that uses the shared lookup table for large tensors
/// when `beta == 1.0`.
///
/// * `beta` - softplus sharpness. The table holds `ln(1 + e^x)` only, so any
///   other value delegates to `crate::activations::softplus`.
/// * `threshold` - inputs strictly greater than this are returned unchanged.
/// * `config` - `None` selects `LookupConfig::default()`. Tensors with fewer
///   than `size_threshold` elements also delegate to
///   `crate::activations::softplus`.
///
/// Inputs outside the table's range are evaluated directly rather than
/// clamped. Softplus is unbounded above, so clamping to the last table entry
/// would cap the result near `max_val` for every input between `max_val` and
/// `threshold`.
///
/// # Errors
///
/// Propagates any error from reading the tensor data, from the delegated
/// softplus, or from building the result tensor.
#[allow(dead_code)]
pub fn softplus_lookup(
    input: &Tensor,
    beta: f64,
    threshold: f64,
    config: Option<LookupConfig>,
) -> TorshResult<Tensor> {
    let config = config.unwrap_or_default();
    let data = input.data()?;
    if data.len() < config.size_threshold || beta != 1.0 {
        return crate::activations::softplus(input, beta, threshold);
    }

    let table = init_softplus_table();
    let linear_above = threshold as f32;
    let result_data: Vec<f32> = data
        .iter()
        .map(|&x| {
            if x > linear_above {
                x
            } else if x >= table.max_val {
                // softplus(x) = x + ln(1 + e^-x); stable for large positive x.
                x + (-x).exp().ln_1p()
            } else if x <= table.min_val {
                // Lower tail decays towards zero; evaluate it exactly.
                x.exp().ln_1p()
            } else if config.use_interpolation {
                table.lookup(x)
            } else {
                table.lookup_fast(x)
            }
        })
        .collect();
    Tensor::from_data(result_data, input.shape().dims().to_vec(), input.device())
}
/// Element-wise GELU that switches to a closed-form tanh approximation for
/// large tensors.
///
/// Despite the name, no table is consulted. Tensors with at least
/// `config.size_threshold` elements are evaluated with
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`; smaller tensors
/// are passed to `Tensor::gelu`. `None` selects `LookupConfig::default()`.
///
/// The previous coefficients (`1.2` and `0.44` applied to `x / sqrt(2)`)
/// deviated from erf by roughly 0.08 around `x = 1`, so results changed
/// noticeably depending on which side of the size threshold a tensor fell.
///
/// # Errors
///
/// Propagates any error from reading the tensor data, from `Tensor::gelu`,
/// or from building the result tensor.
#[allow(dead_code)]
pub fn gelu_lookup(input: &Tensor, config: Option<LookupConfig>) -> TorshResult<Tensor> {
    // Cubic coefficient of the standard tanh-based GELU approximation.
    const CUBIC_COEFF: f32 = 0.044715;

    let config = config.unwrap_or_default();
    let data = input.data()?;
    if data.len() < config.size_threshold {
        return input.gelu();
    }

    let sqrt_2_over_pi = (2.0 / std::f32::consts::PI).sqrt();
    let result_data: Vec<f32> = data
        .iter()
        .map(|&x| {
            // tanh saturates on its own, so no explicit cut-off branch is needed.
            let inner = sqrt_2_over_pi * (x + CUBIC_COEFF * x * x * x);
            0.5 * x * (1.0 + inner.tanh())
        })
        .collect();
    Tensor::from_data(result_data, input.shape().dims().to_vec(), input.device())
}
/// Element-wise swish (`x * sigmoid(x)`) that uses the shared sigmoid table
/// for large tensors.
///
/// Tensors with fewer than `config.size_threshold` elements are computed as
/// `input.sigmoid()` multiplied by `input`. Larger tensors read sigmoid from
/// the table, interpolated or nearest-sample according to
/// `config.use_interpolation`. `None` selects `LookupConfig::default()`.
///
/// Inputs outside the table's range evaluate sigmoid directly. A clamped
/// sigmoid is a fixed positive constant below `min_val`, so multiplying it by
/// `x` would make the result grow linearly negative instead of approaching 0.
///
/// # Errors
///
/// Propagates any error from reading the tensor data, from the tensor
/// operations on the small-tensor path, or from building the result tensor.
#[allow(dead_code)]
pub fn swish_lookup(input: &Tensor, config: Option<LookupConfig>) -> TorshResult<Tensor> {
    let config = config.unwrap_or_default();
    let data = input.data()?;
    if data.len() < config.size_threshold {
        let sigmoid_result = input.sigmoid()?;
        return input.mul_op(&sigmoid_result);
    }

    let table = init_sigmoid_table();
    let result_data: Vec<f32> = data
        .iter()
        .map(|&x| {
            let sigmoid_val = if x <= table.min_val || x >= table.max_val {
                1.0 / (1.0 + (-x).exp())
            } else if config.use_interpolation {
                table.lookup(x)
            } else {
                table.lookup_fast(x)
            };
            x * sigmoid_val
        })
        .collect();
    Tensor::from_data(result_data, input.shape().dims().to_vec(), input.device())
}
/// Sigmoid that picks an implementation from the tensor's element count.
///
/// Tensors with more than 10000 elements go through `sigmoid_lookup` with the
/// default configuration; all others use `Tensor::sigmoid`.
#[allow(dead_code)]
pub fn adaptive_sigmoid(input: &Tensor) -> TorshResult<Tensor> {
    if input.numel() <= 10000 {
        return input.sigmoid();
    }
    sigmoid_lookup(input, None)
}
/// Applies `activation_fn` to every element of `input`, splitting the work
/// into parallel chunks when the tensor is large enough.
///
/// * `activation_fn` - scalar function applied to each element.
/// * `chunk_size` - elements per parallel chunk. `None` or `Some(0)` selects
///   1000; a zero chunk length is rejected by `par_chunks`, so it is treated
///   as "use the default" rather than allowed to panic.
///
/// The parallel path is taken only when the tensor holds at least four
/// chunks' worth of elements; otherwise the elements are mapped sequentially.
///
/// # Errors
///
/// Propagates any error from reading the tensor data or from building the
/// result tensor.
#[allow(dead_code)]
pub fn parallel_activation<F>(
    input: &Tensor,
    activation_fn: F,
    chunk_size: Option<usize>,
) -> TorshResult<Tensor>
where
    F: Fn(f32) -> f32 + Send + Sync,
{
    use scirs2_core::parallel_ops::*;

    let data = input.data()?;
    let chunk_size = chunk_size.filter(|&size| size > 0).unwrap_or(1000);
    // Saturate so an enormous chunk size selects the sequential path instead
    // of overflowing the multiplication.
    let parallel_threshold = chunk_size.saturating_mul(4);

    let result_data: Vec<f32> = if data.len() >= parallel_threshold {
        data.par_chunks(chunk_size)
            .flat_map(|chunk| chunk.iter().map(|&x| activation_fn(x)).collect::<Vec<_>>())
            .collect()
    } else {
        data.iter().map(|&x| activation_fn(x)).collect()
    };
    Tensor::from_data(result_data, input.shape().dims().to_vec(), input.device())
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::rand;

    /// Interpolated lookups on a 1000-entry sigmoid table stay within 0.01 of
    /// the directly computed value at x = 0 and x = 2.
    #[test]
    fn test_sigmoid_lookup_table() {
        let sigmoid = |x: f32| 1.0 / (1.0 + (-x).exp());
        let table = ActivationLookupTable::new(-10.0, 10.0, 1000, sigmoid);
        for x in [0.0_f32, 2.0] {
            let expected = sigmoid(x);
            let actual = table.lookup(x);
            assert!((actual - expected).abs() < 0.01);
        }
    }

    /// Mapping a ReLU-like closure over a 1000-element tensor in chunks of 100
    /// keeps the tensor's shape.
    #[test]
    fn test_parallel_activation() {
        let input = rand(&[1000]).unwrap();
        let relu = |x: f32| x.max(0.0);
        let result = parallel_activation(&input, relu, Some(100)).unwrap();
        assert_eq!(input.shape(), result.shape());
    }
}