1use crate::defs::{
10 ChannelKind, ChannelState, DcOption, PeerAuth, PeerId, PeerInfo, PeerKind, UpdateState,
11 UpdatesState,
12};
13use crate::{DEFAULT_DC, KNOWN_DC_OPTIONS, Session};
14use std::path::Path;
15use std::sync::Mutex;
16
/// Schema version of the tables created by this module. Once the database has been
/// initialized this value is stored in SQLite's `PRAGMA user_version`.
const VERSION: i64 = 1;
18
/// Owns the SQLite connection. Its `impl` block adds schema initialization and small
/// query helpers on top of the raw connection.
struct Database(sqlite::Connection);
20
21struct TransactionGuard<'c>(&'c sqlite::Connection);
22
/// A [`Session`] implementation that persists its state in an SQLite database.
pub struct SqliteSession {
    // Every `Session` method here takes `&self`, yet several of them write to the database,
    // so the connection sits behind a `Mutex`.
    database: Mutex<Database>,
}
27
/// Bit flags stored in the `subtype` column of `peer_info`.
///
/// User and channel flags are interpreted separately, depending on the kind of peer being
/// read. The combined variants are the bitwise OR of their parts:
/// `UserSelfBot = UserSelf | UserBot` and `Gigagroup = Megagroup | Broadcast`.
#[repr(u8)]
enum PeerSubtype {
    /// Set for a user cached with `is_self`.
    UserSelf = 0b0001,
    /// Set for a user cached with `bot`.
    UserBot = 0b0010,
    /// Both `UserSelf` and `UserBot`.
    UserSelfBot = 0b0011,
    /// Channel of kind `ChannelKind::Megagroup`.
    Megagroup = 0b0100,
    /// Channel of kind `ChannelKind::Broadcast`.
    Broadcast = 0b1000,
    /// Channel of kind `ChannelKind::Gigagroup`; both the megagroup and broadcast bits.
    Gigagroup = 0b1100,
}
37
38impl Database {
39 fn init(&self) -> sqlite::Result<()> {
40 let mut user_version = self
41 .fetch_one("PRAGMA user_version", &[], |stmt| stmt.read::<i64, _>(0))?
42 .unwrap_or(0);
43 if user_version == VERSION {
44 return Ok(());
45 }
46
47 if user_version == 0 {
48 self.migrate_v0_to_v1()?;
49 user_version += 1;
50 }
51 if user_version == VERSION {
52 self.0.execute(format!("PRAGMA user_version = {VERSION}"))?;
54 }
55 Ok(())
56 }
57
58 fn migrate_v0_to_v1(&self) -> sqlite::Result<()> {
59 let _transaction = self.begin_transaction()?;
60 self.0.execute(
61 "CREATE TABLE dc_home (
62 dc_id INTEGER NOT NULL,
63 PRIMARY KEY(dc_id))",
64 )?;
65 self.0.execute(
66 "CREATE TABLE dc_option (
67 dc_id INTEGER NOT NULL,
68 ipv4 TEXT NOT NULL,
69 ipv6 TEXT NOT NULL,
70 auth_key BLOB,
71 PRIMARY KEY (dc_id))",
72 )?;
73 self.0.execute(
74 "CREATE TABLE peer_info (
75 peer_id INTEGER NOT NULL,
76 hash INTEGER,
77 subtype INTEGER,
78 PRIMARY KEY (peer_id))",
79 )?;
80 self.0.execute(
81 "CREATE TABLE update_state (
82 pts INTEGER NOT NULL,
83 qts INTEGER NOT NULL,
84 date INTEGER NOT NULL,
85 seq INTEGER NOT NULL)",
86 )?;
87 self.0.execute(
88 "CREATE TABLE channel_state (
89 peer_id INTEGER NOT NULL,
90 pts INTEGER NOT NULL,
91 PRIMARY KEY (peer_id))",
92 )?;
93
94 Ok(())
95 }
96
97 fn begin_transaction(&self) -> sqlite::Result<TransactionGuard<'_>> {
98 self.0.execute("BEGIN TRANSACTION")?;
99 Ok(TransactionGuard(&self.0))
100 }
101
102 fn fetch_one<T, F: FnOnce(sqlite::Statement) -> sqlite::Result<T>>(
103 &self,
104 statement: &str,
105 bindings: &[(&str, sqlite::Value)],
106 select: F,
107 ) -> sqlite::Result<Option<T>> {
108 let mut statement = self.0.prepare(statement)?;
109 statement.bind(bindings)?;
110 let result = match statement.next()? {
111 sqlite::State::Row => Some(select(statement)?),
112 sqlite::State::Done => None,
113 };
114 Ok(result)
115 }
116
117 fn fetch_all<T, F: FnMut(&sqlite::Statement) -> sqlite::Result<T>>(
118 &self,
119 statement: &str,
120 bindings: &[(&str, sqlite::Value)],
121 mut select: F,
122 ) -> sqlite::Result<Vec<T>> {
123 let mut result = Vec::new();
124 let mut statement = self.0.prepare(statement)?;
125 statement.bind(bindings)?;
126 while statement.next()? == sqlite::State::Row {
127 result.push(select(&statement)?);
128 }
129 Ok(result)
130 }
131}
132
133impl Drop for TransactionGuard<'_> {
134 fn drop(&mut self) {
135 self.0.execute("COMMIT").unwrap();
136 }
137}
138
139impl SqliteSession {
140 pub fn open<P: AsRef<Path>>(path: P) -> sqlite::Result<Self> {
143 let database = Database(sqlite::Connection::open(path)?);
144 database.init()?;
145 Ok(SqliteSession {
146 database: Mutex::new(database),
147 })
148 }
149}
150
151impl Session for SqliteSession {
152 fn home_dc_id(&self) -> i32 {
153 let db = self.database.lock().unwrap();
154 db.fetch_one("SELECT * FROM dc_home LIMIT 1", &[], |stmt| {
155 Ok(stmt.read::<i64, _>("dc_id")? as i32)
156 })
157 .unwrap()
158 .unwrap_or(DEFAULT_DC)
159 }
160
161 fn set_home_dc_id(&self, dc_id: i32) {
162 let db = self.database.lock().unwrap();
163 let _transaction = db.begin_transaction().unwrap();
164 db.0.execute("DELETE FROM dc_home").unwrap();
165 let mut stmt = db.0.prepare("INSERT INTO dc_home VALUES (:dc_id)").unwrap();
166 stmt.bind((":dc_id", dc_id as i64)).unwrap();
167 stmt.next().unwrap();
168 }
169
170 fn dc_option(&self, dc_id: i32) -> Option<DcOption> {
171 let db = self.database.lock().unwrap();
172 db.fetch_one(
173 "SELECT * FROM dc_option WHERE dc_id = :dc_id LIMIT 1",
174 &[(":dc_id", sqlite::Value::Integer(dc_id as _))],
175 |stmt| {
176 Ok(DcOption {
177 id: stmt.read::<i64, _>("dc_id")? as _,
178 ipv4: stmt.read::<String, _>("ipv4")?.parse().unwrap(),
179 ipv6: stmt.read::<String, _>("ipv6")?.parse().unwrap(),
180 auth_key: stmt
181 .read::<Option<Vec<u8>>, _>("auth_key")?
182 .map(|auth_key| auth_key.try_into().unwrap()),
183 })
184 },
185 )
186 .unwrap()
187 .or_else(|| {
188 KNOWN_DC_OPTIONS
189 .iter()
190 .find(|dc_option| dc_option.id == dc_id)
191 .cloned()
192 })
193 }
194
195 fn set_dc_option(&self, dc_option: &DcOption) {
196 let db = self.database.lock().unwrap();
197 let mut stmt = db
198 .0
199 .prepare("INSERT OR REPLACE INTO dc_option VALUES (:dc_id, :ipv4, :ipv6, :auth_key)")
200 .unwrap();
201 stmt.bind((":dc_id", dc_option.id as i64)).unwrap();
202 stmt.bind((":ipv4", dc_option.ipv4.to_string().as_str()))
203 .unwrap();
204 stmt.bind((":ipv6", dc_option.ipv6.to_string().as_str()))
205 .unwrap();
206 if let Some(auth_key) = dc_option.auth_key {
207 stmt.bind((":auth_key", auth_key.as_slice())).unwrap();
208 }
209 stmt.next().unwrap();
210 }
211
212 fn peer(&self, peer: PeerId) -> Option<PeerInfo> {
213 let db = self.database.lock().unwrap();
214 let map_stmt = |stmt: sqlite::Statement| {
215 let subtype = stmt.read::<Option<i64>, _>("subtype")?.map(|s| s as u8);
216 Ok(match peer.kind() {
217 PeerKind::User | PeerKind::UserSelf => PeerInfo::User {
218 id: PeerId::user(stmt.read::<i64, _>("peer_id")?).bare_id(),
219 auth: stmt
220 .read::<Option<i64>, _>("hash")?
221 .map(PeerAuth::from_hash),
222 bot: subtype.map(|s| s & PeerSubtype::UserBot as u8 != 0),
223 is_self: subtype.map(|s| s & PeerSubtype::UserSelf as u8 != 0),
224 },
225 PeerKind::Chat => PeerInfo::Chat { id: peer.bare_id() },
226 PeerKind::Channel => PeerInfo::Channel {
227 id: peer.bare_id(),
228 auth: stmt
229 .read::<Option<i64>, _>("hash")?
230 .map(PeerAuth::from_hash),
231 kind: subtype.and_then(|s| {
232 if (s & PeerSubtype::Gigagroup as u8) == PeerSubtype::Gigagroup as _ {
233 Some(ChannelKind::Gigagroup)
234 } else if s & PeerSubtype::Broadcast as u8 != 0 {
235 Some(ChannelKind::Broadcast)
236 } else if s & PeerSubtype::Megagroup as u8 != 0 {
237 Some(ChannelKind::Megagroup)
238 } else {
239 None
240 }
241 }),
242 },
243 })
244 };
245
246 if peer.kind() == PeerKind::UserSelf {
247 db.fetch_one(
248 "SELECT * FROM peer_info WHERE subtype & :type LIMIT 1",
249 &[(":type", sqlite::Value::Integer(PeerSubtype::UserSelf as _))],
250 map_stmt,
251 )
252 .unwrap()
253 } else {
254 db.fetch_one(
255 "SELECT * FROM peer_info WHERE peer_id = :peer_id LIMIT 1",
256 &[(":peer_id", sqlite::Value::Integer(peer.bot_api_dialog_id()))],
257 map_stmt,
258 )
259 .unwrap()
260 }
261 }
262
263 fn cache_peer(&self, peer: &PeerInfo) {
264 let db = self.database.lock().unwrap();
265 let mut stmt =
266 db.0.prepare("INSERT OR REPLACE INTO peer_info VALUES (:peer_id, :hash, :subtype)")
267 .unwrap();
268 stmt.bind((":peer_id", peer.id().bot_api_dialog_id()))
269 .unwrap();
270 if peer.auth() != PeerAuth::default() {
271 stmt.bind((":hash", peer.auth().hash())).unwrap();
272 }
273 let subtype = match peer {
274 PeerInfo::User { bot, is_self, .. } => {
275 match (bot.unwrap_or_default(), is_self.unwrap_or_default()) {
276 (true, true) => Some(PeerSubtype::UserSelfBot),
277 (true, false) => Some(PeerSubtype::UserBot),
278 (false, true) => Some(PeerSubtype::UserSelf),
279 (false, false) => None,
280 }
281 }
282 PeerInfo::Chat { .. } => None,
283 PeerInfo::Channel { kind, .. } => kind.map(|kind| match kind {
284 ChannelKind::Megagroup => PeerSubtype::Megagroup,
285 ChannelKind::Broadcast => PeerSubtype::Broadcast,
286 ChannelKind::Gigagroup => PeerSubtype::Gigagroup,
287 }),
288 };
289 if let Some(subtype) = subtype {
290 stmt.bind((":subtype", subtype as i64)).unwrap();
291 }
292 stmt.next().unwrap();
293 }
294
295 fn updates_state(&self) -> UpdatesState {
296 let db = self.database.lock().unwrap();
297 let mut state = db
298 .fetch_one("SELECT * FROM update_state LIMIT 1", &[], |stmt| {
299 Ok(UpdatesState {
300 pts: stmt.read::<i64, _>("pts")? as _,
301 qts: stmt.read::<i64, _>("qts")? as _,
302 date: stmt.read::<i64, _>("date")? as _,
303 seq: stmt.read::<i64, _>("seq")? as _,
304 channels: Vec::new(),
305 })
306 })
307 .unwrap()
308 .unwrap_or_default();
309 state.channels = db
310 .fetch_all("SELECT * FROM channel_state", &[], |stmt| {
311 Ok(ChannelState {
312 id: stmt.read::<i64, _>("peer_id")?,
313 pts: stmt.read::<i64, _>("pts")? as _,
314 })
315 })
316 .unwrap();
317 state
318 }
319
320 fn set_update_state(&self, update: UpdateState) {
321 let db = self.database.lock().unwrap();
322 let _transaction = db.begin_transaction().unwrap();
323
324 match update {
325 UpdateState::All(updates_state) => {
326 db.0.execute("DELETE FROM update_state").unwrap();
327 let mut stmt =
328 db.0.prepare("INSERT INTO update_state VALUES (:pts, :qts, :date, :seq)")
329 .unwrap();
330 stmt.bind((":pts", updates_state.pts as i64)).unwrap();
331 stmt.bind((":qts", updates_state.qts as i64)).unwrap();
332 stmt.bind((":date", updates_state.date as i64)).unwrap();
333 stmt.bind((":seq", updates_state.seq as i64)).unwrap();
334 stmt.next().unwrap();
335
336 db.0.execute("DELETE FROM channel_state").unwrap();
337 for channel in updates_state.channels {
338 let mut stmt =
339 db.0.prepare("INSERT INTO channel_state VALUES (:peer_id, :pts)")
340 .unwrap();
341 stmt.bind((":peer_id", channel.id as i64)).unwrap();
342 stmt.bind((":pts", channel.pts as i64)).unwrap();
343 stmt.next().unwrap();
344 }
345 }
346 UpdateState::Primary { pts, date, seq } => {
347 let previous = db
348 .fetch_one("SELECT * FROM update_state LIMIT 1", &[], |_| Ok(()))
349 .unwrap();
350
351 let mut stmt = if previous.is_some() {
352 db.0.prepare("UPDATE update_state SET pts = :pts, date = :date, seq = :seq")
353 .unwrap()
354 } else {
355 db.0.prepare("INSERT INTO update_state VALUES (:pts, 0, :date, :seq)")
356 .unwrap()
357 };
358 stmt.bind((":pts", pts as i64)).unwrap();
359 stmt.bind((":date", date as i64)).unwrap();
360 stmt.bind((":seq", seq as i64)).unwrap();
361 stmt.next().unwrap();
362 }
363 UpdateState::Secondary { qts } => {
364 let previous = db
365 .fetch_one("SELECT * FROM update_state LIMIT 1", &[], |_| Ok(()))
366 .unwrap();
367
368 let mut stmt = if previous.is_some() {
369 db.0.prepare("UPDATE update_state SET qts = :qts").unwrap()
370 } else {
371 db.0.prepare("INSERT INTO update_state VALUES (0, :qts, 0, 0)")
372 .unwrap()
373 };
374 stmt.bind((":qts", qts as i64)).unwrap();
375 stmt.next().unwrap();
376 }
377 UpdateState::Channel { id, pts } => {
378 let mut stmt =
379 db.0.prepare("INSERT OR REPLACE INTO channel_state VALUES (:peer_id, :pts)")
380 .unwrap();
381 stmt.bind((":peer_id", id)).unwrap();
382 stmt.bind((":pts", pts as i64)).unwrap();
383 stmt.next().unwrap();
384 }
385 }
386 }
387}
388
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};

    // The session types, the `Session` trait and the constants all come in through the
    // parent module's imports, so no separate `use` of them is needed here.
    use super::*;

    /// Opens a fresh, fully initialized session backed by an in-memory database.
    fn open_in_memory() -> SqliteSession {
        SqliteSession::open(":memory:").unwrap()
    }

    #[test]
    fn init_is_idempotent() {
        let database = Database(sqlite::Connection::open(":memory:").unwrap());
        database.init().unwrap();
        // The second call sees the current version and must not try to recreate the tables.
        database.init().unwrap();
    }

    #[test]
    fn home_dc_id_roundtrip() {
        let session = open_in_memory();
        assert_eq!(session.home_dc_id(), DEFAULT_DC);

        session.set_home_dc_id(DEFAULT_DC + 1);
        assert_eq!(session.home_dc_id(), DEFAULT_DC + 1);

        // Setting it again replaces the stored value rather than adding another row.
        session.set_home_dc_id(DEFAULT_DC + 2);
        assert_eq!(session.home_dc_id(), DEFAULT_DC + 2);
    }

    #[test]
    fn dc_option_falls_back_to_known_options_and_roundtrips() {
        let session = open_in_memory();
        assert_eq!(
            session.dc_option(KNOWN_DC_OPTIONS[0].id),
            Some(KNOWN_DC_OPTIONS[0].clone())
        );

        let new_dc_option = DcOption {
            id: KNOWN_DC_OPTIONS
                .iter()
                .map(|dc_option| dc_option.id)
                .max()
                .unwrap()
                + 1,
            ipv4: SocketAddrV4::new(Ipv4Addr::from_bits(0), 1),
            ipv6: SocketAddrV6::new(Ipv6Addr::from_bits(0), 1, 0, 0),
            auth_key: Some([1; 256]),
        };
        assert_eq!(session.dc_option(new_dc_option.id), None);
        session.set_dc_option(&new_dc_option);
        assert_eq!(
            session.dc_option(new_dc_option.id),
            Some(new_dc_option.clone())
        );

        // Overwriting without an auth key stores NULL, which reads back as `None`.
        let without_key = DcOption {
            auth_key: None,
            ..new_dc_option
        };
        session.set_dc_option(&without_key);
        assert_eq!(session.dc_option(without_key.id), Some(without_key));
    }

    #[test]
    fn user_peer_roundtrip_including_self_lookup() {
        let session = open_in_memory();
        assert_eq!(session.peer(PeerId::self_user()), None);
        assert_eq!(session.peer(PeerId::user(1)), None);

        let peer = PeerInfo::User {
            id: 1,
            auth: None,
            bot: Some(true),
            is_self: Some(true),
        };
        session.cache_peer(&peer);
        assert_eq!(session.peer(PeerId::self_user()), Some(peer.clone()));
        assert_eq!(session.peer(PeerId::user(1)), Some(peer));
    }

    #[test]
    fn channel_peer_roundtrip_for_every_kind() {
        let session = open_in_memory();
        assert_eq!(session.peer(PeerId::channel(1)), None);

        let broadcast = PeerInfo::Channel {
            id: 1,
            auth: Some(PeerAuth::from_hash(-1)),
            kind: Some(ChannelKind::Broadcast),
        };
        session.cache_peer(&broadcast);
        assert_eq!(session.peer(PeerId::channel(1)), Some(broadcast));

        // Megagroup and Gigagroup share subtype bits, so check both decode back correctly.
        let megagroup = PeerInfo::Channel {
            id: 2,
            auth: Some(PeerAuth::from_hash(2)),
            kind: Some(ChannelKind::Megagroup),
        };
        session.cache_peer(&megagroup);
        assert_eq!(session.peer(PeerId::channel(2)), Some(megagroup));

        let gigagroup = PeerInfo::Channel {
            id: 3,
            auth: Some(PeerAuth::from_hash(3)),
            kind: Some(ChannelKind::Gigagroup),
        };
        session.cache_peer(&gigagroup);
        assert_eq!(session.peer(PeerId::channel(3)), Some(gigagroup));

        // No channel subtype carries the `UserSelf` bit.
        assert_eq!(session.peer(PeerId::self_user()), None);
    }

    #[test]
    fn update_state_roundtrip() {
        let session = open_in_memory();
        assert_eq!(session.updates_state(), UpdatesState::default());

        session.set_update_state(UpdateState::All(UpdatesState {
            pts: 1,
            qts: 2,
            date: 3,
            seq: 4,
            channels: vec![
                ChannelState { id: 5, pts: 6 },
                ChannelState { id: 7, pts: 8 },
            ],
        }));
        session.set_update_state(UpdateState::Primary {
            pts: 2,
            date: 4,
            seq: 5,
        });
        session.set_update_state(UpdateState::Secondary { qts: 3 });
        session.set_update_state(UpdateState::Channel { id: 7, pts: 9 });
        assert_eq!(
            session.updates_state(),
            UpdatesState {
                pts: 2,
                qts: 3,
                date: 4,
                seq: 5,
                channels: vec![
                    ChannelState { id: 5, pts: 6 },
                    ChannelState { id: 7, pts: 9 },
                ],
            }
        );
    }

    #[test]
    fn partial_update_state_on_empty_table_inserts_zeroes() {
        let session = open_in_memory();
        session.set_update_state(UpdateState::Primary {
            pts: 1,
            date: 2,
            seq: 3,
        });
        assert_eq!(
            session.updates_state(),
            UpdatesState {
                pts: 1,
                qts: 0,
                date: 2,
                seq: 3,
                channels: Vec::new(),
            }
        );

        let session = open_in_memory();
        session.set_update_state(UpdateState::Secondary { qts: 7 });
        assert_eq!(
            session.updates_state(),
            UpdatesState {
                pts: 0,
                qts: 7,
                date: 0,
                seq: 0,
                channels: Vec::new(),
            }
        );
    }
}