use std::ops::Deref;
use crate::ast::{ArgumentKind, ArgumentSlot, ArgumentValue, Ast, Node, NodeId};
use crate::knowledge::{KnowledgeBase, lookup_command_node_name, lookup_environment_node_name};
use crate::parse::ContentMode;
use crate::rewrite::rule::RuleKey;
use crate::rewrite::{RewriteReport, RuleError};
use texform_knowledge::specs::{
ActiveCharacterRecord, ActiveCommandRecord, ActiveEnvironmentRecord, BuiltinCommandRecord,
BuiltinEnvironmentRecord,
};
/// Borrowed view of a command node, as produced by `RuleContext::match_command`.
#[derive(Clone, Copy)]
pub struct CommandView<'a> {
    /// Name stored on the command node; `subject` prefixes it with a backslash.
    pub name: &'a str,
    /// Explicit argument slots stored on the command node.
    pub args: &'a [ArgumentSlot],
}

impl CommandView<'_> {
    /// Returns the command name prefixed with a backslash (e.g. `\frac`), suitable for
    /// passing as the `subject` of shape-checking helpers.
    pub fn subject(&self) -> String {
        ["\\", self.name].concat()
    }
}
/// Borrowed view of an infix node, as produced by `RuleContext::match_infix`.
#[derive(Clone, Copy)]
pub struct InfixView<'a> {
    /// Name stored on the infix node; `subject` prefixes it with a backslash.
    pub name: &'a str,
    /// Explicit argument slots stored on the infix node.
    pub args: &'a [ArgumentSlot],
    /// Node id of the left operand.
    pub left: NodeId,
    /// Node id of the right operand.
    pub right: NodeId,
}

impl InfixView<'_> {
    /// Returns the infix command name prefixed with a backslash, suitable for passing
    /// as the `subject` of shape-checking helpers.
    pub fn subject(&self) -> String {
        let mut subject = String::with_capacity(self.name.len() + 1);
        subject.push('\\');
        subject.push_str(self.name);
        subject
    }
}
/// Borrowed view of a declarative node, as produced by `RuleContext::match_declarative`.
#[derive(Clone, Copy)]
pub struct DeclarativeView<'a> {
    /// Name stored on the declarative node.
    pub name: &'a str,
    /// Argument slots stored on the declarative node.
    pub args: &'a [ArgumentSlot],
}
/// Borrowed view of an environment node, as produced by `RuleContext::match_environment`.
#[derive(Clone, Copy)]
pub struct EnvironmentView<'a> {
    /// Environment name stored on the node.
    pub name: &'a str,
    /// Argument slots stored on the environment node.
    pub args: &'a [ArgumentSlot],
    /// Node id of the environment's body.
    pub body: NodeId,
}
/// State handed to rewrite rules: the AST being rewritten, the math-mode and
/// text-mode knowledge bases used for name lookups, and the report that collects
/// rule outcomes.
pub struct RuleContext<'a> {
    /// AST under rewrite. Unlike the other fields this one is public — presumably so
    /// rules can edit nodes directly; confirm against callers.
    pub ast: &'a mut Ast,
    /// Knowledge base consulted for `ContentMode::Math` lookups.
    math_kb: &'a KnowledgeBase,
    /// Knowledge base consulted for `ContentMode::Text` lookups.
    text_kb: &'a KnowledgeBase,
    /// Receives applied/skipped marks and iteration counts.
    report: &'a mut RewriteReport,
}
/// A shared borrow of a [`RuleContext`] paired with the [`RuleKey`] of one rule, so
/// the error and shape-checking helpers do not need the key passed on every call.
/// Dereferences to the underlying context.
pub struct RuleScopedContext<'cx, 'ctx> {
    /// Underlying context (shared borrow).
    cx: &'cx RuleContext<'ctx>,
    /// Rule on whose behalf errors are built.
    rule: RuleKey,
}
/// Lets a scoped context be used wherever the plain context's read-only API is
/// needed (lookups, node access, `match_*`).
impl<'ctx> Deref for RuleScopedContext<'_, 'ctx> {
    type Target = RuleContext<'ctx>;

    fn deref(&self) -> &RuleContext<'ctx> {
        let Self { cx, .. } = self;
        *cx
    }
}
impl RuleScopedContext<'_, '_> {
pub fn invalid_shape(&self, message: impl Into<String>) -> RuleError {
self.cx.invalid_shape(self.rule, message)
}
pub fn missing_metadata(&self, name: impl Into<String>) -> RuleError {
self.cx.missing_metadata(self.rule, name)
}
pub fn ensure_shape(
&self,
condition: bool,
message: impl Into<String>,
) -> Result<(), RuleError> {
self.cx.ensure_shape(condition, self.rule, message)
}
pub fn expect_arg_len(
&self,
args: &[ArgumentSlot],
expected: usize,
subject: &str,
) -> Result<(), RuleError> {
self.cx.expect_arg_len(self.rule, args, expected, subject)
}
pub fn expect_no_args(&self, args: &[ArgumentSlot], subject: &str) -> Result<(), RuleError> {
self.cx.expect_no_args(self.rule, args, subject)
}
pub fn star_arg_value(&self, slot: &ArgumentSlot, subject: &str) -> Result<bool, RuleError> {
match slot {
Some(arg) if arg.kind == ArgumentKind::Star => match arg.value {
ArgumentValue::Boolean(value) => Ok(value),
_ => {
Err(self
.invalid_shape(format!("{subject} star slot should carry a boolean value")))
}
},
_ => Err(self.invalid_shape(format!("{subject} should carry a star slot"))),
}
}
pub fn optional_math_content(
&self,
slot: &ArgumentSlot,
subject: &str,
label: &str,
) -> Result<Option<NodeId>, RuleError> {
match slot {
None => Ok(None),
Some(arg) if arg.kind == ArgumentKind::Optional => match arg.value {
ArgumentValue::MathContent(node_id) => Ok(Some(node_id)),
_ => Err(self.invalid_shape(format!("{subject} {label} should be math content"))),
},
_ => Err(self.invalid_shape(format!(
"{subject} {label} should be an optional math argument"
))),
}
}
pub fn optional_group_math_content(
&self,
slot: &ArgumentSlot,
subject: &str,
label: &str,
) -> Result<Option<NodeId>, RuleError> {
match slot {
None => Ok(None),
Some(arg) if arg.kind == ArgumentKind::Group => match arg.value {
ArgumentValue::MathContent(node_id) => Ok(Some(node_id)),
_ => Err(self
.invalid_shape(format!("{subject} optional {label} should be math content"))),
},
_ => Err(self.invalid_shape(format!(
"{subject} optional {label} should be a braced group"
))),
}
}
pub fn mandatory_math_content(
&self,
slot: &ArgumentSlot,
subject: &str,
label: &str,
) -> Result<NodeId, RuleError> {
match slot {
Some(arg) if arg.kind == ArgumentKind::Mandatory => match arg.value {
ArgumentValue::MathContent(node_id) => Ok(node_id),
_ => Err(self.invalid_shape(format!("{subject} {label} should be math content"))),
},
_ => Err(self.invalid_shape(format!(
"{subject} {label} should be a mandatory math argument"
))),
}
}
pub fn mandatory_or_group_math_content(
&self,
slot: &ArgumentSlot,
subject: &str,
label: &str,
) -> Result<NodeId, RuleError> {
match slot {
Some(arg) if matches!(arg.kind, ArgumentKind::Mandatory | ArgumentKind::Group) => {
match arg.value {
ArgumentValue::MathContent(node_id) => Ok(node_id),
_ => {
Err(self.invalid_shape(format!("{subject} {label} should be math content")))
}
}
}
_ => Err(self.invalid_shape(format!("{subject} {label} should be math content"))),
}
}
}
impl<'a> RuleContext<'a> {
    /// Creates a context that rewrites `ast`, resolves names against `math_kb` and
    /// `text_kb`, and records rule outcomes into `report`.
    pub fn new(
        ast: &'a mut Ast,
        math_kb: &'a KnowledgeBase,
        text_kb: &'a KnowledgeBase,
        report: &'a mut RewriteReport,
    ) -> Self {
        Self {
            ast,
            math_kb,
            text_kb,
            report,
        }
    }

    /// Returns the knowledge base that serves lookups in `mode`.
    fn kb_for(&self, mode: ContentMode) -> &'a KnowledgeBase {
        match mode {
            ContentMode::Math => self.math_kb,
            ContentMode::Text => self.text_kb,
        }
    }

    /// Binds this context to `rule`, so the returned view can build errors without the
    /// rule key being passed on every call.
    pub fn for_rule(&self, rule: RuleKey) -> RuleScopedContext<'_, 'a> {
        RuleScopedContext { cx: self, rule }
    }

    /// Returns `true` if `name` is a known command in the math or the text knowledge base.
    pub fn knows_command_name(&self, name: &str) -> bool {
        self.lookup_command(name, ContentMode::Math).is_some()
            || self.lookup_command(name, ContentMode::Text).is_some()
    }

    /// Returns `true` if `name` is a known environment in the math or the text
    /// knowledge base.
    pub fn knows_env_name(&self, name: &str) -> bool {
        self.lookup_env(name, ContentMode::Math).is_some()
            || self.lookup_env(name, ContentMode::Text).is_some()
    }

    /// Returns `true` if either the math-mode or the text-mode record for command
    /// `name` lists `tag`.
    pub fn command_has_tag(&self, name: &str, tag: &str) -> bool {
        self.lookup_command(name, ContentMode::Math)
            .is_some_and(|record| record.tags.contains(&tag))
            || self
                .lookup_command(name, ContentMode::Text)
                .is_some_and(|record| record.tags.contains(&tag))
    }

    /// Returns `true` if either the math-mode or the text-mode record for environment
    /// `name` lists `tag`.
    pub fn env_has_tag(&self, name: &str, tag: &str) -> bool {
        self.lookup_env(name, ContentMode::Math)
            .is_some_and(|record| record.tags.contains(&tag))
            || self
                .lookup_env(name, ContentMode::Text)
                .is_some_and(|record| record.tags.contains(&tag))
    }

    /// Resolves the command record for node `node_id`, trying the math knowledge base
    /// first and falling back to the text one.
    ///
    /// Returns `None` when `lookup_command_node_name` yields no name for the node or
    /// neither knowledge base knows that name.
    pub fn active_command(&self, node_id: NodeId) -> Option<&ActiveCommandRecord> {
        let name = lookup_command_node_name(self.ast.node(node_id))?;
        self.lookup_command(name, ContentMode::Math)
            .or_else(|| self.lookup_command(name, ContentMode::Text))
    }

    /// Resolves the environment record for node `node_id`, trying the math knowledge
    /// base first and falling back to the text one.
    ///
    /// Returns `None` when `lookup_environment_node_name` yields no name for the node
    /// or neither knowledge base knows that name.
    pub fn active_env(&self, node_id: NodeId) -> Option<&ActiveEnvironmentRecord> {
        let name = lookup_environment_node_name(self.ast.node(node_id))?;
        self.lookup_env(name, ContentMode::Math)
            .or_else(|| self.lookup_env(name, ContentMode::Text))
    }

    /// Looks up command `name` in the knowledge base for `mode` only.
    pub fn lookup_command(&self, name: &str, mode: ContentMode) -> Option<&ActiveCommandRecord> {
        self.kb_for(mode).lookup_command(name)
    }

    /// Looks up character `name` in the knowledge base for `mode` only.
    pub fn lookup_character(
        &self,
        name: &str,
        mode: ContentMode,
    ) -> Option<&ActiveCharacterRecord> {
        self.kb_for(mode).lookup_character(name)
    }

    /// Looks up environment `name` in the knowledge base for `mode` only.
    pub fn lookup_env(&self, name: &str, mode: ContentMode) -> Option<&ActiveEnvironmentRecord> {
        self.kb_for(mode).lookup_env(name)
    }

    /// Records in the report that rule `key` was applied.
    pub fn mark_rule_applied(&mut self, key: RuleKey) {
        self.report.mark_rule_applied(key);
    }

    /// Records in the report that rule `key` was skipped.
    pub fn mark_rule_skipped(&mut self, key: RuleKey) {
        self.report.mark_rule_skipped(key);
    }

    /// Forwards `iterations` to `RewriteReport::record_iteration`.
    pub fn record_iteration(&mut self, iterations: usize) {
        self.report.record_iteration(iterations);
    }

    /// Returns the AST node stored under `node_id`.
    pub fn node(&self, node_id: NodeId) -> &Node {
        self.ast.node(node_id)
    }

    /// Builds a [`RuleError::InvalidNodeShape`] carrying `message`.
    ///
    /// NOTE(review): `_rule` is currently ignored, so the error does not record which
    /// rule raised it.
    pub fn invalid_shape(&self, _rule: RuleKey, message: impl Into<String>) -> RuleError {
        RuleError::InvalidNodeShape {
            message: message.into(),
        }
    }

    /// Builds a [`RuleError::MissingMetadata`] for `name`.
    ///
    /// NOTE(review): `_rule` is currently ignored, as in `invalid_shape`.
    pub fn missing_metadata(&self, _rule: RuleKey, name: impl Into<String>) -> RuleError {
        RuleError::MissingMetadata { name: name.into() }
    }

    /// Returns `Ok(())` when `condition` holds, otherwise an invalid-shape error
    /// carrying `message`.
    ///
    /// `message` is only converted on failure, but any formatting the caller did to
    /// build it has already happened.
    pub fn ensure_shape(
        &self,
        condition: bool,
        rule: RuleKey,
        message: impl Into<String>,
    ) -> Result<(), RuleError> {
        if condition {
            Ok(())
        } else {
            Err(self.invalid_shape(rule, message))
        }
    }

    /// Checks that `args` holds exactly `expected` explicit argument slots.
    ///
    /// # Errors
    /// Returns an invalid-shape error naming `subject`, the expected count and the
    /// actual count when they differ.
    pub fn expect_arg_len(
        &self,
        rule: RuleKey,
        args: &[ArgumentSlot],
        expected: usize,
        subject: &str,
    ) -> Result<(), RuleError> {
        let actual = args.len();
        if actual == expected {
            return Ok(());
        }
        // Format the message only on failure so the success path does not allocate.
        Err(self.invalid_shape(
            rule,
            format!(
                "{subject} should carry exactly {expected} explicit argument slots, got {actual}"
            ),
        ))
    }

    /// Checks that `args` is empty; see [`RuleContext::expect_arg_len`].
    pub fn expect_no_args(
        &self,
        rule: RuleKey,
        args: &[ArgumentSlot],
        subject: &str,
    ) -> Result<(), RuleError> {
        self.expect_arg_len(rule, args, 0, subject)
    }

    /// Returns a [`CommandView`] when `node_id` is a `Node::Command` whose name equals
    /// `record.name`, otherwise `None`.
    pub fn match_command(
        &self,
        node_id: NodeId,
        record: &'static BuiltinCommandRecord,
    ) -> Option<CommandView<'_>> {
        match self.ast.node(node_id) {
            Node::Command { name, args, .. } if name == record.name => Some(CommandView {
                name: name.as_str(),
                args: args.as_slice(),
            }),
            _ => None,
        }
    }

    /// Returns an [`InfixView`] when `node_id` is a `Node::Infix` whose name equals
    /// `record.name`, otherwise `None`.
    pub fn match_infix(
        &self,
        node_id: NodeId,
        record: &'static BuiltinCommandRecord,
    ) -> Option<InfixView<'_>> {
        match self.ast.node(node_id) {
            Node::Infix {
                name,
                args,
                left,
                right,
            } if name == record.name => Some(InfixView {
                name: name.as_str(),
                args: args.as_slice(),
                left: *left,
                right: *right,
            }),
            _ => None,
        }
    }

    /// Returns a [`DeclarativeView`] when `node_id` is a `Node::Declarative` whose
    /// name equals `record.name`, otherwise `None`.
    pub fn match_declarative(
        &self,
        node_id: NodeId,
        record: &'static BuiltinCommandRecord,
    ) -> Option<DeclarativeView<'_>> {
        match self.ast.node(node_id) {
            Node::Declarative { name, args } if name == record.name => Some(DeclarativeView {
                name: name.as_str(),
                args: args.as_slice(),
            }),
            _ => None,
        }
    }

    /// Returns an [`EnvironmentView`] when `node_id` is a `Node::Environment` whose
    /// name equals `record.name`, otherwise `None`.
    pub fn match_environment(
        &self,
        node_id: NodeId,
        record: &'static BuiltinEnvironmentRecord,
    ) -> Option<EnvironmentView<'_>> {
        match self.ast.node(node_id) {
            Node::Environment {
                name, args, body, ..
            } if name == record.name => Some(EnvironmentView {
                name: name.as_str(),
                args: args.as_slice(),
                body: *body,
            }),
            _ => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Argument;
    use crate::parse::ParseContext;
    use crate::rewrite::{PackageName, RewriteReport, RuleError, RuleKey};

    const TEST_RULE: RuleKey = RuleKey {
        package: PackageName::Base,
        name: "rule-context-test",
    };

    /// Builds a filled argument slot of `kind` carrying `value`.
    fn slot(kind: ArgumentKind, value: ArgumentValue) -> ArgumentSlot {
        Some(Argument { kind, value })
    }

    /// Returns the message of an invalid-shape error, panicking on any other outcome.
    fn shape_message<T: std::fmt::Debug>(result: Result<T, RuleError>) -> String {
        match result {
            Err(RuleError::InvalidNodeShape { message }) => message,
            other => panic!("expected an invalid-shape error, got {other:?}"),
        }
    }

    #[test]
    fn extracts_common_prefix_argument_shapes() {
        let parse_ctx = ParseContext::from_packages(&["base"]);
        let mut report = RewriteReport::default();
        let mut ast = Ast::new();
        let required_id = ast.new_node(Node::Char('x'));
        let optional_id = ast.new_node(Node::Char('2'));
        let grouped_id = ast.new_node(Node::Char('t'));
        let cx = RuleContext::new(
            &mut ast,
            parse_ctx.math_kb(),
            parse_ctx.text_kb(),
            &mut report,
        );
        let scoped = cx.for_rule(TEST_RULE);

        let star = slot(ArgumentKind::Star, ArgumentValue::Boolean(true));
        let required = slot(ArgumentKind::Mandatory, ArgumentValue::MathContent(required_id));
        let optional = slot(ArgumentKind::Optional, ArgumentValue::MathContent(optional_id));
        let grouped = slot(ArgumentKind::Group, ArgumentValue::MathContent(grouped_id));

        assert!(scoped.star_arg_value(&star, r"\example").unwrap());
        assert_eq!(
            scoped
                .mandatory_math_content(&required, r"\example", "argument")
                .unwrap(),
            required_id
        );
        assert_eq!(
            scoped
                .optional_math_content(&optional, r"\example", "order")
                .unwrap(),
            Some(optional_id)
        );
        assert_eq!(
            scoped
                .optional_group_math_content(&grouped, r"\example", "denominator")
                .unwrap(),
            Some(grouped_id)
        );
        assert_eq!(
            scoped
                .mandatory_or_group_math_content(&grouped, r"\example", "argument")
                .unwrap(),
            grouped_id
        );
        assert_eq!(
            scoped
                .mandatory_or_group_math_content(&required, r"\example", "argument")
                .unwrap(),
            required_id
        );
        assert_eq!(
            scoped
                .optional_math_content(&None, r"\example", "order")
                .unwrap(),
            None
        );
        assert_eq!(
            scoped
                .optional_group_math_content(&None, r"\example", "denominator")
                .unwrap(),
            None
        );
        assert!(scoped.expect_no_args(&[], r"\example").is_ok());
    }

    #[test]
    fn reports_mismatched_argument_shapes() {
        let parse_ctx = ParseContext::from_packages(&["base"]);
        let mut report = RewriteReport::default();
        let mut ast = Ast::new();
        let node_id = ast.new_node(Node::Char('x'));
        let cx = RuleContext::new(
            &mut ast,
            parse_ctx.math_kb(),
            parse_ctx.text_kb(),
            &mut report,
        );
        let scoped = cx.for_rule(TEST_RULE);

        let optional = slot(ArgumentKind::Optional, ArgumentValue::MathContent(node_id));
        let star_with_math = slot(ArgumentKind::Star, ArgumentValue::MathContent(node_id));
        let mandatory_flag = slot(ArgumentKind::Mandatory, ArgumentValue::Boolean(true));

        assert_eq!(
            shape_message(scoped.star_arg_value(&None, r"\example")),
            r"\example should carry a star slot"
        );
        assert_eq!(
            shape_message(scoped.star_arg_value(&star_with_math, r"\example")),
            r"\example star slot should carry a boolean value"
        );
        assert_eq!(
            shape_message(scoped.mandatory_math_content(&None, r"\example", "argument")),
            r"\example argument should be a mandatory math argument"
        );
        assert_eq!(
            shape_message(scoped.mandatory_math_content(&mandatory_flag, r"\example", "argument")),
            r"\example argument should be math content"
        );
        assert_eq!(
            shape_message(scoped.optional_math_content(&mandatory_flag, r"\example", "order")),
            r"\example order should be an optional math argument"
        );
        assert_eq!(
            shape_message(scoped.optional_group_math_content(
                &optional,
                r"\example",
                "denominator"
            )),
            r"\example optional denominator should be a braced group"
        );
        assert_eq!(
            shape_message(scoped.mandatory_or_group_math_content(
                &optional,
                r"\example",
                "argument"
            )),
            r"\example argument should be math content"
        );
        assert_eq!(
            shape_message(scoped.expect_no_args(&[None], r"\example")),
            r"\example should carry exactly 0 explicit argument slots, got 1"
        );
    }
}