Skip to main content

dbgen/
eval.rs

1//! Evaluating compiled expressions into values.
2
3use crate::{
4    error::Error,
5    functions::{Arguments, Function},
6    parser::{Expr, QName},
7    span::{ResultExt, Span, SpanExt, S},
8    value::Value,
9};
10use chrono::{NaiveDateTime, Utc};
11use rand::{distributions::Bernoulli, seq::SliceRandom, Rng, RngCore};
12use rand_distr::{LogNormal, Uniform};
13use rand_regex::EncodedString;
14use std::{cmp::Ordering, fmt, fs, ops::Range, path::PathBuf, sync::Arc};
15use tzfile::{ArcTz, Tz};
16use zipf::ZipfDistribution;
17
18/// Environment information shared by all compilations
19#[derive(Clone, Debug)]
20pub struct CompileContext {
21    /// The zoneinfo directory where timezones can be read.
22    pub zoneinfo: PathBuf,
23    /// The time zone used to interpret strings into timestamps.
24    pub time_zone: ArcTz,
25    /// The current timestamp in UTC.
26    pub current_timestamp: NaiveDateTime,
27    /// The global variables.
28    pub variables: Box<[Value]>,
29}
30
31impl CompileContext {
32    /// Creates a default compile context storing the given number of variables.
33    pub fn new(variables_count: usize) -> Self {
34        Self {
35            zoneinfo: PathBuf::from("/usr/share/zoneinfo"),
36            time_zone: ArcTz::new(Utc.into()),
37            current_timestamp: NaiveDateTime::from_timestamp(0, 0),
38            variables: vec![Value::Null; variables_count].into_boxed_slice(),
39        }
40    }
41
42    /// Parses the time zone name into a time zone object.
43    pub fn parse_time_zone(&self, tz: &str) -> Result<ArcTz, Error> {
44        Ok(ArcTz::new(if tz == "UTC" {
45            Utc.into()
46        } else {
47            let path = self.zoneinfo.join(tz);
48            let content = fs::read(&path).map_err(|source| Error::Io {
49                action: "read time zone file",
50                path,
51                source,
52            })?;
53            Tz::parse(tz, &content).map_err(|source| Error::InvalidTimeZone {
54                time_zone: tz.to_owned(),
55                source,
56            })?
57        }))
58    }
59}
60
61/// The external mutable state used during evaluation.
62pub struct State {
63    pub(crate) row_num: u64,
64    /// Defines the value of `subrownum`.
65    pub sub_row_num: u64,
66    rng: Box<dyn RngCore>,
67    compile_context: CompileContext,
68}
69
70impl fmt::Debug for State {
71    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72        f.debug_struct("State")
73            .field("row_num", &self.row_num)
74            .field("sub_row_num", &self.sub_row_num)
75            .field("rng", &())
76            .field("variables", &self.compile_context.variables)
77            .finish()
78    }
79}
80
81impl State {
82    /// Creates a new state.
83    ///
84    /// # Parameters
85    ///
86    /// - `row_num`: The starting row number in this state. The first file should have this set
87    ///     to 1, and the second to `rows_count * inserts_count + 1`, etc.
88    /// - `rng`: The seeded random number generator.
89    pub fn new(row_num: u64, rng: Box<dyn RngCore>, compile_context: CompileContext) -> Self {
90        Self {
91            row_num,
92            sub_row_num: 1,
93            rng,
94            compile_context,
95        }
96    }
97
98    /// Extracts the compile context from the state.
99    pub fn into_compile_context(self) -> CompileContext {
100        self.compile_context
101    }
102
103    /// Increases the rownum by 1.
104    pub fn increase_row_num(&mut self) {
105        self.row_num += 1;
106    }
107}
108
109/// A compiled table
110#[derive(Debug)]
111pub struct Table {
112    /// Table name.
113    pub name: QName,
114    /// Content of table schema.
115    pub content: String,
116    /// The ranges in `content` which column names appear.
117    pub column_name_ranges: Vec<Range<usize>>,
118    /// Compiled row.
119    pub row: Row,
120    /// Information of dervied tables (index, and number of rows to generate)
121    pub derived: Vec<(usize, Compiled)>,
122}
123
/// Schema information borrowed from a compiled [`Table`].
///
/// This is a cheap, copyable view: every field borrows from the table it was taken from.
#[derive(Clone, Copy, Debug)]
pub struct Schema<'a> {
    /// Table name, either qualified or unqualified.
    pub name: &'a str,
    /// Content of the table schema.
    pub content: &'a str,
    /// The byte ranges in `content` where the column names appear.
    column_name_ranges: &'a [Range<usize>],
}
134
135impl<'a> Schema<'a> {
136    /// Returns an iterator of column names associated with the table.
137    pub fn column_names(&self) -> impl Iterator<Item = &str> + '_ {
138        self.column_name_ranges.iter().map(move |r| &self.content[r.clone()])
139    }
140}
141
142impl Table {
143    /// Gets the schema associated with the table.
144    pub fn schema(&self, qualified: bool) -> Schema<'_> {
145        Schema {
146            name: self.name.table_name(qualified),
147            content: &self.content,
148            column_name_ranges: &self.column_name_ranges,
149        }
150    }
151}
152
153impl CompileContext {
154    /// Compiles a table.
155    pub fn compile_table(&self, table: crate::parser::Table) -> Result<Table, S<Error>> {
156        Ok(Table {
157            name: table.name,
158            content: table.content,
159            column_name_ranges: table.column_name_ranges,
160            row: self.compile_row(table.exprs)?,
161            derived: table
162                .derived
163                .into_iter()
164                .map(|(i, e)| self.compile(e).map(|c| (i, c)))
165                .collect::<Result<_, _>>()?,
166        })
167    }
168}
169
170/// Represents a row of compiled values.
171#[derive(Debug)]
172pub struct Row(Vec<Compiled>);
173
174impl CompileContext {
175    /// Compiles a vector of parsed expressions into a row.
176    pub fn compile_row(&self, exprs: Vec<S<Expr>>) -> Result<Row, S<Error>> {
177        Ok(Row(exprs
178            .into_iter()
179            .map(|e| self.compile(e))
180            .collect::<Result<_, _>>()?))
181    }
182}
183
184impl Row {
185    /// Evaluates the row into a vector of values.
186    pub fn eval(&self, state: &mut State) -> Result<Vec<Value>, S<Error>> {
187        let mut result = Vec::with_capacity(self.0.len());
188        for compiled in &self.0 {
189            result.push(compiled.eval(state)?);
190        }
191        Ok(result)
192    }
193}
194
195/// Interior of a compiled expression.
196#[derive(Clone, Debug)]
197pub enum C {
198    /// The row number.
199    RowNum,
200    /// The derived row number.
201    SubRowNum,
202    /// An evaluated constant.
203    Constant(Value),
204    /// An unevaluated function.
205    RawFunction {
206        /// The function.
207        function: &'static dyn Function,
208        /// Function arguments.
209        args: Box<[Compiled]>,
210    },
211    /// Obtains a local variable.
212    GetVariable(usize),
213    /// Assigns a value to a local variable.
214    SetVariable(usize, Box<Compiled>),
215    /// The `CASE … WHEN` expression.
216    CaseValueWhen {
217        /// The value to match against.
218        value: Option<Box<Compiled>>,
219        /// The conditions and their corresponding results.
220        conditions: Box<[(Compiled, Compiled)]>,
221        /// The result when all conditions failed.
222        otherwise: Box<Compiled>,
223    },
224
225    /// Regex-based random string.
226    RandRegex(rand_regex::Regex),
227    /// Uniform distribution for `u64`.
228    RandUniformU64(Uniform<u64>),
229    /// Uniform distribution for `i64`.
230    RandUniformI64(Uniform<i64>),
231    /// Uniform distribution for `f64`.
232    RandUniformF64(Uniform<f64>),
233    /// Zipfian distribution.
234    RandZipf(ZipfDistribution),
235    /// Log-normal distribution.
236    RandLogNormal(LogNormal<f64>),
237    /// Bernoulli distribution for `bool` (i.e. a weighted random boolean).
238    RandBool(Bernoulli),
239    /// Random f32 with uniform bit pattern
240    RandFiniteF32(Uniform<u32>),
241    /// Random f64 with uniform bit pattern
242    RandFiniteF64(Uniform<u64>),
243    /// Random u31 timestamp
244    RandU31Timestamp(Uniform<i64>),
245    /// Random shuffled array
246    RandShuffle(Arc<[Value]>),
247    /// Random (version 4) UUID
248    RandUuid,
249}
250
251impl C {
252    fn span(self, span: Span) -> Compiled {
253        Compiled(S { span, inner: self })
254    }
255}
256
257/// A compiled expression
258#[derive(Clone, Debug)]
259pub struct Compiled(pub(crate) S<C>);
260
261impl CompileContext {
262    /// Compiles an expression.
263    pub fn compile(&self, expr: S<Expr>) -> Result<Compiled, S<Error>> {
264        Ok(match expr.inner {
265            Expr::RowNum => C::RowNum,
266            Expr::SubRowNum => C::SubRowNum,
267            Expr::CurrentTimestamp => C::Constant(Value::Timestamp(self.current_timestamp, self.time_zone.clone())),
268            Expr::Value(v) => C::Constant(v),
269            Expr::GetVariable(index) => C::GetVariable(index),
270            Expr::SetVariable(index, e) => C::SetVariable(index, Box::new(self.compile(*e)?)),
271            Expr::Function { function, args } => {
272                let args = args
273                    .into_iter()
274                    .map(|e| self.compile(e))
275                    .collect::<Result<Vec<_>, _>>()?;
276                if args.iter().all(Compiled::is_constant) {
277                    let args = args
278                        .into_iter()
279                        .map(|c| match c.0.inner {
280                            C::Constant(v) => v.span(c.0.span),
281                            _ => unreachable!(),
282                        })
283                        .collect();
284                    function.compile(self, expr.span, args)?
285                } else {
286                    C::RawFunction {
287                        function,
288                        args: args.into_boxed_slice(),
289                    }
290                }
291            }
292            Expr::CaseValueWhen {
293                value,
294                conditions,
295                otherwise,
296            } => {
297                let value = value.map(|v| Ok::<_, _>(Box::new(self.compile(*v)?))).transpose()?;
298                let conditions = conditions
299                    .into_iter()
300                    .map(|(p, r)| Ok((self.compile(p)?, self.compile(r)?)))
301                    .collect::<Result<Vec<_>, _>>()?
302                    .into_boxed_slice();
303                let otherwise = Box::new(if let Some(o) = otherwise {
304                    self.compile(*o)?
305                } else {
306                    C::Constant(Value::Null).span(expr.span)
307                });
308                C::CaseValueWhen {
309                    value,
310                    conditions,
311                    otherwise,
312                }
313            }
314        }
315        .span(expr.span))
316    }
317}
318
319impl Compiled {
320    /// Returns whether this compiled value is a constant.
321    pub fn is_constant(&self) -> bool {
322        matches!(self.0.inner, C::Constant(_))
323    }
324
325    /// Evaluates a compiled expression and updates the state. Returns the evaluated value.
326    pub fn eval(&self, state: &mut State) -> Result<Value, S<Error>> {
327        let span = self.0.span;
328        Ok(match &self.0.inner {
329            C::RowNum => state.row_num.into(),
330            C::SubRowNum => state.sub_row_num.into(),
331            C::Constant(v) => v.clone(),
332            C::RawFunction { function, args } => {
333                let mut eval_args = Arguments::with_capacity(args.len());
334                for c in &**args {
335                    eval_args.push(c.eval(state)?.span(c.0.span));
336                }
337                (*function)
338                    .compile(&state.compile_context, span, eval_args)?
339                    .span(span)
340                    .eval(state)?
341            }
342            C::GetVariable(index) => state.compile_context.variables[*index].clone(),
343            C::SetVariable(index, c) => {
344                let value = c.eval(state)?;
345                state.compile_context.variables[*index] = value.clone();
346                value
347            }
348
349            C::CaseValueWhen {
350                value: Some(value),
351                conditions,
352                otherwise,
353            } => {
354                let value = value.eval(state)?;
355                for (p, r) in &**conditions {
356                    let p_span = p.0.span;
357                    let p = p.eval(state)?;
358                    if value.sql_cmp(&p).span_err(p_span)? == Some(Ordering::Equal) {
359                        return r.eval(state);
360                    }
361                }
362                otherwise.eval(state)?
363            }
364
365            C::CaseValueWhen {
366                value: None,
367                conditions,
368                otherwise,
369            } => {
370                for (p, r) in &**conditions {
371                    if p.eval(state)?.is_sql_true().span_err(p.0.span)? {
372                        return r.eval(state);
373                    }
374                }
375                otherwise.eval(state)?
376            }
377
378            C::RandRegex(generator) => state.rng.sample::<EncodedString, _>(generator).into(),
379            C::RandUniformU64(uniform) => state.rng.sample(uniform).into(),
380            C::RandUniformI64(uniform) => state.rng.sample(uniform).into(),
381            C::RandUniformF64(uniform) => Value::from_finite_f64(state.rng.sample(uniform)),
382            C::RandZipf(zipf) => (state.rng.sample(zipf) as u64).into(),
383            C::RandLogNormal(log_normal) => Value::from_finite_f64(state.rng.sample(log_normal)),
384            C::RandBool(bern) => u64::from(state.rng.sample(bern)).into(),
385            C::RandFiniteF32(uniform) => {
386                Value::from_finite_f64(f32::from_bits(state.rng.sample(uniform).rotate_right(1)).into())
387            }
388            C::RandFiniteF64(uniform) => {
389                Value::from_finite_f64(f64::from_bits(state.rng.sample(uniform).rotate_right(1)))
390            }
391
392            C::RandU31Timestamp(uniform) => {
393                let seconds = state.rng.sample(uniform);
394                let timestamp = NaiveDateTime::from_timestamp(seconds, 0);
395                Value::new_timestamp(timestamp, state.compile_context.time_zone.clone())
396            }
397
398            C::RandShuffle(array) => {
399                let mut shuffled_array = Arc::<[Value]>::from(&**array);
400                Arc::get_mut(&mut shuffled_array).unwrap().shuffle(&mut state.rng);
401                Value::Array(shuffled_array)
402            }
403
404            C::RandUuid => {
405                // we will loss 6 bits but that's still uniform.
406                let g = state.rng.gen::<[u16; 8]>();
407                format!(
408                    "{:04x}{:04x}-{:04x}-4{:03x}-{:04x}-{:04x}{:04x}{:04x}",
409                    g[0],
410                    g[1],
411                    g[2],
412                    g[3] & 0xfff,
413                    (g[4] & 0x3fff) | 0x8000,
414                    g[5],
415                    g[6],
416                    g[7],
417                )
418                .into()
419            }
420        })
421    }
422}