1use std::num::TryFromIntError;
2use std::{fmt, io};
3
4use radicle::crypto::Signature;
5use sqlite as sql;
6use thiserror::Error;
7
8use crate::service::filter::Filter;
9use crate::service::message::{
10 Announcement, AnnouncementMessage, InventoryAnnouncement, NodeAnnouncement, RefsAnnouncement,
11};
12use crate::wire;
13use crate::wire::{Decode as _, Encode as _};
14use radicle::node::Database;
15use radicle::node::NodeId;
16use radicle::prelude::Timestamp;
17
18#[derive(Error, Debug)]
19pub enum Error {
20 #[error("internal error: {0}")]
22 Internal(#[from] sql::Error),
23 #[error("unit overflow:: {0}")]
25 UnitOverflow(#[from] TryFromIntError),
26}
27
/// Identifier of a stored announcement: the SQLite `rowid` of its row,
/// reinterpreted as an unsigned integer.
pub type AnnouncementId = u64;
30
31pub trait Store {
35 fn prune(&mut self, cutoff: Timestamp) -> Result<usize, Error>;
37
38 fn last(&self) -> Result<Option<Timestamp>, Error>;
40
41 fn announced(
44 &mut self,
45 nid: &NodeId,
46 ann: &Announcement,
47 ) -> Result<Option<AnnouncementId>, Error>;
48
49 fn set_relay(&mut self, id: AnnouncementId, relay: RelayStatus) -> Result<(), Error>;
51
52 fn relays(&mut self, now: Timestamp) -> Result<Vec<(AnnouncementId, Announcement)>, Error>;
54
55 fn filtered<'a>(
63 &'a self,
64 filter: &'a Filter,
65 from: Timestamp,
66 to: Timestamp,
67 ) -> Result<Box<dyn Iterator<Item = Result<Announcement, Error>> + 'a>, Error>;
68}
69
70impl Store for Database {
71 fn prune(&mut self, cutoff: Timestamp) -> Result<usize, Error> {
72 let mut stmt = self
73 .db
74 .prepare("DELETE FROM `announcements` WHERE timestamp < ?1")?;
75
76 stmt.bind((1, &cutoff))?;
77 stmt.next()?;
78
79 Ok(self.db.change_count())
80 }
81
82 fn last(&self) -> Result<Option<Timestamp>, Error> {
83 let stmt = self
84 .db
85 .prepare("SELECT MAX(timestamp) AS latest FROM `announcements`")?;
86
87 if let Some(Ok(row)) = stmt.into_iter().next() {
88 return match row.try_read::<Option<i64>, _>(0)? {
89 Some(i) => Ok(Some(Timestamp::try_from(i)?)),
90 None => Ok(None),
91 };
92 }
93 Ok(None)
94 }
95
96 fn announced(
97 &mut self,
98 nid: &NodeId,
99 ann: &Announcement,
100 ) -> Result<Option<AnnouncementId>, Error> {
101 assert_ne!(
102 ann.timestamp(),
103 Timestamp::MIN,
104 "Timestamp of {ann:?} must not be zero"
105 );
106 let mut stmt = self.db.prepare(
107 "INSERT INTO `announcements` (node, repo, type, message, signature, timestamp)
108 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
109 ON CONFLICT DO UPDATE
110 SET message = ?4, signature = ?5, timestamp = ?6
111 WHERE timestamp < ?6
112 RETURNING rowid",
113 )?;
114 stmt.bind((1, nid))?;
115
116 match &ann.message {
117 AnnouncementMessage::Node(msg) => {
118 stmt.bind((2, sql::Value::String(String::new())))?;
119 stmt.bind((3, &GossipType::Node))?;
120 stmt.bind((4, &msg.encode_to_vec()[..]))?;
121 }
122 AnnouncementMessage::Refs(msg) => {
123 stmt.bind((2, &msg.rid))?;
124 stmt.bind((3, &GossipType::Refs))?;
125 stmt.bind((4, &msg.encode_to_vec()[..]))?;
126 }
127 AnnouncementMessage::Inventory(msg) => {
128 stmt.bind((2, sql::Value::String(String::new())))?;
129 stmt.bind((3, &GossipType::Inventory))?;
130 stmt.bind((4, &msg.encode_to_vec()[..]))?;
131 }
132 }
133 stmt.bind((5, &ann.signature))?;
134 stmt.bind((6, &ann.message.timestamp()))?;
135
136 if let Some(row) = stmt.into_iter().next() {
137 let row = row?;
138 let id = row.try_read::<i64, _>("rowid")?;
139
140 Ok(Some(id as AnnouncementId))
141 } else {
142 Ok(None)
143 }
144 }
145
146 fn set_relay(&mut self, id: AnnouncementId, relay: RelayStatus) -> Result<(), Error> {
147 let mut stmt = self.db.prepare(
148 "UPDATE announcements
149 SET relay = ?1
150 WHERE rowid = ?2",
151 )?;
152 stmt.bind((1, relay))?;
153 stmt.bind((2, id as i64))?;
154 stmt.next()?;
155
156 Ok(())
157 }
158
159 fn relays(&mut self, now: Timestamp) -> Result<Vec<(AnnouncementId, Announcement)>, Error> {
160 let mut stmt = self.db.prepare(
161 "UPDATE announcements
162 SET relay = ?1
163 WHERE relay IS ?2
164 RETURNING rowid, node, type, message, signature, timestamp",
165 )?;
166 stmt.bind((1, RelayStatus::RelayedAt(now)))?;
167 stmt.bind((2, RelayStatus::Relay))?;
168
169 let mut rows = stmt
170 .into_iter()
171 .map(|row| {
172 let row = row?;
173 parse::announcement(row)
174 })
175 .collect::<Result<Vec<_>, _>>()?;
176
177 rows.sort_by_key(|(id, _)| *id);
180
181 Ok(rows)
182 }
183
184 fn filtered<'a>(
185 &'a self,
186 filter: &'a Filter,
187 from: Timestamp,
188 to: Timestamp,
189 ) -> Result<Box<dyn Iterator<Item = Result<Announcement, Error>> + 'a>, Error> {
190 let mut stmt = self.db.prepare(
191 "SELECT rowid, node, type, message, signature, timestamp
192 FROM announcements
193 WHERE timestamp >= ?1 and timestamp < ?2
194 ORDER BY timestamp, node, type",
195 )?;
196 assert!(*from <= *to);
197
198 stmt.bind((1, &from))?;
199 stmt.bind((2, &to))?;
200
201 Ok(Box::new(
202 stmt.into_iter()
203 .map(|row| {
204 let row = row?;
205 let (_, ann) = parse::announcement(row)?;
206
207 Ok(ann)
208 })
209 .filter(|ann| match ann {
210 Ok(a) => a.matches(filter),
211 Err(_) => true,
212 }),
213 ))
214 }
215}
216
217impl TryFrom<&sql::Value> for NodeAnnouncement {
218 type Error = sql::Error;
219
220 fn try_from(value: &sql::Value) -> Result<Self, Self::Error> {
221 match value {
222 sql::Value::Binary(bytes) => {
223 let mut reader = io::Cursor::new(bytes);
224 NodeAnnouncement::decode(&mut reader).map_err(wire::Error::into)
225 }
226 _ => Err(sql::Error {
227 code: None,
228 message: Some("sql: invalid type for node announcement".to_owned()),
229 }),
230 }
231 }
232}
233
234impl TryFrom<&sql::Value> for RefsAnnouncement {
235 type Error = sql::Error;
236
237 fn try_from(value: &sql::Value) -> Result<Self, Self::Error> {
238 match value {
239 sql::Value::Binary(bytes) => {
240 let mut reader = io::Cursor::new(bytes);
241 RefsAnnouncement::decode(&mut reader).map_err(wire::Error::into)
242 }
243 _ => Err(sql::Error {
244 code: None,
245 message: Some("sql: invalid type for refs announcement".to_owned()),
246 }),
247 }
248 }
249}
250
251impl TryFrom<&sql::Value> for InventoryAnnouncement {
252 type Error = sql::Error;
253
254 fn try_from(value: &sql::Value) -> Result<Self, Self::Error> {
255 match value {
256 sql::Value::Binary(bytes) => {
257 let mut reader = io::Cursor::new(bytes);
258 InventoryAnnouncement::decode(&mut reader).map_err(wire::Error::into)
259 }
260 _ => Err(sql::Error {
261 code: None,
262 message: Some("sql: invalid type for inventory announcement".to_owned()),
263 }),
264 }
265 }
266}
267
268impl From<wire::Error> for sql::Error {
269 fn from(other: wire::Error) -> Self {
270 sql::Error {
271 code: None,
272 message: Some(other.to_string()),
273 }
274 }
275}
276
277#[derive(Debug, Clone, Copy, PartialEq, Eq)]
279pub enum RelayStatus {
280 Relay,
281 DontRelay,
282 RelayedAt(Timestamp),
283}
284
285impl sql::BindableWithIndex for RelayStatus {
286 fn bind<I: sql::ParameterIndex>(self, stmt: &mut sql::Statement<'_>, i: I) -> sql::Result<()> {
287 match self {
288 Self::Relay => sql::Value::Null.bind(stmt, i),
289 Self::DontRelay => sql::Value::Integer(-1).bind(stmt, i),
290 Self::RelayedAt(t) => t.bind(stmt, i),
291 }
292 }
293}
294
/// Kind of gossip message held in a row. It is stored as text in the `type`
/// column and selects how the `message` blob is decoded.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum GossipType {
    /// The message is a refs announcement.
    Refs,
    /// The message is a node announcement.
    Node,
    /// The message is an inventory announcement.
    Inventory,
}
302
303impl fmt::Display for GossipType {
304 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
305 match self {
306 Self::Refs => write!(f, "refs"),
307 Self::Node => write!(f, "node"),
308 Self::Inventory => write!(f, "inventory"),
309 }
310 }
311}
312
313impl sql::BindableWithIndex for &GossipType {
314 fn bind<I: sql::ParameterIndex>(self, stmt: &mut sql::Statement<'_>, i: I) -> sql::Result<()> {
315 self.to_string().as_str().bind(stmt, i)
316 }
317}
318
319impl TryFrom<&sql::Value> for GossipType {
320 type Error = sql::Error;
321
322 fn try_from(value: &sql::Value) -> Result<Self, Self::Error> {
323 match value {
324 sql::Value::String(s) => match s.as_str() {
325 "refs" => Ok(Self::Refs),
326 "node" => Ok(Self::Node),
327 "inventory" => Ok(Self::Inventory),
328 other => Err(sql::Error {
329 code: None,
330 message: Some(format!("unknown gossip type '{other}'")),
331 }),
332 },
333 _ => Err(sql::Error {
334 code: None,
335 message: Some("sql: invalid type for gossip type".to_owned()),
336 }),
337 }
338 }
339}
340
341mod parse {
342 use super::*;
343
344 pub fn announcement(row: sql::Row) -> Result<(AnnouncementId, Announcement), Error> {
345 let id = row.try_read::<i64, _>("rowid")? as AnnouncementId;
346 let node = row.try_read::<NodeId, _>("node")?;
347 let gt = row.try_read::<GossipType, _>("type")?;
348 let message = match gt {
349 GossipType::Refs => {
350 let ann = row.try_read::<RefsAnnouncement, _>("message")?;
351 AnnouncementMessage::Refs(ann)
352 }
353 GossipType::Inventory => {
354 let ann = row.try_read::<InventoryAnnouncement, _>("message")?;
355 AnnouncementMessage::Inventory(ann)
356 }
357 GossipType::Node => {
358 let ann = row.try_read::<NodeAnnouncement, _>("message")?;
359 AnnouncementMessage::Node(ann)
360 }
361 };
362 let signature = row.try_read::<Signature, _>("signature")?;
363 let timestamp = row.try_read::<Timestamp, _>("timestamp")?;
364
365 debug_assert_eq!(timestamp, message.timestamp());
366
367 Ok((
368 id,
369 Announcement {
370 node,
371 message,
372 signature,
373 },
374 ))
375 }
376}
377
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod test {
    use super::*;
    use crate::bounded::BoundedVec;
    use localtime::LocalTime;
    use radicle::assert_matches;
    use radicle::identity::RepoId;
    use radicle::node::device::Device;
    use radicle::test::arbitrary;

    #[test]
    fn test_announced() {
        let mut db = Database::memory().unwrap();
        let nid = arbitrary::gen::<NodeId>(1);
        let rid = arbitrary::gen::<RepoId>(1);
        let timestamp = LocalTime::now().into();
        let signer = Device::mock();
        let now = || -> Timestamp { LocalTime::now().into() };

        let refs_ann = AnnouncementMessage::Refs(RefsAnnouncement {
            rid,
            refs: BoundedVec::new(),
            timestamp,
        })
        .signed(&signer);
        let inv_ann = AnnouncementMessage::Inventory(InventoryAnnouncement {
            inventory: BoundedVec::new(),
            timestamp,
        })
        .signed(&signer);

        // Storing the same announcement again with an equal timestamp is a no-op.
        let refs_id = db.announced(&nid, &refs_ann).unwrap().unwrap();
        assert!(db.announced(&nid, &refs_ann).unwrap().is_none());

        let inv_id = db.announced(&nid, &inv_ann).unwrap().unwrap();
        assert!(db.announced(&nid, &inv_ann).unwrap().is_none());

        // Nothing has been marked for relaying yet.
        assert!(db.relays(now()).unwrap().is_empty());

        db.set_relay(refs_id, RelayStatus::Relay).unwrap();
        db.set_relay(inv_id, RelayStatus::Relay).unwrap();

        assert_matches!(
            db.relays(now()).unwrap().as_slice(),
            &[(first, _), (second, _)]
            if first == refs_id && second == inv_id
        );
        // The previous call marked both as relayed, so none are pending now.
        assert_matches!(db.relays(now()).unwrap().as_slice(), &[]);
    }
}