riz/
storage.rs

1use std::{collections::HashMap, env, fs, net::Ipv4Addr, path::Path};
2
3use ipnet::Ipv4Net;
4use log::{error, warn};
5use uuid::Uuid;
6
7use crate::{
8    models::{Light, LightingResponse, Room},
9    Error, Result,
10};
11
12const STORAGE_ENV_KEY: &str = "RIZ_STORAGE_PATH";
13
14/// Reads and syncs with `rooms.json` in `RIZ_STORAGE_PATH` (env var)
15///
16/// Expected to be wrapped by a [std::sync::Mutex], then wrapped
17/// with a [actix_web::web::Data], and cloned to each request
18///
19/// NB: All `&mut` methods update the contents of `rooms.json`
20///
21/// # Examples
22///
23/// ```
24/// use std::sync::Mutex;
25/// use actix_web::web::Data;
26/// use riz::Storage;
27///
28/// let storage = Data::new(Mutex::new(Storage::new()));
29/// ```
30///
31#[derive(Default, Debug)]
32pub struct Storage {
33    rooms: HashMap<Uuid, Room>,
34    file_path: String,
35}
36
37impl Storage {
38    /// Create a new Stoage object (should only do this once)
39    pub fn new() -> Self {
40        let file_path = Self::get_storage_path();
41        let mut rooms = Self::read_json(&file_path);
42
43        for (id, room) in rooms.iter_mut() {
44            room.link(id);
45        }
46
47        Storage { rooms, file_path }
48    }
49
50    fn read_json(file_path: &str) -> HashMap<Uuid, Room> {
51        match fs::read_to_string(file_path) {
52            Ok(content) => {
53                if let Ok(prev) = serde_json::from_str(&content) {
54                    prev
55                } else {
56                    warn!("Failed to decode previous data");
57                    HashMap::new()
58                }
59            }
60            Err(_) => HashMap::new(),
61        }
62    }
63
64    fn get_storage_path() -> String {
65        let path = env::var(STORAGE_ENV_KEY).unwrap_or(".".to_string());
66        if let Some(file_path) = Path::new(&path).join("rooms.json").to_str() {
67            file_path
68        } else {
69            warn!("Invalid storage file path: {}", path);
70            "./rooms.json"
71        }
72        .to_string()
73    }
74
75    /// Write the contents of self.rooms to rooms.json
76    fn write(&self) {
77        if let Ok(contents) = serde_json::to_string(&self.rooms) {
78            if let Err(e) = fs::write(&self.file_path, contents) {
79                error!("Failed to write JSON: {:?}", e);
80            }
81        } else {
82            error!("Failed to dump JSON");
83        }
84    }
85
86    /// Create a new room
87    ///
88    /// # Errors
89    ///   [Error::InvalidIP] if any light in the new room has an invalid IP address
90    ///
91    pub fn new_room(&mut self, room: Room) -> Result<Uuid> {
92        let mut id = Uuid::new_v4();
93        while self.rooms.contains_key(&id) {
94            id = Uuid::new_v4();
95        }
96
97        // ensure any lights ips in the new room are valid (should be empty...)
98        self.validate_room(&room)?;
99
100        let mut room = room;
101        room.link(&id);
102
103        self.rooms.insert(id, room);
104        self.write();
105        Ok(id)
106    }
107
108    /// Create a new light in the room
109    pub fn new_light(&mut self, room: &Uuid, light: Light) -> Result<Uuid> {
110        self.validate_light(&light)?;
111        if let Some(entry) = self.rooms.get_mut(room) {
112            let id = entry.new_light(light)?;
113            self.write();
114            Ok(id)
115        } else {
116            Err(Error::RoomNotFound(*room))
117        }
118    }
119
120    /// Read a room by ID (returns clone)
121    pub fn read(&self, room: &Uuid) -> Option<Room> {
122        self.rooms.get(room).cloned()
123    }
124
125    /// Updates non-light attributes (currently just name)
126    pub fn update_room(&mut self, id: &Uuid, room: &Room) -> Result<()> {
127        if let Some(entry) = self.rooms.get_mut(id) {
128            if entry.update(room) {
129                self.write();
130                Ok(())
131            } else {
132                Err(Error::NoChangeRoom(*id))
133            }
134        } else {
135            Err(Error::RoomNotFound(*id))
136        }
137    }
138
139    /// Update non-lighting attributes of the light in the room (name, ip)
140    pub fn update_light(&mut self, id: &Uuid, light_id: &Uuid, light: &Light) -> Result<()> {
141        if let Some(room) = self.rooms.get_mut(id) {
142            room.update_light(light_id, light)?;
143            self.write();
144            Ok(())
145        } else {
146            Err(Error::light_not_found(id, light_id))
147        }
148    }
149
150    /// Remove a room
151    pub fn delete_room(&mut self, room: &Uuid) -> Result<()> {
152        match self.rooms.remove(room) {
153            Some(_) => {
154                self.write();
155                Ok(())
156            }
157            None => Err(Error::RoomNotFound(*room)),
158        }
159    }
160
161    /// Remove a light in a room
162    pub fn delete_light(&mut self, room: &Uuid, light: &Uuid) -> Result<()> {
163        match self.rooms.get_mut(room) {
164            Some(rm) => {
165                rm.delete_light(light)?;
166                self.write();
167                Ok(())
168            }
169            None => Err(Error::RoomNotFound(*room)),
170        }
171    }
172
173    /// List room IDs
174    pub fn list(&self) -> Result<Vec<&Uuid>> {
175        Ok(self.rooms.keys().collect())
176    }
177
178    /// Process the response of a lighting request
179    pub fn process_reply(&mut self, resp: &LightingResponse) {
180        let mut any_update = false;
181        for room in self.rooms.values_mut() {
182            let room_update = room.process_reply(resp);
183            any_update = any_update || room_update;
184        }
185
186        if any_update {
187            self.write();
188        }
189    }
190
191    /// Check if all lights in the room are valid and unique
192    fn validate_room(&self, room: &Room) -> Result<()> {
193        if let Some(lights) = room.list() {
194            for light_id in lights {
195                if let Some(light) = room.read(light_id) {
196                    self.validate_light(light)?;
197                }
198            }
199        }
200        Ok(())
201    }
202
203    /// Check if the light's ip is valid and unqiue
204    fn validate_light(&self, light: &Light) -> Result<()> {
205        self.validate_ip(&light.ip())
206    }
207
208    /// Check if the IP is valid and unique
209    fn validate_ip(&self, ip: &Ipv4Addr) -> Result<()> {
210        // || ip.is_benchmarking() can be added once stable
211        if ip.is_documentation() {
212            return self.unique_ip(ip);
213        }
214
215        if ip.is_link_local() || ip.is_loopback() {
216            return Err(Error::invalid_ip(ip, "a local ip"));
217        }
218
219        if ip.is_unspecified() {
220            return Err(Error::invalid_ip(ip, "unspecified"));
221        }
222
223        if ip.is_broadcast() {
224            return Err(Error::invalid_ip(ip, "a broadcast address"));
225        }
226
227        if ip.is_multicast() {
228            return Err(Error::invalid_ip(ip, "a multicast address"));
229        }
230
231        // can add when when stable
232        // if ip.is_reserved() {
233        //     return Err(Error::invalid_ip(ip, "a reserved ip"));
234        // }
235
236        if !ip.is_private() {
237            return Err(Error::invalid_ip(ip, "a public ip"));
238        }
239
240        // check if this IP is a subnet broadcast or network address
241        if let Some(net) = classful_network(ip) {
242            // NB: because we are probably behind docker, we can't
243            //     really tell what our local network is, without
244            //     probing around... which we probably shouldn't do.
245            //     otherwise, it would be possible to limit the IPs
246            //     to the actual connected networks. but as we've
247            //     already limited them to private IPs this is fine.
248            //     it won't correctly pick up classless setups though,
249            //     again because docker. ¯\_(ツ)_/¯ oh well
250
251            if *ip == net.network() {
252                return Err(Error::invalid_ip(ip, "the subnet's network address"));
253            }
254
255            if *ip == net.broadcast() {
256                return Err(Error::invalid_ip(ip, "the subnet's broadcast address"));
257            }
258
259            return self.unique_ip(ip);
260        }
261
262        // this can't actually happen...
263        Err(Error::invalid_ip(ip, "unknown"))
264    }
265
266    /// Check if the IP is unique
267    fn unique_ip(&self, ip: &Ipv4Addr) -> Result<()> {
268        for room in self.rooms.values() {
269            if let Some(lights) = room.list() {
270                for light_id in lights {
271                    if let Some(light) = room.read(light_id) {
272                        if *ip == light.ip() {
273                            return Err(Error::invalid_ip(ip, "already known"));
274                        }
275                    }
276                }
277            }
278        }
279        Ok(())
280    }
281}
282
283fn classful_network(ip: &Ipv4Addr) -> Option<Ipv4Net> {
284    match ip.octets()[0] {
285        (1..=126) => Some(Ipv4Net::new(*ip, 8).unwrap()),
286        (128..=191) => Some(Ipv4Net::new(*ip, 16).unwrap()),
287        (192..=223) => Some(Ipv4Net::new(*ip, 24).unwrap()),
288        _ => None,
289    }
290}
291
292#[cfg(test)]
293mod tests {
294    use rand::{distributions::Alphanumeric, Rng};
295    use std::{env, panic, str::FromStr, vec};
296
297    use super::*;
298
299    /// Run the closure test with a new temp test storage, and clean up after
300    fn test_storage<T>(test: T) -> ()
301    where
302        T: FnOnce() -> () + panic::UnwindSafe,
303    {
304        let s: String = rand::thread_rng()
305            .sample_iter(&Alphanumeric)
306            .take(12)
307            .map(char::from)
308            .collect();
309
310        let mut base = env::temp_dir();
311        base.push(s);
312        env::set_var(STORAGE_ENV_KEY, base.clone());
313
314        let res = panic::catch_unwind(|| test());
315
316        fs::remove_dir_all(base).unwrap_or_else(|_| error!("failed to clean up tmp storage"));
317
318        assert!(res.is_ok())
319    }
320
321    #[test]
322    fn unique_ips_same_room() {
323        let mut room = Room::new("test");
324        let ip = Ipv4Addr::from_str("192.0.2.3").unwrap();
325        let light = Light::new(ip, Some("bulb"));
326
327        assert!(room.new_light(light.clone()).is_ok());
328        let res = room.new_light(light);
329
330        assert_eq!(res, Err(Error::invalid_ip(&ip, "already known")));
331    }
332
333    #[test]
334    fn unique_ips_different_rooms() {
335        test_storage(|| {
336            let ip = Ipv4Addr::from_str("192.0.2.3").unwrap();
337
338            let mut room = Room::new("test");
339            let light = Light::new(ip, Some("bulb"));
340            room.new_light(light.clone()).unwrap();
341
342            let mut room2 = Room::new("test");
343            room2.new_light(light).unwrap();
344
345            let mut storage = Storage::new();
346            assert!(storage.new_room(room).is_ok());
347
348            let res = storage.new_room(room2);
349            assert_eq!(res, Err(Error::invalid_ip(&ip, "already known")));
350        })
351    }
352
353    #[test]
354    fn new_light_unique_ip() {
355        test_storage(|| {
356            let ip = Ipv4Addr::from_str("192.0.2.3").unwrap();
357
358            let mut room = Room::new("test");
359            let light = Light::new(ip, Some("bulb"));
360            room.new_light(light.clone()).unwrap();
361
362            let mut storage = Storage::new();
363            let room_id = storage.new_room(room).unwrap();
364
365            let res = storage.new_light(&room_id, light);
366            assert_eq!(res, Err(Error::invalid_ip(&ip, "already known")));
367        })
368    }
369
370    #[test]
371    fn invalid_ips_denied() {
372        test_storage(|| {
373            let tests = vec![
374                ("8.8.8.8", "a public ip"),
375                ("127.0.0.1", "a local ip"),
376                ("0.0.0.0", "unspecified"),
377                ("255.255.255.255", "a broadcast address"),
378                ("224.224.224.224", "a multicast address"),
379                // ("240.240.240.240", "a reserved ip"),
380                ("192.168.1.0", "the subnet's network address"),
381                ("172.16.255.255", "the subnet's broadcast address"),
382            ];
383
384            for (ip, reason) in tests {
385                let ip = Ipv4Addr::from_str(ip).unwrap();
386
387                let mut room = Room::new("test");
388                let light = Light::new(ip, None);
389                room.new_light(light).unwrap();
390
391                let mut storage = Storage::new();
392                let res = storage.new_room(room);
393
394                assert_eq!(res, Err(Error::invalid_ip(&ip, reason)));
395            }
396        })
397    }
398
399    #[test]
400    fn valid_ips_allowed() {
401        test_storage(|| {
402            let tests = vec!["10.1.2.3", "192.168.1.25", "172.16.0.17"];
403
404            for ip in tests {
405                let ip = Ipv4Addr::from_str(ip).unwrap();
406
407                let mut room = Room::new("test");
408                let light = Light::new(ip, None);
409                room.new_light(light).unwrap();
410
411                let mut storage = Storage::new();
412                let res = storage.new_room(room);
413
414                assert!(res.is_ok());
415            }
416        })
417    }
418}