1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
//! Server configuration.
use mpc_relay_protocol::{decode_keypair, hex, snow::Keypair};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use tokio::fs;

use crate::{Error, Result};

/// Configuration for the web server.
///
/// The container is marked `#[serde(default)]`, so any field missing
/// from the input falls back to its value in `ServerConfig::default()`.
#[derive(Default, Serialize, Deserialize)]
#[serde(default)]
pub struct ServerConfig {
    /// Path to the server key.
    ///
    /// `ServerConfig::load` rejects an empty path and resolves a
    /// relative path against the directory of the config file.
    pub key: PathBuf,

    /// Settings for session management.
    pub session: SessionConfig,

    /// Configuration for TLS encryption.
    pub tls: Option<TlsConfig>,

    /// Allow access to clients with these
    /// public keys.
    ///
    /// When set, `is_allowed_access` only accepts keys in this list;
    /// when `None`, every key that is not denied is accepted.
    pub allow: Option<Vec<AccessKey>>,

    /// Deny access to clients with these
    /// public keys.
    ///
    /// Checked before `allow`, so a key in both lists is denied.
    pub deny: Option<Vec<AccessKey>>,
}

impl ServerConfig {
    /// Determine if a public key is allowed access.
    pub fn is_allowed_access(&self, key: impl AsRef<[u8]>) -> bool {
        //let restricted = self.allow.is_some() || self.deny.is_some();

        if let Some(deny) = &self.deny {
            if deny.iter().any(|k| k.public_key == key.as_ref()) {
                return false;
            }
        }

        if let Some(allow) = &self.allow {
            if allow.iter().any(|k| k.public_key == key.as_ref()) {
                return true;
            }
            false
        } else {
            true
        }
    }
}

/// A client public key used in the `allow` and `deny` access lists
/// of `ServerConfig`.
#[derive(Default, Serialize, Deserialize)]
pub struct AccessKey {
    /// Raw bytes of the public key.
    ///
    /// (De)serialized through the `hex::serde` helper, so presumably
    /// written as a hex string in the configuration — confirm against
    /// the `hex` module re-exported by `mpc_relay_protocol`.
    #[serde(with = "hex::serde")]
    public_key: Vec<u8>,
}

/// Certificate and key for TLS.
///
/// `ServerConfig::load` resolves relative paths in this struct against
/// the directory containing the configuration file.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct TlsConfig {
    /// Path to the certificate.
    pub cert: PathBuf,
    /// Path to the certificate key file.
    pub key: PathBuf,
}

/// Configuration for server sessions.
///
/// The `Default` implementation uses a 300 second timeout and an
/// 1800 second reap interval.
// NOTE(review): unlike `ServerConfig` this struct has no
// `#[serde(default)]`, so a config that includes a session section
// presumably has to spell out both fields — confirm this is intended.
#[derive(Debug, Serialize, Deserialize)]
pub struct SessionConfig {
    /// Timeout for sessions in seconds.
    ///
    /// Default is 300 seconds (5 minutes).
    pub timeout: u64,

    /// Interval in seconds to reap expired sessions.
    ///
    /// Default is every 30 minutes.
    pub reap_interval: u64,
}

impl Default for SessionConfig {
    /// Sessions time out after 5 minutes and expired sessions are
    /// reaped every 30 minutes; both values are stored in seconds.
    fn default() -> Self {
        let timeout = 5 * 60;
        let reap_interval = 30 * 60;
        Self {
            timeout,
            reap_interval,
        }
    }
}

impl ServerConfig {
    /// Load a server config from a file path.
    pub async fn load<P: AsRef<Path>>(
        path: P,
    ) -> Result<(Self, Keypair)> {
        if !fs::try_exists(path.as_ref()).await? {
            return Err(Error::NotFile(path.as_ref().to_path_buf()));
        }

        let contents = fs::read_to_string(path.as_ref()).await?;
        let mut config: ServerConfig = toml::from_str(&contents)?;

        if config.key == PathBuf::default() {
            return Err(Error::KeyFileRequired);
        }

        let dir = Self::directory(path.as_ref())?;

        if config.key.is_relative() {
            config.key = dir.join(&config.key).canonicalize()?;
        }

        if !fs::try_exists(&config.key).await? {
            return Err(Error::KeyNotFound(config.key.clone()));
        }

        let contents = fs::read_to_string(&config.key).await?;
        let keypair = decode_keypair(contents)?;

        if let Some(tls) = config.tls.as_mut() {
            if tls.cert.is_relative() {
                tls.cert = dir.join(&tls.cert).canonicalize()?;
            }
            if tls.key.is_relative() {
                tls.key = dir.join(&tls.key).canonicalize()?;
            }
        }

        Ok((config, keypair))
    }

    /// Parent directory of the configuration file.
    fn directory(file: impl AsRef<Path>) -> Result<PathBuf> {
        file.as_ref()
            .parent()
            .map(|p| p.to_path_buf())
            .ok_or_else(|| Error::NoParentDir)
    }
}