Skip to main content

feagi_brain_development/connectivity/core_morphologies/
bitmask.rs

1// Copyright 2025 Neuraville Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5Bitmask encoder/decoder morphology implementation.
6
7Supports function-type morphologies:
8- bitmask_encoder_x / bitmask_encoder_y / bitmask_encoder_z
9- bitmask_decoder_x / bitmask_decoder_y / bitmask_decoder_z
10*/
11
12use crate::types::BduResult;
13use feagi_npu_neural::types::{NeuronId, SynapticPsp, SynapticWeight};
14use feagi_npu_neural::SynapseType;
15
/// Coordinate axis along which a bitmask morphology encodes or decodes.
///
/// The other two axes are carried through unchanged (clamped to the destination).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BitmaskAxis {
    /// First coordinate component (`pos.0`).
    X,
    /// Second coordinate component (`pos.1`).
    Y,
    /// Third coordinate component (`pos.2`).
    Z,
}
22
/// Direction of the bitmask transformation along the chosen [`BitmaskAxis`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BitmaskMode {
    /// The source coordinate along the axis is a bit index; destinations are the
    /// values whose binary form (MSB-first) has that bit set.
    Encoder,
    /// The source coordinate along the axis is an encoded value; destinations are
    /// the bit indices (MSB-first) that are set in that value.
    Decoder,
}
28
29#[allow(clippy::too_many_arguments)]
30pub fn apply_bitmask_morphology_with_dimensions(
31    npu: &mut feagi_npu_burst_engine::DynamicNPU,
32    src_area_id: u32,
33    dst_area_id: u32,
34    src_dimensions: (usize, usize, usize),
35    dst_dimensions: (usize, usize, usize),
36    axis: BitmaskAxis,
37    mode: BitmaskMode,
38    weight: f32,
39    psp: f32,
40    synapse_attractivity: u8,
41    synapse_type: SynapseType,
42    delay_bursts: u8,
43) -> BduResult<u32> {
44    use crate::rng::get_rng;
45    use rand::Rng;
46
47    let mut rng = get_rng();
48
49    let src_neurons = npu.get_neurons_in_cortical_area(src_area_id);
50    if src_neurons.is_empty() {
51        return Ok(0);
52    }
53
54    if dst_dimensions.0 == 0 || dst_dimensions.1 == 0 || dst_dimensions.2 == 0 {
55        return Ok(0);
56    }
57
58    let mut dst_pos_map = std::collections::HashMap::new();
59    for dst_nid in npu.get_neurons_in_cortical_area(dst_area_id) {
60        if let Some(coords) = npu.get_neuron_coordinates(dst_nid) {
61            dst_pos_map.insert(coords, dst_nid);
62        }
63    }
64
65    let mut synapse_count = 0u32;
66
67    for src_nid in src_neurons {
68        let Some(src_pos) = npu.get_neuron_coordinates(src_nid) else {
69            continue;
70        };
71
72        let dst_positions =
73            map_positions_for_bitmask(src_pos, src_dimensions, dst_dimensions, axis, mode);
74
75        for dst_pos in dst_positions {
76            // Note: Keep nested conditionals to maintain Rust 2021 compatibility.
77            #[allow(clippy::collapsible_if)]
78            if let Some(&dst_nid) = dst_pos_map.get(&dst_pos) {
79                if rng.gen_range(0..100) < synapse_attractivity
80                    && npu
81                        .add_synapse(
82                            NeuronId(src_nid),
83                            NeuronId(dst_nid),
84                            SynapticWeight(weight),
85                            SynapticPsp(psp),
86                            synapse_type,
87                            0,
88                            delay_bursts,
89                        )
90                        .is_ok()
91                {
92                    synapse_count += 1;
93                }
94            }
95        }
96    }
97
98    Ok(synapse_count)
99}
100
101fn map_positions_for_bitmask(
102    src_pos: (u32, u32, u32),
103    src_dimensions: (usize, usize, usize),
104    dst_dimensions: (usize, usize, usize),
105    axis: BitmaskAxis,
106    mode: BitmaskMode,
107) -> Vec<(u32, u32, u32)> {
108    let src_axis_len = axis_len(src_dimensions, axis);
109    let dst_axis_len = axis_len(dst_dimensions, axis);
110    if src_axis_len == 0 || dst_axis_len == 0 {
111        return Vec::new();
112    }
113
114    let clamped = (
115        clamp_to_dim(src_pos.0, dst_dimensions.0),
116        clamp_to_dim(src_pos.1, dst_dimensions.1),
117        clamp_to_dim(src_pos.2, dst_dimensions.2),
118    );
119
120    match mode {
121        BitmaskMode::Encoder => {
122            let src_bit_index = axis_value(src_pos, axis) as usize;
123            if src_bit_index >= src_axis_len {
124                return Vec::new();
125            }
126
127            let mut out = Vec::new();
128            for dst_axis_index in 0..dst_axis_len {
129                if bit_is_set_msb(dst_axis_index as u32, src_bit_index, src_axis_len) {
130                    out.push(compose_pos(axis, dst_axis_index as u32, clamped));
131                }
132            }
133            out
134        }
135        BitmaskMode::Decoder => {
136            let encoded_axis_value = axis_value(src_pos, axis);
137            let mut out = Vec::new();
138            for dst_bit_index in 0..dst_axis_len {
139                if bit_is_set_msb(encoded_axis_value, dst_bit_index, dst_axis_len) {
140                    out.push(compose_pos(axis, dst_bit_index as u32, clamped));
141                }
142            }
143            out
144        }
145    }
146}
147
148#[inline]
149fn axis_len(dim: (usize, usize, usize), axis: BitmaskAxis) -> usize {
150    match axis {
151        BitmaskAxis::X => dim.0,
152        BitmaskAxis::Y => dim.1,
153        BitmaskAxis::Z => dim.2,
154    }
155}
156
157#[inline]
158fn axis_value(pos: (u32, u32, u32), axis: BitmaskAxis) -> u32 {
159    match axis {
160        BitmaskAxis::X => pos.0,
161        BitmaskAxis::Y => pos.1,
162        BitmaskAxis::Z => pos.2,
163    }
164}
165
166#[inline]
167fn compose_pos(
168    axis: BitmaskAxis,
169    axis_value: u32,
170    clamped_src_pos: (u32, u32, u32),
171) -> (u32, u32, u32) {
172    match axis {
173        BitmaskAxis::X => (axis_value, clamped_src_pos.1, clamped_src_pos.2),
174        BitmaskAxis::Y => (clamped_src_pos.0, axis_value, clamped_src_pos.2),
175        BitmaskAxis::Z => (clamped_src_pos.0, clamped_src_pos.1, axis_value),
176    }
177}
178
/// Clamps `value` into the valid index range `0..dim_len` of one dimension.
///
/// Returns `0` for an empty dimension (`dim_len == 0`). When the largest valid index
/// (`dim_len - 1`) does not fit in a `u32`, every `u32` is already in range, so `value`
/// is returned unchanged.
#[inline]
fn clamp_to_dim(value: u32, dim_len: usize) -> u32 {
    let Some(max_index) = dim_len.checked_sub(1) else {
        return 0;
    };
    // A plain `as u32` would truncate huge lengths (e.g. 2^32 + 1 -> 0) and clamp every
    // value to a bogus bound; saturate instead.
    u32::try_from(max_index).map_or(value, |max_index| value.min(max_index))
}
186
/// Reports whether bit `bit_index_from_msb` (counted from the most significant end of a
/// `bit_width`-bit word) is set in `value`.
///
/// Indices at or past `bit_width`, and bit positions that fall beyond the 32 bits of a
/// `u32`, are reported as not set.
#[inline]
fn bit_is_set_msb(value: u32, bit_index_from_msb: usize, bit_width: usize) -> bool {
    // Guard first: an index at or past the width names no bit, and this also keeps the
    // subtraction below from underflowing.
    if bit_index_from_msb >= bit_width {
        return false;
    }
    // MSB-first index i in a `bit_width`-bit word is LSB-first position (width - 1 - i).
    let shift = bit_width - 1 - bit_index_from_msb;
    // The short-circuit keeps the shift amount below 32, so the shift cannot overflow.
    shift < u32::BITS as usize && ((value >> shift) & 1) == 1
}
198
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoder_x_uses_msb_convention() {
        // Source X index 1 with width 3 checks middle bit in 3-bit destination values.
        let out = map_positions_for_bitmask(
            (1, 2, 3),
            (3, 10, 10),
            (8, 10, 10),
            BitmaskAxis::X,
            BitmaskMode::Encoder,
        );

        assert_eq!(
            out,
            vec![(2, 2, 3), (3, 2, 3), (6, 2, 3), (7, 2, 3)],
            "Expected X positions where middle bit is set for 3-bit values"
        );
    }

    #[test]
    fn test_decoder_x_uses_msb_convention() {
        // Encoded value 5 is 0101 in width-4 bitspace => active dst bits at indices 1 and 3.
        let out = map_positions_for_bitmask(
            (5, 1, 1),
            (16, 10, 10),
            (4, 10, 10),
            BitmaskAxis::X,
            BitmaskMode::Decoder,
        );

        assert_eq!(out, vec![(1, 1, 1), (3, 1, 1)]);
    }

    #[test]
    fn test_encoder_rejects_bit_index_outside_source_width() {
        // Source width is 3 bits, so bit index 3 does not exist.
        let out = map_positions_for_bitmask(
            (3, 0, 0),
            (3, 1, 1),
            (8, 1, 1),
            BitmaskAxis::X,
            BitmaskMode::Encoder,
        );

        assert!(out.is_empty());
    }

    #[test]
    fn test_zero_length_axis_yields_no_positions() {
        for mode in [BitmaskMode::Encoder, BitmaskMode::Decoder] {
            let empty_src =
                map_positions_for_bitmask((0, 0, 0), (0, 1, 1), (8, 1, 1), BitmaskAxis::X, mode);
            assert!(empty_src.is_empty());

            let empty_dst =
                map_positions_for_bitmask((0, 0, 0), (8, 1, 1), (0, 1, 1), BitmaskAxis::X, mode);
            assert!(empty_dst.is_empty());
        }
    }

    #[test]
    fn test_encoder_y_clamps_other_axes_into_destination() {
        // Y bit 0 (MSB of a 2-bit word) is set in 2 and 3; X=4 and Z=7 clamp to 2 and 4.
        let out = map_positions_for_bitmask(
            (4, 0, 7),
            (4, 2, 9),
            (3, 4, 5),
            BitmaskAxis::Y,
            BitmaskMode::Encoder,
        );

        assert_eq!(out, vec![(2, 2, 4), (2, 3, 4)]);
    }

    #[test]
    fn test_decoder_zero_value_activates_no_bits() {
        let out = map_positions_for_bitmask(
            (0, 1, 1),
            (16, 10, 10),
            (4, 10, 10),
            BitmaskAxis::X,
            BitmaskMode::Decoder,
        );

        assert!(out.is_empty());
    }

    #[test]
    fn test_decoder_z_all_bits_set() {
        // Encoded 7 is 111 in a 3-bit word; Z=7 is replaced, X/Y stay at 1.
        let out = map_positions_for_bitmask(
            (1, 1, 7),
            (2, 2, 8),
            (2, 2, 3),
            BitmaskAxis::Z,
            BitmaskMode::Decoder,
        );

        assert_eq!(out, vec![(1, 1, 0), (1, 1, 1), (1, 1, 2)]);
    }

    #[test]
    fn test_bit_is_set_msb_edges() {
        assert!(bit_is_set_msb(0b100, 0, 3));
        assert!(!bit_is_set_msb(0b100, 2, 3));
        assert!(bit_is_set_msb(0b001, 2, 3));
        assert!(!bit_is_set_msb(u32::MAX, 3, 3));
        assert!(!bit_is_set_msb(1, 0, 40));
        assert!(bit_is_set_msb(1, 39, 40));
    }

    #[test]
    fn test_clamp_to_dim_basic() {
        assert_eq!(clamp_to_dim(5, 3), 2);
        assert_eq!(clamp_to_dim(1, 3), 1);
        assert_eq!(clamp_to_dim(7, 0), 0);
    }
}