lssd/
lib.rs

1pub mod database;
2pub mod driver;
3pub use database::{Database, Error};
4pub mod util;
5
6use itertools::Itertools;
7use lightning_storage_server::client::PrivAuth;
8use lightning_storage_server::proto::lightning_storage_server::{
9    LightningStorage, LightningStorageServer,
10};
11use lightning_storage_server::proto::{
12    self, GetReply, GetRequest, InfoReply, InfoRequest, PingReply, PingRequest, PutReply,
13    PutRequest,
14};
15use lightning_storage_server::util::compute_shared_hmac;
16use log::{debug, error};
17use secp256k1::{PublicKey, SecretKey};
18use tonic::{Request, Response, Status};
19
20pub struct StorageServer {
21    database: Box<dyn Database>,
22    public_key: PublicKey,
23    secret_key: SecretKey,
24}
25
26fn into_status(s: Error) -> Status {
27    match s {
28        Error::Conflict(_) => unimplemented!("unexpected conflict error"),
29        e => {
30            error!("database error: {:?}", e);
31            Status::internal("unexpected error")
32        }
33    }
34}
35
36impl StorageServer {
37    fn check_auth(&self, auth_proto: &proto::Auth) -> Result<PrivAuth, Status> {
38        let client_id = PublicKey::from_slice(&auth_proto.client_id)
39            .map_err(|_| Status::unauthenticated("invalid client id"))?;
40        let auth = PrivAuth::new_for_server(&self.secret_key, &client_id);
41        if auth_proto.token != auth.auth_token() {
42            return Err(Status::invalid_argument("invalid auth token"));
43        }
44        Ok(auth)
45    }
46}
47
48#[tonic::async_trait]
49impl LightningStorage for StorageServer {
50    async fn ping(&self, request: Request<PingRequest>) -> Result<Response<PingReply>, Status> {
51        let request = request.into_inner();
52
53        let response = PingReply { message: request.message };
54        Ok(Response::new(response))
55    }
56
57    async fn info(&self, request: Request<InfoRequest>) -> Result<Response<InfoReply>, Status> {
58        let _ = request.into_inner();
59
60        let response = InfoReply {
61            version: "0.1".to_string(),
62            server_id: self.public_key.serialize().to_vec(),
63        };
64        Ok(Response::new(response))
65    }
66
67    async fn get(&self, request: Request<GetRequest>) -> Result<Response<GetReply>, Status> {
68        let request = request.into_inner();
69        let auth_proto = request.auth.ok_or_else(|| Status::invalid_argument("missing auth"))?;
70        let auth = self.check_auth(&auth_proto)?;
71        let client_id = auth_proto.client_id;
72        let key_prefix = request.key_prefix;
73        debug!("get request({}) {}", hex::encode(&client_id), key_prefix);
74        let kvs =
75            self.database.get_with_prefix(&client_id, key_prefix).await.map_err(into_status)?;
76        debug!("get result {:?}", kvs);
77        let hmac = compute_shared_hmac(&auth.shared_secret, &request.nonce, &kvs);
78        let kvs_proto = kvs.into_iter().map(|kv| kv.into()).collect();
79
80        let response = GetReply { kvs: kvs_proto, hmac };
81        Ok(Response::new(response))
82    }
83
84    async fn put(&self, request: Request<PutRequest>) -> Result<Response<PutReply>, Status> {
85        let request = request.into_inner();
86        let kvs: Vec<_> = request.kvs.into_iter().map(|kv| kv.into()).collect::<Vec<_>>();
87
88        let auth_proto = request.auth.ok_or_else(|| Status::invalid_argument("missing auth"))?;
89        let client_id = &auth_proto.client_id;
90
91        debug!("put request({}) {:?}", hex::encode(client_id), kvs);
92
93        for ((k1, _), (k2, _)) in kvs.iter().tuple_windows() {
94            if k1 > k2 {
95                return Err(Status::invalid_argument("keys are not sorted"));
96            }
97        }
98
99        let auth = self.check_auth(&auth_proto)?;
100        let client_hmac = compute_shared_hmac(&auth.shared_secret, &[0x01], &kvs);
101
102        if client_hmac != request.hmac {
103            return Err(Status::invalid_argument("invalid client HMAC"));
104        }
105
106        let response = match self.database.put(&client_id, &kvs).await {
107            Ok(()) => {
108                debug!("put result ok");
109                let hmac = compute_shared_hmac(&auth.shared_secret, &[0x02], &kvs);
110
111                PutReply { success: true, hmac, conflicts: vec![] }
112            }
113            Err(Error::Conflict(conflicts)) => {
114                debug!("put result conflict {:?}", conflicts);
115                let conflicts = conflicts.into_iter().map(|kv| kv.into()).collect();
116                PutReply { success: false, hmac: Default::default(), conflicts }
117            }
118            Err(e) => {
119                error!("database error: {:?}", e);
120                return Err(Status::internal("unexpected error"));
121            }
122        };
123        Ok(Response::new(response))
124    }
125}