lightning_storage_server/client/driver.rs

1use crate::client::auth::{Auth, PrivAuth};
2use crate::client::LightningStorageClient;
3use crate::proto::{self, GetRequest, InfoRequest, PingRequest, PutRequest};
4use crate::util::{compute_shared_hmac, prepare_value_for_put, process_value_from_get};
5use crate::Value;
6use log::{debug, error};
7use secp256k1::rand::rngs::OsRng;
8use secp256k1::rand::RngCore;
9use secp256k1::PublicKey;
10use thiserror::Error;
11use tonic::{transport, Request};
12
13#[derive(Debug, Error)]
14pub enum ClientError {
15    #[error("transport error")]
16    Connect(#[from] transport::Error),
17    #[error("API error")]
18    Tonic(#[from] tonic::Status),
19    #[error("invalid response from server")]
20    InvalidResponse,
21    /// client HMAC integrity error, with string
22    #[error("invalid HMAC for key {0} version {1}")]
23    InvalidHmac(String, i64),
24    /// server HMAC integrity error, with string
25    #[error("invalid server HMAC")]
26    InvalidServerHmac(),
27    #[error("Put had conflicts")]
28    PutConflict(Vec<(String, Value)>),
29}
30
31pub struct Client {
32    client: LightningStorageClient<transport::Channel>,
33    auth: Auth,
34}
35
36impl Client {
37    /// Get the server info
38    pub async fn get_info(uri: &str) -> Result<(PublicKey, String), ClientError> {
39        debug!("info");
40        let mut client = connect(uri).await?;
41        let info_request = Request::new(InfoRequest {});
42
43        let response = client.info(info_request).await?.into_inner();
44        debug!("info result {:?}", response);
45        let pubkey =
46            PublicKey::from_slice(&response.server_id).map_err(|_| ClientError::InvalidResponse)?;
47        let version = response.version;
48        Ok((pubkey, version))
49    }
50
51    pub async fn new(uri: &str, auth: Auth) -> Result<Self, ClientError> {
52        let client = connect(uri).await?;
53        Ok(Self { client, auth })
54    }
55
56    pub async fn ping(uri: &str, message: &str) -> Result<String, ClientError> {
57        debug!("ping");
58        let mut client = connect(uri).await?;
59        let ping_request = Request::new(PingRequest { message: message.into() });
60
61        let response = client.ping(ping_request).await?.into_inner();
62        debug!("ping result {:?}", response);
63        Ok(response.message)
64    }
65
66    pub async fn get(
67        &mut self,
68        key_prefix: String,
69        nonce: &[u8],
70    ) -> Result<(Vec<(String, Value)>, Vec<u8>), ClientError> {
71        let get_request = Request::new(GetRequest {
72            auth: self.make_auth_proto(),
73            key_prefix,
74            nonce: nonce.to_vec(),
75        });
76
77        let response = self.client.get(get_request).await?.into_inner();
78        let kvs = kvs_from_proto(response.kvs);
79
80        Ok((kvs, response.hmac))
81    }
82
83    pub async fn put(
84        &mut self,
85        kvs: Vec<(String, Value)>,
86        client_hmac: &[u8],
87    ) -> Result<Vec<u8>, ClientError> {
88        let kvs_proto = kvs
89            .into_iter()
90            .map(|(k, v)| proto::KeyValue {
91                key: k.clone(),
92                value: v.value.clone(),
93                version: v.version,
94            })
95            .collect();
96
97        let put_request = Request::new(PutRequest {
98            auth: self.make_auth_proto(),
99            kvs: kvs_proto,
100            hmac: client_hmac.to_vec(),
101        });
102
103        let response = self.client.put(put_request).await?.into_inner();
104        debug!("put result {:?}", response);
105
106        if response.success {
107            Ok(response.hmac)
108        } else {
109            let conflicts = kvs_from_proto(response.conflicts);
110            Err(ClientError::PutConflict(conflicts))
111        }
112    }
113
114    fn make_auth_proto(&self) -> Option<proto::Auth> {
115        Some(proto::Auth {
116            client_id: self.auth.client_id.serialize().to_vec(),
117            token: self.auth.auth_token(),
118        })
119    }
120}
121
122pub struct PrivClient {
123    client: Client,
124    auth: PrivAuth,
125}
126
127impl PrivClient {
128    /// Get the server info
129    pub async fn get_info(uri: &str) -> Result<(PublicKey, String), ClientError> {
130        Client::get_info(uri).await
131    }
132
133    pub async fn new(uri: &str, auth: PrivAuth) -> Result<Self, ClientError> {
134        let client = Client::new(uri, auth.auth()).await?;
135        Ok(Self { client, auth })
136    }
137
138    pub async fn ping(uri: &str, message: &str) -> Result<String, ClientError> {
139        Client::ping(uri, message).await
140    }
141
142    pub async fn get(
143        &mut self,
144        hmac_secret: &[u8],
145        key_prefix: String,
146    ) -> Result<Vec<(String, Value)>, ClientError> {
147        let mut nonce = Vec::with_capacity(32);
148        nonce.resize(32, 0);
149        let mut rng = OsRng;
150        rng.fill_bytes(&mut nonce);
151
152        debug!("get request '{}'", key_prefix);
153
154        let (mut kvs, received_hmac) = self.client.get(key_prefix, &nonce).await?;
155        let hmac = compute_shared_hmac(&self.auth.shared_secret, &nonce, &kvs);
156        if received_hmac != hmac {
157            error!("get hmac mismatch");
158            return Err(ClientError::InvalidServerHmac());
159        }
160
161        remove_and_check_hmacs(&hmac_secret, &mut kvs)?;
162        debug!("get result {:?}", kvs);
163        Ok(kvs)
164    }
165
166    /// values do not include HMAC
167    pub async fn put(
168        &mut self,
169        hmac_secret: &[u8],
170        mut kvs: Vec<(String, Value)>,
171    ) -> Result<(), ClientError> {
172        debug!("put request {:?}", kvs);
173        kvs.sort_by_key(|(k, _)| k.clone());
174        for (key, value) in kvs.iter_mut() {
175            prepare_value_for_put(hmac_secret, key, value);
176        }
177
178        let client_hmac = compute_shared_hmac(&self.auth.shared_secret, &[0x01], &kvs);
179
180        let server_hmac = compute_shared_hmac(&self.auth.shared_secret, &[0x02], &kvs);
181
182        match self.client.put(kvs, &client_hmac).await {
183            Ok(received_server_hmac) =>
184                if received_server_hmac == server_hmac {
185                    return Ok(());
186                } else {
187                    error!("put hmac mismatch");
188                    return Err(ClientError::InvalidServerHmac());
189                },
190            Err(ClientError::PutConflict(mut conflicts)) => {
191                remove_and_check_hmacs(&hmac_secret, &mut conflicts)?;
192                error!("put conflicts {:?}", conflicts);
193                Err(ClientError::PutConflict(conflicts))
194            }
195            Err(e) => Err(e),
196        }
197    }
198}
199
200async fn connect(uri: &str) -> Result<LightningStorageClient<transport::Channel>, ClientError> {
201    debug!("connect to {}", uri.to_string());
202    let uri_clone = String::from(uri);
203    Ok(LightningStorageClient::connect(uri_clone).await?)
204}
205
206fn kvs_from_proto(conflicts_proto: Vec<proto::KeyValue>) -> Vec<(String, Value)> {
207    conflicts_proto
208        .into_iter()
209        .map(|kv| (kv.key, Value { version: kv.version, value: kv.value }))
210        .collect()
211}
212
213fn remove_and_check_hmacs(
214    hmac_secret: &[u8],
215    kvs: &mut Vec<(String, Value)>,
216) -> Result<(), ClientError> {
217    for (key, value) in kvs.iter_mut() {
218        process_value_from_get(hmac_secret, key, value)
219            .map_err(|()| ClientError::InvalidHmac(key.clone(), value.version))?;
220    }
221    Ok(())
222}