use rustc_hash::FxHashSet;
use super::ops::Op;
use crate::common::CompactArc;
use crate::core::Value;
/// A constant operand: either a general [`Value`] or a string stored in a
/// [`CompactArc<str>`].
#[derive(Clone, Debug)]
pub enum Constant {
    /// An arbitrary runtime [`Value`].
    Value(Value),
    /// A string payload held in a `CompactArc<str>`.
    String(CompactArc<str>),
}
#[derive(Clone)]
pub struct Program {
ops: Vec<Op>,
max_stack_depth: usize,
needs_outer_context: bool,
needs_second_row: bool,
has_subqueries: bool,
#[cfg(debug_assertions)]
source: Option<String>,
}
impl Program {
pub fn new(ops: Vec<Op>) -> Self {
let ops = Self::peephole_optimize(ops);
let max_stack_depth = Self::compute_stack_depth(&ops);
let needs_outer_context = ops.iter().any(|op| matches!(op, Op::LoadOuterColumn(_)));
let needs_second_row = ops.iter().any(|op| matches!(op, Op::LoadColumn2(_)));
let has_subqueries = ops.iter().any(|op| {
matches!(
op,
Op::ExecScalarSubquery(_)
| Op::ExecExists(_)
| Op::ExecInSubquery(_)
| Op::ExecAll(_, _)
| Op::ExecAny(_, _)
)
});
Self {
ops,
max_stack_depth,
needs_outer_context,
needs_second_row,
has_subqueries,
#[cfg(debug_assertions)]
source: None,
}
}
pub fn new_unoptimized(ops: Vec<Op>) -> Self {
let max_stack_depth = Self::compute_stack_depth(&ops);
let needs_outer_context = ops.iter().any(|op| matches!(op, Op::LoadOuterColumn(_)));
let needs_second_row = ops.iter().any(|op| matches!(op, Op::LoadColumn2(_)));
let has_subqueries = ops.iter().any(|op| {
matches!(
op,
Op::ExecScalarSubquery(_)
| Op::ExecExists(_)
| Op::ExecInSubquery(_)
| Op::ExecAll(_, _)
| Op::ExecAny(_, _)
)
});
Self {
ops,
max_stack_depth,
needs_outer_context,
needs_second_row,
has_subqueries,
#[cfg(debug_assertions)]
source: None,
}
}
pub fn null() -> Self {
Self::new(vec![Op::LoadNull(crate::core::DataType::Null), Op::Return])
}
pub fn constant(value: Value) -> Self {
Self::new(vec![Op::LoadConst(value), Op::Return])
}
pub fn always_true() -> Self {
Self::new(vec![Op::ReturnTrue])
}
pub fn always_false() -> Self {
Self::new(vec![Op::ReturnFalse])
}
#[inline]
pub fn ops(&self) -> &[Op] {
&self.ops
}
#[inline]
pub fn max_stack_depth(&self) -> usize {
self.max_stack_depth
}
#[inline]
pub fn needs_outer_context(&self) -> bool {
self.needs_outer_context
}
#[inline]
pub fn needs_second_row(&self) -> bool {
self.needs_second_row
}
#[inline]
pub fn has_subqueries(&self) -> bool {
self.has_subqueries
}
#[inline]
pub fn len(&self) -> usize {
self.ops.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.ops.is_empty()
}
#[cfg(debug_assertions)]
pub fn with_source(mut self, source: String) -> Self {
self.source = Some(source);
self
}
#[cfg(debug_assertions)]
pub fn source(&self) -> Option<&str> {
self.source.as_deref()
}
fn compute_stack_depth(ops: &[Op]) -> usize {
let mut depth: i32 = 0;
let mut max_depth: i32 = 0;
for op in ops {
let effect = match op {
Op::LoadColumn(_)
| Op::LoadColumn2(_)
| Op::LoadOuterColumn(_)
| Op::LoadConst(_)
| Op::LoadParam(_)
| Op::LoadNamedParam(_)
| Op::LoadNull(_)
| Op::LoadAggregateResult(_)
| Op::LoadTransactionId
| Op::Dup
| Op::EqColumnConst(_, _)
| Op::NeColumnConst(_, _)
| Op::LtColumnConst(_, _)
| Op::LeColumnConst(_, _)
| Op::GtColumnConst(_, _)
| Op::GeColumnConst(_, _)
| Op::IsNullColumn(_)
| Op::IsNotNullColumn(_)
| Op::LikeColumn(_, _, _)
| Op::InSetColumn(_, _, _)
| Op::BetweenColumnConst(_, _, _) => 1,
Op::Eq
| Op::Ne
| Op::Lt
| Op::Le
| Op::Gt
| Op::Ge
| Op::IsDistinctFrom
| Op::IsNotDistinctFrom
| Op::AndFinalize
| Op::OrFinalize
| Op::Add
| Op::Sub
| Op::Mul
| Op::Div
| Op::Mod
| Op::BitAnd
| Op::BitOr
| Op::BitXor
| Op::Shl
| Op::Shr
| Op::Concat
| Op::Xor
| Op::NullIf
| Op::CaseCompare => -1,
Op::Between | Op::NotBetween => -2,
Op::LikeDynamic(_)
| Op::LikeDynamicEscape(_, _)
| Op::GlobDynamic
| Op::RegexpDynamic
| Op::JsonAccess
| Op::JsonAccessText
| Op::TimestampAddInterval
| Op::TimestampSubInterval
| Op::TimestampDiff
| Op::TimestampAddDays
| Op::TimestampSubDays
| Op::VectorDistanceL2
| Op::VectorDistanceCosine
| Op::VectorDistanceIP => -1,
Op::IsNull
| Op::IsNotNull
| Op::IsTrue
| Op::IsNotTrue
| Op::IsFalse
| Op::IsNotFalse
| Op::Not
| Op::Neg
| Op::BitNot
| Op::Like(_, _)
| Op::LikeEscape(_, _, _)
| Op::Glob(_)
| Op::Regexp(_)
| Op::InSet(_, _)
| Op::NotInSet(_, _)
| Op::Cast(_)
| Op::TruncateToDate
| Op::ExecScalarSubquery(_)
| Op::ExecExists(_)
| Op::ExecInSubquery(_)
| Op::ExecAll(_, _)
| Op::ExecAny(_, _)
| Op::NativeFn1(_) => 0,
Op::InTupleSet { tuple_size, .. } => 1 - (*tuple_size as i32),
Op::Pop => -1,
Op::And(_) | Op::Or(_) => 0,
Op::CallScalar { arg_count, .. } => 1 - (*arg_count as i32),
Op::Coalesce(n) | Op::Greatest(n) | Op::Least(n) | Op::ConcatN(n) => {
1 - (*n as i32)
}
Op::Jump(_)
| Op::JumpIfTrue(_)
| Op::JumpIfFalse(_)
| Op::JumpIfNull(_)
| Op::JumpIfNotNull(_)
| Op::PopJumpIfTrue(_)
| Op::PopJumpIfFalse(_)
| Op::Swap
| Op::Nop
| Op::Return
| Op::ReturnTrue
| Op::ReturnFalse
| Op::ReturnNull(_)
| Op::CaseStart
| Op::CaseWhen(_)
| Op::CaseThen(_)
| Op::CaseElse
| Op::CaseEnd => 0,
};
depth += effect;
max_depth = max_depth.max(depth);
}
(max_depth as usize).max(1)
}
pub fn disassemble(&self) -> String {
let mut result = String::new();
for (i, op) in self.ops.iter().enumerate() {
result.push_str(&format!("{:04}: {:?}\n", i, op));
}
result
}
pub fn optimize(mut self) -> Self {
self.ops = Self::peephole_optimize(self.ops);
self.max_stack_depth = Self::compute_stack_depth(&self.ops);
self
}
fn peephole_optimize(mut ops: Vec<Op>) -> Vec<Op> {
if ops.len() < 2 {
return ops;
}
let mut jump_targets = FxHashSet::default();
for op in &ops {
match op {
Op::And(t)
| Op::Or(t)
| Op::Jump(t)
| Op::JumpIfTrue(t)
| Op::JumpIfFalse(t)
| Op::JumpIfNull(t)
| Op::JumpIfNotNull(t)
| Op::PopJumpIfTrue(t)
| Op::PopJumpIfFalse(t)
| Op::CaseWhen(t)
| Op::CaseThen(t) => {
jump_targets.insert(*t as usize);
}
_ => {}
}
}
let mut result = Vec::with_capacity(ops.len());
let mut position_map: Vec<usize> = Vec::with_capacity(ops.len());
let mut i = 0;
while i < ops.len() {
let new_pos = result.len();
if i + 3 < ops.len() {
let is_safe = !jump_targets.contains(&i)
&& !jump_targets.contains(&(i + 1))
&& !jump_targets.contains(&(i + 2))
&& !jump_targets.contains(&(i + 3));
if is_safe {
if let (
Op::LoadColumn(col_idx),
Op::LoadConst(low_val),
Op::LoadConst(high_val),
Op::Between,
) = (&ops[i], &ops[i + 1], &ops[i + 2], &ops[i + 3])
{
result.push(Op::BetweenColumnConst(
*col_idx,
low_val.clone(),
high_val.clone(),
));
position_map.push(new_pos);
position_map.push(new_pos);
position_map.push(new_pos);
position_map.push(new_pos);
i += 4;
continue;
}
}
}
if i + 2 < ops.len() {
let is_safe = !jump_targets.contains(&i)
&& !jump_targets.contains(&(i + 1))
&& !jump_targets.contains(&(i + 2));
if is_safe {
if let (Op::LoadColumn(col_idx), Op::LoadConst(const_val)) =
(&ops[i], &ops[i + 1])
{
let fused = match &ops[i + 2] {
Op::Eq => Some(Op::EqColumnConst(*col_idx, const_val.clone())),
Op::Ne => Some(Op::NeColumnConst(*col_idx, const_val.clone())),
Op::Lt => Some(Op::LtColumnConst(*col_idx, const_val.clone())),
Op::Le => Some(Op::LeColumnConst(*col_idx, const_val.clone())),
Op::Gt => Some(Op::GtColumnConst(*col_idx, const_val.clone())),
Op::Ge => Some(Op::GeColumnConst(*col_idx, const_val.clone())),
_ => None,
};
if let Some(fused_op) = fused {
result.push(fused_op);
position_map.push(new_pos);
position_map.push(new_pos);
position_map.push(new_pos);
i += 3;
continue;
}
}
}
}
if i + 1 < ops.len() {
let is_safe = !jump_targets.contains(&i) && !jump_targets.contains(&(i + 1));
if is_safe {
if let Op::LoadColumn(col_idx) = &ops[i] {
let fused = match &ops[i + 1] {
Op::IsNull => Some(Op::IsNullColumn(*col_idx)),
Op::IsNotNull => Some(Op::IsNotNullColumn(*col_idx)),
Op::Like(pattern, case_insensitive) => {
Some(Op::LikeColumn(*col_idx, pattern.clone(), *case_insensitive))
}
Op::InSet(set, has_null) => {
Some(Op::InSetColumn(*col_idx, set.clone(), *has_null))
}
_ => None,
};
if let Some(fused_op) = fused {
result.push(fused_op);
position_map.push(new_pos);
position_map.push(new_pos);
i += 2;
continue;
}
}
}
}
result.push(std::mem::replace(&mut ops[i], Op::Nop));
position_map.push(new_pos);
i += 1;
}
if result.len() != ops.len() {
Self::adjust_jump_targets(&mut result, &position_map);
}
result
}
fn adjust_jump_targets(ops: &mut [Op], position_map: &[usize]) {
for op in ops.iter_mut() {
match op {
Op::And(t)
| Op::Or(t)
| Op::Jump(t)
| Op::JumpIfTrue(t)
| Op::JumpIfFalse(t)
| Op::JumpIfNull(t)
| Op::JumpIfNotNull(t)
| Op::PopJumpIfTrue(t)
| Op::PopJumpIfFalse(t)
| Op::CaseWhen(t)
| Op::CaseThen(t) => {
let old_target = *t as usize;
if old_target < position_map.len() {
*t = position_map[old_target] as u16;
}
}
_ => {}
}
}
}
}
impl std::fmt::Debug for Program {
    /// Prints a summary of the program (op count and cached metadata) rather
    /// than the ops themselves; use `disassemble` for a full listing.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ops_count = self.ops.len();
        let mut summary = f.debug_struct("Program");
        summary.field("ops_count", &ops_count);
        summary.field("max_stack_depth", &self.max_stack_depth);
        summary.field("needs_outer_context", &self.needs_outer_context);
        summary.field("needs_second_row", &self.needs_second_row);
        summary.field("has_subqueries", &self.has_subqueries);
        summary.finish()
    }
}
/// Incrementally assembles the op list for a [`Program`].
///
/// Jump targets are absolute op indices stored as `u16`. A forward jump can be
/// emitted with a placeholder target and fixed up later with
/// [`ProgramBuilder::patch_jump`] once the destination
/// [`ProgramBuilder::position`] is known.
pub struct ProgramBuilder {
    ops: Vec<Op>,
}

impl ProgramBuilder {
    /// Creates an empty builder with capacity for at least 32 ops.
    pub fn new() -> Self {
        Self {
            ops: Vec::with_capacity(32),
        }
    }

    /// Appends `op` to the program being built.
    #[inline]
    pub fn emit(&mut self, op: Op) {
        self.ops.push(op);
    }

    /// Index the next emitted op will occupy, for use as a jump target.
    ///
    /// Jump targets are `u16`, so an index above `u16::MAX` cannot be
    /// represented. Debug builds assert this; release builds keep the
    /// previous truncating conversion.
    #[inline]
    pub fn position(&self) -> u16 {
        let len = self.ops.len();
        debug_assert!(
            len <= usize::from(u16::MAX),
            "program has {} ops; jump targets are limited to u16::MAX",
            len
        );
        len as u16
    }

    /// Sets the target of the jump-carrying op at `pos` to `target`.
    ///
    /// Does nothing if `pos` is out of range or the op at `pos` carries no
    /// jump target.
    pub fn patch_jump(&mut self, pos: usize, target: u16) {
        if let Some(op) = self.ops.get_mut(pos) {
            match op {
                Op::And(t)
                | Op::Or(t)
                | Op::Jump(t)
                | Op::JumpIfTrue(t)
                | Op::JumpIfFalse(t)
                | Op::JumpIfNull(t)
                | Op::JumpIfNotNull(t)
                | Op::PopJumpIfTrue(t)
                | Op::PopJumpIfFalse(t)
                | Op::CaseWhen(t)
                | Op::CaseThen(t) => *t = target,
                _ => {}
            }
        }
    }

    /// Finishes the program, applying peephole optimization ([`Program::new`]).
    pub fn build(self) -> Program {
        Program::new(self.ops)
    }

    /// Finishes the program without optimization ([`Program::new_unoptimized`]).
    pub fn build_unoptimized(self) -> Program {
        Program::new_unoptimized(self.ops)
    }
}

impl Default for ProgramBuilder {
    /// Same as [`ProgramBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a program exactly as given (no peephole fusion).
    fn raw(ops: Vec<Op>) -> Program {
        Program::new_unoptimized(ops)
    }

    /// Builds a program with peephole fusion applied.
    fn fused(ops: Vec<Op>) -> Program {
        Program::new(ops)
    }

    // ----- canned programs -------------------------------------------------

    #[test]
    fn test_program_null() {
        let program = Program::null();
        assert!(!program.is_empty());
        assert!(!program.needs_outer_context());
        assert!(!program.needs_second_row());
        assert!(!program.has_subqueries());
    }

    #[test]
    fn test_program_constant() {
        let program = Program::constant(Value::Integer(42));
        assert!(!program.is_empty());
        assert!(!program.needs_outer_context());
        assert!(!program.needs_second_row());
    }

    #[test]
    fn test_program_always_true() {
        let program = Program::always_true();
        assert_eq!(program.len(), 1);
        assert!(matches!(program.ops()[0], Op::ReturnTrue));
    }

    #[test]
    fn test_program_always_false() {
        let program = Program::always_false();
        assert_eq!(program.len(), 1);
        assert!(matches!(program.ops()[0], Op::ReturnFalse));
    }

    // ----- stack-depth estimation ----------------------------------------

    #[test]
    fn test_stack_depth_simple() {
        let program = raw(vec![Op::LoadConst(Value::Integer(1)), Op::Return]);
        assert_eq!(program.max_stack_depth(), 1);
    }

    #[test]
    fn test_stack_depth_binary_op() {
        let program = raw(vec![
            Op::LoadConst(Value::Integer(1)),
            Op::LoadConst(Value::Integer(2)),
            Op::Add,
            Op::Return,
        ]);
        assert_eq!(program.max_stack_depth(), 2);
    }

    #[test]
    fn test_stack_depth_nested() {
        // (1 + 2) * (3 + 4)
        let program = raw(vec![
            Op::LoadConst(Value::Integer(1)),
            Op::LoadConst(Value::Integer(2)),
            Op::Add,
            Op::LoadConst(Value::Integer(3)),
            Op::LoadConst(Value::Integer(4)),
            Op::Add,
            Op::Mul,
            Op::Return,
        ]);
        assert!(program.max_stack_depth() >= 2);
    }

    #[test]
    fn test_stack_depth_fused_ops() {
        let program = raw(vec![Op::EqColumnConst(0, Value::Integer(5)), Op::Return]);
        assert_eq!(program.max_stack_depth(), 1);
    }

    #[test]
    fn test_stack_depth_between() {
        let program = raw(vec![
            Op::LoadColumn(0),
            Op::LoadConst(Value::Integer(1)),
            Op::LoadConst(Value::Integer(10)),
            Op::Between,
            Op::Return,
        ]);
        assert_eq!(program.max_stack_depth(), 3);
    }

    // ----- context flags ---------------------------------------------------

    #[test]
    fn test_needs_outer_context() {
        let with_outer = raw(vec![Op::LoadOuterColumn("col".into()), Op::Return]);
        assert!(with_outer.needs_outer_context());
        let without_outer = raw(vec![Op::LoadColumn(0), Op::Return]);
        assert!(!without_outer.needs_outer_context());
    }

    #[test]
    fn test_needs_second_row() {
        let with_second = raw(vec![
            Op::LoadColumn(0),
            Op::LoadColumn2(1),
            Op::Eq,
            Op::Return,
        ]);
        assert!(with_second.needs_second_row());
        let without_second = raw(vec![Op::LoadColumn(0), Op::Return]);
        assert!(!without_second.needs_second_row());
    }

    #[test]
    fn test_has_subqueries() {
        let exists = raw(vec![Op::ExecExists(0), Op::Return]);
        assert!(exists.has_subqueries());
        let scalar = raw(vec![Op::ExecScalarSubquery(0), Op::Return]);
        assert!(scalar.has_subqueries());
        let plain = raw(vec![Op::LoadColumn(0), Op::Return]);
        assert!(!plain.has_subqueries());
    }

    // ----- peephole fusion -------------------------------------------------

    #[test]
    fn test_peephole_eq_column_const() {
        let program = fused(vec![
            Op::LoadColumn(0),
            Op::LoadConst(Value::Integer(5)),
            Op::Eq,
            Op::Return,
        ]);
        assert_eq!(program.len(), 2);
        assert!(matches!(program.ops()[0], Op::EqColumnConst(0, _)));
    }

    #[test]
    fn test_peephole_lt_column_const() {
        let program = fused(vec![
            Op::LoadColumn(1),
            Op::LoadConst(Value::Integer(10)),
            Op::Lt,
            Op::Return,
        ]);
        assert_eq!(program.len(), 2);
        assert!(matches!(program.ops()[0], Op::LtColumnConst(1, _)));
    }

    #[test]
    fn test_peephole_is_null_column() {
        let program = fused(vec![Op::LoadColumn(2), Op::IsNull, Op::Return]);
        assert_eq!(program.len(), 2);
        assert!(matches!(program.ops()[0], Op::IsNullColumn(2)));
    }

    #[test]
    fn test_peephole_between_column_const() {
        let program = fused(vec![
            Op::LoadColumn(0),
            Op::LoadConst(Value::Integer(1)),
            Op::LoadConst(Value::Integer(100)),
            Op::Between,
            Op::Return,
        ]);
        assert_eq!(program.len(), 2);
        assert!(matches!(program.ops()[0], Op::BetweenColumnConst(0, _, _)));
    }

    #[test]
    fn test_peephole_no_fusion_when_not_applicable() {
        // Constant before column: not a fusable pattern.
        let program = fused(vec![
            Op::LoadConst(Value::Integer(5)),
            Op::LoadColumn(0),
            Op::Eq,
            Op::Return,
        ]);
        assert!(program.len() >= 3);
    }

    #[test]
    fn test_peephole_preserves_jumps() {
        let program = fused(vec![
            Op::LoadColumn(0),
            Op::JumpIfFalse(3),
            Op::LoadConst(Value::Integer(5)),
            Op::Eq,
            Op::Return,
        ]);
        assert!(program.len() >= 3);
    }

    // ----- builder ---------------------------------------------------------

    #[test]
    fn test_builder_basic() {
        let mut builder = ProgramBuilder::new();
        builder.emit(Op::LoadConst(Value::Integer(42)));
        builder.emit(Op::Return);
        let program = builder.build();
        assert!(!program.is_empty());
    }

    #[test]
    fn test_builder_position() {
        let mut builder = ProgramBuilder::new();
        assert_eq!(builder.position(), 0);
        builder.emit(Op::LoadColumn(0));
        assert_eq!(builder.position(), 1);
        builder.emit(Op::LoadColumn(1));
        assert_eq!(builder.position(), 2);
    }

    #[test]
    fn test_builder_patch_jump() {
        let mut builder = ProgramBuilder::new();
        builder.emit(Op::LoadColumn(0));
        let jump_pos = usize::from(builder.position());
        // Placeholder target; patched once the end position is known.
        builder.emit(Op::JumpIfFalse(0));
        builder.emit(Op::LoadConst(Value::Integer(1)));
        builder.emit(Op::Return);
        let end_pos = builder.position();
        builder.patch_jump(jump_pos, end_pos);
        let program = builder.build();
        let has_jump = program
            .ops()
            .iter()
            .any(|op| matches!(op, Op::JumpIfFalse(_)));
        assert!(has_jump || program.len() < 4);
    }

    #[test]
    fn test_builder_default() {
        let builder: ProgramBuilder = Default::default();
        let program = builder.build();
        assert_eq!(program.max_stack_depth(), 1);
    }

    // ----- diagnostics -----------------------------------------------------

    #[test]
    fn test_disassemble() {
        let program = raw(vec![
            Op::LoadColumn(0),
            Op::LoadConst(Value::Integer(5)),
            Op::Eq,
            Op::Return,
        ]);
        let listing = program.disassemble();
        for needle in ["LoadColumn", "LoadConst", "Eq", "Return", "0000:", "0001:"] {
            assert!(listing.contains(needle));
        }
    }

    #[test]
    fn test_program_debug() {
        let program = Program::constant(Value::Integer(42));
        let rendered = format!("{:?}", program);
        for needle in ["Program", "ops_count", "max_stack_depth"] {
            assert!(rendered.contains(needle));
        }
    }

    // ----- Constant --------------------------------------------------------

    #[test]
    fn test_constant_value() {
        let constant = Constant::Value(Value::Integer(42));
        assert!(matches!(constant, Constant::Value(Value::Integer(42))));
    }

    #[test]
    fn test_constant_string() {
        let constant = Constant::String("test".into());
        assert!(matches!(constant, Constant::String(_)));
    }

    #[test]
    fn test_constant_clone() {
        let original = Constant::Value(Value::Text("hello".into()));
        let copy = original.clone();
        assert!(matches!(copy, Constant::Value(Value::Text(_))));
    }
}