mad/
mad.rs

1//
2// Copyright (c) 2025 Ɓukasz Szpakowski
3//
4// This Source Code Form is subject to the terms of the Mozilla Public
5// License, v. 2.0. If a copy of the MPL was not distributed with this
6// file, You can obtain one at https://mozilla.org/MPL/2.0/.
7//
8use std::env;
9use std::process::exit;
10use std::time::Instant;
11use unmtx_gpu::*;
12
13fn create_matrix(n: usize, m: usize, is_scalar: bool) -> Matrix
14{
15    let mut elems: Vec<f32> = vec![0.0f32; n * m];
16    let scalar = if is_scalar {
17        (n * m) as f32
18    } else {
19        1.0f32
20    };
21    for i in 0..n {
22        for j in 0..m {
23            elems[m * i + j] = ((m * i + j) as f32) * scalar;
24        }
25    }
26    Matrix::new_with_elems(n, m, elems.as_slice())
27}
28
29fn main()
30{
31    let args: Vec<String> = env::args().collect();
32    let n: usize = match args.get(1) {
33        Some(s) => {
34            match s.parse::<usize>() {
35                Ok(tmp_n) => tmp_n,
36                Err(err) => {
37                    eprintln!("{}", err);
38                    exit(1);
39                },
40            }
41        },
42        None => 100,
43    };
44    let m: usize = match args.get(2) {
45        Some(s) => {
46            match s.parse::<usize>() {
47                Ok(tmp_m) => tmp_m,
48                Err(err) => {
49                    eprintln!("{}", err);
50                    exit(1);
51                },
52            }
53        },
54        None => 100,
55    };
56    let l: usize = match args.get(3) {
57        Some(s) => {
58            match s.parse::<usize>() {
59                Ok(tmp_l) => tmp_l,
60                Err(err) => {
61                    eprintln!("{}", err);
62                    exit(1);
63                },
64            }
65        },
66        None => 100,
67    };
68    let frontend = match Frontend::new() {
69        Ok(tmp_frontend) => tmp_frontend,
70        Err(err) => {
71            eprintln!("{}", err);
72            exit(1);
73        },
74    };
75    println!("backend: {}", frontend.backend().name());
76    let a = create_matrix(n, l, false);
77    let b = create_matrix(l, m, false);
78    let c = create_matrix(n, m, true);
79    let now = Instant::now();
80    let d = a * b + c;
81    let duration = now.elapsed();
82    let elems = d.elems();
83    let sum = elems.iter().fold(0.0f32, |x, y| x + y);
84    println!("sum: {}", sum);
85    println!("time: {}.{:06}", duration.as_secs(), duration.as_micros() % 1000000);
86    match finalize_default_backend() {
87        Ok(()) => (),
88        Err(err) => {
89            eprintln!("{}", err);
90            exit(1);
91        },
92    }
93}