1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
//! Interval expressions: parsing, formatting, and sampling of numeric intervals.

use std::{
    fmt::{Display, Write},
    str::FromStr,
};

use owo_colors::OwoColorize;
use rand::{
    distributions::{Open01, OpenClosed01},
    Rng,
};

use crate::regex;
use crate::Pcg;

/// Integer type for the endpoints and samples of integer intervals.
pub type Int = i32;
/// Floating-point type for the endpoints and samples of float intervals.
pub type Float = f32;

/// A parsed interval: its numeric bounds plus whether each endpoint is included.
#[derive(Debug, Clone, PartialEq)]
pub struct Interval {
    /// `true` when the lower endpoint belongs to the interval (`[`), `false` for `(`.
    low_inc: bool,
    /// `true` when the upper endpoint belongs to the interval (`]`), `false` for `)`.
    high_inc: bool,
    /// The numeric bounds of the interval.
    kind: IntervalKind,
}

/// Numeric bounds of an [`Interval`].
///
/// `Int` ranges are stored already normalized to the half-open set of values that can be
/// sampled (an excluded start is bumped up by one, an included end is bumped up by one).
/// `Float` ranges hold the endpoints exactly as written; openness is tracked by the
/// interval's `low_inc`/`high_inc` flags.
#[derive(Debug, Clone, PartialEq)]
enum IntervalKind {
    /// Normalized half-open integer range.
    Int(std::ops::Range<Int>),
    /// Float endpoints as written.
    Float(std::ops::Range<Float>),
}

/// Error returned by [`Interval::from_str`].
#[derive(Debug, thiserror::Error)]
pub enum IntervalParseError {
    /// The input does not have the shape of any supported interval syntax; callers may
    /// try another interpretation of the input.
    #[error("the input is not an interval")]
    NoMatch,
    /// The input has an interval shape but cannot be used; the payload explains why
    /// (e.g. an endpoint that fails to parse or overflows, or an empty interval).
    #[error("invalid interval: {0}")]
    Invalid(String),
}

impl FromStr for Interval {
    type Err = IntervalParseError;

    /// Parse either the range shorthand (`a..b`, `a..=b`) or the bracketed notation
    /// (`[a..b)`, `(a, b]`, ...).
    ///
    /// The shorthand is tried first. Only when it does not match at all is the bracketed
    /// form attempted, so a shorthand that matches but is invalid reports its own error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        parse_range(s).or_else(|err| match err {
            IntervalParseError::NoMatch => parse_interval(s),
            invalid => Err(invalid),
        })
    }
}

/// Label for the first endpoint in error messages.
const START: &'static str = "start";
/// Label for the second endpoint in error messages.
const END: &'static str = "end";
/// Message suffix used when adjusting an endpoint overflows its numeric type.
const TOO_BIG: &'static str = "value is too big";
/// Message used when the interval contains no values.
const EMPTY_INTERVAL: &'static str = "the interval is empty";

/// Parse an integer endpoint, prefixing any error with `part` (e.g. `"start"`).
fn parse_int(num: &str, part: &str) -> Result<Int, IntervalParseError> {
    match num.parse::<Int>() {
        Ok(value) => Ok(value),
        Err(reason) => Err(IntervalParseError::Invalid(format!("{part}: {reason}"))),
    }
}

/// Parse a float endpoint, prefixing any error with `part` (e.g. `"end"`).
fn parse_float(num: &str, part: &str) -> Result<Float, IntervalParseError> {
    match num.parse::<Float>() {
        Ok(value) => Ok(value),
        Err(reason) => Err(IntervalParseError::Invalid(format!("{part}: {reason}"))),
    }
}

/// Convert user-facing integer endpoints into the half-open range stored internally.
///
/// An excluded start is bumped up by one and an included end is bumped up by one. The
/// result therefore always has the form `first..one_past_last`.
///
/// # Errors
/// Returns [`IntervalParseError::Invalid`] when one of those adjustments overflows
/// [`Int`], or when the resulting range contains no values.
fn build_int_range(
    start: Int,
    end: Int,
    low_inc: bool,
    high_inc: bool,
) -> Result<std::ops::Range<Int>, IntervalParseError> {
    let overflow = |part: &str| IntervalParseError::Invalid(format!("{part} {TOO_BIG}"));

    // An excluded start means the smallest sampled value is `start + 1`.
    let first = if low_inc {
        start
    } else {
        start.checked_add(1).ok_or_else(|| overflow(START))?
    };
    // An included end means the exclusive upper bound is `end + 1`.
    let past_last = if high_inc {
        end.checked_add(1).ok_or_else(|| overflow(END))?
    } else {
        end
    };

    if first < past_last {
        Ok(first..past_last)
    } else {
        Err(IntervalParseError::Invalid(EMPTY_INTERVAL.to_string()))
    }
}

fn parse_interval(s: &str) -> Result<Interval, IntervalParseError> {
    let re = regex!(
        r"\A([\[\(])\s*((?:\+|-)?(?:\d*\.)?\d+)\s*(,|\.{2})\s*((?:\+|-)?(?:\d*\.)?\d+)\s*([\]\)])\z"
    );

    let caps = re.captures(s).ok_or(IntervalParseError::NoMatch)?;

    let low_inc = &caps[1] == "[";
    let high_inc = &caps[5] == "]";
    let start = &caps[2];
    let end = &caps[4];
    let is_float = &caps[3] == "," || start.contains('.') || end.contains('.');

    let kind = if is_float {
        let start = parse_float(start, START)?;
        let end = parse_float(end, END)?;
        let range = start..end;
        if range.is_empty() {
            return Err(IntervalParseError::Invalid(EMPTY_INTERVAL.to_string()));
        }
        IntervalKind::Float(start..end)
    } else {
        let start = parse_int(start, START)?;
        let end = parse_int(end, END)?;
        let range = build_int_range(start, end, low_inc, high_inc)?;
        IntervalKind::Int(range)
    };
    Ok(Interval {
        low_inc,
        high_inc,
        kind,
    })
}

/// Parse the range shorthand `a..b` (end excluded) or `a..=b` (end included).
///
/// Both endpoints are optionally signed integers, and the start is always included.
///
/// # Errors
/// - [`IntervalParseError::NoMatch`] if `s` is not in this shorthand. The caller can then
///   fall back to the bracketed notation.
/// - [`IntervalParseError::Invalid`] if an endpoint does not fit in [`Int`] or the range
///   is empty.
fn parse_range(s: &str) -> Result<Interval, IntervalParseError> {
    // The dots must be escaped. An unescaped `..` matches any two characters, which let
    // plain numbers like "1005" parse as the range `1..5`. It also made "12345" fail as an
    // empty range `12..5` instead of reporting `NoMatch`.
    let re = regex!(r"\A((?:\+|-)?\d+)\.\.(=)?((?:\+|-)?\d+)\z");

    let caps = re.captures(s).ok_or(IntervalParseError::NoMatch)?;

    let start = parse_int(&caps[1], START)?;
    let end = parse_int(&caps[3], END)?;
    // `..=` (group 2 present) makes the end inclusive.
    let inclusive = caps.get(2).is_some();

    let range = build_int_range(start, end, true, inclusive)?;

    Ok(Interval {
        low_inc: true,
        high_inc: inclusive,
        kind: IntervalKind::Int(range),
    })
}

impl Display for Interval {
    /// Formats the interval in bracket notation.
    ///
    /// `[`/`]` mark an included endpoint and `(`/`)` an excluded one. Integer intervals
    /// separate the endpoints with `..` (`[1..10)`); float intervals use `, ` (`(0.5, 1]`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_char(if self.low_inc { '[' } else { '(' })?;

        match &self.kind {
            IntervalKind::Int(r) => {
                // Integer ranges are stored half-open (`first..one_past_last`). Undo the
                // +1 adjustments applied at construction to show the endpoints as written.
                let start = if self.low_inc {
                    r.start
                } else {
                    r.start
                        .checked_sub(1)
                        .expect("excluded start was incremented at construction")
                };
                let end = if self.high_inc {
                    r.end
                        .checked_sub(1)
                        .expect("included end was incremented at construction")
                } else {
                    r.end
                };
                write!(f, "{start}..{end}")?;
            }
            IntervalKind::Float(r) => write!(f, "{}, {}", r.start, r.end)?,
        }

        // The closing bracket reflects whether the *upper* endpoint is included.
        f.write_char(if self.high_inc { ']' } else { ')' })
    }
}

/// A value drawn from an [`Interval`], kept together with the interval it came from.
///
/// With `{}` it displays as `interval: value`. The
/// [alternate modifier](std::fmt#sign0) (`{:#}`) displays only the sampled value.
#[derive(Debug, Clone, PartialEq)]
pub struct IntervalSample {
    /// The interval that was sampled.
    interval: Interval,
    /// The number that was drawn from it.
    value: Num,
}

/// A sampled number: an [`Int`] for integer intervals, a [`Float`] for float intervals.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Num {
    /// Value drawn from (or endpoint of) an integer interval.
    Int(Int),
    /// Value drawn from (or endpoint of) a float interval.
    Float(Float),
}

impl IntervalSample {
    /// Sampled value
    pub fn value(&self) -> Num {
        self.value
    }

    /// Start endpoint from the source interval and boolean true if included
    pub fn start(&self) -> (Num, bool) {
        self.interval.start()
    }
    /// End endpoint from the source interval and boolean true if included
    pub fn end(&self) -> (Num, bool) {
        self.interval.end()
    }
}

impl Interval {
    pub(crate) fn eval(&self, rng: &mut Pcg) -> IntervalSample {
        let Interval {
            low_inc,
            high_inc,
            kind,
        } = self;
        let value = match kind {
            IntervalKind::Int(r) => Num::Int(rng.gen_range(r.clone())),
            IntervalKind::Float(r) => {
                let f = match (low_inc, high_inc) {
                    (true, true) => rng.gen_range(r.start..=r.end),
                    (true, false) => rng.gen_range(r.start..r.end),
                    (false, true) => {
                        let val: Float = rng.sample(OpenClosed01);
                        let scale = r.end - r.start;
                        val * scale + r.start
                    }
                    (false, false) => {
                        let val: Float = rng.sample(Open01);
                        let scale = r.end - r.start;
                        val * scale + r.start
                    }
                };
                Num::Float(f)
            }
        };
        IntervalSample {
            interval: self.clone(),
            value,
        }
    }

    fn start(&self) -> (Num, bool) {
        let inc = self.low_inc;
        let n = match &self.kind {
            IntervalKind::Int(r) => {
                let mut start = r.start;
                if !inc {
                    start -= 1;
                }
                Num::Int(start)
            }
            IntervalKind::Float(r) => Num::Float(r.start),
        };
        (n, inc)
    }

    fn end(&self) -> (Num, bool) {
        let inc = self.high_inc;
        let n = match &self.kind {
            IntervalKind::Int(r) => {
                let mut end = r.end;
                if inc {
                    end -= 1;
                }
                Num::Int(end)
            }
            IntervalKind::Float(r) => Num::Float(r.end),
        };
        (n, inc)
    }
}

impl Display for IntervalSample {
    /// `{}` prints `interval: value` with the interval styled bold yellow; `{:#}` prints
    /// only the value, forwarding the formatter to it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if !f.alternate() {
            return write!(f, "{}: {}", self.interval.bold().yellow(), self.value);
        }
        Display::fmt(&self.value, f)
    }
}

impl Display for Num {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Num::Int(n) => n.fmt(f),
            Num::Float(n) => n.fmt(f),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use test_case::test_case;

    // Integer intervals: the expected value is the normalized half-open range stored
    // internally (an excluded start is bumped up by one, an included end is bumped up
    // by one).
    #[test_case("[1..10]" => 1..11 ; "inclusive")]
    #[test_case("[1..10)" => 1..10 ; "end exclusive")]
    #[test_case("(1..10]" => 2..11 ; "start exclusive")]
    #[test_case("(1..10)" => 2..10 ; "exclusive")]
    #[test_case("1..10" => 1..10 ; "alt exclusive")]
    #[test_case("1..=10" => 1..11 ; "alt inclusive")]
    #[test_case("[-5..-3)" => -5..-3 ; "neg")]
    #[test_case("[-5..-3]" => -5..-2 ; "neg inclusive")]
    #[test_case("-5..-3" => -5..-3 ; "alt neg")]
    #[test_case("-5..=-3" => -5..-2 ; "alt neg inclusive")]
    fn parse_int(s: &str) -> std::ops::Range<Int> {
        let interval = s.parse::<Interval>().expect("failed to parse");
        match interval.kind {
            IntervalKind::Int(r) => r,
            IntervalKind::Float(_) => panic!("not int"),
        }
    }

    // Float intervals: endpoints are stored as written and returned alongside the
    // inclusion flags. A `,` separator or a decimal point in either endpoint selects a
    // float interval.
    #[test_case("[1,10]" => (1.0..10.0, true, true) ; "inclusive")]
    #[test_case("[1,10)" => (1.0..10.0, true, false) ; "end exclusive")]
    #[test_case("(1,10]" => (1.0..10.0, false, true) ; "start exclusive")]
    #[test_case("(1,10)" => (1.0..10.0, false, false) ; "exclusive")]
    #[test_case("(1.0,10.0)" => (1.0..10.0, false, false) ; "full decimal")]
    #[test_case("(1.0,10)" => (1.0..10.0, false, false) ; "only first decimal")]
    #[test_case("(1,10.0)" => (1.0..10.0, false, false) ; "only second decimal")]
    #[test_case("(1.,10)" => panics "failed to parse" ; "bad partial decimal start")] // no float with trailing .
    #[test_case("(0,.9)" => (0.0..0.9, false, false) ; "bad partial decimal end")] // a leading "." is accepted
    #[test_case("(1.0,10.0)" => (1.0..10.0, false, false) ; "partial decimal")] // NOTE(review): same input as "full decimal"; possibly meant a different case — confirm
    #[test_case("(1.0..10.0)" => (1.0..10.0, false, false) ; "decimal on int")]
    #[test_case("(.5..1)" => (0.5..1.0, false, false) ; "one decimal on int")]
    #[test_case("(1..10)" => panics "not float" ; "int")]
    #[test_case("(-1, 1)" => (-1.0..1.0, false, false) ; "neg start")]
    #[test_case("(2, -1)" => panics "failed to parse" ; "neg end")] // start > end
    #[test_case("(-2, -1)" => (-2.0..-1.0, false, false) ; "neg")]
    fn parse_float(s: &str) -> (std::ops::Range<Float>, bool, bool) {
        let interval = s.parse::<Interval>().expect("failed to parse");
        match interval.kind {
            IntervalKind::Int(_) => panic!("not float"),
            IntervalKind::Float(r) => (r, interval.low_inc, interval.high_inc),
        }
    }

    // `start()` reports the endpoint as written, undoing the integer normalization.
    #[test_case("1..5" => (Num::Int(1), true) ; "range")]
    #[test_case("[1..5]" => (Num::Int(1), true) ; "inclusive")]
    #[test_case("(1..5]" => (Num::Int(1), false) ; "exclusive")]
    fn interval_start(s: &str) -> (Num, bool) {
        let interval = s.parse::<Interval>().expect("failed to parse");
        interval.start()
    }

    // `end()` reports the endpoint as written, undoing the integer normalization.
    #[test_case("1..5" => (Num::Int(5), false) ; "range exclusive")]
    #[test_case("1..=5" => (Num::Int(5), true) ; "range inclusive")]
    #[test_case("[1..5]" => (Num::Int(5), true) ; "inclusive")]
    #[test_case("[1..5)" => (Num::Int(5), false) ; "exclusive")]
    fn interval_end(s: &str) -> (Num, bool) {
        let interval = s.parse::<Interval>().expect("failed to parse");
        interval.end()
    }
}