1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
crate::ix!();

/**
  | Base implementation for a segment reduction
  | op that leverages continuity of the data.
  |
  | Assumes that segments are sorted and that
  | there are no skipped indices.
  |
  | Type parameters, as used by the reference
  | C++ body kept in `run_on_device`:
  |
  | - `T`: element type of the output tensor
  | - `SIndex`: element type of the segment ids
  | - `Context`: type of the `context` field
  | - `RangeReducer`: functor invoked once per
  |   run of equal segment ids
  | - `InputAccessor`: type of the
  |   `input_accessor` field
  |
  | NOTE(review): the previous version of this
  | comment ended with the fragment
  | `class InputAccessor = BaseInputAccessor<T>>`,
  | which looks like a leftover C++ default
  | template argument. No such default is
  | declared on this struct — confirm against
  | the upstream C++ source.
  |
  */
#[USE_OPERATOR_CONTEXT_FUNCTIONS]
pub struct AbstractSortedSegmentRangeOp<T,SIndex,Context,RangeReducer,InputAccessor> {
    // NOTE(review): presumably the shared operator state behind the
    // `Input`/`Output` accessors — confirm against `OperatorStorage`.
    storage:             OperatorStorage,
    // Handed to the range reducer (`&context_`) in the reference C++ body.
    context:             Context,
    // In the reference C++ body this checks the data input (`observeInput`)
    // and provides the per-segment block pointer (`getBlockPtr`).
    input_accessor:      InputAccessor,
    // Markers for type parameters that appear in no other field.
    phantom:             PhantomData<T>,
    phantomSIndex:       PhantomData<SIndex>,
    phantomRangeReducer: PhantomData<RangeReducer>,
}

// Names the inputs of `AbstractSortedSegmentRangeOp`. The reference C++ body
// kept in `run_on_device` reads them as `Input(DATA)` and `Input(SEGMENT_IDS)`.
// NOTE(review): presumably the tags map to input indices 0 and 1 in
// declaration order — confirm against the `input_tags!` definition.
input_tags!{
    AbstractSortedSegmentRangeOp {
        Data,
        SegmentIds
    }
}

impl<T,SIndex,Context,RangeReducer,InputAccessor> 
AbstractSortedSegmentRangeOp<T,SIndex,Context,RangeReducer,InputAccessor> {

    /**
      | Number of inputs; equals the count of
      | tags declared for this op (`Data`,
      | `SegmentIds`).
      |
      | NOTE(review): nothing visible in this
      | file reads this constant — confirm where
      | it is consumed.
      */
    const kNumInputs: i32 = 2;
    
    /**
      | Runs the sorted-segment range reduction.
      |
      | Not yet ported: the body is `todo!()`,
      | so every call currently panics. The
      | block comment inside the body preserves
      | the C++ implementation to be translated.
      |
      | According to that C++ reference, the op:
      |
      | - requires SEGMENT_IDS to be a vector
      |   with the same length `N` as the outer
      |   dimension of DATA;
      | - shapes the output like DATA, with the
      |   outer dimension replaced by `K`, the
      |   last segment id plus one (0 when
      |   `N == 0`);
      | - returns `true` right away when
      |   `N == 0`;
      | - requires the ids to start at 0 and to
      |   grow by exactly one from each run of
      |   equal ids to the next;
      | - invokes `RangeReducer` once per run of
      |   equal ids, with
      |   `block_size = DATA.numel() / N`, the
      |   run length, the input block pointer
      |   for the run, the destination
      |   `out + block_size * id`, and the
      |   context.
      |
      | The reference returns `true` on every
      | path that does not fail an enforce check.
      |
      | Panics: always, until the body is
      | implemented.
      */
    #[inline] pub fn run_on_device(&mut self) -> bool {
        
        // Placeholder: panics unconditionally, so nothing below is reached.
        todo!();
        // Untranslated C++ reference implementation follows.
        /*
            auto& dataInput = Input(DATA);
        auto& segment_ids = Input(SEGMENT_IDS);

        CAFFE_ENFORCE_EQ(1, segment_ids.dim(), "SEGMENT_IDS must be a vector");
        auto N = segment_ids.size(0);
        CAFFE_ENFORCE_EQ(
            N,
            dataInput.size(0),
            "SEGMENT_IDS must have the same length as outer dimension of DATA");

        OPERATOR_NEEDS_FEATURE(
            inputAccessor_.observeInput(dataInput),
            "Unsupported input type: ",
            dataInput.dtype().name(),
            ".");

        const SIndex* s_ids = segment_ids.template data<SIndex>();

        const SIndex K = N > 0 ? s_ids[N - 1] + 1 : 0;
        auto shape = dataInput.sizes().vec();
        shape[0] = K;
        auto* output = Output(0, shape, at::dtype<T>());

        T* out = output->template mutable_data<T>();

        if (N == 0) {
          return true;
        }

        int64_t block_size = dataInput.numel() / N;

        // Assume the segments are sorted and there are no gaps
        CAFFE_ENFORCE_EQ(0, s_ids[0], "Indices must be sorted and not have gaps");
        for (int64_t i = 0; i < N;) {
          int64_t start = i;
          for (++i; i < N && s_ids[start] == s_ids[i]; ++i)
            ;

          RangeReducer()(
              block_size,
              i - start,
              inputAccessor_.getBlockPtr(block_size, start, i - start),
              out + block_size * s_ids[start],
              &context_);

          // check correctness of the next segment
          if (i < N) {
            CAFFE_ENFORCE_EQ(
                s_ids[start] + 1,
                s_ids[i],
                "Indices must be sorted and not have gaps");
          }
        }
        return true;
        */
    }
}