use std::sync::Arc;
use super::ops::Op;
use crate::core::Value;
/// A constant operand: either a full [`Value`] or a shared immutable string.
///
/// The string form is stored as `Arc<str>`, so cloning a `Constant::String`
/// only bumps a reference count instead of copying the text.
#[derive(Clone, Debug)]
pub enum Constant {
    /// An arbitrary runtime value.
    Value(Value),
    /// A reference-counted, immutable string.
    String(Arc<str>),
}
/// A linear sequence of [`Op`]s plus metadata derived from that sequence
/// when the program is constructed.
#[derive(Clone)]
pub struct Program {
    /// The instruction stream, addressed by index for jumps.
    ops: Vec<Op>,
    /// Stack-depth estimate from a linear scan of `ops` (always at least 1).
    max_stack_depth: usize,
    /// True when some instruction is `Op::LoadOuterColumn`.
    needs_outer_context: bool,
    /// True when some instruction is `Op::LoadColumn2`.
    needs_second_row: bool,
    /// True when some instruction is one of the `Op::Exec*` subquery ops.
    has_subqueries: bool,
    /// Optional source text, retained only in debug builds for diagnostics.
    #[cfg(debug_assertions)]
    source: Option<String>,
}
impl Program {
pub fn new(ops: Vec<Op>) -> Self {
let ops = Self::peephole_optimize(ops);
let max_stack_depth = Self::compute_stack_depth(&ops);
let needs_outer_context = ops.iter().any(|op| matches!(op, Op::LoadOuterColumn(_)));
let needs_second_row = ops.iter().any(|op| matches!(op, Op::LoadColumn2(_)));
let has_subqueries = ops.iter().any(|op| {
matches!(
op,
Op::ExecScalarSubquery(_)
| Op::ExecExists(_)
| Op::ExecInSubquery(_)
| Op::ExecAll(_, _)
| Op::ExecAny(_, _)
)
});
Self {
ops,
max_stack_depth,
needs_outer_context,
needs_second_row,
has_subqueries,
#[cfg(debug_assertions)]
source: None,
}
}
pub fn new_unoptimized(ops: Vec<Op>) -> Self {
let max_stack_depth = Self::compute_stack_depth(&ops);
let needs_outer_context = ops.iter().any(|op| matches!(op, Op::LoadOuterColumn(_)));
let needs_second_row = ops.iter().any(|op| matches!(op, Op::LoadColumn2(_)));
let has_subqueries = ops.iter().any(|op| {
matches!(
op,
Op::ExecScalarSubquery(_)
| Op::ExecExists(_)
| Op::ExecInSubquery(_)
| Op::ExecAll(_, _)
| Op::ExecAny(_, _)
)
});
Self {
ops,
max_stack_depth,
needs_outer_context,
needs_second_row,
has_subqueries,
#[cfg(debug_assertions)]
source: None,
}
}
pub fn null() -> Self {
Self::new(vec![Op::LoadNull(crate::core::DataType::Null), Op::Return])
}
pub fn constant(value: Value) -> Self {
Self::new(vec![Op::LoadConst(value), Op::Return])
}
pub fn always_true() -> Self {
Self::new(vec![Op::ReturnTrue])
}
pub fn always_false() -> Self {
Self::new(vec![Op::ReturnFalse])
}
#[inline]
pub fn ops(&self) -> &[Op] {
&self.ops
}
#[inline]
pub fn max_stack_depth(&self) -> usize {
self.max_stack_depth
}
#[inline]
pub fn needs_outer_context(&self) -> bool {
self.needs_outer_context
}
#[inline]
pub fn needs_second_row(&self) -> bool {
self.needs_second_row
}
#[inline]
pub fn has_subqueries(&self) -> bool {
self.has_subqueries
}
#[inline]
pub fn len(&self) -> usize {
self.ops.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.ops.is_empty()
}
#[cfg(debug_assertions)]
pub fn with_source(mut self, source: String) -> Self {
self.source = Some(source);
self
}
#[cfg(debug_assertions)]
pub fn source(&self) -> Option<&str> {
self.source.as_deref()
}
fn compute_stack_depth(ops: &[Op]) -> usize {
let mut depth: i32 = 0;
let mut max_depth: i32 = 0;
for op in ops {
let effect = match op {
Op::LoadColumn(_)
| Op::LoadColumn2(_)
| Op::LoadOuterColumn(_)
| Op::LoadConst(_)
| Op::LoadParam(_)
| Op::LoadNamedParam(_)
| Op::LoadNull(_)
| Op::LoadAggregateResult(_)
| Op::LoadTransactionId
| Op::Dup
| Op::EqColumnConst(_, _)
| Op::NeColumnConst(_, _)
| Op::LtColumnConst(_, _)
| Op::LeColumnConst(_, _)
| Op::GtColumnConst(_, _)
| Op::GeColumnConst(_, _)
| Op::IsNullColumn(_)
| Op::IsNotNullColumn(_)
| Op::LikeColumn(_, _, _)
| Op::InSetColumn(_, _, _)
| Op::BetweenColumnConst(_, _, _) => 1,
Op::Eq
| Op::Ne
| Op::Lt
| Op::Le
| Op::Gt
| Op::Ge
| Op::IsDistinctFrom
| Op::IsNotDistinctFrom
| Op::AndFinalize
| Op::OrFinalize
| Op::Add
| Op::Sub
| Op::Mul
| Op::Div
| Op::Mod
| Op::BitAnd
| Op::BitOr
| Op::BitXor
| Op::Shl
| Op::Shr
| Op::Concat
| Op::Xor
| Op::NullIf
| Op::CaseCompare => -1,
Op::Between | Op::NotBetween => -2,
Op::JsonAccess
| Op::JsonAccessText
| Op::TimestampAddInterval
| Op::TimestampSubInterval
| Op::TimestampDiff
| Op::TimestampAddDays
| Op::TimestampSubDays => -1,
Op::IsNull
| Op::IsNotNull
| Op::IsTrue
| Op::IsNotTrue
| Op::IsFalse
| Op::IsNotFalse
| Op::Not
| Op::Neg
| Op::BitNot
| Op::Like(_, _)
| Op::LikeEscape(_, _, _)
| Op::Glob(_)
| Op::Regexp(_)
| Op::InSet(_, _)
| Op::NotInSet(_, _)
| Op::Cast(_)
| Op::TruncateToDate
| Op::ExecScalarSubquery(_)
| Op::ExecExists(_)
| Op::ExecInSubquery(_)
| Op::ExecAll(_, _)
| Op::ExecAny(_, _) => 0,
Op::InTupleSet { tuple_size, .. } => 1 - (*tuple_size as i32),
Op::Pop => -1,
Op::And(_) | Op::Or(_) => 0,
Op::CallScalar { arg_count, .. } => 1 - (*arg_count as i32),
Op::Coalesce(n) | Op::Greatest(n) | Op::Least(n) => 1 - (*n as i32),
Op::Jump(_)
| Op::JumpIfTrue(_)
| Op::JumpIfFalse(_)
| Op::JumpIfNull(_)
| Op::PopJumpIfTrue(_)
| Op::PopJumpIfFalse(_)
| Op::Swap
| Op::Nop
| Op::Return
| Op::ReturnTrue
| Op::ReturnFalse
| Op::ReturnNull(_)
| Op::CaseStart
| Op::CaseWhen(_)
| Op::CaseThen(_)
| Op::CaseElse
| Op::CaseEnd => 0,
Op::NextVal | Op::CurrVal => 0,
Op::SetVal => -2,
};
depth += effect;
max_depth = max_depth.max(depth);
}
(max_depth as usize).max(1)
}
pub fn disassemble(&self) -> String {
let mut result = String::new();
for (i, op) in self.ops.iter().enumerate() {
result.push_str(&format!("{:04}: {:?}\n", i, op));
}
result
}
pub fn optimize(mut self) -> Self {
self.ops = Self::peephole_optimize(self.ops);
self.max_stack_depth = Self::compute_stack_depth(&self.ops);
self
}
fn peephole_optimize(mut ops: Vec<Op>) -> Vec<Op> {
if ops.len() < 2 {
return ops;
}
let mut jump_targets = std::collections::HashSet::new();
for op in &ops {
match op {
Op::And(t)
| Op::Or(t)
| Op::Jump(t)
| Op::JumpIfTrue(t)
| Op::JumpIfFalse(t)
| Op::JumpIfNull(t)
| Op::PopJumpIfTrue(t)
| Op::PopJumpIfFalse(t)
| Op::CaseWhen(t)
| Op::CaseThen(t) => {
jump_targets.insert(*t as usize);
}
_ => {}
}
}
let mut result = Vec::with_capacity(ops.len());
let mut position_map: Vec<usize> = Vec::with_capacity(ops.len());
let mut i = 0;
while i < ops.len() {
let new_pos = result.len();
if i + 3 < ops.len() {
let is_safe = !jump_targets.contains(&i)
&& !jump_targets.contains(&(i + 1))
&& !jump_targets.contains(&(i + 2))
&& !jump_targets.contains(&(i + 3));
if is_safe {
if let (
Op::LoadColumn(col_idx),
Op::LoadConst(low_val),
Op::LoadConst(high_val),
Op::Between,
) = (&ops[i], &ops[i + 1], &ops[i + 2], &ops[i + 3])
{
result.push(Op::BetweenColumnConst(
*col_idx,
low_val.clone(),
high_val.clone(),
));
position_map.push(new_pos);
position_map.push(new_pos);
position_map.push(new_pos);
position_map.push(new_pos);
i += 4;
continue;
}
}
}
if i + 2 < ops.len() {
let is_safe = !jump_targets.contains(&i)
&& !jump_targets.contains(&(i + 1))
&& !jump_targets.contains(&(i + 2));
if is_safe {
if let (Op::LoadColumn(col_idx), Op::LoadConst(const_val)) =
(&ops[i], &ops[i + 1])
{
let fused = match &ops[i + 2] {
Op::Eq => Some(Op::EqColumnConst(*col_idx, const_val.clone())),
Op::Ne => Some(Op::NeColumnConst(*col_idx, const_val.clone())),
Op::Lt => Some(Op::LtColumnConst(*col_idx, const_val.clone())),
Op::Le => Some(Op::LeColumnConst(*col_idx, const_val.clone())),
Op::Gt => Some(Op::GtColumnConst(*col_idx, const_val.clone())),
Op::Ge => Some(Op::GeColumnConst(*col_idx, const_val.clone())),
_ => None,
};
if let Some(fused_op) = fused {
result.push(fused_op);
position_map.push(new_pos);
position_map.push(new_pos);
position_map.push(new_pos);
i += 3;
continue;
}
}
}
}
if i + 1 < ops.len() {
let is_safe = !jump_targets.contains(&i) && !jump_targets.contains(&(i + 1));
if is_safe {
if let Op::LoadColumn(col_idx) = &ops[i] {
let fused = match &ops[i + 1] {
Op::IsNull => Some(Op::IsNullColumn(*col_idx)),
Op::IsNotNull => Some(Op::IsNotNullColumn(*col_idx)),
Op::Like(pattern, case_insensitive) => {
Some(Op::LikeColumn(*col_idx, pattern.clone(), *case_insensitive))
}
Op::InSet(set, has_null) => {
Some(Op::InSetColumn(*col_idx, set.clone(), *has_null))
}
_ => None,
};
if let Some(fused_op) = fused {
result.push(fused_op);
position_map.push(new_pos);
position_map.push(new_pos);
i += 2;
continue;
}
}
}
}
result.push(std::mem::replace(&mut ops[i], Op::Nop));
position_map.push(new_pos);
i += 1;
}
if result.len() != ops.len() {
Self::adjust_jump_targets(&mut result, &position_map);
}
result
}
fn adjust_jump_targets(ops: &mut [Op], position_map: &[usize]) {
for op in ops.iter_mut() {
match op {
Op::And(t)
| Op::Or(t)
| Op::Jump(t)
| Op::JumpIfTrue(t)
| Op::JumpIfFalse(t)
| Op::JumpIfNull(t)
| Op::PopJumpIfTrue(t)
| Op::PopJumpIfFalse(t)
| Op::CaseWhen(t)
| Op::CaseThen(t) => {
let old_target = *t as usize;
if old_target < position_map.len() {
*t = position_map[old_target] as u16;
}
}
_ => {}
}
}
}
}
impl std::fmt::Debug for Program {
    /// Prints a compact summary (instruction count and derived metadata)
    /// rather than every instruction; `Program::disassemble` gives the full listing.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut summary = f.debug_struct("Program");
        summary.field("ops_count", &self.ops.len());
        summary.field("max_stack_depth", &self.max_stack_depth);
        summary.field("needs_outer_context", &self.needs_outer_context);
        summary.field("needs_second_row", &self.needs_second_row);
        summary.field("has_subqueries", &self.has_subqueries);
        summary.finish()
    }
}
/// Incrementally assembles an instruction stream and turns it into a [`Program`].
///
/// Jump targets of already-emitted instructions can be fixed up later with
/// [`ProgramBuilder::patch_jump`], using indices obtained from
/// [`ProgramBuilder::position`].
pub struct ProgramBuilder {
    ops: Vec<Op>,
}
impl ProgramBuilder {
    /// Creates an empty builder with room preallocated for 32 instructions.
    pub fn new() -> Self {
        Self {
            ops: Vec::with_capacity(32),
        }
    }

    /// Appends `op` to the instruction stream.
    #[inline]
    pub fn emit(&mut self, op: Op) {
        self.ops.push(op);
    }

    /// Index that the next emitted instruction will occupy, usable as a jump target.
    ///
    /// Jump targets are `u16`. A debug assertion catches programs that grow past
    /// `u16::MAX` instructions; release builds still truncate the value, which
    /// would corrupt jump targets.
    #[inline]
    pub fn position(&self) -> u16 {
        debug_assert!(
            self.ops.len() <= usize::from(u16::MAX),
            "program has {} ops, too many for u16 jump targets",
            self.ops.len()
        );
        self.ops.len() as u16
    }

    /// Sets the jump target of the instruction at `pos` to `target`.
    ///
    /// Does nothing if `pos` is out of range or the instruction at `pos` does
    /// not carry a jump target.
    pub fn patch_jump(&mut self, pos: usize, target: u16) {
        if let Some(op) = self.ops.get_mut(pos) {
            match op {
                Op::And(t)
                | Op::Or(t)
                | Op::Jump(t)
                | Op::JumpIfTrue(t)
                | Op::JumpIfFalse(t)
                | Op::JumpIfNull(t)
                | Op::PopJumpIfTrue(t)
                | Op::PopJumpIfFalse(t)
                | Op::CaseWhen(t)
                | Op::CaseThen(t) => *t = target,
                _ => {}
            }
        }
    }

    /// Consumes the builder and produces a `Program` via `Program::new`,
    /// which runs the peephole optimizer.
    pub fn build(self) -> Program {
        Program::new(self.ops)
    }
}
impl Default for ProgramBuilder {
fn default() -> Self {
Self::new()
}
}