// kir/ir.rs

1use std::{
2  fmt::Debug,
3  hash::Hash,
4  ops::{Deref, DerefMut},
5  str::FromStr,
6};
7
8use crate::{new_key_type, IdFor};
9use slotmap::SlotMap;
10
// Declares `ValueId`, the key type used to index `Value`s in a `SlotMap`
// (see `ValueMap` below). The `=> Value` suffix is consumed by this crate's
// `new_key_type!` macro; presumably it ties the key to `Value` via `IdFor`
// (imported above) — confirm against the macro definition.
new_key_type! {
    pub struct ValueId; => Value
}
14
// TODO: it's strange to have Type inside kir, but it is coupled with Value
/// The type attached to a [`Value`] in the IR.
///
/// Every type has a textual form (its [`std::fmt::Display`] output) that
/// [`FromStr`] parses back, e.g. `integer`, `i8`, `s4`, `r16`, `i8x4`,
/// `{flip a: i1, b: i8}` and `{|none: i0, some: i8|}`. Round-tripping assumes
/// field/variant names are single words without `,`, `:`, `{`, `}` or `|`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Type {
  /// Software integer type (Rust `i32`); printed as `integer`.
  Integer,
  /// Unsigned integer of the given bit width; printed as `i<width>`.
  UInt(u32),
  /// Signed integer of the given bit width; printed as `s<width>`.
  SInt(u32),
  /// Probe / rwprobe of the given width. NOT HARDWARE. Printed as `r<width>`.
  Ref(u32),
  /// `depth` elements of the boxed base type; printed as `<base>x<depth>`.
  Vector(Box<Type>, u32),
  /// Fields `(name, type, flipped)`; printed as `{[flip ]name: type, ...}`.
  /// [`Type::bundle`] keeps the fields sorted by name.
  Bundle(Vec<(String, Type, bool)>),
  /// Variants `(name, payload)`; printed as `{|name: type, ...|}`.
  /// [`Type::union`] gives payload-less variants the [`Type::unit`] payload.
  Enum(Vec<(String, Type)>),
}

impl Type {
  /// The zero-width unsigned integer, used as the "no payload" type.
  pub fn unit() -> Type {
    Type::UInt(0)
  }
  /// Unsigned integer of `width` bits.
  pub fn uint(width: u32) -> Type {
    Type::UInt(width)
  }
  /// Signed integer of `width` bits.
  pub fn sint(width: u32) -> Type {
    Type::SInt(width)
  }
  /// Reference (probe) of `width` bits.
  pub fn new_ref(width: u32) -> Type {
    Type::Ref(width)
  }
  /// Vector of `depth` elements of type `base`.
  pub fn vector(base: Type, depth: u32) -> Type {
    Type::Vector(Box::new(base), depth)
  }
  /// Builds a bundle whose fields are sorted by name. The sort is stable, so
  /// fields with duplicate names keep their relative order.
  pub fn bundle(mut fields: Vec<(String, Type, bool)>) -> Type {
    fields.sort_by(|a, b| a.0.cmp(&b.0));
    Type::Bundle(fields)
  }
  /// Builds an enum type; variants given without a payload get [`Type::unit`].
  pub fn union(variants: Vec<(String, Option<Type>)>) -> Type {
    Type::Enum(
      variants
        .into_iter()
        .map(|(name, ty)| (name, ty.unwrap_or_else(Type::unit)))
        .collect(),
    )
  }
  /// Bit width of a `UInt` or `SInt`.
  ///
  /// # Panics
  /// If `self` is not `UInt` or `SInt`.
  pub fn int_width(&self) -> u32 {
    match self {
      Type::UInt(width) | Type::SInt(width) => *width,
      _ => panic!("Type {self:?} is not an integer"),
    }
  }
  /// Converts `UInt(w)` into `Ref(w)`.
  ///
  /// # Panics
  /// If `self` is not `UInt` (signed integers are rejected too).
  pub fn int_to_ref(&self) -> Self {
    match self {
      Type::UInt(width) => Type::Ref(*width),
      _ => panic!("Type {self:?} is not an unsigned integer"),
    }
  }
  /// Converts `Ref(w)` into `UInt(w)`.
  ///
  /// # Panics
  /// If `self` is not `Ref`.
  pub fn ref_to_int(&self) -> Self {
    match self {
      Type::Ref(width) => Type::UInt(*width),
      _ => panic!("Type {self:?} is not a reference"),
    }
  }
  /// Width of a `Ref`.
  ///
  /// # Panics
  /// If `self` is not `Ref`.
  pub fn ref_width(&self) -> u32 {
    match self {
      Type::Ref(width) => *width,
      _ => panic!("Type {self:?} is not a reference"),
    }
  }
  /// Bit width of a vector's element type.
  ///
  /// # Panics
  /// If `self` is not a `Vector`, or its element is not `UInt`/`SInt`.
  pub fn vector_elem_width(&self) -> u32 {
    match self {
      Type::Vector(base, _) => base.int_width(),
      _ => panic!("Type {self:?} is not an array"),
    }
  }
  /// Number of elements in a vector.
  ///
  /// # Panics
  /// If `self` is not a `Vector`.
  pub fn vector_depth(&self) -> u32 {
    match self {
      Type::Vector(_, depth) => *depth,
      _ => panic!("Type {self:?} is not an array"),
    }
  }
}

// `Display` (rather than a hand-written `ToString`) keeps `to_string()` working
// via the std blanket impl and also enables `format!("{ty}")`.
impl std::fmt::Display for Type {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    match self {
      Type::Integer => f.write_str("integer"),
      Type::UInt(width) => write!(f, "i{width}"),
      Type::SInt(width) => write!(f, "s{width}"),
      Type::Ref(width) => write!(f, "r{width}"),
      Type::Vector(base, depth) => write!(f, "{base}x{depth}"),
      Type::Bundle(fields) => {
        f.write_str("{")?;
        for (i, (name, ty, flip)) in fields.iter().enumerate() {
          if i > 0 {
            f.write_str(", ")?;
          }
          let prefix = if *flip { "flip " } else { "" };
          write!(f, "{prefix}{name}: {ty}")?;
        }
        f.write_str("}")
      }
      Type::Enum(variants) => {
        f.write_str("{|")?;
        for (i, (name, ty)) in variants.iter().enumerate() {
          if i > 0 {
            f.write_str(", ")?;
          }
          write!(f, "{name}: {ty}")?;
        }
        f.write_str("|}")
      }
    }
  }
}

/// Splits the body of a bundle/enum literal on commas that are not nested
/// inside another `{...}`; a blank body yields no entries.
fn split_top_level_commas(body: &str) -> Vec<&str> {
  if body.trim().is_empty() {
    return Vec::new();
  }
  let mut parts = Vec::new();
  let mut depth = 0usize;
  let mut start = 0;
  for (i, c) in body.char_indices() {
    match c {
      '{' => depth += 1,
      '}' => depth = depth.saturating_sub(1),
      ',' if depth == 0 => {
        parts.push(&body[start..i]);
        start = i + 1;
      }
      _ => {}
    }
  }
  parts.push(&body[start..]);
  parts
}

/// Parses one `[flip ]name: type` entry of a bundle body.
fn parse_bundle_field(field: &str) -> Result<(String, Type, bool), String> {
  let (name, ty) = field
    .split_once(':')
    .ok_or_else(|| format!("Invalid field (expected `name: type`): {}", field.trim()))?;
  let words: Vec<&str> = name.split_whitespace().collect();
  // A lone word is the name (even if it is literally `flip`); two words must
  // be the `flip` modifier followed by the name.
  let (name, flip) = match words.as_slice() {
    [name] => (*name, false),
    ["flip", name] => (*name, true),
    _ => return Err(format!("Invalid field name: {}", name.trim())),
  };
  Ok((name.to_string(), Type::from_str(ty)?, flip))
}

/// Parses one `name[: type]` entry of an enum body.
fn parse_enum_variant(variant: &str) -> Result<(String, Option<Type>), String> {
  let (name, ty) = match variant.split_once(':') {
    Some((name, ty)) => {
      let ty = Type::from_str(ty)
        .map_err(|e| format!("Invalid type: {} ({})", ty.trim(), e))?;
      (name.trim(), Some(ty))
    }
    None => (variant.trim(), None),
  };
  if name.is_empty() || name.contains(char::is_whitespace) {
    return Err(format!("Invalid variant name: {:?}", name));
  }
  Ok((name.to_string(), ty))
}

impl FromStr for Type {
  type Err = String;

  /// Parses the textual form produced by `Display`.
  ///
  /// Whitespace around every component is ignored, so the `", "` and `": "`
  /// separators emitted by `Display` round-trip.
  ///
  /// # Errors
  /// Returns a human-readable message for malformed input (never panics).
  fn from_str(s: &str) -> Result<Self, Self::Err> {
    let s = s.trim();

    // Enums must be recognised before bundles: `{|...|}` also starts with `{`
    // and ends with `}`. Chained strip_* also rejects the overlapping `{|}`.
    if let Some(body) = s.strip_prefix("{|").and_then(|r| r.strip_suffix("|}")) {
      let variants = split_top_level_commas(body)
        .into_iter()
        .map(parse_enum_variant)
        .collect::<Result<Vec<_>, String>>()?;
      return Ok(Type::union(variants));
    }

    if let Some(body) = s.strip_prefix('{').and_then(|r| r.strip_suffix('}')) {
      let fields = split_top_level_commas(body)
        .into_iter()
        .map(parse_bundle_field)
        .collect::<Result<Vec<_>, String>>()?;
      return Ok(Type::bundle(fields));
    }

    // Vectors: split on the LAST `x`, so `i8x2x3` is a depth-3 vector of
    // `i8x2`, mirroring how `Display` prints nested vectors. Any `x` inside a
    // bundle/enum base precedes the final `x<depth>`.
    if let Some((base, depth)) = s.rsplit_once('x') {
      let base = Type::from_str(base)?;
      let depth = depth
        .trim()
        .parse()
        .map_err(|e| format!("Invalid depth: {}", e))?;
      return Ok(Type::Vector(Box::new(base), depth));
    }

    if s == "integer" {
      return Ok(Type::Integer);
    }
    let width = |digits: &str| -> Result<u32, String> {
      digits
        .parse()
        .map_err(|e| format!("Invalid width in type {}: {}", s, e))
    };
    if let Some(w) = s.strip_prefix('i') {
      Ok(Type::UInt(width(w)?))
    } else if let Some(w) = s.strip_prefix('r') {
      Ok(Type::Ref(width(w)?))
    } else if let Some(w) = s.strip_prefix('s') {
      Ok(Type::SInt(width(w)?))
    } else {
      Err(format!("Invalid type string: {}", s))
    }
  }
}
204
205#[derive(Debug, Clone)]
206pub struct Value {
207  pub ty: Option<Type>,
208  pub name: Option<String>,
209}
210
211impl Value {
212  pub fn new(ty: Type, name: Option<String>) -> Self {
213    Value { ty: Some(ty), name }
214  }
215
216  pub fn new_wo_ty(name: Option<String>) -> Self {
217    Value { ty: None, name }
218  }
219}
220
221impl ValueId {
222  pub fn ty(&self, t: &SlotMap<ValueId, Value>) -> Option<Type> {
223    t[*self].ty.clone()
224  }
225  pub fn name<'r>(&self, t: &'r SlotMap<ValueId, Value>) -> &'r Option<String> {
226    &t[*self].name
227  }
228  pub fn name_mut<'r>(
229    &self,
230    t: &'r mut SlotMap<ValueId, Value>,
231  ) -> &'r mut Option<String> {
232    &mut t[*self].name
233  }
234}
235
/// Slot map from [`ValueId`] to [`Value`] — the same container type that the
/// `ValueId::ty` / `name` / `name_mut` accessors take as `t`.
pub type ValueMap = SlotMap<ValueId, Value>;
237
238pub trait OpIO {
239  fn num_inputs(&self) -> usize;
240  fn input(&self, i: usize) -> ValueId;
241  fn input_mut(&mut self, i: usize) -> &mut ValueId;
242  fn inputs(&self) -> impl Iterator<Item = ValueId> + '_ {
243    (0..self.num_inputs()).map(move |i| self.input(i))
244  }
245  fn map_inputs(&mut self, mut f: impl FnMut(ValueId) -> ValueId) {
246    for i in 0..self.num_inputs() {
247      *self.input_mut(i) = f(self.input(i));
248    }
249  }
250  fn num_outputs(&self) -> usize;
251  fn output(&self, i: usize) -> ValueId;
252  fn output_mut(&mut self, i: usize) -> &mut ValueId;
253  fn outputs(&self) -> impl Iterator<Item = ValueId> + '_ {
254    (0..self.num_outputs()).map(move |i| self.output(i))
255  }
256  fn map_outputs(&mut self, mut f: impl FnMut(ValueId) -> ValueId) {
257    for i in 0..self.num_outputs() {
258      *self.output_mut(i) = f(self.output(i));
259    }
260  }
261  fn values(&self) -> impl Iterator<Item = ValueId> + '_ {
262    self.inputs().chain(self.outputs())
263  }
264  fn map_values(&mut self, mut f: impl FnMut(ValueId) -> ValueId) {
265    self.map_inputs(&mut f);
266    self.map_outputs(&mut f);
267  }
268  fn attr_eq(&self, _rhs: &Self) -> bool {
269    true
270  }
271  fn attr_hash<H: std::hash::Hasher>(&self, _state: &mut H) {}
272}
273
274impl OpIO for ValueId {
275  fn num_inputs(&self) -> usize {
276    0
277  }
278  fn input(&self, _i: usize) -> ValueId {
279    panic!("ValueId has no inputs");
280  }
281  fn input_mut(&mut self, _i: usize) -> &mut ValueId {
282    panic!("ValueId has no inputs");
283  }
284  fn num_outputs(&self) -> usize {
285    1
286  }
287  fn output(&self, i: usize) -> ValueId {
288    assert_eq!(i, 0);
289    *self
290  }
291  fn output_mut(&mut self, i: usize) -> &mut ValueId {
292    assert_eq!(i, 0);
293    self
294  }
295  fn attr_eq(&self, rhs: &Self) -> bool {
296    *self == *rhs
297  }
298  fn attr_hash<H: std::hash::Hasher>(&self, state: &mut H) {
299    std::hash::Hash::hash(&self, state);
300  }
301}
302
303impl<T: OpIO> OpIO for Box<T> {
304  fn num_inputs(&self) -> usize {
305    self.deref().num_inputs()
306  }
307  fn input(&self, i: usize) -> ValueId {
308    self.deref().input(i)
309  }
310  fn input_mut(&mut self, i: usize) -> &mut ValueId {
311    self.deref_mut().input_mut(i)
312  }
313  fn num_outputs(&self) -> usize {
314    self.deref().num_outputs()
315  }
316  fn output(&self, i: usize) -> ValueId {
317    self.deref().output(i)
318  }
319  fn output_mut(&mut self, i: usize) -> &mut ValueId {
320    self.deref_mut().output_mut(i)
321  }
322  fn attr_eq(&self, rhs: &Self) -> bool {
323    self.deref().attr_eq(rhs.deref())
324  }
325  fn attr_hash<H: std::hash::Hasher>(&self, state: &mut H) {
326    self.deref().attr_hash(state)
327  }
328}
329
330impl<T: OpIO> OpIO for Vec<T> {
331  fn num_inputs(&self) -> usize {
332    self.iter().map(|t| t.num_inputs()).sum()
333  }
334  fn input(&self, i: usize) -> ValueId {
335    let mut offset = 0;
336    for t in self {
337      if i < offset + t.num_inputs() {
338        return t.input(i - offset);
339      }
340      offset += t.num_inputs();
341    }
342    panic!("Index out of bounds");
343  }
344  fn input_mut(&mut self, i: usize) -> &mut ValueId {
345    let mut offset = 0;
346    for t in self {
347      if i < offset + t.num_inputs() {
348        return t.input_mut(i - offset);
349      }
350      offset += t.num_inputs();
351    }
352    panic!("Index out of bounds");
353  }
354  fn num_outputs(&self) -> usize {
355    self.iter().map(|t| t.num_outputs()).sum()
356  }
357  fn output(&self, i: usize) -> ValueId {
358    let mut offset = 0;
359    for t in self {
360      if i < offset + t.num_outputs() {
361        return t.output(i - offset);
362      }
363      offset += t.num_outputs();
364    }
365    panic!("Index out of bounds");
366  }
367  fn output_mut(&mut self, i: usize) -> &mut ValueId {
368    let mut offset = 0;
369    for t in self {
370      if i < offset + t.num_outputs() {
371        return t.output_mut(i - offset);
372      }
373      offset += t.num_outputs();
374    }
375    panic!("Index out of bounds");
376  }
377  fn attr_eq(&self, rhs: &Self) -> bool {
378    self.iter().zip(rhs.iter()).all(|(a, b)| a.attr_eq(b))
379  }
380  fn attr_hash<H: std::hash::Hasher>(&self, state: &mut H) {
381    for t in self {
382      t.attr_hash(state);
383    }
384  }
385}
386
387impl<T: OpIO> OpIO for Option<T> {
388  fn num_inputs(&self) -> usize {
389    self.as_ref().map(|t| t.num_inputs()).unwrap_or(0)
390  }
391
392  fn input(&self, i: usize) -> ValueId {
393    self.as_ref().unwrap().input(i)
394  }
395
396  fn input_mut(&mut self, i: usize) -> &mut ValueId {
397    self.as_mut().unwrap().input_mut(i)
398  }
399
400  fn num_outputs(&self) -> usize {
401    self.as_ref().map(|t| t.num_outputs()).unwrap_or(0)
402  }
403
404  fn output(&self, i: usize) -> ValueId {
405    self.as_ref().unwrap().output(i)
406  }
407
408  fn output_mut(&mut self, i: usize) -> &mut ValueId {
409    self.as_mut().unwrap().output_mut(i)
410  }
411
412  fn attr_eq(&self, rhs: &Self) -> bool {
413    match (self, rhs) {
414      (None, None) => true,
415      (Some(a), Some(b)) => a.attr_eq(b),
416      _ => false,
417    }
418  }
419
420  fn attr_hash<H: std::hash::Hasher>(&self, state: &mut H) {
421    self.as_ref().map(|t| t.attr_hash(state));
422  }
423}
424
425#[derive(Debug, Clone, Copy)]
426pub struct AttrView<'op, T: OpIO>(pub &'op T);
427impl<'op, T: OpIO> AttrView<'op, T> {
428  pub fn new(op: &'op T) -> Self {
429    Self(op)
430  }
431}
432impl<'op, T: OpIO> std::ops::Deref for AttrView<'op, T> {
433  type Target = T;
434  fn deref(&self) -> &Self::Target {
435    self.0
436  }
437}
438impl<'op, T: OpIO> std::cmp::PartialEq for AttrView<'op, T> {
439  fn eq(&self, rhs: &Self) -> bool {
440    self.0.attr_eq(rhs.0)
441  }
442}
443impl<'op, T: OpIO> std::cmp::Eq for AttrView<'op, T> {}
444impl<'op, T: OpIO> std::hash::Hash for AttrView<'op, T> {
445  fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
446    self.0.attr_hash(state)
447  }
448}