apt_swarm/db/
unix.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
use super::proto::{Query, Response, SyncQuery};
use super::DatabaseClient;
use crate::db;
use crate::errors::*;
use crate::signed::Signed;
use crate::sync;
use async_trait::async_trait;
use bstr::BString;
use sequoia_openpgp::Fingerprint;
use std::path::Path;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufStream};
use tokio::net::UnixStream;

pub struct DatabaseUnixClient {
    socket: BufStream<UnixStream>,
}

impl DatabaseUnixClient {
    pub async fn connect(path: &Path) -> Result<Self> {
        let socket = UnixStream::connect(path)
            .await
            .with_context(|| anyhow!("Failed to connect to socket at {path:?}"))?;
        debug!("Connected to unix domain socket at {path:?}");
        let socket = BufStream::new(socket);
        Ok(Self { socket })
    }

    pub async fn send_query(&mut self, q: &Query) -> Result<()> {
        let mut json = serde_json::to_string(q).context("Failed to serialize message as json")?;
        json.push('\n');
        self.socket
            .write_all(json.as_bytes())
            .await
            .context("Failed to send to database server")?;
        self.socket.flush().await?;
        Ok(())
    }

    pub async fn recv_response(&mut self) -> Result<Response> {
        let mut buf = Vec::new();
        self.socket.read_until(b'\n', &mut buf).await?;

        if buf.is_empty() {
            bail!("Database has disconnected without sending a response");
        } else if buf == b"\n" {
            Ok(Response::Ok)
        } else {
            let response = serde_json::from_slice::<Response>(&buf)?;
            match response {
                Response::Error(error) => {
                    bail!("Error from server: {}", error.err);
                }
                _ => Ok(response),
            }
        }
    }
}

#[async_trait]
impl DatabaseClient for DatabaseUnixClient {
    async fn add_release(&mut self, fp: &Fingerprint, signed: &Signed) -> Result<String> {
        self.send_query(&Query::AddRelease(fp.to_string(), signed.clone()))
            .await?;
        let inserted = self.recv_response().await?;
        if let Response::Inserted(hash) = inserted {
            info!("Added release to database: {hash:?}");
            Ok(hash)
        } else {
            bail!("Unexpected response type from database: {inserted:?}");
        }
    }

    async fn index_from_scan(&mut self, query: &sync::Query) -> Result<(String, usize)> {
        self.send_query(&Query::IndexFromScan(SyncQuery {
            fp: query.fp.to_string(),
            hash_algo: query.hash_algo.clone(),
            prefix: query.prefix.clone(),
        }))
        .await?;
        let index = self.recv_response().await?;
        if let Response::Index(index) = index {
            Ok(index)
        } else {
            bail!("Unexpected response type from database: {index:?}");
        }
    }

    async fn spill(&self, _prefix: &[u8]) -> Result<Vec<(db::Key, db::Value)>> {
        todo!("DatabaseUnixClient::spill")
    }

    async fn get_value(&self, _key: &[u8]) -> Result<db::Value> {
        todo!("DatabaseUnixClient::get_value")
    }

    async fn count(&mut self, prefix: &[u8]) -> Result<u64> {
        self.send_query(&Query::Count(BString::new(prefix.to_vec())))
            .await?;
        let count = self.recv_response().await?;
        if let Response::Num(count) = count {
            Ok(count)
        } else {
            bail!("Unexpected response type from database: {count:?}");
        }
    }
}