import pytest
from splintr import Tokenizer
class TestMistralV1ExactTokens:
    """Exact token-ID and roundtrip checks for the mistral_v1 tokenizer."""

    @pytest.fixture
    def tokenizer(self):
        """Function-scoped mistral_v1 tokenizer instance."""
        return Tokenizer.from_pretrained("mistral_v1")

    def test_hello_world_tokens(self, tokenizer):
        """The canonical greeting must map to its known token IDs."""
        tokens = tokenizer.encode("Hello world")
        assert tokens == [16230, 1526], f"Expected [16230, 1526], got {tokens}"

    def test_hello_world_punctuation(self, tokenizer):
        """Punctuated text must survive an encode/decode roundtrip."""
        decoded = tokenizer.decode(tokenizer.encode("Hello, world!"))
        assert decoded == "Hello, world!", f"Roundtrip failed: got {decoded!r}"

    def test_space_preservation(self, tokenizer):
        """A leading space must not be dropped by the roundtrip."""
        decoded = tokenizer.decode(tokenizer.encode(" world!"))
        assert decoded == " world!", f"Space not preserved: got {decoded!r}"

    def test_chinese_tokens(self, tokenizer):
        """CJK text (3-byte UTF-8 characters) must roundtrip losslessly."""
        text = "你好世界"
        decoded = tokenizer.decode(tokenizer.encode(text))
        assert decoded == text, f"Chinese roundtrip failed: {decoded!r}"

    def test_emoji_tokens(self, tokenizer):
        """Emoji (4-byte UTF-8 characters) must roundtrip losslessly."""
        text = "Hello 🌍 World!"
        decoded = tokenizer.decode(tokenizer.encode(text))
        assert decoded == text, f"Emoji roundtrip failed: {decoded!r}"
class TestMistralV1Roundtrip:
    """Encode/decode roundtrip tests for the mistral_v1 tokenizer."""

    @pytest.fixture
    def tokenizer(self):
        """Function-scoped mistral_v1 tokenizer instance."""
        return Tokenizer.from_pretrained("mistral_v1")

    # Parametrized (rather than a single test looping over a list) so each
    # case is reported individually and one failure does not mask the rest.
    @pytest.mark.parametrize(
        "text",
        [
            "Hello, world!",
            "The quick brown fox jumps over the lazy dog.",
            "Rust is a systems programming language.",
            "1234567890",
            "Special characters: !@#$%^&*()",
            "Unicode: こんにちは 世界 🦀",
            "Mixed: Hello 你好 🌍 World!",
        ],
    )
    def test_encode_decode_roundtrip(self, tokenizer, text):
        """decode(encode(text)) must reproduce text exactly."""
        tokens = tokenizer.encode(text)
        decoded = tokenizer.decode(tokens)
        assert decoded == text, f"Roundtrip failed for: {text!r}"

    def test_multiline_roundtrip(self, tokenizer):
        """Embedded newlines must be preserved through the roundtrip."""
        text = "Multi-line\ntext\nwith\nnewlines"
        tokens = tokenizer.encode(text)
        decoded = tokenizer.decode(tokens)
        assert decoded == text, f"Roundtrip failed for: {text!r}"

    def test_code_content(self, tokenizer):
        """Source-code text (quotes, newlines, underscores) must roundtrip."""
        code = '''def hello_world():
print("Hello, World!")
if __name__ == "__main__":
hello_world()
'''
        tokens = tokenizer.encode(code)
        decoded = tokenizer.decode(tokens)
        assert decoded == code
class TestMistralV1SpecialTokens:
    """Special-token handling for the mistral_v1 tokenizer."""

    @pytest.fixture
    def tokenizer(self):
        """Function-scoped mistral_v1 tokenizer instance."""
        return Tokenizer.from_pretrained("mistral_v1")

    def test_bos_eos_tokens(self, tokenizer):
        """BOS and EOS markers encode to their reserved single IDs (1 and 2)."""
        for marker, expected_id in (("<s>", 1), ("</s>", 2)):
            tokens = tokenizer.encode_with_special(marker)
            assert tokens == [expected_id], (
                f"{marker} should be token {expected_id}, got {tokens}"
            )

    def test_v1_tokenizes_inst_as_text(self, tokenizer):
        """[INST] is plain text in V1: several tokens, lossless decode."""
        tokens = tokenizer.encode_with_special("[INST]")
        assert len(tokens) > 1, "[INST] should be multiple text tokens in V1"
        assert tokenizer.decode(tokens) == "[INST]"

    def test_agent_tokens(self, tokenizer):
        """Agent control markers encode to their reserved single IDs."""
        for marker, expected in (("<|think|>", [32005]), ("<|function|>", [32015])):
            tokens = tokenizer.encode_with_special(marker)
            assert tokens == expected, f"{marker} should be {expected}, got {tokens}"

    def test_decode_agent_tokens(self, tokenizer):
        """Reserved agent IDs decode back to their marker strings."""
        for token_id, marker in ((32005, "<|think|>"), (32015, "<|function|>")):
            assert tokenizer.decode([token_id]) == marker
class TestMistralV1VocabSize:
    """Vocabulary-size and pretrained-name resolution checks."""

    def test_vocab_size(self):
        """mistral_v1 exposes the expected vocabulary size."""
        assert Tokenizer.from_pretrained("mistral_v1").vocab_size == 32054

    def test_default_mistral_is_v1(self):
        """The bare name "mistral" resolves to the V1 vocabulary."""
        assert Tokenizer.from_pretrained("mistral").vocab_size == 32054

    def test_hyphenated_names_rejected(self):
        """Hyphenated model names are invalid and must raise ValueError."""
        for bad_name in ("mistral-v1", "mistral-7b"):
            with pytest.raises(ValueError):
                Tokenizer.from_pretrained(bad_name)
class TestMistralV1Batch:
    """Batch-encoding behavior of the mistral_v1 tokenizer."""

    @pytest.fixture
    def tokenizer(self):
        """Function-scoped mistral_v1 tokenizer instance."""
        return Tokenizer.from_pretrained("mistral_v1")

    def test_batch_matches_individual(self, tokenizer):
        """encode_batch must agree with encode applied text-by-text."""
        texts = [
            "Hello, world!",
            "How are you?",
            "I'm doing great!",
            "Unicode: 你好 🌍",
        ]
        batch_tokens = tokenizer.encode_batch(texts)
        assert len(batch_tokens) == len(texts)
        for i, text in enumerate(texts):
            assert batch_tokens[i] == tokenizer.encode(text), (
                f"Batch mismatch for text {i}: {text!r}"
            )

    def test_empty_input(self, tokenizer):
        """An empty string encodes to no tokens; no tokens decode to ""."""
        assert tokenizer.encode("") == []
        assert tokenizer.decode([]) == ""
class TestMistralV1Utf8Boundaries:
    """Multi-byte UTF-8 characters (em dash, curly quotes) at token boundaries."""

    @pytest.fixture
    def tokenizer(self):
        """Function-scoped mistral_v1 tokenizer (default regex backend)."""
        return Tokenizer.from_pretrained("mistral_v1")

    @pytest.fixture
    def tokenizer_pcre2(self):
        """mistral_v1 tokenizer with the PCRE2 regex backend enabled."""
        return Tokenizer.from_pretrained("mistral_v1").pcre2(True)

    def test_em_dash(self, tokenizer):
        """An em dash mid-sentence must roundtrip losslessly."""
        text = "I'm sorry you're hurting—breakups suck, but you'll get through it."
        tokens = tokenizer.encode(text)
        decoded = tokenizer.decode(tokens)
        assert decoded == text

    def test_curly_quotes(self, tokenizer):
        """Typographic single and double quotes must roundtrip losslessly."""
        text = 'He said, \u2018Hello\u2019 and she replied, \u201cGoodbye\u201d.'
        tokens = tokenizer.encode(text)
        decoded = tokenizer.decode(tokens)
        assert decoded == text

    # Parametrized (rather than a single test looping over a list) so each
    # boundary position is reported individually and one failure does not
    # mask the rest.
    @pytest.mark.parametrize(
        "text",
        [
            "word—word",
            "a—b",
            "test—",
            "—start",
            "one—two—three",
            "Check your brake pads—they might be worn out.",
        ],
    )
    def test_em_dash_at_boundaries(self, tokenizer, text):
        """Em dashes at word start, middle, and end must roundtrip."""
        tokens = tokenizer.encode(text)
        decoded = tokenizer.decode(tokens)
        assert decoded == text, f"Failed for: {text!r}"

    @pytest.mark.parametrize(
        "text",
        [
            "word—word",
            "I'm sorry you're hurting—breakups suck.",
            'He said, \u2018Hello\u2019 and she replied, \u201cGoodbye\u201d.',
            "Check credentials—API key—in headers.",
        ],
    )
    def test_backend_consistency_multibyte(self, tokenizer, tokenizer_pcre2, text):
        """Default and PCRE2 backends must emit identical tokens for multibyte text."""
        tokens_regexr = tokenizer.encode(text)
        tokens_pcre2 = tokenizer_pcre2.encode(text)
        assert tokens_regexr == tokens_pcre2, f"Backend mismatch for: {text!r}"
class TestMistralV1LargeScaleBatch:
    """Batch encoding stays correct at larger scale."""

    @pytest.fixture
    def tokenizer(self):
        """Function-scoped mistral_v1 tokenizer instance."""
        return Tokenizer.from_pretrained("mistral_v1")

    def test_large_batch_parallel(self, tokenizer):
        """Encode 700 mixed-content texts in one batch and spot-check roundtrips."""
        corpus = [
            "I'm sorry you're hurting—breakups suck, but you'll get through it.",
            "Check if you're using valid credentials—API key, token—in headers.",
            "你好世界!这是一个测试。",
            "Hello 🌍 World! 🦀 Rust is great!",
            "Mixed: Hello 你好 🌍 —test— World!",
            "Code: def foo(): return 42",
            "A 403 Forbidden error means permission denied.",
        ]
        samples = corpus * 100
        token_lists = tokenizer.encode_batch(samples)
        assert len(token_lists) == len(samples)
        # Spot-check every 50th entry rather than all 700 to keep the test fast.
        for i in range(0, len(samples), 50):
            assert tokenizer.decode(token_lists[i]) == samples[i], (
                f"Failed roundtrip for text {i}"
            )
class TestMistralV1BackendOptions:
    """Regex-backend configuration options all behave consistently."""

    def test_default_backend(self):
        """The default backend roundtrips a simple string."""
        tokenizer = Tokenizer.from_pretrained("mistral_v1")
        text = "Hello, world!"
        assert tokenizer.decode(tokenizer.encode(text)) == text

    def test_pcre2_backend(self):
        """The PCRE2 backend roundtrips a simple string."""
        tokenizer = Tokenizer.from_pretrained("mistral_v1").pcre2(True)
        text = "Hello, world!"
        assert tokenizer.decode(tokenizer.encode(text)) == text

    def test_jit_disabled(self):
        """Disabling the regex JIT still roundtrips a simple string."""
        tokenizer = Tokenizer.from_pretrained("mistral_v1").jit(False)
        text = "Hello, world!"
        assert tokenizer.decode(tokenizer.encode(text)) == text

    def test_backend_consistency(self):
        """Every backend configuration must emit identical token IDs."""
        text = "The quick brown fox 你好 🦀 jumps—over—the lazy dog."
        baseline = Tokenizer.from_pretrained("mistral_v1").encode(text)
        via_pcre2 = Tokenizer.from_pretrained("mistral_v1").pcre2(True).encode(text)
        via_no_jit = Tokenizer.from_pretrained("mistral_v1").jit(False).encode(text)
        assert via_pcre2 == baseline, "PCRE2 should match default"
        assert via_no_jit == baseline, "Non-JIT should match default"