Skip to main content

opensearch_client/
bulker.rs

1use std::sync::{Arc, Mutex};
2
3use serde::{Deserialize, Serialize};
4use tokio::{
5    sync::mpsc,
6    task::JoinHandle,
7    time::{sleep, Duration},
8};
9use tracing::{debug, error};
10
11use crate::{
12    bulk::{BulkAction, CreateAction, DeleteAction, IndexAction, UpdateAction, UpdateActionBody},
13    Error, OsClient,
14};
15
16#[derive(Debug, Clone, Default, Serialize, Deserialize)]
17pub struct BulkerStatistic {
18    // Number of deleted actions
19    pub delete_actions: u64,
20    // Number of created actions
21    pub create_actions: u64,
22    // Number of updated actions
23    pub update_actions: u64,
24    // Number of index actions
25    pub index_actions: u64,
26    // Number of records in the queue
27    pub queue_size: usize,
28    // Number of running Reqwest calls
29    pub running_reqwest_calls: usize,
30    // Number of total Reqwest calls
31    pub total_reqwest_calls: usize,
32    // Number of finished Reqwest calls
33    pub finished_reqwest_calls: usize,
34    // Number of error calls
35    pub error_reqwest_calls: usize,
36    // Number of action without errors
37    pub success_actions: usize,
38    // Number of action with errors
39    pub error_actions: usize,
40    // Number of creation action with errors
41    pub error_create_actions: usize,
42}
43
44#[derive(Debug, Clone)]
45struct Action {
46    action: BulkAction,
47    document: Option<String>,
48}
49// BulkerBuilder is a helper struct to build a Bulker instance.
50#[derive(Clone)]
51pub struct BulkerBuilder {
52    // OpenSearch Client
53    os_client: Arc<OsClient>,
54    // Max items for bulk
55    bulk_size: u32,
56    // Max parallel bulks
57    max_concurrent_connections: u32,
58}
59
60impl BulkerBuilder {
61    /// Create a new BulkerBuilder instance.
62    pub fn new(os_client: Arc<OsClient>, bulk_size: u32) -> Self {
63        BulkerBuilder {
64            os_client,
65            bulk_size,
66            max_concurrent_connections: 10,
67        }
68    }
69
70    /// Set the bulk size.
71    pub fn bulk_size(mut self, bulk_size: u32) -> Self {
72        self.bulk_size = bulk_size;
73        self
74    }
75
76    /// Set the max concurrent connections.
77    pub fn max_concurrent_connections(mut self, max_concurrent_connections: u32) -> Self {
78        self.max_concurrent_connections = max_concurrent_connections;
79        self
80    }
81
82    /// Build a Bulker instance.
83    pub fn build(self) -> Bulker {
84        let (handle, bulker) = Bulker::new(
85            self.os_client,
86            self.bulk_size,
87            self.max_concurrent_connections,
88        );
89        tokio::spawn(async move {
90            handle.await.unwrap();
91        });
92        bulker
93    }
94}
95
96/// Bulker is a helper struct to make bulk requests to OpenSearch.
97#[derive(Clone)]
98pub struct Bulker {
99    sender: mpsc::Sender<Action>,
100    // Create a shared queue using Arc and Mutex
101    queue: Arc<Mutex<Vec<Action>>>,
102    // OPenSearch Client
103    os_client: Arc<OsClient>,
104    // Max items for bulk
105    bulk_size: u32,
106    // Max parallel bulks
107    // max_concurrent_connections: u32,
108    // statistics
109    statistics: Arc<Mutex<BulkerStatistic>>,
110}
111
112impl Bulker {
113    /// Spawn an extraction service on a separate thread and return an extraction
114    /// instance to interact with it
115    pub fn new(
116        os_client: Arc<OsClient>,
117        bulk_size: u32,
118        max_concurrent_connections: u32,
119    ) -> (JoinHandle<()>, Bulker) {
120        let (sender, receiver) =
121            mpsc::channel::<Action>((bulk_size * (max_concurrent_connections + 1)) as usize);
122        let statistics = Arc::new(Mutex::new(BulkerStatistic::default()));
123        let service = Bulker {
124            bulk_size,
125            // max_concurrent_connections,
126            sender,
127            os_client: os_client.clone(),
128            queue: Arc::new(Mutex::new(Vec::new())),
129            statistics: statistics.clone(),
130        };
131        let queue = service.queue.clone();
132        let o_client = os_client.clone();
133        // Spawn a background task to process the queue and make async Reqwest calls
134        let handle = tokio::spawn(async move {
135            process_queue(
136                queue.clone(),
137                receiver,
138                o_client,
139                bulk_size as usize,
140                max_concurrent_connections as usize,
141                statistics.clone(),
142            )
143            .await
144            .unwrap();
145        });
146
147        (handle, service)
148    }
149
150    pub fn statistics(&self) -> BulkerStatistic {
151        self.statistics.lock().unwrap().clone()
152    }
153
154    /// Sends a bulk index request to OpenSearch with the specified index, id and
155    /// document body.
156    ///
157    /// # Arguments
158    ///
159    /// * `index` - A string slice that holds the name of the index.
160    /// * `id` - An optional string slice that holds the id of the document.
161    /// * `body` - A reference to a serializable document body.
162    ///
163    /// # Returns
164    ///
165    /// Returns () on success, or an `Error` on failure.
166    pub async fn index<T: Serialize>(
167        &self,
168        index: &str,
169        body: &T,
170        id: Option<String>,
171    ) -> Result<(), Error> {
172        let action = BulkAction::Index(IndexAction {
173            index: index.to_owned(),
174            id: id.clone(),
175            pipeline: None,
176        });
177        self.sender
178            .send(Action {
179                action,
180                document: Some(serde_json::to_string(&body)?),
181            })
182            .await
183            .map_err(|e| Error::InternalError(format!("{}", e)))?;
184
185        self.statistics.lock().unwrap().index_actions += 1;
186        Ok(())
187    }
188
189    /// Sends a bulk create request to the OpenSearch cluster with the specified
190    /// index, id and body.
191    ///
192    /// # Arguments
193    ///
194    /// * `index` - A string slice that holds the name of the index.
195    /// * `id` - A string slice that holds the id of the document.
196    /// * `body` - A generic type `T` that holds the body of the document to be
197    ///   created.
198    ///
199    /// # Returns
200    ///
201    /// Returns () on success, or an `Error` on failure.
202    pub async fn create<T: Serialize>(&self, index: &str, id: &str, body: &T) -> Result<(), Error> {
203        let action = BulkAction::Create(CreateAction {
204            index: index.to_owned(),
205            id: id.to_owned(),
206            ..Default::default()
207        });
208        self.sender
209            .send(Action {
210                action,
211                document: Some(serde_json::to_string(&body)?),
212            })
213            .await
214            .map_err(|e| Error::InternalError(format!("{}", e)))?;
215        self.statistics.lock().unwrap().create_actions += 1;
216        Ok(())
217    }
218
219    /// Sends a bulk delete request to the OpenSearch cluster with the specified
220    /// index and id.
221    ///
222    /// # Arguments
223    ///
224    /// * `index` - A string slice that holds the name of the index.
225    /// * `id` - A string slice that holds the id of the document.
226    ///
227    /// # Returns
228    ///
229    /// Returns () on success, or an `Error` on failure.
230    pub async fn delete<T: Serialize>(&self, index: &str, id: &str) -> Result<(), Error> {
231        let action = BulkAction::Delete(DeleteAction {
232            index: index.to_owned(),
233            id: id.to_owned(),
234            ..Default::default()
235        });
236        self.sender
237            .send(Action {
238                action,
239                document: None,
240            })
241            .await
242            .map_err(|e| Error::InternalError(format!("{}", e)))?;
243        self.statistics.lock().unwrap().delete_actions += 1;
244        Ok(())
245    }
246
247    /// Asynchronously updates a document in bulk.
248    ///
249    /// # Arguments
250    ///
251    /// * `index` - A string slice that holds the name of the index.
252    /// * `id` - A string slice that holds the ID of the document to update.
253    /// * `body` - An `UpdateAction` struct that holds the update action to
254    ///   perform.
255    ///
256    /// # Returns
257    ///
258    /// Returns a `Result` containing a `serde_json::Value` on success, or an
259    /// `Error` on failure.
260    pub async fn update(
261        &self,
262        index: &str,
263        id: &str,
264        body: &UpdateActionBody,
265    ) -> Result<(), Error> {
266        let action = BulkAction::Update(UpdateAction {
267            index: index.to_owned(),
268            id: id.to_owned(),
269            ..Default::default()
270        });
271        self.sender
272            .send(Action {
273                action: action,
274                document: Some(serde_json::to_string(&body)?),
275            })
276            .await
277            .map_err(|e| Error::InternalError(format!("{}", e)))?;
278        self.statistics.lock().unwrap().update_actions += 1;
279        Ok(())
280    }
281
282    // wait that every reqwest is completed
283    pub async fn flush(&self) {
284        loop {
285            self.refresh_queue_size();
286            let statistics = self.statistics.lock().unwrap();
287            let status=format!(
288                "Bulker: Finished reqwest calls: {}, Total reqwest calls: {}, Queue size: {}, Running reqwest calls: {}, Error reqwest calls: {}, Success actions: {}, Error actions: {}, Error create actions: {}",
289                statistics.finished_reqwest_calls,
290                statistics.total_reqwest_calls,
291                statistics.queue_size,
292                statistics.running_reqwest_calls,
293                statistics.error_reqwest_calls,
294                statistics.success_actions,
295                statistics.error_actions,
296                statistics.error_create_actions
297            );
298            println!("{}", status);
299            if statistics.finished_reqwest_calls == statistics.total_reqwest_calls
300                && statistics.queue_size == 0
301            {
302                break;
303            }
304            drop(statistics);
305            sleep(Duration::from_secs(1)).await;
306        }
307    }
308
309    fn refresh_queue_size(&self) {
310        // we fresh the queue size
311        let mut statistics = self.statistics.lock().unwrap();
312        statistics.queue_size = self.queue.lock().unwrap().len();
313    }
314}
315
316impl Drop for Bulker {
317    fn drop(&mut self) {
318        tokio::task::block_in_place(|| {
319            tokio::runtime::Handle::current().block_on(async {
320                // Clone the records from the queue to process synchronously
321                let records_to_process: Vec<Action> = self.queue.lock().unwrap().clone();
322                if records_to_process.len() > 0 {
323                    debug!(
324                        "Bulker: Processing remaining records: {:?}",
325                        records_to_process.len()
326                    );
327                    make_reqwest_calls(
328                        self.os_client.clone(),
329                        records_to_process,
330                        self.statistics.clone(),
331                    )
332                    .await;
333                }
334                // Clear the queue
335                self.queue.lock().unwrap().clear();
336            });
337        });
338    }
339}
340
341async fn process_queue(
342    queue: Arc<Mutex<Vec<Action>>>,
343    mut receiver: mpsc::Receiver<Action>,
344    os_client: Arc<OsClient>,
345    bulk_size: usize,
346    max_concurrent_connections: usize,
347    statistics: Arc<Mutex<BulkerStatistic>>,
348) -> Result<(), Error> {
349    let mut reqwest_calls: Vec<tokio::task::JoinHandle<()>> = Vec::new();
350    let mut start = std::time::Instant::now();
351    loop {
352        tokio::select! {
353            // Handle incoming JSON records
354            Some(json_record) = receiver.recv() => {
355                // Put the JSON record in the queue
356                queue.lock().unwrap().push(json_record.clone());
357
358                // Check conditions for making async Reqwest calls
359                let queue_size = queue.lock().unwrap().len();
360                let running_reqwest_calls = reqwest_calls.iter().filter(|task| !task.is_finished()).count();
361                {
362                    let mut statistics = statistics.lock().unwrap();
363                    statistics.queue_size = queue_size;
364                    statistics.running_reqwest_calls = running_reqwest_calls;
365                }
366                let end= std::time::Instant::now();
367
368                if (queue_size >= bulk_size && running_reqwest_calls <= max_concurrent_connections) || (queue_size > 0 && end.duration_since(start).as_secs() > 1 && running_reqwest_calls <= max_concurrent_connections) {
369                    // Clone the records from the queue to process asynchronously
370                    let records_to_process: Vec<Action> = queue.lock().unwrap().clone();
371
372                    // Clear the queue
373                    queue.lock().unwrap().clear();
374                    {
375                      let mut statistics = statistics.lock().unwrap();
376                      statistics.total_reqwest_calls += 1;
377                      statistics.queue_size = 0;
378                    }
379
380                    // Spawn an async task to make the Reqwest calls
381                    reqwest_calls.push(tokio::spawn(make_reqwest_calls(os_client.clone(), records_to_process, statistics.clone())));
382                    start= std::time::Instant::now();
383                }
384            }
385            // Handle elapsed time for Reqwest calls
386            _ = sleep(Duration::from_secs(1)) => {
387                reqwest_calls.retain(|task| !task.is_finished());
388            }
389        }
390        // Check if all the records have been processed or timeout to send data
391        {
392            let end = std::time::Instant::now();
393            let queue_size = queue.lock().unwrap().len();
394            let running_reqwest_calls = reqwest_calls
395                .iter()
396                .filter(|task| !task.is_finished())
397                .count();
398            if (queue_size >= bulk_size && running_reqwest_calls <= max_concurrent_connections)
399                || (queue_size > 0
400                    && end.duration_since(start).as_secs() > 1
401                    && running_reqwest_calls <= max_concurrent_connections)
402            {
403                // Clone the records from the queue to process asynchronously
404                let records_to_process: Vec<Action> = queue.lock().unwrap().clone();
405
406                // Clear the queue
407                queue.lock().unwrap().clear();
408                {
409                    let mut statistics = statistics.lock().unwrap();
410                    statistics.total_reqwest_calls += 1;
411                    statistics.queue_size = 0;
412                }
413
414                // Spawn an async task to make the Reqwest calls
415                reqwest_calls.push(tokio::spawn(make_reqwest_calls(
416                    os_client.clone(),
417                    records_to_process,
418                    statistics.clone(),
419                )));
420                start = std::time::Instant::now();
421            }
422        }
423    }
424}
425
426async fn make_reqwest_calls(
427    os_client: Arc<OsClient>,
428    records: Vec<Action>,
429    statistics: Arc<Mutex<BulkerStatistic>>,
430) {
431    let mut bulker = String::new();
432    let total = &records.len();
433
434    for record in records {
435        let j = serde_json::to_string(&record.action).unwrap();
436        bulker.push_str(j.as_str());
437        bulker.push('\n');
438        if let Some(document) = record.document {
439            bulker.push_str(document.as_str());
440            bulker.push('\n');
441        }
442    }
443
444    match os_client.bulk().body(bulker).call().await {
445        Ok(bulk_response) => {
446            let mut statistics = statistics.lock().unwrap();
447            statistics.finished_reqwest_calls += 1;
448            statistics.success_actions += bulk_response.count_ok();
449            statistics.error_actions += bulk_response.count_errors();
450            statistics.error_create_actions += bulk_response.count_create_errors();
451            debug!(
452                "Request successful for record: {:?}",
453                &bulk_response.items.len()
454            );
455        }
456        Err(e) => {
457            let mut statistics = statistics.lock().unwrap();
458            statistics.total_reqwest_calls += 1;
459            statistics.finished_reqwest_calls += 1;
460            statistics.error_reqwest_calls += 1;
461            statistics.error_actions += total;
462            let message = format!("Error making Reqwest call: {:?}", e);
463            error!(message);
464        }
465    }
466}
467
#[cfg(test)]
mod tests {
    use crate::{ConfigurationBuilder, OsClient};
    use opensearch_testcontainer::*;
    use serde_json::json;
    use std::env;
    use testcontainers::{runners::AsyncRunner, ContainerAsync};
    use tracing_test::traced_test;

    /// Returns a client for the integration tests.
    ///
    /// When `OPENSEARCH_URL` is set the client is built from the environment.
    /// Otherwise an OpenSearch test container is started and returned with
    /// the client: the caller must keep it alive for the duration of the
    /// test, because the container is removed when its handle is dropped.
    async fn get_client() -> (OsClient, Option<ContainerAsync<OpenSearch>>) {
        if env::var("OPENSEARCH_URL").is_ok() {
            return (OsClient::from_environment().unwrap(), None);
        }

        let os_image = OpenSearch::default();
        let opensearch = os_image.clone().start().await.unwrap();
        let host_port = opensearch.get_host_port_ipv4(9200).await.unwrap();

        let client = ConfigurationBuilder::new()
            .accept_invalid_certificates(true)
            .base_url(&format!("https://127.0.0.1:{host_port}"))
            .basic_auth(os_image.username(), os_image.password())
            .build();
        (client, Some(opensearch))
    }

    /// Indexes a large number of documents through the bulker and checks the
    /// statistics and the resulting document count.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[traced_test]
    async fn bulker_ingester() -> Result<(), Box<dyn std::error::Error>> {
        // `_container` keeps the test container (if any) running until the
        // end of the test.
        let (client, _container) = get_client().await;

        let test_size: u32 = 100000;
        let bulker = client
            .bulker()
            .bulk_size(1000)
            .max_concurrent_connections(10)
            .build();
        for i in 0..test_size {
            bulker
                .index("test", &json!({"id":i}), Some(i.to_string()))
                .await
                .unwrap();
        }
        bulker.flush().await;
        let statistics = bulker.statistics();
        drop(bulker);

        assert_eq!(statistics.index_actions, test_size as u64);
        assert_eq!(statistics.create_actions, 0);
        assert_eq!(statistics.delete_actions, 0);
        assert_eq!(statistics.update_actions, 0);
        client.indices().refresh().call().await.unwrap();

        let count = client.count().index("test").call().await.unwrap();
        assert_eq!(count.count, test_size);
        Ok(())
    }
}
527}