use std::fs;
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use flate2::read::ZlibDecoder;
use flate2::write::ZlibEncoder;
use flate2::Compression;
use sha1::{Digest, Sha1};
use crate::error::{Error, Result};
use crate::objects::{Object, ObjectId, ObjectKind};
use crate::pack;
/// Handle to an on-disk object database.
///
/// Only the root directory is held; every lookup and write goes to disk on demand.
#[derive(Clone, Debug)]
pub struct Odb {
    /// Root directory: loose objects live at `<objects_dir>/<prefix>/<suffix>`, and local
    /// pack indexes are discovered from here as well.
    objects_dir: PathBuf,
}
impl Odb {
#[must_use]
pub fn new(objects_dir: &Path) -> Self {
Self {
objects_dir: objects_dir.to_path_buf(),
}
}
#[must_use]
pub fn objects_dir(&self) -> &Path {
&self.objects_dir
}
#[must_use]
pub fn object_path(&self, oid: &ObjectId) -> PathBuf {
self.objects_dir
.join(oid.loose_prefix())
.join(oid.loose_suffix())
}
#[must_use]
pub fn exists(&self, oid: &ObjectId) -> bool {
if self.object_path(oid).exists() {
return true;
}
if let Ok(indexes) = pack::read_local_pack_indexes(&self.objects_dir) {
for idx in &indexes {
if idx.entries.iter().any(|e| e.oid == *oid) {
return true;
}
}
}
false
}
pub fn read(&self, oid: &ObjectId) -> Result<Object> {
let path = self.object_path(oid);
match fs::File::open(&path) {
Ok(file) => {
let mut decoder = ZlibDecoder::new(file);
let mut raw = Vec::new();
decoder
.read_to_end(&mut raw)
.map_err(|e| Error::Zlib(e.to_string()))?;
return parse_object_bytes(&raw);
}
Err(_) => {
}
}
pack::read_object_from_packs(&self.objects_dir, oid)
}
#[must_use]
pub fn hash_object_data(kind: ObjectKind, data: &[u8]) -> ObjectId {
let header = format!("{} {}\0", kind, data.len());
let mut hasher = Sha1::new();
hasher.update(header.as_bytes());
hasher.update(data);
let digest = hasher.finalize();
ObjectId::from_bytes(digest.as_slice()).unwrap_or_else(|_| unreachable!("SHA-1 is 20 bytes"))
}
pub fn write(&self, kind: ObjectKind, data: &[u8]) -> Result<ObjectId> {
let store_bytes = build_store_bytes(kind, data);
let oid = hash_bytes(&store_bytes);
let path = self.object_path(&oid);
if path.exists() {
return Ok(oid);
}
let prefix_dir = path
.parent()
.ok_or_else(|| Error::PathError("object path has no parent".to_owned()))?;
fs::create_dir_all(prefix_dir)?;
let tmp_path = prefix_dir.join(format!("tmp_{}", oid.loose_suffix()));
{
let tmp_file = fs::File::create(&tmp_path)?;
let mut encoder = ZlibEncoder::new(tmp_file, Compression::default());
encoder
.write_all(&store_bytes)
.map_err(|e| Error::Zlib(e.to_string()))?;
encoder.finish().map_err(|e| Error::Zlib(e.to_string()))?;
}
fs::rename(&tmp_path, &path)?;
Ok(oid)
}
pub fn write_raw(&self, store_bytes: &[u8]) -> Result<ObjectId> {
parse_object_bytes(store_bytes)?;
let oid = hash_bytes(store_bytes);
let path = self.object_path(&oid);
if path.exists() {
return Ok(oid);
}
let prefix_dir = path
.parent()
.ok_or_else(|| Error::PathError("object path has no parent".to_owned()))?;
fs::create_dir_all(prefix_dir)?;
let tmp_path = prefix_dir.join(format!("tmp_{}", oid.loose_suffix()));
{
let tmp_file = fs::File::create(&tmp_path)?;
let mut encoder = ZlibEncoder::new(tmp_file, Compression::default());
encoder
.write_all(store_bytes)
.map_err(|e| Error::Zlib(e.to_string()))?;
encoder.finish().map_err(|e| Error::Zlib(e.to_string()))?;
}
fs::rename(&tmp_path, &path)?;
Ok(oid)
}
}
/// Returns the SHA-1 object id of `data`.
///
/// `data` is expected to be the full encoded object, header included.
fn hash_bytes(data: &[u8]) -> ObjectId {
    let output = Sha1::digest(data);
    ObjectId::from_bytes(output.as_slice()).unwrap_or_else(|_| unreachable!("SHA-1 is 20 bytes"))
}
/// Encodes an object as `"<kind> <len>\0"` followed by `data`.
///
/// These are the bytes that get hashed and (compressed) stored.
fn build_store_bytes(kind: ObjectKind, data: &[u8]) -> Vec<u8> {
    let prefix = format!("{kind} {}\0", data.len());
    // `concat` sizes the output exactly, so this is a single allocation.
    [prefix.as_bytes(), data].concat()
}
pub(crate) fn parse_object_bytes(raw: &[u8]) -> Result<Object> {
let nul = raw
.iter()
.position(|&b| b == 0)
.ok_or_else(|| Error::CorruptObject("missing NUL in object header".to_owned()))?;
let header = &raw[..nul];
let data = raw[nul + 1..].to_vec();
let sp = header
.iter()
.position(|&b| b == b' ')
.ok_or_else(|| Error::CorruptObject("missing space in object header".to_owned()))?;
let kind = ObjectKind::from_bytes(&header[..sp])?;
let size_str = std::str::from_utf8(&header[sp + 1..])
.map_err(|_| Error::CorruptObject("non-UTF-8 object size".to_owned()))?;
let size: usize = size_str
.parse()
.map_err(|_| Error::CorruptObject(format!("invalid object size: {size_str}")))?;
if data.len() != size {
return Err(Error::CorruptObject(format!(
"object size mismatch: header says {size} but got {}",
data.len()
)));
}
Ok(Object::new(kind, data))
}
#[cfg(test)]
mod tests {
    #![allow(clippy::expect_used, clippy::unwrap_used)]
    use super::*;
    use tempfile::TempDir;

    /// A blob written to a fresh store reads back with the same kind and payload.
    #[test]
    fn round_trip_blob() {
        let scratch = TempDir::new().unwrap();
        let store = Odb::new(scratch.path());
        let payload = b"hello world";

        let written = store.write(ObjectKind::Blob, payload).unwrap();
        let loaded = store.read(&written).unwrap();

        assert_eq!(loaded.kind, ObjectKind::Blob);
        assert_eq!(loaded.data, payload);
    }

    /// The id of the blob "hello" matches the well-known reference value.
    #[test]
    fn known_blob_hash() {
        const EXPECTED_HEX: &str = "b6fc4c620b67d95f953a5c1c1230aaab5db5a1b0";
        let computed = Odb::hash_object_data(ObjectKind::Blob, b"hello");
        assert_eq!(computed.to_hex(), EXPECTED_HEX);
    }
}