stash-cli 0.8.0

A local store for pipeline output and ad hoc file snapshots
Documentation
use signal_hook::SigId;
use signal_hook::consts::signal::{SIGINT, SIGTERM};
use signal_hook::low_level;
use std::collections::BTreeMap;
use std::fs::{self, File};
use std::io::{self, Read, Write};
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};

use super::PartialSavedError;

/// How [`save_or_abort_partial`] should treat a capture that stopped early.
#[derive(Debug, Clone, Copy)]
struct PartialSaveOptions {
    /// Whether a partial capture may be kept at all; when false it is always discarded.
    save_on_error: bool,
    /// Whether to keep the capture even when zero bytes were read.
    save_empty: bool,
    /// Signal number that stopped the capture, or `None` when an I/O error did.
    signal: Option<i32>,
}

/// Reads `reader` to end-of-input and stores everything it produced as a new entry.
///
/// `attrs` are attached to the stored entry. SIGINT and SIGTERM are trapped while
/// reading. If one arrives, or a read fails, the bytes captured so far may be kept
/// as an entry tagged `partial`, and an error wrapping [`PartialSavedError`] is
/// returned. Partial saving is always enabled for this entry point.
///
/// On success returns the `String` produced by `finalize_saved_entry`
/// (presumably the new entry's id).
pub fn push_from_reader<R: Read>(
    reader: &mut R,
    attrs: BTreeMap<String, String>,
) -> io::Result<String> {
    super::init()?;
    let (interrupted, signal) = (Arc::new(AtomicBool::new(false)), Arc::new(AtomicI32::new(0)));
    // Handlers stay installed until this guard drops when the function returns.
    let _signal_guard = SignalGuard::new(&interrupted, &signal)?;

    let id = super::new_ulid()?;
    let data_path = super::tmp_dir()?.join(format!("{id}.data"));
    let data = File::create(&data_path)?;
    let keep_partial_captures = true;
    run_read_loop(
        reader,
        None,
        data,
        data_path,
        id,
        attrs,
        &interrupted,
        &signal,
        keep_partial_captures,
    )
}

/// Copies `reader` to `stdout` while also capturing it as a new stored entry.
///
/// Each chunk read is written to the capture file first and then to `stdout`.
/// If `stdout` reports `BrokenPipe`, reading stops and the bytes captured so far
/// are stored as a regular entry. On a trapped SIGINT/SIGTERM, a read error, or
/// any other `stdout` error, `save_on_error` decides whether the captured bytes
/// may be kept as a `partial` entry (reported through an error wrapping
/// [`PartialSavedError`]) or are discarded.
///
/// On success returns the `String` produced by `finalize_saved_entry`
/// (presumably the new entry's id).
pub fn tee_from_reader_partial<R: Read, W: Write>(
    reader: &mut R,
    stdout: &mut W,
    attrs: BTreeMap<String, String>,
    save_on_error: bool,
) -> io::Result<String> {
    super::init()?;
    let (interrupted, signal) = (Arc::new(AtomicBool::new(false)), Arc::new(AtomicI32::new(0)));
    // Handlers stay installed until this guard drops when the function returns.
    let _signal_guard = SignalGuard::new(&interrupted, &signal)?;

    let id = super::new_ulid()?;
    let data_path = super::tmp_dir()?.join(format!("{id}.data"));
    let data = File::create(&data_path)?;
    let tee: &mut dyn Write = stdout;
    run_read_loop(
        reader,
        Some(tee),
        data,
        data_path,
        id,
        attrs,
        &interrupted,
        &signal,
        save_on_error,
    )
}

/// Reports whether a trapped signal has been recorded.
///
/// Returns `None` while `interrupted` is clear. Once it is set, returns an
/// `ErrorKind::Interrupted` error together with the signal number read from
/// `signal`. The error reads "terminated by signal" for SIGTERM and
/// "interrupted by signal" otherwise.
#[inline]
fn check_interrupted(interrupted: &AtomicBool, signal: &AtomicI32) -> Option<(io::Error, i32)> {
    // Acquire: if the writer stores the signal number and then sets the flag with
    // Release (or stronger) ordering, seeing the flag here guarantees the number is
    // visible too. With Relaxed on both sides, `signal` could still read as 0.
    if !interrupted.load(Ordering::Acquire) {
        return None;
    }
    let signo = signal.load(Ordering::Relaxed);
    let msg = match signo {
        SIGTERM => "terminated by signal",
        _ => "interrupted by signal",
    };
    Some((io::Error::new(io::ErrorKind::Interrupted, msg), signo))
}

fn run_read_loop<R: Read>(
    reader: &mut R,
    mut tee: Option<&mut dyn Write>,
    mut data: File,
    data_path: PathBuf,
    id: String,
    attrs: BTreeMap<String, String>,
    interrupted: &Arc<AtomicBool>,
    signal: &Arc<AtomicI32>,
    save_on_error: bool,
) -> io::Result<String> {
    let mut sample = Vec::with_capacity(512);
    let mut total = 0i64;
    let mut buf = [0u8; 65536];
    loop {
        if let Some((err, signo)) = check_interrupted(interrupted, signal) {
            return save_or_abort_partial(
                id,
                data_path,
                &sample,
                total,
                attrs,
                err,
                PartialSaveOptions {
                    save_on_error,
                    save_empty: true,
                    signal: Some(signo),
                },
            );
        }
        let n = match reader.read(&mut buf) {
            Ok(n) => n,
            Err(err) => {
                return save_or_abort_partial(
                    id,
                    data_path,
                    &sample,
                    total,
                    attrs,
                    err,
                    PartialSaveOptions {
                        save_on_error,
                        save_empty: false,
                        signal: None,
                    },
                );
            }
        };
        if n == 0 {
            if let Some((err, signo)) = check_interrupted(interrupted, signal) {
                return save_or_abort_partial(
                    id,
                    data_path,
                    &sample,
                    total,
                    attrs,
                    err,
                    PartialSaveOptions {
                        save_on_error,
                        save_empty: true,
                        signal: Some(signo),
                    },
                );
            }
            break;
        }
        let sample_len = sample.len();
        if sample_len < 512 {
            let need = (512 - sample_len).min(n);
            sample.extend_from_slice(&buf[..need]);
        }
        if let Err(err) = data.write_all(&buf[..n]) {
            let _ = fs::remove_file(&data_path);
            return Err(err);
        }
        total += n as i64;
        if let Some(ref mut out) = tee {
            if let Err(err) = out.write_all(&buf[..n]) {
                drop(data);
                if err.kind() == io::ErrorKind::BrokenPipe {
                    return super::finalize_saved_entry(id, data_path, &sample, total, attrs);
                }
                return save_or_abort_partial(
                    id,
                    data_path,
                    &sample,
                    total,
                    attrs,
                    err,
                    PartialSaveOptions {
                        save_on_error,
                        save_empty: false,
                        signal: None,
                    },
                );
            }
        }
    }
    drop(data);
    super::finalize_saved_entry(id, data_path, &sample, total, attrs)
}

fn save_or_abort_partial(
    id: String,
    data_path: PathBuf,
    sample: &[u8],
    total: i64,
    mut attrs: BTreeMap<String, String>,
    err: io::Error,
    options: PartialSaveOptions,
) -> io::Result<String> {
    if !options.save_on_error || (total == 0 && !options.save_empty) {
        let _ = fs::remove_file(&data_path);
        return Err(err);
    }
    attrs.insert("partial".into(), "true".into());
    super::finalize_saved_entry(id.clone(), data_path, sample, total, attrs)?;
    Err(io::Error::other(PartialSavedError {
        id,
        cause: err,
        signal: options.signal,
    }))
}

/// RAII guard that keeps the SIGINT and SIGTERM handlers installed by
/// `register_signal` alive for its lifetime and unregisters them on drop.
struct SignalGuard {
    ids: [SigId; 2],
}

impl SignalGuard {
    /// Installs handlers for SIGINT and SIGTERM that report into `flag` and `signal`.
    ///
    /// # Errors
    /// Returns the registration error if either handler cannot be installed. In
    /// that case any handler this call already installed is removed again, so a
    /// failed construction leaves no handler behind.
    fn new(flag: &Arc<AtomicBool>, signal: &Arc<AtomicI32>) -> io::Result<Self> {
        let sigint = register_signal(SIGINT, flag, signal)?;
        let sigterm = match register_signal(SIGTERM, flag, signal) {
            Ok(id) => id,
            Err(err) => {
                // No guard exists yet whose Drop would remove the SIGINT handler,
                // so undo it here instead of leaking it for the rest of the process.
                low_level::unregister(sigint);
                return Err(err);
            }
        };
        Ok(Self {
            ids: [sigint, sigterm],
        })
    }
}

impl Drop for SignalGuard {
    fn drop(&mut self) {
        for &id in &self.ids {
            low_level::unregister(id);
        }
    }
}

fn register_signal(
    signo: i32,
    flag: &Arc<AtomicBool>,
    signal: &Arc<AtomicI32>,
) -> io::Result<SigId> {
    let flag = Arc::clone(flag);
    let signal = Arc::clone(signal);
    unsafe {
        low_level::register(signo, move || {
            signal.store(signo, Ordering::Relaxed);
            flag.store(true, Ordering::Relaxed);
        })
    }
    .map_err(io::Error::other)
}