inspire 0.2.0

InsPIRe: Communication-Efficient PIR with Server-side Preprocessing
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
//! Database encoding for PIR
//!
//! Encodes database entries as polynomial coefficients for PIR queries.
//!
//! # Direct Coefficient Encoding
//!
//! Values are stored directly as polynomial coefficients:
//! - Given values [y_0, y_1, ..., y_{t-1}], create polynomial h(X) = Σ y_k · X^k
//! - The value y_k is stored at coefficient k
//!
//! To retrieve y_k, the query encrypts X^(-k) (the inverse monomial).
//! Multiplying h(X) · X^(-k) rotates coefficients so y_k appears at position 0.
//!
//! In R_q = Z_q\[X\]/(X^d + 1), we have X^d = -1, so:
//! - X^(-k) = -X^(d-k) for k > 0
//! - X^(-0) = X^0 = 1

use crate::math::Poly;
use crate::params::{InspireParams, ShardConfig};

use super::setup::ShardData;

/// Encode one database column as the coefficients of a ring polynomial.
///
/// For column values `[y_0, y_1, ..., y_{t-1}]` the result is `h(X) = Σ y_k · X^k`:
/// value `y_k` sits on coefficient `k`, so it can later be rotated to position 0
/// by multiplying with the inverse monomial `X^(-k)`.
///
/// # Arguments
/// * `column` - Column values to encode (t values); values beyond `params.ring_dim`
///   are dropped by [`encode_direct`]
/// * `params` - System parameters (ring dimension, modulus and CRT moduli are read)
///
/// # Returns
/// Polynomial h(X) = Σ y_k · X^k, or the zero polynomial for an empty column
pub fn encode_column(column: &[u64], params: &InspireParams) -> Poly {
    let ring_dim = params.ring_dim;
    let modulus = params.q;

    // An empty column short-circuits to the zero polynomial.
    if !column.is_empty() {
        return encode_direct(column, ring_dim, modulus, params.moduli());
    }
    Poly::zero_moduli(ring_dim, params.moduli())
}

/// Direct coefficient encoding: place each value on its own polynomial coefficient.
///
/// Builds h(X) = y_0 + y_1·X + y_2·X² + ... + y_{t-1}·X^{t-1}, with every stored
/// value reduced modulo `q`. Only the first `d` values fit in the ring; any
/// further values are silently ignored. Unused coefficients are zero.
///
/// # Arguments
/// * `values` - Values to encode at positions 0, 1, ..., t-1
/// * `d` - Ring dimension (number of coefficients in the result)
/// * `q` - Modulus each stored value is reduced by
/// * `moduli` - CRT moduli passed through to `Poly::from_coeffs_moduli`
///
/// # Returns
/// Polynomial with values stored as coefficients
pub fn encode_direct(values: &[u64], d: usize, q: u64, moduli: &[u64]) -> Poly {
    // Reduce only the values that actually fit, then zero-pad up to d coefficients.
    let mut coefficients: Vec<u64> = values.iter().take(d).map(|&value| value % q).collect();
    coefficients.resize(d, 0);
    Poly::from_coeffs_moduli(coefficients, moduli)
}

/// Create inverse monomial X^(-k) mod (X^d + 1)
///
/// In R_q = Z_q\[X\]/(X^d + 1), we have X^d = -1, so:
/// - X^(-k) = X^(2d-k) mod (X^d + 1)
/// - For k > 0: X^(2d-k) = X^(d + (d-k)) = -X^(d-k)
/// - For k = 0: X^0 = 1
///
/// # Arguments
/// * `k` - The exponent (index to retrieve)
/// * `d` - Ring dimension
/// * `q` - Modulus
///
/// # Returns
/// Polynomial representing X^(-k) = -X^(d-k) for k > 0, or 1 for k = 0
pub fn inverse_monomial(k: usize, d: usize, q: u64, moduli: &[u64]) -> Poly {
    let mut coeffs = vec![0u64; d];

    if k == 0 {
        coeffs[0] = 1;
    } else {
        let pos = d - k;
        coeffs[pos] = q - 1; // -1 mod q
    }

    Poly::from_coeffs_moduli(coeffs, moduli)
}

/// Encode full database into polynomial representation
///
/// Splits the database into shards of at most `shard_config.entries_per_shard()`
/// entries. Each entry is cut into 16-bit little-endian columns. For every column
/// index, the shard's column values are encoded as one polynomial via
/// [`encode_column`], so a shard holds `ceil(entry_size / 2)` polynomials.
///
/// Only whole entries are encoded: if `database.len()` is not a multiple of
/// `entry_size`, the trailing partial entry is ignored.
///
/// # Arguments
/// * `database` - Raw database bytes (entries concatenated)
/// * `entry_size` - Size of each entry in bytes
/// * `params` - System parameters
/// * `shard_config` - Configuration for database sharding
///
/// # Returns
/// Vector of ShardData (ids 0, 1, 2, ...), each containing encoded polynomials.
/// Empty when `database` is empty or `entry_size` is 0.
///
/// # Panics
/// Panics if `shard_config.entries_per_shard()` is 0. In debug builds it also
/// panics if that value exceeds `params.ring_dim`.
pub fn encode_database(
    database: &[u8],
    entry_size: usize,
    params: &InspireParams,
    shard_config: &ShardConfig,
) -> Vec<ShardData> {
    if database.is_empty() || entry_size == 0 {
        return vec![];
    }

    let total_entries = database.len() / entry_size;
    let entries_per_shard = shard_config.entries_per_shard() as usize;

    // With a zero shard capacity no entry could ever be consumed, so sharding
    // would never terminate; fail loudly instead.
    assert!(entries_per_shard > 0, "entries_per_shard must be > 0");
    debug_assert!(
        entries_per_shard <= params.ring_dim,
        "entries_per_shard ({}) must be <= ring_dim ({})",
        entries_per_shard,
        params.ring_dim
    );

    // One polynomial per 16-bit column of an entry (independent of the shard).
    let num_polys = (entry_size * 8).div_ceil(16);

    // Restrict to whole entries; a trailing partial entry is dropped.
    let whole_entries = &database[..total_entries * entry_size];
    let shard_bytes_len = entries_per_shard.saturating_mul(entry_size);

    whole_entries
        .chunks(shard_bytes_len)
        .zip(0u32..)
        .map(|(shard_bytes, shard_id)| {
            let polynomials = (0..num_polys)
                .map(|poly_idx| {
                    // Slots past the shard's last real entry stay zero.
                    let mut column = vec![0u64; entries_per_shard];
                    let entries = shard_bytes.chunks_exact(entry_size);
                    for (slot, entry) in column.iter_mut().zip(entries) {
                        *slot = extract_column_value(entry, poly_idx);
                    }
                    encode_column(&column, params)
                })
                .collect();

            ShardData {
                id: shard_id,
                polynomials,
            }
        })
        .collect()
}

/// Extract the 16-bit column `column_idx` from an entry.
///
/// Column `i` covers bytes `2i` (low) and `2i + 1` (high), read little-endian.
/// If only the low byte exists, it is returned alone. If the column starts past
/// the end of the entry, the value is 0.
fn extract_column_value(entry: &[u8], column_idx: usize) -> u64 {
    let start = column_idx * 2;
    match entry.get(start..) {
        Some([low, high, ..]) => u64::from(*low) | (u64::from(*high) << 8),
        Some([low]) => u64::from(*low),
        _ => 0,
    }
}

/// Reconstruct an entry of `entry_size` bytes from its 16-bit column values.
///
/// Inverse of `extract_column_value`: column `i` writes its low byte to offset `2i`
/// and its high byte to offset `2i + 1`, as far as the entry has room. Bits above
/// 16 are discarded. Missing columns leave zero bytes, and surplus columns are ignored.
pub fn reconstruct_entry(column_values: &[u64], entry_size: usize) -> Vec<u8> {
    let mut bytes = vec![0u8; entry_size];

    // Each 2-byte window (the last one may hold a single byte) takes the
    // little-endian low bytes of the matching column value.
    for (window, &value) in bytes.chunks_mut(2).zip(column_values) {
        let little_endian = value.to_le_bytes();
        let width = window.len();
        window.copy_from_slice(&little_endian[..width]);
    }

    bytes
}

/// Generate evaluation points as unit monomials ±X^j.
///
/// Point `k` (for k = 0..t-1) is `z_k = X^(k · ⌊2d/t⌋)` in R_q = Z_q\[X\]/(X^d + 1).
/// When `t` divides `2d` these are exactly the t-th roots of unity generated by
/// X^(2d/t). Otherwise the step is rounded down, and if `t > 2d` every point is 1.
///
/// # Arguments
/// * `t` - Number of evaluation points
/// * `d` - Ring dimension
/// * `q` - Modulus (used to represent -1 as `q - 1`)
/// * `moduli` - CRT moduli passed through to the polynomial constructor
///
/// # Returns
/// `t` polynomials z_0, ..., z_{t-1}; empty when `t == 0`
#[allow(dead_code)]
pub fn generate_eval_points_poly(t: usize, d: usize, q: u64, moduli: &[u64]) -> Vec<Poly> {
    // Guard before dividing by t.
    if t == 0 {
        return Vec::new();
    }

    let order = 2 * d;
    let stride = order / t;

    (0..t)
        .map(|k| monomial((k * stride) % order, d, q, moduli))
        .collect()
}

/// Create the monomial X^power in Z_q\[X\]/(X^d + 1).
///
/// Since X^d = -1, X has multiplicative order 2d (X^(2d) = 1). The exponent is
/// therefore reduced modulo 2d first, and then X^(d+j) = -X^j handles the upper
/// half. A coefficient of -1 is stored as `q - 1`.
///
/// # Panics
/// Panics if `d == 0`.
#[allow(dead_code)]
fn monomial(power: usize, d: usize, q: u64, moduli: &[u64]) -> Poly {
    let mut coeffs = vec![0u64; d];

    // Exponents wrap with period 2d, so any power maps into [0, 2d).
    let power = power % (2 * d);
    if power < d {
        coeffs[power] = 1;
    } else {
        // X^(d + j) = X^d · X^j = -X^j
        coeffs[power - d] = q - 1;
    }

    Poly::from_coeffs_moduli(coeffs, moduli)
}

// Unit tests for column encoding, inverse/plain monomials, entry byte
// splitting/reassembly and evaluation-point generation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::NttContext;

    // Small ring (d = 256) with a single 60-bit modulus, so q equals the only CRT modulus.
    fn test_params() -> InspireParams {
        InspireParams {
            ring_dim: 256,
            q: 1152921504606830593,
            crt_moduli: vec![1152921504606830593],
            p: 65536,
            sigma: 6.4,
            gadget_base: 1 << 20,
            gadget_len: 3,
            security_level: crate::params::SecurityLevel::Bits128,
        }
    }

    // A short column lands on the low coefficients in order.
    #[test]
    fn test_encode_column_simple() {
        let params = test_params();
        let column = vec![1, 2, 3, 4];

        let poly = encode_column(&column, &params);

        assert_eq!(poly.dimension(), params.ring_dim);
        assert_eq!(poly.coeff(0), 1);
        assert_eq!(poly.coeff(1), 2);
        assert_eq!(poly.coeff(2), 3);
        assert_eq!(poly.coeff(3), 4);
    }

    // Direct encoding stores every value at its index and zero-fills the rest.
    #[test]
    fn test_encode_direct_stores_coefficients() {
        let d = 256;
        let q = 1152921504606830593u64;
        let values: Vec<u64> = (0..16).map(|i| (i * 7 + 3) as u64).collect();

        let poly = encode_direct(&values, d, q, &[q]);

        for (i, &val) in values.iter().enumerate() {
            assert_eq!(poly.coeff(i), val, "Coefficient {} should be {}", i, val);
        }
        for i in values.len()..d {
            assert_eq!(poly.coeff(i), 0, "Coefficient {} should be 0", i);
        }
    }

    // No values -> zero polynomial.
    #[test]
    fn test_encode_direct_empty() {
        let d = 256;
        let q = 1152921504606830593u64;

        let poly = encode_direct(&[], d, q, &[q]);

        assert!(poly.is_zero());
    }

    // Empty column takes the early-return path in encode_column.
    #[test]
    fn test_encode_column_empty() {
        let params = test_params();

        let poly = encode_column(&[], &params);

        assert!(poly.is_zero());
    }

    // X^(-0) is the constant 1.
    #[test]
    fn test_inverse_monomial_zero() {
        let d = 256;
        let q = 1152921504606830593u64;
        let moduli = [q];

        let inv_m0 = inverse_monomial(0, d, q, &moduli);

        assert_eq!(inv_m0.coeff(0), 1);
        for i in 1..d {
            assert_eq!(inv_m0.coeff(i), 0);
        }
    }

    // X^(-1) = -X^(d-1): only the top coefficient is set, to q - 1.
    #[test]
    fn test_inverse_monomial_one() {
        let d = 256;
        let q = 1152921504606830593u64;
        let moduli = [q];

        let inv_m1 = inverse_monomial(1, d, q, &moduli);

        assert_eq!(inv_m1.coeff(d - 1), q - 1);
        for i in 0..(d - 1) {
            assert_eq!(inv_m1.coeff(i), 0);
        }
    }

    // End-to-end retrieval property: h(X) · X^(-k) has y_k at coefficient 0
    // (checked for the first 16 indices via NTT multiplication).
    #[test]
    fn test_inverse_monomial_rotation() {
        let d = 256;
        let q = 1152921504606830593u64;
        let moduli = [q];
        let ctx = NttContext::with_moduli(d, &moduli);

        let values: Vec<u64> = (0..d).map(|i| (i + 1) as u64).collect();
        let h = encode_direct(&values, d, q, &moduli);

        for (k, &expected) in values.iter().enumerate().take(16) {
            let inv_mono = inverse_monomial(k, d, q, &moduli);
            let rotated = h.mul_ntt(&inv_mono, &ctx);

            assert_eq!(
                rotated.coeff(0),
                expected,
                "Rotation by {} should bring value {} to position 0",
                k,
                expected
            );
        }
    }

    // Splitting a 32-byte entry into 16-bit columns and reassembling is lossless.
    #[test]
    fn test_extract_reconstruct_entry() {
        let entry: Vec<u8> = (0..32).collect();
        let entry_size: usize = 32;
        let num_cols = (entry_size * 8).div_ceil(16);

        let mut column_values = Vec::new();
        for col_idx in 0..num_cols {
            column_values.push(extract_column_value(&entry, col_idx));
        }

        let reconstructed = reconstruct_entry(&column_values, entry_size);

        assert_eq!(entry, reconstructed);
    }

    // X^0 = 1, X^1 = X, and X^d = -1 (stored as q - 1 at coefficient 0).
    #[test]
    fn test_monomial_in_ring() {
        let d = 256;
        let q = 1152921504606830593u64;
        let moduli = [q];

        let m0 = monomial(0, d, q, &moduli);
        assert_eq!(m0.coeff(0), 1);
        for i in 1..d {
            assert_eq!(m0.coeff(i), 0);
        }

        let m1 = monomial(1, d, q, &moduli);
        assert_eq!(m1.coeff(0), 0);
        assert_eq!(m1.coeff(1), 1);

        let m_d = monomial(d, d, q, &moduli);
        assert_eq!(m_d.coeff(0), q - 1);
        for i in 1..d {
            assert_eq!(m_d.coeff(i), 0);
        }
    }

    // Exactly t points are produced for several power-of-two t.
    #[test]
    fn test_generate_eval_points_count() {
        let d = 256;
        let q = 1152921504606830593u64;

        for t in [1, 2, 4, 8, 16, 32] {
            let points = generate_eval_points_poly(t, d, q, &[q]);
            assert_eq!(points.len(), t);
        }
    }

    // The first evaluation point z_0 = X^0 is the constant 1.
    #[test]
    fn test_generate_eval_points_first() {
        let d = 256;
        let q = 1152921504606830593u64;
        let t = 8;

        let points = generate_eval_points_poly(t, d, q, &[q]);

        assert_eq!(points[0].coeff(0), 1);
        for i in 1..d {
            assert_eq!(points[0].coeff(i), 0);
        }
    }
}