oxirs_embed/multimodal/impl/
encoders.rs1use anyhow::Result;
4use scirs2_core::ndarray_ext::{Array1, Array2};
5use scirs2_core::random::{Random, Rng};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct TextEncoder {
12 pub encoder_type: String,
14 pub input_dim: usize,
16 pub output_dim: usize,
18 pub parameters: HashMap<String, Array2<f32>>,
20}
21
22impl TextEncoder {
23 pub fn new(encoder_type: String, input_dim: usize, output_dim: usize) -> Self {
24 let mut parameters = HashMap::new();
25
26 let mut random = Random::default();
28 parameters.insert(
29 "projection".to_string(),
30 Array2::from_shape_fn((output_dim, input_dim), |(_, _)| {
31 (random.random::<f32>() - 0.5) * 0.1
32 }),
33 );
34
35 let mut random = Random::default();
36 parameters.insert(
37 "attention".to_string(),
38 Array2::from_shape_fn((output_dim, output_dim), |(_, _)| {
39 (random.random::<f32>() - 0.5) * 0.1
40 }),
41 );
42
43 Self {
44 encoder_type,
45 input_dim,
46 output_dim,
47 parameters,
48 }
49 }
50
51 pub fn encode(&self, text: &str) -> Result<Array1<f32>> {
53 let input_features = self.extract_text_features(text);
54 let projection = self.parameters.get("projection").unwrap();
55
56 let encoded = projection.dot(&input_features);
58
59 let mean = encoded.mean().unwrap_or(0.0);
61 let var = encoded.var(0.0);
62 let normalized = encoded.mapv(|x| (x - mean) / (var + 1e-8).sqrt());
63
64 Ok(normalized)
65 }
66
67 fn extract_text_features(&self, text: &str) -> Array1<f32> {
69 let mut features = vec![0.0; self.input_dim];
70
71 let words: Vec<&str> = text.split_whitespace().collect();
73 for (i, word) in words.iter().enumerate() {
74 if i < self.input_dim {
75 features[i] = word.len() as f32 / 10.0; }
77 }
78
79 if self.input_dim > words.len() {
81 features[words.len()] = text.len() as f32 / 100.0; if self.input_dim > words.len() + 1 {
83 features[words.len() + 1] = words.len() as f32 / 20.0; }
85 }
86
87 Array1::from_vec(features)
88 }
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct KGEncoder {
94 pub architecture: String,
96 pub entity_dim: usize,
98 pub relation_dim: usize,
100 pub output_dim: usize,
102 pub parameters: HashMap<String, Array2<f32>>,
104}
105
106impl KGEncoder {
107 pub fn new(
108 architecture: String,
109 entity_dim: usize,
110 relation_dim: usize,
111 output_dim: usize,
112 ) -> Self {
113 let mut parameters = HashMap::new();
114
115 let mut random = Random::default();
117 parameters.insert(
118 "entity_projection".to_string(),
119 Array2::from_shape_fn((output_dim, entity_dim), |(_, _)| {
120 (random.random::<f32>() - 0.5) * 0.1
121 }),
122 );
123
124 let mut random = Random::default();
125 parameters.insert(
126 "relation_projection".to_string(),
127 Array2::from_shape_fn((output_dim, relation_dim), |(_, _)| {
128 (random.random::<f32>() - 0.5) * 0.1
129 }),
130 );
131
132 Self {
133 architecture,
134 entity_dim,
135 relation_dim,
136 output_dim,
137 parameters,
138 }
139 }
140
141 pub fn encode_entity(&self, entity_embedding: &Array1<f32>) -> Result<Array1<f32>> {
143 let projection = self.parameters.get("entity_projection").unwrap();
144
145 if projection.ncols() != entity_embedding.len() {
147 let target_dim = projection.ncols();
149 let mut adjusted_embedding = Array1::zeros(target_dim);
150
151 let copy_len = entity_embedding.len().min(target_dim);
152 adjusted_embedding
153 .slice_mut(scirs2_core::ndarray_ext::s![..copy_len])
154 .assign(&entity_embedding.slice(scirs2_core::ndarray_ext::s![..copy_len]));
155
156 Ok(projection.dot(&adjusted_embedding))
157 } else {
158 Ok(projection.dot(entity_embedding))
159 }
160 }
161
162 pub fn encode_relation(&self, relation_embedding: &Array1<f32>) -> Result<Array1<f32>> {
164 let projection = self.parameters.get("relation_projection").unwrap();
165
166 if projection.ncols() != relation_embedding.len() {
168 let target_dim = projection.ncols();
170 let mut adjusted_embedding = Array1::zeros(target_dim);
171
172 let copy_len = relation_embedding.len().min(target_dim);
173 adjusted_embedding
174 .slice_mut(scirs2_core::ndarray_ext::s![..copy_len])
175 .assign(&relation_embedding.slice(scirs2_core::ndarray_ext::s![..copy_len]));
176
177 Ok(projection.dot(&adjusted_embedding))
178 } else {
179 Ok(projection.dot(relation_embedding))
180 }
181 }
182
183 pub fn encode_structured(
185 &self,
186 entity: &Array1<f32>,
187 relations: &[Array1<f32>],
188 ) -> Result<Array1<f32>> {
189 let entity_encoded = self.encode_entity(entity)?;
190
191 let mut relation_agg = Array1::<f32>::zeros(self.output_dim);
193 for relation in relations {
194 let rel_encoded = self.encode_relation(relation)?;
195 relation_agg = &relation_agg + &rel_encoded;
196 }
197
198 if !relations.is_empty() {
199 relation_agg /= relations.len() as f32;
200 }
201
202 Ok(&entity_encoded + &relation_agg)
204 }
205}
206
207#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct AlignmentNetwork {
210 pub architecture: String,
212 pub input_dims: (usize, usize),
214 pub hidden_dim: usize,
216 pub output_dim: usize,
218 pub parameters: HashMap<String, Array2<f32>>,
220}
221
222impl AlignmentNetwork {
223 pub fn new(
224 architecture: String,
225 text_dim: usize,
226 kg_dim: usize,
227 hidden_dim: usize,
228 output_dim: usize,
229 ) -> Self {
230 let mut parameters = HashMap::new();
231
232 let mut random = Random::default();
234 parameters.insert(
235 "text_hidden".to_string(),
236 Array2::from_shape_fn((hidden_dim, text_dim), |(_, _)| {
237 (random.random::<f32>() - 0.5) * 0.1
238 }),
239 );
240
241 let mut random = Random::default();
242 parameters.insert(
243 "text_output".to_string(),
244 Array2::from_shape_fn((output_dim, hidden_dim), |(_, _)| {
245 (random.random::<f32>() - 0.5) * 0.1
246 }),
247 );
248
249 let mut random = Random::default();
251 parameters.insert(
252 "kg_hidden".to_string(),
253 Array2::from_shape_fn((hidden_dim, kg_dim), |(_, _)| {
254 (random.random::<f32>() - 0.5) * 0.1
255 }),
256 );
257
258 let mut random = Random::default();
259 parameters.insert(
260 "kg_output".to_string(),
261 Array2::from_shape_fn((output_dim, hidden_dim), |(_, _)| {
262 (random.random::<f32>() - 0.5) * 0.1
263 }),
264 );
265
266 let mut random = Random::default();
268 parameters.insert(
269 "cross_attention".to_string(),
270 Array2::from_shape_fn((output_dim, output_dim), |(_, _)| {
271 (random.random::<f32>() - 0.5) * 0.1
272 }),
273 );
274
275 Self {
276 architecture,
277 input_dims: (text_dim, kg_dim),
278 hidden_dim,
279 output_dim,
280 parameters,
281 }
282 }
283
284 pub fn align(
286 &self,
287 text_emb: &Array1<f32>,
288 kg_emb: &Array1<f32>,
289 ) -> Result<(Array1<f32>, f32)> {
290 let text_hidden_matrix = self.parameters.get("text_hidden").unwrap();
292 let text_hidden = text_hidden_matrix.dot(text_emb);
293 let text_hidden = text_hidden.mapv(|x| x.max(0.0)); let text_output_matrix = self.parameters.get("text_output").unwrap();
295 let text_output = text_output_matrix.dot(&text_hidden);
296
297 let kg_hidden_matrix = self.parameters.get("kg_hidden").unwrap();
299 let kg_hidden = kg_hidden_matrix.dot(kg_emb);
300 let kg_hidden = kg_hidden.mapv(|x| x.max(0.0)); let kg_output_matrix = self.parameters.get("kg_output").unwrap();
302 let kg_output = kg_output_matrix.dot(&kg_hidden);
303
304 let attention_weights = self.compute_attention(&text_output, &kg_output)?;
306
307 let min_dim = text_output.len().min(kg_output.len());
309 let text_slice = text_output
310 .slice(scirs2_core::ndarray_ext::s![..min_dim])
311 .to_owned();
312 let kg_slice = kg_output
313 .slice(scirs2_core::ndarray_ext::s![..min_dim])
314 .to_owned();
315 let unified = &text_slice * attention_weights + &kg_slice * (1.0 - attention_weights);
316
317 let alignment_score = self.compute_alignment_score(&text_output, &kg_output);
319
320 Ok((unified, alignment_score))
321 }
322
323 fn compute_attention(&self, text_emb: &Array1<f32>, kg_emb: &Array1<f32>) -> Result<f32> {
325 let min_dim = text_emb.len().min(kg_emb.len());
327 let text_slice = text_emb.slice(scirs2_core::ndarray_ext::s![..min_dim]);
328 let kg_slice = kg_emb.slice(scirs2_core::ndarray_ext::s![..min_dim]);
329
330 let attention_score = text_slice.dot(&kg_slice);
332 let attention_weight = 1.0 / (1.0 + (-attention_score).exp()); Ok(attention_weight)
335 }
336
337 pub fn compute_alignment_score(&self, text_emb: &Array1<f32>, kg_emb: &Array1<f32>) -> f32 {
339 let min_dim = text_emb.len().min(kg_emb.len());
341 let text_slice = text_emb.slice(scirs2_core::ndarray_ext::s![..min_dim]);
342 let kg_slice = kg_emb.slice(scirs2_core::ndarray_ext::s![..min_dim]);
343
344 let dot_product = text_slice.dot(&kg_slice);
346 let text_norm = text_slice.dot(&text_slice).sqrt();
347 let kg_norm = kg_slice.dot(&kg_slice).sqrt();
348
349 if text_norm > 0.0 && kg_norm > 0.0 {
350 dot_product / (text_norm * kg_norm)
351 } else {
352 0.0
353 }
354 }
355}