1use crate::lex::error::LexError;
50
51const MAX_RECURSION_DEPTH: usize = 100;
53
54const MAX_TENSOR_ELEMENTS: usize = 10_000_000;
56
57#[derive(Debug, Clone, PartialEq)]
83#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
84pub enum Tensor {
85 Scalar(f64),
87 Array(Vec<Tensor>),
89}
90
91impl std::fmt::Display for Tensor {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 match self {
97 Tensor::Scalar(n) => {
98 if n.fract() == 0.0 && n.is_finite() {
99 if *n >= i64::MIN as f64 && *n <= i64::MAX as f64 {
100 write!(f, "{}", *n as i64)
101 } else {
102 write!(f, "{}", n)
103 }
104 } else {
105 write!(f, "{}", n)
106 }
107 }
108 Tensor::Array(items) => {
109 write!(f, "[")?;
110 for (i, item) in items.iter().enumerate() {
111 if i > 0 {
112 write!(f, ", ")?;
113 }
114 write!(f, "{}", item)?;
115 }
116 write!(f, "]")
117 }
118 }
119 }
120}
121
122impl Tensor {
123 pub fn is_integer(&self) -> bool {
139 match self {
140 Tensor::Scalar(n) => n.fract() == 0.0,
141 Tensor::Array(items) => items.iter().all(|t| t.is_integer()),
142 }
143 }
144
145 pub fn shape(&self) -> Vec<usize> {
162 match self {
163 Tensor::Scalar(_) => vec![],
164 Tensor::Array(items) => {
165 if items.is_empty() {
166 vec![0]
167 } else {
168 let mut shape = vec![items.len()];
169 shape.extend(items[0].shape());
170 shape
171 }
172 }
173 }
174 }
175
176 pub fn flatten(&self) -> Vec<f64> {
187 let capacity = self.count_elements();
188 let mut result = Vec::with_capacity(capacity);
189 self.flatten_into(&mut result);
190 result
191 }
192
193 fn count_elements(&self) -> usize {
195 match self {
196 Tensor::Scalar(_) => 1,
197 Tensor::Array(items) => items.iter().map(|t| t.count_elements()).sum(),
198 }
199 }
200
201 fn flatten_into(&self, result: &mut Vec<f64>) {
203 match self {
204 Tensor::Scalar(n) => result.push(*n),
205 Tensor::Array(items) => {
206 for item in items {
207 item.flatten_into(result);
208 }
209 }
210 }
211 }
212
213 #[inline]
215 pub fn is_scalar(&self) -> bool {
216 matches!(self, Tensor::Scalar(_))
217 }
218
219 #[inline]
221 pub fn is_array(&self) -> bool {
222 matches!(self, Tensor::Array(_))
223 }
224
225 #[inline]
227 pub fn ndim(&self) -> usize {
228 self.shape().len()
229 }
230
231 #[inline]
233 pub fn len(&self) -> usize {
234 self.count_elements()
235 }
236
237 #[inline]
239 pub fn is_empty(&self) -> bool {
240 match self {
241 Tensor::Scalar(_) => false,
242 Tensor::Array(items) => items.is_empty(),
243 }
244 }
245}
246
/// Cheap syntactic check for whether `s` could be a tensor literal.
///
/// Only verifies bracket balance and that every character belongs to the
/// tensor alphabet (digits, `.`, `-`, `,`, whitespace, brackets). It does
/// NOT validate structure — `parse_tensor` performs the real validation.
#[inline]
pub fn is_tensor_literal(s: &str) -> bool {
    let trimmed = s.trim();
    // Must be wrapped in brackets at all.
    if !(trimmed.starts_with('[') && trimmed.ends_with(']')) {
        return false;
    }

    let mut open: i32 = 0;
    for byte in trimmed.bytes() {
        match byte {
            b'[' => open += 1,
            // A closer with nothing open means the brackets are unbalanced.
            b']' if open == 0 => return false,
            b']' => open -= 1,
            b'0'..=b'9' | b'.' | b'-' | b',' | b' ' | b'\t' => {}
            _ => return false,
        }
    }

    open == 0
}
286
287pub fn parse_tensor(s: &str) -> Result<Tensor, LexError> {
316 let s = s.trim();
317 if !s.starts_with('[') {
318 return Err(LexError::UnexpectedChar(s.chars().next().unwrap_or(' ')));
319 }
320
321 let (tensor, remaining) = parse_tensor_inner(s, 0)?;
322
323 if !remaining.trim().is_empty() {
324 return Err(LexError::UnexpectedChar(
325 remaining.trim().chars().next().unwrap_or('?'),
326 ));
327 }
328
329 let element_count = tensor.count_elements();
331 if element_count > MAX_TENSOR_ELEMENTS {
332 return Err(LexError::InvalidStructure(format!(
333 "Tensor element count exceeds maximum: {} (max: {})",
334 element_count, MAX_TENSOR_ELEMENTS
335 )));
336 }
337
338 Ok(tensor)
339}
340
/// Estimates how many elements the current bracket level holds by counting
/// top-level commas in `s` (the text following an already-consumed `[`).
///
/// Used only to pre-size a `Vec`; the estimate may be high (a trailing
/// comma, or the default of 2 for comma-less input) but never forces
/// correctness on the parser.
fn estimate_array_size(s: &str) -> usize {
    let mut nesting = 0usize;
    let mut separators = 0usize;

    for ch in s.chars() {
        match ch {
            '[' => nesting += 1,
            // A closer at this level ends the current array.
            ']' if nesting == 0 => break,
            ']' => nesting -= 1,
            // Only commas at the current level separate our elements.
            ',' if nesting == 0 => separators += 1,
            _ => {}
        }
    }

    if separators == 0 { 2 } else { separators + 1 }
}
366
/// Recursive-descent parser for one bracketed tensor (or bare number).
///
/// `s` must point at the element to parse (leading whitespace tolerated).
/// Returns the parsed tensor together with the unconsumed remainder of the
/// input. `depth` tracks bracket nesting so deeply nested input fails with
/// `InvalidStructure` instead of overflowing the stack.
fn parse_tensor_inner(s: &str, depth: usize) -> Result<(Tensor, &str), LexError> {
    if depth > MAX_RECURSION_DEPTH {
        return Err(LexError::InvalidStructure(format!(
            "Recursion depth exceeded (max: {})",
            MAX_RECURSION_DEPTH
        )));
    }
    let s = s.trim();

    if let Some(remaining_str) = s.strip_prefix('[') {
        let mut remaining = remaining_str;
        // Pre-size from a top-level comma count to avoid repeated growth.
        let estimated_capacity = estimate_array_size(remaining_str);
        let mut items = Vec::with_capacity(estimated_capacity);

        loop {
            remaining = remaining.trim_start();

            // Input ended before the closing bracket.
            if remaining.is_empty() {
                return Err(LexError::UnbalancedBrackets);
            }

            if remaining.starts_with(']') {
                remaining = &remaining[1..];
                break;
            }

            // Every element after the first must be comma-separated.
            if !items.is_empty() {
                if !remaining.starts_with(',') {
                    return Err(LexError::UnexpectedChar(
                        remaining.chars().next().unwrap_or('?'),
                    ));
                }
                remaining = remaining[1..].trim_start();
            }

            // A ']' right after the comma: tolerate the trailing comma.
            if remaining.starts_with(']') {
                remaining = &remaining[1..];
                break;
            }

            if remaining.starts_with('[') {
                // Nested array element.
                let (tensor, rest) = parse_tensor_inner(remaining, depth + 1)?;
                items.push(tensor);
                remaining = rest;
            } else {
                // Scalar element.
                let (num, rest) = parse_number(remaining)?;
                items.push(Tensor::Scalar(num));
                remaining = rest;
            }
        }

        // `[]` is rejected at every nesting level.
        if items.is_empty() {
            return Err(LexError::EmptyTensor);
        }

        // Siblings must share a shape so the tensor is rectangular.
        if items.len() > 1 {
            let first_shape = items[0].shape();
            for item in &items[1..] {
                if item.shape() != first_shape {
                    return Err(LexError::InconsistentDimensions);
                }
            }
        }

        Ok((Tensor::Array(items), remaining))
    } else {
        // No opening bracket: parse a bare number. Not reachable via
        // `parse_tensor` (which requires '['); kept for completeness.
        let (num, rest) = parse_number(s)?;
        Ok((Tensor::Scalar(num), rest))
    }
}
438
439fn parse_number(s: &str) -> Result<(f64, &str), LexError> {
440 let s = s.trim_start();
441 let bytes = s.as_bytes();
442
443 let mut end = 0;
444 let mut has_dot = false;
445
446 if bytes.first() == Some(&b'-') {
447 end = 1;
448 }
449
450 while end < bytes.len() {
451 match bytes[end] {
452 b'0'..=b'9' => end += 1,
453 b'.' if !has_dot => {
454 has_dot = true;
455 end += 1;
456 }
457 _ => break,
458 }
459 }
460
461 if end == 0 || (end == 1 && bytes[0] == b'-') {
462 return Err(LexError::InvalidNumber(
463 s.chars().take(10).collect::<String>(),
464 ));
465 }
466
467 let num_str = &s[..end];
468 let num: f64 = num_str.parse().map_err(|_| {
469 let context = if num_str.len() > 80 {
470 format!("{}...", &num_str[..80])
471 } else {
472 num_str.to_string()
473 };
474 LexError::InvalidNumber(context)
475 })?;
476
477 if !num.is_finite() {
478 return Err(LexError::InvalidNumber(format!(
479 "{} (non-finite values not allowed)",
480 num_str
481 )));
482 }
483
484 Ok((num, &s[end..]))
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_tensor_literal_valid() {
        let literals = [
            "[1, 2, 3]",
            "[[1, 2], [3, 4]]",
            "[1.5, 2.5]",
            "[-1, -2]",
            " [1, 2, 3] ",
        ];
        for candidate in literals {
            assert!(is_tensor_literal(candidate), "expected literal: {:?}", candidate);
        }
    }

    #[test]
    fn test_is_tensor_literal_invalid() {
        let non_literals = ["hello", "@reference", "123", "", "[1, 2", "[a, b]"];
        for candidate in non_literals {
            assert!(!is_tensor_literal(candidate), "expected rejection: {:?}", candidate);
        }
    }

    #[test]
    fn test_parse_1d() {
        let tensor = parse_tensor("[1, 2, 3]").unwrap();
        assert_eq!(tensor.shape(), vec![3]);
        assert_eq!(tensor.flatten(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_parse_2d() {
        let tensor = parse_tensor("[[1, 2], [3, 4]]").unwrap();
        assert_eq!(tensor.shape(), vec![2, 2]);
        assert_eq!(tensor.flatten(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_parse_floats() {
        let tensor = parse_tensor("[1.5, 2.5, 3.5]").unwrap();
        assert_eq!(tensor.flatten(), vec![1.5, 2.5, 3.5]);
        assert!(!tensor.is_integer());
    }

    #[test]
    fn test_parse_negatives() {
        let tensor = parse_tensor("[-1, -2, -3]").unwrap();
        assert_eq!(tensor.flatten(), vec![-1.0, -2.0, -3.0]);
    }

    #[test]
    fn test_parse_trailing_comma() {
        // A trailing comma before the closing bracket is tolerated.
        let tensor = parse_tensor("[1, 2, 3,]").unwrap();
        assert_eq!(tensor.flatten(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_empty_tensor_error() {
        assert!(matches!(parse_tensor("[]"), Err(LexError::EmptyTensor)));
    }

    #[test]
    fn test_unbalanced_brackets_error() {
        assert!(matches!(
            parse_tensor("[1, 2"),
            Err(LexError::UnbalancedBrackets)
        ));
    }

    #[test]
    fn test_inconsistent_dimensions_error() {
        // Sibling rows must share a shape.
        assert!(matches!(
            parse_tensor("[[1, 2], [3]]"),
            Err(LexError::InconsistentDimensions)
        ));
    }

    #[test]
    fn test_invalid_number_error() {
        assert!(matches!(
            parse_tensor("[abc]"),
            Err(LexError::InvalidNumber(_))
        ));
    }

    #[test]
    fn test_tensor_display() {
        assert_eq!(parse_tensor("[1, 2, 3]").unwrap().to_string(), "[1, 2, 3]");
        assert_eq!(
            parse_tensor("[[1, 2], [3, 4]]").unwrap().to_string(),
            "[[1, 2], [3, 4]]"
        );
    }

    #[test]
    fn test_tensor_methods() {
        let scalar = Tensor::Scalar(42.0);
        assert!(scalar.is_scalar());
        assert!(!scalar.is_array());
        assert!(!scalar.is_empty());
        assert_eq!((scalar.ndim(), scalar.len()), (0, 1));

        let array = parse_tensor("[1, 2, 3]").unwrap();
        assert!(array.is_array());
        assert!(!array.is_scalar());
        assert!(!array.is_empty());
        assert_eq!((array.ndim(), array.len()), (1, 3));
    }

    #[test]
    fn test_tensor_equality() {
        let first = parse_tensor("[1, 2, 3]").unwrap();
        let same = parse_tensor("[1, 2, 3]").unwrap();
        let different = parse_tensor("[1, 2, 4]").unwrap();
        assert_eq!(first, same);
        assert_ne!(first, different);
    }

    #[test]
    fn test_tensor_clone() {
        let original = parse_tensor("[[1, 2], [3, 4]]").unwrap();
        assert_eq!(original, original.clone());
    }

    #[test]
    fn test_display_roundtrip() {
        // Display output must parse back to an equal tensor.
        for literal in ["[1, 2, 3]", "[[1.5, 2.5], [3.5, 4.5]]"] {
            let original = parse_tensor(literal).unwrap();
            let reparsed = parse_tensor(&original.to_string()).unwrap();
            assert_eq!(original, reparsed);
        }
    }
}