import pytest
from rustling.wordseg import RandomSegmenter
def test_valid_prob_zero():
    """Check that ``prob=0.0`` is accepted and "hello" comes back as one segment."""
    expected = [["hello"]]
    actual = RandomSegmenter(prob=0.0).predict(["hello"])
    assert actual == expected
def test_valid_prob_half():
    """Check that ``prob=0.5`` is accepted and the segments rebuild the input."""
    text = "hello"
    segmented = RandomSegmenter(prob=0.5).predict([text])
    # One input sentence must yield exactly one list of segments.
    assert len(segmented) == 1
    # The split points are not pinned down here; only the rejoined text is.
    assert "".join(segmented[0]) == text
def test_invalid_prob_negative():
    """Check that constructing with ``prob=-0.1`` raises ``ValueError``."""
    message_pattern = "prob must be from"
    with pytest.raises(ValueError, match=message_pattern):
        RandomSegmenter(prob=-0.1)
def test_invalid_prob_one():
    """Check that constructing with ``prob=1.0`` raises ``ValueError``."""
    message_pattern = "prob must be from"
    with pytest.raises(ValueError, match=message_pattern):
        RandomSegmenter(prob=1.0)
def test_invalid_prob_greater_than_one():
    """Check that constructing with ``prob=1.5`` raises ``ValueError``."""
    message_pattern = "prob must be from"
    with pytest.raises(ValueError, match=message_pattern):
        RandomSegmenter(prob=1.5)
def test_empty_input():
    """Check that an empty string produces an empty segment list."""
    empty_sentence = ""
    segmented = RandomSegmenter(prob=0.5).predict([empty_sentence])
    assert segmented == [[]]
def test_single_char():
    """Check that a one-character input comes back as a single segment."""
    expected = [["a"]]
    actual = RandomSegmenter(prob=0.5).predict(["a"])
    assert actual == expected
def test_unicode():
    """Check that a non-ASCII input is returned intact when ``prob=0.0``."""
    expected = [["你好"]]
    actual = RandomSegmenter(prob=0.0).predict(["你好"])
    assert actual == expected
def test_multiple_sentences():
    """Check that each input sentence gets its own segment list, in order."""
    sentences = ["hello", "world"]
    segmented = RandomSegmenter(prob=0.0).predict(sentences)
    assert segmented == [["hello"], ["world"]]
def test_segments_preserve_content():
    """Check that joining each sentence's segments reproduces that sentence.

    Runs ``predict`` with ``prob=0.5`` on several inputs and compares every
    rejoined result with the input it came from.
    """
    segmenter = RandomSegmenter(prob=0.5)
    inputs = ["hello", "world", "test123"]
    results = segmenter.predict(inputs)
    # zip() stops at the shorter iterable, so without this check a truncated
    # (or empty) result list would make the loop below pass vacuously.
    assert len(results) == len(inputs)
    for inp, segments in zip(inputs, results):
        assert "".join(segments) == inp
def test_predict_offsets():
    """Check that ``offsets=True`` pairs each word with a ``(start, end)`` tuple."""
    expected = [[("hello", (0, 5))]]
    actual = RandomSegmenter(prob=0.0).predict(["hello"], offsets=True)
    assert actual == expected
def test_predict_offsets_preserves_content():
    """Check that offsets from ``offsets=True`` map back onto the input.

    Each ``(start, end)`` pair must slice the input string to its word, and
    the words taken together must rebuild the whole input.
    """
    segmenter = RandomSegmenter(prob=0.5)
    input_str = "你好世界"
    result = segmenter.predict([input_str], offsets=True)
    assert len(result) == 1
    for word, (start, end) in result[0]:
        assert input_str[start:end] == word
    # The loop alone passes on an empty or partial segmentation, so also
    # require that no part of the input was dropped.
    assert "".join(word for word, _ in result[0]) == input_str
def test_predict_offsets_false_unchanged():
    """Check that ``offsets=False`` returns plain word lists without offsets."""
    expected = [["hello"]]
    actual = RandomSegmenter(prob=0.0).predict(["hello"], offsets=False)
    assert actual == expected
def test_predict_offsets_empty():
    """Check ``offsets=True`` on an empty batch and on an empty string."""
    segmenter = RandomSegmenter(prob=0.5)
    # No sentences in, no results out.
    no_sentences = segmenter.predict([], offsets=True)
    assert no_sentences == []
    # One empty sentence yields one empty list of (word, offsets) pairs.
    blank_sentence = segmenter.predict([""], offsets=True)
    assert blank_sentence == [[]]