1pub mod config;
2
3use std::{io, pin::pin, sync::Arc};
4
5use diesel::{QueryDsl, RunQueryDsl};
6use futures::{AsyncRead, AsyncWrite, Future, TryStreamExt};
7use futures_loco_protocol::{
8 loco_protocol::command::BoxedCommand,
9 session::{LocoSession, LocoSessionStream},
10 LocoClient,
11};
12use talk_loco_client::talk::session::{
13 load_channel_list::{self},
14 login, TalkSession,
15};
16use thiserror::Error;
17use tokio::time;
18
19use crate::{
20 conn::Conn,
21 constants::PING_INTERVAL,
22 database::{schema::channel_list, DatabasePool, MigrationError, PoolTaskError},
23 event::ClientEvent,
24 handler::{error::HandlerError, SessionHandler},
25 task::BackgroundTask,
26 updater::list::ChannelListUpdater,
27 ClientError, ClientStatus, HeadlessTalk,
28};
29
30use self::config::ClientEnv;
31
32pub struct TalkInitializer<'a, S> {
33 session: LocoSession,
34 stream: LocoSessionStream<S>,
35
36 pool: DatabasePool,
37
38 env: ClientEnv<'a>,
39}
40
41impl<'a, S: AsyncRead + AsyncWrite + Unpin> TalkInitializer<'a, S> {
42 pub async fn new(
43 client: LocoClient<S>,
44 env: ClientEnv<'a>,
45 database_url: impl Into<String>,
46 ) -> Result<TalkInitializer<'a, S>, InitError> {
47 let (session, stream) = LocoSession::new(client);
48
49 let pool = DatabasePool::initialize(database_url).await?;
50 pool.migrate_to_latest().await?;
51
52 Ok(Self {
53 session,
54 stream,
55
56 pool,
57
58 env,
59 })
60 }
61
62 pub async fn login<F, Fut>(
63 mut self,
64 credential: Credential<'_>,
65 status: ClientStatus,
66 command_handler: F,
67 ) -> Result<HeadlessTalk, LoginError>
68 where
69 S: Send + 'static,
70 F: Fn(Result<ClientEvent, HandlerError>) -> Fut + Send + Sync + 'static,
71 Fut: Future + Send + Sync + 'static,
72 {
73 let mut channel_list = Vec::new();
74
75 let (chat_ids, max_ids) = self
76 .pool
77 .spawn(|conn| {
78 let iter = channel_list::table
79 .select((channel_list::id, channel_list::last_seen_log_id))
80 .load_iter::<(i64, Option<i64>), _>(conn)?;
81
82 let mut chat_ids = Vec::with_capacity(iter.size_hint().0);
83 let mut max_ids = Vec::with_capacity(iter.size_hint().0);
84
85 for res in iter {
86 let (channel_id, max_id) = res?;
87
88 chat_ids.push(channel_id);
89 max_ids.push(max_id.unwrap_or(0));
90 }
91
92 Ok((chat_ids, max_ids))
93 })
94 .await
95 .map_err(ClientError::from)?;
96
97 let mut stream_buffer = Vec::new();
98
99 let (user_id, deleted_channels) =
100 run_session(&mut self.stream, &mut stream_buffer, async {
101 let (res, stream) = TalkSession(&self.session)
102 .login(login::Request {
103 os: self.env.os,
104 net_type: self.env.net_type as _,
105 app_version: self.env.app_version,
106 mccmnc: self.env.mccmnc,
107 protocol_version: "1.0",
108 device_uuid: credential.device_uuid,
109 oauth_token: credential.access_token,
110 language: self.env.language,
111 device_type: Some(2),
112 pc_status: Some(status as _),
113 revision: None,
114 rp: [0x00, 0x00, 0xff, 0xff, 0x00, 0x00],
115 chat_list: load_channel_list::Request {
116 chat_ids: &chat_ids,
117 max_ids: &max_ids,
118 last_token_id: 0,
119 last_chat_id: None,
120 },
121 last_block_token: 0,
122 background: None,
123 })
124 .await?;
125
126 channel_list.push(res.chat_list.chat_datas);
127
128 if let Some(stream) = stream {
129 let mut stream = pin!(stream);
130
131 while let Some(res) = stream.try_next().await? {
132 channel_list.push(res.chat_datas);
133 }
134 }
135
136 Ok::<_, ClientError>((res.user_id, res.chat_list.deleted_chat_ids))
137 })
138 .await??;
139
140 let conn = Conn {
141 user_id,
142 session: self.session.clone(),
143 pool: self.pool.clone(),
144 };
145
146 let stream_task = BackgroundTask::new(
147 tokio::spawn({
148 let command_handler = Arc::new(command_handler);
149 let handler = Arc::new(SessionHandler::new(conn.clone()));
150
151 async move {
152 for read in stream_buffer {
153 tokio::spawn({
154 let command_handler = command_handler.clone();
155 let handler = handler.clone();
156
157 async move {
158 match handler.handle(read).await {
159 Ok(Some(event)) => {
160 command_handler(Ok(event)).await;
161 }
162
163 Err(err) => {
164 command_handler(Err(err)).await;
165 }
166
167 _ => {}
168 }
169 }
170 });
171 }
172
173 let res: Result<ClientEvent, HandlerError> = async {
174 let mut stream = pin!(self.stream);
175 while let Some(read) = stream.try_next().await? {
176 tokio::spawn({
177 let command_handler = command_handler.clone();
178 let handler = handler.clone();
179
180 async move {
181 match handler.handle(read).await {
182 Ok(Some(event)) => {
183 command_handler(Ok(event)).await;
184 }
185
186 Err(err) => {
187 command_handler(Err(err)).await;
188 }
189
190 _ => {}
191 }
192 }
193 });
194 }
195
196 unreachable!();
197 }
198 .await;
199
200 command_handler(res);
201 }
202 })
203 .abort_handle(),
204 );
205
206 let ping_task = BackgroundTask::new(
207 tokio::spawn({
208 let session = self.session.clone();
209
210 async move {
211 let mut interval = time::interval(PING_INTERVAL);
212
213 while TalkSession(&session).ping().await.is_ok() {
214 interval.tick().await;
215 }
216 }
217 })
218 .abort_handle(),
219 );
220
221 ChannelListUpdater::new(&self.session, &self.pool)
222 .update(channel_list.into_iter().flatten(), deleted_channels)
223 .await?;
224
225 Ok(HeadlessTalk {
226 conn,
227 _ping_task: ping_task,
228 _stream_task: stream_task,
229 })
230 }
231}
232
233#[derive(Debug, Error)]
234#[error(transparent)]
235pub enum LoginError {
236 Client(#[from] ClientError),
237 Io(#[from] io::Error),
238}
239
240#[derive(Debug, Error)]
241#[error(transparent)]
242pub enum InitError {
243 DatabaseInit(#[from] PoolTaskError),
244 Migration(#[from] MigrationError),
245}
246
/// Borrowed login credentials passed to [`TalkInitializer::login`].
#[derive(Debug)]
#[derive(Clone, Copy)]
pub struct Credential<'a> {
    /// OAuth access token, sent as the login request's `oauth_token`.
    pub access_token: &'a str,
    /// Device UUID, sent as the login request's `device_uuid`.
    pub device_uuid: &'a str,
}
252
253async fn run_session<F: Future>(
254 stream: &mut LocoSessionStream<impl AsyncRead + AsyncWrite + Unpin>,
255 buffer: &mut Vec<BoxedCommand>,
256 task: F,
257) -> Result<F::Output, io::Error> {
258 let stream_task = async {
259 while let Some(read) = stream.try_next().await? {
260 buffer.push(read);
261 }
262
263 Ok::<_, io::Error>(())
264 };
265
266 Ok(tokio::select! {
267 res = task => res,
268 res = stream_task => {
269 res?;
270 unreachable!();
271 },
272 })
273}