garage-sdk 0.1.1

Async Rust SDK for Garage S3-compatible storage with uploads and public URL generation
Documentation
use crate::error::{Error, Result};
use bytes::Bytes;
use futures_util::{StreamExt, TryStreamExt};
use std::error::Error as StdError;
use std::future::Future;
use std::path::Path;
use std::sync::{
    Arc,
    atomic::{AtomicU64, Ordering},
};
use url::Url;

/// Message of the error item a size-limited download stream yields once the number of
/// bytes received exceeds the configured maximum.
pub const STREAM_LIMIT_ERROR: &str = "download size exceeds maximum";
/// Boxed, thread-safe error carried by the items of a streamed download body.
pub type BodyError = Box<dyn StdError + Send + Sync + 'static>;

/// How a response body is consumed once the response headers are known.
///
/// Equality is total for this fieldless enum, so `Eq` is derived alongside `PartialEq`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum DownloadStrategy {
    /// Read the entire body into memory before returning it.
    Buffer,
    /// Return the body as a size-limited stream of chunks.
    Stream,
}

/// Everything a successful download produces: content metadata plus the body.
///
/// `S` is the stream type used when the body is not held in memory.
pub struct DownloadResult<S> {
    /// Media type reported for the content (`ReqwestDownloader` falls back to
    /// `application/octet-stream` when the server does not provide one).
    pub content_type: String,
    /// File extension chosen for the content by the downloader.
    pub extension: String,
    /// The payload, either fully in memory or still streaming.
    pub body: DownloadBody<S>,
}

/// The bytes of a download, delivered either all at once or as a stream.
pub enum DownloadBody<S> {
    /// The complete body, already read into memory.
    Buffered {
        /// Full body contents.
        bytes: Bytes,
    },
    /// A body that is delivered chunk by chunk.
    Streamed {
        /// Stream yielding the body chunks.
        stream: S,
        /// Shared running total of bytes seen; for `ReqwestDownloader` this is updated
        /// by the size-limiting stream as chunks pass through it.
        size_counter: Arc<AtomicU64>,
    },
}

/// Fetches remote content and describes it (content type, extension, body).
///
/// Implemented by [`ReqwestDownloader`]; other implementations can be substituted,
/// for example in tests.
pub trait Downloader {
    /// Chunk stream handed back when a body is not buffered in memory.
    type Stream: Send
        + Sync
        + 'static
        + futures_util::Stream<Item = std::result::Result<Bytes, BodyError>>;

    /// Future resolving to the outcome of [`Downloader::fetch`]; it may borrow `self`
    /// for `'a`.
    type Future<'a>: Send + Future<Output = Result<DownloadResult<Self::Stream>>>
    where
        Self: 'a;

    /// Download the content found at `url`.
    ///
    /// `max_buffered_bytes` is the largest body that may be returned as
    /// [`DownloadBody::Buffered`]; `max_file_size` is the upper bound on the total
    /// content size.
    fn fetch<'a>(
        &'a self,
        url: Url,
        max_buffered_bytes: u64,
        max_file_size: u64,
    ) -> Self::Future<'a>;
}

/// [`Downloader`] implementation that issues HTTP requests through `reqwest`.
#[derive(Clone, Debug)]
pub struct ReqwestDownloader {
    /// Client every `fetch` request is sent through.
    client: reqwest::Client,
}

impl ReqwestDownloader {
    /// Build a downloader around an existing `reqwest::Client`.
    pub fn new(client: reqwest::Client) -> Self {
        ReqwestDownloader { client }
    }
}

impl Downloader for ReqwestDownloader {
    type Stream =
        SyncStream<futures_util::stream::BoxStream<'static, std::result::Result<Bytes, BodyError>>>;
    type Future<'a> =
        std::pin::Pin<Box<dyn Future<Output = Result<DownloadResult<Self::Stream>>> + Send + 'a>>;

    /// Send a GET request for `url` and return the body buffered or as a limited stream.
    ///
    /// The body is read fully into memory when the response announces a `Content-Length`
    /// of at most `max_buffered_bytes`. Otherwise (larger or unknown length) it is
    /// returned as a stream that yields a [`STREAM_LIMIT_ERROR`] error once more than
    /// `max_file_size` bytes have been received.
    ///
    /// # Errors
    ///
    /// * [`Error::Download`] if the request fails, the status is not a success, or a
    ///   buffered body cannot be read.
    /// * [`Error::Config`] if the limits are invalid, or the announced or buffered size
    ///   exceeds `max_file_size`.
    fn fetch<'a>(
        &'a self,
        url: Url,
        max_buffered_bytes: u64,
        max_file_size: u64,
    ) -> Self::Future<'a> {
        Box::pin(async move {
            let response =
                self.client
                    .get(url.clone())
                    .send()
                    .await
                    .map_err(|e| Error::Download {
                        url: url.to_string(),
                        reason: e.to_string(),
                    })?;

            if !response.status().is_success() {
                return Err(Error::Download {
                    url: url.to_string(),
                    reason: format!("HTTP {}", response.status()),
                });
            }

            let strategy = select_download_strategy(
                response.content_length(),
                max_buffered_bytes,
                max_file_size,
            )?;

            let content_type =
                content_type_from_header(response.headers().get(reqwest::header::CONTENT_TYPE));
            let ext = extension_for_download(&url, &content_type);

            match strategy {
                DownloadStrategy::Buffer => {
                    let bytes = response.bytes().await.map_err(|e| Error::Download {
                        url: url.to_string(),
                        reason: format!("failed to read response body: {}", e),
                    })?;

                    // Defensive re-check: the announced length was within limits, but the
                    // received body is what actually gets returned.
                    if bytes.len() as u64 > max_file_size {
                        return Err(Error::Config {
                            message: format!(
                                "downloaded content size {} exceeds maximum {}",
                                bytes.len(),
                                max_file_size
                            ),
                        });
                    }

                    Ok(DownloadResult {
                        content_type,
                        extension: ext,
                        body: DownloadBody::Buffered { bytes },
                    })
                }
                DownloadStrategy::Stream => {
                    let stream = response
                        .bytes_stream()
                        .map_err(|e| -> BodyError { Box::new(e) });
                    // The size limit is enforced lazily, chunk by chunk, as the stream is read.
                    let (limited_stream, counter) = limit_stream(stream, max_file_size);
                    let boxed_stream = limited_stream.boxed();
                    let sync_stream = SyncStream::new(boxed_stream);

                    Ok(DownloadResult {
                        content_type,
                        extension: ext,
                        body: DownloadBody::Streamed {
                            stream: sync_stream,
                            size_counter: counter,
                        },
                    })
                }
            }
        })
    }
}

/// Reduce a `Content-Type` header value to its bare media type.
///
/// Parameters after `;` (such as `charset`) and surrounding whitespace are dropped.
/// Falls back to `application/octet-stream` when the header is absent, is not visible
/// ASCII, or carries an empty media type (e.g. `; charset=utf-8`). The result is
/// therefore never empty.
fn content_type_from_header(value: Option<&reqwest::header::HeaderValue>) -> String {
    value
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.split(';').next())
        .map(str::trim)
        .filter(|media_type| !media_type.is_empty())
        .map(str::to_string)
        .unwrap_or_else(|| "application/octet-stream".to_string())
}

/// Choose a file extension for content downloaded from `url`.
///
/// The candidate is taken from the extension of the URL's last path segment. It is used
/// only when it is non-empty and purely ASCII alphanumeric. This rejects `file.` (whose
/// extension is `""`) and percent-encoded fragments such as `%2F`. Otherwise the first
/// extension `mime_guess` knows for `content_type` is used, and `"bin"` as a last resort.
fn extension_for_download(url: &Url, content_type: &str) -> String {
    let from_path = url
        .path_segments()
        .and_then(|mut segments| segments.next_back())
        .and_then(|name| Path::new(name).extension())
        .and_then(|ext| ext.to_str())
        .filter(|ext| !ext.is_empty() && ext.bytes().all(|b| b.is_ascii_alphanumeric()));

    from_path
        .or_else(|| {
            mime_guess::get_mime_extensions_str(content_type)
                .and_then(|exts| exts.first().copied())
        })
        .unwrap_or("bin")
        .to_string()
}

/// Puts a stream behind a mutex so the wrapper is `Sync`, as the S3 body requires.
///
/// `std::sync::Mutex<T>` is `Sync` whenever `T` is `Send`, so a stream that is only
/// `Send` becomes shareable once wrapped.
pub struct SyncStream<S> {
    /// The heap-pinned stream, guarded so shared references never reach it concurrently.
    inner: std::sync::Mutex<std::pin::Pin<Box<S>>>,
}

impl<S> SyncStream<S> {
    /// Pin `stream` on the heap and place it behind a new mutex.
    fn new(stream: S) -> Self {
        let pinned = Box::pin(stream);
        SyncStream {
            inner: std::sync::Mutex::new(pinned),
        }
    }
}

impl<S> futures_util::Stream for SyncStream<S>
where
    S: futures_util::Stream<Item = std::result::Result<Bytes, BodyError>>,
{
    type Item = std::result::Result<Bytes, BodyError>;

    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let mut guard = match self.inner.lock() {
            Ok(guard) => guard,
            Err(_) => {
                return std::task::Poll::Ready(Some(Err(Box::new(std::io::Error::other(
                    "stream lock poisoned",
                )))));
            }
        };

        guard.as_mut().poll_next(cx)
    }
}

/// Decide whether a response body should be buffered in memory or streamed.
///
/// `content_length` is the size announced by the server, if any. An announced size of at
/// most `max_buffered_bytes` is buffered; larger or unannounced sizes are streamed.
///
/// # Errors
///
/// Returns [`Error::Config`] when `max_buffered_bytes` is zero, when it exceeds
/// `max_file_size`, or when the announced size is larger than `max_file_size`.
fn select_download_strategy(
    content_length: Option<u64>,
    max_buffered_bytes: u64,
    max_file_size: u64,
) -> Result<DownloadStrategy> {
    // Validate the limits themselves before looking at the response size.
    let invalid_limits = if max_buffered_bytes == 0 {
        Some("max_buffered_bytes must be greater than zero")
    } else if max_buffered_bytes > max_file_size {
        Some("max_buffered_bytes must not exceed max_file_size")
    } else {
        None
    };
    if let Some(reason) = invalid_limits {
        return Err(Error::Config {
            message: reason.into(),
        });
    }

    // With no announced length there is nothing to compare against, so stream it.
    let Some(announced) = content_length else {
        return Ok(DownloadStrategy::Stream);
    };

    if announced > max_file_size {
        return Err(Error::Config {
            message: format!("content size {} exceeds maximum {}", announced, max_file_size),
        });
    }

    Ok(if announced <= max_buffered_bytes {
        DownloadStrategy::Buffer
    } else {
        DownloadStrategy::Stream
    })
}

fn limit_stream<S>(
    stream: S,
    max: u64,
) -> (
    impl futures_util::Stream<Item = std::result::Result<Bytes, BodyError>>,
    Arc<AtomicU64>,
)
where
    S: futures_util::Stream<Item = std::result::Result<Bytes, BodyError>>,
{
    let counter = Arc::new(AtomicU64::new(0));
    let counter_inner = Arc::clone(&counter);
    let limited = stream.map(move |result| match result {
        Ok(bytes) => {
            let read = bytes.len() as u64;
            let total = counter_inner.fetch_add(read, Ordering::Relaxed) + read;
            if total > max {
                return Err(Box::new(std::io::Error::other(STREAM_LIMIT_ERROR)) as BodyError);
            }
            Ok(bytes)
        }
        Err(err) => Err(err),
    });

    (limited, counter)
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::TryStreamExt;

    #[test]
    fn test_select_download_strategy_buffers_small_content() {
        let strategy = select_download_strategy(Some(1024), 8 * 1024, 64 * 1024)
            .expect("expected a valid strategy");
        assert_eq!(strategy, DownloadStrategy::Buffer);
    }

    #[test]
    fn test_select_download_strategy_buffers_at_exact_buffer_limit() {
        let strategy = select_download_strategy(Some(8 * 1024), 8 * 1024, 64 * 1024)
            .expect("expected a valid strategy");
        assert_eq!(strategy, DownloadStrategy::Buffer);
    }

    #[test]
    fn test_select_download_strategy_streams_large_content() {
        let strategy = select_download_strategy(Some(16 * 1024), 8 * 1024, 64 * 1024)
            .expect("expected a valid strategy");
        assert_eq!(strategy, DownloadStrategy::Stream);
    }

    #[test]
    fn test_select_download_strategy_streams_at_exact_file_limit() {
        // Content exactly at the maximum is allowed; only strictly larger content fails.
        let strategy = select_download_strategy(Some(64 * 1024), 8 * 1024, 64 * 1024)
            .expect("content equal to the maximum must be accepted");
        assert_eq!(strategy, DownloadStrategy::Stream);
    }

    #[test]
    fn test_select_download_strategy_streams_when_length_unknown() {
        let strategy =
            select_download_strategy(None, 8 * 1024, 64 * 1024).expect("expected a valid strategy");
        assert_eq!(strategy, DownloadStrategy::Stream);
    }

    #[test]
    fn test_select_download_strategy_rejects_oversize() {
        let result = select_download_strategy(Some(128 * 1024), 8 * 1024, 64 * 1024);
        assert!(result.is_err());
    }

    #[test]
    fn test_select_download_strategy_rejects_zero_buffer_limit() {
        let result = select_download_strategy(Some(1), 0, 64 * 1024);
        assert!(matches!(result, Err(Error::Config { .. })));
    }

    #[test]
    fn test_select_download_strategy_rejects_buffer_limit_above_file_limit() {
        let result = select_download_strategy(None, 128 * 1024, 64 * 1024);
        assert!(matches!(result, Err(Error::Config { .. })));
    }

    #[tokio::test]
    async fn test_limit_stream_enforces_max_size() {
        let data = vec![0_u8; 16];
        let stream = futures_util::stream::iter(vec![Ok::<Bytes, BodyError>(Bytes::from(data))]);
        let (limited, _) = limit_stream(stream, 8);
        let result: std::result::Result<Vec<Bytes>, BodyError> = limited.try_collect().await;

        let err = result.expect_err("expected size limit error");
        assert!(err.to_string().contains(STREAM_LIMIT_ERROR));
    }

    #[tokio::test]
    async fn test_limit_stream_enforces_max_across_chunks() {
        // Each chunk is under the limit on its own; only the cumulative total exceeds it.
        let chunks = vec![
            Ok::<Bytes, BodyError>(Bytes::from(vec![0_u8; 4])),
            Ok(Bytes::from(vec![0_u8; 4])),
            Ok(Bytes::from(vec![0_u8; 4])),
        ];
        let (limited, _) = limit_stream(futures_util::stream::iter(chunks), 10);
        let result: std::result::Result<Vec<Bytes>, BodyError> = limited.try_collect().await;

        let err = result.expect_err("cumulative size over the limit must fail");
        assert!(err.to_string().contains(STREAM_LIMIT_ERROR));
    }

    #[tokio::test]
    async fn test_limit_stream_allows_exact_max() {
        let stream =
            futures_util::stream::iter(vec![Ok::<Bytes, BodyError>(Bytes::from(vec![0_u8; 8]))]);
        let (limited, counter) = limit_stream(stream, 8);
        let result: std::result::Result<Vec<Bytes>, BodyError> = limited.try_collect().await;

        let chunks = result.expect("exactly max bytes must be allowed");
        assert_eq!(chunks.len(), 1);
        assert_eq!(counter.load(Ordering::Relaxed), 8);
    }

    #[tokio::test]
    async fn test_limit_stream_tracks_size() {
        let data = vec![1_u8; 5];
        let stream =
            futures_util::stream::iter(vec![Ok::<Bytes, BodyError>(Bytes::from(data.clone()))]);
        let (limited, counter) = limit_stream(stream, 8);
        let result: std::result::Result<Vec<Bytes>, BodyError> = limited.try_collect().await;

        let chunks = result.expect("expected stream to succeed");
        let total: usize = chunks.iter().map(|chunk| chunk.len()).sum();

        assert_eq!(total, data.len());
        assert_eq!(counter.load(Ordering::Relaxed), data.len() as u64);
    }

    #[tokio::test]
    async fn test_limit_stream_passes_upstream_errors_through() {
        let chunks = vec![Err::<Bytes, BodyError>(Box::new(std::io::Error::other(
            "upstream failure",
        )))];
        let (limited, counter) = limit_stream(futures_util::stream::iter(chunks), 8);
        let result: std::result::Result<Vec<Bytes>, BodyError> = limited.try_collect().await;

        let err = result.expect_err("upstream error must be propagated");
        assert_eq!(err.to_string(), "upstream failure");
        assert_eq!(counter.load(Ordering::Relaxed), 0);
    }
}