1use std::collections::{BTreeSet, HashMap, HashSet};
24use std::fmt::Write;
25
26use crate::lcsre::longest_repeated_substring;
27
/// Errors reported while validating CSD patterns or multiplier specifications.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CsdMultiplierError {
    /// A CSD string contains a character other than `'+'`, `'-'` or `'0'`.
    InvalidCharacter,
    /// A CSD string does not contain exactly `max_power + 1` digits.
    LengthMismatch,
    /// `generate_csd_multipliers` was called with an empty coefficient list.
    EmptyCoefficients,
    /// Specs passed to `generate_csd_multipliers` disagree on `input_width` or `max_power`.
    WidthMismatch,
}

impl std::fmt::Display for CsdMultiplierError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            CsdMultiplierError::InvalidCharacter => {
                "CSD string contains a character other than '+', '-' or '0'"
            }
            CsdMultiplierError::LengthMismatch => {
                "CSD string length does not equal max_power + 1"
            }
            CsdMultiplierError::EmptyCoefficients => "no coefficients were provided",
            CsdMultiplierError::WidthMismatch => {
                "coefficients disagree on input_width or max_power"
            }
        };
        f.write_str(msg)
    }
}

// Lets callers propagate the error with `?` into `Box<dyn std::error::Error>`.
impl std::error::Error for CsdMultiplierError {}
40
/// A single validated CSD coefficient plus the widths needed to render it as a
/// Verilog `csd_multiplier` module (see `CsdMultiplier::generate_verilog`).
#[derive(Debug, Clone, PartialEq)]
pub struct CsdMultiplier {
    /// CSD digits (`'+'`, `'-'` or `'0'`), most significant first.
    csd: String,
    /// Bit width of the signed input `x`.
    n: usize,
    /// Power of two of the leftmost digit; `csd.len() == m + 1`.
    m: usize,
}
60
/// Describes one coefficient of a multi-output block built by `generate_csd_multipliers`.
#[derive(Clone, Debug)]
pub struct MultiplierSpec {
    /// Name of the Verilog output port driven by this coefficient.
    pub name: String,
    /// CSD digits, most significant first; each must be `'+'`, `'-'` or `'0'`.
    pub csd: String,
    /// Bit width of the shared signed input `x`.
    pub input_width: usize,
    /// Power of two of the leftmost digit; `csd.len()` must equal `max_power + 1`.
    pub max_power: usize,
}
76
/// Sign of one nonzero CSD digit.
#[derive(Clone, Copy, Debug, PartialEq)]
enum TermOp {
    /// A `'+'` digit: the shifted input is added.
    Add,
    /// A `'-'` digit: the shifted input is subtracted.
    Sub,
}
86
87fn parse_terms(
89 csd_str: &str,
90 max_power: usize,
91) -> Result<Vec<(usize, TermOp)>, CsdMultiplierError> {
92 let mut terms = Vec::new();
93 for (i, c) in csd_str.chars().enumerate() {
94 let power = max_power - i;
95 match c {
96 '+' => terms.push((power, TermOp::Add)),
97 '-' => terms.push((power, TermOp::Sub)),
98 '0' => {}
99 _ => return Err(CsdMultiplierError::InvalidCharacter),
100 }
101 }
102 Ok(terms)
103}
104
/// Renders the digits `csd_str[start..start + length]` as a signed sum of
/// `x_shift<p>` wire names, e.g. `"x_shift7 - x_shift5"`.
///
/// Rules:
/// * The digit at absolute index `i` has power `max_power - i`.
/// * `'0'` digits are skipped.
/// * A leading `'-'` is written as a unary minus.
/// * An all-zero range yields an empty string.
///
/// The range is clamped to the end of `csd_str`. A `start` past the end yields an
/// empty string instead of panicking.
///
/// # Panics
///
/// Panics if a nonzero digit in the range sits at an index greater than
/// `max_power`, because its power would be negative.
fn build_range_expr(csd_str: &str, start: usize, length: usize, max_power: usize) -> String {
    let end = start.saturating_add(length).min(csd_str.len());
    // `get` returns `None` when `start > end`, where direct slicing would panic;
    // treat that case as an empty range.
    let window = csd_str.get(start..end).unwrap_or("");

    let mut expr = String::new();
    for (offset, digit) in window.char_indices() {
        let sign = match digit {
            '+' => "+",
            '-' => "-",
            _ => continue,
        };
        let power = max_power
            .checked_sub(start + offset)
            .expect("CSD digit index exceeds max_power");
        if expr.is_empty() {
            // First term: a minus becomes a unary prefix; a plus is omitted.
            if sign == "-" {
                expr.push('-');
            }
        } else {
            write!(expr, " {} ", sign).unwrap();
        }
        write!(expr, "x_shift{}", power).unwrap();
    }
    expr
}
134
/// Bit width of the signed product bus for an `input_width`-bit input and a CSD
/// pattern whose leftmost digit has weight `2^max_power`.
///
/// NOTE(review): a CSD value can exceed `2^max_power`. For example, `"+0+"` is 5 with
/// `max_power = 2`. With an 8-bit input the product can then be -128 * 5 = -640, which
/// does not fit in the 10 signed bits returned here (-512..=511). Confirm whether one
/// extra bit is intended.
fn output_width(input_width: usize, max_power: usize) -> usize {
    max_power + input_width
}
139
140impl CsdMultiplier {
145 pub fn new(csd: &str, n: usize, m: usize) -> Result<Self, CsdMultiplierError> {
161 if !csd.chars().all(|c| matches!(c, '+' | '-' | '0')) {
162 return Err(CsdMultiplierError::InvalidCharacter);
163 }
164 if csd.len() != m + 1 {
165 return Err(CsdMultiplierError::LengthMismatch);
166 }
167 Ok(Self {
168 csd: csd.to_string(),
169 n,
170 m,
171 })
172 }
173
174 fn decimal_value(&self) -> i32 {
176 self.csd.chars().fold(0, |acc, c| {
177 let acc = acc << 1;
178 match c {
179 '+' => acc + 1,
180 '-' => acc - 1,
181 '0' => acc,
182 _ => unreachable!(),
183 }
184 })
185 }
186
187 pub fn generate_verilog(&self) -> String {
189 let mut output = String::new();
190 self.generate_header(&mut output);
191 self.generate_wires(&mut output);
192 self.generate_result_lcsre(&mut output);
193 writeln!(output, "endmodule").unwrap();
194 output
195 }
196
197 fn generate_header(&self, output: &mut String) {
198 writeln!(
199 output,
200 "// CSD Multiplier for pattern: {} (value: {})",
201 self.csd,
202 self.decimal_value()
203 )
204 .unwrap();
205 writeln!(
206 output,
207 "module csd_multiplier (
208 input signed [{}:0] x, // Input value (signed)
209 output signed [{}:0] result // Result (signed)
210);",
211 self.n - 1,
212 self.n + self.m - 1
213 )
214 .unwrap();
215 }
216
217 fn get_unique_powers(&self) -> Vec<usize> {
219 let mut powers: Vec<usize> = self
220 .csd
221 .char_indices()
222 .filter(|(_, c)| *c != '0')
223 .map(|(i, _)| self.m - i)
224 .collect();
225 powers.sort_unstable_by(|a, b| b.cmp(a));
226 powers.dedup();
227 powers
228 }
229
230 fn generate_wires(&self, output: &mut String) {
231 let shift_powers = self.get_unique_powers();
232 if shift_powers.is_empty() {
233 return;
234 }
235 writeln!(
236 output,
237 "\n // Signed shifted versions (Verilog handles sign extension)"
238 )
239 .unwrap();
240 for &power in &shift_powers {
241 let padding = self.m - power;
242 writeln!(
243 output,
244 " wire signed [{}:0] x_shift{} = $signed({{ {{{}{{x[{}]}}}}, x}}) << {};",
245 self.n + self.m - 1,
246 power,
247 padding,
248 self.n - 1,
249 power
250 )
251 .unwrap();
252 }
253 }
254
255 fn generate_result_lcsre(&self, output: &mut String) {
257 let terms = parse_terms(&self.csd, self.m).unwrap_or_default();
258 if terms.is_empty() {
259 writeln!(output, "\n // CSD implementation").unwrap();
260 writeln!(output, " assign result = 0;").unwrap();
261 return;
262 }
263
264 let repeated = longest_repeated_substring(&self.csd);
266 let pat_positions = if repeated.len() > 1 {
267 let pat_nnz = repeated.chars().filter(|c| *c == '+' || *c == '-').count();
268 if pat_nnz >= 2 {
269 let pos = find_pattern_occurrences(&self.csd, &repeated);
270 if pos.len() >= 2 {
271 Some((repeated, pos))
272 } else {
273 None
274 }
275 } else {
276 None
277 }
278 } else {
279 None
280 };
281
282 if let Some((ref pat, ref positions)) = pat_positions {
283 let base_pos = positions[0];
285 let ow = output_width(self.n, self.m);
286
287 let pat_expr = build_range_expr(&self.csd, base_pos, pat.len(), self.m);
288 writeln!(output, "\n // LCSRe: repeated pattern \"{}\"", pat).unwrap();
289 writeln!(
290 output,
291 " wire signed [{}:0] _pat = {};",
292 ow - 1,
293 pat_expr
294 )
295 .unwrap();
296
297 let mut expr = String::new();
298 let mut cur = 0;
299 for &pos in positions {
300 if pos > cur {
302 let gap = build_range_expr(&self.csd, cur, pos - cur, self.m);
303 if !gap.is_empty() {
304 if expr.is_empty() {
305 expr = gap;
306 } else {
307 write!(expr, " + {}", gap).unwrap();
308 }
309 }
310 }
311 let shift = pos as isize - base_pos as isize;
313 let pat_ref = if shift == 0 {
314 "_pat".to_string()
315 } else {
316 format!("(_pat >>> {})", shift)
317 };
318 if expr.is_empty() {
319 expr = pat_ref;
320 } else {
321 write!(expr, " + {}", pat_ref).unwrap();
322 }
323 cur = pos + pat.len();
324 }
325 if cur < self.csd.len() {
327 let suffix = build_range_expr(&self.csd, cur, self.csd.len() - cur, self.m);
328 if !suffix.is_empty() {
329 write!(expr, " + {}", suffix).unwrap();
330 }
331 }
332
333 writeln!(output, "\n // CSD implementation (LCSRe optimized)").unwrap();
334 writeln!(output, " assign result = {};", expr).unwrap();
335 } else {
336 writeln!(output, "\n // CSD implementation with signed arithmetic").unwrap();
338 let (first_power, first_op) = terms[0];
339 let mut expr = format!(
340 "{}x_shift{}",
341 if first_op == TermOp::Sub { "-" } else { "" },
342 first_power
343 );
344 for (power, op) in &terms[1..] {
345 match op {
346 TermOp::Add => write!(expr, " + x_shift{}", power).unwrap(),
347 TermOp::Sub => write!(expr, " - x_shift{}", power).unwrap(),
348 }
349 }
350 writeln!(output, " assign result = {};", expr).unwrap();
351 }
352 }
353}
354
/// Returns the byte offsets of the non-overlapping occurrences of `pattern` in
/// `csd_str`, scanning left to right. After each match the search resumes just past it.
///
/// An empty `pattern` yields no positions. The empty string matches at every offset
/// without advancing, so a naive `find` loop would never terminate.
fn find_pattern_occurrences(csd_str: &str, pattern: &str) -> Vec<usize> {
    if pattern.is_empty() {
        return Vec::new();
    }
    csd_str.match_indices(pattern).map(|(pos, _)| pos).collect()
}
370
/// Number of nonzero CSD digits (`'+'` or `'-'`) in `s`. Any other character counts as zero.
fn count_nnz(s: &str) -> usize {
    // Both digits are ASCII, and UTF-8 continuation bytes never equal them,
    // so counting bytes gives the same result as counting chars.
    s.bytes().filter(|&b| b == b'+' || b == b'-').count()
}
375
376fn build_coeff_expr(
378 csd: &str,
379 max_power: usize,
380 pattern: &str,
381 cse_base_pos: usize,
382 cse_name: &str,
383) -> String {
384 if pattern.is_empty() {
385 return build_range_expr(csd, 0, csd.len(), max_power);
386 }
387
388 let positions = find_pattern_occurrences(csd, pattern);
389 let mut parts: Vec<String> = Vec::new();
390 let mut cur = 0;
391
392 for pos in positions {
393 if pos > cur {
395 let gap = build_range_expr(csd, cur, pos - cur, max_power);
396 if !gap.is_empty() {
397 parts.push(gap);
398 }
399 }
400 let shift = pos as isize - cse_base_pos as isize;
402 if shift == 0 {
403 parts.push(cse_name.to_string());
404 } else {
405 parts.push(format!("({} >>> {})", cse_name, shift));
406 }
407 cur = pos + pattern.len();
408 }
409 if cur < csd.len() {
411 let gap = build_range_expr(csd, cur, csd.len() - cur, max_power);
412 if !gap.is_empty() {
413 parts.push(gap);
414 }
415 }
416
417 if parts.is_empty() {
418 return String::new();
419 }
420 let mut result = parts[0].clone();
421 for p in &parts[1..] {
422 write!(result, " + {}", p).unwrap();
423 }
424 result
425}
426
427fn find_cross_patterns(csd_list: &[String]) -> HashMap<String, Vec<(usize, usize)>> {
430 let mut patterns: HashMap<String, Vec<(usize, usize)>> = HashMap::new();
431 for (ci, csd) in csd_list.iter().enumerate() {
432 let n = csd.len();
433 for i in 0..n {
434 for j in (i + 2)..=n {
435 let sub: String = csd[i..j].to_string();
436 if count_nnz(&sub) >= 2 {
437 patterns.entry(sub).or_default().push((ci, i));
438 }
439 }
440 }
441 }
442 patterns.retain(|_, occ: &mut Vec<(usize, usize)>| {
444 let unique: HashSet<usize> = occ.iter().map(|(ci, _)| *ci).collect();
445 unique.len() >= 2
446 });
447 patterns
448}
449
450pub fn generate_csd_multiplier(
479 csd_str: &str,
480 input_width: usize,
481 max_power: usize,
482) -> Result<String, CsdMultiplierError> {
483 let len = csd_str.len();
485 if len != max_power + 1 {
486 return Err(CsdMultiplierError::LengthMismatch);
487 }
488 for c in csd_str.chars() {
489 if c != '+' && c != '-' && c != '0' {
490 return Err(CsdMultiplierError::InvalidCharacter);
491 }
492 }
493
494 let terms = parse_terms(csd_str, max_power)?;
495 let ow = output_width(input_width, max_power);
496
497 let mut verilog = String::new();
498
499 writeln!(verilog).unwrap();
501 writeln!(verilog, "module csd_multiplier (").unwrap();
502 writeln!(
503 verilog,
504 " input signed [{}:0] x, // Input value",
505 input_width - 1
506 )
507 .unwrap();
508 writeln!(
509 verilog,
510 " output signed [{}:0] result // Result of multiplication",
511 ow - 1
512 )
513 .unwrap();
514 writeln!(verilog, ");").unwrap();
515
516 if !terms.is_empty() {
518 writeln!(verilog).unwrap();
519 writeln!(verilog, " // Create shifted versions of input").unwrap();
520 let mut powers_needed: BTreeSet<usize> = BTreeSet::new();
521 for (p, _) in &terms {
523 powers_needed.insert(*p);
524 }
525 for p in powers_needed.into_iter().rev() {
526 writeln!(
527 verilog,
528 " wire signed [{}:0] x_shift{} = x <<< {};",
529 ow - 1,
530 p,
531 p
532 )
533 .unwrap();
534 }
535 }
536
537 let repeated = longest_repeated_substring(csd_str);
539
540 let pat_positions: Vec<usize> = if repeated.len() > 1 {
541 let pat_nnz = count_nnz(&repeated);
542 if pat_nnz >= 2 {
543 let pos = find_pattern_occurrences(csd_str, &repeated);
544 if pos.len() >= 2 {
545 pos
546 } else {
547 Vec::new()
548 }
549 } else {
550 Vec::new()
551 }
552 } else {
553 Vec::new()
554 };
555
556 let use_opt = !pat_positions.is_empty();
557
558 if terms.is_empty() {
560 writeln!(verilog).unwrap();
561 writeln!(verilog, " // CSD implementation").unwrap();
562 writeln!(verilog, " assign result = 0;").unwrap();
563 } else if use_opt {
564 let base_pos = pat_positions[0];
566 let pat_expr = build_range_expr(csd_str, base_pos, repeated.len(), max_power);
567 writeln!(verilog).unwrap();
568 writeln!(verilog, " // LCSRe: repeated pattern \"{}\"", repeated).unwrap();
569 writeln!(
570 verilog,
571 " wire signed [{}:0] _pat = {};",
572 ow - 1,
573 pat_expr
574 )
575 .unwrap();
576
577 let mut expr = String::new();
578 let mut cur = 0;
579 for &pos in &pat_positions {
580 if pos > cur {
582 let gap = build_range_expr(csd_str, cur, pos - cur, max_power);
583 if !gap.is_empty() {
584 if expr.is_empty() {
585 expr = gap;
586 } else {
587 write!(expr, " + {}", gap).unwrap();
588 }
589 }
590 }
591 let shift = pos as isize - base_pos as isize;
593 let pat_ref = if shift == 0 {
594 "_pat".to_string()
595 } else {
596 format!("(_pat >>> {})", shift)
597 };
598 if expr.is_empty() {
599 expr = pat_ref;
600 } else {
601 write!(expr, " + {}", pat_ref).unwrap();
602 }
603 cur = pos + repeated.len();
604 }
605 if cur < csd_str.len() {
607 let suffix = build_range_expr(csd_str, cur, csd_str.len() - cur, max_power);
608 if !suffix.is_empty() {
609 write!(expr, " + {}", suffix).unwrap();
610 }
611 }
612
613 writeln!(verilog).unwrap();
614 writeln!(verilog, " // CSD implementation (LCSRe optimized)").unwrap();
615 writeln!(verilog, " assign result = {};", expr).unwrap();
616 } else {
617 writeln!(verilog).unwrap();
619 writeln!(verilog, " // CSD implementation").unwrap();
620 let mut expr = String::new();
621 for (i, (power, op)) in terms.iter().enumerate() {
622 if i == 0 {
623 if *op == TermOp::Sub {
624 write!(expr, "-").unwrap();
625 }
626 write!(expr, "x_shift{}", power).unwrap();
627 } else {
628 match op {
629 TermOp::Add => write!(expr, " + x_shift{}", power).unwrap(),
630 TermOp::Sub => write!(expr, " - x_shift{}", power).unwrap(),
631 }
632 }
633 }
634 writeln!(verilog, " assign result = {};", expr).unwrap();
635 }
636
637 writeln!(verilog, "endmodule").unwrap();
638 Ok(verilog)
639}
640
641pub fn generate_csd_multipliers(
683 coeffs: &[MultiplierSpec],
684 module_name: &str,
685) -> Result<String, CsdMultiplierError> {
686 if coeffs.is_empty() {
687 return Err(CsdMultiplierError::EmptyCoefficients);
688 }
689
690 let input_width = coeffs[0].input_width;
692 let max_power = coeffs[0].max_power;
693
694 for spec in coeffs {
695 if spec.input_width != input_width || spec.max_power != max_power {
696 return Err(CsdMultiplierError::WidthMismatch);
697 }
698 let len = spec.csd.len();
699 if len != max_power + 1 {
700 return Err(CsdMultiplierError::LengthMismatch);
701 }
702 for c in spec.csd.chars() {
703 if c != '+' && c != '-' && c != '0' {
704 return Err(CsdMultiplierError::InvalidCharacter);
705 }
706 }
707 }
708
709 let ow = output_width(input_width, max_power);
710
711 let mut all_powers: BTreeSet<usize> = BTreeSet::new();
713 for spec in coeffs {
714 for (i, c) in spec.csd.char_indices() {
715 if c != '0' {
716 all_powers.insert(max_power - i);
717 }
718 }
719 }
720
721 let csd_strings: Vec<String> = coeffs.iter().map(|s| s.csd.clone()).collect();
723 let cross = find_cross_patterns(&csd_strings);
724
725 let mut best_pattern = String::new();
726 let mut best_occurrences: Vec<(usize, usize)> = Vec::new();
727 let mut best_score = 0;
728
729 for (pat, occ) in &cross {
730 let nnz = count_nnz(pat);
731 let score = (nnz.saturating_sub(1)) * (occ.len().saturating_sub(1));
732 if score > best_score {
733 best_score = score;
734 best_pattern.clone_from(pat);
735 best_occurrences.clone_from(occ);
736 }
737 }
738
739 let cse_base_pos = if best_pattern.is_empty() {
741 0
742 } else {
743 best_occurrences
744 .iter()
745 .map(|(_, pos)| *pos)
746 .min()
747 .unwrap_or(0)
748 };
749
750 let mut verilog = String::new();
752 writeln!(verilog).unwrap();
753 writeln!(verilog, "module {} (", module_name).unwrap();
754 writeln!(
755 verilog,
756 " input signed [{}:0] x, // Input value",
757 input_width - 1
758 )
759 .unwrap();
760 for spec in coeffs {
761 let ow_spec = output_width(spec.input_width, spec.max_power);
762 writeln!(
763 verilog,
764 " output signed [{}:0] {}",
765 ow_spec - 1,
766 spec.name
767 )
768 .unwrap();
769 }
770 writeln!(verilog, ");").unwrap();
771
772 if !all_powers.is_empty() {
774 writeln!(verilog).unwrap();
775 writeln!(verilog, " // Create shifted versions of input").unwrap();
776 for p in all_powers.iter().rev() {
777 writeln!(
778 verilog,
779 " wire signed [{}:0] x_shift{} = x <<< {};",
780 ow - 1,
781 p,
782 p
783 )
784 .unwrap();
785 }
786 }
787
788 let cse_name = "_cse_0";
790 if !best_pattern.is_empty() {
791 let cse_expr = build_range_expr(
792 &best_pattern,
793 0,
794 best_pattern.len(),
795 max_power.saturating_sub(cse_base_pos),
796 );
797 writeln!(verilog).unwrap();
798 writeln!(
799 verilog,
800 " // Cross-CSE: shared pattern \"{}\"",
801 best_pattern
802 )
803 .unwrap();
804 writeln!(
805 verilog,
806 " wire signed [{}:0] {} = {};",
807 ow - 1,
808 cse_name,
809 cse_expr
810 )
811 .unwrap();
812 }
813
814 let cse_coeffs: HashSet<usize> = best_occurrences.iter().map(|(ci, _)| *ci).collect();
816
817 for (idx, spec) in coeffs.iter().enumerate() {
819 writeln!(verilog).unwrap();
820 writeln!(verilog, " // {}: {}", spec.name, spec.csd).unwrap();
821
822 let has_cse = !best_pattern.is_empty() && cse_coeffs.contains(&idx);
823 let expr = if has_cse {
824 build_coeff_expr(&spec.csd, max_power, &best_pattern, cse_base_pos, cse_name)
825 } else {
826 build_coeff_expr(&spec.csd, max_power, "", 0, "")
827 };
828
829 if expr.is_empty() {
830 writeln!(verilog, " assign {} = 0;", spec.name).unwrap();
831 } else {
832 writeln!(verilog, " assign {} = {};", spec.name, expr).unwrap();
833 }
834 }
835
836 writeln!(verilog, "endmodule").unwrap();
837 Ok(verilog)
838}
839
#[cfg(test)]
// Unit tests for the `CsdMultiplier` type, the single-module generator
// `generate_csd_multiplier`, and the multi-coefficient generator
// `generate_csd_multipliers`.
mod tests {
    use super::*;

    // ---- CsdMultiplier: construction and value ----

    #[test]
    // 256 - 32 + 4 + 1 = 229.
    fn test_valid_csd() {
        let csd = "+00-00+0+";
        let multiplier = CsdMultiplier::new(csd, 8, 8).unwrap();
        assert_eq!(multiplier.decimal_value(), 229);
    }

    #[test]
    fn test_decimal_value() {
        let multiplier = CsdMultiplier::new("+", 8, 0).unwrap();
        assert_eq!(multiplier.decimal_value(), 1);

        let multiplier = CsdMultiplier::new("-", 8, 0).unwrap();
        assert_eq!(multiplier.decimal_value(), -1);

        let multiplier = CsdMultiplier::new("+0-", 8, 2).unwrap();
        assert_eq!(multiplier.decimal_value(), 3);

        let multiplier = CsdMultiplier::new("-0+", 8, 2).unwrap();
        assert_eq!(multiplier.decimal_value(), -3);
    }

    #[test]
    fn test_all_zeros_csd() {
        let csd = "0000";
        let multiplier = CsdMultiplier::new(csd, 8, 3).unwrap();
        let verilog = multiplier.generate_verilog();
        assert!(verilog.contains("assign result = 0;"));
    }

    #[test]
    // Both the character and the length are wrong here; the character check wins.
    fn test_invalid_csd_chars() {
        let csd = "+01-00+0+";
        let result = CsdMultiplier::new(csd, 8, 6);
        assert!(matches!(result, Err(CsdMultiplierError::InvalidCharacter)));
    }

    #[test]
    fn test_length_mismatch() {
        let csd = "+00-00+0+";
        let result = CsdMultiplier::new(csd, 8, 5);
        assert!(matches!(result, Err(CsdMultiplierError::LengthMismatch)));
    }

    #[test]
    // Exact-match check of the full module text from `CsdMultiplier::generate_verilog`.
    // The raw string is whitespace-sensitive: keep its line contents unchanged.
    fn test_verilog_generation() {
        let csd = "+0-";
        let n = 8;
        let m = 2;
        let multiplier = CsdMultiplier::new(csd, n, m).unwrap();
        let expected_verilog = r###"// CSD Multiplier for pattern: +0- (value: 3)
module csd_multiplier (
 input signed [7:0] x, // Input value (signed)
 output signed [9:0] result // Result (signed)
);

 // Signed shifted versions (Verilog handles sign extension)
 wire signed [9:0] x_shift2 = $signed({ {0{x[7]}}, x}) << 2;
 wire signed [9:0] x_shift0 = $signed({ {2{x[7]}}, x}) << 0;

 // CSD implementation with signed arithmetic
 assign result = x_shift2 - x_shift0;
endmodule
"###;
        assert_eq!(multiplier.generate_verilog(), expected_verilog);
    }

    // ---- generate_csd_multiplier ----

    #[test]
    fn test_fn_basic_valid() {
        let v = generate_csd_multiplier("+0-", 8, 2).unwrap();
        assert!(v.contains("module csd_multiplier"));
        assert!(v.contains("endmodule"));
        assert!(v.contains("input signed [7:0] x"));
        assert!(v.contains("output signed [9:0] result"));
        assert!(v.contains("assign result = x_shift2 - x_shift0"));
    }

    #[test]
    fn test_fn_positive_only() {
        let v = generate_csd_multiplier("+0+", 4, 2).unwrap();
        assert!(v.contains("assign result = x_shift2 + x_shift0"));
    }

    #[test]
    // A leading '-' is rendered as a unary minus on the first term.
    fn test_fn_negative_only() {
        let v = generate_csd_multiplier("-0-", 8, 2).unwrap();
        assert!(v.contains("assign result = -x_shift2 - x_shift0"));
    }

    #[test]
    // An all-zero pattern emits no shift wires and a constant-zero result.
    fn test_fn_all_zeros() {
        let v = generate_csd_multiplier("000", 8, 2).unwrap();
        assert!(v.contains("assign result = 0;"));
        assert!(!v.contains("x_shift"));
    }

    #[test]
    fn test_fn_single_nonzero() {
        let v = generate_csd_multiplier("+00", 8, 2).unwrap();
        assert!(v.contains("assign result"));
        assert!(v.contains("x_shift2"));
    }

    #[test]
    fn test_fn_invalid_chars() {
        let r = generate_csd_multiplier("123", 8, 2);
        assert_eq!(r, Err(CsdMultiplierError::InvalidCharacter));
    }

    #[test]
    fn test_fn_invalid_length() {
        let r = generate_csd_multiplier("+0-", 8, 3);
        assert_eq!(r, Err(CsdMultiplierError::LengthMismatch));
    }

    #[test]
    // Expects the flat form. Per the test name, the repeated substring of this
    // pattern is assumed to have fewer than two nonzero digits.
    fn test_fn_flat_when_pattern_nnz_is_1() {
        let v = generate_csd_multiplier("+00-00+0", 8, 7).unwrap();
        assert!(!v.contains("_pat"));
        assert!(v.contains("x_shift7 - x_shift4 + x_shift1"));
    }

    #[test]
    // "+0-0" occurs at offsets 0 and 4, so the second copy is `_pat` shifted by 4.
    fn test_fn_double_repeat_optimization() {
        let v = generate_csd_multiplier("+0-0+0-0", 8, 7).unwrap();
        assert!(v.contains("_pat"));
        assert!(v.contains("_pat = x_shift7 - x_shift5"));
        assert!(v.contains("(_pat >>> 4)"));
        assert!(v.contains("LCSRe"));
    }

    #[test]
    fn test_fn_triple_repeat_optimization() {
        let v = generate_csd_multiplier("+0-0+0-0+0-0", 8, 11).unwrap();
        assert!(v.contains("_pat"));
        assert!(v.contains("(_pat >>> 4)"));
        assert!(v.contains("(_pat >>> 8)"));
    }

    #[test]
    fn test_fn_longer_pattern_repeat() {
        let v = generate_csd_multiplier("+00-00+00-00", 8, 11).unwrap();
        assert!(v.contains("_pat"));
        assert!(v.contains("_pat = x_shift11 - x_shift8"));
        assert!(v.contains("(_pat >>> 6)"));
    }

    #[test]
    fn test_fn_leading_minus_no_optimization() {
        let v = generate_csd_multiplier("-0-", 8, 2).unwrap();
        assert!(!v.contains("_pat"));
        assert!(v.contains("-x_shift2 - x_shift0"));
    }

    #[test]
    fn test_fn_pattern_with_leading_minus() {
        let v = generate_csd_multiplier("-0+0-0+0", 8, 7).unwrap();
        assert!(v.contains("_pat"));
        assert!(v.contains("_pat = -x_shift7 + x_shift5"));
        assert!(v.contains("(_pat >>> 4)"));
    }

    #[test]
    fn test_fn_no_optimization_for_single_occurrence() {
        let v = generate_csd_multiplier("+0-+00-0", 8, 7).unwrap();
        assert!(!v.contains("_pat"));
    }

    #[test]
    // `_pat` is declared with the output width: 8 + 7 = 15 bits, i.e. [14:0].
    fn test_fn_pat_wire_width_matches_output() {
        let v = generate_csd_multiplier("+0-0+0-0", 8, 7).unwrap();
        assert!(v.contains("[14:0] _pat"));
    }

    #[test]
    fn test_fn_repeat_with_trailing_gap() {
        let v = generate_csd_multiplier("+0-0+0-0+0", 8, 9).unwrap();
        assert!(v.contains("_pat"));
        assert!(v.contains("(_pat >>> 4)"));
    }

    #[test]
    // A one-digit pattern (max_power = 0) is the input itself.
    fn test_fn_very_short_csd() {
        let v = generate_csd_multiplier("+", 8, 0).unwrap();
        assert!(v.contains("assign result = x_shift0"));
    }

    #[test]
    fn test_fn_all_minus_signs() {
        let v = generate_csd_multiplier("---", 8, 2).unwrap();
        assert!(!v.contains("_pat"));
    }

    #[test]
    fn test_fn_always_has_proper_module_boundaries() {
        let v = generate_csd_multiplier("+0-0+0-0", 8, 7).unwrap();
        assert!(v.contains("\nmodule csd_multiplier"));
        assert!(v.contains("endmodule\n"));
    }

    #[test]
    fn test_fn_lcsre_comment_present_when_optimized() {
        let v = generate_csd_multiplier("+0-0+0-0", 8, 7).unwrap();
        assert!(v.contains("LCSRe"));
    }

    #[test]
    fn test_fn_no_lcsre_comment_when_flat() {
        let v = generate_csd_multiplier("+00-00+0", 8, 7).unwrap();
        assert!(!v.contains("LCSRe"));
    }

    // ---- generate_csd_multipliers ----

    #[test]
    fn test_multi_empty_coeffs() {
        let r = generate_csd_multipliers(&[], "test");
        assert_eq!(r, Err(CsdMultiplierError::EmptyCoefficients));
    }

    #[test]
    fn test_multi_single_coeff() {
        let coeffs = vec![MultiplierSpec {
            name: "y0".to_string(),
            csd: "+0-".to_string(),
            input_width: 8,
            max_power: 2,
        }];
        let v = generate_csd_multipliers(&coeffs, "test_mod").unwrap();
        assert!(v.contains("module test_mod"));
        assert!(v.contains("output signed [9:0] y0"));
    }

    #[test]
    // Identical coefficients share every substring, so a cross-CSE wire must appear.
    fn test_multi_duplicate_coeffs() {
        let coeffs = vec![
            MultiplierSpec {
                name: "y0".to_string(),
                csd: "+00-00+0+".to_string(),
                input_width: 8,
                max_power: 8,
            },
            MultiplierSpec {
                name: "y1".to_string(),
                csd: "+00-00+0+".to_string(),
                input_width: 8,
                max_power: 8,
            },
        ];
        let v = generate_csd_multipliers(&coeffs, "csd_filter").unwrap();
        assert!(v.contains("Cross-CSE"));
        assert!(v.contains("_cse_0"));
    }

    #[test]
    fn test_multi_width_mismatch() {
        let coeffs = vec![
            MultiplierSpec {
                name: "y0".to_string(),
                csd: "+0-".to_string(),
                input_width: 8,
                max_power: 2,
            },
            MultiplierSpec {
                name: "y1".to_string(),
                csd: "+0-".to_string(),
                input_width: 16,
                max_power: 2,
            },
        ];
        let r = generate_csd_multipliers(&coeffs, "test");
        assert_eq!(r, Err(CsdMultiplierError::WidthMismatch));
    }

    #[test]
    fn test_multi_invalid_chars() {
        let coeffs = vec![MultiplierSpec {
            name: "y0".to_string(),
            csd: "123".to_string(),
            input_width: 8,
            max_power: 2,
        }];
        let r = generate_csd_multipliers(&coeffs, "test");
        assert_eq!(r, Err(CsdMultiplierError::InvalidCharacter));
    }
}