1use crate::error::{M1ndError, M1ndResult};
4use crate::graph::Graph;
5use crate::types::*;
6
/// Default Hebbian learning rate.
/// NOTE(review): not referenced in this file — `PlasticityConfig::default`
/// uses `LearningRate::DEFAULT` instead; confirm the two values agree.
pub const DEFAULT_LEARNING_RATE: f32 = 0.08;
/// Default per-query multiplicative decay rate for edges of non-activated nodes.
pub const DEFAULT_DECAY_RATE: f32 = 0.005;
/// Strengthen count at which the one-shot LTP (long-term potentiation) bonus fires.
pub const LTP_THRESHOLD: u16 = 5;
/// Weaken count at which the one-shot LTD (long-term depression) penalty fires.
pub const LTD_THRESHOLD: u16 = 5;
/// Additive weight bonus applied once when LTP triggers.
pub const LTP_BONUS: f32 = 0.15;
/// Subtractive weight penalty applied once when LTD triggers.
pub const LTD_PENALTY: f32 = 0.15;
/// Maximum summed incoming weight per node before homeostatic rescaling.
pub const HOMEOSTATIC_CEILING: f32 = 5.0;
/// Lower clamp for edge weights (decay/LTD never push below this).
pub const WEIGHT_FLOOR: f32 = 0.05;
/// Upper clamp for edge weights (strengthening/LTP never push above this).
pub const WEIGHT_CAP: f32 = 3.0;
/// Default number of past queries retained by `QueryMemory`.
pub const DEFAULT_MEMORY_CAPACITY: usize = 1000;
/// Retry bound passed to compare-and-swap weight writes.
pub const CAS_RETRY_LIMIT: u32 = 64;
24
/// Serializable snapshot of one edge's plasticity state, keyed by the
/// (source, target, relation) label triple so it can be re-applied to a
/// rebuilt graph. Produced by `PlasticityEngine::export_state` and consumed
/// by `PlasticityEngine::import_state`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct SynapticState {
    /// External label of the edge's source node.
    pub source_label: String,
    /// External label of the edge's target node.
    pub target_label: String,
    /// Relation string of the edge ("edge" when unresolvable at export time).
    pub relation: String,
    /// Weight before any plasticity was applied.
    pub original_weight: f32,
    /// Weight after plasticity updates (falls back to `original_weight` on
    /// export if the stored value is non-finite).
    pub current_weight: f32,
    /// Times Hebbian strengthening touched this edge.
    pub strengthen_count: u16,
    /// Times decay weakened this edge.
    pub weaken_count: u16,
    /// Whether the one-shot LTP bonus has already been applied.
    pub ltp_applied: bool,
    /// Whether the one-shot LTD penalty has already been applied.
    pub ltd_applied: bool,
}
44
/// One past query kept in `QueryMemory`: the raw text, the seed nodes it
/// started from, the nodes it activated, and when it ran.
#[derive(Clone, Debug)]
pub struct QueryRecord {
    pub query_text: String,
    pub seeds: Vec<NodeId>,
    pub activated_nodes: Vec<NodeId>,
    /// Seconds since the Unix epoch (0.0 if the clock read failed).
    pub timestamp: f64,
}
59
/// Fixed-capacity ring buffer of recent queries, plus aggregate counters
/// kept consistent with the buffer's current contents.
pub struct QueryMemory {
    /// Ring slots; `None` until the buffer has filled that slot.
    records: Vec<Option<QueryRecord>>,
    /// Number of slots in `records`.
    capacity: usize,
    /// Next slot to overwrite (wraps modulo `capacity`).
    write_head: usize,
    /// Per-node activation counts across the buffered records.
    /// NOTE(review): maintained but never read in this file — confirm a
    /// caller elsewhere consumes it.
    node_frequency: Vec<u32>,
    /// Co-occurrence counts for unordered seed pairs (key stored smaller-first).
    /// NOTE(review): also write-only within this file.
    seed_bigrams: std::collections::HashMap<(NodeId, NodeId), u32>,
}
78
79impl QueryMemory {
80 pub fn new(capacity: usize, num_nodes: u32) -> Self {
81 Self {
82 records: vec![None; capacity],
83 capacity,
84 write_head: 0,
85 node_frequency: vec![0; num_nodes as usize],
86 seed_bigrams: std::collections::HashMap::new(),
87 }
88 }
89
90 pub fn record(&mut self, record: QueryRecord) {
93 if let Some(old) = &self.records[self.write_head] {
95 for &node in &old.activated_nodes {
96 let idx = node.as_usize();
97 if idx < self.node_frequency.len() {
98 self.node_frequency[idx] = self.node_frequency[idx].saturating_sub(1);
99 }
100 }
101 for i in 0..old.seeds.len() {
103 for j in (i + 1)..old.seeds.len() {
104 let key = if old.seeds[i] < old.seeds[j] {
105 (old.seeds[i], old.seeds[j])
106 } else {
107 (old.seeds[j], old.seeds[i])
108 };
109 if let Some(count) = self.seed_bigrams.get_mut(&key) {
110 *count = count.saturating_sub(1);
111 }
112 }
113 }
114 }
115
116 for &node in &record.activated_nodes {
118 let idx = node.as_usize();
119 if idx < self.node_frequency.len() {
120 self.node_frequency[idx] += 1;
121 }
122 }
123
124 for i in 0..record.seeds.len() {
126 for j in (i + 1)..record.seeds.len() {
127 let key = if record.seeds[i] < record.seeds[j] {
128 (record.seeds[i], record.seeds[j])
129 } else {
130 (record.seeds[j], record.seeds[i])
131 };
132 *self.seed_bigrams.entry(key).or_insert(0) += 1;
133 }
134 }
135
136 self.records[self.write_head] = Some(record);
137 self.write_head = (self.write_head + 1) % self.capacity;
138 }
139
140 pub fn get_priming_signal(
143 &self,
144 seeds: &[NodeId],
145 boost_strength: FiniteF32,
146 ) -> Vec<(NodeId, FiniteF32)> {
147 if seeds.is_empty() {
148 return Vec::new();
149 }
150
151 let mut node_scores: std::collections::HashMap<u32, f32> = std::collections::HashMap::new();
153
154 for record in self.records.iter().flatten() {
155 let shared = seeds.iter().any(|s| record.seeds.contains(s));
157 if !shared {
158 continue;
159 }
160
161 for &node in &record.activated_nodes {
162 if !seeds.contains(&node) {
163 *node_scores.entry(node.0).or_insert(0.0) += 1.0;
164 }
165 }
166 }
167
168 let max_score = node_scores.values().cloned().fold(0.0f32, f32::max);
170 if max_score <= 0.0 {
171 return Vec::new();
172 }
173
174 let mut results: Vec<(NodeId, FiniteF32)> = node_scores
175 .into_iter()
176 .map(|(id, score)| {
177 let normalized = (score / max_score) * boost_strength.get();
178 (NodeId::new(id), FiniteF32::new(normalized.min(1.0)))
179 })
180 .filter(|(_, s)| s.get() > 0.01)
181 .collect();
182
183 results.sort_by(|a, b| b.1.cmp(&a.1));
184 results.truncate(50); results
186 }
187
188 pub fn len(&self) -> usize {
190 self.records.iter().filter(|r| r.is_some()).count()
191 }
192
193 pub fn is_empty(&self) -> bool {
194 self.len() == 0
195 }
196}
197
/// Tunable parameters for `PlasticityEngine`. `Default` yields the baseline
/// values defined by the module-level constants above.
pub struct PlasticityConfig {
    /// Hebbian learning rate.
    pub learning_rate: LearningRate,
    /// Per-query multiplicative decay applied to edges of non-activated nodes.
    pub decay_rate: PosF32,
    /// Strengthen count at which the one-shot LTP bonus fires.
    pub ltp_threshold: u16,
    /// Weaken count at which the one-shot LTD penalty fires.
    pub ltd_threshold: u16,
    /// Additive weight bonus on LTP.
    pub ltp_bonus: FiniteF32,
    /// Subtractive weight penalty on LTD.
    pub ltd_penalty: FiniteF32,
    /// Maximum summed incoming weight per node before rescaling.
    pub homeostatic_ceiling: FiniteF32,
    /// Lower clamp for edge weights.
    pub weight_floor: FiniteF32,
    /// Upper clamp for edge weights.
    pub weight_cap: FiniteF32,
    /// Capacity of the query-memory ring buffer.
    pub memory_capacity: usize,
    /// Retry bound passed to CAS weight writes.
    pub cas_retry_limit: u32,
}
217
218impl Default for PlasticityConfig {
219 fn default() -> Self {
220 Self {
221 learning_rate: LearningRate::DEFAULT,
222 decay_rate: PosF32::new(DEFAULT_DECAY_RATE).unwrap(),
223 ltp_threshold: LTP_THRESHOLD,
224 ltd_threshold: LTD_THRESHOLD,
225 ltp_bonus: FiniteF32::new(LTP_BONUS),
226 ltd_penalty: FiniteF32::new(LTD_PENALTY),
227 homeostatic_ceiling: FiniteF32::new(HOMEOSTATIC_CEILING),
228 weight_floor: FiniteF32::new(WEIGHT_FLOOR),
229 weight_cap: FiniteF32::new(WEIGHT_CAP),
230 memory_capacity: DEFAULT_MEMORY_CAPACITY,
231 cas_retry_limit: CAS_RETRY_LIMIT,
232 }
233 }
234}
235
/// Summary of what a single `PlasticityEngine::update` pass did.
#[derive(Clone, Debug)]
pub struct PlasticityResult {
    /// Edges strengthened by the Hebbian rule this pass.
    pub edges_strengthened: u32,
    /// Edges whose weight measurably decayed this pass.
    pub edges_decayed: u32,
    /// One-shot LTP bonuses applied this pass.
    pub ltp_events: u32,
    /// One-shot LTD penalties applied this pass.
    pub ltd_events: u32,
    /// Nodes whose incoming weights were homeostatically rescaled.
    pub homeostatic_rescales: u32,
    /// Number of priming candidates the memory currently yields for this
    /// query's seeds.
    pub priming_nodes: u32,
}
250
/// Applies activity-driven weight updates (Hebbian strengthening, decay,
/// LTP/LTD, homeostatic normalization) to a `Graph` and keeps a memory of
/// past queries for priming.
pub struct PlasticityEngine {
    /// Tuning parameters; see `PlasticityConfig`.
    config: PlasticityConfig,
    /// Ring buffer of past queries used to derive priming signals.
    memory: QueryMemory,
    /// Graph generation captured at construction, used to detect rebuilds.
    expected_generation: Generation,
    /// Number of `update` calls made on this engine.
    query_count: u32,
}
268
impl PlasticityEngine {
    /// Creates an engine bound to `graph`'s current generation, with an
    /// empty query memory sized from `config.memory_capacity`.
    pub fn new(graph: &Graph, config: PlasticityConfig) -> Self {
        // `config.memory_capacity` is read before `config` is moved in.
        Self {
            memory: QueryMemory::new(config.memory_capacity, graph.num_nodes()),
            expected_generation: graph.generation,
            query_count: 0,
            config,
        }
    }

    /// Errors if `graph` has been rebuilt since this engine was constructed
    /// (its generation no longer matches the one captured in `new`).
    ///
    /// NOTE(review): never invoked by any method visible in this file —
    /// `update` in particular mutates the graph without this guard. Confirm
    /// whether `update` should call it first.
    fn check_generation(&self, graph: &Graph) -> M1ndResult<()> {
        if self.expected_generation != graph.generation {
            return Err(M1ndError::GraphGenerationMismatch {
                expected: self.expected_generation,
                actual: graph.generation,
            });
        }
        Ok(())
    }

    /// Runs one full plasticity pass for a completed query: Hebbian
    /// strengthening, synaptic decay, LTP/LTD, homeostatic normalization,
    /// then records the query in memory and reports per-phase counts.
    pub fn update(
        &mut self,
        graph: &mut Graph,
        activated_nodes: &[(NodeId, FiniteF32)],
        seeds: &[(NodeId, FiniteF32)],
        query_text: &str,
    ) -> M1ndResult<PlasticityResult> {
        self.query_count += 1;

        // Dense flag per node index: was it activated this query?
        // Out-of-range node ids are silently ignored.
        let n = graph.num_nodes() as usize;
        let mut activated_set = vec![false; n];
        // NOTE(review): `act_map` is populated but never read afterwards —
        // dead code; candidate for removal.
        let mut act_map = std::collections::HashMap::new();
        for &(node, score) in activated_nodes {
            let idx = node.as_usize();
            if idx < n {
                activated_set[idx] = true;
                act_map.insert(node.0, score.get());
            }
        }

        // Phase 1: strengthen edges between co-activated nodes.
        let edges_strengthened = self.hebbian_strengthen(graph, activated_nodes)?;

        // Phase 2: decay outgoing edges of nodes that were NOT activated.
        let edges_decayed = self.synaptic_decay(graph, &activated_set)?;

        // Phase 3: one-shot long-term potentiation / depression adjustments.
        let (ltp_events, ltd_events) = self.apply_ltp_ltd(graph)?;

        // Phase 4: rescale incoming weights exceeding the homeostatic ceiling.
        let homeostatic_rescales = self.homeostatic_normalize(graph)?;

        // Remember this query so future queries with overlapping seeds can
        // be primed toward the nodes it activated.
        let record = QueryRecord {
            query_text: query_text.to_string(),
            seeds: seeds.iter().map(|s| s.0).collect(),
            activated_nodes: activated_nodes.iter().map(|a| a.0).collect(),
            // Wall-clock Unix seconds; 0.0 if the system clock reads as
            // before the epoch.
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_secs_f64())
                .unwrap_or(0.0),
        };
        self.memory.record(record);

        // Only the COUNT of priming candidates is reported here; the signal
        // itself is recomputed on demand via `get_priming`.
        let priming_nodes = self
            .memory
            .get_priming_signal(
                &seeds.iter().map(|s| s.0).collect::<Vec<_>>(),
                FiniteF32::new(0.1),
            )
            .len() as u32;

        Ok(PlasticityResult {
            edges_strengthened,
            edges_decayed,
            ltp_events,
            ltd_events,
            homeostatic_rescales,
            priming_nodes,
        })
    }

    /// Hebbian rule: for every CSR edge whose source and target were both
    /// activated, add `lr * src_activation * tgt_activation` to the weight,
    /// capped at `weight_cap`. Returns the number of edges strengthened.
    fn hebbian_strengthen(
        &self,
        graph: &mut Graph,
        activated: &[(NodeId, FiniteF32)],
    ) -> M1ndResult<u32> {
        let n = graph.num_nodes() as usize;
        let lr = self.config.learning_rate.get();
        let cap = self.config.weight_cap.get();
        let mut count = 0u32;

        // Dense activation score per node index for O(1) target lookup.
        let mut act_val = vec![0.0f32; n];
        for &(node, score) in activated {
            let idx = node.as_usize();
            if idx < n {
                act_val[idx] = score.get();
            }
        }

        for &(src, src_act) in activated {
            // NOTE(review): `src` is passed to `out_range` without the `< n`
            // bounds guard that targets get below — confirm out-of-range
            // seeds cannot reach this point.
            let range = graph.csr.out_range(src);
            for j in range {
                let tgt = graph.csr.targets[j];
                let tgt_idx = tgt.as_usize();
                if tgt_idx >= n {
                    continue;
                }
                let tgt_act = act_val[tgt_idx];
                if tgt_act <= 0.0 {
                    continue;
                }

                let delta = lr * src_act.get() * tgt_act;
                let edge_idx = EdgeIdx::new(j as u32);
                let current = graph.csr.read_weight(edge_idx).get();
                let new_weight = (current + delta).min(cap);

                // CAS write into the live CSR; the result is deliberately
                // ignored (best-effort under the retry limit).
                let _ = graph.csr.atomic_write_weight(
                    edge_idx,
                    FiniteF32::new(new_weight),
                    self.config.cas_retry_limit,
                );

                // Mirror the change into the plasticity bookkeeping arrays
                // (guarded because they may be shorter than the CSR).
                if j < graph.edge_plasticity.strengthen_count.len() {
                    graph.edge_plasticity.strengthen_count[j] =
                        graph.edge_plasticity.strengthen_count[j].saturating_add(1);
                    graph.edge_plasticity.current_weight[j] = FiniteF32::new(new_weight);
                    graph.edge_plasticity.last_used_query[j] = self.query_count;
                }

                count += 1;
            }
        }

        Ok(count)
    }

    /// Multiplies the outgoing edge weights of every NON-activated node by
    /// `(1 - decay_rate)`, clamped at `weight_floor`. Edges whose source was
    /// activated are untouched regardless of the target's state. Returns the
    /// number of edges whose weight changed by more than 1e-6.
    fn synaptic_decay(&self, graph: &mut Graph, activated_set: &[bool]) -> M1ndResult<u32> {
        let n = graph.num_nodes() as usize;
        let decay_factor = 1.0 - self.config.decay_rate.get();
        let floor = self.config.weight_floor.get();
        let mut count = 0u32;

        for i in 0..n {
            // Activated sources keep all their outgoing weights this pass.
            if activated_set[i] {
                continue; }

            let range = graph.csr.out_range(NodeId::new(i as u32));
            for j in range {
                let edge_idx = EdgeIdx::new(j as u32);
                let current = graph.csr.read_weight(edge_idx).get();
                let new_weight = (current * decay_factor).max(floor);

                // Skip the write (and the weaken count) when the weight is
                // already pinned at the floor / effectively unchanged.
                if (new_weight - current).abs() > 1e-6 {
                    let _ = graph.csr.atomic_write_weight(
                        edge_idx,
                        FiniteF32::new(new_weight),
                        self.config.cas_retry_limit,
                    );

                    if j < graph.edge_plasticity.weaken_count.len() {
                        graph.edge_plasticity.weaken_count[j] =
                            graph.edge_plasticity.weaken_count[j].saturating_add(1);
                        graph.edge_plasticity.current_weight[j] = FiniteF32::new(new_weight);
                    }

                    count += 1;
                }
            }
        }

        Ok(count)
    }

    /// Applies the one-shot LTP bonus / LTD penalty to every edge whose
    /// strengthen/weaken counter has reached its threshold, marking each so
    /// it fires at most once per edge. Returns (ltp_count, ltd_count).
    fn apply_ltp_ltd(&self, graph: &mut Graph) -> M1ndResult<(u32, u32)> {
        let cap = self.config.weight_cap.get();
        let floor = self.config.weight_floor.get();
        let mut ltp_count = 0u32;
        let mut ltd_count = 0u32;

        let num_edges = graph.edge_plasticity.strengthen_count.len();
        for j in 0..num_edges {
            // LTP: frequently strengthened edge gets a permanent bonus once.
            if !graph.edge_plasticity.ltp_applied[j]
                && graph.edge_plasticity.strengthen_count[j] >= self.config.ltp_threshold
            {
                let edge_idx = EdgeIdx::new(j as u32);
                let current = graph.csr.read_weight(edge_idx).get();
                let new_weight = (current + self.config.ltp_bonus.get()).min(cap);
                let _ = graph.csr.atomic_write_weight(
                    edge_idx,
                    FiniteF32::new(new_weight),
                    self.config.cas_retry_limit,
                );
                graph.edge_plasticity.ltp_applied[j] = true;
                graph.edge_plasticity.current_weight[j] = FiniteF32::new(new_weight);
                ltp_count += 1;
            }

            // LTD: frequently decayed edge gets a permanent penalty once.
            // Note: an edge can receive both LTP and LTD in the same pass.
            if !graph.edge_plasticity.ltd_applied[j]
                && graph.edge_plasticity.weaken_count[j] >= self.config.ltd_threshold
            {
                let edge_idx = EdgeIdx::new(j as u32);
                let current = graph.csr.read_weight(edge_idx).get();
                let new_weight = (current - self.config.ltd_penalty.get()).max(floor);
                let _ = graph.csr.atomic_write_weight(
                    edge_idx,
                    FiniteF32::new(new_weight),
                    self.config.cas_retry_limit,
                );
                graph.edge_plasticity.ltd_applied[j] = true;
                graph.edge_plasticity.current_weight[j] = FiniteF32::new(new_weight);
                ltd_count += 1;
            }
        }

        Ok((ltp_count, ltd_count))
    }

    /// For each node whose summed incoming weight exceeds the homeostatic
    /// ceiling, rescales all incoming edge weights proportionally so the sum
    /// equals the ceiling. Returns the number of nodes rescaled.
    fn homeostatic_normalize(&self, graph: &mut Graph) -> M1ndResult<u32> {
        let n = graph.num_nodes() as usize;
        let ceiling = self.config.homeostatic_ceiling.get();
        let mut rescale_count = 0u32;

        for i in 0..n {
            // Incoming edges are reached via the reverse index, which maps
            // each in-edge slot back to its forward CSR edge index.
            let range = graph.csr.in_range(NodeId::new(i as u32));
            let mut total_incoming = 0.0f32;
            for j in range.clone() {
                let fwd_idx = graph.csr.rev_edge_idx[j];
                total_incoming += graph.csr.read_weight(fwd_idx).get();
            }

            if total_incoming > ceiling {
                // total_incoming > ceiling > 0, so the division is safe.
                let scale = ceiling / total_incoming;
                for j in range {
                    let fwd_idx = graph.csr.rev_edge_idx[j];
                    let current = graph.csr.read_weight(fwd_idx).get();
                    let new_weight = current * scale;
                    let _ = graph.csr.atomic_write_weight(
                        fwd_idx,
                        FiniteF32::new(new_weight),
                        self.config.cas_retry_limit,
                    );
                    if fwd_idx.as_usize() < graph.edge_plasticity.current_weight.len() {
                        graph.edge_plasticity.current_weight[fwd_idx.as_usize()] =
                            FiniteF32::new(new_weight);
                    }
                }
                rescale_count += 1;
            }
        }

        Ok(rescale_count)
    }

    /// Exports one `SynapticState` per edge (up to the shorter of the CSR and
    /// plasticity arrays), labeling nodes by their external string ids so the
    /// state can be re-imported after a graph rebuild.
    pub fn export_state(&self, graph: &Graph) -> M1ndResult<Vec<SynapticState>> {
        let n = graph.num_nodes() as usize;
        let num_plasticity = graph.edge_plasticity.original_weight.len();
        let num_csr = graph.csr.num_edges();

        // node index -> external label ("" when unresolvable).
        let mut node_ext_id = vec![String::new(); n];
        for (&interned, &node_id) in &graph.id_to_node {
            if let Some(s) = graph.strings.try_resolve(interned) {
                if node_id.as_usize() < n {
                    node_ext_id[node_id.as_usize()] = s.to_string();
                }
            }
        }

        // Invert the CSR offsets: edge index -> source node index.
        let mut edge_source = vec![0u32; num_csr];
        for i in 0..n {
            let lo = graph.csr.offsets[i] as usize;
            let hi = graph.csr.offsets[i + 1] as usize;
            for j in lo..hi {
                edge_source[j] = i as u32;
            }
        }

        let cap = num_plasticity.min(num_csr);
        let mut states = Vec::with_capacity(cap);

        for j in 0..cap {
            let original = graph.edge_plasticity.original_weight[j].get();
            let mut current = graph.edge_plasticity.current_weight[j].get();

            // Sanitize: a non-finite stored weight falls back to the original.
            if !current.is_finite() {
                current = original;
            }

            // Synthetic "node_{idx}" labels cover indices outside [0, n).
            let src_idx = edge_source[j] as usize;
            let tgt_idx = graph.csr.targets[j].as_usize();
            let source_label = if src_idx < n {
                node_ext_id[src_idx].clone()
            } else {
                format!("node_{}", src_idx)
            };
            let target_label = if tgt_idx < n {
                node_ext_id[tgt_idx].clone()
            } else {
                format!("node_{}", tgt_idx)
            };
            let relation = graph
                .strings
                .try_resolve(graph.csr.relations[j])
                .unwrap_or("edge")
                .to_string();

            states.push(SynapticState {
                source_label,
                target_label,
                relation,
                original_weight: original,
                current_weight: current,
                strengthen_count: graph.edge_plasticity.strengthen_count[j],
                weaken_count: graph.edge_plasticity.weaken_count[j],
                ltp_applied: graph.edge_plasticity.ltp_applied[j],
                ltd_applied: graph.edge_plasticity.ltd_applied[j],
            });
        }

        Ok(states)
    }

    /// Re-applies exported synaptic states to `graph`, matching edges by
    /// their (source label, target label, relation) triple. States whose
    /// triple no longer exists are skipped. Weights are clamped to
    /// [weight_floor, weight_cap]. Returns the number of states applied.
    pub fn import_state(&mut self, graph: &mut Graph, states: &[SynapticState]) -> M1ndResult<u32> {
        let n = graph.num_nodes() as usize;
        let num_csr = graph.csr.num_edges();
        let num_plasticity = graph.edge_plasticity.original_weight.len();

        // node index -> external label (same construction as export_state).
        let mut node_ext_id = vec![String::new(); n];
        for (&interned, &node_id) in &graph.id_to_node {
            if let Some(s) = graph.strings.try_resolve(interned) {
                if node_id.as_usize() < n {
                    node_ext_id[node_id.as_usize()] = s.to_string();
                }
            }
        }

        // Invert the CSR offsets: edge index -> source node index.
        let mut edge_source = vec![0u32; num_csr];
        for i in 0..n {
            let lo = graph.csr.offsets[i] as usize;
            let hi = graph.csr.offsets[i + 1] as usize;
            for j in lo..hi {
                edge_source[j] = i as u32;
            }
        }

        // Lookup table: (source label, target label, relation) -> edge index.
        use std::collections::HashMap;
        let cap = num_plasticity.min(num_csr);
        let mut triple_to_edge: HashMap<(&str, &str, &str), usize> = HashMap::with_capacity(cap);
        for j in 0..cap {
            let src_idx = edge_source[j] as usize;
            let tgt_idx = graph.csr.targets[j].as_usize();
            if src_idx < n && tgt_idx < n {
                let rel = graph
                    .strings
                    .try_resolve(graph.csr.relations[j])
                    .unwrap_or("");
                triple_to_edge.insert((&node_ext_id[src_idx], &node_ext_id[tgt_idx], rel), j);
            }
        }

        let mut applied = 0u32;

        for state in states {
            let rel_str = state.relation.as_str();
            // Unknown triples (e.g. the edge no longer exists) are skipped.
            let j = match triple_to_edge.get(&(
                state.source_label.as_str(),
                state.target_label.as_str(),
                rel_str,
            )) {
                Some(&idx) => idx,
                None => continue, };

            // Sanitize the imported weight, then clamp into the legal range.
            let weight = if state.current_weight.is_finite() {
                state.current_weight
            } else {
                state.original_weight
            };

            let clamped = weight
                .max(self.config.weight_floor.get())
                .min(self.config.weight_cap.get());

            graph.edge_plasticity.current_weight[j] = FiniteF32::new(clamped);
            graph.edge_plasticity.strengthen_count[j] = state.strengthen_count;
            graph.edge_plasticity.weaken_count[j] = state.weaken_count;
            graph.edge_plasticity.ltp_applied[j] = state.ltp_applied;
            graph.edge_plasticity.ltd_applied[j] = state.ltd_applied;

            // Also push the weight into the live CSR (guarded against a CSR
            // weights array shorter than the edge index).
            let edge_idx = EdgeIdx::new(j as u32);
            if j < graph.csr.weights.len() {
                let _ = graph.csr.atomic_write_weight(
                    edge_idx,
                    FiniteF32::new(clamped),
                    self.config.cas_retry_limit,
                );
            }

            applied += 1;
        }

        Ok(applied)
    }

    /// Thin public wrapper over `QueryMemory::get_priming_signal`.
    pub fn get_priming(
        &self,
        seeds: &[NodeId],
        boost_strength: FiniteF32,
    ) -> Vec<(NodeId, FiniteF32)> {
        self.memory.get_priming_signal(seeds, boost_strength)
    }
}
731
// Compile-time guarantee that the engine can be shared across threads.
static_assertions::assert_impl_all!(PlasticityEngine: Send, Sync);