Skip to main content

oxilean_codegen/futhark_backend/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use super::types::{
6    FutharkAttr, FutharkBackend, FutharkExpr, FutharkFeatureFlags, FutharkFun, FutharkModule,
7    FutharkStmt, FutharkType, FutharkTypeAlias,
8};
9
10/// Build a 1-D array type with a named size parameter.
11pub fn array1(elem: FutharkType, size: impl Into<String>) -> FutharkType {
12    FutharkType::Array(Box::new(elem), vec![Some(size.into())])
13}
14/// Build a 1-D array type with an anonymous size.
15pub fn array1_dyn(elem: FutharkType) -> FutharkType {
16    FutharkType::Array(Box::new(elem), vec![None])
17}
18/// Build a 2-D array type.
19pub fn array2(elem: FutharkType, rows: impl Into<String>, cols: impl Into<String>) -> FutharkType {
20    FutharkType::Array(Box::new(elem), vec![Some(rows.into()), Some(cols.into())])
21}
22/// Build a binary lambda: `\\ (x: t) (y: t) -> body`.
23pub fn bin_lambda(
24    x: impl Into<String>,
25    y: impl Into<String>,
26    ty: FutharkType,
27    body: FutharkExpr,
28) -> FutharkExpr {
29    FutharkExpr::Lambda(vec![(x.into(), ty.clone()), (y.into(), ty)], Box::new(body))
30}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a variable-reference expression.
    fn var(s: &str) -> FutharkExpr {
        FutharkExpr::Var(s.to_string())
    }

    /// Build the lambda `\(x: ty) (y: ty) -> x <op> y`, the shape every
    /// reduce/scan/map2 test below feeds to the backend.
    fn binop_lambda(x: &str, y: &str, op: &str, ty: FutharkType) -> FutharkExpr {
        bin_lambda(
            x,
            y,
            ty,
            FutharkExpr::BinOp(op.to_string(), Box::new(var(x)), Box::new(var(y))),
        )
    }

    #[test]
    fn test_type_display() {
        assert_eq!(FutharkType::I32.to_string(), "i32");
        assert_eq!(FutharkType::F64.to_string(), "f64");
        assert_eq!(FutharkType::Bool.to_string(), "bool");
        let arr = array1(FutharkType::F32, "n");
        assert_eq!(arr.to_string(), "[n]f32");
        let arr2 = array2(FutharkType::I64, "n", "m");
        assert_eq!(arr2.to_string(), "[n][m]i64");
        let tup = FutharkType::Tuple(vec![FutharkType::I32, FutharkType::F32]);
        assert_eq!(tup.to_string(), "(i32, f32)");
        let rec = FutharkType::Record(vec![
            ("x".to_string(), FutharkType::F64),
            ("y".to_string(), FutharkType::F64),
        ]);
        assert_eq!(rec.to_string(), "{x: f64, y: f64}");
    }

    #[test]
    fn test_emit_iota_replicate() {
        let mut be = FutharkBackend::new();
        be.emit_expr(&FutharkExpr::Iota(Box::new(FutharkExpr::IntLit(
            10,
            FutharkType::I64,
        ))));
        assert_eq!(be.finish(), "iota 10i64");
        let mut be = FutharkBackend::new();
        be.emit_expr(&FutharkExpr::Replicate(
            Box::new(FutharkExpr::IntLit(5, FutharkType::I64)),
            Box::new(FutharkExpr::FloatLit(0.0, FutharkType::F32)),
        ));
        assert_eq!(be.finish(), "replicate 5i64 0f32");
    }

    #[test]
    fn test_emit_map_reduce() {
        let mut be = FutharkBackend::new();
        be.emit_expr(&FutharkExpr::Reduce(
            Box::new(binop_lambda("a", "b", "+", FutharkType::I32)),
            Box::new(FutharkExpr::IntLit(0, FutharkType::I32)),
            Box::new(var("xs")),
        ));
        let out = be.finish();
        assert!(out.starts_with("reduce"), "got: {out}");
        assert!(out.contains("->"), "no arrow in lambda: {out}");
        assert!(out.contains("xs"), "no array var: {out}");
    }

    #[test]
    fn test_emit_let_in() {
        let mut be = FutharkBackend::new();
        be.emit_expr(&FutharkExpr::LetIn(
            "tmp".to_string(),
            Some(FutharkType::I32),
            Box::new(FutharkExpr::IntLit(42, FutharkType::I32)),
            Box::new(var("tmp")),
        ));
        let out = be.finish();
        assert!(out.contains("let tmp"), "missing let: {out}");
        assert!(out.contains("42i32"), "missing value: {out}");
        assert!(out.contains("in tmp"), "missing in: {out}");
    }

    #[test]
    fn test_emit_function() {
        // dot_product xs ys = reduce (+) 0 (map2 (*) xs ys)
        let products = FutharkExpr::Map2(
            Box::new(binop_lambda("x", "y", "*", FutharkType::F32)),
            Box::new(var("xs")),
            Box::new(var("ys")),
        );
        let fun = FutharkFun::new(
            "dot_product",
            vec![
                ("xs".to_string(), array1_dyn(FutharkType::F32)),
                ("ys".to_string(), array1_dyn(FutharkType::F32)),
            ],
            FutharkType::F32,
            vec![FutharkStmt::ReturnExpr(FutharkExpr::Reduce(
                Box::new(binop_lambda("a", "b", "+", FutharkType::F32)),
                Box::new(FutharkExpr::FloatLit(0.0, FutharkType::F32)),
                Box::new(products),
            ))],
        );
        let mut be = FutharkBackend::new();
        be.emit_fun(&fun);
        let out = be.finish();
        assert!(out.contains("let dot_product"), "fn name: {out}");
        assert!(out.contains("f32"), "return type: {out}");
        assert!(out.contains("reduce"), "body reduce: {out}");
        assert!(out.contains("map2"), "body map2: {out}");
    }

    #[test]
    fn test_emit_entry_point() {
        let fun = FutharkFun::entry(
            "main",
            vec![("input".to_string(), array1_dyn(FutharkType::F64))],
            FutharkType::F64,
            vec![FutharkStmt::ReturnExpr(FutharkExpr::Reduce(
                Box::new(binop_lambda("a", "b", "+", FutharkType::F64)),
                Box::new(FutharkExpr::FloatLit(0.0, FutharkType::F64)),
                Box::new(var("input")),
            ))],
        );
        let mut be = FutharkBackend::new();
        be.emit_fun(&fun);
        let out = be.finish();
        assert!(out.starts_with("entry main"), "entry keyword: {out}");
    }

    #[test]
    fn test_emit_full_module() {
        let mut module = FutharkModule::new();
        module.set_doc("Matrix operations");
        module.add_open("import \"futlib/math\"");
        module.add_type(FutharkTypeAlias {
            name: "Matrix".to_string(),
            params: vec!["t".to_string()],
            ty: FutharkType::Array(
                Box::new(FutharkType::Array(
                    Box::new(FutharkType::Named("t".to_string())),
                    vec![Some("n".to_string())],
                )),
                vec![Some("m".to_string())],
            ),
            is_opaque: false,
        });
        module.add_fun(FutharkFun::entry(
            "matmul",
            vec![
                ("a".to_string(), array2(FutharkType::F32, "n", "k")),
                ("b".to_string(), array2(FutharkType::F32, "k", "m")),
            ],
            array2(FutharkType::F32, "n", "m"),
            vec![FutharkStmt::ReturnExpr(var("a"))],
        ));
        let src = FutharkBackend::generate(&module);
        assert!(src.contains("-- | Matrix operations"), "doc: {src}");
        assert!(src.contains("open import"), "open: {src}");
        assert!(src.contains("type Matrix"), "type alias: {src}");
        assert!(src.contains("entry matmul"), "entry: {src}");
    }

    #[test]
    fn test_attrs_and_scan() {
        let fun = FutharkFun::new(
            "prefix_sum",
            vec![("xs".to_string(), array1_dyn(FutharkType::I32))],
            array1_dyn(FutharkType::I32),
            vec![FutharkStmt::ReturnExpr(FutharkExpr::Scan(
                Box::new(binop_lambda("a", "b", "+", FutharkType::I32)),
                Box::new(FutharkExpr::IntLit(0, FutharkType::I32)),
                Box::new(var("xs")),
            ))],
        )
        .with_attr(FutharkAttr::Inline);
        let mut be = FutharkBackend::new();
        be.emit_fun(&fun);
        let out = be.finish();
        assert!(out.contains("#[inline]"), "attr: {out}");
        assert!(out.contains("scan"), "scan: {out}");
        assert!(out.contains("prefix_sum"), "name: {out}");
        assert_eq!(FutharkAttr::Inline.to_string(), "#[inline]");
        assert_eq!(FutharkAttr::NoInline.to_string(), "#[noinline]");
        assert_eq!(FutharkAttr::NoMap.to_string(), "#[nomap]");
        assert_eq!(FutharkAttr::Sequential.to_string(), "#[sequential]");
        assert_eq!(
            FutharkAttr::Custom("fusable".to_string()).to_string(),
            "#[fusable]"
        );
    }
}
/// Look up the Futhark infix spelling of a binary operator.
///
/// `op` may be either a mnemonic name (`"add"`, `"shl"`, `"pow"`, …) or the
/// operator symbol itself (`"+"`, `"<<"`, `"**"`, …). Returns `None` when
/// the operator is not in the table, so callers can report the error instead
/// of silently emitting the wrong operator.
#[allow(dead_code)]
pub fn futhark_try_binop_str(op: &str) -> Option<&'static str> {
    let sym = match op {
        "add" | "+" => "+",
        "sub" | "-" => "-",
        "mul" | "*" => "*",
        "div" | "/" => "/",
        "rem" | "%" => "%",
        // Futhark's truncating division/remainder and exponentiation.
        "//" => "//",
        "%%" => "%%",
        "pow" | "**" => "**",
        "and" | "&&" => "&&",
        "or" | "||" => "||",
        "eq" | "==" => "==",
        "ne" | "!=" => "!=",
        "lt" | "<" => "<",
        "le" | "<=" => "<=",
        "gt" | ">" => ">",
        "ge" | ">=" => ">=",
        "band" | "&" => "&",
        "bor" | "|" => "|",
        "bxor" | "^" => "^",
        "shl" | "<<" => "<<",
        "shr" | ">>" => ">>",
        _ => return None,
    };
    Some(sym)
}

/// Futhark operator table: map a binary operator name or symbol to its
/// Futhark infix spelling.
///
/// Unknown operators fall back to `"+"`, which is kept for backward
/// compatibility. Prefer [`futhark_try_binop_str`] when an unknown operator
/// should be treated as an error.
#[allow(dead_code)]
pub fn futhark_binop_str(op: &str) -> &'static str {
    futhark_try_binop_str(op).unwrap_or("+")
}
/// Look up the Futhark prefix spelling of a unary operator.
///
/// `op` may be a mnemonic name (`"neg"`, `"not"`, `"bnot"`) or the symbol
/// itself. Returns `None` for operators not in the table.
// NOTE(review): Futhark's reference lists `!` as bitwise complement on
// integral types; confirm that `~` is accepted by the targeted compiler.
#[allow(dead_code)]
pub fn futhark_try_unop_str(op: &str) -> Option<&'static str> {
    match op {
        "neg" | "-" => Some("-"),
        "not" | "!" => Some("!"),
        "bnot" | "~" => Some("~"),
        _ => None,
    }
}

/// Map a unary operator name or symbol to its Futhark prefix spelling.
///
/// Unknown operators fall back to `"-"` for backward compatibility; use
/// [`futhark_try_unop_str`] to detect them instead.
#[allow(dead_code)]
pub fn futhark_unop_str(op: &str) -> &'static str {
    futhark_try_unop_str(op).unwrap_or("-")
}
/// Version string associated with this Futhark codegen pass
/// (presumably the targeted Futhark release — confirm with the pass owner).
#[allow(dead_code)]
pub const FUTHARK_PASS_VERSION: &'static str = "0.25.0";
/// Name of the default Futhark compiler backend.
#[allow(dead_code)]
pub const FUTHARK_DEFAULT_BACKEND: &'static str = "opencl";
/// Upper bound used for inlining depth.
#[allow(dead_code)]
pub const FUTHARK_MAX_INLINE: usize = 20;
/// Spelling of Futhark's one-array `map` SOAC.
#[allow(dead_code)]
pub const FUTHARK_MAP1: &'static str = "map";
/// Spelling of Futhark's `reduce` SOAC.
#[allow(dead_code)]
pub const FUTHARK_REDUCE: &'static str = "reduce";
276/// Futhark default feature flags
277#[allow(dead_code)]
278pub fn futhark_default_features() -> FutharkFeatureFlags {
279    FutharkFeatureFlags {
280        enable_unsafe: false,
281        enable_in_place_updates: true,
282        enable_streaming: false,
283        enable_loop_fusion: true,
284        enable_double_buffering: false,
285    }
286}
/// Integer version marker for generated entry points.
#[allow(dead_code)]
pub const FUTHARK_ENTRY_VERSION: u32 = 1u32;
/// Render `title` followed by `body` as a block of Futhark line comments.
///
/// Each line of both arguments (split on `\n`) gets its own `-- ` prefix and
/// a trailing newline. Multi-line text therefore stays inside the comment
/// instead of leaking into the generated source. An empty argument still
/// produces a bare `-- ` line, so single-line inputs render exactly as
/// `"-- {title}\n-- {body}\n"`.
#[allow(dead_code)]
pub fn futhark_emit_comment_block(title: &str, body: &str) -> String {
    // Rough capacity: the text itself plus a few prefix/newline bytes.
    let mut out = String::with_capacity(title.len() + body.len() + 8);
    for line in title.split('\n').chain(body.split('\n')) {
        out.push_str("-- ");
        out.push_str(line);
        out.push('\n');
    }
    out
}