use std::net::{IpAddr, Ipv6Addr};
use chrono::{DateTime, NaiveDate, NaiveTime, TimeZone, Utc};
use futures_util::TryStreamExt;
use serde_json::Value;
use sqlx::SqlitePool;
use tantivy::query::QueryParserError;
use tantivy::schema::Field;
use tantivy::{Document, TantivyError};
use super::api_counterfeit::search::{
self, PaginatedHits, Pagination, Query, TantivyQueryBuilder, TopDocsPaginatedSearcher,
};
use super::errors::ApiError;
use super::types::RecordDocument;
/// `chrono` format string for date-only values (`YYYY-MM-DD`); used when
/// parsing the string payload of `Date` fields.
const DATE_FORMAT: &str = "%Y-%m-%d";
/// Key looked up in a record's document when the indexed field is named `createdAt`.
const DOCUMENT_FIELD_CREATED_AT: &str = "__created_at";
/// Key looked up in a record's document when the indexed field is named `updatedAt`.
const DOCUMENT_FIELD_UPDATED_AT: &str = "__updated_at";
impl From<TantivyError> for ApiError {
fn from(error: TantivyError) -> Self {
error!("TantivyError: {error}");
Self::ServerError
}
}
impl From<QueryParserError> for ApiError {
fn from(error: QueryParserError) -> Self {
error!("Tantivy QueryParserError: {error}");
Self::ServerError
}
}
/// A Tantivy index over the records of one entity type, together with the
/// search schema it was built from.
pub struct Index<'a> {
/// Underlying Tantivy index that documents are written to and searched in.
inner: tantivy::Index,
/// Search schema borrowed from the `search::Config` handed to `build`.
schema: &'a search::Schema,
/// Tantivy field that stores each record's id as raw bytes.
id_field: Field,
}
impl<'a> Index<'a> {
    /// Runs `query` against the index and returns one page of matching record ids.
    ///
    /// `pagination` picks the search mode: forward from the beginning, forward
    /// after a cursor, or backward before a cursor. The ids come back from the
    /// searcher as raw bytes and are decoded to `String` before returning.
    ///
    /// # Errors
    ///
    /// Propagates (as [`ApiError`]) any failure from building the Tantivy
    /// query, opening the reader, converting a cursor, or running the search.
    ///
    /// # Panics
    ///
    /// Panics if an id returned by the searcher is not valid UTF-8. `build`
    /// writes ids from `record.id.as_bytes()`, so this is presumably
    /// unreachable for indices produced by `build`.
    #[allow(clippy::needless_pass_by_value, clippy::cast_possible_truncation)]
    pub fn search(&self, query: Query, pagination: Pagination) -> Result<PaginatedHits<String>, ApiError> {
        trace!("Executing query: {query:?}");
        let tantivy_query = TantivyQueryBuilder::new(&self.inner, self.schema).build(query)?;

        // NOTE(review): a fresh reader is created on every call. Tantivy's
        // documentation suggests creating one reader and reusing it; consider
        // storing it on `Index` if search latency matters.
        let paginated = TopDocsPaginatedSearcher {
            searcher: self.inner.reader()?.searcher(),
            query: tantivy_query,
            id_field: self.id_field,
            // NOTE(review): presumably an upper bound on the page size —
            // confirm against `TopDocsPaginatedSearcher`.
            pagination_limit: 1000,
        };

        let page: PaginatedHits<Vec<u8>> = match pagination {
            Pagination::Forward { first, after } => match after {
                None => paginated.search_forward(first as usize)?,
                Some(cursor) => paginated.search_forward_after(first as usize, &cursor.try_into()?)?,
            },
            Pagination::Backward { last, before } => {
                paginated.search_backward_before(last as usize, &before.try_into()?)?
            }
        };

        Ok(page.map_id(|raw_id| String::from_utf8(raw_id).unwrap()))
    }

    /// Builds the search index for `entity_type` from the rows of the `records` table.
    ///
    /// The schema is looked up in `config.indices`; every row selected by the
    /// query becomes one Tantivy document holding the record id plus one
    /// entry per indexed field present in the record's document.
    ///
    /// # Errors
    ///
    /// Returns [`ApiError::ServerError`] when `entity_type` has no configured
    /// index or when a field value cannot be converted (both are logged).
    /// Failures from opening the index, running the SQL query, or writing and
    /// committing documents are propagated through `?`.
    ///
    /// # Panics
    ///
    /// Panics if the opened index's schema has no `search::ID_FIELD` field.
    pub async fn build(
        pool: &SqlitePool,
        entity_type: &str,
        config: &'a search::Config,
    ) -> Result<Index<'a>, ApiError> {
        let schema = match config.indices.get(entity_type) {
            Some(index_config) => &index_config.schema,
            None => {
                error!("Unknown index: {entity_type}");
                return Err(ApiError::ServerError);
            }
        };
        trace!("Building index for {entity_type} and schema:\n{schema:?}");

        let (index, fields) = search::open_index(schema)?;
        let id_field = index.schema().get_field(search::ID_FIELD).unwrap();
        // Single writer thread with a 20 MB memory budget.
        let mut writer = index.writer_with_num_threads(1, 20_000_000)?;

        // Rows are streamed rather than loaded all at once.
        let mut rows = sqlx::query_as::<_, RecordDocument>(
            r#"
SELECT pk AS id, document
FROM records WHERE entity_type = $1 AND pk = sk
"#,
        )
        .bind(entity_type)
        .fetch(pool);

        let mut record_count: usize = 0;
        while let Some(record) = rows.try_next().await? {
            record_count += 1;

            let mut doc = Document::default();
            doc.add_bytes(id_field, record.id.as_bytes());
            for field in &fields {
                if let Err(err) = add_field(&mut doc, field, &record) {
                    error!("{:?} for record '{}' on field '{}'", err, &record.id, &field.name);
                    return Err(ApiError::ServerError);
                }
            }
            writer.add_document(doc)?;
        }

        writer.commit()?;
        trace!("Indexed {record_count} documents.");

        Ok(Self {
            inner: index,
            schema,
            id_field,
        })
    }
}
/// Copies the value of one indexed field from a record's document into `doc`.
///
/// The field's `name` selects the key in the record's document (`createdAt`
/// and `updatedAt` are redirected to their `__`-prefixed keys). A record
/// without that key is left untouched and is not an error. List attributes
/// are flattened, so every leaf value is added to the same Tantivy field.
///
/// # Errors
///
/// Returns a [`FieldError`] when a leaf value lacks the type tag expected for
/// the field type or cannot be parsed as that type.
///
/// # Panics
///
/// Panics if a `URL`, `Email` or `String` field has no `tokenized_doc_key`.
fn add_field(
    doc: &mut Document,
    search::IndexedField {
        name,
        doc_key,
        tokenized_doc_key,
        ty,
    }: &search::IndexedField,
    RecordDocument { document, .. }: &RecordDocument,
) -> Result<(), FieldError> {
    // These variant names shadow `String` and chrono's `DateTime` inside this function.
    use search::FieldType::{
        Boolean, Date, DateTime, Email, Float, IPAddress, Int, PhoneNumber, String, Timestamp, URL,
    };

    let source_key = match name.as_str() {
        "createdAt" => DOCUMENT_FIELD_CREATED_AT,
        "updatedAt" => DOCUMENT_FIELD_UPDATED_AT,
        other => other,
    };
    let value = match document.get(source_key) {
        Some(value) => value,
        // Nothing stored under this key: the field is simply not indexed for this record.
        None => return Ok(()),
    };

    let target = *doc_key;
    let items = DynamoItemExt::flatten(value);
    // All three date-like field types are stored as Tantivy dates built from epoch milliseconds.
    let to_tantivy_date = |millis: i64| tantivy::DateTime::from_timestamp_millis(millis);

    match ty {
        URL { .. } | Email { .. } | String { .. } => {
            // Text-like fields are written twice: to the plain field and to the tokenized one.
            let tokenized_target = tokenized_doc_key.unwrap();
            for item in items {
                let text = DynamoItemExt::to_str(item)?;
                doc.add_text(target, text);
                doc.add_text(tokenized_target, text);
            }
        }
        PhoneNumber { .. } => {
            for item in items {
                let text = DynamoItemExt::to_str(item)?;
                doc.add_text(target, text);
            }
        }
        Int { .. } => {
            for item in items {
                let number = DynamoItemExt::to_i64(item)?;
                doc.add_i64(target, number);
            }
        }
        Float { .. } => {
            for item in items {
                let number = DynamoItemExt::to_f64(item)?;
                doc.add_f64(target, number);
            }
        }
        DateTime { .. } => {
            for item in items {
                let instant = DynamoItemExt::to_datetime(item)?;
                doc.add_date(target, to_tantivy_date(instant.timestamp_millis()));
            }
        }
        Date { .. } => {
            // A bare date is indexed as midnight UTC of that day.
            let midnight = NaiveTime::from_hms_opt(0, 0, 0).expect("Valid time");
            for item in items {
                let day = DynamoItemExt::to_date(item)?;
                let instant = Utc.from_utc_datetime(&day.and_time(midnight));
                doc.add_date(target, to_tantivy_date(instant.timestamp_millis()));
            }
        }
        Timestamp { .. } => {
            for item in items {
                let instant = DynamoItemExt::to_timestamp(item)?;
                doc.add_date(target, to_tantivy_date(instant.timestamp_millis()));
            }
        }
        Boolean { .. } => {
            for item in items {
                let flag = DynamoItemExt::to_bool(item)?;
                doc.add_bool(target, flag);
            }
        }
        IPAddress { .. } => {
            for item in items {
                let address = DynamoItemExt::to_ipaddr(item)?;
                doc.add_ip_addr(target, address);
            }
        }
    }

    Ok(())
}
/// Reason a document attribute could not be converted into an index value.
///
/// Implements [`std::error::Error`] and [`std::fmt::Display`] so it can be
/// used with generic error handling, and is `Copy`/`Eq` so callers and tests
/// can compare variants directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FieldError {
    /// The attribute does not contain the type tag that was asked for.
    MissingValue,
    /// The tagged value exists but has the wrong JSON type or failed to parse.
    InvalidValue,
}

impl std::fmt::Display for FieldError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::MissingValue => "missing value",
            Self::InvalidValue => "invalid value",
        })
    }
}

impl std::error::Error for FieldError {}
/// Result of converting a single document attribute; fails with a [`FieldError`].
type FieldResult<T> = Result<T, FieldError>;
/// Namespace for helpers that decode JSON attribute values wrapped in a type
/// tag (`{"S": ..}`, `{"N": ..}`, `{"BOOL": ..}`, `{"L": [..]}`).
struct DynamoItemExt;
impl DynamoItemExt {
fn flatten(value: &Value) -> Vec<&Value> {
value
.get("L")
.and_then(serde_json::Value::as_array)
.map(|array| array.iter().flat_map(DynamoItemExt::flatten).collect::<Vec<_>>())
.unwrap_or(vec![value])
}
fn to_str(value: &Value) -> FieldResult<&str> {
value
.get("S")
.ok_or(FieldError::MissingValue)?
.as_str()
.ok_or(FieldError::InvalidValue)
}
fn to_i64(value: &Value) -> FieldResult<i64> {
value
.get("N")
.ok_or(FieldError::MissingValue)?
.as_str()
.ok_or(FieldError::InvalidValue)?
.parse::<i64>()
.map_err(|_| FieldError::InvalidValue)
}
fn to_u64(value: &Value) -> FieldResult<u64> {
value
.get("N")
.ok_or(FieldError::MissingValue)?
.as_str()
.ok_or(FieldError::InvalidValue)?
.parse::<u64>()
.map_err(|_| FieldError::InvalidValue)
}
fn to_f64(value: &Value) -> FieldResult<f64> {
value
.get("N")
.ok_or(FieldError::MissingValue)?
.as_str()
.ok_or(FieldError::InvalidValue)?
.parse::<f64>()
.map_err(|_| FieldError::InvalidValue)
}
fn to_bool(value: &Value) -> FieldResult<bool> {
value
.get("BOOL")
.ok_or(FieldError::MissingValue)?
.as_bool()
.ok_or(FieldError::InvalidValue)
}
fn to_datetime(value: &Value) -> FieldResult<DateTime<Utc>> {
DynamoItemExt::to_str(value)?
.parse::<DateTime<Utc>>()
.map_err(|_| FieldError::InvalidValue)
}
fn to_date(value: &Value) -> FieldResult<NaiveDate> {
NaiveDate::parse_from_str(DynamoItemExt::to_str(value)?, DATE_FORMAT).map_err(|_| FieldError::InvalidValue)
}
fn to_timestamp(value: &Value) -> FieldResult<DateTime<Utc>> {
i64::try_from(DynamoItemExt::to_u64(value)? * 1_000_000)
.map_err(|_| FieldError::InvalidValue)
.map(|nanos| Utc.timestamp_nanos(nanos))
}
fn to_ipaddr(value: &Value) -> FieldResult<Ipv6Addr> {
DynamoItemExt::to_str(value)?
.parse::<IpAddr>()
.map_err(|_| FieldError::InvalidValue)
.map(|ip| match ip {
IpAddr::V4(addr) => addr.to_ipv6_mapped(),
IpAddr::V6(addr) => addr,
})
}
}
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use super::*;

    /// Fetches a top-level attribute of the fixture, panicking when it is absent.
    fn attr<'v>(record: &'v Value, key: &str) -> &'v Value {
        record.get(key).unwrap()
    }

    /// Exercises every `DynamoItemExt` converter on one fixture record, then
    /// checks that nested list attributes are flattened in order.
    #[test]
    fn test_dynamo_item_ext() {
        let record: Value = serde_json::from_str(
            r#"
{
"date": {"S": "2007-12-03"},
"float": {"N": "98.293"},
"text": {"S": "Dogs are the best!"},
"int": {"N": "8179"},
"timestamp": {"N": "1451653820000"},
"ip": {"S": "127.0.0.1"},
"datetime": {"S": "2016-01-01T13:10:20+00:00"},
"bool": {"BOOL": true}
}
"#,
        )
        .unwrap();

        let text = DynamoItemExt::to_str(attr(&record, "text")).unwrap();
        assert_eq!(text, "Dogs are the best!");

        let int = DynamoItemExt::to_i64(attr(&record, "int")).unwrap();
        assert_eq!(int, 8179);

        let float = DynamoItemExt::to_f64(attr(&record, "float")).unwrap();
        assert!((float - 98.293).abs() < f64::EPSILON);

        let flag = DynamoItemExt::to_bool(attr(&record, "bool")).unwrap();
        assert!(flag);

        // The fixture timestamp is in milliseconds; the expected value is the same instant in nanoseconds.
        let timestamp = DynamoItemExt::to_timestamp(attr(&record, "timestamp")).unwrap();
        assert_eq!(timestamp, Utc.timestamp_nanos(1_451_653_820_000_000_000i64));

        let date = DynamoItemExt::to_date(attr(&record, "date")).unwrap();
        assert_eq!(date, NaiveDate::from_ymd_opt(2007, 12, 3).unwrap());

        let datetime = DynamoItemExt::to_datetime(attr(&record, "datetime")).unwrap();
        assert_eq!(datetime, "2016-01-01T13:10:20.000Z".parse::<DateTime<Utc>>().unwrap());

        // IPv4 input is expected to come back as its IPv4-mapped IPv6 form.
        let ip = DynamoItemExt::to_ipaddr(attr(&record, "ip")).unwrap();
        assert_eq!(ip.to_string(), Ipv4Addr::new(127, 0, 0, 1).to_ipv6_mapped().to_string());

        let nested: Value = serde_json::from_str(
            r#"
{"L": [{"L": [{"S": "first"}]}, {"L": [{"S": "second"}]}]}
"#,
        )
        .unwrap();
        let mut leaves = Vec::new();
        for leaf in DynamoItemExt::flatten(&nested) {
            leaves.push(DynamoItemExt::to_str(leaf).unwrap());
        }
        assert_eq!(leaves, vec!["first", "second"]);
    }
}