use once_cell::sync::OnceCell;
use std::sync::{Arc, Weak};
use fhe_math::{
rns::ScalingFactor,
rq::{scaler::Scaler, Context},
};
use crate::bfv::{context::CipherPlainContext, parameters::MultiplicationParameters};
/// One level of a chain of polynomial contexts.
///
/// Levels are linked with [`ContextLevel::chain`]: each level holds a strong link to the
/// following level (`next`) and a weak back-link to the preceding one (`prev`). Each also
/// holds the scalers used to move between adjacent levels' polynomial contexts.
#[derive(Debug, Clone)]
pub struct ContextLevel {
    /// Polynomial context for this level; `num_moduli` is derived from its moduli.
    pub poly_context: Arc<Context>,
    /// Shared cipher/plain context; part of the equality check for levels.
    pub(crate) cipher_plain_context: Arc<CipherPlainContext>,
    /// Index of this level in the chain.
    pub(crate) level: usize,
    /// Number of moduli in `poly_context`, cached at construction.
    pub(crate) num_moduli: usize,
    /// Following level in the chain, set once by `chain`; its presence is what
    /// `can_switch_down` reports.
    pub next: OnceCell<Arc<ContextLevel>>,
    /// Back-link to the preceding level, set once by `chain`. Held as `Weak` so that,
    /// together with the strong `next` link, the pair does not form an `Arc` cycle.
    pub(crate) prev: OnceCell<Weak<ContextLevel>>,
    /// Scaler from this level's `poly_context` to the next level's, installed by `chain`
    /// when `Scaler::new` succeeds.
    pub(crate) down_scaler: OnceCell<Arc<Scaler>>,
    /// Scaler from this level's `poly_context` to the previous level's, installed by
    /// `chain` when `Scaler::new` succeeds.
    pub(crate) up_scaler: OnceCell<Arc<Scaler>>,
    /// Multiplication parameters; not set anywhere in this file. `mul_params()` panics
    /// while this is unset.
    pub(crate) mul_params: OnceCell<MultiplicationParameters>,
}
/// Levels compare equal when their level index, moduli count and cipher/plain context
/// all match. The polynomial context, the chain links, the scalers and the
/// multiplication parameters take no part in the comparison.
impl PartialEq for ContextLevel {
    fn eq(&self, other: &Self) -> bool {
        // Cheap integer checks first, so the context comparison runs only when needed.
        let same_position = (self.level, self.num_moduli) == (other.level, other.num_moduli);
        same_position && self.cipher_plain_context == other.cipher_plain_context
    }
}

// Marker impl: the equality above is treated as a full equivalence relation.
// NOTE(review): relies on `CipherPlainContext` equality being reflexive — confirm.
impl Eq for ContextLevel {}
impl ContextLevel {
pub fn new(
poly_context: Arc<Context>,
cipher_plain_context: Arc<CipherPlainContext>,
level: usize,
) -> Self {
Self {
num_moduli: poly_context.moduli().len(),
poly_context,
cipher_plain_context,
level,
next: OnceCell::new(),
prev: OnceCell::new(),
down_scaler: OnceCell::new(),
up_scaler: OnceCell::new(),
mul_params: OnceCell::new(),
}
}
pub fn chain(prev: &Arc<Self>, next: &Arc<Self>) {
if let Ok(ds) = Scaler::new(&prev.poly_context, &next.poly_context, ScalingFactor::one()) {
let _ = prev.down_scaler.set(Arc::new(ds));
}
if let Ok(us) = Scaler::new(&next.poly_context, &prev.poly_context, ScalingFactor::one()) {
let _ = next.up_scaler.set(Arc::new(us));
}
let _ = prev.next.set(next.clone());
let _ = next.prev.set(Arc::downgrade(prev));
}
pub fn can_switch_down(&self) -> bool {
self.next.get().is_some()
}
pub fn max_level(&self) -> usize {
let mut max = self.level;
let mut current = self.next.get();
while let Some(ctx) = current {
max = ctx.level;
current = ctx.next.get();
}
max
}
pub fn iter_chain(&self) -> impl Iterator<Item = Arc<ContextLevel>> {
let head = if let Some(prev) = self.prev.get().and_then(|w| w.upgrade()) {
let mut head = prev;
while let Some(p) = head.prev.get().and_then(|w| w.upgrade()) {
head = p;
}
head
} else {
Arc::new(self.clone())
};
std::iter::successors(Some(head), |ctx| ctx.next.get().cloned())
}
pub(crate) fn mul_params(&self) -> &MultiplicationParameters {
self.mul_params
.get()
.expect("multiplication parameters not set")
}
}
#[cfg(test)]
mod tests {
    use crate::bfv::BfvParametersBuilder;

    /// With two moduli configured, the chain should have exactly two levels. The head
    /// can switch down and the tail cannot.
    #[test]
    fn chain_basics() {
        let params = BfvParametersBuilder::new()
            .set_degree(16)
            .set_plaintext_modulus(1153)
            .set_moduli_sizes(&[50, 50])
            .build()
            .unwrap();

        let head = params.context_chain();
        assert!(head.can_switch_down());

        let tail = head.next.get().unwrap();
        assert!(!tail.can_switch_down());

        assert_eq!(head.max_level(), 1);
        assert_eq!(head.iter_chain().count(), 2);
    }
}