1use std::{
2 fmt::Debug,
3 hash::Hash,
4 ops::{Deref, DerefMut},
5 str::FromStr,
6};
7
8use crate::{new_key_type, IdFor};
9use slotmap::SlotMap;
10
// Declares `ValueId`, the key type used to index `Value`s in a `SlotMap` (see `ValueMap`).
// NOTE(review): this is the crate-local `new_key_type!`; the `=> Value` suffix presumably
// associates the key with `Value` (e.g. through `IdFor`) — confirm against the macro.
new_key_type! {
    pub struct ValueId; => Value
}
14
/// The type of a [`Value`].
///
/// Every type has a compact textual form, produced by the `Display` impl and accepted by
/// the `FromStr` impl:
///
/// | variant              | text                          |
/// |----------------------|-------------------------------|
/// | `Integer`            | `integer`                     |
/// | `UInt(w)`            | `i<w>`                        |
/// | `SInt(w)`            | `s<w>`                        |
/// | `Ref(w)`             | `r<w>`                        |
/// | `Vector(base, d)`    | `<base>x<d>`                  |
/// | `Bundle(fields)`     | `{[flip ]name: ty, ...}`      |
/// | `Enum(variants)`     | `{|name: ty, ...|}`           |
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Type {
    /// An integer with no width attached (`integer`).
    Integer,
    /// An unsigned integer of the given width (`i<width>`).
    UInt(u32),
    /// A signed integer of the given width (`s<width>`).
    SInt(u32),
    /// A reference of the given width (`r<width>`); converts to and from `UInt` of the same
    /// width via [`Type::ref_to_int`] / [`Type::int_to_ref`].
    Ref(u32),
    /// `Vector(base, depth)`: a vector over `base` with the given depth (`<base>x<depth>`).
    Vector(Box<Type>, u32),
    /// Named fields `(name, type, flip)`. [`Type::bundle`] keeps them sorted by name.
    Bundle(Vec<(String, Type, bool)>),
    /// Named variants with their payload types. [`Type::union`] gives payload-less variants
    /// the [`Type::unit`] type.
    Enum(Vec<(String, Type)>),
}

impl Type {
    /// The unit type, encoded as a zero-width unsigned integer (`i0`).
    pub fn unit() -> Type {
        Type::UInt(0)
    }

    /// An unsigned integer type of the given width.
    pub fn uint(width: u32) -> Type {
        Type::UInt(width)
    }

    /// A signed integer type of the given width.
    pub fn sint(width: u32) -> Type {
        Type::SInt(width)
    }

    /// A reference type of the given width.
    pub fn new_ref(width: u32) -> Type {
        Type::Ref(width)
    }

    /// A vector type over `base` with the given `depth`.
    pub fn vector(base: Type, depth: u32) -> Type {
        Type::Vector(Box::new(base), depth)
    }

    /// Builds a bundle, sorting `fields` by name so that the caller's field order does not
    /// affect equality. The sort is stable, so duplicate names keep their relative order.
    pub fn bundle(mut fields: Vec<(String, Type, bool)>) -> Type {
        fields.sort_by(|a, b| a.0.cmp(&b.0));
        Type::Bundle(fields)
    }

    /// Builds an enum type. Variants without a payload (`None`) get [`Type::unit`].
    pub fn union(variants: Vec<(String, Option<Type>)>) -> Type {
        Type::Enum(
            variants
                .into_iter()
                .map(|(name, ty)| (name, ty.unwrap_or_else(Type::unit)))
                .collect(),
        )
    }

    /// Width of a `UInt` or `SInt`.
    ///
    /// # Panics
    ///
    /// Panics for any other variant, including `Integer`.
    pub fn int_width(&self) -> u32 {
        match self {
            Type::UInt(width) => *width,
            Type::SInt(width) => *width,
            _ => panic!("Type {self:?} is not an integer"),
        }
    }

    /// Converts `UInt(w)` into `Ref(w)`.
    ///
    /// # Panics
    ///
    /// Panics unless `self` is a `UInt` (an `SInt` panics too).
    pub fn int_to_ref(&self) -> Self {
        match self {
            Type::UInt(width) => Type::Ref(*width),
            _ => panic!("Type {self:?} is not an integer"),
        }
    }

    /// Converts `Ref(w)` into `UInt(w)`.
    ///
    /// # Panics
    ///
    /// Panics unless `self` is a `Ref`.
    pub fn ref_to_int(&self) -> Self {
        match self {
            Type::Ref(width) => Type::UInt(*width),
            _ => panic!("Type {self:?} is not a reference"),
        }
    }

    /// Width of a `Ref`.
    ///
    /// # Panics
    ///
    /// Panics unless `self` is a `Ref`.
    pub fn ref_width(&self) -> u32 {
        match self {
            Type::Ref(width) => *width,
            _ => panic!("Type is not a reference"),
        }
    }

    /// [`Type::int_width`] of a vector's base type.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not a `Vector`, or if its base is not `UInt`/`SInt`.
    pub fn vector_elem_width(&self) -> u32 {
        match self {
            Type::Vector(base, _) => base.int_width(),
            _ => panic!("Type is not an array"),
        }
    }

    /// Depth of a `Vector`.
    ///
    /// # Panics
    ///
    /// Panics unless `self` is a `Vector`.
    pub fn vector_depth(&self) -> u32 {
        match self {
            Type::Vector(_, depth) => *depth,
            _ => panic!("Type is not an array"),
        }
    }
}

// Implementing `Display` (rather than `ToString` directly) still provides `to_string()`
// through the standard blanket impl, and additionally allows `{}` formatting.
impl std::fmt::Display for Type {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Type::Integer => f.write_str("integer"),
            Type::UInt(width) => write!(f, "i{width}"),
            Type::SInt(width) => write!(f, "s{width}"),
            Type::Ref(width) => write!(f, "r{width}"),
            Type::Vector(base, depth) => write!(f, "{base}x{depth}"),
            Type::Bundle(fields) => {
                f.write_str("{")?;
                for (idx, (name, ty, flip)) in fields.iter().enumerate() {
                    if idx > 0 {
                        f.write_str(", ")?;
                    }
                    if *flip {
                        f.write_str("flip ")?;
                    }
                    write!(f, "{name}: {ty}")?;
                }
                f.write_str("}")
            }
            Type::Enum(variants) => {
                f.write_str("{|")?;
                for (idx, (name, ty)) in variants.iter().enumerate() {
                    if idx > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{name}: {ty}")?;
                }
                f.write_str("|}")
            }
        }
    }
}

impl FromStr for Type {
    type Err = String;

    /// Parses the textual form produced by the `Display` impl.
    ///
    /// Whitespace around every component is ignored. Commas nested inside `{...}` or `{|...|}`
    /// groups do not split the enclosing list, so aggregates of aggregates round-trip.
    ///
    /// # Errors
    ///
    /// Returns a message describing the malformed component. Malformed input never panics.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        /// Splits `list` on commas at brace depth 0. A blank list has no items.
        fn split_top_level(list: &str) -> Vec<&str> {
            if list.trim().is_empty() {
                return Vec::new();
            }
            let mut items = Vec::new();
            let mut depth = 0usize;
            let mut start = 0;
            for (idx, ch) in list.char_indices() {
                match ch {
                    '{' => depth += 1,
                    '}' => depth = depth.saturating_sub(1),
                    ',' if depth == 0 => {
                        items.push(&list[start..idx]);
                        start = idx + 1;
                    }
                    _ => {}
                }
            }
            items.push(&list[start..]);
            items
        }

        /// Parses the width digits that follow an `i` / `s` / `r` prefix.
        fn parse_width(digits: &str, whole: &str) -> Result<u32, String> {
            digits
                .parse::<u32>()
                .map_err(|e| format!("Invalid width in type {whole}: {e}"))
        }

        let s = s.trim();
        if s.is_empty() {
            return Err("Empty type string".to_string());
        }

        // Enums must be recognised before bundles: `{|...|}` also starts with `{` and ends
        // with `}`, so testing bundles first would make this branch unreachable.
        if let Some(inner) = s.strip_prefix("{|").and_then(|rest| rest.strip_suffix("|}")) {
            let variants = split_top_level(inner)
                .into_iter()
                .map(|variant| -> Result<(String, Option<Type>), String> {
                    match variant.split_once(':') {
                        Some((name, ty)) => {
                            let parsed = Type::from_str(ty)
                                .map_err(|e| format!("Invalid type: {} ({e})", ty.trim()))?;
                            Ok((name.trim().to_string(), Some(parsed)))
                        }
                        None => Ok((variant.trim().to_string(), None)),
                    }
                })
                .collect::<Result<Vec<_>, String>>()?;
            return Ok(Type::union(variants));
        }

        if let Some(inner) = s.strip_prefix('{').and_then(|rest| rest.strip_suffix('}')) {
            let fields = split_top_level(inner)
                .into_iter()
                .map(|field| -> Result<(String, Type, bool), String> {
                    let (name, ty) = field
                        .split_once(':')
                        .ok_or_else(|| format!("Invalid field: {}", field.trim()))?;
                    let name = name.trim();
                    // An optional `flip ` prefix marks a flipped field; any other
                    // space-separated prefix is rejected.
                    let (name, flip) = match name.split_once(' ') {
                        Some(("flip", rest)) => (rest.trim(), true),
                        Some(_) => return Err(format!("Invalid field name: {}", name)),
                        None => (name, false),
                    };
                    Ok((name.to_string(), Type::from_str(ty)?, flip))
                })
                .collect::<Result<Vec<_>, String>>()?;
            return Ok(Type::bundle(fields));
        }

        // `<base>x<depth>`: split on the *last* `x` so that nested vectors such as `i8x4x2`
        // parse as a vector of `i8x4` instead of silently dropping the outer depth.
        if let Some((base, depth)) = s.rsplit_once('x') {
            let base = Type::from_str(base)?;
            let depth = depth
                .trim()
                .parse::<u32>()
                .map_err(|e| format!("Invalid depth: {}", e))?;
            return Ok(Type::Vector(Box::new(base), depth));
        }

        if s == "integer" {
            Ok(Type::Integer)
        } else if let Some(digits) = s.strip_prefix('i') {
            Ok(Type::UInt(parse_width(digits, s)?))
        } else if let Some(digits) = s.strip_prefix('r') {
            Ok(Type::Ref(parse_width(digits, s)?))
        } else if let Some(digits) = s.strip_prefix('s') {
            Ok(Type::SInt(parse_width(digits, s)?))
        } else {
            Err(format!("Invalid type string: {}", s))
        }
    }
}
204
205#[derive(Debug, Clone)]
206pub struct Value {
207 pub ty: Option<Type>,
208 pub name: Option<String>,
209}
210
211impl Value {
212 pub fn new(ty: Type, name: Option<String>) -> Self {
213 Value { ty: Some(ty), name }
214 }
215
216 pub fn new_wo_ty(name: Option<String>) -> Self {
217 Value { ty: None, name }
218 }
219}
220
221impl ValueId {
222 pub fn ty(&self, t: &SlotMap<ValueId, Value>) -> Option<Type> {
223 t[*self].ty.clone()
224 }
225 pub fn name<'r>(&self, t: &'r SlotMap<ValueId, Value>) -> &'r Option<String> {
226 &t[*self].name
227 }
228 pub fn name_mut<'r>(
229 &self,
230 t: &'r mut SlotMap<ValueId, Value>,
231 ) -> &'r mut Option<String> {
232 &mut t[*self].name
233 }
234}
235
/// Slot map storing [`Value`]s keyed by [`ValueId`].
pub type ValueMap = SlotMap<ValueId, Value>;
237
238pub trait OpIO {
239 fn num_inputs(&self) -> usize;
240 fn input(&self, i: usize) -> ValueId;
241 fn input_mut(&mut self, i: usize) -> &mut ValueId;
242 fn inputs(&self) -> impl Iterator<Item = ValueId> + '_ {
243 (0..self.num_inputs()).map(move |i| self.input(i))
244 }
245 fn map_inputs(&mut self, mut f: impl FnMut(ValueId) -> ValueId) {
246 for i in 0..self.num_inputs() {
247 *self.input_mut(i) = f(self.input(i));
248 }
249 }
250 fn num_outputs(&self) -> usize;
251 fn output(&self, i: usize) -> ValueId;
252 fn output_mut(&mut self, i: usize) -> &mut ValueId;
253 fn outputs(&self) -> impl Iterator<Item = ValueId> + '_ {
254 (0..self.num_outputs()).map(move |i| self.output(i))
255 }
256 fn map_outputs(&mut self, mut f: impl FnMut(ValueId) -> ValueId) {
257 for i in 0..self.num_outputs() {
258 *self.output_mut(i) = f(self.output(i));
259 }
260 }
261 fn values(&self) -> impl Iterator<Item = ValueId> + '_ {
262 self.inputs().chain(self.outputs())
263 }
264 fn map_values(&mut self, mut f: impl FnMut(ValueId) -> ValueId) {
265 self.map_inputs(&mut f);
266 self.map_outputs(&mut f);
267 }
268 fn attr_eq(&self, _rhs: &Self) -> bool {
269 true
270 }
271 fn attr_hash<H: std::hash::Hasher>(&self, _state: &mut H) {}
272}
273
274impl OpIO for ValueId {
275 fn num_inputs(&self) -> usize {
276 0
277 }
278 fn input(&self, _i: usize) -> ValueId {
279 panic!("ValueId has no inputs");
280 }
281 fn input_mut(&mut self, _i: usize) -> &mut ValueId {
282 panic!("ValueId has no inputs");
283 }
284 fn num_outputs(&self) -> usize {
285 1
286 }
287 fn output(&self, i: usize) -> ValueId {
288 assert_eq!(i, 0);
289 *self
290 }
291 fn output_mut(&mut self, i: usize) -> &mut ValueId {
292 assert_eq!(i, 0);
293 self
294 }
295 fn attr_eq(&self, rhs: &Self) -> bool {
296 *self == *rhs
297 }
298 fn attr_hash<H: std::hash::Hasher>(&self, state: &mut H) {
299 std::hash::Hash::hash(&self, state);
300 }
301}
302
303impl<T: OpIO> OpIO for Box<T> {
304 fn num_inputs(&self) -> usize {
305 self.deref().num_inputs()
306 }
307 fn input(&self, i: usize) -> ValueId {
308 self.deref().input(i)
309 }
310 fn input_mut(&mut self, i: usize) -> &mut ValueId {
311 self.deref_mut().input_mut(i)
312 }
313 fn num_outputs(&self) -> usize {
314 self.deref().num_outputs()
315 }
316 fn output(&self, i: usize) -> ValueId {
317 self.deref().output(i)
318 }
319 fn output_mut(&mut self, i: usize) -> &mut ValueId {
320 self.deref_mut().output_mut(i)
321 }
322 fn attr_eq(&self, rhs: &Self) -> bool {
323 self.deref().attr_eq(rhs.deref())
324 }
325 fn attr_hash<H: std::hash::Hasher>(&self, state: &mut H) {
326 self.deref().attr_hash(state)
327 }
328}
329
330impl<T: OpIO> OpIO for Vec<T> {
331 fn num_inputs(&self) -> usize {
332 self.iter().map(|t| t.num_inputs()).sum()
333 }
334 fn input(&self, i: usize) -> ValueId {
335 let mut offset = 0;
336 for t in self {
337 if i < offset + t.num_inputs() {
338 return t.input(i - offset);
339 }
340 offset += t.num_inputs();
341 }
342 panic!("Index out of bounds");
343 }
344 fn input_mut(&mut self, i: usize) -> &mut ValueId {
345 let mut offset = 0;
346 for t in self {
347 if i < offset + t.num_inputs() {
348 return t.input_mut(i - offset);
349 }
350 offset += t.num_inputs();
351 }
352 panic!("Index out of bounds");
353 }
354 fn num_outputs(&self) -> usize {
355 self.iter().map(|t| t.num_outputs()).sum()
356 }
357 fn output(&self, i: usize) -> ValueId {
358 let mut offset = 0;
359 for t in self {
360 if i < offset + t.num_outputs() {
361 return t.output(i - offset);
362 }
363 offset += t.num_outputs();
364 }
365 panic!("Index out of bounds");
366 }
367 fn output_mut(&mut self, i: usize) -> &mut ValueId {
368 let mut offset = 0;
369 for t in self {
370 if i < offset + t.num_outputs() {
371 return t.output_mut(i - offset);
372 }
373 offset += t.num_outputs();
374 }
375 panic!("Index out of bounds");
376 }
377 fn attr_eq(&self, rhs: &Self) -> bool {
378 self.iter().zip(rhs.iter()).all(|(a, b)| a.attr_eq(b))
379 }
380 fn attr_hash<H: std::hash::Hasher>(&self, state: &mut H) {
381 for t in self {
382 t.attr_hash(state);
383 }
384 }
385}
386
387impl<T: OpIO> OpIO for Option<T> {
388 fn num_inputs(&self) -> usize {
389 self.as_ref().map(|t| t.num_inputs()).unwrap_or(0)
390 }
391
392 fn input(&self, i: usize) -> ValueId {
393 self.as_ref().unwrap().input(i)
394 }
395
396 fn input_mut(&mut self, i: usize) -> &mut ValueId {
397 self.as_mut().unwrap().input_mut(i)
398 }
399
400 fn num_outputs(&self) -> usize {
401 self.as_ref().map(|t| t.num_outputs()).unwrap_or(0)
402 }
403
404 fn output(&self, i: usize) -> ValueId {
405 self.as_ref().unwrap().output(i)
406 }
407
408 fn output_mut(&mut self, i: usize) -> &mut ValueId {
409 self.as_mut().unwrap().output_mut(i)
410 }
411
412 fn attr_eq(&self, rhs: &Self) -> bool {
413 match (self, rhs) {
414 (None, None) => true,
415 (Some(a), Some(b)) => a.attr_eq(b),
416 _ => false,
417 }
418 }
419
420 fn attr_hash<H: std::hash::Hasher>(&self, state: &mut H) {
421 self.as_ref().map(|t| t.attr_hash(state));
422 }
423}
424
425#[derive(Debug, Clone, Copy)]
426pub struct AttrView<'op, T: OpIO>(pub &'op T);
427impl<'op, T: OpIO> AttrView<'op, T> {
428 pub fn new(op: &'op T) -> Self {
429 Self(op)
430 }
431}
432impl<'op, T: OpIO> std::ops::Deref for AttrView<'op, T> {
433 type Target = T;
434 fn deref(&self) -> &Self::Target {
435 self.0
436 }
437}
438impl<'op, T: OpIO> std::cmp::PartialEq for AttrView<'op, T> {
439 fn eq(&self, rhs: &Self) -> bool {
440 self.0.attr_eq(rhs.0)
441 }
442}
443impl<'op, T: OpIO> std::cmp::Eq for AttrView<'op, T> {}
444impl<'op, T: OpIO> std::hash::Hash for AttrView<'op, T> {
445 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
446 self.0.attr_hash(state)
447 }
448}