Skip to main content

tensor_data/
lib.rs

1#![doc = include_str!("../README.md")]
2
3pub mod surface;
4use std::collections::BTreeMap;
5
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use video_analysis_core::{DetectError, Result};
9
10fn invalid_argument(message: impl Into<String>) -> DetectError {
11    DetectError::InvalidArgument(message.into())
12}
13
14#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
15/// Checked tensor dimensions with a non-zero rank and non-zero extents.
16pub struct TensorShape {
17    dims: Vec<usize>,
18}
19
20impl TensorShape {
21    /// Creates a tensor shape after validating rank, extents, and element count.
22    pub fn new(dims: impl Into<Vec<usize>>) -> Result<Self> {
23        let shape = Self { dims: dims.into() };
24        shape.validate()?;
25        Ok(shape)
26    }
27
28    /// Borrows the dimensions in storage order.
29    pub fn dimensions(&self) -> &[usize] {
30        &self.dims
31    }
32
33    /// Returns the number of dimensions.
34    pub fn rank(&self) -> usize {
35        self.dims.len()
36    }
37
38    /// Multiplies all dimensions and fails if the count overflows `usize`.
39    pub fn element_count(&self) -> Result<usize> {
40        self.dims.iter().try_fold(1_usize, |count, dimension| {
41            count
42                .checked_mul(*dimension)
43                .ok_or_else(|| invalid_argument("tensor shape element count overflowed usize"))
44        })
45    }
46
47    /// Builds a new shape with the same element count as this shape.
48    pub fn reshape(&self, dims: impl Into<Vec<usize>>) -> Result<Self> {
49        let reshaped = Self::new(dims)?;
50        if reshaped.element_count()? != self.element_count()? {
51            return Err(invalid_argument(format!(
52                "cannot reshape tensor with {} elements into {} elements",
53                self.element_count()?,
54                reshaped.element_count()?
55            )));
56        }
57        Ok(reshaped)
58    }
59
60    fn validate(&self) -> Result<()> {
61        if self.dims.is_empty() {
62            return Err(invalid_argument(
63                "tensor shape must have at least one dimension",
64            ));
65        }
66        if self.dims.contains(&0) {
67            return Err(invalid_argument(
68                "tensor shape dimensions must be greater than zero",
69            ));
70        }
71        let _ = self.element_count()?;
72        Ok(())
73    }
74}
75
76#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
77/// Owned finite `f32` tensor values plus optional JSON metadata.
78pub struct F32Tensor {
79    shape: TensorShape,
80    values: Vec<f32>,
81    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
82    metadata: BTreeMap<String, Value>,
83}
84
85impl F32Tensor {
86    /// Creates an owned tensor when shape and finite-value validation pass.
87    pub fn new(shape: TensorShape, values: Vec<f32>) -> Result<Self> {
88        let tensor = Self {
89            shape,
90            values,
91            metadata: BTreeMap::new(),
92        };
93        tensor.validate()?;
94        Ok(tensor)
95    }
96
97    /// Creates an owned tensor from raw dimensions and values.
98    pub fn from_dims(dims: impl Into<Vec<usize>>, values: Vec<f32>) -> Result<Self> {
99        Self::new(TensorShape::new(dims)?, values)
100    }
101
102    /// Borrows the checked tensor shape.
103    pub fn shape(&self) -> &TensorShape {
104        &self.shape
105    }
106
107    /// Borrows the contiguous tensor values.
108    pub fn values(&self) -> &[f32] {
109        &self.values
110    }
111
112    /// Consumes the tensor and returns its contiguous values.
113    pub fn into_values(self) -> Vec<f32> {
114        self.values
115    }
116
117    /// Borrows optional transport metadata attached to the tensor.
118    pub fn metadata(&self) -> &BTreeMap<String, Value> {
119        &self.metadata
120    }
121
122    /// Attaches one metadata entry and returns the updated tensor.
123    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
124        self.metadata.insert(key.into(), value.into());
125        self
126    }
127
128    /// Inserts or replaces one metadata entry in place.
129    pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<Value>) -> &mut Self {
130        self.metadata.insert(key.into(), value.into());
131        self
132    }
133
134    /// Changes only the shape metadata after verifying the element count is unchanged.
135    pub fn reshape(mut self, dims: impl Into<Vec<usize>>) -> Result<Self> {
136        self.shape = self.shape.reshape(dims)?;
137        Ok(self)
138    }
139
140    /// Borrows this tensor as a validated view that shares the value slice.
141    pub fn as_view(&self) -> F32TensorView<'_> {
142        F32TensorView {
143            shape: self.shape.clone(),
144            values: &self.values,
145            metadata: self.metadata.clone(),
146        }
147    }
148
149    /// Verifies shape/value count agreement and rejects non-finite values.
150    pub fn validate(&self) -> Result<()> {
151        let expected = self.shape.element_count()?;
152        if expected != self.values.len() {
153            return Err(invalid_argument(format!(
154                "tensor shape expects {expected} elements but tensor has {}",
155                self.values.len()
156            )));
157        }
158        if self.values.iter().any(|value| !value.is_finite()) {
159            return Err(invalid_argument("tensor values must be finite"));
160        }
161        Ok(())
162    }
163}
164
165#[derive(Debug, Clone, PartialEq)]
166/// Borrowed finite `f32` tensor values with owned shape and metadata.
167pub struct F32TensorView<'a> {
168    shape: TensorShape,
169    values: &'a [f32],
170    metadata: BTreeMap<String, Value>,
171}
172
173impl<'a> F32TensorView<'a> {
174    /// Creates a borrowed tensor view when shape and values are compatible.
175    pub fn new(shape: TensorShape, values: &'a [f32]) -> Result<Self> {
176        let view = Self {
177            shape,
178            values,
179            metadata: BTreeMap::new(),
180        };
181        view.validate()?;
182        Ok(view)
183    }
184
185    /// Creates a borrowed tensor view from raw dimensions and values.
186    pub fn from_dims(dims: impl Into<Vec<usize>>, values: &'a [f32]) -> Result<Self> {
187        Self::new(TensorShape::new(dims)?, values)
188    }
189
190    /// Borrows the checked tensor shape.
191    pub fn shape(&self) -> &TensorShape {
192        &self.shape
193    }
194
195    /// Borrows the underlying contiguous value slice.
196    pub fn values(&self) -> &'a [f32] {
197        self.values
198    }
199
200    /// Borrows optional transport metadata attached to the view.
201    pub fn metadata(&self) -> &BTreeMap<String, Value> {
202        &self.metadata
203    }
204
205    /// Attaches one metadata entry and returns the updated view.
206    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
207        self.metadata.insert(key.into(), value.into());
208        self
209    }
210
211    /// Changes only the shape metadata after verifying the element count is unchanged.
212    pub fn reshape(mut self, dims: impl Into<Vec<usize>>) -> Result<Self> {
213        self.shape = self.shape.reshape(dims)?;
214        Ok(self)
215    }
216
217    /// Copies the borrowed values into an owned tensor while preserving metadata.
218    pub fn into_owned(self) -> Result<F32Tensor> {
219        let mut tensor = F32Tensor::new(self.shape, self.values.to_vec())?;
220        tensor.metadata = self.metadata;
221        Ok(tensor)
222    }
223
224    /// Verifies shape/value count agreement and rejects non-finite values.
225    pub fn validate(&self) -> Result<()> {
226        let expected = self.shape.element_count()?;
227        if expected != self.values.len() {
228            return Err(invalid_argument(format!(
229                "tensor shape expects {expected} elements but tensor view has {}",
230                self.values.len()
231            )));
232        }
233        if self.values.iter().any(|value| !value.is_finite()) {
234            return Err(invalid_argument("tensor values must be finite"));
235        }
236        Ok(())
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rejects_empty_or_zero_dimension_shapes() {
        assert!(TensorShape::new(Vec::<usize>::new()).is_err());
        assert!(TensorShape::new([1, 0, 2]).is_err());
    }

    #[test]
    fn rejects_shapes_whose_element_count_overflows() {
        let error = TensorShape::new([usize::MAX, 2]).unwrap_err();
        assert!(matches!(error, DetectError::InvalidArgument(_)));
    }

    #[test]
    fn rejects_wrong_element_count() {
        let error = F32Tensor::from_dims([2, 2], vec![0.0; 3]).unwrap_err();
        assert!(matches!(error, DetectError::InvalidArgument(_)));
    }

    #[test]
    fn rejects_non_finite_values() {
        let error = F32Tensor::from_dims([1, 2], vec![0.0, f32::NAN]).unwrap_err();
        assert!(matches!(error, DetectError::InvalidArgument(_)));
    }

    #[test]
    fn reshapes_when_element_counts_match() {
        let tensor = F32Tensor::from_dims([1, 4], vec![0.0; 4]).unwrap();
        let reshaped = tensor.reshape([2, 2]).unwrap();
        assert_eq!(reshaped.shape().dimensions(), &[2, 2]);
    }

    #[test]
    fn reshape_rejects_mismatched_element_count() {
        let tensor = F32Tensor::from_dims([2, 2], vec![0.0; 4]).unwrap();
        let error = tensor.reshape([3]).unwrap_err();
        assert!(matches!(error, DetectError::InvalidArgument(_)));
    }

    #[test]
    fn view_rejects_wrong_element_count_and_non_finite_values() {
        assert!(F32TensorView::from_dims([3], &[0.0, 1.0]).is_err());
        assert!(F32TensorView::from_dims([1], &[f32::INFINITY]).is_err());
    }

    #[test]
    fn view_round_trips_into_owned_tensor() {
        let view = F32TensorView::from_dims([1, 1, 2], &[0.25, 0.75]).unwrap();
        let owned = view.into_owned().unwrap();
        assert_eq!(owned.values(), &[0.25, 0.75]);
    }

    #[test]
    fn metadata_survives_view_round_trip() {
        let tensor = F32Tensor::from_dims([2], vec![1.0, 2.0])
            .unwrap()
            .with_metadata("source", "camera-0");
        let owned = tensor.as_view().into_owned().unwrap();
        assert_eq!(owned.metadata().get("source"), Some(&Value::from("camera-0")));
        assert_eq!(owned, tensor);
    }
}