use crate::error::ParseError;
use crate::parse::lexer;
use std::ops::Range;
use lalrpop_util::ErrorRecovery;
use once_cell::sync::Lazy;
use regex::Regex;
use partiql_ast::ast::{AstNode, SymbolPrimitive};
use partiql_ast::builder::AstNodeBuilder;
use partiql_common::node::{AutoNodeIdGenerator, NodeIdGenerator};
use partiql_common::syntax::location::{ByteOffset, BytePosition, Location};
use partiql_common::syntax::metadata::LocationMap;
/// A single `lalrpop_util::ErrorRecovery` record specialised for this parser:
/// locations are [`ByteOffset`]s, tokens are [`lexer::Token`]s and the error
/// payload is a [`ParseError`] positioned by [`BytePosition`].
type ParseErrorRecovery<'input> =
ErrorRecovery<ByteOffset, lexer::Token<'input>, ParseError<'input, BytePosition>>;
/// Growable list of [`ParseErrorRecovery`] records; the type of
/// `ParserState::errors`.
type ParseErrors<'input> = Vec<ParseErrorRecovery<'input>>;
/// Initial capacity handed to `LocationMap::with_capacity` when a new
/// `ParserState` is built.
const INIT_LOCATIONS: usize = 100;
/// Mutable state carried while parsing: the AST node builder, the map from
/// node ids to source locations, the collected recovery errors, and the
/// pattern used to recognise aggregate function names.
pub(crate) struct ParserState<'input, Id: NodeIdGenerator> {
/// Builder that wraps raw AST values into `AstNode`s (see `create_node`).
pub node_builder: AstNodeBuilder<Id>,
/// Source location recorded for every node created through `create_node`,
/// keyed by the node's id.
pub locations: LocationMap,
/// Recovery errors gathered so far. Starts empty; presumably filled in by
/// the generated parser, which is not visible in this file.
pub errors: ParseErrors<'input>,
/// Compiled pattern consulted by `is_agg_fn`.
aggregates_pat: &'static Regex,
}
/// The default parser state uses a default-constructed [`AutoNodeIdGenerator`].
impl Default for ParserState<'_, AutoNodeIdGenerator> {
    /// Equivalent to `ParserState::with_id_gen(AutoNodeIdGenerator::default())`.
    fn default() -> Self {
        let id_gen = AutoNodeIdGenerator::default();
        Self::with_id_gen(id_gen)
    }
}
/// Regex source listing the recognised aggregate function names.
///
/// Each alternative is wrapped in a case-insensitive group (`(?i:...)`) and
/// anchored with `^`/`$`, so a name only matches when the entire string is
/// one of: `count`, `avg`, `min`, `max`, `sum`, `any`, `some`, `every`.
const KNOWN_AGGREGATES: &str =
"(?i:^count$)|(?i:^avg$)|(?i:^min$)|(?i:^max$)|(?i:^sum$)|(?i:^any$)|(?i:^some$)|(?i:^every$)";
/// Lazily compiled form of [`KNOWN_AGGREGATES`].
///
/// The `unwrap` panics on first access if the pattern fails to compile.
static KNOWN_AGGREGATE_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(KNOWN_AGGREGATES).unwrap());
impl<I: NodeIdGenerator> ParserState<'_, I> {
    /// Creates an empty parser state whose node builder draws ids from `id_gen`.
    ///
    /// The location map is pre-sized to `INIT_LOCATIONS` entries, the error
    /// list starts empty, and the aggregate-name pattern points at the shared
    /// `KNOWN_AGGREGATE_PATTERN` static.
    pub fn with_id_gen(id_gen: I) -> Self {
        let node_builder = AstNodeBuilder::new(id_gen);
        let locations = LocationMap::with_capacity(INIT_LOCATIONS);
        let errors = ParseErrors::default();
        Self {
            node_builder,
            locations,
            errors,
            aggregates_pat: &KNOWN_AGGREGATE_PATTERN,
        }
    }
}
impl<IdGen: NodeIdGenerator> ParserState<'_, IdGen> {
    /// Wraps `node` in an [`AstNode`] and records `location` for it.
    ///
    /// The node is produced by `self.node_builder`; its id is then used as the
    /// key under which the converted location is stored in `self.locations`.
    /// Returns the newly built node.
    pub fn create_node<T, IntoLoc>(&mut self, node: T, location: IntoLoc) -> AstNode<T>
    where
        IntoLoc: Into<Location<BytePosition>>,
    {
        let wrapped = self.node_builder.node(node);
        let recorded: Location<BytePosition> = location.into();
        self.locations.insert(wrapped.id, recorded);
        wrapped
    }

    /// Convenience wrapper around [`Self::create_node`] for a byte-offset range.
    ///
    /// Both ends of the range are converted with `Into` before being passed on
    /// as the node's location.
    #[inline]
    pub fn node<T>(&mut self, ast: T, Range { start, end }: Range<ByteOffset>) -> AstNode<T> {
        let converted = Range {
            start: start.into(),
            end: end.into(),
        };
        self.create_node(ast, converted)
    }

    /// Returns `true` when `name.value` matches the aggregate-name pattern
    /// held in `self.aggregates_pat`.
    #[inline]
    pub fn is_agg_fn(&self, name: &SymbolPrimitive) -> bool {
        let pattern = self.aggregates_pat;
        pattern.is_match(&name.value)
    }
}