cge/
gene.rs

1//! Different types of genes that can be used in a [`Network`][crate::Network] genome.
2
3use num_traits::Float;
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7/// A bias gene.
8///
9/// Adds a constant value to the [`Network`][crate::Network].
10#[derive(Clone, Debug, PartialEq)]
11#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
12pub struct Bias<T> {
13    value: T,
14}
15
16impl<T: Float> Bias<T> {
17    /// Returns a new `Bias` that adds a constant `value` to the [`Network`][crate::Network].
18    pub fn new(value: T) -> Self {
19        Self { value }
20    }
21
22    /// Returns the value of the `Bias`.
23    pub fn value(&self) -> T {
24        self.value
25    }
26
27    /// Returns a mutable reference to the value of the `Bias`.
28    pub fn mut_value(&mut self) -> &mut T {
29        &mut self.value
30    }
31}
32
/// The ID of a [`Network`][crate::Network]'s [`Input`].
///
/// A newtype around a raw `usize` id.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct InputId(usize);

impl InputId {
    /// Wraps the raw `id` in an `InputId`.
    pub fn new(id: usize) -> Self {
        InputId(id)
    }

    /// Unwraps this `InputId` into its raw `usize` id.
    pub fn as_usize(&self) -> usize {
        let InputId(raw) = *self;
        raw
    }
}
49
50/// An input gene.
51///
52/// Adds a connection to one of the [`Network`][crate::Network] inputs.
53#[derive(Clone, Debug, PartialEq)]
54#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
55pub struct Input<T> {
56    // The ID of the network input referred to
57    id: InputId,
58    weight: T,
59}
60
61impl<T: Float> Input<T> {
62    /// Returns a new `Input` that connects to the [`Network`][crate::Network] input with the id
63    /// and weights it by `weight`.
64    pub fn new(id: InputId, weight: T) -> Self {
65        Self { id, weight }
66    }
67
68    /// Returns the id of the [`Network`][crate::Network] input this `Input` refers to.
69    pub fn id(&self) -> InputId {
70        self.id
71    }
72
73    /// Returns the weight of this `Input`.
74    pub fn weight(&self) -> T {
75        self.weight
76    }
77
78    /// Returns a mutable reference to the weight of this `Input`.
79    pub fn mut_weight(&mut self) -> &mut T {
80        &mut self.weight
81    }
82}
83
/// The ID of a [`Neuron`] in a [`Network`][crate::Network].
///
/// A newtype around a raw `usize` id.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct NeuronId(usize);

impl NeuronId {
    /// Wraps the raw `id` in a `NeuronId`.
    pub fn new(id: usize) -> Self {
        NeuronId(id)
    }

    /// Unwraps this `NeuronId` into its raw `usize` id.
    pub fn as_usize(&self) -> usize {
        let NeuronId(raw) = *self;
        raw
    }
}
100
101/// A neuron gene.
102///
103/// Takes some number of incoming connections and applies the activation function to their sum.
104#[derive(Clone, Debug, PartialEq)]
105#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
106pub struct Neuron<T: Float> {
107    // The ID of this neuron
108    id: NeuronId,
109    // The number of incoming connections to this neuron
110    num_inputs: usize,
111    // The weight to apply to the result of the activation function
112    // Note that this weight is not used when the neuron is referred to by a jumper connection; the
113    // jumper's weight is used instead
114    weight: T,
115    // The unweighted value outputted by this neuron during the current network evaluation if it has
116    // been calculated already
117    #[cfg_attr(feature = "serde", serde(skip))]
118    #[cfg_attr(feature = "serde", serde(default = "Default::default"))]
119    current_value: Option<T>,
120    // The unweighted value outputted by this neuron during the previous network evaluation
121    #[cfg_attr(feature = "serde", serde(skip))]
122    #[cfg_attr(feature = "serde", serde(default = "T::zero"))]
123    previous_value: T,
124}
125
126impl<T: Float> Neuron<T> {
127    /// Returns a new `Neuron` that takes `num_inputs` inputs and weights its output by `weight`.
128    pub fn new(id: NeuronId, num_inputs: usize, weight: T) -> Self {
129        Self {
130            id,
131            num_inputs,
132            weight,
133            current_value: None,
134            previous_value: T::zero(),
135        }
136    }
137
138    /// Returns the id of this `Neuron`.
139    pub fn id(&self) -> NeuronId {
140        self.id
141    }
142
143    /// Returns the number of inputs required by this `Neuron`.
144    pub fn num_inputs(&self) -> usize {
145        self.num_inputs
146    }
147
148    /// Sets the number of inputs required by this `Neuron`.
149    pub(crate) fn set_num_inputs(&mut self, num_inputs: usize) {
150        self.num_inputs = num_inputs;
151    }
152
153    /// Returns the weight of this `Neuron`.
154    pub fn weight(&self) -> T {
155        self.weight
156    }
157
158    /// Returns a mutable reference to the weight of this `Neuron`.
159    pub(crate) fn mut_weight(&mut self) -> &mut T {
160        &mut self.weight
161    }
162
163    pub(crate) fn current_value(&self) -> Option<T> {
164        self.current_value
165    }
166
167    pub(crate) fn set_current_value(&mut self, value: Option<T>) {
168        self.current_value = value;
169    }
170
171    /// Returns the output value of this `Neuron` from the previous [`Network`][crate::Network]
172    /// evaluation, or zero if either no evaluation has occurred yet or
173    /// [`Network::clear_state`][crate::Network::clear_state] was called.
174    pub fn previous_value(&self) -> T {
175        self.previous_value
176    }
177
178    pub(crate) fn mut_previous_value(&mut self) -> &mut T {
179        &mut self.previous_value
180    }
181}
182
183/// A forward jumper gene.
184///
185/// Adds a connection to the output of a source neuron with a higher depth than the parent neuron
186/// of the jumper.
187#[derive(Clone, Debug, PartialEq)]
188#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
189pub struct ForwardJumper<T> {
190    // The ID of the source neuron
191    source_id: NeuronId,
192    // The weight of the forward jumper connection
193    // This replaces the weight of the source neuron
194    weight: T,
195}
196
197impl<T: Float> ForwardJumper<T> {
198    /// Returns a new `ForwardJumper` that connects to the output of the neuron with the id and
199    /// weights it by `weight`.
200    pub fn new(source_id: NeuronId, weight: T) -> Self {
201        Self { source_id, weight }
202    }
203
204    /// Returns the id of the source neuron of this `ForwardJumper`.
205    pub fn source_id(&self) -> NeuronId {
206        self.source_id
207    }
208
209    /// Returns the weight of this `ForwardJumper`.
210    pub fn weight(&self) -> T {
211        self.weight
212    }
213
214    /// Returns a mutable reference to the weight of this `ForwardJumper`.
215    pub fn mut_weight(&mut self) -> &mut T {
216        &mut self.weight
217    }
218}
219
220/// A recurrent jumper gene.
221///
222/// Adds a connection to the output from the previous [`Network`][crate::Network] evaluation of a
223/// source [`Neuron`] with any depth.
224#[derive(Clone, Debug, PartialEq)]
225#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
226pub struct RecurrentJumper<T> {
227    // The ID of the source neuron
228    source_id: NeuronId,
229    // The weight of the forward jumper connection
230    // This replaces the weight of the source neuron
231    weight: T,
232}
233
234impl<T: Float> RecurrentJumper<T> {
235    /// Returns a new `RecurrentJumper` that connects to the output of the neuron with the id and
236    /// weights it by `weight`.
237    pub fn new(source_id: NeuronId, weight: T) -> Self {
238        Self { source_id, weight }
239    }
240
241    /// Returns the id of the source neuron of this `ForwardJumper`.
242    pub fn source_id(&self) -> NeuronId {
243        self.source_id
244    }
245
246    /// Returns the weight of this `RecurrentJumper`.
247    pub fn weight(&self) -> T {
248        self.weight
249    }
250
251    /// Returns a mutable reference to the weight of this `RecurrentJumper`.
252    pub fn mut_weight(&mut self) -> &mut T {
253        &mut self.weight
254    }
255}
256
257/// A single gene in a genome, which can be either a [`Bias`], [`Input`], [`Neuron`],
258/// [`ForwardJumper`], or [`RecurrentJumper`].
259#[derive(Clone, Debug, PartialEq)]
260#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
261#[cfg_attr(feature = "serde", serde(tag = "kind"))]
262#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
263pub enum Gene<T: Float> {
264    /// See [`Bias`].
265    Bias(Bias<T>),
266    /// See [`Input`].
267    Input(Input<T>),
268    /// See [`Neuron`].
269    Neuron(Neuron<T>),
270    /// See [`ForwardJumper`].
271    ForwardJumper(ForwardJumper<T>),
272    /// See [`RecurrentJumper`].
273    RecurrentJumper(RecurrentJumper<T>),
274}
275
276impl<T: Float> Gene<T> {
277    /// Returns the weight of this `Gene` or its value if it is a [`Bias`].
278    pub fn weight(&self) -> T {
279        match self {
280            Self::Bias(bias) => bias.value(),
281            Self::Input(input) => input.weight(),
282            Self::Neuron(neuron) => neuron.weight(),
283            Self::ForwardJumper(forward) => forward.weight(),
284            Self::RecurrentJumper(recurrent) => recurrent.weight(),
285        }
286    }
287
288    /// Returns a mutable reference to the weight of this `Gene` or its value if it is a [`Bias`].
289    pub(crate) fn mut_weight(&mut self) -> &mut T {
290        match self {
291            Self::Bias(bias) => bias.mut_value(),
292            Self::Input(input) => input.mut_weight(),
293            Self::Neuron(neuron) => neuron.mut_weight(),
294            Self::ForwardJumper(forward) => forward.mut_weight(),
295            Self::RecurrentJumper(recurrent) => recurrent.mut_weight(),
296        }
297    }
298
299    /// Returns whether this is a [`Bias`] gene.
300    pub fn is_bias(&self) -> bool {
301        matches!(self, Self::Bias(_))
302    }
303
304    /// Returns whether this is a [`Input`] gene.
305    pub fn is_input(&self) -> bool {
306        matches!(self, Self::Input(_))
307    }
308
309    /// Returns whether this is a [`Neuron`] gene.
310    pub fn is_neuron(&self) -> bool {
311        matches!(self, Self::Neuron(_))
312    }
313
314    /// Returns whether this is a [`ForwardJumper`] gene.
315    pub fn is_forward_jumper(&self) -> bool {
316        matches!(self, Self::ForwardJumper(_))
317    }
318
319    /// Returns whether this is a [`RecurrentJumper`] gene.
320    pub fn is_recurrent_jumper(&self) -> bool {
321        matches!(self, Self::RecurrentJumper(_))
322    }
323
324    /// Returns a reference to the contained [`Bias`] if this is a bias gene.
325    pub fn as_bias(&self) -> Option<&Bias<T>> {
326        if let Self::Bias(bias) = self {
327            Some(bias)
328        } else {
329            None
330        }
331    }
332
333    /// Returns a reference to the contained [`Input`] if this is an input gene.
334    pub fn as_input(&self) -> Option<&Input<T>> {
335        if let Self::Input(input) = self {
336            Some(input)
337        } else {
338            None
339        }
340    }
341
342    /// Returns a reference to the contained [`Neuron`] if this is a neuron gene.
343    pub fn as_neuron(&self) -> Option<&Neuron<T>> {
344        if let Self::Neuron(neuron) = self {
345            Some(neuron)
346        } else {
347            None
348        }
349    }
350
351    /// Returns a reference to the contained [`ForwardJumper`] if this is a forward jumper gene.
352    pub fn as_forward_jumper(&self) -> Option<&ForwardJumper<T>> {
353        if let Self::ForwardJumper(forward) = self {
354            Some(forward)
355        } else {
356            None
357        }
358    }
359
360    /// Returns a reference to the contained [`RecurrentJumper`] if this is a recurrent jumper gene.
361    pub fn as_recurrent_jumper(&self) -> Option<&RecurrentJumper<T>> {
362        if let Self::RecurrentJumper(recurrent) = self {
363            Some(recurrent)
364        } else {
365            None
366        }
367    }
368
369    /// Returns a mutable reference to the contained [`Neuron`] if this is a neuron gene.
370    pub(crate) fn as_mut_neuron(&mut self) -> Option<&mut Neuron<T>> {
371        if let Self::Neuron(neuron) = self {
372            Some(neuron)
373        } else {
374            None
375        }
376    }
377}
378
379impl<T: Float> From<Bias<T>> for Gene<T> {
380    fn from(bias: Bias<T>) -> Self {
381        Self::Bias(bias)
382    }
383}
384
385impl<T: Float> From<Input<T>> for Gene<T> {
386    fn from(input: Input<T>) -> Self {
387        Self::Input(input)
388    }
389}
390
391impl<T: Float> From<Neuron<T>> for Gene<T> {
392    fn from(neuron: Neuron<T>) -> Self {
393        Self::Neuron(neuron)
394    }
395}
396
397impl<T: Float> From<ForwardJumper<T>> for Gene<T> {
398    fn from(forward: ForwardJumper<T>) -> Self {
399        Self::ForwardJumper(forward)
400    }
401}
402
403impl<T: Float> From<RecurrentJumper<T>> for Gene<T> {
404    fn from(recurrent: RecurrentJumper<T>) -> Self {
405        Self::RecurrentJumper(recurrent)
406    }
407}
408
409/// Like [`Gene`], but cannot be a [`Neuron`] gene.
410#[derive(Clone, Debug, PartialEq)]
411pub enum NonNeuronGene<T> {
412    /// See [`Bias`].
413    Bias(Bias<T>),
414    /// See [`Input`].
415    Input(Input<T>),
416    /// See [`ForwardJumper`].
417    ForwardJumper(ForwardJumper<T>),
418    /// See [`RecurrentJumper`].
419    RecurrentJumper(RecurrentJumper<T>),
420}
421
422impl<T> From<Bias<T>> for NonNeuronGene<T> {
423    fn from(bias: Bias<T>) -> Self {
424        Self::Bias(bias)
425    }
426}
427
428impl<T> From<Input<T>> for NonNeuronGene<T> {
429    fn from(input: Input<T>) -> Self {
430        Self::Input(input)
431    }
432}
433
434impl<T> From<ForwardJumper<T>> for NonNeuronGene<T> {
435    fn from(forward: ForwardJumper<T>) -> Self {
436        Self::ForwardJumper(forward)
437    }
438}
439
440impl<T> From<RecurrentJumper<T>> for NonNeuronGene<T> {
441    fn from(recurrent: RecurrentJumper<T>) -> Self {
442        Self::RecurrentJumper(recurrent)
443    }
444}
445
446impl<T: Float> From<NonNeuronGene<T>> for Gene<T> {
447    fn from(gene: NonNeuronGene<T>) -> Self {
448        match gene {
449            NonNeuronGene::Bias(bias) => Self::Bias(bias),
450            NonNeuronGene::Input(input) => Self::Input(input),
451            NonNeuronGene::ForwardJumper(forward) => Self::ForwardJumper(forward),
452            NonNeuronGene::RecurrentJumper(recurrent) => Self::RecurrentJumper(recurrent),
453        }
454    }
455}