oxirs_embed/multimodal/impl/
encoders.rs1use anyhow::Result;
4use scirs2_core::ndarray_ext::{Array1, Array2};
5use scirs2_core::random::{Random, Rng};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
/// Encoder that maps raw text to a fixed-length `f32` vector through a
/// linear projection of simple hand-crafted text features.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextEncoder {
    /// Caller-supplied label for the encoder variant; `new` stores it unchanged.
    pub encoder_type: String,
    /// Length of the feature vector extracted from the input text.
    pub input_dim: usize,
    /// Row count `new` uses for the `"projection"` matrix, i.e. the length of
    /// the encoded vector.
    pub output_dim: usize,
    /// Weight matrices keyed by name; `new` inserts `"projection"` and
    /// `"attention"`.
    pub parameters: HashMap<String, Array2<f32>>,
}
21
22impl TextEncoder {
23 pub fn new(encoder_type: String, input_dim: usize, output_dim: usize) -> Self {
24 let mut parameters = HashMap::new();
25
26 let mut random = Random::default();
28 parameters.insert(
29 "projection".to_string(),
30 Array2::from_shape_fn((output_dim, input_dim), |(_, _)| {
31 (random.random::<f32>() - 0.5) * 0.1
32 }),
33 );
34
35 let mut random = Random::default();
36 parameters.insert(
37 "attention".to_string(),
38 Array2::from_shape_fn((output_dim, output_dim), |(_, _)| {
39 (random.random::<f32>() - 0.5) * 0.1
40 }),
41 );
42
43 Self {
44 encoder_type,
45 input_dim,
46 output_dim,
47 parameters,
48 }
49 }
50
51 pub fn encode(&self, text: &str) -> Result<Array1<f32>> {
53 let input_features = self.extract_text_features(text);
54 let projection = self
55 .parameters
56 .get("projection")
57 .expect("parameter 'projection' should be initialized");
58
59 let encoded = projection.dot(&input_features);
61
62 let mean = encoded.mean().unwrap_or(0.0);
64 let var = encoded.var(0.0);
65 let normalized = encoded.mapv(|x| (x - mean) / (var + 1e-8).sqrt());
66
67 Ok(normalized)
68 }
69
70 fn extract_text_features(&self, text: &str) -> Array1<f32> {
72 let mut features = vec![0.0; self.input_dim];
73
74 let words: Vec<&str> = text.split_whitespace().collect();
76 for (i, word) in words.iter().enumerate() {
77 if i < self.input_dim {
78 features[i] = word.len() as f32 / 10.0; }
80 }
81
82 if self.input_dim > words.len() {
84 features[words.len()] = text.len() as f32 / 100.0; if self.input_dim > words.len() + 1 {
86 features[words.len() + 1] = words.len() as f32 / 20.0; }
88 }
89
90 Array1::from_vec(features)
91 }
92}
93
/// Encoder that projects knowledge-graph entity and relation embeddings into
/// a shared output space.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KGEncoder {
    /// Caller-supplied architecture label; `new` stores it unchanged.
    pub architecture: String,
    /// Column count `new` uses for the `"entity_projection"` matrix.
    pub entity_dim: usize,
    /// Column count `new` uses for the `"relation_projection"` matrix.
    pub relation_dim: usize,
    /// Row count `new` uses for both projection matrices; also the length of
    /// the relation accumulator in `encode_structured`.
    pub output_dim: usize,
    /// Weight matrices keyed by name; `new` inserts `"entity_projection"` and
    /// `"relation_projection"`.
    pub parameters: HashMap<String, Array2<f32>>,
}
108
109impl KGEncoder {
110 pub fn new(
111 architecture: String,
112 entity_dim: usize,
113 relation_dim: usize,
114 output_dim: usize,
115 ) -> Self {
116 let mut parameters = HashMap::new();
117
118 let mut random = Random::default();
120 parameters.insert(
121 "entity_projection".to_string(),
122 Array2::from_shape_fn((output_dim, entity_dim), |(_, _)| {
123 (random.random::<f32>() - 0.5) * 0.1
124 }),
125 );
126
127 let mut random = Random::default();
128 parameters.insert(
129 "relation_projection".to_string(),
130 Array2::from_shape_fn((output_dim, relation_dim), |(_, _)| {
131 (random.random::<f32>() - 0.5) * 0.1
132 }),
133 );
134
135 Self {
136 architecture,
137 entity_dim,
138 relation_dim,
139 output_dim,
140 parameters,
141 }
142 }
143
144 pub fn encode_entity(&self, entity_embedding: &Array1<f32>) -> Result<Array1<f32>> {
146 let projection = self
147 .parameters
148 .get("entity_projection")
149 .expect("parameter 'entity_projection' should be initialized");
150
151 if projection.ncols() != entity_embedding.len() {
153 let target_dim = projection.ncols();
155 let mut adjusted_embedding = Array1::zeros(target_dim);
156
157 let copy_len = entity_embedding.len().min(target_dim);
158 adjusted_embedding
159 .slice_mut(scirs2_core::ndarray_ext::s![..copy_len])
160 .assign(&entity_embedding.slice(scirs2_core::ndarray_ext::s![..copy_len]));
161
162 Ok(projection.dot(&adjusted_embedding))
163 } else {
164 Ok(projection.dot(entity_embedding))
165 }
166 }
167
168 pub fn encode_relation(&self, relation_embedding: &Array1<f32>) -> Result<Array1<f32>> {
170 let projection = self
171 .parameters
172 .get("relation_projection")
173 .expect("parameter 'relation_projection' should be initialized");
174
175 if projection.ncols() != relation_embedding.len() {
177 let target_dim = projection.ncols();
179 let mut adjusted_embedding = Array1::zeros(target_dim);
180
181 let copy_len = relation_embedding.len().min(target_dim);
182 adjusted_embedding
183 .slice_mut(scirs2_core::ndarray_ext::s![..copy_len])
184 .assign(&relation_embedding.slice(scirs2_core::ndarray_ext::s![..copy_len]));
185
186 Ok(projection.dot(&adjusted_embedding))
187 } else {
188 Ok(projection.dot(relation_embedding))
189 }
190 }
191
192 pub fn encode_structured(
194 &self,
195 entity: &Array1<f32>,
196 relations: &[Array1<f32>],
197 ) -> Result<Array1<f32>> {
198 let entity_encoded = self.encode_entity(entity)?;
199
200 let mut relation_agg = Array1::<f32>::zeros(self.output_dim);
202 for relation in relations {
203 let rel_encoded = self.encode_relation(relation)?;
204 relation_agg = &relation_agg + &rel_encoded;
205 }
206
207 if !relations.is_empty() {
208 relation_agg /= relations.len() as f32;
209 }
210
211 Ok(&entity_encoded + &relation_agg)
213 }
214}
215
/// Network that maps a text embedding and a knowledge-graph embedding into a
/// common space and blends them into one vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlignmentNetwork {
    /// Caller-supplied architecture label; `new` stores it unchanged.
    pub architecture: String,
    /// `(text_dim, kg_dim)` exactly as passed to `new`.
    pub input_dims: (usize, usize),
    /// Row count `new` uses for the `"text_hidden"` and `"kg_hidden"` matrices.
    pub hidden_dim: usize,
    /// Row count `new` uses for the `"text_output"`, `"kg_output"` and
    /// `"cross_attention"` matrices.
    pub output_dim: usize,
    /// Weight matrices keyed by name; `new` inserts `"text_hidden"`,
    /// `"text_output"`, `"kg_hidden"`, `"kg_output"` and `"cross_attention"`.
    pub parameters: HashMap<String, Array2<f32>>,
}
230
231impl AlignmentNetwork {
232 pub fn new(
233 architecture: String,
234 text_dim: usize,
235 kg_dim: usize,
236 hidden_dim: usize,
237 output_dim: usize,
238 ) -> Self {
239 let mut parameters = HashMap::new();
240
241 let mut random = Random::default();
243 parameters.insert(
244 "text_hidden".to_string(),
245 Array2::from_shape_fn((hidden_dim, text_dim), |(_, _)| {
246 (random.random::<f32>() - 0.5) * 0.1
247 }),
248 );
249
250 let mut random = Random::default();
251 parameters.insert(
252 "text_output".to_string(),
253 Array2::from_shape_fn((output_dim, hidden_dim), |(_, _)| {
254 (random.random::<f32>() - 0.5) * 0.1
255 }),
256 );
257
258 let mut random = Random::default();
260 parameters.insert(
261 "kg_hidden".to_string(),
262 Array2::from_shape_fn((hidden_dim, kg_dim), |(_, _)| {
263 (random.random::<f32>() - 0.5) * 0.1
264 }),
265 );
266
267 let mut random = Random::default();
268 parameters.insert(
269 "kg_output".to_string(),
270 Array2::from_shape_fn((output_dim, hidden_dim), |(_, _)| {
271 (random.random::<f32>() - 0.5) * 0.1
272 }),
273 );
274
275 let mut random = Random::default();
277 parameters.insert(
278 "cross_attention".to_string(),
279 Array2::from_shape_fn((output_dim, output_dim), |(_, _)| {
280 (random.random::<f32>() - 0.5) * 0.1
281 }),
282 );
283
284 Self {
285 architecture,
286 input_dims: (text_dim, kg_dim),
287 hidden_dim,
288 output_dim,
289 parameters,
290 }
291 }
292
293 pub fn align(
295 &self,
296 text_emb: &Array1<f32>,
297 kg_emb: &Array1<f32>,
298 ) -> Result<(Array1<f32>, f32)> {
299 let text_hidden_matrix = self
301 .parameters
302 .get("text_hidden")
303 .expect("parameter 'text_hidden' should be initialized");
304 let text_hidden = text_hidden_matrix.dot(text_emb);
305 let text_hidden = text_hidden.mapv(|x| x.max(0.0)); let text_output_matrix = self
307 .parameters
308 .get("text_output")
309 .expect("parameter 'text_output' should be initialized");
310 let text_output = text_output_matrix.dot(&text_hidden);
311
312 let kg_hidden_matrix = self
314 .parameters
315 .get("kg_hidden")
316 .expect("parameter 'kg_hidden' should be initialized");
317 let kg_hidden = kg_hidden_matrix.dot(kg_emb);
318 let kg_hidden = kg_hidden.mapv(|x| x.max(0.0)); let kg_output_matrix = self
320 .parameters
321 .get("kg_output")
322 .expect("parameter 'kg_output' should be initialized");
323 let kg_output = kg_output_matrix.dot(&kg_hidden);
324
325 let attention_weights = self.compute_attention(&text_output, &kg_output)?;
327
328 let min_dim = text_output.len().min(kg_output.len());
330 let text_slice = text_output
331 .slice(scirs2_core::ndarray_ext::s![..min_dim])
332 .to_owned();
333 let kg_slice = kg_output
334 .slice(scirs2_core::ndarray_ext::s![..min_dim])
335 .to_owned();
336 let unified = &text_slice * attention_weights + &kg_slice * (1.0 - attention_weights);
337
338 let alignment_score = self.compute_alignment_score(&text_output, &kg_output);
340
341 Ok((unified, alignment_score))
342 }
343
344 fn compute_attention(&self, text_emb: &Array1<f32>, kg_emb: &Array1<f32>) -> Result<f32> {
346 let min_dim = text_emb.len().min(kg_emb.len());
348 let text_slice = text_emb.slice(scirs2_core::ndarray_ext::s![..min_dim]);
349 let kg_slice = kg_emb.slice(scirs2_core::ndarray_ext::s![..min_dim]);
350
351 let attention_score = text_slice.dot(&kg_slice);
353 let attention_weight = 1.0 / (1.0 + (-attention_score).exp()); Ok(attention_weight)
356 }
357
358 pub fn compute_alignment_score(&self, text_emb: &Array1<f32>, kg_emb: &Array1<f32>) -> f32 {
360 let min_dim = text_emb.len().min(kg_emb.len());
362 let text_slice = text_emb.slice(scirs2_core::ndarray_ext::s![..min_dim]);
363 let kg_slice = kg_emb.slice(scirs2_core::ndarray_ext::s![..min_dim]);
364
365 let dot_product = text_slice.dot(&kg_slice);
367 let text_norm = text_slice.dot(&text_slice).sqrt();
368 let kg_norm = kg_slice.dot(&kg_slice).sqrt();
369
370 if text_norm > 0.0 && kg_norm > 0.0 {
371 dot_product / (text_norm * kg_norm)
372 } else {
373 0.0
374 }
375 }
376}