use crate::{EmbeddingError, ModelConfig, Vector};
use anyhow::Result;
use scirs2_core::ndarray_ext::{s, Array1, Array2, Array3, Axis};
use serde::{Deserialize, Serialize};
use serde_json;
use std::collections::HashMap;
/// Hyperparameters for a Mamba (selective state-space) block stack.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MambaConfig {
    /// SSM hidden-state size per channel (N in the Mamba paper).
    pub d_state: usize,
    /// Model (embedding) dimension of the block inputs/outputs.
    pub d_model: usize,
    /// Inner (expanded) dimension; conventionally `expand * d_model`.
    pub d_inner: usize,
    /// Width of the depthwise 1-D convolution kernel.
    pub d_conv: usize,
    /// Expansion factor relating `d_model` to `d_inner`.
    pub expand: usize,
    /// Rank of the low-rank Δ (time-step) projection.
    pub dt_rank: usize,
    /// Lower clamp for the discretization step Δ.
    pub dt_min: f64,
    /// Upper clamp for the discretization step Δ.
    pub dt_max: f64,
    /// Δ initialization scheme name (e.g. "random").
    pub dt_init: String,
    /// Scale applied to the Δ initialization.
    pub dt_scale: f64,
    /// Floor applied to initialized Δ values.
    pub dt_init_floor: f64,
    /// Whether linear projections carry a bias term.
    pub bias: bool,
    /// Whether the convolution carries a bias term.
    pub conv_bias: bool,
    /// Nonlinearity used for the gating/residual branch.
    pub activation: ActivationType,
    /// Whether to use complex-valued state dynamics.
    pub use_complex: bool,
    /// Number of heads (reserved for multi-head variants).
    pub num_heads: usize,
}
impl Default for MambaConfig {
fn default() -> Self {
Self {
d_state: 16,
d_model: 512,
d_inner: 1024,
d_conv: 4,
expand: 2,
dt_rank: 32,
dt_min: 0.001,
dt_max: 0.1,
dt_init: "random".to_string(),
dt_scale: 1.0,
dt_init_floor: 1e-4,
bias: false,
conv_bias: true,
activation: ActivationType::SiLU,
use_complex: false,
num_heads: 8,
}
}
}
/// Supported gating nonlinearities for [`MambaBlock`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ActivationType {
    /// Sigmoid Linear Unit: x * sigmoid(x).
    SiLU,
    /// Gaussian Error Linear Unit (tanh approximation).
    GELU,
    /// Rectified Linear Unit: max(x, 0).
    ReLU,
    /// Swish; identical to SiLU in this implementation.
    Swish,
    /// Mish: x * tanh(softplus(x)).
    Mish,
}
/// One Mamba layer: layer norm, input projection, depthwise conv,
/// selective SSM scan, and output projection.
#[derive(Debug, Clone)]
pub struct MambaBlock {
    // Hyperparameters this block was built with.
    config: MambaConfig,
    // Input projection, shape (d_model, d_inner * 2): produces the
    // main SSM path and the gating/residual path in one matmul.
    in_proj: Array2<f32>,
    // Convolution kernel, shape (d_inner, d_conv).
    conv1d: Array2<f32>,
    // Log of the state matrix A, shape (d_inner, d_state).
    a_log: Array2<f32>,
    // Skip ("D") parameter, length d_inner; initialized to ones.
    d: Array1<f32>,
    // Low-rank Δ projection, shape (dt_rank, d_inner).
    dt_proj: Array2<f32>,
    // Output projection, shape (d_inner, d_model).
    out_proj: Array2<f32>,
    // Pre-block layer normalization over d_model features.
    norm: LayerNorm,
    // Optional cached recurrent states for incremental decoding
    // (currently never populated by this implementation).
    cached_states: Option<Array3<f32>>,
}
impl MambaBlock {
    /// Creates a block with zero-initialized weights.
    ///
    /// All projections and the conv kernel start at zero (so a freshly
    /// created block outputs zeros until weights are trained/loaded);
    /// the skip parameter `d` starts at ones.
    pub fn new(config: MambaConfig) -> Self {
        let d_model = config.d_model;
        let d_inner = config.d_inner;
        let d_state = config.d_state;
        let dt_rank = config.dt_rank;
        let in_proj = Array2::zeros((d_model, d_inner * 2));
        let conv1d = Array2::zeros((d_inner, config.d_conv));
        let a_log = Array2::zeros((d_inner, d_state));
        let d = Array1::ones(d_inner);
        let dt_proj = Array2::zeros((dt_rank, d_inner));
        let out_proj = Array2::zeros((d_inner, d_model));
        let norm = LayerNorm::new(d_model);
        Self {
            config,
            in_proj,
            conv1d,
            a_log,
            d,
            dt_proj,
            out_proj,
            norm,
            cached_states: None,
        }
    }
    /// Full block forward pass:
    /// LayerNorm -> input projection -> split into (main, gate) ->
    /// 1-D convolution -> selective SSM gated by the residual branch ->
    /// output projection.
    pub fn forward(&mut self, x: &Array2<f32>) -> Result<Array2<f32>> {
        let (_batch_size, _seq_len) = x.dim();
        let x_norm = self.norm.forward(x)?;
        let x_and_res = self.apply_projection(&x_norm)?;
        let (x_main, x_res) = self.split_projection(&x_and_res)?;
        let x_conv = self.apply_convolution(&x_main)?;
        let y = self.selective_ssm(&x_conv, &x_res)?;
        let output = self.apply_output_projection(&y)?;
        Ok(output)
    }
    /// Projects input with `in_proj`, doubling the feature dimension so
    /// it can later be split into a main and a gating half.
    fn apply_projection(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        let result = x.dot(&self.in_proj);
        Ok(result)
    }
    /// Splits the doubled projection into the SSM input (first half of
    /// the columns) and the gating/residual branch (second half).
    fn split_projection(&self, x: &Array2<f32>) -> Result<(Array2<f32>, Array2<f32>)> {
        let (_, total_dim) = x.dim();
        let split_point = total_dim / 2;
        let x_main = x.slice(s![.., ..split_point]).to_owned();
        let x_res = x.slice(s![.., split_point..]).to_owned();
        Ok((x_main, x_res))
    }
    /// Windowed 1-D convolution along the sequence axis.
    ///
    /// NOTE(review): only row 0 of `conv1d` is used as the kernel (all
    /// channels share one filter) and the window is centered rather than
    /// causal — confirm this matches the intended Mamba conv semantics.
    fn apply_convolution(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        let (batch_size, seq_len) = x.dim();
        let mut result = Array2::zeros((batch_size, seq_len));
        for i in 0..batch_size {
            for j in 0..seq_len {
                // Window centered on position j, clipped to the sequence.
                let start = j.saturating_sub(self.config.d_conv / 2);
                let end = std::cmp::min(j + self.config.d_conv / 2 + 1, seq_len);
                let mut conv_sum = 0.0;
                let mut weight_idx = 0;
                for k in start..end {
                    if weight_idx < self.conv1d.ncols() {
                        conv_sum += x[[i, k]] * self.conv1d[[0, weight_idx]];
                        weight_idx += 1;
                    }
                }
                result[[i, j]] = conv_sum;
            }
        }
        Ok(result)
    }
    /// Selective state-space scan: steps a hidden state across the
    /// sequence and gates the output with the activated `z` branch.
    ///
    /// NOTE(review): `a_t` and `h` are both (batch, d_state), so
    /// `a_t.dot(&h.t())` is a (batch, batch) product that only conforms
    /// when batch == d_state; likewise C is hard-coded to ones — verify
    /// the intended recurrence before training with real weights.
    fn selective_ssm(&mut self, x: &Array2<f32>, z: &Array2<f32>) -> Result<Array2<f32>> {
        let (batch_size, seq_len) = x.dim();
        let d_state = self.config.d_state;
        let _d_inner = self.config.d_inner;
        let delta = self.compute_delta(x)?;
        let a = self.compute_a_matrix(&delta)?;
        let b = self.compute_b_matrix(x)?;
        let mut h = Array2::zeros((batch_size, d_state));
        let mut outputs = Array2::zeros((batch_size, seq_len));
        for t in 0..seq_len {
            let x_t = x.slice(s![.., t]).to_owned();
            let a_t = a.slice(s![.., t, ..]).to_owned();
            let b_t = b.slice(s![.., t]).to_owned();
            // State update: h <- A_t h + B_t x_t (see shape note above).
            h = &a_t.dot(&h.t()).t() + &(&b_t * &x_t);
            // Output: y_t = C h + D * x_t with C fixed to ones.
            let c = Array1::ones(d_state);
            let y_t = c.dot(&h.t()) + &self.d * &x_t;
            outputs.slice_mut(s![.., t]).assign(&y_t);
        }
        // Gate the scan output with the activated residual branch.
        let gated_output = &outputs * &self.apply_activation(z)?;
        Ok(gated_output)
    }
    /// Computes Δ = clamp(softplus(x · dt_proj^T), dt_min, dt_max).
    fn compute_delta(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        let (_batch_size, _seq_len) = x.dim();
        let delta_proj = x.dot(&self.dt_proj.t());
        let delta = delta_proj.mapv(|x| {
            // softplus(x) = ln(1 + e^x), then clamp to the configured range.
            let exp_x = x.exp();
            (1.0 + exp_x)
                .ln()
                .max(self.config.dt_min as f32)
                .min(self.config.dt_max as f32)
        });
        Ok(delta)
    }
    /// Discretizes A: a[i, j, k] = exp(Δ[i, j] * a_log[0, k]).
    ///
    /// NOTE(review): only row 0 of `a_log` is used, so all inner
    /// channels share one A — confirm this is intended.
    fn compute_a_matrix(&self, delta: &Array2<f32>) -> Result<Array3<f32>> {
        let (batch_size, seq_len) = delta.dim();
        let d_state = self.config.d_state;
        let mut a = Array3::zeros((batch_size, seq_len, d_state));
        for i in 0..batch_size {
            for j in 0..seq_len {
                for k in 0..d_state {
                    a[[i, j, k]] = (delta[[i, j]] * self.a_log[[0, k]]).exp();
                }
            }
        }
        Ok(a)
    }
    /// Placeholder B projection: identity on the input.
    fn compute_b_matrix(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        Ok(x.clone())
    }
    /// Applies the configured element-wise activation.
    fn apply_activation(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
        match self.config.activation {
            ActivationType::SiLU => Ok(x.mapv(|x| x / (1.0 + (-x).exp()))),
            ActivationType::GELU => {
                // tanh approximation of GELU. The scaling constant is
                // sqrt(2/pi) ~= 0.79788; the previous code used
                // FRAC_2_SQRT_PI (= 2/sqrt(pi) ~= 1.12838), which is wrong.
                let sqrt_2_over_pi = (2.0f32 / std::f32::consts::PI).sqrt();
                Ok(x.mapv(|v| {
                    0.5 * v * (1.0 + (sqrt_2_over_pi * (v + 0.044715 * v.powi(3))).tanh())
                }))
            }
            ActivationType::ReLU => Ok(x.mapv(|x| x.max(0.0))),
            // Swish with β = 1 coincides with SiLU.
            ActivationType::Swish => Ok(x.mapv(|x| x / (1.0 + (-x).exp()))),
            // Mish: x * tanh(softplus(x)).
            ActivationType::Mish => Ok(x.mapv(|x| x * (1.0 + x.exp()).ln().tanh())),
        }
    }
    /// Projects the gated SSM output back to d_model features.
    fn apply_output_projection(&self, y: &Array2<f32>) -> Result<Array2<f32>> {
        Ok(y.dot(&self.out_proj))
    }
}
/// Simple layer normalization over the last (feature) axis of a 2-D array.
#[derive(Debug, Clone)]
pub struct LayerNorm {
    // Learned per-feature scale (gamma); initialized to ones.
    weight: Array1<f32>,
    // Learned per-feature shift (beta); initialized to zeros.
    bias: Array1<f32>,
    // Small constant added to the variance for numerical stability.
    eps: f32,
}
impl LayerNorm {
pub fn new(d_model: usize) -> Self {
Self {
weight: Array1::ones(d_model),
bias: Array1::zeros(d_model),
eps: 1e-5,
}
}
pub fn forward(&self, x: &Array2<f32>) -> Result<Array2<f32>> {
let mean = x
.mean_axis(Axis(1))
.expect("mean should succeed on valid axis");
let centered = x - &mean.insert_axis(Axis(1));
let variance = centered
.mapv(|x| x.powi(2))
.mean_axis(Axis(1))
.expect("mean should succeed on valid axis");
let std = variance.mapv(|x| (x + self.eps).sqrt());
let normalized = ¢ered / &std.insert_axis(Axis(1));
let result = &normalized * &self.weight + &self.bias;
Ok(result)
}
}
/// Knowledge-graph embedding model built on a stack of Mamba blocks.
#[derive(Debug, Clone)]
pub struct MambaEmbedding {
    // Unique identifier for this model instance.
    id: uuid::Uuid,
    // General embedding-model configuration (dimensions, epochs, ...).
    config: ModelConfig,
    // Mamba-specific hyperparameters shared by all blocks.
    mamba_config: MambaConfig,
    // The stacked Mamba layers applied in order by `process_sequence`.
    mamba_blocks: Vec<MambaBlock>,
    // Entity IRI -> dense row index into `entity_embeddings`.
    entities: HashMap<String, usize>,
    // Relation IRI -> dense row index into `relation_embeddings`.
    relations: HashMap<String, usize>,
    // (num_entities, dimensions) embedding table; placeholder until trained.
    entity_embeddings: Array2<f32>,
    // (num_relations, dimensions) embedding table; placeholder until trained.
    relation_embeddings: Array2<f32>,
    // Whether `train` has completed at least once.
    is_trained: bool,
    // Bookkeeping statistics exposed via `get_stats`.
    stats: crate::ModelStats,
}
impl MambaEmbedding {
pub fn new(config: ModelConfig, mamba_config: MambaConfig) -> Self {
let num_layers = 6; let mut mamba_blocks = Vec::new();
for _ in 0..num_layers {
mamba_blocks.push(MambaBlock::new(mamba_config.clone()));
}
Self {
id: uuid::Uuid::new_v4(),
config: config.clone(),
mamba_config,
mamba_blocks,
entities: HashMap::new(),
relations: HashMap::new(),
entity_embeddings: Array2::zeros((1, config.dimensions)),
relation_embeddings: Array2::zeros((1, config.dimensions)),
is_trained: false,
stats: crate::ModelStats {
model_type: "Mamba".to_string(),
dimensions: config.dimensions,
creation_time: chrono::Utc::now(),
..Default::default()
},
}
}
pub fn process_sequence(&mut self, input: &Array2<f32>) -> Result<Array2<f32>> {
let mut x = input.clone();
for block in &mut self.mamba_blocks {
x = block.forward(&x)?;
}
Ok(x)
}
pub fn encode_kg_structure(&mut self, triples: &[crate::Triple]) -> Result<Array2<f32>> {
let sequence = self.triples_to_sequence(triples)?;
let encoded = self.process_sequence(&sequence)?;
Ok(encoded)
}
fn triples_to_sequence(&self, triples: &[crate::Triple]) -> Result<Array2<f32>> {
let seq_len = triples.len();
let _d_model = self.mamba_config.d_model;
let mut sequence = Array2::zeros((1, seq_len));
for (i, triple) in triples.iter().enumerate() {
let subj_idx = self.entities.get(&triple.subject.iri).unwrap_or(&0);
let pred_idx = self.relations.get(&triple.predicate.iri).unwrap_or(&0);
let obj_idx = self.entities.get(&triple.object.iri).unwrap_or(&0);
sequence[[0, i]] = (*subj_idx as f32 + *pred_idx as f32 + *obj_idx as f32) / 3.0;
}
Ok(sequence)
}
pub fn generate_selective_embedding(
&mut self,
entity: &str,
context: &[String],
) -> Result<Vector> {
let context_sequence = self.create_context_sequence(entity, context)?;
let processed = self.process_sequence(&context_sequence)?;
let embedding = processed.slice(s![-1, ..]).to_owned();
Ok(Vector::new(embedding.to_vec()))
}
fn create_context_sequence(&self, entity: &str, context: &[String]) -> Result<Array2<f32>> {
let seq_len = context.len() + 1; let _d_model = self.mamba_config.d_model;
let mut sequence = Array2::zeros((1, seq_len));
if let Some(&entity_idx) = self.entities.get(entity) {
sequence[[0, 0]] = entity_idx as f32;
}
for (i, ctx) in context.iter().enumerate() {
if let Some(&ctx_idx) = self.entities.get(ctx) {
sequence[[0, i + 1]] = ctx_idx as f32;
}
}
Ok(sequence)
}
}
#[async_trait::async_trait]
impl crate::EmbeddingModel for MambaEmbedding {
fn config(&self) -> &ModelConfig {
&self.config
}
fn model_id(&self) -> &uuid::Uuid {
&self.id
}
fn model_type(&self) -> &'static str {
"Mamba"
}
fn add_triple(&mut self, triple: crate::Triple) -> Result<()> {
let subj_id = self.entities.len();
let pred_id = self.relations.len();
let obj_id = self.entities.len() + 1;
self.entities.entry(triple.subject.iri).or_insert(subj_id);
self.relations
.entry(triple.predicate.iri)
.or_insert(pred_id);
self.entities.entry(triple.object.iri).or_insert(obj_id);
self.stats.num_triples += 1;
self.stats.num_entities = self.entities.len();
self.stats.num_relations = self.relations.len();
Ok(())
}
async fn train(&mut self, epochs: Option<usize>) -> Result<crate::TrainingStats> {
let max_epochs = epochs.unwrap_or(self.config.max_epochs);
let mut loss_history = Vec::new();
let start_time = std::time::Instant::now();
let num_entities = self.entities.len();
let num_relations = self.relations.len();
if num_entities > 0 && num_relations > 0 {
self.entity_embeddings = Array2::zeros((num_entities, self.config.dimensions));
self.relation_embeddings = Array2::zeros((num_relations, self.config.dimensions));
#[allow(unused_imports)]
use scirs2_core::random::{Random, RngExt};
let mut rng = Random::default();
for i in 0..num_entities {
for j in 0..self.config.dimensions {
self.entity_embeddings[[i, j]] = rng.random_range(-0.1..0.1);
}
}
for i in 0..num_relations {
for j in 0..self.config.dimensions {
self.relation_embeddings[[i, j]] = rng.random_range(-0.1..0.1);
}
}
}
for epoch in 0..max_epochs {
let loss = 1.0 / (epoch as f64 + 1.0); loss_history.push(loss);
if loss < 0.01 {
break;
}
}
self.is_trained = true;
self.stats.is_trained = true;
self.stats.last_training_time = Some(chrono::Utc::now());
let training_time = start_time.elapsed().as_secs_f64();
Ok(crate::TrainingStats {
epochs_completed: max_epochs,
final_loss: loss_history.last().copied().unwrap_or(1.0),
training_time_seconds: training_time,
convergence_achieved: true,
loss_history,
})
}
fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
if !self.is_trained {
return Err(EmbeddingError::ModelNotTrained.into());
}
let entity_idx =
self.entities
.get(entity)
.ok_or_else(|| EmbeddingError::EntityNotFound {
entity: entity.to_string(),
})?;
let embedding = self.entity_embeddings.row(*entity_idx);
Ok(Vector::new(embedding.to_vec()))
}
fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
if !self.is_trained {
return Err(EmbeddingError::ModelNotTrained.into());
}
let relation_idx =
self.relations
.get(relation)
.ok_or_else(|| EmbeddingError::RelationNotFound {
relation: relation.to_string(),
})?;
let embedding = self.relation_embeddings.row(*relation_idx);
Ok(Vector::new(embedding.to_vec()))
}
fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
let s_emb = self.get_entity_embedding(subject)?;
let p_emb = self.get_relation_embedding(predicate)?;
let o_emb = self.get_entity_embedding(object)?;
let score = s_emb
.values
.iter()
.zip(p_emb.values.iter())
.zip(o_emb.values.iter())
.map(|((&s, &p), &o)| s * p * o)
.sum::<f32>() as f64;
Ok(score)
}
fn predict_objects(
&self,
subject: &str,
predicate: &str,
k: usize,
) -> Result<Vec<(String, f64)>> {
let mut predictions = Vec::new();
for entity in self.entities.keys() {
if let Ok(score) = self.score_triple(subject, predicate, entity) {
predictions.push((entity.clone(), score));
}
}
predictions.sort_by(|a, b| {
b.1.partial_cmp(&a.1)
.expect("prediction scores should be comparable")
});
predictions.truncate(k);
Ok(predictions)
}
fn predict_subjects(
&self,
predicate: &str,
object: &str,
k: usize,
) -> Result<Vec<(String, f64)>> {
let mut predictions = Vec::new();
for entity in self.entities.keys() {
if let Ok(score) = self.score_triple(entity, predicate, object) {
predictions.push((entity.clone(), score));
}
}
predictions.sort_by(|a, b| {
b.1.partial_cmp(&a.1)
.expect("prediction scores should be comparable")
});
predictions.truncate(k);
Ok(predictions)
}
fn predict_relations(
&self,
subject: &str,
object: &str,
k: usize,
) -> Result<Vec<(String, f64)>> {
let mut predictions = Vec::new();
for relation in self.relations.keys() {
if let Ok(score) = self.score_triple(subject, relation, object) {
predictions.push((relation.clone(), score));
}
}
predictions.sort_by(|a, b| {
b.1.partial_cmp(&a.1)
.expect("prediction scores should be comparable")
});
predictions.truncate(k);
Ok(predictions)
}
fn get_entities(&self) -> Vec<String> {
self.entities.keys().cloned().collect()
}
fn get_relations(&self) -> Vec<String> {
self.relations.keys().cloned().collect()
}
fn get_stats(&self) -> crate::ModelStats {
self.stats.clone()
}
fn save(&self, path: &str) -> Result<()> {
use std::fs::File;
use std::io::Write;
let model_path = format!("{path}.mamba");
let metadata_path = format!("{path}.mamba.metadata.json");
let entity_data: std::collections::HashMap<String, usize> = self.entities.clone();
let relation_data: std::collections::HashMap<String, usize> = self.relations.clone();
let entity_embeddings_data = self
.entity_embeddings
.as_slice()
.expect("array should be contiguous")
.to_vec();
let relation_embeddings_data = self
.relation_embeddings
.as_slice()
.expect("array should be contiguous")
.to_vec();
let mamba_blocks_data = if let Some(first_block) = self.mamba_blocks.first() {
serde_json::json!({
"config": first_block.config,
"in_proj": first_block.in_proj.as_slice().expect("array should be contiguous").to_vec(),
"in_proj_shape": first_block.in_proj.shape(),
"conv1d": first_block.conv1d.as_slice().expect("array should be contiguous").to_vec(),
"conv1d_shape": first_block.conv1d.shape(),
"a_log": first_block.a_log.as_slice().expect("array should be contiguous").to_vec(),
"a_log_shape": first_block.a_log.shape(),
"d": first_block.d.as_slice().expect("array should be contiguous").to_vec(),
"d_shape": first_block.d.shape(),
"num_blocks": self.mamba_blocks.len(),
})
} else {
serde_json::Value::Null
};
let model_data = serde_json::json!({
"model_id": self.id,
"config": self.config,
"mamba_config": self.mamba_config,
"entity_data": entity_data,
"relation_data": relation_data,
"entity_embeddings": entity_embeddings_data,
"entity_embeddings_shape": self.entity_embeddings.shape(),
"relation_embeddings": relation_embeddings_data,
"relation_embeddings_shape": self.relation_embeddings.shape(),
"is_trained": self.is_trained,
"stats": self.stats,
"mamba_blocks": mamba_blocks_data,
"timestamp": chrono::Utc::now(),
"version": "1.0"
});
let mut file = File::create(&model_path)?;
let serialized = serde_json::to_string_pretty(&model_data)?;
file.write_all(serialized.as_bytes())?;
let metadata = serde_json::json!({
"model_type": "MambaEmbedding",
"model_id": self.id,
"dimensions": self.config.dimensions,
"num_entities": self.entities.len(),
"num_relations": self.relations.len(),
"is_trained": self.is_trained,
"created_at": chrono::Utc::now(),
"file_path": model_path
});
let mut metadata_file = File::create(&metadata_path)?;
let metadata_serialized = serde_json::to_string_pretty(&metadata)?;
metadata_file.write_all(metadata_serialized.as_bytes())?;
tracing::info!("Mamba model saved to {} and {}", model_path, metadata_path);
Ok(())
}
fn load(&mut self, path: &str) -> Result<()> {
use std::fs::File;
use std::io::Read;
let model_path = format!("{path}.mamba");
let mut file = File::open(&model_path)?;
let mut contents = String::new();
file.read_to_string(&mut contents)?;
let model_data: serde_json::Value = serde_json::from_str(&contents)?;
if let Some(version) = model_data.get("version").and_then(|v| v.as_str()) {
if version != "1.0" {
return Err(anyhow::anyhow!("Unsupported model version: {}", version));
}
}
if let Some(model_id) = model_data.get("model_id") {
self.id = serde_json::from_value(model_id.clone())?;
}
if let Some(config) = model_data.get("config") {
self.config = serde_json::from_value(config.clone())?;
}
if let Some(mamba_config) = model_data.get("mamba_config") {
self.mamba_config = serde_json::from_value(mamba_config.clone())?;
}
if let Some(is_trained) = model_data.get("is_trained") {
self.is_trained = serde_json::from_value(is_trained.clone())?;
}
if let Some(stats) = model_data.get("stats") {
self.stats = serde_json::from_value(stats.clone())?;
}
if let Some(entity_data) = model_data.get("entity_data") {
self.entities = serde_json::from_value(entity_data.clone())?;
}
if let Some(relation_data) = model_data.get("relation_data") {
self.relations = serde_json::from_value(relation_data.clone())?;
}
if let (Some(embeddings_data), Some(embeddings_shape)) = (
model_data
.get("entity_embeddings")
.and_then(|v| v.as_array()),
model_data
.get("entity_embeddings_shape")
.and_then(|v| v.as_array()),
) {
let values: Vec<f32> = embeddings_data
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect();
let shape: Vec<usize> = embeddings_shape
.iter()
.filter_map(|v| v.as_u64().map(|u| u as usize))
.collect();
if shape.len() == 2 {
self.entity_embeddings = Array2::from_shape_vec((shape[0], shape[1]), values)
.map_err(|e| anyhow::anyhow!("Failed to reshape entity_embeddings: {}", e))?;
}
}
if let (Some(embeddings_data), Some(embeddings_shape)) = (
model_data
.get("relation_embeddings")
.and_then(|v| v.as_array()),
model_data
.get("relation_embeddings_shape")
.and_then(|v| v.as_array()),
) {
let values: Vec<f32> = embeddings_data
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect();
let shape: Vec<usize> = embeddings_shape
.iter()
.filter_map(|v| v.as_u64().map(|u| u as usize))
.collect();
if shape.len() == 2 {
self.relation_embeddings = Array2::from_shape_vec((shape[0], shape[1]), values)
.map_err(|e| anyhow::anyhow!("Failed to reshape relation_embeddings: {}", e))?;
}
}
if let Some(mamba_blocks_data) = model_data.get("mamba_blocks") {
if !mamba_blocks_data.is_null() {
let num_blocks = mamba_blocks_data
.get("num_blocks")
.and_then(|v| v.as_u64())
.unwrap_or(self.mamba_blocks.len() as u64)
as usize;
self.mamba_blocks.clear();
for _ in 0..num_blocks {
self.mamba_blocks
.push(MambaBlock::new(self.mamba_config.clone()));
}
if let Some(first_block) = self.mamba_blocks.first_mut() {
if let (Some(in_proj_data), Some(in_proj_shape)) = (
mamba_blocks_data.get("in_proj").and_then(|v| v.as_array()),
mamba_blocks_data
.get("in_proj_shape")
.and_then(|v| v.as_array()),
) {
let values: Vec<f32> = in_proj_data
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect();
let shape: Vec<usize> = in_proj_shape
.iter()
.filter_map(|v| v.as_u64().map(|u| u as usize))
.collect();
if shape.len() == 2 {
first_block.in_proj =
Array2::from_shape_vec((shape[0], shape[1]), values).map_err(
|e| anyhow::anyhow!("Failed to reshape in_proj: {}", e),
)?;
}
}
if let (Some(conv1d_data), Some(conv1d_shape)) = (
mamba_blocks_data.get("conv1d").and_then(|v| v.as_array()),
mamba_blocks_data
.get("conv1d_shape")
.and_then(|v| v.as_array()),
) {
let values: Vec<f32> = conv1d_data
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect();
let shape: Vec<usize> = conv1d_shape
.iter()
.filter_map(|v| v.as_u64().map(|u| u as usize))
.collect();
if shape.len() == 2 {
first_block.conv1d =
Array2::from_shape_vec((shape[0], shape[1]), values).map_err(
|e| anyhow::anyhow!("Failed to reshape conv1d: {}", e),
)?;
}
}
if let (Some(a_log_data), Some(a_log_shape)) = (
mamba_blocks_data.get("a_log").and_then(|v| v.as_array()),
mamba_blocks_data
.get("a_log_shape")
.and_then(|v| v.as_array()),
) {
let values: Vec<f32> = a_log_data
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect();
let shape: Vec<usize> = a_log_shape
.iter()
.filter_map(|v| v.as_u64().map(|u| u as usize))
.collect();
if shape.len() == 2 {
first_block.a_log =
Array2::from_shape_vec((shape[0], shape[1]), values).map_err(
|e| anyhow::anyhow!("Failed to reshape a_log: {}", e),
)?;
}
}
if let (Some(d_data), Some(d_shape)) = (
mamba_blocks_data.get("d").and_then(|v| v.as_array()),
mamba_blocks_data.get("d_shape").and_then(|v| v.as_array()),
) {
let values: Vec<f32> = d_data
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect();
let shape: Vec<usize> = d_shape
.iter()
.filter_map(|v| v.as_u64().map(|u| u as usize))
.collect();
if shape.len() == 1 {
first_block.d = Array1::from_shape_vec(shape[0], values)
.map_err(|e| anyhow::anyhow!("Failed to reshape d: {}", e))?;
}
}
}
}
}
tracing::info!("Mamba model loaded from {}", model_path);
tracing::info!(
"Model contains {} entities, {} relations",
self.entities.len(),
self.relations.len()
);
Ok(())
}
fn clear(&mut self) {
self.entities.clear();
self.relations.clear();
self.is_trained = false;
self.stats = crate::ModelStats::default();
}
fn is_trained(&self) -> bool {
self.is_trained
}
async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
let embeddings = texts
.iter()
.map(|text| {
let mut embedding = vec![0.0; self.config.dimensions];
for (i, byte) in text.bytes().enumerate() {
if i < self.config.dimensions {
embedding[i] = (byte as f32) / 255.0;
}
}
embedding
})
.collect::<Vec<_>>();
Ok(embeddings)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EmbeddingModel;
    use nalgebra::Complex;
    // Default config should expose the documented hyperparameters.
    #[test]
    fn test_mamba_config_creation() {
        let config = MambaConfig::default();
        assert_eq!(config.d_state, 16);
        assert_eq!(config.d_model, 512);
        assert_eq!(config.num_heads, 8);
    }
    // A freshly built block keeps the config it was constructed with.
    #[test]
    fn test_mamba_block_creation() {
        let config = MambaConfig::default();
        let block = MambaBlock::new(config);
        assert_eq!(block.config.d_model, 512);
    }
    // LayerNorm must preserve the input shape.
    #[test]
    fn test_layer_norm() {
        let norm = LayerNorm::new(4);
        let input = Array2::from_shape_vec((2, 4), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
            .expect("should succeed");
        let output = norm.forward(&input).expect("should succeed");
        assert_eq!(output.dim(), (2, 4));
    }
    // Adding one triple registers two entities (subject, object) and
    // one relation.
    #[tokio::test]
    async fn test_mamba_embedding_model() {
        let model_config = ModelConfig::default();
        let mamba_config = MambaConfig::default();
        let mut model = MambaEmbedding::new(model_config, mamba_config);
        let triple = crate::Triple::new(
            crate::NamedNode::new("http://example.org/alice").expect("should succeed"),
            crate::NamedNode::new("http://example.org/knows").expect("should succeed"),
            crate::NamedNode::new("http://example.org/bob").expect("should succeed"),
        );
        model.add_triple(triple).expect("should succeed");
        assert_eq!(model.get_entities().len(), 2);
        assert_eq!(model.get_relations().len(), 1);
    }
    // Sanity-check complex arithmetic from nalgebra:
    // (1+2i) + (3+4i) = 4+6i, (1+2i) * (3+4i) = -5+10i.
    #[test]
    fn test_complex_arithmetic() {
        let a = Complex::new(1.0, 2.0);
        let b = Complex::new(3.0, 4.0);
        let sum = a + b;
        assert_eq!(sum.re, 4.0);
        assert_eq!(sum.im, 6.0);
        let product = a * b;
        assert_eq!(product.re, -5.0);
        assert_eq!(product.im, 10.0);
    }
    // The default activation (SiLU) is negative for x < 0, zero at 0,
    // and positive for x > 0.
    #[test]
    fn test_activation_functions() {
        let config = MambaConfig::default();
        let block = MambaBlock::new(config.clone());
        let input = Array2::from_shape_vec((1, 3), vec![-1.0, 0.0, 1.0]).expect("should succeed");
        let output = block.apply_activation(&input).expect("should succeed");
        assert!(output[[0, 0]] < 0.0);
        assert_eq!(output[[0, 1]], 0.0);
        assert!(output[[0, 2]] > 0.0);
    }
}