Skip to main content

hedl_core/lex/
tensor.rs

1// Dweve HEDL - Hierarchical Entity Data Language
2//
3// Copyright (c) 2025 Dweve IP B.V. and individual contributors.
4//
5// SPDX-License-Identifier: Apache-2.0
6//
7// Licensed under the Apache License, Version 2.0 (the "License");
8// you may not use this file except in compliance with the License.
9// You may obtain a copy of the License in the LICENSE file at the
10// root of this repository or at: http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! Tensor literal parsing for HEDL format.
19//!
20//! Tensors are multi-dimensional numerical arrays like `[1, 2, 3]` or `[[1, 2], [3, 4]]`.
21//!
22//! # Examples
23//!
24//! ```
25//! use hedl_core::lex::{parse_tensor, is_tensor_literal, Tensor};
26//!
27//! // Parse a 1D tensor
28//! let tensor = parse_tensor("[1, 2, 3]").unwrap();
29//! assert_eq!(tensor.shape(), vec![3]);
30//! assert_eq!(tensor.flatten(), vec![1.0, 2.0, 3.0]);
31//!
32//! // Parse a 2D tensor (matrix)
33//! let matrix = parse_tensor("[[1, 2], [3, 4]]").unwrap();
34//! assert_eq!(matrix.shape(), vec![2, 2]);
35//!
36//! // Quick validation
37//! assert!(is_tensor_literal("[1, 2, 3]"));
38//! assert!(!is_tensor_literal("not a tensor"));
39//! ```
40//!
41//! # Security
42//!
43//! This module includes multiple security protections:
44//! - Maximum recursion depth of 100 to prevent stack overflow
45//! - Maximum element count of 10 million to prevent memory exhaustion
46//! - Rejection of NaN and Infinity values for predictable behavior
47//! - Error message truncation to prevent DoS attacks
48
49use crate::lex::error::LexError;
50
/// Deepest array nesting level `parse_tensor_inner` will descend to (the outermost array is
/// level 0); anything deeper is rejected so hostile input cannot exhaust the call stack.
const MAX_RECURSION_DEPTH: usize = 100;

/// Largest number of scalar leaves `parse_tensor` accepts in a single tensor.
const MAX_TENSOR_ELEMENTS: usize = 10 * 1_000_000;
56
/// An n-dimensional array of numbers, as written in a HEDL tensor literal.
///
/// A tensor is either a single number (`Scalar`) or a list of sub-tensors (`Array`), so
/// `[[1, 2], [3, 4]]` is an `Array` holding two `Array`s of two `Scalar`s each. Every number
/// is held as an `f64`, whether or not it was written with a decimal point.
///
/// # Examples
///
/// ```
/// use hedl_core::lex::{parse_tensor, Tensor};
///
/// // A bare number has no dimensions and exactly one element.
/// let scalar = Tensor::Scalar(42.0);
/// assert_eq!(scalar.shape(), Vec::<usize>::new());
/// assert_eq!(scalar.flatten(), vec![42.0]);
///
/// // One level of brackets: a vector of whole numbers.
/// let vector = parse_tensor("[1, 2, 3]").unwrap();
/// assert_eq!(vector.shape(), vec![3]);
/// assert!(vector.is_integer());
///
/// // Two levels of brackets: a 2x2 matrix with fractional entries.
/// let matrix = parse_tensor("[[1.5, 2.5], [3.5, 4.5]]").unwrap();
/// assert_eq!(matrix.shape(), vec![2, 2]);
/// assert!(!matrix.is_integer());
/// ```
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Tensor {
    /// One number; integers are stored as whole-valued `f64`s.
    Scalar(f64),
    /// A bracketed list of sub-tensors, one per comma-separated entry.
    Array(Vec<Tensor>),
}
90
91impl std::fmt::Display for Tensor {
92    /// Formats the tensor as a parseable HEDL tensor literal.
93    ///
94    /// This produces output that can be parsed back by `parse_tensor()`.
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        match self {
97            Tensor::Scalar(n) => {
98                if n.fract() == 0.0 && n.is_finite() {
99                    if *n >= i64::MIN as f64 && *n <= i64::MAX as f64 {
100                        write!(f, "{}", *n as i64)
101                    } else {
102                        write!(f, "{}", n)
103                    }
104                } else {
105                    write!(f, "{}", n)
106                }
107            }
108            Tensor::Array(items) => {
109                write!(f, "[")?;
110                for (i, item) in items.iter().enumerate() {
111                    if i > 0 {
112                        write!(f, ", ")?;
113                    }
114                    write!(f, "{}", item)?;
115                }
116                write!(f, "]")
117            }
118        }
119    }
120}
121
122impl Tensor {
123    /// Returns `true` if this tensor contains only integers (no decimal points).
124    ///
125    /// A number is considered an integer if its fractional part is zero.
126    ///
127    /// # Examples
128    ///
129    /// ```
130    /// use hedl_core::lex::{parse_tensor, Tensor};
131    ///
132    /// let integers = parse_tensor("[1, 2, 3]").unwrap();
133    /// assert!(integers.is_integer());
134    ///
135    /// let floats = parse_tensor("[1.5, 2.5]").unwrap();
136    /// assert!(!floats.is_integer());
137    /// ```
138    pub fn is_integer(&self) -> bool {
139        match self {
140            Tensor::Scalar(n) => n.fract() == 0.0,
141            Tensor::Array(items) => items.iter().all(|t| t.is_integer()),
142        }
143    }
144
145    /// Returns the shape of the tensor as a vector of dimensions.
146    ///
147    /// # Examples
148    ///
149    /// ```
150    /// use hedl_core::lex::{parse_tensor, Tensor};
151    ///
152    /// let scalar = Tensor::Scalar(42.0);
153    /// assert_eq!(scalar.shape(), vec![]);
154    ///
155    /// let vec = parse_tensor("[1, 2, 3]").unwrap();
156    /// assert_eq!(vec.shape(), vec![3]);
157    ///
158    /// let matrix = parse_tensor("[[1, 2], [3, 4]]").unwrap();
159    /// assert_eq!(matrix.shape(), vec![2, 2]);
160    /// ```
161    pub fn shape(&self) -> Vec<usize> {
162        match self {
163            Tensor::Scalar(_) => vec![],
164            Tensor::Array(items) => {
165                if items.is_empty() {
166                    vec![0]
167                } else {
168                    let mut shape = vec![items.len()];
169                    shape.extend(items[0].shape());
170                    shape
171                }
172            }
173        }
174    }
175
176    /// Flattens the tensor into a 1D vector of f64 values in row-major order.
177    ///
178    /// # Examples
179    ///
180    /// ```
181    /// use hedl_core::lex::parse_tensor;
182    ///
183    /// let matrix = parse_tensor("[[1, 2], [3, 4]]").unwrap();
184    /// assert_eq!(matrix.flatten(), vec![1.0, 2.0, 3.0, 4.0]);
185    /// ```
186    pub fn flatten(&self) -> Vec<f64> {
187        let capacity = self.count_elements();
188        let mut result = Vec::with_capacity(capacity);
189        self.flatten_into(&mut result);
190        result
191    }
192
193    /// Counts the total number of scalar elements.
194    fn count_elements(&self) -> usize {
195        match self {
196            Tensor::Scalar(_) => 1,
197            Tensor::Array(items) => items.iter().map(|t| t.count_elements()).sum(),
198        }
199    }
200
201    /// Flattens into a pre-allocated vector.
202    fn flatten_into(&self, result: &mut Vec<f64>) {
203        match self {
204            Tensor::Scalar(n) => result.push(*n),
205            Tensor::Array(items) => {
206                for item in items {
207                    item.flatten_into(result);
208                }
209            }
210        }
211    }
212
213    /// Returns `true` if this is a scalar value.
214    #[inline]
215    pub fn is_scalar(&self) -> bool {
216        matches!(self, Tensor::Scalar(_))
217    }
218
219    /// Returns `true` if this is an array.
220    #[inline]
221    pub fn is_array(&self) -> bool {
222        matches!(self, Tensor::Array(_))
223    }
224
225    /// Returns the number of dimensions (0 for scalar).
226    #[inline]
227    pub fn ndim(&self) -> usize {
228        self.shape().len()
229    }
230
231    /// Returns the total number of elements.
232    #[inline]
233    pub fn len(&self) -> usize {
234        self.count_elements()
235    }
236
237    /// Returns `true` if the tensor has no elements.
238    #[inline]
239    pub fn is_empty(&self) -> bool {
240        match self {
241            Tensor::Scalar(_) => false,
242            Tensor::Array(items) => items.is_empty(),
243        }
244    }
245}
246
/// Cheap pre-check for whether `s` might be a tensor literal.
///
/// After trimming surrounding whitespace, `s` must meet all of these conditions:
/// - it starts with `[` and ends with `]`;
/// - its brackets are balanced and never close more than they have opened;
/// - it contains nothing but digits, `.`, `-`, `,`, spaces and tabs besides the brackets.
///
/// This is a filter, not a validator. Inputs such as `[]` or `[1] [2]` pass here but are
/// rejected by [`parse_tensor`].
///
/// # Examples
///
/// ```
/// use hedl_core::lex::is_tensor_literal;
///
/// assert!(is_tensor_literal("[1, 2, 3]"));
/// assert!(is_tensor_literal("[[1, 2], [3, 4]]"));
/// assert!(!is_tensor_literal("hello"));
/// assert!(!is_tensor_literal("@reference"));
/// ```
#[inline]
pub fn is_tensor_literal(s: &str) -> bool {
    let trimmed = s.trim();
    if !(trimmed.starts_with('[') && trimmed.ends_with(']')) {
        return false;
    }

    // `all` stops at the first byte that is disallowed or closes an unopened bracket.
    let mut open: i32 = 0;
    let well_formed = trimmed.bytes().all(|byte| match byte {
        b'[' => {
            open += 1;
            true
        }
        b']' => {
            open -= 1;
            open >= 0
        }
        b'0'..=b'9' | b'.' | b'-' | b',' | b' ' | b'\t' => true,
        _ => false,
    });

    well_formed && open == 0
}
286
287/// Parses a tensor literal string into a `Tensor` structure.
288///
289/// # Examples
290///
291/// ```
292/// use hedl_core::lex::parse_tensor;
293///
294/// // Parse 1D tensor
295/// let t = parse_tensor("[1, 2, 3]").unwrap();
296/// assert_eq!(t.shape(), vec![3]);
297///
298/// // Parse 2D tensor
299/// let t = parse_tensor("[[1, 2], [3, 4]]").unwrap();
300/// assert_eq!(t.shape(), vec![2, 2]);
301///
302/// // Parse with floats
303/// let t = parse_tensor("[1.5, 2.5]").unwrap();
304/// assert!(!t.is_integer());
305/// ```
306///
307/// # Errors
308///
309/// Returns error for:
310/// - Unbalanced brackets
311/// - Empty tensor
312/// - Invalid numbers
313/// - Inconsistent dimensions
314/// - Exceeding security limits
315pub fn parse_tensor(s: &str) -> Result<Tensor, LexError> {
316    let s = s.trim();
317    if !s.starts_with('[') {
318        return Err(LexError::UnexpectedChar(s.chars().next().unwrap_or(' ')));
319    }
320
321    let (tensor, remaining) = parse_tensor_inner(s, 0)?;
322
323    if !remaining.trim().is_empty() {
324        return Err(LexError::UnexpectedChar(
325            remaining.trim().chars().next().unwrap_or('?'),
326        ));
327    }
328
329    // Validate total element count
330    let element_count = tensor.count_elements();
331    if element_count > MAX_TENSOR_ELEMENTS {
332        return Err(LexError::InvalidStructure(format!(
333            "Tensor element count exceeds maximum: {} (max: {})",
334            element_count, MAX_TENSOR_ELEMENTS
335        )));
336    }
337
338    Ok(tensor)
339}
340
/// Guesses how many items an array holds, used only to size its `Vec` up front.
///
/// `s` is the text just after the array's opening `[`. Commas at the array's own nesting
/// level are counted up to its matching `]` (or the end of `s`), and `n` commas yield
/// `n + 1`. Without any commas the guess is 2, so single-item and empty arrays over-reserve
/// slightly.
fn estimate_array_size(s: &str) -> usize {
    // The delimiters are all ASCII, and UTF-8 continuation bytes never collide with ASCII,
    // so scanning bytes sees exactly the same delimiters as scanning chars.
    let mut nesting = 0;
    let mut commas = 0;
    for byte in s.bytes() {
        match byte {
            b'[' => nesting += 1,
            b']' if nesting == 0 => break,
            b']' => nesting -= 1,
            b',' if nesting == 0 => commas += 1,
            _ => {}
        }
    }

    if commas == 0 {
        2
    } else {
        commas + 1
    }
}
366
367fn parse_tensor_inner(s: &str, depth: usize) -> Result<(Tensor, &str), LexError> {
368    if depth > MAX_RECURSION_DEPTH {
369        return Err(LexError::InvalidStructure(format!(
370            "Recursion depth exceeded (max: {})",
371            MAX_RECURSION_DEPTH
372        )));
373    }
374    let s = s.trim();
375
376    if let Some(remaining_str) = s.strip_prefix('[') {
377        let mut remaining = remaining_str;
378        let estimated_capacity = estimate_array_size(remaining_str);
379        let mut items = Vec::with_capacity(estimated_capacity);
380
381        loop {
382            remaining = remaining.trim_start();
383
384            if remaining.is_empty() {
385                return Err(LexError::UnbalancedBrackets);
386            }
387
388            if remaining.starts_with(']') {
389                remaining = &remaining[1..];
390                break;
391            }
392
393            if !items.is_empty() {
394                if !remaining.starts_with(',') {
395                    return Err(LexError::UnexpectedChar(
396                        remaining.chars().next().unwrap_or('?'),
397                    ));
398                }
399                remaining = remaining[1..].trim_start();
400            }
401
402            if remaining.starts_with(']') {
403                remaining = &remaining[1..];
404                break;
405            }
406
407            if remaining.starts_with('[') {
408                let (tensor, rest) = parse_tensor_inner(remaining, depth + 1)?;
409                items.push(tensor);
410                remaining = rest;
411            } else {
412                let (num, rest) = parse_number(remaining)?;
413                items.push(Tensor::Scalar(num));
414                remaining = rest;
415            }
416        }
417
418        if items.is_empty() {
419            return Err(LexError::EmptyTensor);
420        }
421
422        // Validate consistent dimensions
423        if items.len() > 1 {
424            let first_shape = items[0].shape();
425            for item in &items[1..] {
426                if item.shape() != first_shape {
427                    return Err(LexError::InconsistentDimensions);
428                }
429            }
430        }
431
432        Ok((Tensor::Array(items), remaining))
433    } else {
434        let (num, rest) = parse_number(s)?;
435        Ok((Tensor::Scalar(num), rest))
436    }
437}
438
439fn parse_number(s: &str) -> Result<(f64, &str), LexError> {
440    let s = s.trim_start();
441    let bytes = s.as_bytes();
442
443    let mut end = 0;
444    let mut has_dot = false;
445
446    if bytes.first() == Some(&b'-') {
447        end = 1;
448    }
449
450    while end < bytes.len() {
451        match bytes[end] {
452            b'0'..=b'9' => end += 1,
453            b'.' if !has_dot => {
454                has_dot = true;
455                end += 1;
456            }
457            _ => break,
458        }
459    }
460
461    if end == 0 || (end == 1 && bytes[0] == b'-') {
462        return Err(LexError::InvalidNumber(
463            s.chars().take(10).collect::<String>(),
464        ));
465    }
466
467    let num_str = &s[..end];
468    let num: f64 = num_str.parse().map_err(|_| {
469        let context = if num_str.len() > 80 {
470            format!("{}...", &num_str[..80])
471        } else {
472            num_str.to_string()
473        };
474        LexError::InvalidNumber(context)
475    })?;
476
477    if !num.is_finite() {
478        return Err(LexError::InvalidNumber(format!(
479            "{} (non-finite values not allowed)",
480            num_str
481        )));
482    }
483
484    Ok((num, &s[end..]))
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    // ==================== is_tensor_literal tests ====================

    #[test]
    fn test_is_tensor_literal_valid() {
        assert!(is_tensor_literal("[1, 2, 3]"));
        assert!(is_tensor_literal("[[1, 2], [3, 4]]"));
        assert!(is_tensor_literal("[1.5, 2.5]"));
        assert!(is_tensor_literal("[-1, -2]"));
        assert!(is_tensor_literal("  [1, 2, 3]  "));
    }

    #[test]
    fn test_is_tensor_literal_invalid() {
        assert!(!is_tensor_literal("hello"));
        assert!(!is_tensor_literal("@reference"));
        assert!(!is_tensor_literal("123"));
        assert!(!is_tensor_literal(""));
        assert!(!is_tensor_literal("[1, 2"));
        assert!(!is_tensor_literal("[a, b]"));
    }

    #[test]
    fn test_is_tensor_literal_is_only_a_prefilter() {
        // Structural problems pass the quick check and are reported by `parse_tensor`.
        assert!(is_tensor_literal("[]"));
        assert!(matches!(parse_tensor("[]"), Err(LexError::EmptyTensor)));
        assert!(is_tensor_literal("[1] [2]"));
        assert!(parse_tensor("[1] [2]").is_err());
    }

    // ==================== parse_tensor tests ====================

    #[test]
    fn test_parse_1d() {
        let t = parse_tensor("[1, 2, 3]").unwrap();
        assert_eq!(t.shape(), vec![3]);
        assert_eq!(t.flatten(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_parse_2d() {
        let t = parse_tensor("[[1, 2], [3, 4]]").unwrap();
        assert_eq!(t.shape(), vec![2, 2]);
        assert_eq!(t.flatten(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_parse_floats() {
        let t = parse_tensor("[1.5, 2.5, 3.5]").unwrap();
        assert_eq!(t.flatten(), vec![1.5, 2.5, 3.5]);
        assert!(!t.is_integer());
    }

    #[test]
    fn test_parse_negatives() {
        let t = parse_tensor("[-1, -2, -3]").unwrap();
        assert_eq!(t.flatten(), vec![-1.0, -2.0, -3.0]);
    }

    #[test]
    fn test_parse_trailing_comma() {
        let t = parse_tensor("[1, 2, 3,]").unwrap();
        assert_eq!(t.flatten(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_recursion_depth_limit() {
        let nest = |levels: usize| format!("{}1{}", "[".repeat(levels), "]".repeat(levels));

        // The outermost array is depth 0, so MAX_RECURSION_DEPTH + 1 levels still parse.
        let deepest_ok = parse_tensor(&nest(MAX_RECURSION_DEPTH + 1)).unwrap();
        assert_eq!(deepest_ok.ndim(), MAX_RECURSION_DEPTH + 1);

        assert!(matches!(
            parse_tensor(&nest(MAX_RECURSION_DEPTH + 2)),
            Err(LexError::InvalidStructure(_))
        ));
    }

    // ==================== Error tests ====================

    #[test]
    fn test_empty_tensor_error() {
        assert!(matches!(parse_tensor("[]"), Err(LexError::EmptyTensor)));
    }

    #[test]
    fn test_unbalanced_brackets_error() {
        assert!(matches!(
            parse_tensor("[1, 2"),
            Err(LexError::UnbalancedBrackets)
        ));
    }

    #[test]
    fn test_inconsistent_dimensions_error() {
        assert!(matches!(
            parse_tensor("[[1, 2], [3]]"),
            Err(LexError::InconsistentDimensions)
        ));
    }

    #[test]
    fn test_mixed_scalar_and_array_error() {
        assert!(matches!(
            parse_tensor("[1, [2]]"),
            Err(LexError::InconsistentDimensions)
        ));
    }

    #[test]
    fn test_invalid_number_error() {
        assert!(matches!(
            parse_tensor("[abc]"),
            Err(LexError::InvalidNumber(_))
        ));
    }

    #[test]
    fn test_overflowing_number_rejected() {
        let literal = format!("[{}]", "9".repeat(400));
        assert!(matches!(
            parse_tensor(&literal),
            Err(LexError::InvalidNumber(_))
        ));
    }

    #[test]
    fn test_non_bracket_input_error() {
        assert!(matches!(
            parse_tensor("1, 2"),
            Err(LexError::UnexpectedChar('1'))
        ));
        assert!(matches!(
            parse_tensor("   "),
            Err(LexError::UnexpectedChar(_))
        ));
    }

    #[test]
    fn test_trailing_input_error() {
        assert!(matches!(
            parse_tensor("[1, 2] x"),
            Err(LexError::UnexpectedChar('x'))
        ));
    }

    #[test]
    fn test_missing_separator_error() {
        assert!(matches!(
            parse_tensor("[1 2]"),
            Err(LexError::UnexpectedChar('2'))
        ));
    }

    // ==================== Tensor struct tests ====================

    #[test]
    fn test_tensor_display() {
        let t = parse_tensor("[1, 2, 3]").unwrap();
        assert_eq!(format!("{}", t), "[1, 2, 3]");

        let t = parse_tensor("[[1, 2], [3, 4]]").unwrap();
        assert_eq!(format!("{}", t), "[[1, 2], [3, 4]]");
    }

    #[test]
    fn test_scalar_display() {
        assert_eq!(Tensor::Scalar(3.0).to_string(), "3");
        assert_eq!(Tensor::Scalar(-2.5).to_string(), "-2.5");
    }

    #[test]
    fn test_tensor_methods() {
        let scalar = Tensor::Scalar(42.0);
        assert!(scalar.is_scalar());
        assert!(!scalar.is_array());
        assert_eq!(scalar.ndim(), 0);
        assert_eq!(scalar.len(), 1);
        assert!(!scalar.is_empty());

        let array = parse_tensor("[1, 2, 3]").unwrap();
        assert!(!array.is_scalar());
        assert!(array.is_array());
        assert_eq!(array.ndim(), 1);
        assert_eq!(array.len(), 3);
        assert!(!array.is_empty());

        assert!(Tensor::Array(vec![]).is_empty());
    }

    #[test]
    fn test_tensor_equality() {
        let t1 = parse_tensor("[1, 2, 3]").unwrap();
        let t2 = parse_tensor("[1, 2, 3]").unwrap();
        assert_eq!(t1, t2);

        let t3 = parse_tensor("[1, 2, 4]").unwrap();
        assert_ne!(t1, t3);
    }

    #[test]
    fn test_tensor_clone() {
        let t1 = parse_tensor("[[1, 2], [3, 4]]").unwrap();
        let t2 = t1.clone();
        assert_eq!(t1, t2);
    }

    // ==================== Round-trip tests ====================

    #[test]
    fn test_display_roundtrip() {
        for literal in ["[1, 2, 3]", "[[1.5, 2.5], [3.5, 4.5]]", "[-0.25, 7]"] {
            let original = parse_tensor(literal).unwrap();
            let parsed = parse_tensor(&original.to_string()).unwrap();
            assert_eq!(original, parsed, "round-trip failed for {}", literal);
        }
    }
}