Skip to main content

csd/
csd_multiplier.rs

1//! CSD Multiplier Module
2//!
3//! This module provides functionality to generate Verilog code for efficient constant multiplication
4//! using Canonical Signed Digit (CSD) representation. CSD representation minimizes the number of
5//! non-zero digits, which reduces the number of adders/subtractors needed in hardware implementation.
6//!
7//! # Overview
8//!
9//! In digital signal processing and hardware design, multiplying a variable by a constant is a common
10//! operation. Using CSD representation, we can implement these multiplications efficiently using only
11//! shifts, additions, and subtractions instead of full multipliers.
12//!
13//! # Single Multiplier (with LCSRe optimization)
14//!
15//! When the CSD string contains a repeated non-overlapping pattern with ≥2 non-zero digits,
16//! the generated Verilog shares hardware via a sub-expression wire `_pat`, reducing adder count.
17//!
18//! # Multi-Coefficient Cross-CSE
19//!
20//! `generate_csd_multipliers()` finds repeated substrings across **different** coefficients and
21//! creates a shared common sub-expression (CSE) wire, reducing total hardware across the filter.
22
23use std::collections::{BTreeSet, HashMap, HashSet};
24use std::fmt::Write;
25
26use crate::lcsre::longest_repeated_substring;
27
/// Error type for CSD multiplier operations.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so values can be
/// propagated with `?` into `Box<dyn std::error::Error>` or wrapped by other
/// error types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CsdMultiplierError {
    /// Invalid character found in CSD string (only '+', '-', '0' allowed)
    InvalidCharacter,
    /// Length of CSD string doesn't match expected length (max_power + 1)
    LengthMismatch,
    /// At least one coefficient is required
    EmptyCoefficients,
    /// All coefficients must share the same input_width and max_power
    WidthMismatch,
}

impl std::fmt::Display for CsdMultiplierError {
    /// Write a short, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::InvalidCharacter => {
                "invalid character in CSD string (only '+', '-', '0' are allowed)"
            }
            Self::LengthMismatch => "CSD string length does not match max_power + 1",
            Self::EmptyCoefficients => "at least one coefficient is required",
            Self::WidthMismatch => {
                "all coefficients must share the same input_width and max_power"
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for CsdMultiplierError {}
40
/// A CSD-based constant multiplier that generates Verilog code
///
/// The value is created through `CsdMultiplier::new`, which validates the
/// digit string against the requested highest power.
///
/// # Example
///
/// ```rust
/// use csd::csd_multiplier::{CsdMultiplier, CsdMultiplierError};
///
/// // Create a multiplier for the CSD pattern "+00-00+" (value: 57)
/// let multiplier = CsdMultiplier::new("+00-00+", 8, 6).unwrap();
///
/// // Generate Verilog code
/// let verilog = multiplier.generate_verilog();
/// assert!(verilog.contains("module csd_multiplier"));
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsdMultiplier {
    /// CSD digit string, most significant digit first ('+', '-', '0').
    csd: String,
    /// Bit width of the input `x`.
    n: usize,
    /// Highest power index (length of `csd` minus 1).
    m: usize,
}
60
/// Specification for a single CSD multiplier coefficient
///
/// Used with [`generate_csd_multipliers()`] for multi-coefficient
/// cross-common-subexpression elimination. Specs can be compared for
/// equality, which is convenient in tests and when de-duplicating inputs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultiplierSpec {
    /// Output port name (e.g. "y0", "y1")
    pub name: String,
    /// CSD string ('+', '-', '0'), most significant digit first
    pub csd: String,
    /// Bit width of input x
    pub input_width: usize,
    /// Highest power (len(csd) - 1)
    pub max_power: usize,
}
76
77// ---------------------------------------------------------------------------
78// Internal helpers
79// ---------------------------------------------------------------------------
80
/// Sign of a single non-zero CSD digit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TermOp {
    /// Digit '+': the shifted input is added.
    Add,
    /// Digit '-': the shifted input is subtracted.
    Sub,
}
86
87/// Parse a CSD string into (power, operation) pairs.
88fn parse_terms(
89    csd_str: &str,
90    max_power: usize,
91) -> Result<Vec<(usize, TermOp)>, CsdMultiplierError> {
92    let mut terms = Vec::new();
93    for (i, c) in csd_str.chars().enumerate() {
94        let power = max_power - i;
95        match c {
96            '+' => terms.push((power, TermOp::Add)),
97            '-' => terms.push((power, TermOp::Sub)),
98            '0' => {}
99            _ => return Err(CsdMultiplierError::InvalidCharacter),
100        }
101    }
102    Ok(terms)
103}
104
/// Build a flat Verilog expression for a range [start, start+length) of the CSD string.
///
/// Each '+' / '-' digit at byte position `p` contributes the term
/// `x_shift{max_power - p}`. The first term is written as `x_shiftN` or
/// `-x_shiftN`; later terms are joined with ` + ` / ` - `. Any other byte
/// (including '0') contributes nothing, so an all-zero range yields an empty
/// string.
///
/// The range is clamped to the string and to positions whose power is
/// non-negative: a `start` past the end of the string, or digits located
/// beyond position `max_power`, produce no terms instead of panicking.
fn build_range_expr(csd_str: &str, start: usize, length: usize, max_power: usize) -> String {
    // Work on bytes: the digits of interest are ASCII, and byte slicing cannot
    // fail on a character boundary.
    let digits = csd_str.as_bytes();
    let end = start
        .saturating_add(length)
        .min(digits.len())
        .min(max_power.saturating_add(1));

    let mut expr = String::new();
    if start >= end {
        return expr;
    }

    for (offset, &digit) in digits[start..end].iter().enumerate() {
        let (leading, joiner) = match digit {
            b'+' => ("", " + "),
            b'-' => ("-", " - "),
            _ => continue,
        };
        // Cannot underflow: start + offset < end <= max_power + 1.
        let power = max_power - (start + offset);
        // An empty buffer means this is the first emitted term.
        let prefix = if expr.is_empty() { leading } else { joiner };
        write!(expr, "{}x_shift{}", prefix, power).unwrap();
    }
    expr
}
134
/// Width, in bits, used for the result port and every intermediate wire.
///
/// It is the sum of the input width and the highest power of two
/// (`max_power`) appearing in the coefficient.
fn output_width(input_width: usize, max_power: usize) -> usize {
    max_power + input_width
}
139
140// ---------------------------------------------------------------------------
141// CsdMultiplier (struct-based, backward compatible)
142// ---------------------------------------------------------------------------
143
144impl CsdMultiplier {
145    /// Create a new CSD multiplier.
146    ///
147    /// # Arguments
148    ///
149    /// * `csd` - The CSD pattern string (e.g., "+0-")
150    /// * `n` - Input bit width
151    /// * `m` - Highest power index (length of CSD minus 1)
152    ///
153    /// # Errors
154    ///
155    /// Returns `CsdMultiplierError::InvalidCharacter` if the CSD string contains
156    /// characters other than '+', '-', or '0'.
157    ///
158    /// Returns `CsdMultiplierError::LengthMismatch` if the CSD string length
159    /// doesn't equal `m + 1`.
160    pub fn new(csd: &str, n: usize, m: usize) -> Result<Self, CsdMultiplierError> {
161        if !csd.chars().all(|c| matches!(c, '+' | '-' | '0')) {
162            return Err(CsdMultiplierError::InvalidCharacter);
163        }
164        if csd.len() != m + 1 {
165            return Err(CsdMultiplierError::LengthMismatch);
166        }
167        Ok(Self {
168            csd: csd.to_string(),
169            n,
170            m,
171        })
172    }
173
174    /// Calculate the decimal value represented by the CSD string.
175    fn decimal_value(&self) -> i32 {
176        self.csd.chars().fold(0, |acc, c| {
177            let acc = acc << 1;
178            match c {
179                '+' => acc + 1,
180                '-' => acc - 1,
181                '0' => acc,
182                _ => unreachable!(),
183            }
184        })
185    }
186
187    /// Generate the Verilog module code (with LCSRe optimization).
188    pub fn generate_verilog(&self) -> String {
189        let mut output = String::new();
190        self.generate_header(&mut output);
191        self.generate_wires(&mut output);
192        self.generate_result_lcsre(&mut output);
193        writeln!(output, "endmodule").unwrap();
194        output
195    }
196
197    fn generate_header(&self, output: &mut String) {
198        writeln!(
199            output,
200            "// CSD Multiplier for pattern: {} (value: {})",
201            self.csd,
202            self.decimal_value()
203        )
204        .unwrap();
205        writeln!(
206            output,
207            "module csd_multiplier (
208    input signed [{}:0] x,      // Input value (signed)
209    output signed [{}:0] result // Result (signed)
210);",
211            self.n - 1,
212            self.n + self.m - 1
213        )
214        .unwrap();
215    }
216
217    /// Return sorted unique powers of non-zero digits, descending.
218    fn get_unique_powers(&self) -> Vec<usize> {
219        let mut powers: Vec<usize> = self
220            .csd
221            .char_indices()
222            .filter(|(_, c)| *c != '0')
223            .map(|(i, _)| self.m - i)
224            .collect();
225        powers.sort_unstable_by(|a, b| b.cmp(a));
226        powers.dedup();
227        powers
228    }
229
230    fn generate_wires(&self, output: &mut String) {
231        let shift_powers = self.get_unique_powers();
232        if shift_powers.is_empty() {
233            return;
234        }
235        writeln!(
236            output,
237            "\n    // Signed shifted versions (Verilog handles sign extension)"
238        )
239        .unwrap();
240        for &power in &shift_powers {
241            let padding = self.m - power;
242            writeln!(
243                output,
244                "    wire signed [{}:0] x_shift{} = $signed({{ {{{}{{x[{}]}}}}, x}}) << {};",
245                self.n + self.m - 1,
246                power,
247                padding,
248                self.n - 1,
249                power
250            )
251            .unwrap();
252        }
253    }
254
255    /// Generate assign statement with LCSRe optimization.
256    fn generate_result_lcsre(&self, output: &mut String) {
257        let terms = parse_terms(&self.csd, self.m).unwrap_or_default();
258        if terms.is_empty() {
259            writeln!(output, "\n    // CSD implementation").unwrap();
260            writeln!(output, "    assign result = 0;").unwrap();
261            return;
262        }
263
264        // Detect LCSRe optimization opportunity
265        let repeated = longest_repeated_substring(&self.csd);
266        let pat_positions = if repeated.len() > 1 {
267            let pat_nnz = repeated.chars().filter(|c| *c == '+' || *c == '-').count();
268            if pat_nnz >= 2 {
269                let pos = find_pattern_occurrences(&self.csd, &repeated);
270                if pos.len() >= 2 {
271                    Some((repeated, pos))
272                } else {
273                    None
274                }
275            } else {
276                None
277            }
278        } else {
279            None
280        };
281
282        if let Some((ref pat, ref positions)) = pat_positions {
283            // LCSRe-optimized path
284            let base_pos = positions[0];
285            let ow = output_width(self.n, self.m);
286
287            let pat_expr = build_range_expr(&self.csd, base_pos, pat.len(), self.m);
288            writeln!(output, "\n    // LCSRe: repeated pattern \"{}\"", pat).unwrap();
289            writeln!(
290                output,
291                "    wire signed [{}:0] _pat = {};",
292                ow - 1,
293                pat_expr
294            )
295            .unwrap();
296
297            let mut expr = String::new();
298            let mut cur = 0;
299            for &pos in positions {
300                // gap before this occurrence
301                if pos > cur {
302                    let gap = build_range_expr(&self.csd, cur, pos - cur, self.m);
303                    if !gap.is_empty() {
304                        if expr.is_empty() {
305                            expr = gap;
306                        } else {
307                            write!(expr, " + {}", gap).unwrap();
308                        }
309                    }
310                }
311                // pattern occurrence
312                let shift = pos as isize - base_pos as isize;
313                let pat_ref = if shift == 0 {
314                    "_pat".to_string()
315                } else {
316                    format!("(_pat >>> {})", shift)
317                };
318                if expr.is_empty() {
319                    expr = pat_ref;
320                } else {
321                    write!(expr, " + {}", pat_ref).unwrap();
322                }
323                cur = pos + pat.len();
324            }
325            // suffix
326            if cur < self.csd.len() {
327                let suffix = build_range_expr(&self.csd, cur, self.csd.len() - cur, self.m);
328                if !suffix.is_empty() {
329                    write!(expr, " + {}", suffix).unwrap();
330                }
331            }
332
333            writeln!(output, "\n    // CSD implementation (LCSRe optimized)").unwrap();
334            writeln!(output, "    assign result = {};", expr).unwrap();
335        } else {
336            // flat path (no repeated pattern)
337            writeln!(output, "\n    // CSD implementation with signed arithmetic").unwrap();
338            let (first_power, first_op) = terms[0];
339            let mut expr = format!(
340                "{}x_shift{}",
341                if first_op == TermOp::Sub { "-" } else { "" },
342                first_power
343            );
344            for (power, op) in &terms[1..] {
345                match op {
346                    TermOp::Add => write!(expr, " + x_shift{}", power).unwrap(),
347                    TermOp::Sub => write!(expr, " - x_shift{}", power).unwrap(),
348                }
349            }
350            writeln!(output, "    assign result = {};", expr).unwrap();
351        }
352    }
353}
354
355// ---------------------------------------------------------------------------
356// Free-function API (matching C++ style)
357// ---------------------------------------------------------------------------
358
/// Find all non-overlapping occurrences of `pattern` in `csd_str`.
///
/// Matches are taken greedily from left to right; the search resumes right
/// after the end of each match. The returned byte offsets are in ascending
/// order. An empty `pattern` has no meaningful occurrences and yields an empty
/// list (a manual `find` loop would never advance on it).
fn find_pattern_occurrences(csd_str: &str, pattern: &str) -> Vec<usize> {
    if pattern.is_empty() {
        return Vec::new();
    }
    // `match_indices` yields disjoint matches, left to right.
    csd_str
        .match_indices(pattern)
        .map(|(offset, _)| offset)
        .collect()
}
370
/// Number of non-zero CSD digits (`'+'` or `'-'`) contained in `s`.
fn count_nnz(s: &str) -> usize {
    // Both digits are ASCII, so scanning bytes gives the same count as scanning chars.
    s.bytes()
        .filter(|digit| matches!(digit, b'+' | b'-'))
        .count()
}
375
376/// Build a coefficient expression using CSE wire + flat gap terms.
377fn build_coeff_expr(
378    csd: &str,
379    max_power: usize,
380    pattern: &str,
381    cse_base_pos: usize,
382    cse_name: &str,
383) -> String {
384    if pattern.is_empty() {
385        return build_range_expr(csd, 0, csd.len(), max_power);
386    }
387
388    let positions = find_pattern_occurrences(csd, pattern);
389    let mut parts: Vec<String> = Vec::new();
390    let mut cur = 0;
391
392    for pos in positions {
393        // gap before this occurrence
394        if pos > cur {
395            let gap = build_range_expr(csd, cur, pos - cur, max_power);
396            if !gap.is_empty() {
397                parts.push(gap);
398            }
399        }
400        // CSE reference
401        let shift = pos as isize - cse_base_pos as isize;
402        if shift == 0 {
403            parts.push(cse_name.to_string());
404        } else {
405            parts.push(format!("({} >>> {})", cse_name, shift));
406        }
407        cur = pos + pattern.len();
408    }
409    // suffix
410    if cur < csd.len() {
411        let gap = build_range_expr(csd, cur, csd.len() - cur, max_power);
412        if !gap.is_empty() {
413            parts.push(gap);
414        }
415    }
416
417    if parts.is_empty() {
418        return String::new();
419    }
420    let mut result = parts[0].clone();
421    for p in &parts[1..] {
422        write!(result, " + {}", p).unwrap();
423    }
424    result
425}
426
427/// Find substrings (NNZ >= 2) that appear in >= 2 different CSD strings.
428/// Returns a map: pattern -> [(coeff_index, position), ...].
429fn find_cross_patterns(csd_list: &[String]) -> HashMap<String, Vec<(usize, usize)>> {
430    let mut patterns: HashMap<String, Vec<(usize, usize)>> = HashMap::new();
431    for (ci, csd) in csd_list.iter().enumerate() {
432        let n = csd.len();
433        for i in 0..n {
434            for j in (i + 2)..=n {
435                let sub: String = csd[i..j].to_string();
436                if count_nnz(&sub) >= 2 {
437                    patterns.entry(sub).or_default().push((ci, i));
438                }
439            }
440        }
441    }
442    // Keep only patterns crossing >= 2 different CSD strings
443    patterns.retain(|_, occ: &mut Vec<(usize, usize)>| {
444        let unique: HashSet<usize> = occ.iter().map(|(ci, _)| *ci).collect();
445        unique.len() >= 2
446    });
447    patterns
448}
449
450/// Generate Verilog code for a single CSD multiplier module (no cross-CSE).
451///
452/// Converts a Canonical Signed Digit (CSD) string into a synthesizable
453/// Verilog module that performs constant multiplication using shifts and
454/// additions/subtractions. When the CSD string contains a repeated
455/// non-overlapping pattern, LCSRe optimization shares hardware via a
456/// `_pat` wire.
457///
458/// # Arguments
459///
460/// * `csd_str` - CSD string using '+', '-', '0' (e.g. "+00-00+0+")
461/// * `input_width` - Bit width of the input signal x
462/// * `max_power` - Highest power of two in the CSD (must be csd_str.len() - 1)
463///
464/// # Errors
465///
466/// Returns `CsdMultiplierError` if csd_str length doesn't match max_power+1
467/// or if the string contains characters other than '+', '-', '0'.
468///
469/// # Examples
470///
471/// ```
472/// use csd::csd_multiplier::generate_csd_multiplier;
473///
474/// let v = generate_csd_multiplier("+0-", 8, 2).unwrap();
475/// assert!(v.contains("module csd_multiplier"));
476/// assert!(v.contains("assign result = x_shift2 - x_shift0"));
477/// ```
478pub fn generate_csd_multiplier(
479    csd_str: &str,
480    input_width: usize,
481    max_power: usize,
482) -> Result<String, CsdMultiplierError> {
483    // --- validation ---
484    let len = csd_str.len();
485    if len != max_power + 1 {
486        return Err(CsdMultiplierError::LengthMismatch);
487    }
488    for c in csd_str.chars() {
489        if c != '+' && c != '-' && c != '0' {
490            return Err(CsdMultiplierError::InvalidCharacter);
491        }
492    }
493
494    let terms = parse_terms(csd_str, max_power)?;
495    let ow = output_width(input_width, max_power);
496
497    let mut verilog = String::new();
498
499    // --- module header ---
500    writeln!(verilog).unwrap();
501    writeln!(verilog, "module csd_multiplier (").unwrap();
502    writeln!(
503        verilog,
504        "    input signed [{}:0] x,      // Input value",
505        input_width - 1
506    )
507    .unwrap();
508    writeln!(
509        verilog,
510        "    output signed [{}:0] result // Result of multiplication",
511        ow - 1
512    )
513    .unwrap();
514    writeln!(verilog, ");").unwrap();
515
516    // --- wire declarations (deduplicated powers) ---
517    if !terms.is_empty() {
518        writeln!(verilog).unwrap();
519        writeln!(verilog, "    // Create shifted versions of input").unwrap();
520        let mut powers_needed: BTreeSet<usize> = BTreeSet::new();
521        // Reverse order: highest power first
522        for (p, _) in &terms {
523            powers_needed.insert(*p);
524        }
525        for p in powers_needed.into_iter().rev() {
526            writeln!(
527                verilog,
528                "    wire signed [{}:0] x_shift{} = x <<< {};",
529                ow - 1,
530                p,
531                p
532            )
533            .unwrap();
534        }
535    }
536
537    // --- detect LCSRe optimization opportunity ---
538    let repeated = longest_repeated_substring(csd_str);
539
540    let pat_positions: Vec<usize> = if repeated.len() > 1 {
541        let pat_nnz = count_nnz(&repeated);
542        if pat_nnz >= 2 {
543            let pos = find_pattern_occurrences(csd_str, &repeated);
544            if pos.len() >= 2 {
545                pos
546            } else {
547                Vec::new()
548            }
549        } else {
550            Vec::new()
551        }
552    } else {
553        Vec::new()
554    };
555
556    let use_opt = !pat_positions.is_empty();
557
558    // --- combinational logic ---
559    if terms.is_empty() {
560        writeln!(verilog).unwrap();
561        writeln!(verilog, "    // CSD implementation").unwrap();
562        writeln!(verilog, "    assign result = 0;").unwrap();
563    } else if use_opt {
564        // LCSRe-optimized path
565        let base_pos = pat_positions[0];
566        let pat_expr = build_range_expr(csd_str, base_pos, repeated.len(), max_power);
567        writeln!(verilog).unwrap();
568        writeln!(verilog, "    // LCSRe: repeated pattern \"{}\"", repeated).unwrap();
569        writeln!(
570            verilog,
571            "    wire signed [{}:0] _pat = {};",
572            ow - 1,
573            pat_expr
574        )
575        .unwrap();
576
577        let mut expr = String::new();
578        let mut cur = 0;
579        for &pos in &pat_positions {
580            // prefix/gap before this occurrence
581            if pos > cur {
582                let gap = build_range_expr(csd_str, cur, pos - cur, max_power);
583                if !gap.is_empty() {
584                    if expr.is_empty() {
585                        expr = gap;
586                    } else {
587                        write!(expr, " + {}", gap).unwrap();
588                    }
589                }
590            }
591            // pattern occurrence
592            let shift = pos as isize - base_pos as isize;
593            let pat_ref = if shift == 0 {
594                "_pat".to_string()
595            } else {
596                format!("(_pat >>> {})", shift)
597            };
598            if expr.is_empty() {
599                expr = pat_ref;
600            } else {
601                write!(expr, " + {}", pat_ref).unwrap();
602            }
603            cur = pos + repeated.len();
604        }
605        // suffix
606        if cur < csd_str.len() {
607            let suffix = build_range_expr(csd_str, cur, csd_str.len() - cur, max_power);
608            if !suffix.is_empty() {
609                write!(expr, " + {}", suffix).unwrap();
610            }
611        }
612
613        writeln!(verilog).unwrap();
614        writeln!(verilog, "    // CSD implementation (LCSRe optimized)").unwrap();
615        writeln!(verilog, "    assign result = {};", expr).unwrap();
616    } else {
617        // flat path (no repeated pattern)
618        writeln!(verilog).unwrap();
619        writeln!(verilog, "    // CSD implementation").unwrap();
620        let mut expr = String::new();
621        for (i, (power, op)) in terms.iter().enumerate() {
622            if i == 0 {
623                if *op == TermOp::Sub {
624                    write!(expr, "-").unwrap();
625                }
626                write!(expr, "x_shift{}", power).unwrap();
627            } else {
628                match op {
629                    TermOp::Add => write!(expr, " + x_shift{}", power).unwrap(),
630                    TermOp::Sub => write!(expr, " - x_shift{}", power).unwrap(),
631                }
632            }
633        }
634        writeln!(verilog, "    assign result = {};", expr).unwrap();
635    }
636
637    writeln!(verilog, "endmodule").unwrap();
638    Ok(verilog)
639}
640
641/// Generate Verilog for multiple CSD multipliers with cross-CSE.
642///
643/// When the same CSD substring appears in multiple coefficients, a shared
644/// sub-expression wire is created — reducing total adder count across the
645/// entire filter.
646///
647/// All coefficients **must** share the same `input_width` and `max_power`
648/// so that the same bit position encodes the same power of two.
649///
650/// # Arguments
651///
652/// * `coeffs` - List of coefficient specifications
653/// * `module_name` - Name for the generated Verilog module
654///
655/// # Errors
656///
657/// Returns `CsdMultiplierError::EmptyCoefficients` if the list is empty.
658/// Returns `CsdMultiplierError::WidthMismatch` if coefficient widths differ.
659///
660/// # Examples
661///
662/// ```
663/// use csd::csd_multiplier::{generate_csd_multipliers, MultiplierSpec};
664///
665/// let coeffs = vec![
666///     MultiplierSpec {
667///         name: "y0".to_string(),
668///         csd: "+00-00+0+".to_string(),
669///         input_width: 8,
670///         max_power: 8,
671///     },
672///     MultiplierSpec {
673///         name: "y1".to_string(),
674///         csd: "+00-00+0+".to_string(),
675///         input_width: 8,
676///         max_power: 8,
677///     },
678/// ];
679/// let v = generate_csd_multipliers(&coeffs, "csd_filter").unwrap();
680/// assert!(v.contains("module csd_filter"));
681/// ```
682pub fn generate_csd_multipliers(
683    coeffs: &[MultiplierSpec],
684    module_name: &str,
685) -> Result<String, CsdMultiplierError> {
686    if coeffs.is_empty() {
687        return Err(CsdMultiplierError::EmptyCoefficients);
688    }
689
690    // Validation and uniform-width enforcement
691    let input_width = coeffs[0].input_width;
692    let max_power = coeffs[0].max_power;
693
694    for spec in coeffs {
695        if spec.input_width != input_width || spec.max_power != max_power {
696            return Err(CsdMultiplierError::WidthMismatch);
697        }
698        let len = spec.csd.len();
699        if len != max_power + 1 {
700            return Err(CsdMultiplierError::LengthMismatch);
701        }
702        for c in spec.csd.chars() {
703            if c != '+' && c != '-' && c != '0' {
704                return Err(CsdMultiplierError::InvalidCharacter);
705            }
706        }
707    }
708
709    let ow = output_width(input_width, max_power);
710
711    // Collect all x_shift powers
712    let mut all_powers: BTreeSet<usize> = BTreeSet::new();
713    for spec in coeffs {
714        for (i, c) in spec.csd.char_indices() {
715            if c != '0' {
716                all_powers.insert(max_power - i);
717            }
718        }
719    }
720
721    // Find best cross-CSD pattern
722    let csd_strings: Vec<String> = coeffs.iter().map(|s| s.csd.clone()).collect();
723    let cross = find_cross_patterns(&csd_strings);
724
725    let mut best_pattern = String::new();
726    let mut best_occurrences: Vec<(usize, usize)> = Vec::new();
727    let mut best_score = 0;
728
729    for (pat, occ) in &cross {
730        let nnz = count_nnz(pat);
731        let score = (nnz.saturating_sub(1)) * (occ.len().saturating_sub(1));
732        if score > best_score {
733            best_score = score;
734            best_pattern.clone_from(pat);
735            best_occurrences.clone_from(occ);
736        }
737    }
738
739    // Base position for the CSE wire
740    let cse_base_pos = if best_pattern.is_empty() {
741        0
742    } else {
743        best_occurrences
744            .iter()
745            .map(|(_, pos)| *pos)
746            .min()
747            .unwrap_or(0)
748    };
749
750    // Build the Verilog module
751    let mut verilog = String::new();
752    writeln!(verilog).unwrap();
753    writeln!(verilog, "module {} (", module_name).unwrap();
754    writeln!(
755        verilog,
756        "    input signed [{}:0] x,      // Input value",
757        input_width - 1
758    )
759    .unwrap();
760    for spec in coeffs {
761        let ow_spec = output_width(spec.input_width, spec.max_power);
762        writeln!(
763            verilog,
764            "    output signed [{}:0] {}",
765            ow_spec - 1,
766            spec.name
767        )
768        .unwrap();
769    }
770    writeln!(verilog, ");").unwrap();
771
772    // x_shift wires
773    if !all_powers.is_empty() {
774        writeln!(verilog).unwrap();
775        writeln!(verilog, "    // Create shifted versions of input").unwrap();
776        for p in all_powers.iter().rev() {
777            writeln!(
778                verilog,
779                "    wire signed [{}:0] x_shift{} = x <<< {};",
780                ow - 1,
781                p,
782                p
783            )
784            .unwrap();
785        }
786    }
787
788    // Shared CSE wire
789    let cse_name = "_cse_0";
790    if !best_pattern.is_empty() {
791        let cse_expr = build_range_expr(
792            &best_pattern,
793            0,
794            best_pattern.len(),
795            max_power.saturating_sub(cse_base_pos),
796        );
797        writeln!(verilog).unwrap();
798        writeln!(
799            verilog,
800            "    // Cross-CSE: shared pattern \"{}\"",
801            best_pattern
802        )
803        .unwrap();
804        writeln!(
805            verilog,
806            "    wire signed [{}:0] {} = {};",
807            ow - 1,
808            cse_name,
809            cse_expr
810        )
811        .unwrap();
812    }
813
814    // Set of coeff indices that have the pattern
815    let cse_coeffs: HashSet<usize> = best_occurrences.iter().map(|(ci, _)| *ci).collect();
816
817    // Per-coefficient assignments
818    for (idx, spec) in coeffs.iter().enumerate() {
819        writeln!(verilog).unwrap();
820        writeln!(verilog, "    // {}: {}", spec.name, spec.csd).unwrap();
821
822        let has_cse = !best_pattern.is_empty() && cse_coeffs.contains(&idx);
823        let expr = if has_cse {
824            build_coeff_expr(&spec.csd, max_power, &best_pattern, cse_base_pos, cse_name)
825        } else {
826            build_coeff_expr(&spec.csd, max_power, "", 0, "")
827        };
828
829        if expr.is_empty() {
830            writeln!(verilog, "    assign {} = 0;", spec.name).unwrap();
831        } else {
832            writeln!(verilog, "    assign {} = {};", spec.name, expr).unwrap();
833        }
834    }
835
836    writeln!(verilog, "endmodule").unwrap();
837    Ok(verilog)
838}
839
840// ---------------------------------------------------------------------------
841// Tests
842// ---------------------------------------------------------------------------
843
844#[cfg(test)]
845mod tests {
846    use super::*;
847
848    // ---- Existing struct-based tests ----
849
850    #[test]
851    fn test_valid_csd() {
852        let csd = "+00-00+0+";
853        let multiplier = CsdMultiplier::new(csd, 8, 8).unwrap();
854        assert_eq!(multiplier.decimal_value(), 229);
855    }
856
857    #[test]
858    fn test_decimal_value() {
859        let multiplier = CsdMultiplier::new("+", 8, 0).unwrap();
860        assert_eq!(multiplier.decimal_value(), 1);
861
862        let multiplier = CsdMultiplier::new("-", 8, 0).unwrap();
863        assert_eq!(multiplier.decimal_value(), -1);
864
865        let multiplier = CsdMultiplier::new("+0-", 8, 2).unwrap();
866        assert_eq!(multiplier.decimal_value(), 3);
867
868        let multiplier = CsdMultiplier::new("-0+", 8, 2).unwrap();
869        assert_eq!(multiplier.decimal_value(), -3);
870    }
871
872    #[test]
873    fn test_all_zeros_csd() {
874        let csd = "0000";
875        let multiplier = CsdMultiplier::new(csd, 8, 3).unwrap();
876        let verilog = multiplier.generate_verilog();
877        assert!(verilog.contains("assign result = 0;"));
878    }
879
880    #[test]
881    fn test_invalid_csd_chars() {
882        let csd = "+01-00+0+";
883        let result = CsdMultiplier::new(csd, 8, 6);
884        assert!(matches!(result, Err(CsdMultiplierError::InvalidCharacter)));
885    }
886
887    #[test]
888    fn test_length_mismatch() {
889        let csd = "+00-00+0+";
890        let result = CsdMultiplier::new(csd, 8, 5);
891        assert!(matches!(result, Err(CsdMultiplierError::LengthMismatch)));
892    }
893
894    #[test]
895    fn test_verilog_generation() {
896        let csd = "+0-";
897        let n = 8;
898        let m = 2;
899        let multiplier = CsdMultiplier::new(csd, n, m).unwrap();
900        let expected_verilog = r###"// CSD Multiplier for pattern: +0- (value: 3)
901module csd_multiplier (
902    input signed [7:0] x,      // Input value (signed)
903    output signed [9:0] result // Result (signed)
904);
905
906    // Signed shifted versions (Verilog handles sign extension)
907    wire signed [9:0] x_shift2 = $signed({ {0{x[7]}}, x}) << 2;
908    wire signed [9:0] x_shift0 = $signed({ {2{x[7]}}, x}) << 0;
909
910    // CSD implementation with signed arithmetic
911    assign result = x_shift2 - x_shift0;
912endmodule
913"###;
914        assert_eq!(multiplier.generate_verilog(), expected_verilog);
915    }
916
917    // ---- Free-function tests (matching C++ test_csd_multiplier.cpp) ----
918
919    // Basic structural tests
920    #[test]
921    fn test_fn_basic_valid() {
922        let v = generate_csd_multiplier("+0-", 8, 2).unwrap();
923        assert!(v.contains("module csd_multiplier"));
924        assert!(v.contains("endmodule"));
925        assert!(v.contains("input signed [7:0] x"));
926        assert!(v.contains("output signed [9:0] result"));
927        assert!(v.contains("assign result = x_shift2 - x_shift0"));
928    }
929
930    #[test]
931    fn test_fn_positive_only() {
932        let v = generate_csd_multiplier("+0+", 4, 2).unwrap();
933        assert!(v.contains("assign result = x_shift2 + x_shift0"));
934    }
935
936    #[test]
937    fn test_fn_negative_only() {
938        let v = generate_csd_multiplier("-0-", 8, 2).unwrap();
939        assert!(v.contains("assign result = -x_shift2 - x_shift0"));
940    }
941
942    #[test]
943    fn test_fn_all_zeros() {
944        let v = generate_csd_multiplier("000", 8, 2).unwrap();
945        assert!(v.contains("assign result = 0;"));
946        assert!(!v.contains("x_shift"));
947    }
948
949    #[test]
950    fn test_fn_single_nonzero() {
951        let v = generate_csd_multiplier("+00", 8, 2).unwrap();
952        assert!(v.contains("assign result"));
953        assert!(v.contains("x_shift2"));
954    }
955
956    #[test]
957    fn test_fn_invalid_chars() {
958        let r = generate_csd_multiplier("123", 8, 2);
959        assert_eq!(r, Err(CsdMultiplierError::InvalidCharacter));
960    }
961
962    #[test]
963    fn test_fn_invalid_length() {
964        let r = generate_csd_multiplier("+0-", 8, 3);
965        assert_eq!(r, Err(CsdMultiplierError::LengthMismatch));
966    }
967
968    // LCSRe optimization tests
969    #[test]
970    fn test_fn_flat_when_pattern_nnz_is_1() {
971        // "+00-00+0" has no repeated pattern with ≥2 nnz
972        let v = generate_csd_multiplier("+00-00+0", 8, 7).unwrap();
973        assert!(!v.contains("_pat"));
974        assert!(v.contains("x_shift7 - x_shift4 + x_shift1"));
975    }
976
977    #[test]
978    fn test_fn_double_repeat_optimization() {
979        // +0-0+0-0: repeated "+0-0" (2 nnz) at positions 0 and 4
980        let v = generate_csd_multiplier("+0-0+0-0", 8, 7).unwrap();
981        assert!(v.contains("_pat"));
982        assert!(v.contains("_pat = x_shift7 - x_shift5"));
983        assert!(v.contains("(_pat >>> 4)"));
984        assert!(v.contains("LCSRe"));
985    }
986
987    #[test]
988    fn test_fn_triple_repeat_optimization() {
989        // +0-0+0-0+0-0: repeated "+0-0" at positions 0, 4, 8
990        let v = generate_csd_multiplier("+0-0+0-0+0-0", 8, 11).unwrap();
991        assert!(v.contains("_pat"));
992        assert!(v.contains("(_pat >>> 4)"));
993        assert!(v.contains("(_pat >>> 8)"));
994    }
995
996    #[test]
997    fn test_fn_longer_pattern_repeat() {
998        // +00-00+00-00: repeated "+00-00" (2 nnz, 5 chars) at positions 0 and 6
999        let v = generate_csd_multiplier("+00-00+00-00", 8, 11).unwrap();
1000        assert!(v.contains("_pat"));
1001        assert!(v.contains("_pat = x_shift11 - x_shift8"));
1002        assert!(v.contains("(_pat >>> 6)"));
1003    }
1004
1005    #[test]
1006    fn test_fn_leading_minus_no_optimization() {
1007        // CSD starting with '-' and no repeated pattern
1008        let v = generate_csd_multiplier("-0-", 8, 2).unwrap();
1009        assert!(!v.contains("_pat"));
1010        assert!(v.contains("-x_shift2 - x_shift0"));
1011    }
1012
1013    #[test]
1014    fn test_fn_pattern_with_leading_minus() {
1015        // Repeated pattern starting with '-': -0+0-0+0
1016        let v = generate_csd_multiplier("-0+0-0+0", 8, 7).unwrap();
1017        assert!(v.contains("_pat"));
1018        assert!(v.contains("_pat = -x_shift7 + x_shift5"));
1019        assert!(v.contains("(_pat >>> 4)"));
1020    }
1021
1022    #[test]
1023    fn test_fn_no_optimization_for_single_occurrence() {
1024        // CSD with unique pattern throughout — no repeat = flat
1025        let v = generate_csd_multiplier("+0-+00-0", 8, 7).unwrap();
1026        assert!(!v.contains("_pat"));
1027    }
1028
1029    #[test]
1030    fn test_fn_pat_wire_width_matches_output() {
1031        // output_width = 8 + 7 = 15, so wire signed [14:0]
1032        let v = generate_csd_multiplier("+0-0+0-0", 8, 7).unwrap();
1033        assert!(v.contains("[14:0] _pat"));
1034    }
1035
1036    #[test]
1037    fn test_fn_repeat_with_trailing_gap() {
1038        // Repeated pattern followed by non-repeating suffix
1039        let v = generate_csd_multiplier("+0-0+0-0+0", 8, 9).unwrap();
1040        assert!(v.contains("_pat"));
1041        assert!(v.contains("(_pat >>> 4)"));
1042    }
1043
1044    // Edge cases
1045    #[test]
1046    fn test_fn_very_short_csd() {
1047        // Length-1 CSD
1048        let v = generate_csd_multiplier("+", 8, 0).unwrap();
1049        assert!(v.contains("assign result = x_shift0"));
1050    }
1051
1052    #[test]
1053    fn test_fn_all_minus_signs() {
1054        let v = generate_csd_multiplier("---", 8, 2).unwrap();
1055        assert!(!v.contains("_pat"));
1056    }
1057
1058    #[test]
1059    fn test_fn_always_has_proper_module_boundaries() {
1060        let v = generate_csd_multiplier("+0-0+0-0", 8, 7).unwrap();
1061        assert!(v.contains("\nmodule csd_multiplier"));
1062        assert!(v.contains("endmodule\n"));
1063    }
1064
1065    #[test]
1066    fn test_fn_lcsre_comment_present_when_optimized() {
1067        let v = generate_csd_multiplier("+0-0+0-0", 8, 7).unwrap();
1068        assert!(v.contains("LCSRe"));
1069    }
1070
1071    #[test]
1072    fn test_fn_no_lcsre_comment_when_flat() {
1073        let v = generate_csd_multiplier("+00-00+0", 8, 7).unwrap();
1074        assert!(!v.contains("LCSRe"));
1075    }
1076
1077    // ---- Multi-coefficient tests ----
1078
1079    #[test]
1080    fn test_multi_empty_coeffs() {
1081        let r = generate_csd_multipliers(&[], "test");
1082        assert_eq!(r, Err(CsdMultiplierError::EmptyCoefficients));
1083    }
1084
1085    #[test]
1086    fn test_multi_single_coeff() {
1087        let coeffs = vec![MultiplierSpec {
1088            name: "y0".to_string(),
1089            csd: "+0-".to_string(),
1090            input_width: 8,
1091            max_power: 2,
1092        }];
1093        let v = generate_csd_multipliers(&coeffs, "test_mod").unwrap();
1094        assert!(v.contains("module test_mod"));
1095        assert!(v.contains("output signed [9:0] y0"));
1096    }
1097
1098    #[test]
1099    fn test_multi_duplicate_coeffs() {
1100        let coeffs = vec![
1101            MultiplierSpec {
1102                name: "y0".to_string(),
1103                csd: "+00-00+0+".to_string(),
1104                input_width: 8,
1105                max_power: 8,
1106            },
1107            MultiplierSpec {
1108                name: "y1".to_string(),
1109                csd: "+00-00+0+".to_string(),
1110                input_width: 8,
1111                max_power: 8,
1112            },
1113        ];
1114        let v = generate_csd_multipliers(&coeffs, "csd_filter").unwrap();
1115        assert!(v.contains("Cross-CSE"));
1116        assert!(v.contains("_cse_0"));
1117    }
1118
1119    #[test]
1120    fn test_multi_width_mismatch() {
1121        let coeffs = vec![
1122            MultiplierSpec {
1123                name: "y0".to_string(),
1124                csd: "+0-".to_string(),
1125                input_width: 8,
1126                max_power: 2,
1127            },
1128            MultiplierSpec {
1129                name: "y1".to_string(),
1130                csd: "+0-".to_string(),
1131                input_width: 16,
1132                max_power: 2,
1133            },
1134        ];
1135        let r = generate_csd_multipliers(&coeffs, "test");
1136        assert_eq!(r, Err(CsdMultiplierError::WidthMismatch));
1137    }
1138
1139    #[test]
1140    fn test_multi_invalid_chars() {
1141        let coeffs = vec![MultiplierSpec {
1142            name: "y0".to_string(),
1143            csd: "123".to_string(),
1144            input_width: 8,
1145            max_power: 2,
1146        }];
1147        let r = generate_csd_multipliers(&coeffs, "test");
1148        assert_eq!(r, Err(CsdMultiplierError::InvalidCharacter));
1149    }
1150}