1use crate::{EmbeddingError, ModelConfig, Vector};
14use anyhow::Result;
15use scirs2_core::ndarray_ext::{s, Array1, Array2, Array3, Axis};
16use serde::{Deserialize, Serialize};
17use serde_json;
18use std::collections::HashMap;
19
/// Hyperparameters for a Mamba selective state-space block.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MambaConfig {
    /// SSM state dimension N (size of the recurrent state per channel).
    pub d_state: usize,
    /// Model / embedding width D.
    pub d_model: usize,
    /// Expanded inner width (default 1024 = d_model * expand).
    pub d_inner: usize,
    /// Depthwise convolution kernel width.
    pub d_conv: usize,
    /// Expansion factor relating d_model to d_inner.
    pub expand: usize,
    /// Rank of the low-rank projection producing the timestep Δ.
    pub dt_rank: usize,
    /// Lower clamp for the discretization step Δ.
    pub dt_min: f64,
    /// Upper clamp for the discretization step Δ.
    pub dt_max: f64,
    /// Name of the Δ initialization scheme (e.g. "random").
    pub dt_init: String,
    /// Scale applied during Δ initialization.
    pub dt_scale: f64,
    /// Floor applied during Δ initialization.
    pub dt_init_floor: f64,
    /// Whether linear projections carry a bias term.
    pub bias: bool,
    /// Whether the convolution carries a bias term.
    pub conv_bias: bool,
    /// Activation used to gate the SSM output.
    pub activation: ActivationType,
    /// Complex-valued SSM flag — not read anywhere in this file; TODO confirm use.
    pub use_complex: bool,
    /// Number of heads — not read anywhere in this file; TODO confirm use.
    pub num_heads: usize,
}
56
impl Default for MambaConfig {
    /// Mid-sized defaults: d_model = 512 with a 2x expansion and a
    /// 16-dimensional SSM state.
    fn default() -> Self {
        Self {
            d_state: 16,
            d_model: 512,
            d_inner: 1024, // = d_model * expand
            d_conv: 4,
            expand: 2,
            dt_rank: 32,
            dt_min: 0.001,
            dt_max: 0.1,
            dt_init: "random".to_string(),
            dt_scale: 1.0,
            dt_init_floor: 1e-4,
            bias: false,
            conv_bias: true,
            activation: ActivationType::SiLU,
            use_complex: false,
            num_heads: 8,
        }
    }
}
79
/// Elementwise nonlinearity applied when gating the SSM output
/// (see `MambaBlock::apply_activation` for the exact formulas).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ActivationType {
    /// x * sigmoid(x).
    SiLU,
    /// Gaussian error linear unit (tanh approximation).
    GELU,
    /// max(0, x).
    ReLU,
    /// Implemented identically to SiLU in this file.
    Swish,
    /// x * tanh(softplus(x)).
    Mish,
}
89
/// One Mamba layer: input projection, depthwise convolution, selective
/// state-space scan, and output projection, preceded by a layer norm.
#[derive(Debug, Clone)]
pub struct MambaBlock {
    config: MambaConfig,
    /// (d_model, 2 * d_inner): fused projection for the main + gate branches.
    in_proj: Array2<f32>,
    /// (d_inner, d_conv): depthwise convolution weights.
    conv1d: Array2<f32>,
    /// (d_inner, d_state): log-space state-transition matrix A.
    a_log: Array2<f32>,
    /// Per-channel skip ("D") coefficients, length d_inner.
    d: Array1<f32>,
    /// (dt_rank, d_inner): projection producing the timestep Δ.
    dt_proj: Array2<f32>,
    /// (d_inner, d_model): output projection back to model width.
    out_proj: Array2<f32>,
    norm: LayerNorm,
    /// Recurrent state cache; never written in this file — TODO confirm
    /// whether incremental decoding is planned.
    cached_states: Option<Array3<f32>>,
}
111
impl MambaBlock {
    /// Build a block for `config` with all projection weights zero-initialized
    /// (the skip vector `d` starts at ones).
    ///
    /// NOTE(review): zero weights make every projection a no-op until real
    /// weights are trained or loaded — confirm this is intended.
    pub fn new(config: MambaConfig) -> Self {
        let d_model = config.d_model;
        let d_inner = config.d_inner;
        let d_state = config.d_state;
        let dt_rank = config.dt_rank;

        let in_proj = Array2::zeros((d_model, d_inner * 2));
        let conv1d = Array2::zeros((d_inner, config.d_conv));
        let a_log = Array2::zeros((d_inner, d_state));
        let d = Array1::ones(d_inner);
        let dt_proj = Array2::zeros((dt_rank, d_inner));
        let out_proj = Array2::zeros((d_inner, d_model));
        let norm = LayerNorm::new(d_model);

        Self {
            config,
            in_proj,
            conv1d,
            a_log,
            d,
            dt_proj,
            out_proj,
            norm,
            cached_states: None,
        }
    }

    /// Run one Mamba block: norm -> in-projection -> split -> conv ->
    /// selective SSM (gated by the residual half) -> out-projection.
    ///
    /// NOTE(review): `x` is a 2-D array treated as (batch, seq) of scalars
    /// rather than (seq, d_model) token vectors; several matmuls below only
    /// line up for particular batch/seq/d_* combinations — verify the
    /// intended shapes before relying on this path.
    pub fn forward(&mut self, x: &Array2<f32>) -> Result<Array2<f32>> {
        let (_batch_size, _seq_len) = x.dim();

        // Pre-norm, then the fused input projection.
        let x_norm = self.norm.forward(x)?;
        let x_and_res = self.apply_projection(&x_norm)?;

        // Split into the SSM branch and the gating branch.
        let (x_main, x_res) = self.split_projection(&x_and_res)?;

        // Local mixing along the sequence via depthwise convolution.
        let x_conv = self.apply_convolution(&x_main)?;

        // Selective state-space scan, gated by the residual branch.
        let y = self.selective_ssm(&x_conv, &x_res)?;

        let output = self.apply_output_projection(&y)?;

        Ok(output)
    }

    /// Apply the fused input projection: (.., d_model) x (d_model, 2*d_inner).
    fn apply_projection(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        let result = x.dot(&self.in_proj);
        Ok(result)
    }

    /// Split the projected tensor in half along the feature axis into the
    /// main (SSM) branch and the residual/gate branch.
    fn split_projection(&self, x: &Array2<f32>) -> Result<(Array2<f32>, Array2<f32>)> {
        let (_, total_dim) = x.dim();
        let split_point = total_dim / 2;

        let x_main = x.slice(s![.., ..split_point]).to_owned();
        let x_res = x.slice(s![.., split_point..]).to_owned();

        Ok((x_main, x_res))
    }

    /// Centered 1-D convolution along the sequence axis.
    ///
    /// NOTE(review): only row 0 of `conv1d` is ever read (per-channel rows
    /// are ignored), and `weight_idx` restarts at 0 for clipped edge windows,
    /// so edge positions use different taps than interior ones — confirm
    /// whether this is intended.
    fn apply_convolution(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        let (batch_size, seq_len) = x.dim();
        let mut result = Array2::zeros((batch_size, seq_len));

        for i in 0..batch_size {
            for j in 0..seq_len {
                // Window centered on j, clipped to the sequence bounds.
                let start = j.saturating_sub(self.config.d_conv / 2);
                let end = std::cmp::min(j + self.config.d_conv / 2 + 1, seq_len);

                let mut conv_sum = 0.0;
                let mut weight_idx = 0;

                for k in start..end {
                    if weight_idx < self.conv1d.ncols() {
                        conv_sum += x[[i, k]] * self.conv1d[[0, weight_idx]];
                        weight_idx += 1;
                    }
                }

                result[[i, j]] = conv_sum;
            }
        }

        Ok(result)
    }

    /// Selective state-space scan over the sequence, gated by `z`.
    ///
    /// NOTE(review): `a_t.dot(&h.t()).t()` produces a (batch, batch) matrix,
    /// so the recurrence only type-checks at runtime when batch_size ==
    /// d_state, and `&self.d * &x_t` requires d_inner == batch_size. This
    /// likely panics for other sizes — verify before production use.
    fn selective_ssm(&mut self, x: &Array2<f32>, z: &Array2<f32>) -> Result<Array2<f32>> {
        let (batch_size, seq_len) = x.dim();
        let d_state = self.config.d_state;
        let _d_inner = self.config.d_inner;

        // Input-dependent timestep Δ (softplus, clamped to [dt_min, dt_max]).
        let delta = self.compute_delta(x)?;

        // Discretized transition and input matrices.
        let a = self.compute_a_matrix(&delta)?;
        let b = self.compute_b_matrix(x)?;

        // Hidden SSM state, updated one timestep at a time.
        let mut h = Array2::zeros((batch_size, d_state));
        let mut outputs = Array2::zeros((batch_size, seq_len));

        for t in 0..seq_len {
            let x_t = x.slice(s![.., t]).to_owned();
            let a_t = a.slice(s![.., t, ..]).to_owned();
            let b_t = b.slice(s![.., t]).to_owned();

            // h_t = A_t h_{t-1} + B_t * x_t (elementwise on the B term).
            h = &a_t.dot(&h.t()).t() + &(&b_t * &x_t);

            // Readout with a fixed all-ones C, plus the D skip connection.
            let c = Array1::ones(d_state);
            let y_t = c.dot(&h.t()) + &self.d * &x_t;
            outputs.slice_mut(s![.., t]).assign(&y_t);
        }

        // Gate the scan output with the activated residual branch.
        let gated_output = &outputs * &self.apply_activation(z)?;

        Ok(gated_output)
    }

    /// Input-dependent timestep Δ: project through `dt_proj`, apply
    /// softplus, and clamp into [dt_min, dt_max].
    ///
    /// NOTE(review): `x.exp()` can overflow to inf for large inputs before
    /// the ln — a numerically stable softplus may be preferable.
    fn compute_delta(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        let (_batch_size, _seq_len) = x.dim();

        let delta_proj = x.dot(&self.dt_proj.t());

        // softplus(v) = ln(1 + e^v), then clamp to the configured range.
        let delta = delta_proj.mapv(|x| {
            let exp_x = x.exp();
            (1.0 + exp_x)
                .ln()
                .max(self.config.dt_min as f32)
                .min(self.config.dt_max as f32)
        });

        Ok(delta)
    }

    /// Discretized state transition A_t = exp(Δ_t * a_log[0, k]).
    ///
    /// NOTE(review): only row 0 of `a_log` is used, and the sign convention
    /// differs from the Mamba paper's A = exp(Δ · −exp(a_log)) — verify.
    fn compute_a_matrix(&self, delta: &Array2<f32>) -> Result<Array3<f32>> {
        let (batch_size, seq_len) = delta.dim();
        let d_state = self.config.d_state;

        let mut a = Array3::zeros((batch_size, seq_len, d_state));

        for i in 0..batch_size {
            for j in 0..seq_len {
                for k in 0..d_state {
                    a[[i, j, k]] = (delta[[i, j]] * self.a_log[[0, k]]).exp();
                }
            }
        }

        Ok(a)
    }

    /// Input matrix B; currently an identity placeholder (B_t = x_t).
    fn compute_b_matrix(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        Ok(x.clone())
    }

    /// Elementwise activation selected by `config.activation`.
    ///
    /// NOTE(review): the tanh-approximate GELU should use sqrt(2/pi)
    /// (~0.7979); `FRAC_2_SQRT_PI` is 2/sqrt(pi) (~1.128) — verify constant.
    fn apply_activation(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        match self.config.activation {
            // SiLU and Swish share the same x * sigmoid(x) form here.
            ActivationType::SiLU => Ok(x.mapv(|x| x / (1.0 + (-x).exp()))),
            ActivationType::GELU => Ok(x.mapv(|x| {
                0.5 * x
                    * (1.0 + (std::f32::consts::FRAC_2_SQRT_PI * (x + 0.044715 * x.powi(3))).tanh())
            })),
            ActivationType::ReLU => Ok(x.mapv(|x| x.max(0.0))),
            ActivationType::Swish => Ok(x.mapv(|x| x / (1.0 + (-x).exp()))),
            // Mish: x * tanh(softplus(x)).
            ActivationType::Mish => Ok(x.mapv(|x| x * (1.0 + x.exp()).ln().tanh())),
        }
    }

    /// Project back to model width: (.., d_inner) x (d_inner, d_model).
    fn apply_output_projection(&self, y: &Array2<f32>) -> Result<Array2<f32>> {
        Ok(y.dot(&self.out_proj))
    }
}
314
/// Layer normalization over the feature axis with a learnable affine.
#[derive(Debug, Clone)]
pub struct LayerNorm {
    /// Per-feature scale (gamma), initialized to ones.
    weight: Array1<f32>,
    /// Per-feature shift (beta), initialized to zeros.
    bias: Array1<f32>,
    /// Numerical-stability epsilon added to the variance.
    eps: f32,
}
322
323impl LayerNorm {
324 pub fn new(d_model: usize) -> Self {
325 Self {
326 weight: Array1::ones(d_model),
327 bias: Array1::zeros(d_model),
328 eps: 1e-5,
329 }
330 }
331
332 pub fn forward(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
333 let mean = x
334 .mean_axis(Axis(1))
335 .expect("mean should succeed on valid axis");
336 let centered = x - &mean.insert_axis(Axis(1));
337 let variance = centered
338 .mapv(|x| x.powi(2))
339 .mean_axis(Axis(1))
340 .expect("mean should succeed on valid axis");
341 let std = variance.mapv(|x| (x + self.eps).sqrt());
342
343 let normalized = ¢ered / &std.insert_axis(Axis(1));
344 let result = &normalized * &self.weight + &self.bias;
345
346 Ok(result)
347 }
348}
349
/// Knowledge-graph embedding model built from a stack of `MambaBlock`s.
#[derive(Debug, Clone)]
pub struct MambaEmbedding {
    /// Unique id for this model instance.
    id: uuid::Uuid,
    /// Generic embedding-model configuration (dimensions, epochs, ...).
    config: ModelConfig,
    /// Mamba-specific hyperparameters shared by all blocks.
    mamba_config: MambaConfig,
    /// The block stack applied in order by `process_sequence`.
    mamba_blocks: Vec<MambaBlock>,
    /// Entity IRI -> row index into `entity_embeddings`.
    entities: HashMap<String, usize>,
    /// Relation IRI -> row index into `relation_embeddings`.
    relations: HashMap<String, usize>,
    /// (num_entities, dimensions); placeholder (1, dimensions) until trained.
    entity_embeddings: Array2<f32>,
    /// (num_relations, dimensions); placeholder (1, dimensions) until trained.
    relation_embeddings: Array2<f32>,
    is_trained: bool,
    stats: crate::ModelStats,
}
364
impl MambaEmbedding {
    /// Create a model with a hard-coded stack of 6 Mamba blocks and
    /// placeholder (1, dimensions) embedding matrices (resized in `train`).
    pub fn new(config: ModelConfig, mamba_config: MambaConfig) -> Self {
        // Depth of the block stack; not currently configurable.
        let num_layers = 6;
        let mut mamba_blocks = Vec::new();

        for _ in 0..num_layers {
            mamba_blocks.push(MambaBlock::new(mamba_config.clone()));
        }

        Self {
            id: uuid::Uuid::new_v4(),
            config: config.clone(),
            mamba_config,
            mamba_blocks,
            entities: HashMap::new(),
            relations: HashMap::new(),
            entity_embeddings: Array2::zeros((1, config.dimensions)),
            relation_embeddings: Array2::zeros((1, config.dimensions)),
            is_trained: false,
            stats: crate::ModelStats {
                model_type: "Mamba".to_string(),
                dimensions: config.dimensions,
                creation_time: chrono::Utc::now(),
                ..Default::default()
            },
        }
    }

    /// Run `input` through every block in order.
    pub fn process_sequence(&mut self, input: &Array2<f32>) -> Result<Array2<f32>> {
        let mut x = input.clone();

        for block in &mut self.mamba_blocks {
            x = block.forward(&x)?;
        }

        Ok(x)
    }

    /// Encode a set of triples by turning them into a scalar sequence and
    /// pushing it through the block stack.
    pub fn encode_kg_structure(&mut self, triples: &[crate::Triple]) -> Result<Array2<f32>> {
        let sequence = self.triples_to_sequence(triples)?;

        let encoded = self.process_sequence(&sequence)?;

        Ok(encoded)
    }

    /// Toy featurization: each triple becomes the mean of its subject,
    /// predicate, and object indices (unknown IRIs map to index 0).
    fn triples_to_sequence(&self, triples: &[crate::Triple]) -> Result<Array2<f32>> {
        let seq_len = triples.len();
        let _d_model = self.mamba_config.d_model;

        // One row: a scalar per triple.
        let mut sequence = Array2::zeros((1, seq_len));

        for (i, triple) in triples.iter().enumerate() {
            let subj_idx = self.entities.get(&triple.subject.iri).unwrap_or(&0);
            let pred_idx = self.relations.get(&triple.predicate.iri).unwrap_or(&0);
            let obj_idx = self.entities.get(&triple.object.iri).unwrap_or(&0);

            sequence[[0, i]] = (*subj_idx as f32 + *pred_idx as f32 + *obj_idx as f32) / 3.0;
        }

        Ok(sequence)
    }

    /// Embed `entity` conditioned on `context` by scanning the combined
    /// sequence and taking the final row as the summary.
    ///
    /// NOTE(review): the returned vector's length is the processed row
    /// width, which need not equal `config.dimensions` — verify callers.
    pub fn generate_selective_embedding(
        &mut self,
        entity: &str,
        context: &[String],
    ) -> Result<Vector> {
        let context_sequence = self.create_context_sequence(entity, context)?;

        let processed = self.process_sequence(&context_sequence)?;

        // s![-1, ..] selects the last row of the processed sequence.
        let embedding = processed.slice(s![-1, ..]).to_owned();

        Ok(Vector::new(embedding.to_vec()))
    }

    /// Build a (1, 1 + context.len()) sequence: the entity's index first,
    /// followed by the context indices (unknown names stay 0).
    fn create_context_sequence(&self, entity: &str, context: &[String]) -> Result<Array2<f32>> {
        // One slot for the entity itself plus one per context item.
        let seq_len = context.len() + 1;
        let _d_model = self.mamba_config.d_model;

        let mut sequence = Array2::zeros((1, seq_len));

        if let Some(&entity_idx) = self.entities.get(entity) {
            sequence[[0, 0]] = entity_idx as f32;
        }

        for (i, ctx) in context.iter().enumerate() {
            if let Some(&ctx_idx) = self.entities.get(ctx) {
                sequence[[0, i + 1]] = ctx_idx as f32;
            }
        }

        Ok(sequence)
    }
}
476
477#[async_trait::async_trait]
478impl crate::EmbeddingModel for MambaEmbedding {
479 fn config(&self) -> &ModelConfig {
480 &self.config
481 }
482
483 fn model_id(&self) -> &uuid::Uuid {
484 &self.id
485 }
486
487 fn model_type(&self) -> &'static str {
488 "Mamba"
489 }
490
491 fn add_triple(&mut self, triple: crate::Triple) -> Result<()> {
492 let subj_id = self.entities.len();
494 let pred_id = self.relations.len();
495 let obj_id = self.entities.len() + 1;
496
497 self.entities.entry(triple.subject.iri).or_insert(subj_id);
498 self.relations
499 .entry(triple.predicate.iri)
500 .or_insert(pred_id);
501 self.entities.entry(triple.object.iri).or_insert(obj_id);
502
503 self.stats.num_triples += 1;
504 self.stats.num_entities = self.entities.len();
505 self.stats.num_relations = self.relations.len();
506
507 Ok(())
508 }
509
510 async fn train(&mut self, epochs: Option<usize>) -> Result<crate::TrainingStats> {
511 let max_epochs = epochs.unwrap_or(self.config.max_epochs);
512 let mut loss_history = Vec::new();
513 let start_time = std::time::Instant::now();
514
515 let num_entities = self.entities.len();
517 let num_relations = self.relations.len();
518
519 if num_entities > 0 && num_relations > 0 {
520 self.entity_embeddings = Array2::zeros((num_entities, self.config.dimensions));
521 self.relation_embeddings = Array2::zeros((num_relations, self.config.dimensions));
522
523 #[allow(unused_imports)]
525 use scirs2_core::random::{Random, Rng};
526 let mut rng = Random::default();
527
528 for i in 0..num_entities {
529 for j in 0..self.config.dimensions {
530 self.entity_embeddings[[i, j]] = rng.random_range(-0.1..0.1);
531 }
532 }
533
534 for i in 0..num_relations {
535 for j in 0..self.config.dimensions {
536 self.relation_embeddings[[i, j]] = rng.random_range(-0.1..0.1);
537 }
538 }
539 }
540
541 for epoch in 0..max_epochs {
543 let loss = 1.0 / (epoch as f64 + 1.0); loss_history.push(loss);
545
546 if loss < 0.01 {
547 break;
548 }
549 }
550
551 self.is_trained = true;
552 self.stats.is_trained = true;
553 self.stats.last_training_time = Some(chrono::Utc::now());
554
555 let training_time = start_time.elapsed().as_secs_f64();
556
557 Ok(crate::TrainingStats {
558 epochs_completed: max_epochs,
559 final_loss: loss_history.last().copied().unwrap_or(1.0),
560 training_time_seconds: training_time,
561 convergence_achieved: true,
562 loss_history,
563 })
564 }
565
566 fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
567 if !self.is_trained {
568 return Err(EmbeddingError::ModelNotTrained.into());
569 }
570
571 let entity_idx =
572 self.entities
573 .get(entity)
574 .ok_or_else(|| EmbeddingError::EntityNotFound {
575 entity: entity.to_string(),
576 })?;
577
578 let embedding = self.entity_embeddings.row(*entity_idx);
579 Ok(Vector::new(embedding.to_vec()))
580 }
581
582 fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
583 if !self.is_trained {
584 return Err(EmbeddingError::ModelNotTrained.into());
585 }
586
587 let relation_idx =
588 self.relations
589 .get(relation)
590 .ok_or_else(|| EmbeddingError::RelationNotFound {
591 relation: relation.to_string(),
592 })?;
593
594 let embedding = self.relation_embeddings.row(*relation_idx);
595 Ok(Vector::new(embedding.to_vec()))
596 }
597
598 fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
599 let s_emb = self.get_entity_embedding(subject)?;
600 let p_emb = self.get_relation_embedding(predicate)?;
601 let o_emb = self.get_entity_embedding(object)?;
602
603 let score = s_emb
605 .values
606 .iter()
607 .zip(p_emb.values.iter())
608 .zip(o_emb.values.iter())
609 .map(|((&s, &p), &o)| s * p * o)
610 .sum::<f32>() as f64;
611
612 Ok(score)
613 }
614
615 fn predict_objects(
616 &self,
617 subject: &str,
618 predicate: &str,
619 k: usize,
620 ) -> Result<Vec<(String, f64)>> {
621 let mut predictions = Vec::new();
622
623 for entity in self.entities.keys() {
624 if let Ok(score) = self.score_triple(subject, predicate, entity) {
625 predictions.push((entity.clone(), score));
626 }
627 }
628
629 predictions.sort_by(|a, b| {
630 b.1.partial_cmp(&a.1)
631 .expect("prediction scores should be comparable")
632 });
633 predictions.truncate(k);
634
635 Ok(predictions)
636 }
637
638 fn predict_subjects(
639 &self,
640 predicate: &str,
641 object: &str,
642 k: usize,
643 ) -> Result<Vec<(String, f64)>> {
644 let mut predictions = Vec::new();
645
646 for entity in self.entities.keys() {
647 if let Ok(score) = self.score_triple(entity, predicate, object) {
648 predictions.push((entity.clone(), score));
649 }
650 }
651
652 predictions.sort_by(|a, b| {
653 b.1.partial_cmp(&a.1)
654 .expect("prediction scores should be comparable")
655 });
656 predictions.truncate(k);
657
658 Ok(predictions)
659 }
660
661 fn predict_relations(
662 &self,
663 subject: &str,
664 object: &str,
665 k: usize,
666 ) -> Result<Vec<(String, f64)>> {
667 let mut predictions = Vec::new();
668
669 for relation in self.relations.keys() {
670 if let Ok(score) = self.score_triple(subject, relation, object) {
671 predictions.push((relation.clone(), score));
672 }
673 }
674
675 predictions.sort_by(|a, b| {
676 b.1.partial_cmp(&a.1)
677 .expect("prediction scores should be comparable")
678 });
679 predictions.truncate(k);
680
681 Ok(predictions)
682 }
683
684 fn get_entities(&self) -> Vec<String> {
685 self.entities.keys().cloned().collect()
686 }
687
688 fn get_relations(&self) -> Vec<String> {
689 self.relations.keys().cloned().collect()
690 }
691
692 fn get_stats(&self) -> crate::ModelStats {
693 self.stats.clone()
694 }
695
696 fn save(&self, path: &str) -> Result<()> {
697 use std::fs::File;
698 use std::io::Write;
699
700 let model_path = format!("{path}.mamba");
702 let metadata_path = format!("{path}.mamba.metadata.json");
703
704 let entity_data: std::collections::HashMap<String, usize> = self.entities.clone();
706 let relation_data: std::collections::HashMap<String, usize> = self.relations.clone();
707
708 let entity_embeddings_data = self
710 .entity_embeddings
711 .as_slice()
712 .expect("array should be contiguous")
713 .to_vec();
714 let relation_embeddings_data = self
715 .relation_embeddings
716 .as_slice()
717 .expect("array should be contiguous")
718 .to_vec();
719
720 let mamba_blocks_data = if let Some(first_block) = self.mamba_blocks.first() {
722 serde_json::json!({
723 "config": first_block.config,
724 "in_proj": first_block.in_proj.as_slice().expect("array should be contiguous").to_vec(),
725 "in_proj_shape": first_block.in_proj.shape(),
726 "conv1d": first_block.conv1d.as_slice().expect("array should be contiguous").to_vec(),
727 "conv1d_shape": first_block.conv1d.shape(),
728 "a_log": first_block.a_log.as_slice().expect("array should be contiguous").to_vec(),
729 "a_log_shape": first_block.a_log.shape(),
730 "d": first_block.d.as_slice().expect("array should be contiguous").to_vec(),
731 "d_shape": first_block.d.shape(),
732 "num_blocks": self.mamba_blocks.len(),
733 })
734 } else {
735 serde_json::Value::Null
736 };
737
738 let model_data = serde_json::json!({
739 "model_id": self.id,
740 "config": self.config,
741 "mamba_config": self.mamba_config,
742 "entity_data": entity_data,
743 "relation_data": relation_data,
744 "entity_embeddings": entity_embeddings_data,
745 "entity_embeddings_shape": self.entity_embeddings.shape(),
746 "relation_embeddings": relation_embeddings_data,
747 "relation_embeddings_shape": self.relation_embeddings.shape(),
748 "is_trained": self.is_trained,
749 "stats": self.stats,
750 "mamba_blocks": mamba_blocks_data,
751 "timestamp": chrono::Utc::now(),
752 "version": "1.0"
753 });
754
755 let mut file = File::create(&model_path)?;
757 let serialized = serde_json::to_string_pretty(&model_data)?;
758 file.write_all(serialized.as_bytes())?;
759
760 let metadata = serde_json::json!({
762 "model_type": "MambaEmbedding",
763 "model_id": self.id,
764 "dimensions": self.config.dimensions,
765 "num_entities": self.entities.len(),
766 "num_relations": self.relations.len(),
767 "is_trained": self.is_trained,
768 "created_at": chrono::Utc::now(),
769 "file_path": model_path
770 });
771
772 let mut metadata_file = File::create(&metadata_path)?;
773 let metadata_serialized = serde_json::to_string_pretty(&metadata)?;
774 metadata_file.write_all(metadata_serialized.as_bytes())?;
775
776 tracing::info!("Mamba model saved to {} and {}", model_path, metadata_path);
777 Ok(())
778 }
779
780 fn load(&mut self, path: &str) -> Result<()> {
781 use std::fs::File;
782 use std::io::Read;
783
784 let model_path = format!("{path}.mamba");
786
787 let mut file = File::open(&model_path)?;
789 let mut contents = String::new();
790 file.read_to_string(&mut contents)?;
791
792 let model_data: serde_json::Value = serde_json::from_str(&contents)?;
793
794 if let Some(version) = model_data.get("version").and_then(|v| v.as_str()) {
796 if version != "1.0" {
797 return Err(anyhow::anyhow!("Unsupported model version: {}", version));
798 }
799 }
800
801 if let Some(model_id) = model_data.get("model_id") {
803 self.id = serde_json::from_value(model_id.clone())?;
804 }
805
806 if let Some(config) = model_data.get("config") {
807 self.config = serde_json::from_value(config.clone())?;
808 }
809
810 if let Some(mamba_config) = model_data.get("mamba_config") {
811 self.mamba_config = serde_json::from_value(mamba_config.clone())?;
812 }
813
814 if let Some(is_trained) = model_data.get("is_trained") {
815 self.is_trained = serde_json::from_value(is_trained.clone())?;
816 }
817
818 if let Some(stats) = model_data.get("stats") {
819 self.stats = serde_json::from_value(stats.clone())?;
820 }
821
822 if let Some(entity_data) = model_data.get("entity_data") {
824 self.entities = serde_json::from_value(entity_data.clone())?;
825 }
826
827 if let Some(relation_data) = model_data.get("relation_data") {
829 self.relations = serde_json::from_value(relation_data.clone())?;
830 }
831
832 if let (Some(embeddings_data), Some(embeddings_shape)) = (
834 model_data
835 .get("entity_embeddings")
836 .and_then(|v| v.as_array()),
837 model_data
838 .get("entity_embeddings_shape")
839 .and_then(|v| v.as_array()),
840 ) {
841 let values: Vec<f32> = embeddings_data
842 .iter()
843 .filter_map(|v| v.as_f64().map(|f| f as f32))
844 .collect();
845 let shape: Vec<usize> = embeddings_shape
846 .iter()
847 .filter_map(|v| v.as_u64().map(|u| u as usize))
848 .collect();
849 if shape.len() == 2 {
850 self.entity_embeddings = Array2::from_shape_vec((shape[0], shape[1]), values)
851 .map_err(|e| anyhow::anyhow!("Failed to reshape entity_embeddings: {}", e))?;
852 }
853 }
854
855 if let (Some(embeddings_data), Some(embeddings_shape)) = (
857 model_data
858 .get("relation_embeddings")
859 .and_then(|v| v.as_array()),
860 model_data
861 .get("relation_embeddings_shape")
862 .and_then(|v| v.as_array()),
863 ) {
864 let values: Vec<f32> = embeddings_data
865 .iter()
866 .filter_map(|v| v.as_f64().map(|f| f as f32))
867 .collect();
868 let shape: Vec<usize> = embeddings_shape
869 .iter()
870 .filter_map(|v| v.as_u64().map(|u| u as usize))
871 .collect();
872 if shape.len() == 2 {
873 self.relation_embeddings = Array2::from_shape_vec((shape[0], shape[1]), values)
874 .map_err(|e| anyhow::anyhow!("Failed to reshape relation_embeddings: {}", e))?;
875 }
876 }
877
878 if let Some(mamba_blocks_data) = model_data.get("mamba_blocks") {
880 if !mamba_blocks_data.is_null() {
881 let num_blocks = mamba_blocks_data
883 .get("num_blocks")
884 .and_then(|v| v.as_u64())
885 .unwrap_or(self.mamba_blocks.len() as u64)
886 as usize;
887
888 self.mamba_blocks.clear();
890 for _ in 0..num_blocks {
891 self.mamba_blocks
892 .push(MambaBlock::new(self.mamba_config.clone()));
893 }
894
895 if let Some(first_block) = self.mamba_blocks.first_mut() {
897 if let (Some(in_proj_data), Some(in_proj_shape)) = (
899 mamba_blocks_data.get("in_proj").and_then(|v| v.as_array()),
900 mamba_blocks_data
901 .get("in_proj_shape")
902 .and_then(|v| v.as_array()),
903 ) {
904 let values: Vec<f32> = in_proj_data
905 .iter()
906 .filter_map(|v| v.as_f64().map(|f| f as f32))
907 .collect();
908 let shape: Vec<usize> = in_proj_shape
909 .iter()
910 .filter_map(|v| v.as_u64().map(|u| u as usize))
911 .collect();
912 if shape.len() == 2 {
913 first_block.in_proj =
914 Array2::from_shape_vec((shape[0], shape[1]), values).map_err(
915 |e| anyhow::anyhow!("Failed to reshape in_proj: {}", e),
916 )?;
917 }
918 }
919
920 if let (Some(conv1d_data), Some(conv1d_shape)) = (
922 mamba_blocks_data.get("conv1d").and_then(|v| v.as_array()),
923 mamba_blocks_data
924 .get("conv1d_shape")
925 .and_then(|v| v.as_array()),
926 ) {
927 let values: Vec<f32> = conv1d_data
928 .iter()
929 .filter_map(|v| v.as_f64().map(|f| f as f32))
930 .collect();
931 let shape: Vec<usize> = conv1d_shape
932 .iter()
933 .filter_map(|v| v.as_u64().map(|u| u as usize))
934 .collect();
935 if shape.len() == 2 {
936 first_block.conv1d =
937 Array2::from_shape_vec((shape[0], shape[1]), values).map_err(
938 |e| anyhow::anyhow!("Failed to reshape conv1d: {}", e),
939 )?;
940 }
941 }
942
943 if let (Some(a_log_data), Some(a_log_shape)) = (
945 mamba_blocks_data.get("a_log").and_then(|v| v.as_array()),
946 mamba_blocks_data
947 .get("a_log_shape")
948 .and_then(|v| v.as_array()),
949 ) {
950 let values: Vec<f32> = a_log_data
951 .iter()
952 .filter_map(|v| v.as_f64().map(|f| f as f32))
953 .collect();
954 let shape: Vec<usize> = a_log_shape
955 .iter()
956 .filter_map(|v| v.as_u64().map(|u| u as usize))
957 .collect();
958 if shape.len() == 2 {
959 first_block.a_log =
960 Array2::from_shape_vec((shape[0], shape[1]), values).map_err(
961 |e| anyhow::anyhow!("Failed to reshape a_log: {}", e),
962 )?;
963 }
964 }
965
966 if let (Some(d_data), Some(d_shape)) = (
968 mamba_blocks_data.get("d").and_then(|v| v.as_array()),
969 mamba_blocks_data.get("d_shape").and_then(|v| v.as_array()),
970 ) {
971 let values: Vec<f32> = d_data
972 .iter()
973 .filter_map(|v| v.as_f64().map(|f| f as f32))
974 .collect();
975 let shape: Vec<usize> = d_shape
976 .iter()
977 .filter_map(|v| v.as_u64().map(|u| u as usize))
978 .collect();
979 if shape.len() == 1 {
980 first_block.d = Array1::from_shape_vec(shape[0], values)
981 .map_err(|e| anyhow::anyhow!("Failed to reshape d: {}", e))?;
982 }
983 }
984 }
985 }
986 }
987
988 tracing::info!("Mamba model loaded from {}", model_path);
989 tracing::info!(
990 "Model contains {} entities, {} relations",
991 self.entities.len(),
992 self.relations.len()
993 );
994
995 Ok(())
996 }
997
998 fn clear(&mut self) {
999 self.entities.clear();
1000 self.relations.clear();
1001 self.is_trained = false;
1002 self.stats = crate::ModelStats::default();
1003 }
1004
1005 fn is_trained(&self) -> bool {
1006 self.is_trained
1007 }
1008
1009 async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
1010 let embeddings = texts
1012 .iter()
1013 .map(|text| {
1014 let mut embedding = vec![0.0; self.config.dimensions];
1015 for (i, byte) in text.bytes().enumerate() {
1016 if i < self.config.dimensions {
1017 embedding[i] = (byte as f32) / 255.0;
1018 }
1019 }
1020 embedding
1021 })
1022 .collect::<Vec<_>>();
1023 Ok(embeddings)
1024 }
1025}
1026
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EmbeddingModel;
    use nalgebra::Complex;

    /// Default config carries the documented default hyperparameters.
    #[test]
    fn test_mamba_config_creation() {
        let config = MambaConfig::default();
        assert_eq!(config.d_state, 16);
        assert_eq!(config.d_model, 512);
        assert_eq!(config.num_heads, 8);
    }

    /// A block stores the config it was built from.
    #[test]
    fn test_mamba_block_creation() {
        let config = MambaConfig::default();
        let block = MambaBlock::new(config);
        assert_eq!(block.config.d_model, 512);
    }

    /// LayerNorm preserves the input shape.
    #[test]
    fn test_layer_norm() {
        let norm = LayerNorm::new(4);
        let input =
            Array2::from_shape_vec((2, 4), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let output = norm.forward(&input).unwrap();
        assert_eq!(output.dim(), (2, 4));
    }

    /// Adding one triple registers two entities and one relation.
    #[tokio::test]
    async fn test_mamba_embedding_model() {
        let model_config = ModelConfig::default();
        let mamba_config = MambaConfig::default();
        let mut model = MambaEmbedding::new(model_config, mamba_config);

        let triple = crate::Triple::new(
            crate::NamedNode::new("http://example.org/alice").unwrap(),
            crate::NamedNode::new("http://example.org/knows").unwrap(),
            crate::NamedNode::new("http://example.org/bob").unwrap(),
        );

        model.add_triple(triple).unwrap();
        assert_eq!(model.get_entities().len(), 2);
        assert_eq!(model.get_relations().len(), 1);
    }

    /// Sanity-check complex arithmetic from nalgebra
    /// (relevant to the `use_complex` config option).
    #[test]
    fn test_complex_arithmetic() {
        let a = Complex::new(1.0, 2.0);
        let b = Complex::new(3.0, 4.0);

        let sum = a + b;
        assert_eq!(sum.re, 4.0);
        assert_eq!(sum.im, 6.0);

        let product = a * b;
        // (1 + 2i)(3 + 4i) = -5 + 10i.
        assert_eq!(product.re, -5.0);
        assert_eq!(product.im, 10.0);
    }

    /// Default activation (SiLU) keeps sign for +/- inputs and maps 0 to 0.
    #[test]
    fn test_activation_functions() {
        let config = MambaConfig::default();
        let block = MambaBlock::new(config.clone());

        let input = Array2::from_shape_vec((1, 3), vec![-1.0, 0.0, 1.0]).unwrap();

        let output = block.apply_activation(&input).unwrap();
        assert!(output[[0, 0]] < 0.0);
        assert_eq!(output[[0, 1]], 0.0);
        assert!(output[[0, 2]] > 0.0);
    }
}