1use axum::{
10 body::Body,
11 extract::{Path, Query, State},
12 http::{header, HeaderMap, StatusCode},
13 response::{IntoResponse, Response},
14 Json,
15};
16use ipfrs_core::Cid;
17use ipfrs_storage::BlockStoreTrait;
18use serde::{Deserialize, Serialize};
19use std::path::PathBuf;
20use std::sync::Arc;
21
22use crate::gateway::GatewayState;
23use crate::middleware::{
24 add_caching_headers, check_etag_match, not_modified_response, CacheConfig,
25};
26use crate::mmap::{MmapCache, MmapError};
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct TensorMetadata {
35 pub shape: Vec<usize>,
37 pub dtype: String,
39 pub num_elements: usize,
41 pub size_bytes: usize,
43 pub layout: TensorLayout,
45}
46
47#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
49#[serde(rename_all = "lowercase")]
50pub enum TensorLayout {
51 RowMajor,
53 ColumnMajor,
55}
56
57impl TensorMetadata {
58 pub fn from_safetensors_data(data: &[u8]) -> Result<Self, String> {
60 if data.len() < 8 {
62 return Err("Data too short for safetensors format".to_string());
63 }
64
65 let header_len = u64::from_le_bytes(data[0..8].try_into().unwrap()) as usize;
66 if data.len() < 8 + header_len {
67 return Err("Incomplete safetensors header".to_string());
68 }
69
70 let header_bytes = &data[8..8 + header_len];
72 let header: serde_json::Value = serde_json::from_slice(header_bytes)
73 .map_err(|e| format!("Failed to parse safetensors header: {}", e))?;
74
75 if let Some(tensors) = header.as_object() {
77 if let Some((_name, tensor_info)) =
78 tensors.iter().find(|(k, _)| k.as_str() != "__metadata__")
79 {
80 if let Some(shape) = tensor_info.get("shape").and_then(|s| s.as_array()) {
81 let shape: Vec<usize> = shape
82 .iter()
83 .filter_map(|v| v.as_u64().map(|n| n as usize))
84 .collect();
85
86 let dtype = tensor_info
87 .get("dtype")
88 .and_then(|d| d.as_str())
89 .unwrap_or("f32")
90 .to_string();
91
92 let num_elements = shape.iter().product();
93 let element_size = Self::dtype_size(&dtype);
94 let size_bytes = num_elements * element_size;
95
96 return Ok(TensorMetadata {
97 shape,
98 dtype,
99 num_elements,
100 size_bytes,
101 layout: TensorLayout::RowMajor, });
103 }
104 }
105 }
106
107 Err("No tensor found in safetensors data".to_string())
108 }
109
110 fn dtype_size(dtype: &str) -> usize {
112 match dtype {
113 "f16" | "bf16" => 2,
114 "f32" | "i32" | "u32" => 4,
115 "f64" | "i64" | "u64" => 8,
116 "i8" | "u8" => 1,
117 "i16" | "u16" => 2,
118 _ => 4, }
120 }
121
122 pub fn from_raw(shape: Vec<usize>, dtype: String) -> Self {
124 let num_elements = shape.iter().product();
125 let element_size = Self::dtype_size(&dtype);
126 let size_bytes = num_elements * element_size;
127
128 TensorMetadata {
129 shape,
130 dtype,
131 num_elements,
132 size_bytes,
133 layout: TensorLayout::RowMajor,
134 }
135 }
136}
137
138#[derive(Debug, Deserialize)]
144pub struct TensorQuery {
145 pub metadata_only: Option<bool>,
147 pub slice: Option<String>,
149 pub format: Option<String>,
151}
152
153#[derive(Debug)]
155pub struct TensorSlice {
156 pub ranges: Vec<(usize, Option<usize>)>,
158}
159
160impl TensorSlice {
161 pub fn extract_data(&self, data: &[u8], metadata: &TensorMetadata) -> Result<Vec<u8>, String> {
166 if self.ranges.len() != metadata.shape.len() {
167 return Err(format!(
168 "Slice dimensions ({}) don't match tensor dimensions ({})",
169 self.ranges.len(),
170 metadata.shape.len()
171 ));
172 }
173
174 let element_size = TensorMetadata::dtype_size(&metadata.dtype);
175
176 match metadata.shape.len() {
178 1 => self.extract_1d(data, &metadata.shape, element_size),
179 2 => self.extract_2d(data, &metadata.shape, element_size),
180 _ => Err("Tensor slicing for dimensions > 2 not yet implemented".to_string()),
181 }
182 }
183
184 fn extract_1d(
186 &self,
187 data: &[u8],
188 shape: &[usize],
189 element_size: usize,
190 ) -> Result<Vec<u8>, String> {
191 let (start, end) = (self.ranges[0].0, self.ranges[0].1.unwrap_or(shape[0]));
192
193 if start >= shape[0] || end > shape[0] || start >= end {
194 return Err(format!(
195 "Invalid 1D slice range [{}:{}] for shape [{}]",
196 start, end, shape[0]
197 ));
198 }
199
200 let byte_start = start * element_size;
201 let byte_end = end * element_size;
202
203 if byte_end > data.len() {
204 return Err(format!(
205 "Slice range {}..{} exceeds data length {}",
206 byte_start,
207 byte_end,
208 data.len()
209 ));
210 }
211
212 Ok(data[byte_start..byte_end].to_vec())
213 }
214
215 fn extract_2d(
217 &self,
218 data: &[u8],
219 shape: &[usize],
220 element_size: usize,
221 ) -> Result<Vec<u8>, String> {
222 let rows = shape[0];
223 let cols = shape[1];
224
225 let (row_start, row_end) = (self.ranges[0].0, self.ranges[0].1.unwrap_or(rows));
226 let (col_start, col_end) = (self.ranges[1].0, self.ranges[1].1.unwrap_or(cols));
227
228 if row_start >= rows || row_end > rows || row_start >= row_end {
229 return Err(format!(
230 "Invalid row range [{}:{}] for shape [{}, {}]",
231 row_start, row_end, rows, cols
232 ));
233 }
234
235 if col_start >= cols || col_end > cols || col_start >= col_end {
236 return Err(format!(
237 "Invalid column range [{}:{}] for shape [{}, {}]",
238 col_start, col_end, rows, cols
239 ));
240 }
241
242 let mut result = Vec::new();
243 let row_size = cols * element_size;
244
245 for row in row_start..row_end {
246 let row_offset = row * row_size;
247 let slice_start = row_offset + col_start * element_size;
248 let slice_end = row_offset + col_end * element_size;
249
250 if slice_end > data.len() {
251 return Err(format!(
252 "Row {} slice range {}..{} exceeds data length {}",
253 row,
254 slice_start,
255 slice_end,
256 data.len()
257 ));
258 }
259
260 result.extend_from_slice(&data[slice_start..slice_end]);
261 }
262
263 Ok(result)
264 }
265
266 pub fn parse(slice_str: &str) -> Result<Self, String> {
268 let ranges: Result<Vec<_>, String> = slice_str
269 .split(',')
270 .map(|part| {
271 let parts: Vec<&str> = part.split(':').collect();
272 match parts.len() {
273 1 => {
274 let idx = parts[0]
275 .parse::<usize>()
276 .map_err(|e| format!("Invalid slice index: {}", e))?;
277 Ok((idx, Some(idx + 1)))
278 }
279 2 => {
280 let start = parts[0]
281 .parse::<usize>()
282 .map_err(|e| format!("Invalid slice start: {}", e))?;
283 let end = if parts[1].is_empty() {
284 None
285 } else {
286 Some(
287 parts[1]
288 .parse::<usize>()
289 .map_err(|e| format!("Invalid slice end: {}", e))?,
290 )
291 };
292 Ok((start, end))
293 }
294 _ => Err(format!("Invalid slice format: {}", part)),
295 }
296 })
297 .collect();
298
299 Ok(TensorSlice { ranges: ranges? })
300 }
301
302 pub fn calculate_size(&self, metadata: &TensorMetadata) -> Result<usize, String> {
304 if self.ranges.len() != metadata.shape.len() {
305 return Err(format!(
306 "Slice dimensions ({}) don't match tensor dimensions ({})",
307 self.ranges.len(),
308 metadata.shape.len()
309 ));
310 }
311
312 let mut slice_elements = 1;
313 for (i, (start, end)) in self.ranges.iter().enumerate() {
314 let dim_size = metadata.shape[i];
315 let actual_end = end.unwrap_or(dim_size);
316
317 if *start >= dim_size || actual_end > dim_size || *start >= actual_end {
318 return Err(format!(
319 "Invalid slice range [{}:{}] for dimension {} of size {}",
320 start, actual_end, i, dim_size
321 ));
322 }
323
324 slice_elements *= actual_end - start;
325 }
326
327 let element_size = TensorMetadata::dtype_size(&metadata.dtype);
328 Ok(slice_elements * element_size)
329 }
330}
331
332#[derive(Debug, Serialize)]
338pub struct TensorInfoResponse {
339 pub cid: String,
340 pub metadata: TensorMetadata,
341}
342
343pub async fn get_tensor(
354 State(state): State<GatewayState>,
355 Path(cid_str): Path<String>,
356 Query(query): Query<TensorQuery>,
357 headers: HeaderMap,
358) -> Result<Response, TensorError> {
359 let cid: Cid = cid_str
360 .parse()
361 .map_err(|_| TensorError::InvalidCid(cid_str.clone()))?;
362
363 let cache_config = CacheConfig::default();
365 if check_etag_match(&headers, &cid_str) {
366 return Ok(not_modified_response(&cid_str, &cache_config));
367 }
368
369 let block = state
371 .store
372 .get(&cid)
373 .await
374 .map_err(|e| TensorError::Storage(e.to_string()))?
375 .ok_or_else(|| TensorError::NotFound(cid_str.clone()))?;
376
377 let data = block.data();
378
379 let metadata = TensorMetadata::from_safetensors_data(data).ok();
381
382 if query.metadata_only.unwrap_or(false) {
384 if let Some(metadata) = metadata {
385 return Ok(Json(TensorInfoResponse {
386 cid: cid_str,
387 metadata,
388 })
389 .into_response());
390 } else {
391 return Err(TensorError::InvalidFormat(
392 "Cannot extract metadata from tensor".to_string(),
393 ));
394 }
395 }
396
397 let (response_data, is_partial, metadata_for_response) = if let Some(slice_str) = query.slice {
399 let meta = metadata.ok_or_else(|| {
400 TensorError::InvalidFormat("Metadata required for slicing".to_string())
401 })?;
402
403 let slice = TensorSlice::parse(&slice_str)?;
404
405 let sliced_data = slice.extract_data(data, &meta)?;
407
408 (sliced_data, true, Some(meta))
409 } else {
410 (data.to_vec(), false, metadata)
412 };
413
414 let mut response_builder = Response::builder();
416
417 if is_partial {
418 response_builder = response_builder.status(StatusCode::PARTIAL_CONTENT);
419 } else {
420 response_builder = response_builder.status(StatusCode::OK);
421 }
422
423 let content_type = match query.format.as_deref() {
425 Some("safetensors") | None if metadata_for_response.is_some() => {
426 "application/vnd.safetensors"
427 }
428 _ => "application/octet-stream",
429 };
430
431 let mut response = response_builder
432 .header(header::CONTENT_TYPE, content_type)
433 .header(header::CONTENT_LENGTH, response_data.len())
434 .header(
435 "X-Tensor-Format",
436 if metadata_for_response.is_some() {
437 "safetensors"
438 } else {
439 "raw"
440 },
441 )
442 .body(Body::from(response_data))
443 .unwrap();
444
445 add_caching_headers(response.headers_mut(), &cid_str, &cache_config);
447
448 if let Some(ref meta) = metadata_for_response {
450 if let Ok(shape_json) = serde_json::to_string(&meta.shape) {
451 if let Ok(header_value) = header::HeaderValue::from_str(&shape_json) {
452 response
453 .headers_mut()
454 .insert("X-Tensor-Shape", header_value);
455 }
456 }
457 if let Ok(header_value) = header::HeaderValue::from_str(&meta.dtype) {
458 response
459 .headers_mut()
460 .insert("X-Tensor-Dtype", header_value);
461 }
462 }
463
464 Ok(response)
465}
466
467pub async fn get_tensor_info(
473 State(state): State<GatewayState>,
474 Path(cid_str): Path<String>,
475) -> Result<Json<TensorInfoResponse>, TensorError> {
476 let cid: Cid = cid_str
477 .parse()
478 .map_err(|_| TensorError::InvalidCid(cid_str.clone()))?;
479
480 let block = state
482 .store
483 .get(&cid)
484 .await
485 .map_err(|e| TensorError::Storage(e.to_string()))?
486 .ok_or_else(|| TensorError::NotFound(cid_str.clone()))?;
487
488 let data = block.data();
489
490 let metadata = TensorMetadata::from_safetensors_data(data).map_err(|e| {
492 TensorError::InvalidFormat(format!("Failed to parse tensor metadata: {}", e))
493 })?;
494
495 Ok(Json(TensorInfoResponse {
496 cid: cid_str,
497 metadata,
498 }))
499}
500
501pub async fn get_tensor_arrow(
508 State(state): State<GatewayState>,
509 Path(cid_str): Path<String>,
510 Query(query): Query<TensorQuery>,
511) -> Result<Response, TensorError> {
512 let cid: Cid = cid_str
513 .parse()
514 .map_err(|_| TensorError::InvalidCid(cid_str.clone()))?;
515
516 let block = state
518 .store
519 .get(&cid)
520 .await
521 .map_err(|e| TensorError::Storage(e.to_string()))?
522 .ok_or_else(|| TensorError::NotFound(cid_str.clone()))?;
523
524 let data = block.data();
525
526 let metadata = TensorMetadata::from_safetensors_data(data)
528 .map_err(|e| TensorError::InvalidFormat(format!("Cannot parse tensor metadata: {}", e)))?;
529
530 let response_data = if let Some(slice_str) = query.slice {
532 let slice = TensorSlice::parse(&slice_str)?;
533 slice.extract_data(data, &metadata)?
534 } else {
535 let header_len = u64::from_le_bytes(data[0..8].try_into().unwrap()) as usize;
537 data[8 + header_len..].to_vec()
538 };
539
540 let batch = crate::arrow::tensor_to_record_batch(&metadata, &response_data)
542 .map_err(|e| TensorError::Storage(format!("Failed to create Arrow batch: {}", e)))?;
543
544 let ipc_bytes = crate::arrow::record_batch_to_ipc_bytes(&batch)
545 .map_err(|e| TensorError::Storage(format!("Failed to serialize Arrow IPC: {}", e)))?;
546
547 Response::builder()
549 .status(StatusCode::OK)
550 .header(header::CONTENT_TYPE, "application/vnd.apache.arrow.stream")
551 .header("X-Tensor-Shape", format!("{:?}", metadata.shape))
552 .header("X-Tensor-Dtype", &metadata.dtype)
553 .header("X-Tensor-Elements", metadata.num_elements.to_string())
554 .body(Body::from(ipc_bytes))
555 .map_err(|e| TensorError::Storage(format!("Failed to build response: {}", e)))
556}
557
558pub async fn get_tensor_mmap(
582 Path(cid_str): Path<String>,
583 Query(query): Query<TensorQuery>,
584 headers: HeaderMap,
585 mmap_cache: Arc<MmapCache>,
586 tensor_storage_path: PathBuf,
587) -> Result<Response, TensorError> {
588 let _cid: Cid = cid_str
589 .parse()
590 .map_err(|_| TensorError::InvalidCid(cid_str.clone()))?;
591
592 let cache_config = CacheConfig::default();
594 if check_etag_match(&headers, &cid_str) {
595 return Ok(not_modified_response(&cid_str, &cache_config));
596 }
597
598 let file_path = tensor_storage_path.join(format!("{}.tensor", cid_str));
601
602 let mmap_file = mmap_cache.get_or_create(&file_path).map_err(|e| match e {
604 MmapError::FileNotFound(_) => TensorError::NotFound(cid_str.clone()),
605 _ => TensorError::Storage(e.to_string()),
606 })?;
607
608 let data = mmap_file.bytes();
610
611 let metadata = TensorMetadata::from_safetensors_data(&data).ok();
613
614 if query.metadata_only.unwrap_or(false) {
616 if let Some(metadata) = metadata {
617 return Ok(Json(TensorInfoResponse {
618 cid: cid_str,
619 metadata,
620 })
621 .into_response());
622 } else {
623 return Err(TensorError::InvalidFormat(
624 "Cannot extract metadata from tensor".to_string(),
625 ));
626 }
627 }
628
629 let (response_data, is_partial, metadata_for_response) = if let Some(slice_str) = query.slice {
631 let meta = metadata.ok_or_else(|| {
632 TensorError::InvalidFormat("Metadata required for slicing".to_string())
633 })?;
634
635 let slice = TensorSlice::parse(&slice_str)?;
636
637 let sliced_data = slice.extract_data(&data, &meta)?;
640
641 (sliced_data, true, Some(meta))
642 } else {
643 (data.to_vec(), false, metadata)
645 };
646
647 let mut response_builder = Response::builder();
649
650 if is_partial {
651 response_builder = response_builder.status(StatusCode::PARTIAL_CONTENT);
652 } else {
653 response_builder = response_builder.status(StatusCode::OK);
654 }
655
656 let content_type = match query.format.as_deref() {
658 Some("safetensors") | None if metadata_for_response.is_some() => {
659 "application/vnd.safetensors"
660 }
661 _ => "application/octet-stream",
662 };
663
664 let mut response = response_builder
665 .header(header::CONTENT_TYPE, content_type)
666 .header(header::CONTENT_LENGTH, response_data.len())
667 .header("X-Served-By", "mmap")
668 .header(
669 "X-Tensor-Format",
670 if metadata_for_response.is_some() {
671 "safetensors"
672 } else {
673 "raw"
674 },
675 )
676 .body(Body::from(response_data))
677 .unwrap();
678
679 add_caching_headers(response.headers_mut(), &cid_str, &cache_config);
681
682 if let Some(ref meta) = metadata_for_response {
684 if let Ok(shape_json) = serde_json::to_string(&meta.shape) {
685 if let Ok(header_value) = header::HeaderValue::from_str(&shape_json) {
686 response
687 .headers_mut()
688 .insert("X-Tensor-Shape", header_value);
689 }
690 }
691 if let Ok(header_value) = header::HeaderValue::from_str(&meta.dtype) {
692 response
693 .headers_mut()
694 .insert("X-Tensor-Dtype", header_value);
695 }
696 }
697
698 Ok(response)
699}
700
701#[allow(dead_code)]
706pub async fn get_tensor_mmap_range(
707 cid_str: String,
708 range: std::ops::Range<usize>,
709 mmap_cache: Arc<MmapCache>,
710 tensor_storage_path: PathBuf,
711) -> Result<Response, TensorError> {
712 let _cid: Cid = cid_str
713 .parse()
714 .map_err(|_| TensorError::InvalidCid(cid_str.clone()))?;
715
716 let file_path = tensor_storage_path.join(format!("{}.tensor", cid_str));
718
719 let mmap_file = mmap_cache.get_or_create(&file_path).map_err(|e| match e {
721 MmapError::FileNotFound(_) => TensorError::NotFound(cid_str.clone()),
722 _ => TensorError::Storage(e.to_string()),
723 })?;
724
725 let range_data = mmap_file
727 .range(range.clone())
728 .map_err(|e| TensorError::Storage(e.to_string()))?;
729
730 let response = Response::builder()
732 .status(StatusCode::PARTIAL_CONTENT)
733 .header(header::CONTENT_TYPE, "application/octet-stream")
734 .header(header::CONTENT_LENGTH, range_data.len())
735 .header(
736 header::CONTENT_RANGE,
737 format!(
738 "bytes {}-{}/{}",
739 range.start,
740 range.end - 1,
741 mmap_file.size()
742 ),
743 )
744 .header("X-Served-By", "mmap")
745 .body(Body::from(range_data))
746 .unwrap();
747
748 Ok(response)
749}
750
/// Errors produced by the tensor endpoints.
///
/// Every variant carries a human-readable detail string. The HTTP status for
/// each variant is chosen by its `IntoResponse` implementation.
#[derive(Debug)]
pub enum TensorError {
    /// The request path segment is not a valid CID (holds the raw string).
    InvalidCid(String),
    /// No tensor exists for the CID (holds the CID string).
    NotFound(String),
    /// The tensor bytes, or the slice/format parameters, are invalid.
    InvalidFormat(String),
    /// Backend failure: store lookup, mmap access, Arrow conversion, or
    /// response construction.
    Storage(String),
    /// The requested operation is not supported.
    NotImplemented(String),
}
764
765impl IntoResponse for TensorError {
766 fn into_response(self) -> Response {
767 let (status, message) = match self {
768 TensorError::InvalidCid(cid) => {
769 (StatusCode::BAD_REQUEST, format!("Invalid CID: {}", cid))
770 }
771 TensorError::NotFound(cid) => {
772 (StatusCode::NOT_FOUND, format!("Tensor not found: {}", cid))
773 }
774 TensorError::InvalidFormat(msg) => (
775 StatusCode::BAD_REQUEST,
776 format!("Invalid tensor format: {}", msg),
777 ),
778 TensorError::Storage(msg) => (
779 StatusCode::INTERNAL_SERVER_ERROR,
780 format!("Storage error: {}", msg),
781 ),
782 TensorError::NotImplemented(msg) => (
783 StatusCode::NOT_IMPLEMENTED,
784 format!("Not implemented: {}", msg),
785 ),
786 };
787
788 (status, message).into_response()
789 }
790}
791
792impl From<String> for TensorError {
793 fn from(s: String) -> Self {
794 TensorError::InvalidFormat(s)
795 }
796}
797
#[cfg(test)]
mod tests {
    use super::*;

    /// Little-endian `f32` encoding of `0.0, 1.0, ..., (count - 1) as f32`.
    fn f32_sequence(count: u32) -> Vec<u8> {
        (0..count).flat_map(|i| (i as f32).to_le_bytes()).collect()
    }

    /// Decodes a buffer of little-endian `f32` values.
    fn decode_f32(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect()
    }

    /// Row-major `f32` metadata for `shape`.
    fn f32_meta(shape: Vec<usize>) -> TensorMetadata {
        TensorMetadata::from_raw(shape, "f32".to_string())
    }

    #[test]
    fn test_tensor_metadata_dtype_size() {
        let expectations = [("f32", 4), ("f64", 8), ("i32", 4), ("u8", 1), ("f16", 2)];
        for &(dtype, width) in expectations.iter() {
            assert_eq!(TensorMetadata::dtype_size(dtype), width, "dtype {}", dtype);
        }
    }

    #[test]
    fn test_tensor_metadata_from_raw() {
        let meta = f32_meta(vec![10, 20, 30]);
        assert_eq!(meta.shape, vec![10, 20, 30]);
        assert_eq!(meta.dtype, "f32");
        assert_eq!(meta.num_elements, 10 * 20 * 30);
        assert_eq!(meta.size_bytes, 10 * 20 * 30 * 4);
    }

    #[test]
    fn test_tensor_slice_parse_single() {
        assert_eq!(TensorSlice::parse("5").unwrap().ranges, vec![(5, Some(6))]);
    }

    #[test]
    fn test_tensor_slice_parse_range() {
        assert_eq!(TensorSlice::parse("10:20").unwrap().ranges, vec![(10, Some(20))]);
    }

    #[test]
    fn test_tensor_slice_parse_open_end() {
        assert_eq!(TensorSlice::parse("10:").unwrap().ranges, vec![(10, None)]);
    }

    #[test]
    fn test_tensor_slice_parse_multi_dim() {
        let expected = vec![(0, Some(10)), (5, Some(15)), (2, Some(8))];
        assert_eq!(TensorSlice::parse("0:10,5:15,2:8").unwrap().ranges, expected);
    }

    #[test]
    fn test_tensor_slice_calculate_size() {
        let slice = TensorSlice::parse("0:10,0:10").unwrap();
        let size = slice.calculate_size(&f32_meta(vec![100, 100])).unwrap();
        assert_eq!(size, 10 * 10 * 4);
    }

    #[test]
    fn test_tensor_slice_invalid_dimensions() {
        // One selector for a 2-D tensor.
        let slice = TensorSlice::parse("0:10").unwrap();
        assert!(slice.calculate_size(&f32_meta(vec![100, 100])).is_err());
    }

    #[test]
    fn test_tensor_slice_out_of_bounds() {
        let slice = TensorSlice::parse("0:200,0:10").unwrap();
        assert!(slice.calculate_size(&f32_meta(vec![100, 100])).is_err());
    }

    #[test]
    fn test_tensor_layout_serialization() {
        let cases = [
            (TensorLayout::RowMajor, r#""rowmajor""#),
            (TensorLayout::ColumnMajor, r#""columnmajor""#),
        ];
        for (layout, expected) in cases.iter() {
            assert_eq!(serde_json::to_string(layout).unwrap(), *expected);
        }
    }

    #[test]
    fn test_tensor_slice_extract_1d() {
        let data = f32_sequence(10);
        let result = TensorSlice::parse("2:5")
            .unwrap()
            .extract_data(&data, &f32_meta(vec![10]))
            .unwrap();

        // Three f32 elements.
        assert_eq!(result.len(), 3 * 4);
        assert_eq!(decode_f32(&result), vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_tensor_slice_extract_2d() {
        // A 4x3 matrix holding 0..12; take rows 1..3 and columns 0..2.
        let data = f32_sequence(12);
        let result = TensorSlice::parse("1:3,0:2")
            .unwrap()
            .extract_data(&data, &f32_meta(vec![4, 3]))
            .unwrap();

        assert_eq!(result.len(), 2 * 2 * 4);
        assert_eq!(decode_f32(&result), vec![3.0, 4.0, 6.0, 7.0]);
    }

    #[test]
    fn test_tensor_slice_extract_2d_single_row() {
        let data = f32_sequence(12);
        let result = TensorSlice::parse("2:3,0:3")
            .unwrap()
            .extract_data(&data, &f32_meta(vec![4, 3]))
            .unwrap();

        assert_eq!(result.len(), 3 * 4);
        assert_eq!(decode_f32(&result), vec![6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_tensor_slice_extract_invalid_dimension() {
        // Two selectors for a 1-D tensor.
        let data = vec![0u8; 40];
        let outcome = TensorSlice::parse("2:5,0:2")
            .unwrap()
            .extract_data(&data, &f32_meta(vec![10]));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_tensor_slice_extract_out_of_bounds() {
        let data = f32_sequence(10);
        let outcome = TensorSlice::parse("8:12")
            .unwrap()
            .extract_data(&data, &f32_meta(vec![10]));
        assert!(outcome.is_err());
    }
}