ipfrs_interface/
tensor.rs

1//! Zero-Copy Tensor API
2//!
3//! Provides high-performance tensor access with:
4//! - Zero-copy streaming
5//! - Memory-mapped responses
6//! - Partial tensor retrieval
7//! - Range request support
8
9use axum::{
10    body::Body,
11    extract::{Path, Query, State},
12    http::{header, HeaderMap, StatusCode},
13    response::{IntoResponse, Response},
14    Json,
15};
16use ipfrs_core::Cid;
17use ipfrs_storage::BlockStoreTrait;
18use serde::{Deserialize, Serialize};
19use std::path::PathBuf;
20use std::sync::Arc;
21
22use crate::gateway::GatewayState;
23use crate::middleware::{
24    add_caching_headers, check_etag_match, not_modified_response, CacheConfig,
25};
26use crate::mmap::{MmapCache, MmapError};
27
28// ============================================================================
29// Tensor Metadata
30// ============================================================================
31
/// Tensor shape and type information
///
/// Built either from a safetensors header or from an explicit shape/dtype
/// pair, and serialized to clients inside `TensorInfoResponse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorMetadata {
    /// Tensor shape (dimensions)
    pub shape: Vec<usize>,
    /// Data type (e.g., "f32", "f64", "i32", "u8")
    pub dtype: String,
    /// Total number of elements (product of all entries in `shape`)
    pub num_elements: usize,
    /// Size in bytes (`num_elements` multiplied by the element size of `dtype`)
    pub size_bytes: usize,
    /// Layout (row-major or column-major)
    pub layout: TensorLayout,
}
46
/// Tensor memory layout
///
/// Serialized in lowercase without separators, i.e. `"rowmajor"` and
/// `"columnmajor"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum TensorLayout {
    /// Row-major (C-style)
    RowMajor,
    /// Column-major (Fortran-style)
    ColumnMajor,
}
56
57impl TensorMetadata {
58    /// Create metadata from safetensors format
59    pub fn from_safetensors_data(data: &[u8]) -> Result<Self, String> {
60        // Safetensors format: first 8 bytes = header length
61        if data.len() < 8 {
62            return Err("Data too short for safetensors format".to_string());
63        }
64
65        let header_len = u64::from_le_bytes(data[0..8].try_into().unwrap()) as usize;
66        if data.len() < 8 + header_len {
67            return Err("Incomplete safetensors header".to_string());
68        }
69
70        // Parse JSON header
71        let header_bytes = &data[8..8 + header_len];
72        let header: serde_json::Value = serde_json::from_slice(header_bytes)
73            .map_err(|e| format!("Failed to parse safetensors header: {}", e))?;
74
75        // Extract first tensor metadata
76        if let Some(tensors) = header.as_object() {
77            if let Some((_name, tensor_info)) =
78                tensors.iter().find(|(k, _)| k.as_str() != "__metadata__")
79            {
80                if let Some(shape) = tensor_info.get("shape").and_then(|s| s.as_array()) {
81                    let shape: Vec<usize> = shape
82                        .iter()
83                        .filter_map(|v| v.as_u64().map(|n| n as usize))
84                        .collect();
85
86                    let dtype = tensor_info
87                        .get("dtype")
88                        .and_then(|d| d.as_str())
89                        .unwrap_or("f32")
90                        .to_string();
91
92                    let num_elements = shape.iter().product();
93                    let element_size = Self::dtype_size(&dtype);
94                    let size_bytes = num_elements * element_size;
95
96                    return Ok(TensorMetadata {
97                        shape,
98                        dtype,
99                        num_elements,
100                        size_bytes,
101                        layout: TensorLayout::RowMajor, // Default
102                    });
103                }
104            }
105        }
106
107        Err("No tensor found in safetensors data".to_string())
108    }
109
110    /// Get size of a data type in bytes
111    fn dtype_size(dtype: &str) -> usize {
112        match dtype {
113            "f16" | "bf16" => 2,
114            "f32" | "i32" | "u32" => 4,
115            "f64" | "i64" | "u64" => 8,
116            "i8" | "u8" => 1,
117            "i16" | "u16" => 2,
118            _ => 4, // Default to 4 bytes
119        }
120    }
121
122    /// Create metadata from raw tensor data
123    pub fn from_raw(shape: Vec<usize>, dtype: String) -> Self {
124        let num_elements = shape.iter().product();
125        let element_size = Self::dtype_size(&dtype);
126        let size_bytes = num_elements * element_size;
127
128        TensorMetadata {
129            shape,
130            dtype,
131            num_elements,
132            size_bytes,
133            layout: TensorLayout::RowMajor,
134        }
135    }
136}
137
138// ============================================================================
139// Tensor Query Parameters
140// ============================================================================
141
/// Query parameters for tensor retrieval
///
/// Deserialized from the request query string; every field is optional.
#[derive(Debug, Deserialize)]
pub struct TensorQuery {
    /// Retrieve only metadata (no data)
    pub metadata_only: Option<bool>,
    /// Slice specification (e.g., "0:10,5:15" for 2D tensor)
    pub slice: Option<String>,
    /// Format: "raw" or "safetensors" (default: auto-detect)
    pub format: Option<String>,
}
152
153/// Tensor slice specification
154#[derive(Debug)]
155pub struct TensorSlice {
156    /// Slice ranges for each dimension (start, end)
157    pub ranges: Vec<(usize, Option<usize>)>,
158}
159
160impl TensorSlice {
161    /// Extract a slice from tensor data
162    ///
163    /// This performs actual data slicing for row-major tensors.
164    /// For multi-dimensional tensors, this extracts a contiguous region.
165    pub fn extract_data(&self, data: &[u8], metadata: &TensorMetadata) -> Result<Vec<u8>, String> {
166        if self.ranges.len() != metadata.shape.len() {
167            return Err(format!(
168                "Slice dimensions ({}) don't match tensor dimensions ({})",
169                self.ranges.len(),
170                metadata.shape.len()
171            ));
172        }
173
174        let element_size = TensorMetadata::dtype_size(&metadata.dtype);
175
176        // For simplicity, implement 1D and 2D slicing
177        match metadata.shape.len() {
178            1 => self.extract_1d(data, &metadata.shape, element_size),
179            2 => self.extract_2d(data, &metadata.shape, element_size),
180            _ => Err("Tensor slicing for dimensions > 2 not yet implemented".to_string()),
181        }
182    }
183
184    /// Extract 1D slice
185    fn extract_1d(
186        &self,
187        data: &[u8],
188        shape: &[usize],
189        element_size: usize,
190    ) -> Result<Vec<u8>, String> {
191        let (start, end) = (self.ranges[0].0, self.ranges[0].1.unwrap_or(shape[0]));
192
193        if start >= shape[0] || end > shape[0] || start >= end {
194            return Err(format!(
195                "Invalid 1D slice range [{}:{}] for shape [{}]",
196                start, end, shape[0]
197            ));
198        }
199
200        let byte_start = start * element_size;
201        let byte_end = end * element_size;
202
203        if byte_end > data.len() {
204            return Err(format!(
205                "Slice range {}..{} exceeds data length {}",
206                byte_start,
207                byte_end,
208                data.len()
209            ));
210        }
211
212        Ok(data[byte_start..byte_end].to_vec())
213    }
214
215    /// Extract 2D slice (row-major layout)
216    fn extract_2d(
217        &self,
218        data: &[u8],
219        shape: &[usize],
220        element_size: usize,
221    ) -> Result<Vec<u8>, String> {
222        let rows = shape[0];
223        let cols = shape[1];
224
225        let (row_start, row_end) = (self.ranges[0].0, self.ranges[0].1.unwrap_or(rows));
226        let (col_start, col_end) = (self.ranges[1].0, self.ranges[1].1.unwrap_or(cols));
227
228        if row_start >= rows || row_end > rows || row_start >= row_end {
229            return Err(format!(
230                "Invalid row range [{}:{}] for shape [{}, {}]",
231                row_start, row_end, rows, cols
232            ));
233        }
234
235        if col_start >= cols || col_end > cols || col_start >= col_end {
236            return Err(format!(
237                "Invalid column range [{}:{}] for shape [{}, {}]",
238                col_start, col_end, rows, cols
239            ));
240        }
241
242        let mut result = Vec::new();
243        let row_size = cols * element_size;
244
245        for row in row_start..row_end {
246            let row_offset = row * row_size;
247            let slice_start = row_offset + col_start * element_size;
248            let slice_end = row_offset + col_end * element_size;
249
250            if slice_end > data.len() {
251                return Err(format!(
252                    "Row {} slice range {}..{} exceeds data length {}",
253                    row,
254                    slice_start,
255                    slice_end,
256                    data.len()
257                ));
258            }
259
260            result.extend_from_slice(&data[slice_start..slice_end]);
261        }
262
263        Ok(result)
264    }
265
266    /// Parse slice string (e.g., "0:10,5:15")
267    pub fn parse(slice_str: &str) -> Result<Self, String> {
268        let ranges: Result<Vec<_>, String> = slice_str
269            .split(',')
270            .map(|part| {
271                let parts: Vec<&str> = part.split(':').collect();
272                match parts.len() {
273                    1 => {
274                        let idx = parts[0]
275                            .parse::<usize>()
276                            .map_err(|e| format!("Invalid slice index: {}", e))?;
277                        Ok((idx, Some(idx + 1)))
278                    }
279                    2 => {
280                        let start = parts[0]
281                            .parse::<usize>()
282                            .map_err(|e| format!("Invalid slice start: {}", e))?;
283                        let end = if parts[1].is_empty() {
284                            None
285                        } else {
286                            Some(
287                                parts[1]
288                                    .parse::<usize>()
289                                    .map_err(|e| format!("Invalid slice end: {}", e))?,
290                            )
291                        };
292                        Ok((start, end))
293                    }
294                    _ => Err(format!("Invalid slice format: {}", part)),
295                }
296            })
297            .collect();
298
299        Ok(TensorSlice { ranges: ranges? })
300    }
301
302    /// Calculate the slice size in bytes
303    pub fn calculate_size(&self, metadata: &TensorMetadata) -> Result<usize, String> {
304        if self.ranges.len() != metadata.shape.len() {
305            return Err(format!(
306                "Slice dimensions ({}) don't match tensor dimensions ({})",
307                self.ranges.len(),
308                metadata.shape.len()
309            ));
310        }
311
312        let mut slice_elements = 1;
313        for (i, (start, end)) in self.ranges.iter().enumerate() {
314            let dim_size = metadata.shape[i];
315            let actual_end = end.unwrap_or(dim_size);
316
317            if *start >= dim_size || actual_end > dim_size || *start >= actual_end {
318                return Err(format!(
319                    "Invalid slice range [{}:{}] for dimension {} of size {}",
320                    start, actual_end, i, dim_size
321                ));
322            }
323
324            slice_elements *= actual_end - start;
325        }
326
327        let element_size = TensorMetadata::dtype_size(&metadata.dtype);
328        Ok(slice_elements * element_size)
329    }
330}
331
332// ============================================================================
333// Tensor Responses
334// ============================================================================
335
/// Tensor metadata response
///
/// JSON body returned when only metadata is requested.
#[derive(Debug, Serialize)]
pub struct TensorInfoResponse {
    /// CID string exactly as supplied in the request path
    pub cid: String,
    /// Metadata parsed from the stored tensor
    pub metadata: TensorMetadata,
}
342
343// ============================================================================
344// Tensor Endpoints
345// ============================================================================
346
347/// Get tensor with zero-copy streaming
348///
349/// GET /v1/tensor/{cid}
350///
351/// Retrieves tensor data with optional range requests for partial loading.
352/// Supports both safetensors and raw binary formats.
353pub async fn get_tensor(
354    State(state): State<GatewayState>,
355    Path(cid_str): Path<String>,
356    Query(query): Query<TensorQuery>,
357    headers: HeaderMap,
358) -> Result<Response, TensorError> {
359    let cid: Cid = cid_str
360        .parse()
361        .map_err(|_| TensorError::InvalidCid(cid_str.clone()))?;
362
363    // Check if cached (ETag)
364    let cache_config = CacheConfig::default();
365    if check_etag_match(&headers, &cid_str) {
366        return Ok(not_modified_response(&cid_str, &cache_config));
367    }
368
369    // Get the block
370    let block = state
371        .store
372        .get(&cid)
373        .await
374        .map_err(|e| TensorError::Storage(e.to_string()))?
375        .ok_or_else(|| TensorError::NotFound(cid_str.clone()))?;
376
377    let data = block.data();
378
379    // Try to parse metadata (safetensors format or assume raw)
380    let metadata = TensorMetadata::from_safetensors_data(data).ok();
381
382    // If metadata_only requested, return just metadata
383    if query.metadata_only.unwrap_or(false) {
384        if let Some(metadata) = metadata {
385            return Ok(Json(TensorInfoResponse {
386                cid: cid_str,
387                metadata,
388            })
389            .into_response());
390        } else {
391            return Err(TensorError::InvalidFormat(
392                "Cannot extract metadata from tensor".to_string(),
393            ));
394        }
395    }
396
397    // Handle partial retrieval (slicing)
398    let (response_data, is_partial, metadata_for_response) = if let Some(slice_str) = query.slice {
399        let meta = metadata.ok_or_else(|| {
400            TensorError::InvalidFormat("Metadata required for slicing".to_string())
401        })?;
402
403        let slice = TensorSlice::parse(&slice_str)?;
404
405        // Extract the sliced data
406        let sliced_data = slice.extract_data(data, &meta)?;
407
408        (sliced_data, true, Some(meta))
409    } else {
410        // Return full tensor
411        (data.to_vec(), false, metadata)
412    };
413
414    // Build response
415    let mut response_builder = Response::builder();
416
417    if is_partial {
418        response_builder = response_builder.status(StatusCode::PARTIAL_CONTENT);
419    } else {
420        response_builder = response_builder.status(StatusCode::OK);
421    }
422
423    // Determine content type based on format
424    let content_type = match query.format.as_deref() {
425        Some("safetensors") | None if metadata_for_response.is_some() => {
426            "application/vnd.safetensors"
427        }
428        _ => "application/octet-stream",
429    };
430
431    let mut response = response_builder
432        .header(header::CONTENT_TYPE, content_type)
433        .header(header::CONTENT_LENGTH, response_data.len())
434        .header(
435            "X-Tensor-Format",
436            if metadata_for_response.is_some() {
437                "safetensors"
438            } else {
439                "raw"
440            },
441        )
442        .body(Body::from(response_data))
443        .unwrap();
444
445    // Add caching headers
446    add_caching_headers(response.headers_mut(), &cid_str, &cache_config);
447
448    // Add tensor metadata as headers if available
449    if let Some(ref meta) = metadata_for_response {
450        if let Ok(shape_json) = serde_json::to_string(&meta.shape) {
451            if let Ok(header_value) = header::HeaderValue::from_str(&shape_json) {
452                response
453                    .headers_mut()
454                    .insert("X-Tensor-Shape", header_value);
455            }
456        }
457        if let Ok(header_value) = header::HeaderValue::from_str(&meta.dtype) {
458            response
459                .headers_mut()
460                .insert("X-Tensor-Dtype", header_value);
461        }
462    }
463
464    Ok(response)
465}
466
467/// Get tensor metadata only
468///
469/// GET /v1/tensor/{cid}/info
470///
471/// Retrieves only tensor metadata without downloading the full data.
472pub async fn get_tensor_info(
473    State(state): State<GatewayState>,
474    Path(cid_str): Path<String>,
475) -> Result<Json<TensorInfoResponse>, TensorError> {
476    let cid: Cid = cid_str
477        .parse()
478        .map_err(|_| TensorError::InvalidCid(cid_str.clone()))?;
479
480    // Get the block
481    let block = state
482        .store
483        .get(&cid)
484        .await
485        .map_err(|e| TensorError::Storage(e.to_string()))?
486        .ok_or_else(|| TensorError::NotFound(cid_str.clone()))?;
487
488    let data = block.data();
489
490    // Parse metadata
491    let metadata = TensorMetadata::from_safetensors_data(data).map_err(|e| {
492        TensorError::InvalidFormat(format!("Failed to parse tensor metadata: {}", e))
493    })?;
494
495    Ok(Json(TensorInfoResponse {
496        cid: cid_str,
497        metadata,
498    }))
499}
500
501/// Get tensor in Apache Arrow IPC format
502///
503/// GET /v1/tensor/{cid}/arrow
504///
505/// Retrieves tensor data in Apache Arrow IPC Stream format for efficient
506/// data exchange with Arrow-compatible systems (Pandas, Polars, PyArrow, etc.)
507pub async fn get_tensor_arrow(
508    State(state): State<GatewayState>,
509    Path(cid_str): Path<String>,
510    Query(query): Query<TensorQuery>,
511) -> Result<Response, TensorError> {
512    let cid: Cid = cid_str
513        .parse()
514        .map_err(|_| TensorError::InvalidCid(cid_str.clone()))?;
515
516    // Get the block
517    let block = state
518        .store
519        .get(&cid)
520        .await
521        .map_err(|e| TensorError::Storage(e.to_string()))?
522        .ok_or_else(|| TensorError::NotFound(cid_str.clone()))?;
523
524    let data = block.data();
525
526    // Try to parse metadata (safetensors format)
527    let metadata = TensorMetadata::from_safetensors_data(data)
528        .map_err(|e| TensorError::InvalidFormat(format!("Cannot parse tensor metadata: {}", e)))?;
529
530    // Handle partial retrieval (slicing) if requested
531    let response_data = if let Some(slice_str) = query.slice {
532        let slice = TensorSlice::parse(&slice_str)?;
533        slice.extract_data(data, &metadata)?
534    } else {
535        // Return full tensor data (skip safetensors header)
536        let header_len = u64::from_le_bytes(data[0..8].try_into().unwrap()) as usize;
537        data[8 + header_len..].to_vec()
538    };
539
540    // Convert to Arrow RecordBatch and serialize
541    let batch = crate::arrow::tensor_to_record_batch(&metadata, &response_data)
542        .map_err(|e| TensorError::Storage(format!("Failed to create Arrow batch: {}", e)))?;
543
544    let ipc_bytes = crate::arrow::record_batch_to_ipc_bytes(&batch)
545        .map_err(|e| TensorError::Storage(format!("Failed to serialize Arrow IPC: {}", e)))?;
546
547    // Build response
548    Response::builder()
549        .status(StatusCode::OK)
550        .header(header::CONTENT_TYPE, "application/vnd.apache.arrow.stream")
551        .header("X-Tensor-Shape", format!("{:?}", metadata.shape))
552        .header("X-Tensor-Dtype", &metadata.dtype)
553        .header("X-Tensor-Elements", metadata.num_elements.to_string())
554        .body(Body::from(ipc_bytes))
555        .map_err(|e| TensorError::Storage(format!("Failed to build response: {}", e)))
556}
557
558// ============================================================================
559// Memory-Mapped Tensor Serving
560// ============================================================================
561
562/// Get tensor using memory-mapped I/O (zero-copy from disk)
563///
564/// GET /v1/tensor/{cid}/mmap
565///
566/// Retrieves tensor data using memory-mapped I/O for maximum performance.
567/// This endpoint is optimized for serving large tensors directly from disk
568/// without loading them into memory.
569///
570/// # Performance
571///
572/// - **Zero-copy**: Data is served directly from disk via OS page cache
573/// - **Lazy loading**: Only requested pages are loaded into memory
574/// - **OS optimizations**: Leverages sendfile and similar system calls
575///
576/// # Limitations
577///
578/// - Only works for tensors stored on local filesystem
579/// - Requires tensor file path to be available
580/// - Not suitable for tensors stored in distributed storage
581pub async fn get_tensor_mmap(
582    Path(cid_str): Path<String>,
583    Query(query): Query<TensorQuery>,
584    headers: HeaderMap,
585    mmap_cache: Arc<MmapCache>,
586    tensor_storage_path: PathBuf,
587) -> Result<Response, TensorError> {
588    let _cid: Cid = cid_str
589        .parse()
590        .map_err(|_| TensorError::InvalidCid(cid_str.clone()))?;
591
592    // Check if cached (ETag)
593    let cache_config = CacheConfig::default();
594    if check_etag_match(&headers, &cid_str) {
595        return Ok(not_modified_response(&cid_str, &cache_config));
596    }
597
598    // Construct file path from CID
599    // In production, this would use the actual storage backend's file path
600    let file_path = tensor_storage_path.join(format!("{}.tensor", cid_str));
601
602    // Get or create memory-mapped file
603    let mmap_file = mmap_cache.get_or_create(&file_path).map_err(|e| match e {
604        MmapError::FileNotFound(_) => TensorError::NotFound(cid_str.clone()),
605        _ => TensorError::Storage(e.to_string()),
606    })?;
607
608    // Get file data
609    let data = mmap_file.bytes();
610
611    // Try to parse metadata (safetensors format)
612    let metadata = TensorMetadata::from_safetensors_data(&data).ok();
613
614    // If metadata_only requested, return just metadata
615    if query.metadata_only.unwrap_or(false) {
616        if let Some(metadata) = metadata {
617            return Ok(Json(TensorInfoResponse {
618                cid: cid_str,
619                metadata,
620            })
621            .into_response());
622        } else {
623            return Err(TensorError::InvalidFormat(
624                "Cannot extract metadata from tensor".to_string(),
625            ));
626        }
627    }
628
629    // Handle partial retrieval (slicing)
630    let (response_data, is_partial, metadata_for_response) = if let Some(slice_str) = query.slice {
631        let meta = metadata.ok_or_else(|| {
632            TensorError::InvalidFormat("Metadata required for slicing".to_string())
633        })?;
634
635        let slice = TensorSlice::parse(&slice_str)?;
636
637        // For mmap, we can efficiently retrieve just the slice
638        // by calculating the byte range
639        let sliced_data = slice.extract_data(&data, &meta)?;
640
641        (sliced_data, true, Some(meta))
642    } else {
643        // Return full tensor
644        (data.to_vec(), false, metadata)
645    };
646
647    // Build response
648    let mut response_builder = Response::builder();
649
650    if is_partial {
651        response_builder = response_builder.status(StatusCode::PARTIAL_CONTENT);
652    } else {
653        response_builder = response_builder.status(StatusCode::OK);
654    }
655
656    // Determine content type
657    let content_type = match query.format.as_deref() {
658        Some("safetensors") | None if metadata_for_response.is_some() => {
659            "application/vnd.safetensors"
660        }
661        _ => "application/octet-stream",
662    };
663
664    let mut response = response_builder
665        .header(header::CONTENT_TYPE, content_type)
666        .header(header::CONTENT_LENGTH, response_data.len())
667        .header("X-Served-By", "mmap")
668        .header(
669            "X-Tensor-Format",
670            if metadata_for_response.is_some() {
671                "safetensors"
672            } else {
673                "raw"
674            },
675        )
676        .body(Body::from(response_data))
677        .unwrap();
678
679    // Add caching headers
680    add_caching_headers(response.headers_mut(), &cid_str, &cache_config);
681
682    // Add tensor metadata as headers
683    if let Some(ref meta) = metadata_for_response {
684        if let Ok(shape_json) = serde_json::to_string(&meta.shape) {
685            if let Ok(header_value) = header::HeaderValue::from_str(&shape_json) {
686                response
687                    .headers_mut()
688                    .insert("X-Tensor-Shape", header_value);
689            }
690        }
691        if let Ok(header_value) = header::HeaderValue::from_str(&meta.dtype) {
692            response
693                .headers_mut()
694                .insert("X-Tensor-Dtype", header_value);
695        }
696    }
697
698    Ok(response)
699}
700
701/// Mmap-based tensor range request
702///
703/// Efficiently serves byte ranges from memory-mapped tensor files.
704/// Optimized for HTTP 206 Partial Content responses.
705#[allow(dead_code)]
706pub async fn get_tensor_mmap_range(
707    cid_str: String,
708    range: std::ops::Range<usize>,
709    mmap_cache: Arc<MmapCache>,
710    tensor_storage_path: PathBuf,
711) -> Result<Response, TensorError> {
712    let _cid: Cid = cid_str
713        .parse()
714        .map_err(|_| TensorError::InvalidCid(cid_str.clone()))?;
715
716    // Construct file path
717    let file_path = tensor_storage_path.join(format!("{}.tensor", cid_str));
718
719    // Get memory-mapped file
720    let mmap_file = mmap_cache.get_or_create(&file_path).map_err(|e| match e {
721        MmapError::FileNotFound(_) => TensorError::NotFound(cid_str.clone()),
722        _ => TensorError::Storage(e.to_string()),
723    })?;
724
725    // Get the requested range (zero-copy)
726    let range_data = mmap_file
727        .range(range.clone())
728        .map_err(|e| TensorError::Storage(e.to_string()))?;
729
730    // Build partial content response
731    let response = Response::builder()
732        .status(StatusCode::PARTIAL_CONTENT)
733        .header(header::CONTENT_TYPE, "application/octet-stream")
734        .header(header::CONTENT_LENGTH, range_data.len())
735        .header(
736            header::CONTENT_RANGE,
737            format!(
738                "bytes {}-{}/{}",
739                range.start,
740                range.end - 1,
741                mmap_file.size()
742            ),
743        )
744        .header("X-Served-By", "mmap")
745        .body(Body::from(range_data))
746        .unwrap();
747
748    Ok(response)
749}
750
751// ============================================================================
752// Error Types
753// ============================================================================
754
/// Tensor operation errors
///
/// Each variant carries the text that is included in the HTTP error body.
#[derive(Debug)]
pub enum TensorError {
    /// The supplied CID string could not be parsed (payload: the CID string)
    InvalidCid(String),
    /// No tensor exists for the CID (payload: the CID string)
    NotFound(String),
    /// The tensor data or the request could not be interpreted (payload: reason)
    InvalidFormat(String),
    /// The storage backend or response construction failed (payload: reason)
    Storage(String),
    /// The requested operation is not supported (payload: description)
    NotImplemented(String),
}
764
765impl IntoResponse for TensorError {
766    fn into_response(self) -> Response {
767        let (status, message) = match self {
768            TensorError::InvalidCid(cid) => {
769                (StatusCode::BAD_REQUEST, format!("Invalid CID: {}", cid))
770            }
771            TensorError::NotFound(cid) => {
772                (StatusCode::NOT_FOUND, format!("Tensor not found: {}", cid))
773            }
774            TensorError::InvalidFormat(msg) => (
775                StatusCode::BAD_REQUEST,
776                format!("Invalid tensor format: {}", msg),
777            ),
778            TensorError::Storage(msg) => (
779                StatusCode::INTERNAL_SERVER_ERROR,
780                format!("Storage error: {}", msg),
781            ),
782            TensorError::NotImplemented(msg) => (
783                StatusCode::NOT_IMPLEMENTED,
784                format!("Not implemented: {}", msg),
785            ),
786        };
787
788        (status, message).into_response()
789    }
790}
791
792impl From<String> for TensorError {
793    fn from(s: String) -> Self {
794        TensorError::InvalidFormat(s)
795    }
796}
797
/// Unit tests for metadata construction, slice parsing, slice size
/// calculation and 1D/2D data extraction.
#[cfg(test)]
mod tests {
    use super::*;

    /// Element sizes for a representative set of lowercase dtype names.
    #[test]
    fn test_tensor_metadata_dtype_size() {
        assert_eq!(TensorMetadata::dtype_size("f32"), 4);
        assert_eq!(TensorMetadata::dtype_size("f64"), 8);
        assert_eq!(TensorMetadata::dtype_size("i32"), 4);
        assert_eq!(TensorMetadata::dtype_size("u8"), 1);
        assert_eq!(TensorMetadata::dtype_size("f16"), 2);
    }

    /// `from_raw` derives element count and byte size from shape and dtype.
    #[test]
    fn test_tensor_metadata_from_raw() {
        let meta = TensorMetadata::from_raw(vec![10, 20, 30], "f32".to_string());
        assert_eq!(meta.shape, vec![10, 20, 30]);
        assert_eq!(meta.dtype, "f32");
        assert_eq!(meta.num_elements, 6000);
        assert_eq!(meta.size_bytes, 24000);
    }

    /// A bare index selects exactly one element: `i` becomes `i..i+1`.
    #[test]
    fn test_tensor_slice_parse_single() {
        let slice = TensorSlice::parse("5").unwrap();
        assert_eq!(slice.ranges, vec![(5, Some(6))]);
    }

    /// `start:end` is parsed as an explicit half-open range.
    #[test]
    fn test_tensor_slice_parse_range() {
        let slice = TensorSlice::parse("10:20").unwrap();
        assert_eq!(slice.ranges, vec![(10, Some(20))]);
    }

    /// An empty end leaves the range open (`None`).
    #[test]
    fn test_tensor_slice_parse_open_end() {
        let slice = TensorSlice::parse("10:").unwrap();
        assert_eq!(slice.ranges, vec![(10, None)]);
    }

    /// Comma-separated components yield one range per dimension, in order.
    #[test]
    fn test_tensor_slice_parse_multi_dim() {
        let slice = TensorSlice::parse("0:10,5:15,2:8").unwrap();
        assert_eq!(
            slice.ranges,
            vec![(0, Some(10)), (5, Some(15)), (2, Some(8))]
        );
    }

    /// Slice size is the selected element count times the element size.
    #[test]
    fn test_tensor_slice_calculate_size() {
        let meta = TensorMetadata::from_raw(vec![100, 100], "f32".to_string());
        let slice = TensorSlice::parse("0:10,0:10").unwrap();

        let size = slice.calculate_size(&meta).unwrap();
        assert_eq!(size, 10 * 10 * 4); // 10x10 elements * 4 bytes
    }

    /// A slice with fewer ranges than tensor dimensions is rejected.
    #[test]
    fn test_tensor_slice_invalid_dimensions() {
        let meta = TensorMetadata::from_raw(vec![100, 100], "f32".to_string());
        let slice = TensorSlice::parse("0:10").unwrap(); // Only 1 dimension

        let result = slice.calculate_size(&meta);
        assert!(result.is_err());
    }

    /// A range ending past the dimension size is rejected.
    #[test]
    fn test_tensor_slice_out_of_bounds() {
        let meta = TensorMetadata::from_raw(vec![100, 100], "f32".to_string());
        let slice = TensorSlice::parse("0:200,0:10").unwrap();

        let result = slice.calculate_size(&meta);
        assert!(result.is_err());
    }

    /// Layout variants serialize as lowercase names without separators.
    #[test]
    fn test_tensor_layout_serialization() {
        let layout = TensorLayout::RowMajor;
        let json = serde_json::to_string(&layout).unwrap();
        assert_eq!(json, r#""rowmajor""#);

        let layout = TensorLayout::ColumnMajor;
        let json = serde_json::to_string(&layout).unwrap();
        assert_eq!(json, r#""columnmajor""#);
    }

    /// 1D extraction returns the contiguous bytes of the selected elements.
    #[test]
    fn test_tensor_slice_extract_1d() {
        // 1D tensor: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] (f32)
        let data: Vec<u8> = (0..10).flat_map(|i| (i as f32).to_le_bytes()).collect();

        let meta = TensorMetadata::from_raw(vec![10], "f32".to_string());
        let slice = TensorSlice::parse("2:5").unwrap();

        let result = slice.extract_data(&data, &meta).unwrap();

        // Should extract elements 2, 3, 4 (3 elements * 4 bytes = 12 bytes)
        assert_eq!(result.len(), 12);

        // Verify the extracted values
        let values: Vec<f32> = result
            .chunks_exact(4)
            .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
            .collect();

        assert_eq!(values, vec![2.0, 3.0, 4.0]);
    }

    /// 2D extraction gathers the column span of each selected row, row by row.
    #[test]
    fn test_tensor_slice_extract_2d() {
        // 2D tensor: 4x3 matrix (f32)
        // [[0, 1, 2],
        //  [3, 4, 5],
        //  [6, 7, 8],
        //  [9, 10, 11]]
        let data: Vec<u8> = (0..12).flat_map(|i| (i as f32).to_le_bytes()).collect();

        let meta = TensorMetadata::from_raw(vec![4, 3], "f32".to_string());
        let slice = TensorSlice::parse("1:3,0:2").unwrap(); // Rows 1-2, Cols 0-1

        let result = slice.extract_data(&data, &meta).unwrap();

        // Should extract:
        // [[3, 4],
        //  [6, 7]]
        // 2 rows * 2 cols * 4 bytes = 16 bytes
        assert_eq!(result.len(), 16);

        let values: Vec<f32> = result
            .chunks_exact(4)
            .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
            .collect();

        assert_eq!(values, vec![3.0, 4.0, 6.0, 7.0]);
    }

    /// Selecting one full row of a 2D tensor returns that row's elements.
    #[test]
    fn test_tensor_slice_extract_2d_single_row() {
        let data: Vec<u8> = (0..12).flat_map(|i| (i as f32).to_le_bytes()).collect();

        let meta = TensorMetadata::from_raw(vec![4, 3], "f32".to_string());
        let slice = TensorSlice::parse("2:3,0:3").unwrap(); // Row 2, all columns

        let result = slice.extract_data(&data, &meta).unwrap();

        // Should extract: [6, 7, 8]
        assert_eq!(result.len(), 12);

        let values: Vec<f32> = result
            .chunks_exact(4)
            .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
            .collect();

        assert_eq!(values, vec![6.0, 7.0, 8.0]);
    }

    /// Extraction rejects a slice whose rank differs from the tensor's.
    #[test]
    fn test_tensor_slice_extract_invalid_dimension() {
        let data = vec![0u8; 40]; // 10 f32 elements
        let meta = TensorMetadata::from_raw(vec![10], "f32".to_string());
        let slice = TensorSlice::parse("2:5,0:2").unwrap(); // 2D slice for 1D tensor

        let result = slice.extract_data(&data, &meta);
        assert!(result.is_err());
    }

    /// Extraction rejects a range that ends past the dimension size.
    #[test]
    fn test_tensor_slice_extract_out_of_bounds() {
        let data: Vec<u8> = (0..10).flat_map(|i| (i as f32).to_le_bytes()).collect();

        let meta = TensorMetadata::from_raw(vec![10], "f32".to_string());
        let slice = TensorSlice::parse("8:12").unwrap(); // Out of bounds

        let result = slice.extract_data(&data, &meta);
        assert!(result.is_err());
    }
}