1use std::collections::HashMap;
40
/// A borrowed entity annotation: `(entity text, entity type, start, end)`.
///
/// For example `("Apple", "ORG", 19, 24)` annotates `"Apple"` inside
/// `"Steve Jobs founded Apple."`.
///
/// NOTE(review): `start`/`end` look like end-exclusive offsets into the
/// annotated text; whether they count bytes or characters is not fixed by this
/// alias — confirm against callers.
pub type EntityAnnotation<'a> = (&'a str, &'a str, usize, usize);
43
/// A batch of demonstrations: each element pairs a text with the
/// [`EntityAnnotation`]s that belong to it.
///
/// This is the input shape accepted by `DemonstrationBank::add_all`.
pub type DemoBatch<'a> = Vec<(&'a str, Vec<EntityAnnotation<'a>>)>;
46
/// Weights and threshold that control how `DemonstrationBank` scores examples.
///
/// The helpfulness score of a demonstration is
/// `similarity_weight * token_similarity + type_overlap_weight * type_overlap
/// + density_weight * density_similarity`; demonstrations scoring below
/// `min_score` are never selected.
#[derive(Debug, Clone)]
pub struct HelpfulnessConfig {
    /// Multiplier for the token-similarity component of the score.
    pub similarity_weight: f64,
    /// Multiplier for the entity-type-overlap component of the score.
    pub type_overlap_weight: f64,
    /// Multiplier for the entity-density-similarity component of the score.
    pub density_weight: f64,
    /// Lowest score a demonstration may have and still be selected.
    pub min_score: f64,
}

impl Default for HelpfulnessConfig {
    /// Returns the stock configuration: `0.4` for token similarity, `0.4` for
    /// type overlap, `0.2` for density similarity and a `0.1` score cut-off.
    fn default() -> Self {
        HelpfulnessConfig {
            min_score: 0.1,
            density_weight: 0.2,
            type_overlap_weight: 0.4,
            similarity_weight: 0.4,
        }
    }
}
70
/// One annotated demonstration: a text, its entity annotations, and features
/// pre-computed from both at construction time.
#[derive(Debug, Clone)]
pub struct DemonstrationExample {
    /// The demonstration text exactly as supplied to [`DemonstrationExample::new`].
    pub text: String,
    /// Owned annotations as `(entity text, entity type, start, end)`.
    pub entities: Vec<(String, String, usize, usize)>,
    /// Features derived once from `text` and `entities`.
    features: ExampleFeatures,
}

/// Features derived from a text and its entity annotations.
#[derive(Debug, Clone, Default)]
struct ExampleFeatures {
    /// Lower-cased whitespace-separated tokens of the text, in order.
    tokens: Vec<String>,
    /// Type label of every annotation, in annotation order (duplicates kept).
    entity_types: Vec<String>,
    /// Annotations per token, times 100; `0.0` when the text has no tokens.
    entity_density: f64,
}

impl DemonstrationExample {
    /// Builds an example from borrowed data, copying the text and every
    /// annotation into owned storage and computing the features up front.
    pub fn new(text: &str, entities: Vec<(&str, &str, usize, usize)>) -> Self {
        let mut owned: Vec<(String, String, usize, usize)> = Vec::with_capacity(entities.len());
        for (surface, label, start, end) in entities {
            owned.push((surface.to_owned(), label.to_owned(), start, end));
        }

        let features = Self::compute_features(text, &owned);

        DemonstrationExample {
            text: text.to_owned(),
            entities: owned,
            features,
        }
    }

    /// Derives [`ExampleFeatures`] from `text` and its annotations.
    ///
    /// Tokens are the whitespace-separated words of `text`, lower-cased. The
    /// density is `entities.len() / tokens.len() * 100`, or `0.0` when there
    /// are no tokens.
    fn compute_features(
        text: &str,
        entities: &[(String, String, usize, usize)],
    ) -> ExampleFeatures {
        let tokens: Vec<String> = text.split_whitespace().map(str::to_lowercase).collect();
        let entity_types: Vec<String> = entities.iter().map(|entity| entity.1.clone()).collect();

        let entity_density = match tokens.len() {
            0 => 0.0,
            token_count => entities.len() as f64 / token_count as f64 * 100.0,
        };

        ExampleFeatures {
            tokens,
            entity_types,
            entity_density,
        }
    }
}
131
132#[derive(Debug, Clone, Default)]
136pub struct DemonstrationBank {
137 examples: Vec<DemonstrationExample>,
138 config: HelpfulnessConfig,
139}
140
141impl DemonstrationBank {
142 #[must_use]
144 pub fn new() -> Self {
145 Self::default()
146 }
147
148 #[must_use]
150 pub fn with_config(config: HelpfulnessConfig) -> Self {
151 Self {
152 examples: vec![],
153 config,
154 }
155 }
156
157 pub fn add(&mut self, text: &str, entities: Vec<(&str, &str, usize, usize)>) {
159 self.examples
160 .push(DemonstrationExample::new(text, entities));
161 }
162
163 pub fn add_all(&mut self, demos: DemoBatch<'_>) {
165 for (text, entities) in demos {
166 self.add(text, entities);
167 }
168 }
169
170 #[must_use]
172 pub fn len(&self) -> usize {
173 self.examples.len()
174 }
175
176 #[must_use]
178 pub fn is_empty(&self) -> bool {
179 self.examples.is_empty()
180 }
181
182 #[must_use]
193 pub fn select(&self, query: &str, k: usize) -> Vec<&DemonstrationExample> {
194 if self.examples.is_empty() || k == 0 {
195 return vec![];
196 }
197
198 let query_features = DemonstrationExample::compute_features(query, &[]);
199
200 let mut scored: Vec<_> = Vec::with_capacity(self.examples.len().min(k * 2));
203 scored.extend(
204 self.examples
205 .iter()
206 .map(|ex| {
207 let score = self.helpfulness_score(&query_features, ex);
208 (ex, score)
209 })
210 .filter(|(_, score)| *score >= self.config.min_score),
211 );
212
213 scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
216
217 scored.into_iter().take(k).map(|(ex, _)| ex).collect()
219 }
220
221 #[must_use]
223 pub fn select_with_scores(&self, query: &str, k: usize) -> Vec<(&DemonstrationExample, f64)> {
224 if self.examples.is_empty() || k == 0 {
225 return vec![];
226 }
227
228 let query_features = DemonstrationExample::compute_features(query, &[]);
229
230 let mut scored: Vec<_> = Vec::with_capacity(self.examples.len().min(k * 2));
232 scored.extend(
233 self.examples
234 .iter()
235 .map(|ex| {
236 let score = self.helpfulness_score(&query_features, ex);
237 (ex, score)
238 })
239 .filter(|(_, score)| *score >= self.config.min_score),
240 );
241
242 scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
244
245 scored.into_iter().take(k).collect()
246 }
247
248 fn helpfulness_score(&self, query: &ExampleFeatures, demo: &DemonstrationExample) -> f64 {
255 let sim = self.token_similarity(&query.tokens, &demo.features.tokens);
256 let type_overlap = self.type_overlap(&query.entity_types, &demo.features.entity_types);
257 let density_sim =
258 self.density_similarity(query.entity_density, demo.features.entity_density);
259
260 self.config.similarity_weight * sim
261 + self.config.type_overlap_weight * type_overlap
262 + self.config.density_weight * density_sim
263 }
264
265 fn token_similarity(&self, a: &[String], b: &[String]) -> f64 {
267 if a.is_empty() && b.is_empty() {
268 return 1.0;
269 }
270 if a.is_empty() || b.is_empty() {
271 return 0.0;
272 }
273
274 let set_a: std::collections::HashSet<_> = a.iter().collect();
275 let set_b: std::collections::HashSet<_> = b.iter().collect();
276
277 let intersection = set_a.intersection(&set_b).count();
278 let union = set_a.union(&set_b).count();
279
280 if union == 0 {
281 0.0
282 } else {
283 intersection as f64 / union as f64
284 }
285 }
286
287 fn type_overlap(&self, query_types: &[String], demo_types: &[String]) -> f64 {
289 if query_types.is_empty() {
291 return 1.0;
292 }
293 if demo_types.is_empty() {
294 return 0.0;
295 }
296
297 let query_set: std::collections::HashSet<_> = query_types.iter().collect();
298 let demo_set: std::collections::HashSet<_> = demo_types.iter().collect();
299
300 let overlap = query_set.intersection(&demo_set).count();
301 overlap as f64 / query_set.len() as f64
302 }
303
304 fn density_similarity(&self, query_density: f64, demo_density: f64) -> f64 {
306 let diff = (query_density - demo_density).abs();
308 (-diff / 5.0).exp() }
310}
311
/// Collects the words surrounding annotated entities, grouped by entity type.
///
/// For every entity, up to `window_size` tokens on each side of the token in
/// which the entity starts are gathered as context.
#[derive(Debug, Clone)]
pub struct TRFExtractor {
    /// Number of tokens taken on each side of the entity's first token.
    window_size: usize,
}

impl Default for TRFExtractor {
    /// Same as [`TRFExtractor::new`], so `default()` and `new()` agree on the
    /// window size (a derived `Default` would silently use a window of 0).
    fn default() -> Self {
        Self::new()
    }
}

impl TRFExtractor {
    /// Window size used by [`TRFExtractor::new`].
    const DEFAULT_WINDOW: usize = 3;

    /// Creates an extractor with a context window of 3 tokens per side.
    #[must_use]
    pub fn new() -> Self {
        Self {
            window_size: Self::DEFAULT_WINDOW,
        }
    }

    /// Creates an extractor with a context window of `size` tokens per side.
    #[must_use]
    pub fn with_window(size: usize) -> Self {
        Self { window_size: size }
    }

    /// Extracts context features for every entity in `entities`.
    ///
    /// `text` is split on whitespace. Each entity is `(text, type, start, end)`;
    /// `start` is matched against the byte range of each token to find the
    /// token the entity begins in (`end` is not used).
    ///
    /// The returned map has two kinds of keys:
    /// * `"<type>"`: lower-cased context tokens around the entity's first
    ///   token (that token itself is excluded). Only present when `start`
    ///   falls inside a token.
    /// * `"<type>_text"`: the lower-cased entity text, always recorded.
    ///
    /// NOTE(review): offsets are interpreted as byte offsets, matching how
    /// token lengths are measured (`str::len`). If callers pass character
    /// offsets, non-ASCII text needs converting first — confirm.
    ///
    /// NOTE(review): for a multi-token entity only its first token is left out
    /// of the context; the remaining entity tokens are included — confirm
    /// this is intended.
    #[must_use]
    pub fn extract(
        &self,
        text: &str,
        entities: &[(String, String, usize, usize)],
    ) -> HashMap<String, Vec<String>> {
        let mut features: HashMap<String, Vec<String>> = HashMap::new();
        let tokens = Self::tokens_with_offsets(text);

        for (entity_text, entity_type, start, _end) in entities {
            // Token ranges are disjoint, so at most one token contains `start`.
            let token_idx = tokens
                .iter()
                .position(|&(offset, token)| offset <= *start && *start < offset + token.len());

            if let Some(idx) = token_idx {
                let lo = idx.saturating_sub(self.window_size);
                // Saturating so that a huge window (e.g. `usize::MAX`) means
                // "to the end of the text" instead of overflowing.
                let hi = idx
                    .saturating_add(self.window_size)
                    .saturating_add(1)
                    .min(tokens.len());

                let context = tokens[lo..hi]
                    .iter()
                    .enumerate()
                    .filter(|(i, _)| lo + i != idx)
                    .map(|(_, &(_, token))| token.to_lowercase());

                features
                    .entry(entity_type.clone())
                    .or_default()
                    .extend(context);
            }

            features
                .entry(format!("{}_text", entity_type))
                .or_default()
                .push(entity_text.to_lowercase());
        }

        features
    }

    /// Splits `text` on Unicode whitespace (the same tokens as
    /// `str::split_whitespace`) and pairs each token with its byte offset.
    ///
    /// Using real offsets avoids assuming that tokens are separated by exactly
    /// one single-byte space and that the text has no leading whitespace.
    fn tokens_with_offsets(text: &str) -> Vec<(usize, &str)> {
        let mut tokens = Vec::new();
        let mut token_start: Option<usize> = None;

        for (offset, ch) in text.char_indices() {
            if ch.is_whitespace() {
                if let Some(begin) = token_start.take() {
                    tokens.push((begin, &text[begin..offset]));
                }
            } else if token_start.is_none() {
                token_start = Some(offset);
            }
        }
        if let Some(begin) = token_start {
            tokens.push((begin, &text[begin..]));
        }

        tokens
    }
}
386
/// Unit tests for demonstration construction, bank selection and context
/// feature extraction.
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction keeps every annotation and records each entity type.
    #[test]
    fn test_demonstration_example_creation() {
        let demo = DemonstrationExample::new(
            "Steve Jobs founded Apple.",
            vec![("Steve Jobs", "PER", 0, 10), ("Apple", "ORG", 19, 24)],
        );

        assert_eq!(demo.entities.len(), 2);
        assert!(demo.features.entity_types.contains(&"PER".to_string()));
        assert!(demo.features.entity_types.contains(&"ORG".to_string()));
    }

    /// A new bank is empty and grows by one per `add`.
    #[test]
    fn test_bank_add_and_len() {
        let mut bank = DemonstrationBank::new();
        assert!(bank.is_empty());

        bank.add("Test text.", vec![("Test", "MISC", 0, 4)]);
        assert_eq!(bank.len(), 1);
        assert!(!bank.is_empty());
    }

    /// With the default configuration every stored example clears the
    /// minimum score, so asking for three returns all three.
    #[test]
    fn test_select_demonstrations() {
        let mut bank = DemonstrationBank::new();

        bank.add(
            "Steve Jobs founded Apple in California.",
            vec![
                ("Steve Jobs", "PER", 0, 10),
                ("Apple", "ORG", 19, 24),
                ("California", "LOC", 28, 38),
            ],
        );

        bank.add(
            "The weather in New York is nice today.",
            vec![("New York", "LOC", 15, 23)],
        );

        bank.add(
            "Bill Gates started Microsoft in Seattle.",
            vec![
                ("Bill Gates", "PER", 0, 10),
                ("Microsoft", "ORG", 19, 28),
                ("Seattle", "LOC", 32, 39),
            ],
        );

        let demos = bank.select("Steve Jobs founded Apple in Silicon Valley.", 3);

        assert_eq!(demos.len(), 3);

        let demo_texts: Vec<_> = demos.iter().map(|d| d.text.as_str()).collect();
        assert!(demo_texts.contains(&"Steve Jobs founded Apple in California."));
        assert!(demo_texts.contains(&"Bill Gates started Microsoft in Seattle."));
        assert!(demo_texts.contains(&"The weather in New York is nice today."));
    }

    /// Scores are positive and come back in descending order.
    #[test]
    fn test_select_with_scores() {
        let mut bank = DemonstrationBank::new();

        bank.add("Apple is in Cupertino.", vec![("Apple", "ORG", 0, 5)]);
        bank.add("Google is in Mountain View.", vec![("Google", "ORG", 0, 6)]);

        let demos = bank.select_with_scores("Microsoft is in Redmond.", 2);

        assert_eq!(demos.len(), 2);
        for (_, score) in &demos {
            assert!(*score > 0.0);
        }
        assert!(demos.windows(2).all(|pair| pair[0].1 >= pair[1].1));
    }

    /// Selecting from an empty bank yields nothing.
    #[test]
    fn test_select_empty_bank() {
        let bank = DemonstrationBank::new();
        let demos = bank.select("Test query.", 5);
        assert!(demos.is_empty());
    }

    /// Asking for zero examples yields nothing even when the bank has entries.
    #[test]
    fn test_select_zero_k() {
        let mut bank = DemonstrationBank::new();
        bank.add("Apple is in Cupertino.", vec![("Apple", "ORG", 0, 5)]);

        assert!(bank.select("Apple is in Cupertino.", 0).is_empty());
        assert!(bank
            .select_with_scores("Apple is in Cupertino.", 0)
            .is_empty());
    }

    /// Context for an entity includes neighbouring words, lower-cased.
    #[test]
    fn test_trf_extractor() {
        let extractor = TRFExtractor::new();

        let features = extractor.extract(
            "The CEO Steve Jobs announced the new iPhone.",
            &[("Steve Jobs".to_string(), "PER".to_string(), 8, 18)],
        );

        assert!(features.contains_key("PER"));
        let per_context = features.get("PER").unwrap();
        assert!(per_context.iter().any(|w| w == "ceo" || w == "announced"));
    }

    /// `with_config` stores the supplied configuration unchanged and starts
    /// with no examples.
    #[test]
    fn test_helpfulness_config() {
        let config = HelpfulnessConfig {
            similarity_weight: 0.5,
            type_overlap_weight: 0.3,
            density_weight: 0.2,
            min_score: 0.2,
        };

        let bank = DemonstrationBank::with_config(config);
        assert!(bank.is_empty());
        assert!((bank.config.similarity_weight - 0.5).abs() < f64::EPSILON);
        assert!((bank.config.type_overlap_weight - 0.3).abs() < f64::EPSILON);
        assert!((bank.config.density_weight - 0.2).abs() < f64::EPSILON);
        assert!((bank.config.min_score - 0.2).abs() < f64::EPSILON);
    }
}