1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
//! Modulation and demodulation.
//!
//! This module implements routines for modulation of bits to symbols and
//! demodulation of symbols to LLRs.

use super::channel::ChannelType;
use crate::gf2::GF2;
use ndarray::{ArrayBase, Data, Ix1};
use num_complex::Complex;
use num_traits::{One, Zero};

/// A modulation scheme usable by the simulation.
///
/// Bundles a modulator and a demodulator that agree on the scalar type of the
/// channel (real or complex), and records how many bits each channel symbol
/// carries.
pub trait Modulation: 'static {
    /// Number of bits carried by each channel symbol.
    const BITS_PER_SYMBOL: f64;

    /// Scalar type of the channel symbols.
    type T: ChannelType;

    /// Modulator producing symbols of type `Self::T`.
    type Modulator: Modulator<T = Self::T>;

    /// Demodulator consuming symbols of type `Self::T`.
    type Demodulator: Demodulator<T = Self::T>;
}

/// Converts sequences of bits into channel symbols.
///
/// Implementors define the symbol scalar type `T` and the bit-to-symbol
/// mapping used by [`Modulator::modulate`].
pub trait Modulator: Default + Clone + Send {
    /// Scalar type of the produced symbols.
    type T;

    /// Maps the bits of `codeword` to a vector of channel symbols.
    fn modulate<S: Data<Elem = GF2>>(&self, codeword: &ArrayBase<S, Ix1>) -> Vec<Self::T>;
}

/// Computes bit LLRs from received channel symbols.
pub trait Demodulator: Send {
    /// Scalar type of the received symbols.
    type T;

    /// Computes the LLRs of the bits carried by `symbols`.
    fn demodulate(&self, symbols: &[Self::T]) -> Vec<f64>;

    /// Builds a demodulator for a given channel noise level.
    ///
    /// `noise_sigma` is the standard deviation of the channel noise. For a
    /// complex channel it is the standard deviation of each of the real and
    /// imaginary parts; for a real channel it is the noise standard deviation
    /// itself.
    fn from_noise_sigma(noise_sigma: f64) -> Self;
}

/// BPSK modulation.
///
/// Real-valued modulation carrying one bit per symbol, built from
/// [BpskModulator] and [BpskDemodulator].
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Default)]
pub struct Bpsk {}

impl Modulation for Bpsk {
    const BITS_PER_SYMBOL: f64 = 1.0;
    type T = f64;
    type Modulator = BpskModulator;
    type Demodulator = BpskDemodulator;
}

/// BPSK modulator.
///
/// Sends bit 0 as the real symbol -1.0 and bit 1 as the real symbol +1.0.
#[derive(Debug, Clone, Default)]
pub struct BpskModulator {}

impl BpskModulator {
    /// Creates a new BPSK modulator.
    pub fn new() -> BpskModulator {
        Self::default()
    }

    /// Maps one bit to its antipodal BPSK symbol.
    ///
    /// Panics if `bit` is neither zero nor one.
    fn modulate_bit(bit: GF2) -> f64 {
        match (bit.is_zero(), bit.is_one()) {
            (true, _) => -1.0,
            (false, true) => 1.0,
            (false, false) => panic!("invalid GF2 value"),
        }
    }
}

impl Modulator for BpskModulator {
    type T = f64;

    /// Produces one real symbol per bit of `codeword`, in order.
    fn modulate<S>(&self, codeword: &ArrayBase<S, Ix1>) -> Vec<f64>
    where
        S: Data<Elem = GF2>,
    {
        let mut symbols = Vec::with_capacity(codeword.len());
        for &bit in codeword.iter() {
            symbols.push(Self::modulate_bit(bit));
        }
        symbols
    }
}

/// BPSK demodulator.
///
/// Inverts the mapping of the [BpskModulator]: each received real sample is
/// turned into one LLR, where a positive LLR favours a 0 bit.
#[derive(Debug, Clone, Default)]
pub struct BpskDemodulator {
    // Gain applied to every received sample to obtain its LLR.
    scale: f64,
}

impl BpskDemodulator {
    /// Creates a new BPSK demodulator.
    ///
    /// `noise_sigma` is the standard deviation of the real, zero-mean Gaussian
    /// channel noise.
    pub fn new(noise_sigma: f64) -> BpskDemodulator {
        let variance = noise_sigma * noise_sigma;
        // The gain is negative: the symbol +1 carries a 1 bit, while a
        // positive LLR stands for a 0 bit.
        BpskDemodulator {
            scale: -2.0 / variance,
        }
    }
}

impl Demodulator for BpskDemodulator {
    type T = f64;

    fn from_noise_sigma(noise_sigma: f64) -> BpskDemodulator {
        Self::new(noise_sigma)
    }

    /// Scales every received sample by the demodulator gain, giving one LLR
    /// per symbol.
    fn demodulate(&self, symbols: &[f64]) -> Vec<f64> {
        let scale = self.scale;
        symbols.iter().copied().map(|y| y * scale).collect()
    }
}

/// 8PSK modulation.
///
/// Complex-valued modulation carrying 3 bits per symbol, built from
/// [Psk8Modulator] and [Psk8Demodulator].
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Default)]
pub struct Psk8 {}

impl Modulation for Psk8 {
    type T = Complex<f64>;
    type Modulator = Psk8Modulator;
    type Demodulator = Psk8Demodulator;
    const BITS_PER_SYMBOL: f64 = 3.0;
}

/// 8PSK modulator.
///
/// Follows the Gray-coded 8PSK constellation of DVB-S2. Codewords passed to
/// this modulator must have a length that is a multiple of 3 bits.
#[derive(Debug, Clone, Default)]
pub struct Psk8Modulator {}

impl Psk8Modulator {
    /// Creates a new 8PSK modulator.
    pub fn new() -> Psk8Modulator {
        Self::default()
    }

    /// Maps the 3-bit label `b0 b1 b2` (b0 most significant) to its
    /// unit-energy constellation point.
    fn modulate_bits(b0: GF2, b1: GF2, b2: GF2) -> Complex<f64> {
        let h = (0.5f64).sqrt();
        // Constellation points indexed by the integer value of the label.
        let constellation = [
            Complex::new(h, h),      // 000
            Complex::new(1.0, 0.0),  // 001
            Complex::new(-1.0, 0.0), // 010
            Complex::new(-h, -h),    // 011
            Complex::new(0.0, 1.0),  // 100
            Complex::new(h, -h),     // 101
            Complex::new(-h, h),     // 110
            Complex::new(0.0, -1.0), // 111
        ];
        let index = [b0, b1, b2]
            .iter()
            .fold(0usize, |acc, b| (acc << 1) | usize::from(b.is_one()));
        constellation[index]
    }
}

impl Modulator for Psk8Modulator {
    type T = Complex<f64>;

    /// Modulates a sequence of bits into symbols.
    ///
    /// Each consecutive group of three bits `(b0, b1, b2)` becomes one symbol.
    ///
    /// # Panics
    ///
    /// Panics if the length of the codeword is not a multiple of 3 bits.
    fn modulate<S>(&self, codeword: &ArrayBase<S, Ix1>) -> Vec<Complex<f64>>
    where
        S: Data<Elem = GF2>,
    {
        assert_eq!(codeword.len() % 3, 0);
        let mut bits = codeword.iter().copied();
        let mut symbols = Vec::with_capacity(codeword.len() / 3);
        // Tuple elements are evaluated left to right, so each pass pulls the
        // next three bits in order.
        while let (Some(b0), Some(b1), Some(b2)) = (bits.next(), bits.next(), bits.next()) {
            symbols.push(Self::modulate_bits(b0, b1, b2));
        }
        symbols
    }
}

/// 8PSK demodulator.
///
/// Assumes the same mapping as the [Psk8Modulator]. Demodulates symbols into
/// LLRs using the exact formula implemented with the max-* function.
#[derive(Debug, Clone, Default)]
pub struct Psk8Demodulator {
    scale: f64,
}

impl Psk8Demodulator {
    /// Creates a new 8PSK demodulator.
    ///
    /// The `noise_sigma` indicates the channel noise standard deviation. The
    /// channel noise is assumed to be a circularly symmetric Gaussian with mean
    /// zero and standard deviation `noise_sigma` in its real part and imaginary
    /// part (the total variance is `2 * noise_sigma * noise_sigma`.
    pub fn new(noise_sigma: f64) -> Psk8Demodulator {
        Psk8Demodulator {
            scale: 1.0 / (noise_sigma * noise_sigma),
        }
    }

    fn demodulate_symbol(&self, symbol: Complex<f64>) -> [f64; 3] {
        let a = (0.5f64).sqrt();
        let symbol = symbol * self.scale;
        let d000 = dot(symbol, Complex::new(a, a));
        let d100 = dot(symbol, Complex::new(0.0, 1.0));
        let d110 = dot(symbol, Complex::new(-a, a));
        let d010 = dot(symbol, Complex::new(-1.0, 0.0));
        let d011 = dot(symbol, Complex::new(-a, -a));
        let d111 = dot(symbol, Complex::new(0.0, -1.0));
        let d101 = dot(symbol, Complex::new(a, -a));
        let d001 = dot(symbol, Complex::new(1.0, 0.0));
        let b0 = [d000, d001, d010, d011]
            .into_iter()
            .reduce(maxstar)
            .unwrap()
            - [d100, d101, d110, d111]
                .into_iter()
                .reduce(maxstar)
                .unwrap();
        let b1 = [d000, d001, d100, d101]
            .into_iter()
            .reduce(maxstar)
            .unwrap()
            - [d010, d011, d110, d111]
                .into_iter()
                .reduce(maxstar)
                .unwrap();
        let b2 = [d000, d010, d100, d110]
            .into_iter()
            .reduce(maxstar)
            .unwrap()
            - [d001, d011, d101, d111]
                .into_iter()
                .reduce(maxstar)
                .unwrap();
        [b0, b1, b2]
    }
}

impl Demodulator for Psk8Demodulator {
    type T = Complex<f64>;

    fn from_noise_sigma(noise_sigma: f64) -> Psk8Demodulator {
        Self::new(noise_sigma)
    }

    /// Emits three LLRs per received symbol, in symbol order.
    fn demodulate(&self, symbols: &[Complex<f64>]) -> Vec<f64> {
        let mut llrs = Vec::with_capacity(3 * symbols.len());
        for &symbol in symbols {
            llrs.extend(self.demodulate_symbol(symbol));
        }
        llrs
    }
}

fn dot(a: Complex<f64>, b: Complex<f64>) -> f64 {
    a.re * b.re + a.im * b.im
}

/// Jacobian logarithm ("max-*"): returns `ln(exp(a) + exp(b))`.
///
/// Evaluated as `max(a, b) + ln(1 + exp(-|a - b|))`. The correction term lies
/// in `[0, ln 2]`, so large finite inputs do not overflow.
///
/// Equal infinite arguments are returned unchanged. Otherwise `a - b` would be
/// `inf - inf = NaN` and poison the result, even though
/// `ln(exp(-inf) + exp(-inf)) = -inf` (and likewise for `+inf`).
fn maxstar(a: f64, b: f64) -> f64 {
    if a == b && a.is_infinite() {
        return a;
    }
    a.max(b) + (-(a - b).abs()).exp().ln_1p()
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn bpsk_modulator() {
        let modulator = BpskModulator::new();
        let x = modulator.modulate(&ndarray::arr1(&[GF2::one(), GF2::zero()]));
        assert_eq!(&x, &[1.0, -1.0]);
    }

    #[test]
    fn bpsk_demodulator() {
        let demodulator = BpskDemodulator::new(2.0_f64.sqrt());
        let x = demodulator.demodulate(&[1.0, -1.0]);
        assert_eq!(x.len(), 2);
        let tol = 1e-4;
        assert!((x[0] + 1.0).abs() < tol);
        assert!((x[1] - 1.0).abs() < tol);
    }

    #[test]
    fn psk8_modulator() {
        let o = GF2::one();
        let z = GF2::zero();
        let modulator = Psk8Modulator::new();
        let x = modulator.modulate(&ndarray::arr1(&[o, o, z, z, z, z, o, z, o]));
        let a = (0.5f64).sqrt();
        assert_eq!(
            &x,
            &[Complex::new(-a, a), Complex::new(a, a), Complex::new(a, -a)]
        );
    }

    #[test]
    fn psk8_demodulator_signs() {
        let noise_sigma = 1.0;
        let demodulator = Psk8Demodulator::new(noise_sigma);
        let a = (0.5f64).sqrt();
        let llr = demodulator.demodulate(&[
            Complex::new(1.0, 0.0),
            Complex::new(a, a),
            Complex::new(0.0, 1.0),
        ]);
        // Three LLRs per symbol.
        assert_eq!(llr.len(), 9);
        // 001
        assert!(llr[0] > 0.0);
        assert!(llr[1] > 0.0);
        assert!(llr[2] < 0.0);
        // 000
        assert!(llr[3] > 0.0);
        assert!(llr[4] > 0.0);
        assert!(llr[5] > 0.0);
        // 100
        assert!(llr[6] < 0.0);
        assert!(llr[7] > 0.0);
        assert!(llr[8] > 0.0);
    }

    /// Returns the 3-bit label `b0 b1 b2` (b0 most significant) of `label`.
    fn label_bits(label: usize) -> [GF2; 3] {
        [2usize, 1, 0].map(|shift| {
            if (label >> shift) & 1 == 1 {
                GF2::one()
            } else {
                GF2::zero()
            }
        })
    }

    #[test]
    fn psk8_noiseless_roundtrip_signs() {
        // For every label: modulate, demodulate without noise, and check that
        // each LLR sign matches the transmitted bit (positive means a 0 bit).
        let modulator = Psk8Modulator::new();
        let demodulator = Psk8Demodulator::new(0.5);
        for label in 0..8 {
            let bits = label_bits(label);
            let symbols = modulator.modulate(&ndarray::arr1(&bits));
            assert_eq!(symbols.len(), 1);
            let llr = demodulator.demodulate(&symbols);
            assert_eq!(llr.len(), 3);
            for (value, bit) in llr.iter().zip(bits.iter()) {
                if bit.is_one() {
                    assert!(*value < 0.0, "label {}: {:?}", label, llr);
                } else {
                    assert!(*value > 0.0, "label {}: {:?}", label, llr);
                }
            }
        }
    }

    #[test]
    fn psk8_demodulator_matches_exact_llr() {
        // Compare against a direct evaluation of
        // ln sum_{x: b = 0} p(y | x) - ln sum_{x: b = 1} p(y | x)
        // for circularly symmetric Gaussian noise.
        let noise_sigma = 0.8;
        let modulator = Psk8Modulator::new();
        let demodulator = Psk8Demodulator::new(noise_sigma);
        let points: Vec<Complex<f64>> = (0..8)
            .map(|label| modulator.modulate(&ndarray::arr1(&label_bits(label)))[0])
            .collect();
        let received = [
            Complex::new(0.3, -0.7),
            Complex::new(-1.2, 0.4),
            Complex::new(0.05, 0.02),
        ];
        for &y in &received {
            let llr = demodulator.demodulate(&[y]);
            assert_eq!(llr.len(), 3);
            for (k, &value) in llr.iter().enumerate() {
                let shift = 2 - k;
                let (mut p0, mut p1) = (0.0, 0.0);
                for (label, &x) in points.iter().enumerate() {
                    let likelihood =
                        (-(y - x).norm_sqr() / (2.0 * noise_sigma * noise_sigma)).exp();
                    if (label >> shift) & 1 == 0 {
                        p0 += likelihood;
                    } else {
                        p1 += likelihood;
                    }
                }
                let expected = (p0 / p1).ln();
                assert!(
                    (value - expected).abs() < 1e-9,
                    "y = {:?}, bit {}: {} vs {}",
                    y,
                    k,
                    value,
                    expected
                );
            }
        }
    }
}