hedl_core/lex/
tensor.rs

1// Dweve HEDL - Hierarchical Entity Data Language
2//
3// Copyright (c) 2025 Dweve IP B.V. and individual contributors.
4//
5// SPDX-License-Identifier: Apache-2.0
6//
7// Licensed under the Apache License, Version 2.0 (the "License");
8// you may not use this file except in compliance with the License.
9// You may obtain a copy of the License in the LICENSE file at the
10// root of this repository or at: http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
//! Tensor literal parsing for HEDL format.
//!
//! Tensors are multi-dimensional numerical arrays like `[1, 2, 3]` or `[[1, 2], [3, 4]]`.
//! All arrays at the same nesting level must have the same shape, and each leaf is a
//! plain decimal number: an optional leading `-`, digits, and at most one `.`.
//!
//! # Examples
//!
//! ```
//! use hedl_core::lex::{parse_tensor, is_tensor_literal, Tensor};
//!
//! // Parse a 1D tensor
//! let tensor = parse_tensor("[1, 2, 3]").unwrap();
//! assert_eq!(tensor.shape(), vec![3]);
//! assert_eq!(tensor.flatten(), vec![1.0, 2.0, 3.0]);
//!
//! // Parse a 2D tensor (matrix)
//! let matrix = parse_tensor("[[1, 2], [3, 4]]").unwrap();
//! assert_eq!(matrix.shape(), vec![2, 2]);
//!
//! // Quick validation
//! assert!(is_tensor_literal("[1, 2, 3]"));
//! assert!(!is_tensor_literal("not a tensor"));
//! ```
//!
//! # Security
//!
//! Tensor literals may come from untrusted input, so this module includes several
//! protections:
//! - a nesting-depth limit (`MAX_RECURSION_DEPTH`, 100) to prevent stack overflow;
//! - a total element-count limit (`MAX_TENSOR_ELEMENTS`, 10 million) to prevent
//!   memory exhaustion;
//! - rejection of NaN and infinite values (for example, digit strings too large for
//!   `f64`) for predictable behavior;
//! - truncation of oversized input echoed in error messages, to prevent DoS attacks.
48
49use crate::lex::error::LexError;
50
/// Nesting limit for bracketed arrays: `parse_tensor_inner` fails with an
/// `InvalidStructure` error once its depth argument (0 for the outermost
/// array) exceeds this value. Guards against stack overflow on hostile input.
const MAX_RECURSION_DEPTH: usize = 100;

/// Ceiling on the total number of scalar leaves `parse_tensor` will accept
/// (ten million); larger tensors are rejected with an `InvalidStructure` error.
const MAX_TENSOR_ELEMENTS: usize = 10 * 1_000_000;
56
/// An n-dimensional array of numbers, stored as a tree of `f64` leaves.
///
/// A `Tensor` is either a single number (`Tensor::Scalar`) or a list of
/// sub-tensors (`Tensor::Array`). Integer literals are stored as `f64` as
/// well, so `1` and `1.0` compare equal.
///
/// # Examples
///
/// ```
/// use hedl_core::lex::{parse_tensor, Tensor};
///
/// // A bare number is a zero-dimensional tensor.
/// let scalar = Tensor::Scalar(42.0);
/// assert_eq!(scalar.shape(), vec![]);
/// assert_eq!(scalar.flatten(), vec![42.0]);
///
/// // A flat list of whole numbers.
/// let vector = parse_tensor("[1, 2, 3]").unwrap();
/// assert_eq!(vector.shape(), vec![3]);
/// assert!(vector.is_integer());
///
/// // A 2x2 matrix with fractional entries.
/// let matrix = parse_tensor("[[1.5, 2.5], [3.5, 4.5]]").unwrap();
/// assert_eq!(matrix.shape(), vec![2, 2]);
/// assert!(!matrix.is_integer());
/// ```
#[derive(Clone, Debug, PartialEq)]
pub enum Tensor {
    /// One numeric leaf; integer literals are held as whole-valued `f64`s.
    Scalar(f64),
    /// A list of nested tensors, one per element of a bracketed literal.
    Array(Vec<Tensor>),
}
89
90impl std::fmt::Display for Tensor {
91    /// Formats the tensor as a parseable HEDL tensor literal.
92    ///
93    /// This produces output that can be parsed back by `parse_tensor()`.
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        match self {
96            Tensor::Scalar(n) => {
97                if n.fract() == 0.0 && n.is_finite() {
98                    if *n >= i64::MIN as f64 && *n <= i64::MAX as f64 {
99                        write!(f, "{}", *n as i64)
100                    } else {
101                        write!(f, "{}", n)
102                    }
103                } else {
104                    write!(f, "{}", n)
105                }
106            }
107            Tensor::Array(items) => {
108                write!(f, "[")?;
109                for (i, item) in items.iter().enumerate() {
110                    if i > 0 {
111                        write!(f, ", ")?;
112                    }
113                    write!(f, "{}", item)?;
114                }
115                write!(f, "]")
116            }
117        }
118    }
119}
120
121impl Tensor {
122    /// Returns `true` if this tensor contains only integers (no decimal points).
123    ///
124    /// A number is considered an integer if its fractional part is zero.
125    ///
126    /// # Examples
127    ///
128    /// ```
129    /// use hedl_core::lex::{parse_tensor, Tensor};
130    ///
131    /// let integers = parse_tensor("[1, 2, 3]").unwrap();
132    /// assert!(integers.is_integer());
133    ///
134    /// let floats = parse_tensor("[1.5, 2.5]").unwrap();
135    /// assert!(!floats.is_integer());
136    /// ```
137    pub fn is_integer(&self) -> bool {
138        match self {
139            Tensor::Scalar(n) => n.fract() == 0.0,
140            Tensor::Array(items) => items.iter().all(|t| t.is_integer()),
141        }
142    }
143
144    /// Returns the shape of the tensor as a vector of dimensions.
145    ///
146    /// # Examples
147    ///
148    /// ```
149    /// use hedl_core::lex::{parse_tensor, Tensor};
150    ///
151    /// let scalar = Tensor::Scalar(42.0);
152    /// assert_eq!(scalar.shape(), vec![]);
153    ///
154    /// let vec = parse_tensor("[1, 2, 3]").unwrap();
155    /// assert_eq!(vec.shape(), vec![3]);
156    ///
157    /// let matrix = parse_tensor("[[1, 2], [3, 4]]").unwrap();
158    /// assert_eq!(matrix.shape(), vec![2, 2]);
159    /// ```
160    pub fn shape(&self) -> Vec<usize> {
161        match self {
162            Tensor::Scalar(_) => vec![],
163            Tensor::Array(items) => {
164                if items.is_empty() {
165                    vec![0]
166                } else {
167                    let mut shape = vec![items.len()];
168                    shape.extend(items[0].shape());
169                    shape
170                }
171            }
172        }
173    }
174
175    /// Flattens the tensor into a 1D vector of f64 values in row-major order.
176    ///
177    /// # Examples
178    ///
179    /// ```
180    /// use hedl_core::lex::parse_tensor;
181    ///
182    /// let matrix = parse_tensor("[[1, 2], [3, 4]]").unwrap();
183    /// assert_eq!(matrix.flatten(), vec![1.0, 2.0, 3.0, 4.0]);
184    /// ```
185    pub fn flatten(&self) -> Vec<f64> {
186        let capacity = self.count_elements();
187        let mut result = Vec::with_capacity(capacity);
188        self.flatten_into(&mut result);
189        result
190    }
191
192    /// Counts the total number of scalar elements.
193    fn count_elements(&self) -> usize {
194        match self {
195            Tensor::Scalar(_) => 1,
196            Tensor::Array(items) => items.iter().map(|t| t.count_elements()).sum(),
197        }
198    }
199
200    /// Flattens into a pre-allocated vector.
201    fn flatten_into(&self, result: &mut Vec<f64>) {
202        match self {
203            Tensor::Scalar(n) => result.push(*n),
204            Tensor::Array(items) => {
205                for item in items {
206                    item.flatten_into(result);
207                }
208            }
209        }
210    }
211
212    /// Returns `true` if this is a scalar value.
213    #[inline]
214    pub fn is_scalar(&self) -> bool {
215        matches!(self, Tensor::Scalar(_))
216    }
217
218    /// Returns `true` if this is an array.
219    #[inline]
220    pub fn is_array(&self) -> bool {
221        matches!(self, Tensor::Array(_))
222    }
223
224    /// Returns the number of dimensions (0 for scalar).
225    #[inline]
226    pub fn ndim(&self) -> usize {
227        self.shape().len()
228    }
229
230    /// Returns the total number of elements.
231    #[inline]
232    pub fn len(&self) -> usize {
233        self.count_elements()
234    }
235
236    /// Returns `true` if the tensor has no elements.
237    #[inline]
238    pub fn is_empty(&self) -> bool {
239        match self {
240            Tensor::Scalar(_) => false,
241            Tensor::Array(items) => items.is_empty(),
242        }
243    }
244}
245
/// Cheaply decides whether `s` *might* be a tensor literal.
///
/// After trimming surrounding whitespace, the text must start with `[` and end
/// with `]`. Its brackets must be balanced, and every other byte must be a digit,
/// `.`, `-`, `,`, a space or a tab. Numbers, dimensions and emptiness are not
/// checked, so a `true` result still needs confirming with `parse_tensor`.
///
/// # Examples
///
/// ```
/// use hedl_core::lex::is_tensor_literal;
///
/// assert!(is_tensor_literal("[1, 2, 3]"));
/// assert!(is_tensor_literal("[[1, 2], [3, 4]]"));
/// assert!(!is_tensor_literal("hello"));
/// assert!(!is_tensor_literal("@reference"));
/// ```
#[inline]
pub fn is_tensor_literal(s: &str) -> bool {
    let candidate = s.trim();
    if !(candidate.starts_with('[') && candidate.ends_with(']')) {
        return false;
    }

    // Walk the bytes while tracking bracket depth. `None` means either a
    // disallowed byte or a `]` that has no matching `[`.
    let final_depth = candidate.bytes().try_fold(0i32, |open, byte| match byte {
        b'[' => Some(open + 1),
        b']' => {
            if open > 0 {
                Some(open - 1)
            } else {
                None
            }
        }
        b'0'..=b'9' | b'.' | b'-' | b',' | b' ' | b'\t' => Some(open),
        _ => None,
    });

    final_depth == Some(0)
}
285
286/// Parses a tensor literal string into a `Tensor` structure.
287///
288/// # Examples
289///
290/// ```
291/// use hedl_core::lex::parse_tensor;
292///
293/// // Parse 1D tensor
294/// let t = parse_tensor("[1, 2, 3]").unwrap();
295/// assert_eq!(t.shape(), vec![3]);
296///
297/// // Parse 2D tensor
298/// let t = parse_tensor("[[1, 2], [3, 4]]").unwrap();
299/// assert_eq!(t.shape(), vec![2, 2]);
300///
301/// // Parse with floats
302/// let t = parse_tensor("[1.5, 2.5]").unwrap();
303/// assert!(!t.is_integer());
304/// ```
305///
306/// # Errors
307///
308/// Returns error for:
309/// - Unbalanced brackets
310/// - Empty tensor
311/// - Invalid numbers
312/// - Inconsistent dimensions
313/// - Exceeding security limits
314pub fn parse_tensor(s: &str) -> Result<Tensor, LexError> {
315    let s = s.trim();
316    if !s.starts_with('[') {
317        return Err(LexError::UnexpectedChar(s.chars().next().unwrap_or(' ')));
318    }
319
320    let (tensor, remaining) = parse_tensor_inner(s, 0)?;
321
322    if !remaining.trim().is_empty() {
323        return Err(LexError::UnexpectedChar(
324            remaining.trim().chars().next().unwrap_or('?'),
325        ));
326    }
327
328    // Validate total element count
329    let element_count = tensor.count_elements();
330    if element_count > MAX_TENSOR_ELEMENTS {
331        return Err(LexError::InvalidStructure(format!(
332            "Tensor element count exceeds maximum: {} (max: {})",
333            element_count, MAX_TENSOR_ELEMENTS
334        )));
335    }
336
337    Ok(tensor)
338}
339
340/// Estimates array size by counting commas.
341fn estimate_array_size(s: &str) -> usize {
342    let mut depth = 0;
343    let mut comma_count = 0;
344
345    for ch in s.chars() {
346        match ch {
347            '[' => depth += 1,
348            ']' => {
349                if depth == 0 {
350                    break;
351                }
352                depth -= 1;
353            }
354            ',' if depth == 0 => comma_count += 1,
355            _ => {}
356        }
357    }
358
359    if comma_count > 0 {
360        comma_count + 1
361    } else {
362        2
363    }
364}
365
366fn parse_tensor_inner(s: &str, depth: usize) -> Result<(Tensor, &str), LexError> {
367    if depth > MAX_RECURSION_DEPTH {
368        return Err(LexError::InvalidStructure(format!(
369            "Recursion depth exceeded (max: {})",
370            MAX_RECURSION_DEPTH
371        )));
372    }
373    let s = s.trim();
374
375    if let Some(remaining_str) = s.strip_prefix('[') {
376        let mut remaining = remaining_str;
377        let estimated_capacity = estimate_array_size(remaining_str);
378        let mut items = Vec::with_capacity(estimated_capacity);
379
380        loop {
381            remaining = remaining.trim_start();
382
383            if remaining.is_empty() {
384                return Err(LexError::UnbalancedBrackets);
385            }
386
387            if remaining.starts_with(']') {
388                remaining = &remaining[1..];
389                break;
390            }
391
392            if !items.is_empty() {
393                if !remaining.starts_with(',') {
394                    return Err(LexError::UnexpectedChar(
395                        remaining.chars().next().unwrap_or('?'),
396                    ));
397                }
398                remaining = remaining[1..].trim_start();
399            }
400
401            if remaining.starts_with(']') {
402                remaining = &remaining[1..];
403                break;
404            }
405
406            if remaining.starts_with('[') {
407                let (tensor, rest) = parse_tensor_inner(remaining, depth + 1)?;
408                items.push(tensor);
409                remaining = rest;
410            } else {
411                let (num, rest) = parse_number(remaining)?;
412                items.push(Tensor::Scalar(num));
413                remaining = rest;
414            }
415        }
416
417        if items.is_empty() {
418            return Err(LexError::EmptyTensor);
419        }
420
421        // Validate consistent dimensions
422        if items.len() > 1 {
423            let first_shape = items[0].shape();
424            for item in &items[1..] {
425                if item.shape() != first_shape {
426                    return Err(LexError::InconsistentDimensions);
427                }
428            }
429        }
430
431        Ok((Tensor::Array(items), remaining))
432    } else {
433        let (num, rest) = parse_number(s)?;
434        Ok((Tensor::Scalar(num), rest))
435    }
436}
437
438fn parse_number(s: &str) -> Result<(f64, &str), LexError> {
439    let s = s.trim_start();
440    let bytes = s.as_bytes();
441
442    let mut end = 0;
443    let mut has_dot = false;
444
445    if bytes.first() == Some(&b'-') {
446        end = 1;
447    }
448
449    while end < bytes.len() {
450        match bytes[end] {
451            b'0'..=b'9' => end += 1,
452            b'.' if !has_dot => {
453                has_dot = true;
454                end += 1;
455            }
456            _ => break,
457        }
458    }
459
460    if end == 0 || (end == 1 && bytes[0] == b'-') {
461        return Err(LexError::InvalidNumber(
462            s.chars().take(10).collect::<String>(),
463        ));
464    }
465
466    let num_str = &s[..end];
467    let num: f64 = num_str.parse().map_err(|_| {
468        let context = if num_str.len() > 80 {
469            format!("{}...", &num_str[..80])
470        } else {
471            num_str.to_string()
472        };
473        LexError::InvalidNumber(context)
474    })?;
475
476    if !num.is_finite() {
477        return Err(LexError::InvalidNumber(format!(
478            "{} (non-finite values not allowed)",
479            num_str
480        )));
481    }
482
483    Ok((num, &s[end..]))
484}
485
#[cfg(test)]
mod tests {
    use super::*;

    // ==================== is_tensor_literal tests ====================

    #[test]
    fn test_is_tensor_literal_valid() {
        assert!(is_tensor_literal("[1, 2, 3]"));
        assert!(is_tensor_literal("[[1, 2], [3, 4]]"));
        assert!(is_tensor_literal("[1.5, 2.5]"));
        assert!(is_tensor_literal("[-1, -2]"));
        assert!(is_tensor_literal("  [1, 2, 3]  "));
    }

    #[test]
    fn test_is_tensor_literal_invalid() {
        assert!(!is_tensor_literal("hello"));
        assert!(!is_tensor_literal("@reference"));
        assert!(!is_tensor_literal("123"));
        assert!(!is_tensor_literal(""));
        assert!(!is_tensor_literal("[1, 2"));
        assert!(!is_tensor_literal("[a, b]"));
    }

    #[test]
    fn test_is_tensor_literal_is_only_a_quick_check() {
        // Only characters and bracket balance are checked; emptiness is not.
        assert!(is_tensor_literal("[]"));
        assert!(!is_tensor_literal("[1]]"));
        // Newlines are not in the quick-check alphabet.
        assert!(!is_tensor_literal("[1,\n2]"));
    }

    // ==================== parse_tensor tests ====================

    #[test]
    fn test_parse_1d() {
        let t = parse_tensor("[1, 2, 3]").unwrap();
        assert_eq!(t.shape(), vec![3]);
        assert_eq!(t.flatten(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_parse_2d() {
        let t = parse_tensor("[[1, 2], [3, 4]]").unwrap();
        assert_eq!(t.shape(), vec![2, 2]);
        assert_eq!(t.flatten(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_parse_3d() {
        let t = parse_tensor("[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]").unwrap();
        assert_eq!(t.shape(), vec![2, 2, 2]);
        assert_eq!(t.ndim(), 3);
        assert_eq!(t.len(), 8);
    }

    #[test]
    fn test_parse_floats() {
        let t = parse_tensor("[1.5, 2.5, 3.5]").unwrap();
        assert_eq!(t.flatten(), vec![1.5, 2.5, 3.5]);
        assert!(!t.is_integer());
    }

    #[test]
    fn test_parse_negatives() {
        let t = parse_tensor("[-1, -2, -3]").unwrap();
        assert_eq!(t.flatten(), vec![-1.0, -2.0, -3.0]);
    }

    #[test]
    fn test_parse_trailing_comma() {
        let t = parse_tensor("[1, 2, 3,]").unwrap();
        assert_eq!(t.flatten(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_parse_multiline_whitespace() {
        let t = parse_tensor("[\n  1,\n  2\n]").unwrap();
        assert_eq!(t.flatten(), vec![1.0, 2.0]);
    }

    // ==================== Error tests ====================

    #[test]
    fn test_empty_tensor_error() {
        assert!(matches!(parse_tensor("[]"), Err(LexError::EmptyTensor)));
    }

    #[test]
    fn test_nested_empty_tensor_error() {
        assert!(matches!(parse_tensor("[[]]"), Err(LexError::EmptyTensor)));
    }

    #[test]
    fn test_unbalanced_brackets_error() {
        assert!(matches!(
            parse_tensor("[1, 2"),
            Err(LexError::UnbalancedBrackets)
        ));
    }

    #[test]
    fn test_inconsistent_dimensions_error() {
        assert!(matches!(
            parse_tensor("[[1, 2], [3]]"),
            Err(LexError::InconsistentDimensions)
        ));
    }

    #[test]
    fn test_mixed_scalar_and_array_error() {
        assert!(matches!(
            parse_tensor("[1, [2]]"),
            Err(LexError::InconsistentDimensions)
        ));
    }

    #[test]
    fn test_invalid_number_error() {
        assert!(matches!(
            parse_tensor("[abc]"),
            Err(LexError::InvalidNumber(_))
        ));
    }

    #[test]
    fn test_unexpected_char_errors() {
        assert!(matches!(parse_tensor("1, 2"), Err(LexError::UnexpectedChar('1'))));
        assert!(matches!(parse_tensor("[1, 2] x"), Err(LexError::UnexpectedChar('x'))));
        assert!(matches!(parse_tensor("[1 2]"), Err(LexError::UnexpectedChar('2'))));
        // Exponent notation is not part of the tensor number grammar.
        assert!(matches!(parse_tensor("[1e5]"), Err(LexError::UnexpectedChar('e'))));
    }

    // ==================== Security limit tests ====================

    #[test]
    fn test_recursion_depth_limit() {
        let nest = |levels: usize| format!("{}1{}", "[".repeat(levels), "]".repeat(levels));

        // The outermost array is depth 0, so MAX_RECURSION_DEPTH + 1 levels fit.
        let deepest = parse_tensor(&nest(MAX_RECURSION_DEPTH + 1)).unwrap();
        assert_eq!(deepest.ndim(), MAX_RECURSION_DEPTH + 1);

        assert!(matches!(
            parse_tensor(&nest(MAX_RECURSION_DEPTH + 2)),
            Err(LexError::InvalidStructure(_))
        ));
    }

    #[test]
    fn test_non_finite_rejected() {
        // A digit run far beyond f64::MAX parses to infinity and must be rejected.
        let huge = format!("[1{}]", "0".repeat(400));
        assert!(matches!(
            parse_tensor(&huge),
            Err(LexError::InvalidNumber(_))
        ));
    }

    // ==================== Tensor struct tests ====================

    #[test]
    fn test_tensor_display() {
        let t = parse_tensor("[1, 2, 3]").unwrap();
        assert_eq!(format!("{}", t), "[1, 2, 3]");

        let t = parse_tensor("[[1, 2], [3, 4]]").unwrap();
        assert_eq!(format!("{}", t), "[[1, 2], [3, 4]]");
    }

    #[test]
    fn test_tensor_methods() {
        let scalar = Tensor::Scalar(42.0);
        assert!(scalar.is_scalar());
        assert!(!scalar.is_array());
        assert_eq!(scalar.ndim(), 0);
        assert_eq!(scalar.len(), 1);
        assert!(!scalar.is_empty());

        let array = parse_tensor("[1, 2, 3]").unwrap();
        assert!(!array.is_scalar());
        assert!(array.is_array());
        assert_eq!(array.ndim(), 1);
        assert_eq!(array.len(), 3);
        assert!(!array.is_empty());
    }

    #[test]
    fn test_empty_array_methods() {
        let empty = Tensor::Array(vec![]);
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert_eq!(empty.shape(), vec![0]);
    }

    #[test]
    fn test_tensor_equality() {
        let t1 = parse_tensor("[1, 2, 3]").unwrap();
        let t2 = parse_tensor("[1, 2, 3]").unwrap();
        assert_eq!(t1, t2);

        let t3 = parse_tensor("[1, 2, 4]").unwrap();
        assert_ne!(t1, t3);
    }

    #[test]
    fn test_tensor_clone() {
        let t1 = parse_tensor("[[1, 2], [3, 4]]").unwrap();
        let t2 = t1.clone();
        assert_eq!(t1, t2);
    }

    // ==================== Round-trip tests ====================

    #[test]
    fn test_display_roundtrip() {
        let original = parse_tensor("[1, 2, 3]").unwrap();
        let serialized = original.to_string();
        let parsed = parse_tensor(&serialized).unwrap();
        assert_eq!(original, parsed);

        let original = parse_tensor("[[1.5, 2.5], [3.5, 4.5]]").unwrap();
        let serialized = original.to_string();
        let parsed = parse_tensor(&serialized).unwrap();
        assert_eq!(original, parsed);
    }

    #[test]
    fn test_display_roundtrip_wide_range() {
        // Includes a whole number outside the i64 range and small fractions.
        let original = parse_tensor("[-1.25, 1000000000000000000000, 0.001]").unwrap();
        let serialized = original.to_string();
        let parsed = parse_tensor(&serialized).unwrap();
        assert_eq!(original, parsed);
    }
}