use std::collections::BTreeMap;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
/// Smallest part size, in bytes, that a multipart configuration may use (5 MiB).
pub const MIN_PART_SIZE: u64 = 5 << 20;
/// Part size, in bytes, used by the default configuration (8 MiB).
pub const DEFAULT_PART_SIZE: u64 = 8 << 20;
/// Number identifying one part of a multipart write.
///
/// Identifiers order by their numeric value, so an ordered map keyed by
/// `PartId` iterates parts in ascending part number.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PartId(pub u32);

impl std::fmt::Display for PartId {
    /// Renders the identifier as `part-NNNN`, zero-padded to at least four digits.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(number) = *self;
        write!(f, "part-{number:04}")
    }
}
/// Lifecycle state of a single part.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PartStatus {
    /// The part has been planned but no data has been stored for it yet.
    Pending,
    /// The part is being processed.
    // NOTE(review): nothing visible in this file assigns this state — confirm
    // whether external callers are expected to set it.
    InProgress,
    /// Data for the part has been stored; the payload is the number of bytes written.
    Complete(u64),
    /// Processing the part failed; the payload is a human-readable reason.
    Failed(String),
}
/// Bookkeeping record for one planned part.
#[derive(Debug, Clone)]
pub struct PartInfo {
    /// Identifier of the part this record describes.
    pub id: PartId,
    /// Byte offset of the part within the complete payload.
    pub offset: u64,
    /// Planned length of the part in bytes.
    pub size: u64,
    /// Current lifecycle state of the part.
    pub status: PartStatus,
    /// Optional entity tag attached to the part, if one has been recorded.
    pub etag: Option<String>,
}
/// Tunables for a multipart write.
#[derive(Debug, Clone)]
pub struct MultipartConfig {
    /// Size in bytes of every planned part; the final part may be shorter.
    pub part_size: u64,
    /// Upper bound on parts processed at the same time.
    // NOTE(review): no code visible in this file reads this field — presumably
    // it is consumed by whatever drives the writer; confirm.
    pub max_concurrency: usize,
    /// Directory for temporary part files; `None` selects the system temp directory.
    pub temp_dir: Option<PathBuf>,
    /// When `true`, temporary part files are deleted once they are no longer needed.
    pub cleanup_on_drop: bool,
}
impl Default for MultipartConfig {
fn default() -> Self {
Self {
part_size: DEFAULT_PART_SIZE,
max_concurrency: 4,
temp_dir: None,
cleanup_on_drop: true,
}
}
}
/// Summary returned after the parts have been assembled.
#[derive(Debug, Clone)]
pub struct MultipartResult {
    /// Number of bytes written to the assembled output.
    pub total_bytes: u64,
    /// Number of parts that made up the output.
    pub part_count: usize,
    /// Entity tag of each part in ascending part order; `None` where none was set.
    pub etags: Vec<Option<String>>,
}
/// Errors reported by the multipart writer.
#[derive(Debug)]
pub enum MultipartError {
    /// An underlying I/O operation failed.
    Io(io::Error),
    /// The requested part is not part of the plan.
    PartOutOfRange {
        /// Identifier that was requested.
        id: PartId,
        /// Highest planned part number, or `0` when nothing has been planned.
        max: u32,
    },
    /// The total payload size has not been provided yet.
    SizeUnknown,
    /// The configuration or a supplied parameter is invalid; the payload explains why.
    Config(String),
    /// Assembling the final output failed; the payload explains why.
    Assembly(String),
}
impl std::fmt::Display for MultipartError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SizeUnknown => f.write_str("total size unknown; call set_total_size first"),
            Self::Io(source) => write!(f, "multipart I/O error: {source}"),
            Self::Config(detail) => write!(f, "multipart config error: {detail}"),
            Self::Assembly(detail) => write!(f, "multipart assembly error: {detail}"),
            // NOTE(review): `PartId` already renders as `part-NNNN`, so this
            // message reads "part part-NNNN out of range"; the text is left
            // untouched here because callers may match on it.
            Self::PartOutOfRange { id, max } => {
                write!(f, "part {id} out of range (max part-{max:04})")
            }
        }
    }
}
impl std::error::Error for MultipartError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
if let Self::Io(e) = self {
Some(e)
} else {
None
}
}
}
impl From<io::Error> for MultipartError {
fn from(e: io::Error) -> Self {
Self::Io(e)
}
}
/// Shorthand for results whose error type is [`MultipartError`].
pub type MultipartWriterResult<T> = std::result::Result<T, MultipartError>;
/// Splits a payload of known size into parts, stores each part in a
/// temporary file, and assembles the parts into a single output file.
pub struct MultipartWriter {
    /// Settings supplied at construction time.
    config: MultipartConfig,
    /// Destination path supplied at construction time.
    destination: PathBuf,
    /// Total payload size in bytes, once it has been provided.
    total_size: Option<u64>,
    /// Planned parts keyed, and therefore ordered, by identifier.
    parts: BTreeMap<PartId, PartInfo>,
    /// Location of the temporary file holding each part that has been written.
    temp_paths: BTreeMap<PartId, PathBuf>,
}
impl MultipartWriter {
pub fn new(
destination: impl AsRef<Path>,
config: MultipartConfig,
) -> MultipartWriterResult<Self> {
if config.part_size < MIN_PART_SIZE {
return Err(MultipartError::Config(format!(
"part_size {} is below minimum {} bytes",
config.part_size, MIN_PART_SIZE
)));
}
Ok(Self {
config,
destination: destination.as_ref().to_path_buf(),
total_size: None,
parts: BTreeMap::new(),
temp_paths: BTreeMap::new(),
})
}
pub fn set_total_size(&mut self, size: u64) -> MultipartWriterResult<()> {
if size == 0 {
return Err(MultipartError::Config(
"total size must be greater than zero".to_string(),
));
}
self.total_size = Some(size);
Ok(())
}
#[must_use]
pub fn plan_parts(&mut self) -> Vec<PartId> {
let total = self.total_size.unwrap_or(0);
if total == 0 {
return Vec::new();
}
if !self.parts.is_empty() {
return self.parts.keys().copied().collect();
}
let part_size = self.config.part_size;
let mut offset = 0u64;
let mut part_num = 1u32;
while offset < total {
let size = (total - offset).min(part_size);
let id = PartId(part_num);
self.parts.insert(
id,
PartInfo {
id,
offset,
size,
status: PartStatus::Pending,
etag: None,
},
);
offset += size;
part_num += 1;
}
self.parts.keys().copied().collect()
}
pub fn part_range(&self, id: PartId) -> MultipartWriterResult<(u64, u64)> {
self.parts
.get(&id)
.map(|p| (p.offset, p.size))
.ok_or_else(|| {
let max = self.parts.keys().next_back().map(|p| p.0).unwrap_or(0);
MultipartError::PartOutOfRange { id, max }
})
}
pub fn write_part(&mut self, id: PartId, data: &[u8]) -> MultipartWriterResult<()> {
if !self.parts.contains_key(&id) {
let max = self.parts.keys().next_back().map(|p| p.0).unwrap_or(0);
return Err(MultipartError::PartOutOfRange { id, max });
}
let temp_dir = self
.config
.temp_dir
.clone()
.unwrap_or_else(std::env::temp_dir);
let temp_path = temp_dir.join(format!(
"oximedia_mpart_{}.bin",
id
));
let mut file = std::fs::File::create(&temp_path)?;
file.write_all(data)?;
file.flush()?;
let written = data.len() as u64;
self.temp_paths.insert(id, temp_path);
if let Some(info) = self.parts.get_mut(&id) {
info.status = PartStatus::Complete(written);
}
Ok(())
}
pub fn set_part_etag(&mut self, id: PartId, etag: impl Into<String>) {
if let Some(info) = self.parts.get_mut(&id) {
info.etag = Some(etag.into());
}
}
#[must_use]
pub fn all_parts_complete(&self) -> bool {
self.parts
.values()
.all(|p| matches!(p.status, PartStatus::Complete(_)))
}
#[must_use]
pub fn completed_count(&self) -> usize {
self.parts
.values()
.filter(|p| matches!(p.status, PartStatus::Complete(_)))
.count()
}
pub fn finalize(&mut self, destination: impl AsRef<Path>) -> MultipartWriterResult<MultipartResult> {
if !self.all_parts_complete() {
return Err(MultipartError::Assembly(
"not all parts are complete".to_string(),
));
}
let dest = destination.as_ref();
let mut out = std::fs::File::create(dest)?;
let mut total_bytes = 0u64;
let mut etags = Vec::with_capacity(self.parts.len());
for (id, info) in &self.parts {
etags.push(info.etag.clone());
if let Some(temp_path) = self.temp_paths.get(id) {
let chunk = std::fs::read(temp_path).map_err(|e| {
MultipartError::Assembly(format!("reading part {id}: {e}"))
})?;
out.write_all(&chunk)?;
total_bytes += chunk.len() as u64;
}
}
out.flush()?;
if self.config.cleanup_on_drop {
for temp_path in self.temp_paths.values() {
let _ = std::fs::remove_file(temp_path);
}
}
Ok(MultipartResult {
total_bytes,
part_count: self.parts.len(),
etags,
})
}
#[must_use]
pub fn destination(&self) -> &Path {
&self.destination
}
pub fn parts(&self) -> impl Iterator<Item = &PartInfo> {
self.parts.values()
}
}
impl Drop for MultipartWriter {
    /// Deletes the temporary part files still tracked by the writer when
    /// `cleanup_on_drop` is enabled.
    fn drop(&mut self) {
        if !self.config.cleanup_on_drop {
            return;
        }
        self.temp_paths.values().for_each(|leftover| {
            // Best effort: `drop` cannot report failures, so they are ignored.
            let _ = std::fs::remove_file(leftover);
        });
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration that uses the smallest permitted part size.
    fn min_part_config() -> MultipartConfig {
        MultipartConfig {
            part_size: MIN_PART_SIZE,
            ..Default::default()
        }
    }

    /// Builds a writer for a placeholder destination with `total` already set.
    fn writer_with_total(config: MultipartConfig, total: u64) -> MultipartWriter {
        let mut writer = MultipartWriter::new("/tmp/test.bin", config).unwrap();
        writer.set_total_size(total).unwrap();
        writer
    }

    /// Three full parts plus a remainder yield four parts.
    #[test]
    fn test_plan_parts_divides_correctly() {
        let mut writer = writer_with_total(min_part_config(), MIN_PART_SIZE * 3 + 100);
        let planned = writer.plan_parts();
        assert_eq!(planned.len(), 4);
    }

    /// An exact multiple of the part size produces no trailing short part.
    #[test]
    fn test_plan_parts_exact_multiple() {
        let mut writer = writer_with_total(min_part_config(), MIN_PART_SIZE * 2);
        let planned = writer.plan_parts();
        assert_eq!(planned.len(), 2);
    }

    /// Offsets are contiguous and the last part carries the remainder.
    #[test]
    fn test_part_range() {
        let mut writer = writer_with_total(min_part_config(), MIN_PART_SIZE + 1000);
        let _ = writer.plan_parts();

        let first = writer.part_range(PartId(1)).unwrap();
        assert_eq!(first.0, 0);
        assert_eq!(first.1, MIN_PART_SIZE);

        let second = writer.part_range(PartId(2)).unwrap();
        assert_eq!(second.0, MIN_PART_SIZE);
        assert_eq!(second.1, 1000);
    }

    /// Writing every part and finalizing yields a file of the full size.
    #[test]
    fn test_write_and_finalize() {
        let scratch = std::env::temp_dir();
        let dest = scratch.join("oximedia_mpart_final_test.bin");
        let config = MultipartConfig {
            part_size: MIN_PART_SIZE,
            temp_dir: Some(scratch.clone()),
            ..Default::default()
        };
        let total = MIN_PART_SIZE + 512;

        let mut writer = MultipartWriter::new(&dest, config).unwrap();
        writer.set_total_size(total).unwrap();
        for id in writer.plan_parts() {
            let (_, size) = writer.part_range(id).unwrap();
            // Fill each part with its own number so the parts are distinguishable.
            let payload = vec![id.0 as u8; size as usize];
            writer.write_part(id, &payload).unwrap();
        }
        assert!(writer.all_parts_complete());

        let summary = writer.finalize(&dest).unwrap();
        assert_eq!(summary.total_bytes, total);
        assert_eq!(summary.part_count, 2);

        let assembled = std::fs::read(&dest).unwrap();
        assert_eq!(assembled.len() as u64, total);
        let _ = std::fs::remove_file(&dest);
    }

    /// A part size below the minimum is rejected at construction.
    #[test]
    fn test_config_too_small_part_size() {
        let config = MultipartConfig {
            part_size: 1024,
            ..Default::default()
        };
        let outcome = MultipartWriter::new("/tmp/test.bin", config);
        assert!(outcome.is_err());
    }

    /// Asking for a part beyond the plan is an error.
    #[test]
    fn test_part_out_of_range() {
        let mut writer = writer_with_total(MultipartConfig::default(), MIN_PART_SIZE);
        let _ = writer.plan_parts();
        assert!(writer.part_range(PartId(999)).is_err());
    }
}