1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
use std::sync::Arc;

use crate::{
    channel::Message,
    error::Error,
    hub::Hub,
    permission::ChannelPermission,
    server::{Server, ServerNotification},
};
use crate::{server::client_command, ID};
use crate::{server::HubUpdateType, signing::KeyPair};
use futures_util::{SinkExt, StreamExt};
use pgp::{crypto::HashAlgorithm, types::CompressionAlgorithm, Message as OpenPGPMessage};
use pgp::{packet::LiteralData, types::KeyTrait, SignedPublicKey};
use tokio::sync::Mutex;
use warp::ws::WebSocket;
use xactor::Addr;

use crate::error::Result;
use serde::{Deserialize, Serialize};

pub use warp::ws::Message as WebSocketMessage;

/// A command sent by an authenticated client over its WebSocket.
///
/// On the wire each command is serde JSON (the default, externally tagged enum form, since no
/// `#[serde(...)]` attribute is set) wrapped in an OpenPGP message signed with the client's key.
/// `handle_connection` answers frames whose signature cannot be verified with
/// `ServerMessage::NotSigned`, and verified frames that do not decode as a `ClientMessage` with
/// `ServerMessage::InvalidCommand`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClientMessage {
    /// Subscribe this connection to hub `hub_id` (forwarded as `client_command::SubscribeHub`).
    SubscribeHub {
        hub_id: ID,
    },
    /// Drop this connection's subscription to hub `hub_id`.
    UnsubscribeHub {
        hub_id: ID,
    },
    /// Subscribe this connection to channel `channel_id` of hub `hub_id`.
    SubscribeChannel {
        hub_id: ID,
        channel_id: ID,
    },
    /// Drop this connection's subscription to channel `channel_id` of hub `hub_id`.
    UnsubscribeChannel {
        hub_id: ID,
        channel_id: ID,
    },
    /// Forwarded to the server as `client_command::StartTyping` for this user and channel.
    StartTyping {
        hub_id: ID,
        channel_id: ID,
    },
    /// Forwarded to the server as `client_command::StopTyping` for this user and channel.
    StopTyping {
        hub_id: ID,
        channel_id: ID,
    },
    /// First step of posting a message: the server checks the sender's
    /// `ChannelPermission::Write` on the channel, builds a message from `content`, signs it with
    /// the server key and replies with `ServerMessage::MessageForSigning`.
    SendMessageInit {
        hub_id: ID,
        channel_id: ID,
        content: String,
    },
    /// Second step of posting a message: `signed_message` is the armoured message, which the
    /// server verifies against both its own key and the client's key before storing it and
    /// announcing it via `ServerNotification::NewMessage`.
    SendMessage {
        signed_message: String,
    },
}

/// A reply or notification sent from the server to a client.
///
/// `handle_connection` answers client text frames with one of these, serialized to JSON and
/// wrapped in an OpenPGP message signed with the server key and ZIP-compressed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ServerMessage {
    /// A command failed; carries the failure's display text.
    Error(String),
    /// The frame was correctly signed but its payload did not decode as a `ClientMessage`.
    InvalidCommand,
    /// The frame's signature could not be verified against the client's key
    /// (`crate::signing::verify_message_extract` failed).
    NotSigned,
    /// NOTE(review): not produced anywhere in this file — confirm where it is used.
    CommandFailed,
    /// An armoured chat message in channel `channel_id` of hub `hub_id`.
    /// NOTE(review): not built in this file; presumably pushed by the `Server` actor to
    /// subscribed connections.
    ChatMessage {
        hub_id: ID,
        channel_id: ID,
        message_id: ID,
        armoured_message: String,
    },
    /// Hub `hub_id` changed in the way described by `update_type`.
    /// NOTE(review): not built in this file; presumably pushed to hub subscribers.
    HubUpdated {
        hub_id: ID,
        update_type: HubUpdateType,
    },
    /// The command completed successfully.
    Success,
    /// `user_id` started typing in channel `channel_id` of hub `hub_id`.
    /// NOTE(review): `user_id` is presumably the upper-case hex key fingerprint, matching
    /// how `handle_connection` derives user ids — confirm.
    UserStartedTyping {
        user_id: String,
        hub_id: ID,
        channel_id: ID,
    },
    /// `user_id` stopped typing in channel `channel_id` of hub `hub_id`.
    UserStoppedTyping {
        user_id: String,
        hub_id: ID,
        channel_id: ID,
    },
    /// Reply to `ClientMessage::SendMessageInit`: the armoured, server-signed message that the
    /// client presumably signs as well and returns via `ClientMessage::SendMessage`.
    MessageForSigning {
        server_signed_message: String,
    },
}

pub async fn handle_connection(
    websocket: WebSocket,
    public_key: SignedPublicKey,
    server_keys: Arc<KeyPair>,
    addr: Arc<Addr<Server>>,
) -> Result {
    let (mut outgoing, mut incoming) = websocket.split();
    let key = rand::random::<u128>().to_string();
    let message = OpenPGPMessage::Literal(LiteralData::from_str("auth_key", &key)).sign(
        &server_keys.secret_key,
        String::new,
        HashAlgorithm::SHA2_256,
    )?;
    outgoing
        .send(WebSocketMessage::text(message.to_armored_string(None)?))
        .await?;

    if let Some(msg) = incoming.next().await {
        let msg = msg?;
        if let Ok(text) = msg.to_str() {
            let message = crate::signing::verify_message_extract(&public_key, text)?.0;
            if message == key {
                drop((message, key, text));
                drop(msg);
                let out_arc = Arc::new(Mutex::new(outgoing));
                let connection_id: u128;
                {
                    let result = addr
                        .call(client_command::Connect {
                            websocket_writer: out_arc.clone(),
                        })
                        .await
                        .map_err(|_| Error::InternalMessageFailed)?;
                    connection_id = result;
                }
                let user_id = hex::encode_upper(public_key.fingerprint());
                let internal_message_error = Error::InternalMessageFailed.to_string();
                while let Some(msg) = incoming.next().await {
                    let msg = msg?;
                    if let Ok(text) = msg.to_str() {
                        let raw_response = if let Ok((command_text, _)) =
                            crate::signing::verify_message_extract(&public_key, text)
                        {
                            if let Ok(command) = serde_json::from_str(&command_text) {
                                match command {
                                    ClientMessage::SubscribeChannel { hub_id, channel_id } => {
                                        if let Ok(result) = addr
                                            .call(client_command::SubscribeChannel {
                                                user_id: user_id.clone(),
                                                hub_id,
                                                channel_id,
                                                connection_id,
                                            })
                                            .await
                                        {
                                            result.map_or_else(
                                                |err| ServerMessage::Error(err.to_string()),
                                                |_| ServerMessage::Success,
                                            )
                                        } else {
                                            ServerMessage::Error(internal_message_error.clone())
                                        }
                                    }
                                    ClientMessage::UnsubscribeChannel { hub_id, channel_id } => {
                                        if addr
                                            .call(client_command::UnsubscribeChannel {
                                                hub_id,
                                                channel_id,
                                                connection_id,
                                            })
                                            .await
                                            .is_ok()
                                        {
                                            ServerMessage::Success
                                        } else {
                                            ServerMessage::Error(internal_message_error.clone())
                                        }
                                    }
                                    ClientMessage::StartTyping { hub_id, channel_id } => {
                                        if let Ok(result) = addr
                                            .call(client_command::StartTyping {
                                                user_id: user_id.clone(),
                                                hub_id,
                                                channel_id,
                                            })
                                            .await
                                        {
                                            result.map_or_else(
                                                |err| ServerMessage::Error(err.to_string()),
                                                |_| ServerMessage::Success,
                                            )
                                        } else {
                                            ServerMessage::Error(internal_message_error.clone())
                                        }
                                    }
                                    ClientMessage::StopTyping { hub_id, channel_id } => {
                                        if let Ok(result) = addr
                                            .call(client_command::StopTyping {
                                                user_id: user_id.clone(),
                                                hub_id,
                                                channel_id,
                                            })
                                            .await
                                        {
                                            result.map_or_else(
                                                |err| ServerMessage::Error(err.to_string()),
                                                |_| ServerMessage::Success,
                                            )
                                        } else {
                                            ServerMessage::Error(internal_message_error.clone())
                                        }
                                    }
                                    ClientMessage::SubscribeHub { hub_id } => {
                                        if let Ok(result) = addr
                                            .call(client_command::SubscribeHub {
                                                user_id: user_id.clone(),
                                                hub_id,
                                                connection_id,
                                            })
                                            .await
                                        {
                                            result.map_or_else(
                                                |err| ServerMessage::Error(err.to_string()),
                                                |_| ServerMessage::Success,
                                            )
                                        } else {
                                            ServerMessage::Error(internal_message_error.clone())
                                        }
                                    }
                                    ClientMessage::UnsubscribeHub { hub_id } => {
                                        if addr
                                            .call(client_command::UnsubscribeHub {
                                                hub_id,
                                                connection_id,
                                            })
                                            .await
                                            .is_ok()
                                        {
                                            ServerMessage::Success
                                        } else {
                                            ServerMessage::Error(internal_message_error.clone())
                                        }
                                    }
                                    ClientMessage::SendMessageInit {
                                        hub_id,
                                        channel_id,
                                        content,
                                    } => {
                                        let hub = Hub::load(hub_id).await?;
                                        let member = hub.get_member(&user_id)?;
                                        crate::check_permission!(
                                            &member,
                                            channel_id,
                                            ChannelPermission::Write,
                                            &hub
                                        );
                                        ServerMessage::MessageForSigning {
                                            server_signed_message: Message::new(
                                                user_id.clone(),
                                                content,
                                                hub_id,
                                                channel_id,
                                            )
                                            .sign(&server_keys.secret_key, String::new)?
                                            .compress(CompressionAlgorithm::ZIP)?
                                            .to_armored_string(None)?,
                                        }
                                    }
                                    ClientMessage::SendMessage { signed_message } => {
                                        let message = Message::from_double_signed_verify(
                                            &signed_message,
                                            &server_keys.public_key,
                                            &public_key,
                                        )?;
                                        if let Err(err) = crate::channel::Channel::write_message(
                                            message.hub_id,
                                            message.channel_id,
                                            crate::channel::SignedMessage::new(
                                                message.id,
                                                message.created,
                                                signed_message.clone(),
                                            ),
                                        )
                                        .await
                                        {
                                            ServerMessage::Error(err.to_string())
                                        } else if addr
                                            .call(ServerNotification::NewMessage(
                                                message.hub_id,
                                                message.channel_id,
                                                message.id,
                                                signed_message,
                                                message,
                                            ))
                                            .await
                                            .is_ok()
                                        {
                                            ServerMessage::Success
                                        } else {
                                            ServerMessage::Error(internal_message_error.clone())
                                        }
                                    }
                                }
                            } else {
                                ServerMessage::InvalidCommand
                            }
                        } else {
                            ServerMessage::NotSigned
                        };
                        let message = OpenPGPMessage::new_literal(
                            "",
                            serde_json::to_string(&raw_response)?.as_str(),
                        )
                        .sign(
                            &server_keys.secret_key,
                            String::new,
                            HashAlgorithm::SHA2_256,
                        )?
                        .compress(CompressionAlgorithm::ZIP)?;
                        out_arc
                            .lock()
                            .await
                            .send(WebSocketMessage::text(message.to_armored_string(None)?))
                            .await?;
                    }
                }
                return Ok(());
            }
        }
    }
    Err(Error::WsNotAuthenticated)
}