use float_cmp::assert_approx_eq;
use itertools::{assert_equal, Itertools};
use morphine::{
instance::{self, create_instance, Instance, TickInput, TickResult},
params::{
InitialSynWeight, InstanceParams, LayerConnectionParams, LayerParams,
PlasticityModulationParams, ShortTermStdpParams, StdpParams, StpParams,
},
};
use rand::{
distributions::Uniform, prelude::Distribution, rngs::StdRng, seq::SliceRandom, SeedableRng,
};
/// Advances `instance` by one tick in which the given input channels spike and returns
/// the tick result. Panics if the tick fails.
fn tick(instance: &mut Instance, in_channel_ids: &[usize]) -> TickResult {
    let input = TickInput::from_spiking_in_channel_ids(in_channel_ids);
    instance.tick(&input).unwrap()
}
/// Like `tick`, but also requests a state snapshot, which the tests then read from
/// `TickResult::state_snapshot`. Panics if the tick fails.
fn tick_extract_snapshot(instance: &mut Instance, in_channel_ids: &[usize]) -> TickResult {
    let input = {
        let mut input = TickInput::from_spiking_in_channel_ids(in_channel_ids);
        input.extract_state_snapshot = true;
        input
    };
    instance.tick(&input).unwrap()
}
/// Builds a two-layer instance with one neuron per layer (membrane time constant 10,
/// refractory period 10) connected by a single layer-0 -> layer-1 synapse whose initial
/// weight is the constant `weight`.
fn make_simple_1_in_1_out_instance(weight: f32) -> Instance {
    let mut neuron_layer = LayerParams::default();
    neuron_layer.num_neurons = 1;
    neuron_layer.neuron_params.refractory_period = 10;
    neuron_layer.neuron_params.tau_membrane = 10.0;

    let mut projection = LayerConnectionParams::defaults_for_layer_ids(0, 1);
    projection.initial_syn_weight = InitialSynWeight::Constant(weight);

    let mut params = InstanceParams::default();
    params.layers.push(neuron_layer.clone());
    params.layers.push(neuron_layer);
    params.layer_connections.push(projection);
    create_instance(params).unwrap()
}
/// STDP parameters shared by many tests below, used both as long-term STDP params and
/// inside `ShortTermStdpParams`.
///
/// - Pre-before-post pairings: factor +0.1, time constant 20.
/// - Pre-after-post pairings: factor -0.11, time constant 25.
///
/// Several expected values in the tests (e.g. `0.4 + 0.1`, `-0.11 * exp(-dt / 25)`) are
/// derived directly from these numbers.
const STDP_PARAMS: StdpParams = StdpParams {
    factor_pre_before_post: 0.1,
    tau_pre_before_post: 20.0,
    factor_pre_after_post: -0.11,
    tau_pre_after_post: 25.0,
};
/// Ticking an instance built from default parameters (the "empty" instance) yields
/// neither spiking output channels nor spiking neurons.
#[test]
fn empty_instance() {
    let mut instance = create_instance(InstanceParams::default()).unwrap();
    let result = instance.tick_no_input();
    assert!(result.spiking_out_channel_ids.is_empty());
    assert!(result.spiking_nids.is_empty());
}
/// With a one-neuron input layer and a zero-neuron output layer, spiking in-channel 0
/// makes neuron 0 spike but produces no spiking output channels.
#[test]
fn empty_output_layer() {
    let mut input_layer = LayerParams::default();
    input_layer.num_neurons = 1;
    let mut output_layer = input_layer.clone();
    output_layer.num_neurons = 0;

    let mut params = InstanceParams::default();
    params.layers.push(input_layer);
    params.layers.push(output_layer);
    let mut instance = create_instance(params).unwrap();

    let result = tick(&mut instance, &[0]);
    assert!(result.spiking_out_channel_ids.is_empty());
    assert_equal(result.spiking_nids, [0]);
}
/// A single one-neuron layer: nothing spikes without input, and spiking in-channel 0
/// makes out-channel 0 spike in the same tick.
#[test]
fn single_neuron() {
    let mut layer = LayerParams::default();
    layer.num_neurons = 1;
    let mut params = InstanceParams::default();
    params.layers.push(layer);
    let mut instance = create_instance(params).unwrap();

    // Tick 0: no input, no output.
    let idle = tick(&mut instance, &[]);
    assert!(idle.spiking_out_channel_ids.is_empty());
    // Tick 1: in-channel 0 spikes and is reported on out-channel 0.
    let driven = tick(&mut instance, &[0]);
    assert_equal(driven.spiking_out_channel_ids, [0]);
}
/// Spiking in-channel 0 twice in one tick through a 0.5-weight synapse makes the output
/// neuron spike on the following tick, not in the input tick itself.
#[test]
fn single_direct_mapped_output() {
    let mut instance = make_simple_1_in_1_out_instance(0.5);
    // Tick 0: doubled input; no output yet.
    let input_tick = tick(&mut instance, &[0, 0]);
    assert!(input_tick.spiking_out_channel_ids.is_empty());
    // Tick 1: the output neuron fires.
    let follow_up = instance.tick_no_input();
    assert_equal(follow_up.spiking_out_channel_ids, [0]);
}
/// Two single input spikes one tick apart (0.5-weight synapse, tau_membrane = 10) do
/// not make the output neuron fire. At tick 2 its voltage is the first PSP decayed by
/// one tick plus the second PSP, `0.5 * exp(-1/10) + 0.5`, which is presumably just below
/// the default firing threshold.
#[test]
fn missed_spike_after_leakage() {
    let mut instance = make_simple_1_in_1_out_instance(0.5);
    // Tick 0: first input spike.
    let tick_0_result = tick(&mut instance, &[0]);
    assert!(tick_0_result.spiking_out_channel_ids.is_empty());
    // Tick 1: second input spike.
    let tick_1_result = tick(&mut instance, &[0]);
    assert!(tick_1_result.spiking_out_channel_ids.is_empty());
    // Tick 2: still no output spike; check the leaked-plus-new voltage of nid 1.
    let tick_2_result = tick_extract_snapshot(&mut instance, &[]);
    assert!(tick_2_result.spiking_out_channel_ids.is_empty());
    assert_approx_eq!(
        f32,
        tick_2_result.state_snapshot.unwrap().neuron_states[1].voltage,
        0.5 * (-1.0 / 10.0f32).exp() + 0.5
    );
}
/// A single 0.5 PSP reaches the output neuron at tick 1 and then decays exponentially
/// with tau_membrane = 10: at tick 6 (5 ticks later) the voltage is `0.5 * exp(-5/10)`.
#[test]
fn voltage_trajectory() {
    let mut instance = make_simple_1_in_1_out_instance(0.5);
    // Tick 0: one input spike.
    tick(&mut instance, &[0]);
    // Tick 1: the PSP has arrived in full.
    let tick_1_result = tick_extract_snapshot(&mut instance, &[]);
    assert_approx_eq!(
        f32,
        tick_1_result.state_snapshot.unwrap().neuron_states[1].voltage,
        0.5
    );
    // Ticks 2-5: let the voltage leak.
    for _ in 0..4 {
        instance.tick_no_input();
    }
    // Tick 6: five ticks of decay since tick 1.
    let tick_6_result = tick_extract_snapshot(&mut instance, &[]);
    assert_approx_eq!(
        f32,
        tick_6_result.state_snapshot.unwrap().neuron_states[1].voltage,
        0.5 * (-5.0 / 10.0f32).exp()
    );
}
/// PSPs reaching the output neuron during its refractory period (refractory_period = 10)
/// leave its voltage unchanged.
///
/// The neuron fires after the doubled input of tick 0 and reads 0.0 at tick 1. Inputs
/// sent at ticks 2, 9 and 10 are checked:
/// - the tick 2 and tick 9 inputs leave the voltage at 0.0;
/// - the tick 10 input is the first one that shows up, as 0.5 at tick 11.
#[test]
fn no_psp_during_refractory_period() {
    let mut instance = make_simple_1_in_1_out_instance(0.5);
    // Tick 0: doubled input; the output neuron presumably fires and resets.
    tick(&mut instance, &[0, 0]);
    let tick_1_result = tick_extract_snapshot(&mut instance, &[]);
    assert_approx_eq!(
        f32,
        tick_1_result.state_snapshot.unwrap().neuron_states[1].voltage,
        0.0
    );
    // Tick 2: input whose PSP falls into the refractory period.
    tick(&mut instance, &[0]);
    let tick_3_result = tick_extract_snapshot(&mut instance, &[]);
    assert_approx_eq!(
        f32,
        tick_3_result.state_snapshot.unwrap().neuron_states[1].voltage,
        0.0
    );
    // Idle until the tick period reaches 9 (presumably ticks 4-8).
    while instance.get_tick_period() < 9 {
        instance.tick_no_input();
    }
    // Tick 9: this input's PSP is still ignored at tick 10.
    tick(&mut instance, &[0]);
    // Tick 10: send another input and check the voltage is still 0.0.
    let tick_10_result = tick_extract_snapshot(&mut instance, &[0]);
    assert_approx_eq!(
        f32,
        tick_10_result.state_snapshot.unwrap().neuron_states[1].voltage,
        0.0
    );
    // Tick 11: the PSP of the tick-10 input is integrated.
    let tick_11_result = tick_extract_snapshot(&mut instance, &[]);
    assert_approx_eq!(
        f32,
        tick_11_result.state_snapshot.unwrap().neuron_states[1].voltage,
        0.5
    );
}
/// Three one-neuron layers (nids 0, 1, 2):
/// - nid 0 -> nid 1 with weight 0.5;
/// - nid 0 -> nid 2 with weight 0.5 and one extra tick of conduction delay;
/// - nid 1 -> nid 2 with weight 0.1, made inhibitory by a weight scale factor of -1.
///
/// A doubled input at tick 1 makes nid 1 fire at tick 2. At tick 3 nid 2 has summed two
/// 0.5 EPSPs and one 0.1 IPSP, giving a voltage of 0.9, and it does not fire.
#[test]
fn two_epsps_and_one_ipsp() {
    let mut params = InstanceParams::default();
    let mut layer = LayerParams::default();
    layer.neuron_params.tau_membrane = 10.0;
    layer.num_neurons = 1;
    params.layers.push(layer.clone());
    params.layers.push(layer.clone());
    params.layers.push(layer);
    let mut connection_params = LayerConnectionParams::defaults_for_layer_ids(0, 1);
    connection_params.initial_syn_weight = InitialSynWeight::Constant(0.5);
    params.layer_connections.push(connection_params);
    let mut connection_params = LayerConnectionParams::defaults_for_layer_ids(0, 2);
    connection_params.initial_syn_weight = InitialSynWeight::Constant(0.5);
    connection_params.conduction_delay_add_on = 1;
    params.layer_connections.push(connection_params);
    let mut connection_params = LayerConnectionParams::defaults_for_layer_ids(1, 2);
    connection_params.initial_syn_weight = InitialSynWeight::Constant(0.1);
    // A negative scale factor turns this projection inhibitory.
    connection_params
        .projection_params
        .synapse_params
        .weight_scale_factor = -1.0;
    params.layer_connections.push(connection_params);
    let mut instance = create_instance(params).unwrap();
    // Tick 0: idle.
    instance.tick_no_input();
    // Tick 1: doubled input on in-channel 0.
    tick(&mut instance, &[0, 0]);
    // Tick 2: nid 1 (hidden layer) fires; nid 2 (the output layer) does not.
    let tick_2_result = tick(&mut instance, &[]);
    assert_equal(tick_2_result.spiking_nids, [1]);
    // Checked with `is_empty()` like the rest of this file, instead of comparing to an
    // untyped empty array literal.
    assert!(tick_2_result.spiking_out_channel_ids.is_empty());
    // Tick 3: 2 * 0.5 (delayed EPSPs) - 0.1 (IPSP) = 0.9, below threshold.
    let tick_3_result = tick_extract_snapshot(&mut instance, &[]);
    assert!(tick_3_result.spiking_nids.is_empty());
    assert!(tick_3_result.spiking_out_channel_ids.is_empty());
    assert_approx_eq!(
        f32,
        tick_3_result.state_snapshot.unwrap().neuron_states[2].voltage,
        0.9
    );
}
/// An inhibitory input (0.5 weight, scale factor -1) delivered twice in one tick would
/// push the output voltage to -1.0. The layer's voltage_floor of -0.6 clamps it, and the
/// snapshot at tick 1 expects exactly -0.6.
#[test]
fn voltage_floor() {
    let mut params = InstanceParams::default();
    let mut layer = LayerParams::default();
    layer.neuron_params.tau_membrane = 10.0;
    layer.neuron_params.refractory_period = 10;
    layer.neuron_params.voltage_floor = -0.6;
    layer.num_neurons = 1;
    params.layers.push(layer.clone());
    params.layers.push(layer);
    let mut connection_params = LayerConnectionParams::defaults_for_layer_ids(0, 1);
    connection_params.initial_syn_weight = InitialSynWeight::Constant(0.5);
    connection_params
        .projection_params
        .synapse_params
        .weight_scale_factor = -1.0;
    params.layer_connections.push(connection_params);
    let mut instance = create_instance(params).unwrap();
    // Tick 0: doubled inhibitory input.
    tick(&mut instance, &[0, 0]);
    // Tick 1: the voltage is clamped at the floor.
    let tick_1_result = tick_extract_snapshot(&mut instance, &[]);
    assert_approx_eq!(
        f32,
        tick_1_result.state_snapshot.unwrap().neuron_states[1].voltage,
        -0.6
    );
}
/// Threshold adaptation on the output layer only (adaptation_threshold = 0.4,
/// tau_threshold = 10; the input layer is pushed before these are set).
///
/// The output neuron is expected to fire at tick 3. A later input at tick 6 is expected
/// not to make it fire at tick 7, presumably because its threshold is still elevated.
#[test]
fn threshold_adaptation() {
    let mut params = InstanceParams::default();
    let mut layer = LayerParams::default();
    layer.neuron_params.refractory_period = 1;
    layer.num_neurons = 1;
    params.layers.push(layer.clone());
    layer.neuron_params.adaptation_threshold = 0.4;
    layer.neuron_params.tau_threshold = 10.0;
    params.layers.push(layer);
    let mut connection_params = LayerConnectionParams::defaults_for_layer_ids(0, 1);
    connection_params.initial_syn_weight = InitialSynWeight::Constant(0.5);
    params.layer_connections.push(connection_params);
    let mut instance = create_instance(params).unwrap();
    // Tick 0: doubled input.
    tick(&mut instance, &[0, 0]);
    // Tick 1.
    instance.tick_no_input();
    // Tick 2: single input.
    // NOTE(review): the tick-3 spike below follows a single [0] input (a 0.5 PSP) at
    // tick 2. Confirm this is the intended input pattern for the adaptation check.
    tick(&mut instance, &[0]);
    // Tick 3: the output neuron fires.
    let tick_3_result = instance.tick_no_input();
    assert_equal(tick_3_result.spiking_out_channel_ids, [0]);
    // Ticks 4-5.
    instance.tick_no_input();
    instance.tick_no_input();
    // Tick 6: the same single input again.
    tick(&mut instance, &[0]);
    // Tick 7: no spike this time.
    let tick_7_result = instance.tick_no_input();
    assert!(tick_7_result.spiking_out_channel_ids.is_empty());
}
/// Long-term STDP potentiation of the single 0 -> 1 synapse (initial weight 1.0,
/// max_weight 1.5): after a pre/post pairing the weight read at tick 2 is 1.1, i.e. it
/// grew by STDP_PARAMS' pre-before-post factor 0.1.
///
/// The output layer has a large threshold adaptation (2.0), presumably so that it does not
/// fire again when the potentiated 1.1 PSP arrives at tick 3.
#[test]
fn simple_potentiation_long_term_stdp() {
    let mut params = InstanceParams::default();
    let mut layer = LayerParams::default();
    layer.neuron_params.refractory_period = 1;
    layer.num_neurons = 1;
    params.layers.push(layer.clone());
    layer.neuron_params.adaptation_threshold = 2.0;
    layer.neuron_params.tau_threshold = 10.0;
    params.layers.push(layer);
    let mut connection_params = LayerConnectionParams::defaults_for_layer_ids(0, 1);
    connection_params.initial_syn_weight = InitialSynWeight::Constant(1.0);
    connection_params
        .projection_params
        .synapse_params
        .max_weight = 1.5;
    connection_params.projection_params.long_term_stdp_params = Some(STDP_PARAMS);
    params.layer_connections.push(connection_params);
    let mut instance = create_instance(params).unwrap();
    // Tick 0: pre-synaptic spike; its 1.0 PSP presumably makes nid 1 fire at tick 1.
    tick(&mut instance, &[0]);
    // Tick 1.
    instance.tick_no_input();
    // Tick 2: fire nid 0 again and check the potentiated weight.
    let tick_2_result = tick_extract_snapshot(&mut instance, &[0]);
    assert_approx_eq!(
        f32,
        tick_2_result.state_snapshot.unwrap().synapse_states[0].weight,
        1.1
    );
    // Tick 3: the 1.1 PSP arrives without triggering a spike.
    let tick_3_result = tick_extract_snapshot(&mut instance, &[]);
    assert!(tick_3_result.spiking_out_channel_ids.is_empty());
    assert_approx_eq!(
        f32,
        tick_3_result.state_snapshot.as_ref().unwrap().neuron_states[1].voltage,
        1.1
    );
}
/// Counts synaptic transmissions caused by one input spike. There are two layers of
/// 10 neurons with projections 0 -> 0 (connect width 2.0, density 1.0) and 0 -> 1
/// (connect width 2.0, density 0.5). All delays are position-independent (distance scale
/// factor 0.0).
///
/// Spiking in-channel 0 at tick 0 is expected to yield 15 transmissions at tick 1,
/// presumably 10 within layer 0 plus 5 into layer 1.
#[test]
fn synaptic_transmission_count() {
    let mut params = InstanceParams::default();
    let mut layer = LayerParams::default();
    layer.neuron_params.tau_membrane = 10.0;
    layer.neuron_params.refractory_period = 10;
    layer.neuron_params.t_cutoff_coincidence = 10;
    layer.num_neurons = 10;
    params.layers.push(layer.clone());
    params.layers.push(layer);
    let mut connection_params = LayerConnectionParams::defaults_for_layer_ids(0, 0);
    connection_params.projection_params.long_term_stdp_params = Some(STDP_PARAMS);
    connection_params.initial_syn_weight = InitialSynWeight::Constant(0.5);
    connection_params.conduction_delay_position_distance_scale_factor = 0.0;
    connection_params.connect_width = 2.0;
    connection_params.connect_density = 1.0;
    connection_params
        .projection_params
        .synapse_params
        .max_weight = 1.0;
    params.layer_connections.push(connection_params.clone());
    // Second projection: same settings but 0 -> 1 at half the density.
    connection_params.to_layer_id = 1;
    connection_params.connect_density = 0.5;
    params.layer_connections.push(connection_params);
    let mut instance = create_instance(params).unwrap();
    // Tick 0: spike in-channel 0.
    tick(&mut instance, &[0]);
    // Tick 1: transmissions of that spike are counted.
    let tick_1_result = instance.tick_no_input();
    assert_eq!(tick_1_result.synaptic_transmission_count, 15);
}
/// Short-term STDP (STDP_PARAMS, tau = 1.0) on a 1.0-weight synapse.
///
/// A pairing shortly after tick 0 adds a transient offset of +0.1 to the synapse's
/// effective weight. That offset decays with tau 1, so a PSP arriving at tick 4 carries
/// `1.0 + 0.1 * exp(-3)`. The stored (long-term) weight stays at 1.0.
///
/// The output layer's large threshold adaptation (2.0) presumably keeps it from firing
/// again on that PSP.
#[test]
fn simple_potentiation_short_term_stdp() {
    let mut params = InstanceParams::default();
    let mut layer = LayerParams::default();
    layer.neuron_params.refractory_period = 1;
    layer.num_neurons = 1;
    params.layers.push(layer.clone());
    layer.neuron_params.adaptation_threshold = 2.0;
    layer.neuron_params.tau_threshold = 10.0;
    params.layers.push(layer);
    let mut connection_params = LayerConnectionParams::defaults_for_layer_ids(0, 1);
    connection_params.initial_syn_weight = InitialSynWeight::Constant(1.0);
    connection_params
        .projection_params
        .synapse_params
        .max_weight = 1.5;
    connection_params.projection_params.short_term_stdp_params = Some(ShortTermStdpParams {
        stdp_params: STDP_PARAMS,
        tau: 1.0,
    });
    params.layer_connections.push(connection_params);
    let mut instance = create_instance(params).unwrap();
    // Tick 0: pre-synaptic spike; nid 1 presumably fires at tick 1.
    tick(&mut instance, &[0]);
    // Ticks 1-2.
    instance.tick_no_input();
    instance.tick_no_input();
    // Tick 3: second pre-synaptic spike.
    tick(&mut instance, &[0]);
    // Tick 4 (the fifth tick call): the PSP carries the decayed short-term offset.
    // Named after the tick it observes, like the other `tick_N_result` bindings here.
    let tick_4_result = tick_extract_snapshot(&mut instance, &[]);
    assert!(tick_4_result.spiking_out_channel_ids.is_empty());
    assert_approx_eq!(
        f32,
        tick_4_result.state_snapshot.as_ref().unwrap().neuron_states[1].voltage,
        1.0 + 0.1 * (-3.0f32).exp()
    );
    // Short-term STDP must not touch the stored weight.
    assert_approx_eq!(
        f32,
        tick_4_result.state_snapshot.unwrap().synapse_states[0].weight,
        1.0
    );
}
/// Long-term STDP with one pre-synaptic spike followed by two post-synaptic spikes.
///
/// Setup: one layer of two neurons with recurrent 0.4-weight synapses (max_weight 1.0).
/// Nid 0 fires at tick 0; nid 1 is forced to fire at ticks 1 and 3.
///
/// Expected results:
/// - The 0 -> 1 synapse (synapse_states[1]) ends at 0.5, i.e. potentiated once by +0.1;
///   the second post-synaptic spike adds nothing.
/// - That weight is seen as nid 1's voltage at tick 21, after nid 0 fires at tick 20.
#[test]
fn pre_syn_spike_then_two_post_syn_spikes() {
    let mut params = InstanceParams::default();
    let mut layer = LayerParams::default();
    layer.neuron_params.refractory_period = 10;
    layer.num_neurons = 2;
    layer.neuron_params.tau_threshold = 10.0;
    layer.neuron_params.t_cutoff_coincidence = 10;
    params.layers.push(layer);
    let mut connection_params = LayerConnectionParams::defaults_for_layer_ids(0, 0);
    connection_params.initial_syn_weight = InitialSynWeight::Constant(0.4);
    connection_params
        .projection_params
        .synapse_params
        .max_weight = 1.0;
    connection_params.connect_width = 2.0;
    connection_params.projection_params.long_term_stdp_params = Some(STDP_PARAMS);
    params.layer_connections.push(connection_params);
    let mut instance = create_instance(params).unwrap();
    // Tick 0: pre-synaptic spike (nid 0).
    tick(&mut instance, &[0]);
    // Tick 1: first post-synaptic spike (nid 1).
    tick(&mut instance, &[1]);
    // Idle until the tick period reaches 3.
    instance.tick_no_input_until(3);
    // Tick 3: second post-synaptic spike (nid 1).
    tick(&mut instance, &[1]);
    // Idle until the tick period reaches 20.
    instance.tick_no_input_until(20);
    // Tick 20: fire nid 0 and inspect the 0 -> 1 synapse.
    let tick_20_result = tick_extract_snapshot(&mut instance, &[0]);
    assert_eq!(
        tick_20_result
            .state_snapshot
            .as_ref()
            .unwrap()
            .synapse_states[1]
            .pre_syn_nid,
        0
    );
    assert_eq!(
        tick_20_result
            .state_snapshot
            .as_ref()
            .unwrap()
            .synapse_states[1]
            .post_syn_nid,
        1
    );
    assert_approx_eq!(
        f32,
        tick_20_result.state_snapshot.unwrap().synapse_states[1].weight,
        0.5
    );
    // Tick 21: nid 1 receives the 0.5 PSP from the tick-20 spike.
    let tick_21_result = tick_extract_snapshot(&mut instance, &[]);
    assert_approx_eq!(
        f32,
        tick_21_result.state_snapshot.unwrap().neuron_states[1].voltage,
        0.5
    );
}
/// Long-term plus short-term STDP with a post-synaptic spike followed by later
/// pre-synaptic spikes on the 1 -> 0 synapse (initial weight 0.4).
///
/// Nid 0 fires at tick 0 and nid 1 fires afterwards. The expected voltage of nid 0 at
/// tick 10 is the sum of two terms:
/// - the weight depressed once by long-term STDP, with the pairing evaluated 6 ticks
///   after the post spike: `0.4 - 0.11 * exp(-6/25)`;
/// - the short-term offset `-0.06 * exp(-6/20)`, decayed for 4 more ticks with tau 10.
///
/// NOTE(review): nid 1 is actually fired three times (ticks 5, 7 and 9) although the name
/// says "two".
#[test]
fn post_syn_spike_then_two_pre_syn_spikes() {
    let mut params = InstanceParams::default();
    let mut layer = LayerParams::default();
    layer.neuron_params.refractory_period = 10;
    layer.num_neurons = 2;
    layer.neuron_params.tau_threshold = 10.0;
    layer.neuron_params.t_cutoff_coincidence = 10;
    params.layers.push(layer);
    let mut connection_params = LayerConnectionParams::defaults_for_layer_ids(0, 0);
    connection_params.initial_syn_weight = InitialSynWeight::Constant(0.4);
    connection_params
        .projection_params
        .synapse_params
        .max_weight = 1.5;
    connection_params.connect_width = 2.0;
    connection_params.projection_params.long_term_stdp_params = Some(STDP_PARAMS);
    connection_params.projection_params.short_term_stdp_params = Some(ShortTermStdpParams {
        stdp_params: StdpParams {
            factor_pre_before_post: 0.05,
            tau_pre_before_post: 15.0,
            factor_pre_after_post: -0.06,
            tau_pre_after_post: 20.0,
        },
        tau: 10.0,
    });
    params.layer_connections.push(connection_params);
    let mut instance = create_instance(params).unwrap();
    // Tick 0: post-synaptic spike (nid 0).
    tick(&mut instance, &[0]);
    // Ticks 1-4.
    for _tick_period in 1..5 {
        instance.tick_no_input();
    }
    // Tick 5: pre-synaptic spike (nid 1).
    tick(&mut instance, &[1]);
    // Tick 6 idle, tick 7 pre-synaptic spike.
    instance.tick_no_input();
    tick(&mut instance, &[1]);
    // Tick 8 idle, tick 9 pre-synaptic spike.
    instance.tick_no_input();
    tick(&mut instance, &[1]);
    // Tick 10: nid 0 has received the PSP of the last pre-synaptic spike.
    let tick_10_result = tick_extract_snapshot(&mut instance, &[]);
    let expected_lt_stdp_value_tick_6 = -0.11 * (-6.0 / 25.0f32).exp();
    let expected_st_stdp_value_tick_6 = -0.06 * (-6.0 / 20.0f32).exp();
    let expected_weight = 0.4 + expected_lt_stdp_value_tick_6;
    let expected_st_stdp_offset_tick_10 = expected_st_stdp_value_tick_6 * (-4.0 / 10.0f32).exp();
    let expected_psp = expected_weight + expected_st_stdp_offset_tick_10;
    assert_approx_eq!(
        f32,
        tick_10_result.state_snapshot.unwrap().neuron_states[0].voltage,
        expected_psp
    );
}
/// Long-term STDP on the 0 -> 1 synapse with alternating pre- and post-synaptic
/// activity. The synapse has two extra ticks of conduction delay, so PSPs presumably
/// arrive 3 ticks after the spike.
///
/// - Nid 1 (post) fires at ticks 0 and 12.
/// - Nid 0 (pre) fires at ticks 8, 21 and 30.
///
/// The voltage at tick 33 must equal the weight after three STDP updates, named after the
/// ticks they refer to:
/// - depression at tick 11, 11 ticks after the post spike at tick 0;
/// - potentiation at tick 12, 1 tick after the pre arrival at tick 11;
/// - depression at tick 24, 12 ticks after the post spike at tick 12.
#[test]
fn stdp_alternating_pre_post_syn_spikes() {
    let mut params = InstanceParams::default();
    let mut layer = LayerParams::default();
    layer.neuron_params.refractory_period = 15;
    layer.num_neurons = 2;
    layer.neuron_params.t_cutoff_coincidence = 20;
    params.layers.push(layer);
    let mut connection_params = LayerConnectionParams::defaults_for_layer_ids(0, 0);
    connection_params.initial_syn_weight = InitialSynWeight::Constant(0.4);
    connection_params
        .projection_params
        .synapse_params
        .max_weight = 1.5;
    connection_params.conduction_delay_add_on = 2;
    connection_params.connect_width = 2.0;
    connection_params.projection_params.long_term_stdp_params = Some(STDP_PARAMS);
    params.layer_connections.push(connection_params);
    let mut instance = create_instance(params).unwrap();
    // Tick 0: post (nid 1).
    tick(&mut instance, &[1]);
    instance.tick_no_input_until(8);
    // Tick 8: pre (nid 0).
    tick(&mut instance, &[0]);
    instance.tick_no_input_until(12);
    // Tick 12: post (nid 1).
    tick(&mut instance, &[1]);
    instance.tick_no_input_until(21);
    // Tick 21: pre (nid 0).
    tick(&mut instance, &[0]);
    instance.tick_no_input_until(30);
    // Tick 30: pre (nid 0); its PSP carries the final weight.
    tick(&mut instance, &[0]);
    instance.tick_no_input_until(33);
    let tick_33_result = tick_extract_snapshot(&mut instance, &[]);
    let tick_11_stdp = -0.11 * (-11.0 / 25.0f32).exp();
    let tick_12_stdp = 0.1 * (-1.0 / 20.0f32).exp();
    let tick_24_stdp = -0.11 * (-12.0 / 25.0f32).exp();
    let expected_weight = 0.4 + tick_11_stdp + tick_12_stdp + tick_24_stdp;
    assert_approx_eq!(
        f32,
        tick_33_result.state_snapshot.unwrap().neuron_states[1].voltage,
        expected_weight
    );
}
/// Multi-step long-term STDP scenario on a 10 -> 10 feed-forward projection (initial
/// weight 0.5, max 1.0). Conduction delays grow with the positional distance between pre-
/// and post-synaptic neuron (scale factor 10).
///
/// All checks concern nid 19, the last neuron of layer 1, presumably reported as
/// out-channel 9. The expected values show:
/// - synapses 9 -> 19 and 0 -> 19 potentiated to 0.6;
/// - synapse 5 -> 19 depressed by `0.11 * exp(-2/25)`;
/// - synapse 6 -> 19 left at 0.5.
#[test]
fn long_term_stdp_complex_scenario() {
    let mut params = InstanceParams::default();
    let mut layer = LayerParams::default();
    layer.neuron_params.tau_membrane = 10.0;
    layer.neuron_params.refractory_period = 10;
    layer.neuron_params.t_cutoff_coincidence = 10;
    layer.num_neurons = 10;
    params.layers.push(layer.clone());
    params.layers.push(layer);
    let mut connection_params = LayerConnectionParams::defaults_for_layer_ids(0, 1);
    connection_params.projection_params.long_term_stdp_params = Some(STDP_PARAMS);
    connection_params.initial_syn_weight = InitialSynWeight::Constant(0.5);
    connection_params.conduction_delay_position_distance_scale_factor = 10.0;
    connection_params
        .projection_params
        .synapse_params
        .max_weight = 1.0;
    params.layer_connections.push(connection_params);
    let mut instance = create_instance(params).unwrap();
    // Ticks 0, 8, 10: inputs 0, 5, 9 with different distance-dependent delays to nid 19.
    tick(&mut instance, &[0]);
    instance.tick_no_input_until(8);
    tick(&mut instance, &[5]);
    instance.tick_no_input_until(10);
    tick(&mut instance, &[9]);
    // Tick 11: nid 19 fires (presumably the PSPs from inputs 0 and 9 coincide here).
    let tick_11_result = instance.tick_no_input();
    assert_equal(tick_11_result.spiking_out_channel_ids, [9]);
    instance.tick_no_input_until(25);
    // Tick 25: input 9 again; at tick 26 its PSP reveals the potentiated weight 0.6.
    tick(&mut instance, &[9]);
    let tick_26_result = tick_extract_snapshot(&mut instance, &[9]);
    assert_approx_eq!(
        f32,
        tick_26_result.state_snapshot.unwrap().neuron_states[19].voltage,
        0.6
    );
    // Tick 27: the second 0.6 PSP (from tick 26) makes nid 19 fire.
    let tick_27_result = tick(&mut instance, &[9]);
    assert_equal(tick_27_result.spiking_out_channel_ids, [9]);
    // Ticks 28-29: input 0 twice; its PSPs arrive at ticks 39 and 40.
    tick(&mut instance, &[0]);
    tick(&mut instance, &[0]);
    instance.tick_no_input_until(39);
    // Tick 39: the 0 -> 19 weight is also 0.6.
    let tick_39_result = tick_extract_snapshot(&mut instance, &[]);
    assert_approx_eq!(
        f32,
        tick_39_result.state_snapshot.unwrap().neuron_states[19].voltage,
        0.6
    );
    // Tick 40: the second PSP makes nid 19 fire.
    let tick_40_result = instance.tick_no_input();
    assert_equal(tick_40_result.spiking_out_channel_ids, [9]);
    instance.tick_no_input_until(50);
    // Tick 50: input 5; its PSP at tick 55 shows the depressed 5 -> 19 weight.
    tick(&mut instance, &[5]);
    instance.tick_no_input_until(55);
    let tick_55_result = tick_extract_snapshot(&mut instance, &[]);
    assert_approx_eq!(
        f32,
        tick_55_result.state_snapshot.unwrap().neuron_states[19].voltage,
        0.5 - 0.11 * (-2.0 / 25.0f32).exp()
    );
    // Tick 56: doubled input 9 makes nid 19 fire at tick 57.
    tick(&mut instance, &[9, 9]);
    let tick_57_result = instance.tick_no_input();
    assert_equal(tick_57_result.spiking_out_channel_ids, [9]);
    instance.tick_no_input_until(70);
    // Tick 70: input 6, never paired so far; its PSP at tick 74 is the initial 0.5.
    tick(&mut instance, &[6]);
    instance.tick_no_input_until(74);
    let tick_74_result = tick_extract_snapshot(&mut instance, &[]);
    assert_approx_eq!(
        f32,
        tick_74_result.state_snapshot.unwrap().neuron_states[19].voltage,
        0.5
    );
}
/// With plasticity modulation enabled on the output layer and no reward ever given,
/// the STDP pairing of tick 0/1 must not change the 0.5 weight. This is checked by
/// reading the PSP of an input at tick 1500 as voltage at tick 1501.
#[test]
fn no_dopamine() {
    let mut params = InstanceParams::default();
    let mut layer = LayerParams::default();
    layer.neuron_params.refractory_period = 1;
    layer.num_neurons = 1;
    params.layers.push(layer.clone());
    layer.plasticity_modulation_params = Some(PlasticityModulationParams {
        tau_eligibility_trace: 1000.0,
        eligibility_trace_delay: 0,
        dopamine_modulation_factor: 1.0,
        t_cutoff_eligibility_trace: 2000,
        dopamine_flush_period: 100,
        dopamine_conflation_period: 10,
    });
    params.layers.push(layer);
    let mut connection_params = LayerConnectionParams::defaults_for_layer_ids(0, 1);
    connection_params.initial_syn_weight = InitialSynWeight::Constant(0.5);
    connection_params
        .projection_params
        .synapse_params
        .max_weight = 1.5;
    connection_params.projection_params.long_term_stdp_params = Some(STDP_PARAMS);
    params.layer_connections.push(connection_params);
    let mut instance = create_instance(params).unwrap();
    // Tick 0: doubled input, presumably making the output neuron fire (STDP pairing).
    tick(&mut instance, &[0, 0]);
    // Tick 1, then idle until the tick period reaches 1500.
    instance.tick_no_input();
    instance.tick_no_input_until(1500);
    // Tick 1500: probe input.
    tick(&mut instance, &[0]);
    // Tick 1501: the PSP still carries the unmodified weight 0.5.
    let tick_1501_result = tick_extract_snapshot(&mut instance, &[]);
    assert_approx_eq!(
        f32,
        tick_1501_result.state_snapshot.unwrap().neuron_states[1].voltage,
        0.5
    );
}
/// A negative reward turns an STDP potentiation into a depression.
///
/// The output layer uses default modulation params, except conflation and flush
/// periods of 1 and an eligibility-trace tau of 15. After the pre/post pairing of
/// ticks 0/1, a reward of -1.0 at tick 2 is expected to set the weight to
/// `1.0 - 0.1 * exp(-1/15)`.
#[test]
fn negative_reward() {
    let mut params = InstanceParams::default();
    let mut layer = LayerParams::default();
    layer.num_neurons = 1;
    params.layers.push(layer.clone());
    let mut plasticity_modulation_params = PlasticityModulationParams::default();
    plasticity_modulation_params.dopamine_conflation_period = 1;
    plasticity_modulation_params.dopamine_flush_period = 1;
    plasticity_modulation_params.tau_eligibility_trace = 15.0;
    layer.plasticity_modulation_params = Some(plasticity_modulation_params);
    params.layers.push(layer);
    let mut connection_params = LayerConnectionParams::defaults_for_layer_ids(0, 1);
    connection_params.projection_params.long_term_stdp_params = Some(STDP_PARAMS);
    connection_params
        .projection_params
        .synapse_params
        .max_weight = 1.5;
    connection_params.initial_syn_weight = InitialSynWeight::Constant(1.0);
    params.layer_connections.push(connection_params);
    let mut instance = create_instance(params).unwrap();
    // Tick 0: pre-synaptic spike; the 1.0 PSP presumably makes nid 1 fire at tick 1.
    tick(&mut instance, &[0]);
    // Tick 1.
    tick(&mut instance, &[]);
    // Tick 2: deliver the negative reward and take a snapshot in the same tick.
    let mut tick_input = TickInput::from_reward(-1.0);
    tick_input.extract_state_snapshot = true;
    let tick_result = instance.tick(&tick_input).unwrap();
    // The eligibility trace has decayed for one tick with tau 15.
    let expected_depression = 0.1 * (-1.0 / 15.0f32).exp();
    let synapse_states = tick_result.state_snapshot.unwrap().synapse_states;
    assert_approx_eq!(f32, synapse_states[0].weight, 1.0 - expected_depression);
}
/// Reward-modulated STDP: an STDP value of 0.2, presumably from the two pairings
/// produced by the doubled input at tick 0, sits in an eligibility trace (tau 1000). A
/// reward of 1.5 is given at tick 53.
///
/// The expected weight is `0.5 + 0.3 * 0.2 * 1.5 * exp(-59/1000)`. The 59-tick decay
/// presumably spans from the pairing to the end of the 10-tick dopamine conflation period
/// that contains tick 53. The new weight is checked directly at tick 1500 and as a PSP
/// at tick 1501.
#[test]
fn simple_dopamine_scenario() {
    let mut params = InstanceParams::default();
    let mut layer = LayerParams::default();
    layer.neuron_params.refractory_period = 1;
    layer.num_neurons = 1;
    params.layers.push(layer.clone());
    layer.plasticity_modulation_params = Some(PlasticityModulationParams {
        tau_eligibility_trace: 1000.0,
        eligibility_trace_delay: 0,
        dopamine_modulation_factor: 0.3,
        t_cutoff_eligibility_trace: 2000,
        dopamine_flush_period: 100,
        dopamine_conflation_period: 10,
    });
    params.layers.push(layer);
    let mut connection_params = LayerConnectionParams::defaults_for_layer_ids(0, 1);
    connection_params.initial_syn_weight = InitialSynWeight::Constant(0.5);
    connection_params
        .projection_params
        .synapse_params
        .max_weight = 1.5;
    connection_params.projection_params.long_term_stdp_params = Some(STDP_PARAMS);
    params.layer_connections.push(connection_params);
    let mut instance = create_instance(params).unwrap();
    // Tick 0: doubled input.
    tick(&mut instance, &[0, 0]);
    instance.tick_no_input_until(53);
    // Tick 53: reward.
    instance.tick(&TickInput::from_reward(1.5)).unwrap();
    instance.tick_no_input_until(1500);
    // Tick 1500: check the weight and send a probe spike.
    let tick_1500_result = tick_extract_snapshot(&mut instance, &[0]);
    let stdp_value = 0.2;
    let elig_trace_value = 1.5 * (-59.0 / 1000f32).exp();
    let expected_weight = 0.5 + 0.3 * stdp_value * elig_trace_value;
    assert_approx_eq!(
        f32,
        tick_1500_result.state_snapshot.unwrap().synapse_states[0].weight,
        expected_weight
    );
    // Tick 1501: the probe PSP carries the new weight.
    let tick_1501_result = tick_extract_snapshot(&mut instance, &[]);
    assert_approx_eq!(
        f32,
        tick_1501_result.state_snapshot.unwrap().neuron_states[1].voltage,
        expected_weight
    );
}
/// Short-term depression (STP: tau 800, p0 0.8, factor 0.6) on a 2 -> 2 projection with
/// weight 0.5.
///
/// Layer 0 holds nids 0-1 and layer 1 holds nids 2-3. The delays below are inferred from
/// the ticks being checked: 0 -> 2 and 1 -> 3 take 1 tick, 0 -> 3 takes 6 ticks.
///
/// Expected results:
/// - The first transmission on each synapse is scaled by p0.
/// - A transmission 1000 ticks later is scaled by `0.8 * (1 - 0.6 * exp(-1000/800))`.
#[test]
fn short_term_plasticity() {
    let mut params = InstanceParams::default();
    let mut layer = LayerParams::default();
    layer.neuron_params.refractory_period = 1;
    layer.num_neurons = 2;
    params.layers.push(layer.clone());
    params.layers.push(layer);
    let mut connection_params = LayerConnectionParams::defaults_for_layer_ids(0, 1);
    connection_params.initial_syn_weight = InitialSynWeight::Constant(0.5);
    connection_params.conduction_delay_position_distance_scale_factor = 5.0;
    connection_params
        .projection_params
        .synapse_params
        .max_weight = 1.5;
    connection_params.projection_params.long_term_stdp_params = Some(STDP_PARAMS);
    connection_params.projection_params.stp_params = StpParams::Depression {
        tau: 800.0,
        p0: 0.8,
        factor: 0.6,
    };
    params.layer_connections.push(connection_params);
    let mut instance = create_instance(params).unwrap();
    // Tick 0: nid 0 fires; first transmission on 0 -> 2 arrives at tick 1.
    tick(&mut instance, &[0]);
    let tick_1_result = tick_extract_snapshot(&mut instance, &[]);
    assert_approx_eq!(
        f32,
        tick_1_result.state_snapshot.unwrap().neuron_states[2].voltage,
        0.5 * 0.8
    );
    // First transmission on the longer 0 -> 3 synapse arrives at tick 6.
    instance.tick_no_input_until(6);
    let tick_6_result = tick_extract_snapshot(&mut instance, &[]);
    assert_approx_eq!(
        f32,
        tick_6_result.state_snapshot.unwrap().neuron_states[3].voltage,
        0.5 * 0.8
    );
    // Tick 500: nid 1 fires; the first transmission on 1 -> 3 is also scaled by p0.
    instance.tick_no_input_until(500);
    tick(&mut instance, &[1]);
    // This snapshot is taken at tick 501, so it is named after that tick (it was
    // previously mislabeled `tick_8_result`).
    let tick_501_result = tick_extract_snapshot(&mut instance, &[]);
    assert_approx_eq!(
        f32,
        tick_501_result.state_snapshot.unwrap().neuron_states[3].voltage,
        0.5 * 0.8
    );
    // Tick 1000: nid 0 fires again, 1000 ticks after its previous transmission.
    instance.tick_no_input_until(1000);
    tick(&mut instance, &[0]);
    let tick_1001_result = tick_extract_snapshot(&mut instance, &[]);
    let expected_stp_factor = 0.8 * (1.0 - 0.6 * (-1000.0 / 800.0f32).exp());
    assert_approx_eq!(
        f32,
        tick_1001_result.state_snapshot.unwrap().neuron_states[2].voltage,
        0.5 * expected_stp_factor
    );
    // The same depressed factor applies on the longer 0 -> 3 synapse (arrival at 1006).
    instance.tick_no_input_until(1006);
    let tick_1006_result = tick_extract_snapshot(&mut instance, &[]);
    assert_approx_eq!(
        f32,
        tick_1006_result.state_snapshot.unwrap().neuron_states[3].voltage,
        0.5 * expected_stp_factor
    );
}
/// Builds the larger randomized scenario shared by the equivalence tests.
///
/// Layers:
/// - `layers[0]`: 800 neurons, refractory period 10, with plasticity modulation.
/// - `layers[1]`: 200 neurons, tau_membrane 4, refractory period 5, no modulation.
///
/// Connections, in this order (callers index `layer_connections[0]`, so keep it):
/// - `[0]` 0 -> 0: randomized initial weight (`Randomized(0.5)`), density 0.1, width 2,
///   up to 20 ticks of random delay, max_weight 0.5, default long-term STDP, short-term
///   STDP and STP depression.
/// - `[1]` 0 -> 1: same as `[0]` but density 0.25 and weight scale factor 2.
/// - `[2]` 1 -> 0: inhibitory (scale -1), constant weight 0.85, no STDP, no STP, no
///   random delay.
/// - `[3]` 1 -> 1: same as `[2]`.
fn get_scenario_template_params() -> InstanceParams {
    let mut params = InstanceParams::default();
    let mut layer = LayerParams::default();
    layer.neuron_params.refractory_period = 10;
    layer.num_neurons = 800;
    layer.plasticity_modulation_params = Some(PlasticityModulationParams {
        tau_eligibility_trace: 1000.0,
        eligibility_trace_delay: 20,
        dopamine_modulation_factor: 1.5,
        t_cutoff_eligibility_trace: 1000,
        dopamine_flush_period: 100,
        dopamine_conflation_period: 50,
    });
    params.layers.push(layer.clone());
    // Layer 1 reuses layer 0's settings except for the fields overridden here.
    layer.plasticity_modulation_params = None;
    layer.num_neurons = 200;
    layer.neuron_params.tau_membrane = 4.0;
    layer.neuron_params.refractory_period = 5;
    params.layers.push(layer);
    // Connection [0]: 0 -> 0.
    let mut connection_params = LayerConnectionParams::defaults_for_layer_ids(0, 0);
    connection_params.initial_syn_weight = InitialSynWeight::Randomized(0.5);
    connection_params.conduction_delay_position_distance_scale_factor = 0.0;
    connection_params.connect_width = 2.0;
    connection_params.connect_density = 0.1;
    connection_params.conduction_delay_max_random_part = 20;
    connection_params
        .projection_params
        .synapse_params
        .max_weight = 0.5;
    connection_params.projection_params.long_term_stdp_params = Some(StdpParams::default());
    // NOTE(review): factor_pre_after_post is positive here, unlike in STDP_PARAMS; the
    // equivalence tests do not depend on its sign, but confirm it is intended.
    connection_params.projection_params.short_term_stdp_params = Some(ShortTermStdpParams {
        stdp_params: StdpParams {
            factor_pre_before_post: 0.01,
            tau_pre_before_post: 20.0,
            factor_pre_after_post: 0.012,
            tau_pre_after_post: 20.0,
        },
        tau: 500.0,
    });
    connection_params.projection_params.stp_params = StpParams::Depression {
        tau: 800.0,
        p0: 0.9,
        factor: 0.2,
    };
    params.layer_connections.push(connection_params.clone());
    // Connection [1]: 0 -> 1, derived from [0].
    connection_params.connect_density = 0.25;
    connection_params.to_layer_id = 1;
    connection_params
        .projection_params
        .synapse_params
        .weight_scale_factor = 2.0;
    params.layer_connections.push(connection_params.clone());
    // Connection [2]: 1 -> 0, inhibitory and without any plasticity.
    connection_params.from_layer_id = 1;
    connection_params.to_layer_id = 0;
    connection_params.initial_syn_weight = InitialSynWeight::Constant(0.85);
    connection_params.projection_params.long_term_stdp_params = None;
    connection_params.projection_params.short_term_stdp_params = None;
    connection_params.projection_params.stp_params = StpParams::NoStp;
    connection_params.conduction_delay_max_random_part = 0;
    connection_params
        .projection_params
        .synapse_params
        .weight_scale_factor = -1.0;
    params.layer_connections.push(connection_params.clone());
    // Connection [3]: 1 -> 1, otherwise identical to [2].
    connection_params.to_layer_id = 1;
    params.layer_connections.push(connection_params);
    params
}
/// Drives every instance in `instances` with the same reproducible pseudo-random input
/// for `t_stop` ticks and asserts that they all behave identically.
///
/// Each tick:
/// - 5 distinct in-channels are drawn from `0..800`;
/// - a reward is drawn uniformly from `[0.0, 0.005)`, using an RNG seeded with 0.
///
/// The spiking output channels, spiking neuron ids and synaptic transmission count of
/// every instance are compared against those of the last instance.
///
/// # Panics
///
/// Panics in any of these cases:
/// - fewer than two instances are passed, since there would be nothing to compare;
/// - any tick fails;
/// - any instance diverges from the reference. The message names the tick and the
///   instance index.
fn assert_equivalence(instances: &mut [Instance], t_stop: usize) {
    assert!(
        instances.len() >= 2,
        "assert_equivalence needs at least two instances to compare, got {}",
        instances.len()
    );
    // 800 is the neuron count of layer 0 in `get_scenario_template_params`.
    let all_in_channels: Vec<usize> = (0..800).collect();
    let mut rng = StdRng::seed_from_u64(0);
    let reward_dist = Uniform::new(0.0, 0.005);
    // One input, reset every tick and passed to all instances, so they see identical input.
    let mut tick_input = TickInput::new();
    for tick_period in 0..t_stop {
        tick_input.reset();
        tick_input.spiking_in_channel_ids = all_in_channels
            .choose_multiple(&mut rng, 5)
            .copied()
            .collect();
        tick_input.reward = reward_dist.sample(&mut rng);
        let mut tick_results = instances
            .iter_mut()
            .map(|instance| instance.tick(&tick_input).unwrap())
            .collect_vec();
        // The last instance is the reference; cannot fail since len >= 2 was asserted.
        let cmp_result = tick_results.pop().unwrap();
        for (instance_idx, tick_result) in tick_results.into_iter().enumerate() {
            assert_eq!(
                tick_result.spiking_out_channel_ids,
                cmp_result.spiking_out_channel_ids,
                "spiking_out_channel_ids of instance {} differ from the last instance at tick {}",
                instance_idx,
                tick_period
            );
            assert_eq!(
                tick_result.spiking_nids,
                cmp_result.spiking_nids,
                "spiking_nids of instance {} differ from the last instance at tick {}",
                instance_idx,
                tick_period
            );
            assert_eq!(
                tick_result.synaptic_transmission_count,
                cmp_result.synaptic_transmission_count,
                "synaptic_transmission_count of instance {} differs from the last instance at tick {}",
                instance_idx,
                tick_period
            );
        }
    }
}
/// Thread-count invariance: the template scenario must produce identical tick results
/// for its first 102 ticks whether it runs on 1, 6 or 7 threads.
#[test]
fn invariance_partitioning_buffering() {
    let mut instances: Vec<Instance> = [1, 6, 7]
        .iter()
        .map(|&num_threads| {
            let mut params = get_scenario_template_params();
            params.technical_params.num_threads = Some(num_threads);
            instance::create_instance(params).unwrap()
        })
        .collect();
    assert_equivalence(&mut instances, 102);
}
/// Long-term STDP whose factors are zero must behave exactly like no long-term STDP.
///
/// Plasticity modulation on layer 0 is disabled, and STP and short-term STDP are removed
/// from connection 0, so only long-term STDP differs between the two instances. The other
/// connections keep their template settings in both instances.
#[test]
fn zero_vs_absent_long_term_stdp() {
    let mut params = get_scenario_template_params();
    params.layers[0].plasticity_modulation_params = None;
    params.layer_connections[0].projection_params.stp_params = StpParams::NoStp;
    params.layer_connections[0]
        .projection_params
        .short_term_stdp_params = None;
    // Instance 1: long-term STDP present but with zero factors.
    params.layer_connections[0]
        .projection_params
        .long_term_stdp_params = Some(StdpParams {
        factor_pre_before_post: 0.0,
        tau_pre_before_post: 20.0,
        factor_pre_after_post: 0.0,
        tau_pre_after_post: 20.0,
    });
    let mut instances = Vec::new();
    instances.push(instance::create_instance(params.clone()).unwrap());
    // Instance 2: long-term STDP absent.
    params.layer_connections[0]
        .projection_params
        .long_term_stdp_params = None;
    instances.push(instance::create_instance(params).unwrap());
    assert_equivalence(&mut instances, 100);
}
/// Short-term STDP whose factors are zero must behave exactly like no short-term STDP.
///
/// Plasticity modulation on layer 0 is disabled, and STP and long-term STDP are removed
/// from connection 0, so only short-term STDP differs between the two instances.
#[test]
fn zero_vs_absent_short_term_stdp() {
    let mut params = get_scenario_template_params();
    params.layers[0].plasticity_modulation_params = None;
    params.layer_connections[0].projection_params.stp_params = StpParams::NoStp;
    params.layer_connections[0]
        .projection_params
        .long_term_stdp_params = None;
    // Instance 1: short-term STDP present but with zero factors.
    params.layer_connections[0]
        .projection_params
        .short_term_stdp_params = Some(ShortTermStdpParams {
        stdp_params: StdpParams {
            factor_pre_before_post: 0.0,
            tau_pre_before_post: 20.0,
            factor_pre_after_post: 0.0,
            tau_pre_after_post: 20.0,
        },
        tau: 500.0,
    });
    let mut instances = Vec::new();
    instances.push(instance::create_instance(params.clone()).unwrap());
    // Instance 2: short-term STDP absent.
    params.layer_connections[0]
        .projection_params
        .short_term_stdp_params = None;
    instances.push(instance::create_instance(params).unwrap());
    assert_equivalence(&mut instances, 100);
}
/// Compares the unmodified template scenario with a variant where two things are removed:
/// plasticity modulation on layer 0 and long-term STDP on connection 0 (0 -> 0). The test
/// asserts the two behave identically for the first 110 ticks.
///
/// Connection 1 (0 -> 1) keeps its long-term STDP in both instances.
#[test]
fn zero_vs_absent_plasticity_modulation() {
    let mut params = get_scenario_template_params();
    let mut instances = Vec::new();
    // Instance 1: template as-is.
    instances.push(instance::create_instance(params.clone()).unwrap());
    // Instance 2: no modulation on layer 0 and no long-term STDP on connection 0.
    params.layers[0].plasticity_modulation_params = None;
    params.layer_connections[0]
        .projection_params
        .long_term_stdp_params = None;
    instances.push(instance::create_instance(params).unwrap());
    assert_equivalence(&mut instances, 110);
}
/// Checks the contents and ordering of a state snapshot.
///
/// Setup:
/// - Layer 0 has nids 0-4 and layer 1 has nids 5-6.
/// - Connection 0 -> 0 has weight 0.5; connection 0 -> 1 has weight 1.0. Both use
///   STDP with max_weight 1.5.
/// - In-channel 0 spikes at tick 0, and nids 5 and 6 fire at tick 1.
///
/// The index arithmetic below assumes this ordering:
/// - the 0 -> 0 synapses occupy indices 0..25 (5 x 5, ordered by pre then post nid);
/// - the 0 -> 1 synapses start at index 25.
///
/// Only the synapses from nid 0 to the firing nids 5 and 6 are expected to be potentiated
/// (to 1.1).
#[test]
fn state_snapshot() {
    let mut params = InstanceParams::default();
    let mut layer = LayerParams::default();
    layer.num_neurons = 5;
    params.layers.push(layer.clone());
    layer.num_neurons = 2;
    params.layers.push(layer);
    let mut conn_params = LayerConnectionParams::defaults_for_layer_ids(0, 0);
    conn_params.projection_params.long_term_stdp_params = Some(STDP_PARAMS);
    conn_params.projection_params.synapse_params.max_weight = 1.5;
    conn_params.initial_syn_weight = InitialSynWeight::Constant(0.5);
    params.layer_connections.push(conn_params.clone());
    conn_params.to_layer_id = 1;
    conn_params.initial_syn_weight = InitialSynWeight::Constant(1.0);
    params.layer_connections.push(conn_params);
    let mut instance = create_instance(params).unwrap();
    // Tick 0: nid 0 fires.
    tick(&mut instance, &[0]);
    // Tick 1: the output neurons fire and the snapshot is taken.
    let tick_1_result = tick_extract_snapshot(&mut instance, &[]);
    assert_equal(tick_1_result.spiking_nids, [5, 6]);
    assert_equal(tick_1_result.spiking_out_channel_ids, [0, 1]);
    let state_snapshot = tick_1_result.state_snapshot.unwrap();
    // Synapses 0 -> 1..=4 within layer 0: unchanged weight 0.5.
    for i in 1..5 {
        assert_eq!(state_snapshot.synapse_states[i].pre_syn_nid, 0);
        assert_eq!(state_snapshot.synapse_states[i].post_syn_nid, i);
        assert_approx_eq!(f32, state_snapshot.synapse_states[i].weight, 0.5);
    }
    // Synapses 0 -> 5 and 0 -> 6: potentiated to 1.1.
    for i in 25..27 {
        assert_eq!(state_snapshot.synapse_states[i].pre_syn_nid, 0);
        assert_eq!(state_snapshot.synapse_states[i].post_syn_nid, i - 20);
        assert_approx_eq!(f32, state_snapshot.synapse_states[i].weight, 1.1);
    }
    // Synapses from nids 1-4 into layer 1: unchanged weight 1.0.
    for pre_syn_nid in 1..5 {
        for post_syn_nid in 5..7 {
            let idx = 25 + 2 * pre_syn_nid + post_syn_nid - 5;
            assert_eq!(state_snapshot.synapse_states[idx].pre_syn_nid, pre_syn_nid);
            assert_eq!(
                state_snapshot.synapse_states[idx].post_syn_nid,
                post_syn_nid
            );
            assert_approx_eq!(f32, state_snapshot.synapse_states[idx].weight, 1.0);
        }
    }
}
/// Checks that `allow_self_innervation = false` prevents synapses from a neuron onto
/// itself.
///
/// Layer 0 has nids 0-1 and layer 1 has nids 2-6. The within-layer projection 1 -> 1 is
/// the only one where self-innervation could occur. The test first asserts that this
/// projection actually produced synapses, so that the self-innervation check cannot pass
/// vacuously.
#[test]
fn no_self_innervation() {
    let mut params = InstanceParams::default();
    let mut layer = LayerParams::default();
    layer.num_neurons = 2;
    params.layers.push(layer.clone());
    layer.num_neurons = 5;
    params.layers.push(layer);
    let mut conn_params = LayerConnectionParams::defaults_for_layer_ids(1, 0);
    params.layer_connections.push(conn_params.clone());
    conn_params.to_layer_id = 1;
    conn_params.allow_self_innervation = false;
    params.layer_connections.push(conn_params);
    let mut instance = create_instance(params).unwrap();
    let tick_result = tick_extract_snapshot(&mut instance, &[]);
    let state_snapshot = tick_result.state_snapshot.unwrap();
    // Layer 1 starts after the 2 neurons of layer 0.
    let first_layer_1_nid = 2;
    assert!(
        state_snapshot
            .synapse_states
            .iter()
            .any(|s| s.pre_syn_nid >= first_layer_1_nid && s.post_syn_nid >= first_layer_1_nid),
        "the 1 -> 1 projection created no synapses, so the self-innervation check would be vacuous"
    );
    for synapse_state in state_snapshot.synapse_states {
        assert_ne!(synapse_state.pre_syn_nid, synapse_state.post_syn_nid);
    }
}