use std::path::{Path, PathBuf};
use bytes::BytesMut;
use tokio::fs::File;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWriteExt};
use crate::error::{Error, Result};
use crate::smtp::codec::LineReader;
/// How an attempt to read a `DATA` payload with [`read_data`] ended.
#[derive(Debug)]
pub enum DataOutcome {
/// The terminator line (a line consisting of a single `.`) was read; the
/// captured payload is attached.
Complete(CapturedSource),
/// Accepting the next line would have pushed the running byte count past
/// `DataReadCfg::max_bytes`.
SizeExceeded,
/// The input ended on a line boundary before the terminator line was seen.
Eof,
}
/// Where the payload captured by [`read_data`] is stored.
#[derive(Debug)]
pub enum CapturedSource {
/// The payload is held in memory, each line followed by `\r\n`.
InMemory(BytesMut),
/// The payload was written to the file at the given path; the `u64` is the
/// byte count `read_data` accumulated while writing it.
OnDisk(PathBuf, u64),
}
impl CapturedSource {
pub fn size_bytes(&self) -> u64 {
match self {
CapturedSource::InMemory(b) => b.len() as u64,
CapturedSource::OnDisk(_, n) => *n,
}
}
}
/// Limits and spill settings used by [`read_data`].
#[derive(Debug, Clone)]
pub struct DataReadCfg {
    /// Maximum accepted length of a single line, passed through to the line
    /// reader.
    pub max_line: usize,
    /// Maximum accepted payload size. Each line counts as its length after
    /// dot-unstuffing plus two bytes for the `\r\n` that is stored with it.
    pub max_bytes: u64,
    /// In-memory threshold in bytes: once appending a line would make the
    /// in-memory buffer larger than this, the payload moves to a file.
    pub spill_at: usize,
    /// Directory in which spill files are created (created on demand).
    pub spill_dir: PathBuf,
}
/// Reads an SMTP `DATA` payload from `reader` until the terminator line.
///
/// Lines are read one at a time. A line consisting of a single `.` ends the
/// payload; any other line that starts with `.` has that leading dot removed
/// (dot-unstuffing). Every stored line is followed by `\r\n`.
///
/// The payload is accumulated in memory until appending a line would make the
/// buffer larger than `cfg.spill_at`; from then on everything is written to a
/// freshly named `*.tmp` file under `cfg.spill_dir`.
///
/// # Returns
///
/// * [`DataOutcome::Complete`] with the captured payload once the terminator
///   line is read.
/// * [`DataOutcome::SizeExceeded`] as soon as a line would push the running
///   total past `cfg.max_bytes`. The rest of the payload is left unread in
///   `reader`.
/// * [`DataOutcome::Eof`] if the input ends before the terminator line.
///
/// # Errors
///
/// Propagates errors from reading a line (I/O, over-long line, input ending
/// in the middle of a line) and from creating or writing the spill file.
///
/// Whenever the result is anything other than `Complete`, a spill file that
/// was already created is removed (best effort) so abandoned transfers do not
/// accumulate temp files.
pub async fn read_data<R: AsyncRead + Unpin>(
    reader: &mut LineReader<R>,
    cfg: &DataReadCfg,
) -> Result<DataOutcome> {
    let mut spill_path: Option<PathBuf> = None;
    let outcome = capture_data_lines(reader, cfg, &mut spill_path).await;
    if !matches!(outcome, Ok(DataOutcome::Complete(_))) {
        if let Some(path) = spill_path {
            // Best effort: the outcome being reported matters more than a
            // failure to delete the temp file. The file handle has already
            // been dropped together with the helper's locals.
            let _ = tokio::fs::remove_file(&path).await;
        }
    }
    outcome
}

/// Runs the capture loop for [`read_data`], recording in `spill_path` the
/// location of the spill file as soon as it exists so the caller can clean up.
async fn capture_data_lines<R: AsyncRead + Unpin>(
    reader: &mut LineReader<R>,
    cfg: &DataReadCfg,
    spill_path: &mut Option<PathBuf>,
) -> Result<DataOutcome> {
    // Cap the up-front allocation; the buffer grows on demand up to spill_at.
    let mut mem = BytesMut::with_capacity(cfg.spill_at.min(64 * 1024));
    let mut spill: Option<(File, PathBuf)> = None;
    let mut total: u64 = 0;
    loop {
        let line = match read_data_line(reader, cfg.max_line).await? {
            Some(line) => line,
            None => return Ok(DataOutcome::Eof),
        };
        if line.as_slice() == b"." {
            return Ok(DataOutcome::Complete(finalize(mem, spill, total).await?));
        }
        // Dot-unstuffing: drop the extra leading dot the sender added.
        let body: &[u8] = if line.first() == Some(&b'.') {
            &line[1..]
        } else {
            &line
        };
        // Account for the CRLF that is stored after every line.
        let line_len = body.len() as u64 + 2;
        total = match total.checked_add(line_len) {
            Some(next) if next <= cfg.max_bytes => next,
            _ => return Ok(DataOutcome::SizeExceeded),
        };
        if let Some((file, _)) = spill.as_mut() {
            file.write_all(body).await?;
            file.write_all(b"\r\n").await?;
        } else if mem.len() + body.len() + 2 > cfg.spill_at {
            tokio::fs::create_dir_all(&cfg.spill_dir).await?;
            let path = cfg
                .spill_dir
                .join(format!("{}.tmp", uuid::Uuid::new_v4()));
            let mut f = File::create(&path).await?;
            // Record the path before the first write so that a failed write
            // still gets the file cleaned up.
            *spill_path = Some(path.clone());
            f.write_all(&mem).await?;
            f.write_all(body).await?;
            f.write_all(b"\r\n").await?;
            spill = Some((f, path));
            // Release the in-memory copy; everything lives in the file now.
            mem = BytesMut::new();
        } else {
            mem.extend_from_slice(body);
            mem.extend_from_slice(b"\r\n");
        }
    }
}
/// Turns the capture state into a [`CapturedSource`].
///
/// With a spill file present, pending writes are flushed, the handle is
/// closed, and the result is `OnDisk(path, total)`; `mem` is discarded.
/// Without one, the in-memory buffer is returned as `InMemory`.
///
/// # Errors
///
/// Returns the error from flushing the spill file, if any.
async fn finalize(
    mem: BytesMut,
    spill: Option<(File, PathBuf)>,
    total: u64,
) -> Result<CapturedSource> {
    let (mut file, path) = match spill {
        Some(parts) => parts,
        None => return Ok(CapturedSource::InMemory(mem)),
    };
    file.flush().await?;
    // Close the handle before handing the path to the caller.
    drop(file);
    Ok(CapturedSource::OnDisk(path, total))
}
async fn read_data_line<R: AsyncRead + Unpin>(
reader: &mut LineReader<R>,
max_line: usize,
) -> Result<Option<Vec<u8>>> {
let inner = reader.as_buf_mut();
let mut buf = Vec::with_capacity(128);
loop {
let chunk = inner.fill_buf().await?;
if chunk.is_empty() {
return if buf.is_empty() {
Ok(None)
} else {
Err(Error::SmtpProto("DATA unterminated at EOF".into()))
};
}
let mut consumed = 0;
for (i, b) in chunk.iter().enumerate() {
consumed = i + 1;
if *b == b'\n' {
if let Some(prev) = buf.last() {
if *prev == b'\r' {
buf.pop();
}
}
inner.consume(consumed);
if buf.len() > max_line {
return Err(Error::SmtpProto("DATA line too long".into()));
}
return Ok(Some(buf));
}
buf.push(*b);
if buf.len() > max_line {
inner.consume(consumed);
return Err(Error::SmtpProto("DATA line too long".into()));
}
}
inner.consume(consumed);
}
}
/// Stores a captured payload as `<blob_dir>/<email_id>.eml` and returns that
/// path.
///
/// `blob_dir` is created if missing. An in-memory payload is written out
/// directly. An on-disk payload is renamed into place; if the rename fails
/// for any reason, the file is copied instead and the temp file is then
/// removed on a best-effort basis.
///
/// # Errors
///
/// Returns I/O errors from creating the directory, writing the file, or the
/// fallback copy. A failed rename and a failed temp-file removal are not
/// reported.
///
/// NOTE(review): `email_id` is interpolated into the file name as-is;
/// presumably callers pass an id without path separators — confirm.
pub async fn finalize_to_blob(
    captured: &CapturedSource,
    blob_dir: &Path,
    email_id: &str,
) -> Result<PathBuf> {
    tokio::fs::create_dir_all(blob_dir).await?;

    let mut destination = blob_dir.to_path_buf();
    destination.push(format!("{email_id}.eml"));

    match captured {
        CapturedSource::OnDisk(temp_path, _) => {
            let renamed = tokio::fs::rename(temp_path, &destination).await;
            if renamed.is_err() {
                // Fall back to copy + delete when the rename did not succeed.
                tokio::fs::copy(temp_path, &destination).await?;
                let _ = tokio::fs::remove_file(temp_path).await;
            }
        }
        CapturedSource::InMemory(bytes) => {
            tokio::fs::write(&destination, bytes).await?;
        }
    }

    Ok(destination)
}
/// Returns the full payload as an owned byte vector.
///
/// An in-memory payload is copied; an on-disk payload is read from its file.
///
/// # Errors
///
/// Returns the I/O error from reading the file for an on-disk payload.
pub async fn load_bytes(captured: &CapturedSource) -> Result<Vec<u8>> {
    let bytes = match captured {
        CapturedSource::OnDisk(path, _) => tokio::fs::read(path).await?,
        CapturedSource::InMemory(buffer) => buffer.to_vec(),
    };
    Ok(bytes)
}