1use super::Attributes;
2use crate::Attribute;
3use calyx_utils::{CalyxResult, Error, GetName, Id};
4use linked_hash_map::LinkedHashMap;
5use smallvec::SmallVec;
6
/// Definition of a primitive: its name, parameters, port signature, and
/// associated metadata.
#[derive(Clone, Debug)]
pub struct Primitive {
    /// Name of the primitive.
    pub name: Id,
    /// Parameter names. `Primitive::resolve` pairs these positionally with
    /// the values supplied at instantiation.
    pub params: Vec<Id>,
    /// Ports of the primitive. Widths may be constants or references to one
    /// of `params`.
    pub signature: Vec<PortDef<Width>>,
    /// Attributes attached to the primitive.
    pub attributes: Attributes,
    /// Whether the primitive is marked combinational.
    /// NOTE(review): exact semantics are defined by the consumers of this
    /// flag — confirm against them.
    pub is_comb: bool,
    /// Optional latency; `NonZeroU64` rules out a latency of zero.
    /// NOTE(review): unit presumably clock cycles — confirm.
    pub latency: Option<std::num::NonZeroU64>,
    /// Optional body text.
    /// NOTE(review): presumably an inline implementation of the primitive —
    /// confirm.
    pub body: Option<String>,
}
37
38impl Primitive {
39 #[allow(clippy::type_complexity)]
42 pub fn resolve(
43 &self,
44 parameters: &[u64],
45 ) -> CalyxResult<(SmallVec<[(Id, u64); 5]>, Vec<PortDef<u64>>)> {
46 if self.params.len() != parameters.len() {
47 let msg = format!(
48 "primitive `{}` requires {} parameters but instantiation provides {} parameters",
49 self.name.clone(),
50 self.params.len(),
51 parameters.len(),
52 );
53 return Err(Error::malformed_structure(msg));
54 }
55 let bindings = self
56 .params
57 .iter()
58 .cloned()
59 .zip(parameters.iter().cloned())
60 .collect::<LinkedHashMap<Id, u64>>();
61
62 let ports = self
63 .signature
64 .iter()
65 .cloned()
66 .map(|pd| pd.resolve(&bindings))
67 .collect::<Result<_, _>>()?;
68
69 Ok((bindings.into_iter().collect(), ports))
70 }
71
72 pub fn find_all_with_attr<A>(
74 &self,
75 attr: A,
76 ) -> impl Iterator<Item = &PortDef<Width>>
77 where
78 A: Into<Attribute> + Copy,
79 {
80 self.signature
81 .iter()
82 .filter(move |&g| g.attributes.has(attr))
83 }
84}
85
86impl GetName for Primitive {
87 fn name(&self) -> Id {
88 self.name
89 }
90}
91
/// Definition of a port, generic over the representation of its width `W`
/// (`PortDef<Width>` before parameter resolution, `PortDef<u64>` after
/// `PortDef::resolve`).
#[derive(Clone, Debug)]
pub struct PortDef<W> {
    /// Name of the port. Private; read it through `PortDef::name`.
    name: Id,
    /// Width of the port.
    pub width: W,
    /// Direction of the port. `PortDef::new` rejects `Direction::Inout`.
    pub direction: Direction,
    /// Attributes attached to the port.
    pub attributes: Attributes,
}
107
108impl<W> PortDef<W> {
109 pub fn new(
110 name: impl Into<Id>,
111 width: W,
112 direction: Direction,
113 attributes: Attributes,
114 ) -> Self {
115 assert!(
116 matches!(direction, Direction::Input | Direction::Output),
117 "Direction must be either Input or Output"
118 );
119
120 Self {
121 name: name.into(),
122 width,
123 direction,
124 attributes,
125 }
126 }
127
128 pub fn name(&self) -> Id {
130 self.name
131 }
132}
133
/// Width of a port: either a constant or a reference to a parameter.
#[derive(Clone, Debug, PartialEq)]
pub enum Width {
    /// A constant width value.
    Const { value: u64 },
    /// A width given by the named parameter; `PortDef::resolve` looks the
    /// name up in the parameter bindings.
    Param { value: Id },
}
142
143impl std::fmt::Display for Width {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 match self {
146 Width::Const { value } => write!(f, "{}", value),
147 Width::Param { value } => write!(f, "{}", value),
148 }
149 }
150}
151
152impl From<u64> for Width {
153 fn from(value: u64) -> Self {
154 Width::Const { value }
155 }
156}
157
158impl From<Id> for Width {
159 fn from(value: Id) -> Self {
160 Width::Param { value }
161 }
162}
163
164impl PortDef<Width> {
165 pub fn resolve(
169 self,
170 binding: &LinkedHashMap<Id, u64>,
171 ) -> CalyxResult<PortDef<u64>> {
172 match &self.width {
173 Width::Const { value } => Ok(PortDef {
174 name: self.name,
175 width: *value,
176 attributes: self.attributes,
177 direction: self.direction,
178 }),
179 Width::Param { value } => match binding.get(value) {
180 Some(width) => Ok(PortDef {
181 name: self.name,
182 width: *width,
183 attributes: self.attributes,
184 direction: self.direction,
185 }),
186 None => {
187 let param_name = &self.name;
188 let msg = format!("Failed to resolve: {param_name}");
189 Err(Error::malformed_structure(msg))
190 }
191 },
192 }
193 }
194}
195
/// Direction of a port.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
pub enum Direction {
    /// Input port.
    Input,
    /// Output port.
    Output,
    /// Port that is both input and output; `reverse` maps it to itself.
    Inout,
}

impl Direction {
    /// Returns the opposite direction: `Input` and `Output` are swapped,
    /// `Inout` is unchanged.
    pub fn reverse(&self) -> Self {
        use Direction::{Inout, Input, Output};
        // Unit-variant patterns bind nothing, so matching on `*self` does not
        // move out of the borrow.
        match *self {
            Inout => Inout,
            Input => Output,
            Output => Input,
        }
    }
}