use std::collections::VecDeque;
use std::io::{self, Write};
use std::process::ChildStdin;
use std::sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc, Mutex,
};
#[cfg(not(feature = "unstable"))]
use itertools::Itertools;
#[cfg(not(feature = "unstable"))]
use std::ops::Deref;
use crossbeam_channel::{bounded, Receiver, Sender};
use hashbrown::HashMap;
use crate::common::{CompileError, FileSpec, Notification, Result};
use crate::runtime::{command::command_for_write, Str};
/// Capacity of the bounded channel that carries requests to each writer thread; also the maximum
/// number of finished write guards a `FileHandle` keeps around for reuse.
const IO_CHAN_SIZE: usize = 8;
/// Buffered-byte threshold (64 KiB): once a handle's pending data exceeds this, the current batch
/// is sent to the writer thread.
const BUFFER_SIZE: usize = 64 << 10;
/// A source of output writers: named files, command pipes, and a stdout stream.
///
/// Implementations must be cloneable and usable from other threads
/// (`Clone + 'static + Send + Sync`).
pub trait FileFactory: Clone + 'static + Send + Sync {
/// Writer type returned by `build`.
type Output: io::Write;
/// Writer type returned by `stdout`.
type Stdout: io::Write;
/// Returns the `ChildStdin` produced by `command_for_write(cmd)` (presumably the stdin pipe of
/// a process spawned from `cmd`; see that function for details).
fn cmd(&self, cmd: &[u8]) -> io::Result<ChildStdin> {
command_for_write(cmd)
}
/// Opens a writer for `path` according to `spec`.
fn build(&self, path: &str, spec: FileSpec) -> io::Result<Self::Output>;
/// Returns a writer for standard output.
fn stdout(&self) -> Self::Stdout;
}
/// Any cloneable, thread-safe callable mapping `(path, spec)` to a writer is a `FileFactory`.
///
/// `build` simply invokes the callable; `stdout` is the `grep_cli` standard stream with
/// automatic color choice.
impl<W, T> FileFactory for T
where
    W: io::Write,
    T: Fn(&str, FileSpec) -> io::Result<W> + Clone + 'static + Send + Sync,
{
    type Output = W;
    type Stdout = grep_cli::StandardStream;

    /// Delegates to the wrapped callable.
    fn build(&self, path: &str, spec: FileSpec) -> io::Result<W> {
        self(path, spec)
    }

    /// Returns a fresh `grep_cli` stdout stream.
    fn stdout(&self) -> Self::Stdout {
        grep_cli::stdout(termcolor::ColorChoice::Auto)
    }
}
/// Concrete writer type used for files on the real filesystem.
type FileWriter = std::fs::File;

/// Opens `path` for writing, creating it if it does not exist.
///
/// * `FileSpec::Append` — writes are appended to the existing contents.
/// * `FileSpec::Trunc` — existing contents are discarded before writing.
/// * any other spec — the file is opened for writing with neither flag set.
///
/// # Errors
/// Returns the underlying I/O error if the file cannot be opened.
fn open_file(path: &str, spec: FileSpec) -> io::Result<FileWriter> {
    std::fs::OpenOptions::new()
        .write(true)
        .create(true)
        .append(matches!(spec, FileSpec::Append))
        // Without this flag a `Trunc` open would overwrite from offset 0 but leave any longer
        // previous contents in place after the newly written bytes.
        .truncate(matches!(spec, FileSpec::Trunc))
        .open(path)
}
/// Returns the default `FileFactory`: the `open_file` function itself, which qualifies as a
/// factory through the blanket implementation for `Fn(&str, FileSpec) -> io::Result<W>`.
pub fn default_factory() -> impl FileFactory {
open_file
}
/// Builds a `FileFactory` whose "stdout" is the file named `fname`.
///
/// `fname` is opened once up front with `FileSpec::Trunc`; a failure there is returned to the
/// caller. Afterwards every `stdout()` call reopens `fname` with `FileSpec::Append` and panics
/// with "failed to open stdout" if that fails. Named files are opened with `open_file`.
pub fn factory_from_file(fname: &str) -> io::Result<impl FileFactory> {
    // Factory that remembers the path standing in for stdout.
    #[derive(Clone)]
    struct FileStdout(String);

    impl FileFactory for FileStdout {
        type Output = FileWriter;
        type Stdout = FileWriter;

        fn build(&self, path: &str, spec: FileSpec) -> io::Result<Self::Output> {
            open_file(path, spec)
        }

        fn stdout(&self) -> Self::Stdout {
            let target: &str = &self.0;
            open_file(target, FileSpec::Append).expect("failed to open stdout")
        }
    }

    // Kept alive until the function returns, exactly like the original binding.
    let _initial_open = open_file(fname, FileSpec::Trunc)?;
    Ok(FileStdout(String::from(fname)))
}
/// Spawns a writer thread and returns the `RawHandle` used to send it requests.
///
/// * `f` — passed to `receive_thread`, which calls it to open the underlying writer.
/// * `is_stdout` — when true and stdout is a TTY (per `grep_cli::is_tty_stdout`), the handle is
///   marked as line-buffered.
///
/// The channel to the thread is bounded at `IO_CHAN_SIZE` requests, and the returned handle shares
/// an error slot with the thread.
fn build_handle<W, F>(f: F, is_stdout: bool) -> RawHandle
where
    W: io::Write,
    F: Fn(FileSpec) -> io::Result<W> + Send + 'static,
{
    let (sender, requests) = bounded(IO_CHAN_SIZE);
    let error = Arc::new(Mutex::new(None));
    let thread_error = Arc::clone(&error);
    std::thread::spawn(move || receive_thread(requests, thread_error, f));
    let line_buffer = is_stdout && grep_cli::is_tty_stdout();
    RawHandle {
        error,
        sender,
        line_buffer,
    }
}
/// A per-owner view of the open output files and commands.
///
/// Each `Registry` caches its own `FileHandle`s but shares the `Root` behind `global` with its
/// clones.
pub struct Registry {
// Shared root that creates and hands out the underlying raw handles.
global: Arc<dyn Root>,
// Locally cached handles for named files, keyed by path.
files: HashMap<Str<'static>, FileHandle>,
// Locally cached handles for commands, keyed by command string.
cmds: HashMap<Str<'static>, FileHandle>,
// This registry's handle for standard output.
stdout: FileHandle,
}
impl Registry {
    /// Creates a registry (and a fresh shared root) that opens writers through `f`.
    pub fn from_factory(f: impl FileFactory) -> Registry {
        let root = RootImpl::from_factory(f);
        let stdout = root.get_stdout().into_handle();
        Registry {
            global: Arc::new(root),
            files: Default::default(),
            cmds: Default::default(),
            stdout,
        }
    }

    /// Looks up (creating if necessary) the handle for `name`.
    ///
    /// `None` selects stdout. Otherwise `FileSpec::Cmd` treats `name` as a command and the other
    /// specs treat it as a file path.
    pub fn get_handle<'a>(
        &mut self,
        name: Option<&Str<'a>>,
        fspec: FileSpec,
    ) -> Result<&mut FileHandle> {
        match name {
            None => self.get_file(None),
            Some(target) => match fspec {
                FileSpec::Cmd => self.get_cmd(target),
                FileSpec::Trunc | FileSpec::Append => self.get_file(Some(target)),
            },
        }
    }

    /// Closes the file or command named `path_or_cmd`.
    ///
    /// A locally cached file handle is tried first, then a locally cached command handle; if
    /// neither exists the request is forwarded to the shared root.
    pub fn close<'a>(&mut self, path_or_cmd: &Str<'a>) -> Result<()> {
        let key = path_or_cmd.clone().unmoor();
        if let Some(fh) = self.files.get_mut(&key) {
            return fh.close();
        }
        if let Some(ch) = self.cmds.get_mut(&key) {
            return ch.close();
        }
        path_or_cmd.with_bytes(|bs| self.global.close(bs))
    }

    /// Returns the handle for the command `cmd`, obtaining it from the shared root on first use.
    pub fn get_cmd<'a>(&mut self, cmd: &Str<'a>) -> Result<&mut FileHandle> {
        let global = &self.global;
        let handle = self
            .cmds
            .entry(cmd.clone().unmoor())
            .or_insert_with(|| cmd.with_bytes(|bs| global.get_command(bs)).into_handle());
        Ok(handle)
    }

    /// Returns the handle for the file `name`, or the stdout handle when `name` is `None`.
    ///
    /// # Errors
    /// Fails if a not-yet-cached file name is not valid UTF-8.
    pub fn get_file<'a>(&mut self, name: Option<&Str<'a>>) -> Result<&mut FileHandle> {
        use hashbrown::hash_map::Entry;
        let path = match name {
            Some(path) => path,
            None => return Ok(&mut self.stdout),
        };
        let global = &self.global;
        let slot = match self.files.entry(path.clone().unmoor()) {
            Entry::Occupied(existing) => return Ok(existing.into_mut()),
            Entry::Vacant(slot) => slot,
        };
        let raw = path.with_bytes(|bs| match std::str::from_utf8(bs) {
            Ok(s) => Ok(global.get_handle(s)),
            Err(e) => err!("invalid UTF8 in filename: {}", e),
        })?;
        Ok(slot.insert(raw.into_handle()))
    }

    /// Flushes and drops every cached file and command handle (stdout is left alone).
    ///
    /// All handles are flushed even if some fail; the last failure seen is returned.
    pub fn destroy_and_flush_all_files(&mut self) -> Result<()> {
        let open_handles = self.files.drain().chain(self.cmds.drain());
        open_handles.fold(Ok(()), |last_error, (_, mut fh)| {
            let res = fh.flush();
            if res.is_err() {
                res
            } else {
                last_error
            }
        })
    }
}
impl Clone for Registry {
fn clone(&self) -> Registry {
Registry {
global: self.global.clone(),
files: Default::default(),
cmds: Default::default(),
stdout: self.stdout.raw().into_handle(),
}
}
}
/// Type-erased interface to the shared store of raw handles (see `RootImpl`).
trait Root: 'static + Send + Sync {
/// Returns the raw handle for the command `cmd`.
fn get_command(&self, cmd: &[u8]) -> RawHandle;
/// Returns the raw handle for the file `fname`.
fn get_handle(&self, fname: &str) -> RawHandle;
/// Returns the raw handle for standard output.
fn get_stdout(&self) -> RawHandle;
/// Closes the command or file registered under `fname`.
fn close(&self, fname: &[u8]) -> Result<()>;
}
/// Shared store of raw handles, generic over the factory used to open writers.
struct RootImpl<F> {
// Raw handles for named files, keyed by path.
handles: Mutex<HashMap<String, RawHandle>>,
// Raw handles for commands, keyed by the command bytes.
commands: Mutex<HashMap<Box<[u8]>, RawHandle>>,
// Raw handle for standard output, created when the root is constructed.
stdout_raw: RawHandle,
// Factory cloned into each writer thread to open its writer.
file_factory: F,
}
impl<F: FileFactory> RootImpl<F> {
fn from_factory(file_factory: F) -> RootImpl<F> {
let local_factory = file_factory.clone();
let stdout_raw = build_handle(
move |_append| Ok(local_factory.stdout()),
true,
);
RootImpl {
handles: Default::default(),
commands: Default::default(),
stdout_raw,
file_factory,
}
}
}
impl<F: FileFactory> Root for RootImpl<F> {
    /// Closes the shared handle registered under `fname`, if any.
    ///
    /// Commands are checked first, then files. Names with no registered handle (including names
    /// that are not valid UTF-8 and match no command) succeed without doing anything.
    fn close(&self, fname: &[u8]) -> Result<()> {
        // Clone the handle out while holding the lock; the close itself runs after the lock guard
        // (a temporary of the `let` statement) has been released.
        let command = self.commands.lock().unwrap().get(fname).cloned();
        if let Some(h) = command {
            return h.into_handle().close();
        }
        let path = match std::str::from_utf8(fname) {
            Ok(path) => path,
            Err(_) => return Ok(()),
        };
        let file = self.handles.lock().unwrap().get(path).cloned();
        match file {
            Some(h) => h.into_handle().close(),
            None => Ok(()),
        }
    }

    /// Returns the raw handle for `cmd`, spawning its writer thread on first use.
    fn get_command(&self, cmd: &[u8]) -> RawHandle {
        let mut cmds = self.commands.lock().unwrap();
        if let Some(h) = cmds.get(cmd) {
            return h.clone();
        }
        let local_factory = self.file_factory.clone();
        let local_name = Box::<[u8]>::from(cmd);
        let global_name = local_name.clone();
        // A command pipe is not stdout, so it must not inherit TTY line buffering; only the
        // stdout handle passes `true` here.
        let handle = build_handle(move |_| local_factory.cmd(&*local_name), false);
        let _old = cmds.insert(global_name, handle.clone());
        debug_assert!(
            _old.is_none(),
            "duplicate insertion of command: {}",
            String::from_utf8_lossy(cmd)
        );
        handle
    }

    /// Returns the raw handle for the file `fname`, spawning its writer thread on first use.
    fn get_handle(&self, fname: &str) -> RawHandle {
        let mut handles = self.handles.lock().unwrap();
        if let Some(h) = handles.get(fname) {
            return h.clone();
        }
        let local_factory = self.file_factory.clone();
        let local_name = String::from(fname);
        let global_name = local_name.clone();
        let handle = build_handle(
            move |append| local_factory.build(local_name.as_str(), append),
            false,
        );
        handles.insert(global_name, handle.clone());
        handle
    }

    /// Returns a clone of the raw stdout handle.
    fn get_stdout(&self) -> RawHandle {
        self.stdout_raw.clone()
    }
}
/// A buffered, single-owner front end to a shared writer thread.
pub struct FileHandle {
// Channel and error slot shared with the writer thread.
raw: RawHandle,
// Finished guards kept for reuse (at most `IO_CHAN_SIZE` of them).
#[allow(clippy::vec_box)]
old_guards: Vec<Box<WriteGuard>>,
// Guards whose write requests have been sent, in send order.
guards: VecDeque<Box<WriteGuard>>,
// Buffer currently being filled; not yet sent to the writer thread.
cur_batch: Box<WriteGuard>,
}
impl FileHandle {
/// Returns a clone of the underlying shared `RawHandle`.
fn raw(&self) -> RawHandle {
self.raw.clone()
}
/// Retires finished guards from the front of `guards`, keeping up to `IO_CHAN_SIZE` of them in
/// `old_guards` for reuse.
///
/// Returns the recorded writer error if a queued guard reports `RequestStatus::Error`.
fn clear_guards(&mut self) -> Result<()> {
let mut done_count = 0;
for (i, guard) in self.guards.iter().enumerate() {
match guard.status() {
// Guards are queued in the order their requests were sent; stop at the first
// unfinished one.
RequestStatus::Ongoing => break,
// NOTE(review): this stores the index of the last finished guard rather than a
// count, so that guard itself stays queued. Possibly deliberate (the writer thread's
// `Request` drop still reads the status after it has been set) — confirm before
// changing this to `i + 1`.
RequestStatus::Ok => done_count = i,
RequestStatus::Error => return Err(self.read_error()),
}
}
for _ in 0..done_count {
let old = self.guards.pop_front().unwrap();
if self.old_guards.len() < IO_CHAN_SIZE {
self.old_guards.push(old);
}
}
Ok(())
}
/// Returns an empty guard, reusing a retired one when available.
fn guard(&mut self) -> Box<WriteGuard> {
if let Some(mut g) = self.old_guards.pop() {
g.activate();
g
} else {
Box::new(WriteGuard::default())
}
}
/// Returns the error recorded by the writer thread, or a generic internal error if none was
/// recorded or the error mutex is poisoned.
fn read_error(&self) -> CompileError {
const BAD_SHUTDOWN_MSG: &str = "internal error: (writer?) thread did not shut down cleanly";
if let Ok(lock) = self.raw.error.lock() {
match &*lock {
Some(err) => err.clone(),
None => CompileError(BAD_SHUTDOWN_MSG.into()),
}
} else {
CompileError(BAD_SHUTDOWN_MSG.into())
}
}
/// Sends the buffered data to the writer thread; does nothing if the buffer is empty.
///
/// With `upto = Some(ix)`, only the first `ix` bytes are sent, the request asks the writer to
/// flush, and the remaining bytes become the start of the next buffer. With `None` everything
/// is sent without a flush request.
fn clear_batch(&mut self, upto: Option<usize>) -> Result<()> {
if self.cur_batch.data.is_empty() {
return Ok(());
}
let (flush, upto) = if let Some(ix) = upto {
(true, ix)
} else {
(false, self.cur_batch.data.len())
};
self.clear_guards()?;
let mut next_batch = self.guard();
// Move any bytes past `upto` into the next buffer before building the request.
self.cur_batch.peel(upto, &mut *next_batch);
let req = self.cur_batch.request(flush);
self.raw.sender.send(req).unwrap();
// After the swap `next_batch` holds the guard that was just sent; it is queued so that it
// outlives the request, which points into it.
std::mem::swap(&mut next_batch, &mut self.cur_batch);
self.guards.push_back(next_batch);
Ok(())
}
/// Appends every string in `ss` to the buffer and sends the buffer when required.
///
/// The buffer is sent when the handle is line-buffered and a newline was written, or when the
/// buffered size exceeds `BUFFER_SIZE`.
pub fn write_all<'a>(&mut self, ss: &[&Str<'a>], spec: FileSpec) -> Result<()> {
let cur_len = self.cur_batch.data.len();
let mut added_bytes = 0;
let mut last_line = None;
for s in ss.iter() {
// NOTE(review): raw access to the string's bytes; assumes the pointer from `get_bytes`
// stays valid for the duration of this call — confirm against `Str`.
let bs = unsafe { &*s.get_bytes() };
self.cur_batch.extend(&*bs, spec);
if self.raw.line_buffer {
// `memchr` finds the first newline in this string; the cut point is just past it,
// expressed as an offset into the whole buffer.
if let Some(ix) = memchr::memchr(b'\n', bs) {
last_line = Some(cur_len + added_bytes + ix + 1);
}
}
added_bytes += bs.len();
}
if (self.raw.line_buffer && last_line.is_some()) || (added_bytes + cur_len > BUFFER_SIZE) {
self.clear_batch(last_line)?;
}
Ok(())
}
/// Writes a single string; shorthand for `write_all(&[s], spec)`.
pub fn write<'a>(&mut self, s: &Str<'a>, spec: FileSpec) -> Result<()> {
self.write_all(&[s], spec)
}
/// Sends any buffered data, then a flush request, and blocks until the writer thread
/// acknowledges it. Returns the writer's error if the flush request finished with an error.
pub fn flush(&mut self) -> Result<()> {
self.clear_batch(None)?;
let (n, req) = Request::flush();
self.raw.sender.send(req).unwrap();
n.1.wait();
// Presumably every earlier request from this handle has been handled by the time the flush
// is acknowledged, so the queued guards can be dropped.
self.guards.clear();
if let RequestStatus::Error = n.0.read() {
Err(self.read_error())
} else {
Ok(())
}
}
/// Sends any buffered data followed by a close request; does not wait for the writer thread.
pub fn close(&mut self) -> Result<()> {
self.clear_batch(None)?;
self.raw.sender.send(Request::Close).unwrap();
Ok(())
}
}
impl Drop for FileHandle {
// Flushes on drop, ignoring any error, then marks the never-sent current buffer as finished so
// that `WriteGuard`'s drop assertion (status must not be `Ongoing`) holds for it.
fn drop(&mut self) {
let _ = self.flush();
self.cur_batch.status.set_ok();
}
}
/// Atomic completion status of a request; the default value (0) means `Ongoing`.
#[derive(Default)]
struct ErrorCode(AtomicUsize);

/// The states an `ErrorCode` can hold, with their numeric encodings.
#[derive(Debug)]
enum RequestStatus {
    Ongoing = 0,
    Ok = 1,
    Error = 2,
}

impl ErrorCode {
    /// Loads the current status (acquire ordering).
    ///
    /// # Panics
    /// Panics if the stored value is not one of the `RequestStatus` encodings.
    fn read(&self) -> RequestStatus {
        let raw = self.0.load(Ordering::Acquire);
        if raw == RequestStatus::Ongoing as usize {
            RequestStatus::Ongoing
        } else if raw == RequestStatus::Ok as usize {
            RequestStatus::Ok
        } else if raw == RequestStatus::Error as usize {
            RequestStatus::Error
        } else {
            unreachable!()
        }
    }

    /// Marks the request as completed successfully.
    fn set_ok(&self) {
        self.publish(RequestStatus::Ok);
    }

    /// Marks the request as failed.
    fn set_error(&self) {
        self.publish(RequestStatus::Error);
    }

    /// Stores `status` with release ordering.
    fn publish(&self, status: RequestStatus) {
        self.0.store(status as usize, Ordering::Release);
    }
}
/// A message sent from a `FileHandle` to its writer thread.
enum Request {
/// Write `data`, reporting completion through `status`.
///
/// Both pointers are created from a `WriteGuard` (see `WriteGuard::request`); the receiver
/// dereferences them, so the guard must stay alive until the status leaves `Ongoing`.
Write {
data: *const [u8],
status: *const ErrorCode,
spec: FileSpec,
flush: bool,
},
/// Flush the writer, then record the outcome and notify the waiter.
Flush(Arc<(ErrorCode, Notification)>),
/// Flush and close the current writer.
Close,
}
// The raw pointers in `Write` make `Request` `!Send` by default.
// NOTE(review): soundness relies on the guard-lifetime protocol described above — confirm.
unsafe impl Send for Request {}
impl Request {
    /// Creates a flush request together with the shared status/notification pair the sender
    /// waits on.
    fn flush() -> (Arc<(ErrorCode, Notification)>, Request) {
        let shared = Arc::new((ErrorCode::default(), Notification::default()));
        let request = Request::Flush(Arc::clone(&shared));
        (shared, request)
    }

    /// Number of payload bytes carried by this request (zero for flush and close).
    fn size(&self) -> usize {
        if let Request::Write { data, .. } = self {
            // SAFETY (assumed protocol): the pointer targets a `WriteGuard` buffer that the
            // sending handle keeps alive until this request has been completed.
            let bytes: &[u8] = unsafe { &**data };
            bytes.len()
        } else {
            0
        }
    }

    /// Applies `f` to this request's status code; a flush request also notifies its waiter
    /// afterwards. Close requests carry no status and are ignored.
    fn set_code(&self, mut f: impl FnMut(&ErrorCode)) {
        match self {
            Request::Write { status, .. } => {
                // SAFETY (assumed protocol): see `size`; the status lives in the same guard.
                let code: &ErrorCode = unsafe { &**status };
                f(code)
            }
            Request::Flush(shared) => {
                let (code, done) = &**shared;
                f(code);
                done.notify();
            }
            Request::Close => {}
        }
    }
}
// Sanity checks run when a request is destroyed: a write must already have a final status and a
// flush must already have notified its waiter. Either assertion failing panics the dropping thread.
impl Drop for Request {
fn drop(&mut self) {
match self {
Request::Write { status, .. } => {
// Reads through the raw pointer, so the pointed-to status must still be alive here.
let status = unsafe { &**status }.read();
assert!(!matches!(status, RequestStatus::Ongoing));
}
Request::Flush(n) => {
assert!(n.1.has_been_notified());
}
Request::Close => {}
}
}
}
/// Owns the bytes and completion status of one write, so that a `Request::Write` can point at
/// them while the writer thread works.
#[derive(Default)]
struct WriteGuard {
// Buffered bytes to be written.
data: Vec<u8>,
// Completion status, updated by the writer thread.
status: ErrorCode,
// Spec passed with the most recent `extend` call.
spec: FileSpec,
}
impl WriteGuard {
fn extend(&mut self, bs: &[u8], spec: FileSpec) {
self.data.extend(bs);
self.spec = spec;
}
fn peel(&mut self, bytes: usize, next: &mut WriteGuard) {
if bytes < self.data.len() {
next.data.extend(self.data[bytes..].iter().cloned());
self.data.truncate(bytes);
}
}
fn request(&self, flush: bool) -> Request {
Request::Write {
data: &self.data[..],
status: &self.status,
spec: self.spec,
flush,
}
}
fn status(&self) -> RequestStatus {
self.status.read()
}
fn activate(&mut self) {
self.status = ErrorCode::default();
self.spec = FileSpec::Trunc;
self.data.clear();
}
}
// Dropping a guard whose status is still `Ongoing` panics; a `Request::Write` built from this
// guard holds raw pointers into it, so presumably this catches a guard being freed while the
// writer thread may still use it.
impl Drop for WriteGuard {
fn drop(&mut self) {
let status = self.status();
assert!(!matches!(status, RequestStatus::Ongoing));
}
}
/// Cloneable connection to a writer thread.
#[derive(Clone)]
struct RawHandle {
// Error slot shared with the writer thread; filled in when the writer fails.
error: Arc<Mutex<Option<CompileError>>>,
// Sending side of the request channel.
sender: Sender<Request>,
// Whether writes should be sent and flushed at newlines.
line_buffer: bool,
}
impl RawHandle {
fn into_handle(self) -> FileHandle {
FileHandle {
cur_batch: Default::default(),
raw: self,
guards: Default::default(),
old_guards: Default::default(),
}
}
}
/// Requests accumulated by the writer thread to be issued together.
#[derive(Default)]
struct WriteBatch {
// Slices for each queued write. The `'static` lifetime is nominal: the slices are built from
// the raw pointers carried by the requests.
io_vec: Vec<io::IoSlice<'static>>,
// Every request in the batch, kept so their status codes can be set afterwards.
requests: Vec<Request>,
// Number of write requests in the batch.
n_writes: usize,
// Total payload bytes across the queued writes.
write_bytes: usize,
// Set when any request in the batch asks for a flush.
flush: bool,
// Set when the batch contains a close request.
close: bool,
}
/// Writes every queued slice of `batch` to `w` using vectored I/O (`unstable` feature only).
///
/// Returns `Ok(true)` on success and the first I/O error otherwise.
#[cfg(feature = "unstable")]
fn write_all(batch: &mut WriteBatch, w: &mut impl Write) -> io::Result<bool> {
    let slices = &mut batch.io_vec[..];
    w.write_all_vectored(slices)?;
    Ok(true)
}
/// Writes every queued slice of `batch` to `w`, one `write_all` call per slice (used when the
/// `unstable` feature is off).
///
/// Returns `Ok(true)` when every write succeeds; `fold_results` stops at the first error and
/// returns it.
#[cfg(not(feature = "unstable"))]
fn write_all(batch: &mut WriteBatch, w: &mut impl Write) -> io::Result< bool> {
batch
.io_vec
.iter()
.map(|slice| {
let buf: &[u8] = slice.deref();
w.write_all(buf)
})
.fold_results(true, |a, _| a)
}
impl WriteBatch {
    /// Number of write requests currently in the batch.
    fn n_writes(&self) -> usize {
        self.n_writes
    }

    /// Writes the batch to `w`, flushing if a flush or close was requested, then marks every
    /// request as successful and empties the batch.
    ///
    /// Returns whether the batch contained a close request. On an I/O error the batch is left
    /// untouched so the caller can mark its requests as failed.
    fn issue(&mut self, w: &mut impl Write) -> io::Result<bool> {
        if self.write_bytes > 0 {
            write_all(self, w)?;
        }
        let should_close = self.close;
        if should_close || self.flush {
            w.flush()?;
        }
        self.clear();
        Ok(should_close)
    }

    /// Spec of the first write request in the batch, or the default spec if there is none.
    fn get_spec(&self) -> FileSpec {
        self.requests
            .iter()
            .find_map(|req| match req {
                Request::Write { spec, .. } => Some(*spec),
                Request::Flush(_) | Request::Close => None,
            })
            .unwrap_or_default()
    }

    /// Adds `req` to the batch and returns true once the batch needs a flush or close.
    fn push(&mut self, req: Request) -> bool {
        match &req {
            Request::Write { data, flush, .. } => {
                // SAFETY (assumed protocol): the sender keeps the pointed-to buffer alive until
                // the request's status has been set; `'static` is nominal.
                let bytes: &'static [u8] = unsafe { &**data };
                self.write_bytes += bytes.len();
                self.io_vec.push(io::IoSlice::new(bytes));
                self.n_writes += 1;
                if *flush {
                    self.flush = true;
                }
            }
            Request::Flush(_) => self.flush = true,
            Request::Close => self.close = true,
        }
        self.requests.push(req);
        self.close || self.flush
    }

    /// Empties the batch, applying `f` to the status code of every request (in order) as it is
    /// removed.
    fn clear_batch(&mut self, mut f: impl FnMut(&ErrorCode)) {
        self.io_vec.clear();
        self.requests
            .drain(..)
            .for_each(|req| req.set_code(&mut f));
        self.n_writes = 0;
        self.write_bytes = 0;
        self.flush = false;
        self.close = false;
    }

    /// Empties the batch, marking every request as failed.
    fn clear_error(&mut self) {
        self.clear_batch(ErrorCode::set_error)
    }

    /// Empties the batch, marking every request as successful.
    fn clear(&mut self) {
        self.clear_batch(ErrorCode::set_ok)
    }
}
fn receive_thread<W: io::Write>(
receiver: Receiver<Request>,
error: Arc<Mutex<Option<CompileError>>>,
f: impl Fn(FileSpec) -> io::Result<W>,
) {
let mut batch = WriteBatch::default();
if let Err(e) = receive_loop(&receiver, &mut batch, f) {
{
let mut err = error.lock().unwrap();
*err = Some(CompileError(format!("{}", e)));
}
batch.clear_error();
while let Ok(req) = receiver.recv() {
req.set_code(ErrorCode::set_error)
}
}
}
/// Receives requests until every sender is gone, grouping them into batches and issuing each
/// batch to the writer.
///
/// The writer is opened lazily with `f` (using the spec of the batch's first write) and dropped
/// after a batch that contained a close request, so a later write reopens it.
///
/// Returns the first I/O error from opening or writing; the failing batch is left in `batch`.
fn receive_loop<W: io::Write>(
receiver: &Receiver<Request>,
batch: &mut WriteBatch,
f: impl Fn(FileSpec) -> io::Result<W>,
) -> io::Result<()> {
// Limits on how much is gathered into one batch before it is issued.
const MAX_BATCH_BYTES: usize = 1 << 20;
const MAX_BATCH_SIZE: usize = 1 << 10;
let mut writer = None;
// Block for the first request of each batch.
while let Ok(req) = receiver.recv() {
let mut batch_bytes = req.size();
if !batch.push(req) {
// Opportunistically add whatever is already queued, stopping at a flush/close request
// or once a batch limit is reached.
while let Ok(req) = receiver.try_recv() {
batch_bytes += req.size();
if batch.push(req)
|| batch.n_writes() >= MAX_BATCH_SIZE
|| batch_bytes >= MAX_BATCH_BYTES
{
break;
}
}
}
if writer.is_none() {
// With no open writer and nothing to write (only flush/close requests), acknowledge
// the batch without opening anything.
if batch.n_writes() == 0 {
batch.clear();
continue;
}
writer = Some(f(batch.get_spec())?);
}
// `issue` returns true when the batch contained a close request.
if batch.issue(writer.as_mut().unwrap())? {
writer = None;
}
}
Ok(())
}
/// In-memory stand-ins for the filesystem, for use in tests.
pub mod testing {
    use super::*;

    /// A fake filesystem: one in-memory stdout plus in-memory files keyed by path.
    #[derive(Clone, Default)]
    pub struct FakeFs {
        pub stdout: FakeFile,
        named: Arc<Mutex<HashMap<String, FakeFile>>>,
    }

    impl FakeFs {
        /// Returns the fake file registered at `path`, if one has been built.
        pub fn get_handle(&self, path: &str) -> Option<FakeFile> {
            self.named.lock().unwrap().get(path).cloned()
        }
    }

    impl FileFactory for FakeFs {
        type Output = FakeFile;
        type Stdout = FakeFile;

        /// Returns the file at `path`, reopening it with `spec` if it already exists and
        /// registering a new empty file otherwise.
        fn build(&self, path: &str, spec: FileSpec) -> io::Result<Self::Output> {
            let mut named = self.named.lock().unwrap();
            let existing = named.get(path).cloned();
            if let Some(file) = existing {
                file.reopen(spec);
                return Ok(file);
            }
            let fresh = FakeFile::default();
            named.insert(String::from(path), fresh.clone());
            Ok(fresh)
        }

        fn stdout(&self) -> Self::Stdout {
            self.stdout.clone()
        }
    }

    /// Shared state behind a `FakeFile`.
    #[derive(Default)]
    struct FakeFileInner {
        data: Mutex<Vec<u8>>,
        poison: AtomicBool,
    }

    impl FakeFileInner {
        /// Fails with "poisoned fake file!" while the poison flag is set.
        fn result(&self) -> io::Result<()> {
            let poisoned = self.poison.load(Ordering::Acquire);
            if !poisoned {
                return Ok(());
            }
            Err(io::Error::new(io::ErrorKind::Other, "poisoned fake file!"))
        }
    }

    /// A cloneable in-memory file; clones share the same contents and poison flag.
    #[derive(Clone, Default)]
    pub struct FakeFile(Arc<FakeFileInner>);

    impl FakeFile {
        /// Sets or clears the poison flag; while set, writes and flushes fail.
        pub fn set_poison(&self, p: bool) {
            self.0.poison.store(p, Ordering::Release);
        }

        /// Returns a copy of the current contents.
        pub fn read_data(&self) -> Vec<u8> {
            self.0.data.lock().unwrap().to_vec()
        }

        /// Simulates reopening: contents are kept for `Append` and cleared for any other spec.
        pub fn reopen(&self, spec: FileSpec) {
            if matches!(spec, FileSpec::Append) {
                return;
            }
            self.clear();
        }

        /// Discards the contents.
        pub fn clear(&self) {
            self.0.data.lock().unwrap().clear();
        }
    }

    impl Write for FakeFile {
        fn write(&mut self, bytes: &[u8]) -> io::Result<usize> {
            self.0.result()?;
            self.0.data.lock().unwrap().extend_from_slice(bytes);
            Ok(bytes.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            self.0.result()
        }

        fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> io::Result<usize> {
            self.0.result()?;
            let mut data = self.0.data.lock().unwrap();
            let written: usize = bufs
                .iter()
                .map(|buf| {
                    data.extend_from_slice(buf);
                    buf.len()
                })
                .sum();
            Ok(written)
        }
    }
}
// Unit tests driving the registry against the in-memory `FakeFs`.
#[cfg(test)]
mod tests {
use super::testing::*;
use super::*;
// Writes to stdout (no name given) in two flushed rounds and checks the concatenated output.
#[test]
fn basic_writing() {
let s1 = Str::from("hello");
let s2 = Str::from(" there");
let fs = FakeFs::default();
let mut reg = Registry::from_factory(fs.clone());
{
let handle = reg
.get_handle( None, FileSpec::default())
.unwrap();
handle.write(&s1, FileSpec::Append).unwrap();
handle.write(&s2, FileSpec::Append).unwrap();
handle.flush().unwrap();
handle.write(&s1, FileSpec::Append).unwrap();
handle.write(&s2, FileSpec::Append).unwrap();
handle.flush().unwrap();
}
let data = fs.stdout.read_data();
assert_eq!(&data[..], "hello therehello there".as_bytes());
}
// After an explicit close, writing with `FileSpec::Trunc` reopens the fake file, which clears
// its earlier contents, so only the final "hello there" remains.
#[test]
fn reopen_named_file() {
let fname_str = "/fake";
let fname = Str::from(fname_str);
let s1 = Str::from("hello");
let s2 = Str::from(" there");
let fs = FakeFs::default();
let mut reg = Registry::from_factory(fs.clone());
{
let handle = reg.get_handle(Some(&fname), FileSpec::default()).unwrap();
handle.write(&s1, FileSpec::Append).unwrap();
handle.write(&s2, FileSpec::Append).unwrap();
handle.flush().unwrap();
handle.write(&s1, FileSpec::Append).unwrap();
handle.write(&s2, FileSpec::Append).unwrap();
}
{
let handle = reg.get_handle(Some(&fname), FileSpec::default()).unwrap();
handle.close().unwrap();
handle.write(&s1, FileSpec::Trunc).unwrap();
handle.write(&s2, FileSpec::Trunc).unwrap();
handle.flush().unwrap();
}
let data = fs.get_handle(fname_str).unwrap().read_data();
assert_eq!(&data[..], "hello there".as_bytes());
}
// Many threads, each with its own registry clone, write to two shared files (occasionally
// closing or flushing them) and to a poisoned file. Every byte must reach the good files, and
// flushing the poisoned file must report an error.
#[test]
fn multithreaded_write() {
const N_THREADS: usize = 100;
const WRITES_PER_THREAD: usize = 1000;
let fs = FakeFs::default();
// Register the bad file up front and poison it so that every write to it fails.
fs.build("/fake/BAD", FileSpec::Trunc)
.unwrap()
.set_poison(true);
let mut threads = Vec::with_capacity(N_THREADS);
{
let reg = Registry::from_factory(fs.clone());
for t in 0..N_THREADS {
let mut treg = reg.clone();
threads.push(std::thread::spawn(move || {
let a = Str::from("A");
let b = Str::from("B");
let fa = Str::from("/fake/A");
let fb = Str::from("/fake/B");
let fbad = Str::from("/fake/BAD");
for i in 0..WRITES_PER_THREAD {
{
let h1 = treg.get_handle(Some(&fa), FileSpec::default()).unwrap();
h1.write(&a, FileSpec::Append).unwrap();
if (t + i) % 100 == 0 {
h1.close().unwrap();
}
}
{
let h2 = treg.get_handle(Some(&fb), FileSpec::default()).unwrap();
h2.write(&b, FileSpec::Trunc).unwrap();
if (t + i) % 105 == 0 {
h2.flush().unwrap();
}
}
{
let h3 = treg.get_handle(Some(&fbad), FileSpec::default()).unwrap();
let _ = h3.write(&a, FileSpec::Append);
if (t + i) % 103 == 0 {
assert!(h3.flush().is_err())
}
}
}
}));
}
}
for t in threads.into_iter() {
t.join().unwrap();
}
let expected_a = vec![b'A'; N_THREADS * WRITES_PER_THREAD];
let expected_b = vec![b'B'; N_THREADS * WRITES_PER_THREAD];
assert_eq!(fs.get_handle("/fake/A").unwrap().read_data(), expected_a);
assert_eq!(fs.get_handle("/fake/B").unwrap().read_data(), expected_b);
}
}