import os
import re
from pathlib import Path
def fix_recursive_grammar_format(models_dir):
    """Box the ``grammar`` field of the GrammarFormat models.

    For each of ``grammar_format.rs`` and ``grammar_format_1.rs`` that exists
    in *models_dir*, this function:

    * rewrites ``pub grammar: models::GrammarFormat,`` and
      ``pub grammar: models::GrammarFormat1,`` to
      ``pub grammar: Box<models::GrammarFormat>,``;
    * retypes the ``new`` constructor's ``grammar`` parameter from
      ``models::GrammarFormat1`` to ``models::GrammarFormat`` (the
      parameter stays unboxed);
    * once the field is boxed, wraps the constructor's field initializer
      (shorthand ``r#type, grammar,`` or explicit ``grammar: grammar,``)
      in ``Box::new(...)``.

    Files are written back only when their content changed. Running the
    function again on already-fixed files changes nothing.
    """
    for filename in ["grammar_format.rs", "grammar_format_1.rs"]:
        grammar_file = models_dir / filename
        if not grammar_file.exists():
            continue
        content = grammar_file.read_text(encoding="utf-8")
        original_content = content

        # Field declarations: box the field and point both variants at GrammarFormat.
        content = re.sub(
            r"pub grammar: models::GrammarFormat,",
            r"pub grammar: Box<models::GrammarFormat>,",
            content,
        )
        content = re.sub(
            r"pub grammar: models::GrammarFormat1,",
            r"pub grammar: Box<models::GrammarFormat>,",
            content,
        )

        # Constructor signature: accept an unboxed models::GrammarFormat.
        content = re.sub(
            r"pub fn new\(r#type: Type, grammar: models::GrammarFormat1\)",
            r"pub fn new(r#type: Type, grammar: models::GrammarFormat)",
            content,
        )
        content = re.sub(
            r"pub fn new\(r#type: Type, grammar: models::GrammarFormat\) -> (GrammarFormat1?)\s*\{",
            r"pub fn new(r#type: Type, grammar: models::GrammarFormat) -> \1 {",
            content,
        )

        if "pub grammar: Box<models::GrammarFormat>" in content:
            # Shorthand initializer `r#type, grammar,`: box the value while
            # keeping the original whitespace between the two fields.
            # (Capturing the leading whitespace with `\s+` and re-emitting it
            # after a literal newline inserted a stray blank line.)
            content = re.sub(
                r"(\s)r#type,(\s+)grammar,",
                r"\1r#type,\2grammar: Box::new(grammar),",
                content,
            )
            # Explicit initializer `grammar: grammar,`.
            content = re.sub(
                r"(\s+grammar:) grammar,", r"\1 Box::new(grammar),", content
            )

        if content != original_content:
            grammar_file.write_text(content, encoding="utf-8")
            print(f"Fixed recursive type in {filename}")


def fix_invalid_enum_variants(models_dir):
    """Rewrite identifiers in generated Rust models that are not valid Rust.

    For every ``*.rs`` file in *models_dir*:

    * ``Gpt4.1``, ``Gpt4.5`` and ``Gpt3.5`` become ``Gpt4_1``, ``Gpt4_5``
      and ``Gpt3_5``;
    * hyphenated model paths such as ``models::Foo-2`` (including inside
      ``Box<...>``) become ``models::Foo2``;
    * a line holding only a digit-leading name and a comma (for example
      ``    1024x1024,``) is replaced by ``#[serde(rename = "1024x1024")]``
      followed by ``Variant1024x1024,`` at the same indentation, and every
      ``::1024x1024`` reference in the same file is renamed to match.

    Files are written back only when their content changed. A second run
    over already-fixed files changes nothing.
    """
    # Indent and trailing space are limited to spaces/tabs so a match never
    # spans blank lines (the indent is re-emitted twice in the replacement).
    numeric_variant_pattern = re.compile(r"(?m)^([ \t]*)(\d\w*)[ \t]*,[ \t]*$")
    for file_path in models_dir.glob("*.rs"):
        content = file_path.read_text(encoding="utf-8")
        original_content = content

        content = re.sub(r"\bGpt4\.1\b", "Gpt4_1", content)
        content = re.sub(r"\bGpt4\.5\b", "Gpt4_5", content)
        content = re.sub(r"\bGpt3\.5\b", "Gpt3_5", content)
        # Also covers `Box<models::Foo-1>`, which contains `models::Foo-1`.
        content = re.sub(r"(models::)([A-Z]\w+)-(\d+)", r"\1\2\3", content)

        # Splice variants bottom-up so the offsets of the not-yet-processed
        # (earlier) matches stay valid. Reference renames may touch any part
        # of the file, so they are deferred until every splice is done;
        # applying them inside this loop shifted text ahead of earlier
        # matches and corrupted those splices.
        renamed = []
        for match in reversed(list(numeric_variant_pattern.finditer(content))):
            indent, raw_name = match.groups()
            sanitized = f"Variant{re.sub(r'[^A-Za-z0-9_]', '_', raw_name)}"
            indent_re = re.escape(indent)
            existing_pattern = (
                rf"^{indent_re}"
                rf'(?:#\[serde\(rename = "{re.escape(raw_name)}"\)\]\n{indent_re})?'
                rf"{re.escape(sanitized)}\s*,"
            )
            if re.search(existing_pattern, content, re.MULTILINE):
                continue
            replacement = (
                f'{indent}#[serde(rename = "{raw_name}")]\n{indent}{sanitized},'
            )
            content = content[: match.start()] + replacement + content[match.end() :]
            renamed.append((raw_name, sanitized))

        for raw_name, sanitized in renamed:
            # `::` also covers `Self::`; `\b` stops `::1x` matching `::1x2`.
            content = re.sub(rf"::{re.escape(raw_name)}\b", f"::{sanitized}", content)

        if content != original_content:
            file_path.write_text(content, encoding="utf-8")
            print(f"Fixed enum variants in {file_path.name}")


def add_display_impl_for_structs(models_dir):
    """Append a JSON-based ``std::fmt::Display`` impl to serializable structs.

    For every ``pub struct Name`` in a ``*.rs`` file of *models_dir*: if the
    ``#[derive(...)]`` attribute directly above it (only whitespace between)
    lists ``Serialize``, and the file has no
    ``impl std::fmt::Display for Name`` yet, an impl is appended to the end
    of the file. The impl writes ``serde_json::to_string(self)`` and maps a
    serialization error to ``std::fmt::Error``.

    Files are written back only when at least one impl was added.
    """
    struct_pattern = re.compile(r"pub struct (\w+)")
    for file_path in models_dir.glob("*.rs"):
        content = file_path.read_text(encoding="utf-8")
        added = []
        # dict.fromkeys de-duplicates names while keeping declaration order.
        for struct_name in dict.fromkeys(struct_pattern.findall(content)):
            derive_match = re.search(
                rf"#\[derive\(([^)]*?)\)\]\s*pub struct {struct_name}\b",
                content,
                re.DOTALL,
            )
            if derive_match is None or "Serialize" not in derive_match.group(1):
                continue
            # `\b` keeps an existing impl for `FooBar` from hiding a missing
            # impl for `Foo` (a plain substring test matched both).
            if re.search(rf"impl std::fmt::Display for {struct_name}\b", content):
                continue
            impl_body = f"""
impl std::fmt::Display for {struct_name} {{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{
match serde_json::to_string(self) {{
Ok(s) => write!(f, "{{}}", s),
Err(_) => Err(std::fmt::Error),
}}
}}
}}
"""
            added.append(impl_body)
        if added:
            file_path.write_text(content + "".join(added), encoding="utf-8")
            print(f"Added JSON Display impls in {file_path.name}")


def fix_manual_option_box_map(models_dir):
    """Collapse hand-written ``Option`` boxing into ``.map(Box::new)``.

    In every ``*.rs`` file of *models_dir*, each
    ``if let Some(x) = var { Some(Box::new(x)) } else { None }`` (any
    whitespace inside the braces, ``var`` a plain identifier) is rewritten
    to ``var.map(Box::new)``. Files with no occurrence are not touched.
    """
    manual_map = re.compile(
        r"if let Some\(x\) = (?P<var>[A-Za-z_][A-Za-z0-9_]*) \{\s*"
        r"Some\(Box::new\(x\)\)\s*\}\s*else \{\s*None\s*\}",
        re.DOTALL,
    )
    for rs_file in models_dir.glob("*.rs"):
        rewritten, hits = manual_map.subn(r"\g<var>.map(Box::new)", rs_file.read_text())
        if not hits:
            continue
        rs_file.write_text(rewritten)
        print(f"Simplified Option::map pattern in {rs_file.name} ({hits} occurrences)")


def remove_default_from_empty_enums(models_dir):
    """Delete ``impl Default`` blocks that belong to variant-less enums.

    For every ``*.rs`` file in *models_dir* that declares one or more empty
    enums (``pub enum Name { }``, whitespace allowed inside the braces):

    * ``impl Default for Name { ... }`` is removed for each such ``Name``,
      together with the whitespace after it. The pattern spans exactly one
      nested brace pair (the ``fn default()`` body), so impls with deeper
      nesting are left alone.
    * a lone ``}`` that follows the empty enum on a later line, with only
      whitespace between, is dropped. NOTE(review): presumably a stray brace
      in the generator output; confirm.

    ``impl Default`` blocks for other types in the same file are kept; the
    previous version deleted every matching impl in the file. Files are
    written back only when their content changed.
    """
    for file_path in models_dir.glob("*.rs"):
        content = file_path.read_text(encoding="utf-8")
        empty_enums = re.findall(r"pub enum (\w+) \{\s*\}", content)
        if not empty_enums:
            continue
        new_content = content
        for enum_name in empty_enums:
            # An empty enum has no variant for `default()` to return.
            new_content = re.sub(
                r"impl Default for " + re.escape(enum_name) + r" \{[^}]*\}\s*\}\s*",
                "",
                new_content,
            )
        new_content = re.sub(
            r"(pub enum \w+ \{\s*\})\s*\n\s*\}", r"\1", new_content
        )
        if new_content != content:
            file_path.write_text(new_content, encoding="utf-8")
            print(f"Fixed empty enum in {file_path.name}")


def remove_default_from_problematic_structs(models_dir):
    """Remove ``Default`` from types that hold a non-``Default`` type.

    The starting set is every enum in *models_dir* with neither a ``Default``
    derive nor an ``impl Default``, plus every empty enum. Then, repeatedly:

    * in each file whose first ``pub struct`` derives ``Default`` and has a
      ``pub`` field holding a type from the set by value (``T``,
      ``models::T``, ``Box<T>``, ``Box<models::T>``), ``Default`` is removed
      from the file's ``#[derive(...)]`` lists and the struct joins the set.
      A struct that does not derive ``Default`` joins the set directly.
      ``Option<T>`` fields do not count, because ``Option<T>`` implements
      ``Default`` for every ``T``.
    * in each file whose first ``impl Default for Name`` is not yet in the
      set, the impl is deleted when some tuple variant wraps a type from the
      set, and ``Name`` joins the set.

    Passes repeat until nothing changes or 10 passes have run, because each
    removal can make further types non-``Default``.

    NOTE(review): type names share one set across all files, so same-named
    inline enums (e.g. ``Type``) in different files are conflated; confirm
    this is acceptable for the generator's output.
    """
    non_default_types = _collect_non_default_enums(models_dir)
    changes = True
    iterations = 0
    while changes and iterations < 10:
        iterations += 1
        structs_changed = _drop_struct_default_derives(models_dir, non_default_types)
        enums_changed = _drop_enum_default_impls(models_dir, non_default_types)
        changes = structs_changed or enums_changed


def _collect_non_default_enums(models_dir):
    """Return names of enums in *models_dir* that have no ``Default`` impl.

    An enum qualifies when it lacks both a ``#[derive(..Default..)]`` directly
    above it (other attributes may sit in between) and an
    ``impl Default for`` block. Every empty enum is included as well.
    """
    non_default = set()
    for file_path in models_dir.glob("*.rs"):
        content = file_path.read_text(encoding="utf-8")
        for name in re.findall(r"pub enum (\w+)", content):
            derives_default = re.search(
                r"#\[derive\([^)]*\bDefault\b[^)]*\)\]\s*(?:#\[[^\]]*\]\s*)*"
                rf"pub enum {name}\b",
                content,
            )
            impls_default = re.search(rf"impl Default for {name}\b", content)
            if not derives_default and not impls_default:
                non_default.add(name)
        # An empty enum has no variant for `default()` to return.
        non_default.update(re.findall(r"pub enum (\w+) \{\s*\}", content))
    return non_default


def _has_non_default_field(content, non_default_types):
    """Return True if a ``pub`` field holds a type from *non_default_types* by value.

    Accepts raw-identifier field names such as ``r#type``. ``Option<...>``
    fields are ignored on purpose (``Option<T>: Default`` for every ``T``).
    """
    for type_name in non_default_types:
        name = re.escape(type_name)
        if re.search(rf"pub (?:r#)?\w+: Box<(?:models::)?{name}>", content):
            return True
        if re.search(rf"pub (?:r#)?\w+: (?:models::)?{name}\b", content):
            return True
    return False


def _strip_default_from_derives(content):
    """Return *content* with ``Default`` removed from its ``#[derive(...)]`` lists.

    Only derive lists naming ``Default`` as a whole item are rewritten (items
    re-joined with ``", "``); a list left empty drops the attribute. Text
    outside those attributes, such as the trailing comma of a 1-tuple
    ``(String,)``, is never touched.
    """

    def _rewrite(match):
        items = [item.strip() for item in match.group(1).split(",")]
        if "Default" not in items:
            return match.group(0)
        kept = [item for item in items if item and item != "Default"]
        return f"#[derive({', '.join(kept)})]" if kept else ""

    return re.sub(r"#\[derive\(([^)]*)\)\]", _rewrite, content)


def _drop_struct_default_derives(models_dir, non_default_types):
    """Run one struct pass; return True if any file was rewritten.

    Adds every affected or non-``Default`` struct name to *non_default_types*.
    """
    changed = False
    for file_path in models_dir.glob("*.rs"):
        content = file_path.read_text(encoding="utf-8")
        struct_match = re.search(r"pub struct (\w+)", content)
        if not struct_match:
            continue
        struct_name = struct_match.group(1)
        if not re.search(r"#\[derive\([^)]*\bDefault\b[^)]*\)\]", content):
            non_default_types.add(struct_name)
            continue
        if not _has_non_default_field(content, non_default_types):
            continue
        new_content = _strip_default_from_derives(content)
        if new_content != content:
            file_path.write_text(new_content, encoding="utf-8")
            print(f"Removed Default derive from {file_path.name}")
            non_default_types.add(struct_name)
            changed = True
    return changed


def _drop_enum_default_impls(models_dir, non_default_types):
    """Run one ``impl Default`` pass; return True if any file was rewritten.

    Adds every type whose impl was removed to *non_default_types*.
    """
    changed = False
    for file_path in models_dir.glob("*.rs"):
        content = file_path.read_text(encoding="utf-8")
        impl_match = re.search(r"impl Default for (\w+)\s*\{", content)
        if not impl_match:
            continue
        impl_name = impl_match.group(1)
        if impl_name in non_default_types:
            continue
        wraps_non_default = any(
            re.search(rf"\w+\((?:Box<)?(?:models::)?{re.escape(t)}>?\)", content)
            for t in non_default_types
        )
        if not wraps_non_default:
            continue
        # Matches the impl with exactly one nested brace pair (the fn body).
        new_content = re.sub(
            r"impl Default for "
            + re.escape(impl_name)
            + r"\s*\{[^}]*\{[^}]*\}[^}]*\}",
            "",
            content,
        )
        if new_content != content:
            file_path.write_text(new_content, encoding="utf-8")
            print(f"Removed impl Default from enum {impl_name} in {file_path.name}")
            non_default_types.add(impl_name)
            changed = True
    return changed


def main():
    """Run every generated-code fix over the project's ``src/models`` directory.

    The project root is the parent of this script's directory. When
    ``src/models`` does not exist, a message is printed and nothing else
    happens. Otherwise the fixes run in a fixed order.
    """
    models_dir = Path(__file__).parent.parent / "src" / "models"
    if models_dir.exists():
        print("Fixing generated Rust code...")
        for apply_fix in (
            fix_invalid_enum_variants,
            fix_recursive_grammar_format,
            add_display_impl_for_structs,
            fix_manual_option_box_map,
            remove_default_from_empty_enums,
            remove_default_from_problematic_structs,
        ):
            apply_fix(models_dir)
        print("\nCode fixes applied successfully!")
    else:
        print("No models directory found, skipping fixes")


if __name__ == "__main__":
    main()