import os
import re
import sys
from pathlib import Path

from utils import load_spec
def find_nullable_fields(spec):
    """Collect the schema properties marked ``nullable: true`` in an OpenAPI spec.

    Walks ``spec['components']['schemas']``. For each schema that is a mapping,
    it records every property whose definition sets ``nullable`` to the boolean
    ``True``. Sections that are missing or explicitly ``null`` (e.g. a YAML
    ``properties:`` key with no value) are treated as empty instead of raising.

    Args:
        spec: Parsed spec as nested dicts (``main`` passes the result of
            ``load_spec``).

    Returns:
        ``{schema_name: [property_name, ...]}`` in the spec's iteration order.
        Schemas without nullable properties are omitted.
    """
    nullable_fields = {}
    # `or {}` also covers keys that are present but null, which `.get(k, {})` does not.
    schemas = (spec.get('components') or {}).get('schemas') or {}
    for schema_name, schema_def in schemas.items():
        if not isinstance(schema_def, dict):
            continue
        properties = schema_def.get('properties') or {}
        if not isinstance(properties, dict):
            continue
        for prop_name, prop_def in properties.items():
            if isinstance(prop_def, dict) and prop_def.get('nullable') is True:
                nullable_fields.setdefault(schema_name, []).append(prop_name)
    return nullable_fields
def fix_nullable_field(file_path, field_name):
    """Compute the nullable rewrite of one field in a generated Rust model file.

    Reads *file_path* but does NOT write it back. A
    ``pub <field_name>: Box<...>`` declaration becomes
    ``pub <field_name>: Option<Box<...>>``. The ``#[serde(...)]`` attribute
    directly above that declaration then gains
    ``skip_serializing_if = "Option::is_none"`` unless it already has a
    ``skip_serializing_if``.

    Args:
        file_path: Path of the ``.rs`` file to read.
        field_name: Rust (snake_case) field identifier, matched literally.

    Returns:
        ``(changed, content)``: whether a declaration was rewritten, and the
        resulting file text.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    return _make_field_nullable(content, field_name)


def _make_field_nullable(content, field_name):
    """Apply the ``fix_nullable_field`` rewrite to in-memory *content*; return ``(changed, text)``."""
    name = re.escape(field_name)
    # `[^>]+>` stops at the first `>`. For nested generics (Box<Vec<T>>) the
    # closing `>` of Option<...> lands among identical `>` characters, so the
    # resulting text is still balanced.
    content, count = re.subn(
        rf'(\bpub {name}:\s+)(Box<[^>]+>)', r'\1Option<\2>', content
    )
    if not count:
        return False, content

    def add_skip(match):
        # Check this field's own attribute only, and never add a duplicate key.
        if 'skip_serializing_if' in match.group('args'):
            return match.group(0)
        return (
            match.group('head')
            + match.group('args')
            + ', skip_serializing_if = "Option::is_none"'
            + match.group('tail')
        )

    # The attribute must sit right before the declaration (only whitespace in
    # between); `[^\n]` keeps the attribute match on a single line.
    content = re.sub(
        rf'(?P<head>#\[serde\()(?P<args>[^\n]*?)'
        rf'(?P<tail>\)\]\s*\n\s*pub {name}:\s+Option<)',
        add_skip,
        content,
    )
    return True, content


def _snake_case(name):
    """Convert CamelCase/camelCase to snake_case, keeping acronym runs whole (``HTTPError`` -> ``http_error``)."""
    name = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', name)
    name = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', name)
    return name.replace('-', '_').lower()


def fix_struct_nullable_fields(models_dir, struct_name, nullable_fields):
    """Make the spec-nullable fields of one generated model struct optional.

    Finds the struct's ``.rs`` file in *models_dir*. Every listed field
    declared as ``Box<...>`` becomes ``Option<Box<...>>``, and its serde
    attribute gains ``skip_serializing_if``. Only if something changed, it
    then:

    * adds ``Default`` to ``#[derive(...)]`` lists that contain ``PartialEq``,
      provided the file does not mention ``Default`` anywhere, and
    * wraps ``<field>: Box::new(...)`` initialisers of the rewritten fields
      in ``Some(...)``, before writing the file back.

    Args:
        models_dir: ``pathlib.Path`` of the generated models directory.
        struct_name: Schema name from the spec (CamelCase).
        nullable_fields: Spec property names (camelCase or snake_case).

    Returns:
        True if the file was rewritten; False if it was missing or unchanged.
    """
    # Historical file-name conversion first, then an acronym-aware one.
    file_names = list(dict.fromkeys([
        re.sub(r'(?<!^)(?=[A-Z])', '_', struct_name).lower() + '.rs',
        _snake_case(struct_name) + '.rs',
    ]))
    file_path = next(
        (models_dir / n for n in file_names if (models_dir / n).exists()), None
    )
    if file_path is None:
        print(f"Warning: File {models_dir / file_names[0]} not found")
        return False

    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    # Apply every field's rewrite to the same in-memory text so earlier
    # fields' edits are kept.
    changed_fields = []
    for field_name in nullable_fields:
        for snake_field in dict.fromkeys([
            re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', field_name).lower(),
            _snake_case(field_name),
        ]):
            changed, content = _make_field_nullable(content, snake_field)
            if changed:
                changed_fields.append(snake_field)
                print(f" Fixed nullable field: {snake_field}")

    if not changed_fields:
        return False

    # Any textual `Default` (a derive or a manual `impl Default`) leaves the
    # derives alone: a second Default impl would not compile.
    if 'Default' not in content and '#[derive(' in content:
        content, added = re.subn(
            r'(#\[derive\([^)]*)(PartialEq)', r'\1Default, \2', content
        )
        if added:
            print(" Re-added Default derive")

    # Initialisers of the rewritten fields now need an Option. `\b` keeps
    # `owner` from matching inside `prev_owner`.
    for snake_field in changed_fields:
        content = re.sub(
            rf'(\b{re.escape(snake_field)}:\s+)(Box::new\([^)]+\))',
            r'\1Some(\2)',
            content,
        )

    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(content)
    return True
def main():
    """Command-line entry point.

    Usage: ``fix_nullable_fields.py <root_dir> [spec_path]``

    Loads the spec (default ``<root_dir>/stainless.yaml``) with ``load_spec``
    and finds the properties it marks nullable. It then patches the matching
    model files under ``<root_dir>/src/models`` and prints progress plus the
    number of structs changed.

    Exits with status 1 when no root directory is given, or when the models
    directory or the spec file does not exist.
    """
    if len(sys.argv) < 2:
        print("Usage: fix_nullable_fields.py <root_dir> [spec_path]")
        sys.exit(1)
    root_dir = Path(sys.argv[1])
    spec_path = Path(sys.argv[2]) if len(sys.argv) > 2 else root_dir / 'stainless.yaml'
    models_dir = root_dir / 'src' / 'models'
    if not models_dir.exists():
        print(f"Models directory not found: {models_dir}")
        sys.exit(1)
    if not spec_path.exists():
        print(f"OpenAPI spec not found: {spec_path}")
        sys.exit(1)

    print(f"Loading OpenAPI spec from: {spec_path}")
    spec = load_spec(str(spec_path))
    print("Finding nullable fields in spec...")
    nullable_fields = find_nullable_fields(spec)
    print(f"Found {len(nullable_fields)} schemas with nullable fields")

    fixed_count = 0
    for struct_name, fields in nullable_fields.items():
        print(f"Checking {struct_name} with {len(fields)} nullable fields...")
        if fix_struct_nullable_fields(models_dir, struct_name, fields):
            fixed_count += 1
    print(f"Fixed {fixed_count} structs with nullable fields")


if __name__ == '__main__':
    main()