use std::collections::HashMap;
use anyhow::{Result, anyhow};
use wacore::poll;
use wacore_binary::jid::{Jid, JidExt};
use waproto::whatsapp as wa;
use crate::client::Client;
#[derive(Debug, Clone)]
pub struct PollOptionResult {
pub name: String,
pub voters: Vec<String>,
}
pub struct Polls<'a> {
client: &'a Client,
}
impl<'a> Polls<'a> {
pub(crate) fn new(client: &'a Client) -> Self {
Self { client }
}
pub async fn create(
&self,
to: &Jid,
name: &str,
options: &[String],
selectable_count: u32,
) -> Result<(String, Vec<u8>)> {
if options.len() < 2 {
return Err(anyhow!("Poll must have at least 2 options"));
}
if options.len() > 12 {
return Err(anyhow!("Polls can have a maximum of 12 options"));
}
if selectable_count < 1 || selectable_count > options.len() as u32 {
return Err(anyhow!(
"selectable_count must be between 1 and {} (got {selectable_count})",
options.len()
));
}
let mut seen = std::collections::HashSet::new();
for opt in options {
if !seen.insert(opt) {
return Err(anyhow!("Duplicate option name: {opt}"));
}
}
let poll_options: Vec<wa::message::poll_creation_message::Option> = options
.iter()
.map(|name| wa::message::poll_creation_message::Option {
option_name: Some(name.clone()),
option_hash: None,
})
.collect();
let poll_msg = wa::message::PollCreationMessage {
enc_key: None,
name: Some(name.to_string()),
options: poll_options,
selectable_options_count: Some(selectable_count),
context_info: None,
poll_content_type: None,
poll_type: None,
correct_answer: None,
};
let mut message = if selectable_count == 1 {
wa::Message {
poll_creation_message_v3: Some(Box::new(poll_msg)),
..Default::default()
}
} else {
wa::Message {
poll_creation_message: Some(Box::new(poll_msg)),
..Default::default()
}
};
let message_secret: Vec<u8> = {
use rand::Rng;
let mut secret = vec![0u8; 32];
rand::make_rng::<rand::rngs::StdRng>().fill_bytes(&mut secret);
secret
};
message.message_context_info = Some(wa::MessageContextInfo {
message_secret: Some(message_secret.clone()),
..Default::default()
});
let msg_id = self.client.send_message(to.clone(), message).await?;
Ok((msg_id, message_secret))
}
pub async fn vote(
&self,
chat_jid: &Jid,
poll_msg_id: &str,
poll_creator_jid: &Jid,
message_secret: &[u8],
option_names: &[String],
) -> Result<String> {
let my_jid = self
.client
.get_pn()
.await
.ok_or_else(|| anyhow!("Not logged in — cannot determine own JID"))?;
let voter_jid_str = my_jid.to_non_ad().to_string();
let creator_jid_str = poll_creator_jid.to_non_ad().to_string();
let selected_hashes: Vec<Vec<u8>> = option_names
.iter()
.map(|name| poll::compute_option_hash(name).to_vec())
.collect();
let key = poll::derive_vote_encryption_key(
message_secret,
poll_msg_id,
&creator_jid_str,
&voter_jid_str,
)?;
let (enc_payload, iv) =
poll::encrypt_poll_vote(&selected_hashes, &key, poll_msg_id, &voter_jid_str)?;
let from_me = my_jid.to_non_ad() == poll_creator_jid.to_non_ad();
let poll_update = wa::message::PollUpdateMessage {
poll_creation_message_key: Some(wa::MessageKey {
remote_jid: Some(chat_jid.to_string()),
from_me: Some(from_me),
id: Some(poll_msg_id.to_string()),
participant: if chat_jid.is_group() {
Some(poll_creator_jid.to_string())
} else {
None
},
}),
vote: Some(wa::message::PollEncValue {
enc_payload: Some(enc_payload),
enc_iv: Some(iv.to_vec()),
}),
metadata: Some(wa::message::PollUpdateMessageMetadata {}),
sender_timestamp_ms: Some(wacore::time::now_millis()),
};
let message = wa::Message {
poll_update_message: Some(poll_update),
..Default::default()
};
self.client.send_message(chat_jid.clone(), message).await
}
pub fn decrypt_vote(
enc_payload: &[u8],
enc_iv: &[u8],
message_secret: &[u8],
poll_msg_id: &str,
poll_creator_jid: &Jid,
voter_jid: &Jid,
) -> Result<Vec<Vec<u8>>> {
let creator = poll_creator_jid.to_non_ad().to_string();
let voter = voter_jid.to_non_ad().to_string();
let key = poll::derive_vote_encryption_key(message_secret, poll_msg_id, &creator, &voter)?;
poll::decrypt_poll_vote(enc_payload, enc_iv, &key, poll_msg_id, &voter)
}
pub fn aggregate_votes(
poll_options: &[String],
votes: &[(&Jid, &[u8], &[u8])], message_secret: &[u8],
poll_msg_id: &str,
poll_creator_jid: &Jid,
) -> Result<Vec<PollOptionResult>> {
let option_hashes: Vec<([u8; 32], &str)> = poll_options
.iter()
.map(|name| (poll::compute_option_hash(name), name.as_str()))
.collect();
let mut latest_votes: HashMap<String, Vec<Vec<u8>>> = HashMap::new();
for (voter_jid, enc_payload, enc_iv) in votes {
let voter_key = voter_jid.to_non_ad().to_string();
match Self::decrypt_vote(
enc_payload,
enc_iv,
message_secret,
poll_msg_id,
poll_creator_jid,
voter_jid,
) {
Ok(selected_hashes) => {
if selected_hashes.is_empty() {
latest_votes.remove(&voter_key);
} else {
latest_votes.insert(voter_key, selected_hashes);
}
}
Err(e) => {
log::warn!("Failed to decrypt vote from {voter_jid}: {e}");
}
}
}
let mut results: Vec<PollOptionResult> = poll_options
.iter()
.map(|name| PollOptionResult {
name: name.clone(),
voters: Vec::new(),
})
.collect();
for (voter_jid, selected_hashes) in &latest_votes {
for hash in selected_hashes {
if let Ok(hash_arr) = <[u8; 32]>::try_from(hash.as_slice())
&& let Some(idx) = option_hashes.iter().position(|(h, _)| *h == hash_arr)
{
results[idx].voters.push(voter_jid.clone());
}
}
}
Ok(results)
}
}
impl Client {
pub fn polls(&self) -> Polls<'_> {
Polls::new(self)
}
}