use std::collections::HashMap;
use pest::Parser;
use crate::schema::{
NumericValue, ObjectSchema, PrimitiveType, ScalarBase, ScalarDefinition,
ScalarExpression, ScalarLiteral, ScalarPredicate, SchemaDefinition, SchemaField,
SchemaRegistry, SchemaRegistryError, SchemaShape, TypeReference, TypeReferenceKind,
};
use super::{CollectionParser, Rule};
/// Plain start/end offset pair describing a region of the source text.
///
/// Unlike `pest::Span`, this type holds no reference to the source string;
/// the source is supplied again when converting back via `into_pest_span`.
#[derive(Debug, Clone, Copy)]
pub(super) struct SourceSpan {
    /// Offset at which the region begins.
    start: usize,
    /// Offset at which the region ends.
    end: usize,
}

impl SourceSpan {
    /// Captures the start and end offsets reported by a pest span.
    fn from_pest_span(span: pest::Span<'_>) -> Self {
        let (start, end) = (span.start(), span.end());
        Self { start, end }
    }

    /// Rebuilds a pest span over `source` from the stored offsets.
    ///
    /// # Panics
    ///
    /// Panics if `pest::Span::new` rejects the stored offsets for `source`.
    fn into_pest_span<'a>(self, source: &'a str) -> pest::Span<'a> {
        let Self { start, end } = self;
        pest::Span::new(source, start, end).unwrap()
    }
}
/// Lookup table from a declaration's name to the source span recorded for it.
pub(super) type DeclarationSpanIndex = HashMap<String, SourceSpan>;
/// Registers one top-level declaration with the schema registry.
///
/// The first inner pair of `pair` selects the route: scalar declarations go
/// through the scalar path, object and array schema declarations through the
/// schema path.
///
/// # Errors
///
/// Returns whatever error the selected registration routine produces.
///
/// # Panics
///
/// Panics if `pair` has no inner pair, or if the inner pair is any rule other
/// than the three declaration rules handled here.
pub(super) fn register_declaration(
    schema_registry: &mut SchemaRegistry,
    pair: pest::iterators::Pair<Rule>,
) -> Result<(), pest::error::Error<Rule>> {
    let declaration = pair.into_inner().next().unwrap();
    let rule = declaration.as_rule();
    if matches!(rule, Rule::scalar_declaration) {
        return register_scalar_declaration(schema_registry, declaration);
    }
    if matches!(
        rule,
        Rule::schema_object_declaration | Rule::schema_array_declaration
    ) {
        return register_schema_declaration(schema_registry, declaration);
    }
    unreachable!("unexpected declaration rule: {:?}", rule)
}
/// Parses a scalar declaration and adds it to the registry.
///
/// # Errors
///
/// Returns the parse error from `parse_scalar_definition`, or the registry's
/// rejection converted into a pest error located at the declaration's span.
fn register_scalar_declaration(
    schema_registry: &mut SchemaRegistry,
    pair: pest::iterators::Pair<Rule>,
) -> Result<(), pest::error::Error<Rule>> {
    // Grab the span first: parsing consumes the pair.
    let declaration_span = pair.as_span();
    let scalar = parse_scalar_definition(pair)?;
    if let Err(registry_error) = schema_registry.register_scalar(scalar) {
        return Err(schema_registry_error_to_span(
            registry_error,
            declaration_span,
        ));
    }
    Ok(())
}
/// Parses a schema declaration and adds it to the registry.
///
/// # Errors
///
/// Returns the parse error from `parse_schema_definition`, or the registry's
/// rejection converted into a pest error located at the declaration's span.
fn register_schema_declaration(
    schema_registry: &mut SchemaRegistry,
    pair: pest::iterators::Pair<Rule>,
) -> Result<(), pest::error::Error<Rule>> {
    // Grab the span first: parsing consumes the pair.
    let declaration_span = pair.as_span();
    let schema = parse_schema_definition(pair)?;
    match schema_registry.register_schema(schema) {
        Ok(()) => Ok(()),
        Err(registry_error) => Err(schema_registry_error_to_span(
            registry_error,
            declaration_span,
        )),
    }
}
fn parse_scalar_definition(
pair: pest::iterators::Pair<Rule>,
) -> Result<ScalarDefinition, pest::error::Error<Rule>> {
let span = pair.as_span();
let mut inner = pair.into_inner();
let name = inner.next().unwrap().as_str().to_string();
let expression = parse_scalar_expression(inner.next().unwrap())?;
if expression.base.is_none() && expression.predicates.is_empty() {
return Err(pest::error::Error::new_from_span(
pest::error::ErrorVariant::CustomError {
message: "scalar declaration must define a base type or predicate".to_string(),
},
span,
));
}
Ok(ScalarDefinition::new(name, expression))
}
/// Collects the base type and predicates of a scalar expression.
///
/// Each inner term is either a base type (a primitive type or a named
/// identifier) or one of the predicate forms. Predicates are kept in source
/// order.
///
/// # Errors
///
/// Fails when more than one base type is declared, or when a `len`/`range`
/// predicate fails to parse.
///
/// # Panics
///
/// Panics on a term whose rule is not one of the handled scalar terms.
fn parse_scalar_expression(
    pair: pest::iterators::Pair<Rule>,
) -> Result<ScalarExpression, pest::error::Error<Rule>> {
    let expression_span = pair.as_span();
    let mut base = None;
    let mut predicates = Vec::new();

    for term in pair.into_inner() {
        let rule = term.as_rule();
        // Predicate terms are pushed and skip ahead; only base-type terms
        // fall through to the single-base check below.
        let declared_base = match rule {
            Rule::primitive_type => ScalarBase::Primitive(parse_primitive_type(term.as_str())),
            Rule::identifier => ScalarBase::Named(term.as_str().to_string()),
            Rule::enum_predicate => {
                predicates.push(parse_enum_predicate(term));
                continue;
            }
            Rule::format_predicate => {
                predicates.push(parse_format_predicate(term));
                continue;
            }
            Rule::len_predicate => {
                predicates.push(parse_len_predicate(term)?);
                continue;
            }
            Rule::pattern_predicate => {
                predicates.push(parse_pattern_predicate(term));
                continue;
            }
            Rule::range_predicate => {
                predicates.push(parse_range_predicate(term)?);
                continue;
            }
            _ => unreachable!("unexpected scalar term: {:?}", rule),
        };

        if base.is_some() {
            return Err(pest::error::Error::new_from_span(
                pest::error::ErrorVariant::CustomError {
                    message: "scalar expressions can only declare one base type".to_string(),
                },
                expression_span,
            ));
        }
        base = Some(declared_base);
    }

    Ok(ScalarExpression::new(base, predicates))
}
/// Turns every inner literal of an enum predicate into a `ScalarLiteral`
/// and wraps the collection in `ScalarPredicate::Enum`.
fn parse_enum_predicate(pair: pest::iterators::Pair<Rule>) -> ScalarPredicate {
    let literals = pair.into_inner();
    ScalarPredicate::Enum(literals.map(parse_scalar_literal).collect())
}
/// Builds a `ScalarPredicate::Format` from the text of the predicate's first
/// inner pair.
///
/// # Panics
///
/// Panics if the predicate has no inner pair.
fn parse_format_predicate(pair: pest::iterators::Pair<Rule>) -> ScalarPredicate {
    let mut arguments = pair.into_inner();
    let format_name = arguments.next().unwrap();
    ScalarPredicate::Format(format_name.as_str().to_string())
}
fn parse_len_predicate(
pair: pest::iterators::Pair<Rule>,
) -> Result<ScalarPredicate, pest::error::Error<Rule>> {
let span = pair.as_span();
let (min, max) = parse_bounds_body(pair.as_str(), span.clone(), "len")?;
let min = min
.map(|value| {
value.parse::<usize>().map_err(|_| {
pest::error::Error::new_from_span(
pest::error::ErrorVariant::CustomError {
message: "len() bounds must be integers".to_string(),
},
span.clone(),
)
})
})
.transpose()?;
let max = max
.map(|value| {
value.parse::<usize>().map_err(|_| {
pest::error::Error::new_from_span(
pest::error::ErrorVariant::CustomError {
message: "len() bounds must be integers".to_string(),
},
span.clone(),
)
})
})
.transpose()?;
Ok(ScalarPredicate::Length { min, max })
}
/// Builds a `ScalarPredicate::Pattern` from the predicate's first inner pair,
/// dropping the first and last byte of its text (the literal's delimiters).
///
/// # Panics
///
/// Panics if the predicate has no inner pair, or if the literal's text is too
/// short to strip one byte from each end.
fn parse_pattern_predicate(pair: pest::iterators::Pair<Rule>) -> ScalarPredicate {
    let literal = pair.into_inner().next().unwrap();
    let raw = literal.as_str();
    let body = &raw[1..raw.len() - 1];
    ScalarPredicate::Pattern(body.to_string())
}
/// Parses a `range(min..max)` predicate into `ScalarPredicate::Range`.
///
/// Either bound may be omitted; `parse_bounds_body` enforces that at least
/// one is present.
///
/// # Errors
///
/// Fails when the bounds body is malformed or when a present bound is
/// rejected by `parse_numeric_value`.
fn parse_range_predicate(
    pair: pest::iterators::Pair<Rule>,
) -> Result<ScalarPredicate, pest::error::Error<Rule>> {
    let span = pair.as_span();
    let (lower, upper) = parse_bounds_body(pair.as_str(), span.clone(), "range")?;

    let min = match lower {
        Some(text) => Some(parse_numeric_value(text, span.clone())?),
        None => None,
    };
    let max = match upper {
        Some(text) => Some(parse_numeric_value(text, span)?),
        None => None,
    };
    Ok(ScalarPredicate::Range { min, max })
}
/// Converts a literal pair into the matching `ScalarLiteral` variant.
///
/// Numeric literals without a `.` are read as integers. When such a literal
/// does not fit the integer type (for example an over-long digit sequence
/// typed by the user), it falls back to the floating-point variant instead of
/// panicking. Literals containing a `.` are always read as floating-point.
///
/// # Panics
///
/// Panics if the pair is not one of the literal rules handled here, or if a
/// numeric literal cannot be parsed as a number at all; both indicate a
/// grammar/parser mismatch rather than bad user input.
fn parse_scalar_literal(pair: pest::iterators::Pair<Rule>) -> ScalarLiteral {
    let text = pair.as_str();
    match pair.as_rule() {
        Rule::quoted_string => ScalarLiteral::String(normalize_quoted_string(text)),
        Rule::boolean_literal => ScalarLiteral::Boolean(text == "true"),
        Rule::null_literal => ScalarLiteral::Null,
        Rule::numeric_literal => {
            if !text.contains('.') {
                if let Ok(integer) = text.parse() {
                    return ScalarLiteral::Integer(integer);
                }
                // Integer parsing failed (most likely overflow): fall through
                // and keep the value as a floating-point number.
            }
            ScalarLiteral::Number(
                text.parse()
                    .expect("numeric_literal must be parseable as a number"),
            )
        }
        rule => unreachable!("unexpected scalar literal: {:?}", rule),
    }
}
/// Builds a `SchemaDefinition` from an object or array schema declaration.
///
/// The first inner pair supplies the name. The second is parsed as an object
/// body or as a type reference, depending on the declaration's rule.
///
/// # Errors
///
/// Propagates errors from the nested parsers, and rejects an array
/// declaration whose type reference is not of the array kind.
///
/// # Panics
///
/// Panics if the expected inner pairs are missing or the pair's rule is not
/// one of the two schema declaration rules.
fn parse_schema_definition(
    pair: pest::iterators::Pair<Rule>,
) -> Result<SchemaDefinition, pest::error::Error<Rule>> {
    let declaration_span = pair.as_span();
    let declaration_rule = pair.as_rule();
    let mut parts = pair.into_inner();
    let name = parts.next().unwrap().as_str().to_string();

    let shape = if declaration_rule == Rule::schema_object_declaration {
        SchemaShape::Object(parse_object_schema(parts.next().unwrap())?)
    } else if declaration_rule == Rule::schema_array_declaration {
        let value_type = parse_type_reference(parts.next().unwrap())?;
        let is_array = matches!(value_type.kind, TypeReferenceKind::Array(_));
        if !is_array {
            return Err(pest::error::Error::new_from_span(
                pest::error::ErrorVariant::CustomError {
                    message: "schema Name = ... must use an array type like Type[]".to_string(),
                },
                declaration_span,
            ));
        }
        SchemaShape::Array(value_type)
    } else {
        unreachable!("unexpected schema declaration: {:?}", declaration_rule)
    };

    Ok(SchemaDefinition::new(name, shape))
}
fn parse_object_schema(
pair: pest::iterators::Pair<Rule>,
) -> Result<ObjectSchema, pest::error::Error<Rule>> {
let mut fields = Vec::new();
for field in pair.into_inner() {
fields.push(parse_schema_field(field)?);
}
Ok(ObjectSchema::open(fields))
}
/// Parses one object field: a name, an optional `field_optional` marker, and
/// a type reference.
///
/// The field is built with `SchemaField::optional` when the marker is present
/// and with `SchemaField::required` otherwise.
///
/// # Errors
///
/// Propagates errors from `parse_type_reference`.
///
/// # Panics
///
/// Panics if an expected inner pair is missing.
fn parse_schema_field(
    pair: pest::iterators::Pair<Rule>,
) -> Result<SchemaField, pest::error::Error<Rule>> {
    let mut parts = pair.into_inner();
    let name = parts.next().unwrap().as_str().to_string();

    let marker_or_type = parts.next().unwrap();
    let is_optional = marker_or_type.as_rule() == Rule::field_optional;
    // When the marker is present, the type reference is the pair after it.
    let type_pair = if is_optional {
        parts.next().unwrap()
    } else {
        marker_or_type
    };

    let value_type = parse_type_reference(type_pair)?;
    let field = if is_optional {
        SchemaField::optional(name, value_type)
    } else {
        SchemaField::required(name, value_type)
    };
    Ok(field)
}
/// Parses a type reference: a base (primitive type or named identifier)
/// followed by any number of array/nullable suffixes.
///
/// Suffixes are applied left to right, each one wrapping the type built so
/// far.
///
/// # Panics
///
/// Panics if the base pair is missing, or if the base or a suffix has an
/// unexpected rule.
fn parse_type_reference(
    pair: pest::iterators::Pair<Rule>,
) -> Result<TypeReference, pest::error::Error<Rule>> {
    let mut parts = pair.into_inner();
    let base = parts.next().unwrap();
    let base_type = match base.as_rule() {
        Rule::primitive_type => TypeReference::primitive(parse_primitive_type(base.as_str())),
        Rule::identifier => TypeReference::named(base.as_str()),
        other => unreachable!("unexpected type reference base: {:?}", other),
    };

    let value_type = parts.fold(base_type, |current, suffix| match suffix.as_rule() {
        Rule::array_suffix => TypeReference::array(current),
        Rule::nullable_suffix => current.nullable(),
        other => unreachable!("unexpected type reference component: {:?}", other),
    });
    Ok(value_type)
}
/// Maps a primitive type keyword to its `PrimitiveType` variant.
///
/// # Panics
///
/// Panics if `value` is not one of `string`, `integer`, `number`, `boolean`
/// or `null`.
fn parse_primitive_type(value: &str) -> PrimitiveType {
    if value == "string" {
        PrimitiveType::String
    } else if value == "integer" {
        PrimitiveType::Integer
    } else if value == "number" {
        PrimitiveType::Number
    } else if value == "boolean" {
        PrimitiveType::Boolean
    } else if value == "null" {
        PrimitiveType::Null
    } else {
        unreachable!("unexpected primitive type: {}", value)
    }
}
fn parse_bounds_body<'a>(
raw: &'a str,
span: pest::Span<'_>,
prefix: &str,
) -> Result<(Option<&'a str>, Option<&'a str>), pest::error::Error<Rule>> {
let body = raw
.trim()
.strip_prefix(prefix)
.and_then(|value| value.trim_start().strip_prefix('('))
.and_then(|value| value.strip_suffix(')'))
.ok_or_else(|| {
pest::error::Error::new_from_span(
pest::error::ErrorVariant::CustomError {
message: format!("invalid {}() predicate", prefix),
},
span.clone(),
)
})?;
let Some((min, max)) = body.split_once("..") else {
return Err(pest::error::Error::new_from_span(
pest::error::ErrorVariant::CustomError {
message: format!("{}() requires min..max syntax", prefix),
},
span,
));
};
let min = if min.trim().is_empty() {
None
} else {
Some(min.trim())
};
let max = if max.trim().is_empty() {
None
} else {
Some(max.trim())
};
if min.is_none() && max.is_none() {
return Err(pest::error::Error::new_from_span(
pest::error::ErrorVariant::CustomError {
message: format!("{}() requires at least one bound", prefix),
},
span,
));
}
Ok((min, max))
}
/// Parses a range bound into a `NumericValue`.
///
/// Text containing a `.` is parsed as `f64` and yields
/// `NumericValue::Number`; anything else is parsed as `i64` and yields
/// `NumericValue::Integer`.
///
/// # Errors
///
/// Returns a custom error located at `span` when the chosen parse fails.
fn parse_numeric_value(
    raw: &str,
    span: pest::Span<'_>,
) -> Result<NumericValue, pest::error::Error<Rule>> {
    let parsed = if raw.contains('.') {
        raw.parse::<f64>().ok().map(NumericValue::Number)
    } else {
        raw.parse::<i64>().ok().map(NumericValue::Integer)
    };

    parsed.ok_or_else(|| {
        pest::error::Error::new_from_span(
            pest::error::ErrorVariant::CustomError {
                message: "range() bounds must be numbers".to_string(),
            },
            span,
        )
    })
}
/// Returns `value` without its first and last byte (the surrounding quotes).
///
/// No escape processing is performed.
///
/// # Panics
///
/// Panics if `value` is shorter than two bytes or if the cut points do not
/// fall on character boundaries.
fn normalize_quoted_string(value: &str) -> String {
    let closing_quote = value.len() - 1;
    String::from(&value[1..closing_quote])
}
fn schema_registry_error_to_span(
error: SchemaRegistryError,
span: pest::Span<'_>,
) -> pest::error::Error<Rule> {
pest::error::Error::new_from_span(
pest::error::ErrorVariant::CustomError {
message: error.to_string(),
},
span,
)
}
pub(super) fn validate_schema_registry(
schema_registry: &SchemaRegistry,
source: &str,
declaration_spans: &DeclarationSpanIndex,
) -> Result<(), pest::error::Error<Rule>> {
schema_registry.validate_references().map_err(|error| {
let span = span_for_schema_registry_error(&error, source, declaration_spans);
schema_registry_error_to_span(error, span)
})
}
pub(super) fn parse_request_collection(
source: &str,
) -> Result<pest::iterators::Pair<'_, Rule>, pest::error::Error<Rule>> {
let mut pairs = CollectionParser::parse(Rule::request_collection, source)
.map_err(|error| rewrite_request_collection_error(source, error))?;
Ok(pairs.next().unwrap())
}
fn rewrite_request_collection_error(
source: &str,
error: pest::error::Error<Rule>,
) -> pest::error::Error<Rule> {
if let Some(span) = misplaced_declaration_span(source, &error) {
return pest::error::Error::new_from_span(
pest::error::ErrorVariant::CustomError {
message: "schema and scalar declarations must appear before the first ---"
.to_string(),
},
span.into_pest_span(source),
);
}
error
}
fn misplaced_declaration_span(
source: &str,
error: &pest::error::Error<Rule>,
) -> Option<SourceSpan> {
let line_number = match error.line_col {
pest::error::LineColLocation::Pos((line, _)) => line,
pest::error::LineColLocation::Span((line, _), _) => line,
};
let line_text = source.lines().nth(line_number.checked_sub(1)?)?;
if !looks_like_declaration(line_text) {
return None;
}
line_span(source, line_number)
}
/// Computes the byte span of the 1-based line `line_number` in `source`.
///
/// The span excludes the line terminator: the trailing `\n`, and also a
/// `\r` immediately before it, so CRLF input yields the same span text that
/// `str::lines` reports for the line.
///
/// Returns `None` when `line_number` is zero or lies beyond the lines that
/// can be reached by counting `\n` characters.
fn line_span(source: &str, line_number: usize) -> Option<SourceSpan> {
    let preceding_lines = line_number.checked_sub(1)?;

    // Hop over one newline per preceding line to find where ours begins.
    let mut start = 0;
    for _ in 0..preceding_lines {
        start += source[start..].find('\n')? + 1;
    }

    let rest = &source[start..];
    let line = rest.find('\n').map_or(rest, |offset| &rest[..offset]);
    // Keep the carriage return of a CRLF ending out of the span.
    let line = line.strip_suffix('\r').unwrap_or(line);

    Some(SourceSpan {
        start,
        end: start + line.len(),
    })
}
/// Records the span of a declaration under its declared name.
///
/// The span is that of the first inner pair of `pair`; the name is the text
/// of that pair's own first inner pair. An existing entry with the same name
/// is overwritten.
///
/// # Panics
///
/// Panics if either expected inner pair is missing.
pub(super) fn remember_declaration_span(
    declaration_spans: &mut DeclarationSpanIndex,
    pair: &pest::iterators::Pair<Rule>,
) {
    let declaration = pair.clone().into_inner().next().unwrap();
    let recorded = SourceSpan::from_pest_span(declaration.as_span());
    let name_pair = declaration.into_inner().next().unwrap();
    declaration_spans.insert(name_pair.as_str().to_string(), recorded);
}
/// Chooses the source span to blame for a registry error.
///
/// The span is looked up by the name the error carries: the declared name,
/// the owner of an unknown reference, the scalar with an invalid base, or
/// the first name on a circular path that has a recorded span. When nothing
/// is recorded, the span covers the entire source.
fn span_for_schema_registry_error<'a>(
    error: &SchemaRegistryError,
    source: &'a str,
    declaration_spans: &DeclarationSpanIndex,
) -> pest::Span<'a> {
    let recorded = match error {
        SchemaRegistryError::ReservedName(name) | SchemaRegistryError::DuplicateName(name) => {
            declaration_spans.get(name)
        }
        SchemaRegistryError::UnknownReference { owner, .. } => declaration_spans.get(owner),
        SchemaRegistryError::InvalidScalarBaseReference { scalar, .. } => {
            declaration_spans.get(scalar)
        }
        SchemaRegistryError::CircularReference(path) => {
            path.iter().find_map(|name| declaration_spans.get(name))
        }
    };

    match recorded.copied() {
        Some(span) => span.into_pest_span(source),
        None => pest::Span::new(source, 0, source.len()).unwrap(),
    }
}
pub(super) fn reject_misplaced_declaration(
raw: &str,
span: pest::Span<'_>,
) -> Result<(), pest::error::Error<Rule>> {
if looks_like_declaration(raw) {
return Err(pest::error::Error::new_from_span(
pest::error::ErrorVariant::CustomError {
message: "schema and scalar declarations must appear before the first ---"
.to_string(),
},
span,
));
}
Ok(())
}
/// Reports whether `raw`, ignoring surrounding whitespace, has the shape of
/// a scalar or schema declaration.
fn looks_like_declaration(raw: &str) -> bool {
    let candidate = raw.trim();
    if looks_like_scalar_declaration(candidate) {
        return true;
    }
    looks_like_schema_declaration(candidate)
}
/// Reports whether `raw` has the shape `scalar <identifier> =`.
///
/// `raw` must begin with the literal `scalar ` (leading whitespace is not
/// skipped here). The identifier is a non-empty run of ASCII alphanumerics
/// or `_`, and the next non-whitespace character after it must be `=`.
fn looks_like_scalar_declaration(raw: &str) -> bool {
    match raw.strip_prefix("scalar ") {
        None => false,
        Some(rest) => {
            let rest = rest.trim_start();
            let after_name =
                rest.trim_start_matches(|ch: char| ch.is_ascii_alphanumeric() || ch == '_');
            // A shorter remainder means at least one identifier character
            // was consumed.
            let has_name = after_name.len() < rest.len();
            has_name && after_name.trim_start().starts_with('=')
        }
    }
}
/// Reports whether `raw` has the shape `schema <identifier> {` or
/// `schema <identifier> =`.
///
/// `raw` must begin with the literal `schema ` (leading whitespace is not
/// skipped here). The identifier is a non-empty run of ASCII alphanumerics
/// or `_`, and the next non-whitespace character after it must be `{` or
/// `=`.
fn looks_like_schema_declaration(raw: &str) -> bool {
    let Some(rest) = raw.strip_prefix("schema ").map(str::trim_start) else {
        return false;
    };
    let after_name = rest.trim_start_matches(|ch: char| ch.is_ascii_alphanumeric() || ch == '_');
    if after_name.len() == rest.len() {
        // Nothing was consumed, so there is no identifier.
        return false;
    }
    let opener = after_name.trim_start().chars().next();
    opener == Some('{') || opener == Some('=')
}