use itertools::Itertools;
use keepcalm::SharedMut;

use tantivy::collector::TopDocs;
use tantivy::query::{
    AllQuery, BooleanQuery, BoostQuery, Occur, PhraseQuery, Query, QueryParser, TermQuery,
};
use tantivy::tokenizer::TokenizerManager;
use tantivy::{schema::*, DocAddress, IndexWriter, Searcher, SegmentReader};

use progscrape_scrapers::{ScrapeCollection, StoryDate, StoryUrl, TypedScrape};

use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::panic::catch_unwind;
use std::time::Duration;

use crate::persist::index::indexshard::{StoryIndexShard, StoryLookup, StoryLookupId};
use crate::persist::scrapestore::ScrapeStore;
use crate::persist::shard::{ShardOrder, ShardRange};
use crate::persist::{ScrapePersistResult, Shard, ShardSummary, StorageFetch, StoryQuery};
use crate::story::{StoryCollector, TagSet};
use crate::{
    timer_end, timer_start, MemIndex, PersistError, PersistLocation, Storage, StorageSummary,
    StorageWriter, Story, StoryEvaluator, StoryIdentifier,
};

use super::indexshard::StoryInsert;
use super::schema::StorySchema;

const STORY_INDEXING_CHUNK_SIZE: usize = 10000;
const SCRAPE_PROCESSING_CHUNK_SIZE: usize = 1000;

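/// Caches the set of open index shards, tracking the on-disk shard range and
/// memoizing the most recent story date (invalidated on every commit).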
struct IndexCache {
    cache: HashMap<Shard, SharedMut<StoryIndexShard>>,
    location: PersistLocation,
    range: ShardRange,
    schema: StorySchema,
    most_recent_story: Option<StoryDate>,
}

impl IndexCache {
    fn get_shard(&mut self, shard: Shard) -> Result<SharedMut<StoryIndexShard>, PersistError> {
        if let Some(shard) = self.cache.get(&shard) {
            Ok(shard.clone())
        } else {
            tracing::info!("Creating shard {}", shard.to_string());
            let new_shard =
                StoryIndexShard::initialize(self.location.clone(), shard, self.schema.clone())?;
            self.range.include(shard);
            Ok(self
                .cache
                .entry(shard)
                .or_insert(SharedMut::new(new_shard))
                .clone())
        }
    }
}

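/// The top-level, tantivy-backed story index: a cache of monthly index shards
/// plus the underlying scrape store.
///
/// A minimal usage sketch, mirroring the tests at the bottom of this file
/// (`scrapes` here is a placeholder for any `IntoIterator` of `TypedScrape`):
///
/// ```ignore
/// let mut index = StoryIndex::new(PersistLocation::Memory)?;
/// let eval = StoryEvaluator::new_for_test();
/// index.insert_scrapes(&eval, scrapes)?;
/// let front_page = index.fetch::<Shard>(StoryQuery::FrontPage, 30)?;
/// ```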
pub struct StoryIndex {
    index_cache: SharedMut<IndexCache>,
    scrape_db: ScrapeStore,
    schema: StorySchema,
}

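/// Hands out at most one lazily-created `IndexWriter` per shard, so a batch
/// operation can write to several shards and have them committed or rolled
/// back together.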
struct WriterProvider {
    writers: HashMap<Shard, IndexWriter>,
    index: SharedMut<IndexCache>,
}

impl WriterProvider {
    fn provide<T>(
        &mut self,
        shard: Shard,
        f: impl FnOnce(Shard, &StoryIndexShard, &'_ mut IndexWriter) -> Result<T, PersistError>,
    ) -> Result<T, PersistError> {
        let shard_index = self.index.write().get_shard(shard)?;
        let shard_index = shard_index.write();
        let writer = if let Some(writer) = self.writers.get_mut(&shard) {
            writer
        } else {
            let writer = shard_index.writer()?;
            self.writers.entry(shard).or_insert(writer)
        };

        f(shard, &shard_index, writer)
    }
}

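/// Which scoring to apply in `fetch_search_query`: `Default` mixes the stored
/// story score and an age penalty into the relevance score, while `Related`
/// uses the raw relevance score alone.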
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScoreAlgo {
    Default,
    Related,
}

impl StoryIndex {
    pub fn new(location: PersistLocation) -> Result<Self, PersistError> {
        let scrape_db = ScrapeStore::new(location.clone())?;
        tracing::info!("Initialized StoryIndex at {:?}", location);

        // Determine which shards already exist on disk.
        let mut range = ShardRange::default();
        if let PersistLocation::Path(path) = &location {
            for d in std::fs::read_dir(path)?.flatten() {
                if let Some(s) = d.file_name().to_str() {
                    if let Some(shard) = Shard::from_string(s) {
                        range.include(shard);
                    }
                }
            }
        }

        tracing::info!("Found shards {:?}", range);
        let schema = StorySchema::instantiate_global_schema();
        let new = Self {
            index_cache: SharedMut::new(IndexCache {
                cache: HashMap::new(),
                location,
                range,
                schema: schema.clone(),
                most_recent_story: None,
            }),
            scrape_db,
            schema,
        };

        Ok(new)
    }

    pub fn shards(&self) -> ShardRange {
        self.index_cache.read().range
    }

    fn get_shard(&self, shard: Shard) -> Result<SharedMut<StoryIndexShard>, PersistError> {
        let mut lock = self.index_cache.write();
        lock.get_shard(shard)
    }

    #[inline(always)]
    pub fn with_scrapes<F: FnOnce(&ScrapeStore) -> T, T>(&self, f: F) -> T {
        f(&self.scrape_db)
    }

    #[inline(always)]
    fn with_searcher<F: FnMut(Shard, &Searcher, &StorySchema) -> Result<T, PersistError>, T>(
        &self,
        shard: Shard,
        mut f: F,
    ) -> Result<T, PersistError> {
        let shard_index = self.get_shard(shard)?;
        let shard_index = shard_index.read();
        shard_index.with_searcher(|searcher, schema| f(shard, searcher, schema))
    }

    #[inline(always)]
    fn with_index<F: FnMut(Shard, &StoryIndexShard) -> Result<T, PersistError>, T>(
        &self,
        shard: Shard,
        mut f: F,
    ) -> Result<T, PersistError> {
        let shard_index = self.get_shard(shard)?;
        let shard_index = shard_index.read();
        f(shard, &shard_index)
    }

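    /// Runs a closure against a `WriterProvider`, committing every writer it
    /// created if the closure succeeds and rolling all of them back if it
    /// fails.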
    fn with_writers<
        TOuter,
        WriterOuterClosure: FnOnce(&mut WriterProvider) -> Result<TOuter, PersistError>,
    >(
        &self,
        f: WriterOuterClosure,
    ) -> Result<TOuter, PersistError> {
        let mut provider = WriterProvider {
            writers: Default::default(),
            index: self.index_cache.clone(),
        };
        let res = f(&mut provider);
        let WriterProvider { writers, .. } = provider;

        let writer_count = writers.len();
        if res.is_ok() {
            tracing::info!("Committing {} writer(s)", writer_count);
            let commit_start = timer_start!();
            for (shard, writer) in writers.into_iter().sorted_by_key(|(shard, _)| *shard) {
                tracing::info!("Committing shard {:?}...", shard);
                let shard = self.get_shard(shard)?;
                let mut shard = shard.write();
                shard.commit_writer(writer)?;
            }
            timer_end!(commit_start, "Committed {} writer(s).", writer_count);
            self.index_cache.write().most_recent_story = None;
        } else {
            for mut writer in writers.into_values() {
                if let Err(e) = writer.rollback() {
                    tracing::error!("Ignoring nested error in writer rollback: {:?}", e);
                }
            }
        }
        res
    }

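    /// Scrape IDs are stored as `{shard}:{scrape id}` so the owning shard can
    /// be recovered from the ID alone.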
    fn create_scrape_id_from_scrape(scrape: &TypedScrape) -> String {
        format!(
            "{}:{}",
            Shard::from_date_time(scrape.date).to_string(),
            scrape.id
        )
    }

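    /// Builds the `StoryInsert` document for a collection of scrapes: extracts
    /// the canonical title and URL, scores the story, and computes its tag set.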
    fn create_story_insert(eval: &StoryEvaluator, story: &ScrapeCollection) -> StoryInsert {
        let extracted = story.extract(&eval.extractor);
        let score = eval.scorer.score(&extracted);
        let scrape_ids = extracted
            .scrapes
            .values()
            .map(|x| x.1)
            .map(Self::create_scrape_id_from_scrape)
            .collect_vec();
        let title = extracted.title().to_owned();
        let mut tags = TagSet::new();
        eval.tagger.tag(&title, &mut tags);
        for tag in extracted.tags() {
            // Note the shadowing: if the tagger recognizes this as a search
            // tag, the canonical form it returns is added; otherwise the raw
            // tag is added as-is.
            if let Some(tag) = eval.tagger.check_tag_search(&tag) {
                tags.add(tag);
            } else {
                tags.add(tag);
            }
        }
        let url = extracted.url();
        let id = StoryIdentifier::new(story.earliest, extracted.url().normalization()).to_base64();
        StoryInsert {
            id,
            host: url.host().to_owned(),
            url: url.raw().to_owned(),
            url_norm: url.normalization().string().to_owned(),
            url_norm_hash: url.normalization().hash(),
            score: score as f64,
            date: story.earliest.timestamp(),
            title,
            scrape_ids,
            tags,
        }
    }

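    /// For each story, locates the document it should merge into, if any: the
    /// story's own monthly shard is probed first, then the previous month,
    /// matching on the normalized-URL hash within a ±30-day window. Returns
    /// the target shard and the existing doc address when one was found.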
    fn find_insert_position<'a, I: IntoIterator<Item = ScrapeCollection> + 'a>(
        &self,
        scrapes: I,
    ) -> Result<Vec<(ScrapeCollection, Shard, Option<DocAddress>)>, PersistError> {
        let one_month = Duration::from_secs(60 * 60 * 24 * 30).as_secs() as i64;
        let mut res = vec![];

        for story in scrapes {
            let current_shard = Shard::from_date_time(story.earliest);
            let mut shard = current_shard;
            let mut i = 0;
            let (shard, doc_address) = loop {
                let doc_address = self.with_index(shard, |_, index| {
                    let lookup = StoryLookupId {
                        url_norm_hash: story.url().normalization().hash(),
                        date: story.earliest.timestamp(),
                    };
                    let lookup = HashSet::from_iter([lookup]);
                    let result = index.lookup_stories(lookup, (-one_month)..one_month)?;
                    Ok(match result.into_iter().next() {
                        Some(StoryLookup::Found(_, doc)) => Some(doc),
                        _ => None,
                    })
                })?;

                break if doc_address.is_some() {
                    (shard, doc_address)
                } else if i == 0 {
                    // Not in this shard: look back one month before giving up.
                    i -= 1;
                    shard = shard.sub_months(1);
                    continue;
                } else {
                    // Not found anywhere: insert fresh into the story's own shard.
                    (current_shard, None)
                };
            };
            res.push((story, shard, doc_address));
        }
        Ok(res)
    }

    fn insert_scrape_batch<'a, I: IntoIterator<Item = TypedScrape> + 'a>(
        &mut self,
        eval: &StoryEvaluator,
        scrapes: I,
    ) -> Result<Vec<ScrapePersistResult>, PersistError> {
        let mut memindex = MemIndex::default();
        memindex.insert_scrapes(scrapes)?;
        let positions = self.find_insert_position(memindex.get_all_stories())?;

        self.with_writers(|provider| {
            let mut res = vec![];
            for (story, shard, doc_address) in positions {
                res.push(provider.provide(shard, |_, index, writer| {
                    if let Some(doc) = doc_address {
                        // Merge into the existing story document and reinsert it.
                        let doc = index.with_searcher(|searcher, _| Ok(searcher.doc(doc)?))?;
                        let ids = index.extract_scrape_ids_from_doc(&doc);
                        let scrapes = self.scrape_db.fetch_scrape_batch(ids)?;
                        let mut orig_story =
                            ScrapeCollection::new_from_iter(scrapes.into_values().flatten());
                        orig_story.merge_all(story);
                        let doc = Self::create_story_insert(eval, &orig_story);
                        index.reinsert_story_document(writer, doc)
                    } else {
                        let doc = Self::create_story_insert(eval, &story);
                        index.insert_story_document(writer, doc)
                    }
                })?);
            }
            Ok(res)
        })
    }

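    /// Stores the raw scrapes in the scrape database, then indexes them,
    /// merging into an existing story where `find_insert_position` finds one.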
    fn insert_scrapes<I: IntoIterator<Item = TypedScrape>>(
        &mut self,
        eval: &StoryEvaluator,
        scrapes: I,
    ) -> Result<Vec<ScrapePersistResult>, PersistError> {
        let v = scrapes.into_iter().collect_vec();

        tracing::info!("Storing raw scrapes...");
        self.scrape_db.insert_scrape_batch(v.iter())?;

        tracing::info!("Indexing scrapes...");
        self.insert_scrape_batch(eval, v)
    }

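    /// Bulk-inserts pre-grouped scrape collections as new stories, chunked to
    /// bound memory use; every story is reported as `NewStory`.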
    fn insert_scrape_collections<I: IntoIterator<Item = ScrapeCollection>>(
        &mut self,
        eval: &StoryEvaluator,
        scrape_collections: I,
    ) -> Result<Vec<ScrapePersistResult>, PersistError> {
        self.with_writers(|provider| {
            let mut res = vec![];
            let start = timer_start!();
            let mut total = 0;
            for chunk in &scrape_collections
                .into_iter()
                .chunks(STORY_INDEXING_CHUNK_SIZE)
            {
                tracing::info!("Indexing chunk...");
                let start_chunk = timer_start!();
                let mut count = 0;
                let mut scrapes_batch = vec![];

                for story in chunk {
                    count += 1;
                    res.push(ScrapePersistResult::NewStory);
                    let doc = Self::create_story_insert(eval, &story);
                    let scrapes = story.scrapes.into_values();
                    scrapes_batch.extend(scrapes);
                    provider.provide(
                        Shard::from_date_time(story.earliest),
                        move |_, index, writer| {
                            index.insert_story_document(writer, doc)?;
                            Ok(())
                        },
                    )?;

                    if scrapes_batch.len() > SCRAPE_PROCESSING_CHUNK_SIZE {
                        self.scrape_db.insert_scrape_batch(scrapes_batch.iter())?;
                        scrapes_batch.clear();
                    }
                }
                self.scrape_db.insert_scrape_batch(scrapes_batch.iter())?;
                scrapes_batch.clear();
                total += count;
                timer_end!(start_chunk, "Indexed chunk of {} stories", count);
            }
            timer_end!(start, "Indexed total of {} stories", total);
            Ok(res)
        })
    }

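    /// Re-fetches each story's scrapes by ID and rebuilds its document in
    /// place, reporting `MergedWithExistingStory` on success and `NotFound`
    /// when the story does not exist.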
    fn reinsert_stories<I: IntoIterator<Item = StoryIdentifier>>(
        &mut self,
        eval: &StoryEvaluator,
        stories: I,
    ) -> Result<Vec<ScrapePersistResult>, PersistError> {
        self.with_writers(|provider| {
            let mut res = vec![];
            for id in stories {
                let searcher = self.fetch_by_id(&id);
                let docs = self.with_searcher(id.shard(), searcher)?;
                if let Some((shard, doc)) = docs.first() {
                    provider.provide(*shard, |_, index, writer| {
                        let doc = index.with_searcher(|searcher, _| Ok(searcher.doc(*doc)?))?;
                        let ids = index.extract_scrape_ids_from_doc(&doc);
                        let scrapes = self.scrape_db.fetch_scrape_batch(ids)?;
                        let orig_story =
                            ScrapeCollection::new_from_iter(scrapes.into_values().flatten());
                        let doc = Self::create_story_insert(eval, &orig_story);
                        index.reinsert_story_document(writer, doc)?;
                        Ok(())
                    })?;
                    res.push(ScrapePersistResult::MergedWithExistingStory);
                } else {
                    res.push(ScrapePersistResult::NotFound)
                }
            }
            Ok(res)
        })
    }

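    /// Returns a searcher closure that enumerates every live document in a
    /// shard, segment by segment.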
    fn fetch_by_segment(
        &self,
    ) -> impl FnMut(Shard, &Searcher, &StorySchema) -> Result<Vec<(Shard, DocAddress)>, PersistError>
    {
        move |shard, searcher, _schema| {
            let mut v = vec![];
            let now = timer_start!();
            for (idx, segment_reader) in searcher.segment_readers().iter().enumerate() {
                for doc_id in segment_reader.doc_ids_alive() {
                    let doc_address = DocAddress::new(idx as u32, doc_id);
                    v.push((shard, doc_address));
                }
            }
            timer_end!(now, "Loaded {} stories from shard {:?}", v.len(), shard);
            Ok(v)
        }
    }

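    /// Returns a searcher closure that finds at most one document by its
    /// base64 story identifier.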
    fn fetch_by_id(
        &self,
        id: &StoryIdentifier,
    ) -> impl FnMut(Shard, &Searcher, &StorySchema) -> Result<Vec<(Shard, DocAddress)>, PersistError>
    {
        let id = id.to_base64();
        move |shard, searcher, schema| {
            let query = TermQuery::new(
                Term::from_field_text(schema.id_field, &id),
                IndexRecordOption::Basic,
            );
            let docs = searcher.search(&query, &TopDocs::with_limit(1))?;
            if let Some((_, doc_address)) = docs.into_iter().next() {
                return Ok(vec![(shard, doc_address)]);
            }
            Ok(vec![])
        }
    }

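    /// Runs `query` against each shard, newest first, collecting up to `max`
    /// documents. With `ScoreAlgo::Default` each hit's relevance is adjusted
    /// by the story's stored score plus a small penalty for its age;
    /// `ScoreAlgo::Related` leaves the relevance score untouched.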
    fn fetch_search_query<Q: Query>(
        &self,
        query: Q,
        max: usize,
        score_algo: ScoreAlgo,
    ) -> Result<Vec<(Shard, DocAddress)>, PersistError> {
        let mut vec = vec![];
        let mut remaining = max;
        let now = self.most_recent_story()?.timestamp();
        for shard in self.shards().iterate(ShardOrder::NewestFirst) {
            if remaining == 0 {
                break;
            }
            let docs = self.with_searcher(shard, |shard, searcher, schema| {
                let schema = schema.clone();
                let docs =
                    TopDocs::with_limit(remaining).tweak_score(move |reader: &SegmentReader| {
                        let score_field = reader
                            .fast_fields()
                            .f64(schema.score_field)
                            .expect("Failed to get fast fields");
                        let date_field = reader
                            .fast_fields()
                            .i64(schema.date_field)
                            .expect("Failed to get fast fields");
                        move |doc, score| {
                            if score_algo == ScoreAlgo::Related {
                                score
                            } else {
                                let doc_score = score_field.get_val(doc);
                                let doc_date = date_field.get_val(doc);
                                let age = now - doc_date;
                                score + doc_score as f32 + (age as f32) * -0.00001
                            }
                        }
                    });
                let docs = searcher.search(&query, &docs)?;
                Ok(docs.into_iter().map(move |x| (shard, x.1)))
            })?;
            vec.extend(docs);
            remaining = max.saturating_sub(vec.len());
        }
        Ok(vec)
    }

    fn fetch_tag_search(
        &self,
        tag: &str,
        alt: Option<&str>,
        max: usize,
    ) -> Result<Vec<(Shard, DocAddress)>, PersistError> {
        let mut query_parser = QueryParser::new(
            self.schema.schema.clone(),
            vec![self.schema.title_field, self.schema.tags_field],
            TokenizerManager::default(),
        );
        query_parser.set_field_boost(self.schema.tags_field, 10.0);
        let query = if let Some(alt) = alt {
            query_parser.parse_query(&format!("{tag} OR {alt}"))?
        } else {
            query_parser.parse_query(tag)?
        };
        tracing::debug!("Tag symbol query = {:?}", query);
        self.fetch_search_query(query, max, ScoreAlgo::Default)
    }

    fn fetch_domain_search(
        &self,
        domain: &str,
        max: usize,
    ) -> Result<Vec<(Shard, DocAddress)>, PersistError> {
        let host_field = self.schema.host_field;
        let phrase = domain
            .split('.')
            .filter(|s| !s.is_empty())
            .map(|s| Term::from_field_text(host_field, s))
            .collect_vec();

        if phrase.is_empty() {
            return Err(PersistError::UnexpectedError("Empty domain".to_string()));
        }

        // A single-segment domain is a plain term query; multiple segments
        // must match as a phrase.
        if phrase.len() == 1 {
            let query =
                TermQuery::new(phrase.into_iter().next().unwrap(), IndexRecordOption::Basic);
            tracing::debug!("Domain term query = {:?}", query);
            self.fetch_search_query(query, max, ScoreAlgo::Default)
        } else {
            let query = PhraseQuery::new(phrase);
            tracing::debug!("Domain phrase query = {:?}", query);
            self.fetch_search_query(query, max, ScoreAlgo::Default)
        }
    }

    fn fetch_url_search(
        &self,
        url: &StoryUrl,
        max: usize,
    ) -> Result<Vec<(Shard, DocAddress)>, PersistError> {
        let hash = url.normalization().hash();
        let hash_field = self.schema.url_norm_hash_field;
        let query = TermQuery::new(
            Term::from_field_i64(hash_field, hash),
            IndexRecordOption::Basic,
        );

        tracing::debug!("URL hash query = {:?} (for {url})", query);
        self.fetch_search_query(query, max, ScoreAlgo::Default)
    }

    fn fetch_text_search(
        &self,
        search: &str,
        max: usize,
    ) -> Result<Vec<(Shard, DocAddress)>, PersistError> {
        let mut query_parser = QueryParser::new(
            self.schema.schema.clone(),
            vec![self.schema.title_field, self.schema.tags_field],
            TokenizerManager::default(),
        );
        query_parser.set_field_boost(self.schema.tags_field, 3.0);

        // Split the colon off "http:"/"https:" so the query parser doesn't
        // read them as tantivy's `field:term` syntax.
        let search = if search.contains("http:") {
            search.replace("http:", "http ").into()
        } else {
            Cow::Borrowed(search)
        };
        let search = if search.contains("https:") {
            search.replace("https:", "https ").into()
        } else {
            search
        };

        let query = query_parser.parse_query(&search)?;
        tracing::debug!("Term query = {:?}", query);
        self.fetch_search_query(query, max, ScoreAlgo::Related)
    }

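    /// Builds a "related stories" query: the (lightly de-stopworded) title
    /// terms, plus strongly boosted host phrases for dotted tags and
    /// lightly-weighted term queries for plain tags.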
    fn fetch_related(
        &self,
        title: &str,
        tags: &[String],
        max: usize,
    ) -> Result<Vec<(Shard, DocAddress)>, PersistError> {
        let mut query_parser = QueryParser::new(
            self.schema.schema.clone(),
            vec![self.schema.title_field, self.schema.tags_field],
            TokenizerManager::default(),
        );
        query_parser.set_field_boost(self.schema.title_field, 2.0);

        // Crudely strip leading and embedded articles from the title before
        // tokenizing it into the query.
        let title = title.to_lowercase().replace(" the ", " ");
        let title = title.replace(" a ", " ");
        let title = title.trim_start_matches("the ");
        let title = title.trim_start_matches("a ");
        let title_query = query_parser
            .parse_query(&title.replace(|c: char| c != ' ' && !c.is_alphanumeric(), " "))?;

        let mut subqueries = vec![(Occur::Should, title_query)];
        for tag in tags {
            let query: Box<dyn Query> = if tag.contains('.') {
                // Dotted tags are treated as hosts and matched as a phrase.
                let phrase = tag
                    .split('.')
                    .filter(|s| !s.is_empty())
                    .map(|s| Term::from_field_text(self.schema.host_field, s))
                    .collect_vec();
                Box::new(BoostQuery::new(Box::new(PhraseQuery::new(phrase)), 10.0))
            } else {
                Box::new(BoostQuery::new(
                    Box::new(TermQuery::new(
                        Term::from_field_text(self.schema.tags_field, tag),
                        IndexRecordOption::Basic,
                    )),
                    0.2,
                ))
            };
            subqueries.push((Occur::Should, query));
        }

        let query = BooleanQuery::new(subqueries);
        tracing::debug!("Related query = {:?}", query);
        self.fetch_search_query(query, max, ScoreAlgo::Related)
    }

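    /// Assembles the front page from the three newest shards: documents are
    /// read in date order and offered to a `StoryCollector` that keeps the
    /// best-scored `max_count` of up to `2 * max_count` processed candidates.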
    fn fetch_front_page(&self, max_count: usize) -> Result<Vec<(Shard, DocAddress)>, PersistError> {
        let mut story_collector: StoryCollector<(Shard, DocAddress)> =
            StoryCollector::new(max_count);
        let mut processed = 0;
        let processing_target = max_count * 2;

        // Only the three newest shards contribute to the front page.
        for shard in self.shards().iterate(ShardOrder::NewestFirst).take(3) {
            if processed >= processing_target {
                break;
            }

            self.with_searcher(shard, |shard, searcher, _schema| {
                let top = TopDocs::with_limit(processing_target - processed)
                    .order_by_fast_field::<i64>(self.schema.date_field);
                let docs = searcher.search(&AllQuery {}, &top)?;
                tracing::info!("Got {} doc(s) from shard {:?}", docs.len(), shard);

                for (_, doc_address) in docs {
                    processed += 1;
                    let score = searcher
                        .segment_reader(doc_address.segment_ord)
                        .fast_fields()
                        .f64(self.schema.score_field)?
                        .get_val(doc_address.doc_id) as f32;
                    if story_collector.would_accept(score) {
                        story_collector.accept(score, (shard, doc_address));
                    }
                }

                Ok(())
            })?;
        }
        tracing::info!(
            "Got {}/{} docs for front page (processed {})",
            story_collector.len(),
            max_count,
            processed
        );
        Ok(story_collector.to_sorted())
    }

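    /// Dispatches a `StoryQuery` to the matching fetch routine. The search is
    /// wrapped in `catch_unwind` so that a panic inside the index surfaces as
    /// a `PersistError` instead of unwinding into the caller.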
    fn fetch_doc_addresses(
        &self,
        query: StoryQuery,
        max: usize,
    ) -> Result<Vec<(Shard, DocAddress)>, PersistError> {
        catch_unwind(|| match query {
            StoryQuery::ById(id) => self.with_searcher(id.shard(), self.fetch_by_id(&id)),
            StoryQuery::ByShard(shard) => self.with_searcher(shard, self.fetch_by_segment()),
            StoryQuery::FrontPage => self.fetch_front_page(max),
            StoryQuery::TagSearch(tag, alt) => self.fetch_tag_search(&tag, alt.as_deref(), max),
            StoryQuery::DomainSearch(domain) => self.fetch_domain_search(&domain, max),
            StoryQuery::UrlSearch(url) => self.fetch_url_search(&url, max),
            StoryQuery::TextSearch(text) => self.fetch_text_search(&text, max),
            StoryQuery::Related(title, tags) => self.fetch_related(&title, tags.as_slice(), max),
        })
        .map_err(|e| match e.downcast::<&'static str>() {
            Ok(v) => PersistError::UnexpectedError(format!("Storage fetch panic: {v}")),
            Err(_) => PersistError::UnexpectedError("Storage fetch panic".to_owned()),
        })?
    }
}

impl StorageWriter for StoryIndex {
    fn insert_scrapes<I: IntoIterator<Item = TypedScrape>>(
        &mut self,
        eval: &StoryEvaluator,
        scrapes: I,
    ) -> Result<Vec<ScrapePersistResult>, PersistError> {
        self.insert_scrapes(eval, scrapes)
    }

    fn insert_scrape_collections<I: IntoIterator<Item = ScrapeCollection>>(
        &mut self,
        eval: &StoryEvaluator,
        scrape_collections: I,
    ) -> Result<Vec<ScrapePersistResult>, PersistError> {
        self.insert_scrape_collections(eval, scrape_collections)
    }

    fn reinsert_stories<I: IntoIterator<Item = StoryIdentifier>>(
        &mut self,
        eval: &StoryEvaluator,
        stories: I,
    ) -> Result<Vec<ScrapePersistResult>, PersistError> {
        self.reinsert_stories(eval, stories)
    }
}

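/// Fetching `Story<Shard>` yields stories whose scrapes are referenced by ID
/// only; the `StorageFetch<TypedScrape>` implementation below additionally
/// hydrates the scrapes themselves from the scrape store.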
impl StorageFetch<Shard> for StoryIndex {
    fn fetch_type(&self, query: StoryQuery, max: usize) -> Result<Vec<Story<Shard>>, PersistError> {
        let mut v = vec![];
        for (shard, doc) in self.fetch_doc_addresses(query, max)? {
            let doc = self.with_index(shard, |_, index| {
                let story = index.lookup_story(doc)?;
                let url = StoryUrl::parse(story.url).expect("Failed to parse URL");
                let date = StoryDate::from_seconds(story.date).expect("Failed to re-parse date");
                let score = story.score as f32;
                Ok(Story::new_from_parts(
                    story.title,
                    url,
                    date,
                    score,
                    story.tags,
                    story.scrape_ids,
                ))
            })?;

            v.push(doc);
        }
        Ok(v)
    }
}

impl StorageFetch<TypedScrape> for StoryIndex {
    fn fetch_type(
        &self,
        query: StoryQuery,
        max: usize,
    ) -> Result<Vec<Story<TypedScrape>>, PersistError> {
        let mut v = vec![];
        for (shard, doc) in self.fetch_doc_addresses(query, max)? {
            let doc = self.with_index(shard, |_, index| {
                let story = index.lookup_story(doc)?;
                let url = StoryUrl::parse(story.url).expect("Failed to parse URL");
                let date = StoryDate::from_seconds(story.date).expect("Failed to re-parse date");
                let score = story.score as f32;

                let scrapes = self
                    .scrape_db
                    .fetch_scrape_batch(story.scrape_ids.clone())?;
                let story = Story::new_from_parts(
                    story.title,
                    url,
                    date,
                    score,
                    story.tags,
                    scrapes.into_values().flatten(),
                );

                Ok(story)
            })?;

            v.push(doc);
        }
        Ok(v)
    }
}

impl Storage for StoryIndex {
    fn most_recent_story(&self) -> Result<StoryDate, PersistError> {
        if let Some(most_recent_story) = self.index_cache.read().most_recent_story {
            return Ok(most_recent_story);
        }

        if let Some(max) = self.shards().iterate(ShardOrder::NewestFirst).next() {
            let shard = self.get_shard(max)?;
            let index = shard.read();
            let result = index.most_recent_story()?;
            self.index_cache.write().most_recent_story = Some(result);
            Ok(result)
        } else {
            Ok(StoryDate::MIN)
        }
    }

    fn shard_range(&self) -> Result<ShardRange, PersistError> {
        Ok(self.shards())
    }

    fn story_count(&self) -> Result<StorageSummary, PersistError> {
        let mut summary = StorageSummary::default();
        for shard in self.shards().iterate(ShardOrder::OldestFirst) {
            let index = self.get_shard(shard)?;
            let subtotal = index.read().total_docs()?;
            let scrape_subtotal = self.scrape_db.stats(shard)?.count;
            summary.by_shard.push((
                shard.to_string(),
                ShardSummary {
                    story_count: subtotal,
                    scrape_count: scrape_subtotal,
                },
            ));
            summary.total.story_count += subtotal;
            summary.total.scrape_count += scrape_subtotal;
        }
        Ok(summary)
    }

    fn fetch_count(&self, query: StoryQuery, max: usize) -> Result<usize, PersistError> {
        Ok(self.fetch_doc_addresses(query, max)?.len())
    }

    fn fetch_detail_one(
        &self,
        query: StoryQuery,
    ) -> Result<Option<HashMap<String, Vec<String>>>, PersistError> {
        if let Some((shard, doc)) = self.fetch_doc_addresses(query, 1)?.first() {
            let res = self.with_index(*shard, |_, index| {
                let named_doc = index.doc_fields(*doc)?;
                let mut map = HashMap::new();
                for (key, value) in named_doc.0 {
                    map.insert(
                        key,
                        value
                            .into_iter()
                            .map(|v| serde_json::to_string(&v).unwrap_or_else(|e| e.to_string()))
                            .collect_vec(),
                    );
                }
                Ok(map)
            })?;
            Ok(Some(res))
        } else {
            Ok(None)
        }
    }
}

#[cfg(test)]
mod test {

    use super::*;
    use progscrape_scrapers::{
        hacker_news::*, lobsters::LobstersStory, reddit::*, ScrapeConfig, ScrapeSource, StoryUrl,
    };

    use crate::{story::TagSet, test::*, MemIndex};
    use rstest::*;

    fn populate_shard(
        ids: impl Iterator<Item = (i64, i64)>,
    ) -> Result<StoryIndexShard, PersistError> {
        let mut shard = StoryIndexShard::initialize(
            PersistLocation::Memory,
            Shard::default(),
            StorySchema::instantiate_global_schema(),
        )?;
        shard.with_writer(move |shard, writer, _| {
            for (url_norm_hash, date) in ids {
                shard.insert_story_document(
                    writer,
                    StoryInsert {
                        url_norm_hash,
                        date,
                        ..Default::default()
                    },
                )?;
            }
            Ok(())
        })?;
        Ok(shard)
    }

    fn hn_story(id: &str, date: StoryDate, title: &str, url: &StoryUrl) -> TypedScrape {
        HackerNewsStory::new_with_defaults(id, date, title, url.clone()).into()
    }

    fn reddit_story(
        id: &str,
        subreddit: &str,
        date: StoryDate,
        title: &str,
        url: &StoryUrl,
    ) -> TypedScrape {
        RedditStory::new_subsource_with_defaults(id, subreddit, date, title, url.clone()).into()
    }

    fn lobsters_story(
        id: &str,
        date: StoryDate,
        title: &str,
        url: &StoryUrl,
        tags: Vec<String>,
    ) -> TypedScrape {
        let mut lobsters = LobstersStory::new_with_defaults(id, date, title, url.clone());
        lobsters.data.tags = tags;
        lobsters.into()
    }

    fn rust_story_hn() -> TypedScrape {
        let url = StoryUrl::parse("http://example.com").expect("URL");
        let date = StoryDate::year_month_day(2020, 1, 1).expect("Date failed");
        hn_story("story1", date, "I love Rust", &url)
    }

    fn rust_story_hn_prev_year() -> TypedScrape {
        let url = StoryUrl::parse("http://example.com").expect("URL");
        let date = StoryDate::year_month_day(2019, 12, 31).expect("Date failed");
        hn_story("story1", date, "I love Rust", &url)
    }

    fn rust_story_reddit() -> TypedScrape {
        let url = StoryUrl::parse("http://example.com").expect("URL");
        let date = StoryDate::year_month_day(2020, 1, 1).expect("Date failed");
        reddit_story("story1", "rust", date, "I love rust", &url)
    }

    fn rust_story_lobsters() -> TypedScrape {
        let url = StoryUrl::parse("http://example.com").expect("URL");
        let date = StoryDate::year_month_day(2020, 1, 1).expect("Date failed");
        lobsters_story(
            "story1",
            date,
            "Type inference in Rust",
            &url,
            vec!["plt".to_string(), "c++".to_string()],
        )
    }

    #[rstest]
    fn test_index_shard(_enable_tracing: &bool) {
        let ids1 = (0..100).map(|x| (x, 0));
        let ids2 = (100..200).map(|x| (x, 10));
        let shard = populate_shard(ids1.chain(ids2)).expect("Failed to initialize shard");
        let count_found = |vec: Vec<StoryLookup>| {
            vec.iter()
                .filter(|x| matches!(x, StoryLookup::Found(..)))
                .count()
        };
        macro_rules! test_range {
            ($date:expr, $slop:expr, $expected:expr) => {
                let lookup = (95..110)
                    .into_iter()
                    .map(|n| StoryLookupId {
                        url_norm_hash: n,
                        date: $date,
                    })
                    .collect();
                let result = shard
                    .lookup_stories(lookup, $slop)
                    .expect("Failed to look up");
                assert_eq!($expected, count_found(result));
            };
        }
        // Exact date of zero: matches the five hashes 95..100 from the first batch.
        test_range!(0, 0..=0, 5);
        // Exact date of ten: matches the ten hashes 100..110 from the second batch.
        test_range!(10, 0..=0, 10);
        // Slop of 0..=10 at date zero: matches both batches, 15 in total.
        test_range!(0, 0..=10, 15);
    }

    #[rstest]
    fn test_index_scrapes(_enable_tracing: &bool) -> Result<(), Box<dyn std::error::Error>> {
        use ScrapeSource::*;

        let mut index = StoryIndex::new(PersistLocation::Memory)?;
        let eval = StoryEvaluator::new_for_test();
        index.insert_scrapes(&eval, [rust_story_hn()])?;

        let counts = index.story_count()?;
        assert_eq!(counts.total.story_count, 1);

        index.insert_scrapes(&eval, [rust_story_reddit()])?;

        let counts = index.story_count()?;
        assert_eq!(counts.total.story_count, 1);

        let search = index.fetch::<Shard>(StoryQuery::from_search(&eval.tagger, "rust"), 10)?;
        assert_eq!(search.len(), 1);

        let story = &search[0];
        assert_eq!("I love Rust", story.title);
        assert!(itertools::equal(
            [
                &HackerNews.id("story1"),
                &Reddit.subsource_id("rust", "story1")
            ],
            story.scrapes.keys().sorted()
        ));
        assert_eq!(TagSet::from_iter(["rust"]), story.tags);

        Ok(())
    }

    #[rstest]
    fn test_index_scrapes_across_shard(
        _enable_tracing: &bool,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use ScrapeSource::*;

        let mut index = StoryIndex::new(PersistLocation::Memory)?;
        let eval = StoryEvaluator::new_for_test();
        index.insert_scrapes(&eval, [rust_story_hn_prev_year()])?;

        let counts = index.story_count()?;
        assert_eq!(counts.total.story_count, 1);

        index.insert_scrapes(&eval, [rust_story_reddit()])?;

        let counts = index.story_count()?;
        assert_eq!(counts.total.story_count, 1);

        let search = index.fetch::<Shard>(StoryQuery::from_search(&eval.tagger, "rust"), 10)?;
        assert_eq!(search.len(), 1);

        let story = &search[0];
        assert_eq!("I love Rust", story.title);
        assert!(itertools::equal(
            [
                &HackerNews.id("story1"),
                &Reddit.subsource_id("rust", "story1")
            ],
            story.scrapes.keys().sorted()
        ));
        assert_eq!(TagSet::from_iter(["rust"]), story.tags);

        Ok(())
    }

    #[rstest]
    fn test_index_scrape_collections(
        _enable_tracing: &bool,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use ScrapeSource::*;

        let mut memindex = MemIndex::default();
        let eval = StoryEvaluator::new_for_test();
        memindex.insert_scrapes([rust_story_hn(), rust_story_reddit()])?;

        let mut index = StoryIndex::new(PersistLocation::Memory)?;
        index.insert_scrape_collections(&eval, memindex.get_all_stories())?;

        let counts = index.story_count()?;
        assert_eq!(counts.total.story_count, 1);

        let search = index.fetch::<Shard>(StoryQuery::from_search(&eval.tagger, "rust"), 10)?;
        assert_eq!(search.len(), 1);

        let story = &search[0];
        assert_eq!("I love Rust", story.title);
        assert!(itertools::equal(
            [
                &HackerNews.id("story1"),
                &Reddit.subsource_id("rust", "story1")
            ],
            story.scrapes.keys().sorted()
        ));
        assert_eq!(TagSet::from_iter(["rust"]), story.tags);

        Ok(())
    }

    #[test]
    fn test_reindex_story() -> Result<(), Box<dyn std::error::Error>> {
        let mut memindex = MemIndex::default();
        let eval = StoryEvaluator::new_for_test();
        memindex.insert_scrapes([rust_story_hn(), rust_story_reddit()])?;
        let mut index = StoryIndex::new(PersistLocation::Memory)?;
        index.insert_scrape_collections(&eval, memindex.get_all_stories())?;

        let story = index
            .fetch_one::<Shard>(StoryQuery::from_search(&eval.tagger, "rust"))?
            .expect("Missing story");

        assert_eq!(
            index.reinsert_stories(&eval, [story.id])?,
            vec![ScrapePersistResult::MergedWithExistingStory]
        );
        let story = index
            .fetch_one::<Shard>(StoryQuery::from_search(&eval.tagger, "rust"))?
            .expect("Missing story");
        assert_eq!(story.title, "I love Rust");

        let counts = index.story_count()?;
        assert_eq!(counts.total.story_count, 1);

        Ok(())
    }

    #[rstest]
    fn test_insert_batch(_enable_tracing: &bool) -> Result<(), Box<dyn std::error::Error>> {
        let mut batch = vec![];
        let date = StoryDate::year_month_day(2020, 1, 1).expect("Date failed");

        for i in 0..30 {
            let url = StoryUrl::parse(format!("http://domain-{}.com/", i)).expect("URL");
            batch.push(hn_story(
                &format!("story-{}", i),
                date,
                &format!("Title {}", i),
                &url,
            ));
        }

        let mut index = StoryIndex::new(PersistLocation::Memory)?;
        let eval = StoryEvaluator::new_for_test();

        index.insert_scrapes(&eval, batch.clone())?;

        // Inserting a scrape with the same URL should merge into the existing
        // story rather than create a duplicate.
        let url = StoryUrl::parse("http://domain-3.com/").expect("URL");

        index.insert_scrapes(
            &eval,
            [reddit_story("story-3", "subreddit", date, "Title 3", &url)],
        )?;

        index.insert_scrapes(&eval, batch.clone())?;

        let front_page = index.fetch_count(StoryQuery::FrontPage, 100)?;
        assert_eq!(30, front_page);

        Ok(())
    }

    #[test]
    fn test_findable_by_extracted_tag() -> Result<(), Box<dyn std::error::Error>> {
        let mut index = StoryIndex::new(PersistLocation::Memory)?;
        let eval = StoryEvaluator::new_for_test();
        let scrape = rust_story_lobsters();
        let id = StoryIdentifier::new(scrape.date, scrape.url.normalization());
        index.insert_scrapes(&eval, [scrape.clone()])?;

        let counts = index.story_count()?;
        assert_eq!(counts.total.story_count, 1);

        let story = index
            .fetch_one::<Shard>(StoryQuery::ById(id.clone()))?
            .expect("Expected one story");
        assert_eq!(
            story.tags.into_iter().sorted().collect_vec(),
            vec!["cplusplus", "plt", "rust"]
        );

        for term in ["plt", "c++", "cplusplus", "type", "inference"] {
            let search = index.fetch_count(StoryQuery::from_search(&eval.tagger, term), 10)?;
            let doc = index.fetch_detail_one(StoryQuery::ById(id.clone()))?;
            assert_eq!(
                1, search,
                "Expected one search result when querying '{}' for title={} url={} doc={:?}",
                term, scrape.raw_title, scrape.url, doc
            );
        }
        Ok(())
    }

    #[test]
    fn test_torture() -> Result<(), Box<dyn std::error::Error>> {
        let mut index = StoryIndex::new(PersistLocation::Memory)?;
        let eval = StoryEvaluator::new_for_test();
        let url = StoryUrl::parse("http://example.com").expect("URL");
        let date = StoryDate::year_month_day(2020, 1, 1).expect("Date failed");
        index.insert_scrapes(&eval, [hn_story("story1", date, "title", &url)])?;
        // Throw a pile of URL-shaped garbage at the search and check that none
        // of it panics or matches the indexed story.
        for a in ["http://", "http:", " ", ""] {
            for b in ["x", "2", ".", " ", ""] {
                for c in [".1", ".x", ".", "x", "2", " ", ""] {
                    for d in ["/", "/x", "/x.", "?", " ", ""] {
                        let s = format!("{a}{b}{c}{d}");
                        // Skip queries with no alphanumeric content at all.
                        if s.trim().is_empty() || !s.contains(|c: char| c.is_alphanumeric()) {
                            continue;
                        }
                        eprintln!("{s}");
                        let res =
                            index.fetch_count(StoryQuery::from_search(&eval.tagger, &s), 10)?;
                        assert_eq!(res, 0, "Expected zero results for '{s}'");
                    }
                }
            }
        }
        Ok(())
    }

    #[rstest]
    #[case("http://example.com", "I love Rust", &["rust", "love", "example.com"])]
    #[case("http://medium.com", "The Pitfalls of C++", &["c++", "cplusplus", "pitfalls", "Pitfalls"])]
    #[case("http://www.att.com", "New AT&T plans", &["at&t", "atandt", "att.com", "http://att.com"])]
    #[case("http://example.com", "I love Go", &["golang", "love"])]
    #[case("http://example.com", "I love C", &["clanguage", "love"])]
    #[case("http://localhost", "About that special host", &["http://localhost", "localhost."])]
    #[case("http://www3.xyz.imperial.co.uk", "Why England is England", &["england", "www3.xyz.imperial.co.uk", "xyz.imperial.co.uk", "co.uk"])]
    fn test_findable(
        #[case] url: &str,
        #[case] title: &str,
        #[case] search_terms: &[&str],
    ) -> Result<(), Box<dyn std::error::Error>> {
        let mut index = StoryIndex::new(PersistLocation::Memory)?;
        let eval = StoryEvaluator::new_for_test();
        let url = StoryUrl::parse(url).expect("URL");
        let date = StoryDate::year_month_day(2020, 1, 1).expect("Date failed");
        let id = StoryIdentifier::new(date, url.normalization());
        index.insert_scrapes(&eval, [hn_story("story1", date, title, &url)])?;

        let counts = index.story_count()?;
        assert_eq!(counts.total.story_count, 1);

        let doc = index
            .fetch_detail_one(StoryQuery::ById(id))?
            .expect("Didn't find doc");

        for term in search_terms {
            let search = index.fetch_count(StoryQuery::from_search(&eval.tagger, term), 10)?;
            assert_eq!(
                1, search,
                "Expected one search result when querying '{}' for title={} url={} doc={:?}",
                term, title, url, doc
            );
        }

        Ok(())
    }

    #[rstest]
    fn test_index_lots(_enable_tracing: &bool) -> Result<(), Box<dyn std::error::Error>> {
        let path = "/tmp/indextest";
        std::fs::create_dir_all(path)?;
        let mut index = StoryIndex::new(PersistLocation::Path(path.into()))?;

        let scrapes = progscrape_scrapers::load_sample_scrapes(&ScrapeConfig::default());
        let eval = StoryEvaluator::new_for_test();
        let mut memindex = MemIndex::default();

        memindex.insert_scrapes(scrapes)?;

        index.insert_scrape_collections(&eval, memindex.get_all_stories())?;

        for story in index.fetch::<Shard>(StoryQuery::from_search(&eval.tagger, "rust"), 10)? {
            println!("{:?}", story);
        }

        Ok(())
    }
}