lextrail-test 0.1.0

A library for constraining language-model outputs to follow a context-free grammar (CFG), a regular expression, or JSON (experimental).
Documentation
use std::collections::{HashMap, HashSet, VecDeque};
use std::mem;

use crate::build::{Symbol, SymbolVar, UUId};
use crate::guide::{CFGGraph, IdState, State, TrailLayer, build_cfg_graph};
use crate::helpers::TrailError;

/// One vertex of the byte-level token trie built by `build_asm_graph`.
#[derive(Debug, Clone)]
struct ASMNode {
    /// The byte this vertex matches.
    value: u8,
    /// Key under which the vertex is stored in `ASMGraph::nodes`.
    id: UUId,
}

impl Default for ASMNode {
    /// A node carrying byte `0` and a freshly generated id.
    fn default() -> Self {
        let id = UUId::new();
        ASMNode { id, value: 0 }
    }
}

/// Byte-level trie over the LM token alphabet, produced by `build_asm_graph`.
///
/// Every node is stored by id in `nodes`. The root is implicit: its children are `heads`.
/// `edges` maps a node to its children, and `tails` marks nodes at which a complete alphabet
/// token ends (a tail may still have children when one token is a prefix of another).
#[derive(Debug, Clone, Default)]
pub struct ASMGraph {
    /// All trie nodes, keyed by their id.
    nodes: HashMap<UUId, ASMNode>,
    /// Children of the implicit root, one node per distinct first byte.
    heads: HashSet<UUId>,
    /// Children of every non-root node that has at least one child.
    edges: HashMap<UUId, HashSet<UUId>>,
    /// Nodes at which some alphabet token ends.
    tails: HashSet<UUId>,
}

/// Progress through the token trie for the LM token currently being spelled out.
#[derive(Clone, PartialEq, Eq)]
struct ASMStep {
    /// Bytes matched so far for the in-progress token.
    accumulator: Vec<u8>,
    /// Trie node reached so far; `Unset` means matching restarts at the trie's `heads`.
    id: IdState,
}

impl Default for ASMStep {
    /// An empty step: nothing accumulated, positioned at the trie root.
    fn default() -> Self {
        ASMStep {
            id: IdState::Unset,
            accumulator: vec![],
        }
    }
}

/// Bytes of the current grammar terminal that are still waiting to be matched.
#[derive(Clone)]
enum ASMToken {
    /// Pending bytes, consumed from the front.
    At(VecDeque<u8>),
    /// Nothing pending.
    End,
}

impl ASMToken {
    /// Wraps the UTF-8 bytes of `s` as pending bytes.
    pub fn new(s: &str) -> Self {
        let pending = s.as_bytes().to_vec();
        ASMToken::At(VecDeque::from(pending))
    }
}

/// One branch of the search: where it is in the grammar and in the token trie.
#[derive(Clone)]
struct ASMFrame<'a> {
    // Stack of grammar layers; the last entry is the innermost rule being walked
    // (a VARIABLE pushes a new layer, an exhausted graph pops one).
    layers: Vec<TrailLayer<'a>>,
    // Progress through the token trie for the token being spelled out.
    step: ASMStep,
    // Bytes of the current terminal not yet matched against the trie.
    token: ASMToken,
}

/// A candidate next LM token together with the frame to resume if it is chosen.
#[derive(Clone)]
pub struct ASMProposal<'a> {
    // Search state right after this token; resumed by `asm_run` when the proposal is kept.
    frame: ASMFrame<'a>,
    // Text of the proposed token (empty for an END proposal).
    value: String,
}

/// Search state: the pending proposals plus recorded backreference text.
pub type ASMState<'a> = State<ASMProposal<'a>>;

/// Compiled constraint: the grammar graphs and the token trie of the LM alphabet.
pub struct ASMSchema {
    pub cfg: CFGGraph,
    pub asm: ASMGraph,
}

/// A schema paired with a fresh search state, as returned by the `asm_*` constructors.
pub type ASM<'a> = (ASMSchema, ASMState<'a>);

/// Raw UTF-8 bytes of a token.
type UTF8 = Vec<u8>;

pub fn build_asm_graph(alphabet: Vec<String>) -> ASMGraph {
    let (mut graph, mut id) = (ASMGraph::default(), UUId::new());
    let (nodes, tails) = (&mut graph.nodes, &mut graph.tails);

    let tokens: Vec<UTF8> = alphabet.into_iter().map(|s| s.into_bytes()).collect();

    for token in tokens {
        for (i, byte) in token.iter().enumerate() {
            let candidates = if i == 0 {
                &mut graph.heads
            } else {
                graph.edges.entry(id).or_insert(HashSet::new())
            };

            let opt_id = candidates
                .iter()
                .find_map(|id| (nodes[id].value == *byte).then_some(*id));

            match opt_id {
                Some(value) => id = value,
                None => {
                    let new_node = ASMNode {
                        value: *byte,
                        ..ASMNode::default()
                    };
                    let new_id = new_node.id;

                    candidates.insert(new_id);
                    nodes.insert(new_id, new_node);

                    id = new_id;
                }
            }
        }
        tails.insert(id);
    }

    return graph;
}

/// Compiles `grammar` with `build_cfg_graph` and builds the token trie for `alphabet`.
///
/// This is the shared implementation of the public `asm_*` constructors. The grammar is
/// compiled first, so a grammar error is reported before any work is spent on the trie.
///
/// # Errors
///
/// Propagates the error from `build_cfg_graph` when `grammar` fails to compile.
fn asm_from_grammar<'a>(grammar: &str, alphabet: Vec<String>) -> Result<ASM<'a>, TrailError> {
    let cfg = build_cfg_graph(grammar)?;
    let asm = build_asm_graph(alphabet);
    Ok((ASMSchema { cfg, asm }, ASMState::default()))
}

/// Builds a constraint from a full grammar `cfg` over the token `alphabet`.
///
/// # Errors
///
/// Returns the grammar compilation error from `build_cfg_graph`.
pub fn asm_cfg<'a>(cfg: &str, alphabet: Vec<String>) -> Result<ASM<'a>, TrailError> {
    asm_from_grammar(cfg, alphabet)
}

/// Builds a constraint from the pattern `rex`, which is wrapped as `/{rex}/` before compiling.
///
/// `/…/` is presumably the grammar's regex-terminal syntax; `rex` is inserted verbatim and
/// is not escaped. TODO: confirm how a `/` inside `rex` is handled.
///
/// # Errors
///
/// Returns the grammar compilation error from `build_cfg_graph`.
pub fn asm_rex<'a>(rex: &str, alphabet: Vec<String>) -> Result<ASM<'a>, TrailError> {
    asm_from_grammar(&format!("/{rex}/"), alphabet)
}

/// Builds a constraint from the rule body `exp`, compiled as the grammar `start: {exp}`.
///
/// # Errors
///
/// Returns the grammar compilation error from `build_cfg_graph`.
pub fn asm_exp<'a>(exp: &str, alphabet: Vec<String>) -> Result<ASM<'a>, TrailError> {
    asm_from_grammar(&format!("start: {exp}"), alphabet)
}

fn assemble<'a>(graph: &ASMGraph, frame: &mut ASMFrame<'a>) -> Vec<ASMProposal<'a>> {
    let mut proposals: Vec<ASMProposal> = Vec::new();

    let (nodes, tails) = (&graph.nodes, &graph.tails);
    let (token, step) = (&frame.token, &frame.step);

    let mut successors = if let IdState::Set(uuid) = step.id {
        graph.edges.get(&uuid).cloned().unwrap_or_default()
    } else {
        graph.heads.clone()
    };

    let mut bytes = if let ASMToken::At(value) = token {
        value.clone()
    } else {
        unreachable!()
    };

    while !bytes.is_empty() {
        let found = successors.iter().find(|uuid| nodes[uuid].value == bytes[0]);

        match found {
            Some(uuid) => {
                let byte = bytes.pop_front().unwrap();

                frame.step.id = IdState::Set(*uuid);
                frame.step.accumulator.push(byte);

                if tails.contains(uuid) {
                    let mut final_frame = frame.clone();
                    final_frame.token = if bytes.is_empty() {
                        ASMToken::End
                    } else {
                        ASMToken::At(bytes.clone())
                    };

                    let final_token = mem::take(&mut final_frame.step).accumulator;
                    let final_value = String::from_utf8(final_token.into_iter().collect()).unwrap();

                    proposals.push(ASMProposal {
                        frame: final_frame,
                        value: final_value,
                    });
                }

                match graph.edges.get(&uuid) {
                    Some(uuids) => successors = uuids.clone(),
                    None => break,
                }
            }
            None => break,
        }
    }

    frame.token = if bytes.is_empty() {
        ASMToken::End
    } else {
        ASMToken::At(bytes)
    };

    return proposals;
}

/// Advances the search past the selected token and refills `state.proposals`.
///
/// On entry, `state.proposals` holds the proposals that survived the caller's filtering (see
/// `get_next_tokens`). If it is empty, a single frame is started at the `start` rule of
/// `schema.cfg`. Otherwise the surviving proposals are drained and their frames are resumed.
///
/// Frames are then explored depth-first through the grammar graphs. Every alphabet token that
/// can be spelled from the current position (via `assemble` over `schema.asm`) is pushed as a
/// new proposal. On return, `state.proposals` holds only these new proposals.
///
/// The text of tagged TERMINAL nodes is appended to `state.backrefs`, where REFERENCE nodes
/// read it back.
///
/// # Panics
///
/// Panics if a surviving proposal has no layers, or if its innermost layer has no node set. It
/// also panics on a missing key when indexing `cfg`, a graph's `nodes`, or `backrefs` (for
/// example, a REFERENCE whose tag has not been recorded yet).
fn asm_run<'a>(schema: &'a ASMSchema, state: &mut ASMState<'a>) {
    let (cfg, asm) = (&schema.cfg, &schema.asm);
    let (proposals, backrefs) = (&mut state.proposals, &mut state.backrefs);

    // Record backreferences for the token just selected: if a surviving proposal stopped on a
    // tagged TERMINAL, append that terminal's full text to each of its tags.
    // NOTE(review): the whole terminal `value` is appended every time a proposal stops on it.
    // So a terminal spread over several LM tokens appears to be appended once per token, and
    // several surviving proposals each append. A tagged terminal that was completed earlier
    // in the same LM token (so it is no longer the innermost node) is not recorded at all.
    // Confirm the intended semantics before relying on REFERENCE nodes.
    for proposal in &*proposals {
        let checkpoint = proposal.frame.layers.last().unwrap();

        let id = if let IdState::Set(uuid) = checkpoint.id {
            uuid
        } else {
            unreachable!()
        };

        if let Symbol::TERMINAL { value, tags, .. } = &checkpoint.graph.nodes[&id] {
            for tag in tags {
                backrefs
                    .entry(tag.clone())
                    .or_insert_with(String::new)
                    .push_str(value);
            }
        }
    }

    // Nothing selected yet: start one frame at the `start` rule with no pending bytes.
    // Otherwise resume the frame of every surviving proposal (this empties `proposals`).
    let mut frames = if proposals.is_empty() {
        let start = vec![TrailLayer {
            graph: &cfg[&SymbolVar::new("start")],
            id: IdState::Unset,
        }];

        vec![ASMFrame {
            layers: start,
            step: ASMStep::default(),
            token: ASMToken::End,
        }]
    } else {
        proposals.drain(..).map(|proposal| proposal.frame).collect()
    };

    // Depth-first exploration: `frames` is used as a LIFO stack.
    while !frames.is_empty() {
        let mut frame = frames.pop().unwrap();

        let checkpoint = frame
            .layers
            .last()
            .expect("Empty proposal states are neither processed, nor pushed to the state.");

        let (id, graph) = (checkpoint.id, checkpoint.graph);

        let nodes = &graph.nodes;

        // Grammar nodes to try next:
        // - a layer that has not entered its graph yet starts at the graph's heads;
        // - a node whose text is fully matched (`End`) moves on to its successors;
        // - a node with bytes still pending (`At`) is retried to match the rest of its text.
        let successors = match id {
            IdState::Set(uuid) => {
                if let ASMToken::End = frame.token {
                    graph.edges.get(&uuid).cloned().unwrap_or_default()
                } else {
                    HashSet::from([uuid])
                }
            }
            IdState::Unset => graph.heads.clone(),
        };

        // This layer's graph is finished: return to the enclosing layer, or drop the frame
        // once the outermost layer is done.
        if successors.is_empty() {
            let layers = &mut frame.layers;

            layers.pop();

            if !layers.is_empty() {
                frames.push(frame);
            }

            continue;
        }

        for successor in &successors {
            match &nodes[successor] {
                // Step onto the terminal and, unless resuming a partial match, load its text
                // as the pending bytes. `assemble` emits a proposal for every complete
                // alphabet token. The frame keeps exploring only if all pending bytes were
                // consumed (token back to `End`).
                Symbol::TERMINAL { value, .. } => {
                    let mut next_frame = frame.clone();
                    next_frame.layers.last_mut().unwrap().id = IdState::Set(*successor);

                    if let ASMToken::End = frame.token {
                        next_frame.token = ASMToken::new(value)
                    }

                    let assembled = assemble(&asm, &mut next_frame);
                    proposals.extend(assembled);

                    if let ASMToken::End = next_frame.token {
                        frames.push(next_frame)
                    }
                }
                // Enter the referenced rule: mark this node in the current layer and push a
                // fresh layer for the rule's graph.
                Symbol::VARIABLE { value, .. } => {
                    let mut next_frame = frame.clone();
                    next_frame.layers.last_mut().unwrap().id = IdState::Set(*successor);

                    let next_layer = TrailLayer {
                        graph: &cfg[&value],
                        id: IdState::Unset,
                    };
                    next_frame.layers.push(next_layer);

                    frames.push(next_frame);
                }
                // Like TERMINAL, but the pending bytes are the text recorded in `backrefs`
                // for this tag.
                // NOTE(review): `backrefs[value]` panics if nothing was recorded for the tag.
                Symbol::REFERENCE { value, .. } => {
                    let mut next_frame = frame.clone();
                    next_frame.layers.last_mut().unwrap().id = IdState::Set(*successor);

                    if let ASMToken::End = frame.token {
                        next_frame.token = ASMToken::new(&backrefs[value])
                    }

                    let assembled = assemble(&asm, &mut next_frame);
                    proposals.extend(assembled);

                    if let ASMToken::End = next_frame.token {
                        frames.push(next_frame)
                    }
                }
                // END reached with no token partially spelled out: propose the empty string.
                // NOTE(review): presumably this signals that generation may stop here; confirm.
                Symbol::END { .. } if frame.step == ASMStep::default() => {
                    let mut next_frame = frame.clone();
                    next_frame.layers.last_mut().unwrap().id = IdState::Set(*successor);

                    proposals.push(ASMProposal {
                        frame: next_frame,
                        value: String::new(),
                    });
                }
                // Anything else (including END while a token is only partially spelled out)
                // is dropped.
                _ => (),
            }
        }
    }
}

/// Commits the chosen token `value` and returns the tokens allowed to follow it.
///
/// `value` should be one of the strings returned by the previous call. Pending proposals
/// with any other text are discarded before the search advances.
///
/// When there are no pending proposals, `value` is not checked and the search starts from the
/// `start` rule. That covers the first call, but also any call after a previous one returned
/// no tokens or failed.
/// NOTE(review): confirm that this implicit restart is intended.
///
/// # Errors
///
/// Returns a [`TrailError`] if proposals were pending but none of them has the text `value`.
/// In that case `state.proposals` has already been emptied.
pub fn get_next_tokens<'a>(
    schema: &'a ASMSchema,
    state: &mut ASMState<'a>,
    value: &str,
) -> Result<Vec<String>, TrailError> {
    let is_initial = state.proposals.is_empty();

    state.proposals.retain(|proposal| proposal.value == value);

    if state.proposals.is_empty() && !is_initial {
        return Err(TrailError(format!(
            "Symbol `{:#?}` has no previous state.",
            value
        )));
    }

    asm_run(schema, state);

    Ok(state
        .proposals
        .iter()
        .map(|proposal| proposal.value.clone())
        .collect())
}