use alloc::vec::Vec;
use mls_rs_core::{
    crypto::{CipherSuite, SignatureSecretKey},
    extension::ExtensionList,
    identity::SigningIdentity,
    protocol_version::ProtocolVersion,
};
use crate::{client::MlsError, Client, Group, MlsMessage};
use super::{
    proposal::ReInitProposal, ClientConfig, ExportedTree, JustPreSharedKeyID, MessageProcessor,
    NewMemberInfo, PreSharedKeyID, PskGroupId, PskSecretInput, ResumptionPSKUsage, ResumptionPsk,
};
struct ResumptionGroupParameters<'a> {
    group_id: &'a [u8],
    cipher_suite: CipherSuite,
    version: ProtocolVersion,
    extensions: &'a ExtensionList,
}
pub struct ReinitClient<C: ClientConfig + Clone> {
    client: Client<C>,
    reinit: ReInitProposal,
    psk_input: PskSecretInput,
}
impl<C> Group<C>
where
    C: ClientConfig + Clone,
{
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn branch(
        &self,
        sub_group_id: Vec<u8>,
        new_key_packages: Vec<MlsMessage>,
    ) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
        let new_group_params = ResumptionGroupParameters {
            group_id: &sub_group_id,
            cipher_suite: self.cipher_suite(),
            version: self.protocol_version(),
            extensions: &self.group_state().context.extensions,
        };
        let current_leaf_node_extensions = &self.current_user_leaf_node()?.ungreased_extensions();
        resumption_create_group(
            self.config.clone(),
            new_key_packages,
            &new_group_params,
            self.current_member_signing_identity()?.clone(),
            self.signer.clone(),
            current_leaf_node_extensions,
            #[cfg(any(feature = "private_message", feature = "psk"))]
            self.resumption_psk_input(ResumptionPSKUsage::Branch)?,
        )
        .await
    }
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn join_subgroup(
        &self,
        welcome: &MlsMessage,
        tree_data: Option<ExportedTree<'_>>,
    ) -> Result<(Group<C>, NewMemberInfo), MlsError> {
        let expected_new_group_prams = ResumptionGroupParameters {
            group_id: &[],
            cipher_suite: self.cipher_suite(),
            version: self.protocol_version(),
            extensions: &self.group_state().context.extensions,
        };
        resumption_join_group(
            self.config.clone(),
            self.signer.clone(),
            welcome,
            tree_data,
            expected_new_group_prams,
            false,
            self.resumption_psk_input(ResumptionPSKUsage::Branch)?,
        )
        .await
    }
    pub fn get_reinit_client(
        self,
        new_signer: Option<SignatureSecretKey>,
        new_signing_identity: Option<SigningIdentity>,
    ) -> Result<ReinitClient<C>, MlsError> {
        let psk_input = self.resumption_psk_input(ResumptionPSKUsage::Reinit)?;
        let new_signing_identity = new_signing_identity
            .map(Ok)
            .unwrap_or_else(|| self.current_member_signing_identity().cloned())?;
        let reinit = self
            .state
            .pending_reinit
            .ok_or(MlsError::PendingReInitNotFound)?;
        let new_signer = match new_signer {
            Some(signer) => signer,
            None => self.signer,
        };
        let client = Client::new(
            self.config,
            Some(new_signer),
            Some((new_signing_identity, reinit.new_cipher_suite())),
            reinit.new_version(),
        );
        Ok(ReinitClient {
            client,
            reinit,
            psk_input,
        })
    }
    fn resumption_psk_input(&self, usage: ResumptionPSKUsage) -> Result<PskSecretInput, MlsError> {
        let psk = self.epoch_secrets.resumption_secret.clone();
        let id = JustPreSharedKeyID::Resumption(ResumptionPsk {
            usage,
            psk_group_id: PskGroupId(self.group_id().to_vec()),
            psk_epoch: self.current_epoch(),
        });
        let id = PreSharedKeyID::new(id, self.cipher_suite_provider())?;
        Ok(PskSecretInput { id, psk })
    }
}
impl<C: ClientConfig + Clone> ReinitClient<C> {
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn generate_key_package(&self) -> Result<MlsMessage, MlsError> {
        self.client
            .generate_key_package_message(Default::default(), Default::default())
            .await
    }
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn commit(
        self,
        new_key_packages: Vec<MlsMessage>,
        new_leaf_node_extensions: ExtensionList,
    ) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
        let new_group_params = ResumptionGroupParameters {
            group_id: self.reinit.group_id(),
            cipher_suite: self.reinit.new_cipher_suite(),
            version: self.reinit.new_version(),
            extensions: self.reinit.new_group_context_extensions(),
        };
        resumption_create_group(
            self.client.config.clone(),
            new_key_packages,
            &new_group_params,
            self.client.signing_identity.unwrap().0,
            self.client.signer.unwrap(),
            &new_leaf_node_extensions,
            #[cfg(any(feature = "private_message", feature = "psk"))]
            self.psk_input,
        )
        .await
    }
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn join(
        self,
        welcome: &MlsMessage,
        tree_data: Option<ExportedTree<'_>>,
    ) -> Result<(Group<C>, NewMemberInfo), MlsError> {
        let reinit = self.reinit;
        let expected_group_params = ResumptionGroupParameters {
            group_id: reinit.group_id(),
            cipher_suite: reinit.new_cipher_suite(),
            version: reinit.new_version(),
            extensions: reinit.new_group_context_extensions(),
        };
        resumption_join_group(
            self.client.config,
            self.client.signer.unwrap(),
            welcome,
            tree_data,
            expected_group_params,
            true,
            self.psk_input,
        )
        .await
    }
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn resumption_create_group<C: ClientConfig + Clone>(
    config: C,
    new_key_packages: Vec<MlsMessage>,
    new_group_params: &ResumptionGroupParameters<'_>,
    signing_identity: SigningIdentity,
    signer: SignatureSecretKey,
    leaf_node_extensions: &ExtensionList,
    psk_input: PskSecretInput,
) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
    let mut group = Group::new(
        config,
        Some(new_group_params.group_id.to_vec()),
        new_group_params.cipher_suite,
        new_group_params.version,
        signing_identity,
        new_group_params.extensions.clone(),
        leaf_node_extensions.clone(),
        signer,
    )
    .await?;
    group.previous_psk = Some(psk_input);
    let mut commit = group.commit_builder();
    for kp in new_key_packages.into_iter() {
        commit = commit.add_member(kp)?;
    }
    let commit = commit.build().await?;
    group.apply_pending_commit().await?;
    group.previous_psk = None;
    Ok((group, commit.welcome_messages))
}
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn resumption_join_group<C: ClientConfig + Clone>(
    config: C,
    signer: SignatureSecretKey,
    welcome: &MlsMessage,
    tree_data: Option<ExportedTree<'_>>,
    expected_new_group_params: ResumptionGroupParameters<'_>,
    verify_group_id: bool,
    psk_input: PskSecretInput,
) -> Result<(Group<C>, NewMemberInfo), MlsError> {
    let psk_input = Some(psk_input);
    let (group, new_member_info) =
        Group::<C>::from_welcome_message(welcome, tree_data, config, signer, psk_input).await?;
    if group.protocol_version() != expected_new_group_params.version {
        Err(MlsError::ProtocolVersionMismatch)
    } else if group.cipher_suite() != expected_new_group_params.cipher_suite {
        Err(MlsError::CipherSuiteMismatch)
    } else if verify_group_id && group.group_id() != expected_new_group_params.group_id {
        Err(MlsError::GroupIdMismatch)
    } else if &group.group_state().context.extensions != expected_new_group_params.extensions {
        Err(MlsError::ReInitExtensionsMismatch)
    } else {
        Ok((group, new_member_info))
    }
}