1use crate::migration::{maybe_migrate, MigrationResult};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::net::Ipv4Addr;
5use std::path::PathBuf;
6use std::sync::Arc;
7use thiserror::Error;
8use tokio::sync::RwLock;
9use tracing::{info, warn};
10
11#[derive(Error, Debug)]
12pub enum LeaseError {
13 #[error("Lease not found")]
14 NotFound,
15
16 #[error("Client ID mismatch")]
17 ClientMismatch,
18
19 #[error("CSV error: {0}")]
20 CsvError(#[from] csv::Error),
21
22 #[error("IO error: {0}")]
23 IoError(#[from] std::io::Error),
24}
25
/// A single lease record, as held in memory and as one row of the CSV lease file.
///
/// Field order matters: serde serializes fields in declaration order, so it determines
/// the CSV column order written to disk. Do not reorder fields casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Lease {
    /// Leased IPv4 address; stored in the CSV as a dotted-quad string.
    #[serde(with = "ipv4_serde")]
    pub ip: Ipv4Addr,
    /// Client identifier bytes; stored in the CSV as colon-separated lowercase hex.
    /// Empty for the placeholder records that `LeaseStore::mark_ip_declined` creates.
    #[serde(with = "mac_serde")]
    pub client_id: Vec<u8>,
    /// Whether the address is currently leased to `client_id`.
    pub leased: bool,
    /// Expiry time, compared against seconds since the Unix epoch by
    /// `LeaseStore::get_valid_leases`.
    pub expires_at: i64,
    /// NOTE(review): meaning is not evident from this file (declined placeholders set it
    /// to 0). Confirm against callers.
    pub network: i64,
    /// Set when the address has been declined (see `LeaseStore::mark_ip_declined`);
    /// such entries count as assigned but are excluded from valid leases.
    pub probation: bool,
}
37
38mod ipv4_serde {
39 use serde::{Deserialize, Deserializer, Serializer};
40 use std::net::Ipv4Addr;
41
42 pub fn serialize<S>(ip: &Ipv4Addr, serializer: S) -> Result<S::Ok, S::Error>
43 where
44 S: Serializer,
45 {
46 serializer.serialize_str(&ip.to_string())
47 }
48
49 pub fn deserialize<'de, D>(deserializer: D) -> Result<Ipv4Addr, D::Error>
50 where
51 D: Deserializer<'de>,
52 {
53 let s = String::deserialize(deserializer)?;
54 s.parse().map_err(serde::de::Error::custom)
55 }
56}
57
58mod mac_serde {
59 use serde::{Deserialize, Deserializer, Serializer};
60
61 pub fn serialize<S>(bytes: &Vec<u8>, serializer: S) -> Result<S::Ok, S::Error>
62 where
63 S: Serializer,
64 {
65 let hex_str: String = bytes
66 .iter()
67 .map(|b| format!("{:02x}", b))
68 .collect::<Vec<_>>()
69 .join(":");
70 serializer.serialize_str(&hex_str)
71 }
72
73 pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
74 where
75 D: Deserializer<'de>,
76 {
77 let s = String::deserialize(deserializer)?;
78 if s.is_empty() {
79 return Ok(Vec::new());
80 }
81 s.split(':')
82 .map(|part| u8::from_str_radix(part, 16).map_err(serde::de::Error::custom))
83 .collect()
84 }
85}
86
87#[derive(Clone)]
88pub struct LeaseStore {
89 leases: Arc<RwLock<HashMap<Ipv4Addr, Lease>>>,
90 file_path: PathBuf,
91}
92
93impl LeaseStore {
94 pub async fn new(file_path: PathBuf) -> Result<Self, LeaseError> {
95 match maybe_migrate(&file_path) {
97 MigrationResult::Migrated(count) => {
98 info!("Migrated {} leases from SQLite to CSV", count);
99 }
100 MigrationResult::Skipped => {
101 info!("Migration skipped: CSV file already exists");
102 }
103 MigrationResult::NoDatabase => {
104 info!("No SQLite database found, starting fresh");
105 }
106 MigrationResult::Failed(err) => {
107 warn!("Migration failed: {}. Starting with empty lease store.", err);
108 }
109 }
110
111 let leases = if file_path.exists() {
112 Self::load_from_csv(&file_path)?
113 } else {
114 HashMap::new()
115 };
116
117 Ok(Self {
118 leases: Arc::new(RwLock::new(leases)),
119 file_path,
120 })
121 }
122
123 fn load_from_csv(path: &PathBuf) -> Result<HashMap<Ipv4Addr, Lease>, LeaseError> {
124 let mut reader = csv::Reader::from_path(path)?;
125 let mut leases = HashMap::new();
126 for result in reader.deserialize() {
127 let lease: Lease = result?;
128 leases.insert(lease.ip, lease);
129 }
130 Ok(leases)
131 }
132
133 async fn flush(&self) -> Result<(), LeaseError> {
134 let leases = self.leases.read().await;
135 let mut writer = csv::Writer::from_path(&self.file_path)?;
136 for lease in leases.values() {
137 writer.serialize(lease)?;
138 }
139 writer.flush()?;
140 Ok(())
141 }
142
143 pub async fn insert_lease(&self, lease: Lease) -> Result<(), LeaseError> {
144 {
145 let mut leases = self.leases.write().await;
146 leases.insert(lease.ip, lease);
147 }
148 self.flush().await
149 }
150
151 pub async fn get_lease_by_ip(&self, ip: &Ipv4Addr) -> Result<Lease, LeaseError> {
152 let leases = self.leases.read().await;
153 leases
154 .get(ip)
155 .filter(|lease| lease.leased)
156 .cloned()
157 .ok_or(LeaseError::NotFound)
158 }
159
160 pub async fn get_ip_from_client_id(&self, client_id: &Vec<u8>) -> Result<Ipv4Addr, LeaseError> {
161 let leases = self.leases.read().await;
162 leases
163 .values()
164 .find(|lease| &lease.client_id == client_id)
165 .map(|lease| lease.ip)
166 .ok_or(LeaseError::NotFound)
167 }
168
169 pub async fn get_all_leases(&self) -> Result<Vec<Lease>, LeaseError> {
170 let leases = self.leases.read().await;
171 Ok(leases.values().cloned().collect())
172 }
173
174 pub async fn get_valid_leases(&self) -> Result<Vec<Lease>, LeaseError> {
175 let leases = self.leases.read().await;
176 let now = std::time::SystemTime::now()
177 .duration_since(std::time::UNIX_EPOCH)
178 .unwrap()
179 .as_secs() as i64;
180
181 Ok(leases
182 .values()
183 .filter(|lease| lease.leased && !lease.probation && lease.expires_at > now)
184 .cloned()
185 .collect())
186 }
187
188 pub async fn is_ip_assigned(&self, ip: Ipv4Addr) -> Result<bool, LeaseError> {
189 let leases = self.leases.read().await;
190 Ok(leases
191 .get(&ip)
192 .map(|lease| lease.leased || lease.probation)
193 .unwrap_or(false))
194 }
195
196 pub async fn update_lease_expiry(&self, ip: Ipv4Addr, expires_at: i64) -> Result<(), LeaseError> {
197 {
198 let mut leases = self.leases.write().await;
199 if let Some(lease) = leases.get_mut(&ip) {
200 lease.expires_at = expires_at;
201 } else {
202 return Err(LeaseError::NotFound);
203 }
204 }
205 self.flush().await
206 }
207
208 pub async fn mark_ip_declined(&self, ip: Ipv4Addr) -> Result<(), LeaseError> {
209 {
210 let mut leases = self.leases.write().await;
211 if let Some(lease) = leases.get_mut(&ip) {
212 lease.probation = true;
213 lease.leased = false;
214 } else {
215 let lease = Lease {
217 ip,
218 client_id: Vec::new(),
219 leased: false,
220 expires_at: 0,
221 network: 0,
222 probation: true,
223 };
224 leases.insert(ip, lease);
225 }
226 }
227 self.flush().await
228 }
229
230 pub async fn release_lease(&self, ip: Ipv4Addr, client_id: &Vec<u8>) -> Result<(), LeaseError> {
231 {
232 let mut leases = self.leases.write().await;
233 if let Some(lease) = leases.get(&ip) {
234 if &lease.client_id == client_id {
235 leases.remove(&ip);
236 } else {
237 return Err(LeaseError::ClientMismatch);
238 }
239 }
240 }
241 self.flush().await
242 }
243}