espipe 0.2.0

A minimalist command-line utility to pipe documents from a file or I/O stream into an Elasticsearch cluster.
mod bulk_response;

use super::{BulkAction, Sender};
use bulk_response::BulkResponse;
use elasticsearch::{
    Elasticsearch,
    http::{Method, StatusCode, headers::HeaderMap, headers::HeaderValue},
};
use eyre::{OptionExt, Result, eyre};
use futures::{StreamExt, stream::FuturesUnordered};
use serde_json::{Value, json, value::RawValue};
use std::{sync::Arc, time::Duration};
use tokio::{
    sync::mpsc,
    task::JoinHandle,
    time::sleep,
};
use url::Url;

const DEFAULT_BATCH_SIZE: usize = 5_000;
const DEFAULT_MAX_INFLIGHT_REQUESTS: usize = 16;

#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ElasticsearchOutputConfig {
    batch_size: usize,
    max_inflight_requests: usize,
}

impl ElasticsearchOutputConfig {
    pub const DEFAULT_BATCH_SIZE: usize = DEFAULT_BATCH_SIZE;
    pub const DEFAULT_MAX_INFLIGHT_REQUESTS: usize = DEFAULT_MAX_INFLIGHT_REQUESTS;

    pub fn try_new(batch_size: usize, max_inflight_requests: usize) -> Result<Self> {
        if batch_size == 0 {
            return Err(eyre!("batch size must be greater than zero"));
        }
        if max_inflight_requests == 0 {
            return Err(eyre!("max requests must be greater than zero"));
        }

        Ok(Self {
            batch_size,
            max_inflight_requests,
        })
    }

    fn channel_capacity(self) -> usize {
        self.batch_size
    }
}

impl Default for ElasticsearchOutputConfig {
    fn default() -> Self {
        Self {
            batch_size: DEFAULT_BATCH_SIZE,
            max_inflight_requests: DEFAULT_MAX_INFLIGHT_REQUESTS,
        }
    }
}

#[derive(Debug)]
pub struct ElasticsearchOutput {
    hostname: String,
    index: String,
    sender: Option<mpsc::Sender<Box<RawValue>>>,
    worker: JoinHandle<Result<usize>>,
}

impl ElasticsearchOutput {
    pub fn try_new(
        client: Elasticsearch,
        url: Url,
        action: BulkAction,
        config: ElasticsearchOutputConfig,
    ) -> Result<Self> {
        let hostname = url
            .host_str()
            .ok_or_eyre("Url missing host_str")?
            .to_string();
        let index = url.path().trim_start_matches('/').to_string();
        log::debug!("Elasticsearch output to {hostname}/{index}");

        let client = Arc::new(client);
        let (sender, receiver) = mpsc::channel(config.channel_capacity());
        let worker = tokio::spawn(run_bulk_worker(
            Arc::clone(&client),
            hostname.clone(),
            index.clone(),
            action,
            config,
            receiver,
        ));

        Ok(Self {
            hostname,
            index,
            sender: Some(sender),
            worker,
        })
    }
}

impl Sender for ElasticsearchOutput {
    async fn send(&mut self, value: Box<RawValue>) -> Result<usize> {
        let sender = self
            .sender
            .as_ref()
            .ok_or_eyre("Elasticsearch output already closed")?;
        sender
            .send(value)
            .await
            .map_err(|_| eyre!("Elasticsearch output worker closed unexpectedly"))?;
        Ok(0)
    }

    async fn close(mut self) -> Result<usize> {
        self.sender.take();
        self.worker.await.map_err(eyre::Report::new)?
    }
}

impl std::fmt::Display for ElasticsearchOutput {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}:{}", self.hostname, self.index)
    }
}

async fn run_bulk_worker(
    client: Arc<Elasticsearch>,
    hostname: String,
    index: String,
    action: BulkAction,
    config: ElasticsearchOutputConfig,
    mut receiver: mpsc::Receiver<Box<RawValue>>,
) -> Result<usize> {
    let mut batch = Vec::with_capacity(config.batch_size);
    let mut docs_sent = 0usize;
    let mut inflight = FuturesUnordered::<JoinHandle<Result<usize>>>::new();

    while let Some(doc) = receiver.recv().await {
        batch.push(doc);
        if batch.len() >= config.batch_size {
            spawn_flush(&mut inflight, &client, &hostname, &index, action, config, &mut batch)?;
            docs_sent += reap_inflight_if_needed(&mut inflight, config.max_inflight_requests).await?;
        }
    }

    if !batch.is_empty() {
        spawn_flush(&mut inflight, &client, &hostname, &index, action, config, &mut batch)?;
    }

    while let Some(result) = inflight.next().await {
        docs_sent += result.map_err(eyre::Report::new)??;
    }

    Ok(docs_sent)
}

fn spawn_flush(
    inflight: &mut FuturesUnordered<JoinHandle<Result<usize>>>,
    client: &Arc<Elasticsearch>,
    hostname: &str,
    index: &str,
    action: BulkAction,
    config: ElasticsearchOutputConfig,
    batch: &mut Vec<Box<RawValue>>,
) -> Result<()> {
    let docs = std::mem::replace(batch, Vec::with_capacity(config.batch_size));
    let body = build_bulk_body(action, &docs)?;
    log::debug!("Bulk sending {} docs to {hostname}/{index}", docs.len());
    let client = Arc::clone(client);
    let index = index.to_string();

    inflight.push(tokio::spawn(async move {
        let mut headers = HeaderMap::new();
        headers.insert("content-type", HeaderValue::from_static("application/x-ndjson"));

        let mut attempt = 0u64;
        let mut backoff = Duration::from_secs(1);
        let max_backoff = Duration::from_secs(30);

        loop {
            attempt += 1;
            let response = client
                .send(
                    Method::Post,
                    &format!("/{index}/_bulk"),
                    headers.clone(),
                    Option::<&()>::None,
                    Some(body.clone()),
                    None,
                )
                .await?;

            let status_code = response.status_code();
            let bulk_response = response.json::<BulkResponse>().await?;
            match status_code {
                StatusCode::BAD_REQUEST => {
                    log::error!(
                        "Bulk response: 400 - Bad request ({})",
                        bulk_response.error_cause()
                    );
                    return Ok(0);
                }
                StatusCode::TOO_MANY_REQUESTS => {
                    log::warn!(
                        "Bulk response: 429 - Too many requests (attempt {attempt}, backoff {:?}): {}",
                        backoff,
                        bulk_response.error_cause()
                    );
                    sleep(backoff).await;
                    if backoff < max_backoff {
                        backoff = std::cmp::min(backoff * 2, max_backoff);
                    }
                }
                _ => {
                    log::debug!("Bulk response status: {status_code}");
                    if bulk_response.has_errors() {
                        log::warn!(
                            "Bulk response contained errors: {}",
                            bulk_response.error_counts()
                        );
                    }
                    return Ok(bulk_response.success_count());
                }
            }
        }
    }));

    Ok(())
}

async fn reap_inflight_if_needed(
    inflight: &mut FuturesUnordered<JoinHandle<Result<usize>>>,
    max_inflight_requests: usize,
) -> Result<usize> {
    let mut docs_sent = 0usize;
    while inflight.len() >= max_inflight_requests {
        if let Some(result) = inflight.next().await {
            docs_sent += result.map_err(eyre::Report::new)??;
        }
    }
    Ok(docs_sent)
}

fn build_bulk_body(action: BulkAction, batch: &[Box<RawValue>]) -> Result<Vec<u8>> {
    let mut body = Vec::with_capacity(batch.len() * 64);
    for doc in batch {
        match action {
            BulkAction::Create => {
                body.extend_from_slice(b"{\"create\":{}}\n");
                body.extend_from_slice(doc.get().as_bytes());
                body.push(b'\n');
            }
            BulkAction::Index => {
                body.extend_from_slice(b"{\"index\":{}}\n");
                body.extend_from_slice(doc.get().as_bytes());
                body.push(b'\n');
            }
            BulkAction::Update => append_update_operation(&mut body, doc)?,
        }
    }
    Ok(body)
}

fn append_update_operation(body: &mut Vec<u8>, doc: &RawValue) -> Result<()> {
    let (id, doc) = extract_update_id(doc)?;
    body.extend_from_slice(b"{\"update\":{\"_id\":");
    serde_json::to_writer(&mut *body, &id)?;
    body.extend_from_slice(b"}}\n");
    serde_json::to_writer(&mut *body, &json!({ "doc": doc }))?;
    body.push(b'\n');
    Ok(())
}

fn extract_update_id(doc: &RawValue) -> Result<(String, Value)> {
    match serde_json::from_str::<Value>(doc.get())? {
        Value::Object(mut map) => {
            let id_value = map
                .remove("_id")
                .ok_or_eyre("Update action requires an _id field on each document")?;
            let id = id_value
                .as_str()
                .ok_or_eyre("Update action requires _id to be a string")?
                .to_string();
            Ok((id, Value::Object(map)))
        }
        _ => Err(eyre!(
            "Update action requires each document to be a JSON object"
        )),
    }
}

#[cfg(test)]
mod tests {
    use super::{
        DEFAULT_BATCH_SIZE, DEFAULT_MAX_INFLIGHT_REQUESTS, ElasticsearchOutputConfig,
        build_bulk_body, extract_update_id,
    };
    use crate::output::BulkAction;
    use serde_json::{Value, json, value::RawValue};

    #[test]
    fn build_bulk_body_uses_create_ndjson() {
        let docs = vec![
            RawValue::from_string("{\"a\":1}".to_string()).unwrap(),
            RawValue::from_string("{\"b\":2}".to_string()).unwrap(),
        ];

        let body = build_bulk_body(BulkAction::Create, &docs).unwrap();
        assert_eq!(
            String::from_utf8(body).unwrap(),
            "{\"create\":{}}\n{\"a\":1}\n{\"create\":{}}\n{\"b\":2}\n"
        );
    }

    #[test]
    fn build_bulk_body_uses_index_ndjson() {
        let docs = vec![RawValue::from_string("{\"a\":1}".to_string()).unwrap()];
        let body = build_bulk_body(BulkAction::Index, &docs).unwrap();
        assert_eq!(
            String::from_utf8(body).unwrap(),
            "{\"index\":{}}\n{\"a\":1}\n"
        );
    }

    #[test]
    fn build_bulk_body_wraps_update_docs() {
        let docs = vec![
            RawValue::from_string("{\"_id\":\"1\",\"a\":1}".to_string()).unwrap(),
        ];
        let body = build_bulk_body(BulkAction::Update, &docs).unwrap();
        let lines: Vec<Value> = String::from_utf8(body)
            .unwrap()
            .lines()
            .map(|line| serde_json::from_str(line).unwrap())
            .collect();
        assert_eq!(lines[0]["update"]["_id"], "1");
        assert_eq!(lines[1], json!({ "doc": { "a": 1 } }));
    }

    #[test]
    fn extract_update_id_requires_id() {
        let doc = RawValue::from_string("{\"message\":\"hello\"}".to_string()).unwrap();
        let err = extract_update_id(&doc).err().expect("expected error");
        assert!(err.to_string().contains("_id"));
    }

    #[test]
    fn default_worker_limits_are_bounded() {
        let config = ElasticsearchOutputConfig::default();
        assert_eq!(config.batch_size, DEFAULT_BATCH_SIZE);
        assert_eq!(config.channel_capacity(), DEFAULT_BATCH_SIZE);
        assert_eq!(config.max_inflight_requests, DEFAULT_MAX_INFLIGHT_REQUESTS);
    }

    #[test]
    fn config_rejects_zero_limits() {
        let batch_err = ElasticsearchOutputConfig::try_new(0, 1).unwrap_err();
        assert!(batch_err.to_string().contains("batch size"));

        let requests_err = ElasticsearchOutputConfig::try_new(1, 0).unwrap_err();
        assert!(requests_err.to_string().contains("max requests"));
    }
}