apt_swarm/db/unix.rs

1use super::proto::{Query, Response, SyncQuery};
2use super::DatabaseClient;
3use crate::db;
4use crate::errors::*;
5use crate::signed::Signed;
6use crate::sync;
7use async_trait::async_trait;
8use bstr::BString;
9use sequoia_openpgp::Fingerprint;
10use std::path::Path;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufStream};
12use tokio::net::UnixStream;
13
14pub struct DatabaseUnixClient {
15    socket: BufStream<UnixStream>,
16}
17
18impl DatabaseUnixClient {
19    pub async fn connect(path: &Path) -> Result<Self> {
20        let socket = UnixStream::connect(path)
21            .await
22            .with_context(|| anyhow!("Failed to connect to socket at {path:?}"))?;
23        debug!("Connected to unix domain socket at {path:?}");
24        let socket = BufStream::new(socket);
25        Ok(Self { socket })
26    }
27
28    pub async fn send_query(&mut self, q: &Query) -> Result<()> {
29        let mut json = serde_json::to_string(q).context("Failed to serialize message as json")?;
30        json.push('\n');
31        self.socket
32            .write_all(json.as_bytes())
33            .await
34            .context("Failed to send to database server")?;
35        self.socket.flush().await?;
36        Ok(())
37    }
38
39    pub async fn recv_response(&mut self) -> Result<Response> {
40        let mut buf = Vec::new();
41        self.socket.read_until(b'\n', &mut buf).await?;
42
43        if buf.is_empty() {
44            bail!("Database has disconnected without sending a response");
45        } else if buf == b"\n" {
46            Ok(Response::Ok)
47        } else {
48            let response = serde_json::from_slice::<Response>(&buf)?;
49            match response {
50                Response::Error(error) => {
51                    bail!("Error from server: {}", error.err);
52                }
53                _ => Ok(response),
54            }
55        }
56    }
57}
58
59#[async_trait]
60impl DatabaseClient for DatabaseUnixClient {
61    async fn add_release(&mut self, fp: &Fingerprint, signed: &Signed) -> Result<String> {
62        self.send_query(&Query::AddRelease(fp.to_string(), signed.clone()))
63            .await?;
64        let inserted = self.recv_response().await?;
65        if let Response::Inserted(hash) = inserted {
66            info!("Added release to database: {hash:?}");
67            Ok(hash)
68        } else {
69            bail!("Unexpected response type from database: {inserted:?}");
70        }
71    }
72
73    async fn index_from_scan(&mut self, query: &sync::TreeQuery) -> Result<(String, usize)> {
74        self.send_query(&Query::IndexFromScan(SyncQuery {
75            fp: query.fp.to_string(),
76            hash_algo: query.hash_algo.clone(),
77            prefix: query.prefix.clone(),
78        }))
79        .await?;
80        let index = self.recv_response().await?;
81        if let Response::Index(index) = index {
82            Ok(index)
83        } else {
84            bail!("Unexpected response type from database: {index:?}");
85        }
86    }
87
88    async fn spill(&self, _prefix: &[u8]) -> Result<Vec<(db::Key, db::Value)>> {
89        todo!("DatabaseUnixClient::spill")
90    }
91
92    async fn get_value(&self, _key: &[u8]) -> Result<db::Value> {
93        todo!("DatabaseUnixClient::get_value")
94    }
95
96    async fn count(&mut self, prefix: &[u8]) -> Result<u64> {
97        self.send_query(&Query::Count(BString::new(prefix.to_vec())))
98            .await?;
99        let count = self.recv_response().await?;
100        if let Response::Num(count) = count {
101            Ok(count)
102        } else {
103            bail!("Unexpected response type from database: {count:?}");
104        }
105    }
106}