headless_talk/init/mod.rs

1pub mod config;
2
3use std::{io, pin::pin, sync::Arc};
4
5use diesel::{QueryDsl, RunQueryDsl};
6use futures::{AsyncRead, AsyncWrite, Future, TryStreamExt};
7use futures_loco_protocol::{
8    loco_protocol::command::BoxedCommand,
9    session::{LocoSession, LocoSessionStream},
10    LocoClient,
11};
12use talk_loco_client::talk::session::{
13    load_channel_list::{self},
14    login, TalkSession,
15};
16use thiserror::Error;
17use tokio::time;
18
19use crate::{
20    conn::Conn,
21    constants::PING_INTERVAL,
22    database::{schema::channel_list, DatabasePool, MigrationError, PoolTaskError},
23    event::ClientEvent,
24    handler::{error::HandlerError, SessionHandler},
25    task::BackgroundTask,
26    updater::list::ChannelListUpdater,
27    ClientError, ClientStatus, HeadlessTalk,
28};
29
30use self::config::ClientEnv;
31
32pub struct TalkInitializer<'a, S> {
33    session: LocoSession,
34    stream: LocoSessionStream<S>,
35
36    pool: DatabasePool,
37
38    env: ClientEnv<'a>,
39}
40
41impl<'a, S: AsyncRead + AsyncWrite + Unpin> TalkInitializer<'a, S> {
42    pub async fn new(
43        client: LocoClient<S>,
44        env: ClientEnv<'a>,
45        database_url: impl Into<String>,
46    ) -> Result<TalkInitializer<'a, S>, InitError> {
47        let (session, stream) = LocoSession::new(client);
48
49        let pool = DatabasePool::initialize(database_url).await?;
50        pool.migrate_to_latest().await?;
51
52        Ok(Self {
53            session,
54            stream,
55
56            pool,
57
58            env,
59        })
60    }
61
62    pub async fn login<F, Fut>(
63        mut self,
64        credential: Credential<'_>,
65        status: ClientStatus,
66        command_handler: F,
67    ) -> Result<HeadlessTalk, LoginError>
68    where
69        S: Send + 'static,
70        F: Fn(Result<ClientEvent, HandlerError>) -> Fut + Send + Sync + 'static,
71        Fut: Future + Send + Sync + 'static,
72    {
73        let mut channel_list = Vec::new();
74
75        let (chat_ids, max_ids) = self
76            .pool
77            .spawn(|conn| {
78                let iter = channel_list::table
79                    .select((channel_list::id, channel_list::last_seen_log_id))
80                    .load_iter::<(i64, Option<i64>), _>(conn)?;
81
82                let mut chat_ids = Vec::with_capacity(iter.size_hint().0);
83                let mut max_ids = Vec::with_capacity(iter.size_hint().0);
84
85                for res in iter {
86                    let (channel_id, max_id) = res?;
87
88                    chat_ids.push(channel_id);
89                    max_ids.push(max_id.unwrap_or(0));
90                }
91
92                Ok((chat_ids, max_ids))
93            })
94            .await
95            .map_err(ClientError::from)?;
96
97        let mut stream_buffer = Vec::new();
98
99        let (user_id, deleted_channels) =
100            run_session(&mut self.stream, &mut stream_buffer, async {
101                let (res, stream) = TalkSession(&self.session)
102                    .login(login::Request {
103                        os: self.env.os,
104                        net_type: self.env.net_type as _,
105                        app_version: self.env.app_version,
106                        mccmnc: self.env.mccmnc,
107                        protocol_version: "1.0",
108                        device_uuid: credential.device_uuid,
109                        oauth_token: credential.access_token,
110                        language: self.env.language,
111                        device_type: Some(2),
112                        pc_status: Some(status as _),
113                        revision: None,
114                        rp: [0x00, 0x00, 0xff, 0xff, 0x00, 0x00],
115                        chat_list: load_channel_list::Request {
116                            chat_ids: &chat_ids,
117                            max_ids: &max_ids,
118                            last_token_id: 0,
119                            last_chat_id: None,
120                        },
121                        last_block_token: 0,
122                        background: None,
123                    })
124                    .await?;
125
126                channel_list.push(res.chat_list.chat_datas);
127
128                if let Some(stream) = stream {
129                    let mut stream = pin!(stream);
130
131                    while let Some(res) = stream.try_next().await? {
132                        channel_list.push(res.chat_datas);
133                    }
134                }
135
136                Ok::<_, ClientError>((res.user_id, res.chat_list.deleted_chat_ids))
137            })
138            .await??;
139
140        let conn = Conn {
141            user_id,
142            session: self.session.clone(),
143            pool: self.pool.clone(),
144        };
145
146        let stream_task = BackgroundTask::new(
147            tokio::spawn({
148                let command_handler = Arc::new(command_handler);
149                let handler = Arc::new(SessionHandler::new(conn.clone()));
150
151                async move {
152                    for read in stream_buffer {
153                        tokio::spawn({
154                            let command_handler = command_handler.clone();
155                            let handler = handler.clone();
156
157                            async move {
158                                match handler.handle(read).await {
159                                    Ok(Some(event)) => {
160                                        command_handler(Ok(event)).await;
161                                    }
162
163                                    Err(err) => {
164                                        command_handler(Err(err)).await;
165                                    }
166
167                                    _ => {}
168                                }
169                            }
170                        });
171                    }
172
173                    let res: Result<ClientEvent, HandlerError> = async {
174                        let mut stream = pin!(self.stream);
175                        while let Some(read) = stream.try_next().await? {
176                            tokio::spawn({
177                                let command_handler = command_handler.clone();
178                                let handler = handler.clone();
179
180                                async move {
181                                    match handler.handle(read).await {
182                                        Ok(Some(event)) => {
183                                            command_handler(Ok(event)).await;
184                                        }
185
186                                        Err(err) => {
187                                            command_handler(Err(err)).await;
188                                        }
189
190                                        _ => {}
191                                    }
192                                }
193                            });
194                        }
195
196                        unreachable!();
197                    }
198                    .await;
199
200                    command_handler(res);
201                }
202            })
203            .abort_handle(),
204        );
205
206        let ping_task = BackgroundTask::new(
207            tokio::spawn({
208                let session = self.session.clone();
209
210                async move {
211                    let mut interval = time::interval(PING_INTERVAL);
212
213                    while TalkSession(&session).ping().await.is_ok() {
214                        interval.tick().await;
215                    }
216                }
217            })
218            .abort_handle(),
219        );
220
221        ChannelListUpdater::new(&self.session, &self.pool)
222            .update(channel_list.into_iter().flatten(), deleted_channels)
223            .await?;
224
225        Ok(HeadlessTalk {
226            conn,
227            _ping_task: ping_task,
228            _stream_task: stream_task,
229        })
230    }
231}
232
233#[derive(Debug, Error)]
234#[error(transparent)]
235pub enum LoginError {
236    Client(#[from] ClientError),
237    Io(#[from] io::Error),
238}
239
240#[derive(Debug, Error)]
241#[error(transparent)]
242pub enum InitError {
243    DatabaseInit(#[from] PoolTaskError),
244    Migration(#[from] MigrationError),
245}
246
/// Login credentials passed to [`TalkInitializer::login`].
#[derive(Clone, Copy, Debug)]
pub struct Credential<'a> {
    /// Sent as `oauth_token` in the login request.
    pub access_token: &'a str,
    /// Sent as `device_uuid` in the login request.
    pub device_uuid: &'a str,
}
252
253async fn run_session<F: Future>(
254    stream: &mut LocoSessionStream<impl AsyncRead + AsyncWrite + Unpin>,
255    buffer: &mut Vec<BoxedCommand>,
256    task: F,
257) -> Result<F::Output, io::Error> {
258    let stream_task = async {
259        while let Some(read) = stream.try_next().await? {
260            buffer.push(read);
261        }
262
263        Ok::<_, io::Error>(())
264    };
265
266    Ok(tokio::select! {
267        res = task => res,
268        res = stream_task => {
269            res?;
270            unreachable!();
271        },
272    })
273}