// hugr_llvm/emit/func/mailbox.rs

1use std::{borrow::Cow, rc::Rc};
2
3use anyhow::{Result, bail};
4use delegate::delegate;
5use inkwell::{
6    builder::Builder,
7    types::{BasicType, BasicTypeEnum},
8    values::{BasicValue, BasicValueEnum, PointerValue},
9};
10use itertools::{Itertools as _, zip_eq};
11
12#[derive(Eq, PartialEq, Clone)]
13pub struct ValueMailBox<'c> {
14    typ: BasicTypeEnum<'c>,
15    ptr: PointerValue<'c>,
16    name: Cow<'static, str>,
17}
18
/// Joins the non-empty entries of `names` with `_`.
///
/// Empty strings are skipped entirely, so they never produce doubled or
/// leading/trailing separators. Returns an empty string if every entry is
/// empty (or there are none).
fn join_names<'a>(names: impl IntoIterator<Item = &'a str>) -> String {
    let mut joined = String::new();
    for name in names.into_iter().filter(|n| !n.is_empty()) {
        // Every pushed part is non-empty, so an empty buffer means "first part".
        if !joined.is_empty() {
            joined.push('_');
        }
        joined.push_str(name);
    }
    joined
}
26
27impl<'c> ValueMailBox<'c> {
28    pub(super) fn new(
29        typ: impl BasicType<'c>,
30        ptr: PointerValue<'c>,
31        name: Option<String>,
32    ) -> Self {
33        Self {
34            typ: typ.as_basic_type_enum(),
35            ptr,
36            name: name.map_or(Cow::Borrowed(""), Cow::Owned),
37        }
38    }
39    pub fn get_type(&self) -> BasicTypeEnum<'c> {
40        self.typ
41    }
42
43    pub fn name(&self) -> &str {
44        self.name.as_ref()
45    }
46
47    pub fn promise(&self) -> ValuePromise<'c> {
48        ValuePromise(self.clone())
49    }
50
51    pub fn read<'a>(
52        &'a self,
53        builder: &Builder<'c>,
54        labels: impl IntoIterator<Item = &'a str>,
55    ) -> Result<BasicValueEnum<'c>> {
56        let r = builder.build_load(
57            self.ptr,
58            &join_names(
59                labels
60                    .into_iter()
61                    .chain(std::iter::once(self.name.as_ref())),
62            ),
63        )?;
64        debug_assert_eq!(r.get_type(), self.get_type());
65        Ok(r)
66    }
67
68    fn write(&self, builder: &Builder<'c>, v: impl BasicValue<'c>) -> Result<()> {
69        builder.build_store(self.ptr, v)?;
70        Ok(())
71    }
72}
73
74#[must_use]
75pub struct ValuePromise<'c>(ValueMailBox<'c>);
76
77impl<'c> ValuePromise<'c> {
78    pub fn finish(self, builder: &Builder<'c>, v: impl BasicValue<'c>) -> Result<()> {
79        self.0.write(builder, v)
80    }
81
82    delegate! {
83        to self.0 {
84            pub fn get_type(&self) -> BasicTypeEnum<'c>;
85        }
86    }
87}
88
89/// Holds a vector of [`PointerValue`]s pointing to `alloca`s in the first block
90/// of a function.
91#[derive(Eq, PartialEq, Clone)]
92#[allow(clippy::len_without_is_empty)]
93pub struct RowMailBox<'c>(Rc<Vec<ValueMailBox<'c>>>, Cow<'static, str>);
94
95impl<'c> RowMailBox<'c> {
96    #[must_use]
97    pub fn new_empty() -> Self {
98        Self::new(std::iter::empty(), None)
99    }
100
101    pub(super) fn new(
102        mbs: impl IntoIterator<Item = ValueMailBox<'c>>,
103        name: Option<String>,
104    ) -> Self {
105        Self(
106            Rc::new(mbs.into_iter().collect_vec()),
107            name.map_or(Cow::Borrowed(""), Cow::Owned),
108        )
109    }
110
111    /// Returns a [`RowPromise`] that when [`RowPromise::finish`]ed will write to this `RowMailBox`.
112    pub fn promise(&self) -> RowPromise<'c> {
113        RowPromise(self.clone())
114    }
115
116    /// Get the LLVM types of this `RowMailBox`.
117    pub fn get_types(&'_ self) -> impl Iterator<Item = BasicTypeEnum<'c>> + '_ {
118        self.0.iter().map(ValueMailBox::get_type)
119    }
120
121    /// Returns the number of values in this `RowMailBox`.
122    #[must_use]
123    pub fn len(&self) -> usize {
124        self.0.len()
125    }
126
127    /// Read from the inner pointers.
128    pub fn read_vec<'a>(
129        &'a self,
130        builder: &Builder<'c>,
131        labels: impl IntoIterator<Item = &'a str>,
132    ) -> Result<Vec<BasicValueEnum<'c>>> {
133        self.read(builder, labels)
134    }
135
136    /// Read from the inner pointers.
137    pub fn read<'a, R: FromIterator<BasicValueEnum<'c>>>(
138        &'a self,
139        builder: &Builder<'c>,
140        labels: impl IntoIterator<Item = &'a str>,
141    ) -> Result<R> {
142        let labels = labels.into_iter().collect_vec();
143        self.mailboxes()
144            .map(|mb| mb.read(builder, labels.clone()))
145            .collect::<Result<_>>()
146    }
147
148    pub(crate) fn write(
149        &self,
150        builder: &Builder<'c>,
151        vs: impl IntoIterator<Item = BasicValueEnum<'c>>,
152    ) -> Result<()> {
153        let vs = vs.into_iter().collect_vec();
154        #[cfg(debug_assertions)]
155        {
156            let actual_types = vs.clone().into_iter().map(|x| x.get_type()).collect_vec();
157            let expected_types = self.get_types().collect_vec();
158            if actual_types != expected_types {
159                bail!(
160                    "RowMailbox::write: Expected types {:?}, got {:?}",
161                    expected_types,
162                    actual_types
163                );
164            }
165        }
166        zip_eq(self.0.iter(), vs).try_for_each(|(mb, v)| mb.write(builder, v))
167    }
168
169    fn mailboxes(&'_ self) -> impl Iterator<Item = ValueMailBox<'c>> + '_ {
170        self.0.iter().cloned()
171    }
172}
173
174impl<'c> FromIterator<ValueMailBox<'c>> for RowMailBox<'c> {
175    fn from_iter<T: IntoIterator<Item = ValueMailBox<'c>>>(iter: T) -> Self {
176        Self::new(iter, None)
177    }
178}
179
180/// A promise to write values into a `RowMailBox`
181#[must_use]
182#[allow(clippy::len_without_is_empty)]
183pub struct RowPromise<'c>(RowMailBox<'c>);
184
185impl<'c> RowPromise<'c> {
186    /// Consumes the `RowPromise`, writing the values to the promised [`RowMailBox`].
187    pub fn finish(
188        self,
189        builder: &Builder<'c>,
190        vs: impl IntoIterator<Item = BasicValueEnum<'c>>,
191    ) -> Result<()> {
192        self.0.write(builder, vs)
193    }
194
195    delegate! {
196        to self.0 {
197            /// Get the LLVM types of this `RowMailBox`.
198            pub fn get_types(&'_ self) -> impl Iterator<Item=BasicTypeEnum<'c>> + '_;
199            /// Returns the number of values promised to be written.
200            #[must_use] pub fn len(&self) -> usize;
201        }
202    }
203}