1use crate::{
4 error::Error,
5 functions::{Arguments, Function},
6 parser::{Expr, QName},
7 span::{ResultExt, Span, SpanExt, S},
8 value::Value,
9};
10use chrono::{NaiveDateTime, Utc};
11use rand::{distributions::Bernoulli, seq::SliceRandom, Rng, RngCore};
12use rand_distr::{LogNormal, Uniform};
13use rand_regex::EncodedString;
14use std::{cmp::Ordering, fmt, fs, ops::Range, path::PathBuf, sync::Arc};
15use tzfile::{ArcTz, Tz};
16use zipf::ZipfDistribution;
17
18#[derive(Clone, Debug)]
20pub struct CompileContext {
21 pub zoneinfo: PathBuf,
23 pub time_zone: ArcTz,
25 pub current_timestamp: NaiveDateTime,
27 pub variables: Box<[Value]>,
29}
30
31impl CompileContext {
32 pub fn new(variables_count: usize) -> Self {
34 Self {
35 zoneinfo: PathBuf::from("/usr/share/zoneinfo"),
36 time_zone: ArcTz::new(Utc.into()),
37 current_timestamp: NaiveDateTime::from_timestamp(0, 0),
38 variables: vec![Value::Null; variables_count].into_boxed_slice(),
39 }
40 }
41
42 pub fn parse_time_zone(&self, tz: &str) -> Result<ArcTz, Error> {
44 Ok(ArcTz::new(if tz == "UTC" {
45 Utc.into()
46 } else {
47 let path = self.zoneinfo.join(tz);
48 let content = fs::read(&path).map_err(|source| Error::Io {
49 action: "read time zone file",
50 path,
51 source,
52 })?;
53 Tz::parse(tz, &content).map_err(|source| Error::InvalidTimeZone {
54 time_zone: tz.to_owned(),
55 source,
56 })?
57 }))
58 }
59}
60
61pub struct State {
63 pub(crate) row_num: u64,
64 pub sub_row_num: u64,
66 rng: Box<dyn RngCore>,
67 compile_context: CompileContext,
68}
69
70impl fmt::Debug for State {
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 f.debug_struct("State")
73 .field("row_num", &self.row_num)
74 .field("sub_row_num", &self.sub_row_num)
75 .field("rng", &())
76 .field("variables", &self.compile_context.variables)
77 .finish()
78 }
79}
80
81impl State {
82 pub fn new(row_num: u64, rng: Box<dyn RngCore>, compile_context: CompileContext) -> Self {
90 Self {
91 row_num,
92 sub_row_num: 1,
93 rng,
94 compile_context,
95 }
96 }
97
98 pub fn into_compile_context(self) -> CompileContext {
100 self.compile_context
101 }
102
103 pub fn increase_row_num(&mut self) {
105 self.row_num += 1;
106 }
107}
108
109#[derive(Debug)]
111pub struct Table {
112 pub name: QName,
114 pub content: String,
116 pub column_name_ranges: Vec<Range<usize>>,
118 pub row: Row,
120 pub derived: Vec<(usize, Compiled)>,
122}
123
/// Borrowed view of a table's name and schema text, returned by
/// [`Table::schema`].
///
/// It only holds references, so it is cheap to copy.
#[derive(Debug, Copy, Clone)]
pub struct Schema<'a> {
    /// Table name, in whichever form (qualified or not) was requested.
    pub name: &'a str,
    /// Schema source text.
    pub content: &'a str,
    /// Byte ranges of the column names within `content`.
    column_name_ranges: &'a [Range<usize>],
}

impl<'a> Schema<'a> {
    /// Iterates over the column names, in order, as slices of `content`.
    ///
    /// The yielded names borrow the underlying table data (`'a`), not this
    /// `Schema` value. The iterator can therefore outlive a temporary
    /// `Schema`, such as the one produced by `table.schema(true)`.
    ///
    /// # Panics
    ///
    /// Panics while iterating if a range is out of bounds of `content` or does
    /// not fall on UTF-8 character boundaries.
    pub fn column_names(&self) -> impl Iterator<Item = &'a str> + 'a {
        // Copy the `'a` references out of `self` so the closure does not
        // capture the short-lived `&self` borrow.
        let content: &'a str = self.content;
        let ranges: &'a [Range<usize>] = self.column_name_ranges;
        ranges.iter().map(move |r| &content[r.clone()])
    }
}
141
142impl Table {
143 pub fn schema(&self, qualified: bool) -> Schema<'_> {
145 Schema {
146 name: self.name.table_name(qualified),
147 content: &self.content,
148 column_name_ranges: &self.column_name_ranges,
149 }
150 }
151}
152
153impl CompileContext {
154 pub fn compile_table(&self, table: crate::parser::Table) -> Result<Table, S<Error>> {
156 Ok(Table {
157 name: table.name,
158 content: table.content,
159 column_name_ranges: table.column_name_ranges,
160 row: self.compile_row(table.exprs)?,
161 derived: table
162 .derived
163 .into_iter()
164 .map(|(i, e)| self.compile(e).map(|c| (i, c)))
165 .collect::<Result<_, _>>()?,
166 })
167 }
168}
169
170#[derive(Debug)]
172pub struct Row(Vec<Compiled>);
173
174impl CompileContext {
175 pub fn compile_row(&self, exprs: Vec<S<Expr>>) -> Result<Row, S<Error>> {
177 Ok(Row(exprs
178 .into_iter()
179 .map(|e| self.compile(e))
180 .collect::<Result<_, _>>()?))
181 }
182}
183
184impl Row {
185 pub fn eval(&self, state: &mut State) -> Result<Vec<Value>, S<Error>> {
187 let mut result = Vec::with_capacity(self.0.len());
188 for compiled in &self.0 {
189 result.push(compiled.eval(state)?);
190 }
191 Ok(result)
192 }
193}
194
195#[derive(Clone, Debug)]
197pub enum C {
198 RowNum,
200 SubRowNum,
202 Constant(Value),
204 RawFunction {
206 function: &'static dyn Function,
208 args: Box<[Compiled]>,
210 },
211 GetVariable(usize),
213 SetVariable(usize, Box<Compiled>),
215 CaseValueWhen {
217 value: Option<Box<Compiled>>,
219 conditions: Box<[(Compiled, Compiled)]>,
221 otherwise: Box<Compiled>,
223 },
224
225 RandRegex(rand_regex::Regex),
227 RandUniformU64(Uniform<u64>),
229 RandUniformI64(Uniform<i64>),
231 RandUniformF64(Uniform<f64>),
233 RandZipf(ZipfDistribution),
235 RandLogNormal(LogNormal<f64>),
237 RandBool(Bernoulli),
239 RandFiniteF32(Uniform<u32>),
241 RandFiniteF64(Uniform<u64>),
243 RandU31Timestamp(Uniform<i64>),
245 RandShuffle(Arc<[Value]>),
247 RandUuid,
249}
250
251impl C {
252 fn span(self, span: Span) -> Compiled {
253 Compiled(S { span, inner: self })
254 }
255}
256
257#[derive(Clone, Debug)]
259pub struct Compiled(pub(crate) S<C>);
260
261impl CompileContext {
262 pub fn compile(&self, expr: S<Expr>) -> Result<Compiled, S<Error>> {
264 Ok(match expr.inner {
265 Expr::RowNum => C::RowNum,
266 Expr::SubRowNum => C::SubRowNum,
267 Expr::CurrentTimestamp => C::Constant(Value::Timestamp(self.current_timestamp, self.time_zone.clone())),
268 Expr::Value(v) => C::Constant(v),
269 Expr::GetVariable(index) => C::GetVariable(index),
270 Expr::SetVariable(index, e) => C::SetVariable(index, Box::new(self.compile(*e)?)),
271 Expr::Function { function, args } => {
272 let args = args
273 .into_iter()
274 .map(|e| self.compile(e))
275 .collect::<Result<Vec<_>, _>>()?;
276 if args.iter().all(Compiled::is_constant) {
277 let args = args
278 .into_iter()
279 .map(|c| match c.0.inner {
280 C::Constant(v) => v.span(c.0.span),
281 _ => unreachable!(),
282 })
283 .collect();
284 function.compile(self, expr.span, args)?
285 } else {
286 C::RawFunction {
287 function,
288 args: args.into_boxed_slice(),
289 }
290 }
291 }
292 Expr::CaseValueWhen {
293 value,
294 conditions,
295 otherwise,
296 } => {
297 let value = value.map(|v| Ok::<_, _>(Box::new(self.compile(*v)?))).transpose()?;
298 let conditions = conditions
299 .into_iter()
300 .map(|(p, r)| Ok((self.compile(p)?, self.compile(r)?)))
301 .collect::<Result<Vec<_>, _>>()?
302 .into_boxed_slice();
303 let otherwise = Box::new(if let Some(o) = otherwise {
304 self.compile(*o)?
305 } else {
306 C::Constant(Value::Null).span(expr.span)
307 });
308 C::CaseValueWhen {
309 value,
310 conditions,
311 otherwise,
312 }
313 }
314 }
315 .span(expr.span))
316 }
317}
318
319impl Compiled {
320 pub fn is_constant(&self) -> bool {
322 matches!(self.0.inner, C::Constant(_))
323 }
324
325 pub fn eval(&self, state: &mut State) -> Result<Value, S<Error>> {
327 let span = self.0.span;
328 Ok(match &self.0.inner {
329 C::RowNum => state.row_num.into(),
330 C::SubRowNum => state.sub_row_num.into(),
331 C::Constant(v) => v.clone(),
332 C::RawFunction { function, args } => {
333 let mut eval_args = Arguments::with_capacity(args.len());
334 for c in &**args {
335 eval_args.push(c.eval(state)?.span(c.0.span));
336 }
337 (*function)
338 .compile(&state.compile_context, span, eval_args)?
339 .span(span)
340 .eval(state)?
341 }
342 C::GetVariable(index) => state.compile_context.variables[*index].clone(),
343 C::SetVariable(index, c) => {
344 let value = c.eval(state)?;
345 state.compile_context.variables[*index] = value.clone();
346 value
347 }
348
349 C::CaseValueWhen {
350 value: Some(value),
351 conditions,
352 otherwise,
353 } => {
354 let value = value.eval(state)?;
355 for (p, r) in &**conditions {
356 let p_span = p.0.span;
357 let p = p.eval(state)?;
358 if value.sql_cmp(&p).span_err(p_span)? == Some(Ordering::Equal) {
359 return r.eval(state);
360 }
361 }
362 otherwise.eval(state)?
363 }
364
365 C::CaseValueWhen {
366 value: None,
367 conditions,
368 otherwise,
369 } => {
370 for (p, r) in &**conditions {
371 if p.eval(state)?.is_sql_true().span_err(p.0.span)? {
372 return r.eval(state);
373 }
374 }
375 otherwise.eval(state)?
376 }
377
378 C::RandRegex(generator) => state.rng.sample::<EncodedString, _>(generator).into(),
379 C::RandUniformU64(uniform) => state.rng.sample(uniform).into(),
380 C::RandUniformI64(uniform) => state.rng.sample(uniform).into(),
381 C::RandUniformF64(uniform) => Value::from_finite_f64(state.rng.sample(uniform)),
382 C::RandZipf(zipf) => (state.rng.sample(zipf) as u64).into(),
383 C::RandLogNormal(log_normal) => Value::from_finite_f64(state.rng.sample(log_normal)),
384 C::RandBool(bern) => u64::from(state.rng.sample(bern)).into(),
385 C::RandFiniteF32(uniform) => {
386 Value::from_finite_f64(f32::from_bits(state.rng.sample(uniform).rotate_right(1)).into())
387 }
388 C::RandFiniteF64(uniform) => {
389 Value::from_finite_f64(f64::from_bits(state.rng.sample(uniform).rotate_right(1)))
390 }
391
392 C::RandU31Timestamp(uniform) => {
393 let seconds = state.rng.sample(uniform);
394 let timestamp = NaiveDateTime::from_timestamp(seconds, 0);
395 Value::new_timestamp(timestamp, state.compile_context.time_zone.clone())
396 }
397
398 C::RandShuffle(array) => {
399 let mut shuffled_array = Arc::<[Value]>::from(&**array);
400 Arc::get_mut(&mut shuffled_array).unwrap().shuffle(&mut state.rng);
401 Value::Array(shuffled_array)
402 }
403
404 C::RandUuid => {
405 let g = state.rng.gen::<[u16; 8]>();
407 format!(
408 "{:04x}{:04x}-{:04x}-4{:03x}-{:04x}-{:04x}{:04x}{:04x}",
409 g[0],
410 g[1],
411 g[2],
412 g[3] & 0xfff,
413 (g[4] & 0x3fff) | 0x8000,
414 g[5],
415 g[6],
416 g[7],
417 )
418 .into()
419 }
420 })
421 }
422}