mini_dhcp/
db.rs

1use crate::migration::{maybe_migrate, MigrationResult};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::net::Ipv4Addr;
5use std::path::PathBuf;
6use std::sync::Arc;
7use thiserror::Error;
8use tokio::sync::RwLock;
9use tracing::{info, warn};
10
11#[derive(Error, Debug)]
12pub enum LeaseError {
13    #[error("Lease not found")]
14    NotFound,
15
16    #[error("Client ID mismatch")]
17    ClientMismatch,
18
19    #[error("CSV error: {0}")]
20    CsvError(#[from] csv::Error),
21
22    #[error("IO error: {0}")]
23    IoError(#[from] std::io::Error),
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct Lease {
28    #[serde(with = "ipv4_serde")]
29    pub ip: Ipv4Addr,
30    #[serde(with = "mac_serde")]
31    pub client_id: Vec<u8>,
32    pub leased: bool,
33    pub expires_at: i64,
34    pub network: i64,
35    pub probation: bool,
36}
37
38mod ipv4_serde {
39    use serde::{Deserialize, Deserializer, Serializer};
40    use std::net::Ipv4Addr;
41
42    pub fn serialize<S>(ip: &Ipv4Addr, serializer: S) -> Result<S::Ok, S::Error>
43    where
44        S: Serializer,
45    {
46        serializer.serialize_str(&ip.to_string())
47    }
48
49    pub fn deserialize<'de, D>(deserializer: D) -> Result<Ipv4Addr, D::Error>
50    where
51        D: Deserializer<'de>,
52    {
53        let s = String::deserialize(deserializer)?;
54        s.parse().map_err(serde::de::Error::custom)
55    }
56}
57
58mod mac_serde {
59    use serde::{Deserialize, Deserializer, Serializer};
60
61    pub fn serialize<S>(bytes: &Vec<u8>, serializer: S) -> Result<S::Ok, S::Error>
62    where
63        S: Serializer,
64    {
65        let hex_str: String = bytes
66            .iter()
67            .map(|b| format!("{:02x}", b))
68            .collect::<Vec<_>>()
69            .join(":");
70        serializer.serialize_str(&hex_str)
71    }
72
73    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
74    where
75        D: Deserializer<'de>,
76    {
77        let s = String::deserialize(deserializer)?;
78        if s.is_empty() {
79            return Ok(Vec::new());
80        }
81        s.split(':')
82            .map(|part| u8::from_str_radix(part, 16).map_err(serde::de::Error::custom))
83            .collect()
84    }
85}
86
87#[derive(Clone)]
88pub struct LeaseStore {
89    leases: Arc<RwLock<HashMap<Ipv4Addr, Lease>>>,
90    file_path: PathBuf,
91}
92
93impl LeaseStore {
94    pub async fn new(file_path: PathBuf) -> Result<Self, LeaseError> {
95        // Attempt migration from SQLite if needed
96        match maybe_migrate(&file_path) {
97            MigrationResult::Migrated(count) => {
98                info!("Migrated {} leases from SQLite to CSV", count);
99            }
100            MigrationResult::Skipped => {
101                info!("Migration skipped: CSV file already exists");
102            }
103            MigrationResult::NoDatabase => {
104                info!("No SQLite database found, starting fresh");
105            }
106            MigrationResult::Failed(err) => {
107                warn!("Migration failed: {}. Starting with empty lease store.", err);
108            }
109        }
110
111        let leases = if file_path.exists() {
112            Self::load_from_csv(&file_path)?
113        } else {
114            HashMap::new()
115        };
116
117        Ok(Self {
118            leases: Arc::new(RwLock::new(leases)),
119            file_path,
120        })
121    }
122
123    fn load_from_csv(path: &PathBuf) -> Result<HashMap<Ipv4Addr, Lease>, LeaseError> {
124        let mut reader = csv::Reader::from_path(path)?;
125        let mut leases = HashMap::new();
126        for result in reader.deserialize() {
127            let lease: Lease = result?;
128            leases.insert(lease.ip, lease);
129        }
130        Ok(leases)
131    }
132
133    async fn flush(&self) -> Result<(), LeaseError> {
134        let leases = self.leases.read().await;
135        let mut writer = csv::Writer::from_path(&self.file_path)?;
136        for lease in leases.values() {
137            writer.serialize(lease)?;
138        }
139        writer.flush()?;
140        Ok(())
141    }
142
143    pub async fn insert_lease(&self, lease: Lease) -> Result<(), LeaseError> {
144        {
145            let mut leases = self.leases.write().await;
146            leases.insert(lease.ip, lease);
147        }
148        self.flush().await
149    }
150
151    pub async fn get_lease_by_ip(&self, ip: &Ipv4Addr) -> Result<Lease, LeaseError> {
152        let leases = self.leases.read().await;
153        leases
154            .get(ip)
155            .filter(|lease| lease.leased)
156            .cloned()
157            .ok_or(LeaseError::NotFound)
158    }
159
160    pub async fn get_ip_from_client_id(&self, client_id: &Vec<u8>) -> Result<Ipv4Addr, LeaseError> {
161        let leases = self.leases.read().await;
162        leases
163            .values()
164            .find(|lease| &lease.client_id == client_id)
165            .map(|lease| lease.ip)
166            .ok_or(LeaseError::NotFound)
167    }
168
169    pub async fn get_all_leases(&self) -> Result<Vec<Lease>, LeaseError> {
170        let leases = self.leases.read().await;
171        Ok(leases.values().cloned().collect())
172    }
173
174    pub async fn get_valid_leases(&self) -> Result<Vec<Lease>, LeaseError> {
175        let leases = self.leases.read().await;
176        let now = std::time::SystemTime::now()
177            .duration_since(std::time::UNIX_EPOCH)
178            .unwrap()
179            .as_secs() as i64;
180
181        Ok(leases
182            .values()
183            .filter(|lease| lease.leased && !lease.probation && lease.expires_at > now)
184            .cloned()
185            .collect())
186    }
187
188    pub async fn is_ip_assigned(&self, ip: Ipv4Addr) -> Result<bool, LeaseError> {
189        let leases = self.leases.read().await;
190        Ok(leases
191            .get(&ip)
192            .map(|lease| lease.leased || lease.probation)
193            .unwrap_or(false))
194    }
195
196    pub async fn update_lease_expiry(&self, ip: Ipv4Addr, expires_at: i64) -> Result<(), LeaseError> {
197        {
198            let mut leases = self.leases.write().await;
199            if let Some(lease) = leases.get_mut(&ip) {
200                lease.expires_at = expires_at;
201            } else {
202                return Err(LeaseError::NotFound);
203            }
204        }
205        self.flush().await
206    }
207
208    pub async fn mark_ip_declined(&self, ip: Ipv4Addr) -> Result<(), LeaseError> {
209        {
210            let mut leases = self.leases.write().await;
211            if let Some(lease) = leases.get_mut(&ip) {
212                lease.probation = true;
213                lease.leased = false;
214            } else {
215                // Create a new entry for declined IP if it doesn't exist
216                let lease = Lease {
217                    ip,
218                    client_id: Vec::new(),
219                    leased: false,
220                    expires_at: 0,
221                    network: 0,
222                    probation: true,
223                };
224                leases.insert(ip, lease);
225            }
226        }
227        self.flush().await
228    }
229
230    pub async fn release_lease(&self, ip: Ipv4Addr, client_id: &Vec<u8>) -> Result<(), LeaseError> {
231        {
232            let mut leases = self.leases.write().await;
233            if let Some(lease) = leases.get(&ip) {
234                if &lease.client_id == client_id {
235                    leases.remove(&ip);
236                } else {
237                    return Err(LeaseError::ClientMismatch);
238                }
239            }
240        }
241        self.flush().await
242    }
243}