# outlines-core 0.2.14
#
# Structured Generation
# Documentation
import copy
import pickle
from typing import Dict, List, Union

import pytest

from outlines_core import Guide, Index, Vocabulary


@pytest.fixture(scope="session")
def index() -> Index:
    """Provide a session-scoped ``Index`` built from the regex ``[1-9]``.

    The vocabulary maps the tokens ``"1"`` and ``"2"`` to ids 1 and 2 and
    uses 3 as the end-of-sequence token id.
    """
    # The explicit annotation is here only to please mypy checks.
    vocab_tokens: Dict[Union[str, bytes], List[int]] = {"1": [1], "2": [2]}
    return Index(r"[1-9]", Vocabulary(3, vocab_tokens))


def test_interface():
    """Exercise the basic ``Guide`` interface on a one-token regex.

    Checks the initial state and allowed tokens, advances with the only
    matching token, verifies the guide reports itself finished, and then
    confirms that any further advancement raises ``ValueError``.
    """
    eos_token_id = 3
    tokens = {"1": [1], "a": [2]}
    regex = r"[1-9]"

    vocabulary = Vocabulary(eos_token_id, tokens)
    index = Index(regex, vocabulary)
    guide = Guide(index)

    # Only token "1" (id 1) matches the regex; "a" (id 2) does not.
    assert guide.get_state() == index.get_initial_state() == 12
    assert guide.get_tokens() == [1]

    assert guide.advance(1) == [vocabulary.get_eos_token_id()]
    assert guide.is_finished()
    assert guide.get_state() == 20
    assert guide.get_tokens() == [eos_token_id]

    # Each call gets its own ``pytest.raises`` block: the block is left as
    # soon as the first exception is raised, so a second call placed in the
    # same block would never be executed.

    # No advancement is possible for state with allowed tokens == eos
    with pytest.raises(
        ValueError,
        match="No next state found for the current state",
    ):
        guide.advance(eos_token_id)

    # As well as with any other random token id
    with pytest.raises(
        ValueError,
        match="No next state found for the current state",
    ):
        guide.advance(4)


def test_regex_final_state_walk():
    """Make sure that the Guide can walk to the final state correctly."""
    backtick, dot, newline, eos = 101, 102, 103, 104

    vocabulary = Vocabulary(
        eos,
        {b"\n": [newline], b".": [dot], b"`": [backtick]},
    )
    guide = Guide(Index(r"`\n(\.\n)?`\n", vocabulary))

    # The walk must start with a backtick followed by a newline.
    assert guide.get_tokens() == [backtick]
    assert guide.advance(backtick) == [newline]

    # After the first line, either the optional ".\n" group or the closing
    # backtick may follow.
    allowed_after_first_line = sorted(guide.advance(newline))
    assert allowed_after_first_line == [backtick, dot]

    # Skip the optional group and close directly.
    assert guide.advance(backtick) == [newline]
    assert guide.advance(newline) == [vocabulary.get_eos_token_id()]
    assert guide.is_finished()


def test_token_trans_keys_identical():
    """Two guides over one index agree when fed `a` versus `b`.

    For the regex ``z[ab]z`` the tokens `a` and `b` are interchangeable in
    the middle position, so both guides must expose the same allowed tokens
    at every step and finish together.
    """
    eos_token_id = 4
    vocabulary = Vocabulary(eos_token_id, {"a": [1], "b": [2], "z": [3]})
    index = Index(r"z[ab]z", vocabulary)

    first = Guide(index)
    second = Guide(index)

    # Leading `z` on both guides.
    first_allowed = sorted(first.advance(3))
    second_allowed = sorted(second.advance(3))
    assert first_allowed == second_allowed

    # `a` and `b` have similar transitions to `z`
    first_allowed = sorted(first.advance(1))
    second_allowed = sorted(second.advance(2))
    assert first_allowed == second_allowed

    # Trailing `z` leaves only the end-of-sequence token.
    assert first.advance(3) == second.advance(3) == [eos_token_id]
    for guide in (first, second):
        assert guide.is_finished()


def test_str_and_bytes_produce_the_same():
    """Vocabularies keyed by ``str`` and by ``bytes`` yield matching guides.

    The same token ids are registered once under string keys and once under
    bytes keys; guides built from either must expose the same allowed tokens
    along the walk through ``z[ab]z``.
    """
    eos_token_id = 4
    regex = r"z[ab]z"

    str_vocabulary = Vocabulary(eos_token_id, {"a": [1], "b": [2], "z": [3]})
    bytes_vocabulary = Vocabulary(eos_token_id, {b"a": [1], b"b": [2], b"z": [3]})

    str_index = Index(regex, str_vocabulary)
    bytes_index = Index(regex, bytes_vocabulary)

    str_guide = Guide(str_index)
    bytes_guide = Guide(bytes_index)

    # Leading `z` on both guides.
    str_allowed = sorted(str_guide.advance(3))
    bytes_allowed = sorted(bytes_guide.advance(3))
    assert str_allowed == bytes_allowed

    # `a` and `b` have similar transitions to `z`
    str_allowed = sorted(str_guide.advance(1))
    bytes_allowed = sorted(bytes_guide.advance(2))
    assert str_allowed == bytes_allowed

    # Trailing `z` leaves only the end-of-sequence token.
    assert str_guide.advance(3) == bytes_guide.advance(3) == [eos_token_id]
    for guide in (str_guide, bytes_guide):
        assert guide.is_finished()


def test_pickling(index):
    """A pickled-and-restored guide exposes the same allowed tokens."""
    original = Guide(index)
    restored = pickle.loads(pickle.dumps(original))

    expected_tokens = sorted(original.get_tokens())
    assert sorted(restored.get_tokens()) == expected_tokens


@pytest.mark.parametrize(
    "model, revision",
    [
        ("openai-community/gpt2", "607a30d783dfa663caf39e06633721c8d4cfcd7e"),
        ("microsoft/phi-2", "ef382358ec9e382308935a992d908de099b64c23"),
        ("Qwen/Qwen1.5-0.5B-Chat", "4d14e384a4b037942bb3f3016665157c8bcb70ea"),
        (
            "NousResearch/Hermes-2-Pro-Llama-3-8B",
            "783fd50eb82d7f57758de033861f54d62dde234f",
        ),
    ],
)
def test_pickling_from_pretrained_with_revision(model, revision):
    """Pickle round-trip of a guide built from a pretrained vocabulary.

    For each pinned ``(model, revision)`` pair the vocabulary is loaded via
    ``Vocabulary.from_pretrained``, an index is built for a large regex, its
    transition table is expected to hold 810 entries, and a guide restored
    from pickle must expose the same allowed tokens as the original.
    """
    # Kept on a single line so the pattern stays exactly as authored.
    pattern = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~"

    pretrained_vocabulary = Vocabulary.from_pretrained(model, revision=revision)
    index = Index(pattern, pretrained_vocabulary)
    transition_count = len(index.get_transitions())
    assert transition_count == 810

    original = Guide(index)
    restored = pickle.loads(pickle.dumps(original))
    assert sorted(restored.get_tokens()) == sorted(original.get_tokens())


def test_equality(index):
    """Guides compare equal by index content and current state."""
    advanced = Guide(index)
    untouched = Guide(index)
    assert advanced == untouched

    # confirm that equality is about inner index, not reference difference
    from_copied_index = Guide(copy.deepcopy(index))
    assert from_copied_index == untouched == advanced

    # progress one of the guides, confirm different state == different guide
    last_allowed_token = advanced.get_tokens()[-1]
    advanced.advance(last_allowed_token)
    assert advanced != untouched
    assert from_copied_index == untouched


def test_write_mask_into(index):
    """``write_mask_into`` fills a caller-owned buffer with the token mask.

    The expected value sets one bit per allowed token id, at bit position
    ``token % 32`` of the single 32-bit word in the buffer.
    """
    import torch

    guide = Guide(index)
    # Build the buffer directly. Wrapping an existing tensor in another
    # ``torch.tensor(...)`` call only makes a redundant copy and emits a
    # copy-construct UserWarning.
    mask = torch.tensor([-1], dtype=torch.uint32)
    guide.write_mask_into(mask.data_ptr(), mask.numel(), mask.element_size())

    expected_mask = 0
    for token in guide.get_tokens():
        expected_mask |= 1 << (token % 32)

    assert (
        mask[0] == expected_mask
    ), f"mask and expected mask do not match. Mask: {mask[0]}, Expected: {expected_mask}"


def test_write_mask_into_interface(index):
    """``write_mask_into`` rejects invalid buffer descriptions.

    Each call passes exactly one bad argument (zero length, unsupported
    element size, null pointer, misaligned pointer) and expects a
    ``ValueError`` with the corresponding message.
    """
    import torch

    guide = Guide(index)
    # Build the buffer directly. Wrapping an existing tensor in another
    # ``torch.tensor(...)`` call only makes a redundant copy and emits a
    # copy-construct UserWarning.
    mask = torch.tensor([-1], dtype=torch.uint32)

    with pytest.raises(ValueError, match="Invalid buffer size"):
        guide.write_mask_into(mask.data_ptr(), 0, mask.element_size())
    with pytest.raises(ValueError, match="Invalid element size"):
        guide.write_mask_into(mask.data_ptr(), mask.numel(), 5)
    with pytest.raises(ValueError, match="Invalid data pointer"):
        guide.write_mask_into(0, mask.numel(), mask.element_size())
    with pytest.raises(ValueError, match="Invalid data pointer alignment"):
        guide.write_mask_into(5, mask.numel(), mask.element_size())


def test_rollback(index):
    """Rolling back one step after one advance restores the initial state."""
    guide = Guide(index, max_rollback=3)
    initial_state = guide.get_state()

    # Step forward once, then undo that single step.
    guide.advance(1)
    guide.rollback_state(1)

    assert not guide.is_finished()
    # we should now be in the initial state
    assert guide.get_state() == initial_state


def test_rollback_interface(index):
    """Rolling back further than the recorded history raises ``ValueError``."""
    guide = Guide(index, max_rollback=3)
    expect_rollback_error = pytest.raises(ValueError, match="Cannot roll back")

    # Rolling back more than recorded history must raise
    with expect_rollback_error:
        guide.rollback_state(5)


@pytest.mark.parametrize(
    "seq, expected",
    [
        # single allowed token: accept
        ([1], True),
        # different allowed token: accept
        ([2], True),
        # too long for r"[1-9]": reject
        ([1, 1], False),
        # extra token: reject
        ([2, 3], False),
    ],
)
def test_accepts_tokens_correctness(index, seq, expected):
    """``accepts_tokens`` returns exactly ``True`` or ``False`` per sequence."""
    verdict = Guide(index).accepts_tokens(seq)
    assert verdict is expected