1use std::borrow::Cow;
20
21#[derive(Debug, Clone)]
30pub enum ModalityInput<'a> {
31 Text(Cow<'a, str>),
33 Image {
35 data: Cow<'a, [u8]>,
37 format: ImageFormat,
39 },
40 Hybrid {
42 text: Cow<'a, str>,
44 visual_positions: Vec<VisualPosition>,
46 },
47}
48
/// Container format of an image payload.
///
/// [`ImageFormat::Png`] is the [`Default`] value.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ImageFormat {
    /// PNG image (the default).
    #[default]
    Png,
    /// JPEG image.
    Jpeg,
    /// WebP image.
    Webp,
    /// Format not known.
    Unknown,
}
62
/// Rectangle associated with a single token on a page.
///
/// NOTE(review): the coordinate system and units of `x`, `y`, `width` and
/// `height` (pixels vs. normalised) are not established in this file, and
/// neither is whether `page` is zero- or one-based; confirm with producers.
#[derive(Clone, Copy, Debug)]
pub struct VisualPosition {
    /// Index of the token this rectangle belongs to.
    pub token_idx: u32,
    /// Horizontal coordinate of the rectangle.
    pub x: f32,
    /// Vertical coordinate of the rectangle.
    pub y: f32,
    /// Horizontal extent of the rectangle.
    pub width: f32,
    /// Vertical extent of the rectangle.
    pub height: f32,
    /// Page the rectangle is located on.
    pub page: u32,
}
79
80pub mod registry;
83pub use registry::*;
84
85pub mod encoder;
86pub use encoder::*;
87
88pub mod traits;
89pub use traits::*;
90
91pub mod late_interaction;
92pub use late_interaction::*;
93
94pub mod span;
95pub use span::*;
96
97pub mod coref;
98pub use coref::*;
99
100pub mod relation_extraction;
101pub use relation_extraction::{
102 extract_relation_triples, extract_relations, RelationExtractionConfig,
103};
104
105pub mod binary_embeddings;
106pub use binary_embeddings::*;
/// Smoke tests for the items re-exported from this module: the semantic
/// registry builder, similarity helpers, coreference resolution, the
/// handshaking matrix, relation extraction, and binary-hash retrieval.
#[cfg(test)]
mod tests {
    use super::coref::{resolve_coreferences, CoreferenceConfig};
    use super::late_interaction::DotProductInteraction;
    use super::*;
    use crate::{Entity, EntityType};

    /// The builder keeps entity and relation labels apart while counting
    /// both in the overall length.
    #[test]
    fn test_semantic_registry_builder() {
        let registry = SemanticRegistry::builder()
            .add_entity("person", "A human being")
            .add_entity("organization", "A company or group")
            .add_relation("WORKS_FOR", "Employment relationship")
            .build_placeholder(768);

        assert_eq!(registry.len(), 3);
        assert_eq!(registry.entity_labels().count(), 2);
        assert_eq!(registry.relation_labels().count(), 1);
    }

    /// The stock NER registry exposes at least the common entity labels.
    #[test]
    fn test_standard_ner_registry() {
        let registry = SemanticRegistry::standard_ner(768);
        assert!(registry.len() >= 5);
        assert!(registry.label_index.contains_key("person"));
        assert!(registry.label_index.contains_key("organization"));
    }

    /// Two one-hot span vectors scored against three label vectors.
    #[test]
    fn test_dot_product_interaction() {
        let interaction = DotProductInteraction::new();

        // 2 vectors x 4 dims, flattened.
        let span_embs = vec![
            1.0, 0.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, 0.0, //
        ];
        // 3 vectors x 4 dims, flattened.
        let label_embs = vec![
            1.0, 0.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, 0.0, //
            0.5, 0.5, 0.0, 0.0, //
        ];

        let scores = interaction.compute_similarity(&span_embs, 2, &label_embs, 3, 4);

        // One score per (span, label) pair.
        assert_eq!(scores.len(), 6);
        // Presumably laid out row-major as [span][label]: index 0 is
        // span 0 vs label 0, index 4 is span 1 vs label 1 -- TODO confirm.
        assert!((scores[0] - 1.0).abs() < 0.01);
        assert!((scores[4] - 1.0).abs() < 0.01);
    }

    /// Identical, orthogonal and opposite vectors give 1, 0 and -1.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        let c = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &c).abs() < 0.001);

        let d = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &d) - (-1.0)).abs() < 0.001);
    }

    /// "Curie" is clustered with "Marie Curie" even though the embeddings
    /// are all zeros, so the match cannot be coming from embedding content.
    #[test]
    fn test_coreference_string_match() {
        let entities = vec![
            Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95),
            Entity::new("Curie", EntityType::Person, 50, 55, 0.90),
        ];

        // One all-zero 768-dim vector per entity.
        let embeddings = vec![0.0f32; 2 * 768];
        let clusters =
            resolve_coreferences(&entities, &embeddings, 768, &CoreferenceConfig::default());

        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].members.len(), 2);
        assert_eq!(clusters[0].canonical_name, "Marie Curie");
    }

    /// Building from a dense score buffer keeps the high-scoring cells.
    #[test]
    fn test_handshaking_matrix() {
        // 18 scores = 3 * 3 * 2 (presumably [start][end][label] -- TODO
        // confirm the layout against `HandshakingMatrix::from_dense`).
        let scores = vec![
            0.9, 0.1, 0.2, 0.8, 0.1, 0.1, //
            0.0, 0.0, 0.7, 0.2, 0.3, 0.6, //
            0.0, 0.0, 0.0, 0.0, 0.1, 0.1, //
        ];

        let matrix = HandshakingMatrix::from_dense(&scores, 3, 2, 0.5);

        // Exactly four inputs (0.9, 0.8, 0.7, 0.6) exceed the 0.5 argument.
        assert!(matrix.cells.len() >= 4);
    }

    /// A person and an organization around the word "founded" yield a
    /// FOUNDED relation.
    #[test]
    fn test_relation_extraction() {
        let entities = vec![
            Entity::new("Steve Jobs", EntityType::Person, 0, 10, 0.95),
            Entity::new("Apple", EntityType::Organization, 20, 25, 0.90),
        ];

        let text = "Steve Jobs founded Apple Inc in 1976";

        let registry = SemanticRegistry::builder()
            .add_relation("FOUNDED", "Founded an organization")
            .build_placeholder(768);

        let config = RelationExtractionConfig::default();
        let relations = extract_relations(&entities, text, &registry, &config);

        assert!(!relations.is_empty());
        assert_eq!(relations[0].relation_type, "FOUNDED");
    }

    /// Entity spans and the reported trigger span are character offsets.
    ///
    /// The leading emoji occupies several bytes but a single `char`, so byte
    /// and character offsets diverge for everything after it.
    #[test]
    fn test_relation_extraction_uses_character_offsets_with_unicode_prefix() {
        let text = "👋 Steve Jobs founded Apple Inc.";

        // `str::find` returns byte offsets; convert them to char offsets.
        let steve_start = text.find("Steve Jobs").expect("substring present");
        let conv = crate::offset::SpanConverter::new(text);
        let steve_start_char = conv.byte_to_char(steve_start);
        let steve_end_char = steve_start_char + "Steve Jobs".chars().count();

        let apple_start = text.find("Apple").expect("substring present");
        let apple_start_char = conv.byte_to_char(apple_start);
        let apple_end_char = apple_start_char + "Apple".chars().count();

        let entities = vec![
            Entity::new(
                "Steve Jobs",
                EntityType::Person,
                steve_start_char,
                steve_end_char,
                0.95,
            ),
            Entity::new(
                "Apple",
                EntityType::Organization,
                apple_start_char,
                apple_end_char,
                0.90,
            ),
        ];

        let registry = SemanticRegistry::builder()
            .add_relation("FOUNDED", "Founded an organization")
            .build_placeholder(768);

        let config = RelationExtractionConfig::default();
        let relations = extract_relations(&entities, text, &registry, &config);

        assert!(
            !relations.is_empty(),
            "Expected FOUNDED relation to be detected"
        );
        assert_eq!(relations[0].relation_type, "FOUNDED");

        // Slice by chars, not bytes, to check the span's unit.
        let trigger = relations[0]
            .trigger_span
            .expect("expected trigger_span to be present");
        let trigger_text: String = text
            .chars()
            .skip(trigger.0)
            .take(trigger.1.saturating_sub(trigger.0))
            .collect();
        assert_eq!(trigger_text.to_ascii_lowercase(), "founded");
    }

    /// Alternating signs produce an alternating bit pattern.
    #[test]
    fn test_binary_hash_creation() {
        let embedding = vec![0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8];
        let hash = BinaryHash::from_embedding(&embedding);

        assert_eq!(hash.dim, 8);
        // 85 == 0b0101_0101: consistent with positive entries (indices
        // 0, 2, 4, 6) mapping to set bits packed from the least significant
        // bit -- TODO confirm against `BinaryHash::from_embedding`.
        assert_eq!(hash.bits[0], 85);
    }

    /// Hashes of the same embedding are at distance zero.
    #[test]
    fn test_hamming_distance_identical() {
        let embedding = vec![0.1; 64];
        let hash1 = BinaryHash::from_embedding(&embedding);
        let hash2 = BinaryHash::from_embedding(&embedding);

        assert_eq!(hash1.hamming_distance(&hash2), 0);
    }

    /// Flipping the sign of every component flips every bit.
    #[test]
    fn test_hamming_distance_opposite() {
        let embedding1 = vec![0.1; 64];
        let embedding2 = vec![-0.1; 64];
        let hash1 = BinaryHash::from_embedding(&embedding1);
        let hash2 = BinaryHash::from_embedding(&embedding2);

        assert_eq!(hash1.hamming_distance(&hash2), 64);
    }

    /// Flipping the sign of half the components flips half the bits.
    #[test]
    fn test_hamming_distance_half() {
        let embedding1 = vec![0.1; 64];
        let mut embedding2 = vec![0.1; 64];
        embedding2[32..64].iter_mut().for_each(|x| *x = -0.1);

        let hash1 = BinaryHash::from_embedding(&embedding1);
        let hash2 = BinaryHash::from_embedding(&embedding2);

        assert_eq!(hash1.hamming_distance(&hash2), 32);
    }

    /// The blocker returns near-identical hashes and rejects distant ones.
    #[test]
    fn test_binary_blocker() {
        // NOTE(review): `5` is presumably the maximum Hamming distance
        // accepted as a candidate -- confirm against `BinaryBlocker::new`.
        let mut blocker = BinaryBlocker::new(5);

        let base_embedding = vec![0.1; 64];
        // Differs from the base in exactly two sign positions.
        let similar_embedding = {
            let mut e = vec![0.1; 64];
            e[0] = -0.1;
            e[1] = -0.1;
            e
        };
        let different_embedding = vec![-0.1; 64];

        blocker.add(0, BinaryHash::from_embedding(&base_embedding));
        blocker.add(1, BinaryHash::from_embedding(&similar_embedding));
        blocker.add(2, BinaryHash::from_embedding(&different_embedding));

        let query = BinaryHash::from_embedding(&base_embedding);
        let candidates = blocker.query(&query);

        assert!(candidates.contains(&0), "Should find exact match");
        assert!(
            candidates.contains(&1),
            "Should find similar (2 bits different)"
        );
        assert!(
            !candidates.contains(&2),
            "Should NOT find opposite (64 bits different)"
        );
    }

    /// The exact match is ranked first with a score of about 1.
    #[test]
    fn test_two_stage_retrieval() {
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let candidates = vec![
            vec![1.0, 0.0, 0.0, 0.0],  // identical to the query
            vec![0.9, 0.1, 0.0, 0.0],  // close to the query
            vec![-1.0, 0.0, 0.0, 0.0], // opposite direction
            vec![0.0, 1.0, 0.0, 0.0],  // orthogonal
        ];

        // NOTE(review): the meaning of the trailing `4, 2` arguments is not
        // visible here -- see `two_stage_retrieval`.
        let results = two_stage_retrieval(&query, &candidates, 4, 2);

        assert!(!results.is_empty());
        assert_eq!(results[0].0, 0);
        assert!((results[0].1 - 1.0).abs() < 0.001);
    }

    /// The hash-based cosine estimate hits the +1 and -1 extremes.
    #[test]
    fn test_approximate_cosine() {
        let embedding1 = vec![0.1; 768];
        let embedding2 = vec![0.1; 768];
        let hash1 = BinaryHash::from_embedding(&embedding1);
        let hash2 = BinaryHash::from_embedding(&embedding2);

        let approx = hash1.approximate_cosine(&hash2);
        assert!((approx - 1.0).abs() < 0.001);

        let embedding3 = vec![-0.1; 768];
        let hash3 = BinaryHash::from_embedding(&embedding3);
        let approx_opp = hash1.approximate_cosine(&hash3);
        assert!((approx_opp - (-1.0)).abs() < 0.001);
    }
}