ciphercore_base/inline/
exponential_inliner.rs

1use crate::broadcast::number_to_index;
2use crate::data_types::{array_type, scalar_type, Type, BIT, UINT64};
3use crate::data_values::Value;
4use crate::errors::Result;
5use crate::graphs::{Graph, Node, SliceElement};
6use crate::inline::data_structures::{log_depth_sum, CombineOp};
7use crate::inline::inline_common::{
8    pick_prefix_sum_algorithm, DepthOptimizationLevel, InlineState,
9};
10use crate::ops::utils::constant_scalar;
11
/// Maximum number of state bits accepted by `inline_iterate_small_state`; states with
/// more bits are rejected with a runtime error. The inliner enumerates all
/// 2**num_bits state values, so this limit caps that enumeration at 16 masks.
const MAX_ALLOWED_STATE_BITS: u64 = 4;
13
14/// This version inlines Iterate assuming the state has low number of bits.
15/// If there are K bits, the additional time complexity multiplier is 2**(K+2),
16/// and the resulting depth is O(log(length of the inputs)).
17/// Important note about the contract: state can be batched, it is assumed that
18/// the last dimension of the state is the actual state consisting of K bits, and
19/// there are no interactions between various states.
20/// More formally:
21/// -- let the input state S have shape (..., K) (rank d);
22/// -- let T = Call(graph, S, some input);
23/// -- let i_1, .., i_{d-1} be a valid combination of indices,
24/// -- then for every S' such that
25///    S[i_1, .., i_{d-1}] == S'[i_1, .., i_{d-1}] (as vectors),
26///    the result T[i_1, .., i_{d-1}] equals T'[i_1, .., i_{d-1}].
27/// (i.e. we can change other "rows" of the input, and the output for this "row"
28/// won't change).
29/// If this property is not true, results might be incorrect.
30///
31/// There is also a special case of K == 1, where the caller can (but not must)
32/// omit the last dimension in the state (e.g. have state of shape (n, m) rather
33/// than (n, m, 1)).
34pub(super) fn inline_iterate_small_state(
35    single_bit: bool,
36    optimization_level: DepthOptimizationLevel,
37    graph: Graph,
38    initial_state: Node,
39    inputs_node: Node,
40    inliner: &mut dyn InlineState,
41) -> Result<(Node, Vec<Node>)> {
42    // TODO(ilyakor): for "batched state" case, support the "all items are equal" optimization.
43    let graph_output_type = graph.get_output_node()?.get_type()?;
44    let output_element_type = match graph_output_type {
45        Type::Tuple(tuple_types) => (*tuple_types[1]).clone(),
46        _ => {
47            panic!("Inconsistency with type checker for Iterate output.");
48        }
49    };
50    let empty_output = match output_element_type {
51        Type::Tuple(tuple_types) => tuple_types.is_empty(),
52        _ => false,
53    };
54
55    let inputs_len = match inputs_node.get_type()? {
56        Type::Vector(len, _) => len,
57        _ => {
58            panic!("Inconsistency with type checker");
59        }
60    };
61    if inputs_len == 0 {
62        return Ok((initial_state, vec![]));
63    }
64
65    let num_bits = get_number_of_bits(initial_state.get_type()?, single_bit)?;
66    if num_bits > MAX_ALLOWED_STATE_BITS {
67        return Err(runtime_error!("Too many bits in the state"));
68    }
69    if num_bits == 0 {
70        return Err(runtime_error!(
71            "This inlining method doesn't support empty state"
72        ));
73    }
74    let num_masks = u64::pow(2, num_bits as u32);
75
76    // Intuitively, the algorithm is as follows.
77    // Let's look at our operation G(state, input) -> state'. In general, it can be
78    // non-associative, however, we can do the following trick. Let's compute
79    // matrices M_input of form M[state1, state2] = (G(state1, input) == state2).
80    // One can see that the resulting state is:
81    //   OneHot(initial_state) * M(input1) * M(input2) * ....
82    // Where "*" is MatMul. But matrix multiplication is associative, so we can use the
83    // same tricks as the associative case to reduce the depth.
84    //
85    // While simple conceptually, the code is quite complicated since we want to handle
86    // states of shape (..., K bits), not just (K bits).
87
88    // Precalculate K-bit masks, they're used in many places.
89    let state_type = initial_state.get_type()?;
90    let mut mask_constants = vec![];
91    for mask in 0..u64::pow(2, num_bits as u32) {
92        let value = mask_to_value(state_type.clone(), num_bits, mask)?;
93        let mask_const = inliner.output_graph().constant(state_type.clone(), value)?;
94        mask_constants.push(mask_const);
95    }
96
97    // First, let's compute our mappings in form of matrices.
98    // If state shape is (..., K), the shape of mappings is (..., K, K).
99    let mappings = create_mappings(
100        initial_state.get_type()?,
101        mask_constants.clone(),
102        num_bits,
103        single_bit,
104        inputs_node.clone(),
105        graph.clone(),
106        inliner,
107    )?;
108
109    // Precompute the transformation of initial_state, which is needed to
110    // extract states from the transformation matrices. See extract_state_from_mapping()
111    // for more detailed explanation.
112    let unused_node = inliner.output_graph().zeros(scalar_type(BIT))?;
113    let initial_state_one_hot = if single_bit {
114        unused_node.clone()
115    } else {
116        let mut initial_state_one_hot = one_hot_encode(
117            initial_state.clone(),
118            num_masks,
119            mask_constants.clone(),
120            inliner.output_graph(),
121            state_type.clone(),
122            single_bit,
123        )?;
124        let mut new_shape = initial_state_one_hot.get_type()?.get_shape();
125        new_shape.insert(0, 1);
126        initial_state_one_hot =
127            initial_state_one_hot.reshape(array_type(new_shape.clone(), BIT))?;
128        let mut permutation: Vec<u64> = (0..new_shape.len()).map(|x| x as u64).collect();
129        permutation.rotate_left(2);
130        initial_state_one_hot = initial_state_one_hot.permute_axes(permutation)?; // ...1i
131        initial_state_one_hot
132    };
133
134    // Precompute the mask array, needed for extracting states. See extract_state_from_mapping()
135    // for more detailed explanation.
136    let masks_arr = if single_bit {
137        unused_node
138    } else {
139        let masks_arr = inliner
140            .output_graph()
141            .create_vector(mask_constants[0].get_type()?, mask_constants)?
142            .vector_to_array()?;
143        let masks_arr_shape = masks_arr.get_type()?.get_shape();
144        let mut masks_arr_permutation: Vec<u64> =
145            (0..masks_arr_shape.len()).map(|x| x as u64).collect();
146        masks_arr_permutation.rotate_left(1);
147        let rank = masks_arr_permutation.len();
148        masks_arr_permutation.swap(rank - 2, rank - 1);
149        masks_arr.permute_axes(masks_arr_permutation)?
150    };
151
152    let mut combiner = MappingCombiner {};
153    let mut bit_combiner = MappingCombiner1Bit {};
154    if empty_output {
155        // Outputs for this case are trivial.
156        let mut outputs = vec![];
157        let empty_tuple = inliner.output_graph().create_tuple(vec![])?;
158        for _ in 0..inputs_len {
159            outputs.push(empty_tuple.clone());
160        }
161
162        let final_mapping = if single_bit {
163            log_depth_sum(&mappings, &mut bit_combiner)?
164        } else {
165            log_depth_sum(&mappings, &mut combiner)?
166        };
167        // We have the final mapping, let's compute and extract the answer.
168
169        let result = extract_state_from_mapping(
170            single_bit,
171            initial_state,
172            initial_state_one_hot,
173            final_mapping,
174            masks_arr,
175            state_type,
176        )?;
177        Ok((result, outputs))
178    } else {
179        let prefix_sums = if single_bit {
180            pick_prefix_sum_algorithm(inputs_len, optimization_level)(&mappings, &mut bit_combiner)?
181        } else {
182            pick_prefix_sum_algorithm(inputs_len, optimization_level)(&mappings, &mut combiner)?
183        };
184        let mut outputs = vec![];
185        for i in 0..inputs_len {
186            let state = if i == 0 {
187                initial_state.clone()
188            } else {
189                extract_state_from_mapping(
190                    single_bit,
191                    initial_state.clone(),
192                    initial_state_one_hot.clone(),
193                    prefix_sums[i as usize - 1].clone(),
194                    masks_arr.clone(),
195                    state_type.clone(),
196                )?
197            };
198            let input =
199                inputs_node.vector_get(constant_scalar(&inliner.output_graph(), i, UINT64)?)?;
200            inliner.assign_input_nodes(graph.clone(), vec![state, input])?;
201            let output = inliner.recursively_inline_graph(graph.clone())?;
202            inliner.unassign_nodes(graph.clone())?;
203            outputs.push(output.tuple_get(1)?);
204        }
205        let result = extract_state_from_mapping(
206            single_bit,
207            initial_state,
208            initial_state_one_hot,
209            prefix_sums[prefix_sums.len() - 1].clone(),
210            masks_arr,
211            state_type,
212        )?;
213        Ok((result, outputs))
214    }
215}
216
217struct MappingCombiner {}
218
219impl CombineOp<Node> for MappingCombiner {
220    fn combine(&mut self, arg1: Node, arg2: Node) -> Result<Node> {
221        arg1.matmul(arg2)
222    }
223}
224
225// Optimized version of mapping combiner for the single-bit case.
226// It produces more operations, but the operations are more lightweight.
227struct MappingCombiner1Bit {}
228
229impl CombineOp<Node> for MappingCombiner1Bit {
230    fn combine(&mut self, arg1: Node, arg2: Node) -> Result<Node> {
231        // Mappings are in the form tuple(bit_0, bit_1), so we extract
232        // both bits from the 2nd mapping, and combine mappings as follows:
233        // output_0 = bit_10 * (bit20 + bit21) + bit_20;
234        // output_1 = bit_11 * (bit20 + bit21) + bit_20.
235        let bit10 = arg1.tuple_get(0)?;
236        let bit11 = arg1.tuple_get(1)?;
237        let bit20 = arg2.tuple_get(0)?;
238        let bit21 = arg2.tuple_get(1)?;
239        let distinct = bit20.add(bit21)?;
240        let bit0 = bit10.multiply(distinct.clone())?.add(bit20.clone())?;
241        let bit1 = bit11.multiply(distinct)?.add(bit20)?;
242        arg1.get_graph().create_tuple(vec![bit0, bit1])
243    }
244}
245
246fn extract_state_from_mapping(
247    single_bit: bool,
248    initial_state: Node,
249    initial_state_one_hot: Node,
250    mapping: Node,
251    masks_arr: Node,
252    state_type: Type,
253) -> Result<Node> {
254    if single_bit {
255        // Optimized 1-bit case. The mapping is a tuple with 2 bits (0->bit_0, 1->bit_1).
256        let g = mapping.get_graph();
257        let out0 = mapping.tuple_get(0)?;
258        let out1 = mapping.tuple_get(1)?;
259        let one = g.ones(scalar_type(BIT))?;
260        let not_initial_state = initial_state.add(one)?;
261        out0.multiply(not_initial_state)?
262            .add(out1.multiply(initial_state)?)
263    } else {
264        // Currently, initial_state_one_hot is "...i1", and final_mapping - "...ij".
265        // We want to multiply them, but for this, we'll need Reshape and PermuteAxes.
266        //
267        // To make it easier to follow, here is what is happening below.
268        // Let M (...ij) be the final mapping, S (...i) be the one-hot-encoded state, C (2**k, ..., k)
269        // be the precomputed mask array.
270        // O := einsum('...ij,...i->...j', M, S), this is the one-hot-encoded output state.
271        // Result := einsum('...j,j...k->...k', M, C)
272        // Then we just reshape the result to the correct shape.
273
274        let output_state_one_hot = initial_state_one_hot.matmul(mapping)?;
275        // Now we have one-hot encoded result, we just need to decode it.
276        let final_state = output_state_one_hot.matmul(masks_arr)?;
277        final_state.reshape(state_type)
278    }
279}
280
281fn get_number_of_bits(state_type: Type, single_bit: bool) -> Result<u64> {
282    match state_type {
283        Type::Scalar(scalar_type) => {
284            if !single_bit {
285                Err(runtime_error!(
286                    "Scalar state is only supported in a single-bit mode"
287                ))
288            } else if scalar_type != BIT {
289                Err(runtime_error!("State must consist of bits"))
290            } else {
291                Ok(1)
292            }
293        }
294        Type::Array(shape, scalar_type) => {
295            if scalar_type != BIT {
296                Err(runtime_error!("State must consist of bits"))
297            } else if single_bit {
298                Ok(1)
299            } else {
300                Ok(shape[shape.len() - 1])
301            }
302        }
303        _ => Err(runtime_error!("Unsupported state type")),
304    }
305}
306
307fn mask_to_value(state_type: Type, num_bits: u64, mask: u64) -> Result<Value> {
308    let data_shape = match state_type.clone() {
309        Type::Scalar(scalar_type) => {
310            return Value::from_scalar(mask, scalar_type);
311        }
312        Type::Array(shape, _) => shape,
313        _ => panic!("Cannot be here"),
314    };
315    let value = Value::zero_of_type(state_type);
316    let mut bytes = value.access_bytes(|ref_bytes| Ok(ref_bytes.to_vec()))?;
317    for i in 0..data_shape.iter().product() {
318        let index = number_to_index(i, &data_shape);
319        let state_index = if num_bits == 1 {
320            0
321        } else {
322            index[index.len() - 1]
323        };
324        let bit = ((mask >> state_index) & 1) as u8;
325        let position = i / 8;
326        let offset = i % 8;
327        bytes[position as usize] &= !(1 << offset);
328        bytes[position as usize] |= bit << offset;
329    }
330    Ok(Value::from_bytes(bytes))
331}
332
333fn one_hot_encode(
334    val: Node,
335    depth: u64,
336    mask_constants: Vec<Node>,
337    output: Graph,
338    state_type: Type,
339    single_bit: bool,
340) -> Result<Node> {
341    let mut result = vec![];
342    // We have a single (batched) value `val`, which we want to one-hot encode into a vector
343    // of size `depth`.
344    // We do this by comparing `val` with every single value in range [0, depth).
345    // Each comparison is done by taking bit_diff := val xor ~target_val, which consists of
346    // ones if there is equality. So we're multiplying bit_diff[..., i] for all i to check this.
347    for mask in 0..depth {
348        // We use ~mask to avoid taking negation within the graph.
349        let column_id = mask_constants[((depth - 1) ^ mask) as usize].clone();
350        let bit_diff = val.add(column_id)?;
351        if single_bit {
352            result.push(bit_diff.clone());
353        } else {
354            let shape = match state_type.clone() {
355                Type::Array(shape, _) => shape,
356                _ => panic!("Cannot be here"),
357            };
358            let mut bit_columns = vec![];
359            for bit_index in 0..shape[shape.len() - 1] {
360                bit_columns.push(bit_diff.get_slice(vec![
361                    SliceElement::Ellipsis,
362                    SliceElement::SingleIndex(bit_index as i64),
363                ])?);
364            }
365            // Note: this can also be done with smaller depth (log(k) instead of k), we can
366            // do it if it even becomes a problem.
367            let mut equality = bit_columns[0].clone();
368            for bit_index in 1..shape[shape.len() - 1] {
369                equality = equality.multiply(bit_columns[bit_index as usize].clone())?;
370            }
371            result.push(equality.clone());
372        }
373    }
374
375    output.vector_to_array(output.create_vector(result[0].get_type()?, result)?)
376}
377
378fn create_mapping_matrix(
379    mapping: Vec<Node>,
380    output: Graph,
381    mask_constants: Vec<Node>,
382    state_type: Type,
383    single_bit: bool,
384) -> Result<Node> {
385    if single_bit {
386        // Single-bit optimization: we don't need to one-hot encode the mapping in this case.
387        return output.create_tuple(mapping);
388    }
389    // We're given 2 ** K mappings, where each mapping is a state. We want to produce
390    // the transition matrix of shape (2 ** K, 2 ** K), by one-hot-encoding every state.
391    // We want the transition matrix to be in the last two dimensions, so we do
392    // PermuteAxes at the end of the process.
393    let mut result = vec![];
394    let depth = mapping.len() as u64;
395    for node_to_map in mapping {
396        result.push(one_hot_encode(
397            node_to_map,
398            depth,
399            mask_constants.clone(),
400            output.clone(),
401            state_type.clone(),
402            single_bit,
403        )?);
404    }
405    let matrix = output.vector_to_array(output.create_vector(result[0].get_type()?, result)?)?;
406    Ok(matrix)
407}
408
409/// Creates mappings and one-hot-encoded initial state.  
410fn create_mappings(
411    state_type: Type,
412    mask_constants: Vec<Node>,
413    num_bits: u64,
414    single_bit: bool,
415    inputs_node: Node,
416    graph: Graph,
417    inliner: &mut dyn InlineState,
418) -> Result<Vec<Node>> {
419    let inputs_len = match inputs_node.get_type()? {
420        Type::Vector(len, _) => len,
421        _ => {
422            panic!("Inconsistency with type checker");
423        }
424    };
425    let mut mappings = vec![];
426    for i in 0..inputs_len {
427        let current_input = inputs_node.vector_get(
428            inliner
429                .output_graph()
430                .constant(scalar_type(UINT64), Value::from_scalar(i, UINT64)?)?,
431        )?;
432        let mut mapping_table = vec![];
433        for mask in 0..u64::pow(2, num_bits as u32) {
434            let current_state = mask_constants[mask as usize].clone();
435            inliner.assign_input_nodes(
436                graph.clone(),
437                vec![current_state.clone(), current_input.clone()],
438            )?;
439            let output = inliner.recursively_inline_graph(graph.clone())?;
440            inliner.unassign_nodes(graph.clone())?;
441            mapping_table.push(inliner.output_graph().tuple_get(output, 0)?);
442        }
443        // Note: from the perspective of graph size, it might be a good idea to batch
444        // matrix creation outsize of the loop.
445        mappings.push(create_mapping_matrix(
446            mapping_table,
447            inliner.output_graph().clone(),
448            mask_constants.clone(),
449            state_type.clone(),
450            single_bit,
451        )?);
452    }
453
454    if single_bit {
455        return Ok(mappings);
456    }
457    let mut mappings_arr = inliner
458        .output_graph()
459        .create_vector(mappings[0].get_type()?, mappings)?
460        .vector_to_array()?;
461    let shape_len = mappings_arr.get_type()?.get_dimensions().len();
462    let mut permutation: Vec<u64> = (1..shape_len).map(|x| x as u64).collect();
463    permutation.rotate_left(2);
464    permutation.insert(0, 0);
465    mappings_arr = mappings_arr.permute_axes(permutation)?;
466    let mut final_mappings = vec![];
467    for i in 0..inputs_len {
468        final_mappings.push(mappings_arr.get(vec![i])?);
469    }
470    Ok(final_mappings)
471}
472
#[cfg(test)]
mod tests {
    // Note: only basic behavior is tested here; end-to-end behavior is covered by the
    // general inliner tests.
    use super::*;
    use crate::data_values::Value;
    use crate::graphs::create_context;
    use crate::inline::inline_test_utils::{build_test_data, MockInlineState};

    /// Creates a mock inliner around `fake_graph` with nothing recorded yet.
    fn fresh_mock(fake_graph: Graph) -> MockInlineState {
        MockInlineState {
            fake_graph,
            inputs: vec![],
            inline_graph_calls: vec![],
            returned_nodes: vec![],
        }
    }

    /// A 10-bit state exceeds the allowed number of state bits and must be rejected.
    #[test]
    fn test_small_state_iterate_too_many_bits() {
        fn run() -> Result<()> {
            let context = create_context()?;
            let g = context.create_graph()?;
            let initial_state = g.constant(
                array_type(vec![10], BIT),
                Value::from_flattened_array(&vec![0; 10], BIT)?,
            )?;
            let mut inputs = vec![];
            for bit in vec![1; 5] {
                inputs.push(g.constant(scalar_type(BIT), Value::from_scalar(bit, BIT)?)?);
            }
            let inputs_node = g.create_vector(scalar_type(BIT), inputs)?;
            let mut inliner = fresh_mock(g);

            let g_inline = context.create_graph()?;
            let empty = g_inline.create_tuple(vec![])?;
            let graph_output = g_inline.create_tuple(vec![empty.clone(), empty])?;
            g_inline.set_output_node(graph_output)?;

            let res = inline_iterate_small_state(
                false,
                DepthOptimizationLevel::Extreme,
                g_inline,
                initial_state,
                inputs_node,
                &mut inliner,
            );
            assert!(res.is_err());
            Ok(())
        }
        run().unwrap();
    }

    /// Single-bit state with a non-empty per-step output: checks the number of input
    /// assignments recorded by the mock.
    #[test]
    fn test_small_state_iterate_nonempty_output() {
        fn run() -> Result<()> {
            let context = create_context()?;
            let (g, initial_state, inputs_node, _input_vals) =
                build_test_data(context.clone(), BIT)?;
            let mut inliner = fresh_mock(g);

            let g_inline = context.create_graph()?;
            let one_bit = g_inline.input(scalar_type(BIT))?;
            let graph_output = g_inline.create_tuple(vec![one_bit.clone(), one_bit])?;
            g_inline.set_output_node(graph_output)?;

            inline_iterate_small_state(
                true,
                DepthOptimizationLevel::Extreme,
                g_inline,
                initial_state,
                inputs_node,
                &mut inliner,
            )?;
            assert_eq!(inliner.inputs.len(), 15);
            Ok(())
        }
        run().unwrap();
    }

    /// Single-bit state with an empty per-step output: checks the number of graph
    /// inlining calls recorded by the mock.
    #[test]
    fn test_small_state_iterate_valid_case() {
        fn run() -> Result<()> {
            let context = create_context()?;
            let (g, initial_state, inputs_node, _input_vals) =
                build_test_data(context.clone(), BIT)?;
            let mut inliner = fresh_mock(g);

            let g_inline = context.create_graph()?;
            let one_bit = g_inline.input(scalar_type(BIT))?;
            let empty = g_inline.create_tuple(vec![])?;
            let graph_output = g_inline.create_tuple(vec![one_bit, empty])?;
            g_inline.set_output_node(graph_output)?;

            inline_iterate_small_state(
                true,
                DepthOptimizationLevel::Extreme,
                g_inline,
                initial_state,
                inputs_node,
                &mut inliner,
            )?;
            assert_eq!(inliner.inline_graph_calls.len(), 5 * 2);
            Ok(())
        }
        run().unwrap();
    }
}