1use futures::future::BoxFuture;
7use rocksdb::{DBCompressionType, DBWithThreadMode, MultiThreaded};
8use serde::{de::DeserializeOwned, Serialize};
9use std::{
10 convert::Infallible,
11 fmt::{Debug, Display},
12 str,
13 sync::Arc,
14};
15use teloxide::dispatching::dialogue::{Serializer, Storage};
16use teloxide_core::types::ChatId;
17use thiserror::Error;
18
19pub struct RocksDbStorage<S> {
21 db: DBWithThreadMode<MultiThreaded>,
22 serializer: S,
23}
24
25#[derive(Debug, Error)]
27pub enum RocksDbStorageError<SE>
28where
29 SE: Debug + Display,
30{
31 #[error("dialogue serialization error: {0}")]
32 SerdeError(SE),
33
34 #[error("RocksDb error: {0}")]
35 RocksDbError(#[from] rocksdb::Error),
36
37 #[error("row not found")]
39 DialogueNotFound,
40}
41
42impl<S> RocksDbStorage<S> {
43 pub async fn open(
46 path: &str,
47 serializer: S,
48 options: Option<rocksdb::Options>,
49 ) -> Result<Arc<Self>, RocksDbStorageError<Infallible>> {
50 let options = match options {
51 Some(opts) => opts,
52 None => {
53 let mut opts = rocksdb::Options::default();
54 opts.set_compression_type(DBCompressionType::Lz4);
55 opts.create_if_missing(true);
56 opts
57 }
58 };
59
60 let db = DBWithThreadMode::<MultiThreaded>::open(&options, path)?;
61 Ok(Arc::new(Self { db, serializer }))
62 }
63}
64
65impl<S, D> Storage<D> for RocksDbStorage<S>
66where
67 S: Send + Sync + Serializer<D> + 'static,
68 D: Send + Serialize + DeserializeOwned + 'static,
69 <S as Serializer<D>>::Error: Debug + Display,
70{
71 type Error = RocksDbStorageError<<S as Serializer<D>>::Error>;
72
73 fn remove_dialogue(
76 self: Arc<Self>,
77 ChatId(chat_id): ChatId,
78 ) -> BoxFuture<'static, Result<(), Self::Error>> {
79 Box::pin(async move {
80 let key = chat_id.to_le_bytes();
81
82 if self.db.get(&key)?.is_none() {
83 return Err(RocksDbStorageError::DialogueNotFound);
84 }
85
86 self.db.delete(&key).unwrap();
87
88 Ok(())
89 })
90 }
91
92 fn update_dialogue(
93 self: Arc<Self>,
94 ChatId(chat_id): ChatId,
95 dialogue: D,
96 ) -> BoxFuture<'static, Result<(), Self::Error>> {
97 Box::pin(async move {
98 let d =
99 self.serializer.serialize(&dialogue).map_err(RocksDbStorageError::SerdeError)?;
100
101 let key = chat_id.to_le_bytes();
102 self.db.put(&key, &d)?;
103
104 Ok(())
105 })
106 }
107
108 fn get_dialogue(
109 self: Arc<Self>,
110 ChatId(chat_id): ChatId,
111 ) -> BoxFuture<'static, Result<Option<D>, Self::Error>> {
112 Box::pin(async move {
113 let key = chat_id.to_le_bytes();
114 self.db
115 .get(&key)?
116 .map(|d| self.serializer.deserialize(&d).map_err(RocksDbStorageError::SerdeError))
117 .transpose()
118 })
119 }
120}