"""Tests for the splintr Mistral V3 tokenizer: loading, vocabulary size,
native special tokens (<s>, </s>, <unk>), the agent-token ID block, and
encode/decode roundtrip behavior."""
import pytest
from splintr import Tokenizer, MISTRAL_V3_AGENT_TOKENS
class TestMistralV3Loading:
def test_load_mistral_v3(self):
tok = Tokenizer.from_pretrained("mistral_v3")
assert tok is not None
assert tok.vocab_size > 130000
class TestMistralV3VocabSize:
def test_vocab_size(self):
tok = Tokenizer.from_pretrained("mistral_v3")
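        # 131126 = 131072 (2**17, presumably the base vocabulary) plus 54
        # higher IDs; the agent tokens tested below occupy 131072-131102,
        # with the remainder apparently reserved.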
assert tok.vocab_size == 131126
def test_v3_much_larger_than_v2(self):
v2 = Tokenizer.from_pretrained("mistral_v2")
v3 = Tokenizer.from_pretrained("mistral_v3")
assert v3.vocab_size > v2.vocab_size * 3
assert v2.vocab_size == 32822
assert v3.vocab_size == 131126
class TestMistralV3NativeSpecialTokens:
@pytest.fixture
def tokenizer(self):
return Tokenizer.from_pretrained("mistral_v3")
def test_bos_token(self, tokenizer):
tokens = tokenizer.encode_with_special("<s>")
assert tokens == [1], f"<s> should be token 1, got {tokens}"
def test_eos_token(self, tokenizer):
tokens = tokenizer.encode_with_special("</s>")
assert tokens == [2], f"</s> should be token 2, got {tokens}"
def test_unk_token(self, tokenizer):
tokens = tokenizer.encode_with_special("<unk>")
assert tokens == [0], f"<unk> should be token 0, got {tokens}"
def test_decode_bos_eos_unk(self, tokenizer):
assert tokenizer.decode([0]) == "<unk>"
assert tokenizer.decode([1]) == "<s>"
assert tokenizer.decode([2]) == "</s>"
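# The agent tokens sit directly above the base vocabulary: <|system|> at
# 131072 (2**17), counting upward through <|/quote|> at 131102.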
class TestMistralV3AgentTokens:
@pytest.fixture
def tokenizer(self):
return Tokenizer.from_pretrained("mistral_v3")
def test_conversation_tokens(self, tokenizer):
tokens = tokenizer.encode_with_special("<|system|>")
assert tokens == [131072]
tokens = tokenizer.encode_with_special("<|user|>")
assert tokens == [131073]
tokens = tokenizer.encode_with_special("<|assistant|>")
assert tokens == [131074]
def test_thinking_tokens(self, tokenizer):
tokens = tokenizer.encode_with_special("<|think|>")
assert tokens == [131077]
tokens = tokenizer.encode_with_special("<|/think|>")
assert tokens == [131078]
def test_function_calling_tokens(self, tokenizer):
tokens = tokenizer.encode_with_special("<|function|>")
assert tokens == [131087]
tokens = tokenizer.encode_with_special("<|/function|>")
assert tokens == [131088]
class TestMistralV3AgentTokensClass:
def test_conversation_tokens(self):
assert MISTRAL_V3_AGENT_TOKENS.SYSTEM == 131072
assert MISTRAL_V3_AGENT_TOKENS.USER == 131073
assert MISTRAL_V3_AGENT_TOKENS.ASSISTANT == 131074
assert MISTRAL_V3_AGENT_TOKENS.IM_START == 131075
assert MISTRAL_V3_AGENT_TOKENS.IM_END == 131076
def test_thinking_tokens(self):
assert MISTRAL_V3_AGENT_TOKENS.THINK == 131077
assert MISTRAL_V3_AGENT_TOKENS.THINK_END == 131078
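    # PLAN/STEP/ACT/OBSERVE mirror a ReAct-style agent loop: draft a plan,
    # take a step, act (e.g. invoke a tool), then observe the result.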
def test_react_tokens(self):
assert MISTRAL_V3_AGENT_TOKENS.PLAN == 131079
assert MISTRAL_V3_AGENT_TOKENS.PLAN_END == 131080
assert MISTRAL_V3_AGENT_TOKENS.STEP == 131081
assert MISTRAL_V3_AGENT_TOKENS.STEP_END == 131082
assert MISTRAL_V3_AGENT_TOKENS.ACT == 131083
assert MISTRAL_V3_AGENT_TOKENS.ACT_END == 131084
assert MISTRAL_V3_AGENT_TOKENS.OBSERVE == 131085
assert MISTRAL_V3_AGENT_TOKENS.OBSERVE_END == 131086
def test_function_tokens(self):
assert MISTRAL_V3_AGENT_TOKENS.FUNCTION == 131087
assert MISTRAL_V3_AGENT_TOKENS.FUNCTION_END == 131088
assert MISTRAL_V3_AGENT_TOKENS.RESULT == 131089
assert MISTRAL_V3_AGENT_TOKENS.RESULT_END == 131090
assert MISTRAL_V3_AGENT_TOKENS.ERROR == 131091
assert MISTRAL_V3_AGENT_TOKENS.ERROR_END == 131092
def test_code_tokens(self):
assert MISTRAL_V3_AGENT_TOKENS.CODE == 131093
assert MISTRAL_V3_AGENT_TOKENS.CODE_END == 131094
assert MISTRAL_V3_AGENT_TOKENS.OUTPUT == 131095
assert MISTRAL_V3_AGENT_TOKENS.OUTPUT_END == 131096
assert MISTRAL_V3_AGENT_TOKENS.LANG == 131097
assert MISTRAL_V3_AGENT_TOKENS.LANG_END == 131098
def test_rag_tokens(self):
assert MISTRAL_V3_AGENT_TOKENS.CONTEXT == 131099
assert MISTRAL_V3_AGENT_TOKENS.CONTEXT_END == 131100
assert MISTRAL_V3_AGENT_TOKENS.QUOTE == 131101
assert MISTRAL_V3_AGENT_TOKENS.QUOTE_END == 131102
class TestMistralV3DecodeAgentTokens:
@pytest.fixture
def tokenizer(self):
return Tokenizer.from_pretrained("mistral_v3")
def test_decode_system(self, tokenizer):
decoded = tokenizer.decode([131072])
assert decoded == "<|system|>"
def test_decode_user(self, tokenizer):
decoded = tokenizer.decode([131073])
assert decoded == "<|user|>"
def test_decode_assistant(self, tokenizer):
decoded = tokenizer.decode([131074])
assert decoded == "<|assistant|>"
def test_decode_think(self, tokenizer):
decoded = tokenizer.decode([131077])
assert decoded == "<|think|>"
decoded = tokenizer.decode([131078])
assert decoded == "<|/think|>"
class TestMistralV3SpecialTokensMixed:
@pytest.fixture
def tokenizer(self):
return Tokenizer.from_pretrained("mistral_v3")
def test_special_tokens_in_mixed_text(self, tokenizer):
tokens = tokenizer.encode_with_special("<|system|>Hi<|user|>Hello<|assistant|>World")
        assert 131072 in tokens
        assert 131073 in tokens
        assert 131074 in tokens
decoded = tokenizer.decode(tokens)
assert "<|system|>" in decoded
assert "<|user|>" in decoded
assert "<|assistant|>" in decoded
def test_thinking_tokens_mixed(self, tokenizer):
tokens = tokenizer.encode_with_special("<|think|>reasoning<|/think|>")
        assert 131077 in tokens
        assert 131078 in tokens
decoded = tokenizer.decode(tokens)
assert "<|think|>" in decoded
assert "<|/think|>" in decoded
class TestMistralV3VsOthers:
def test_different_from_v1(self):
v1 = Tokenizer.from_pretrained("mistral_v1")
v3 = Tokenizer.from_pretrained("mistral_v3")
text = "Hello"
v1_tokens = v1.encode(text)
v3_tokens = v3.encode(text)
assert v1_tokens != v3_tokens
def test_different_from_v2(self):
v2 = Tokenizer.from_pretrained("mistral_v2")
v3 = Tokenizer.from_pretrained("mistral_v3")
text = "Test"
v2_tokens = v2.encode(text)
v3_tokens = v3.encode(text)
assert v2_tokens != v3_tokens
class TestMistralV3BasicEncoding:
@pytest.fixture
def tokenizer(self):
return Tokenizer.from_pretrained("mistral_v3")
def test_encodes_text(self, tokenizer):
tokens = tokenizer.encode("Hello")
assert len(tokens) > 0
def test_empty_input(self, tokenizer):
tokens = tokenizer.encode("")
assert tokens == []
decoded = tokenizer.decode([])
assert decoded == ""
def test_batch_encoding(self, tokenizer):
texts = ["Hello", "World", "Test"]
batch_tokens = tokenizer.encode_batch(texts)
assert len(batch_tokens) == 3
for i, text in enumerate(texts):
individual = tokenizer.encode(text)
assert batch_tokens[i] == individual
class TestMistralV3Roundtrip:
@pytest.fixture
def tokenizer(self):
return Tokenizer.from_pretrained("mistral_v3")
def test_roundtrip_hello_world(self, tokenizer):
text = "Hello world"
tokens = tokenizer.encode(text)
decoded = tokenizer.decode(tokens)
assert decoded == text, "Spaces should be preserved"
def test_roundtrip_with_punctuation(self, tokenizer):
text = "Hello, world!"
tokens = tokenizer.encode(text)
decoded = tokenizer.decode(tokens)
assert decoded == text
def test_roundtrip_leading_space(self, tokenizer):
text = " hello world "
tokens = tokenizer.encode(text)
decoded = tokenizer.decode(tokens)
assert decoded == text, "Leading/trailing spaces should be preserved"
def test_roundtrip_multiple_spaces(self, tokenizer):
text = "hello world"
tokens = tokenizer.encode(text)
decoded = tokenizer.decode(tokens)
assert decoded == text, "Multiple spaces should be preserved"
def test_roundtrip_chinese(self, tokenizer):
text = "δ½ ε₯½δΈη"
tokens = tokenizer.encode(text)
decoded = tokenizer.decode(tokens)
assert decoded == text
def test_roundtrip_emoji(self, tokenizer):
text = "Hello π World!"
tokens = tokenizer.encode(text)
decoded = tokenizer.decode(tokens)
assert decoded == text
def test_roundtrip_multiline(self, tokenizer):
text = "Multi-line\ntext\nwith\nnewlines"
tokens = tokenizer.encode(text)
decoded = tokenizer.decode(tokens)
assert decoded == text
def test_roundtrip_code(self, tokenizer):
text = "def hello():\n print('Hello')"
tokens = tokenizer.encode(text)
decoded = tokenizer.decode(tokens)
assert decoded == text
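    @pytest.mark.parametrize("text", [
        "tabs\tbetween\twords",
        "mixed scripts: δ½ ε₯½ hello π",
    ])
    def test_roundtrip_extra_inputs(self, tokenizer, text):
        # Hedged extension of the cases above, assuming the tokenizer
        # roundtrips arbitrary UTF-8 losslessly (as byte-level BPE
        # vocabularies typically do); the tab and mixed-script inputs are
        # extra probes, not documented guarantees.
        assert tokenizer.decode(tokenizer.encode(text)) == text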