use crate::{
lexer::State,
utils::{recursion_tracker::RecursionLimiter, token_tracker::TokenLimiter},
};
use super::{
recursion_tracker::{RecursionLimitExceeded, RecursionTracker},
token_tracker::{TokenLimitExceeded, TokenTracker},
};
/// Error type that unifies the two limit errors used in this module.
///
/// Each variant wraps the error of one underlying limiter. `From` conversions
/// are generated for both inner error types (via `#[from]`), so either one can
/// be turned into a `LimitExceeded` with `?` or `.into()`.
///
/// The enum is `#[non_exhaustive]`: code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
thiserror::Error,
derive_more::IsVariant,
derive_more::Unwrap,
derive_more::TryUnwrap,
)]
#[unwrap(ref)]
#[try_unwrap(ref)]
#[non_exhaustive]
pub enum LimitExceeded {
/// Wraps a [`TokenLimitExceeded`]; `transparent` forwards the message and
/// source of the inner error unchanged.
#[error(transparent)]
Token(#[from] TokenLimitExceeded),
/// Wraps a [`RecursionLimitExceeded`]; `transparent` forwards the message
/// and source of the inner error unchanged.
#[error(transparent)]
Recursion(#[from] RecursionLimitExceeded),
}
/// Pairs a [`TokenLimiter`] with a [`RecursionLimiter`] so both can be
/// updated and checked through a single value.
///
/// The type is `Copy`, so passing it by value duplicates both limiters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Limiter {
/// Limiter that receives the token-related updates and checks.
token_tracker: TokenLimiter,
/// Limiter that receives the recursion-related updates and checks.
recursion_tracker: RecursionLimiter,
}
impl Default for Limiter {
#[cfg_attr(not(tarpaulin), inline(always))]
fn default() -> Self {
Self::new()
}
}
impl Limiter {
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn new() -> Self {
Self::with_trackers(TokenLimiter::new(), RecursionLimiter::new())
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn with_token_tracker(token_tracker: TokenLimiter) -> Self {
Self::with_trackers(token_tracker, RecursionLimiter::new())
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn with_recursion_tracker(recursion_tracker: RecursionLimiter) -> Self {
Self::with_trackers(TokenLimiter::new(), recursion_tracker)
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn with_trackers(
token_tracker: TokenLimiter,
recursion_tracker: RecursionLimiter,
) -> Self {
Self {
token_tracker,
recursion_tracker,
}
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn token(&self) -> &TokenLimiter {
&self.token_tracker
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn token_mut(&mut self) -> &mut TokenLimiter {
&mut self.token_tracker
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn recursion(&self) -> &RecursionLimiter {
&self.recursion_tracker
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn recursion_mut(&mut self) -> &mut RecursionLimiter {
&mut self.recursion_tracker
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn increase_token(&mut self) {
self.token_mut().increase();
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn increase_recursion(&mut self) {
self.recursion_mut().increase();
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn decrease_recursion(&mut self) {
self.recursion_mut().decrease();
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub fn check(&self) -> Result<(), LimitExceeded> {
self
.recursion_tracker
.check()
.map_err(LimitExceeded::from)?;
self.token_tracker.check().map_err(LimitExceeded::from)?;
Ok(())
}
}
impl State for Limiter {
    type Error = LimitExceeded;

    /// Delegates to the [`Tracker`] implementation of `check` for `Limiter`.
    #[cfg_attr(not(tarpaulin), inline(always))]
    fn check(&self) -> Result<(), Self::Error> {
        <Limiter as Tracker>::check(self)
    }
}
impl RecursionTracker for Limiter {
type Error = LimitExceeded;
#[cfg_attr(not(tarpaulin), inline(always))]
fn increase(&mut self) {
self.recursion_tracker.increase();
}
#[cfg_attr(not(tarpaulin), inline(always))]
fn decrease(&mut self) {
self.recursion_tracker.decrease();
}
#[cfg_attr(not(tarpaulin), inline(always))]
fn check(&self) -> Result<(), Self::Error> {
self.recursion_tracker.check().map_err(Into::into)
}
}
impl TokenTracker for Limiter {
type Error = LimitExceeded;
#[cfg_attr(not(tarpaulin), inline(always))]
fn increase(&mut self) {
self.token_tracker.increase();
}
#[cfg_attr(not(tarpaulin), inline(always))]
fn check(&self) -> Result<(), Self::Error> {
self.token_tracker.check().map_err(Into::into)
}
}
/// Combined interface for updating token and recursion bookkeeping and for
/// checking it.
///
/// Implementors supply the four primitive operations; the remaining methods
/// are conveniences that compose them and may be overridden.
pub trait Tracker {
    /// Error returned by [`Tracker::check`] and the `*_and_check` helpers.
    type Error;

    /// Registers one more token.
    fn increase_token(&mut self);

    /// Registers one more level of recursion.
    fn increase_recursion(&mut self);

    /// Registers that one level of recursion has been left.
    fn decrease_recursion(&mut self);

    /// Reports whether the tracked state is still acceptable.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the implementor decides a limit is exceeded.
    fn check(&self) -> Result<(), Self::Error>;

    /// Calls [`Tracker::increase_token`], then [`Tracker::decrease_recursion`].
    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_token_and_decrease_recursion(&mut self) {
        Self::increase_token(self);
        Self::decrease_recursion(self);
    }

    /// Calls [`Tracker::increase_token_and_decrease_recursion`], then returns
    /// the result of [`Tracker::check`].
    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_token_and_decrease_recursion_and_check(&mut self) -> Result<(), Self::Error> {
        // Go through the combined method (not the two primitives) so an
        // implementor's override of it is honoured.
        Self::increase_token_and_decrease_recursion(self);
        Self::check(self)
    }

    /// Calls [`Tracker::increase_token`], then returns the result of
    /// [`Tracker::check`].
    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_token_and_check(&mut self) -> Result<(), Self::Error> {
        Self::increase_token(self);
        Self::check(self)
    }

    /// Calls [`Tracker::increase_token`], then [`Tracker::increase_recursion`].
    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_both(&mut self) {
        Self::increase_token(self);
        Self::increase_recursion(self);
    }

    /// Calls [`Tracker::increase_both`], then returns the result of
    /// [`Tracker::check`].
    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_both_and_check(&mut self) -> Result<(), Self::Error> {
        // Go through `increase_both` so an implementor's override is honoured.
        Self::increase_both(self);
        Self::check(self)
    }
}
impl Tracker for Limiter {
type Error = LimitExceeded;
#[cfg_attr(not(tarpaulin), inline(always))]
fn increase_token(&mut self) {
self.increase_token();
}
#[cfg_attr(not(tarpaulin), inline(always))]
fn increase_recursion(&mut self) {
self.increase_recursion();
}
#[cfg_attr(not(tarpaulin), inline(always))]
fn decrease_recursion(&mut self) {
self.decrease_recursion();
}
#[cfg_attr(not(tarpaulin), inline(always))]
fn increase_token_and_check(&mut self) -> Result<(), Self::Error> {
self.increase_token();
<Self as TokenTracker>::check(self)
}
#[cfg_attr(not(tarpaulin), inline(always))]
fn increase_token_and_decrease_recursion_and_check(&mut self) -> Result<(), Self::Error> {
self.increase_token();
self.decrease_recursion();
<Self as TokenTracker>::check(self)
}
#[cfg_attr(not(tarpaulin), inline(always))]
fn check(&self) -> Result<(), Self::Error> {
self.check()
}
}
#[cfg(feature = "logos")]
const _: () = {
use logos::{Lexer, Logos};
use crate::{Token, lexer::LogosLexer};
/// Forwards every [`Tracker`] operation to the lexer's `extras` value.
///
/// All provided methods are forwarded as well (rather than left to the trait
/// defaults) so that any overrides on `T::Extras` are used.
impl<'a, T> Tracker for Lexer<'a, T>
where
    T: Logos<'a>,
    T::Extras: Tracker,
{
    type Error = <T::Extras as Tracker>::Error;

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_token(&mut self) {
        <T::Extras as Tracker>::increase_token(&mut self.extras);
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_recursion(&mut self) {
        <T::Extras as Tracker>::increase_recursion(&mut self.extras);
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn decrease_recursion(&mut self) {
        <T::Extras as Tracker>::decrease_recursion(&mut self.extras);
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn check(&self) -> Result<(), Self::Error> {
        <T::Extras as Tracker>::check(&self.extras)
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_token_and_check(&mut self) -> Result<(), Self::Error> {
        <T::Extras as Tracker>::increase_token_and_check(&mut self.extras)
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_both(&mut self) {
        <T::Extras as Tracker>::increase_both(&mut self.extras);
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_both_and_check(&mut self) -> Result<(), Self::Error> {
        <T::Extras as Tracker>::increase_both_and_check(&mut self.extras)
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_token_and_decrease_recursion(&mut self) {
        <T::Extras as Tracker>::increase_token_and_decrease_recursion(&mut self.extras);
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_token_and_decrease_recursion_and_check(&mut self) -> Result<(), Self::Error> {
        <T::Extras as Tracker>::increase_token_and_decrease_recursion_and_check(&mut self.extras)
    }
}
/// Forwards every [`Tracker`] operation to the value returned by
/// `inner()` / `inner_mut()`.
///
/// All provided methods are forwarded as well (rather than left to the trait
/// defaults) so that any overrides further down the chain are used.
impl<'a, T, L> Tracker for LogosLexer<'a, T, L>
where
    T: From<L> + Token<'a>,
    L: Logos<'a>,
    L::Extras: Tracker,
{
    type Error = <L::Extras as Tracker>::Error;

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_token(&mut self) {
        let lexer = self.inner_mut();
        lexer.increase_token();
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_recursion(&mut self) {
        let lexer = self.inner_mut();
        lexer.increase_recursion();
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn decrease_recursion(&mut self) {
        let lexer = self.inner_mut();
        lexer.decrease_recursion();
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn check(&self) -> Result<(), Self::Error> {
        let lexer = self.inner();
        lexer.check()
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_token_and_check(&mut self) -> Result<(), Self::Error> {
        let lexer = self.inner_mut();
        lexer.increase_token_and_check()
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_both(&mut self) {
        let lexer = self.inner_mut();
        lexer.increase_both();
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_both_and_check(&mut self) -> Result<(), Self::Error> {
        let lexer = self.inner_mut();
        lexer.increase_both_and_check()
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_token_and_decrease_recursion(&mut self) {
        let lexer = self.inner_mut();
        lexer.increase_token_and_decrease_recursion();
    }

    #[cfg_attr(not(tarpaulin), inline(always))]
    fn increase_token_and_decrease_recursion_and_check(&mut self) -> Result<(), Self::Error> {
        let lexer = self.inner_mut();
        lexer.increase_token_and_decrease_recursion_and_check()
    }
}
};