calyx_frontend/
common.rs

1use super::Attributes;
2use crate::Attribute;
3use calyx_utils::{CalyxResult, Error, GetName, Id};
4use linked_hash_map::LinkedHashMap;
5use smallvec::SmallVec;
6
7/// Representation of a external primitive definition.
8///
9/// # Example
10/// ```
11/// primitive std_reg<"static"=1>[WIDTH](
12///   in: WIDTH, write_en: 1, clk: 1
13/// ) -> (
14///   out: WIDTH, done: 1
15/// );
16/// ```
17///
18/// The signature of a port is represented using [PortDef] which also specify
19/// the direction of the port.
20#[derive(Clone, Debug)]
21pub struct Primitive {
22    /// Name of this primitive.
23    pub name: Id,
24    /// Paramters for this primitive.
25    pub params: Vec<Id>,
26    /// The input/output signature for this primitive.
27    pub signature: Vec<PortDef<Width>>,
28    /// Key-value attributes for this primitive.
29    pub attributes: Attributes,
30    /// True iff this is a combinational primitive
31    pub is_comb: bool,
32    /// (Optional) latency; for static primitives
33    pub latency: Option<std::num::NonZeroU64>,
34    /// body of the string, if it is an inlined primitive
35    pub body: Option<String>,
36}
37
38impl Primitive {
39    /// Retuns the bindings for all the paramters, the input ports and the
40    /// output ports.
41    #[allow(clippy::type_complexity)]
42    pub fn resolve(
43        &self,
44        parameters: &[u64],
45    ) -> CalyxResult<(SmallVec<[(Id, u64); 5]>, Vec<PortDef<u64>>)> {
46        if self.params.len() != parameters.len() {
47            let msg = format!(
48               "primitive `{}` requires {} parameters but instantiation provides {} parameters",
49               self.name.clone(),
50               self.params.len(),
51               parameters.len(),
52            );
53            return Err(Error::malformed_structure(msg));
54        }
55        let bindings = self
56            .params
57            .iter()
58            .cloned()
59            .zip(parameters.iter().cloned())
60            .collect::<LinkedHashMap<Id, u64>>();
61
62        let ports = self
63            .signature
64            .iter()
65            .cloned()
66            .map(|pd| pd.resolve(&bindings))
67            .collect::<Result<_, _>>()?;
68
69        Ok((bindings.into_iter().collect(), ports))
70    }
71
72    /// Return all ports that have the attribute `attr`.
73    pub fn find_all_with_attr<A>(
74        &self,
75        attr: A,
76    ) -> impl Iterator<Item = &PortDef<Width>>
77    where
78        A: Into<Attribute> + Copy,
79    {
80        self.signature
81            .iter()
82            .filter(move |&g| g.attributes.has(attr))
83    }
84}
85
86impl GetName for Primitive {
87    fn name(&self) -> Id {
88        self.name
89    }
90}
91
92/// Definition of a port parameterized by a width type.
93/// Ports on Primitives can be parameteris and use [Width].
94/// Ports on Components cannot be parameterized and therefore use `u64`.
95#[derive(Clone, Debug)]
96pub struct PortDef<W> {
97    /// The name of the port.
98    name: Id,
99    /// The width of the port. .
100    pub width: W,
101    /// The direction of the port. Only allowed to be [Direction::Input]
102    /// or [Direction::Output].
103    pub direction: Direction,
104    /// Attributes attached to this port definition
105    pub attributes: Attributes,
106}
107
108impl<W> PortDef<W> {
109    pub fn new(
110        name: impl Into<Id>,
111        width: W,
112        direction: Direction,
113        attributes: Attributes,
114    ) -> Self {
115        assert!(
116            matches!(direction, Direction::Input | Direction::Output),
117            "Direction must be either Input or Output"
118        );
119
120        Self {
121            name: name.into(),
122            width,
123            direction,
124            attributes,
125        }
126    }
127
128    /// Return the name of the port definition
129    pub fn name(&self) -> Id {
130        self.name
131    }
132}
133
134/// Represents an abstract width of a primitive signature.
135#[derive(Clone, Debug, PartialEq)]
136pub enum Width {
137    /// The width is a constant.
138    Const { value: u64 },
139    /// The width is a parameter.
140    Param { value: Id },
141}
142
143impl std::fmt::Display for Width {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        match self {
146            Width::Const { value } => write!(f, "{}", value),
147            Width::Param { value } => write!(f, "{}", value),
148        }
149    }
150}
151
152impl From<u64> for Width {
153    fn from(value: u64) -> Self {
154        Width::Const { value }
155    }
156}
157
158impl From<Id> for Width {
159    fn from(value: Id) -> Self {
160        Width::Param { value }
161    }
162}
163
164impl PortDef<Width> {
165    /// Given a map from names of parameters to their values, attempt to
166    /// resolve this definition.
167    /// Errors if there is no binding for a required parameter binding.
168    pub fn resolve(
169        self,
170        binding: &LinkedHashMap<Id, u64>,
171    ) -> CalyxResult<PortDef<u64>> {
172        match &self.width {
173            Width::Const { value } => Ok(PortDef {
174                name: self.name,
175                width: *value,
176                attributes: self.attributes,
177                direction: self.direction,
178            }),
179            Width::Param { value } => match binding.get(value) {
180                Some(width) => Ok(PortDef {
181                    name: self.name,
182                    width: *width,
183                    attributes: self.attributes,
184                    direction: self.direction,
185                }),
186                None => {
187                    let param_name = &self.name;
188                    let msg = format!("Failed to resolve: {param_name}");
189                    Err(Error::malformed_structure(msg))
190                }
191            },
192        }
193    }
194}
195
/// The direction in which a port on a cell carries data.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
pub enum Direction {
    /// The port is an input.
    Input,
    /// The port is an output.
    Output,
    /// The "port" is both an input and an output. Should only be used by
    /// holes.
    Inout,
}
207
208impl Direction {
209    /// Return the direction opposite to the current direction
210    pub fn reverse(&self) -> Self {
211        match self {
212            Direction::Input => Direction::Output,
213            Direction::Output => Direction::Input,
214            Direction::Inout => Direction::Inout,
215        }
216    }
217}