1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#![feature(test)]
/// Constructs a Givens plane rotation, following the classic reference BLAS `SROTG`.
///
/// For the point `(sa, sb)` this computes `r`, `c` and `s` with `c = sa / r`,
/// `s = sb / r` and `r² = sa² + sb²` (up to rounding), so that
///
/// ```text
/// [  c  s ] [ sa ]   [ r ]
/// [ -s  c ] [ sb ] = [ 0 ]
/// ```
///
/// # Returns
///
/// `(r, z, c, s)`, i.e. the values BLAS writes back into `(sa, sb, c, s)`:
///
/// * `r` — signed norm; it carries the sign of whichever input has the larger
///   magnitude (`sb` when the magnitudes are equal).
/// * `z` — reconstruction value: `s` if `|sa| > |sb|`, otherwise `1 / c`, or `1`
///   when `c == 0`.
/// * `c`, `s` — cosine and sine of the rotation.
///
/// When `sa == sb == 0` the identity rotation `(0, 0, 1, 0)` is returned.
pub fn srotg(sa: f32, sb: f32) -> (f32, f32, f32, f32) {
    let abs_sa = sa.abs();
    let abs_sb = sb.abs();
    let abs_sa_gt_abs_sb = abs_sa > abs_sb;
    // `roe` supplies the sign of `r`: the input with the larger magnitude, `sb` on ties.
    let roe = if abs_sa_gt_abs_sb { sa } else { sb };
    let scale = abs_sa + abs_sb;
    if scale == 0.0 {
        // Both inputs are zero: return the identity rotation (c = 1, s = 0).
        (0.0, 0.0, 1.0, 0.0)
    } else {
        // Dividing by `scale` keeps both ratios in [-1, 1], so squaring them cannot overflow.
        let sa_scale = sa / scale;
        let sb_scale = sb / scale;
        // Unless an input is NaN, `roe` is non-zero here, so `signum` is exactly ±1.
        let r = scale * (sa_scale * sa_scale + sb_scale * sb_scale).sqrt() * roe.signum();
        let c = sa / r;
        let s = sb / r;
        let z = if abs_sa_gt_abs_sb {
            s
        } else if c != 0.0 {
            1.0 / c
        } else {
            1.0
        };
        (r, z, c, s)
    }
}
#[cfg(test)]
mod tests {
    extern crate test;
    extern crate blas;
    // Not referenced by name below; presumably declared so that the native OpenBLAS
    // library backing `blas` gets linked.
    extern crate openblas_src;
    use test::{Bencher, black_box};
    use super::*;

    /// Runs the linked BLAS `srotg` on `(sa, sb)` and returns its outputs in the same
    /// `(r, z, c, s)` order as [`srotg`].
    fn blas_srotg(sa: f32, sb: f32) -> (f32, f32, f32, f32) {
        let mut r = sa;
        let mut z = sb;
        let mut c = 0.0;
        let mut s = 0.0;
        // SAFETY: all four arguments are exclusive references to live local `f32`s.
        unsafe {
            blas::srotg(&mut r, &mut z, &mut c, &mut s);
        }
        (r, z, c, s)
    }

    /// Asserts that [`srotg`] agrees with the linked BLAS for the input `(sa, sb)`.
    fn assert_matches_blas(sa: f32, sb: f32) {
        let expected = blas_srotg(sa, sb);
        let actual = srotg(sa, sb);
        let all_close = float_eq(expected.0, actual.0)
            && float_eq(expected.1, actual.1)
            && float_eq(expected.2, actual.2)
            && float_eq(expected.3, actual.3);
        assert!(
            all_close,
            "srotg({}, {}): blas = {:?}, ours = {:?}",
            sa, sb, expected, actual
        );
    }

    #[test]
    fn test_srotg() {
        let mut sa: f32 = -10.23;
        while sa < 10.23 {
            // Reset `sb` for every `sa`; otherwise the inner loop only runs for the first `sa`.
            let mut sb: f32 = -10.23;
            while sb < 10.23 {
                assert_matches_blas(sa, sb);
                sb += 1.0;
            }
            sa += 1.0;
        }
    }

    #[test]
    fn test_srotg_zero_and_axis_inputs() {
        // The grid above steps from -10.23 by 1.0 and never lands on an exact zero,
        // so the zero/axis branches are covered explicitly here.
        assert_eq!(srotg(0.0, 0.0), (0.0, 0.0, 1.0, 0.0));
        let cases: [(f32, f32); 5] = [(0.0, 0.0), (0.0, 2.5), (-2.5, 0.0), (3.0, 4.0), (-4.0, 3.0)];
        for &(sa, sb) in cases.iter() {
            assert_matches_blas(sa, sb);
        }
    }

    #[bench]
    fn bench_srotg_openblas(bencher: &mut Bencher) {
        bencher.iter(|| {
            // Black-box the inputs so the call cannot be specialised on constants.
            let mut blas_sa = black_box(15.0_f32);
            let mut blas_sb = black_box(55.3_f32);
            let mut blas_c = 0.0_f32;
            let mut blas_s = 0.0_f32;
            // SAFETY: all four arguments are exclusive references to live local `f32`s.
            unsafe {
                blas::srotg(&mut blas_sa, &mut blas_sb, &mut blas_c, &mut blas_s);
            }
            // Returning the outputs lets `Bencher::iter` black-box them.
            (blas_sa, blas_sb, blas_c, blas_s)
        });
    }

    #[bench]
    fn bench_srotg_blaster(bencher: &mut Bencher) {
        // Black-box the inputs so the pure function cannot be constant-folded away.
        bencher.iter(|| srotg(black_box(15.0), black_box(55.3)));
    }

    /// Returns `true` when `a` and `b` agree to roughly five significant digits.
    ///
    /// The tolerance is relative to the larger magnitude (floored at 1.0). Outputs
    /// such as `r` or `z = 1 / c` can be large enough that a fixed absolute bound
    /// would be tighter than the `f32` spacing at that magnitude.
    fn float_eq(a: f32, b: f32) -> bool {
        const REL_TOL: f32 = 1e-5;
        (a - b).abs() <= REL_TOL * a.abs().max(b.abs()).max(1.0)
    }
}