use std::fmt::Debug;
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::SystemTime;
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::StreamExt;
use serde::{Deserialize, Deserializer, Serialize};
use tokio::io::AsyncRead;
use tokio_stream::Stream;
use utoipa::ToSchema;
use crate::constants;
use crate::constants::MAX_CONCURRENT_VERSION_PROBES;
use crate::error::OxenError;
use crate::storage::{LocalVersionStore, S3Opts, S3VersionStore};
use crate::util;
use crate::view::versions::CleanCorruptedVersionsResult;
/// A local filesystem path holding the contents of a version, as returned by
/// `VersionStore::materialize`.
///
/// Derefs to [`Path`], so it can be passed anywhere a `&Path` is expected.
pub enum LocalFilePath {
    /// A path owned by the version store itself (the default `materialize` returns this for
    /// `VersionLocation::Local` without copying anything).
    Stable(PathBuf),
    /// A temporary copy made for a non-local backend. The `TempFile` is kept alive inside
    /// this value (presumably the file is removed when it is dropped — confirm against the
    /// `async_tempfile` docs), so hold on to the `LocalFilePath` while using the path.
    Temp(async_tempfile::TempFile),
}
impl Deref for LocalFilePath {
type Target = Path;
fn deref(&self) -> &Path {
match self {
LocalFilePath::Stable(p) => p.as_path(),
LocalFilePath::Temp(t) => t.file_path(),
}
}
}
impl AsRef<Path> for LocalFilePath {
fn as_ref(&self) -> &Path {
self.deref()
}
}
impl std::fmt::Debug for LocalFilePath {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LocalFilePath::Stable(p) => write!(f, "Stable({p:?})"),
LocalFilePath::Temp(t) => write!(f, "Temp({:?})", t.file_path()),
}
}
}
impl std::fmt::Display for LocalFilePath {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.deref().display())
}
}
impl LocalFilePath {
    /// Returns an owned copy of the underlying path.
    ///
    /// The copy does not keep a `Temp` file alive; keep the `LocalFilePath` around for as
    /// long as the returned path is used.
    pub fn to_pathbuf(&self) -> PathBuf {
        Path::to_path_buf(self)
    }
}
/// Where a stored version physically resides, as reported by
/// `VersionStore::version_location`.
#[derive(Debug, Clone)]
pub enum VersionLocation {
    /// A file on the local filesystem; the default `materialize` hands this path back
    /// directly as `LocalFilePath::Stable`.
    Local(PathBuf),
    /// An object in S3 (or an S3-compatible service). The default `materialize` copies it to
    /// a local temporary file.
    S3 {
        /// Object URL (presumably `s3://bucket/key` form — confirm against the S3 backend).
        url: String,
        /// Region the bucket lives in.
        region: String,
        /// Optional endpoint override (presumably for S3-compatible services — confirm).
        endpoint_url: Option<String>,
    },
}
/// Which backend a repository's versions are stored in.
///
/// Serde (de)serializes the variants as lowercase strings (`"local"`, `"s3"`). The
/// `Display`/`FromStr` impls in this file use the same spellings.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize, ToSchema)]
#[serde(rename_all = "lowercase")]
pub enum StorageKind {
    /// Versions live on the local filesystem. This is the default.
    #[default]
    Local,
    /// Versions live in an S3 bucket.
    S3,
}
impl std::fmt::Display for StorageKind {
    /// Writes the lowercase identifier (`local` / `s3`), matching the serde representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            StorageKind::Local => "local",
            StorageKind::S3 => "s3",
        };
        f.write_str(label)
    }
}
impl std::str::FromStr for StorageKind {
    type Err = OxenError;

    /// Parses the exact lowercase identifiers produced by `Display` (`local` / `s3`).
    ///
    /// # Errors
    /// Any other input (matching is case-sensitive) yields
    /// `OxenError::UnsupportedStorageKind` carrying the original text.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "local" {
            Ok(StorageKind::Local)
        } else if s == "s3" {
            Ok(StorageKind::S3)
        } else {
            Err(OxenError::UnsupportedStorageKind(s.to_owned()))
        }
    }
}
/// Per-repository storage configuration consumed by `create_version_store`.
///
/// Serialization is derived. Deserialization is hand-written so that the legacy
/// `settings.path` key is still accepted as a fallback for `versions_path`.
#[derive(Serialize, Debug, Clone, Default)]
pub struct StorageConfig {
    /// Backend to use; defaults to `StorageKind::Local`.
    pub kind: StorageKind,
    /// Optional override for the local versions directory. `create_version_store` only
    /// consults it for `StorageKind::Local`. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub versions_path: Option<PathBuf>,
}
impl<'de> Deserialize<'de> for StorageConfig {
    /// Deserializes a `StorageConfig`, accepting both layouts:
    /// - the current one: top-level `kind` and `versions_path`;
    /// - the legacy one: the versions directory under `settings.path`.
    ///
    /// Every field is optional. A missing `kind` falls back to `StorageKind::default()`. When
    /// both `versions_path` and `settings.path` are present, `versions_path` takes precedence.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Legacy nested `settings` object carrying only a `path` key. The type name is part
        // of serde's error messages, so it is kept stable.
        #[derive(Deserialize)]
        struct LegacySettings {
            #[serde(default)]
            path: Option<PathBuf>,
        }

        // Union of the current and legacy fields, all optional.
        #[derive(Deserialize)]
        struct Raw {
            #[serde(default)]
            kind: StorageKind,
            #[serde(default)]
            versions_path: Option<PathBuf>,
            #[serde(default)]
            settings: Option<LegacySettings>,
        }

        let Raw {
            kind,
            versions_path,
            settings,
        } = Raw::deserialize(deserializer)?;

        let legacy_path = settings.and_then(|legacy| legacy.path);

        Ok(Self {
            kind,
            versions_path: versions_path.or(legacy_path),
        })
    }
}
/// Backend-agnostic storage for file versions, each addressed by a `hash` string.
///
/// Backends are selected per repository by [`create_version_store`]; this module wires up
/// [`LocalVersionStore`] and [`S3VersionStore`]. Methods without a default body are the
/// primitives every backend must provide. `materialize` and `find_missing_versions` are
/// built on top of them.
#[async_trait]
pub trait VersionStore: Debug + Send + Sync + 'static {
    /// Prepares the backend for use. What this involves is backend-specific.
    async fn init(&self) -> Result<(), OxenError>;

    /// Stores the bytes produced by `reader` as version `hash`.
    ///
    /// `size` is the byte count the caller expects `reader` to yield
    /// (NOTE(review): confirm whether backends validate it or only use it as a hint).
    async fn store_version_from_reader(
        &self,
        hash: &str,
        reader: Box<dyn AsyncRead + Send + Unpin>,
        size: u64,
    ) -> Result<(), OxenError>;

    /// Stores `data` as the complete contents of version `hash`.
    async fn store_version(&self, hash: &str, data: Bytes) -> Result<(), OxenError>;

    /// Stores one chunk of version `hash` beginning at byte `offset`.
    ///
    /// Presumably chunks are later assembled by
    /// [`combine_version_chunks`](Self::combine_version_chunks) — confirm per backend.
    async fn store_version_chunk(
        &self,
        hash: &str,
        offset: u64,
        data: Bytes,
    ) -> Result<(), OxenError>;

    /// Stores `derived_data` under the name `derived_filename` as an artifact derived from
    /// version `orig_hash`.
    async fn store_version_derived(
        &self,
        orig_hash: &str,
        derived_filename: &str,
        derived_data: Bytes,
    ) -> Result<(), OxenError>;

    /// Reads `size` bytes of version `hash` starting at byte `offset`.
    async fn get_version_chunk(
        &self,
        hash: &str,
        offset: u64,
        size: u64,
    ) -> Result<Vec<u8>, OxenError>;

    /// Lists the chunks stored for `hash`, identified by `u64` values (presumably the
    /// `offset`s passed to `store_version_chunk` — confirm).
    async fn list_version_chunks(&self, hash: &str) -> Result<Vec<u64>, OxenError>;

    /// Combines the previously stored chunks of `hash` into the complete version.
    async fn combine_version_chunks(&self, hash: &str) -> Result<(), OxenError>;

    /// Returns the size of version `hash` (presumably in bytes — confirm).
    async fn get_version_size(&self, hash: &str) -> Result<u64, OxenError>;

    /// Reads the entire contents of version `hash` into memory.
    async fn get_version(&self, hash: &str) -> Result<Vec<u8>, OxenError>;

    /// Returns the contents of version `hash` as a stream of byte buffers.
    async fn get_version_stream(
        &self,
        hash: &str,
    ) -> Result<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send + Unpin>, OxenError>;

    /// Returns the size of the derived artifact `derived_filename` of version `orig_hash`.
    async fn get_version_derived_size(
        &self,
        orig_hash: &str,
        derived_filename: &str,
    ) -> Result<u64, OxenError>;

    /// Streams the derived artifact `derived_filename` of version `orig_hash`.
    async fn get_version_derived_stream(
        &self,
        orig_hash: &str,
        derived_filename: &str,
    ) -> Result<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send + Unpin>, OxenError>;

    /// Reports whether the derived artifact `derived_filename` of version `orig_hash` exists.
    async fn derived_version_exists(
        &self,
        orig_hash: &str,
        derived_filename: &str,
    ) -> Result<bool, OxenError>;

    /// Reports where version `hash` is physically stored. The default
    /// [`materialize`](Self::materialize) uses this to decide whether a copy is needed.
    async fn version_location(&self, hash: &str) -> Result<VersionLocation, OxenError>;

    /// Copies version `hash` to `dest_path`. `mtime` is the modification time to apply to the
    /// copy (presumably — confirm per backend).
    async fn copy_version_to_path(
        &self,
        hash: &str,
        dest_path: &Path,
        mtime: SystemTime,
    ) -> Result<(), OxenError>;

    /// Returns a local filesystem path holding the contents of version `hash`.
    ///
    /// - `VersionLocation::Local`: the stored file's own path is returned as
    ///   [`LocalFilePath::Stable`] and nothing is copied.
    /// - `VersionLocation::S3`: `dir` is created if needed and a temporary file is created
    ///   inside it. The version is copied there with the current time as `mtime`, and the
    ///   result is returned as [`LocalFilePath::Temp`].
    ///
    /// # Errors
    /// Propagates failures from `version_location`, directory creation, temp-file creation
    /// and `copy_version_to_path`.
    async fn materialize(&self, hash: &str, dir: &Path) -> Result<LocalFilePath, OxenError> {
        match self.version_location(hash).await? {
            VersionLocation::Local(path) => Ok(LocalFilePath::Stable(path)),
            VersionLocation::S3 { .. } => {
                util::fs::create_dir_all(dir)?;
                let temp = async_tempfile::TempFile::new_in(dir).await.map_err(|e| {
                    OxenError::basic_str(format!("Failed to create temp file: {e}"))
                })?;
                self.copy_version_to_path(hash, temp.file_path(), SystemTime::now())
                    .await?;
                Ok(LocalFilePath::Temp(temp))
            }
        }
    }

    /// Reports whether version `hash` is present in the store.
    async fn version_exists(&self, hash: &str) -> Result<bool, OxenError>;

    /// Returns the entries of `hashes` that are not present in the store, in input order.
    ///
    /// Up to `MAX_CONCURRENT_VERSION_PROBES` `version_exists` probes run concurrently.
    /// Duplicate entries in `hashes` are probed and reported independently.
    ///
    /// # Errors
    /// Returns the first probe error encountered. Probes still in flight are dropped
    /// at that point.
    async fn find_missing_versions(&self, hashes: &[String]) -> Result<Vec<String>, OxenError> {
        if hashes.is_empty() {
            return Ok(Vec::new());
        }
        let max_concurrent = MAX_CONCURRENT_VERSION_PROBES.min(hashes.len());
        // Probes complete in arbitrary order, so each carries its input index. That lets the
        // result be reported deterministically, and only missing hashes are cloned.
        let mut probes = futures_util::stream::iter(hashes.iter().enumerate())
            .map(|(idx, hash)| async move {
                let exists = self.version_exists(hash).await?;
                Ok::<_, OxenError>((idx, exists))
            })
            .buffer_unordered(max_concurrent);
        let mut missing_idx = Vec::new();
        while let Some(result) = probes.next().await {
            let (idx, exists) = result?;
            if !exists {
                missing_idx.push(idx);
            }
        }
        missing_idx.sort_unstable();
        // `idx` comes from `enumerate` over `hashes`, so indexing cannot go out of bounds.
        Ok(missing_idx
            .into_iter()
            .map(|idx| hashes[idx].clone())
            .collect())
    }

    /// Deletes version `hash` from the store.
    async fn delete_version(&self, hash: &str) -> Result<(), OxenError>;

    /// Lists the versions held by the store (presumably their hashes — confirm per backend).
    async fn list_versions(&self) -> Result<Vec<String>, OxenError>;

    /// Looks for corrupted versions and reports them. When `dry_run` is false, presumably
    /// also removes them (criteria are backend-specific — confirm).
    async fn clean_corrupted_versions(
        &self,
        dry_run: bool,
    ) -> Result<CleanCorruptedVersionsResult, OxenError>;

    /// Which backend kind this store is.
    fn storage_kind(&self) -> StorageKind;
}
/// Builds the [`VersionStore`] backend selected by `config` for the repository at `repo_dir`.
///
/// - `StorageKind::Local`: versions live in `config.versions_path` when set. A relative
///   path is resolved against `repo_dir`; an absolute path is used verbatim. Otherwise they
///   live in the repo's hidden directory under `VERSIONS_DIR/FILES_DIR`.
/// - `StorageKind::S3`: requires `server_s3_opts`. Objects go to `opts.bucket` under the
///   prefix `"<namespace>/<name>"`, taken from the last two components of `repo_dir`.
///
/// # Errors
/// - `OxenError::S3BackendMissingServerOpts` if the kind is S3 and `server_s3_opts` is `None`.
/// - `OxenError::S3PrefixUnresolvable` if the kind is S3 and `repo_dir` lacks a UTF-8 final
///   component or a UTF-8 parent component.
pub fn create_version_store(
    repo_dir: &Path,
    config: &StorageConfig,
    server_s3_opts: Option<&S3Opts>,
) -> Result<Arc<dyn VersionStore>, OxenError> {
    match config.kind {
        StorageKind::Local => {
            let versions_dir = match &config.versions_path {
                // `Path::join` returns `path` unchanged when it is absolute, so this anchors
                // every relative path (not only `.oxen/...`) at the repository rather than
                // at the process's current working directory.
                Some(path) => repo_dir.join(path),
                None => util::fs::oxen_hidden_dir(repo_dir)
                    .join(constants::VERSIONS_DIR)
                    .join(constants::FILES_DIR),
            };
            let store = LocalVersionStore::new(versions_dir);
            Ok(Arc::new(store))
        }
        StorageKind::S3 => {
            let opts = server_s3_opts.ok_or(OxenError::S3BackendMissingServerOpts)?;
            // Captures only a shared reference, so the closure is `Copy` and reusable below.
            let unresolvable = || OxenError::S3PrefixUnresolvable(repo_dir.into());
            let name = repo_dir
                .file_name()
                .and_then(|s| s.to_str())
                .ok_or_else(unresolvable)?;
            let namespace = repo_dir
                .parent()
                .and_then(|p| p.file_name())
                .and_then(|s| s.to_str())
                .ok_or_else(unresolvable)?;
            let prefix = format!("{namespace}/{name}");
            let store = S3VersionStore::new(opts.bucket.clone(), prefix);
            Ok(Arc::new(store))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Minimal S3 config with no local path override.
    fn s3_config() -> StorageConfig {
        StorageConfig {
            kind: StorageKind::S3,
            versions_path: None,
        }
    }

    #[test]
    fn create_version_store_s3_without_server_opts_errors() {
        let repo_dir = PathBuf::from("/srv/oxen/test-ns/test-repo");
        let result = create_version_store(&repo_dir, &s3_config(), None);
        assert!(
            matches!(result, Err(OxenError::S3BackendMissingServerOpts)),
            "expected S3BackendMissingServerOpts, got {result:?}",
        );
    }

    #[test]
    fn create_version_store_s3_with_server_opts_builds_s3_store() {
        let repo_dir = PathBuf::from("/srv/oxen/test-ns/test-repo");
        let opts = S3Opts {
            bucket: "my-bucket".to_string(),
        };
        let store = create_version_store(&repo_dir, &s3_config(), Some(&opts))
            .expect("S3 store should construct when server opts are present");
        assert_eq!(store.storage_kind(), StorageKind::S3);
    }

    #[test]
    fn create_version_store_s3_without_namespace_component_errors() {
        // The filesystem root has no final component, so no `<namespace>/<name>` prefix
        // can be derived from it.
        let repo_dir = PathBuf::from("/");
        let opts = S3Opts {
            bucket: "my-bucket".to_string(),
        };
        let result = create_version_store(&repo_dir, &s3_config(), Some(&opts));
        assert!(
            matches!(result, Err(OxenError::S3PrefixUnresolvable(_))),
            "expected S3PrefixUnresolvable, got {result:?}",
        );
    }

    #[test]
    fn create_version_store_local_default_builds_local_store() {
        let repo_dir = PathBuf::from("/srv/oxen/test-ns/test-repo");
        let store = create_version_store(&repo_dir, &StorageConfig::default(), None)
            .expect("local store should construct without server opts");
        assert_eq!(store.storage_kind(), StorageKind::Local);
    }

    #[test]
    fn storage_kind_round_trips_through_display_and_from_str() {
        for kind in [StorageKind::Local, StorageKind::S3] {
            let parsed: StorageKind = kind
                .to_string()
                .parse()
                .expect("Display output should parse back");
            assert_eq!(parsed, kind);
        }
    }
}