headless-talk 0.6.1

Headless talk implementation
Documentation
pub mod config;

use std::{io, pin::pin, sync::Arc};

use diesel::{QueryDsl, RunQueryDsl};
use futures::{AsyncRead, AsyncWrite, Future, TryStreamExt};
use futures_loco_protocol::{
    loco_protocol::command::BoxedCommand,
    session::{LocoSession, LocoSessionStream},
    LocoClient,
};
use talk_loco_client::talk::session::{
    load_channel_list::{self},
    login, TalkSession,
};
use thiserror::Error;
use tokio::time;

use crate::{
    conn::Conn,
    constants::PING_INTERVAL,
    database::{schema::channel_list, DatabasePool, MigrationError, PoolTaskError},
    event::ClientEvent,
    handler::{error::HandlerError, SessionHandler},
    task::BackgroundTask,
    updater::list::ChannelListUpdater,
    ClientError, ClientStatus, HeadlessTalk,
};

use self::config::ClientEnv;

/// Pre-login state of a headless talk client.
///
/// Created by [`TalkInitializer::new`] and consumed by
/// [`TalkInitializer::login`], which turns it into a `HeadlessTalk`.
pub struct TalkInitializer<'a, S> {
    /// Session half produced by `LocoSession::new`; used to issue requests.
    session: LocoSession,
    /// Stream half produced by `LocoSession::new`; yields incoming commands.
    stream: LocoSessionStream<S>,

    /// Database pool, initialized and migrated to the latest schema in `new`.
    pool: DatabasePool,

    /// Client environment values copied into the login request.
    env: ClientEnv<'a>,
}

impl<'a, S: AsyncRead + AsyncWrite + Unpin> TalkInitializer<'a, S> {
    pub async fn new(
        client: LocoClient<S>,
        env: ClientEnv<'a>,
        database_url: impl Into<String>,
    ) -> Result<TalkInitializer<'a, S>, InitError> {
        let (session, stream) = LocoSession::new(client);

        let pool = DatabasePool::initialize(database_url).await?;
        pool.migrate_to_latest().await?;

        Ok(Self {
            session,
            stream,

            pool,

            env,
        })
    }

    pub async fn login<F, Fut>(
        mut self,
        credential: Credential<'_>,
        status: ClientStatus,
        command_handler: F,
    ) -> Result<HeadlessTalk, LoginError>
    where
        S: Send + 'static,
        F: Fn(Result<ClientEvent, HandlerError>) -> Fut + Send + Sync + 'static,
        Fut: Future + Send + Sync + 'static,
    {
        let mut channel_list = Vec::new();

        let (chat_ids, max_ids) = self
            .pool
            .spawn(|conn| {
                let iter = channel_list::table
                    .select((channel_list::id, channel_list::last_seen_log_id))
                    .load_iter::<(i64, Option<i64>), _>(conn)?;

                let mut chat_ids = Vec::with_capacity(iter.size_hint().0);
                let mut max_ids = Vec::with_capacity(iter.size_hint().0);

                for res in iter {
                    let (channel_id, max_id) = res?;

                    chat_ids.push(channel_id);
                    max_ids.push(max_id.unwrap_or(0));
                }

                Ok((chat_ids, max_ids))
            })
            .await
            .map_err(ClientError::from)?;

        let mut stream_buffer = Vec::new();

        let (user_id, deleted_channels) =
            run_session(&mut self.stream, &mut stream_buffer, async {
                let (res, stream) = TalkSession(&self.session)
                    .login(login::Request {
                        os: self.env.os,
                        net_type: self.env.net_type as _,
                        app_version: self.env.app_version,
                        mccmnc: self.env.mccmnc,
                        protocol_version: "1.0",
                        device_uuid: credential.device_uuid,
                        oauth_token: credential.access_token,
                        language: self.env.language,
                        device_type: Some(2),
                        pc_status: Some(status as _),
                        revision: None,
                        rp: [0x00, 0x00, 0xff, 0xff, 0x00, 0x00],
                        chat_list: load_channel_list::Request {
                            chat_ids: &chat_ids,
                            max_ids: &max_ids,
                            last_token_id: 0,
                            last_chat_id: None,
                        },
                        last_block_token: 0,
                        background: None,
                    })
                    .await?;

                channel_list.push(res.chat_list.chat_datas);

                if let Some(stream) = stream {
                    let mut stream = pin!(stream);

                    while let Some(res) = stream.try_next().await? {
                        channel_list.push(res.chat_datas);
                    }
                }

                Ok::<_, ClientError>((res.user_id, res.chat_list.deleted_chat_ids))
            })
            .await??;

        let conn = Conn {
            user_id,
            session: self.session.clone(),
            pool: self.pool.clone(),
        };

        let stream_task = BackgroundTask::new(
            tokio::spawn({
                let command_handler = Arc::new(command_handler);
                let handler = Arc::new(SessionHandler::new(conn.clone()));

                async move {
                    for read in stream_buffer {
                        tokio::spawn({
                            let command_handler = command_handler.clone();
                            let handler = handler.clone();

                            async move {
                                match handler.handle(read).await {
                                    Ok(Some(event)) => {
                                        command_handler(Ok(event)).await;
                                    }

                                    Err(err) => {
                                        command_handler(Err(err)).await;
                                    }

                                    _ => {}
                                }
                            }
                        });
                    }

                    let res: Result<ClientEvent, HandlerError> = async {
                        let mut stream = pin!(self.stream);
                        while let Some(read) = stream.try_next().await? {
                            tokio::spawn({
                                let command_handler = command_handler.clone();
                                let handler = handler.clone();

                                async move {
                                    match handler.handle(read).await {
                                        Ok(Some(event)) => {
                                            command_handler(Ok(event)).await;
                                        }

                                        Err(err) => {
                                            command_handler(Err(err)).await;
                                        }

                                        _ => {}
                                    }
                                }
                            });
                        }

                        unreachable!();
                    }
                    .await;

                    command_handler(res);
                }
            })
            .abort_handle(),
        );

        let ping_task = BackgroundTask::new(
            tokio::spawn({
                let session = self.session.clone();

                async move {
                    let mut interval = time::interval(PING_INTERVAL);

                    while TalkSession(&session).ping().await.is_ok() {
                        interval.tick().await;
                    }
                }
            })
            .abort_handle(),
        );

        ChannelListUpdater::new(&self.session, &self.pool)
            .update(channel_list.into_iter().flatten(), deleted_channels)
            .await?;

        Ok(HeadlessTalk {
            conn,
            _ping_task: ping_task,
            _stream_task: stream_task,
        })
    }
}

/// Error returned by [`TalkInitializer::login`].
///
/// Both variants forward `Display` and `source` to the wrapped error
/// (`#[error(transparent)]`).
#[derive(Debug, Error)]
#[error(transparent)]
pub enum LoginError {
    /// A client error, e.g. from the channel list query or the login request.
    Client(#[from] ClientError),
    /// An I/O error, e.g. from reading the session stream during login.
    Io(#[from] io::Error),
}

/// Error returned by [`TalkInitializer::new`].
///
/// Both variants forward `Display` and `source` to the wrapped error
/// (`#[error(transparent)]`).
#[derive(Debug, Error)]
#[error(transparent)]
pub enum InitError {
    /// A database pool task failed (presumably while initializing the pool).
    DatabaseInit(#[from] PoolTaskError),
    /// A database migration failed (presumably while migrating to latest).
    Migration(#[from] MigrationError),
}

/// Borrowed login credential passed to [`TalkInitializer::login`].
#[derive(Debug, Clone, Copy)]
pub struct Credential<'a> {
    /// Sent as `oauth_token` in the login request.
    pub access_token: &'a str,
    /// Sent as `device_uuid` in the login request.
    pub device_uuid: &'a str,
}

/// Drives `task` to completion while draining `stream` into `buffer`.
///
/// Every command read from `stream` before `task` finishes is appended to
/// `buffer` so the caller can process it later.
///
/// # Errors
///
/// Returns the `io::Error` produced by `stream` if reading fails before
/// `task` completes.
///
/// # Panics
///
/// Panics if `stream` ends without an error before `task` completes.
async fn run_session<F: Future>(
    stream: &mut LocoSessionStream<impl AsyncRead + AsyncWrite + Unpin>,
    buffer: &mut Vec<BoxedCommand>,
    task: F,
) -> Result<F::Output, io::Error> {
    // Collects incoming commands until the stream ends or fails.
    let drain = async {
        loop {
            match stream.try_next().await? {
                Some(command) => buffer.push(command),
                None => break Ok::<(), io::Error>(()),
            }
        }
    };

    tokio::select! {
        output = task => Ok(output),
        drained = drain => {
            drained?;
            unreachable!();
        },
    }
}