import os
import re
# Rust example files (bare file names) that main() looks up in its examples
# directory and rewrites, in this order.
CI_EXAMPLES = [
    'comprehensive_demo.rs',
    'output_access_demo.rs',
    'parallel_execution_demo.rs',
    'per_node_output_access.rs',
    'tuple_api_demo.rs',
    'variant_demo_full.rs',
]
def add_imports(content):
    """Ensure a Rust source imports ``Arc`` (and ``NodeFunction`` when needed).

    ``use std::sync::Arc;`` is inserted after the last top-level (column-0)
    ``use`` statement. A grouped import spanning several lines is followed
    to its terminating ``;`` so the new line never lands inside a brace
    list. If the file has no top-level ``use``, the import goes after any
    leading inner doc comments (``//!``) / inner attributes (``#![...]``),
    because an item may not precede them.

    Nothing is inserted when ``Arc`` is already imported from ``std::sync``,
    either directly or inside a grouped ``use std::sync::{...}``. Importing
    the same name twice is a Rust compile error.

    If the file calls ``variants(`` and never mentions ``NodeFunction``,
    ``NodeFunction`` is prepended to a braced ``use dagex::{`` import. A file
    without that braced form is left unchanged by this step.

    Args:
        content: Full text of a Rust source file.

    Returns:
        The possibly modified source text.
    """
    has_arc = re.search(
        r'^\s*use\s+std::sync::(?:Arc\b|\{[^}]*\bArc\b)', content, re.MULTILINE
    )
    if not has_arc:
        lines = content.split('\n')
        # Default position: after leading `//!` / `#![...]` lines, which must
        # stay ahead of every item.
        insert_pos = 0
        while insert_pos < len(lines) and lines[insert_pos].lstrip().startswith(('//!', '#![')):
            insert_pos += 1
        i = 0
        while i < len(lines):
            # Column-0 only: `use` inside a fn body would scope Arc to that fn.
            if lines[i].startswith(('use ', 'pub use ')):
                # Follow a multi-line grouped import to the line ending it.
                while i < len(lines) and ';' not in lines[i]:
                    i += 1
                insert_pos = i + 1
            i += 1
        lines.insert(insert_pos, 'use std::sync::Arc;')
        content = '\n'.join(lines)
    if 'variants(' in content and 'NodeFunction' not in content:
        content = content.replace('use dagex::{', 'use dagex::{NodeFunction, ')
    return content
def fix_simple_add_calls(content):
    """Wrap a bare identifier passed as the first ``.add(`` argument in ``Arc::new``.

    ``.add(my_fn, ...)`` becomes ``.add(Arc::new(my_fn), ...)``. Only a
    lowercase identifier immediately followed by a comma is rewritten.
    Closures (``|...|``) and already-wrapped ``Arc::new(...)`` arguments do
    not match. Whitespace, including newlines, between ``(`` and the name is
    preserved.
    """
    bare_fn_arg = re.compile(r'(\.(add)\(\s*)([a-z_][a-z0-9_]*)(,)')
    return bare_fn_arg.sub(r'\1Arc::new(\3)\4', content)
def _closure_end(text, start):
    """Return the offset just past the last token of the closure at *start*.

    *start* is the offset of the closure's opening ``|`` or leading ``move``.
    After the parameter list, the body is scanned with ``()``/``[]``/``{}``
    nesting tracked. The closure ends at the first depth-0 ``,`` or unmatched
    closing bracket, or at end of text. The returned offset is just after the
    last non-whitespace, non-comment character before that point, so the
    caller's ``)`` hugs the body.

    String literals, char literals and comments are skipped so brackets and
    commas inside them do not count. Raw strings (``r#"..."#``) are not
    specially handled.

    Returns ``None`` if the parameter list or a return-type body cannot be
    found.
    """
    char_literal = re.compile(r"'(?:[^'\\\n]|\\[^\n]{1,9}?)'")
    open_bar = text.find('|', start)
    close_bar = text.find('|', open_bar + 1)
    if open_bar == -1 or close_bar == -1:
        return None
    i = close_bar + 1
    # An explicit return type may contain depth-0 commas (`Result<A, B>`) but
    # never braces, so jump straight to the block body.
    if re.compile(r'\s*->').match(text, i):
        i = text.find('{', i)
        if i == -1:
            return None
    last = i
    depth = 0
    n = len(text)
    while i < n:
        ch = text[i]
        if ch == '"':
            j = i + 1
            while j < n and text[j] != '"':
                j += 2 if text[j] == '\\' else 1
            i = last = min(j + 1, n)
            continue
        if ch == "'":
            lit = char_literal.match(text, i)
            if lit:  # a lifetime such as 'a does not match and falls through
                i = last = lit.end()
                continue
        if text.startswith('//', i):
            newline = text.find('\n', i)
            i = n if newline == -1 else newline
            continue
        if text.startswith('/*', i):
            close = text.find('*/', i + 2)
            i = n if close == -1 else close + 2
            continue
        if ch in '([{':
            depth += 1
        elif ch in ')]}':
            if depth == 0:
                break  # closes the enclosing `.add(` call
            depth -= 1
        elif ch == ',' and depth == 0:
            break  # separator before the next `.add(` argument
        if not ch.isspace():
            last = i + 1
        i += 1
    return last


def fix_closure_wrapping(content):
    """Wrap a closure passed as the first argument of a multi-line ``.add(`` in ``Arc::new``.

    Targets the layout where ``.add(`` ends a line and the next line begins
    (after indentation) with a closure, ``|...|`` or ``move |...|``::

        graph.add(
            |inputs, params| {
                ...
            },
            Some("label"),
        )

    ``Arc::new(`` is inserted before the closure. A ``)`` is inserted right
    after the closure's last token: after the body's closing ``}``, or, for
    a brace-less closure such as ``|x| x * 2,``, before the separating
    comma. Closures passed to further multi-line ``.add(`` calls inside a
    wrapped closure's body are wrapped too. Text not matching this layout,
    including already-wrapped ``Arc::new(`` lines, is left unchanged, so the
    function is idempotent.

    Args:
        content: Full text of a Rust source file.

    Returns:
        The rewritten source text.
    """
    # `[^\S\n]` is whitespace other than newline, so `\r\n` line ends work too.
    add_then_closure = re.compile(r'\.add\([^\S\n]*\n[^\S\n]*(?=(?:move\s+)?\|)')
    pieces = []
    pos = 0
    for match in add_then_closure.finditer(content):
        start = match.end()
        if start < pos:
            # Inside a closure already wrapped; the recursive call handled it.
            continue
        end = _closure_end(content, start)
        if end is None:
            continue
        pieces.append(content[pos:start])
        pieces.append('Arc::new(' + fix_closure_wrapping(content[start:end]) + ')')
        pos = end
    pieces.append(content[pos:])
    return ''.join(pieces)
def convert_variants_patterns(content):
    """Box closures collected into a ``variants(`` argument as ``NodeFunction``.

    When the text mentions ``variants(`` and contains ``}).collect()``, every
    ``}).collect()`` is rewritten to
    ``}).map(|f| Arc::new(f) as dagex::NodeFunction).collect()``.

    NOTE(review): the replacement is file-wide, not limited to the argument
    of ``variants(``. Any unrelated ``.map(|..| { .. }).collect()`` chain in
    the same file is rewritten too.
    """
    needle = '}).collect()'
    if 'variants(' not in content or needle not in content:
        return content
    return content.replace(
        needle, '}).map(|f| Arc::new(f) as dagex::NodeFunction).collect()'
    )
def fix_file(filepath):
    """Apply every rewrite pass in this script to the Rust file *filepath* in place.

    The passes run in this order:
      1. ``add_imports``
      2. ``fix_simple_add_calls``
      3. ``fix_closure_wrapping``
      4. ``convert_variants_patterns``

    The file is written back only if the passes changed its text.

    Args:
        filepath: Path to a ``.rs`` source file.

    Raises:
        OSError: If the file cannot be read or written.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    print(f"Fixing {filepath}")
    # Rust source files are UTF-8; don't depend on the platform's locale encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        original = f.read()
    content = add_imports(original)
    content = fix_simple_add_calls(content)
    content = fix_closure_wrapping(content)
    content = convert_variants_patterns(content)
    # Skip the write when nothing changed, leaving the file untouched.
    if content != original:
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(content)
def main(base_path='/workspaces/graph-sp/examples/rs'):
    """Run ``fix_file`` on every example in ``CI_EXAMPLES`` found under *base_path*.

    Missing examples produce a warning line instead of an error.

    Args:
        base_path: Directory containing the example ``.rs`` files. The
            default is the previously hard-coded location, so existing
            invocations behave the same.
    """
    for example_name in CI_EXAMPLES:
        filepath = os.path.join(base_path, example_name)
        # isfile rather than exists: a directory of that name would crash open().
        if os.path.isfile(filepath):
            fix_file(filepath)
        else:
            print(f"Warning: {filepath} not found")
    print("Done fixing CI examples!")


if __name__ == '__main__':
    import sys

    # An optional first command-line argument overrides the examples directory.
    main(*sys.argv[1:2])