1use std::num::TryFromIntError;
2use std::{fmt, io};
3
4use radicle::crypto::Signature;
5use sqlite as sql;
6use thiserror::Error;
7
8use crate::node::{Database, NodeId};
9use crate::prelude::{Filter, Timestamp};
10use crate::service::message::{
11 Announcement, AnnouncementMessage, InventoryAnnouncement, NodeAnnouncement, RefsAnnouncement,
12};
13use crate::wire;
14use crate::wire::Decode;
15
16#[derive(Error, Debug)]
17pub enum Error {
18 #[error("internal error: {0}")]
20 Internal(#[from] sql::Error),
21 #[error("unit overflow:: {0}")]
23 UnitOverflow(#[from] TryFromIntError),
24}
25
26pub type AnnouncementId = u64;
28
29pub trait Store {
33 fn prune(&mut self, cutoff: Timestamp) -> Result<usize, Error>;
35
36 fn last(&self) -> Result<Option<Timestamp>, Error>;
38
39 fn announced(
42 &mut self,
43 nid: &NodeId,
44 ann: &Announcement,
45 ) -> Result<Option<AnnouncementId>, Error>;
46
47 fn set_relay(&mut self, id: AnnouncementId, relay: RelayStatus) -> Result<(), Error>;
49
50 fn relays(&mut self, now: Timestamp) -> Result<Vec<(AnnouncementId, Announcement)>, Error>;
52
53 fn filtered<'a>(
61 &'a self,
62 filter: &'a Filter,
63 from: Timestamp,
64 to: Timestamp,
65 ) -> Result<Box<dyn Iterator<Item = Result<Announcement, Error>> + 'a>, Error>;
66}
67
68impl Store for Database {
69 fn prune(&mut self, cutoff: Timestamp) -> Result<usize, Error> {
70 let mut stmt = self
71 .db
72 .prepare("DELETE FROM `announcements` WHERE timestamp < ?1")?;
73
74 stmt.bind((1, &cutoff))?;
75 stmt.next()?;
76
77 Ok(self.db.change_count())
78 }
79
80 fn last(&self) -> Result<Option<Timestamp>, Error> {
81 let stmt = self
82 .db
83 .prepare("SELECT MAX(timestamp) AS latest FROM `announcements`")?;
84
85 if let Some(Ok(row)) = stmt.into_iter().next() {
86 return match row.try_read::<Option<i64>, _>(0)? {
87 Some(i) => Ok(Some(Timestamp::try_from(i)?)),
88 None => Ok(None),
89 };
90 }
91 Ok(None)
92 }
93
94 fn announced(
95 &mut self,
96 nid: &NodeId,
97 ann: &Announcement,
98 ) -> Result<Option<AnnouncementId>, Error> {
99 assert_ne!(
100 ann.timestamp(),
101 Timestamp::MIN,
102 "Timestamp of {ann:?} must not be zero"
103 );
104 let mut stmt = self.db.prepare(
105 "INSERT INTO `announcements` (node, repo, type, message, signature, timestamp)
106 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
107 ON CONFLICT DO UPDATE
108 SET message = ?4, signature = ?5, timestamp = ?6
109 WHERE timestamp < ?6
110 RETURNING rowid",
111 )?;
112 stmt.bind((1, nid))?;
113
114 match &ann.message {
115 AnnouncementMessage::Node(msg) => {
116 stmt.bind((2, sql::Value::String(String::new())))?;
117 stmt.bind((3, &GossipType::Node))?;
118 stmt.bind((4, msg))?;
119 }
120 AnnouncementMessage::Refs(msg) => {
121 stmt.bind((2, &msg.rid))?;
122 stmt.bind((3, &GossipType::Refs))?;
123 stmt.bind((4, msg))?;
124 }
125 AnnouncementMessage::Inventory(msg) => {
126 stmt.bind((2, sql::Value::String(String::new())))?;
127 stmt.bind((3, &GossipType::Inventory))?;
128 stmt.bind((4, msg))?;
129 }
130 }
131 stmt.bind((5, &ann.signature))?;
132 stmt.bind((6, &ann.message.timestamp()))?;
133
134 if let Some(row) = stmt.into_iter().next() {
135 let row = row?;
136 let id = row.read::<i64, _>("rowid");
137
138 Ok(Some(id as AnnouncementId))
139 } else {
140 Ok(None)
141 }
142 }
143
144 fn set_relay(&mut self, id: AnnouncementId, relay: RelayStatus) -> Result<(), Error> {
145 let mut stmt = self.db.prepare(
146 "UPDATE announcements
147 SET relay = ?1
148 WHERE rowid = ?2",
149 )?;
150 stmt.bind((1, relay))?;
151 stmt.bind((2, id as i64))?;
152 stmt.next()?;
153
154 Ok(())
155 }
156
157 fn relays(&mut self, now: Timestamp) -> Result<Vec<(AnnouncementId, Announcement)>, Error> {
158 let mut stmt = self.db.prepare(
159 "UPDATE announcements
160 SET relay = ?1
161 WHERE relay IS ?2
162 RETURNING rowid, node, type, message, signature, timestamp",
163 )?;
164 stmt.bind((1, RelayStatus::RelayedAt(now)))?;
165 stmt.bind((2, RelayStatus::Relay))?;
166
167 let mut rows = stmt
168 .into_iter()
169 .map(|row| {
170 let row = row?;
171 parse::announcement(row)
172 })
173 .collect::<Result<Vec<_>, _>>()?;
174
175 rows.sort_by_key(|(id, _)| *id);
178
179 Ok(rows)
180 }
181
182 fn filtered<'a>(
183 &'a self,
184 filter: &'a Filter,
185 from: Timestamp,
186 to: Timestamp,
187 ) -> Result<Box<dyn Iterator<Item = Result<Announcement, Error>> + 'a>, Error> {
188 let mut stmt = self.db.prepare(
189 "SELECT rowid, node, type, message, signature, timestamp
190 FROM announcements
191 WHERE timestamp >= ?1 and timestamp < ?2
192 ORDER BY timestamp, node, type",
193 )?;
194 assert!(*from <= *to);
195
196 stmt.bind((1, &from))?;
197 stmt.bind((2, &to))?;
198
199 Ok(Box::new(
200 stmt.into_iter()
201 .map(|row| {
202 let row = row?;
203 let (_, ann) = parse::announcement(row)?;
204
205 Ok(ann)
206 })
207 .filter(|ann| match ann {
208 Ok(a) => a.matches(filter),
209 Err(_) => true,
210 }),
211 ))
212 }
213}
214
215impl TryFrom<&sql::Value> for NodeAnnouncement {
216 type Error = sql::Error;
217
218 fn try_from(value: &sql::Value) -> Result<Self, Self::Error> {
219 match value {
220 sql::Value::Binary(bytes) => {
221 let mut reader = io::Cursor::new(bytes);
222 NodeAnnouncement::decode(&mut reader).map_err(wire::Error::into)
223 }
224 _ => Err(sql::Error {
225 code: None,
226 message: Some("sql: invalid type for node announcement".to_owned()),
227 }),
228 }
229 }
230}
231
232impl sql::BindableWithIndex for &NodeAnnouncement {
233 fn bind<I: sql::ParameterIndex>(self, stmt: &mut sql::Statement<'_>, i: I) -> sql::Result<()> {
234 wire::serialize(self).bind(stmt, i)
235 }
236}
237
238impl TryFrom<&sql::Value> for RefsAnnouncement {
239 type Error = sql::Error;
240
241 fn try_from(value: &sql::Value) -> Result<Self, Self::Error> {
242 match value {
243 sql::Value::Binary(bytes) => {
244 let mut reader = io::Cursor::new(bytes);
245 RefsAnnouncement::decode(&mut reader).map_err(wire::Error::into)
246 }
247 _ => Err(sql::Error {
248 code: None,
249 message: Some("sql: invalid type for refs announcement".to_owned()),
250 }),
251 }
252 }
253}
254
255impl sql::BindableWithIndex for &RefsAnnouncement {
256 fn bind<I: sql::ParameterIndex>(self, stmt: &mut sql::Statement<'_>, i: I) -> sql::Result<()> {
257 wire::serialize(self).bind(stmt, i)
258 }
259}
260
261impl TryFrom<&sql::Value> for InventoryAnnouncement {
262 type Error = sql::Error;
263
264 fn try_from(value: &sql::Value) -> Result<Self, Self::Error> {
265 match value {
266 sql::Value::Binary(bytes) => {
267 let mut reader = io::Cursor::new(bytes);
268 InventoryAnnouncement::decode(&mut reader).map_err(wire::Error::into)
269 }
270 _ => Err(sql::Error {
271 code: None,
272 message: Some("sql: invalid type for inventory announcement".to_owned()),
273 }),
274 }
275 }
276}
277
278impl sql::BindableWithIndex for &InventoryAnnouncement {
279 fn bind<I: sql::ParameterIndex>(self, stmt: &mut sql::Statement<'_>, i: I) -> sql::Result<()> {
280 wire::serialize(self).bind(stmt, i)
281 }
282}
283
284impl From<wire::Error> for sql::Error {
285 fn from(other: wire::Error) -> Self {
286 sql::Error {
287 code: None,
288 message: Some(other.to_string()),
289 }
290 }
291}
292
293#[derive(Debug, Clone, Copy, PartialEq, Eq)]
295pub enum RelayStatus {
296 Relay,
297 DontRelay,
298 RelayedAt(Timestamp),
299}
300
301impl sql::BindableWithIndex for RelayStatus {
302 fn bind<I: sql::ParameterIndex>(self, stmt: &mut sql::Statement<'_>, i: I) -> sql::Result<()> {
303 match self {
304 Self::Relay => sql::Value::Null.bind(stmt, i),
305 Self::DontRelay => sql::Value::Integer(-1).bind(stmt, i),
306 Self::RelayedAt(t) => t.bind(stmt, i),
307 }
308 }
309}
310
/// Kind of gossip announcement, as written to the `type` column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GossipType {
    /// A refs announcement, displayed as `refs`.
    Refs,
    /// A node announcement, displayed as `node`.
    Node,
    /// An inventory announcement, displayed as `inventory`.
    Inventory,
}

impl fmt::Display for GossipType {
    /// Write the lowercase label of the gossip type.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Refs => "refs",
            Self::Node => "node",
            Self::Inventory => "inventory",
        };
        f.write_str(label)
    }
}
328
329impl sql::BindableWithIndex for &GossipType {
330 fn bind<I: sql::ParameterIndex>(self, stmt: &mut sql::Statement<'_>, i: I) -> sql::Result<()> {
331 self.to_string().as_str().bind(stmt, i)
332 }
333}
334
335impl TryFrom<&sql::Value> for GossipType {
336 type Error = sql::Error;
337
338 fn try_from(value: &sql::Value) -> Result<Self, Self::Error> {
339 match value {
340 sql::Value::String(s) => match s.as_str() {
341 "refs" => Ok(Self::Refs),
342 "node" => Ok(Self::Node),
343 "inventory" => Ok(Self::Inventory),
344 other => Err(sql::Error {
345 code: None,
346 message: Some(format!("unknown gossip type '{other}'")),
347 }),
348 },
349 _ => Err(sql::Error {
350 code: None,
351 message: Some("sql: invalid type for gossip type".to_owned()),
352 }),
353 }
354 }
355}
356
357mod parse {
358 use super::*;
359
360 pub fn announcement(row: sql::Row) -> Result<(AnnouncementId, Announcement), Error> {
361 let id = row.read::<i64, _>("rowid") as AnnouncementId;
362 let node = row.read::<NodeId, _>("node");
363 let gt = row.read::<GossipType, _>("type");
364 let message = match gt {
365 GossipType::Refs => {
366 let ann = row.try_read::<RefsAnnouncement, _>("message")?;
367 AnnouncementMessage::Refs(ann)
368 }
369 GossipType::Inventory => {
370 let ann = row.try_read::<InventoryAnnouncement, _>("message")?;
371 AnnouncementMessage::Inventory(ann)
372 }
373 GossipType::Node => {
374 let ann = row.try_read::<NodeAnnouncement, _>("message")?;
375 AnnouncementMessage::Node(ann)
376 }
377 };
378 let signature = row.read::<Signature, _>("signature");
379 let timestamp = row.read::<Timestamp, _>("timestamp");
380
381 debug_assert_eq!(timestamp, message.timestamp());
382
383 Ok((
384 id,
385 Announcement {
386 node,
387 message,
388 signature,
389 },
390 ))
391 }
392}
393
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod test {
    use super::*;
    use crate::prelude::{BoundedVec, RepoId};
    use crate::test::arbitrary;
    use localtime::LocalTime;
    use radicle::assert_matches;
    use radicle_crypto::test::signer::MockSigner;

    /// Announcements are stored once per type, and are handed out for
    /// relaying exactly once after being marked for relay.
    #[test]
    fn test_announced() {
        let mut db = Database::memory().unwrap();
        let nid = arbitrary::gen::<NodeId>(1);
        let rid = arbitrary::gen::<RepoId>(1);
        let timestamp = LocalTime::now().into();
        let signer = MockSigner::default();

        let refs_ann = AnnouncementMessage::Refs(RefsAnnouncement {
            rid,
            refs: BoundedVec::new(),
            timestamp,
        })
        .signed(&signer);
        let inv_ann = AnnouncementMessage::Inventory(InventoryAnnouncement {
            inventory: BoundedVec::new(),
            timestamp,
        })
        .signed(&signer);

        // Announcing the same message a second time stores nothing new.
        let refs_id = db.announced(&nid, &refs_ann).unwrap().unwrap();
        assert!(db.announced(&nid, &refs_ann).unwrap().is_none());

        let inv_id = db.announced(&nid, &inv_ann).unwrap().unwrap();
        assert!(db.announced(&nid, &inv_ann).unwrap().is_none());

        // Nothing is handed out for relaying at this point.
        assert!(db.relays(LocalTime::now().into()).unwrap().is_empty());

        // Mark both announcements for relaying.
        db.set_relay(refs_id, RelayStatus::Relay).unwrap();
        db.set_relay(inv_id, RelayStatus::Relay).unwrap();

        // They are now returned, ordered by id...
        assert_matches!(
            db.relays(LocalTime::now().into()).unwrap().as_slice(),
            &[(id1_, _), (id2_, _)]
            if id1_ == refs_id && id2_ == inv_id
        );
        // ...but only the first time around.
        assert_matches!(db.relays(LocalTime::now().into()).unwrap().as_slice(), &[]);
    }
}