use bincode::Error as BincodeError;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::convert::From;
use strip_markdown::strip_markdown;
use xorf::{HashProxy, Xor8};
use crate::{PostId, SearchIndex, Storage};
/// Read-only view of a document that can be added to a search index.
///
/// `TinySearch::build_index` accepts any slice of `Post` implementors and
/// reads the title, URL, optional body and metadata through these accessors.
pub trait Post {
/// Returns the post title.
fn title(&self) -> &str;
/// Returns the post URL.
fn url(&self) -> &str;
/// Returns the post body, or `None` when the post has no body text.
fn body(&self) -> Option<&str>;
/// Returns the post metadata as an owned map of string keys to string values.
fn meta(&self) -> HashMap<String, String>;
}
/// Plain, owned post record that can be serialized and deserialized with serde.
///
/// This is the element type produced by `TinySearch::parse_posts_from_json`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct BasicPost {
/// Title of the post.
pub title: String,
/// URL of the post.
pub url: String,
/// Body text of the post; `None` when the post has no body.
pub body: Option<String>,
/// Free-form string metadata; `#[serde(default)]` makes it an empty map
/// when the field is absent from the deserialized input.
#[serde(default)]
pub meta: HashMap<String, String>,
}
/// [`Post`] implementation that reads straight from the fields of [`BasicPost`].
impl Post for BasicPost {
    /// Borrows the stored title.
    fn title(&self) -> &str {
        self.title.as_str()
    }

    /// Borrows the stored URL.
    fn url(&self) -> &str {
        self.url.as_str()
    }

    /// Borrows the body text, or yields `None` when no body is stored.
    fn body(&self) -> Option<&str> {
        self.body.as_deref()
    }

    /// Hands out an independent copy of the metadata map.
    fn meta(&self) -> HashMap<String, String> {
        HashMap::clone(&self.meta)
    }
}
/// Entry point for building, serializing, loading and querying a search index.
#[derive(Debug, Clone)]
pub struct TinySearch {
/// Caller-supplied stopword set. When `None`, the stopword list bundled
/// from `assets/stopwords` is used instead (see `get_stopwords`).
custom_stopwords: Option<HashSet<String>>,
}
impl TinySearch {
    /// Creates an instance with no custom stopwords configured.
    pub const fn new() -> Self {
        Self {
            custom_stopwords: None,
        }
    }

    /// Returns `self` with its stopword set replaced by the given words.
    ///
    /// Any previously configured custom set is discarded; duplicates in
    /// `stopwords` collapse because the words are stored in a `HashSet`.
    #[must_use]
    pub fn with_stopwords<I>(mut self, stopwords: I) -> Self
    where
        I: IntoIterator<Item = String>,
    {
        let words: HashSet<String> = stopwords.into_iter().collect();
        self.custom_stopwords = Some(words);
        self
    }

    /// Deserializes `json_str` into a list of [`BasicPost`]s.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error when `json_str` cannot be
    /// deserialized as `Vec<BasicPost>`.
    pub fn parse_posts_from_json(
        &self,
        json_str: &str,
    ) -> Result<Vec<BasicPost>, serde_json::Error> {
        serde_json::from_str::<Vec<BasicPost>>(json_str)
    }

    /// Builds a search index from `posts`.
    ///
    /// Posts are first converted to `PostId`/body pairs, then one filter per
    /// post is generated using the active stopword set.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`; the `Result` is part
    /// of the signature only.
    pub fn build_index<P: Post>(
        &self,
        posts: &[P],
    ) -> Result<SearchIndex, Box<dyn std::error::Error>> {
        let prepared = Self::prepare_posts(posts);
        let index = Self::generate_filters(prepared, &self.get_stopwords());
        Ok(index)
    }

    /// Runs `query` against `index` by delegating to `crate::search`.
    ///
    /// `num_results` is passed through unchanged; the returned references
    /// borrow from `index`.
    pub fn search<'index>(
        &self,
        index: &'index SearchIndex,
        query: &str,
        num_results: usize,
    ) -> Vec<&'index PostId> {
        crate::search(index, query, num_results)
    }

    /// Builds an index from `posts` and serializes it to bytes via `Storage`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::build_index`] or from
    /// `Storage::to_bytes`, boxed as `dyn Error`.
    pub fn build_and_serialize_index<P: Post>(
        &self,
        posts: &[P],
    ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
        let index = self.build_index(posts)?;
        let bytes = Storage::from(index).to_bytes()?;
        Ok(bytes)
    }

    /// Decodes a `Storage` from `bytes` and returns the filters it holds.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `Storage::from_bytes`.
    pub fn load_index_from_bytes(&self, bytes: &[u8]) -> Result<SearchIndex, BincodeError> {
        Ok(Storage::from_bytes(bytes)?.filters)
    }
}
impl Default for TinySearch {
fn default() -> Self {
Self::new()
}
}
impl TinySearch {
    /// Returns the stopword set to use for tokenization.
    ///
    /// A custom set (if configured) is cloned and returned; otherwise the
    /// list embedded at compile time from `assets/stopwords` is split on
    /// whitespace into a fresh set.
    fn get_stopwords(&self) -> HashSet<String> {
        self.custom_stopwords.clone().unwrap_or_else(|| {
            include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/stopwords"))
                .split_whitespace()
                .map(String::from)
                .collect()
        })
    }

    /// Replaces every character that is neither alphabetic nor an apostrophe
    /// with a single space, so the result can be split on whitespace.
    fn cleanup(s: &str) -> String {
        s.replace(|c: char| !(c.is_alphabetic() || c == '\''), " ")
    }

    /// Turns `words` into a set of lowercased tokens with stopwords removed.
    ///
    /// The text is stripped of markdown, cleaned with [`Self::cleanup`], split
    /// on whitespace and lowercased. Stopwords are compared against the
    /// lowercased token, so entries in `stopwords` only match if they are
    /// lowercase themselves.
    fn tokenize_with_stopwords(words: &str, stopwords: &HashSet<String>) -> HashSet<String> {
        // `split_whitespace` never yields empty items, so no extra
        // emptiness filter is needed before lowercasing.
        Self::cleanup(&strip_markdown(words))
            .split_whitespace()
            .map(str::to_lowercase)
            .filter(|word| !stopwords.contains(word))
            .collect()
    }

    /// Builds one filter per post from the tokens of its title, metadata
    /// string and (when present) body.
    ///
    /// Each post is processed in a single pass; no intermediate map of
    /// tokenized bodies is materialized.
    fn generate_filters(
        posts: HashMap<PostId, Option<String>>,
        stopwords: &HashSet<String>,
    ) -> SearchIndex {
        posts
            .into_iter()
            .map(|(post_id, body)| {
                let mut content: HashSet<String> =
                    Self::tokenize_with_stopwords(&post_id.title, stopwords);
                // Skip the tokenizer entirely for posts without metadata.
                if !post_id.meta.is_empty() {
                    content.extend(Self::tokenize_with_stopwords(&post_id.meta, stopwords));
                }
                if let Some(body) = body {
                    content.extend(Self::tokenize_with_stopwords(&body, stopwords));
                }
                // The tokens are gathered in a set first, so the key list
                // handed to the filter contains each word at most once.
                let content_vec: Vec<String> = content.into_iter().collect();
                let filter =
                    HashProxy::<String, std::collections::hash_map::DefaultHasher, Xor8>::from(
                        &content_vec,
                    );
                (post_id, filter)
            })
            .collect()
    }

    /// Converts posts into `PostId` keys paired with an owned copy of the body.
    ///
    /// The metadata map is serialized to a JSON string for `PostId::meta`
    /// (empty string when there is no metadata). Keys are emitted in sorted
    /// order so the same input always produces the same string.
    fn prepare_posts<P: Post>(posts: &[P]) -> HashMap<PostId, Option<String>> {
        posts
            .iter()
            .map(|post| {
                // `meta()` returns an owned map, so fetch it only once per post.
                let meta = post.meta();
                let meta_str = if meta.is_empty() {
                    String::new()
                } else {
                    // A `BTreeMap` iterates in key order, which makes the
                    // serialized JSON independent of `HashMap` iteration order.
                    let sorted: std::collections::BTreeMap<String, String> =
                        meta.into_iter().collect();
                    // Best effort: fall back to an empty string if
                    // serialization fails, as this function cannot report errors.
                    serde_json::to_string(&sorted).unwrap_or_default()
                };
                let post_id = PostId {
                    title: post.title().to_string(),
                    url: post.url().to_string(),
                    meta: meta_str,
                };
                let body = post.body().map(str::to_owned);
                (post_id, body)
            })
            .collect()
    }
}