rustlearn 0.4.3

A machine learning package for Rust.
Documentation
import os

import subprocess

import numpy as np


def serialize_array(arr):
    """Render a NumPy array as a Rust ``Array::from(&vec![...])`` expression.

    A 1-D array is emitted as a column: one single-element row per entry.
    Any other array is emitted row by row as nested ``vec!`` literals.

    Args:
        arr: NumPy array to serialize.

    Returns:
        str: Rust source text for the array.
    """
    template = 'Array::from(&{})'

    # Serialize via tolist(), which returns nested lists of plain Python
    # scalars. Their str() is comma-separated and never abbreviated with
    # '...'. By contrast, str(arr) on a 2-D array is space-separated, which
    # is invalid inside vec![...]. On NumPy >= 2, NumPy scalars inside a
    # list also print as np.float64(...).
    values = arr.tolist()

    if arr.ndim == 1:
        # Turn a flat vector into a column: each entry becomes its own row.
        values = [[x] for x in values]

    return template.format(str(values).replace('[', 'vec!['))


class Module(object):
    """A generated Rust test module built from imports, flag lines and tests.

    ``render`` returns the module's Rust source text. ``write`` saves that
    text to a file and runs ``rustfmt`` on it.
    """

    # Skeleton of the generated module. Doubled braces are literal braces
    # for str.format; {flags}, {imports} and {tests} are filled by render().
    TEMPLATE = """

#[cfg(test)]
#[allow(unused_imports)]
{flags}
mod generated_tests {{
    {imports}

    {tests}

}}

"""

    def __init__(self, imports=None, flags=None):
        """Create an empty module.

        Args:
            imports: Extra Rust paths to ``use``, written without the
                ``use`` keyword or trailing ``;``. ``prelude::*`` and
                ``super::*`` are always appended after them.
            flags: Lines inserted verbatim just above the
                ``mod generated_tests`` declaration (presumably Rust
                attributes).
        """
        # Build a new list so the caller's ``imports`` list is not mutated.
        self.imports = (imports or []) + ['prelude::*',
                                          'super::*']
        self.flags = flags or []

        self.tests = []

    def add_test(self, test):
        """Append ``test``, an object whose ``render()`` returns Rust source."""
        self.tests.append(test)

    def render_flags(self):
        """Return the flag lines joined with newlines."""
        return '\n'.join(self.flags)

    def render_imports(self):
        """Return one ``use <path>;`` line per import, joined with newlines."""
        return '\n'.join(['use ' + x + ';' for x in self.imports])

    def render(self):
        """Return the full Rust source for this module."""
        return self.TEMPLATE.format(flags=self.render_flags(),
                                    imports=self.render_imports(),
                                    tests='\n'.join([x.render() for x
                                                     in self.tests]))

    def write(self, fname):
        """Write the rendered module to ``fname``, then format it with rustfmt.

        Raises:
            subprocess.CalledProcessError: If rustfmt exits with a non-zero
                status.
        """
        # render() returns str, so open the file in text mode. The previous
        # 'wb' mode raised TypeError on Python 3. Plain 'w' works on both
        # Python 2 and 3.
        with open(fname, 'w') as datafile:
            datafile.write(self.render())

        # NOTE(review): --write-mode is a legacy rustfmt option. Newer
        # rustfmt releases format in place by default and reject this flag.
        # Confirm against the rustfmt version in use.
        subprocess.check_call(['rustfmt',
                               fname,
                               '--write-mode=overwrite'])


class Test(object):
    """A single generated Rust ``#[test]`` function.

    Each value in ``args`` is turned into Rust source text using
    ``SERIALIZERS``, falling back to ``str``. The results and the test
    ``name`` are then substituted into ``TEMPLATE``.
    """

    # Rust source for the test function. Doubled braces are literal braces
    # for str.format. This base template only uses {name}.
    # NOTE(review): presumably subclasses override TEMPLATE to consume the
    # rendered args. Confirm against callers.
    TEMPLATE = """
               #[test]
               fn {name}() {{
                   // Body goes here
               }}

    """

    # Maps a Python type to a function that returns Rust source for a value
    # of that type. Values matching no entry are rendered with str().
    SERIALIZERS = {np.ndarray: serialize_array}

    def __init__(self, name, args):
        """
        Args:
            name: Name of the generated Rust test function.
            args: Mapping of template placeholder names to Python values.
        """
        self.name = name
        self.args = args

    def _render_args(self):
        """Return a new dict mapping each arg key to its Rust source text."""
        rendered = {}

        for key, value in self.args.items():
            for tpe, fnc in self.SERIALIZERS.items():
                if isinstance(value, tpe):
                    rendered[key] = fnc(value)
                    # Stop at the first matching serializer. Without this
                    # break, a later non-matching entry would overwrite the
                    # result with str(value).
                    break
            else:
                rendered[key] = str(value)

        return rendered

    def render(self):
        """Return ``TEMPLATE`` filled with the rendered args and ``name``.

        ``self.name`` overrides any arg that is also called ``'name'``.
        """
        args = self._render_args()
        args['name'] = self.name

        return self.TEMPLATE.format(**args)