use std::{
collections::HashSet,
fs::File,
io::Read,
path::{Path, PathBuf},
sync::{Arc, RwLock},
};
use base64::prelude::*;
use futures::StreamExt;
use log::info;
use mime_guess::from_path;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use regex::Regex;
use reqwest::{header, Client};
use wax::{Glob, WalkEntry, WalkError};
/// Configuration for recursively posting a directory of files to a Solr
/// collection's extract endpoint.
pub struct PostConfig {
// Maximum number of file POSTs in flight at once (drives `buffer_unordered`).
pub concurrency: usize,
// Solr hostname; only used to build the endpoint when `update_url` is `None`.
pub host: String,
// Solr port; only used to build the endpoint when `update_url` is `None`.
pub port: u16,
// Target collection name; only used when `update_url` is `None`.
pub collection: String,
// Root directory that is walked recursively for candidate files.
pub directory_path: PathBuf,
// File extensions (without dots) included in the glob, e.g. "pdf", "html".
pub file_extensions: Vec<String>,
// Full update/extract endpoint URL; when set it overrides host/port/collection.
pub update_url: Option<String>,
// Files whose contents match this regex are skipped.
// NOTE(review): field name has a typo ("exclued" -> "exclude"); renaming
// would break the public API, so it is kept as-is.
pub exclued_regex: Option<Regex>,
// When set, only files whose contents match this regex are indexed.
pub include_regex: Option<Regex>,
// Credentials sent base64-encoded as an HTTP Basic Authorization header
// on every request; presumably in "user:password" form — TODO confirm.
pub basic_auth_creds: Option<String>,
}
impl Default for PostConfig {
/// Sensible defaults: a local Solr at `localhost:8983`, collection
/// `collection1`, the current directory, and a broad set of document
/// extensions commonly posted to Solr's extract endpoint.
fn default() -> Self {
// Document/file types indexed by default.
const DEFAULT_EXTENSIONS: [&str; 22] = [
"xml", "json", "jsonl", "csv", "pdf", "doc", "docx", "ppt", "pptx",
"xls", "xlsx", "odt", "odp", "ods", "ott", "otp", "ots", "rtf",
"htm", "html", "txt", "log",
];
PostConfig {
concurrency: 8,
host: "localhost".to_owned(),
port: 8983,
collection: "collection1".to_owned(),
directory_path: PathBuf::from("./"),
file_extensions: DEFAULT_EXTENSIONS.iter().map(|ext| ext.to_string()).collect(),
update_url: None,
exclued_regex: None,
include_regex: None,
basic_auth_creds: None,
}
}
}
/// Recursively walks `config.directory_path` for files matching
/// `config.file_extensions`, optionally filters them by the include/exclude
/// content regexes, and concurrently POSTs each file to a Solr
/// `update/extract` endpoint, finishing with a commit.
///
/// Callbacks (all optional):
/// * `on_start` — called once with the total number of files selected.
/// * `on_next` — called after each file completes, with the running count.
/// * `on_finish` — called once after the final commit.
///
/// Returns the total number of files selected for indexing (not the number
/// successfully indexed).
pub async fn solr_post(
    config: PostConfig,
    mut on_start: Option<Box<dyn FnMut(u64)>>,
    mut on_next: Option<Box<dyn FnMut(u64)>>,
    mut on_finish: Option<Box<dyn FnMut()>>,
) -> usize {
    // Build a glob like "**/*.{xml,json,pdf,...}" covering all extensions.
    let file_extensions_joined = config.file_extensions.join(",");
    let glob_expression = format!("**/*.{{{}}}", file_extensions_joined);
    let glob = Glob::new(glob_expression.as_str()).unwrap();
    let files: Vec<Result<WalkEntry, WalkError>> = glob.walk(config.directory_path).collect();

    // Attach a Basic auth header to every request when credentials are given.
    let mut default_headers = header::HeaderMap::new();
    if let Some(creds) = &config.basic_auth_creds {
        let auth_value = BASE64_STANDARD.encode(creds);
        default_headers.insert(
            header::AUTHORIZATION,
            header::HeaderValue::from_str(&format!("Basic {}", auth_value)).unwrap(),
        );
    }
    let client = Client::builder()
        .default_headers(default_headers)
        .build()
        .unwrap();

    let solr_collection_update_endpoint = match &config.update_url {
        Some(url) => url.clone(),
        None => format!(
            "http://{0}:{1}/solr/{2}/update/extract",
            config.host, config.port, config.collection
        ),
    };

    // Filter candidate files in parallel, collecting directly into a set
    // instead of sharing an Arc<RwLock<HashSet>> across workers.
    //
    // Read raw bytes rather than read_to_string: many default extensions
    // (pdf, docx, xlsx, ...) are binary and not valid UTF-8, so
    // read_to_string would error and the unwrap would panic. The regexes are
    // applied to a lossy UTF-8 view of the contents.
    let files_to_index_set: HashSet<String> = files
        .par_iter()
        .filter_map(|file| match file {
            Ok(entry) => {
                let path_str = entry.path().to_str().unwrap();
                let bytes = std::fs::read(path_str).unwrap();
                let contents = String::from_utf8_lossy(&bytes);
                if let Some(exclude_regex) = config.exclued_regex.as_ref() {
                    if exclude_regex.is_match(&contents) {
                        return None;
                    }
                }
                if let Some(include_regex) = config.include_regex.as_ref() {
                    if !include_regex.is_match(&contents) {
                        return None;
                    }
                }
                Some(path_str.to_string())
            }
            Err(e) => {
                eprintln!("error: {:?}", e);
                None
            }
        })
        .collect();

    let total_files_to_index = files_to_index_set.len();
    let mut posts = futures::stream::iter(files_to_index_set.into_iter().map(|file| async {
        let file_path = Path::new(&file);
        let file_path_absolute = file_path.canonicalize().unwrap();
        let file_path_encoded = urlencoding::encode(file_path_absolute.to_str().unwrap());
        // Send the raw bytes; Solr's extract handler parses binary formats
        // itself, and read_to_string would panic on non-UTF-8 files.
        let contents = std::fs::read(&file).unwrap();
        let solr_post_url = format!(
            "{0}?resource.name={1}&literal.id={1}",
            solr_collection_update_endpoint, file_path_encoded
        );
        let mime_type = from_path(&file_path_absolute).first_or_octet_stream();
        (
            client
                .post(solr_post_url)
                .header(header::CONTENT_TYPE, mime_type.to_string())
                .body(contents)
                .send()
                .await,
            file_path_absolute,
        )
    }))
    .buffer_unordered(config.concurrency);

    info!("indexing {} files", total_files_to_index);
    let mut indexed_count = 0;
    if let Some(ref mut on_start) = on_start {
        on_start(total_files_to_index as u64);
    }

    while let Some((res, file_path)) = posts.next().await {
        match res {
            Ok(response) => {
                if response.status().is_success() {
                    info!("indexed: {}", file_path.to_str().unwrap());
                } else {
                    eprintln!(
                        "POST {} {}\nIs collection correct?\nfailed to index file: {}",
                        response.url(),
                        response.status(),
                        file_path.to_str().unwrap(),
                    );
                }
                // Counts every completed request, successful or not.
                indexed_count += 1;
                if let Some(ref mut on_next) = on_next {
                    on_next(indexed_count as u64);
                }
            }
            Err(e) => {
                eprintln!("{}\nIs Solr server running and collection available?", e)
            }
        }
    }

    // Commit against the same endpoint we posted to. (Previously this was a
    // hard-coded http://localhost:8983/solr/portal URL, which silently
    // committed the wrong server/collection for any non-default config.)
    // "<base>/update/extract" becomes "<base>/update?commit=true".
    let commit_url = format!(
        "{}?commit=true",
        solr_collection_update_endpoint
            .strip_suffix("/extract")
            .unwrap_or(&solr_collection_update_endpoint)
    );
    let response = client.get(&commit_url).send().await;
    match response {
        Ok(response) => {
            if response.status().is_success() {
                info!("commit successful");
            } else {
                info!("commit failed");
            }
        }
        Err(e) => {
            eprintln!("{}\nIs Solr server running and collection available?", e);
        }
    }

    info!("indexing complete");
    if let Some(ref mut on_finish) = on_finish {
        on_finish();
    }
    total_files_to_index
}