use crate::{
error::Error,
functions::{Arguments, Function},
parser::{Expr, QName},
span::{ResultExt, Span, SpanExt, S},
value::Value,
};
use chrono::{NaiveDateTime, Utc};
use rand::{distributions::Bernoulli, seq::SliceRandom, Rng, RngCore};
use rand_distr::{LogNormal, Uniform};
use rand_regex::EncodedString;
use std::{cmp::Ordering, fmt, fs, ops::Range, path::PathBuf, sync::Arc};
use tzfile::{ArcTz, Tz};
use zipf::ZipfDistribution;
/// Settings and shared storage used while compiling and evaluating expressions.
#[derive(Debug, Clone)]
pub struct CompileContext {
    /// Directory onto which `parse_time_zone` joins a time-zone name before reading it.
    pub zoneinfo: PathBuf,
    /// Time zone attached to `Expr::CurrentTimestamp` and to randomly generated timestamps.
    pub time_zone: ArcTz,
    /// The instant that `Expr::CurrentTimestamp` compiles to.
    pub current_timestamp: NaiveDateTime,
    /// Variable slots addressed by index from `GetVariable` / `SetVariable` nodes.
    pub variables: Box<[Value]>,
}
impl CompileContext {
pub fn new(variables_count: usize) -> Self {
Self {
zoneinfo: PathBuf::from("/usr/share/zoneinfo"),
time_zone: ArcTz::new(Utc.into()),
current_timestamp: NaiveDateTime::from_timestamp(0, 0),
variables: vec![Value::Null; variables_count].into_boxed_slice(),
}
}
pub fn parse_time_zone(&self, tz: &str) -> Result<ArcTz, Error> {
Ok(ArcTz::new(if tz == "UTC" {
Utc.into()
} else {
let path = self.zoneinfo.join(tz);
let content = fs::read(&path).map_err(|source| Error::Io {
action: "read time zone file",
path,
source,
})?;
Tz::parse(tz, &content).map_err(|source| Error::InvalidTimeZone {
time_zone: tz.to_owned(),
source,
})?
}))
}
}
/// Mutable state threaded through `Compiled::eval` while generating rows.
pub struct State {
    /// Current row number. Read by `C::RowNum`; advanced by `increase_row_num`.
    pub(crate) row_num: u64,
    /// Current sub-row number. Read by `C::SubRowNum`; `State::new` starts it at 1.
    pub sub_row_num: u64,
    /// Random source that every `Rand*` node samples from.
    rng: Box<dyn RngCore>,
    /// Context handed to runtime function compilation. Its `variables` are read and
    /// written by `GetVariable` / `SetVariable` during evaluation.
    compile_context: CompileContext,
}
impl fmt::Debug for State {
    /// Shows the row counters and variable slots.
    ///
    /// The RNG is a `dyn RngCore` with no `Debug` impl, so `()` is printed in its place.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("State");
        builder.field("row_num", &self.row_num);
        builder.field("sub_row_num", &self.sub_row_num);
        builder.field("rng", &());
        builder.field("variables", &self.compile_context.variables);
        builder.finish()
    }
}
impl State {
    /// Builds evaluation state that starts at `row_num`, with `sub_row_num` set to 1.
    pub fn new(row_num: u64, rng: Box<dyn RngCore>, compile_context: CompileContext) -> Self {
        let sub_row_num = 1;
        Self {
            row_num,
            sub_row_num,
            rng,
            compile_context,
        }
    }

    /// Consumes the state and hands back its compile context.
    ///
    /// The returned context includes any variable values assigned during evaluation.
    pub fn into_compile_context(self) -> CompileContext {
        let Self {
            compile_context, ..
        } = self;
        compile_context
    }

    /// Moves on to the next row by incrementing `row_num`.
    pub fn increase_row_num(&mut self) {
        let next = &mut self.row_num;
        *next += 1;
    }
}
/// A compiled table: its name, column-name layout, and compiled row expressions.
#[derive(Debug)]
pub struct Table {
    /// The table's (possibly qualified) name.
    pub name: QName,
    /// Text that the entries of `column_name_ranges` index into.
    pub content: String,
    /// Byte ranges within `content`, each spelling one column name.
    pub column_name_ranges: Vec<Range<usize>>,
    /// Compiled expressions of the main row, evaluated in order.
    pub row: Row,
    /// Further compiled expressions, each paired with an index carried over from the parser.
    // NOTE(review): the meaning of the index is defined by `crate::parser::Table::derived` — confirm.
    pub derived: Vec<(usize, Compiled)>,
}
/// A borrowed, copyable view of a table's name and column layout.
#[derive(Debug, Copy, Clone)]
pub struct Schema<'a> {
    pub name: &'a str,
    /// Text that the entries of `column_name_ranges` index into.
    pub content: &'a str,
    /// Byte ranges within `content`, each spelling one column name.
    column_name_ranges: &'a [Range<usize>],
}
impl<'a> Schema<'a> {
    /// Iterates over the column names, in the order of `column_name_ranges`.
    ///
    /// The yielded `&str`s borrow the underlying `'a` data, not this `Schema` value.
    /// They therefore stay usable after a temporary `Schema` has been dropped, e.g.
    /// `let names: Vec<&str> = table.schema(true).column_names().collect();`.
    ///
    /// # Panics
    ///
    /// Panics if a range is out of bounds for `content` or does not fall on UTF-8
    /// character boundaries.
    pub fn column_names(&self) -> impl Iterator<Item = &'a str> + 'a {
        // Copy the `'a` references out of `self` so the iterator does not borrow `self`.
        let content = self.content;
        let ranges = self.column_name_ranges;
        ranges.iter().map(move |r| &content[r.clone()])
    }
}
impl Table {
    /// Borrows this table's name and column layout as a `Schema`.
    ///
    /// `qualified` is forwarded to `QName::table_name` to pick the form of the name.
    pub fn schema(&self, qualified: bool) -> Schema<'_> {
        let name = self.name.table_name(qualified);
        let content = self.content.as_str();
        let column_name_ranges = self.column_name_ranges.as_slice();
        Schema {
            name,
            content,
            column_name_ranges,
        }
    }
}
impl CompileContext {
    /// Compiles a parsed table into a `Table`.
    ///
    /// The name, content and column ranges are moved over unchanged. The row
    /// expressions and the derived expressions are compiled with this context.
    ///
    /// # Errors
    ///
    /// Returns the first compilation error. The row expressions are compiled before
    /// the derived ones.
    pub fn compile_table(&self, table: crate::parser::Table) -> Result<Table, S<Error>> {
        let row = self.compile_row(table.exprs)?;

        let mut derived = Vec::new();
        for (index, expr) in table.derived {
            let compiled = self.compile(expr)?;
            derived.push((index, compiled));
        }

        Ok(Table {
            name: table.name,
            content: table.content,
            column_name_ranges: table.column_name_ranges,
            row,
            derived,
        })
    }
}
/// The compiled expressions of one row; `Row::eval` evaluates them in order.
#[derive(Debug)]
pub struct Row(Vec<Compiled>);
impl CompileContext {
    /// Compiles each expression of a row, in order.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first compilation error.
    pub fn compile_row(&self, exprs: Vec<S<Expr>>) -> Result<Row, S<Error>> {
        let mut compiled = Vec::new();
        for expr in exprs {
            compiled.push(self.compile(expr)?);
        }
        Ok(Row(compiled))
    }
}
impl Row {
    /// Evaluates every compiled column against `state`, in order, producing one row of values.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first evaluation error. Earlier columns may already
    /// have mutated `state`, e.g. by advancing the RNG or setting variables.
    pub fn eval(&self, state: &mut State) -> Result<Vec<Value>, S<Error>> {
        self.0.iter().map(|column| column.eval(state)).collect()
    }
}
/// The kind of a compiled expression node; `Compiled::eval` interprets each variant.
#[derive(Debug, Clone)]
pub enum C {
    /// Yields the state's current `row_num`.
    RowNum,
    /// Yields the state's current `sub_row_num`.
    SubRowNum,
    /// Yields a clone of a fixed value.
    Constant(Value),
    /// A function call kept for evaluation time.
    ///
    /// `CompileContext::compile` produces this when not every argument is constant.
    /// Each evaluation evaluates `args`, compiles `function` with them, and then
    /// evaluates the result.
    RawFunction {
        function: &'static dyn Function,
        args: Box<[Compiled]>,
    },
    /// Yields a clone of the variable in the given slot.
    GetVariable(usize),
    /// Evaluates the inner node, stores the result in the given slot, and yields it.
    SetVariable(usize, Box<Compiled>),
    /// A `CASE` expression.
    ///
    /// With `value`, the first condition that compares SQL-equal to it selects its
    /// result. Without `value`, the first SQL-true condition does. `otherwise` is
    /// evaluated when nothing matches.
    CaseValueWhen {
        value: Option<Box<Compiled>>,
        conditions: Box<[(Compiled, Compiled)]>,
        otherwise: Box<Compiled>,
    },
    /// Samples a string from a regex-driven generator.
    RandRegex(rand_regex::Regex),
    /// Samples a `u64` from a uniform distribution.
    RandUniformU64(Uniform<u64>),
    /// Samples an `i64` from a uniform distribution.
    RandUniformI64(Uniform<i64>),
    /// Samples an `f64` from a uniform distribution.
    RandUniformF64(Uniform<f64>),
    /// Samples from a Zipf distribution; the result is converted to `u64`.
    RandZipf(ZipfDistribution),
    /// Samples an `f64` from a log-normal distribution.
    RandLogNormal(LogNormal<f64>),
    /// Runs a Bernoulli trial, yielding `1` for true and `0` for false.
    RandBool(Bernoulli),
    /// Samples a `u32`, rotates it right by one bit, and reinterprets it as `f32` bits.
    // NOTE(review): the constructor presumably picks the range so the bits are finite — confirm.
    RandFiniteF32(Uniform<u32>),
    /// Samples a `u64`, rotates it right by one bit, and reinterprets it as `f64` bits.
    RandFiniteF64(Uniform<u64>),
    /// Samples seconds since the Unix epoch and yields a timestamp in the context's time zone.
    RandU31Timestamp(Uniform<i64>),
    /// Yields a freshly shuffled copy of the array on each evaluation.
    RandShuffle(Arc<[Value]>),
    /// Yields a random version-4 UUID formatted as a string.
    RandUuid,
}
impl C {
    /// Attaches `span` to this node, producing a `Compiled`.
    fn span(self, span: Span) -> Compiled {
        let located = S { inner: self, span };
        Compiled(located)
    }
}
/// A compiled node paired with the span of the source expression it was compiled from.
#[derive(Debug, Clone)]
pub struct Compiled(pub(crate) S<C>);
impl CompileContext {
pub fn compile(&self, expr: S<Expr>) -> Result<Compiled, S<Error>> {
Ok(match expr.inner {
Expr::RowNum => C::RowNum,
Expr::SubRowNum => C::SubRowNum,
Expr::CurrentTimestamp => C::Constant(Value::Timestamp(self.current_timestamp, self.time_zone.clone())),
Expr::Value(v) => C::Constant(v),
Expr::GetVariable(index) => C::GetVariable(index),
Expr::SetVariable(index, e) => C::SetVariable(index, Box::new(self.compile(*e)?)),
Expr::Function { function, args } => {
let args = args
.into_iter()
.map(|e| self.compile(e))
.collect::<Result<Vec<_>, _>>()?;
if args.iter().all(Compiled::is_constant) {
let args = args
.into_iter()
.map(|c| match c.0.inner {
C::Constant(v) => v.span(c.0.span),
_ => unreachable!(),
})
.collect();
function.compile(self, expr.span, args)?
} else {
C::RawFunction {
function,
args: args.into_boxed_slice(),
}
}
}
Expr::CaseValueWhen {
value,
conditions,
otherwise,
} => {
let value = value.map(|v| Ok::<_, _>(Box::new(self.compile(*v)?))).transpose()?;
let conditions = conditions
.into_iter()
.map(|(p, r)| Ok((self.compile(p)?, self.compile(r)?)))
.collect::<Result<Vec<_>, _>>()?
.into_boxed_slice();
let otherwise = Box::new(if let Some(o) = otherwise {
self.compile(*o)?
} else {
C::Constant(Value::Null).span(expr.span)
});
C::CaseValueWhen {
value,
conditions,
otherwise,
}
}
}
.span(expr.span))
}
}
impl Compiled {
    /// Reports whether this node is a `C::Constant`, whose evaluation is just a clone.
    pub fn is_constant(&self) -> bool {
        matches!(&self.0.inner, C::Constant(..))
    }

    /// Evaluates this node against `state`.
    ///
    /// Random nodes draw from the state's RNG. `C::SetVariable` writes into the
    /// context's variable slots. `C::RawFunction` compiles its function with the
    /// freshly evaluated arguments and then evaluates the result.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a sub-expression or by runtime function
    /// compilation. It also returns errors from a `CASE` comparison or truth test,
    /// spanned at the offending condition.
    ///
    /// # Panics
    ///
    /// Panics if a `GetVariable` / `SetVariable` index is outside the context's
    /// variable slots.
    pub fn eval(&self, state: &mut State) -> Result<Value, S<Error>> {
        let span = self.0.span;
        let value: Value = match &self.0.inner {
            C::RowNum => state.row_num.into(),
            C::SubRowNum => state.sub_row_num.into(),
            C::Constant(constant) => constant.clone(),
            C::RawFunction { function, args } => {
                let mut evaluated = Arguments::with_capacity(args.len());
                for arg in args.iter() {
                    let arg_span = arg.0.span;
                    evaluated.push(arg.eval(state)?.span(arg_span));
                }
                let folded = (*function).compile(&state.compile_context, span, evaluated)?;
                folded.span(span).eval(state)?
            }
            C::GetVariable(index) => state.compile_context.variables[*index].clone(),
            C::SetVariable(index, assigned) => {
                let result = assigned.eval(state)?;
                state.compile_context.variables[*index] = result.clone();
                result
            }
            C::CaseValueWhen {
                value: Some(subject),
                conditions,
                otherwise,
            } => {
                let subject = subject.eval(state)?;
                let mut branch: &Compiled = otherwise;
                for (candidate, result) in conditions.iter() {
                    let candidate_span = candidate.0.span;
                    let candidate = candidate.eval(state)?;
                    let ordering = subject.sql_cmp(&candidate).span_err(candidate_span)?;
                    if ordering == Some(Ordering::Equal) {
                        branch = result;
                        break;
                    }
                }
                branch.eval(state)?
            }
            C::CaseValueWhen {
                value: None,
                conditions,
                otherwise,
            } => {
                let mut branch: &Compiled = otherwise;
                for (predicate, result) in conditions.iter() {
                    let predicate_span = predicate.0.span;
                    if predicate.eval(state)?.is_sql_true().span_err(predicate_span)? {
                        branch = result;
                        break;
                    }
                }
                branch.eval(state)?
            }
            C::RandRegex(generator) => state.rng.sample::<EncodedString, _>(generator).into(),
            C::RandUniformU64(dist) => state.rng.sample(dist).into(),
            C::RandUniformI64(dist) => state.rng.sample(dist).into(),
            C::RandUniformF64(dist) => Value::from_finite_f64(state.rng.sample(dist)),
            C::RandZipf(dist) => (state.rng.sample(dist) as u64).into(),
            C::RandLogNormal(dist) => Value::from_finite_f64(state.rng.sample(dist)),
            C::RandBool(dist) => u64::from(state.rng.sample(dist)).into(),
            C::RandFiniteF32(dist) => {
                let bits = state.rng.sample(dist).rotate_right(1);
                Value::from_finite_f64(f32::from_bits(bits).into())
            }
            C::RandFiniteF64(dist) => {
                let bits = state.rng.sample(dist).rotate_right(1);
                Value::from_finite_f64(f64::from_bits(bits))
            }
            C::RandU31Timestamp(dist) => {
                let seconds = state.rng.sample(dist);
                // `from_timestamp` panics on out-of-range seconds; the sampled range must stay valid.
                let timestamp = NaiveDateTime::from_timestamp(seconds, 0);
                let time_zone = state.compile_context.time_zone.clone();
                Value::new_timestamp(timestamp, time_zone)
            }
            C::RandShuffle(items) => {
                // Copy into a new, uniquely owned `Arc` so the compiled array itself is never
                // mutated; uniqueness makes `get_mut` succeed.
                let mut shuffled = Arc::<[Value]>::from(&**items);
                Arc::get_mut(&mut shuffled).unwrap().shuffle(&mut state.rng);
                Value::Array(shuffled)
            }
            C::RandUuid => {
                let words = state.rng.gen::<[u16; 8]>();
                // Keep 12 random bits after the literal version nibble `4`.
                let time_hi = words[3] & 0xfff;
                // Force the top two bits to `10` (RFC 4122 variant).
                let clock_seq = (words[4] & 0x3fff) | 0x8000;
                format!(
                    "{:04x}{:04x}-{:04x}-4{:03x}-{:04x}-{:04x}{:04x}{:04x}",
                    words[0], words[1], words[2], time_hi, clock_seq, words[5], words[6], words[7],
                )
                .into()
            }
        };
        Ok(value)
    }
}