use crate::client::Client;
use crate::request::{InfoQuery, InfoQueryType, IqError};
use crate::types::events::Event;
use log::{error, info, warn};
use std::sync::Arc;
use wacore::libsignal::protocol::KeyPair;
use wacore::pair_code::{PairCodeState, PairCodeUtils, resolve_companion_platform};
use wacore_binary::Jid;
use wacore_binary::{NodeContent, NodeContentRef, NodeRef};
pub use wacore::companion_reg::CompanionWebClientType;
pub use wacore::pair_code::{PairCodeError, PairCodeOptions};
/// Error type returned by [`Client::pair_with_code`].
#[derive(Debug, thiserror::Error)]
pub enum PairError {
/// Input validation or pair-code protocol failure (bad phone number, invalid
/// custom code, missing pairing ref, key-agreement failure, ...).
/// `#[error(transparent)]` forwards both `Display` and `source()` to the
/// wrapped [`PairCodeError`], so callers see its message and its cause chain.
#[error(transparent)]
PairCode(#[from] PairCodeError),
/// The pair-code IQ request failed. The underlying [`IqError`] is preserved
/// and reported as this error's `source()`.
#[error("pair-code IQ request failed")]
RequestFailed(#[from] IqError),
}
impl Client {
    /// Starts phone-number ("pair code") linking for this client.
    ///
    /// This is stage 1 of the flow, in order:
    ///
    /// 1. `options.phone_number` is reduced to its ASCII digits and validated.
    /// 2. A pairing code is chosen: the validated, upper-cased
    ///    `options.custom_code` if given, otherwise a generated one.
    /// 3. A fresh ephemeral key pair is created, its public half is wrapped
    ///    with the code, and the companion-hello IQ is sent.
    ///
    /// On success the client stores a `WaitingForPhoneConfirmation` state,
    /// dispatches [`Event::PairingCode`], and returns the code. The user then
    /// enters this code on their primary device.
    ///
    /// # Errors
    ///
    /// * [`PairError::PairCode`] when any of the following holds:
    ///   - the phone number has no digits;
    ///   - it has fewer than seven digits;
    ///   - it starts with `0`;
    ///   - the custom code fails validation;
    ///   - the server response carries no pairing ref.
    /// * [`PairError::RequestFailed`] when the IQ request itself fails.
    ///
    /// # Panics
    ///
    /// Panics if the device noise key or the generated ephemeral key does not
    /// expose a 32-byte public key.
    pub async fn pair_with_code(
        self: &Arc<Self>,
        options: PairCodeOptions,
    ) -> Result<String, PairError> {
        // Only the digits matter; separators such as '+', ' ', '-' are dropped.
        let digits: String = options
            .phone_number
            .chars()
            .filter(char::is_ascii_digit)
            .collect();

        // Checks run in a fixed order so the most specific error wins.
        let rejection = if digits.is_empty() {
            Some(PairCodeError::PhoneNumberRequired)
        } else if digits.len() < 7 {
            Some(PairCodeError::PhoneNumberTooShort)
        } else if digits.starts_with('0') {
            // A leading zero is rejected as "not international" format.
            Some(PairCodeError::PhoneNumberNotInternational)
        } else {
            None
        };
        if let Some(err) = rejection {
            return Err(err.into());
        }

        let pair_code = match &options.custom_code {
            None => PairCodeUtils::generate_code(),
            Some(custom) if PairCodeUtils::validate_code(custom) => custom.to_uppercase(),
            Some(_) => return Err(PairCodeError::InvalidCustomCode.into()),
        };

        info!(
            target: "Client/PairCode",
            "Starting pair code authentication for phone: {}",
            digits
        );

        let ephemeral_keypair = KeyPair::generate(&mut rand::make_rng::<rand::rngs::StdRng>());
        let device = self.persistence_manager.get_device_snapshot().await;
        let noise_static_pub: [u8; 32] = device
            .noise_key
            .public_key
            .public_key_bytes()
            .try_into()
            .expect("noise key is 32 bytes");
        let ephemeral_pub: [u8; 32] = ephemeral_keypair
            .public_key
            .public_key_bytes()
            .try_into()
            .expect("ephemeral key is 32 bytes");

        // The code-based wrapping runs on the blocking pool (presumably because
        // its key derivation is CPU-heavy — confirm in PairCodeUtils).
        let code_for_worker = pair_code.clone();
        let wrapped_ephemeral = wacore::runtime::blocking(&*self.runtime, move || {
            PairCodeUtils::encrypt_ephemeral_pub(&ephemeral_pub, &code_for_worker)
        })
        .await;

        let (platform_id, platform_display) =
            resolve_companion_platform(&options, &device.device_props);
        let platform_id = platform_id.to_string();

        let request_id = self.generate_request_id();
        let hello = PairCodeUtils::build_companion_hello_iq(
            &digits,
            &noise_static_pub,
            &wrapped_ephemeral,
            &platform_id,
            &platform_display,
            options.show_push_notification,
            request_id.clone(),
        );
        let hello_children = hello.children().map(|c| c.to_vec()).unwrap_or_default();

        let query = InfoQuery {
            query_type: InfoQueryType::Set,
            namespace: "md",
            to: Jid::new("", wacore_binary::Server::Pn),
            target: None,
            content: Some(NodeContent::Nodes(hello_children)),
            id: Some(request_id),
            timeout: Some(std::time::Duration::from_secs(30)),
        };

        let response = self.send_iq(query).await?;
        let Some(pairing_ref) = PairCodeUtils::parse_companion_hello_response(response.get())
        else {
            return Err(PairCodeError::MissingPairingRef.into());
        };

        info!(
            target: "Client/PairCode",
            "Stage 1 complete, waiting for phone confirmation. Code: {}",
            pair_code
        );

        // Everything stage 2 (the phone's confirmation notification) needs.
        *self.pair_code_state.lock().await = PairCodeState::WaitingForPhoneConfirmation {
            pairing_ref,
            phone_jid: digits,
            pair_code: pair_code.clone(),
            ephemeral_keypair: Box::new(ephemeral_keypair),
        };

        self.core.event_bus.dispatch(Event::PairingCode {
            code: pair_code.clone(),
            timeout: PairCodeUtils::code_validity(),
        });

        Ok(pair_code)
    }
}
/// Handles the stage-2 pair-code notification, sent after the code has been
/// entered on the primary device.
///
/// If `node` has no `link_code_companion_reg` child, the function does
/// nothing and returns `false`. Otherwise it proceeds in order:
///
/// 1. Extracts the primary's wrapped ephemeral public key (exactly 80 bytes)
///    and the primary identity public key (exactly 32 bytes).
/// 2. Consumes the `WaitingForPhoneConfirmation` state stored by
///    [`Client::pair_with_code`].
/// 3. Decrypts the primary ephemeral key with the pairing code and prepares
///    the wrapped key bundle.
/// 4. Persists the new ADV secret key.
/// 5. Sends the `companion_finish` node.
///
/// Returns `true` only when `companion_finish` was sent successfully. The
/// pair-code state is then set to `PairCodeState::Completed`. Returns `false`
/// in any of these cases:
///
/// - the node is not a pair-code notification;
/// - a required child is missing or has the wrong length;
/// - no pairing is waiting;
/// - decryption, bundle preparation or sending fails.
///
/// If no pairing is waiting, the current state is left exactly as it was.
/// Once a waiting state has been consumed, it is not restored when a later
/// step fails, so a retry needs a fresh `pair_with_code` call.
pub(crate) async fn handle_pair_code_notification(
    client: &Arc<Client>,
    node: &NodeRef<'_>,
) -> bool {
    let Some(reg_node) = node.get_optional_child_by_tag(&["link_code_companion_reg"]) else {
        return false;
    };

    // Primary's ephemeral public key, wrapped with the pairing code.
    // Anything other than exactly 80 bytes is rejected.
    let primary_wrapped_ephemeral = match reg_node
        .get_optional_child_by_tag(&["link_code_pairing_wrapped_primary_ephemeral_pub"])
        .and_then(|n| match n.content.as_deref() {
            Some(NodeContentRef::Bytes(b)) if b.len() == 80 => Some(b.to_vec()),
            _ => None,
        }) {
        Some(b) => b,
        None => {
            warn!(
                target: "Client/PairCode",
                "Missing or invalid primary wrapped ephemeral pub in notification"
            );
            return false;
        }
    };

    let primary_identity_pub: [u8; 32] = match reg_node
        .get_optional_child_by_tag(&["primary_identity_pub"])
        .and_then(|n| match n.content.as_deref() {
            Some(NodeContentRef::Bytes(b)) if b.len() == 32 => b.as_ref().try_into().ok(),
            _ => None,
        }) {
        Some(arr) => arr,
        None => {
            warn!(
                target: "Client/PairCode",
                "Missing or invalid primary identity pub in notification"
            );
            return false;
        }
    };

    // Take the pending state while holding the lock. If it is not the waiting
    // state (e.g. already `Completed`), write it back unchanged: a stray or
    // duplicate notification must not reset it to the default.
    let (pairing_ref, phone_jid, pair_code, ephemeral_keypair) = {
        let mut state_guard = client.pair_code_state.lock().await;
        let current = std::mem::take(&mut *state_guard);
        match current {
            PairCodeState::WaitingForPhoneConfirmation {
                pairing_ref,
                phone_jid,
                pair_code,
                ephemeral_keypair,
            } => (pairing_ref, phone_jid, pair_code, ephemeral_keypair),
            other => {
                *state_guard = other;
                drop(state_guard);
                warn!(
                    target: "Client/PairCode",
                    "Received pair code notification but not in waiting state"
                );
                return false;
            }
        }
    };

    info!(
        target: "Client/PairCode",
        "Phone confirmed code entry, processing stage 2"
    );

    let pair_code_clone = pair_code.clone();
    let primary_ephemeral_pub = match wacore::runtime::blocking(&*client.runtime, move || {
        PairCodeUtils::decrypt_primary_ephemeral_pub(&primary_wrapped_ephemeral, &pair_code_clone)
    })
    .await
    {
        Ok(pub_key) => pub_key,
        Err(e) => {
            error!(
                target: "Client/PairCode",
                "Failed to decrypt primary ephemeral pub: {e}"
            );
            return false;
        }
    };

    let device_snapshot = client.persistence_manager.get_device_snapshot().await;
    let (wrapped_bundle, new_adv_secret) = match PairCodeUtils::prepare_key_bundle(
        &ephemeral_keypair,
        &primary_ephemeral_pub,
        &primary_identity_pub,
        &device_snapshot.identity_key,
    ) {
        Ok(result) => result,
        Err(e) => {
            error!(target: "Client/PairCode", "Failed to prepare key bundle: {e}");
            return false;
        }
    };

    // The new ADV secret is persisted before companion_finish is sent.
    client
        .persistence_manager
        .process_command(crate::store::commands::DeviceCommand::SetAdvSecretKey(
            new_adv_secret,
        ))
        .await;

    let req_id = client.generate_request_id();
    let identity_pub: [u8; 32] = device_snapshot
        .identity_key
        .public_key
        .public_key_bytes()
        .try_into()
        .expect("identity key is 32 bytes");
    let iq = PairCodeUtils::build_companion_finish_iq(
        &phone_jid,
        wrapped_bundle,
        &identity_pub,
        &pairing_ref,
        req_id,
    );
    if let Err(e) = client.send_node(iq).await {
        error!(target: "Client/PairCode", "Failed to send companion_finish: {e}");
        return false;
    }

    info!(
        target: "Client/PairCode",
        "Sent companion_finish, waiting for pair-success"
    );
    *client.pair_code_state.lock().await = PairCodeState::Completed;
    true
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error as _;

    /// `RequestFailed` must keep the original `IqError`, including its
    /// variant data, reachable through `Error::source`.
    #[test]
    fn pair_error_request_failed_preserves_iq_source() {
        let wrapped = PairError::from(IqError::ServerError {
            code: 400,
            text: "bad-request".into(),
        });
        let inner = wrapped
            .source()
            .expect("source preserved")
            .downcast_ref::<IqError>()
            .expect("downcasts to IqError");
        assert!(matches!(inner, IqError::ServerError { code: 400, .. }));
    }

    /// The transparent `PairCode` variant must forward both `Display` and
    /// `source` to the inner `PairCodeError`, so the cause chain reaches the
    /// underlying `CurveError`.
    #[test]
    fn pair_error_paircode_transparent_walks_to_curve_error() {
        use wacore::libsignal::protocol::CurveError;

        let wrapped = PairError::from(PairCodeError::EphemeralKeyAgreement(
            CurveError::NoKeyTypeIdentifier,
        ));
        assert_eq!(wrapped.to_string(), "ephemeral key agreement failed");

        let curve = wrapped
            .source()
            .expect("source preserved")
            .downcast_ref::<CurveError>()
            .expect("downcasts to CurveError through transparent wrapper");
        assert!(matches!(curve, CurveError::NoKeyTypeIdentifier));
    }
}