1pub mod accounts;
2pub mod screen_names;
3pub mod table;
4pub mod util;
5
6use accounts::AccountTable;
7use chrono::NaiveDate;
8use screen_names::ScreenNameTable;
9use std::collections::HashMap;
10use std::path::Path;
11use std::sync::Arc;
12pub use table::{Mode, ReadOnly, Table, Writeable};
13
14#[derive(thiserror::Error, Debug)]
15pub enum Error {
16 #[error("RocksDb error")]
17 Db(#[from] rocksdb::Error),
18 #[error("Invalid UTF-8 string")]
19 InvalidString(#[from] std::str::Utf8Error),
20 #[error("Invalid key")]
21 InvalidKey(Vec<u8>),
22 #[error("Invalid value")]
23 InvalidValue(Vec<u8>),
24 #[error("Invalid Twitter epoch day")]
25 InvalidDay(i64),
26 #[error("Invalid Twitter screen name")]
27 InvalidScreenName(String),
28 #[error("Channel send error")]
29 ChannelSend,
30 #[error("Channel receive error")]
31 ChannelRecv(#[from] std::sync::mpsc::RecvError),
32}
33
34pub struct Database<M> {
35 pub accounts: Arc<AccountTable<M>>,
36 pub screen_names: ScreenNameTable<M>,
37}
38
39impl<M: Sync + Send + 'static> Database<M> {
40 pub fn get_counts(
41 &self,
42 ) -> Result<
43 (
44 accounts::AccountTableCounts,
45 screen_names::ScreenNameTableCounts,
46 ),
47 Error,
48 > {
49 let (tx, rx) = std::sync::mpsc::channel();
50 let accounts = self.accounts.clone();
51
52 std::thread::spawn(move || {
53 tx.send(accounts.get_counts())
54 .map_err(|_| Error::ChannelSend)
55 });
56
57 let screen_name_counts = self.screen_names.get_counts()?;
58 let account_counts = rx.recv()??;
59
60 Ok((account_counts, screen_name_counts))
61 }
62
63 pub fn lookup_by_user_id(
64 &self,
65 user_id: u64,
66 ) -> Result<HashMap<String, Vec<NaiveDate>>, Error> {
67 self.accounts.lookup(user_id)
68 }
69
70 pub fn lookup_by_screen_name(&self, screen_name: &str) -> Result<Vec<u64>, Error> {
71 self.screen_names.lookup(screen_name)
72 }
73
74 pub fn lookup_by_screen_name_prefix(
75 &self,
76 screen_name_prefix: &str,
77 limit: usize,
78 ) -> Result<Vec<(String, Vec<u64>)>, Error> {
79 self.screen_names
80 .lookup_by_prefix(screen_name_prefix, limit)
81 }
82
83 pub fn limited_lookup_by_user_id(
84 &self,
85 user_id: u64,
86 earliest: Option<NaiveDate>,
87 ) -> Result<HashMap<String, Vec<NaiveDate>>, Error> {
88 match earliest {
89 Some(earliest) => self.accounts.limited_lookup(user_id, earliest),
90 None => self.accounts.lookup(user_id),
91 }
92 }
93}
94
95impl<M: Mode> Database<M> {
96 pub fn open<P: AsRef<Path>>(base: P) -> Result<Self, Error> {
97 Self::open_from_tables(
98 base.as_ref().join("accounts"),
99 base.as_ref().join("screen-names"),
100 )
101 }
102
103 fn open_from_tables<P: AsRef<Path>>(
104 accounts_path: P,
105 screen_names_path: P,
106 ) -> Result<Self, Error> {
107 Ok(Self {
108 accounts: Arc::new(AccountTable::open(accounts_path)?),
109 screen_names: ScreenNameTable::open(screen_names_path)?,
110 })
111 }
112}
113
114impl Database<Writeable> {
115 pub fn insert(&self, id: u64, screen_name: &str, dates: Vec<NaiveDate>) -> Result<(), Error> {
116 self.accounts.insert(id, screen_name, dates)?;
117 self.screen_names.insert(screen_name, id)?;
118 Ok(())
119 }
120
121 pub fn rebuild_index(&mut self) -> Result<(), Error> {
122 self.screen_names.rebuild(&self.accounts)
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Inserting overlapping (id, screen name) pairs yields de-duplicated lookups, counts and
    /// pair listings, both before and after compacting the account table.
    #[test]
    fn insert() {
        // `Database::open` takes its path argument by value. Passing the `TempDir` itself
        // would drop it (and delete the directory) as soon as `open` returned, so the rest of
        // the test would run against a removed directory. Borrow the path instead so the
        // directory lives until the end of the test, after `db` has been dropped.
        let dir = tempfile::tempdir().unwrap();
        let db = Database::open(dir.path()).unwrap();
        db.insert(123, "foo", vec![]).unwrap();
        db.insert(123, "bar", vec![]).unwrap();
        db.insert(456, "foo", vec![]).unwrap();
        // Repeats the first insert; the expectations below count it only once.
        db.insert(123, "foo", vec![]).unwrap();

        let mut expected_by_id = HashMap::new();
        expected_by_id.insert("foo".to_string(), vec![]);
        expected_by_id.insert("bar".to_string(), vec![]);

        let expected_pairs = vec![
            (123, "bar".to_string(), vec![]),
            (123, "foo".to_string(), vec![]),
            (456, "foo".to_string(), vec![]),
        ];

        let expected_counts = (
            accounts::AccountTableCounts {
                id_count: 2,
                pair_count: 3,
            },
            screen_names::ScreenNameTableCounts {
                screen_name_count: 2,
                mapping_count: 3,
            },
        );

        assert_eq!(db.lookup_by_screen_name("foo").unwrap(), vec![123, 456]);
        assert_eq!(db.lookup_by_user_id(123).unwrap(), expected_by_id);
        assert_eq!(db.get_counts().unwrap(), expected_counts);
        assert_eq!(
            db.accounts.pairs().collect::<Result<Vec<_>, _>>().unwrap(),
            expected_pairs
        );

        db.accounts.compact_ranges().unwrap();

        // Compaction must not change anything observable.
        assert_eq!(db.lookup_by_screen_name("foo").unwrap(), vec![123, 456]);
        assert_eq!(db.lookup_by_user_id(123).unwrap(), expected_by_id);
        assert_eq!(db.get_counts().unwrap(), expected_counts);
        assert_eq!(
            db.accounts.pairs().collect::<Result<Vec<_>, _>>().unwrap(),
            expected_pairs
        );
    }

    /// A prefix lookup returns only the matching screen names, each with its IDs.
    #[test]
    fn lookup_by_screen_name_prefix() {
        // Borrow the path so the temporary directory outlives the database (see `insert`).
        let dir = tempfile::tempdir().unwrap();
        let db = Database::open(dir.path()).unwrap();
        db.insert(123, "foo", vec![]).unwrap();
        db.insert(123, "bar", vec![]).unwrap();
        db.insert(1000, "for", vec![]).unwrap();
        db.insert(1001, "baz", vec![]).unwrap();
        db.insert(1002, "follow", vec![]).unwrap();
        db.insert(1003, "FOR", vec![]).unwrap();

        // "FOR" is expected to be grouped with "for", i.e. the screen name index is expected
        // to treat screen names case-insensitively.
        let expected = vec![
            ("follow".to_string(), vec![1002]),
            ("foo".to_string(), vec![123]),
            ("for".to_string(), vec![1000, 1003]),
        ];

        assert_eq!(
            db.lookup_by_screen_name_prefix("fo", 128).unwrap(),
            expected
        );
    }
}