use std::sync::{Arc, Mutex};
use serde::{Deserialize, Serialize};
use tokio::{
sync::mpsc,
task::JoinHandle,
time::{sleep, Duration},
};
use tracing::{debug, error};
use crate::{
bulk::{BulkAction, CreateAction, DeleteAction, IndexAction, UpdateAction, UpdateActionBody},
Error, OsClient,
};
/// Counters describing the activity of a [`Bulker`].
///
/// A snapshot is obtained with [`Bulker::statistics`]; the values are updated
/// by the enqueue methods, by the background queue processor and by every
/// bulk request that completes.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BulkerStatistic {
/// Number of delete actions accepted by [`Bulker::delete`].
pub delete_actions: u64,
/// Number of create actions accepted by [`Bulker::create`].
pub create_actions: u64,
/// Number of update actions accepted by [`Bulker::update`].
pub update_actions: u64,
/// Number of index actions accepted by [`Bulker::index`].
pub index_actions: u64,
/// Length of the pending-action queue when it was last sampled.
pub queue_size: usize,
/// Number of spawned bulk requests that had not finished when last sampled.
pub running_reqwest_calls: usize,
/// Bulk requests counted as issued; `flush` compares this against
/// `finished_reqwest_calls` to decide when all work is done.
pub total_reqwest_calls: usize,
/// Bulk requests that have completed, successfully or not.
pub finished_reqwest_calls: usize,
/// Bulk requests whose call returned an error.
pub error_reqwest_calls: usize,
/// Sum of `count_ok()` over all bulk responses received.
pub success_actions: usize,
/// Sum of `count_errors()` over all bulk responses, plus the whole batch
/// size for every bulk request that failed outright.
pub error_actions: usize,
/// Sum of `count_create_errors()` over all bulk responses received.
pub error_create_actions: usize,
}
/// One entry of a bulk request: the action line plus its optional payload.
#[derive(Debug, Clone)]
struct Action {
// Action metadata; serialized to JSON as its own line of the bulk body.
action: BulkAction,
// Already-serialized JSON document written on the line after the action;
// `None` for actions that carry no body (delete).
document: Option<String>,
}
/// Builder that configures and starts a [`Bulker`].
#[derive(Clone)]
pub struct BulkerBuilder {
// Client used to issue the bulk requests.
os_client: Arc<OsClient>,
// Number of queued actions that triggers a bulk request.
bulk_size: u32,
// Concurrency limit handed to the queue processor (defaults to 10 in `new`).
max_concurrent_connections: u32,
}
impl BulkerBuilder {
pub fn new(os_client: Arc<OsClient>, bulk_size: u32) -> Self {
BulkerBuilder {
os_client,
bulk_size,
max_concurrent_connections: 10,
}
}
pub fn bulk_size(mut self, bulk_size: u32) -> Self {
self.bulk_size = bulk_size;
self
}
pub fn max_concurrent_connections(mut self, max_concurrent_connections: u32) -> Self {
self.max_concurrent_connections = max_concurrent_connections;
self
}
pub fn build(self) -> Bulker {
let (handle, bulker) = Bulker::new(
self.os_client,
self.bulk_size,
self.max_concurrent_connections,
);
tokio::spawn(async move {
handle.await.unwrap();
});
bulker
}
}
/// Front-end handle used to enqueue bulk actions.
///
/// Cloning yields another handle onto the same channel, queue and statistics.
#[derive(Clone)]
pub struct Bulker {
// Sending half of the channel consumed by the background queue processor.
sender: mpsc::Sender<Action>,
// Pending actions; shared with the background queue processor.
queue: Arc<Mutex<Vec<Action>>>,
// Client used by `Drop` to send whatever is still queued.
os_client: Arc<OsClient>,
// Batch size this bulker was configured with.
bulk_size: u32,
// Counters shared with the queue processor and the bulk request tasks.
statistics: Arc<Mutex<BulkerStatistic>>,
}
impl Bulker {
pub fn new(
os_client: Arc<OsClient>,
bulk_size: u32,
max_concurrent_connections: u32,
) -> (JoinHandle<()>, Bulker) {
let (sender, receiver) =
mpsc::channel::<Action>((bulk_size * (max_concurrent_connections + 1)) as usize);
let statistics = Arc::new(Mutex::new(BulkerStatistic::default()));
let service = Bulker {
bulk_size,
sender,
os_client: os_client.clone(),
queue: Arc::new(Mutex::new(Vec::new())),
statistics: statistics.clone(),
};
let queue = service.queue.clone();
let o_client = os_client.clone();
let handle = tokio::spawn(async move {
process_queue(
queue.clone(),
receiver,
o_client,
bulk_size as usize,
max_concurrent_connections as usize,
statistics.clone(),
)
.await
.unwrap();
});
(handle, service)
}
pub fn statistics(&self) -> BulkerStatistic {
self.statistics.lock().unwrap().clone()
}
pub async fn index<T: Serialize>(
&self,
index: &str,
body: &T,
id: Option<String>,
) -> Result<(), Error> {
let action = BulkAction::Index(IndexAction {
index: index.to_owned(),
id: id.clone(),
pipeline: None,
});
self.sender
.send(Action {
action,
document: Some(serde_json::to_string(&body)?),
})
.await
.map_err(|e| Error::InternalError(format!("{}", e)))?;
self.statistics.lock().unwrap().index_actions += 1;
Ok(())
}
pub async fn create<T: Serialize>(&self, index: &str, id: &str, body: &T) -> Result<(), Error> {
let action = BulkAction::Create(CreateAction {
index: index.to_owned(),
id: id.to_owned(),
..Default::default()
});
self.sender
.send(Action {
action,
document: Some(serde_json::to_string(&body)?),
})
.await
.map_err(|e| Error::InternalError(format!("{}", e)))?;
self.statistics.lock().unwrap().create_actions += 1;
Ok(())
}
pub async fn delete<T: Serialize>(&self, index: &str, id: &str) -> Result<(), Error> {
let action = BulkAction::Delete(DeleteAction {
index: index.to_owned(),
id: id.to_owned(),
..Default::default()
});
self.sender
.send(Action {
action,
document: None,
})
.await
.map_err(|e| Error::InternalError(format!("{}", e)))?;
self.statistics.lock().unwrap().delete_actions += 1;
Ok(())
}
pub async fn update(
&self,
index: &str,
id: &str,
body: &UpdateActionBody,
) -> Result<(), Error> {
let action = BulkAction::Update(UpdateAction {
index: index.to_owned(),
id: id.to_owned(),
..Default::default()
});
self.sender
.send(Action {
action: action,
document: Some(serde_json::to_string(&body)?),
})
.await
.map_err(|e| Error::InternalError(format!("{}", e)))?;
self.statistics.lock().unwrap().update_actions += 1;
Ok(())
}
pub async fn flush(&self) {
loop {
self.refresh_queue_size();
let statistics = self.statistics.lock().unwrap();
let status=format!(
"Bulker: Finished reqwest calls: {}, Total reqwest calls: {}, Queue size: {}, Running reqwest calls: {}, Error reqwest calls: {}, Success actions: {}, Error actions: {}, Error create actions: {}",
statistics.finished_reqwest_calls,
statistics.total_reqwest_calls,
statistics.queue_size,
statistics.running_reqwest_calls,
statistics.error_reqwest_calls,
statistics.success_actions,
statistics.error_actions,
statistics.error_create_actions
);
println!("{}", status);
if statistics.finished_reqwest_calls == statistics.total_reqwest_calls
&& statistics.queue_size == 0
{
break;
}
drop(statistics);
sleep(Duration::from_secs(1)).await;
}
}
fn refresh_queue_size(&self) {
let mut statistics = self.statistics.lock().unwrap();
statistics.queue_size = self.queue.lock().unwrap().len();
}
}
impl Drop for Bulker {
    /// Sends whatever is still in the pending queue before the handle goes away.
    ///
    /// The flush blocks the current thread, which Tokio only permits on a
    /// multi-threaded runtime. When no runtime is available, or the runtime is
    /// single-threaded, the queue is left untouched instead of panicking
    /// inside `drop`.
    fn drop(&mut self) {
        let handle = match tokio::runtime::Handle::try_current() {
            Ok(handle)
                if !matches!(
                    handle.runtime_flavor(),
                    tokio::runtime::RuntimeFlavor::CurrentThread
                ) =>
            {
                handle
            }
            _ => {
                let pending = self.queue.lock().map(|queue| queue.len()).unwrap_or(0);
                if pending > 0 {
                    debug!(
                        "Bulker: cannot block on drop without a multi-threaded Tokio runtime; leaving {:?} queued records in place",
                        pending
                    );
                }
                return;
            }
        };

        // Take the records in a single critical section so that nothing pushed
        // concurrently is wiped by a later `clear`, and nothing is sent twice.
        let records_to_process: Vec<Action> = std::mem::take(
            &mut *self
                .queue
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner),
        );
        if records_to_process.is_empty() {
            return;
        }
        debug!(
            "Bulker: Processing remaining records: {:?}",
            records_to_process.len()
        );
        // Register the request before issuing it, as `process_queue` does for
        // the batches it dispatches.
        self.statistics
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .total_reqwest_calls += 1;

        tokio::task::block_in_place(|| {
            handle.block_on(make_reqwest_calls(
                self.os_client.clone(),
                records_to_process,
                self.statistics.clone(),
            ));
        });
    }
}
async fn process_queue(
queue: Arc<Mutex<Vec<Action>>>,
mut receiver: mpsc::Receiver<Action>,
os_client: Arc<OsClient>,
bulk_size: usize,
max_concurrent_connections: usize,
statistics: Arc<Mutex<BulkerStatistic>>,
) -> Result<(), Error> {
let mut reqwest_calls: Vec<tokio::task::JoinHandle<()>> = Vec::new();
let mut start = std::time::Instant::now();
loop {
tokio::select! {
Some(json_record) = receiver.recv() => {
queue.lock().unwrap().push(json_record.clone());
let queue_size = queue.lock().unwrap().len();
let running_reqwest_calls = reqwest_calls.iter().filter(|task| !task.is_finished()).count();
{
let mut statistics = statistics.lock().unwrap();
statistics.queue_size = queue_size;
statistics.running_reqwest_calls = running_reqwest_calls;
}
let end= std::time::Instant::now();
if (queue_size >= bulk_size && running_reqwest_calls <= max_concurrent_connections) || (queue_size > 0 && end.duration_since(start).as_secs() > 1 && running_reqwest_calls <= max_concurrent_connections) {
let records_to_process: Vec<Action> = queue.lock().unwrap().clone();
queue.lock().unwrap().clear();
{
let mut statistics = statistics.lock().unwrap();
statistics.total_reqwest_calls += 1;
statistics.queue_size = 0;
}
reqwest_calls.push(tokio::spawn(make_reqwest_calls(os_client.clone(), records_to_process, statistics.clone())));
start= std::time::Instant::now();
}
}
_ = sleep(Duration::from_secs(1)) => {
reqwest_calls.retain(|task| !task.is_finished());
}
}
{
let end = std::time::Instant::now();
let queue_size = queue.lock().unwrap().len();
let running_reqwest_calls = reqwest_calls
.iter()
.filter(|task| !task.is_finished())
.count();
if (queue_size >= bulk_size && running_reqwest_calls <= max_concurrent_connections)
|| (queue_size > 0
&& end.duration_since(start).as_secs() > 1
&& running_reqwest_calls <= max_concurrent_connections)
{
let records_to_process: Vec<Action> = queue.lock().unwrap().clone();
queue.lock().unwrap().clear();
{
let mut statistics = statistics.lock().unwrap();
statistics.total_reqwest_calls += 1;
statistics.queue_size = 0;
}
reqwest_calls.push(tokio::spawn(make_reqwest_calls(
os_client.clone(),
records_to_process,
statistics.clone(),
)));
start = std::time::Instant::now();
}
}
}
}
async fn make_reqwest_calls(
os_client: Arc<OsClient>,
records: Vec<Action>,
statistics: Arc<Mutex<BulkerStatistic>>,
) {
let mut bulker = String::new();
let total = &records.len();
for record in records {
let j = serde_json::to_string(&record.action).unwrap();
bulker.push_str(j.as_str());
bulker.push('\n');
if let Some(document) = record.document {
bulker.push_str(document.as_str());
bulker.push('\n');
}
}
match os_client.bulk().body(bulker).call().await {
Ok(bulk_response) => {
let mut statistics = statistics.lock().unwrap();
statistics.finished_reqwest_calls += 1;
statistics.success_actions += bulk_response.count_ok();
statistics.error_actions += bulk_response.count_errors();
statistics.error_create_actions += bulk_response.count_create_errors();
debug!(
"Request successful for record: {:?}",
&bulk_response.items.len()
);
}
Err(e) => {
let mut statistics = statistics.lock().unwrap();
statistics.total_reqwest_calls += 1;
statistics.finished_reqwest_calls += 1;
statistics.error_reqwest_calls += 1;
statistics.error_actions += total;
let message = format!("Error making Reqwest call: {:?}", e);
error!(message);
}
}
}
#[cfg(test)]
mod tests {
    use crate::{ConfigurationBuilder, OsClient};
    use opensearch_testcontainer::*;
    use serde_json::json;
    use std::env;
    use testcontainers::{runners::AsyncRunner, ContainerAsync};
    use tracing_test::traced_test;
    use url::Url;

    /// Returns a client for the test together with the container backing it.
    ///
    /// When `OPENSEARCH_URL` is set the client is built from the environment
    /// and no container is started. Otherwise an OpenSearch container is
    /// started; its handle is returned so the caller can keep it alive, since
    /// dropping the handle tears the container down.
    async fn get_client() -> (OsClient, Option<ContainerAsync<OpenSearch>>) {
        if env::var("OPENSEARCH_URL").is_ok() {
            (OsClient::from_environment().unwrap(), None)
        } else {
            let os_image = OpenSearch::default();
            let opensearch = os_image.clone().start().await.unwrap();
            let host_port = opensearch.get_host_port_ipv4(9200).await.unwrap();
            let client = ConfigurationBuilder::new()
                .accept_invalid_certificates(true)
                .base_url(&format!("https://127.0.0.1:{host_port}"))
                .basic_auth(os_image.username(), os_image.password())
                .build();
            (client, Some(opensearch))
        }
    }

    /// Indexes 100000 documents through the bulker and checks both the
    /// bulker's counters and the document count reported by the server.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    #[traced_test]
    async fn bulker_ingester() -> Result<(), Box<dyn std::error::Error>> {
        // `_container` must stay bound until the end of the test.
        let (client, _container) = get_client().await;
        let test_size: u32 = 100000;
        let bulker = client
            .bulker()
            .bulk_size(1000)
            .max_concurrent_connections(10)
            .build();
        for i in 0..test_size {
            bulker
                .index("test", &json!({"id":i}), Some(i.to_string()))
                .await
                .unwrap();
        }
        bulker.flush().await;
        let statistics = bulker.statistics();
        drop(bulker);
        assert_eq!(statistics.index_actions, test_size as u64);
        assert_eq!(statistics.create_actions, 0);
        assert_eq!(statistics.delete_actions, 0);
        assert_eq!(statistics.update_actions, 0);
        client.indices().refresh().call().await.unwrap();
        let count = client.count().index("test").call().await.unwrap();
        assert_eq!(count.count, test_size);
        Ok(())
    }
}