use std::collections::HashMap;
use std::sync::Arc;
use crate::dual::Dual;
use crate::float::Float;
use crate::opcode::{self, OpCode, UNUSED};
mod forward;
mod jacobian;
mod optimize;
mod reverse;
mod sparse;
mod tangent;
#[cfg(feature = "parallel")]
mod parallel;
#[cfg(feature = "serde")]
mod serde_support;
#[cfg(feature = "taylor")]
mod taylor;
mod thread_local;
pub use self::thread_local::{with_active_btape, BtapeGuard, BtapeThreadLocal};
/// Sentinel `u32` value, equal to `u32::MAX`.
///
/// NOTE(review): presumably used in place of a tape index to mark "this is a
/// constant, not a tape variable" — it is not referenced in this part of the
/// file, so confirm against its users.
pub const CONSTANT: u32 = u32::MAX;
/// A user-supplied operation that can be registered on a [`BytecodeTape`].
///
/// Implementors must be `Send + Sync`; the tape stores them behind
/// `Arc<dyn CustomOp<F>>`.
pub trait CustomOp<F: Float>: Send + Sync {
    /// Evaluates the operation on the operand values `a` and `b`.
    ///
    /// NOTE(review): what is passed as `b` for ops recorded through
    /// `push_custom_unary` is decided by the evaluators, which are not visible
    /// here — confirm before relying on it.
    fn eval(&self, a: F, b: F) -> F;

    /// Returns the partial derivatives of the operation as `(d/da, d/db)`.
    ///
    /// `result` is the value previously produced by [`CustomOp::eval`] for the
    /// same operands, so implementations can reuse it.
    fn partials(&self, a: F, b: F, result: F) -> (F, F);

    /// Evaluates the operation on dual numbers.
    ///
    /// The default evaluates on the `re` parts and forms the `eps` part as
    /// `d/da * a.eps + d/db * b.eps` using [`CustomOp::partials`].
    fn eval_dual(&self, a: Dual<F>, b: Dual<F>) -> Dual<F> {
        let primal = self.eval(a.re, b.re);
        let (wrt_a, wrt_b) = self.partials(a.re, b.re, primal);
        let tangent = wrt_a * a.eps + wrt_b * b.eps;
        Dual::new(primal, tangent)
    }

    /// Returns the partial derivatives as dual numbers.
    ///
    /// The default only looks at the `re` parts of its arguments and wraps
    /// each real-valued partial with `Dual::constant`; the `eps` parts of `a`,
    /// `b` and `result` are ignored.
    ///
    /// NOTE(review): assuming `Dual::constant` carries a zero `eps`, this
    /// default treats the partials as input-independent. Ops whose partials
    /// vary with the operands likely need to override it — confirm.
    fn partials_dual(&self, a: Dual<F>, b: Dual<F>, result: Dual<F>) -> (Dual<F>, Dual<F>) {
        let (wrt_a, wrt_b) = self.partials(a.re, b.re, result.re);
        (Dual::constant(wrt_a), Dual::constant(wrt_b))
    }
}
/// Handle to a custom operation registered on a [`BytecodeTape`].
///
/// Wraps the `u16` position the op occupies in the tape's `custom_ops` list,
/// as handed out by `BytecodeTape::register_custom`.
#[derive(Clone, Copy, Debug)]
pub struct CustomOpHandle(pub(crate) u16);
/// A recorded sequence of operations stored as parallel arrays.
///
/// Every recorded slot (input, constant or operation) appends one entry to
/// each of `opcodes`, `arg_indices` and `values`; the slot's index is the
/// `u32` returned by the `new_input` / `push_*` methods.
pub struct BytecodeTape<F: Float> {
/// Opcode of each slot (`OpCode::Input`, `OpCode::Const`, or the recorded op).
pub(crate) opcodes: Vec<OpCode>,
/// Operand pair of each slot. Inputs and constants store `[UNUSED, UNUSED]`;
/// `Powi` slots store the encoded exponent in the second entry; `Custom`
/// slots store the custom-op handle value in the second entry.
pub(crate) arg_indices: Vec<[u32; 2]>,
/// Value recorded for each slot at the time it was pushed.
pub(crate) values: Vec<F>,
/// Number of slots created through `new_input`.
pub(crate) num_inputs: u32,
/// Number of slots recorded so far; also the index the next slot will get.
pub(crate) num_variables: u32,
/// Slot index of the (primary) output, set by `set_output` or taken from the
/// first entry passed to `set_outputs`.
pub(crate) output_index: u32,
/// Slot indices installed by `set_outputs`; empty means only `output_index`
/// is in use.
pub(crate) output_indices: Vec<u32>,
/// Custom operations registered through `register_custom`, indexed by the
/// value inside `CustomOpHandle`.
pub(crate) custom_ops: Vec<Arc<dyn CustomOp<F>>>,
/// For slots recorded by `push_custom_binary`: maps the slot index to the
/// index of its second operand.
pub(crate) custom_second_args: HashMap<u32, u32>,
}
impl<F: Float> BytecodeTape<F> {
#[must_use]
pub fn new() -> Self {
BytecodeTape {
opcodes: Vec::new(),
arg_indices: Vec::new(),
values: Vec::new(),
num_inputs: 0,
num_variables: 0,
output_index: 0,
output_indices: Vec::new(),
custom_ops: Vec::new(),
custom_second_args: HashMap::new(),
}
}
#[must_use]
pub fn with_capacity(est_ops: usize) -> Self {
BytecodeTape {
opcodes: Vec::with_capacity(est_ops),
arg_indices: Vec::with_capacity(est_ops),
values: Vec::with_capacity(est_ops),
num_inputs: 0,
num_variables: 0,
output_index: 0,
output_indices: Vec::new(),
custom_ops: Vec::new(),
custom_second_args: HashMap::new(),
}
}
/// Records a new input slot holding `value` and returns its index.
///
/// The slot gets `OpCode::Input` with both operand entries set to `UNUSED`,
/// and the input counter is bumped.
#[inline]
pub fn new_input(&mut self, value: F) -> u32 {
    debug_assert!(
        self.num_variables < u32::MAX,
        "tape variable count overflow"
    );
    let slot = self.num_variables;
    self.opcodes.push(OpCode::Input);
    self.arg_indices.push([UNUSED; 2]);
    self.values.push(value);
    self.num_inputs += 1;
    self.num_variables = slot + 1;
    slot
}
/// Records a constant slot holding `value` and returns its index.
///
/// The slot gets `OpCode::Const` with both operand entries set to `UNUSED`.
#[inline]
pub fn push_const(&mut self, value: F) -> u32 {
    debug_assert!(
        self.num_variables < u32::MAX,
        "tape variable count overflow"
    );
    let slot = self.num_variables;
    self.opcodes.push(OpCode::Const);
    self.arg_indices.push([UNUSED; 2]);
    self.values.push(value);
    self.num_variables = slot + 1;
    slot
}
/// Records the operation `op(arg0, arg1)` with result `value` and returns the
/// index of the slot that holds the result.
///
/// Before a new slot is appended, three shortcuts are tried in order:
/// 1. every operand is a constant: a constant slot holding `value` is pushed;
/// 2. binary op with exactly one constant operand: identity/zero rewrites via
///    `try_algebraic_simplify`;
/// 3. binary op with `arg0 == arg1`: `try_same_index_simplify`.
///
/// A shortcut may return an existing index (for example `arg0`) instead of a
/// new one. Pass `UNUSED` as `arg1` for unary operations.
///
/// # Panics
/// Panics if `arg0`, or `arg1` when it is not `UNUSED`, is not a valid slot
/// index.
#[inline]
pub fn push_op(&mut self, op: OpCode, arg0: u32, arg1: u32, value: F) -> u32 {
    let is_binary = arg1 != UNUSED;
    let lhs_const = self.opcodes[arg0 as usize] == OpCode::Const;
    let rhs_const = !is_binary || self.opcodes[arg1 as usize] == OpCode::Const;

    // Nothing depends on a variable: fold the whole op into a constant.
    if lhs_const && rhs_const {
        return self.push_const(value);
    }

    if is_binary {
        if lhs_const || rhs_const {
            if let Some(existing) =
                self.try_algebraic_simplify(op, arg0, arg1, lhs_const, rhs_const, value)
            {
                return existing;
            }
        }
        if arg0 == arg1 {
            if let Some(existing) = self.try_same_index_simplify(op, value) {
                return existing;
            }
        }
    }

    debug_assert!(
        self.num_variables < u32::MAX,
        "tape variable count overflow"
    );
    let slot = self.num_variables;
    self.opcodes.push(op);
    self.arg_indices.push([arg0, arg1]);
    self.values.push(value);
    self.num_variables = slot + 1;
    slot
}
#[inline(never)]
fn try_algebraic_simplify(
&mut self,
op: OpCode,
arg0: u32,
arg1: u32,
arg0_const: bool,
arg1_const: bool,
value: F,
) -> Option<u32> {
let zero = F::zero();
let one = F::one();
match op {
OpCode::Add => {
if arg1_const && self.values[arg1 as usize] == zero {
return Some(arg0);
}
if arg0_const && self.values[arg0 as usize] == zero {
return Some(arg1);
}
}
OpCode::Sub if arg1_const && self.values[arg1 as usize] == zero => {
return Some(arg0);
}
OpCode::Mul => {
if arg1_const && self.values[arg1 as usize] == one {
return Some(arg0);
}
if arg0_const && self.values[arg0 as usize] == one {
return Some(arg1);
}
if arg1_const && self.values[arg1 as usize] == zero && value == zero {
return Some(self.push_const(value));
}
if arg0_const && self.values[arg0 as usize] == zero && value == zero {
return Some(self.push_const(value));
}
}
OpCode::Div if arg1_const && self.values[arg1 as usize] == one => {
return Some(arg0);
}
_ => {}
}
None
}
/// Tries to fold a binary op whose two operands are the same slot.
///
/// `x - x` is folded when the recorded `value` is zero, and `x / x` when it is
/// one. In both cases a constant slot holding `value` is pushed and its index
/// returned. Any other op, or a `value` that does not match, yields `None`.
#[inline(never)]
fn try_same_index_simplify(&mut self, op: OpCode, value: F) -> Option<u32> {
    let foldable = match op {
        OpCode::Sub => value == F::zero(),
        OpCode::Div => value == F::one(),
        _ => false,
    };
    if foldable {
        Some(self.push_const(value))
    } else {
        None
    }
}
/// Records `arg0` raised to the integer power `exp`, with result `value`, and
/// returns the index of the slot that holds the result.
///
/// Shortcuts, tried in order:
/// - constant base: a constant slot holding `value` is pushed;
/// - `exp == 0` and `value` is one: a constant slot holding one is pushed;
/// - `exp == 1`: `arg0` itself is returned;
/// - `exp == -1`: recorded as `OpCode::Recip` through `push_op`.
///
/// Otherwise an `OpCode::Powi` slot is appended whose second operand entry is
/// the exponent encoded by `opcode::powi_exp_encode`.
///
/// # Panics
/// Panics if `arg0` is not a valid slot index. In debug builds, also panics if
/// the tape cannot hold another slot.
#[inline]
pub fn push_powi(&mut self, arg0: u32, exp: i32, value: F) -> u32 {
    if self.opcodes[arg0 as usize] == OpCode::Const {
        return self.push_const(value);
    }
    if exp == 0 && value == F::one() {
        return self.push_const(F::one());
    }
    if exp == 1 {
        return arg0;
    }
    if exp == -1 {
        return self.push_op(OpCode::Recip, arg0, UNUSED, value);
    }
    // Same guard as `new_input` / `push_const` / `push_op`: a slot index must
    // never reach `u32::MAX`.
    debug_assert!(
        self.num_variables < u32::MAX,
        "tape variable count overflow"
    );
    let idx = self.num_variables;
    self.num_variables += 1;
    self.opcodes.push(OpCode::Powi);
    self.arg_indices.push([arg0, opcode::powi_exp_encode(exp)]);
    self.values.push(value);
    idx
}
/// Registers a custom operation on this tape and returns a handle to it.
///
/// The handle wraps the position of `op` in the tape's `custom_ops` list.
///
/// # Panics
/// Panics with "too many custom ops" if that position does not fit in a `u16`.
pub fn register_custom(&mut self, op: Arc<dyn CustomOp<F>>) -> CustomOpHandle {
    let slot = self.custom_ops.len();
    assert!(slot <= u16::MAX as usize, "too many custom ops");
    let handle = CustomOpHandle(slot as u16);
    self.custom_ops.push(op);
    handle
}
/// Records a one-operand custom operation and returns the index of its slot.
///
/// The slot gets `OpCode::Custom`; its operand pair is `[arg0, handle]`, where
/// the second entry is the raw value of `handle`, not a slot index. `value` is
/// stored as the slot's recorded result. No constant folding is attempted.
///
/// # Panics
/// In debug builds, panics if the tape cannot hold another slot.
#[inline]
pub fn push_custom_unary(&mut self, arg0: u32, handle: CustomOpHandle, value: F) -> u32 {
    // Same guard as `new_input` / `push_const` / `push_op`: a slot index must
    // never reach `u32::MAX`.
    debug_assert!(
        self.num_variables < u32::MAX,
        "tape variable count overflow"
    );
    let idx = self.num_variables;
    self.num_variables += 1;
    self.opcodes.push(OpCode::Custom);
    self.arg_indices.push([arg0, handle.0 as u32]);
    self.values.push(value);
    idx
}
/// Records a two-operand custom operation and returns the index of its slot.
///
/// The slot gets `OpCode::Custom`; its operand pair is `[arg0, handle]`, where
/// the second entry is the raw value of `handle`. Because that entry is taken
/// by the handle, the second operand `arg1` is kept in `custom_second_args`,
/// keyed by the new slot's index. `value` is stored as the slot's recorded
/// result. No constant folding is attempted.
///
/// # Panics
/// In debug builds, panics if the tape cannot hold another slot.
#[inline]
pub fn push_custom_binary(
    &mut self,
    arg0: u32,
    arg1: u32,
    handle: CustomOpHandle,
    value: F,
) -> u32 {
    // Same guard as `new_input` / `push_const` / `push_op`: a slot index must
    // never reach `u32::MAX`.
    debug_assert!(
        self.num_variables < u32::MAX,
        "tape variable count overflow"
    );
    let idx = self.num_variables;
    self.num_variables += 1;
    self.opcodes.push(OpCode::Custom);
    self.arg_indices.push([arg0, handle.0 as u32]);
    self.custom_second_args.insert(idx, arg1);
    self.values.push(value);
    idx
}
/// Selects `index` as the tape's single output slot.
///
/// Only `output_index` is written. `index` is not range-checked here, so an
/// out-of-range value makes `output_value` panic later. A non-empty
/// `output_indices` list installed earlier by `set_outputs` is left as it was,
/// so `num_outputs`, `output_values` and `all_output_indices` keep reporting
/// that list.
#[inline]
pub fn set_output(&mut self, index: u32) {
self.output_index = index;
}
/// Returns the recorded value of the slot at `output_index`.
///
/// # Panics
/// Panics if `output_index` is not a valid slot index (for example on an empty
/// tape).
#[inline]
#[must_use]
pub fn output_value(&self) -> F {
self.values[self.output_index as usize]
}
/// Returns the slot index currently stored as the (primary) output.
#[inline]
#[must_use]
pub fn output_index(&self) -> usize {
self.output_index as usize
}
/// Returns how many input slots have been created through `new_input`.
#[inline]
#[must_use]
pub fn num_inputs(&self) -> usize {
self.num_inputs as usize
}
/// Returns the number of recorded slots (the length of the opcode array).
///
/// Input and constant slots are included, since each of them also pushes an
/// opcode.
#[inline]
#[must_use]
pub fn num_ops(&self) -> usize {
self.opcodes.len()
}
/// Installs `indices` as the tape's list of output slots.
///
/// Every index is validated before anything is modified. On success the list
/// replaces `output_indices`, and the first entry (if any) also becomes
/// `output_index`. With an empty slice `output_index` is left unchanged.
///
/// # Panics
/// Panics if any entry of `indices` is not a valid slot index.
pub fn set_outputs(&mut self, indices: &[u32]) {
    let tape_len = self.values.len();
    indices.iter().enumerate().for_each(|(pos, &target)| {
        assert!(
            (target as usize) < tape_len,
            "set_outputs: indices[{}] = {} is out of range (tape has \
             {} values). Indices must point to tape variables created \
             via new_input/push_op/push_const.",
            pos,
            target,
            tape_len
        );
    });

    self.output_indices = indices.to_vec();
    if let [head, ..] = indices {
        self.output_index = *head;
    }
}
/// Returns the number of outputs: the length of `output_indices`, or 1 when
/// that list is empty and only `output_index` is in use.
#[must_use]
pub fn num_outputs(&self) -> usize {
    match self.output_indices.len() {
        0 => 1,
        listed => listed,
    }
}
/// Returns the recorded value of every output slot, in order.
///
/// The slots are the ones reported by `all_output_indices`: the
/// `output_indices` list when it is non-empty, otherwise just `output_index`.
///
/// # Panics
/// Panics if any of those indices is not a valid slot index.
#[must_use]
pub fn output_values(&self) -> Vec<F> {
    self.all_output_indices()
        .iter()
        .map(|&slot| self.values[slot as usize])
        .collect()
}
/// Returns the output slot indices as a slice.
///
/// When `output_indices` is empty, the result is a one-element slice viewing
/// `output_index`; otherwise it is the `output_indices` list itself.
#[must_use]
pub fn all_output_indices(&self) -> &[u32] {
    match &self.output_indices[..] {
        [] => std::slice::from_ref(&self.output_index),
        listed => listed,
    }
}
/// Returns the opcode of every recorded slot, indexed by slot.
#[inline]
#[must_use]
pub fn opcodes_slice(&self) -> &[OpCode] {
&self.opcodes
}
/// Returns the operand pair of every recorded slot, indexed by slot.
///
/// The second entry is not always a slot index: `Powi` slots store an encoded
/// exponent there and `Custom` slots store the custom-op handle value.
#[inline]
#[must_use]
pub fn arg_indices_slice(&self) -> &[[u32; 2]] {
&self.arg_indices
}
/// Returns the recorded value of every slot, indexed by slot.
#[inline]
#[must_use]
pub fn values_slice(&self) -> &[F] {
&self.values
}
/// Returns the number of slots recorded so far (the `num_variables` counter,
/// which is bumped once per appended slot).
#[inline]
#[must_use]
pub fn num_variables_count(&self) -> usize {
self.num_variables as usize
}
/// Returns `true` if at least one custom op has been registered on this tape.
///
/// This checks the registration list only; it does not check whether any
/// recorded slot actually uses a custom op.
#[inline]
#[must_use]
pub fn has_custom_ops(&self) -> bool {
!self.custom_ops.is_empty()
}
}
/// The default tape is the empty tape produced by [`BytecodeTape::new`].
impl<F: Float> Default for BytecodeTape<F> {
fn default() -> Self {
Self::new()
}
}