use crate::mpc::Relay;
use crate::{AgentKeyShare, Error, Result, SessionConfig, THRESHOLD};
use k256::{
ProjectivePoint, Scalar,
elliptic_curve::{Field, bigint::U256, ops::Reduce, sec1::ToEncodedPoint},
};
use rand::rngs::OsRng;
use serde::{Deserialize, Serialize};
use tracing::{debug, info, instrument};
/// Round-1 broadcast of the key-share refresh protocol.
///
/// Each party publishes commitments to the coefficients of its refresh polynomial so that
/// recipients can later check the share they receive in round 2.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefreshRound1Message {
    /// Identifier of the broadcasting party; receivers match it against `RefreshRound2Message::from`.
    pub party_id: usize,
    /// Compressed SEC1 encodings of `a_k · G` for each polynomial coefficient `a_k`, lowest
    /// degree first. The first entry commits to the zero constant term (the identity point).
    pub commitments: Vec<Vec<u8>>,
}
/// Round-2 point-to-point message of the key-share refresh protocol.
///
/// Carries the sender's refresh polynomial evaluated at the recipient's evaluation point
/// (`to + 1`, as computed in `refresh_shares`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefreshRound2Message {
    /// Party that produced the share.
    pub from: usize,
    /// Party the share is intended for.
    pub to: usize,
    /// 32-byte big-endian scalar encoding (`Scalar::to_bytes`) of the refresh share.
    pub share: Vec<u8>,
}
#[instrument(skip(share, relay))]
pub async fn refresh_shares<R: Relay>(
share: &AgentKeyShare,
config: &SessionConfig,
relay: &R,
) -> Result<AgentKeyShare> {
info!(
party_id = config.party_id,
role = %config.role,
"Starting key share refresh"
);
debug!("Refresh Round 1: Generating refresh polynomial");
let (refresh_poly, commitments) = generate_refresh_polynomial(THRESHOLD)?;
let round1_msg = RefreshRound1Message {
party_id: config.party_id,
commitments,
};
relay.broadcast(&config.session_id, 1, &round1_msg).await?;
let all_commitments = relay
.collect_broadcasts::<RefreshRound1Message>(&config.session_id, 1, config.n_parties)
.await?;
debug!("Refresh Round 2: Sending refresh shares");
for party_id in &config.parties {
if *party_id == config.party_id {
continue;
}
let refresh_share = evaluate_refresh_polynomial(&refresh_poly, *party_id as u64 + 1);
let round2_msg = RefreshRound2Message {
from: config.party_id,
to: *party_id,
share: refresh_share.to_bytes().to_vec(),
};
relay
.send_direct(&config.session_id, 2, *party_id, &round2_msg)
.await?;
}
let received_shares = relay
.collect_direct::<RefreshRound2Message>(
&config.session_id,
2,
config.party_id,
config.n_parties - 1,
)
.await?;
debug!("Refresh Round 3: Verifying and applying refresh");
for refresh_msg in &received_shares {
let sender_commitments = all_commitments
.iter()
.find(|c| c.party_id == refresh_msg.from)
.ok_or_else(|| Error::VerificationFailed("Missing commitment".into()))?;
verify_refresh_share(
refresh_msg,
&sender_commitments.commitments,
config.party_id,
)?;
}
let my_refresh = evaluate_refresh_polynomial(&refresh_poly, config.party_id as u64 + 1);
let mut new_secret = share.secret_share + my_refresh;
for refresh_msg in &received_shares {
let refresh_bytes: [u8; 32] = refresh_msg
.share
.clone()
.try_into()
.map_err(|_| Error::Deserialization("Invalid share length".into()))?;
let refresh = <Scalar as Reduce<U256>>::reduce_bytes(&refresh_bytes.into());
new_secret = new_secret + refresh;
}
let mut new_share = share.clone();
new_share.secret_share = new_secret;
new_share.metadata.last_refreshed_at = Some(chrono::Utc::now().timestamp());
info!(
party_id = config.party_id,
"Key share refresh completed successfully"
);
Ok(new_share)
}
/// Samples a random refresh polynomial whose constant term is zero.
///
/// The polynomial has `max(threshold, 1)` coefficients, lowest degree first: coefficient 0 is
/// `Scalar::ZERO` and the rest are drawn from `OsRng`. Alongside the coefficients it returns
/// compressed SEC1 encodings of `a_k · G` for every coefficient `a_k`. The zero constant term
/// therefore commits to the identity point.
fn generate_refresh_polynomial(threshold: usize) -> Result<(Vec<Scalar>, Vec<Vec<u8>>)> {
    let mut rng = OsRng;

    // Pin f(0) = 0 so that adding the refresh shares never changes the shared secret.
    let coefficients: Vec<Scalar> = std::iter::once(Scalar::ZERO)
        .chain((1..threshold).map(|_| Scalar::random(&mut rng)))
        .collect();

    // Feldman commitments; 0 · G is the identity, which encodes as the single byte 0x00.
    let commitments: Vec<Vec<u8>> = coefficients
        .iter()
        .map(|coef| {
            (ProjectivePoint::GENERATOR * *coef)
                .to_affine()
                .to_encoded_point(true)
                .as_bytes()
                .to_vec()
        })
        .collect();

    Ok((coefficients, commitments))
}
/// Evaluates the polynomial with the given coefficients (lowest degree first) at `x`.
///
/// Uses Horner's rule, `(((a_n)·x + a_{n-1})·x + …)·x + a_0`. An empty coefficient slice
/// evaluates to zero.
fn evaluate_refresh_polynomial(coefficients: &[Scalar], x: u64) -> Scalar {
    let point = Scalar::from(x);
    coefficients
        .iter()
        .rev()
        .fold(Scalar::ZERO, |acc, coef| acc * point + *coef)
}
fn verify_refresh_share(
refresh_msg: &RefreshRound2Message,
commitments: &[Vec<u8>],
my_id: usize,
) -> Result<()> {
use k256::{AffinePoint, elliptic_curve::sec1::FromEncodedPoint};
let share_bytes: [u8; 32] = refresh_msg
.share
.clone()
.try_into()
.map_err(|_| Error::Deserialization("Invalid share length".into()))?;
let share = <Scalar as Reduce<U256>>::reduce_bytes(&share_bytes.into());
let expected = ProjectivePoint::GENERATOR * share;
let x = (my_id + 1) as u64;
let mut actual = ProjectivePoint::IDENTITY;
let mut x_power = Scalar::ONE;
let x_scalar = Scalar::from(x);
for commitment_bytes in commitments {
if commitment_bytes.len() == 1 && commitment_bytes[0] == 0 {
x_power = x_power * x_scalar;
continue;
}
let point = k256::EncodedPoint::from_bytes(commitment_bytes)
.map_err(|e| Error::VerificationFailed(e.to_string()))?;
let affine_opt = AffinePoint::from_encoded_point(&point);
let affine: AffinePoint = Option::<AffinePoint>::from(affine_opt)
.ok_or_else(|| Error::VerificationFailed("Invalid commitment point".into()))?;
let commitment = ProjectivePoint::from(affine);
actual = actual + (commitment * x_power);
x_power = x_power * x_scalar;
}
if expected != actual {
return Err(Error::VerificationFailed(format!(
"Refresh share from party {} does not match commitment",
refresh_msg.from
)));
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The refresh polynomial must vanish at zero so it never changes the shared secret.
    #[test]
    fn test_refresh_polynomial_zero_free_term() {
        let (poly, commitments) = generate_refresh_polynomial(2).unwrap();
        assert_eq!(poly.len(), 2);
        assert_eq!(commitments.len(), 2);
        assert_eq!(poly[0], Scalar::ZERO);
        assert_eq!(evaluate_refresh_polynomial(&poly, 0), Scalar::ZERO);
    }

    /// Any two evaluation points of a degree-1 refresh polynomial reconstruct zero at x = 0.
    #[test]
    fn test_refresh_shares_sum_to_zero() {
        let (poly, _) = generate_refresh_polynomial(2).unwrap();
        let s1 = evaluate_refresh_polynomial(&poly, 1);
        let s2 = evaluate_refresh_polynomial(&poly, 2);
        let s3 = evaluate_refresh_polynomial(&poly, 3);

        // Lagrange coefficients at x = 0 for points {1, 2}: λ1 = 2, λ2 = -1.
        assert_eq!(s1 * Scalar::from(2u64) - s2, Scalar::ZERO);
        // Lagrange coefficients at x = 0 for points {2, 3}: λ2 = 3, λ3 = -2.
        assert_eq!(s2 * Scalar::from(3u64) - s3 * Scalar::from(2u64), Scalar::ZERO);
    }

    #[test]
    fn test_verify_refresh_share_accepts_honest_share() {
        let (poly, commitments) = generate_refresh_polynomial(THRESHOLD).unwrap();
        // Party 1 evaluates at x = 2.
        let msg = RefreshRound2Message {
            from: 0,
            to: 1,
            share: evaluate_refresh_polynomial(&poly, 2).to_bytes().to_vec(),
        };
        assert!(verify_refresh_share(&msg, &commitments, 1).is_ok());
    }

    #[test]
    fn test_verify_refresh_share_rejects_tampered_share() {
        let (poly, commitments) = generate_refresh_polynomial(THRESHOLD).unwrap();
        let tampered = evaluate_refresh_polynomial(&poly, 2) + Scalar::ONE;
        let msg = RefreshRound2Message {
            from: 0,
            to: 1,
            share: tampered.to_bytes().to_vec(),
        };
        assert!(verify_refresh_share(&msg, &commitments, 1).is_err());
    }

    #[test]
    fn test_verify_refresh_share_rejects_short_share() {
        let (_, commitments) = generate_refresh_polynomial(THRESHOLD).unwrap();
        let msg = RefreshRound2Message {
            from: 0,
            to: 1,
            share: vec![0u8; 31],
        };
        assert!(verify_refresh_share(&msg, &commitments, 1).is_err());
    }
}