use crate::broadcast::number_to_index;
use crate::data_types::{array_type, scalar_type, Type, BIT, UINT64};
use crate::data_values::Value;
use crate::errors::Result;
use crate::graphs::{Graph, Node, SliceElement};
use crate::inline::data_structures::{log_depth_sum, CombineOp};
use crate::inline::inline_common::{
pick_prefix_sum_algorithm, DepthOptimizationLevel, InlineState,
};
use crate::ops::utils::constant_scalar;
// Largest state width (in bits) accepted by `inline_iterate_small_state`.
// The inliner builds 2^bits mask constants and inlines the loop body once per
// mask for every input element, so the work grows exponentially in this value.
const MAX_ALLOWED_STATE_BITS: u64 = 4;
/// Inlines an `Iterate` whose state consists of a small number of bits.
///
/// For every input element the loop body `graph` is inlined once per possible
/// state value (see `create_mappings`), which yields a "mapping" from the state
/// before that element to the state after it. The mappings are then combined
/// with `log_depth_sum` (when the body's per-iteration output is an empty
/// tuple) or with the routine returned by `pick_prefix_sum_algorithm`
/// (otherwise), and the combined mapping is applied to `initial_state`.
///
/// # Arguments
///
/// * `single_bit` - selects the one-bit representation: mappings are tuples
///   combined by `MappingCombiner1Bit` instead of arrays combined by
///   `MappingCombiner`
/// * `optimization_level` - forwarded to `pick_prefix_sum_algorithm`
/// * `graph` - loop body; its output must be a tuple whose element 0 is the new
///   state and element 1 is the per-iteration output
/// * `initial_state` - state node fed into the first iteration
/// * `inputs_node` - node of `Vector` type holding the per-iteration inputs
/// * `inliner` - inlining context; nodes are created in its output graph
///
/// # Returns
///
/// The final state node and one output node per input element.
///
/// # Errors
///
/// Returns a runtime error if the state has more than
/// `MAX_ALLOWED_STATE_BITS` bits or zero bits, and propagates errors from
/// `get_number_of_bits` and from graph operations.
///
/// # Panics
///
/// Panics if the output type of `graph` is not a tuple or if `inputs_node` is
/// not of `Vector` type.
pub(super) fn inline_iterate_small_state(
single_bit: bool,
optimization_level: DepthOptimizationLevel,
graph: Graph,
initial_state: Node,
inputs_node: Node,
inliner: &mut dyn InlineState,
) -> Result<(Node, Vec<Node>)> {
let graph_output_type = graph.get_output_node()?.get_type()?;
// Element 1 of the body's output tuple is the per-iteration output.
let output_element_type = match graph_output_type {
Type::Tuple(tuple_types) => (*tuple_types[1]).clone(),
_ => {
panic!("Inconsistency with type checker for Iterate output.");
}
};
// An empty-tuple output lets us skip re-inlining the body for the outputs.
let empty_output = match output_element_type {
Type::Tuple(tuple_types) => tuple_types.is_empty(),
_ => false,
};
let inputs_len = match inputs_node.get_type()? {
Type::Vector(len, _) => len,
_ => {
panic!("Inconsistency with type checker");
}
};
// Nothing to iterate over: the state is returned untouched.
if inputs_len == 0 {
return Ok((initial_state, vec![]));
}
let num_bits = get_number_of_bits(initial_state.get_type()?, single_bit)?;
if num_bits > MAX_ALLOWED_STATE_BITS {
return Err(runtime_error!("Too many bits in the state"));
}
if num_bits == 0 {
return Err(runtime_error!(
"This inlining method doesn't support empty state"
));
}
let num_masks = u64::pow(2, num_bits as u32);
let state_type = initial_state.get_type()?;
// One constant of the state type per possible state value (bit mask).
let mut mask_constants = vec![];
for mask in 0..u64::pow(2, num_bits as u32) {
let value = mask_to_value(state_type.clone(), num_bits, mask)?;
let mask_const = inliner.output_graph().constant(state_type.clone(), value)?;
mask_constants.push(mask_const);
}
// One mapping per input element, describing how that element transforms
// every possible state.
let mappings = create_mappings(
initial_state.get_type()?,
mask_constants.clone(),
num_bits,
single_bit,
inputs_node.clone(),
graph.clone(),
inliner,
)?;
// Placeholder passed where the single-bit path ignores the argument.
let unused_node = inliner.output_graph().zeros(scalar_type(BIT))?;
let initial_state_one_hot = if single_bit {
unused_node.clone()
} else {
let mut initial_state_one_hot = one_hot_encode(
initial_state.clone(),
num_masks,
mask_constants.clone(),
inliner.output_graph(),
state_type.clone(),
single_bit,
)?;
// Prepend a unit axis, then rotate the two leading axes to the end.
// NOTE(review): the intended layout appears to be [..., 1, num_masks] - confirm.
let mut new_shape = initial_state_one_hot.get_type()?.get_shape();
new_shape.insert(0, 1);
initial_state_one_hot =
initial_state_one_hot.reshape(array_type(new_shape.clone(), BIT))?;
let mut permutation: Vec<u64> = (0..new_shape.len()).map(|x| x as u64).collect();
permutation.rotate_left(2);
initial_state_one_hot = initial_state_one_hot.permute_axes(permutation)?; initial_state_one_hot
};
let masks_arr = if single_bit {
unused_node
} else {
// Stack all mask constants into one array.
let masks_arr = inliner
.output_graph()
.create_vector(mask_constants[0].get_type()?, mask_constants)?
.vector_to_array()?;
let masks_arr_shape = masks_arr.get_type()?.get_shape();
let mut masks_arr_permutation: Vec<u64> =
(0..masks_arr_shape.len()).map(|x| x as u64).collect();
// Move the leading axis to the end, then swap the last two axes.
// NOTE(review): the intended layout appears to be [..., num_masks, num_bits] - confirm.
masks_arr_permutation.rotate_left(1);
let rank = masks_arr_permutation.len();
masks_arr_permutation.swap(rank - 2, rank - 1);
masks_arr.permute_axes(masks_arr_permutation)?
};
let mut combiner = MappingCombiner {};
let mut bit_combiner = MappingCombiner1Bit {};
if empty_output {
// Every per-iteration output is the same empty tuple node, so only the
// combination of all mappings is needed.
let mut outputs = vec![];
let empty_tuple = inliner.output_graph().create_tuple(vec![])?;
for _ in 0..inputs_len {
outputs.push(empty_tuple.clone());
}
let final_mapping = if single_bit {
log_depth_sum(&mappings, &mut bit_combiner)?
} else {
log_depth_sum(&mappings, &mut combiner)?
};
let result = extract_state_from_mapping(
single_bit,
initial_state,
initial_state_one_hot,
final_mapping,
masks_arr,
state_type,
)?;
Ok((result, outputs))
} else {
// Outputs are required, so the state before each iteration is needed:
// it is derived from the combined mappings of all preceding inputs.
let prefix_sums = if single_bit {
pick_prefix_sum_algorithm(inputs_len, optimization_level)(&mappings, &mut bit_combiner)?
} else {
pick_prefix_sum_algorithm(inputs_len, optimization_level)(&mappings, &mut combiner)?
};
let mut outputs = vec![];
for i in 0..inputs_len {
// State seen by iteration `i`: the initial state for the first one,
// otherwise the initial state pushed through `prefix_sums[i - 1]`.
let state = if i == 0 {
initial_state.clone()
} else {
extract_state_from_mapping(
single_bit,
initial_state.clone(),
initial_state_one_hot.clone(),
prefix_sums[i as usize - 1].clone(),
masks_arr.clone(),
state_type.clone(),
)?
};
let input =
inputs_node.vector_get(constant_scalar(&inliner.output_graph(), i, UINT64)?)?;
// Inline the body once more with the actual state; only element 1 of
// its output tuple (the per-iteration output) is kept.
inliner.assign_input_nodes(graph.clone(), vec![state, input])?;
let output = inliner.recursively_inline_graph(graph.clone())?;
inliner.unassign_nodes(graph.clone())?;
outputs.push(output.tuple_get(1)?);
}
// The last prefix sum yields the state after all inputs.
let result = extract_state_from_mapping(
single_bit,
initial_state,
initial_state_one_hot,
prefix_sums[prefix_sums.len() - 1].clone(),
masks_arr,
state_type,
)?;
Ok((result, outputs))
}
}
/// Combiner used for the multi-bit state representation, where each mapping is
/// an array node.
struct MappingCombiner {}

impl CombineOp<Node> for MappingCombiner {
    /// Combines two mappings with `matmul`, keeping `arg1` as the left operand.
    fn combine(&mut self, arg1: Node, arg2: Node) -> Result<Node> {
        let composed = arg1.matmul(arg2)?;
        Ok(composed)
    }
}
/// Combiner used for the single-bit state representation, where each mapping
/// is a tuple `(image of 0, image of 1)`.
struct MappingCombiner1Bit {}

impl CombineOp<Node> for MappingCombiner1Bit {
    /// Composes two single-bit mappings: `arg1` is applied first, `arg2` second.
    ///
    /// With `arg2 = (z, o)`, applying it to a bit `y` gives `z + y * (z + o)`,
    /// so each component of `arg1` is pushed through that expression.
    fn combine(&mut self, arg1: Node, arg2: Node) -> Result<Node> {
        let first_on_zero = arg1.tuple_get(0)?;
        let first_on_one = arg1.tuple_get(1)?;
        let second_on_zero = arg2.tuple_get(0)?;
        let second_on_one = arg2.tuple_get(1)?;

        // `second_differs` is 1 exactly where the second mapping sends 0 and 1
        // to different values.
        let second_differs = second_on_zero.add(second_on_one)?;

        let scaled_zero = first_on_zero.multiply(second_differs.clone())?;
        let composed_on_zero = scaled_zero.add(second_on_zero.clone())?;

        let scaled_one = first_on_one.multiply(second_differs)?;
        let composed_on_one = scaled_one.add(second_on_zero)?;

        let components = vec![composed_on_zero, composed_on_one];
        arg1.get_graph().create_tuple(components)
    }
}
/// Applies a combined `mapping` to the initial state and returns the resulting
/// state node.
///
/// In the multi-bit case the one-hot encoded initial state is multiplied by the
/// mapping and then by `masks_arr`, and the product is reshaped to
/// `state_type`. In the single-bit case `mapping` is a tuple
/// `(image of 0, image of 1)` and the result is
/// `image_of_0 * (initial_state + 1) + image_of_1 * initial_state`;
/// `initial_state_one_hot`, `masks_arr` and `state_type` are not used there.
fn extract_state_from_mapping(
    single_bit: bool,
    initial_state: Node,
    initial_state_one_hot: Node,
    mapping: Node,
    masks_arr: Node,
    state_type: Type,
) -> Result<Node> {
    if !single_bit {
        let mapped_one_hot = initial_state_one_hot.matmul(mapping)?;
        let decoded = mapped_one_hot.matmul(masks_arr)?;
        return decoded.reshape(state_type);
    }

    let graph = mapping.get_graph();
    let image_of_zero = mapping.tuple_get(0)?;
    let image_of_one = mapping.tuple_get(1)?;
    let one = graph.ones(scalar_type(BIT))?;
    // Adding one flips the bit, giving the selector for the "state is 0" case.
    let state_is_zero = initial_state.add(one)?;
    let zero_branch = image_of_zero.multiply(state_is_zero)?;
    let one_branch = image_of_one.multiply(initial_state)?;
    zero_branch.add(one_branch)
}
/// Returns the number of state bits tracked by the small-state inliner.
///
/// * A scalar state is accepted only in single-bit mode and must be `BIT`;
///   it counts as one bit.
/// * An array state must have `BIT` elements. In single-bit mode it counts as
///   one bit; otherwise the size of its last dimension is returned.
///
/// # Errors
///
/// Returns a runtime error for a scalar state outside single-bit mode, for a
/// non-`BIT` element type, for an array state without any dimension, and for
/// any other kind of type.
fn get_number_of_bits(state_type: Type, single_bit: bool) -> Result<u64> {
    match state_type {
        Type::Scalar(element_type) => {
            if !single_bit {
                Err(runtime_error!(
                    "Scalar state is only supported in a single-bit mode"
                ))
            } else if element_type != BIT {
                Err(runtime_error!("State must consist of bits"))
            } else {
                Ok(1)
            }
        }
        Type::Array(shape, element_type) => {
            if element_type != BIT {
                Err(runtime_error!("State must consist of bits"))
            } else if single_bit {
                Ok(1)
            } else {
                // `last()` avoids the `len() - 1` underflow panic that direct
                // indexing would hit on a shape with no dimensions.
                match shape.last() {
                    Some(&last_dimension) => Ok(last_dimension),
                    None => Err(runtime_error!(
                        "State array must have at least one dimension"
                    )),
                }
            }
        }
        _ => Err(runtime_error!("Unsupported state type")),
    }
}
/// Builds the constant value of type `state_type` that represents the state
/// `mask`.
///
/// For a scalar type the mask itself is converted with `Value::from_scalar`.
/// For an array type every element receives one bit of `mask`: bit 0 when
/// `num_bits == 1`, otherwise the bit selected by the element's index along the
/// last dimension. Bits are written into the byte buffer eight per byte,
/// element `i` going to bit `i % 8` of byte `i / 8`.
///
/// # Panics
///
/// Panics if `state_type` is neither a scalar nor an array.
fn mask_to_value(state_type: Type, num_bits: u64, mask: u64) -> Result<Value> {
    let dims = match state_type.clone() {
        Type::Array(dims, _) => dims,
        Type::Scalar(element_type) => {
            return Value::from_scalar(mask, element_type);
        }
        _ => panic!("Cannot be here"),
    };

    let zeroed = Value::zero_of_type(state_type);
    let mut packed = zeroed.access_bytes(|raw| Ok(raw.to_vec()))?;

    for flat in 0..dims.iter().product() {
        let multi_index = number_to_index(flat, &dims);
        let mask_bit_position = if num_bits == 1 {
            0
        } else {
            multi_index[multi_index.len() - 1]
        };
        let bit = ((mask >> mask_bit_position) & 1) as u8;

        let byte_position = (flat / 8) as usize;
        let shift = flat % 8;
        // Clear the target bit first, then store the new one.
        packed[byte_position] &= !(1 << shift);
        packed[byte_position] |= bit << shift;
    }

    Ok(Value::from_bytes(packed))
}
/// One-hot encodes `val` against every possible state value.
///
/// For each candidate mask in `0..depth`, `val` is added to the constant for
/// the bitwise complement of that mask (`(depth - 1) ^ mask`), so a bit of the
/// sum is 1 exactly where `val` agrees with the candidate. In single-bit mode
/// that sum is used directly; otherwise the slices along the last dimension are
/// multiplied together, giving 1 only when every bit agrees. The per-candidate
/// results are stacked into one array node.
///
/// # Panics
///
/// Panics if `single_bit` is false and `state_type` is not an array, or if
/// `depth` is zero.
fn one_hot_encode(
    val: Node,
    depth: u64,
    mask_constants: Vec<Node>,
    output: Graph,
    state_type: Type,
    single_bit: bool,
) -> Result<Node> {
    let mut indicators = Vec::new();
    for candidate in 0..depth {
        let complement = mask_constants[((depth - 1) ^ candidate) as usize].clone();
        let agreement = val.add(complement)?;
        if single_bit {
            indicators.push(agreement);
            continue;
        }

        let dims = match state_type.clone() {
            Type::Array(dims, _) => dims,
            _ => panic!("Cannot be here"),
        };
        let width = dims[dims.len() - 1];

        // Take all slices first, then multiply them, so graph nodes are created
        // in the same order as before.
        let per_bit = (0..width)
            .map(|bit_index| {
                agreement.get_slice(vec![
                    SliceElement::Ellipsis,
                    SliceElement::SingleIndex(bit_index as i64),
                ])
            })
            .collect::<Result<Vec<Node>>>()?;

        let mut all_bits_agree = per_bit[0].clone();
        for column in per_bit.iter().skip(1) {
            all_bits_agree = all_bits_agree.multiply(column.clone())?;
        }
        indicators.push(all_bits_agree);
    }
    output.vector_to_array(output.create_vector(indicators[0].get_type()?, indicators)?)
}
/// Packs the images of all possible states under one input into a single
/// mapping node.
///
/// `mapping[k]` is the state produced from the state with mask `k`. In
/// single-bit mode the images are simply wrapped into a tuple. Otherwise each
/// image is one-hot encoded (with depth equal to the number of images) and the
/// encodings are stacked into one array node.
///
/// # Panics
///
/// Panics if `single_bit` is false and `mapping` is empty.
fn create_mapping_matrix(
    mapping: Vec<Node>,
    output: Graph,
    mask_constants: Vec<Node>,
    state_type: Type,
    single_bit: bool,
) -> Result<Node> {
    if single_bit {
        return output.create_tuple(mapping);
    }

    let depth = mapping.len() as u64;
    let rows = mapping
        .into_iter()
        .map(|image| {
            one_hot_encode(
                image,
                depth,
                mask_constants.clone(),
                output.clone(),
                state_type.clone(),
                single_bit,
            )
        })
        .collect::<Result<Vec<Node>>>()?;

    let row_type = rows[0].get_type()?;
    output.vector_to_array(output.create_vector(row_type, rows)?)
}
/// Builds one state-to-state mapping per element of `inputs_node`.
///
/// For every input element the body `graph` is inlined once for each of the
/// `2^num_bits` mask constants, and element 0 of each inlined output (the new
/// state) is collected. The images are packed with `create_mapping_matrix`.
///
/// In single-bit mode the packed tuples are returned as they are. Otherwise
/// all matrices are stacked into one array, axis 0 is kept in place while the
/// two axes following it are rotated to the end, and the array is split again
/// into one node per input.
///
/// # Panics
///
/// Panics if `inputs_node` is not of `Vector` type.
fn create_mappings(
    state_type: Type,
    mask_constants: Vec<Node>,
    num_bits: u64,
    single_bit: bool,
    inputs_node: Node,
    graph: Graph,
    inliner: &mut dyn InlineState,
) -> Result<Vec<Node>> {
    let inputs_len = if let Type::Vector(len, _) = inputs_node.get_type()? {
        len
    } else {
        panic!("Inconsistency with type checker");
    };

    let mut per_input = Vec::new();
    for position in 0..inputs_len {
        let position_node = inliner
            .output_graph()
            .constant(scalar_type(UINT64), Value::from_scalar(position, UINT64)?)?;
        let current_input = inputs_node.vector_get(position_node)?;

        // Evaluate the body on every possible state for this input.
        let mut images = Vec::new();
        for mask in 0..u64::pow(2, num_bits as u32) {
            let candidate_state = mask_constants[mask as usize].clone();
            inliner.assign_input_nodes(
                graph.clone(),
                vec![candidate_state, current_input.clone()],
            )?;
            let inlined = inliner.recursively_inline_graph(graph.clone())?;
            inliner.unassign_nodes(graph.clone())?;
            let new_state = inliner.output_graph().tuple_get(inlined, 0)?;
            images.push(new_state);
        }

        let packed = create_mapping_matrix(
            images,
            inliner.output_graph().clone(),
            mask_constants.clone(),
            state_type.clone(),
            single_bit,
        )?;
        per_input.push(packed);
    }

    if single_bit {
        return Ok(per_input);
    }

    let output_graph = inliner.output_graph();
    let matrix_type = per_input[0].get_type()?;
    let stacked = output_graph
        .create_vector(matrix_type, per_input)?
        .vector_to_array()?;

    // Keep axis 0 in place and rotate the two axes after it to the end.
    let rank = stacked.get_type()?.get_dimensions().len();
    let mut trailing_axes: Vec<u64> = (1..rank).map(|axis| axis as u64).collect();
    trailing_axes.rotate_left(2);
    let mut permutation: Vec<u64> = vec![0];
    permutation.extend(trailing_axes);
    let stacked = stacked.permute_axes(permutation)?;

    (0..inputs_len)
        .map(|position| stacked.get(vec![position]))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data_values::Value;
    use crate::graphs::create_context;
    use crate::inline::inline_test_utils::{build_test_data, MockInlineState};

    /// Creates a mock inliner that targets `graph` and has recorded nothing yet.
    fn fresh_inliner(graph: &Graph) -> MockInlineState {
        MockInlineState {
            fake_graph: graph.clone(),
            inputs: vec![],
            inline_graph_calls: vec![],
            returned_nodes: vec![],
        }
    }

    /// A 10-bit state exceeds `MAX_ALLOWED_STATE_BITS`, so inlining must fail.
    #[test]
    fn test_small_state_iterate_too_many_bits() {
        fn scenario() -> Result<()> {
            let context = create_context()?;
            let graph = context.create_graph()?;
            let initial_state = graph.constant(
                array_type(vec![10], BIT),
                Value::from_flattened_array(&vec![0; 10], BIT)?,
            )?;
            let input_vals = vec![1; 5];
            let mut input_nodes = vec![];
            for bit in input_vals {
                let node = graph.constant(scalar_type(BIT), Value::from_scalar(bit, BIT)?)?;
                input_nodes.push(node);
            }
            let inputs_node = graph.create_vector(scalar_type(BIT), input_nodes)?;
            let mut inliner = fresh_inliner(&graph);

            let body = context.create_graph()?;
            let empty = body.create_tuple(vec![])?;
            let body_output = body.create_tuple(vec![empty.clone(), empty.clone()])?;
            body.set_output_node(body_output)?;

            let outcome = inline_iterate_small_state(
                false,
                DepthOptimizationLevel::Extreme,
                body.clone(),
                initial_state.clone(),
                inputs_node.clone(),
                &mut inliner,
            );
            assert!(outcome.is_err());
            Ok(())
        }
        scenario().unwrap();
    }

    /// With a non-empty output the mock inliner is expected to record 15 input
    /// assignments.
    #[test]
    fn test_small_state_iterate_nonempty_output() {
        fn scenario() -> Result<()> {
            let context = create_context()?;
            let (graph, initial_state, inputs_node, _input_vals) =
                build_test_data(context.clone(), BIT)?;
            let mut inliner = fresh_inliner(&graph);

            let body = context.create_graph()?;
            let one_bit = body.input(scalar_type(BIT))?;
            let body_output = body.create_tuple(vec![one_bit.clone(), one_bit.clone()])?;
            body.set_output_node(body_output)?;

            inline_iterate_small_state(
                true,
                DepthOptimizationLevel::Extreme,
                body.clone(),
                initial_state.clone(),
                inputs_node.clone(),
                &mut inliner,
            )?;
            assert_eq!(inliner.inputs.len(), 15);
            Ok(())
        }
        scenario().unwrap();
    }

    /// With an empty-tuple output the mock inliner is expected to record
    /// `5 * 2` inline calls.
    #[test]
    fn test_small_state_iterate_valid_case() {
        fn scenario() -> Result<()> {
            let context = create_context()?;
            let (graph, initial_state, inputs_node, _input_vals) =
                build_test_data(context.clone(), BIT)?;
            let mut inliner = fresh_inliner(&graph);

            let body = context.create_graph()?;
            let one_bit = body.input(scalar_type(BIT))?;
            let empty = body.create_tuple(vec![])?;
            let body_output = body.create_tuple(vec![one_bit.clone(), empty.clone()])?;
            body.set_output_node(body_output)?;

            inline_iterate_small_state(
                true,
                DepthOptimizationLevel::Extreme,
                body.clone(),
                initial_state.clone(),
                inputs_node.clone(),
                &mut inliner,
            )?;
            assert_eq!(inliner.inline_graph_calls.len(), 5 * 2);
            Ok(())
        }
        scenario().unwrap();
    }
}