use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::time::{Duration, Instant};
/// An RDF term that can be bound to a variable in a solution mapping.
///
/// `Display` renders it in an N-Triples-like form: `<iri>`, a double-quoted
/// literal (optionally with `^^<datatype>` / `@lang`), or `_:label`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum LateralValue {
    /// An IRI; `Display` wraps it in angle brackets.
    Iri(String),
    /// A literal term.
    Literal {
        /// Lexical form of the literal.
        value: String,
        /// Optional datatype IRI, rendered as `^^<datatype>`.
        datatype: Option<String>,
        /// Optional language tag, rendered as `@lang`.
        lang: Option<String>,
    },
    /// A blank node label; `Display` prefixes it with `_:`.
    BlankNode(String),
}
impl fmt::Display for LateralValue {
    /// Renders the term in N-Triples-like syntax: `<iri>`, `_:label`, or a
    /// double-quoted literal followed by `^^<datatype>` and/or `@lang` when set.
    ///
    /// The literal's lexical form is escaped (`\` as `\\`, `"` as `\"`, line
    /// feed as `\n`, carriage return as `\r`) so the surrounding quotes always
    /// delimit exactly the value. The executor compares this rendering against
    /// pushed filter operands and builds cache keys from it, so raw quotes or
    /// backslashes would yield text that is malformed and cannot match an
    /// escaped operand.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Iri(iri) => write!(f, "<{iri}>"),
            Self::Literal {
                value,
                datatype,
                lang,
            } => {
                f.write_str("\"")?;
                // Copy unescaped runs verbatim and splice escapes in between.
                let mut run_start = 0;
                for (idx, ch) in value.char_indices() {
                    let escaped = match ch {
                        '\\' => "\\\\",
                        '"' => "\\\"",
                        '\n' => "\\n",
                        '\r' => "\\r",
                        _ => continue,
                    };
                    f.write_str(&value[run_start..idx])?;
                    f.write_str(escaped)?;
                    run_start = idx + ch.len_utf8();
                }
                f.write_str(&value[run_start..])?;
                f.write_str("\"")?;
                if let Some(dt) = datatype {
                    write!(f, "^^<{dt}>")?;
                }
                if let Some(l) = lang {
                    write!(f, "@{l}")?;
                }
                Ok(())
            }
            Self::BlankNode(id) => write!(f, "_:{id}"),
        }
    }
}
/// One solution: variable names mapped to the terms they are bound to. A
/// variable the solution leaves unbound is simply absent from the map.
/// NOTE(review): names appear to be stored without the `?` sigil (pushed
/// filters strip it before lookup) — confirm against the producers of rows.
pub type SolutionMapping = HashMap<String, LateralValue>;
/// Planner-level description of the right-hand side of a LATERAL join.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LateralSubquery {
    /// Free-form description of the subquery.
    pub description: String,
    /// Variables the subquery takes from the left operand; their values in
    /// each left row are what the executor hands to the evaluator.
    pub correlated_vars: Vec<String>,
    /// Variables the subquery projects; the validator checks them for clashes
    /// with left-side bindings.
    pub projected_vars: Vec<String>,
    /// Whether the subquery aggregates. Doubles the optimizer's per-evaluation
    /// cost and is required for the `Decorrelate` candidate.
    pub has_aggregates: bool,
    /// Optional LIMIT; the optimizer scales cost by `min(limit, 100) / 100`.
    pub limit: Option<usize>,
    /// ORDER BY keys; a non-empty list multiplies the estimated cost by 1.5.
    pub order_by: Vec<OrderSpec>,
}
/// One ORDER BY key of a subquery. Only its presence is consulted in this
/// file (by the optimizer's cost estimate).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrderSpec {
    /// Variable to sort on.
    pub variable: String,
    /// Sort direction; per the field name, `true` means ascending.
    pub ascending: bool,
}
/// A LATERAL join ready for execution by `LateralJoinExecutor`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LateralJoin {
    /// Free-form description of the left operand (not read by the executor).
    pub left_description: String,
    /// The correlated right-hand subquery.
    pub subquery: LateralSubquery,
    /// Execution strategy the executor dispatches on.
    pub strategy: LateralStrategy,
    /// Filters checked against each left row before the subquery runs. Only
    /// simple `?var = value` equalities are understood (compared against the
    /// `Display` rendering of the bound term); other shapes are ignored, and a
    /// filter on a variable the row leaves unbound lets the row through.
    pub pushed_filters: Vec<String>,
}
/// How a LATERAL join is evaluated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum LateralStrategy {
    /// Evaluate the subquery once per left row that passes the pushed filters.
    NestedLoop,
    /// Process left rows in chunks of `LateralJoinConfig::batch_size`.
    BatchedValues,
    /// Intended for decorrelated execution. NOTE(review): the executor in this
    /// file routes it to the same path as `CachedCorrelation`.
    Decorrelate,
    /// Memoize subquery results per distinct set of correlated bindings.
    CachedCorrelation,
}
impl fmt::Display for LateralStrategy {
    /// Writes the bare variant name, e.g. `BatchedValues`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::NestedLoop => "NestedLoop",
            Self::BatchedValues => "BatchedValues",
            Self::Decorrelate => "Decorrelate",
            Self::CachedCorrelation => "CachedCorrelation",
        };
        f.write_str(label)
    }
}
/// Tuning knobs for `LateralJoinExecutor`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LateralJoinConfig {
    /// Left rows per chunk for the `BatchedValues` strategy (0 is treated as 1).
    pub batch_size: usize,
    /// Entry limit for the executor's correlation cache.
    pub cache_capacity: usize,
    /// Intended per-subquery time budget. NOTE(review): nothing in this file
    /// reads it, so the executor shown here enforces no timeout.
    pub subquery_timeout: Duration,
    /// NOTE(review): not read anywhere in this file.
    pub auto_decorrelate: bool,
    /// NOTE(review): not read anywhere in this file; `LateralValidator::validate`
    /// takes its depth limit as a parameter instead.
    pub max_nesting_depth: usize,
}
impl Default for LateralJoinConfig {
fn default() -> Self {
Self {
batch_size: 128,
cache_capacity: 4096,
subquery_timeout: Duration::from_secs(30),
auto_decorrelate: true,
max_nesting_depth: 4,
}
}
}
/// Counters collected by `LateralJoinExecutor`.
///
/// `left_rows` and `result_rows` are overwritten by each execution; the other
/// counters accumulate until the executor is reset.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LateralJoinStats {
    /// Number of left rows passed to the most recent execution.
    pub left_rows: u64,
    /// Number of rows produced by the most recent execution.
    pub result_rows: u64,
    /// Number of times the subquery evaluator was invoked.
    pub subquery_evaluations: u64,
    /// Correlation-cache lookups answered from the cache.
    pub cache_hits: u64,
    /// Correlation-cache lookups that required an evaluation.
    pub cache_misses: u64,
    /// Number of left-row chunks processed by the batched strategy.
    pub batches_submitted: u64,
    /// Summed evaluator wall-clock time. Each call is truncated to whole
    /// milliseconds before being added, so sub-millisecond calls add 0.
    pub subquery_time_ms: u64,
    /// NOTE(review): never set by the executor in this file.
    pub decorrelated: bool,
    /// Left rows skipped because a pushed filter rejected them.
    pub rows_filtered: u64,
}
impl LateralJoinStats {
    /// Share of correlation-cache lookups that were hits, as a percentage.
    ///
    /// Despite the name, the value is scaled to 0.0–100.0 rather than 0–1.
    /// Returns 0.0 when no lookups have been recorded.
    pub fn cache_hit_ratio(&self) -> f64 {
        let lookups = self.cache_hits + self.cache_misses;
        match lookups {
            0 => 0.0,
            n => 100.0 * (self.cache_hits as f64 / n as f64),
        }
    }

    /// Mean recorded evaluator time per evaluation in milliseconds, or 0.0
    /// when nothing has been evaluated yet.
    pub fn avg_subquery_time_ms(&self) -> f64 {
        if self.subquery_evaluations > 0 {
            self.subquery_time_ms as f64 / self.subquery_evaluations as f64
        } else {
            0.0
        }
    }
}
/// Runs LATERAL joins using the strategy requested on each `LateralJoin`.
pub struct LateralJoinExecutor {
    /// Execution settings.
    config: LateralJoinConfig,
    /// Counters gathered across executions; cleared by `reset`.
    stats: LateralJoinStats,
    /// Memoized subquery solutions keyed by a string derived from correlated
    /// bindings. It survives across `execute` calls and is emptied by `reset`.
    cache: HashMap<String, Vec<SolutionMapping>>,
}
impl LateralJoinExecutor {
    /// Creates an executor with `config`, zeroed statistics and an empty
    /// correlation cache.
    pub fn new(config: LateralJoinConfig) -> Self {
        Self {
            config,
            stats: LateralJoinStats::default(),
            cache: HashMap::new(),
        }
    }

    /// Creates an executor using `LateralJoinConfig::default()`.
    pub fn with_defaults() -> Self {
        Self::new(LateralJoinConfig::default())
    }

    /// Statistics gathered since construction or the last [`reset`](Self::reset).
    pub fn stats(&self) -> &LateralJoinStats {
        &self.stats
    }

    /// Clears all statistics and empties the correlation cache.
    pub fn reset(&mut self) {
        self.stats = LateralJoinStats::default();
        self.cache.clear();
    }

    /// Joins every row of `left_rows` with the solutions of `lateral.subquery`
    /// evaluated under that row's correlated bindings.
    ///
    /// `subquery_evaluator` receives the left row's bindings restricted to
    /// `lateral.subquery.correlated_vars` (variables the row leaves unbound are
    /// omitted) and returns the subquery's solutions. Each solution is merged
    /// into the left row; on a shared variable the solution's value wins. Rows
    /// rejected by `lateral.pushed_filters` are skipped before evaluation.
    ///
    /// The strategy only changes how often the evaluator runs. If an evaluator's
    /// solutions agree with the correlated bindings it was given, every
    /// strategy returns the same rows in the same order. `Decorrelate` has no
    /// dedicated implementation and runs on the cached path.
    ///
    /// `stats.left_rows` and `stats.result_rows` describe this call; the other
    /// counters accumulate until [`reset`](Self::reset).
    ///
    /// # Errors
    /// Returns the first error produced by `subquery_evaluator`.
    pub fn execute<F>(
        &mut self,
        lateral: &LateralJoin,
        left_rows: &[SolutionMapping],
        subquery_evaluator: F,
    ) -> Result<Vec<SolutionMapping>, LateralJoinError>
    where
        F: Fn(&SolutionMapping) -> Result<Vec<SolutionMapping>, LateralJoinError>,
    {
        self.stats.left_rows = left_rows.len() as u64;
        match lateral.strategy {
            LateralStrategy::NestedLoop => {
                self.execute_nested_loop(lateral, left_rows, subquery_evaluator)
            }
            LateralStrategy::BatchedValues => {
                self.execute_batched(lateral, left_rows, subquery_evaluator)
            }
            // There is no dedicated decorrelation rewrite here; memoizing per
            // correlated key is the closest available plan.
            LateralStrategy::CachedCorrelation | LateralStrategy::Decorrelate => {
                self.execute_cached(lateral, left_rows, subquery_evaluator)
            }
        }
    }

    /// Evaluates the subquery once for every left row that passes the filters.
    fn execute_nested_loop<F>(
        &mut self,
        lateral: &LateralJoin,
        left_rows: &[SolutionMapping],
        evaluator: F,
    ) -> Result<Vec<SolutionMapping>, LateralJoinError>
    where
        F: Fn(&SolutionMapping) -> Result<Vec<SolutionMapping>, LateralJoinError>,
    {
        let vars = &lateral.subquery.correlated_vars;
        let mut results = Vec::new();
        for left_row in left_rows {
            if !self.passes_pushed_filters(left_row, &lateral.pushed_filters) {
                self.stats.rows_filtered += 1;
                continue;
            }
            let correlated = Self::extract_correlated_bindings(left_row, vars);
            let sub_results = self.evaluate_timed(&evaluator, &correlated)?;
            Self::append_merged(&mut results, left_row, &sub_results)?;
        }
        self.stats.result_rows = results.len() as u64;
        Ok(results)
    }

    /// Processes left rows in chunks of `config.batch_size` (minimum 1).
    ///
    /// A `SolutionMapping` holds one value per variable, so a chunk cannot be
    /// handed to the evaluator as a single multi-valued binding set. Keeping
    /// one representative value would join that value's results onto every row
    /// of the chunk. Instead, the evaluator runs once per distinct correlated
    /// key within the chunk, and every row with that key shares the result.
    /// Solutions that bind a correlated variable to a different value than the
    /// left row are dropped.
    fn execute_batched<F>(
        &mut self,
        lateral: &LateralJoin,
        left_rows: &[SolutionMapping],
        evaluator: F,
    ) -> Result<Vec<SolutionMapping>, LateralJoinError>
    where
        F: Fn(&SolutionMapping) -> Result<Vec<SolutionMapping>, LateralJoinError>,
    {
        let vars = &lateral.subquery.correlated_vars;
        let batch_size = self.config.batch_size.max(1);
        let mut results = Vec::new();
        for chunk in left_rows.chunks(batch_size) {
            self.stats.batches_submitted += 1;
            let mut per_key: HashMap<String, Vec<SolutionMapping>> = HashMap::new();
            for left_row in chunk {
                if !self.passes_pushed_filters(left_row, &lateral.pushed_filters) {
                    self.stats.rows_filtered += 1;
                    continue;
                }
                let correlated = Self::extract_correlated_bindings(left_row, vars);
                let key = Self::cache_key(&correlated, vars);
                let sub_results = match per_key.entry(key) {
                    std::collections::hash_map::Entry::Occupied(slot) => slot.into_mut(),
                    std::collections::hash_map::Entry::Vacant(slot) => {
                        slot.insert(self.evaluate_timed(&evaluator, &correlated)?)
                    }
                };
                for sub_row in sub_results.iter() {
                    if Self::is_compatible(left_row, sub_row, vars) {
                        results.push(Self::merge_mappings(left_row, sub_row)?);
                    }
                }
            }
        }
        self.stats.result_rows = results.len() as u64;
        Ok(results)
    }

    /// Memoizes subquery solutions per correlated key in the executor's cache.
    fn execute_cached<F>(
        &mut self,
        lateral: &LateralJoin,
        left_rows: &[SolutionMapping],
        evaluator: F,
    ) -> Result<Vec<SolutionMapping>, LateralJoinError>
    where
        F: Fn(&SolutionMapping) -> Result<Vec<SolutionMapping>, LateralJoinError>,
    {
        let vars = &lateral.subquery.correlated_vars;
        // The cache outlives a single `execute` call, so keys are namespaced by
        // the subquery. Otherwise two different subqueries with equal
        // correlated bindings would share results. The length prefix keeps the
        // namespace unambiguous whatever characters the description contains.
        let description = &lateral.subquery.description;
        let namespace = format!("{}:{}#", description.len(), description);
        let mut results = Vec::new();
        for left_row in left_rows {
            if !self.passes_pushed_filters(left_row, &lateral.pushed_filters) {
                self.stats.rows_filtered += 1;
                continue;
            }
            let correlated = Self::extract_correlated_bindings(left_row, vars);
            let key = format!("{namespace}{}", Self::cache_key(&correlated, vars));
            if let Some(cached) = self.cache.get(&key) {
                self.stats.cache_hits += 1;
                Self::append_merged(&mut results, left_row, cached)?;
                continue;
            }
            self.stats.cache_misses += 1;
            let fresh = self.evaluate_timed(&evaluator, &correlated)?;
            Self::append_merged(&mut results, left_row, &fresh)?;
            self.store_in_cache(key, fresh);
        }
        self.stats.result_rows = results.len() as u64;
        Ok(results)
    }

    /// Runs `evaluator` once and records its wall-clock time and the
    /// evaluation count. An evaluator error is returned without recording.
    fn evaluate_timed<F>(
        &mut self,
        evaluator: &F,
        bindings: &SolutionMapping,
    ) -> Result<Vec<SolutionMapping>, LateralJoinError>
    where
        F: Fn(&SolutionMapping) -> Result<Vec<SolutionMapping>, LateralJoinError>,
    {
        let start = Instant::now();
        let solutions = evaluator(bindings)?;
        let elapsed_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
        self.stats.subquery_time_ms = self.stats.subquery_time_ms.saturating_add(elapsed_ms);
        self.stats.subquery_evaluations += 1;
        Ok(solutions)
    }

    /// Stores `solutions` under `key`, evicting an entry first when the cache
    /// is full. With `cache_capacity == 0` nothing is cached.
    fn store_in_cache(&mut self, key: String, solutions: Vec<SolutionMapping>) {
        let capacity = self.config.cache_capacity;
        if capacity == 0 {
            return;
        }
        if self.cache.len() >= capacity {
            // HashMap iteration order is unspecified, so this evicts an
            // arbitrary entry rather than the oldest one.
            if let Some(victim) = self.cache.keys().next().cloned() {
                self.cache.remove(&victim);
            }
        }
        self.cache.insert(key, solutions);
    }

    /// Appends `left_row` merged with each of `sub_results` to `out`.
    fn append_merged(
        out: &mut Vec<SolutionMapping>,
        left_row: &SolutionMapping,
        sub_results: &[SolutionMapping],
    ) -> Result<(), LateralJoinError> {
        for sub_row in sub_results {
            out.push(Self::merge_mappings(left_row, sub_row)?);
        }
        Ok(())
    }

    /// Copies the bindings of `correlated_vars` that `row` binds; unbound
    /// variables are omitted.
    fn extract_correlated_bindings(
        row: &SolutionMapping,
        correlated_vars: &[String],
    ) -> SolutionMapping {
        correlated_vars
            .iter()
            .filter_map(|var| row.get(var).map(|val| (var.clone(), val.clone())))
            .collect()
    }

    /// Collapses `rows` into one mapping that holds, for each correlated
    /// variable, the first value any row binds it to.
    ///
    /// `execute_batched` no longer uses this, because one value per variable
    /// cannot represent a whole batch. It is kept (with `dead_code` allowed)
    /// so that any existing caller elsewhere in this module keeps compiling.
    #[allow(dead_code)]
    fn build_batch_bindings(
        rows: &[SolutionMapping],
        correlated_vars: &[String],
    ) -> SolutionMapping {
        correlated_vars
            .iter()
            .filter_map(|var| {
                rows.iter()
                    .find_map(|row| row.get(var))
                    .map(|val| (var.clone(), val.clone()))
            })
            .collect()
    }

    /// True unless some correlated variable is bound on both sides to
    /// different values; a variable bound on only one side is compatible.
    fn is_compatible(
        left: &SolutionMapping,
        right: &SolutionMapping,
        correlated_vars: &[String],
    ) -> bool {
        correlated_vars
            .iter()
            .all(|var| match (left.get(var), right.get(var)) {
                (Some(l), Some(r)) => l == r,
                _ => true,
            })
    }

    /// Returns `left` extended with every binding of `right`; on shared
    /// variables `right` wins. Never returns `Err` as written.
    fn merge_mappings(
        left: &SolutionMapping,
        right: &SolutionMapping,
    ) -> Result<SolutionMapping, LateralJoinError> {
        let mut merged = left.clone();
        merged.extend(right.iter().map(|(var, val)| (var.clone(), val.clone())));
        Ok(merged)
    }

    /// Builds `var=<rendered value>` (or `var=UNDEF`) for each of `vars`, in
    /// order, joined with `|`.
    fn cache_key(correlated: &SolutionMapping, vars: &[String]) -> String {
        vars.iter()
            .map(|var| match correlated.get(var) {
                Some(val) => format!("{var}={val}"),
                None => format!("{var}=UNDEF"),
            })
            .collect::<Vec<_>>()
            .join("|")
    }

    /// True unless an equality filter names a variable `row` binds to a term
    /// whose `Display` rendering differs from the filter's value. Filters that
    /// are not simple equalities, or that name unbound variables, pass.
    fn passes_pushed_filters(&self, row: &SolutionMapping, filters: &[String]) -> bool {
        filters
            .iter()
            .all(|filter| match Self::parse_equality_filter(filter) {
                Some((var, expected)) => match row.get(&var) {
                    Some(actual) => actual.to_string() == expected,
                    None => true,
                },
                None => true,
            })
    }

    /// Parses a filter of the form `?var = value` into `(var, value)`.
    ///
    /// Whitespace around `=` and around the whole filter is optional. The
    /// variable side must be a single token and has leading `?` characters
    /// stripped. Returns `None` when there is no `=` or the operator is `!=`,
    /// `<=`, `>=` or `==`.
    fn parse_equality_filter(filter: &str) -> Option<(String, String)> {
        let (lhs, rhs) = filter.split_once('=')?;
        let lhs = lhs.trim();
        let rhs = rhs.trim();
        let other_operator =
            lhs.ends_with(|c: char| matches!(c, '!' | '<' | '>')) || rhs.starts_with('=');
        if lhs.is_empty() || lhs.contains(char::is_whitespace) || other_operator {
            return None;
        }
        Some((lhs.trim_start_matches('?').to_string(), rhs.to_string()))
    }
}
/// Cost-based chooser of a `LateralStrategy` for a LATERAL join.
#[derive(Default)]
pub struct LateralOptimizer {
    /// Thresholds used by the cost model.
    config: LateralOptimizerConfig,
}
/// Tuning knobs for `LateralOptimizer`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LateralOptimizerConfig {
    /// The `CachedCorrelation` estimate is flagged `cacheable` when the
    /// distinct-key count is below this value.
    pub cache_threshold: usize,
    /// Assumed batch size when costing `BatchedValues` (0 is treated as 1).
    pub batch_threshold: usize,
    /// NOTE(review): not read by `LateralOptimizer` in this file.
    pub decorrelate_min_improvement: f64,
}
impl Default for LateralOptimizerConfig {
fn default() -> Self {
Self {
cache_threshold: 1000,
batch_threshold: 500,
decorrelate_min_improvement: 0.3,
}
}
}
/// The optimizer's estimate for one candidate strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LateralCostEstimate {
    /// Strategy being estimated.
    pub strategy: LateralStrategy,
    /// Unitless relative cost; lower is better.
    pub estimated_cost: f64,
    /// Expected number of subquery evaluations.
    pub estimated_evaluations: u64,
    /// Set only on the `CachedCorrelation` estimate, when the distinct-key
    /// count is below `cache_threshold`.
    pub cacheable: bool,
    /// `true` only on the `Decorrelate` estimate.
    pub decorrelatable: bool,
}
impl LateralOptimizer {
    /// Creates an optimizer with `LateralOptimizerConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an optimizer with an explicit configuration.
    pub fn with_config(config: LateralOptimizerConfig) -> Self {
        Self { config }
    }

    /// Returns the cheapest strategy estimate for the described workload.
    ///
    /// This is exactly the first entry of [`analyze`](Self::analyze), so the
    /// two entry points share a single cost model and cannot disagree. On
    /// equal costs, the earlier candidate in this order wins: NestedLoop,
    /// CachedCorrelation, BatchedValues, Decorrelate.
    pub fn choose_strategy(
        &self,
        left_cardinality: u64,
        distinct_keys: u64,
        subquery: &LateralSubquery,
    ) -> LateralCostEstimate {
        self.analyze(left_cardinality, distinct_keys, subquery)
            .into_iter()
            .next()
            .expect("analyze always yields the three baseline strategies")
    }

    /// Relative cost of one subquery evaluation. It starts at 1.0, is doubled
    /// for aggregates, is scaled by `min(limit, 100) / 100` when a LIMIT is
    /// present (so `LIMIT 0` costs nothing), and is multiplied by 1.5 when
    /// ORDER BY is present.
    fn estimate_subquery_cost(&self, subquery: &LateralSubquery) -> f64 {
        let mut cost = 1.0;
        if subquery.has_aggregates {
            cost *= 2.0;
        }
        if let Some(limit) = subquery.limit {
            cost *= (limit as f64).min(100.0) / 100.0;
        }
        if !subquery.order_by.is_empty() {
            cost *= 1.5;
        }
        cost
    }

    /// Decorrelation is only considered for aggregating subqueries with
    /// exactly one correlated variable.
    fn can_decorrelate(&self, subquery: &LateralSubquery) -> bool {
        subquery.correlated_vars.len() == 1 && subquery.has_aggregates
    }

    /// Estimates every applicable strategy, cheapest first.
    ///
    /// With `c` the per-evaluation cost, the unitless cost model is:
    /// - NestedLoop: `left * c`
    /// - CachedCorrelation: `distinct * c + (left - distinct) * 0.01`, where
    ///   the subtraction saturates at 0
    /// - BatchedValues: `ceil(left / batch_threshold) * c + left * 0.5`
    /// - Decorrelate (only when eligible): `left * 0.5`
    ///
    /// `decorrelate_min_improvement` is not consulted. The sort is stable, so
    /// ties keep the order listed above.
    pub fn analyze(
        &self,
        left_cardinality: u64,
        distinct_keys: u64,
        subquery: &LateralSubquery,
    ) -> Vec<LateralCostEstimate> {
        let per_evaluation = self.estimate_subquery_cost(subquery);
        let left = left_cardinality as f64;

        // Rows whose key repeats are served from the cache at a nominal 0.01.
        let cache_cost = distinct_keys as f64 * per_evaluation
            + (left_cardinality.saturating_sub(distinct_keys)) as f64 * 0.01;

        // `batch_threshold` doubles as the assumed batch size; every left row
        // also pays a flat 0.5 to be correlated with the batch output.
        let batch_size = self.config.batch_threshold.max(1) as f64;
        let batch_evals = (left / batch_size).ceil();
        let batch_cost = batch_evals * per_evaluation + left * 0.5;

        let mut estimates = vec![
            LateralCostEstimate {
                strategy: LateralStrategy::NestedLoop,
                estimated_cost: left * per_evaluation,
                estimated_evaluations: left_cardinality,
                cacheable: false,
                decorrelatable: false,
            },
            LateralCostEstimate {
                strategy: LateralStrategy::CachedCorrelation,
                estimated_cost: cache_cost,
                estimated_evaluations: distinct_keys,
                cacheable: distinct_keys < self.config.cache_threshold as u64,
                decorrelatable: false,
            },
            LateralCostEstimate {
                strategy: LateralStrategy::BatchedValues,
                estimated_cost: batch_cost,
                estimated_evaluations: batch_evals as u64,
                cacheable: false,
                decorrelatable: false,
            },
        ];
        if self.can_decorrelate(subquery) {
            estimates.push(LateralCostEstimate {
                strategy: LateralStrategy::Decorrelate,
                estimated_cost: left * 0.5,
                estimated_evaluations: 1,
                cacheable: false,
                decorrelatable: true,
            });
        }
        estimates.sort_by(|a, b| {
            a.estimated_cost
                .partial_cmp(&b.estimated_cost)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        estimates
    }
}
/// Static checks for a LATERAL subquery against its left operand.
pub struct LateralValidator;
/// Outcome of `LateralValidator::validate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LateralValidationResult {
    /// Cleared by excessive nesting or an unbound correlated variable.
    /// `VariableConflict` errors do not clear it.
    pub is_valid: bool,
    /// Problems found, in the order they were detected.
    pub errors: Vec<LateralValidationError>,
    /// Non-fatal remarks.
    pub warnings: Vec<String>,
    /// Correlated variables confirmed to be bound by the left operand.
    pub detected_correlated_vars: Vec<String>,
    /// Sorted, de-duplicated union of left and projected variables.
    pub output_vars: Vec<String>,
}
/// A single validation problem.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LateralValidationError {
    /// Human-readable description.
    pub message: String,
    /// Machine-readable category.
    pub code: LateralErrorCode,
}
/// Category of a `LateralValidationError`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum LateralErrorCode {
    /// NOTE(review): not produced by `LateralValidator::validate`, which
    /// reports a missing correlation as a warning instead.
    NoCorrelation,
    /// A correlated variable is not bound by the left operand.
    UnboundCorrelatedVar,
    /// Nesting depth exceeds the allowed maximum.
    ExcessiveNesting,
    /// A projected variable shadows a non-correlated left binding.
    VariableConflict,
    /// NOTE(review): not produced anywhere in this file.
    DisallowedConstruct,
}
impl LateralValidator {
    /// Checks `subquery` against the variables bound by its left operand.
    ///
    /// Findings are recorded in this order:
    /// 1. nesting deeper than `max_depth`;
    /// 2. correlated variables the left side does not bind;
    /// 3. a warning when there are no correlated variables at all;
    /// 4. projected variables that shadow a non-correlated left binding.
    ///
    /// Only the first two kinds clear `is_valid`. Each shadowing variable
    /// yields both a `VariableConflict` error and an "overridden" warning,
    /// but leaves `is_valid` untouched. `output_vars` is the sorted,
    /// de-duplicated union of `left_vars` and the projected variables.
    pub fn validate(
        subquery: &LateralSubquery,
        left_vars: &[String],
        nesting_depth: usize,
        max_depth: usize,
    ) -> LateralValidationResult {
        let bound_on_left: HashSet<&str> = left_vars.iter().map(String::as_str).collect();
        let mut is_valid = true;
        let mut errors = Vec::new();
        let mut warnings = Vec::new();
        let mut detected_correlated_vars = Vec::new();

        if nesting_depth > max_depth {
            is_valid = false;
            errors.push(LateralValidationError {
                message: format!(
                    "LATERAL nesting depth {nesting_depth} exceeds maximum {max_depth}"
                ),
                code: LateralErrorCode::ExcessiveNesting,
            });
        }

        for var in &subquery.correlated_vars {
            if !bound_on_left.contains(var.as_str()) {
                is_valid = false;
                errors.push(LateralValidationError {
                    message: format!("Correlated variable ?{var} is not bound by the left operand"),
                    code: LateralErrorCode::UnboundCorrelatedVar,
                });
                continue;
            }
            detected_correlated_vars.push(var.clone());
        }

        if subquery.correlated_vars.is_empty() {
            warnings.push(
                "LATERAL subquery has no correlated variables; consider using a regular join"
                    .to_string(),
            );
        }

        let shadowing = subquery.projected_vars.iter().filter(|proj_var| {
            bound_on_left.contains(proj_var.as_str())
                && !subquery.correlated_vars.contains(*proj_var)
        });
        for proj_var in shadowing {
            errors.push(LateralValidationError {
                message: format!(
                    "Projected variable ?{proj_var} conflicts with left operand binding"
                ),
                code: LateralErrorCode::VariableConflict,
            });
            warnings.push(format!(
                "Variable ?{proj_var} will be overridden by LATERAL subquery"
            ));
        }

        let mut output_vars: Vec<String> = left_vars
            .iter()
            .chain(subquery.projected_vars.iter())
            .cloned()
            .collect();
        output_vars.sort();
        output_vars.dedup();

        LateralValidationResult {
            is_valid,
            errors,
            warnings,
            detected_correlated_vars,
            output_vars,
        }
    }
}
/// Errors surfaced while executing a LATERAL join.
///
/// NOTE(review): the executor in this file never constructs these itself; it
/// only propagates whatever the subquery evaluator callback returns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LateralJoinError {
    /// The subquery failed; carries a description.
    SubqueryError(String),
    /// The subquery exceeded its time budget.
    Timeout {
        /// Which subquery timed out.
        description: String,
        /// Elapsed time in milliseconds.
        elapsed_ms: u64,
    },
    /// Left and right sides bound a variable to different values.
    IncompatibleBindings {
        /// The conflicting variable.
        variable: String,
        /// Rendering of the left-side value.
        left_value: String,
        /// Rendering of the right-side value.
        right_value: String,
    },
    /// LATERAL nesting went beyond the allowed depth.
    NestingDepthExceeded {
        /// Depth reached.
        depth: usize,
        /// Maximum allowed depth.
        max: usize,
    },
}
impl fmt::Display for LateralJoinError {
    /// Writes a one-line, human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::SubqueryError(msg) => f.write_fmt(format_args!("Lateral subquery error: {msg}")),
            Self::Timeout {
                description,
                elapsed_ms,
            } => f.write_fmt(format_args!(
                "Lateral subquery timed out after {elapsed_ms}ms: {description}"
            )),
            Self::IncompatibleBindings {
                variable,
                left_value,
                right_value,
            } => f.write_fmt(format_args!(
                "Incompatible bindings for ?{variable}: left={left_value}, right={right_value}"
            )),
            Self::NestingDepthExceeded { depth, max } => f.write_fmt(format_args!(
                "LATERAL nesting depth {depth} exceeds maximum allowed {max}"
            )),
        }
    }
}

/// `LateralJoinError` has no underlying source error, so the default
/// `source()` (returning `None`) applies.
impl std::error::Error for LateralJoinError {}
/// Lightweight text scanner for LATERAL clauses in SPARQL query strings.
/// It works on raw text and is not a full SPARQL parser.
pub struct LateralParser;
/// Summary of one parsed LATERAL clause.
///
/// NOTE(review): not constructed anywhere in this file; the field meanings
/// below are inferred from their names and the `LateralParser` helpers they
/// mirror. Confirm against the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParsedLateral {
    /// Variables of the enclosing (outer) query.
    pub outer_vars: Vec<String>,
    /// Outer variables referenced by the subquery.
    pub correlated_vars: Vec<String>,
    /// Variables projected by the subquery.
    pub projected_vars: Vec<String>,
    /// Whether the subquery text contains an aggregate call.
    pub has_aggregates: bool,
    /// Whether the subquery text contains ORDER BY.
    pub has_order_by: bool,
    /// Whether the subquery text contains LIMIT.
    pub has_limit: bool,
    /// Raw text of the subquery.
    pub subquery_text: String,
}
impl LateralParser {
    /// Locates every `LATERAL { ... }` group in `query`.
    ///
    /// The keyword is matched case-insensitively and only as a standalone
    /// token. It may not sit inside a longer identifier, and it may not be a
    /// variable (`?lateral`, `$lateral`) or part of a prefixed name
    /// (`ex:lateral`). Only whitespace may separate the keyword from its `{`.
    /// `start` is the byte offset of the keyword and `end` the byte offset
    /// just past the matching `}`, both into `query`. Nested LATERAL groups
    /// are reported as well.
    ///
    /// Limitation: braces are balanced by depth alone, so a `{` or `}` inside
    /// a string literal or comment can mislead the matcher.
    pub fn detect_lateral_clauses(query: &str) -> Vec<LateralClausePosition> {
        const KEYWORD: &str = "LATERAL";
        // ASCII-only upper-casing preserves every byte offset. Full Unicode
        // upper-casing can change byte lengths (e.g. U+FB01 `ﬁ` becomes `FI`).
        // Offsets found in such a copy would point at the wrong bytes of
        // `query`, or panic when slicing off a char boundary.
        let upper = query.to_ascii_uppercase();
        let bytes = query.as_bytes();
        let mut positions = Vec::new();
        let mut search_from = 0;
        while let Some(offset) = upper[search_from..].find(KEYWORD) {
            let start = search_from + offset;
            let after = start + KEYWORD.len();
            search_from = after;
            if !Self::is_standalone(bytes, start, after) {
                continue;
            }
            let gap = &query[after..];
            let open = after + (gap.len() - gap.trim_start().len());
            if bytes.get(open) != Some(&b'{') {
                continue;
            }
            if let Some(close) = Self::find_matching_brace(query, open) {
                let body = &query[open + 1..close];
                positions.push(LateralClausePosition {
                    start,
                    end: close + 1,
                    body: body.trim().to_string(),
                    has_select: body.to_ascii_uppercase().contains("SELECT"),
                });
            }
        }
        positions
    }

    /// Returns the byte index of the `}` that balances the `{` at `pos`, or
    /// `None` if `pos` is not a `{` or the braces never balance.
    fn find_matching_brace(s: &str, pos: usize) -> Option<usize> {
        let bytes = s.as_bytes();
        if bytes.get(pos) != Some(&b'{') {
            return None;
        }
        let mut depth = 0i32;
        for (offset, &byte) in bytes[pos..].iter().enumerate() {
            match byte {
                b'{' => depth += 1,
                b'}' => {
                    depth -= 1;
                    if depth == 0 {
                        return Some(pos + offset);
                    }
                }
                _ => {}
            }
        }
        None
    }

    /// ASCII bytes that can continue an identifier.
    fn is_word_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_'
    }

    /// True when `bytes[start..end]` is a whole keyword token. The bytes
    /// around it must not continue an identifier, and it must not follow a
    /// variable sigil or a prefix colon.
    fn is_standalone(bytes: &[u8], start: usize, end: usize) -> bool {
        let before_ok = match start.checked_sub(1) {
            None => true,
            Some(prev) => {
                let b = bytes[prev];
                !Self::is_word_byte(b) && !matches!(b, b'?' | b'$' | b':')
            }
        };
        let after_ok = bytes.get(end).map_or(true, |&b| !Self::is_word_byte(b));
        before_ok && after_ok
    }

    /// Byte offsets just past each standalone occurrence of `keyword` in
    /// `upper`, which must already be ASCII upper-cased.
    fn keyword_ends(upper: &str, keyword: &str) -> Vec<usize> {
        let bytes = upper.as_bytes();
        let mut ends = Vec::new();
        let mut from = 0;
        while let Some(offset) = upper[from..].find(keyword) {
            let start = from + offset;
            let end = start + keyword.len();
            from = end;
            if Self::is_standalone(bytes, start, end) {
                ends.push(end);
            }
        }
        ends
    }

    /// Returns the sorted, de-duplicated names (without sigil) of all `?name`
    /// and `$name` occurrences. Names are runs of alphanumeric (Unicode-aware)
    /// or `_` characters. Occurrences inside IRIs or string literals are not
    /// excluded.
    pub fn extract_variables(fragment: &str) -> Vec<String> {
        let mut vars = HashSet::new();
        let mut chars = fragment.char_indices().peekable();
        while let Some((idx, ch)) = chars.next() {
            if ch != '?' && ch != '$' {
                continue;
            }
            let start = idx + ch.len_utf8();
            let mut end = start;
            while let Some(&(i, c)) = chars.peek() {
                if c.is_alphanumeric() || c == '_' {
                    end = i + c.len_utf8();
                    chars.next();
                } else {
                    break;
                }
            }
            if end > start {
                vars.insert(fragment[start..end].to_string());
            }
        }
        let mut result: Vec<String> = vars.into_iter().collect();
        result.sort();
        result
    }

    /// True if `fragment` calls an aggregate (`COUNT`, `SUM`, `AVG`, `MIN`,
    /// `MAX`, `GROUP_CONCAT`, `SAMPLE`). Matching is case-insensitive, needs a
    /// standalone name, and allows whitespace before the `(`.
    pub fn detect_aggregates(fragment: &str) -> bool {
        const AGGREGATES: [&str; 7] = [
            "COUNT",
            "SUM",
            "AVG",
            "MIN",
            "MAX",
            "GROUP_CONCAT",
            "SAMPLE",
        ];
        let upper = fragment.to_ascii_uppercase();
        AGGREGATES.iter().any(|agg| {
            Self::keyword_ends(&upper, agg)
                .into_iter()
                .any(|end| upper[end..].trim_start().starts_with('('))
        })
    }

    /// True if `fragment` contains the keywords `ORDER BY` (case-insensitive,
    /// any whitespace between the two words).
    pub fn detect_order_by(fragment: &str) -> bool {
        let upper = fragment.to_ascii_uppercase();
        Self::keyword_ends(&upper, "ORDER").into_iter().any(|end| {
            match upper[end..].trim_start().strip_prefix("BY") {
                Some(tail) => tail.bytes().next().map_or(true, |b| !Self::is_word_byte(b)),
                None => false,
            }
        })
    }

    /// True if `fragment` contains `LIMIT` as a standalone keyword
    /// (case-insensitive). This excludes words like `DELIMITER` and `?limit`.
    pub fn detect_limit(fragment: &str) -> bool {
        let upper = fragment.to_ascii_uppercase();
        !Self::keyword_ends(&upper, "LIMIT").is_empty()
    }
}
/// Location and contents of one `LATERAL { ... }` group found by
/// `LateralParser::detect_lateral_clauses`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LateralClausePosition {
    /// Byte offset of the `LATERAL` keyword in the query.
    pub start: usize,
    /// Byte offset just past the matching closing `}`.
    pub end: usize,
    /// Text between the braces, trimmed of surrounding whitespace.
    pub body: String,
    /// Whether the body contains `SELECT` (case-insensitive substring test).
    pub has_select: bool,
}
#[cfg(test)]
mod lateral_join_tests;