1use crate::broadcast::number_to_index;
2use crate::data_types::{array_type, scalar_type, Type, BIT, UINT64};
3use crate::data_values::Value;
4use crate::errors::Result;
5use crate::graphs::{Graph, Node, SliceElement};
6use crate::inline::data_structures::{log_depth_sum, CombineOp};
7use crate::inline::inline_common::{
8 pick_prefix_sum_algorithm, DepthOptimizationLevel, InlineState,
9};
10use crate::ops::utils::constant_scalar;
11
/// Upper bound on the number of state bits accepted by
/// `inline_iterate_small_state`. The iteration body is tabulated once per
/// possible state (`2^bits` times per input element), so the cost grows
/// exponentially with this value.
const MAX_ALLOWED_STATE_BITS: u64 = 4;
13
14pub(super) fn inline_iterate_small_state(
35 single_bit: bool,
36 optimization_level: DepthOptimizationLevel,
37 graph: Graph,
38 initial_state: Node,
39 inputs_node: Node,
40 inliner: &mut dyn InlineState,
41) -> Result<(Node, Vec<Node>)> {
42 let graph_output_type = graph.get_output_node()?.get_type()?;
44 let output_element_type = match graph_output_type {
45 Type::Tuple(tuple_types) => (*tuple_types[1]).clone(),
46 _ => {
47 panic!("Inconsistency with type checker for Iterate output.");
48 }
49 };
50 let empty_output = match output_element_type {
51 Type::Tuple(tuple_types) => tuple_types.is_empty(),
52 _ => false,
53 };
54
55 let inputs_len = match inputs_node.get_type()? {
56 Type::Vector(len, _) => len,
57 _ => {
58 panic!("Inconsistency with type checker");
59 }
60 };
61 if inputs_len == 0 {
62 return Ok((initial_state, vec![]));
63 }
64
65 let num_bits = get_number_of_bits(initial_state.get_type()?, single_bit)?;
66 if num_bits > MAX_ALLOWED_STATE_BITS {
67 return Err(runtime_error!("Too many bits in the state"));
68 }
69 if num_bits == 0 {
70 return Err(runtime_error!(
71 "This inlining method doesn't support empty state"
72 ));
73 }
74 let num_masks = u64::pow(2, num_bits as u32);
75
76 let state_type = initial_state.get_type()?;
90 let mut mask_constants = vec![];
91 for mask in 0..u64::pow(2, num_bits as u32) {
92 let value = mask_to_value(state_type.clone(), num_bits, mask)?;
93 let mask_const = inliner.output_graph().constant(state_type.clone(), value)?;
94 mask_constants.push(mask_const);
95 }
96
97 let mappings = create_mappings(
100 initial_state.get_type()?,
101 mask_constants.clone(),
102 num_bits,
103 single_bit,
104 inputs_node.clone(),
105 graph.clone(),
106 inliner,
107 )?;
108
109 let unused_node = inliner.output_graph().zeros(scalar_type(BIT))?;
113 let initial_state_one_hot = if single_bit {
114 unused_node.clone()
115 } else {
116 let mut initial_state_one_hot = one_hot_encode(
117 initial_state.clone(),
118 num_masks,
119 mask_constants.clone(),
120 inliner.output_graph(),
121 state_type.clone(),
122 single_bit,
123 )?;
124 let mut new_shape = initial_state_one_hot.get_type()?.get_shape();
125 new_shape.insert(0, 1);
126 initial_state_one_hot =
127 initial_state_one_hot.reshape(array_type(new_shape.clone(), BIT))?;
128 let mut permutation: Vec<u64> = (0..new_shape.len()).map(|x| x as u64).collect();
129 permutation.rotate_left(2);
130 initial_state_one_hot = initial_state_one_hot.permute_axes(permutation)?; initial_state_one_hot
132 };
133
134 let masks_arr = if single_bit {
137 unused_node
138 } else {
139 let masks_arr = inliner
140 .output_graph()
141 .create_vector(mask_constants[0].get_type()?, mask_constants)?
142 .vector_to_array()?;
143 let masks_arr_shape = masks_arr.get_type()?.get_shape();
144 let mut masks_arr_permutation: Vec<u64> =
145 (0..masks_arr_shape.len()).map(|x| x as u64).collect();
146 masks_arr_permutation.rotate_left(1);
147 let rank = masks_arr_permutation.len();
148 masks_arr_permutation.swap(rank - 2, rank - 1);
149 masks_arr.permute_axes(masks_arr_permutation)?
150 };
151
152 let mut combiner = MappingCombiner {};
153 let mut bit_combiner = MappingCombiner1Bit {};
154 if empty_output {
155 let mut outputs = vec![];
157 let empty_tuple = inliner.output_graph().create_tuple(vec![])?;
158 for _ in 0..inputs_len {
159 outputs.push(empty_tuple.clone());
160 }
161
162 let final_mapping = if single_bit {
163 log_depth_sum(&mappings, &mut bit_combiner)?
164 } else {
165 log_depth_sum(&mappings, &mut combiner)?
166 };
167 let result = extract_state_from_mapping(
170 single_bit,
171 initial_state,
172 initial_state_one_hot,
173 final_mapping,
174 masks_arr,
175 state_type,
176 )?;
177 Ok((result, outputs))
178 } else {
179 let prefix_sums = if single_bit {
180 pick_prefix_sum_algorithm(inputs_len, optimization_level)(&mappings, &mut bit_combiner)?
181 } else {
182 pick_prefix_sum_algorithm(inputs_len, optimization_level)(&mappings, &mut combiner)?
183 };
184 let mut outputs = vec![];
185 for i in 0..inputs_len {
186 let state = if i == 0 {
187 initial_state.clone()
188 } else {
189 extract_state_from_mapping(
190 single_bit,
191 initial_state.clone(),
192 initial_state_one_hot.clone(),
193 prefix_sums[i as usize - 1].clone(),
194 masks_arr.clone(),
195 state_type.clone(),
196 )?
197 };
198 let input =
199 inputs_node.vector_get(constant_scalar(&inliner.output_graph(), i, UINT64)?)?;
200 inliner.assign_input_nodes(graph.clone(), vec![state, input])?;
201 let output = inliner.recursively_inline_graph(graph.clone())?;
202 inliner.unassign_nodes(graph.clone())?;
203 outputs.push(output.tuple_get(1)?);
204 }
205 let result = extract_state_from_mapping(
206 single_bit,
207 initial_state,
208 initial_state_one_hot,
209 prefix_sums[prefix_sums.len() - 1].clone(),
210 masks_arr,
211 state_type,
212 )?;
213 Ok((result, outputs))
214 }
215}
216
217struct MappingCombiner {}
218
219impl CombineOp<Node> for MappingCombiner {
220 fn combine(&mut self, arg1: Node, arg2: Node) -> Result<Node> {
221 arg1.matmul(arg2)
222 }
223}
224
225struct MappingCombiner1Bit {}
228
229impl CombineOp<Node> for MappingCombiner1Bit {
230 fn combine(&mut self, arg1: Node, arg2: Node) -> Result<Node> {
231 let bit10 = arg1.tuple_get(0)?;
236 let bit11 = arg1.tuple_get(1)?;
237 let bit20 = arg2.tuple_get(0)?;
238 let bit21 = arg2.tuple_get(1)?;
239 let distinct = bit20.add(bit21)?;
240 let bit0 = bit10.multiply(distinct.clone())?.add(bit20.clone())?;
241 let bit1 = bit11.multiply(distinct)?.add(bit20)?;
242 arg1.get_graph().create_tuple(vec![bit0, bit1])
243 }
244}
245
246fn extract_state_from_mapping(
247 single_bit: bool,
248 initial_state: Node,
249 initial_state_one_hot: Node,
250 mapping: Node,
251 masks_arr: Node,
252 state_type: Type,
253) -> Result<Node> {
254 if single_bit {
255 let g = mapping.get_graph();
257 let out0 = mapping.tuple_get(0)?;
258 let out1 = mapping.tuple_get(1)?;
259 let one = g.ones(scalar_type(BIT))?;
260 let not_initial_state = initial_state.add(one)?;
261 out0.multiply(not_initial_state)?
262 .add(out1.multiply(initial_state)?)
263 } else {
264 let output_state_one_hot = initial_state_one_hot.matmul(mapping)?;
275 let final_state = output_state_one_hot.matmul(masks_arr)?;
277 final_state.reshape(state_type)
278 }
279}
280
281fn get_number_of_bits(state_type: Type, single_bit: bool) -> Result<u64> {
282 match state_type {
283 Type::Scalar(scalar_type) => {
284 if !single_bit {
285 Err(runtime_error!(
286 "Scalar state is only supported in a single-bit mode"
287 ))
288 } else if scalar_type != BIT {
289 Err(runtime_error!("State must consist of bits"))
290 } else {
291 Ok(1)
292 }
293 }
294 Type::Array(shape, scalar_type) => {
295 if scalar_type != BIT {
296 Err(runtime_error!("State must consist of bits"))
297 } else if single_bit {
298 Ok(1)
299 } else {
300 Ok(shape[shape.len() - 1])
301 }
302 }
303 _ => Err(runtime_error!("Unsupported state type")),
304 }
305}
306
307fn mask_to_value(state_type: Type, num_bits: u64, mask: u64) -> Result<Value> {
308 let data_shape = match state_type.clone() {
309 Type::Scalar(scalar_type) => {
310 return Value::from_scalar(mask, scalar_type);
311 }
312 Type::Array(shape, _) => shape,
313 _ => panic!("Cannot be here"),
314 };
315 let value = Value::zero_of_type(state_type);
316 let mut bytes = value.access_bytes(|ref_bytes| Ok(ref_bytes.to_vec()))?;
317 for i in 0..data_shape.iter().product() {
318 let index = number_to_index(i, &data_shape);
319 let state_index = if num_bits == 1 {
320 0
321 } else {
322 index[index.len() - 1]
323 };
324 let bit = ((mask >> state_index) & 1) as u8;
325 let position = i / 8;
326 let offset = i % 8;
327 bytes[position as usize] &= !(1 << offset);
328 bytes[position as usize] |= bit << offset;
329 }
330 Ok(Value::from_bytes(bytes))
331}
332
333fn one_hot_encode(
334 val: Node,
335 depth: u64,
336 mask_constants: Vec<Node>,
337 output: Graph,
338 state_type: Type,
339 single_bit: bool,
340) -> Result<Node> {
341 let mut result = vec![];
342 for mask in 0..depth {
348 let column_id = mask_constants[((depth - 1) ^ mask) as usize].clone();
350 let bit_diff = val.add(column_id)?;
351 if single_bit {
352 result.push(bit_diff.clone());
353 } else {
354 let shape = match state_type.clone() {
355 Type::Array(shape, _) => shape,
356 _ => panic!("Cannot be here"),
357 };
358 let mut bit_columns = vec![];
359 for bit_index in 0..shape[shape.len() - 1] {
360 bit_columns.push(bit_diff.get_slice(vec![
361 SliceElement::Ellipsis,
362 SliceElement::SingleIndex(bit_index as i64),
363 ])?);
364 }
365 let mut equality = bit_columns[0].clone();
368 for bit_index in 1..shape[shape.len() - 1] {
369 equality = equality.multiply(bit_columns[bit_index as usize].clone())?;
370 }
371 result.push(equality.clone());
372 }
373 }
374
375 output.vector_to_array(output.create_vector(result[0].get_type()?, result)?)
376}
377
378fn create_mapping_matrix(
379 mapping: Vec<Node>,
380 output: Graph,
381 mask_constants: Vec<Node>,
382 state_type: Type,
383 single_bit: bool,
384) -> Result<Node> {
385 if single_bit {
386 return output.create_tuple(mapping);
388 }
389 let mut result = vec![];
394 let depth = mapping.len() as u64;
395 for node_to_map in mapping {
396 result.push(one_hot_encode(
397 node_to_map,
398 depth,
399 mask_constants.clone(),
400 output.clone(),
401 state_type.clone(),
402 single_bit,
403 )?);
404 }
405 let matrix = output.vector_to_array(output.create_vector(result[0].get_type()?, result)?)?;
406 Ok(matrix)
407}
408
409fn create_mappings(
411 state_type: Type,
412 mask_constants: Vec<Node>,
413 num_bits: u64,
414 single_bit: bool,
415 inputs_node: Node,
416 graph: Graph,
417 inliner: &mut dyn InlineState,
418) -> Result<Vec<Node>> {
419 let inputs_len = match inputs_node.get_type()? {
420 Type::Vector(len, _) => len,
421 _ => {
422 panic!("Inconsistency with type checker");
423 }
424 };
425 let mut mappings = vec![];
426 for i in 0..inputs_len {
427 let current_input = inputs_node.vector_get(
428 inliner
429 .output_graph()
430 .constant(scalar_type(UINT64), Value::from_scalar(i, UINT64)?)?,
431 )?;
432 let mut mapping_table = vec![];
433 for mask in 0..u64::pow(2, num_bits as u32) {
434 let current_state = mask_constants[mask as usize].clone();
435 inliner.assign_input_nodes(
436 graph.clone(),
437 vec![current_state.clone(), current_input.clone()],
438 )?;
439 let output = inliner.recursively_inline_graph(graph.clone())?;
440 inliner.unassign_nodes(graph.clone())?;
441 mapping_table.push(inliner.output_graph().tuple_get(output, 0)?);
442 }
443 mappings.push(create_mapping_matrix(
446 mapping_table,
447 inliner.output_graph().clone(),
448 mask_constants.clone(),
449 state_type.clone(),
450 single_bit,
451 )?);
452 }
453
454 if single_bit {
455 return Ok(mappings);
456 }
457 let mut mappings_arr = inliner
458 .output_graph()
459 .create_vector(mappings[0].get_type()?, mappings)?
460 .vector_to_array()?;
461 let shape_len = mappings_arr.get_type()?.get_dimensions().len();
462 let mut permutation: Vec<u64> = (1..shape_len).map(|x| x as u64).collect();
463 permutation.rotate_left(2);
464 permutation.insert(0, 0);
465 mappings_arr = mappings_arr.permute_axes(permutation)?;
466 let mut final_mappings = vec![];
467 for i in 0..inputs_len {
468 final_mappings.push(mappings_arr.get(vec![i])?);
469 }
470 Ok(final_mappings)
471}
472
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data_values::Value;
    use crate::graphs::create_context;
    use crate::inline::inline_test_utils::{build_test_data, MockInlineState};

    /// A mock inliner that emits into `g` and has recorded nothing yet.
    fn fresh_inliner(g: &Graph) -> MockInlineState {
        MockInlineState {
            fake_graph: g.clone(),
            inputs: vec![],
            inline_graph_calls: vec![],
            returned_nodes: vec![],
        }
    }

    #[test]
    fn test_small_state_iterate_too_many_bits() -> Result<()> {
        let c = create_context()?;
        let g = c.create_graph()?;
        // A 10-bit state exceeds MAX_ALLOWED_STATE_BITS.
        let initial_state = g.constant(
            array_type(vec![10], BIT),
            Value::from_flattened_array(&vec![0; 10], BIT)?,
        )?;
        let mut inputs = Vec::new();
        for bit in vec![1; 5] {
            inputs.push(g.constant(scalar_type(BIT), Value::from_scalar(bit, BIT)?)?);
        }
        let inputs_node = g.create_vector(scalar_type(BIT), inputs)?;
        let mut inliner = fresh_inliner(&g);

        let g_inline = c.create_graph()?;
        let empty = g_inline.create_tuple(vec![])?;
        g_inline.set_output_node(g_inline.create_tuple(vec![empty.clone(), empty])?)?;

        let outcome = inline_iterate_small_state(
            false,
            DepthOptimizationLevel::Extreme,
            g_inline.clone(),
            initial_state.clone(),
            inputs_node.clone(),
            &mut inliner,
        );
        assert!(outcome.is_err());
        Ok(())
    }

    #[test]
    fn test_small_state_iterate_nonempty_output() -> Result<()> {
        let c = create_context()?;
        let (g, initial_state, inputs_node, _input_vals) = build_test_data(c.clone(), BIT)?;
        let mut inliner = fresh_inliner(&g);

        let g_inline = c.create_graph()?;
        let state_bit = g_inline.input(scalar_type(BIT))?;
        g_inline.set_output_node(g_inline.create_tuple(vec![state_bit.clone(), state_bit])?)?;

        inline_iterate_small_state(
            true,
            DepthOptimizationLevel::Extreme,
            g_inline.clone(),
            initial_state.clone(),
            inputs_node.clone(),
            &mut inliner,
        )?;
        // 2 tabulations per input plus 1 output inline per input
        // (assuming build_test_data provides 5 inputs).
        assert_eq!(inliner.inputs.len(), 15);
        Ok(())
    }

    #[test]
    fn test_small_state_iterate_valid_case() -> Result<()> {
        let c = create_context()?;
        let (g, initial_state, inputs_node, _input_vals) = build_test_data(c.clone(), BIT)?;
        let mut inliner = fresh_inliner(&g);

        let g_inline = c.create_graph()?;
        let state_bit = g_inline.input(scalar_type(BIT))?;
        let empty = g_inline.create_tuple(vec![])?;
        g_inline.set_output_node(g_inline.create_tuple(vec![state_bit, empty])?)?;

        inline_iterate_small_state(
            true,
            DepthOptimizationLevel::Extreme,
            g_inline.clone(),
            initial_state.clone(),
            inputs_node.clone(),
            &mut inliner,
        )?;
        // Empty per-step output: only the 2-state tabulation per input is inlined.
        assert_eq!(inliner.inline_graph_calls.len(), 5 * 2);
        Ok(())
    }
}