use anyhow::Result;
use scirs2_core::ndarray_ext::{Array1, Array2};
use scirs2_core::random::{Random, RngExt};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Text-side encoder: a bag of named weight matrices plus the dimensions
/// they were built for.
///
/// `TextEncoder::new` fills `parameters` with a "projection" matrix
/// (`output_dim` x `input_dim`) and an "attention" matrix
/// (`output_dim` x `output_dim`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextEncoder {
/// Caller-supplied label for the encoder kind; stored as given.
pub encoder_type: String,
/// Length of the hand-crafted feature vector built from the input text.
pub input_dim: usize,
/// Number of rows of the matrices created by `new`.
pub output_dim: usize,
/// Weight matrices keyed by name ("projection", "attention").
pub parameters: HashMap<String, Array2<f32>>,
}
impl TextEncoder {
pub fn new(encoder_type: String, input_dim: usize, output_dim: usize) -> Self {
let mut parameters = HashMap::new();
let mut random = Random::default();
parameters.insert(
"projection".to_string(),
Array2::from_shape_fn((output_dim, input_dim), |(_, _)| {
(random.random::<f32>() - 0.5) * 0.1
}),
);
let mut random = Random::default();
parameters.insert(
"attention".to_string(),
Array2::from_shape_fn((output_dim, output_dim), |(_, _)| {
(random.random::<f32>() - 0.5) * 0.1
}),
);
Self {
encoder_type,
input_dim,
output_dim,
parameters,
}
}
pub fn encode(&self, text: &str) -> Result<Array1<f32>> {
let input_features = self.extract_text_features(text);
let projection = self
.parameters
.get("projection")
.expect("parameter 'projection' should be initialized");
let encoded = projection.dot(&input_features);
let mean = encoded.mean().unwrap_or(0.0);
let var = encoded.var(0.0);
let normalized = encoded.mapv(|x| (x - mean) / (var + 1e-8).sqrt());
Ok(normalized)
}
/// Builds the fixed-length (`input_dim`) feature vector for `text`.
///
/// Layout, with `k` the number of whitespace-separated tokens:
/// * slots `0..k` (clipped to `input_dim`): token byte length / 10
/// * slot `k`, if it exists: byte length of the whole text / 100
/// * slot `k + 1`, if it exists: token count / 20
///
/// Every remaining slot stays at zero.
fn extract_text_features(&self, text: &str) -> Array1<f32> {
    let tokens: Vec<&str> = text.split_whitespace().collect();
    let token_count = tokens.len();
    let mut buf = vec![0.0_f32; self.input_dim];

    // `zip` stops at the shorter side, so surplus tokens are ignored.
    for (slot, token) in buf.iter_mut().zip(&tokens) {
        *slot = token.len() as f32 / 10.0;
    }
    if let Some(slot) = buf.get_mut(token_count) {
        *slot = text.len() as f32 / 100.0;
    }
    if let Some(slot) = buf.get_mut(token_count + 1) {
        *slot = token_count as f32 / 20.0;
    }

    Array1::from_vec(buf)
}
}
/// Knowledge-graph encoder: named projection matrices for entity and
/// relation embeddings, plus the dimensions they were built for.
///
/// `KGEncoder::new` fills `parameters` with "entity_projection"
/// (`output_dim` x `entity_dim`) and "relation_projection"
/// (`output_dim` x `relation_dim`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KGEncoder {
/// Caller-supplied architecture label; stored as given.
pub architecture: String,
/// Column count of the "entity_projection" matrix created by `new`.
pub entity_dim: usize,
/// Column count of the "relation_projection" matrix created by `new`.
pub relation_dim: usize,
/// Row count of both projection matrices; also the length of the
/// relation accumulator used by `encode_structured`.
pub output_dim: usize,
/// Weight matrices keyed by name.
pub parameters: HashMap<String, Array2<f32>>,
}
impl KGEncoder {
/// Creates a knowledge-graph encoder with randomly initialised weights.
///
/// Registers "entity_projection" with shape `(output_dim, entity_dim)` and
/// "relation_projection" with shape `(output_dim, relation_dim)`. Every
/// entry is `(r - 0.5) * 0.1` for a fresh `f32` sample `r`.
pub fn new(
    architecture: String,
    entity_dim: usize,
    relation_dim: usize,
    output_dim: usize,
) -> Self {
    // A new default generator is created for every matrix.
    let init_weights = |cols: usize| {
        let mut rng = Random::default();
        Array2::from_shape_fn((output_dim, cols), |_| (rng.random::<f32>() - 0.5) * 0.1)
    };

    let entity_projection = init_weights(entity_dim);
    let relation_projection = init_weights(relation_dim);

    let mut parameters = HashMap::new();
    parameters.insert("entity_projection".to_string(), entity_projection);
    parameters.insert("relation_projection".to_string(), relation_projection);

    Self {
        architecture,
        entity_dim,
        relation_dim,
        output_dim,
        parameters,
    }
}
pub fn encode_entity(&self, entity_embedding: &Array1<f32>) -> Result<Array1<f32>> {
let projection = self
.parameters
.get("entity_projection")
.expect("parameter 'entity_projection' should be initialized");
if projection.ncols() != entity_embedding.len() {
let target_dim = projection.ncols();
let mut adjusted_embedding = Array1::zeros(target_dim);
let copy_len = entity_embedding.len().min(target_dim);
adjusted_embedding
.slice_mut(scirs2_core::ndarray_ext::s![..copy_len])
.assign(&entity_embedding.slice(scirs2_core::ndarray_ext::s![..copy_len]));
Ok(projection.dot(&adjusted_embedding))
} else {
Ok(projection.dot(entity_embedding))
}
}
pub fn encode_relation(&self, relation_embedding: &Array1<f32>) -> Result<Array1<f32>> {
let projection = self
.parameters
.get("relation_projection")
.expect("parameter 'relation_projection' should be initialized");
if projection.ncols() != relation_embedding.len() {
let target_dim = projection.ncols();
let mut adjusted_embedding = Array1::zeros(target_dim);
let copy_len = relation_embedding.len().min(target_dim);
adjusted_embedding
.slice_mut(scirs2_core::ndarray_ext::s![..copy_len])
.assign(&relation_embedding.slice(scirs2_core::ndarray_ext::s![..copy_len]));
Ok(projection.dot(&adjusted_embedding))
} else {
Ok(projection.dot(relation_embedding))
}
}
/// Encodes an entity together with its relations.
///
/// The result is the encoded entity plus the element-wise mean of the
/// encoded relations. With no relations the mean term is an all-zero
/// vector of length `output_dim`.
///
/// # Errors
///
/// Propagates any error from `encode_entity` or `encode_relation`.
pub fn encode_structured(
    &self,
    entity: &Array1<f32>,
    relations: &[Array1<f32>],
) -> Result<Array1<f32>> {
    let entity_part = self.encode_entity(entity)?;

    // Sum the encoded relations, stopping at the first failure.
    let relation_sum = relations.iter().try_fold(
        Array1::<f32>::zeros(self.output_dim),
        |acc, relation| -> Result<Array1<f32>> {
            let rel_encoded = self.encode_relation(relation)?;
            Ok(&acc + &rel_encoded)
        },
    )?;

    let relation_part = if relations.is_empty() {
        relation_sum
    } else {
        relation_sum / relations.len() as f32
    };

    Ok(&entity_part + &relation_part)
}
}
/// Two-branch network that maps a text embedding and a knowledge-graph
/// embedding into a shared output space and blends them.
///
/// `AlignmentNetwork::new` fills `parameters` with "text_hidden",
/// "text_output", "kg_hidden", "kg_output" and "cross_attention".
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlignmentNetwork {
/// Caller-supplied architecture label; stored as given.
pub architecture: String,
/// `(text_dim, kg_dim)` as passed to `new`.
pub input_dims: (usize, usize),
/// Row count of the "text_hidden" / "kg_hidden" matrices created by `new`.
pub hidden_dim: usize,
/// Row count of the "text_output" / "kg_output" matrices created by `new`.
pub output_dim: usize,
/// Weight matrices keyed by name.
pub parameters: HashMap<String, Array2<f32>>,
}
impl AlignmentNetwork {
/// Creates an alignment network with randomly initialised weights.
///
/// Registered matrices and their `(rows, cols)` shapes:
/// * "text_hidden": `(hidden_dim, text_dim)`
/// * "text_output": `(output_dim, hidden_dim)`
/// * "kg_hidden": `(hidden_dim, kg_dim)`
/// * "kg_output": `(output_dim, hidden_dim)`
/// * "cross_attention": `(output_dim, output_dim)`
///
/// Every entry is `(r - 0.5) * 0.1` for a fresh `f32` sample `r`.
pub fn new(
    architecture: String,
    text_dim: usize,
    kg_dim: usize,
    hidden_dim: usize,
    output_dim: usize,
) -> Self {
    let layer_shapes = [
        ("text_hidden", (hidden_dim, text_dim)),
        ("text_output", (output_dim, hidden_dim)),
        ("kg_hidden", (hidden_dim, kg_dim)),
        ("kg_output", (output_dim, hidden_dim)),
        ("cross_attention", (output_dim, output_dim)),
    ];

    // Matrices are generated in table order, each from its own default generator.
    let parameters = layer_shapes
        .iter()
        .map(|&(name, shape)| {
            let mut rng = Random::default();
            let weights = Array2::from_shape_fn(shape, |_| (rng.random::<f32>() - 0.5) * 0.1);
            (name.to_string(), weights)
        })
        .collect();

    Self {
        architecture,
        input_dims: (text_dim, kg_dim),
        hidden_dim,
        output_dim,
        parameters,
    }
}
pub fn align(
&self,
text_emb: &Array1<f32>,
kg_emb: &Array1<f32>,
) -> Result<(Array1<f32>, f32)> {
let text_hidden_matrix = self
.parameters
.get("text_hidden")
.expect("parameter 'text_hidden' should be initialized");
let text_hidden = text_hidden_matrix.dot(text_emb);
let text_hidden = text_hidden.mapv(|x| x.max(0.0)); let text_output_matrix = self
.parameters
.get("text_output")
.expect("parameter 'text_output' should be initialized");
let text_output = text_output_matrix.dot(&text_hidden);
let kg_hidden_matrix = self
.parameters
.get("kg_hidden")
.expect("parameter 'kg_hidden' should be initialized");
let kg_hidden = kg_hidden_matrix.dot(kg_emb);
let kg_hidden = kg_hidden.mapv(|x| x.max(0.0)); let kg_output_matrix = self
.parameters
.get("kg_output")
.expect("parameter 'kg_output' should be initialized");
let kg_output = kg_output_matrix.dot(&kg_hidden);
let attention_weights = self.compute_attention(&text_output, &kg_output)?;
let min_dim = text_output.len().min(kg_output.len());
let text_slice = text_output
.slice(scirs2_core::ndarray_ext::s![..min_dim])
.to_owned();
let kg_slice = kg_output
.slice(scirs2_core::ndarray_ext::s![..min_dim])
.to_owned();
let unified = &text_slice * attention_weights + &kg_slice * (1.0 - attention_weights);
let alignment_score = self.compute_alignment_score(&text_output, &kg_output);
Ok((unified, alignment_score))
}
/// Computes the scalar blending weight for two embeddings.
///
/// The weight is the logistic sigmoid of the dot product of the two
/// vectors, taken over their common prefix when the lengths differ.
/// The current implementation always returns `Ok`.
fn compute_attention(&self, text_emb: &Array1<f32>, kg_emb: &Array1<f32>) -> Result<f32> {
    let shared_len = text_emb.len().min(kg_emb.len());
    let lhs = text_emb.slice(scirs2_core::ndarray_ext::s![..shared_len]);
    let rhs = kg_emb.slice(scirs2_core::ndarray_ext::s![..shared_len]);

    let logit = lhs.dot(&rhs);
    Ok(1.0 / (1.0 + (-logit).exp()))
}
/// Cosine similarity between two embeddings.
///
/// Vectors of different lengths are compared over their common prefix.
/// Returns `0.0` unless both (prefix) norms are strictly positive.
pub fn compute_alignment_score(&self, text_emb: &Array1<f32>, kg_emb: &Array1<f32>) -> f32 {
    let shared_len = text_emb.len().min(kg_emb.len());
    let lhs = text_emb.slice(scirs2_core::ndarray_ext::s![..shared_len]);
    let rhs = kg_emb.slice(scirs2_core::ndarray_ext::s![..shared_len]);

    let lhs_norm = lhs.dot(&lhs).sqrt();
    let rhs_norm = rhs.dot(&rhs).sqrt();

    // Guard against zero-length or degenerate (zero / NaN norm) vectors.
    if !(lhs_norm > 0.0 && rhs_norm > 0.0) {
        return 0.0;
    }
    lhs.dot(&rhs) / (lhs_norm * rhs_norm)
}
}