1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
pub mod driver;

use crate::client::PrivAuth;
use crate::proto::lightning_storage_server::{LightningStorage, LightningStorageServer};
use crate::proto::{
    self, GetReply, GetRequest, InfoReply, InfoRequest, PingReply, PingRequest, PutReply,
    PutRequest,
};
use crate::util::compute_shared_hmac;
use crate::{Database, Error, Value};
use itertools::Itertools;
use log::{debug, error};
use secp256k1::{PublicKey, SecretKey};
use tonic::{Request, Response, Status};

/// gRPC storage service state: a key/value backend plus the server's keypair.
pub struct StorageServer {
    /// Backend that serves prefix reads (`get_with_prefix`) and writes (`put`).
    database: Box<dyn Database>,
    /// Server public key; its serialized form is returned as `server_id` by `info`.
    public_key: PublicKey,
    /// Server secret key; combined with a client's public key in `check_auth`.
    secret_key: SecretKey,
}

/// Converts a proto key/value entry into a `(key, Value)` pair.
///
/// Implemented as `From` rather than `Into`: the standard blanket impl then
/// provides `Into<(String, Value)> for proto::KeyValue`, so `.into()` call
/// sites keep working unchanged.
impl From<proto::KeyValue> for (String, Value) {
    fn from(kv: proto::KeyValue) -> Self {
        (kv.key, Value { version: kv.version, value: kv.value })
    }
}

// convert a conflict to proto
impl Into<proto::KeyValue> for (String, Option<Value>) {
    fn into(self) -> proto::KeyValue {
        let (key, v) = self;
        let version = v.as_ref().map(|v| v.version).unwrap_or(-1);
        let value = v.as_ref().map(|v| v.value.clone()).unwrap_or_default();
        proto::KeyValue { key, version, value }
    }
}

/// Converts a get result entry to proto, carrying over key, version and value.
///
/// Implemented as `From`; the standard blanket impl supplies the matching
/// `Into`, so `.into()` call sites keep working unchanged.
impl From<(String, Value)> for proto::KeyValue {
    fn from((key, v): (String, Value)) -> Self {
        proto::KeyValue { key, version: v.version, value: v.value }
    }
}

fn into_status(s: Error) -> Status {
    match s {
        Error::Conflict(_) => unimplemented!("unexpected conflict error"),
        e => {
            error!("database error: {:?}", e);
            Status::internal("unexpected error")
        }
    }
}

impl StorageServer {
    /// Validates the credentials carried in `auth_proto`.
    ///
    /// The client id bytes are parsed as a public key, a [`PrivAuth`] is built
    /// from the server secret key and that public key, and the token presented
    /// by the client is compared against the token derived from it.
    ///
    /// # Errors
    ///
    /// * `unauthenticated` - the client id is not a valid public key.
    /// * `invalid_argument` - the presented token does not match.
    fn check_auth(&self, auth_proto: &proto::Auth) -> Result<PrivAuth, Status> {
        let client_key = match PublicKey::from_slice(&auth_proto.client_id) {
            Ok(key) => key,
            Err(_) => return Err(Status::unauthenticated("invalid client id")),
        };

        let auth = PrivAuth::new_for_server(&self.secret_key, &client_key);

        // NOTE(review): a bad token is reported as `invalid_argument` while a bad
        // client id is `unauthenticated` - confirm whether clients rely on this.
        if auth_proto.token == auth.auth_token() {
            Ok(auth)
        } else {
            Err(Status::invalid_argument("invalid auth token"))
        }
    }
}

#[tonic::async_trait]
impl LightningStorage for StorageServer {
    async fn ping(&self, request: Request<PingRequest>) -> Result<Response<PingReply>, Status> {
        let request = request.into_inner();

        let response = PingReply { message: request.message };
        Ok(Response::new(response))
    }

    async fn info(&self, request: Request<InfoRequest>) -> Result<Response<InfoReply>, Status> {
        let _ = request.into_inner();

        let response = InfoReply {
            version: "0.1".to_string(),
            server_id: self.public_key.serialize().to_vec(),
        };
        Ok(Response::new(response))
    }

    async fn get(&self, request: Request<GetRequest>) -> Result<Response<GetReply>, Status> {
        let request = request.into_inner();
        let auth_proto = request.auth.ok_or_else(|| Status::invalid_argument("missing auth"))?;
        let auth = self.check_auth(&auth_proto)?;
        let client_id = auth_proto.client_id;
        let key_prefix = request.key_prefix;
        debug!("get request({}) {}", hex::encode(&client_id), key_prefix);
        let kvs =
            self.database.get_with_prefix(&client_id, key_prefix).await.map_err(into_status)?;
        debug!("get result {:?}", kvs);
        let hmac = compute_shared_hmac(&auth.shared_secret, &request.nonce, &kvs);
        let kvs_proto = kvs.into_iter().map(|kv| kv.into()).collect();

        let response = GetReply { kvs: kvs_proto, hmac };
        Ok(Response::new(response))
    }

    async fn put(&self, request: Request<PutRequest>) -> Result<Response<PutReply>, Status> {
        let request = request.into_inner();
        let kvs: Vec<_> = request.kvs.into_iter().map(|kv| kv.into()).collect::<Vec<_>>();

        let auth_proto = request.auth.ok_or_else(|| Status::invalid_argument("missing auth"))?;
        let client_id = &auth_proto.client_id;

        debug!("put request({}) {:?}", hex::encode(client_id), kvs);

        for ((k1, _), (k2, _)) in kvs.iter().tuple_windows() {
            if k1 > k2 {
                return Err(Status::invalid_argument("keys are not sorted"));
            }
        }

        let auth = self.check_auth(&auth_proto)?;
        let client_hmac = compute_shared_hmac(&auth.shared_secret, &[0x01], &kvs);

        if client_hmac != request.hmac {
            return Err(Status::invalid_argument("invalid client HMAC"));
        }

        let response = match self.database.put(&client_id, &kvs).await {
            Ok(()) => {
                debug!("put result ok");
                let hmac = compute_shared_hmac(&auth.shared_secret, &[0x02], &kvs);

                PutReply { success: true, hmac, conflicts: vec![] }
            }
            Err(Error::Conflict(conflicts)) => {
                debug!("put result conflict {:?}", conflicts);
                let conflicts = conflicts.into_iter().map(|kv| kv.into()).collect();
                PutReply { success: false, hmac: Default::default(), conflicts }
            }
            Err(e) => {
                error!("database error: {:?}", e);
                return Err(Status::internal("unexpected error"));
            }
        };
        Ok(Response::new(response))
    }
}