use crate::template::{
LocalizedTemplateSpec, TemplateComponent, TemplateVariant, TemplateVariants,
};
use crate::version::{MAX_TEMPLATE_COMPONENTS, MAX_TEMPLATE_NESTING_DEPTH};
use crate::{BibliographySpec, CitationSpec, ResolutionError};
use super::Style;
#[cfg(test)]
use crate::template::TemplateGroup;
#[derive(Debug, Clone, PartialEq)]
pub enum SchemaWarning {
UnknownTypeName {
name: String,
location: String,
},
}
impl std::fmt::Display for SchemaWarning {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SchemaWarning::UnknownTypeName { name, location } => {
write!(
f,
"unknown reference type \"{name}\" in {location} \
(will silently match nothing; check for typos)"
)
}
}
}
}
impl Style {
pub fn validate_resource_limits(&self) -> Result<(), String> {
let mut budget = TemplateResourceBudget::default();
if let Some(templates) = &self.templates {
for (name, template) in templates {
budget.check_template(template, &format!("templates.{name}"), 0)?;
}
}
if let Some(citation) = &self.citation {
budget.check_citation_spec(citation, "citation", 0)?;
}
if let Some(bibliography) = &self.bibliography {
budget.check_bibliography_spec(bibliography, "bibliography", 0)?;
}
Ok(())
}
pub fn validate(&self) -> Vec<SchemaWarning> {
let mut warnings = Vec::new();
self.collect_type_selector_warnings(&mut warnings);
warnings
}
fn collect_type_selector_warnings(&self, warnings: &mut Vec<SchemaWarning>) {
if let Some(bib) = &self.bibliography
&& let Some(type_variants) = &bib.type_variants
{
for selector in type_variants.keys() {
for name in selector.unknown_type_names() {
warnings.push(SchemaWarning::UnknownTypeName {
name: name.to_string(),
location: "bibliography.type-variants".to_string(),
});
}
}
}
if let Some(cit) = &self.citation {
collect_citation_spec_warnings(cit, "citation", warnings);
}
}
pub(crate) fn validate_profile_shape(&self) -> Result<(), ResolutionError> {
if self.templates.is_some() || yaml_path_present(self.raw_yaml.as_ref(), &["templates"]) {
return Err(ResolutionError::InvalidProfileOverride {
location: "templates".to_string(),
});
}
if let Some(location) = forbidden_profile_template_path(self.raw_yaml.as_ref()) {
return Err(ResolutionError::InvalidProfileOverride { location });
}
Ok(())
}
}
fn forbidden_profile_template_path(raw_yaml: Option<&serde_yaml::Value>) -> Option<String> {
let raw_yaml = raw_yaml?;
for (section, recursive) in [("citation", true), ("bibliography", false)] {
if let Some(section_value) = mapping_child(raw_yaml, section) {
if recursive {
if let Some(location) = forbidden_citation_template_path(section_value, section) {
return Some(location);
}
} else if let Some(location) = forbidden_section_template_path(section_value, section) {
return Some(location);
}
}
}
None
}
fn forbidden_section_template_path(section: &serde_yaml::Value, location: &str) -> Option<String> {
for key in ["template", "template-ref", "type-variants", "locales"] {
if mapping_child(section, key).is_some() {
return Some(format!("{location}.{key}"));
}
}
None
}
fn forbidden_citation_template_path(section: &serde_yaml::Value, location: &str) -> Option<String> {
if let Some(location) = forbidden_section_template_path(section, location) {
return Some(location);
}
for sub_section in ["integral", "non-integral", "subsequent", "ibid"] {
if let Some(child) = mapping_child(section, sub_section)
&& let Some(location) =
forbidden_citation_template_path(child, &format!("{location}.{sub_section}"))
{
return Some(location);
}
}
None
}
fn mapping_child<'a>(value: &'a serde_yaml::Value, segment: &str) -> Option<&'a serde_yaml::Value> {
let serde_yaml::Value::Mapping(map) = value else {
return None;
};
let key = serde_yaml::Value::String(segment.to_string());
map.get(&key)
}
fn yaml_path_present(value: Option<&serde_yaml::Value>, path: &[&str]) -> bool {
let Some(mut current) = value else {
return false;
};
for segment in path {
let Some(next) = mapping_child(current, segment) else {
return false;
};
current = next;
}
true
}
fn collect_citation_spec_warnings(
spec: &CitationSpec,
location: &str,
warnings: &mut Vec<SchemaWarning>,
) {
if let Some(type_variants) = &spec.type_variants {
for selector in type_variants.keys() {
for name in selector.unknown_type_names() {
warnings.push(SchemaWarning::UnknownTypeName {
name: name.to_string(),
location: format!("{location}.type-variants"),
});
}
}
}
for (sub_name, sub_spec) in [
("integral", spec.integral.as_deref()),
("non-integral", spec.non_integral.as_deref()),
("subsequent", spec.subsequent.as_deref()),
("ibid", spec.ibid.as_deref()),
]
.into_iter()
.filter_map(|(n, s)| s.map(|s| (n, s)))
{
collect_citation_spec_warnings(sub_spec, &format!("{location}.{sub_name}"), warnings);
}
}
#[derive(Default)]
struct TemplateResourceBudget {
component_count: usize,
}
impl TemplateResourceBudget {
fn check_template(
&mut self,
template: &[TemplateComponent],
location: &str,
depth: usize,
) -> Result<(), String> {
if depth > MAX_TEMPLATE_NESTING_DEPTH {
return Err(format!(
"{location} exceeds maximum template nesting depth of {MAX_TEMPLATE_NESTING_DEPTH}"
));
}
for component in template {
self.check_component(component, location, depth)?;
}
Ok(())
}
fn check_component(
&mut self,
component: &TemplateComponent,
location: &str,
depth: usize,
) -> Result<(), String> {
self.component_count = self.component_count.saturating_add(1);
if self.component_count > MAX_TEMPLATE_COMPONENTS {
return Err(format!(
"style exceeds maximum template component count of {MAX_TEMPLATE_COMPONENTS}"
));
}
match component {
TemplateComponent::Date(date) => {
if let Some(fallback) = &date.fallback {
self.check_template(fallback, &format!("{location}.date.fallback"), depth + 1)?;
}
}
TemplateComponent::Group(group) => {
self.check_template(&group.group, &format!("{location}.group"), depth + 1)?;
}
TemplateComponent::Contributor(_)
| TemplateComponent::Title(_)
| TemplateComponent::Number(_)
| TemplateComponent::Variable(_)
| TemplateComponent::Term(_) => {}
}
Ok(())
}
fn check_variant(
&mut self,
variant: &TemplateVariant,
location: &str,
depth: usize,
) -> Result<(), String> {
match variant {
TemplateVariant::Full(template) => self.check_template(template, location, depth),
TemplateVariant::Diff(diff) => {
for (index, add) in diff.add.iter().enumerate() {
self.check_component(
&add.component,
&format!("{location}.add[{index}].component"),
depth,
)?;
}
Ok(())
}
}
}
fn check_variants(
&mut self,
variants: &TemplateVariants,
location: &str,
depth: usize,
) -> Result<(), String> {
for (selector, variant) in variants {
self.check_variant(variant, &format!("{location}.{selector:?}"), depth)?;
}
Ok(())
}
fn check_locales(
&mut self,
locales: &[LocalizedTemplateSpec],
location: &str,
depth: usize,
) -> Result<(), String> {
for (index, locale) in locales.iter().enumerate() {
self.check_template(
&locale.template,
&format!("{location}[{index}].template"),
depth,
)?;
}
Ok(())
}
fn check_citation_spec(
&mut self,
spec: &CitationSpec,
location: &str,
depth: usize,
) -> Result<(), String> {
if let Some(template) = &spec.template {
self.check_template(template, &format!("{location}.template"), depth)?;
}
if let Some(locales) = &spec.locales {
self.check_locales(locales, &format!("{location}.locales"), depth)?;
}
if let Some(variants) = &spec.type_variants {
self.check_variants(variants, &format!("{location}.type-variants"), depth)?;
}
for (sub_name, sub_spec) in [
("integral", spec.integral.as_deref()),
("non-integral", spec.non_integral.as_deref()),
("subsequent", spec.subsequent.as_deref()),
("ibid", spec.ibid.as_deref()),
]
.into_iter()
.filter_map(|(n, s)| s.map(|s| (n, s)))
{
self.check_citation_spec(sub_spec, &format!("{location}.{sub_name}"), depth + 1)?;
}
Ok(())
}
fn check_bibliography_spec(
&mut self,
spec: &BibliographySpec,
location: &str,
depth: usize,
) -> Result<(), String> {
if let Some(template) = &spec.template {
self.check_template(template, &format!("{location}.template"), depth)?;
}
if let Some(locales) = &spec.locales {
self.check_locales(locales, &format!("{location}.locales"), depth)?;
}
if let Some(variants) = &spec.type_variants {
self.check_variants(variants, &format!("{location}.type-variants"), depth)?;
}
Ok(())
}
}
#[cfg(test)]
#[allow(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::indexing_slicing,
clippy::todo,
clippy::unimplemented,
clippy::unreachable,
clippy::get_unwrap,
reason = "Panicking is acceptable and often desired in tests."
)]
mod security_resource_tests {
use super::*;
fn nested_group(depth: usize) -> TemplateComponent {
if depth == 0 {
TemplateComponent::default()
} else {
TemplateComponent::Group(TemplateGroup {
group: vec![nested_group(depth - 1)],
..TemplateGroup::default()
})
}
}
#[test]
fn validate_resource_limits_rejects_deeply_nested_templates() {
let style = Style {
bibliography: Some(BibliographySpec {
template: Some(vec![nested_group(MAX_TEMPLATE_NESTING_DEPTH + 1)]),
..BibliographySpec::default()
}),
..Style::default()
};
let err = style
.validate_resource_limits()
.expect_err("deep template must be rejected");
assert!(err.contains("maximum template nesting depth"));
}
#[test]
fn validate_resource_limits_rejects_too_many_components() {
let style = Style {
bibliography: Some(BibliographySpec {
template: Some(vec![
TemplateComponent::default();
MAX_TEMPLATE_COMPONENTS + 1
]),
..BibliographySpec::default()
}),
..Style::default()
};
let err = style
.validate_resource_limits()
.expect_err("oversized template must be rejected");
assert!(err.contains("maximum template component count"));
}
}