ckks_engine/utils.rs
1use crate::polynomial::Polynomial;
2use rand::Rng;
3
4
/// Rounds `value` to `decimal_places` digits after the decimal point.
///
/// The value is multiplied by `10^decimal_places`, rounded half away from zero with
/// [`f64::round`], and divided back. The arithmetic is binary floating point, so the result
/// is the nearest representable `f64` to the decimal answer, not an exact decimal.
/// `decimal_places` is converted with `as i32`, so values above `i32::MAX` wrap.
fn round_to(value: f64, decimal_places: usize) -> f64 {
    let scale = 10f64.powi(decimal_places as i32);
    let shifted = value * scale;
    shifted.round() / scale
}
10
11// Encode real numbers into polynomial form with scaling
12pub fn encode(plaintext: &[f64], scaling_factor: f64) -> Polynomial {
13 if scaling_factor <= 0.0 {
14 panic!("Scaling factor must be positive"); // Ensure the scaling factor is positive
15 }
16 // Print the input plaintext and scaling factor
17
18 // Scale the real numbers and convert them to integer coefficients
19 let coeffs: Vec<i64> = plaintext.iter()
20 .map(|&x| (x * scaling_factor).round() as i64) // Scale the real numbers
21 .collect();
22
23 // Print the resulting polynomial coefficients
24
25 Polynomial::new(coeffs) // Return a new polynomial with the coefficients
26}
27
28// Decode polynomial back to real numbers
29pub fn decode(ciphertext: &Polynomial, scaling_factor: f64) -> Vec<f64> {
30 if scaling_factor <= 0.0 {
31 panic!("Scaling factor must be positive"); // Ensure the scaling factor is positive
32 }
33 let threshold = 1e-10; // Define a small threshold for considering values as zero
34 let decimal_places = 2; // Number of decimal places for rounding
35
36 // Print the input ciphertext and scaling factor
37
38 // Perform decoding (reverse scaling) and apply thresholding and rounding
39 let decoded_values: Vec<f64> = ciphertext.coeffs.iter()
40 .map(|&c| {
41 let value = (c as f64) / scaling_factor; // Reverse scaling
42 let rounded_value = round_to(value, decimal_places); // Round the value to 2 decimal places
43 // Apply thresholding to treat small values as zero
44 if rounded_value.abs() < threshold {
45 0.0 // Treat small values as zero
46 } else {
47 rounded_value // Return the rounded value
48 }
49 })
50 .collect();
51
52 // Print the decoded real numbers
53
54 decoded_values // Return the decoded values
55}
56
57// Add noise to a polynomial
58pub fn add_noise(poly: &Polynomial, _pub_key: &impl std::fmt::Debug) -> Polynomial {
59 let mut rng = rand::thread_rng(); // Create a random number generator
60 // Generate noise for each coefficient of the polynomial
61 let noise: Vec<i64> = poly.coeffs.iter().map(|&coeff| coeff + rng.gen_range(-10..10)).collect();
62 Polynomial::new(noise) // Return a new polynomial with added noise
63}
64
65// Modular reduction using the prime modulus q
66pub fn mod_reduce(poly: &Polynomial, modulus: i64) -> Polynomial {
67 // Reduce each coefficient of the polynomial modulo the given modulus
68 let reduced: Vec<i64> = poly.coeffs.iter().map(|&coeff| coeff % modulus).collect();
69 Polynomial::new(reduced) // Return a new polynomial with reduced coefficients
70}
71
72// Modular reduction using the prime modulus q
73pub fn mod_reduce_string(poly: &Polynomial, modulus: i64) -> Polynomial {
74 // Reduce each coefficient of the polynomial modulo the given modulus
75 let reduced: Vec<i64> = poly.coeffs.iter().map(|&coeff| coeff % modulus).collect();
76
77
78 // Filter out zero coefficients if necessary (optional)
79 let filtered: Vec<i64> = reduced.into_iter().filter(|&coeff| coeff != 0).collect();
80
81 Polynomial::new(filtered) // Return a new polynomial with reduced coefficients
82}
83
84pub fn encode_string(plaintext: &str, scaling_factor: f64) -> Polynomial {
85 if scaling_factor <= 0.0 {
86 panic!("Scaling factor must be positive");
87 }
88
89 // Convert each character to its Unicode code point and scale it
90 let coeffs: Vec<i64> = plaintext.chars()
91 .map(|c| {
92 let unicode_val = c as u32; // Get Unicode code point
93 let scaled_val = (unicode_val as f64 * scaling_factor).round(); // Scale and round
94 scaled_val as i64 // Convert to i64 for polynomial storage
95 })
96 .collect();
97
98
99 Polynomial::new(coeffs) // Return the polynomial with encoded coefficients
100}
101
102// Decode polynomial back to a string
103pub fn decode_string(ciphertext: &Polynomial, scaling_factor: f64) -> String {
104 if scaling_factor <= 0.0 {
105 panic!("Scaling factor must be positive");
106 }
107
108 // Reverse scaling and convert coefficients back to Unicode characters
109 let decoded_chars: String = ciphertext.coeffs.iter()
110 .map(|&c| {
111 let scaled_val = c as f64 / scaling_factor; // Reverse scaling
112 let unicode_val = scaled_val.round() as u32; // Convert back to Unicode value
113 std::char::from_u32(unicode_val).unwrap_or('?') // Map to character or '?' if invalid
114 })
115 .collect();
116
117
118 decoded_chars // Return the decoded string
119}