imessage-database 4.0.0

Parsers and tools to interact with iMessage SQLite data
Documentation
/*!
 Parsers for poll messages created by the Polls app on iOS 26 and newer.
*/

use std::collections::HashMap;

use base64::{Engine as _, engine::general_purpose};
use jzon::{self, JsonValue};
use plist::Value;

use crate::{
    error::plist::PlistParseError,
    util::plist::{get_string_from_nested_dict, parse_ns_keyed_archiver},
};

/// Identifier of a poll option, taken from the `optionIdentifier` field of a poll option entry.
pub type PollOptionID = String;

/// A single option offered by a [`Poll`]
#[derive(Debug, PartialEq, Eq)]
pub struct PollOption {
    /// The text of the option
    pub text: String,
    /// The handle of the participant who created the option
    pub creator: String,
    /// The votes cast for this option
    pub votes: Vec<PollVote>,
}

impl PollOption {
    /// Build an `(id, option)` pair from one entry of a poll's option list.
    ///
    /// Each field (`optionIdentifier`, `creatorHandle`, `text`) falls back to an empty
    /// string when it is absent or not a string. The returned option has no votes yet.
    fn from_json(json_data: &JsonValue) -> (PollOptionID, Self) {
        // Read a string field, defaulting to "" rather than failing.
        let field = |key: &str| json_data[key].as_str().unwrap_or_default().to_string();

        let option = PollOption {
            text: field("text"),
            creator: field("creatorHandle"),
            votes: Vec::new(),
        };

        (field("optionIdentifier"), option)
    }
}

/// A single participant's vote for one [`PollOption`]
#[derive(Debug, PartialEq, Eq)]
pub struct PollVote {
    /// The handle of the participant who voted
    pub voter: String,
    /// The ID of the option the vote was cast for
    pub option_id: PollOptionID,
}

impl PollVote {
    /// Build a vote from one entry of a payload's `votes` array.
    ///
    /// # Errors
    ///
    /// Returns [`PlistParseError::MissingKey`] naming the first of `participantHandle` or
    /// `voteOptionIdentifier` (checked in that order) that is absent or not a string.
    fn from_json(payload: &JsonValue) -> Result<Self, PlistParseError> {
        // Fetch a mandatory string field, reporting the key name if it is unusable.
        let required = |key: &str| -> Result<String, PlistParseError> {
            payload[key]
                .as_str()
                .map(str::to_string)
                .ok_or_else(|| PlistParseError::MissingKey(key.to_string()))
        };

        // Struct-literal fields are evaluated top to bottom, preserving the check order.
        Ok(PollVote {
            voter: required("participantHandle")?,
            option_id: required("voteOptionIdentifier")?,
        })
    }
}

/// Represents a Poll message
#[derive(Debug, PartialEq, Eq)]
pub struct Poll {
    /// Map of option ID to [`PollOption`]
    pub options: HashMap<PollOptionID, PollOption>,
    /// The order of the options as they were created
    pub order: Vec<PollOptionID>,
}

impl Poll {
    /// Parse a Poll from the given payload
    pub fn from_payload(payload: &Value) -> Result<Self, PlistParseError> {
        let parsed = parse_ns_keyed_archiver(payload)?;

        let url = get_string_from_nested_dict(&parsed, "URL")
            .ok_or(PlistParseError::MissingKey("URL".to_string()))?;

        // Parse the JSON
        let parsed_json = base64_url_to_json(url)?;

        let mut options = HashMap::new();
        let mut order = Vec::new();
        let ordered_options = parsed_json["item"]["orderedPollOptions"].as_array().ok_or(
            PlistParseError::MissingKey("orderedPollOptions".to_string()),
        )?;

        for option in ordered_options {
            let (id, poll_option) = PollOption::from_json(option);
            order.push(id.clone());
            options.insert(id, poll_option);
        }

        Ok(Poll { options, order })
    }

    /// Count votes from a vote payload and update the poll options
    pub fn count_votes(&mut self, payload: &Value) -> Result<(), PlistParseError> {
        let parsed = parse_ns_keyed_archiver(payload)?;

        let url = get_string_from_nested_dict(&parsed, "URL")
            .ok_or(PlistParseError::MissingKey("URL".to_string()))?;

        // Parse the JSON
        let parsed_json = base64_url_to_json(url)?;

        let votes = parsed_json["item"]["votes"]
            .as_array()
            .ok_or(PlistParseError::MissingKey("votes".to_string()))?;

        for vote in votes {
            let poll_vote = PollVote::from_json(vote)?;
            if let Some(option) = self.options.get_mut(&poll_vote.option_id) {
                option.votes.push(poll_vote);
            }
        }

        Ok(())
    }
}

fn base64_url_to_json(data: &str) -> Result<JsonValue, PlistParseError> {
    // Strip the fixed prefix
    let after_prefix = data
        .strip_prefix("data:,")
        .ok_or(PlistParseError::WrongMessageType)?;

    // Extract the base64 part before the first "?" (if present)
    let base64_part = after_prefix
        .split_once('?')
        .map_or(after_prefix, |(before, _)| before);

    // Decode the base64 part
    let bytes = String::from_utf8(
        general_purpose::URL_SAFE
            .decode(base64_part)
            .map_err(|_| PlistParseError::WrongMessageType)?,
    )
    .map_err(|_| PlistParseError::WrongMessageType)?;

    // Parse the JSON
    jzon::parse(&bytes).map_err(|_| PlistParseError::WrongMessageType)
}

#[cfg(test)]
mod tests {
    use std::{env::current_dir, fs::File};

    use plist::Value;

    use crate::message_types::polls::Poll;

    /// Load a plist fixture from `test_data/app_message/`.
    fn load_plist(file_name: &str) -> Value {
        let path = current_dir()
            .unwrap()
            .join("test_data/app_message")
            .join(file_name);
        let file = File::open(&path)
            .unwrap_or_else(|e| panic!("cannot open {}: {e}", path.display()));
        Value::from_reader(file).unwrap()
    }

    /// Check the structural invariants every parsed poll should satisfy.
    fn assert_consistent(poll: &Poll) {
        assert_eq!(poll.order.len(), poll.options.len());
        for id in &poll.order {
            assert!(poll.options.contains_key(id), "ordered id {id:?} has no option");
        }
        for (id, option) in &poll.options {
            for vote in &option.votes {
                assert_eq!(&vote.option_id, id, "vote filed under the wrong option");
            }
        }
    }

    #[test]
    fn test_parse_poll_creation() {
        let poll = Poll::from_payload(&load_plist("PollCreate.plist")).unwrap();

        assert_eq!(poll.options.len(), 4);
        assert_consistent(&poll);
        // A freshly created poll has no votes yet.
        assert!(poll.options.values().all(|option| option.votes.is_empty()));
    }

    #[test]
    fn test_parse_poll_votes() {
        let mut poll = Poll::from_payload(&load_plist("PollCreate.plist")).unwrap();

        poll.count_votes(&load_plist("PollVote.plist")).unwrap();
        poll.count_votes(&load_plist("PollVote2.plist")).unwrap();

        assert_eq!(poll.options.len(), 4);
        assert_consistent(&poll);
    }

    #[test]
    fn test_parse_poll_vote_removed() {
        let mut poll = Poll::from_payload(&load_plist("PollCreate.plist")).unwrap();

        poll.count_votes(&load_plist("PollRemovedVote.plist")).unwrap();

        assert_eq!(poll.options.len(), 4);
        assert_consistent(&poll);
    }
}