lextrail 0.1.0

A library for constraining language model outputs to follow CFG, REGEX and JSON (experimental).
Documentation
use std::collections::VecDeque;
use std::{env, fmt};

use crate::build::{SymbolGraph, UUId};

/// Error type that carries a single human-readable message.
///
/// The message lives in the public tuple field; both formatting impls emit it
/// verbatim, and the type implements [`std::error::Error`] with the default
/// method bodies.
pub struct TrailError(pub String);

impl fmt::Debug for TrailError {
    /// Writes a newline followed by the stored message.
    ///
    /// NOTE(review): the leading newline presumably makes the message start on
    /// its own line when the error is printed through `{:?}` (e.g. by a panic
    /// from `unwrap`) — confirm against callers.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let TrailError(message) = self;
        f.write_fmt(format_args!("\n{}", message))
    }
}

impl fmt::Display for TrailError {
    /// Writes the stored message exactly as-is.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let TrailError(message) = self;
        f.write_fmt(format_args!("{}", message))
    }
}

impl std::error::Error for TrailError {}

/// Flushes the pending character buffer into the lexeme list.
///
/// When `accumulate` holds at least one character, its contents are joined
/// into a `String`, appended to `lexemes`, and the buffer is left empty.
/// When `accumulate` is already empty nothing is pushed.
pub fn consume_lexeme(lexemes: &mut Vec<String>, accumulate: &mut Vec<char>) {
    if accumulate.is_empty() {
        return;
    }

    // Draining the full range empties the buffer while building the lexeme.
    let lexeme: String = accumulate.drain(..).collect();
    lexemes.push(lexeme);
}

/// Reports whether the position `index` is preceded by an odd number of
/// consecutive backslashes.
///
/// Only the characters strictly before `index` are inspected; the character at
/// `index` itself is never read, so `index == unicodes.len()` is a valid query.
///
/// Returns `false` when `index` is `0` or when `index` lies beyond the end of
/// `unicodes` (there is no such position to be escaped), instead of panicking
/// on an out-of-bounds access.
pub fn is_escaped(unicodes: &Vec<char>, index: usize) -> bool {
    // `get(..index)` yields `None` for `index > len`, which avoids both the
    // out-of-bounds panic and any signed/unsigned index casting.
    match unicodes.get(..index) {
        Some(prefix) => {
            let backslashes = prefix
                .iter()
                .rev()
                .take_while(|&&unit| unit == '\\')
                .count();
            // An even run of backslashes escapes itself pairwise.
            backslashes % 2 == 1
        }
        None => false,
    }
}

/// Returns the character located `offset` positions away from `index`.
///
/// `offset` may be negative (look behind) or positive (look ahead). When the
/// resulting position falls outside `unicodes` — including when computing it
/// would underflow or overflow — the NUL character `'\0'` is returned as the
/// "nothing there" sentinel.
pub fn peek(unicodes: &Vec<char>, index: usize, offset: i32) -> char {
    // Work in `usize` with checked arithmetic: casting `index` to `i32` would
    // truncate large indices and the signed addition could overflow.
    // `u32 -> usize` is lossless on 32- and 64-bit targets.
    let distance = offset.unsigned_abs() as usize;
    let target = if offset < 0 {
        index.checked_sub(distance)
    } else {
        index.checked_add(distance)
    };

    target
        .and_then(|position| unicodes.get(position))
        .copied()
        .unwrap_or('\0')
}

/// Reads the environment variable `key` as an integer flag.
///
/// The value is parsed as an `i32`; any non-zero number yields `true` and zero
/// yields `false`. If the variable cannot be read or does not parse as an
/// `i32`, `default` is returned.
pub fn get_env(key: &str, default: bool) -> bool {
    match env::var(key) {
        Ok(raw) => match raw.parse::<i32>() {
            Ok(number) => number != 0,
            Err(_) => default,
        },
        Err(_) => default,
    }
}

/// Reports whether `head` contains at least one character from the fixed
/// special-character set `[@!#$%^&*()<>?/\|}~:`.
///
/// NOTE(review): the set includes `[` and `}` but neither `]` nor `{` —
/// confirm that this asymmetry is intended.
pub fn contains_special_characters(head: &str) -> bool {
    const SPECIAL: &str = "[@!#$%^&*()<>?/\\|}~:";

    head.contains(|symbol: char| SPECIAL.contains(symbol))
}

/// Traverses `graph` breadth-first from the `start` vertices.
///
/// Vertices are taken from a FIFO queue seeded with `start` (in the given
/// order). Each vertex is recorded the first time it is dequeued, and the
/// successors found under it in `graph.edges` are appended to the queue; a
/// vertex with no entry contributes no successors.
///
/// Returns every reached vertex exactly once, in first-visit order.
pub fn bfs<'a>(graph: &'a SymbolGraph, start: Vec<UUId>) -> Vec<UUId> {
    let mut pending: VecDeque<UUId> = VecDeque::from(start);
    let mut seen: Vec<UUId> = Vec::new();

    while let Some(current) = pending.pop_front() {
        // A vertex may be queued several times; only its first visit counts.
        if seen.contains(&current) {
            continue;
        }
        seen.push(current);

        let next = graph.edges.get(&current).cloned().unwrap_or_default();
        pending.extend(next);
    }

    seen
}

/// Builds a multi-line, ANSI-coloured error report.
///
/// The report consists of:
/// 1. `error: {header}`
/// 2. an empty gutter line
/// 3. `context` immediately followed by `source`
/// 4. a marker line with one `.` per character of `context` and one `^` per
///    character of `source`, so the carets sit under the offending text.
///
/// Marker lengths are counted in `char`s rather than bytes so that non-ASCII
/// text does not produce an over-long marker row.
pub fn format_error(header: &str, context: &str, source: &str) -> String {
    const RED: &str = "\x1b[31m";
    const BLUE: &str = "\x1b[34m";
    const YELLOW: &str = "\x1b[33m";
    const BOLD: &str = "\x1b[1m";
    const RESET: &str = "\x1b[0m";

    // `str::len` is a byte count; multi-byte characters would inflate it.
    let context_marks = ".".repeat(context.chars().count());
    let source_marks = "^".repeat(source.chars().count());

    format!(
        "{BOLD}{RED}error{RESET}: {header}\n\
         {BOLD}{BLUE}  |{RESET}\n\
         {BOLD}{BLUE}  |{RESET} {context}{source}\n\
         {BOLD}{BLUE}  |{RESET} {BOLD}{YELLOW}{}{BOLD}{RED}{}{RESET}",
        context_marks, source_marks,
    )
}