use std::sync::Arc;
use sim_kernel::{Cx, Diagnostic, Expr, Result, ShapeRef, Value};
use crate::{
MatchScore, Shape, ShapeDoc, ShapeMatch,
hooks::types::{
MatchHook, MatchHookContext, MatchHookDecision, MatchHookKind, MatchHookPhase,
MatchHookTargetKind,
},
};
/// A [`Shape`] wrapper that runs a list of [`MatchHook`]s around the matches performed by an
/// inner shape.
pub struct HookedShape {
    /// The wrapped shape that performs the actual match.
    inner: Arc<dyn Shape>,
    /// Hooks consulted around each match, in the order they were supplied.
    hooks: Vec<Arc<dyn MatchHook>>,
}

impl HookedShape {
    /// Wraps `inner` so that `hooks` are consulted around each of its matches.
    pub fn new(inner: Arc<dyn Shape>, hooks: Vec<Arc<dyn MatchHook>>) -> Self {
        HookedShape { inner, hooks }
    }

    /// Returns the wrapped shape.
    pub fn inner(&self) -> &Arc<dyn Shape> {
        &self.inner
    }

    /// Returns the hooks in the order they were supplied to [`HookedShape::new`].
    pub fn hooks(&self) -> &[Arc<dyn MatchHook>] {
        self.hooks.as_slice()
    }
}
impl Shape for HookedShape {
fn parents(&self, cx: &mut Cx) -> Result<Vec<ShapeRef>> {
self.inner.parents(cx)
}
fn is_effectful(&self) -> bool {
self.inner.is_effectful()
}
fn is_total(&self) -> bool {
self.inner.is_total()
}
fn is_subshape_of(&self, cx: &mut Cx, parent: &dyn Shape) -> Result<Option<bool>> {
self.inner.is_subshape_of(cx, parent)
}
fn check_value(&self, cx: &mut Cx, value: Value) -> Result<ShapeMatch> {
let label = self.inner.describe(cx)?.name;
let before = self.run_marks(
cx,
MatchHookTargetKind::Value,
MatchHookPhase::BeforeInner,
&label,
None,
)?;
let matched = self.inner.check_value(cx, value)?;
self.finish_match(cx, MatchHookTargetKind::Value, label, matched, before)
}
fn check_expr(&self, cx: &mut Cx, expr: &Expr) -> Result<ShapeMatch> {
let label = self.inner.describe(cx)?.name;
let before = self.run_marks(
cx,
MatchHookTargetKind::Expr,
MatchHookPhase::BeforeInner,
&label,
None,
)?;
let matched = self.inner.check_expr(cx, expr)?;
self.finish_match(cx, MatchHookTargetKind::Expr, label, matched, before)
}
fn describe(&self, cx: &mut Cx) -> Result<ShapeDoc> {
let mut doc = ShapeDoc::new("hooked shape").with_detail(self.inner.describe(cx)?.name);
for hook in &self.hooks {
doc = doc.with_detail(hook.symbol().to_string());
}
Ok(doc)
}
}
impl HookedShape {
fn finish_match(
&self,
cx: &mut Cx,
target_kind: MatchHookTargetKind,
label: String,
mut matched: ShapeMatch,
before_marks: Vec<Diagnostic>,
) -> Result<ShapeMatch> {
matched.diagnostics.extend(before_marks);
let after_marks = self.run_marks(
cx,
target_kind,
MatchHookPhase::AfterInner,
&label,
Some(&matched),
)?;
matched.diagnostics.extend(after_marks);
if !matched.accepted {
matched = self.run_accept_hooks(cx, target_kind, &label, matched)?;
}
if matched.accepted {
matched = self.run_discard_hooks(cx, target_kind, &label, matched)?;
}
self.run_annotate_hooks(cx, target_kind, &label, matched)
}
fn run_marks(
&self,
cx: &mut Cx,
target_kind: MatchHookTargetKind,
phase: MatchHookPhase,
shape_label: &str,
current: Option<&ShapeMatch>,
) -> Result<Vec<Diagnostic>> {
let mut diagnostics = Vec::new();
for (hook_index, hook) in self.hooks.iter().enumerate() {
if hook.kind() != MatchHookKind::Mark {
continue;
}
let ctx = MatchHookContext {
hook_index,
phase,
target_kind,
shape_label: shape_label.to_owned(),
};
if let MatchHookDecision::Mark { message } = hook.apply(cx, &ctx, current)? {
diagnostics.push(Diagnostic::info(format!("shape-hook:mark {message}")));
}
}
Ok(diagnostics)
}
fn run_accept_hooks(
&self,
cx: &mut Cx,
target_kind: MatchHookTargetKind,
shape_label: &str,
mut matched: ShapeMatch,
) -> Result<ShapeMatch> {
for (hook_index, hook) in self.hooks.iter().enumerate() {
if hook.kind() != MatchHookKind::Accept {
continue;
}
let ctx = MatchHookContext {
hook_index,
phase: MatchHookPhase::AfterInner,
target_kind,
shape_label: shape_label.to_owned(),
};
if let MatchHookDecision::Accept { reason, score } =
hook.apply(cx, &ctx, Some(&matched))?
{
matched.accepted = true;
if matched.score == MatchScore::reject() {
matched.score = score;
}
matched
.diagnostics
.push(Diagnostic::info(format!("shape-hook:accept {reason}")));
}
}
Ok(matched)
}
fn run_discard_hooks(
&self,
cx: &mut Cx,
target_kind: MatchHookTargetKind,
shape_label: &str,
mut matched: ShapeMatch,
) -> Result<ShapeMatch> {
for (hook_index, hook) in self.hooks.iter().enumerate() {
if hook.kind() != MatchHookKind::Discard {
continue;
}
let ctx = MatchHookContext {
hook_index,
phase: MatchHookPhase::AfterInner,
target_kind,
shape_label: shape_label.to_owned(),
};
if let MatchHookDecision::Discard { reason } = hook.apply(cx, &ctx, Some(&matched))? {
matched.accepted = false;
matched.score = MatchScore::reject();
matched
.diagnostics
.push(Diagnostic::error(format!("shape-hook:discard {reason}")));
break;
}
}
Ok(matched)
}
fn run_annotate_hooks(
&self,
cx: &mut Cx,
target_kind: MatchHookTargetKind,
shape_label: &str,
mut matched: ShapeMatch,
) -> Result<ShapeMatch> {
for (hook_index, hook) in self.hooks.iter().enumerate() {
if hook.kind() != MatchHookKind::Annotate {
continue;
}
let ctx = MatchHookContext {
hook_index,
phase: MatchHookPhase::AfterInner,
target_kind,
shape_label: shape_label.to_owned(),
};
if let MatchHookDecision::Annotate {
message,
score_delta,
} = hook.apply(cx, &ctx, Some(&matched))?
{
matched.score += MatchScore::exact(score_delta);
matched
.diagnostics
.push(Diagnostic::info(format!("shape-hook:annotate {message}")));
}
}
Ok(matched)
}
}