use std::rc::Rc;
use crate::autograd::Variable;
use crate::nn::{Linear, Module};
use crate::tensor::{Device, Result, Tensor};
/// Parameter-free halting score computed from a fixed threshold.
///
/// Its [`Module`] implementation scores an input as `max(input) - threshold`.
/// A positive score means the largest element is above `threshold`.
pub struct ThresholdHalt {
    /// Value subtracted from the input's largest element.
    threshold: f32,
}

impl ThresholdHalt {
    /// Creates a scorer that compares inputs against `threshold`.
    pub fn new(threshold: f32) -> Self {
        Self { threshold }
    }
}
impl Module for ThresholdHalt {
    fn name(&self) -> &str {
        "threshold_halt"
    }

    /// Returns a one-element tensor holding `max(input) - threshold`.
    ///
    /// The input is read back as host-side `f32` values via `to_f32_vec`.
    /// `f32::max` returns the non-NaN operand, so NaN elements are ignored.
    /// An input with no non-NaN elements (including an empty one) therefore
    /// scores `-inf - threshold`. The result is placed on the input's device.
    /// It is built from a plain `f32`, so it carries no autograd history from
    /// `input`.
    ///
    /// # Errors
    /// Propagates errors from reading the input data or creating the output tensor.
    fn forward(&self, input: &Variable) -> Result<Variable> {
        let values = input.data().to_f32_vec()?;

        // Seed with -inf so that an empty input still produces a defined score.
        let mut peak = f32::NEG_INFINITY;
        for v in values.iter().copied() {
            peak = peak.max(v);
        }

        let score = peak - self.threshold;
        let tensor = Tensor::from_f32(&[score], &[1], input.device())?;
        // NOTE(review): `false` presumably means "does not require grad"; confirm against `Variable::new`.
        Ok(Variable::new(tensor, false))
    }
}
/// Learned halting score: a single linear projection of the input.
pub struct LearnedHalt {
    /// Projection built with `Linear::on_device(input_dim, 1, device)`.
    /// It presumably maps `input_dim` features to one output; confirm against `Linear`.
    proj: Rc<Linear>,
}

impl LearnedHalt {
    /// Creates a `LearnedHalt` by calling [`LearnedHalt::on_device`] with `Device::CPU`.
    ///
    /// # Errors
    /// Propagates any error from constructing the underlying [`Linear`] layer.
    pub fn new(input_dim: i64) -> Result<Self> {
        LearnedHalt::on_device(input_dim, Device::CPU)
    }

    /// Creates a `LearnedHalt` whose projection is constructed for `device`.
    ///
    /// # Errors
    /// Propagates any error from constructing the underlying [`Linear`] layer.
    pub fn on_device(input_dim: i64, device: Device) -> Result<Self> {
        let linear = Linear::on_device(input_dim, 1, device)?;
        Ok(Self {
            proj: Rc::new(linear),
        })
    }
}
impl Module for LearnedHalt {
    fn name(&self) -> &str {
        "learned_halt"
    }

    /// Runs the projection on `input` and returns its output as-is.
    ///
    /// # Errors
    /// Propagates any error from the projection's forward pass.
    fn forward(&self, input: &Variable) -> Result<Variable> {
        self.proj.forward(input)
    }

    /// Reports the projection as this module's only child.
    fn sub_modules(&self) -> Vec<Rc<dyn Module>> {
        // Cloning the `Rc` bumps the refcount; the child layer is shared, not copied.
        let proj: Rc<dyn Module> = self.proj.clone();
        vec![proj]
    }
}