1use std::sync::{Arc, Mutex};
2
3use serde::{Deserialize, Serialize};
4use tokio::{
5 sync::mpsc,
6 task::JoinHandle,
7 time::{sleep, Duration},
8};
9use tracing::{debug, error};
10
11use crate::{
12 bulk::{BulkAction, CreateAction, DeleteAction, IndexAction, UpdateAction, UpdateActionBody},
13 Error, OsClient,
14};
15
16#[derive(Debug, Clone, Default, Serialize, Deserialize)]
17pub struct BulkerStatistic {
18 pub delete_actions: u64,
20 pub create_actions: u64,
22 pub update_actions: u64,
24 pub index_actions: u64,
26 pub queue_size: usize,
28 pub running_reqwest_calls: usize,
30 pub total_reqwest_calls: usize,
32 pub finished_reqwest_calls: usize,
34 pub error_reqwest_calls: usize,
36 pub success_actions: usize,
38 pub error_actions: usize,
40 pub error_create_actions: usize,
42}
43
44#[derive(Debug, Clone)]
45struct Action {
46 action: BulkAction,
47 document: Option<String>,
48}
49#[derive(Clone)]
51pub struct BulkerBuilder {
52 os_client: Arc<OsClient>,
54 bulk_size: u32,
56 max_concurrent_connections: u32,
58}
59
60impl BulkerBuilder {
61 pub fn new(os_client: Arc<OsClient>, bulk_size: u32) -> Self {
63 BulkerBuilder {
64 os_client,
65 bulk_size,
66 max_concurrent_connections: 10,
67 }
68 }
69
70 pub fn bulk_size(mut self, bulk_size: u32) -> Self {
72 self.bulk_size = bulk_size;
73 self
74 }
75
76 pub fn max_concurrent_connections(mut self, max_concurrent_connections: u32) -> Self {
78 self.max_concurrent_connections = max_concurrent_connections;
79 self
80 }
81
82 pub fn build(self) -> Bulker {
84 let (handle, bulker) = Bulker::new(
85 self.os_client,
86 self.bulk_size,
87 self.max_concurrent_connections,
88 );
89 tokio::spawn(async move {
90 handle.await.unwrap();
91 });
92 bulker
93 }
94}
95
96#[derive(Clone)]
98pub struct Bulker {
99 sender: mpsc::Sender<Action>,
100 queue: Arc<Mutex<Vec<Action>>>,
102 os_client: Arc<OsClient>,
104 bulk_size: u32,
106 statistics: Arc<Mutex<BulkerStatistic>>,
110}
111
112impl Bulker {
113 pub fn new(
116 os_client: Arc<OsClient>,
117 bulk_size: u32,
118 max_concurrent_connections: u32,
119 ) -> (JoinHandle<()>, Bulker) {
120 let (sender, receiver) =
121 mpsc::channel::<Action>((bulk_size * (max_concurrent_connections + 1)) as usize);
122 let statistics = Arc::new(Mutex::new(BulkerStatistic::default()));
123 let service = Bulker {
124 bulk_size,
125 sender,
127 os_client: os_client.clone(),
128 queue: Arc::new(Mutex::new(Vec::new())),
129 statistics: statistics.clone(),
130 };
131 let queue = service.queue.clone();
132 let o_client = os_client.clone();
133 let handle = tokio::spawn(async move {
135 process_queue(
136 queue.clone(),
137 receiver,
138 o_client,
139 bulk_size as usize,
140 max_concurrent_connections as usize,
141 statistics.clone(),
142 )
143 .await
144 .unwrap();
145 });
146
147 (handle, service)
148 }
149
150 pub fn statistics(&self) -> BulkerStatistic {
151 self.statistics.lock().unwrap().clone()
152 }
153
154 pub async fn index<T: Serialize>(
167 &self,
168 index: &str,
169 body: &T,
170 id: Option<String>,
171 ) -> Result<(), Error> {
172 let action = BulkAction::Index(IndexAction {
173 index: index.to_owned(),
174 id: id.clone(),
175 pipeline: None,
176 });
177 self.sender
178 .send(Action {
179 action,
180 document: Some(serde_json::to_string(&body)?),
181 })
182 .await
183 .map_err(|e| Error::InternalError(format!("{}", e)))?;
184
185 self.statistics.lock().unwrap().index_actions += 1;
186 Ok(())
187 }
188
189 pub async fn create<T: Serialize>(&self, index: &str, id: &str, body: &T) -> Result<(), Error> {
203 let action = BulkAction::Create(CreateAction {
204 index: index.to_owned(),
205 id: id.to_owned(),
206 ..Default::default()
207 });
208 self.sender
209 .send(Action {
210 action,
211 document: Some(serde_json::to_string(&body)?),
212 })
213 .await
214 .map_err(|e| Error::InternalError(format!("{}", e)))?;
215 self.statistics.lock().unwrap().create_actions += 1;
216 Ok(())
217 }
218
219 pub async fn delete<T: Serialize>(&self, index: &str, id: &str) -> Result<(), Error> {
231 let action = BulkAction::Delete(DeleteAction {
232 index: index.to_owned(),
233 id: id.to_owned(),
234 ..Default::default()
235 });
236 self.sender
237 .send(Action {
238 action,
239 document: None,
240 })
241 .await
242 .map_err(|e| Error::InternalError(format!("{}", e)))?;
243 self.statistics.lock().unwrap().delete_actions += 1;
244 Ok(())
245 }
246
247 pub async fn update(
261 &self,
262 index: &str,
263 id: &str,
264 body: &UpdateActionBody,
265 ) -> Result<(), Error> {
266 let action = BulkAction::Update(UpdateAction {
267 index: index.to_owned(),
268 id: id.to_owned(),
269 ..Default::default()
270 });
271 self.sender
272 .send(Action {
273 action: action,
274 document: Some(serde_json::to_string(&body)?),
275 })
276 .await
277 .map_err(|e| Error::InternalError(format!("{}", e)))?;
278 self.statistics.lock().unwrap().update_actions += 1;
279 Ok(())
280 }
281
282 pub async fn flush(&self) {
284 loop {
285 self.refresh_queue_size();
286 let statistics = self.statistics.lock().unwrap();
287 let status=format!(
288 "Bulker: Finished reqwest calls: {}, Total reqwest calls: {}, Queue size: {}, Running reqwest calls: {}, Error reqwest calls: {}, Success actions: {}, Error actions: {}, Error create actions: {}",
289 statistics.finished_reqwest_calls,
290 statistics.total_reqwest_calls,
291 statistics.queue_size,
292 statistics.running_reqwest_calls,
293 statistics.error_reqwest_calls,
294 statistics.success_actions,
295 statistics.error_actions,
296 statistics.error_create_actions
297 );
298 println!("{}", status);
299 if statistics.finished_reqwest_calls == statistics.total_reqwest_calls
300 && statistics.queue_size == 0
301 {
302 break;
303 }
304 drop(statistics);
305 sleep(Duration::from_secs(1)).await;
306 }
307 }
308
309 fn refresh_queue_size(&self) {
310 let mut statistics = self.statistics.lock().unwrap();
312 statistics.queue_size = self.queue.lock().unwrap().len();
313 }
314}
315
316impl Drop for Bulker {
317 fn drop(&mut self) {
318 tokio::task::block_in_place(|| {
319 tokio::runtime::Handle::current().block_on(async {
320 let records_to_process: Vec<Action> = self.queue.lock().unwrap().clone();
322 if records_to_process.len() > 0 {
323 debug!(
324 "Bulker: Processing remaining records: {:?}",
325 records_to_process.len()
326 );
327 make_reqwest_calls(
328 self.os_client.clone(),
329 records_to_process,
330 self.statistics.clone(),
331 )
332 .await;
333 }
334 self.queue.lock().unwrap().clear();
336 });
337 });
338 }
339}
340
341async fn process_queue(
342 queue: Arc<Mutex<Vec<Action>>>,
343 mut receiver: mpsc::Receiver<Action>,
344 os_client: Arc<OsClient>,
345 bulk_size: usize,
346 max_concurrent_connections: usize,
347 statistics: Arc<Mutex<BulkerStatistic>>,
348) -> Result<(), Error> {
349 let mut reqwest_calls: Vec<tokio::task::JoinHandle<()>> = Vec::new();
350 let mut start = std::time::Instant::now();
351 loop {
352 tokio::select! {
353 Some(json_record) = receiver.recv() => {
355 queue.lock().unwrap().push(json_record.clone());
357
358 let queue_size = queue.lock().unwrap().len();
360 let running_reqwest_calls = reqwest_calls.iter().filter(|task| !task.is_finished()).count();
361 {
362 let mut statistics = statistics.lock().unwrap();
363 statistics.queue_size = queue_size;
364 statistics.running_reqwest_calls = running_reqwest_calls;
365 }
366 let end= std::time::Instant::now();
367
368 if (queue_size >= bulk_size && running_reqwest_calls <= max_concurrent_connections) || (queue_size > 0 && end.duration_since(start).as_secs() > 1 && running_reqwest_calls <= max_concurrent_connections) {
369 let records_to_process: Vec<Action> = queue.lock().unwrap().clone();
371
372 queue.lock().unwrap().clear();
374 {
375 let mut statistics = statistics.lock().unwrap();
376 statistics.total_reqwest_calls += 1;
377 statistics.queue_size = 0;
378 }
379
380 reqwest_calls.push(tokio::spawn(make_reqwest_calls(os_client.clone(), records_to_process, statistics.clone())));
382 start= std::time::Instant::now();
383 }
384 }
385 _ = sleep(Duration::from_secs(1)) => {
387 reqwest_calls.retain(|task| !task.is_finished());
388 }
389 }
390 {
392 let end = std::time::Instant::now();
393 let queue_size = queue.lock().unwrap().len();
394 let running_reqwest_calls = reqwest_calls
395 .iter()
396 .filter(|task| !task.is_finished())
397 .count();
398 if (queue_size >= bulk_size && running_reqwest_calls <= max_concurrent_connections)
399 || (queue_size > 0
400 && end.duration_since(start).as_secs() > 1
401 && running_reqwest_calls <= max_concurrent_connections)
402 {
403 let records_to_process: Vec<Action> = queue.lock().unwrap().clone();
405
406 queue.lock().unwrap().clear();
408 {
409 let mut statistics = statistics.lock().unwrap();
410 statistics.total_reqwest_calls += 1;
411 statistics.queue_size = 0;
412 }
413
414 reqwest_calls.push(tokio::spawn(make_reqwest_calls(
416 os_client.clone(),
417 records_to_process,
418 statistics.clone(),
419 )));
420 start = std::time::Instant::now();
421 }
422 }
423 }
424}
425
426async fn make_reqwest_calls(
427 os_client: Arc<OsClient>,
428 records: Vec<Action>,
429 statistics: Arc<Mutex<BulkerStatistic>>,
430) {
431 let mut bulker = String::new();
432 let total = &records.len();
433
434 for record in records {
435 let j = serde_json::to_string(&record.action).unwrap();
436 bulker.push_str(j.as_str());
437 bulker.push('\n');
438 if let Some(document) = record.document {
439 bulker.push_str(document.as_str());
440 bulker.push('\n');
441 }
442 }
443
444 match os_client.bulk().body(bulker).call().await {
445 Ok(bulk_response) => {
446 let mut statistics = statistics.lock().unwrap();
447 statistics.finished_reqwest_calls += 1;
448 statistics.success_actions += bulk_response.count_ok();
449 statistics.error_actions += bulk_response.count_errors();
450 statistics.error_create_actions += bulk_response.count_create_errors();
451 debug!(
452 "Request successful for record: {:?}",
453 &bulk_response.items.len()
454 );
455 }
456 Err(e) => {
457 let mut statistics = statistics.lock().unwrap();
458 statistics.total_reqwest_calls += 1;
459 statistics.finished_reqwest_calls += 1;
460 statistics.error_reqwest_calls += 1;
461 statistics.error_actions += total;
462 let message = format!("Error making Reqwest call: {:?}", e);
463 error!(message);
464 }
465 }
466}
467
#[cfg(test)]
mod tests {
    use crate::{ConfigurationBuilder, OsClient};
    use opensearch_testcontainer::*;
    use serde_json::json;
    use std::env;
    use testcontainers::{runners::AsyncRunner, ContainerAsync};
    use tracing_test::traced_test;
    use url::Url;

    /// Returns a client for the cluster named by `OPENSEARCH_URL`, or for a freshly started
    /// OpenSearch container.
    ///
    /// In the container case the container handle is returned as well. The caller must keep
    /// it alive for as long as the client is used, because dropping the handle stops and
    /// removes the container.
    async fn get_client() -> (OsClient, Option<ContainerAsync<OpenSearch>>) {
        if env::var("OPENSEARCH_URL").is_ok() {
            return (OsClient::from_environment().unwrap(), None);
        }

        let os_image = OpenSearch::default();
        let opensearch = os_image.clone().start().await.unwrap();
        let host_port = opensearch.get_host_port_ipv4(9200).await.unwrap();

        let client = ConfigurationBuilder::new()
            .accept_invalid_certificates(true)
            .base_url(&format!("https://127.0.0.1:{host_port}"))
            .basic_auth(os_image.username(), os_image.password())
            .build();
        (client, Some(opensearch))
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[traced_test]
    async fn bulker_ingester() -> Result<(), Box<dyn std::error::Error>> {
        // `_container` (not `_`) keeps the container running until the end of the test.
        let (client, _container) = get_client().await;

        let test_size: u32 = 100000;
        let bulker = client
            .bulker()
            .bulk_size(1000)
            .max_concurrent_connections(10)
            .build();
        for i in 0..test_size {
            bulker
                .index("test", &json!({"id":i}), Some(i.to_string()))
                .await
                .unwrap();
        }
        bulker.flush().await;
        let statistics = bulker.statistics();
        drop(bulker);

        assert_eq!(statistics.index_actions, u64::from(test_size));
        assert_eq!(statistics.create_actions, 0);
        assert_eq!(statistics.delete_actions, 0);
        assert_eq!(statistics.update_actions, 0);
        assert_eq!(statistics.error_reqwest_calls, 0);
        assert_eq!(statistics.error_actions, 0);
        client.indices().refresh().call().await.unwrap();

        let count = client.count().index("test").call().await.unwrap();
        assert_eq!(count.count, test_size);
        Ok(())
    }
}