use crate::State;
/// Error returned by [`TokenLimiter::check`] when more tokens have been counted
/// than the limiter's configured maximum.
///
/// It wraps a copy of the limiter taken when the check failed, so both the
/// observed token count and the limit are available for reporting (see the
/// `#[error]` message below and the accessor methods).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, thiserror::Error)]
#[error("token limit exceeded: tokens {}, maximum {}", .0.tokens(), .0.limitation())]
pub struct TokenLimitExceeded(TokenLimiter);
impl TokenLimitExceeded {
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn tokens(&self) -> usize {
self.0.tokens()
}
#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn limitation(&self) -> usize {
self.0.limitation()
}
}
/// Counts tokens and compares the running total against an upper bound.
///
/// [`TokenLimiter::new`] (also used by `Default`) starts with a maximum of
/// `usize::MAX`; [`TokenLimiter::with_limitation`] sets an explicit maximum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TokenLimiter {
// Maximum number of tokens allowed; `check` fails only when `current` is greater.
max: usize,
// Number of tokens counted so far.
current: usize,
}
impl Default for TokenLimiter {
#[cfg_attr(not(tarpaulin), inline(always))]
fn default() -> Self {
Self::new()
}
}
impl TokenLimiter {
    /// Creates a limiter with zero tokens counted and a maximum of `usize::MAX`.
    ///
    /// Because [`increase`](Self::increase) saturates at `usize::MAX`, the count
    /// can never exceed this maximum, so [`check`](Self::check) never fails.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn new() -> Self {
        Self {
            max: usize::MAX,
            current: 0,
        }
    }

    /// Creates a limiter with zero tokens counted that allows at most `max` tokens.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn with_limitation(max: usize) -> Self {
        Self { max, current: 0 }
    }

    /// Returns the number of tokens counted so far.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn tokens(&self) -> usize {
        self.current
    }

    /// Records one more token.
    ///
    /// The counter saturates at `usize::MAX` instead of overflowing. A plain
    /// `+= 1` would panic in debug builds and wrap to `0` in release builds,
    /// resetting the count and letting an over-limit input pass `check`.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn increase(&mut self) {
        self.current = self.current.saturating_add(1);
    }

    /// Returns the maximum number of tokens this limiter allows.
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn limitation(&self) -> usize {
        self.max
    }

    /// Alias for [`increase`](Self::increase).
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub const fn increase_token(&mut self) {
        self.increase();
    }

    /// Checks the counted tokens against the limit.
    ///
    /// Having counted exactly [`limitation`](Self::limitation) tokens is still `Ok`.
    ///
    /// # Errors
    ///
    /// Returns [`TokenLimitExceeded`], carrying a copy of this limiter, when
    /// [`tokens`](Self::tokens) is greater than [`limitation`](Self::limitation).
    #[cfg_attr(not(tarpaulin), inline(always))]
    pub fn check(&self) -> Result<(), TokenLimitExceeded> {
        if self.tokens() > self.limitation() {
            Err(TokenLimitExceeded(*self))
        } else {
            Ok(())
        }
    }
}
impl State for TokenLimiter {
type Error = TokenLimitExceeded;
#[cfg_attr(not(tarpaulin), inline(always))]
fn check(&self) -> Result<(), Self::Error> {
<Self as TokenTracker>::check(self)
}
}
/// A value that counts tokens and can report whether a limit has been exceeded.
pub trait TokenTracker {
    /// Error produced by [`check`](Self::check).
    type Error;

    /// Records one more token.
    fn increase(&mut self);

    /// Validates the tracker's current state.
    ///
    /// The implementations in this module return `Err` once the counted tokens
    /// exceed the configured limit. The `Self: Sized` bound excludes this method
    /// from trait objects, so it cannot be called through `dyn TokenTracker`.
    fn check(&self) -> Result<(), Self::Error>
    where
        Self: Sized;
}
impl TokenTracker for TokenLimiter {
type Error = TokenLimitExceeded;
#[cfg_attr(not(tarpaulin), inline(always))]
fn increase(&mut self) {
self.increase();
}
#[cfg_attr(not(tarpaulin), inline(always))]
fn check(&self) -> Result<(), Self::Error> {
self.check()
}
}
#[cfg(feature = "logos")]
const _: () = {
    use logos::{Lexer, Logos};
    use crate::{Token, lexer::LogosLexer};

    // A raw logos `Lexer` tracks tokens through its `extras` value.
    impl<'a, T> TokenTracker for Lexer<'a, T>
    where
        T: Logos<'a>,
        T::Extras: TokenTracker,
    {
        type Error = <T::Extras as TokenTracker>::Error;

        #[cfg_attr(not(tarpaulin), inline(always))]
        fn increase(&mut self) {
            <T::Extras as TokenTracker>::increase(&mut self.extras);
        }

        #[cfg_attr(not(tarpaulin), inline(always))]
        fn check(&self) -> Result<(), Self::Error> {
            <T::Extras as TokenTracker>::check(&self.extras)
        }
    }

    // `LogosLexer` reaches the wrapped logos lexer via `inner`/`inner_mut` and
    // tracks tokens through that lexer's `extras` value.
    impl<'a, T, L> TokenTracker for LogosLexer<'a, T, L>
    where
        T: From<L> + Token<'a>,
        L: Logos<'a>,
        L::Extras: TokenTracker,
    {
        type Error = <L::Extras as TokenTracker>::Error;

        #[cfg_attr(not(tarpaulin), inline(always))]
        fn increase(&mut self) {
            <L::Extras as TokenTracker>::increase(&mut self.inner_mut().extras);
        }

        #[cfg_attr(not(tarpaulin), inline(always))]
        fn check(&self) -> Result<(), Self::Error> {
            <L::Extras as TokenTracker>::check(&self.inner().extras)
        }
    }
};