cell-map 0.5.3

Many-layered 2D cellular generic map
Documentation
'''
Helper class for loading, transforming and plotting generic CellMaps.
'''

import numpy as np
import json
import typing
import matplotlib.pyplot as plt

class CellMap:
    '''
    A many-layered 2D cell map loaded from JSON, with helpers to transform
    map-frame points into the parent frame and to plot the map.
    '''

    # Layer name -> cell values, reshaped to `num_cells` (rows = y, cols = x)
    data: typing.Dict[str, np.ndarray]
    # Layer names, in the same order as the file's `data` entries
    layers: typing.List[str]
    # Size of a single cell, as read from the file's `cell_size`
    cell_size: np.ndarray
    # 2x2 array: row 0 is the x cell bounds, row 1 is the y cell bounds
    cell_bounds: np.ndarray
    # Number of cells as [num_y, num_x], matching the shape of each layer array
    num_cells: np.ndarray
    # (4, 2) array of the cell-bound corners transformed into the parent frame
    extents: np.ndarray
    # Read verbatim from the file's `cell_boundary_precision`
    cell_boundary_precision: np.ndarray
    # 3x3 matrix from the file's `from_parent_matrix`, reshaped row-major
    from_parent: np.ndarray
    # Inverse of `from_parent`, used by `transform_to_parent`
    to_parent: np.ndarray
    # Path the map was loaded from ('' when built from a dict without one)
    path: str = ''

    @staticmethod
    def load(path):
        '''
        Loads a CellMap from the given path, expects a JSON file.

        The path is recorded on the returned map as `path` (and so is used as
        the default plot title).
        '''

        # Read the data, closing the file before building the map
        with open(path, 'r', encoding='utf-8') as f:
            raw = json.load(f)

        return CellMap.from_raw_dict(raw, path=path)

    @staticmethod
    def from_raw_dict(raw, path=''):
        '''
        Loads a CellMap from a raw dictionary, i.e. from deserialised JSON.

        Arguments:
            raw: The deserialised JSON object. Keys read: 'layers',
                'cell_size', 'cell_bounds' (with 'x' and 'y'),
                'cell_boundary_precision', 'from_parent_matrix' (9 values)
                and 'data' (one entry per layer, each with 'dim' and a flat
                'data' list).
            path: The file the data came from; stored as `path` and used as
                the default plot title.

        Returns:
            The constructed CellMap.

        Raises:
            RuntimeError: If the number of data entries differs from the
                number of layers, or a layer's 'dim' does not match the
                number of cells implied by the cell bounds.
        '''

        cm = CellMap()

        # Load metadata
        cm.path = path
        cm.layers = raw['layers']
        cm.cell_size = np.array(raw['cell_size'])
        cm.cell_bounds = np.array([raw['cell_bounds']['x'], raw['cell_bounds']['y']])
        # Stored as [num_y, num_x] so layer arrays are indexed [row=y, col=x]
        cm.num_cells = np.array([
            cm.cell_bounds[1][1] - cm.cell_bounds[1][0],
            cm.cell_bounds[0][1] - cm.cell_bounds[0][0],
        ])
        cm.cell_boundary_precision = np.array(raw['cell_boundary_precision'])
        # NOTE(review): transform_to_parent multiplies row vectors on the left
        # of this matrix (`homog @ to_parent`), i.e. the row-major reshape is
        # treated as the transpose of a column-vector transform. Presumably the
        # file stores the matrix column-major - confirm against the writer.
        cm.from_parent = np.array(raw['from_parent_matrix']).reshape((3, 3))
        cm.to_parent = np.linalg.inv(cm.from_parent)

        # Calculate extents of map: the four map-frame corners of the cell
        # bounds, expressed in the parent frame
        (x_min, x_max), (y_min, y_max) = cm.cell_bounds
        corners = np.array([
            [x_min, y_min],
            [x_max, y_min],
            [x_min, y_max],
            [x_max, y_max],
        ])
        cm.extents = cm.transform_to_parent(corners)

        # zip() below would silently drop layers (or data) on a count mismatch
        raw_layers = raw['data']
        if len(raw_layers) != len(cm.layers):
            raise RuntimeError(
                f'Cell map file lists {len(cm.layers)} layers but contains '
                f'data for {len(raw_layers)}'
            )

        # Load each layer in turn, reshaping as needed
        cm.data = dict()
        for layer, layer_data in zip(cm.layers, raw_layers):
            dim = layer_data['dim']
            if dim[0] != cm.num_cells[0] or dim[1] != cm.num_cells[1]:
                raise RuntimeError(f'Data in cell map file is of wrong shape. Expected {cm.num_cells} but got {dim}')
            cm.data[layer] = np.array(layer_data['data']).reshape(cm.num_cells)

        return cm

    def transform_to_parent(self, points: np.ndarray) -> np.ndarray:
        '''
        Converts the given point(s) from the map frame to the parent frame.

        Arguments:
            points: An (N, 2) array (or nested sequence) of map-frame points.

        Returns:
            An (N, 2) float array of the corresponding parent-frame points.
        '''
        points = np.asarray(points, dtype=float)

        # Homogenise: append a 1 to every point
        homog = np.ones((points.shape[0], 3))
        homog[:, :-1] = points
        homog = homog @ self.to_parent

        # Dehomogenise every row at once; the `-1:` slice keeps a column axis
        # so the division broadcasts per row
        return homog[:, :-1] / homog[:, -1:]

    def plot(self, name=None, ax=None, parent_relative=True, show_grid=False, layer=None):
        '''
        Plots this CellMap.

        Arguments:
            name: The name to place in the title, or None to use the map's path
            ax: The axis to plot onto, or None if we should create a new figure
            parent_relative: True to plot relative to the parent frame, drawing
                the layer data as a mesh; False to draw only the map-frame cell
                grid lines, origin and axes
            show_grid: True to draw cell edges on the parent-relative mesh
            layer: Name of the layer to draw, or None for the first layer

        Returns:
            The QuadMesh drawn for the layer when parent_relative is True,
            otherwise None.

        If a new figure is created here it is shown with plt.show() before
        returning.
        '''

        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = None

        # Map-relative origin and axes directions
        origin = np.array([0.0, 0.0])
        x_dir = np.array([1.0, 0.0])
        y_dir = np.array([0.0, 1.0])

        # No data mesh is drawn in the map-relative view
        grid = None

        # Setup the axis grid
        if parent_relative:
            # Draw the layer mesh directly onto `ax` (not pyplot's current axes)
            grid = self._get_parent_rel_grid(show_grid=show_grid, ax=ax, layer=layer)

            # Move the origin and unit axes into the parent frame
            frame = self.transform_to_parent(np.array([origin, x_dir, y_dir]))
            origin = frame[0]
            x_dir = frame[1] - origin
            y_dir = frame[2] - origin

            # Pad the limits by half the larger cell dimension
            pad = np.max([self.cell_size[0] * 0.5, self.cell_size[1] * 0.5])

            # Limits enclose all map corners plus the origin
            xs = np.append(self.extents[:, 0], origin[0])
            ys = np.append(self.extents[:, 1], origin[1])
            x_lims = [np.min(xs) - pad, np.max(xs) + pad]
            y_lims = [np.min(ys) - pad, np.max(ys) + pad]

        else:
            # The map frame is indexed by cell bounds (not 0..num_cells); put a
            # tick on every cell edge, including the end line
            (x_min, x_max), (y_min, y_max) = self.cell_bounds
            ax.set_xticks(range(x_min, x_max + 1))
            ax.set_yticks(range(y_min, y_max + 1))
            ax.grid(True)

            # Limits enclose the bounds and the origin, with half a cell padding
            x_lims = [min(x_min, 0) - 0.5, max(x_max, 0) + 0.5]
            y_lims = [min(y_min, 0) - 0.5, max(y_max, 0) + 0.5]

        # Plot origin and directions
        ax.plot(origin[0], origin[1], '.k')
        ax.quiver(*origin, x_dir[0], x_dir[1], color='r', angles='xy', scale_units='xy', scale=1)
        ax.quiver(*origin, y_dir[0], y_dir[1], color='g', angles='xy', scale_units='xy', scale=1)

        # Set limits
        ax.set_xlim(x_lims)
        ax.set_ylim(y_lims)
        ax.set_aspect('equal', 'box')

        ax.set_title(self.path if name is None else name)

        if fig is not None:
            plt.show()

        return grid

    def _get_parent_rel_grid(self, show_grid=False, ax=None, layer=None):
        '''
        Draws a layer as a parent-relative mesh and returns it.

        Arguments:
            show_grid: True to draw the cell edges in grey
            ax: The axes to draw onto, or None for pyplot's current axes
            layer: Name of the layer to draw, or None for the first layer

        Returns:
            The matplotlib QuadMesh created by `pcolormesh`.
        '''
        if ax is None:
            ax = plt.gca()
        if layer is None:
            layer = self.layers[0]

        # Cell-edge coordinates in the map frame (one more edge than cells)
        x_edges = np.arange(self.cell_bounds[0][0], self.cell_bounds[0][1] + 1)
        y_edges = np.arange(self.cell_bounds[1][0], self.cell_bounds[1][1] + 1)
        mesh_x, mesh_y = np.meshgrid(x_edges, y_edges)
        mesh_shape = mesh_x.shape

        # Transform every mesh vertex into the parent frame, then restore shape
        mesh_points = self.transform_to_parent(
            np.column_stack([mesh_x.ravel(), mesh_y.ravel()])
        )
        mesh_x = mesh_points[:, 0].reshape(mesh_shape)
        mesh_y = mesh_points[:, 1].reshape(mesh_shape)

        # 'flat' shading needs X/Y one larger than the data in each dimension
        return ax.pcolormesh(
            mesh_x, mesh_y, self.data[layer],
            shading='flat',
            edgecolors='grey' if show_grid else None,
            linewidth=0.1,
            zorder=-1.0
        )