1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
use crate::indexing::{combine_json_objects, index_manager::IndexManager, KNOWN_DESCRIPTORS};
use crate::rest_api::api::{
    BulkRequest, BulkRequestDoc, PostIndexBulkResponseError, PostIndexBulkResponseOk,
    PostIndexBulkResponseOkStatus, PostIndexesBulkIndexResponse,
};
use crate::search::compound_processing::process_cpd;
use crate::search::scaffold_search::{scaffold_search, PARSED_SCAFFOLDS};
use poem_openapi::payload::Json;
use rayon::prelude::*;
use serde_json::{Map, Value};
use std::collections::HashMap;
use tantivy::schema::Field;

pub async fn v1_post_index_bulk(
    index_manager: &IndexManager,
    index: String,
    bulk_request: BulkRequest,
) -> PostIndexesBulkIndexResponse {
    let index = match index_manager.open(&index) {
        Ok(index) => index,
        Err(e) => {
            return PostIndexesBulkIndexResponse::Err(Json(PostIndexBulkResponseError {
                error: e.to_string(),
            }))
        }
    };

    let mut writer = match index.writer(16 * 1024 * 1024) {
        Ok(writer) => writer,
        Err(e) => {
            return PostIndexesBulkIndexResponse::Err(Json(PostIndexBulkResponseError {
                error: e.to_string(),
            }))
        }
    };

    let schema = index.schema();

    let smiles_field = schema.get_field("smiles").unwrap();
    let fingerprint_field = schema.get_field("fingerprint").unwrap();
    let extra_data_field = schema.get_field("extra_data").unwrap();

    let descriptors_fields = KNOWN_DESCRIPTORS
        .iter()
        .map(|kd| (*kd, schema.get_field(kd).unwrap()))
        .collect::<HashMap<&str, Field>>();

    let tantivy_docs_conversion_operation = tokio::task::spawn_blocking(move || {
        bulk_request
            .docs
            .into_par_iter()
            .map(|doc| {
                bulk_request_doc_to_tantivy_doc(
                    doc,
                    smiles_field,
                    fingerprint_field,
                    &descriptors_fields,
                    extra_data_field,
                )
            })
            .collect::<Vec<_>>()
    })
    .await;

    let tantivy_docs = match tantivy_docs_conversion_operation {
        Ok(docs) => docs,
        Err(e) => {
            return PostIndexesBulkIndexResponse::Err(Json(PostIndexBulkResponseError {
                error: e.to_string(),
            }))
        }
    };

    let mut document_insert_statuses = Vec::with_capacity(tantivy_docs.len());

    for doc_conversion_result in tantivy_docs {
        let tantivy_doc = match doc_conversion_result {
            Ok(doc) => doc,
            Err(e) => {
                document_insert_statuses.push(PostIndexBulkResponseOkStatus {
                    opcode: None,
                    error: Some(e.to_string()),
                });
                continue;
            }
        };

        let write_operation = writer.add_document(tantivy_doc);

        let status = match write_operation {
            Ok(opstamp) => PostIndexBulkResponseOkStatus {
                opcode: Some(opstamp),
                error: None,
            },
            Err(e) => PostIndexBulkResponseOkStatus {
                opcode: None,
                error: Some(e.to_string()),
            },
        };
        document_insert_statuses.push(status);
    }

    match writer.commit() {
        Ok(_) => (),
        Err(e) => {
            return PostIndexesBulkIndexResponse::Err(Json(PostIndexBulkResponseError {
                error: e.to_string(),
            }))
        }
    }

    PostIndexesBulkIndexResponse::Ok(Json(PostIndexBulkResponseOk {
        statuses: document_insert_statuses,
    }))
}

/// Converts one bulk-request document into a tantivy document.
///
/// The document's SMILES is run through `process_cpd`; the resulting tautomer
/// SMILES and fingerprint bytes are stored in `smiles_field` and
/// `fingerprint_field`. Scaffold matches for the tautomer are merged with the
/// request's own `extra_data` via `combine_json_objects` and, when that yields
/// a value, stored in `extra_data_field`. Finally every entry of
/// `KNOWN_DESCRIPTORS` that is present as a JSON number among the computed
/// descriptors is stored in its field from `descriptors_fields`.
///
/// # Arguments
///
/// * `bulk_request_doc` - the incoming document (`smiles` plus `extra_data`).
/// * `smiles_field`, `fingerprint_field`, `extra_data_field` - target schema fields.
/// * `descriptors_fields` - schema field for each descriptor name.
///
/// # Errors
///
/// Returns an error (instead of panicking) when:
/// * `process_cpd` or `scaffold_search` fails,
/// * the descriptors do not serialize to a JSON object,
/// * a numeric descriptor has no entry in `descriptors_fields`, or
/// * a numeric descriptor cannot be represented as an `f64`.
///
/// Reporting these as errors lets the caller record a failure for this one
/// document rather than having a panic take down the whole batch.
fn bulk_request_doc_to_tantivy_doc(
    bulk_request_doc: BulkRequestDoc,
    smiles_field: Field,
    fingerprint_field: Field,
    descriptors_fields: &HashMap<&str, Field>,
    extra_data_field: Field,
) -> eyre::Result<tantivy::Document> {
    // By default, do not attempt to fix problematic molecules
    let (tautomer, fingerprint, descriptors) = process_cpd(&bulk_request_doc.smiles, false)?;

    // Go through JSON so descriptors can be looked up by name below.
    let jsonified_compound_descriptors: Map<String, Value> =
        match serde_json::to_value(descriptors)? {
            Value::Object(map) => map,
            _ => return Err(eyre::eyre!("not an object")),
        };

    let mut doc = tantivy::doc!(
        smiles_field => tautomer.as_smiles(),
        fingerprint_field => fingerprint.0.into_vec()
    );

    let scaffolds = &PARSED_SCAFFOLDS;
    let scaffold_matches = scaffold_search(&tautomer, scaffolds)?;

    // An empty match list is recorded as the single sentinel value -1 rather
    // than as an empty array.
    let scaffold_json = if scaffold_matches.is_empty() {
        serde_json::json!({"scaffolds": vec![-1]})
    } else {
        serde_json::json!({"scaffolds": scaffold_matches})
    };

    let extra_data_json = combine_json_objects(Some(scaffold_json), bulk_request_doc.extra_data);
    if let Some(extra_data_json) = extra_data_json {
        doc.add_field_value(extra_data_field, extra_data_json);
    }

    for field in KNOWN_DESCRIPTORS {
        // Descriptors that are missing or not JSON numbers are skipped.
        if let Some(Value::Number(val)) = jsonified_compound_descriptors.get(field) {
            let target_field = *descriptors_fields
                .get(field)
                .ok_or_else(|| eyre::eyre!("no schema field for descriptor {}", field))?;
            let number = val.as_f64().ok_or_else(|| {
                eyre::eyre!("descriptor {} is not representable as f64", field)
            })?;

            // Descriptors named "Num*" or "lipinski*" are stored as integers
            // (the float is truncated by the cast); all others are stored as f64.
            if field.starts_with("Num") || field.starts_with("lipinski") {
                doc.add_field_value(target_field, number as i64);
            } else {
                doc.add_field_value(target_field, number);
            }
        }
    }

    Ok(doc)
}