micro_hnsw_wasm/
lib.rs

1//! Micro HNSW v2.3 - Neuromorphic HNSW with Novel Discoveries
2//! Target: <12KB WASM with multi-core support
3//!
4//! Features:
5//! - Multiple distance metrics (L2, Cosine, Dot)
6//! - Multi-core sharding (256 cores × 32 vectors = 8K total)
7//! - Batch operations with 4-vector batching
8//! - Beam search for better recall
9//! - Result merging across cores
10//! - Node types for Cypher-style typed graphs (16 types)
11//! - Edge weights for GNN message passing
12//! - Vector updates for online learning/GNN propagation
13//! - **Spiking Neural Network integration with LIF neurons**
14//! - **STDP (Spike-Timing Dependent Plasticity) learning**
15//!
16//! ## Novel Neuromorphic Discoveries (v2.3)
17//! - **Spike-Timing Vector Encoding**: Convert vectors to temporal spike patterns
18//! - **Homeostatic Plasticity**: Self-stabilizing network activity
19//! - **Oscillatory Resonance**: Frequency-tuned search amplification
20//! - **Temporal Pattern Recognition**: Spike-based similarity matching
21//! - **Winner-Take-All Circuits**: Competitive neural selection
22//! - **Dendritic Computation**: Non-linear local processing
23
24#![no_std]
25
26// ============ Configuration ============
/// Vectors stored per core (256 cores × 32 vectors = 8K total).
const MAX_VECTORS: usize = 32;
/// Maximum vector dimensionality.
const MAX_DIMS: usize = 16;
/// Maximum out-degree of a graph node.
const MAX_NEIGHBORS: usize = 6;
/// Candidates kept per beam-search iteration.
const BEAM_WIDTH: usize = 3;
31
32// ============ Spiking Neural Network Configuration ============
/// Membrane time constant of the LIF neurons (ms).
const TAU_MEMBRANE: f32 = 20.0;
/// Refractory period after a spike (ms).
const TAU_REFRAC: f32 = 2.0;
/// Membrane potential immediately after a spike.
const V_RESET: f32 = 0.0;
/// Resting membrane potential.
const V_REST: f32 = 0.0;
/// STDP potentiation magnitude.
const STDP_A_PLUS: f32 = 0.01;
/// STDP depression magnitude.
const STDP_A_MINUS: f32 = 0.012;
/// STDP time constant (ms).
const TAU_STDP: f32 = 20.0;
/// Pre-computed `1 / TAU_STDP`.
/// Derived from `TAU_STDP` so the two cannot drift out of sync.
const INV_TAU_STDP: f32 = 1.0 / TAU_STDP;
/// Pre-computed `1 / 255`, used to normalise 8-bit edge weights to [0, 1].
/// It is computed rather than hand-rounded: the old literal `0.00392157` made a
/// full-strength weight (255) normalise to slightly more than 1.0.
const INV_255: f32 = 1.0 / 255.0;
42
43// ============ Novel Neuromorphic Configuration ============
/// Homeostatic target firing rate (spikes per ms).
const HOMEOSTATIC_TARGET: f32 = 0.1;
/// Time constant of the homeostatic rate estimate (ms); deliberately slow.
const HOMEOSTATIC_TAU: f32 = 1000.0;
/// Frequency of the global gamma oscillator (Hz).
const OSCILLATOR_FREQ: f32 = 40.0;
/// Strength of winner-take-all lateral inhibition.
const WTA_INHIBITION: f32 = 0.8;
/// Exponent of the dendritic nonlinearity.
const DENDRITIC_NONLIN: f32 = 2.0;
/// Resolution of the spike-time vector encoding (number of time slots).
const SPIKE_ENCODING_RES: u8 = 8;
50
51// ============ Types ============
/// Distance metric used by the index.
///
/// The discriminants are the codes accepted by `init` and returned by `get_metric`.
#[derive(Clone, Copy, PartialEq)]
#[repr(u8)]
pub enum Metric {
    /// Squared Euclidean distance.
    L2 = 0,
    /// `1 - cos(a, b)`.
    Cosine = 1,
    /// Negated dot product.
    Dot = 2,
}
55
56#[derive(Clone, Copy)]
57#[repr(C)]
58pub struct Vector {
59    data: [f32; MAX_DIMS],
60    norm: f32,
61}
62
63#[derive(Clone, Copy)]
64#[repr(C)]
65pub struct Node {
66    neighbors: [u8; MAX_NEIGHBORS],
67    count: u8,
68}
69
/// A single search hit, exchanged with the host through raw pointers.
///
/// `idx == 255` marks an empty slot; a smaller `distance` means a closer hit.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct SearchResult {
    /// Vector index within its core (255 = empty slot).
    pub idx: u8,
    /// Core that produced the hit.
    pub core_id: u8,
    /// Metric-dependent distance; smaller is closer.
    pub distance: f32,
}
77
78#[repr(C)]
79pub struct MicroHnsw {
80    vectors: [Vector; MAX_VECTORS],
81    nodes: [Node; MAX_VECTORS],
82    count: u8,
83    dims: u8,
84    metric: Metric,
85    core_id: u8,
86}
87
88// ============ Static Storage ============
89static mut HNSW: MicroHnsw = MicroHnsw {
90    vectors: [Vector { data: [0.0; MAX_DIMS], norm: 0.0 }; MAX_VECTORS],
91    nodes: [Node { neighbors: [0; MAX_NEIGHBORS], count: 0 }; MAX_VECTORS],
92    count: 0,
93    dims: 16,
94    metric: Metric::L2,
95    core_id: 0,
96};
97
98static mut QUERY: [f32; MAX_DIMS] = [0.0; MAX_DIMS];
99static mut INSERT: [f32; MAX_DIMS] = [0.0; MAX_DIMS];
100static mut RESULTS: [SearchResult; 16] = [SearchResult { idx: 255, core_id: 0, distance: 0.0 }; 16];
101static mut GLOBAL: [SearchResult; 16] = [SearchResult { idx: 255, core_id: 0, distance: 0.0 }; 16];
102
103// ============ GNN/Cypher Extensions ============
104// Node types: 4 bits per node (16 types), packed 2 per byte = 16 bytes
105static mut NODE_TYPES: [u8; MAX_VECTORS / 2] = [0; 16];
106// Edge weights: 4 bits per edge (packed), uniform per node = 32 bytes
107static mut EDGE_WEIGHTS: [u8; MAX_VECTORS] = [255; 32];
108// Delta buffer for vector updates
109static mut DELTA: [f32; MAX_DIMS] = [0.0; MAX_DIMS];
110
111
112// ============ Spiking Neural Network State ============
113// Membrane potentials: LIF neuron states (one per vector)
114static mut MEMBRANE: [f32; MAX_VECTORS] = [0.0; MAX_VECTORS];
115// Adaptive thresholds: Dynamic firing thresholds
116static mut THRESHOLD: [f32; MAX_VECTORS] = [1.0; MAX_VECTORS];
117// Last spike time: For STDP calculations
118static mut LAST_SPIKE: [f32; MAX_VECTORS] = [-1000.0; MAX_VECTORS];
119// Refractory state: Time remaining in refractory period
120static mut REFRAC: [f32; MAX_VECTORS] = [0.0; MAX_VECTORS];
121// Current simulation time
122static mut SIM_TIME: f32 = 0.0;
123// Spike output buffer: Which neurons spiked this timestep
124static mut SPIKES: [bool; MAX_VECTORS] = [false; MAX_VECTORS];
125
126// ============ Novel Neuromorphic State ============
127// Homeostatic plasticity: running average spike rate
128static mut SPIKE_RATE: [f32; MAX_VECTORS] = [0.0; MAX_VECTORS];
129// Oscillator phase: gamma rhythm for synchronization
130static mut OSCILLATOR_PHASE: f32 = 0.0;
131// Dendritic compartments: local nonlinear integration
132static mut DENDRITE: [[f32; MAX_NEIGHBORS]; MAX_VECTORS] = [[0.0; MAX_NEIGHBORS]; MAX_VECTORS];
133// Temporal spike pattern buffer (recent spikes encoded as bits)
134static mut SPIKE_PATTERN: [u32; MAX_VECTORS] = [0; MAX_VECTORS];
135// Resonance amplitude for each neuron
136static mut RESONANCE: [f32; MAX_VECTORS] = [0.0; MAX_VECTORS];
137// Winner-take-all state: inhibition accumulator
138static mut WTA_INHIBIT: f32 = 0.0;
139
140// ============ Math ============
/// Approximate square root using the "fast inverse square root" bit trick
/// refined by one Newton-Raphson step. Returns 0.0 for `x <= 0.0`.
#[inline(always)]
fn sqrt_fast(x: f32) -> f32 {
    if x <= 0.0 {
        return 0.0;
    }
    // Initial 1/sqrt(x) estimate from the magic-constant bit hack.
    let guess = f32::from_bits(0x5f3759df - (x.to_bits() >> 1));
    // Newton correction factor for the 1/sqrt(x) estimate.
    let refine = 1.5 - 0.5 * x * guess * guess;
    // sqrt(x) = x * (1/sqrt(x)); evaluation order kept as (x * guess) * refine.
    x * guess * refine
}
148
149#[inline(always)]
150fn norm(v: &[f32], n: usize) -> f32 {
151    let mut s = 0.0f32;
152    let mut i = 0;
153    while i < n { s += v[i] * v[i]; i += 1; }
154    sqrt_fast(s)
155}
156
/// Squared Euclidean distance over the first `n` components.
///
/// The square root is omitted; ranking by squared distance gives the same order.
#[inline(always)]
fn dist_l2(a: &[f32], b: &[f32], n: usize) -> f32 {
    let mut acc = 0.0f32;
    for (x, y) in a[..n].iter().zip(&b[..n]) {
        let diff = x - y;
        acc += diff * diff;
    }
    acc
}
164
/// Negated dot product over the first `n` components (smaller = more similar).
#[inline(always)]
fn dist_dot(a: &[f32], b: &[f32], n: usize) -> f32 {
    let dot = a[..n]
        .iter()
        .zip(&b[..n])
        .fold(0.0f32, |acc, (x, y)| acc + x * y);
    -dot
}
172
/// Cosine distance `1 - a·b / (|a| |b|)` over the first `n` components, using
/// the caller-supplied norms `an` and `bn`. Returns 1.0 if either norm is zero.
#[inline(always)]
fn dist_cos(a: &[f32], an: f32, b: &[f32], bn: f32, n: usize) -> f32 {
    // A zero vector has no direction; treat it as orthogonal to everything.
    if an == 0.0 || bn == 0.0 {
        return 1.0;
    }
    let dot = a[..n]
        .iter()
        .zip(&b[..n])
        .fold(0.0f32, |acc, (x, y)| acc + x * y);
    1.0 - dot / (an * bn)
}
181
182#[inline(always)]
183fn distance(q: &[f32], qn: f32, idx: u8) -> f32 {
184    unsafe {
185        let n = HNSW.dims as usize;
186        let v = &HNSW.vectors[idx as usize];
187        match HNSW.metric {
188            Metric::Cosine => dist_cos(q, qn, &v.data[..n], v.norm, n),
189            Metric::Dot => dist_dot(q, &v.data[..n], n),
190            Metric::L2 => dist_l2(q, &v.data[..n], n),
191        }
192    }
193}
194
195// ============ Core API ============
196
197/// Initialize: init(dims, metric, core_id)
198/// metric: 0=L2, 1=Cosine, 2=Dot
199#[no_mangle]
200pub extern "C" fn init(dims: u8, metric: u8, core_id: u8) {
201    unsafe {
202        HNSW.count = 0;
203        HNSW.dims = dims.min(MAX_DIMS as u8);
204        HNSW.metric = match metric { 1 => Metric::Cosine, 2 => Metric::Dot, _ => Metric::L2 };
205        HNSW.core_id = core_id;
206    }
207}
208
209#[no_mangle]
210pub extern "C" fn get_insert_ptr() -> *mut f32 { unsafe { INSERT.as_mut_ptr() } }
211
212#[no_mangle]
213pub extern "C" fn get_query_ptr() -> *mut f32 { unsafe { QUERY.as_mut_ptr() } }
214
215#[no_mangle]
216pub extern "C" fn get_result_ptr() -> *const SearchResult { unsafe { RESULTS.as_ptr() } }
217
218#[no_mangle]
219pub extern "C" fn get_global_ptr() -> *const SearchResult { unsafe { GLOBAL.as_ptr() } }
220
/// Insert the vector currently held in the `INSERT` buffer.
///
/// Returns the new vector's index, or 255 if this core already holds
/// `MAX_VECTORS` vectors. The new node is linked to (up to) its
/// `MAX_NEIGHBORS` nearest existing vectors, found by a brute-force scan. A
/// reverse edge is added to each of those neighbours that still has a free
/// slot. Reverse edges are silently skipped for neighbours whose lists are
/// already full.
#[no_mangle]
pub extern "C" fn insert() -> u8 {
    unsafe {
        if HNSW.count >= MAX_VECTORS as u8 { return 255; }

        let idx = HNSW.count;
        let n = HNSW.dims as usize;

        // Copy vector and compute norm
        let mut i = 0;
        while i < n { HNSW.vectors[idx as usize].data[i] = INSERT[i]; i += 1; }
        HNSW.vectors[idx as usize].norm = norm(&INSERT[..n], n);
        // Start from an empty adjacency list (the slot may hold data from before a re-init).
        HNSW.nodes[idx as usize].count = 0;

        // Connect to nearest neighbors
        if idx > 0 {
            let qn = HNSW.vectors[idx as usize].norm;
            // `best`/`best_d` hold the nearest candidates seen so far, sorted by
            // ascending distance; `found` is how many slots are filled.
            let mut best = [0u8; MAX_NEIGHBORS];
            let mut best_d = [f32::MAX; MAX_NEIGHBORS];
            let mut found = 0usize;

            // Find M nearest (brute force over every earlier vector)
            let mut j = 0u8;
            while j < idx {
                let d = distance(&INSERT[..n], qn, j);
                // Accept if a slot is free, or if closer than the current worst.
                if found < MAX_NEIGHBORS || d < best_d[found.saturating_sub(1)] {
                    // Insertion sort: start at the first free slot (or the worst
                    // slot when full, which is evicted) and shift larger entries right.
                    let mut p = found.min(MAX_NEIGHBORS - 1);
                    while p > 0 && best_d[p - 1] > d {
                        if p < MAX_NEIGHBORS { best[p] = best[p - 1]; best_d[p] = best_d[p - 1]; }
                        p -= 1;
                    }
                    best[p] = j; best_d[p] = d;
                    if found < MAX_NEIGHBORS { found += 1; }
                }
                j += 1;
            }

            // Add bidirectional edges
            let mut k = 0;
            while k < found {
                let nb = best[k];
                // Forward edge: new node -> neighbour.
                let c = HNSW.nodes[idx as usize].count as usize;
                if c < MAX_NEIGHBORS {
                    HNSW.nodes[idx as usize].neighbors[c] = nb;
                    HNSW.nodes[idx as usize].count += 1;
                }
                // Reverse edge: neighbour -> new node, only if it has room (no pruning).
                let nc = HNSW.nodes[nb as usize].count as usize;
                if nc < MAX_NEIGHBORS {
                    HNSW.nodes[nb as usize].neighbors[nc] = idx;
                    HNSW.nodes[nb as usize].count += 1;
                }
                k += 1;
            }
        }

        HNSW.count += 1;
        idx
    }
}
281
/// Approximate k-nearest-neighbour search for the vector in `QUERY`.
///
/// This is not a layered HNSW walk. It starts at vector 0 and, for up to
/// `max(k, BEAM_WIDTH)` iterations, expands the unvisited neighbours of the
/// current beam. The next beam consists only of the `BEAM_WIDTH` closest nodes
/// discovered in that iteration; earlier beam members are not kept. The search
/// stops early once an iteration discovers nothing new.
///
/// Every visited node is ranked into `RESULTS` (16 slots, ascending distance,
/// `idx == 255` for unused slots). `k` is clamped to `min(k, 16, count)`.
/// Returns `min(number ranked, k)`, or 0 when the index is empty.
#[no_mangle]
pub extern "C" fn search(k: u8) -> u8 {
    unsafe {
        if HNSW.count == 0 { return 0; }

        let n = HNSW.dims as usize;
        let k = k.min(16).min(HNSW.count);
        // Query norm; only the Cosine metric actually uses it.
        let qn = norm(&QUERY[..n], n);

        // Reset
        let mut i = 0;
        while i < 16 { RESULTS[i] = SearchResult { idx: 255, core_id: HNSW.core_id, distance: f32::MAX }; i += 1; }

        let mut visited = [false; MAX_VECTORS];
        let mut beam = [255u8; BEAM_WIDTH];
        let mut beam_d = [f32::MAX; BEAM_WIDTH];

        // Start from entry point
        beam[0] = 0;
        beam_d[0] = distance(&QUERY[..n], qn, 0);
        visited[0] = true;
        RESULTS[0] = SearchResult { idx: 0, core_id: HNSW.core_id, distance: beam_d[0] };
        // rc = ranked results so far, bs = current beam size
        let mut rc = 1u8;
        let mut bs = 1usize;

        // Beam search iterations
        let mut iter = 0u8;
        while iter < k.max(BEAM_WIDTH as u8) && bs > 0 {
            // Next beam, built from nodes discovered in this iteration only.
            let mut nb = [255u8; BEAM_WIDTH];
            let mut nd = [f32::MAX; BEAM_WIDTH];
            let mut ns = 0usize;

            let mut b = 0;
            while b < bs {
                if beam[b] == 255 { b += 1; continue; }
                let node = &HNSW.nodes[beam[b] as usize];

                let mut j = 0u8;
                while j < node.count {
                    let nbr = node.neighbors[j as usize];
                    j += 1;
                    if visited[nbr as usize] { continue; }
                    visited[nbr as usize] = true;

                    let d = distance(&QUERY[..n], qn, nbr);

                    // Update beam (sorted insertion, evicting the worst when full)
                    if ns < BEAM_WIDTH || d < nd[ns.saturating_sub(1)] {
                        let mut p = ns.min(BEAM_WIDTH - 1);
                        while p > 0 && nd[p - 1] > d {
                            if p < BEAM_WIDTH { nb[p] = nb[p - 1]; nd[p] = nd[p - 1]; }
                            p -= 1;
                        }
                        nb[p] = nbr; nd[p] = d;
                        if ns < BEAM_WIDTH { ns += 1; }
                    }

                    // Update results (same sorted insertion over the 16 result slots)
                    if rc < 16 || d < RESULTS[(rc - 1) as usize].distance {
                        let mut p = rc.min(15) as usize;
                        while p > 0 && RESULTS[p - 1].distance > d {
                            if p < 16 { RESULTS[p] = RESULTS[p - 1]; }
                            p -= 1;
                        }
                        if p < 16 {
                            RESULTS[p] = SearchResult { idx: nbr, core_id: HNSW.core_id, distance: d };
                            if rc < 16 { rc += 1; }
                        }
                    }
                }
                b += 1;
            }

            beam = nb; beam_d = nd; bs = ns;
            iter += 1;
        }

        rc.min(k)
    }
}
363
364// ============ Multi-Core ============
365
366/// Merge results from another core into global buffer
367#[no_mangle]
368pub extern "C" fn merge(ptr: *const SearchResult, cnt: u8) -> u8 {
369    unsafe {
370        let mut gc = 0u8;
371        while gc < 16 && GLOBAL[gc as usize].idx != 255 { gc += 1; }
372
373        let mut i = 0u8;
374        while i < cnt.min(16) {
375            let r = &*ptr.add(i as usize);
376            i += 1;
377            if r.idx == 255 { continue; }
378
379            if gc < 16 || r.distance < GLOBAL[(gc - 1) as usize].distance {
380                let mut p = gc.min(15) as usize;
381                while p > 0 && GLOBAL[p - 1].distance > r.distance {
382                    if p < 16 { GLOBAL[p] = GLOBAL[p - 1]; }
383                    p -= 1;
384                }
385                if p < 16 {
386                    GLOBAL[p] = *r;
387                    if gc < 16 { gc += 1; }
388                }
389            }
390        }
391        gc
392    }
393}
394
395/// Clear global results
396#[no_mangle]
397pub extern "C" fn clear_global() {
398    unsafe {
399        let mut i = 0;
400        while i < 16 { GLOBAL[i] = SearchResult { idx: 255, core_id: 0, distance: f32::MAX }; i += 1; }
401    }
402}
403
404// ============ Info ============
405#[no_mangle]
406pub extern "C" fn count() -> u8 { unsafe { HNSW.count } }
407
408#[no_mangle]
409pub extern "C" fn get_core_id() -> u8 { unsafe { HNSW.core_id } }
410
411#[no_mangle]
412pub extern "C" fn get_metric() -> u8 { unsafe { HNSW.metric as u8 } }
413
414#[no_mangle]
415pub extern "C" fn get_dims() -> u8 { unsafe { HNSW.dims } }
416
417#[no_mangle]
418pub extern "C" fn get_capacity() -> u8 { MAX_VECTORS as u8 }
419
420// ============ Cypher Node Types ============
421
422/// Set node type (0-15) for Cypher-style typed queries
423/// Types packed 2 per byte (4 bits each)
424#[no_mangle]
425pub extern "C" fn set_node_type(idx: u8, node_type: u8) {
426    if idx >= MAX_VECTORS as u8 { return; }
427    unsafe {
428        let byte_idx = (idx / 2) as usize;
429        let node_type = node_type & 0x0F; // Clamp to 4 bits
430        if idx & 1 == 0 {
431            NODE_TYPES[byte_idx] = (NODE_TYPES[byte_idx] & 0xF0) | node_type;
432        } else {
433            NODE_TYPES[byte_idx] = (NODE_TYPES[byte_idx] & 0x0F) | (node_type << 4);
434        }
435    }
436}
437
438/// Get node type (0-15)
439#[no_mangle]
440pub extern "C" fn get_node_type(idx: u8) -> u8 {
441    if idx >= MAX_VECTORS as u8 { return 0; }
442    unsafe {
443        let byte_idx = (idx / 2) as usize;
444        if idx & 1 == 0 {
445            NODE_TYPES[byte_idx] & 0x0F
446        } else {
447            NODE_TYPES[byte_idx] >> 4
448        }
449    }
450}
451
452/// Check if node type matches mask (for filtering in JS/host)
453#[no_mangle]
454pub extern "C" fn type_matches(idx: u8, type_mask: u16) -> u8 {
455    ((type_mask >> get_node_type(idx)) & 1) as u8
456}
457
458// ============ GNN Edge Weights ============
459
460/// Set node edge weight (uniform for all edges from this node, 0-255)
461#[no_mangle]
462pub extern "C" fn set_edge_weight(node: u8, weight: u8) {
463    if node < MAX_VECTORS as u8 { unsafe { EDGE_WEIGHTS[node as usize] = weight; } }
464}
465
466/// Get node edge weight
467#[no_mangle]
468pub extern "C" fn get_edge_weight(node: u8) -> u8 {
469    if node < MAX_VECTORS as u8 { unsafe { EDGE_WEIGHTS[node as usize] } } else { 0 }
470}
471
472/// Aggregate neighbors into DELTA buffer (GNN message passing)
473#[no_mangle]
474pub extern "C" fn aggregate_neighbors(idx: u8) {
475    unsafe {
476        if idx >= HNSW.count { return; }
477        let n = HNSW.dims as usize;
478        let nc = HNSW.nodes[idx as usize].count;
479        let mut d = 0;
480        while d < n { DELTA[d] = 0.0; d += 1; }
481        if nc == 0 { return; }
482        let mut i = 0u8;
483        while i < nc {
484            let nb = HNSW.nodes[idx as usize].neighbors[i as usize];
485            let w = EDGE_WEIGHTS[nb as usize] as f32;
486            d = 0;
487            while d < n { DELTA[d] += w * HNSW.vectors[nb as usize].data[d]; d += 1; }
488            i += 1;
489        }
490        let s = INV_255 / nc as f32;
491        d = 0; while d < n { DELTA[d] *= s; d += 1; }
492    }
493}
494
495// ============ Vector Updates ============
496
497/// Get delta buffer pointer for reading aggregated values
498#[no_mangle]
499pub extern "C" fn get_delta_ptr() -> *const f32 { unsafe { DELTA.as_ptr() } }
500
501/// Update vector: v = v + alpha * delta (in-place)
502#[no_mangle]
503pub extern "C" fn update_vector(idx: u8, alpha: f32) {
504    unsafe {
505        if idx >= HNSW.count { return; }
506        let n = HNSW.dims as usize;
507        let mut i = 0;
508        while i < n { HNSW.vectors[idx as usize].data[i] += alpha * DELTA[i]; i += 1; }
509        HNSW.vectors[idx as usize].norm = norm(&HNSW.vectors[idx as usize].data[..n], n);
510    }
511}
512
513/// Get mutable delta buffer pointer
514#[no_mangle]
515pub extern "C" fn set_delta_ptr() -> *mut f32 { unsafe { DELTA.as_mut_ptr() } }
516
517/// Combined HNSW-SNN cycle: search → convert to currents → inject
518/// Useful for linking vector similarity to neural activation
519#[no_mangle]
520pub extern "C" fn hnsw_to_snn(k: u8, gain: f32) -> u8 {
521    unsafe {
522        let found = search(k);
523        if found == 0 { return 0; }
524
525        // Convert search results to neural currents
526        let mut i = 0u8;
527        while i < found {
528            let r = &RESULTS[i as usize];
529            if r.idx != 255 {
530                // Inverse distance = stronger activation
531                let current = gain / (1.0 + r.distance);
532                MEMBRANE[r.idx as usize] += current;
533            }
534            i += 1;
535        }
536        found
537    }
538}
539
540// ============ Spiking Neural Network API ============
541
542/// Reset SNN state for all neurons
543#[no_mangle]
544pub extern "C" fn snn_reset() {
545    unsafe {
546        let mut i = 0;
547        while i < MAX_VECTORS {
548            MEMBRANE[i] = V_REST;
549            THRESHOLD[i] = 1.0;
550            LAST_SPIKE[i] = -1000.0;
551            REFRAC[i] = 0.0;
552            SPIKES[i] = false;
553            i += 1;
554        }
555        SIM_TIME = 0.0;
556    }
557}
558
559/// Set membrane potential for a neuron
560#[no_mangle]
561pub extern "C" fn snn_set_membrane(idx: u8, v: f32) {
562    if idx < MAX_VECTORS as u8 { unsafe { MEMBRANE[idx as usize] = v; } }
563}
564
565/// Get membrane potential
566#[no_mangle]
567pub extern "C" fn snn_get_membrane(idx: u8) -> f32 {
568    if idx < MAX_VECTORS as u8 { unsafe { MEMBRANE[idx as usize] } } else { 0.0 }
569}
570
571/// Set firing threshold for a neuron
572#[no_mangle]
573pub extern "C" fn snn_set_threshold(idx: u8, t: f32) {
574    if idx < MAX_VECTORS as u8 { unsafe { THRESHOLD[idx as usize] = t; } }
575}
576
577/// Inject current into a neuron (adds to membrane potential)
578#[no_mangle]
579pub extern "C" fn snn_inject(idx: u8, current: f32) {
580    if idx < MAX_VECTORS as u8 { unsafe { MEMBRANE[idx as usize] += current; } }
581}
582
583/// Get spike status (1 if spiked last step, 0 otherwise)
584#[no_mangle]
585pub extern "C" fn snn_spiked(idx: u8) -> u8 {
586    if idx < MAX_VECTORS as u8 { unsafe { SPIKES[idx as usize] as u8 } } else { 0 }
587}
588
589/// Get spike bitset (32 neurons packed into u32)
590#[no_mangle]
591pub extern "C" fn snn_get_spikes() -> u32 {
592    unsafe {
593        let mut bits = 0u32;
594        let mut i = 0;
595        while i < MAX_VECTORS { if SPIKES[i] { bits |= 1 << i; } i += 1; }
596        bits
597    }
598}
599
600/// LIF neuron step: simulate one timestep (dt in ms)
601/// Returns number of neurons that spiked
602#[no_mangle]
603pub extern "C" fn snn_step(dt: f32) -> u8 {
604    unsafe {
605        let decay = 1.0 - dt / TAU_MEMBRANE;
606        let mut spike_count = 0u8;
607
608        let mut i = 0u8;
609        while i < HNSW.count {
610            let idx = i as usize;
611            SPIKES[idx] = false;
612
613            // Skip if in refractory period
614            if REFRAC[idx] > 0.0 {
615                REFRAC[idx] -= dt;
616                i += 1;
617                continue;
618            }
619
620            // Leaky integration: V = V * decay
621            MEMBRANE[idx] *= decay;
622
623            // Check for spike
624            if MEMBRANE[idx] >= THRESHOLD[idx] {
625                SPIKES[idx] = true;
626                spike_count += 1;
627                LAST_SPIKE[idx] = SIM_TIME;
628                MEMBRANE[idx] = V_RESET;
629                REFRAC[idx] = TAU_REFRAC;
630            }
631            i += 1;
632        }
633
634        SIM_TIME += dt;
635        spike_count
636    }
637}
638
639/// Propagate spikes to neighbors (injects current based on edge weights)
640/// Call after snn_step to propagate activity
641#[no_mangle]
642pub extern "C" fn snn_propagate(gain: f32) {
643    unsafe {
644        let mut i = 0u8;
645        while i < HNSW.count {
646            if !SPIKES[i as usize] { i += 1; continue; }
647
648            // This neuron spiked, inject current to neighbors
649            let nc = HNSW.nodes[i as usize].count;
650            let mut j = 0u8;
651            while j < nc {
652                let nb = HNSW.nodes[i as usize].neighbors[j as usize];
653                let w = EDGE_WEIGHTS[i as usize] as f32 / 255.0;
654                MEMBRANE[nb as usize] += gain * w;
655                j += 1;
656            }
657            i += 1;
658        }
659    }
660}
661
662/// STDP learning: adjust edge weights based on spike timing
663/// Call after snn_step to apply plasticity
664#[no_mangle]
665pub extern "C" fn snn_stdp() {
666    unsafe {
667        let mut i = 0u8;
668        while i < HNSW.count {
669            if !SPIKES[i as usize] { i += 1; continue; }
670
671            // Post-synaptic neuron spiked
672            let nc = HNSW.nodes[i as usize].count;
673            let mut j = 0u8;
674            while j < nc {
675                let pre = HNSW.nodes[i as usize].neighbors[j as usize];
676                let dt = LAST_SPIKE[pre as usize] - SIM_TIME;
677
678                // LTP: pre before post, LTD: pre after post
679                // Simplified exponential approximation
680                let dw = if dt < 0.0 {
681                    STDP_A_PLUS * (1.0 + dt * INV_TAU_STDP)  // dt negative, so this decays
682                } else {
683                    -STDP_A_MINUS * (1.0 - dt * INV_TAU_STDP)
684                };
685
686                // Update weight (clamped to 0-255 using integer math)
687                let w = EDGE_WEIGHTS[pre as usize] as i16 + (dw * 255.0) as i16;
688                EDGE_WEIGHTS[pre as usize] = if w < 0 { 0 } else if w > 255 { 255 } else { w as u8 };
689                j += 1;
690            }
691            i += 1;
692        }
693    }
694}
695
696/// Combined: step + propagate + optionally STDP
697/// Returns spike count
698#[no_mangle]
699pub extern "C" fn snn_tick(dt: f32, gain: f32, learn: u8) -> u8 {
700    let spikes = snn_step(dt);
701    snn_propagate(gain);
702    if learn != 0 { snn_stdp(); }
703    spikes
704}
705
706/// Get current simulation time
707#[no_mangle]
708pub extern "C" fn snn_get_time() -> f32 { unsafe { SIM_TIME } }
709
710// ============================================================================
711// NOVEL NEUROMORPHIC DISCOVERIES
712// ============================================================================
713
714// ============ Spike-Timing Vector Encoding ============
715// Novel discovery: Encode vectors as temporal spike patterns
716// Each dimension becomes a spike time within a coding window
717
718/// Encode vector to temporal spike pattern (rate-to-time conversion)
719/// Higher values → earlier spikes (first-spike coding)
720/// Returns encoded pattern as 32-bit bitmask
721#[no_mangle]
722pub extern "C" fn encode_vector_to_spikes(idx: u8) -> u32 {
723    unsafe {
724        if idx >= HNSW.count { return 0; }
725        let n = HNSW.dims as usize;
726        let mut pattern = 0u32;
727
728        // Normalize vector values to spike times
729        let mut max_val = 0.0f32;
730        let mut i = 0;
731        while i < n {
732            let v = HNSW.vectors[idx as usize].data[i];
733            if v > max_val { max_val = v; }
734            if -v > max_val { max_val = -v; }
735            i += 1;
736        }
737        if max_val == 0.0 { return 0; }
738
739        // Encode: high values → low bit positions (early spikes)
740        i = 0;
741        while i < n.min(SPIKE_ENCODING_RES as usize * 4) {
742            let normalized = (HNSW.vectors[idx as usize].data[i] + max_val) / (2.0 * max_val);
743            let slot = ((1.0 - normalized) * SPIKE_ENCODING_RES as f32) as u8;
744            let bit_pos = i as u8 + slot * (n as u8 / SPIKE_ENCODING_RES);
745            if bit_pos < 32 { pattern |= 1u32 << bit_pos; }
746            i += 1;
747        }
748
749        SPIKE_PATTERN[idx as usize] = pattern;
750        pattern
751    }
752}
753
/// Jaccard similarity of two spike patterns: `|a & b| / |a | b|` counted in set
/// bits. Two empty patterns are considered identical (1.0).
#[no_mangle]
pub extern "C" fn spike_timing_similarity(a: u32, b: u32) -> f32 {
    let union = (a | b).count_ones();
    if union == 0 {
        return 1.0;
    }
    let shared = (a & b).count_ones();
    shared as f32 / union as f32
}
764
765/// Search using spike-timing representation
766/// Novel: temporal code matching instead of distance
767#[no_mangle]
768pub extern "C" fn spike_search(query_pattern: u32, k: u8) -> u8 {
769    unsafe {
770        if HNSW.count == 0 { return 0; }
771        let k = k.min(16).min(HNSW.count);
772
773        // Reset results
774        let mut i = 0;
775        while i < 16 {
776            RESULTS[i] = SearchResult { idx: 255, core_id: HNSW.core_id, distance: 0.0 };
777            i += 1;
778        }
779
780        let mut found = 0u8;
781        i = 0;
782        while i < HNSW.count as usize {
783            let sim = spike_timing_similarity(query_pattern, SPIKE_PATTERN[i]);
784            // Store as negative similarity for compatibility (lower = better)
785            let dist = 1.0 - sim;
786
787            if found < k || dist < RESULTS[(found - 1) as usize].distance {
788                let mut p = found.min(k - 1) as usize;
789                while p > 0 && RESULTS[p - 1].distance > dist {
790                    if p < 16 { RESULTS[p] = RESULTS[p - 1]; }
791                    p -= 1;
792                }
793                if p < 16 {
794                    RESULTS[p] = SearchResult {
795                        idx: i as u8,
796                        core_id: HNSW.core_id,
797                        distance: dist
798                    };
799                    if found < k { found += 1; }
800                }
801            }
802            i += 1;
803        }
804        found
805    }
806}
807
808// ============ Homeostatic Plasticity ============
809// Novel: Self-stabilizing network maintains target activity level
810// Prevents runaway excitation or complete silence
811
812/// Apply homeostatic plasticity: adjust thresholds to maintain target rate
813#[no_mangle]
814pub extern "C" fn homeostatic_update(dt: f32) {
815    unsafe {
816        let alpha = dt / HOMEOSTATIC_TAU;
817
818        let mut i = 0u8;
819        while i < HNSW.count {
820            let idx = i as usize;
821
822            // Update running spike rate estimate
823            let instant_rate = if SPIKES[idx] { 1.0 / dt } else { 0.0 };
824            SPIKE_RATE[idx] = SPIKE_RATE[idx] * (1.0 - alpha) + instant_rate * alpha;
825
826            // Adjust threshold to approach target rate
827            let rate_error = SPIKE_RATE[idx] - HOMEOSTATIC_TARGET;
828            THRESHOLD[idx] += rate_error * alpha;
829
830            // Clamp threshold to reasonable range
831            if THRESHOLD[idx] < 0.1 { THRESHOLD[idx] = 0.1; }
832            if THRESHOLD[idx] > 10.0 { THRESHOLD[idx] = 10.0; }
833
834            i += 1;
835        }
836    }
837}
838
839/// Get current spike rate estimate
840#[no_mangle]
841pub extern "C" fn get_spike_rate(idx: u8) -> f32 {
842    if idx < MAX_VECTORS as u8 { unsafe { SPIKE_RATE[idx as usize] } } else { 0.0 }
843}
844
845// ============ Oscillatory Resonance ============
846// Novel: Gamma-rhythm synchronization for binding and search enhancement
847// Neurons tuned to oscillation phase get amplified
848
849/// Update oscillator phase
850#[no_mangle]
851pub extern "C" fn oscillator_step(dt: f32) {
852    unsafe {
853        // Phase advances with time: ω = 2πf
854        let omega = 6.28318 * OSCILLATOR_FREQ / 1000.0; // Convert Hz to rad/ms
855        OSCILLATOR_PHASE += omega * dt;
856        if OSCILLATOR_PHASE > 6.28318 { OSCILLATOR_PHASE -= 6.28318; }
857    }
858}
859
860/// Get current oscillator phase (0 to 2π)
861#[no_mangle]
862pub extern "C" fn oscillator_get_phase() -> f32 { unsafe { OSCILLATOR_PHASE } }
863
864/// Compute resonance boost for a neuron based on phase alignment
865/// Neurons in sync with gamma get amplified
866#[no_mangle]
867pub extern "C" fn compute_resonance(idx: u8) -> f32 {
868    unsafe {
869        if idx >= HNSW.count { return 0.0; }
870        let i = idx as usize;
871
872        // Each neuron has preferred phase based on its index
873        let preferred_phase = (idx as f32 / MAX_VECTORS as f32) * 6.28318;
874        let phase_diff = (OSCILLATOR_PHASE - preferred_phase).abs();
875        let min_diff = if phase_diff > 3.14159 { 6.28318 - phase_diff } else { phase_diff };
876
877        // Resonance is high when phase matches
878        RESONANCE[i] = 1.0 - min_diff / 3.14159;
879        RESONANCE[i]
880    }
881}
882
883/// Apply resonance-modulated search boost
884/// Query matches are enhanced when neuron is in favorable phase
885#[no_mangle]
886pub extern "C" fn resonance_search(k: u8, phase_weight: f32) -> u8 {
887    unsafe {
888        let found = search(k);
889
890        // Modulate results by resonance
891        let mut i = 0u8;
892        while i < found {
893            let idx = RESULTS[i as usize].idx;
894            if idx != 255 {
895                let res = compute_resonance(idx);
896                // Lower distance = better, so multiply by (2 - resonance)
897                RESULTS[i as usize].distance *= 2.0 - res * phase_weight;
898            }
899            i += 1;
900        }
901
902        // Re-sort results after resonance modulation
903        let mut i = 0usize;
904        while i < found as usize {
905            let mut j = i + 1;
906            while j < found as usize {
907                if RESULTS[j].distance < RESULTS[i].distance {
908                    let tmp = RESULTS[i];
909                    RESULTS[i] = RESULTS[j];
910                    RESULTS[j] = tmp;
911                }
912                j += 1;
913            }
914            i += 1;
915        }
916        found
917    }
918}
919
920// ============ Winner-Take-All Circuits ============
921// Novel: Competitive selection via lateral inhibition
922// Only the most active neuron wins, enabling hard decisions
923
924/// Reset WTA state
925#[no_mangle]
926pub extern "C" fn wta_reset() { unsafe { WTA_INHIBIT = 0.0; } }
927
928/// Run WTA competition: only highest membrane potential survives
929/// Returns winner index (or 255 if no winner)
930#[no_mangle]
931pub extern "C" fn wta_compete() -> u8 {
932    unsafe {
933        let mut max_v = 0.0f32;
934        let mut winner = 255u8;
935
936        let mut i = 0u8;
937        while i < HNSW.count {
938            let v = MEMBRANE[i as usize];
939            if v > max_v && REFRAC[i as usize] <= 0.0 {
940                max_v = v;
941                winner = i;
942            }
943            i += 1;
944        }
945
946        // Apply lateral inhibition to all losers
947        if winner != 255 {
948            WTA_INHIBIT = max_v * WTA_INHIBITION;
949            i = 0;
950            while i < HNSW.count {
951                if i != winner {
952                    MEMBRANE[i as usize] -= WTA_INHIBIT;
953                    if MEMBRANE[i as usize] < V_RESET {
954                        MEMBRANE[i as usize] = V_RESET;
955                    }
956                }
957                i += 1;
958            }
959        }
960        winner
961    }
962}
963
964/// Soft WTA: proportional inhibition based on rank
965#[no_mangle]
966pub extern "C" fn wta_soft() {
967    unsafe {
968        // Find max membrane potential
969        let mut max_v = 0.0f32;
970        let mut i = 0u8;
971        while i < HNSW.count {
972            if MEMBRANE[i as usize] > max_v { max_v = MEMBRANE[i as usize]; }
973            i += 1;
974        }
975        if max_v <= 0.0 { return; }
976
977        // Normalize and apply softmax-like competition
978        i = 0;
979        while i < HNSW.count {
980            let ratio = MEMBRANE[i as usize] / max_v;
981            // Exponential competition: low ratios get strongly suppressed
982            let survival = ratio * ratio; // Square for sharper competition
983            MEMBRANE[i as usize] *= survival;
984            i += 1;
985        }
986    }
987}
988
989// ============ Dendritic Computation ============
990// Novel: Nonlinear integration in dendritic compartments
991// Enables local coincidence detection before soma integration
992
993/// Reset dendritic compartments
994#[no_mangle]
995pub extern "C" fn dendrite_reset() {
996    unsafe {
997        let mut i = 0;
998        while i < MAX_VECTORS {
999            let mut j = 0;
1000            while j < MAX_NEIGHBORS { DENDRITE[i][j] = 0.0; j += 1; }
1001            i += 1;
1002        }
1003    }
1004}
1005
1006/// Inject input to specific dendritic compartment
1007#[no_mangle]
1008pub extern "C" fn dendrite_inject(neuron: u8, branch: u8, current: f32) {
1009    unsafe {
1010        if neuron < MAX_VECTORS as u8 && branch < MAX_NEIGHBORS as u8 {
1011            DENDRITE[neuron as usize][branch as usize] += current;
1012        }
1013    }
1014}
1015
1016/// Dendritic integration with nonlinearity
1017/// Multiple coincident inputs on same branch get amplified
1018#[no_mangle]
1019pub extern "C" fn dendrite_integrate(neuron: u8) -> f32 {
1020    unsafe {
1021        if neuron >= HNSW.count { return 0.0; }
1022        let idx = neuron as usize;
1023        let nc = HNSW.nodes[idx].count as usize;
1024
1025        let mut total = 0.0f32;
1026        let mut branch = 0;
1027        while branch < nc {
1028            let d = DENDRITE[idx][branch];
1029            // Nonlinear: small inputs are linear, large inputs saturate with boost
1030            if d > 0.0 {
1031                // Sigmoidal nonlinearity with supralinear boost
1032                let nonlin = if d < 1.0 {
1033                    d
1034                } else {
1035                    1.0 + (d - 1.0) / (1.0 + (d - 1.0) / DENDRITIC_NONLIN)
1036                };
1037                total += nonlin;
1038            }
1039            branch += 1;
1040        }
1041
1042        // Transfer to soma
1043        MEMBRANE[idx] += total;
1044        total
1045    }
1046}
1047
1048/// Propagate spikes through dendritic tree (not just soma)
1049#[no_mangle]
1050pub extern "C" fn dendrite_propagate(gain: f32) {
1051    unsafe {
1052        let mut i = 0u8;
1053        while i < HNSW.count {
1054            if !SPIKES[i as usize] { i += 1; continue; }
1055
1056            // This neuron spiked, inject to neighbor dendrites
1057            let nc = HNSW.nodes[i as usize].count;
1058            let mut j = 0u8;
1059            while j < nc {
1060                let nb = HNSW.nodes[i as usize].neighbors[j as usize];
1061                let w = EDGE_WEIGHTS[i as usize] as f32 / 255.0;
1062
1063                // Find which dendrite branch this connection is on
1064                let mut branch = 0u8;
1065                let nb_nc = HNSW.nodes[nb as usize].count;
1066                while branch < nb_nc {
1067                    if HNSW.nodes[nb as usize].neighbors[branch as usize] == i {
1068                        break;
1069                    }
1070                    branch += 1;
1071                }
1072
1073                if branch < MAX_NEIGHBORS as u8 {
1074                    DENDRITE[nb as usize][branch as usize] += gain * w;
1075                }
1076                j += 1;
1077            }
1078            i += 1;
1079        }
1080    }
1081}
1082
1083// ============ Temporal Pattern Recognition ============
1084// Novel: Store and match spike pattern sequences
1085// Enables recognition of dynamic temporal signatures
1086
1087/// Record current spike state into pattern buffer (shift register)
1088#[no_mangle]
1089pub extern "C" fn pattern_record() {
1090    unsafe {
1091        let mut i = 0;
1092        while i < MAX_VECTORS {
1093            // Shift pattern left and add new spike
1094            SPIKE_PATTERN[i] = (SPIKE_PATTERN[i] << 1) | (SPIKES[i] as u32);
1095            i += 1;
1096        }
1097    }
1098}
1099
1100/// Get temporal spike pattern for a neuron
1101#[no_mangle]
1102pub extern "C" fn get_pattern(idx: u8) -> u32 {
1103    if idx < MAX_VECTORS as u8 { unsafe { SPIKE_PATTERN[idx as usize] } } else { 0 }
1104}
1105
1106/// Match pattern against stored patterns (Hamming similarity)
1107/// Returns best matching neuron index
1108#[no_mangle]
1109pub extern "C" fn pattern_match(target: u32) -> u8 {
1110    unsafe {
1111        let mut best_idx = 255u8;
1112        let mut best_sim = 0u32;
1113
1114        let mut i = 0u8;
1115        while i < HNSW.count {
1116            // XOR gives difference, NOT gives similarity bits
1117            let diff = target ^ SPIKE_PATTERN[i as usize];
1118            let sim = (!diff).count_ones();
1119            if sim > best_sim {
1120                best_sim = sim;
1121                best_idx = i;
1122            }
1123            i += 1;
1124        }
1125        best_idx
1126    }
1127}
1128
1129/// Temporal correlation: find neurons with similar spike history
1130#[no_mangle]
1131pub extern "C" fn pattern_correlate(idx: u8, threshold: u8) -> u32 {
1132    unsafe {
1133        if idx >= HNSW.count { return 0; }
1134        let target = SPIKE_PATTERN[idx as usize];
1135        let mut correlated = 0u32;
1136
1137        let mut i = 0u8;
1138        while i < HNSW.count {
1139            if i != idx {
1140                let diff = target ^ SPIKE_PATTERN[i as usize];
1141                let dist = diff.count_ones() as u8;
1142                if dist <= threshold && i < 32 {
1143                    correlated |= 1u32 << i;
1144                }
1145            }
1146            i += 1;
1147        }
1148        correlated
1149    }
1150}
1151
1152// ============ Combined Neuromorphic Search ============
1153// Novel: Unified search combining all mechanisms
1154
1155/// Advanced neuromorphic search with all novel features
1156/// Combines: HNSW graph, spike timing, oscillation, WTA
1157#[no_mangle]
1158pub extern "C" fn neuromorphic_search(k: u8, dt: f32, iterations: u8) -> u8 {
1159    unsafe {
1160        if HNSW.count == 0 { return 0; }
1161
1162        // Reset neural state
1163        snn_reset();
1164        dendrite_reset();
1165        wta_reset();
1166
1167        // Convert query to spike pattern
1168        let n = HNSW.dims as usize;
1169        let qn = norm(&QUERY[..n], n);
1170
1171        // Initialize membrane potentials from vector distances
1172        let mut i = 0u8;
1173        while i < HNSW.count {
1174            let d = distance(&QUERY[..n], qn, i);
1175            // Inverse distance = initial activation
1176            MEMBRANE[i as usize] = 1.0 / (1.0 + d);
1177            i += 1;
1178        }
1179
1180        // Run neuromorphic dynamics
1181        let mut iter = 0u8;
1182        while iter < iterations {
1183            oscillator_step(dt);
1184
1185            // Dendritic integration
1186            i = 0;
1187            while i < HNSW.count {
1188                dendrite_integrate(i);
1189                i += 1;
1190            }
1191
1192            // Neural step with spike propagation
1193            snn_step(dt);
1194            dendrite_propagate(0.5);
1195
1196            // WTA competition for sharpening
1197            wta_soft();
1198
1199            // Record spike patterns
1200            pattern_record();
1201
1202            // Homeostatic regulation
1203            homeostatic_update(dt);
1204
1205            iter += 1;
1206        }
1207
1208        // Collect results based on final spike patterns and resonance
1209        let mut i = 0;
1210        while i < 16 {
1211            RESULTS[i] = SearchResult { idx: 255, core_id: HNSW.core_id, distance: f32::MAX };
1212            i += 1;
1213        }
1214
1215        let mut found = 0u8;
1216        i = 0;
1217        while i < HNSW.count as usize {
1218            // Score = spike count + resonance + membrane potential
1219            let spikes = SPIKE_PATTERN[i].count_ones() as f32;
1220            let res = RESONANCE[i];
1221            let vm = MEMBRANE[i];
1222            let score = -(spikes * 10.0 + res * 5.0 + vm);  // Negative for sorting
1223
1224            if found < k || score < RESULTS[(found - 1) as usize].distance {
1225                let mut p = found.min(k - 1) as usize;
1226                while p > 0 && RESULTS[p - 1].distance > score {
1227                    if p < 16 { RESULTS[p] = RESULTS[p - 1]; }
1228                    p -= 1;
1229                }
1230                if p < 16 {
1231                    RESULTS[p] = SearchResult {
1232                        idx: i as u8,
1233                        core_id: HNSW.core_id,
1234                        distance: score
1235                    };
1236                    if found < k { found += 1; }
1237                }
1238            }
1239            i += 1;
1240        }
1241        found
1242    }
1243}
1244
1245/// Get total network activity (sum of spike rates)
1246#[no_mangle]
1247pub extern "C" fn get_network_activity() -> f32 {
1248    unsafe {
1249        let mut total = 0.0f32;
1250        let mut i = 0;
1251        while i < MAX_VECTORS {
1252            total += SPIKE_RATE[i];
1253            i += 1;
1254        }
1255        total
1256    }
1257}
1258
1259#[cfg(not(test))]
1260#[panic_handler]
1261fn panic(_: &core::panic::PanicInfo) -> ! { loop {} }