use std::{
fmt, io,
sync::{Arc, Mutex},
};
use crate::Conn;
/// Shared, lock-protected storage for a local-infile callback.
///
/// The callback receives a byte slice and a mutable [`LocalInfile`] writer and
/// reports success or failure as an `io::Result<()>`.
/// NOTE(review): the byte slice is presumably the file name requested by the
/// server — confirm against the call site.
pub(crate) type LocalInfileInner =
Arc<Mutex<dyn for<'a> FnMut(&'a [u8], &'a mut LocalInfile<'_>) -> io::Result<()> + Send>>;
/// Cloneable handle to a local-infile callback.
///
/// The callback lives behind an `Arc<Mutex<..>>`, so cloning the handler
/// produces another handle to the same callback rather than a copy of it.
#[derive(Clone)]
pub struct LocalInfileHandler(pub(crate) LocalInfileInner);
impl LocalInfileHandler {
pub fn new<F>(f: F) -> Self
where
F: for<'a> FnMut(&'a [u8], &'a mut LocalInfile<'_>) -> io::Result<()> + Send + 'static,
{
LocalInfileHandler(Arc::new(Mutex::new(f)))
}
}
impl PartialEq for LocalInfileHandler {
fn eq(&self, other: &LocalInfileHandler) -> bool {
std::ptr::eq(&*self.0, &*other.0)
}
}
// Handler equality is an identity comparison of the shared callback, so it is
// reflexive, symmetric and transitive; `Eq` records that it is a full
// equivalence relation.
impl Eq for LocalInfileHandler {}
impl fmt::Debug for LocalInfileHandler {
    /// Writes a fixed placeholder, since the wrapped closure has no useful
    /// debug representation of its own.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        f.write_str("LocalInfileHandler(...)")
    }
}
/// Buffered writer passed to a local-infile callback.
///
/// Bytes written through its `io::Write` implementation are staged in a
/// fixed-size buffer and sent over the connection as a packet when the buffer
/// is full or when `flush` is called.
#[derive(Debug)]
pub struct LocalInfile<'a> {
// Staging area; the cursor position is the number of bytes waiting to be sent.
buffer: io::Cursor<&'a mut [u8]>,
// Connection the staged bytes are written to.
conn: &'a mut Conn,
}
impl<'a> LocalInfile<'a> {
    /// Size, in bytes, of the staging buffer a `LocalInfile` writes into.
    pub(crate) const BUFFER_SIZE: usize = 4096;

    /// Creates a writer that stages data in `buffer` and sends it over `conn`.
    ///
    /// The writer starts empty: the cursor is positioned at the beginning of
    /// `buffer`, and whatever the array already contains is overwritten as data
    /// is written.
    pub(crate) fn new(buffer: &'a mut [u8; LocalInfile::BUFFER_SIZE], conn: &'a mut Conn) -> Self {
        // View the fixed-size array as a slice so it matches the cursor's type.
        let storage: &'a mut [u8] = buffer;
        let buffer = io::Cursor::new(storage);
        Self { buffer, conn }
    }
}
impl io::Write for LocalInfile<'_> {
    /// Copies as much of `buf` as fits into the staging buffer and returns the
    /// number of bytes accepted.
    ///
    /// If the staging buffer is already full, it is flushed to the connection
    /// first. A single call may accept fewer bytes than `buf.len()`; callers
    /// that need everything written should use `write_all`.
    ///
    /// # Errors
    ///
    /// Returns any error produced while flushing a full buffer.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let is_full = self.buffer.position() == Self::BUFFER_SIZE as u64;
        if is_full {
            self.flush()?;
        }
        self.buffer.write(buf)
    }

    /// Sends the staged bytes to the connection as one packet and resets the
    /// staging buffer.
    ///
    /// Nothing is sent when no bytes are staged.
    ///
    /// # Errors
    ///
    /// A failure from `write_packet` is wrapped in an `io::Error`; in that case
    /// the staged bytes are left in place.
    fn flush(&mut self) -> io::Result<()> {
        let Self { buffer, conn } = self;
        let pending = buffer.position() as usize;
        if pending != 0 {
            // Only the filled prefix of the buffer is sent.
            let mut payload = &buffer.get_ref()[..pending];
            conn.write_packet(&mut payload).map_err(io::Error::other)?;
        }
        buffer.set_position(0);
        Ok(())
    }
}