//! Streaming download/upload handlers, batch block operations, flow control,
//! resumable-transfer tokens, and SSE progress reporting for the gateway.
use axum::{
body::Body,
extract::{Multipart, Path, Query, State},
http::{header, StatusCode},
response::{
sse::{Event, KeepAlive, Sse},
IntoResponse, Response,
},
Json,
};
use bytes::Bytes;
use futures::stream::{self, Stream, StreamExt};
use ipfrs_core::{Block, Cid};
use ipfrs_storage::BlockStoreTrait;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::time::Duration;
use tokio::sync::broadcast;
use tracing::info;
use uuid::Uuid;
use crate::gateway::GatewayState;
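/// Caps fan-out for handlers that spawn parallel tasks. The `conservative`,
/// `aggressive`, and `sequential` constructors are tuning presets; any other
/// combination can be built directly and checked with `validate`.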
#[derive(Debug, Clone)]
pub struct ConcurrencyConfig {
pub max_concurrent_tasks: usize,
pub parallel_enabled: bool,
}
impl Default for ConcurrencyConfig {
fn default() -> Self {
Self {
max_concurrent_tasks: 100,
parallel_enabled: true,
}
}
}
impl ConcurrencyConfig {
pub fn conservative() -> Self {
Self {
max_concurrent_tasks: 50,
parallel_enabled: true,
}
}
pub fn aggressive() -> Self {
Self {
max_concurrent_tasks: 200,
parallel_enabled: true,
}
}
pub fn sequential() -> Self {
Self {
max_concurrent_tasks: 1,
parallel_enabled: false,
}
}
pub fn validate(&self) -> Result<(), String> {
if self.max_concurrent_tasks == 0 && self.parallel_enabled {
return Err(
"max_concurrent_tasks cannot be 0 when parallel_enabled is true".to_string(),
);
}
Ok(())
}
}
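/// Bandwidth-shaping settings for streamed downloads. A
/// `max_bytes_per_second` of 0 disables rate limiting; the window sizes
/// bound the adaptive chunk size used by `FlowController`.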
#[derive(Debug, Clone)]
pub struct FlowControlConfig {
pub max_bytes_per_second: u64,
pub initial_window_size: usize,
pub max_window_size: usize,
pub min_window_size: usize,
pub dynamic_adjustment: bool,
}
impl Default for FlowControlConfig {
fn default() -> Self {
Self {
max_bytes_per_second: 0, // 0 disables rate limiting
initial_window_size: 256 * 1024,
max_window_size: 1024 * 1024,
min_window_size: 64 * 1024,
dynamic_adjustment: true,
}
}
}
impl FlowControlConfig {
pub fn with_rate_limit(bytes_per_second: u64) -> Self {
Self {
max_bytes_per_second: bytes_per_second,
..Default::default()
}
}
pub fn conservative() -> Self {
Self {
initial_window_size: 64 * 1024,
max_window_size: 256 * 1024,
min_window_size: 32 * 1024,
..Default::default()
}
}
pub fn aggressive() -> Self {
Self {
initial_window_size: 512 * 1024,
max_window_size: 2 * 1024 * 1024,
min_window_size: 128 * 1024,
..Default::default()
}
}
pub fn validate(&self) -> Result<(), String> {
if self.min_window_size > self.initial_window_size {
return Err(format!(
"Minimum window size ({}) cannot exceed initial window size ({})",
self.min_window_size, self.initial_window_size
));
}
if self.initial_window_size > self.max_window_size {
return Err(format!(
"Initial window size ({}) cannot exceed maximum window size ({})",
self.initial_window_size, self.max_window_size
));
}
if self.max_bytes_per_second > 0 {
validation::validate_rate_limit(self.max_bytes_per_second)?;
}
Ok(())
}
}
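/// Paces a download against the configured rate budget and adapts the chunk
/// window: it grows by roughly 10% per adjustment interval and halves on
/// congestion, clamped to the configured bounds.
///
/// A sketch of the intended pacing loop (variable names here are
/// illustrative, mirroring the loop in `stream_download`):
///
/// ```ignore
/// let mut fc = FlowController::new(FlowControlConfig::with_rate_limit(1_000_000));
/// for chunk in payload.chunks(fc.window_size()) {
///     let delay = fc.calculate_delay(chunk.len());
///     if !delay.is_zero() {
///         tokio::time::sleep(delay).await;
///     }
///     // ...send the chunk to the client...
///     fc.on_data_sent(chunk.len());
/// }
/// ```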
#[derive(Debug)]
pub struct FlowController {
config: FlowControlConfig,
current_window_size: usize,
bytes_sent: u64,
start_time: std::time::Instant,
last_adjustment: std::time::Instant,
}
impl FlowController {
pub fn new(config: FlowControlConfig) -> Self {
Self {
current_window_size: config.initial_window_size,
config,
bytes_sent: 0,
start_time: std::time::Instant::now(),
last_adjustment: std::time::Instant::now(),
}
}
pub fn window_size(&self) -> usize {
self.current_window_size
}
pub fn calculate_delay(&self, bytes_to_send: usize) -> std::time::Duration {
if self.config.max_bytes_per_second == 0 {
return std::time::Duration::from_secs(0);
}
let elapsed = self.start_time.elapsed();
let elapsed_secs = elapsed.as_secs_f64();
if elapsed_secs == 0.0 {
return std::time::Duration::from_secs(0);
}
let current_rate = self.bytes_sent as f64 / elapsed_secs;
let target_rate = self.config.max_bytes_per_second as f64;
if current_rate + (bytes_to_send as f64 / elapsed_secs) > target_rate {
let delay_secs = (bytes_to_send as f64 / target_rate).max(0.0);
std::time::Duration::from_secs_f64(delay_secs)
} else {
std::time::Duration::from_secs(0)
}
}
pub fn on_data_sent(&mut self, bytes: usize) {
self.bytes_sent += bytes as u64;
if self.config.dynamic_adjustment {
self.adjust_window();
}
}
fn adjust_window(&mut self) {
let elapsed = self.last_adjustment.elapsed();
if elapsed < std::time::Duration::from_millis(100) {
return;
}
self.last_adjustment = std::time::Instant::now();
let new_size = (self.current_window_size as f64 * 1.1)
.min(self.config.max_window_size as f64) as usize;
self.current_window_size =
new_size.clamp(self.config.min_window_size, self.config.max_window_size);
}
#[allow(dead_code)]
pub fn on_congestion(&mut self) {
let new_size = self.current_window_size / 2;
self.current_window_size = new_size.max(self.config.min_window_size);
self.last_adjustment = std::time::Instant::now();
}
pub fn current_throughput(&self) -> f64 {
let elapsed = self.start_time.elapsed().as_secs_f64();
if elapsed > 0.0 {
self.bytes_sent as f64 / elapsed
} else {
0.0
}
}
}
#[derive(Debug, Clone)]
pub struct OperationState {
pub operation_id: String,
pub offset: u64,
pub total_size: Option<u64>,
pub operation_type: OperationType,
pub status: OperationStatus,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OperationType {
Upload,
Download,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OperationStatus {
InProgress,
Paused,
Cancelled,
Completed,
Failed,
}
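/// Opaque token that lets a client resume an interrupted transfer. Encoded
/// as URL-safe, unpadded base64 over the JSON-serialized struct.
///
/// Round-trip sketch (the operation id and CID are illustrative):
///
/// ```ignore
/// let token = ResumeToken::new("op-123".to_string(), 4096, Some("QmExample".to_string()));
/// let encoded = token.encode().unwrap();
/// let decoded = ResumeToken::decode(&encoded).unwrap();
/// assert_eq!(decoded.offset, 4096);
/// ```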
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResumeToken {
pub operation_id: String,
pub offset: u64,
pub cid: Option<String>,
}
impl ResumeToken {
pub fn new(operation_id: String, offset: u64, cid: Option<String>) -> Self {
Self {
operation_id,
offset,
cid,
}
}
pub fn encode(&self) -> Result<String, String> {
let json = serde_json::to_string(self).map_err(|e| e.to_string())?;
Ok(base64::Engine::encode(
&base64::engine::general_purpose::URL_SAFE_NO_PAD,
json.as_bytes(),
))
}
pub fn decode(encoded: &str) -> Result<Self, String> {
let bytes =
base64::Engine::decode(&base64::engine::general_purpose::URL_SAFE_NO_PAD, encoded)
.map_err(|e| e.to_string())?;
let json = String::from_utf8(bytes).map_err(|e| e.to_string())?;
serde_json::from_str(&json).map_err(|e| e.to_string())
}
}
#[derive(Debug, Deserialize)]
pub struct CancelRequest {
pub operation_id: String,
}
#[derive(Debug, Serialize)]
pub struct CancelResponse {
pub operation_id: String,
pub cancelled: bool,
pub resume_token: Option<String>,
}
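/// One progress update published to subscribers of a `ProgressTracker`.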
#[derive(Debug, Clone, Serialize)]
pub struct ProgressEvent {
pub operation_id: String,
pub bytes_processed: u64,
pub total_bytes: Option<u64>,
pub progress_percent: Option<f32>,
pub status: ProgressStatus,
}
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ProgressStatus {
Started,
InProgress,
Completed,
Failed,
}
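/// Fan-out channel for progress events, backed by a Tokio broadcast channel.
/// Clones share the same sender; `send` is fire-and-forget and silently
/// drops events when no subscriber is listening.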
#[derive(Clone)]
pub struct ProgressTracker {
sender: broadcast::Sender<ProgressEvent>,
}
impl ProgressTracker {
pub fn new() -> Self {
let (sender, _) = broadcast::channel(100);
Self { sender }
}
pub fn send(&self, event: ProgressEvent) {
let _ = self.sender.send(event);
}
pub fn subscribe(&self) -> broadcast::Receiver<ProgressEvent> {
self.sender.subscribe()
}
}
impl Default for ProgressTracker {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Deserialize)]
pub struct StreamDownloadQuery {
pub chunk_size: Option<usize>,
pub max_bytes_per_second: Option<u64>,
pub flow_control: Option<bool>,
pub resume_token: Option<String>,
pub offset: Option<u64>,
}
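/// Streams a block back to the client in chunks. Supports resumption via a
/// `resume_token` or raw `offset` (answered with `206 Partial Content` and a
/// `Content-Range` header), an explicit `chunk_size`, and optional flow
/// control with a per-request rate cap.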
pub async fn stream_download(
State(state): State<GatewayState>,
Path(cid_str): Path<String>,
Query(query): Query<StreamDownloadQuery>,
) -> Result<Response, StreamingError> {
let cid: Cid = cid_str
.parse()
.map_err(|_| StreamingError::InvalidCid(cid_str.clone()))?;
let block = state
.store
.get(&cid)
.await
.map_err(|e| StreamingError::Storage(e.to_string()))?
.ok_or_else(|| StreamingError::NotFound(cid_str.clone()))?;
let data = block.data().to_vec();
let total_size = data.len();
let start_offset = if let Some(resume_token) = &query.resume_token {
let token = ResumeToken::decode(resume_token)
.map_err(|e| StreamingError::Upload(format!("Invalid resume token: {}", e)))?;
if let Some(token_cid) = &token.cid {
if token_cid != &cid_str {
return Err(StreamingError::Upload(
"Resume token CID mismatch".to_string(),
));
}
}
token.offset as usize
} else {
query.offset.unwrap_or(0) as usize
};
if start_offset >= total_size {
return Err(StreamingError::Upload(format!(
"Invalid offset: {} (total size: {})",
start_offset, total_size
)));
}
let enable_flow_control = query.flow_control.unwrap_or(false);
let flow_controller = if enable_flow_control {
let mut config = FlowControlConfig::default();
if let Some(rate) = query.max_bytes_per_second {
config.max_bytes_per_second = rate;
}
// Reject client-supplied rates outside the limits the module enforces elsewhere.
config.validate().map_err(StreamingError::Validation)?;
Some(FlowController::new(config))
} else {
None
};
// A zero chunk size would panic in `chunks()` below, so validate any
// client-supplied value before using it.
if let Some(requested) = query.chunk_size {
validation::validate_chunk_size(requested).map_err(StreamingError::Validation)?;
}
let chunk_size = query.chunk_size.unwrap_or_else(|| {
flow_controller
.as_ref()
.map(|fc| fc.window_size())
.unwrap_or(64 * 1024)
});
let chunks: Vec<Vec<u8>> = data[start_offset..]
.chunks(chunk_size)
.map(|chunk| chunk.to_vec())
.collect();
let remaining_size = total_size - start_offset;
let stream = if let Some(mut fc) = flow_controller {
let stream = async_stream::stream! {
for chunk in chunks {
let chunk_len = chunk.len();
let delay = fc.calculate_delay(chunk_len);
if !delay.is_zero() {
tokio::time::sleep(delay).await;
}
fc.on_data_sent(chunk_len);
yield Ok::<_, Infallible>(Bytes::from(chunk));
}
};
Body::from_stream(stream)
} else {
let stream = stream::iter(chunks).map(|chunk| Ok::<_, Infallible>(Bytes::from(chunk)));
Body::from_stream(stream)
};
let mut response_builder = Response::builder();
if start_offset > 0 {
response_builder = response_builder.status(StatusCode::PARTIAL_CONTENT);
let end_offset = total_size - 1;
response_builder = response_builder.header(
header::CONTENT_RANGE,
format!("bytes {}-{}/{}", start_offset, end_offset, total_size),
);
} else {
response_builder = response_builder.status(StatusCode::OK);
}
Ok(response_builder
.header(header::CONTENT_TYPE, "application/octet-stream")
.header(header::CONTENT_LENGTH, remaining_size.to_string())
.header("X-Chunk-Size", chunk_size.to_string())
.header("Accept-Ranges", "bytes")
.body(stream)
.unwrap())
}
#[derive(Debug, Serialize)]
pub struct StreamUploadResponse {
pub cid: String,
pub size: u64,
pub chunks_received: usize,
}
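/// Accepts a multipart upload, concatenates every field's bytes into a
/// single block, stores it, and returns the resulting CID and size.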
pub async fn stream_upload(
State(state): State<GatewayState>,
mut multipart: Multipart,
) -> Result<Json<StreamUploadResponse>, StreamingError> {
let mut total_data = Vec::new();
let mut chunks_received = 0;
while let Some(field) = multipart
.next_field()
.await
.map_err(|e| StreamingError::Upload(format!("Failed to read field: {}", e)))?
{
let data = field
.bytes()
.await
.map_err(|e| StreamingError::Upload(format!("Failed to read data: {}", e)))?;
total_data.extend_from_slice(&data);
chunks_received += 1;
}
if total_data.is_empty() {
return Err(StreamingError::Upload("No data received".to_string()));
}
let block = Block::new(Bytes::from(total_data))
.map_err(|e| StreamingError::Upload(format!("Failed to create block: {}", e)))?;
let cid = *block.cid();
let size = block.size();
state
.store
.put(&block)
.await
.map_err(|e| StreamingError::Storage(e.to_string()))?;
info!("Stream upload completed: {} ({} bytes)", cid, size);
Ok(Json(StreamUploadResponse {
cid: cid.to_string(),
size,
chunks_received,
}))
}
#[derive(Debug, Deserialize)]
pub struct BatchGetRequest {
pub cids: Vec<String>,
}
#[derive(Debug, Serialize)]
pub struct BatchGetResponse {
pub blocks: Vec<BatchBlockResult>,
pub errors: Vec<BatchError>,
}
#[derive(Debug, Serialize)]
pub struct BatchBlockResult {
pub cid: String,
/// Base64-encoded block bytes.
pub data: String,
pub size: u64,
}
#[derive(Debug, Serialize)]
pub struct BatchError {
pub cid: String,
pub error: String,
}
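/// Fetches a batch of blocks in parallel, one spawned task per CID. Block
/// bytes are returned base64-encoded; per-CID failures land in `errors`
/// instead of failing the whole request.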
pub async fn batch_get(
State(state): State<GatewayState>,
Json(req): Json<BatchGetRequest>,
) -> Result<Json<BatchGetResponse>, StreamingError> {
validation::validate_batch_size(req.cids.len()).map_err(StreamingError::Validation)?;
let tasks: Vec<_> = req
.cids
.into_iter()
.map(|cid_str| {
let state = state.clone();
tokio::spawn(async move {
match cid_str.parse::<Cid>() {
Ok(cid) => match state.store.get(&cid).await {
Ok(Some(block)) => {
let data_base64 = base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
block.data(),
);
Ok(BatchBlockResult {
cid: cid_str,
data: data_base64,
size: block.size(),
})
}
Ok(None) => Err(BatchError {
cid: cid_str,
error: "Block not found".to_string(),
}),
Err(e) => Err(BatchError {
cid: cid_str,
error: e.to_string(),
}),
},
Err(_) => Err(BatchError {
cid: cid_str,
error: "Invalid CID".to_string(),
}),
}
})
})
.collect();
let mut blocks = Vec::new();
let mut errors = Vec::new();
for task in tasks {
match task.await {
Ok(Ok(block)) => blocks.push(block),
Ok(Err(error)) => errors.push(error),
Err(e) => {
errors.push(BatchError {
cid: "unknown".to_string(),
error: format!("Task execution error: {}", e),
});
}
}
}
Ok(Json(BatchGetResponse { blocks, errors }))
}
#[derive(Debug, Deserialize)]
pub struct BatchPutItem {
pub data: String,
}
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum TransactionMode {
Atomic,
#[default]
BestEffort,
}
#[derive(Debug, Deserialize)]
pub struct BatchPutRequest {
pub blocks: Vec<BatchPutItem>,
#[serde(default)]
pub transaction_mode: TransactionMode,
}
#[derive(Debug, Serialize)]
pub struct BatchPutResponse {
pub stored: Vec<BatchStoredResult>,
pub errors: Vec<BatchPutError>,
pub transaction_id: Option<String>,
pub transaction_status: TransactionStatus,
}
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum TransactionStatus {
Committed,
PartialSuccess,
RolledBack,
}
#[derive(Debug, Serialize)]
pub struct BatchStoredResult {
pub cid: String,
pub size: u64,
pub index: usize,
}
#[derive(Debug, Serialize)]
pub struct BatchPutError {
pub index: usize,
pub error: String,
}
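/// Stores a batch of base64-encoded blocks under the requested
/// `transaction_mode`: `Atomic` is all-or-nothing (already-stored blocks are
/// deleted if a later write fails), while `BestEffort` (the default) reports
/// failures per-index and keeps whatever succeeded. An illustrative request
/// body, matching the test fixtures below:
/// `{"blocks": [{"data": "SGVsbG8="}], "transaction_mode": "atomic"}`.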
pub async fn batch_put(
State(state): State<GatewayState>,
Json(req): Json<BatchPutRequest>,
) -> Result<Json<BatchPutResponse>, StreamingError> {
let transaction_id = Uuid::new_v4().to_string();
match req.transaction_mode {
TransactionMode::Atomic => batch_put_atomic(state, req.blocks, transaction_id).await,
TransactionMode::BestEffort => {
batch_put_best_effort(state, req.blocks, transaction_id).await
}
}
}
async fn batch_put_atomic(
state: GatewayState,
items: Vec<BatchPutItem>,
transaction_id: String,
) -> Result<Json<BatchPutResponse>, StreamingError> {
validation::validate_batch_size(items.len()).map_err(StreamingError::Validation)?;
let mut prepared_blocks = Vec::new();
let mut errors = Vec::new();
for (index, item) in items.into_iter().enumerate() {
match base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &item.data) {
Ok(data) => match Block::new(Bytes::from(data)) {
Ok(block) => {
prepared_blocks.push((index, block));
}
Err(e) => {
errors.push(BatchPutError {
index,
error: format!("Block creation error: {}", e),
});
}
},
Err(e) => {
errors.push(BatchPutError {
index,
error: format!("Base64 decode error: {}", e),
});
}
}
}
if !errors.is_empty() {
info!(
"Atomic batch put [{}] rolled back: {} validation errors",
transaction_id,
errors.len()
);
return Ok(Json(BatchPutResponse {
stored: vec![],
errors,
transaction_id: Some(transaction_id),
transaction_status: TransactionStatus::RolledBack,
}));
}
let mut stored = Vec::new();
let mut stored_cids = Vec::new();
for (index, block) in prepared_blocks {
let cid = *block.cid();
let size = block.size();
match state.store.put(&block).await {
Ok(_) => {
stored_cids.push(cid);
stored.push(BatchStoredResult {
cid: cid.to_string(),
size,
index,
});
}
Err(e) => {
info!(
"Atomic batch put [{}] rolling back: storage error at index {}",
transaction_id, index
);
for stored_cid in stored_cids {
// Best-effort rollback: remove every block already stored in this transaction.
let _ = state.store.delete(&stored_cid).await;
}
return Ok(Json(BatchPutResponse {
stored: vec![],
errors: vec![BatchPutError {
index,
error: format!("Storage error (transaction rolled back): {}", e),
}],
transaction_id: Some(transaction_id),
transaction_status: TransactionStatus::RolledBack,
}));
}
}
}
info!(
"Atomic batch put [{}] committed: {} blocks stored",
transaction_id,
stored.len()
);
Ok(Json(BatchPutResponse {
stored,
errors: vec![],
transaction_id: Some(transaction_id),
transaction_status: TransactionStatus::Committed,
}))
}
async fn batch_put_best_effort(
state: GatewayState,
items: Vec<BatchPutItem>,
transaction_id: String,
) -> Result<Json<BatchPutResponse>, StreamingError> {
validation::validate_batch_size(items.len()).map_err(StreamingError::Validation)?;
let mut stored = Vec::new();
let mut errors = Vec::new();
for (index, item) in items.into_iter().enumerate() {
match base64::Engine::decode(&base64::engine::general_purpose::STANDARD, &item.data) {
Ok(data) => match Block::new(Bytes::from(data)) {
Ok(block) => {
let cid = *block.cid();
let size = block.size();
match state.store.put(&block).await {
Ok(_) => {
stored.push(BatchStoredResult {
cid: cid.to_string(),
size,
index,
});
}
Err(e) => {
errors.push(BatchPutError {
index,
error: format!("Storage error: {}", e),
});
}
}
}
Err(e) => {
errors.push(BatchPutError {
index,
error: format!("Block creation error: {}", e),
});
}
},
Err(e) => {
errors.push(BatchPutError {
index,
error: format!("Base64 decode error: {}", e),
});
}
}
}
let status = if errors.is_empty() {
TransactionStatus::Committed
} else {
TransactionStatus::PartialSuccess
};
info!(
"Best-effort batch put [{}] completed: {} stored, {} errors",
transaction_id,
stored.len(),
errors.len()
);
Ok(Json(BatchPutResponse {
stored,
errors,
transaction_id: Some(transaction_id),
transaction_status: status,
}))
}
#[derive(Debug, Deserialize)]
pub struct BatchHasRequest {
pub cids: Vec<String>,
}
#[derive(Debug, Serialize)]
pub struct BatchHasResponse {
pub results: Vec<BatchHasResult>,
}
#[derive(Debug, Serialize)]
pub struct BatchHasResult {
pub cid: String,
pub exists: bool,
}
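/// Checks existence for a batch of CIDs in parallel. CIDs that fail to
/// parse are reported as `exists: false` rather than as errors.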
pub async fn batch_has(
State(state): State<GatewayState>,
Json(req): Json<BatchHasRequest>,
) -> Result<Json<BatchHasResponse>, StreamingError> {
validation::validate_batch_size(req.cids.len()).map_err(StreamingError::Validation)?;
let tasks: Vec<_> = req
.cids
.into_iter()
.map(|cid_str| {
let state = state.clone();
tokio::spawn(async move {
let exists = if let Ok(cid) = cid_str.parse::<Cid>() {
state.store.has(&cid).await.unwrap_or(false)
} else {
false
};
BatchHasResult {
cid: cid_str,
exists,
}
})
})
.collect();
let mut results = Vec::new();
for task in tasks {
match task.await {
Ok(result) => results.push(result),
Err(e) => {
results.push(BatchHasResult {
cid: format!("task_error_{}", e),
exists: false,
});
}
}
}
Ok(Json(BatchHasResponse { results }))
}
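/// SSE endpoint reporting progress for `operation_id`. Emits an initial
/// "started" event, a keepalive comment after 30 seconds of inactivity, and
/// closes once a `Completed` or `Failed` event is seen.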
pub async fn progress_stream(
State(_state): State<GatewayState>,
Path(operation_id): Path<String>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
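// NOTE: this subscribes to a tracker created locally, so only the synthetic
// "started" event below can ever arrive; delivering real progress updates
// would require sharing a single ProgressTracker (e.g. via GatewayState).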
let tracker = ProgressTracker::new();
let mut receiver = tracker.subscribe();
let stream = async_stream::stream! {
let initial = ProgressEvent {
operation_id: operation_id.clone(),
bytes_processed: 0,
total_bytes: None,
progress_percent: Some(0.0),
status: ProgressStatus::Started,
};
yield Ok(Event::default()
.event("progress")
.json_data(initial)
.unwrap());
loop {
match tokio::time::timeout(Duration::from_secs(30), receiver.recv()).await {
Ok(Ok(event)) => {
let is_complete = matches!(event.status, ProgressStatus::Completed | ProgressStatus::Failed);
yield Ok(Event::default()
.event("progress")
.json_data(event)
.unwrap());
if is_complete {
break;
}
}
Ok(Err(_)) => {
break;
}
Err(_) => {
yield Ok(Event::default().comment("keepalive"));
}
}
}
};
Sse::new(stream).keep_alive(KeepAlive::default())
}
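/// Input-validation helpers shared by the streaming and batch handlers.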
pub mod validation {
pub fn validate_cid(cid: &str) -> Result<(), String> {
if cid.is_empty() {
return Err("CID cannot be empty".to_string());
}
if !cid.starts_with("Qm") && !cid.starts_with("bafy") && !cid.starts_with("baf") {
return Err(format!("Invalid CID format: {}", cid));
}
if cid.len() < 10 {
return Err(format!("CID too short: {}", cid));
}
Ok(())
}
pub fn validate_offset(offset: u64, total_size: usize) -> Result<(), String> {
if offset as usize >= total_size {
return Err(format!(
"Offset {} exceeds total size {}",
offset, total_size
));
}
Ok(())
}
pub fn validate_chunk_size(chunk_size: usize) -> Result<(), String> {
const MIN_CHUNK_SIZE: usize = 1024;
const MAX_CHUNK_SIZE: usize = 10 * 1024 * 1024;
if chunk_size < MIN_CHUNK_SIZE {
return Err(format!(
"Chunk size {} is too small (minimum: {})",
chunk_size, MIN_CHUNK_SIZE
));
}
if chunk_size > MAX_CHUNK_SIZE {
return Err(format!(
"Chunk size {} is too large (maximum: {})",
chunk_size, MAX_CHUNK_SIZE
));
}
Ok(())
}
pub fn validate_rate_limit(bytes_per_second: u64) -> Result<(), String> {
const MAX_RATE: u64 = 10 * 1024 * 1024 * 1024;
if bytes_per_second > MAX_RATE {
return Err(format!(
"Rate limit {} exceeds maximum {}",
bytes_per_second, MAX_RATE
));
}
Ok(())
}
pub fn validate_batch_size(count: usize) -> Result<(), String> {
const MAX_BATCH_SIZE: usize = 1000;
if count == 0 {
return Err("Batch cannot be empty".to_string());
}
if count > MAX_BATCH_SIZE {
return Err(format!(
"Batch size {} exceeds maximum {}",
count, MAX_BATCH_SIZE
));
}
Ok(())
}
}
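/// Error type shared by the streaming and batch handlers; `IntoResponse`
/// maps each variant to an appropriate HTTP status code.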
#[derive(Debug)]
pub enum StreamingError {
InvalidCid(String),
NotFound(String),
Upload(String),
Storage(String),
Validation(String),
}
impl IntoResponse for StreamingError {
fn into_response(self) -> Response {
let (status, message) = match self {
StreamingError::InvalidCid(cid) => {
(StatusCode::BAD_REQUEST, format!("Invalid CID: {}", cid))
}
StreamingError::NotFound(cid) => {
(StatusCode::NOT_FOUND, format!("Block not found: {}", cid))
}
StreamingError::Upload(msg) => {
(StatusCode::BAD_REQUEST, format!("Upload error: {}", msg))
}
StreamingError::Storage(msg) => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Storage error: {}", msg),
),
StreamingError::Validation(msg) => (
StatusCode::BAD_REQUEST,
format!("Validation error: {}", msg),
),
};
(status, message).into_response()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_progress_event_serialization() {
let event = ProgressEvent {
operation_id: "test-123".to_string(),
bytes_processed: 1024,
total_bytes: Some(2048),
progress_percent: Some(50.0),
status: ProgressStatus::InProgress,
};
let json = serde_json::to_string(&event).unwrap();
assert!(json.contains("test-123"));
assert!(json.contains("1024"));
assert!(json.contains("inprogress"));
}
#[test]
fn test_progress_tracker() {
let tracker = ProgressTracker::new();
let _receiver = tracker.subscribe();
let event = ProgressEvent {
operation_id: "test".to_string(),
bytes_processed: 100,
total_bytes: Some(200),
progress_percent: Some(50.0),
status: ProgressStatus::InProgress,
};
tracker.send(event);
}
#[test]
fn test_batch_request_deserialization() {
let json = r#"{"cids": ["QmTest1", "QmTest2"]}"#;
let req: BatchGetRequest = serde_json::from_str(json).unwrap();
assert_eq!(req.cids.len(), 2);
assert_eq!(req.cids[0], "QmTest1");
}
#[test]
fn test_batch_put_request_deserialization() {
let json = r#"{"blocks": [{"data": "SGVsbG8="}]}"#;
let req: BatchPutRequest = serde_json::from_str(json).unwrap();
assert_eq!(req.blocks.len(), 1);
assert_eq!(req.blocks[0].data, "SGVsbG8=");
assert_eq!(req.transaction_mode, TransactionMode::BestEffort);
}
#[test]
fn test_batch_put_request_atomic_mode() {
let json = r#"{"blocks": [{"data": "SGVsbG8="}], "transaction_mode": "atomic"}"#;
let req: BatchPutRequest = serde_json::from_str(json).unwrap();
assert_eq!(req.transaction_mode, TransactionMode::Atomic);
}
#[test]
fn test_transaction_mode_default() {
let mode = TransactionMode::default();
assert_eq!(mode, TransactionMode::BestEffort);
}
#[test]
fn test_transaction_status_serialization() {
let status = TransactionStatus::Committed;
let json = serde_json::to_string(&status).unwrap();
assert_eq!(json, r#""committed""#);
let status = TransactionStatus::RolledBack;
let json = serde_json::to_string(&status).unwrap();
assert_eq!(json, r#""rolledback""#);
}
#[test]
fn test_batch_put_response_with_transaction() {
let response = BatchPutResponse {
stored: vec![],
errors: vec![],
transaction_id: Some("test-txn-123".to_string()),
transaction_status: TransactionStatus::Committed,
};
let json = serde_json::to_string(&response).unwrap();
assert!(json.contains("test-txn-123"));
assert!(json.contains("committed"));
}
#[test]
fn test_flow_control_config_default() {
let config = FlowControlConfig::default();
assert_eq!(config.max_bytes_per_second, 0);
assert_eq!(config.initial_window_size, 256 * 1024);
assert_eq!(config.max_window_size, 1024 * 1024);
assert_eq!(config.min_window_size, 64 * 1024);
assert!(config.dynamic_adjustment);
}
#[test]
fn test_flow_control_config_with_rate_limit() {
let config = FlowControlConfig::with_rate_limit(1_000_000);
assert_eq!(config.max_bytes_per_second, 1_000_000);
assert!(config.dynamic_adjustment);
}
#[test]
fn test_flow_control_config_conservative() {
let config = FlowControlConfig::conservative();
assert_eq!(config.initial_window_size, 64 * 1024);
assert_eq!(config.max_window_size, 256 * 1024);
assert_eq!(config.min_window_size, 32 * 1024);
}
#[test]
fn test_flow_control_config_aggressive() {
let config = FlowControlConfig::aggressive();
assert_eq!(config.initial_window_size, 512 * 1024);
assert_eq!(config.max_window_size, 2 * 1024 * 1024);
assert_eq!(config.min_window_size, 128 * 1024);
}
#[test]
fn test_flow_controller_window_size() {
let config = FlowControlConfig::default();
let controller = FlowController::new(config.clone());
assert_eq!(controller.window_size(), config.initial_window_size);
}
#[test]
fn test_flow_controller_no_rate_limit() {
let config = FlowControlConfig::default();
let controller = FlowController::new(config);
let delay = controller.calculate_delay(1024);
assert_eq!(delay, std::time::Duration::from_secs(0));
}
#[test]
fn test_flow_controller_on_data_sent() {
let config = FlowControlConfig::default();
let mut controller = FlowController::new(config);
controller.on_data_sent(1024);
assert_eq!(controller.bytes_sent, 1024);
controller.on_data_sent(2048);
assert_eq!(controller.bytes_sent, 3072);
}
#[test]
fn test_flow_controller_on_congestion() {
let config = FlowControlConfig::default();
let mut controller = FlowController::new(config.clone());
let initial_window = controller.window_size();
controller.on_congestion();
assert!(controller.window_size() < initial_window);
assert!(controller.window_size() >= config.min_window_size);
}
#[test]
fn test_flow_controller_throughput() {
let config = FlowControlConfig::default();
let mut controller = FlowController::new(config);
controller.on_data_sent(1024);
let throughput = controller.current_throughput();
assert!(throughput >= 0.0);
}
#[test]
fn test_resume_token_encode_decode() {
let token = ResumeToken::new("op-123".to_string(), 4096, Some("QmTest123".to_string()));
let encoded = token.encode().unwrap();
assert!(!encoded.is_empty());
let decoded = ResumeToken::decode(&encoded).unwrap();
assert_eq!(decoded.operation_id, "op-123");
assert_eq!(decoded.offset, 4096);
assert_eq!(decoded.cid, Some("QmTest123".to_string()));
}
#[test]
fn test_resume_token_invalid() {
let result = ResumeToken::decode("invalid!!!base64");
assert!(result.is_err());
let invalid_json = base64::Engine::encode(
&base64::engine::general_purpose::URL_SAFE_NO_PAD,
b"not json",
);
let result = ResumeToken::decode(&invalid_json);
assert!(result.is_err());
}
#[test]
fn test_operation_type_serialization() {
let upload = OperationType::Upload;
let json = serde_json::to_string(&upload).unwrap();
assert_eq!(json, r#""upload""#);
let download = OperationType::Download;
let json = serde_json::to_string(&download).unwrap();
assert_eq!(json, r#""download""#);
}
#[test]
fn test_operation_status_serialization() {
let status = OperationStatus::InProgress;
let json = serde_json::to_string(&status).unwrap();
assert_eq!(json, r#""inprogress""#);
let status = OperationStatus::Cancelled;
let json = serde_json::to_string(&status).unwrap();
assert_eq!(json, r#""cancelled""#);
}
#[test]
fn test_cancel_response_serialization() {
let response = CancelResponse {
operation_id: "op-456".to_string(),
cancelled: true,
resume_token: Some("token123".to_string()),
};
let json = serde_json::to_string(&response).unwrap();
assert!(json.contains("op-456"));
assert!(json.contains("true"));
assert!(json.contains("token123"));
}
#[test]
fn test_validate_cid_valid() {
assert!(validation::validate_cid("QmTest123456").is_ok());
assert!(validation::validate_cid(
"bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
)
.is_ok());
assert!(validation::validate_cid(
"bafkreigh2akiscaildcqabsyg3dfr6chu3fgpregiymsck7e7aqa4s52zy"
)
.is_ok());
}
#[test]
fn test_validate_cid_invalid() {
assert!(validation::validate_cid("").is_err());
assert!(validation::validate_cid("invalid").is_err());
assert!(validation::validate_cid("Qm").is_err());
}
#[test]
fn test_validate_offset_valid() {
assert!(validation::validate_offset(0, 1000).is_ok());
assert!(validation::validate_offset(500, 1000).is_ok());
assert!(validation::validate_offset(999, 1000).is_ok());
}
#[test]
fn test_validate_offset_invalid() {
assert!(validation::validate_offset(1000, 1000).is_err());
assert!(validation::validate_offset(2000, 1000).is_err());
}
#[test]
fn test_validate_chunk_size_valid() {
assert!(validation::validate_chunk_size(1024).is_ok());
assert!(validation::validate_chunk_size(64 * 1024).is_ok());
assert!(validation::validate_chunk_size(10 * 1024 * 1024).is_ok());
}
#[test]
fn test_validate_chunk_size_invalid() {
assert!(validation::validate_chunk_size(512).is_err());
assert!(validation::validate_chunk_size(20 * 1024 * 1024).is_err());
}
#[test]
fn test_validate_rate_limit_valid() {
assert!(validation::validate_rate_limit(0).is_ok());
assert!(validation::validate_rate_limit(1_000_000).is_ok());
assert!(validation::validate_rate_limit(10 * 1024 * 1024 * 1024).is_ok());
}
#[test]
fn test_validate_rate_limit_invalid() {
assert!(validation::validate_rate_limit(20 * 1024 * 1024 * 1024).is_err());
}
#[test]
fn test_validate_batch_size_valid() {
assert!(validation::validate_batch_size(1).is_ok());
assert!(validation::validate_batch_size(100).is_ok());
assert!(validation::validate_batch_size(1000).is_ok());
}
#[test]
fn test_validate_batch_size_invalid() {
assert!(validation::validate_batch_size(0).is_err());
assert!(validation::validate_batch_size(1001).is_err());
assert!(validation::validate_batch_size(5000).is_err());
}
#[test]
fn test_flow_control_config_validation_valid() {
let config = FlowControlConfig::default();
assert!(config.validate().is_ok());
let config = FlowControlConfig::conservative();
assert!(config.validate().is_ok());
let config = FlowControlConfig::aggressive();
assert!(config.validate().is_ok());
}
#[test]
fn test_flow_control_config_validation_invalid() {
let config = FlowControlConfig {
max_bytes_per_second: 0,
initial_window_size: 64 * 1024,
max_window_size: 1024 * 1024,
min_window_size: 128 * 1024,
dynamic_adjustment: true,
};
assert!(config.validate().is_err());
let config = FlowControlConfig {
max_bytes_per_second: 0,
initial_window_size: 2 * 1024 * 1024,
max_window_size: 1024 * 1024,
min_window_size: 64 * 1024,
dynamic_adjustment: true,
};
assert!(config.validate().is_err());
let config = FlowControlConfig {
max_bytes_per_second: 20 * 1024 * 1024 * 1024,
initial_window_size: 256 * 1024,
max_window_size: 1024 * 1024,
min_window_size: 64 * 1024,
dynamic_adjustment: true,
};
assert!(config.validate().is_err());
}
#[test]
fn test_concurrency_config_default() {
let config = ConcurrencyConfig::default();
assert_eq!(config.max_concurrent_tasks, 100);
assert!(config.parallel_enabled);
assert!(config.validate().is_ok());
}
#[test]
fn test_concurrency_config_conservative() {
let config = ConcurrencyConfig::conservative();
assert_eq!(config.max_concurrent_tasks, 50);
assert!(config.parallel_enabled);
assert!(config.validate().is_ok());
}
#[test]
fn test_concurrency_config_aggressive() {
let config = ConcurrencyConfig::aggressive();
assert_eq!(config.max_concurrent_tasks, 200);
assert!(config.parallel_enabled);
assert!(config.validate().is_ok());
}
#[test]
fn test_concurrency_config_sequential() {
let config = ConcurrencyConfig::sequential();
assert_eq!(config.max_concurrent_tasks, 1);
assert!(!config.parallel_enabled);
assert!(config.validate().is_ok());
}
#[test]
fn test_concurrency_config_validation_invalid() {
let config = ConcurrencyConfig {
max_concurrent_tasks: 0,
parallel_enabled: true,
};
assert!(config.validate().is_err());
}
#[test]
fn test_concurrency_config_validation_valid() {
let config = ConcurrencyConfig {
max_concurrent_tasks: 0,
parallel_enabled: false,
};
assert!(config.validate().is_ok());
let config = ConcurrencyConfig {
max_concurrent_tasks: 100,
parallel_enabled: true,
};
assert!(config.validate().is_ok());
}
}