// gemla 0.1.32
//
// Using evolutionary computation to generate machine learning algorithms.
// Documentation
use async_trait::async_trait;
use gemla::{
    core::genetic_node::{GeneticNode, GeneticNodeContext},
    error::Error,
};
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use uuid::Uuid;

/// Number of individuals a node holds after `initialize`, and the size `mutate`
/// refills a culled population back up to.
const POPULATION_SIZE: u64 = 5;
/// Number of fittest (largest) individuals kept as parents when a population is
/// culled in `mutate` and `merge`.
const POPULATION_REDUCTION_SIZE: u64 = 3;

// Compile-time guard on the two constants above: culling must keep at least two
// parents, so cross-breeding always has two distinct candidates. It must also
// never keep more individuals than a full population holds.
const _: () = assert!(
    2 <= POPULATION_REDUCTION_SIZE && POPULATION_REDUCTION_SIZE <= POPULATION_SIZE,
    "POPULATION_REDUCTION_SIZE must be in 2..=POPULATION_SIZE"
);

/// Minimal [`GeneticNode`] state: a population of integers in which larger
/// values are treated as fitter whenever the population is culled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestState {
    /// The current individuals of this node.
    pub population: Vec<i64>,
    /// Generation number at which `simulate` reports that the node is finished.
    pub max_generations: u64,
}

#[async_trait]
impl GeneticNode for TestState {
    /// This node needs no shared context.
    type Context = ();

    /// Creates a node with `POPULATION_SIZE` random individuals drawn from
    /// `0..100` and a budget of 10 generations.
    async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error> {
        let mut rng = thread_rng();
        let population: Vec<i64> = (0..POPULATION_SIZE)
            .map(|_| rng.gen_range(0..100))
            .collect();

        Ok(Box::new(TestState {
            population,
            max_generations: 10,
        }))
    }

    /// Runs one generation: every individual drifts by -1, 0 or +1.
    ///
    /// Returns `Ok(true)` while `context.generation` is below
    /// `max_generations`, and `Ok(false)` once that budget is reached.
    async fn simulate(
        &mut self,
        context: GeneticNodeContext<Self::Context>,
    ) -> Result<bool, Error> {
        let mut rng = thread_rng();

        // `gen_range(-1..2)` is half-open, so the step is one of -1, 0, +1.
        for individual in self.population.iter_mut() {
            *individual = individual.saturating_add(rng.gen_range(-1..2));
        }

        Ok(context.generation < self.max_generations)
    }

    /// Culls the population to its `POPULATION_REDUCTION_SIZE` largest
    /// individuals (largest first). It then refills it to `POPULATION_SIZE`
    /// with children, each the average of two parents plus a -1/0/+1 jitter.
    ///
    /// An empty population is left untouched, because there is nobody to
    /// breed from.
    async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
        let mut rng = thread_rng();
        let target_size = POPULATION_SIZE as usize;

        // Selection: sort descending and keep the fittest. `truncate` keeps
        // whatever exists when the population is already smaller than the
        // reduction size, instead of panicking on an out-of-range slice.
        self.population.sort_unstable_by(|a, b| b.cmp(a));
        self.population.truncate(POPULATION_REDUCTION_SIZE as usize);

        if self.population.is_empty() {
            return Ok(());
        }

        // Reproduction: parents are drawn from the whole current population,
        // so children added earlier in this loop can also become parents.
        while self.population.len() < target_size {
            let len = self.population.len();
            let first = rng.gen_range(0..len);
            let second = if len == 1 {
                // A lone survivor breeds with itself. The average of the pair
                // is the survivor itself, so the child is the survivor plus jitter.
                first
            } else {
                // Choose uniformly among the other `len - 1` indices by skipping
                // over `first`. This needs no retry loop.
                let pick = rng.gen_range(0..len - 1);
                if pick >= first {
                    pick + 1
                } else {
                    pick
                }
            };

            let child = (self.population[first].saturating_add(self.population[second]) / 2)
                .saturating_add(rng.gen_range(-1..2));

            self.population.push(child);
        }

        Ok(())
    }

    /// Pools `left` and `right`, then keeps the `POPULATION_REDUCTION_SIZE`
    /// largest individuals. It then refills the result through `mutate`
    /// (generation 0, the given `id` and context). The merged node gets a
    /// budget of 10 generations.
    async fn merge(
        left: &TestState,
        right: &TestState,
        id: &Uuid,
        gemla_context: Self::Context,
    ) -> Result<Box<TestState>, Error> {
        // `i64` is totally ordered, so `cmp` replaces `partial_cmp(..).unwrap()`.
        // `truncate` tolerates pools smaller than the reduction size.
        let mut population: Vec<i64> = left
            .population
            .iter()
            .chain(&right.population)
            .copied()
            .collect();
        population.sort_unstable_by(|a, b| b.cmp(a));
        population.truncate(POPULATION_REDUCTION_SIZE as usize);

        let mut result = TestState {
            population,
            max_generations: 10,
        };

        result
            .mutate(GeneticNodeContext {
                id: *id,
                generation: 0,
                gemla_context,
            })
            .await?;

        Ok(Box::new(result))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use gemla::core::genetic_node::GeneticNode;

    /// Builds a unit-context `GeneticNodeContext` with a fresh id for `generation`.
    fn context(generation: u64) -> GeneticNodeContext<()> {
        GeneticNodeContext {
            id: Uuid::new_v4(),
            generation,
            gemla_context: (),
        }
    }

    #[tokio::test]
    async fn test_initialize() {
        let state = TestState::initialize(context(0)).await.unwrap();

        assert_eq!(state.population.len(), POPULATION_SIZE as usize);
        // `initialize` draws individuals from the half-open range 0..100.
        assert!(state.population.iter().all(|p| (0..100).contains(p)));
    }

    #[tokio::test]
    async fn test_simulate() {
        let mut state = TestState {
            population: vec![1, 1, 2, 3],
            max_generations: 1,
        };

        let original_population = state.population.clone();

        // Below the generation budget, the node asks to keep simulating.
        assert!(state.simulate(context(0)).await.unwrap());
        // A single step moves each individual by at most one.
        assert!(original_population
            .iter()
            .zip(state.population.iter())
            .all(|(&a, &b)| (b - a).abs() <= 1));

        // Once the budget is reached, the node reports it is finished.
        assert!(!state.simulate(context(1)).await.unwrap());
        state.simulate(context(0)).await.unwrap();
        // Three steps in total: drift is bounded by three.
        assert!(original_population
            .iter()
            .zip(state.population.iter())
            .all(|(&a, &b)| (b - a).abs() <= 3));
    }

    #[tokio::test]
    async fn test_mutate() {
        let mut state = TestState {
            population: vec![4, 3, 3],
            max_generations: 1,
        };

        state.mutate(context(0)).await.unwrap();

        assert_eq!(state.population.len(), POPULATION_SIZE as usize);
        // Survivors are kept, largest first, ahead of the new children.
        assert_eq!(&state.population[..3], &[4, 3, 3]);
    }

    #[tokio::test]
    async fn test_merge() {
        let state1 = TestState {
            population: vec![1, 2, 4, 5],
            max_generations: 1,
        };

        let state2 = TestState {
            population: vec![0, 1, 3, 7],
            max_generations: 1,
        };

        let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ())
            .await
            .unwrap();

        assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize);
        // The three largest values across both parents survive, largest first.
        assert_eq!(&merged_state.population[..3], &[7, 5, 4]);
    }
}