apt_swarm/db/
channel.rs

1use super::{Database, DatabaseClient};
2use crate::db;
3use crate::errors::*;
4use crate::signed::Signed;
5use crate::sync;
6use async_trait::async_trait;
7use sequoia_openpgp::Fingerprint;
8use tokio::sync::mpsc;
9
10pub enum Query {
11    AddRelease(Fingerprint, Signed, mpsc::Sender<String>),
12    IndexFromScan(sync::TreeQuery, mpsc::Sender<(String, usize)>),
13    Spill(Vec<u8>, mpsc::Sender<Vec<(db::Key, db::Value)>>),
14    GetValue(Vec<u8>, mpsc::Sender<db::Value>),
15    // Delete(Vec<u8>, mpsc::Sender<()>),
16    Count(Vec<u8>, mpsc::Sender<u64>),
17}
18
19#[derive(Debug)]
20pub struct DatabaseServer {
21    db: Database,
22    rx: mpsc::Receiver<Query>,
23}
24
25impl DatabaseServer {
26    pub fn new(db: Database) -> (DatabaseServer, DatabaseServerClient) {
27        let (tx, rx) = mpsc::channel(32);
28
29        let server = DatabaseServer { db, rx };
30        let client = DatabaseServerClient { tx };
31
32        (server, client)
33    }
34
35    pub async fn run(&mut self) -> Result<()> {
36        while let Some(msg) = self.rx.recv().await {
37            match msg {
38                Query::AddRelease(fp, signed, tx) => {
39                    let hash = self.db.add_release(&fp, &signed).await?;
40                    tx.send(hash).await.ok();
41                }
42                Query::IndexFromScan(query, tx) => {
43                    let ret = self.db.index_from_scan(&query).await?;
44                    tx.send(ret).await.ok();
45                }
46                Query::Spill(prefix, tx) => {
47                    let ret = self.db.spill(&prefix).await?;
48                    tx.send(ret).await.ok();
49                }
50                Query::GetValue(key, tx) => {
51                    let ret = self.db.get_value(&key).await?;
52                    tx.send(ret).await.ok();
53                }
54                Query::Count(key, tx) => {
55                    let ret = self.db.count(&key).await?;
56                    tx.send(ret).await.ok();
57                }
58            }
59        }
60        Ok(())
61    }
62}
63
64#[derive(Debug, Clone)]
65pub struct DatabaseServerClient {
66    tx: mpsc::Sender<Query>,
67}
68
69impl DatabaseServerClient {
70    async fn request<T>(&self, query: Query, mut rx: mpsc::Receiver<T>) -> Result<T> {
71        self.tx
72            .send(query)
73            .await
74            .map_err(|_| anyhow!("Database server disconnected"))?;
75        let ret = rx.recv().await.context("Database server disconnected")?;
76        Ok(ret)
77    }
78}
79
80#[async_trait]
81impl DatabaseClient for DatabaseServerClient {
82    async fn add_release(&mut self, fp: &Fingerprint, signed: &Signed) -> Result<String> {
83        let (tx, rx) = mpsc::channel(1);
84        let query = Query::AddRelease(fp.clone(), signed.clone(), tx);
85        self.request(query, rx).await
86    }
87
88    async fn index_from_scan(&mut self, query: &sync::TreeQuery) -> Result<(String, usize)> {
89        let (tx, rx) = mpsc::channel(1);
90        let query = Query::IndexFromScan(query.clone(), tx);
91        self.request(query, rx).await
92    }
93
94    async fn spill(&self, prefix: &[u8]) -> Result<Vec<(db::Key, db::Value)>> {
95        let (tx, rx) = mpsc::channel(1);
96        let query = Query::Spill(prefix.to_vec(), tx);
97        self.request(query, rx).await
98    }
99
100    async fn get_value(&self, key: &[u8]) -> Result<db::Value> {
101        let (tx, rx) = mpsc::channel(1);
102        let query = Query::GetValue(key.to_vec(), tx);
103        self.request(query, rx).await
104    }
105
106    async fn count(&mut self, key: &[u8]) -> Result<u64> {
107        let (tx, rx) = mpsc::channel(1);
108        let query = Query::Count(key.to_vec(), tx);
109        self.request(query, rx).await
110    }
111}