Skip to main content

ic_sis/
settings.rs

1use candid::Principal;
2use url::Url;
3
/// Scheme used when `SettingsBuilder::scheme` is not called.
const DEFAULT_SCHEME: &str = "https";
/// Statement used when `SettingsBuilder::statement` is not called.
const DEFAULT_STATEMENT: &str = "Sign in with Sui";
/// Network used when `SettingsBuilder::network` is not called (Sui mainnet).
const DEFAULT_NETWORK: &str = "mainnet";

/// Nanoseconds per second; the expiry durations below are in nanoseconds.
const NANOS_PER_SECOND: u64 = 1_000_000_000;
/// Default `Settings::sign_in_expires_in`: 5 minutes.
const DEFAULT_SIGN_IN_EXPIRES_IN: u64 = 5 * 60 * NANOS_PER_SECOND;
/// Default `Settings::session_expires_in`: 30 minutes.
const DEFAULT_SESSION_EXPIRES_IN: u64 = 30 * 60 * NANOS_PER_SECOND;
9
/// Opt-in behaviours that can be enabled through
/// `SettingsBuilder::runtime_features`.
#[derive(Clone, Debug, PartialEq)]
pub enum RuntimeFeature {
    /// NOTE(review): by its name this includes the configured URI in seed
    /// derivation; the consumer lives outside this file — confirm there.
    IncludeUriInSeed,
}
14
15#[derive(Default, Debug, Clone)]
16pub struct Settings {
17    pub domain: String,
18
19    pub uri: String,
20
21    pub salt: String,
22
23    pub network: String,
24
25    pub scheme: String,
26
27    pub statement: String,
28
29    pub sign_in_expires_in: u64,
30
31    pub session_expires_in: u64,
32
33    pub targets: Option<Vec<Principal>>,
34
35    pub runtime_features: Option<Vec<RuntimeFeature>>,
36}
37
38pub struct SettingsBuilder {
39    settings: Settings,
40}
41
42impl SettingsBuilder {
43    pub fn new<S: Into<String>, T: Into<String>, U: Into<String>>(
44        domain: S,
45        uri: T,
46        salt: U,
47    ) -> Self {
48        SettingsBuilder {
49            settings: Settings {
50                domain: domain.into(),
51                uri: uri.into(),
52                salt: salt.into(),
53                network: DEFAULT_NETWORK.to_string(),
54                scheme: DEFAULT_SCHEME.to_string(),
55                statement: DEFAULT_STATEMENT.to_string(),
56                sign_in_expires_in: DEFAULT_SIGN_IN_EXPIRES_IN,
57                session_expires_in: DEFAULT_SESSION_EXPIRES_IN,
58                targets: None,
59                runtime_features: None,
60            },
61        }
62    }
63
64    pub fn network<S: Into<String>>(mut self, network: S) -> Self {
65        self.settings.network = network.into();
66        self
67    }
68
69    pub fn scheme<S: Into<String>>(mut self, scheme: S) -> Self {
70        self.settings.scheme = scheme.into();
71        self
72    }
73
74    pub fn statement<S: Into<String>>(mut self, statement: S) -> Self {
75        self.settings.statement = statement.into();
76        self
77    }
78
79    pub fn sign_in_expires_in(mut self, expires_in: u64) -> Self {
80        self.settings.sign_in_expires_in = expires_in;
81        self
82    }
83
84    pub fn session_expires_in(mut self, expires_in: u64) -> Self {
85        self.settings.session_expires_in = expires_in;
86        self
87    }
88
89    pub fn targets(mut self, targets: Vec<Principal>) -> Self {
90        self.settings.targets = Some(targets);
91        self
92    }
93
94    pub fn runtime_features(mut self, features: Vec<RuntimeFeature>) -> Self {
95        self.settings.runtime_features = Some(features);
96        self
97    }
98
99    pub fn build(self) -> Result<Settings, String> {
100        validate_domain(&self.settings.scheme, &self.settings.domain)?;
101        validate_uri(&self.settings.uri)?;
102        validate_salt(&self.settings.salt)?;
103        validate_network(&self.settings.network)?;
104        validate_scheme(&self.settings.scheme)?;
105        validate_statement(&self.settings.statement)?;
106        validate_sign_in_expires_in(self.settings.sign_in_expires_in)?;
107        validate_session_expires_in(self.settings.session_expires_in)?;
108        validate_targets(&self.settings.targets)?;
109
110        Ok(self.settings)
111    }
112}
113
114fn validate_domain(scheme: &str, domain: &str) -> Result<String, String> {
115    let url_str = format!("{}://{}", scheme, domain);
116    let parsed_url = Url::parse(&url_str).map_err(|_| String::from("Invalid domain"))?;
117    if !parsed_url.has_authority() {
118        Err(String::from("Invalid domain"))
119    } else {
120        Ok(parsed_url.host_str().unwrap().to_string())
121    }
122}
123
124fn validate_uri(uri: &str) -> Result<String, String> {
125    let parsed_uri = Url::parse(uri).map_err(|_| String::from("Invalid URI"))?;
126    if !parsed_uri.has_host() {
127        Err(String::from("Invalid URI"))
128    } else {
129        Ok(uri.to_string())
130    }
131}
132
/// Checks that `salt` is non-empty and made only of printable, non-space
/// ASCII characters (`'!'..='~'`), returning an owned copy.
fn validate_salt(salt: &str) -> Result<String, String> {
    match salt {
        "" => Err(String::from("Salt cannot be empty")),
        // Any non-ASCII char has UTF-8 bytes >= 0x80, which are never
        // ASCII-graphic, so a byte-level check matches a char-level one.
        s if s.bytes().all(|b| b.is_ascii_graphic()) => Ok(s.to_owned()),
        _ => Err(String::from("Invalid salt")),
    }
}
142
/// Checks that `network` names a known Sui network (`mainnet`, `testnet`,
/// `devnet` or `localnet`), returning an owned copy.
fn validate_network(network: &str) -> Result<String, String> {
    const KNOWN_NETWORKS: [&str; 4] = ["mainnet", "testnet", "devnet", "localnet"];

    if network.is_empty() {
        Err(String::from("Network cannot be empty"))
    } else if KNOWN_NETWORKS.contains(&network) {
        Ok(network.to_owned())
    } else {
        Err(String::from(
            "Unknown Sui network. Use 'mainnet', 'testnet', 'devnet', or 'localnet'",
        ))
    }
}
154
/// Accepts only the `http` and `https` schemes, returning an owned copy.
fn validate_scheme(scheme: &str) -> Result<String, String> {
    match scheme {
        "http" | "https" => Ok(scheme.to_string()),
        _ => Err(String::from("Invalid scheme")),
    }
}
161
/// Accepts any statement without a line feed (`'\n'`), returning an owned
/// copy.
fn validate_statement(statement: &str) -> Result<String, String> {
    let is_single_line = !statement.contains('\n');
    if is_single_line {
        Ok(statement.to_owned())
    } else {
        Err(String::from("Invalid statement"))
    }
}
168
/// Rejects a zero sign-in expiry; any positive value is returned unchanged.
fn validate_sign_in_expires_in(expires_in: u64) -> Result<u64, String> {
    match expires_in {
        0 => Err(String::from("Sign in expires in must be greater than 0")),
        positive => Ok(positive),
    }
}
175
/// Rejects a zero session expiry; any positive value is returned unchanged.
fn validate_session_expires_in(expires_in: u64) -> Result<u64, String> {
    if expires_in > 0 {
        Ok(expires_in)
    } else {
        Err(String::from("Session expires in must be greater than 0"))
    }
}
182
183fn validate_targets(targets: &Option<Vec<Principal>>) -> Result<Option<Vec<Principal>>, String> {
184    if let Some(targets) = targets {
185        if targets.is_empty() {
186            return Err(String::from("Targets cannot be empty"));
187        }
188
189        if targets.len() > 1000 {
190            return Err(String::from("Too many targets"));
191        }
192
193        let mut targets_clone = targets.clone();
194        targets_clone.sort();
195        targets_clone.dedup();
196        if targets_clone.len() != targets.len() {
197            return Err(String::from("Duplicate targets are not allowed"));
198        }
199    }
200    Ok(targets.clone())
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use candid::Principal;

    /// Builder with valid required values, used by the rejection tests.
    fn valid_builder() -> SettingsBuilder {
        SettingsBuilder::new("example.com", "http://example.com", "some_salt")
    }

    #[test]
    fn test_successful_settings_creation_defaults() {
        let builder = SettingsBuilder::new("example.com", "http://example.com", "some_salt");
        let settings = builder
            .build()
            .expect("Failed to create settings with defaults");
        assert_eq!(settings.domain, "example.com");
        assert_eq!(settings.uri, "http://example.com");
        assert_eq!(settings.salt, "some_salt");
        assert_eq!(settings.network, DEFAULT_NETWORK);
        assert_eq!(settings.scheme, DEFAULT_SCHEME);
        assert_eq!(settings.statement, DEFAULT_STATEMENT);
        assert_eq!(settings.sign_in_expires_in, DEFAULT_SIGN_IN_EXPIRES_IN);
        assert_eq!(settings.session_expires_in, DEFAULT_SESSION_EXPIRES_IN);
        assert!(settings.targets.is_none());
        assert!(settings.runtime_features.is_none());
    }

    #[test]
    fn test_successful_settings_creation_custom() {
        let targets = vec![Principal::anonymous()];
        let builder = SettingsBuilder::new("example.com", "http://example.com", "some_salt")
            .network("testnet")
            .scheme("http")
            .statement("Custom statement")
            .sign_in_expires_in(10_000_000_000)
            .session_expires_in(20_000_000_000)
            .targets(targets.clone());
        let settings = builder
            .build()
            .expect("Failed to create settings with custom values");
        assert_eq!(settings.network, "testnet");
        assert_eq!(settings.scheme, "http");
        assert_eq!(settings.statement, "Custom statement");
        assert_eq!(settings.sign_in_expires_in, 10_000_000_000);
        assert_eq!(settings.session_expires_in, 20_000_000_000);
        assert_eq!(settings.targets, Some(targets));
    }

    #[test]
    fn test_runtime_features_are_kept() {
        let settings = valid_builder()
            .runtime_features(vec![RuntimeFeature::IncludeUriInSeed])
            .build()
            .expect("Failed to create settings with runtime features");
        assert_eq!(
            settings.runtime_features,
            Some(vec![RuntimeFeature::IncludeUriInSeed])
        );
    }

    #[test]
    fn test_all_known_networks_accepted() {
        for network in &["mainnet", "testnet", "devnet", "localnet"] {
            let settings = valid_builder()
                .network(*network)
                .build()
                .expect("Known network rejected");
            assert_eq!(settings.network, *network);
        }
    }

    #[test]
    fn test_invalid_network() {
        let builder = SettingsBuilder::new("example.com", "http://example.com", "some_salt")
            .network("invalid_network");
        assert!(builder.build().is_err());
    }

    #[test]
    fn test_empty_salt() {
        let builder = SettingsBuilder::new("example.com", "http://example.com", "");
        assert!(builder.build().is_err());
    }

    #[test]
    fn test_salt_with_non_graphic_characters() {
        let spaced = SettingsBuilder::new("example.com", "http://example.com", "some salt");
        assert_eq!(spaced.build().unwrap_err(), "Invalid salt");
        let non_ascii = SettingsBuilder::new("example.com", "http://example.com", "s\u{e4}lt");
        assert_eq!(non_ascii.build().unwrap_err(), "Invalid salt");
    }

    #[test]
    fn test_invalid_scheme() {
        let builder =
            SettingsBuilder::new("example.com", "http://example.com", "some_salt").scheme("ftp");
        assert!(builder.build().is_err());
    }

    #[test]
    fn test_invalid_domain() {
        let builder = SettingsBuilder::new("exa mple.com", "http://example.com", "some_salt");
        assert_eq!(builder.build().unwrap_err(), "Invalid domain");
    }

    #[test]
    fn test_invalid_uri() {
        let builder = SettingsBuilder::new("example.com", "not a uri", "some_salt");
        assert_eq!(builder.build().unwrap_err(), "Invalid URI");
    }

    #[test]
    fn test_uri_without_host() {
        let builder = SettingsBuilder::new("example.com", "mailto:user@example.com", "some_salt");
        assert_eq!(builder.build().unwrap_err(), "Invalid URI");
    }

    #[test]
    fn test_statement_with_newline() {
        let builder = valid_builder().statement("line one\nline two");
        assert_eq!(builder.build().unwrap_err(), "Invalid statement");
    }

    #[test]
    fn test_zero_sign_in_expires_in() {
        let builder = valid_builder().sign_in_expires_in(0);
        assert!(builder.build().is_err());
    }

    #[test]
    fn test_zero_session_expires_in() {
        let builder = valid_builder().session_expires_in(0);
        assert!(builder.build().is_err());
    }

    #[test]
    fn test_empty_targets() {
        let builder = valid_builder().targets(vec![]);
        assert_eq!(builder.build().unwrap_err(), "Targets cannot be empty");
    }

    #[test]
    fn test_duplicate_targets() {
        let builder = valid_builder().targets(vec![Principal::anonymous(), Principal::anonymous()]);
        assert_eq!(builder.build().unwrap_err(), "Duplicate targets are not allowed");
    }

    #[test]
    fn test_too_many_targets() {
        // The length limit is checked before duplicates, so repeated entries suffice.
        let builder = valid_builder().targets(vec![Principal::anonymous(); 1001]);
        assert_eq!(builder.build().unwrap_err(), "Too many targets");
    }
}