import os
import subprocess
import numpy as np
def serialize_array(arr):
template = 'Array::from(&{})'
if len(arr.shape) == 1:
return template.format(str([[x] for x in arr]).replace('[', 'vec!['))
else:
return template.format(str(arr).replace('[', 'vec!['))
class Module(object):
TEMPLATE = """
#[cfg(test)]
#[allow(unused_imports)]
{flags}
mod generated_tests {{
{imports}
{tests}
}}
"""
def __init__(self, imports=None, flags=None):
self.imports = (imports or []) + ['prelude::*',
'super::*']
self.flags = flags or []
self.tests = []
def add_test(self, test):
self.tests.append(test)
def render_flags(self):
return '\n'.join(self.flags)
def render_imports(self):
return '\n'.join(['use ' + x + ';' for x in self.imports])
def render(self):
return self.TEMPLATE.format(flags=self.render_flags(),
imports=self.render_imports(),
tests='\n'.join([x.render() for x
in self.tests]))
def write(self, fname):
with open(fname, 'wb') as datafile:
datafile.write(self.render())
subprocess.check_call(['rustfmt',
fname,
'--write-mode=overwrite'])
class Test(object):
TEMPLATE = """
#[test]
fn {name}() {{
// Body goes here
}}
"""
SERIALIZERS = {np.ndarray: serialize_array}
def __init__(self, name, args):
self.name = name
self.args = args
def _render_args(self):
rendered = {}
for key, value in self.args.items():
for tpe, fnc in self.SERIALIZERS.items():
if isinstance(value, tpe):
rendered[key] = fnc(value)
else:
rendered[key] = str(value)
return rendered
def render(self):
args = self._render_args()
args['name'] = self.name
return self.TEMPLATE.format(**args)