use std::panic::{self, AssertUnwindSafe};
use tensorlogic_ir::{EinsumGraph, TLExpr};
use crate::compile_to_einsum_with_context;
use crate::context::CompilerContext;
use super::collector::DiagnosticCollector;
use super::diagnostic::{Diagnostic, Severity};
use super::strategy::{RecoveryAction, RecoveryStrategy};
/// Outcome of compiling a whole program with [`TolerantCompiler`], where
/// individual expressions may fail without discarding the others.
#[derive(Debug, Clone)]
pub struct PartialCompilationResult {
/// One entry per input expression, in input order: `Some` if it compiled,
/// `None` if it failed, was skipped, or was never attempted after an abort.
pub graphs: Vec<Option<EinsumGraph>>,
/// Diagnostics recorded for failed expressions (compilation errors, caught
/// panics, missing contexts).
pub diagnostics: DiagnosticCollector,
/// The recovery strategy that was in effect for this compilation.
pub strategy: RecoveryStrategy,
/// `true` if the strategy chose to abort, so later expressions were not compiled.
pub aborted: bool,
/// Index of the expression whose failure triggered the abort, if any.
pub aborted_at: Option<usize>,
}
impl PartialCompilationResult {
    /// Number of expressions that produced a graph.
    pub fn success_count(&self) -> usize {
        self.successes().count()
    }

    /// Number of expressions without a graph (failed, skipped, or not
    /// attempted after an abort).
    pub fn failure_count(&self) -> usize {
        // Every slot is either `Some` or `None`, so this cannot underflow.
        self.graphs.len() - self.success_count()
    }

    /// `true` when no slot is empty; vacuously `true` for an empty program.
    pub fn is_all_success(&self) -> bool {
        !self.graphs.iter().any(Option::is_none)
    }

    /// Iterates over `(expression index, graph)` for every successful expression,
    /// in input order.
    pub fn successes(&self) -> impl Iterator<Item = (usize, &EinsumGraph)> {
        self.graphs
            .iter()
            .enumerate()
            .filter_map(|(index, slot)| Some((index, slot.as_ref()?)))
    }

    /// Indices (ascending) of every expression that did not produce a graph.
    pub fn failures(&self) -> Vec<usize> {
        self.graphs
            .iter()
            .enumerate()
            .filter(|(_, slot)| slot.is_none())
            .map(|(index, _)| index)
            .collect()
    }
}
/// Compiles each expression of a program independently, turning compilation
/// errors and panics into diagnostics instead of stopping at the first failure.
/// The configured [`RecoveryStrategy`] decides whether a failure skips the
/// expression or aborts the rest of the program.
#[derive(Debug, Clone, Default)]
pub struct TolerantCompiler {
/// Policy consulted (via `decide`) on every failure, keyed by its severity.
strategy: RecoveryStrategy,
}
impl TolerantCompiler {
pub fn new() -> Self {
Self::default()
}
pub fn with_strategy(strategy: RecoveryStrategy) -> Self {
Self { strategy }
}
pub fn strategy(&self) -> RecoveryStrategy {
self.strategy
}
pub fn set_strategy(&mut self, strategy: RecoveryStrategy) {
self.strategy = strategy;
}
pub fn compile_program(&self, program: &[TLExpr]) -> PartialCompilationResult {
self.compile_program_with(program, |_idx| CompilerContext::new())
}
pub fn compile_program_with<F>(
&self,
program: &[TLExpr],
mut make_ctx: F,
) -> PartialCompilationResult
where
F: FnMut(usize) -> CompilerContext,
{
let collector = DiagnosticCollector::new();
let mut graphs: Vec<Option<EinsumGraph>> = Vec::with_capacity(program.len());
let mut aborted = false;
let mut aborted_at: Option<usize> = None;
for (idx, expr) in program.iter().enumerate() {
if aborted {
graphs.push(None);
continue;
}
let mut ctx = make_ctx(idx);
match self.compile_one(idx, expr, &mut ctx, &collector) {
OneResult::Ok(graph) => graphs.push(Some(graph)),
OneResult::Skipped => graphs.push(None),
OneResult::Aborted => {
graphs.push(None);
aborted = true;
aborted_at = Some(idx);
}
}
}
PartialCompilationResult {
graphs,
diagnostics: collector,
strategy: self.strategy,
aborted,
aborted_at,
}
}
pub fn compile_program_with_contexts(
&self,
program: &[TLExpr],
contexts: &mut [CompilerContext],
) -> PartialCompilationResult {
let collector = DiagnosticCollector::new();
let mut graphs: Vec<Option<EinsumGraph>> = Vec::with_capacity(program.len());
let mut aborted = false;
let mut aborted_at: Option<usize> = None;
for (idx, expr) in program.iter().enumerate() {
if aborted {
graphs.push(None);
continue;
}
if idx >= contexts.len() {
collector.push(
Diagnostic::fatal(format!(
"tolerant compiler: missing CompilerContext for expression #{}",
idx
))
.with_expression_index(idx),
);
let action = self.strategy.decide(Severity::Fatal);
match action {
RecoveryAction::Continue => graphs.push(None),
RecoveryAction::SkipExpression => graphs.push(None),
RecoveryAction::AbortProgram => {
graphs.push(None);
aborted = true;
aborted_at = Some(idx);
}
}
continue;
}
match self.compile_one(idx, expr, &mut contexts[idx], &collector) {
OneResult::Ok(graph) => graphs.push(Some(graph)),
OneResult::Skipped => graphs.push(None),
OneResult::Aborted => {
graphs.push(None);
aborted = true;
aborted_at = Some(idx);
}
}
}
PartialCompilationResult {
graphs,
diagnostics: collector,
strategy: self.strategy,
aborted,
aborted_at,
}
}
fn compile_one(
&self,
idx: usize,
expr: &TLExpr,
ctx: &mut CompilerContext,
collector: &DiagnosticCollector,
) -> OneResult {
let unwind_result = panic::catch_unwind(AssertUnwindSafe(|| {
compile_to_einsum_with_context(expr, ctx)
}));
match unwind_result {
Ok(Ok(graph)) => OneResult::Ok(graph),
Ok(Err(err)) => {
let diag =
Diagnostic::error(format!("compilation error in expression #{}: {}", idx, err))
.with_expression_index(idx);
collector.push(diag);
self.react(idx, Severity::Error)
}
Err(payload) => {
let msg = panic_payload_to_string(&payload);
let diag = Diagnostic::fatal(format!(
"panic while compiling expression #{}: {}",
idx, msg
))
.with_expression_index(idx);
collector.push(diag);
self.react(idx, Severity::Fatal)
}
}
}
fn react(&self, _idx: usize, severity: Severity) -> OneResult {
match self.strategy.decide(severity) {
RecoveryAction::Continue => {
OneResult::Skipped
}
RecoveryAction::SkipExpression => OneResult::Skipped,
RecoveryAction::AbortProgram => OneResult::Aborted,
}
}
}
/// Outcome of compiling one expression inside [`TolerantCompiler`].
#[derive(Debug)]
enum OneResult {
/// Compilation succeeded and produced a graph.
Ok(EinsumGraph),
/// Compilation failed; the strategy chose to keep going with later expressions.
Skipped,
/// Compilation failed; the strategy chose to stop compiling the program.
Aborted,
}
/// Compiles `program` with a default-configured [`TolerantCompiler`], giving
/// each expression a fresh `CompilerContext`.
pub fn compile_tolerant(program: &[TLExpr]) -> PartialCompilationResult {
    let compiler = TolerantCompiler::new();
    compiler.compile_program(program)
}
/// Compiles `program` with a [`TolerantCompiler`] configured to use `strategy`,
/// giving each expression a fresh `CompilerContext`.
pub fn compile_tolerant_with_strategy(
    program: &[TLExpr],
    strategy: RecoveryStrategy,
) -> PartialCompilationResult {
    let mut compiler = TolerantCompiler::new();
    compiler.set_strategy(strategy);
    compiler.compile_program(program)
}
/// Extracts a human-readable message from a caught panic payload.
///
/// The formatted forms of `panic!` produce a `String` payload. A plain literal
/// produces `&'static str`. Any other payload type (e.g. from `panic_any`) maps
/// to a fixed placeholder.
///
/// The parameter stays `&Box<..>` on purpose. Coercing `&Box<dyn Any + Send>`
/// to `&dyn Any` would unsize the Box itself, and every downcast would then fail.
fn panic_payload_to_string(payload: &Box<dyn std::any::Any + Send>) -> String {
    payload
        .downcast_ref::<&'static str>()
        .map(|text| (*text).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "<non-string panic payload>".to_owned())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{TLExpr, Term};

    /// A single-predicate expression that compiles cleanly with a fresh context.
    fn good_expr() -> TLExpr {
        TLExpr::pred("p", vec![Term::var("x")])
    }

    #[test]
    fn compile_tolerant_all_good() {
        let program = vec![good_expr(), good_expr(), good_expr()];
        let res = compile_tolerant(&program);
        assert_eq!(res.graphs.len(), 3);
        assert!(res.is_all_success());
        assert_eq!(res.success_count(), 3);
        assert!(!res.aborted);
        assert!(res.diagnostics.is_empty());
    }

    #[test]
    fn partial_result_success_iter() {
        let program = vec![good_expr(), good_expr()];
        let res = compile_tolerant(&program);
        let v: Vec<usize> = res.successes().map(|(i, _)| i).collect();
        assert_eq!(v, vec![0, 1]);
    }

    #[test]
    fn compile_tolerant_empty_program() {
        let res = compile_tolerant(&[]);
        assert!(res.graphs.is_empty());
        assert!(res.is_all_success());
        assert_eq!(res.failure_count(), 0);
        assert!(!res.aborted);
        assert_eq!(res.aborted_at, None);
        assert!(res.diagnostics.is_empty());
    }

    #[test]
    fn compile_program_with_builds_one_context_per_expression() {
        let program = vec![good_expr(), good_expr(), good_expr()];
        let mut requested = Vec::new();
        let res = TolerantCompiler::new().compile_program_with(&program, |idx| {
            requested.push(idx);
            CompilerContext::new()
        });
        assert_eq!(requested, vec![0, 1, 2]);
        assert!(res.is_all_success());
    }

    #[test]
    fn compile_program_with_contexts_uses_supplied_contexts() {
        let program = vec![good_expr(), good_expr()];
        let mut contexts = vec![CompilerContext::new(), CompilerContext::new()];
        let res = TolerantCompiler::new().compile_program_with_contexts(&program, &mut contexts);
        assert!(res.is_all_success());
        assert!(res.diagnostics.is_empty());
    }

    #[test]
    fn compile_program_with_contexts_reports_missing_contexts() {
        let program = vec![good_expr(), good_expr()];
        let mut contexts: Vec<CompilerContext> = Vec::new();
        let res = TolerantCompiler::new().compile_program_with_contexts(&program, &mut contexts);
        // Whether the default strategy skips or aborts, no expression can
        // compile and at least the first one is reported.
        assert_eq!(res.graphs.len(), 2);
        assert_eq!(res.success_count(), 0);
        assert_eq!(res.failures(), vec![0, 1]);
        assert!(!res.diagnostics.is_empty());
    }
}