use crate::ast::SourceLocation;
use crate::compiler::_value::PartitionAxisOrigin;
use crate::compiler::shared_types::EntryAttrs;
use crate::error::{JITError, SpannedJITError};
use syn::{BinOp, Expr, ExprBinary, ExprCall};
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) enum MetadataExpr {
Dim { tensor: String, axis: usize },
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct MetadataFact {
pub(crate) lhs: MetadataExpr,
pub(crate) rhs: MetadataExpr,
}
#[derive(Debug, Clone, Default)]
pub(crate) struct ProofResults {
pub(crate) metadata_facts: Vec<MetadataFact>,
}
impl ProofResults {
pub(crate) fn analyze_entry_attrs(entry_attrs: &EntryAttrs) -> Result<Self, JITError> {
let Some(expr) = entry_attrs.get_entry_arg_expr("preconditions") else {
return Ok(Self::default());
};
let mut metadata_facts = Vec::new();
for expr in precondition_entries(expr)? {
metadata_facts.push(parse_metadata_fact(expr)?);
}
Ok(Self { metadata_facts })
}
pub(crate) fn has_dim_equality(
&self,
lhs: &str,
lhs_axis: usize,
rhs: &str,
rhs_axis: usize,
) -> bool {
let lhs = MetadataExpr::Dim {
tensor: lhs.to_string(),
axis: lhs_axis,
};
let rhs = MetadataExpr::Dim {
tensor: rhs.to_string(),
axis: rhs_axis,
};
self.metadata_facts.iter().any(|fact| {
(fact.lhs == lhs && fact.rhs == rhs) || (fact.lhs == rhs && fact.rhs == lhs)
})
}
pub(crate) fn proves_partition_axis_access(
&self,
index_origin: &PartitionAxisOrigin,
target_tensor: &str,
target_axis: usize,
target_tile_dim: i32,
) -> bool {
index_origin.tile_dim == target_tile_dim
&& self.has_dim_equality(
&index_origin.tensor,
index_origin.axis,
target_tensor,
target_axis,
)
}
}
fn precondition_entries(expr: &Expr) -> Result<Vec<&Expr>, JITError> {
match expr {
Expr::Tuple(tuple) => Ok(tuple.elems.iter().collect()),
Expr::Paren(paren) => Ok(vec![paren.expr.as_ref()]),
Expr::Binary(_) => Ok(vec![expr]),
Expr::Call(call) => {
if call_name(call).as_deref() == Some("same_partition_axis") {
return SourceLocation::unknown().jit_error_result(
"`same_partition_axis` preconditions have been replaced by `dim(lhs, axis) == dim(rhs, axis)`",
);
}
SourceLocation::unknown()
.jit_error_result("`preconditions` entries must be metadata equalities")
}
_ => SourceLocation::unknown()
.jit_error_result("`preconditions` must be an equality or tuple of equalities"),
}
}
fn parse_metadata_fact(expr: &Expr) -> Result<MetadataFact, JITError> {
let Expr::Binary(binary) = expr else {
return SourceLocation::unknown()
.jit_error_result("each `preconditions` entry must be a metadata equality");
};
parse_metadata_equality(binary)
}
fn parse_metadata_equality(binary: &ExprBinary) -> Result<MetadataFact, JITError> {
if !matches!(binary.op, BinOp::Eq(_)) {
return SourceLocation::unknown().jit_error_result("precondition predicates must use `==`");
}
Ok(MetadataFact {
lhs: parse_metadata_expr(&binary.left)?,
rhs: parse_metadata_expr(&binary.right)?,
})
}
fn parse_metadata_expr(expr: &Expr) -> Result<MetadataExpr, JITError> {
let Expr::Call(call) = expr else {
return SourceLocation::unknown().jit_error_result(
"precondition metadata expressions must be calls like `dim(tensor, axis)`",
);
};
let Some(name) = call_name(call) else {
return SourceLocation::unknown()
.jit_error_result("precondition metadata expressions must use a function path");
};
match name.as_str() {
"dim" => parse_dim_expr(call),
other => SourceLocation::unknown().jit_error_result(&format!(
"unsupported precondition metadata expression `{other}`; expected `dim`"
)),
}
}
fn parse_dim_expr(call: &ExprCall) -> Result<MetadataExpr, JITError> {
if call.args.len() != 2 {
return SourceLocation::unknown().jit_error_result(&format!(
"`dim` expects 2 arguments, got {}",
call.args.len()
));
}
Ok(MetadataExpr::Dim {
tensor: parse_tensor_arg(&call.args[0])?,
axis: parse_axis_arg(&call.args[1])?,
})
}
fn call_name(call: &ExprCall) -> Option<String> {
let Expr::Path(func_path) = call.func.as_ref() else {
return None;
};
func_path
.path
.segments
.last()
.map(|segment| segment.ident.to_string())
}
fn parse_tensor_arg(expr: &Expr) -> Result<String, JITError> {
let Expr::Path(path) = expr else {
return SourceLocation::unknown()
.jit_error_result("precondition tensor arguments must be parameter names");
};
if path.qself.is_some() || path.path.segments.len() != 1 {
return SourceLocation::unknown()
.jit_error_result("precondition tensor arguments must be simple parameter names");
}
Ok(path.path.segments[0].ident.to_string())
}
fn parse_axis_arg(expr: &Expr) -> Result<usize, JITError> {
let axis = crate::types::parse_signed_literal_as_i32(expr);
if axis < 0 {
return SourceLocation::unknown().jit_error_result(&format!(
"precondition axis must be non-negative, got {axis}"
));
}
Ok(axis as usize)
}
impl std::fmt::Display for MetadataExpr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MetadataExpr::Dim { tensor, axis } => write!(f, "dim({tensor}, {axis})"),
}
}
}