import pytest
import io
import tempfile
import os
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from mrrc import MARCReader, MARCWriter
class TestWriteGILRelease:
    """Exercise MARCWriter sequentially and across threads.

    The concurrent tests check that writes performed in parallel produce
    byte-identical output to a sequential baseline.
    """

    def test_write_single_record(self, fixture_1k):
        """Writing a single record produces non-empty output."""
        reader = MARCReader(io.BytesIO(fixture_1k))
        records = []
        for rec in reader:
            records.append(rec)
            if len(records) >= 1:
                break
        sink = io.BytesIO()
        writer = MARCWriter(sink)
        for rec in records:
            writer.write_record(rec)
        writer.close()
        sink.seek(0)
        written = sink.read()
        assert len(written) > 0

    def test_write_multiple_records(self, fixture_1k):
        """Writing every record from the fixture produces non-empty output."""
        records = list(MARCReader(io.BytesIO(fixture_1k)))
        assert len(records) > 0
        sink = io.BytesIO()
        writer = MARCWriter(sink)
        for rec in records:
            writer.write_record(rec)
        writer.close()
        sink.seek(0)
        written = sink.read()
        assert len(written) > 0

    def test_sequential_write_2x_1k(self, fixture_1k):
        """Two full passes of the fixture can be written back-to-back."""
        def load():
            return list(MARCReader(io.BytesIO(fixture_1k)))

        batch_a = load()
        batch_b = load()
        combined = batch_a + batch_b
        sink = io.BytesIO()
        writer = MARCWriter(sink)
        for rec in combined:
            writer.write_record(rec)
        writer.close()
        sink.seek(0)
        sequential_data = sink.read()
        assert len(sequential_data) > 0
        assert len(combined) == len(batch_a) + len(batch_b)

    def test_concurrent_write_2x_1k_speedup(self, fixture_1k):
        """Two threads writing the same records each match the sequential output."""
        records = list(MARCReader(io.BytesIO(fixture_1k)))
        assert len(records) > 0

        def serialize(batch):
            # Serialize one batch to an in-memory buffer and return the bytes.
            sink = io.BytesIO()
            writer = MARCWriter(sink)
            for rec in batch:
                writer.write_record(rec)
            writer.close()
            sink.seek(0)
            return sink.read()

        sequential_data = serialize(records)
        with ThreadPoolExecutor(max_workers=2) as pool:
            pending = [pool.submit(serialize, records),
                       pool.submit(serialize, records)]
            concurrent_data = [fut.result() for fut in pending]
        assert len(concurrent_data) == 2
        assert all(len(chunk) > 0 for chunk in concurrent_data)
        assert concurrent_data[0] == sequential_data
        assert concurrent_data[1] == sequential_data

    def test_concurrent_write_4x_1k(self, fixture_1k):
        """Four threads writing the same records each match a baseline write."""
        records = list(MARCReader(io.BytesIO(fixture_1k)))

        def serialize(batch):
            sink = io.BytesIO()
            writer = MARCWriter(sink)
            for rec in batch:
                writer.write_record(rec)
            writer.close()
            sink.seek(0)
            return sink.read()

        baseline = serialize(records)
        with ThreadPoolExecutor(max_workers=4) as pool:
            results = list(pool.map(serialize, [records] * 4))
        assert len(results) == 4
        assert all(chunk == baseline for chunk in results)
class TestRoundTrip:
    """Verify records survive a write/read cycle unchanged."""

    def test_round_trip_basic(self, fixture_1k):
        """Records read back from written output compare equal to the originals."""
        originals = list(MARCReader(io.BytesIO(fixture_1k)))
        assert len(originals) > 0
        buffer = io.BytesIO()
        writer = MARCWriter(buffer)
        for rec in originals:
            writer.write_record(rec)
        writer.close()
        buffer.seek(0)
        reread = list(MARCReader(buffer))
        assert len(reread) == len(originals)
        for before, after in zip(originals, reread):
            assert before == after

    def test_round_trip_preserves_fields(self, fixture_1k):
        """Leader bytes, title, and author survive the round trip."""
        originals = list(MARCReader(io.BytesIO(fixture_1k)))
        buffer = io.BytesIO()
        writer = MARCWriter(buffer)
        for rec in originals:
            writer.write_record(rec)
        writer.close()
        buffer.seek(0)
        reread = list(MARCReader(buffer))
        for before, after in zip(originals, reread):
            assert before.leader().record_type == after.leader().record_type
            assert before.leader().bibliographic_level == after.leader().bibliographic_level
            before_title = before.title
            after_title = after.title
            if before_title:
                assert after_title == before_title
            before_author = before.author
            after_author = after.author
            if before_author:
                assert after_author == before_author

    def test_round_trip_with_modification(self, fixture_1k):
        """In-place leader edits are persisted through a write/read cycle."""
        originals = list(MARCReader(io.BytesIO(fixture_1k)))
        assert len(originals) > 0
        # Mutate the leader of (up to) the first three records.
        for rec in originals[:3]:
            ldr = rec.leader()
            ldr.record_status = 'c'
            ldr.encoding_level = 'I'
            ldr.cataloging_form = 'a'
        buffer = io.BytesIO()
        writer = MARCWriter(buffer)
        for rec in originals:
            writer.write_record(rec)
        writer.close()
        buffer.seek(0)
        reread = list(MARCReader(buffer))
        assert len(reread) == len(originals)
        # The modified records carry the new leader values...
        for before, after in zip(originals[:3], reread[:3]):
            before_leader = before.leader()
            after_leader = after.leader()
            assert after_leader.record_status == 'c'
            assert after_leader.encoding_level == 'I'
            assert after_leader.cataloging_form == 'a'
            assert after_leader.record_status == before_leader.record_status
            assert after_leader.encoding_level == before_leader.encoding_level
            assert after_leader.cataloging_form == before_leader.cataloging_form
        # ...while the untouched remainder is preserved verbatim.
        for before, after in zip(originals[3:], reread[3:]):
            assert before.leader() == after.leader()

    def test_round_trip_large_file(self, fixture_10k):
        """A larger fixture round-trips with the same record count and endpoints."""
        originals = list(MARCReader(io.BytesIO(fixture_10k)))
        expected_count = len(originals)
        buffer = io.BytesIO()
        writer = MARCWriter(buffer)
        for rec in originals:
            writer.write_record(rec)
        writer.close()
        buffer.seek(0)
        reread = list(MARCReader(buffer))
        assert len(reread) == expected_count
        assert originals[0] == reread[0]
        assert originals[-1] == reread[-1]
class TestWriteEdgeCases:
    """Edge cases around the writer lifecycle: empty output, context-manager
    use, and close() semantics."""

    def test_write_empty_file(self):
        """Closing a writer that wrote no records must not fail.

        The output may legitimately be empty (no records means any bytes
        present would be writer framing only), so the assertion is just
        that reading back succeeds.
        """
        output = io.BytesIO()
        writer = MARCWriter(output)
        writer.close()
        output.seek(0)
        data = output.read()
        assert len(data) >= 0

    def test_write_context_manager(self, fixture_1k):
        """MARCWriter works as a context manager; exit closes the writer."""
        reader = MARCReader(io.BytesIO(fixture_1k))
        records = list(reader)
        output = io.BytesIO()
        with MARCWriter(output) as writer:
            for record in records:
                writer.write_record(record)
        output.seek(0)
        data = output.read()
        assert len(data) > 0

    def test_write_after_close_raises_error(self, fixture_1k):
        """Writing to a closed writer raises RuntimeError."""
        reader = MARCReader(io.BytesIO(fixture_1k))
        record = next(reader)
        output = io.BytesIO()
        writer = MARCWriter(output)
        writer.close()
        with pytest.raises(RuntimeError):
            writer.write_record(record)

    def test_write_close_idempotent(self, fixture_1k):
        """close() may be called repeatedly without raising.

        Bug fix: the original had two ``writer.close()`` calls fused onto a
        single line (a SyntaxError); they are now separate statements so the
        test actually exercises repeated closes.
        """
        output = io.BytesIO()
        writer = MARCWriter(output)
        writer.close()
        writer.close()
        writer.close()
class TestRustFileBackend:
    """Round-trip tests that write through the path-based (Rust file) backend.

    Fixes over the original: the writer is now closed in a ``finally`` block
    so a failing ``write_record`` cannot leak the file handle (which would
    make the temp-file ``unlink`` fail on Windows), and the concurrent test
    re-raises with a bare ``raise`` to preserve the original traceback.
    """

    def test_write_roundtrip_rust_file(self, fixture_1k):
        """Records written to a filesystem path (str) read back identically."""
        reader = MARCReader(io.BytesIO(fixture_1k))
        records_original = list(reader)
        assert len(records_original) > 0
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mrc') as tmp:
            temp_path = tmp.name
        try:
            writer = MARCWriter(temp_path)
            try:
                for record in records_original:
                    writer.write_record(record)
            finally:
                # Always release the handle, even if a write fails.
                writer.close()
            with open(temp_path, 'rb') as f:
                records_roundtrip = list(MARCReader(f))
            assert len(records_roundtrip) == len(records_original)
            for orig, roundtrip in zip(records_original, records_roundtrip):
                assert orig == roundtrip
        finally:
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    def test_write_roundtrip_pathlib_path(self, fixture_1k):
        """Records written to a pathlib.Path read back identically."""
        reader = MARCReader(io.BytesIO(fixture_1k))
        records_original = list(reader)
        assert len(records_original) > 0
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mrc') as tmp:
            temp_path = Path(tmp.name)
        try:
            writer = MARCWriter(temp_path)
            try:
                for record in records_original:
                    writer.write_record(record)
            finally:
                writer.close()
            with open(temp_path, 'rb') as f:
                records_roundtrip = list(MARCReader(f))
            assert len(records_roundtrip) == len(records_original)
            for orig, roundtrip in zip(records_original, records_roundtrip):
                assert orig == roundtrip
        finally:
            if temp_path.exists():
                temp_path.unlink()

    def test_write_multiple_records_rust_file(self, fixture_1k):
        """All records survive a write to a file path (count check only)."""
        reader = MARCReader(io.BytesIO(fixture_1k))
        records = list(reader)
        assert len(records) > 0
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mrc') as tmp:
            temp_path = tmp.name
        try:
            writer = MARCWriter(temp_path)
            try:
                for record in records:
                    writer.write_record(record)
            finally:
                writer.close()
            with open(temp_path, 'rb') as f:
                roundtrip_records = list(MARCReader(f))
            assert len(roundtrip_records) == len(records)
        finally:
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    def test_concurrent_writes_different_files(self, fixture_1k):
        """Two threads writing to distinct files both succeed."""
        reader = MARCReader(io.BytesIO(fixture_1k))
        records = list(reader)
        assert len(records) > 0

        def write_to_file(file_index):
            # Each worker writes the shared records to its own temp file,
            # then reads them back. On success the caller cleans the file up.
            with tempfile.NamedTemporaryFile(delete=False, suffix=f'_{file_index}.mrc') as tmp:
                temp_path = tmp.name
            try:
                writer = MARCWriter(temp_path)
                try:
                    for record in records:
                        writer.write_record(record)
                finally:
                    writer.close()
                with open(temp_path, 'rb') as f:
                    roundtrip = list(MARCReader(f))
                return len(roundtrip) == len(records), temp_path
            except Exception:
                if os.path.exists(temp_path):
                    os.unlink(temp_path)
                raise  # bare raise preserves the original traceback

        with ThreadPoolExecutor(max_workers=2) as executor:
            results = list(executor.map(write_to_file, range(2)))
        for success, temp_path in results:
            assert success, f"Write to {temp_path} failed"
            if os.path.exists(temp_path):
                os.unlink(temp_path)