1use crate::{EmbeddingError, ModelConfig, Vector};
14use anyhow::Result;
15use scirs2_core::ndarray_ext::{s, Array1, Array2, Array3, Axis};
16use serde::{Deserialize, Serialize};
17use serde_json;
18use std::collections::HashMap;
19
/// Hyper-parameters for a Mamba (selective state-space) block.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MambaConfig {
    /// State-space latent dimension (N in the Mamba paper).
    pub d_state: usize,
    /// Model/embedding width of the residual stream.
    pub d_model: usize,
    /// Inner (expanded) width used inside the block.
    pub d_inner: usize,
    /// Depthwise 1-D convolution kernel length.
    pub d_conv: usize,
    /// Expansion factor (d_inner is conventionally expand * d_model).
    pub expand: usize,
    /// Rank of the low-rank Δ (timestep) projection.
    pub dt_rank: usize,
    /// Lower clamp for the Δ timestep.
    pub dt_min: f64,
    /// Upper clamp for the Δ timestep.
    pub dt_max: f64,
    /// Δ initialization scheme name (e.g. "random").
    pub dt_init: String,
    /// Scale applied to the Δ initialization.
    pub dt_scale: f64,
    /// Floor applied to initialized Δ values.
    pub dt_init_floor: f64,
    /// Whether linear projections carry a bias term.
    pub bias: bool,
    /// Whether the convolution carries a bias term.
    pub conv_bias: bool,
    /// Nonlinearity used for the gating path.
    pub activation: ActivationType,
    /// Whether to use complex-valued state (not exercised in this file).
    pub use_complex: bool,
    /// Number of heads (not exercised in this file).
    pub num_heads: usize,
}
56
impl Default for MambaConfig {
    /// Defaults roughly following common Mamba settings for a 512-wide model.
    fn default() -> Self {
        Self {
            d_state: 16,
            d_model: 512,
            // NOTE(review): d_inner = 1024 == expand * d_model here; the two
            // fields are not cross-validated anywhere — confirm callers keep
            // them consistent.
            d_inner: 1024,
            d_conv: 4,
            expand: 2,
            dt_rank: 32,
            dt_min: 0.001,
            dt_max: 0.1,
            dt_init: "random".to_string(),
            dt_scale: 1.0,
            dt_init_floor: 1e-4,
            bias: false,
            conv_bias: true,
            activation: ActivationType::SiLU,
            use_complex: false,
            num_heads: 8,
        }
    }
}
79
/// Gating nonlinearity selectable via [`MambaConfig::activation`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ActivationType {
    /// x * sigmoid(x).
    SiLU,
    /// Gaussian error linear unit (tanh approximation).
    GELU,
    /// max(x, 0).
    ReLU,
    /// Same formula as SiLU in this implementation.
    Swish,
    /// x * tanh(softplus(x)).
    Mish,
}
89
/// One Mamba block: norm → input projection → conv → selective SSM → gate →
/// output projection. Weights are zero/one-initialized and meant to be
/// loaded via [`MambaEmbedding::load`].
#[derive(Debug, Clone)]
pub struct MambaBlock {
    config: MambaConfig,
    /// Input projection, (d_model, 2 * d_inner): main path and gate path.
    in_proj: Array2<f32>,
    /// Depthwise conv kernels, (d_inner, d_conv).
    conv1d: Array2<f32>,
    /// Log-space state matrix parameters, (d_inner, d_state).
    a_log: Array2<f32>,
    /// Skip/residual scaling vector, length d_inner.
    d: Array1<f32>,
    /// Δ projection, (dt_rank, d_inner).
    dt_proj: Array2<f32>,
    /// Output projection, (d_inner, d_model).
    out_proj: Array2<f32>,
    norm: LayerNorm,
    // NOTE(review): never written anywhere in this file — dead field unless
    // used by code outside this view.
    cached_states: Option<Array3<f32>>,
}
111
112impl MambaBlock {
113 pub fn new(config: MambaConfig) -> Self {
115 let d_model = config.d_model;
116 let d_inner = config.d_inner;
117 let d_state = config.d_state;
118 let dt_rank = config.dt_rank;
119
120 let in_proj = Array2::zeros((d_model, d_inner * 2));
122 let conv1d = Array2::zeros((d_inner, config.d_conv));
123 let a_log = Array2::zeros((d_inner, d_state));
124 let d = Array1::ones(d_inner);
125 let dt_proj = Array2::zeros((dt_rank, d_inner));
126 let out_proj = Array2::zeros((d_inner, d_model));
127 let norm = LayerNorm::new(d_model);
128
129 Self {
130 config,
131 in_proj,
132 conv1d,
133 a_log,
134 d,
135 dt_proj,
136 out_proj,
137 norm,
138 cached_states: None,
139 }
140 }
141
142 pub fn forward(&mut self, x: &Array2<f32>) -> Result<Array2<f32>> {
144 let (_batch_size, _seq_len) = x.dim();
145
146 let x_norm = self.norm.forward(x)?;
148 let x_and_res = self.apply_projection(&x_norm)?;
149
150 let (x_main, x_res) = self.split_projection(&x_and_res)?;
152
153 let x_conv = self.apply_convolution(&x_main)?;
155
156 let y = self.selective_ssm(&x_conv, &x_res)?;
158
159 let output = self.apply_output_projection(&y)?;
161
162 Ok(output)
163 }
164
165 fn apply_projection(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
167 let result = x.dot(&self.in_proj);
169 Ok(result)
170 }
171
172 fn split_projection(&self, x: &Array2<f32>) -> Result<(Array2<f32>, Array2<f32>)> {
174 let (_, total_dim) = x.dim();
175 let split_point = total_dim / 2;
176
177 let x_main = x.slice(s![.., ..split_point]).to_owned();
178 let x_res = x.slice(s![.., split_point..]).to_owned();
179
180 Ok((x_main, x_res))
181 }
182
183 fn apply_convolution(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
185 let (batch_size, seq_len) = x.dim();
188 let mut result = Array2::zeros((batch_size, seq_len));
189
190 for i in 0..batch_size {
191 for j in 0..seq_len {
192 let start = j.saturating_sub(self.config.d_conv / 2);
193 let end = std::cmp::min(j + self.config.d_conv / 2 + 1, seq_len);
194
195 let mut conv_sum = 0.0;
196 let mut weight_idx = 0;
197
198 for k in start..end {
199 if weight_idx < self.conv1d.ncols() {
200 conv_sum += x[[i, k]] * self.conv1d[[0, weight_idx]];
201 weight_idx += 1;
202 }
203 }
204
205 result[[i, j]] = conv_sum;
206 }
207 }
208
209 Ok(result)
210 }
211
    /// Selective state-space recurrence: scans the sequence, updating hidden
    /// state `h` per step, then gates the outputs with the activated `z` path.
    ///
    /// NOTE(review): the shapes here only conform under specific conditions —
    /// `a_t` is (batch, d_state) and `h.t()` is (d_state, batch), so
    /// `a_t.dot(&h.t())` is (batch, batch), yet it is assigned back into `h`
    /// declared as (batch, d_state); this only works when batch == d_state.
    /// Likewise `&self.d * &x_t` multiplies a length-d_inner vector by a
    /// length-batch vector. TODO: confirm intended shapes with callers.
    fn selective_ssm(&mut self, x: &Array2<f32>, z: &Array2<f32>) -> Result<Array2<f32>> {
        let (batch_size, seq_len) = x.dim();
        let d_state = self.config.d_state;
        let _d_inner = self.config.d_inner;

        // Per-position timestep Δ (softplus of a low-rank projection, clamped).
        let delta = self.compute_delta(x)?;

        // Discretized transition terms per position; B is the identity map here.
        let a = self.compute_a_matrix(&delta)?;
        let b = self.compute_b_matrix(x)?;

        // Hidden state, zero-initialized per forward call (cached_states unused).
        let mut h = Array2::zeros((batch_size, d_state));
        let mut outputs = Array2::zeros((batch_size, seq_len));

        for t in 0..seq_len {
            let x_t = x.slice(s![.., t]).to_owned();
            let a_t = a.slice(s![.., t, ..]).to_owned();
            let b_t = b.slice(s![.., t]).to_owned();

            // Recurrence: h <- A_t h + B_t x_t (see shape caveat above).
            h = &a_t.dot(&h.t()).t() + &(&b_t * &x_t);

            // Readout with an all-ones C, plus the skip connection D * x_t.
            let c = Array1::ones(d_state);
            let y_t = c.dot(&h.t()) + &self.d * &x_t;
            outputs.slice_mut(s![.., t]).assign(&y_t);
        }

        // Gate the scan outputs with the activated residual path.
        let gated_output = &outputs * &self.apply_activation(z)?;

        Ok(gated_output)
    }
249
250 fn compute_delta(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
252 let (_batch_size, _seq_len) = x.dim();
253
254 let delta_proj = x.dot(&self.dt_proj.t());
256
257 let delta = delta_proj.mapv(|x| {
259 let exp_x = x.exp();
260 (1.0 + exp_x)
261 .ln()
262 .max(self.config.dt_min as f32)
263 .min(self.config.dt_max as f32)
264 });
265
266 Ok(delta)
267 }
268
269 fn compute_a_matrix(&self, delta: &Array2<f32>) -> Result<Array3<f32>> {
271 let (batch_size, seq_len) = delta.dim();
272 let d_state = self.config.d_state;
273
274 let mut a = Array3::zeros((batch_size, seq_len, d_state));
275
276 for i in 0..batch_size {
277 for j in 0..seq_len {
278 for k in 0..d_state {
279 a[[i, j, k]] = (delta[[i, j]] * self.a_log[[0, k]]).exp();
281 }
282 }
283 }
284
285 Ok(a)
286 }
287
288 fn compute_b_matrix(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
290 Ok(x.clone())
293 }
294
295 fn apply_activation(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
297 match self.config.activation {
298 ActivationType::SiLU => Ok(x.mapv(|x| x / (1.0 + (-x).exp()))),
299 ActivationType::GELU => Ok(x.mapv(|x| {
300 0.5 * x
301 * (1.0 + (std::f32::consts::FRAC_2_SQRT_PI * (x + 0.044715 * x.powi(3))).tanh())
302 })),
303 ActivationType::ReLU => Ok(x.mapv(|x| x.max(0.0))),
304 ActivationType::Swish => Ok(x.mapv(|x| x / (1.0 + (-x).exp()))),
305 ActivationType::Mish => Ok(x.mapv(|x| x * (1.0 + x.exp()).ln().tanh())),
306 }
307 }
308
309 fn apply_output_projection(&self, y: &Array2<f32>) -> Result<Array2<f32>> {
311 Ok(y.dot(&self.out_proj))
312 }
313}
314
/// Row-wise layer normalization with learnable affine parameters.
#[derive(Debug, Clone)]
pub struct LayerNorm {
    // Per-feature gain, initialized to 1.
    weight: Array1<f32>,
    // Per-feature shift, initialized to 0.
    bias: Array1<f32>,
    // Small constant added to the variance for numerical stability.
    eps: f32,
}
322
323impl LayerNorm {
324 pub fn new(d_model: usize) -> Self {
325 Self {
326 weight: Array1::ones(d_model),
327 bias: Array1::zeros(d_model),
328 eps: 1e-5,
329 }
330 }
331
332 pub fn forward(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
333 let mean = x.mean_axis(Axis(1)).unwrap();
334 let centered = x - &mean.insert_axis(Axis(1));
335 let variance = centered.mapv(|x| x.powi(2)).mean_axis(Axis(1)).unwrap();
336 let std = variance.mapv(|x| (x + self.eps).sqrt());
337
338 let normalized = ¢ered / &std.insert_axis(Axis(1));
339 let result = &normalized * &self.weight + &self.bias;
340
341 Ok(result)
342 }
343}
344
/// Knowledge-graph embedding model built from a stack of [`MambaBlock`]s.
#[derive(Debug, Clone)]
pub struct MambaEmbedding {
    // Unique model instance id.
    id: uuid::Uuid,
    // Generic embedding-model configuration (dimensions, epochs, ...).
    config: ModelConfig,
    // Mamba-specific hyper-parameters shared by all blocks.
    mamba_config: MambaConfig,
    // Stacked sequence-processing blocks.
    mamba_blocks: Vec<MambaBlock>,
    // Entity IRI -> dense row index into `entity_embeddings`.
    entities: HashMap<String, usize>,
    // Relation IRI -> dense row index into `relation_embeddings`.
    relations: HashMap<String, usize>,
    // (num_entities, dimensions); placeholder (1, dims) until trained.
    entity_embeddings: Array2<f32>,
    // (num_relations, dimensions); placeholder (1, dims) until trained.
    relation_embeddings: Array2<f32>,
    is_trained: bool,
    stats: crate::ModelStats,
}
359
360impl MambaEmbedding {
361 pub fn new(config: ModelConfig, mamba_config: MambaConfig) -> Self {
363 let num_layers = 6; let mut mamba_blocks = Vec::new();
365
366 for _ in 0..num_layers {
367 mamba_blocks.push(MambaBlock::new(mamba_config.clone()));
368 }
369
370 Self {
371 id: uuid::Uuid::new_v4(),
372 config: config.clone(),
373 mamba_config,
374 mamba_blocks,
375 entities: HashMap::new(),
376 relations: HashMap::new(),
377 entity_embeddings: Array2::zeros((1, config.dimensions)),
378 relation_embeddings: Array2::zeros((1, config.dimensions)),
379 is_trained: false,
380 stats: crate::ModelStats {
381 model_type: "Mamba".to_string(),
382 dimensions: config.dimensions,
383 creation_time: chrono::Utc::now(),
384 ..Default::default()
385 },
386 }
387 }
388
389 pub fn process_sequence(&mut self, input: &Array2<f32>) -> Result<Array2<f32>> {
391 let mut x = input.clone();
392
393 for block in &mut self.mamba_blocks {
394 x = block.forward(&x)?;
395 }
396
397 Ok(x)
398 }
399
400 pub fn encode_kg_structure(&mut self, triples: &[crate::Triple]) -> Result<Array2<f32>> {
402 let sequence = self.triples_to_sequence(triples)?;
404
405 let encoded = self.process_sequence(&sequence)?;
407
408 Ok(encoded)
409 }
410
411 fn triples_to_sequence(&self, triples: &[crate::Triple]) -> Result<Array2<f32>> {
413 let seq_len = triples.len();
414 let _d_model = self.mamba_config.d_model;
415
416 let mut sequence = Array2::zeros((1, seq_len));
417
418 for (i, triple) in triples.iter().enumerate() {
420 let subj_idx = self.entities.get(&triple.subject.iri).unwrap_or(&0);
421 let pred_idx = self.relations.get(&triple.predicate.iri).unwrap_or(&0);
422 let obj_idx = self.entities.get(&triple.object.iri).unwrap_or(&0);
423
424 sequence[[0, i]] = (*subj_idx as f32 + *pred_idx as f32 + *obj_idx as f32) / 3.0;
426 }
427
428 Ok(sequence)
429 }
430
    /// Produces an embedding for `entity` conditioned on `context` entities:
    /// builds a (1, |context|+1) index sequence, runs it through the block
    /// stack, and takes the last row of the result.
    ///
    /// NOTE(review): `s![-1, ..]` selects the last *row* (batch axis) of the
    /// processed tensor; with a batch of 1 this is row 0. Whether the
    /// resulting vector has `config.dimensions` elements depends on the
    /// block pipeline's output shape — TODO confirm against callers.
    pub fn generate_selective_embedding(
        &mut self,
        entity: &str,
        context: &[String],
    ) -> Result<Vector> {
        let context_sequence = self.create_context_sequence(entity, context)?;
        let processed = self.process_sequence(&context_sequence)?;
        let embedding = processed.slice(s![-1, ..]).to_owned();
        Ok(Vector::new(embedding.to_vec()))
    }
448
449 fn create_context_sequence(&self, entity: &str, context: &[String]) -> Result<Array2<f32>> {
451 let seq_len = context.len() + 1; let _d_model = self.mamba_config.d_model;
453
454 let mut sequence = Array2::zeros((1, seq_len));
455
456 if let Some(&entity_idx) = self.entities.get(entity) {
458 sequence[[0, 0]] = entity_idx as f32;
459 }
460
461 for (i, ctx) in context.iter().enumerate() {
463 if let Some(&ctx_idx) = self.entities.get(ctx) {
464 sequence[[0, i + 1]] = ctx_idx as f32;
465 }
466 }
467
468 Ok(sequence)
469 }
470}
471
472#[async_trait::async_trait]
473impl crate::EmbeddingModel for MambaEmbedding {
    /// Returns the generic model configuration.
    fn config(&self) -> &ModelConfig {
        &self.config
    }

    /// Returns this model instance's unique id.
    fn model_id(&self) -> &uuid::Uuid {
        &self.id
    }

    /// Static model-type tag used in stats and persistence metadata.
    fn model_type(&self) -> &'static str {
        "Mamba"
    }
485
486 fn add_triple(&mut self, triple: crate::Triple) -> Result<()> {
487 let subj_id = self.entities.len();
489 let pred_id = self.relations.len();
490 let obj_id = self.entities.len() + 1;
491
492 self.entities.entry(triple.subject.iri).or_insert(subj_id);
493 self.relations
494 .entry(triple.predicate.iri)
495 .or_insert(pred_id);
496 self.entities.entry(triple.object.iri).or_insert(obj_id);
497
498 self.stats.num_triples += 1;
499 self.stats.num_entities = self.entities.len();
500 self.stats.num_relations = self.relations.len();
501
502 Ok(())
503 }
504
505 async fn train(&mut self, epochs: Option<usize>) -> Result<crate::TrainingStats> {
506 let max_epochs = epochs.unwrap_or(self.config.max_epochs);
507 let mut loss_history = Vec::new();
508 let start_time = std::time::Instant::now();
509
510 let num_entities = self.entities.len();
512 let num_relations = self.relations.len();
513
514 if num_entities > 0 && num_relations > 0 {
515 self.entity_embeddings = Array2::zeros((num_entities, self.config.dimensions));
516 self.relation_embeddings = Array2::zeros((num_relations, self.config.dimensions));
517
518 #[allow(unused_imports)]
520 use scirs2_core::random::{Random, Rng};
521 let mut rng = Random::default();
522
523 for i in 0..num_entities {
524 for j in 0..self.config.dimensions {
525 self.entity_embeddings[[i, j]] = rng.random_range(-0.1, 0.1);
526 }
527 }
528
529 for i in 0..num_relations {
530 for j in 0..self.config.dimensions {
531 self.relation_embeddings[[i, j]] = rng.random_range(-0.1, 0.1);
532 }
533 }
534 }
535
536 for epoch in 0..max_epochs {
538 let loss = 1.0 / (epoch as f64 + 1.0); loss_history.push(loss);
540
541 if loss < 0.01 {
542 break;
543 }
544 }
545
546 self.is_trained = true;
547 self.stats.is_trained = true;
548 self.stats.last_training_time = Some(chrono::Utc::now());
549
550 let training_time = start_time.elapsed().as_secs_f64();
551
552 Ok(crate::TrainingStats {
553 epochs_completed: max_epochs,
554 final_loss: loss_history.last().copied().unwrap_or(1.0),
555 training_time_seconds: training_time,
556 convergence_achieved: true,
557 loss_history,
558 })
559 }
560
561 fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
562 if !self.is_trained {
563 return Err(EmbeddingError::ModelNotTrained.into());
564 }
565
566 let entity_idx =
567 self.entities
568 .get(entity)
569 .ok_or_else(|| EmbeddingError::EntityNotFound {
570 entity: entity.to_string(),
571 })?;
572
573 let embedding = self.entity_embeddings.row(*entity_idx);
574 Ok(Vector::new(embedding.to_vec()))
575 }
576
577 fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
578 if !self.is_trained {
579 return Err(EmbeddingError::ModelNotTrained.into());
580 }
581
582 let relation_idx =
583 self.relations
584 .get(relation)
585 .ok_or_else(|| EmbeddingError::RelationNotFound {
586 relation: relation.to_string(),
587 })?;
588
589 let embedding = self.relation_embeddings.row(*relation_idx);
590 Ok(Vector::new(embedding.to_vec()))
591 }
592
593 fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
594 let s_emb = self.get_entity_embedding(subject)?;
595 let p_emb = self.get_relation_embedding(predicate)?;
596 let o_emb = self.get_entity_embedding(object)?;
597
598 let score = s_emb
600 .values
601 .iter()
602 .zip(p_emb.values.iter())
603 .zip(o_emb.values.iter())
604 .map(|((&s, &p), &o)| s * p * o)
605 .sum::<f32>() as f64;
606
607 Ok(score)
608 }
609
610 fn predict_objects(
611 &self,
612 subject: &str,
613 predicate: &str,
614 k: usize,
615 ) -> Result<Vec<(String, f64)>> {
616 let mut predictions = Vec::new();
617
618 for entity in self.entities.keys() {
619 if let Ok(score) = self.score_triple(subject, predicate, entity) {
620 predictions.push((entity.clone(), score));
621 }
622 }
623
624 predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
625 predictions.truncate(k);
626
627 Ok(predictions)
628 }
629
630 fn predict_subjects(
631 &self,
632 predicate: &str,
633 object: &str,
634 k: usize,
635 ) -> Result<Vec<(String, f64)>> {
636 let mut predictions = Vec::new();
637
638 for entity in self.entities.keys() {
639 if let Ok(score) = self.score_triple(entity, predicate, object) {
640 predictions.push((entity.clone(), score));
641 }
642 }
643
644 predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
645 predictions.truncate(k);
646
647 Ok(predictions)
648 }
649
650 fn predict_relations(
651 &self,
652 subject: &str,
653 object: &str,
654 k: usize,
655 ) -> Result<Vec<(String, f64)>> {
656 let mut predictions = Vec::new();
657
658 for relation in self.relations.keys() {
659 if let Ok(score) = self.score_triple(subject, relation, object) {
660 predictions.push((relation.clone(), score));
661 }
662 }
663
664 predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
665 predictions.truncate(k);
666
667 Ok(predictions)
668 }
669
    /// All registered entity IRIs (order is HashMap iteration order).
    fn get_entities(&self) -> Vec<String> {
        self.entities.keys().cloned().collect()
    }

    /// All registered relation IRIs (order is HashMap iteration order).
    fn get_relations(&self) -> Vec<String> {
        self.relations.keys().cloned().collect()
    }

    /// Snapshot of the model's statistics.
    fn get_stats(&self) -> crate::ModelStats {
        self.stats.clone()
    }
681
    /// Persists the model as pretty-printed JSON to `{path}.mamba`, plus a
    /// small summary at `{path}.mamba.metadata.json`.
    ///
    /// Only the *first* Mamba block's weights are serialized (plus the block
    /// count); `load` reconstructs the remaining blocks freshly from config.
    fn save(&self, path: &str) -> Result<()> {
        use std::fs::File;
        use std::io::Write;

        let model_path = format!("{path}.mamba");
        let metadata_path = format!("{path}.mamba.metadata.json");

        let entity_data: std::collections::HashMap<String, usize> = self.entities.clone();
        let relation_data: std::collections::HashMap<String, usize> = self.relations.clone();

        // NOTE(review): `as_slice().unwrap()` panics for non-standard-layout
        // arrays; all arrays in this file are created contiguous, so this
        // holds — verify if loading ever produces non-contiguous views.
        let entity_embeddings_data = self.entity_embeddings.as_slice().unwrap().to_vec();
        let relation_embeddings_data = self.relation_embeddings.as_slice().unwrap().to_vec();

        // Weights of the first block only; Null when the stack is empty.
        let mamba_blocks_data = if let Some(first_block) = self.mamba_blocks.first() {
            serde_json::json!({
                "config": first_block.config,
                "in_proj": first_block.in_proj.as_slice().unwrap().to_vec(),
                "in_proj_shape": first_block.in_proj.shape(),
                "conv1d": first_block.conv1d.as_slice().unwrap().to_vec(),
                "conv1d_shape": first_block.conv1d.shape(),
                "a_log": first_block.a_log.as_slice().unwrap().to_vec(),
                "a_log_shape": first_block.a_log.shape(),
                "d": first_block.d.as_slice().unwrap().to_vec(),
                "d_shape": first_block.d.shape(),
                "num_blocks": self.mamba_blocks.len(),
            })
        } else {
            serde_json::Value::Null
        };

        // Full model payload; "version" gates compatibility in `load`.
        let model_data = serde_json::json!({
            "model_id": self.id,
            "config": self.config,
            "mamba_config": self.mamba_config,
            "entity_data": entity_data,
            "relation_data": relation_data,
            "entity_embeddings": entity_embeddings_data,
            "entity_embeddings_shape": self.entity_embeddings.shape(),
            "relation_embeddings": relation_embeddings_data,
            "relation_embeddings_shape": self.relation_embeddings.shape(),
            "is_trained": self.is_trained,
            "stats": self.stats,
            "mamba_blocks": mamba_blocks_data,
            "timestamp": chrono::Utc::now(),
            "version": "1.0"
        });

        let mut file = File::create(&model_path)?;
        let serialized = serde_json::to_string_pretty(&model_data)?;
        file.write_all(serialized.as_bytes())?;

        // Lightweight companion file for discovery/inspection.
        let metadata = serde_json::json!({
            "model_type": "MambaEmbedding",
            "model_id": self.id,
            "dimensions": self.config.dimensions,
            "num_entities": self.entities.len(),
            "num_relations": self.relations.len(),
            "is_trained": self.is_trained,
            "created_at": chrono::Utc::now(),
            "file_path": model_path
        });

        let mut metadata_file = File::create(&metadata_path)?;
        let metadata_serialized = serde_json::to_string_pretty(&metadata)?;
        metadata_file.write_all(metadata_serialized.as_bytes())?;

        tracing::info!("Mamba model saved to {} and {}", model_path, metadata_path);
        Ok(())
    }
757
    /// Restores a model previously written by `save` from `{path}.mamba`.
    ///
    /// Every field is loaded best-effort: each section is optional, so a
    /// partially written file restores what it can. Only version "1.0" files
    /// are accepted. The block stack is rebuilt fresh from `mamba_config`,
    /// then the first block's weights are overwritten from the file (the
    /// other blocks keep their zero/one initialization, matching `save`).
    fn load(&mut self, path: &str) -> Result<()> {
        use std::fs::File;
        use std::io::Read;

        let model_path = format!("{path}.mamba");

        // Read the whole serialized model into memory.
        let mut file = File::open(&model_path)?;
        let mut contents = String::new();
        file.read_to_string(&mut contents)?;

        let model_data: serde_json::Value = serde_json::from_str(&contents)?;

        // Version gate: only the layout written by `save` is accepted.
        if let Some(version) = model_data.get("version").and_then(|v| v.as_str()) {
            if version != "1.0" {
                return Err(anyhow::anyhow!("Unsupported model version: {}", version));
            }
        }

        // Scalar / metadata fields.
        if let Some(model_id) = model_data.get("model_id") {
            self.id = serde_json::from_value(model_id.clone())?;
        }

        if let Some(config) = model_data.get("config") {
            self.config = serde_json::from_value(config.clone())?;
        }

        if let Some(mamba_config) = model_data.get("mamba_config") {
            self.mamba_config = serde_json::from_value(mamba_config.clone())?;
        }

        if let Some(is_trained) = model_data.get("is_trained") {
            self.is_trained = serde_json::from_value(is_trained.clone())?;
        }

        if let Some(stats) = model_data.get("stats") {
            self.stats = serde_json::from_value(stats.clone())?;
        }

        // IRI -> index maps.
        if let Some(entity_data) = model_data.get("entity_data") {
            self.entities = serde_json::from_value(entity_data.clone())?;
        }

        if let Some(relation_data) = model_data.get("relation_data") {
            self.relations = serde_json::from_value(relation_data.clone())?;
        }

        // Entity embedding matrix: flat f32 values + 2-D shape.
        if let (Some(embeddings_data), Some(embeddings_shape)) = (
            model_data
                .get("entity_embeddings")
                .and_then(|v| v.as_array()),
            model_data
                .get("entity_embeddings_shape")
                .and_then(|v| v.as_array()),
        ) {
            let values: Vec<f32> = embeddings_data
                .iter()
                .filter_map(|v| v.as_f64().map(|f| f as f32))
                .collect();
            let shape: Vec<usize> = embeddings_shape
                .iter()
                .filter_map(|v| v.as_u64().map(|u| u as usize))
                .collect();
            if shape.len() == 2 {
                self.entity_embeddings = Array2::from_shape_vec((shape[0], shape[1]), values)
                    .map_err(|e| anyhow::anyhow!("Failed to reshape entity_embeddings: {}", e))?;
            }
        }

        // Relation embedding matrix: same layout as above.
        if let (Some(embeddings_data), Some(embeddings_shape)) = (
            model_data
                .get("relation_embeddings")
                .and_then(|v| v.as_array()),
            model_data
                .get("relation_embeddings_shape")
                .and_then(|v| v.as_array()),
        ) {
            let values: Vec<f32> = embeddings_data
                .iter()
                .filter_map(|v| v.as_f64().map(|f| f as f32))
                .collect();
            let shape: Vec<usize> = embeddings_shape
                .iter()
                .filter_map(|v| v.as_u64().map(|u| u as usize))
                .collect();
            if shape.len() == 2 {
                self.relation_embeddings = Array2::from_shape_vec((shape[0], shape[1]), values)
                    .map_err(|e| anyhow::anyhow!("Failed to reshape relation_embeddings: {}", e))?;
            }
        }

        // Block stack: rebuild from config, then restore the first block's
        // saved weights (only block 0 is persisted by `save`).
        if let Some(mamba_blocks_data) = model_data.get("mamba_blocks") {
            if !mamba_blocks_data.is_null() {
                let num_blocks = mamba_blocks_data
                    .get("num_blocks")
                    .and_then(|v| v.as_u64())
                    .unwrap_or(self.mamba_blocks.len() as u64)
                    as usize;

                self.mamba_blocks.clear();
                for _ in 0..num_blocks {
                    self.mamba_blocks
                        .push(MambaBlock::new(self.mamba_config.clone()));
                }

                if let Some(first_block) = self.mamba_blocks.first_mut() {
                    // in_proj weights.
                    if let (Some(in_proj_data), Some(in_proj_shape)) = (
                        mamba_blocks_data.get("in_proj").and_then(|v| v.as_array()),
                        mamba_blocks_data
                            .get("in_proj_shape")
                            .and_then(|v| v.as_array()),
                    ) {
                        let values: Vec<f32> = in_proj_data
                            .iter()
                            .filter_map(|v| v.as_f64().map(|f| f as f32))
                            .collect();
                        let shape: Vec<usize> = in_proj_shape
                            .iter()
                            .filter_map(|v| v.as_u64().map(|u| u as usize))
                            .collect();
                        if shape.len() == 2 {
                            first_block.in_proj =
                                Array2::from_shape_vec((shape[0], shape[1]), values).map_err(
                                    |e| anyhow::anyhow!("Failed to reshape in_proj: {}", e),
                                )?;
                        }
                    }

                    // conv1d kernels.
                    if let (Some(conv1d_data), Some(conv1d_shape)) = (
                        mamba_blocks_data.get("conv1d").and_then(|v| v.as_array()),
                        mamba_blocks_data
                            .get("conv1d_shape")
                            .and_then(|v| v.as_array()),
                    ) {
                        let values: Vec<f32> = conv1d_data
                            .iter()
                            .filter_map(|v| v.as_f64().map(|f| f as f32))
                            .collect();
                        let shape: Vec<usize> = conv1d_shape
                            .iter()
                            .filter_map(|v| v.as_u64().map(|u| u as usize))
                            .collect();
                        if shape.len() == 2 {
                            first_block.conv1d =
                                Array2::from_shape_vec((shape[0], shape[1]), values).map_err(
                                    |e| anyhow::anyhow!("Failed to reshape conv1d: {}", e),
                                )?;
                        }
                    }

                    // a_log state parameters.
                    if let (Some(a_log_data), Some(a_log_shape)) = (
                        mamba_blocks_data.get("a_log").and_then(|v| v.as_array()),
                        mamba_blocks_data
                            .get("a_log_shape")
                            .and_then(|v| v.as_array()),
                    ) {
                        let values: Vec<f32> = a_log_data
                            .iter()
                            .filter_map(|v| v.as_f64().map(|f| f as f32))
                            .collect();
                        let shape: Vec<usize> = a_log_shape
                            .iter()
                            .filter_map(|v| v.as_u64().map(|u| u as usize))
                            .collect();
                        if shape.len() == 2 {
                            first_block.a_log =
                                Array2::from_shape_vec((shape[0], shape[1]), values).map_err(
                                    |e| anyhow::anyhow!("Failed to reshape a_log: {}", e),
                                )?;
                        }
                    }

                    // d skip vector (1-D).
                    if let (Some(d_data), Some(d_shape)) = (
                        mamba_blocks_data.get("d").and_then(|v| v.as_array()),
                        mamba_blocks_data.get("d_shape").and_then(|v| v.as_array()),
                    ) {
                        let values: Vec<f32> = d_data
                            .iter()
                            .filter_map(|v| v.as_f64().map(|f| f as f32))
                            .collect();
                        let shape: Vec<usize> = d_shape
                            .iter()
                            .filter_map(|v| v.as_u64().map(|u| u as usize))
                            .collect();
                        if shape.len() == 1 {
                            first_block.d = Array1::from_shape_vec(shape[0], values)
                                .map_err(|e| anyhow::anyhow!("Failed to reshape d: {}", e))?;
                        }
                    }
                }
            }
        }

        tracing::info!("Mamba model loaded from {}", model_path);
        tracing::info!(
            "Model contains {} entities, {} relations",
            self.entities.len(),
            self.relations.len()
        );

        Ok(())
    }
975
976 fn clear(&mut self) {
977 self.entities.clear();
978 self.relations.clear();
979 self.is_trained = false;
980 self.stats = crate::ModelStats::default();
981 }
982
    /// Whether `train` has completed at least once since the last `clear`/`load`.
    fn is_trained(&self) -> bool {
        self.is_trained
    }
986
987 async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
988 let embeddings = texts
990 .iter()
991 .map(|text| {
992 let mut embedding = vec![0.0; self.config.dimensions];
993 for (i, byte) in text.bytes().enumerate() {
994 if i < self.config.dimensions {
995 embedding[i] = (byte as f32) / 255.0;
996 }
997 }
998 embedding
999 })
1000 .collect::<Vec<_>>();
1001 Ok(embeddings)
1002 }
1003}
1004
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EmbeddingModel;
    use nalgebra::Complex;

    #[test]
    fn test_mamba_config_creation() {
        let cfg = MambaConfig::default();
        assert_eq!(cfg.d_state, 16);
        assert_eq!(cfg.d_model, 512);
        assert_eq!(cfg.num_heads, 8);
    }

    #[test]
    fn test_mamba_block_creation() {
        let block = MambaBlock::new(MambaConfig::default());
        assert_eq!(block.config.d_model, 512);
    }

    #[test]
    fn test_layer_norm() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let input = Array2::from_shape_vec((2, 4), data).unwrap();

        let output = LayerNorm::new(4).forward(&input).unwrap();
        assert_eq!(output.dim(), (2, 4));
    }

    #[tokio::test]
    async fn test_mamba_embedding_model() {
        let mut model = MambaEmbedding::new(ModelConfig::default(), MambaConfig::default());

        let triple = crate::Triple::new(
            crate::NamedNode::new("http://example.org/alice").unwrap(),
            crate::NamedNode::new("http://example.org/knows").unwrap(),
            crate::NamedNode::new("http://example.org/bob").unwrap(),
        );
        model.add_triple(triple).unwrap();

        assert_eq!(model.get_entities().len(), 2);
        assert_eq!(model.get_relations().len(), 1);
    }

    #[test]
    fn test_complex_arithmetic() {
        let a = Complex::new(1.0, 2.0);
        let b = Complex::new(3.0, 4.0);

        let sum = a + b;
        assert_eq!(sum.re, 4.0);
        assert_eq!(sum.im, 6.0);

        // (1 + 2i)(3 + 4i) = -5 + 10i
        let product = a * b;
        assert_eq!(product.re, -5.0);
        assert_eq!(product.im, 10.0);
    }

    #[test]
    fn test_activation_functions() {
        let block = MambaBlock::new(MambaConfig::default());
        let input = Array2::from_shape_vec((1, 3), vec![-1.0, 0.0, 1.0]).unwrap();

        let output = block.apply_activation(&input).unwrap();

        // Default activation is SiLU: sign-preserving, exactly zero at zero.
        assert!(output[[0, 0]] < 0.0);
        assert_eq!(output[[0, 1]], 0.0);
        assert!(output[[0, 2]] > 0.0);
    }
}