1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::RwLock;
4use tracing::{info, warn};
5
6use crate::qos::QoSHandler;
7use crate::spec_registry::MqttSpecRegistry;
8use crate::topics::TopicTree;
9
/// MQTT protocol version setting.
///
/// Defaults to [`MqttVersion::V5_0`]. Equality and hashing are derived so the
/// configured version can be compared (`cfg.version == MqttVersion::V5_0`) or
/// used as a map key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum MqttVersion {
    /// MQTT 3.1.1.
    V3_1_1,
    /// MQTT 5.0 (the default).
    #[default]
    V5_0,
}
17
18#[derive(Debug, Clone)]
20pub struct MqttConfig {
21 pub port: u16,
22 pub host: String,
23 pub max_connections: usize,
24 pub max_packet_size: usize,
25 pub keep_alive_secs: u16,
26 pub version: MqttVersion,
27}
28
29impl Default for MqttConfig {
30 fn default() -> Self {
31 Self {
32 port: 1883,
33 host: "0.0.0.0".to_string(),
34 max_connections: 1000,
35 max_packet_size: 1024 * 1024, keep_alive_secs: 60,
37 version: MqttVersion::default(),
38 }
39 }
40}
41
/// Session data for one client, kept while connected and, for persistent
/// (non-clean) sessions, stored across disconnects.
#[derive(Debug, Clone)]
pub struct ClientSession {
    /// Identifier the client connected with.
    pub client_id: String,
    /// Subscribed topic filters mapped to the QoS requested for each.
    pub subscriptions: HashMap<String, u8>,
    /// `true`: the session is discarded on disconnect; `false`: it is persisted.
    pub clean_session: bool,
    /// Connection time in seconds since the Unix epoch.
    pub connected_at: u64,
    /// Last activity time in seconds since the Unix epoch.
    pub last_seen: u64,
}
51
52#[derive(Debug)]
54pub struct ClientState {
55 pub session: ClientSession,
56 pub pending_messages: Vec<crate::qos::MessageState>, }
58
59pub struct MqttBroker {
61 config: MqttConfig,
62 topics: Arc<RwLock<TopicTree>>,
63 clients: Arc<RwLock<HashMap<String, ClientState>>>,
64 session_store: Arc<RwLock<HashMap<String, ClientSession>>>,
65 qos_handler: QoSHandler,
66 fixture_registry: Arc<RwLock<crate::fixtures::MqttFixtureRegistry>>,
67 next_packet_id: Arc<RwLock<u16>>,
68}
69
70impl MqttBroker {
71 pub fn new(config: MqttConfig, _spec_registry: Arc<MqttSpecRegistry>) -> Self {
72 Self {
73 config,
74 topics: Arc::new(RwLock::new(TopicTree::new())),
75 clients: Arc::new(RwLock::new(HashMap::new())),
76 session_store: Arc::new(RwLock::new(HashMap::new())),
77 qos_handler: QoSHandler::new(),
78 fixture_registry: Arc::new(RwLock::new(crate::fixtures::MqttFixtureRegistry::new())),
79 next_packet_id: Arc::new(RwLock::new(1)),
80 }
81 }
82
83 pub async fn client_connect(
85 &self,
86 client_id: &str,
87 clean_session: bool,
88 ) -> Result<(), Box<dyn std::error::Error>> {
89 let now = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
90
91 let mut clients = self.clients.write().await;
92 let mut sessions = self.session_store.write().await;
93
94 if let Some(_existing_client) = clients.get(client_id) {
95 info!("Client {} already connected, updating session", client_id);
97 }
98
99 let session = if clean_session {
100 sessions.remove(client_id); ClientSession {
103 client_id: client_id.to_string(),
104 subscriptions: HashMap::new(),
105 clean_session: true,
106 connected_at: now,
107 last_seen: now,
108 }
109 } else {
110 if let Some(persistent_session) = sessions.get(client_id) {
112 let mut restored_session = persistent_session.clone();
113 restored_session.connected_at = now;
114 restored_session.last_seen = now;
115 restored_session.clean_session = false;
116 restored_session
117 } else {
118 ClientSession {
119 client_id: client_id.to_string(),
120 subscriptions: HashMap::new(),
121 clean_session: false,
122 connected_at: now,
123 last_seen: now,
124 }
125 }
126 };
127
128 let client_state = ClientState {
129 session: session.clone(),
130 pending_messages: Vec::new(),
131 };
132
133 clients.insert(client_id.to_string(), client_state);
134
135 info!("Client {} connected with clean_session: {}", client_id, clean_session);
142 Ok(())
143 }
144
145 pub async fn client_disconnect(
147 &self,
148 client_id: &str,
149 ) -> Result<(), Box<dyn std::error::Error>> {
150 let now = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
151
152 let mut clients = self.clients.write().await;
153 let mut sessions = self.session_store.write().await;
154
155 if let Some(client_state) = clients.remove(client_id) {
156 let session = client_state.session;
157
158 if !session.clean_session {
159 let mut persistent_session = session.clone();
161 persistent_session.last_seen = now;
162 sessions.insert(client_id.to_string(), persistent_session);
163
164 info!("Persisted session for client {}", client_id);
165 } else {
166 let mut topics = self.topics.write().await;
168 for filter in session.subscriptions.keys() {
169 topics.unsubscribe(filter, client_id);
170 }
171
172 info!("Cleaned up session for client {}", client_id);
173 }
174 }
175
176 Ok(())
182 }
183
184 pub async fn client_subscribe(
186 &self,
187 client_id: &str,
188 topics: Vec<(String, u8)>,
189 ) -> Result<(), Box<dyn std::error::Error>> {
190 let mut clients = self.clients.write().await;
191 let mut broker_topics = self.topics.write().await;
192
193 if let Some(client_state) = clients.get_mut(client_id) {
194 for (filter, qos) in topics {
195 broker_topics.subscribe(&filter, qos, client_id);
196 client_state.session.subscriptions.insert(filter.clone(), qos);
197
198 let retained_messages = broker_topics.get_retained_for_filter(&filter);
200 for (topic, message) in retained_messages {
201 info!("Sending retained message for topic {} to client {}", topic, client_id);
202 let qos_level = crate::qos::QoS::from_u8(message.qos)
203 .unwrap_or(crate::qos::QoS::AtMostOnce);
204 if let Err(e) = self
205 .route_message_to_client(client_id, topic, &message.payload, qos_level)
206 .await
207 {
208 warn!("Failed to deliver retained message to client {}: {}", client_id, e);
209 }
210 }
211 }
212
213 if !client_state.session.clean_session {
215 let mut sessions = self.session_store.write().await;
216 if let Some(session) = sessions.get_mut(client_id) {
217 session.subscriptions.clone_from(&client_state.session.subscriptions);
218 }
219 }
220 }
221
222 Ok(())
223 }
224
225 pub async fn client_unsubscribe(
227 &self,
228 client_id: &str,
229 filters: Vec<String>,
230 ) -> Result<(), Box<dyn std::error::Error>> {
231 let mut clients = self.clients.write().await;
232 let mut broker_topics = self.topics.write().await;
233
234 if let Some(client_state) = clients.get_mut(client_id) {
235 for filter in filters {
236 broker_topics.unsubscribe(&filter, client_id);
237 client_state.session.subscriptions.remove(&filter);
238 }
239
240 if !client_state.session.clean_session {
242 let mut sessions = self.session_store.write().await;
243 if let Some(session) = sessions.get_mut(client_id) {
244 session.subscriptions.clone_from(&client_state.session.subscriptions);
245 }
246 }
247 }
248
249 Ok(())
250 }
251
252 pub fn config(&self) -> &MqttConfig {
254 &self.config
255 }
256
257 pub async fn get_active_topics(&self) -> Vec<String> {
259 let topics = self.topics.read().await;
260 let mut all_topics = topics.get_all_topic_filters();
261 all_topics.extend(topics.get_all_retained_topics());
262 all_topics.sort();
263 all_topics.dedup();
264 all_topics
265 }
266
267 pub async fn get_connected_clients(&self) -> Vec<String> {
269 let clients = self.clients.read().await;
270 clients.keys().cloned().collect()
271 }
272
273 pub async fn get_client_info(&self, client_id: &str) -> Option<ClientSession> {
275 let clients = self.clients.read().await;
276 clients.get(client_id).map(|state| state.session.clone())
277 }
278
279 pub async fn disconnect_client(
281 &self,
282 client_id: &str,
283 ) -> Result<(), Box<dyn std::error::Error>> {
284 self.client_disconnect(client_id).await
285 }
286
287 pub async fn get_topic_stats(&self) -> crate::topics::TopicStats {
289 let topics = self.topics.read().await;
290 topics.stats()
291 }
292
293 pub async fn next_packet_id(&self) -> u16 {
295 let mut packet_id = self.next_packet_id.write().await;
296 let id = *packet_id;
297 *packet_id = packet_id.wrapping_add(1);
298 if *packet_id == 0 {
299 *packet_id = 1; }
301 id
302 }
303
304 pub async fn handle_publish(
305 &self,
306 client_id: &str,
307 topic: &str,
308 payload: Vec<u8>,
309 qos: u8,
310 retain: bool,
311 ) -> Result<(), Box<dyn std::error::Error>> {
312 self.handle_publish_internal(client_id, topic, payload, qos, retain, false)
313 .await
314 }
315
316 pub async fn publish_with_qos(
318 &self,
319 client_id: &str,
320 topic: &str,
321 payload: Vec<u8>,
322 qos: u8,
323 retain: bool,
324 ) -> Result<(), Box<dyn std::error::Error>> {
325 info!("Publishing with QoS to topic: {} with QoS: {}", topic, qos);
326
327 let qos_level = crate::qos::QoS::from_u8(qos).unwrap_or(crate::qos::QoS::AtMostOnce);
328
329 let packet_id = if qos_level != crate::qos::QoS::AtMostOnce {
330 self.next_packet_id().await
331 } else {
332 0 };
334
335 let message_state = crate::qos::MessageState {
336 packet_id,
337 topic: topic.to_string(),
338 payload: payload.clone(),
339 qos: qos_level,
340 retained: retain,
341 timestamp: std::time::SystemTime::now()
342 .duration_since(std::time::UNIX_EPOCH)?
343 .as_secs(),
344 };
345
346 if retain {
348 let mut topics = self.topics.write().await;
349 topics.retain_message(topic, payload.clone(), qos);
350 info!("Stored retained message for topic: {}", topic);
351 }
352
353 match qos_level {
355 crate::qos::QoS::AtMostOnce => {
356 self.qos_handler.handle_qo_s0(message_state).await?;
357 }
358 crate::qos::QoS::AtLeastOnce => {
359 self.qos_handler.handle_qo_s1(message_state, client_id).await?;
360 }
361 crate::qos::QoS::ExactlyOnce => {
362 self.qos_handler.handle_qo_s2(message_state, client_id).await?;
363 }
364 }
365
366 Ok(())
367 }
368
369 async fn handle_publish_internal(
370 &self,
371 client_id: &str,
372 topic: &str,
373 payload: Vec<u8>,
374 qos: u8,
375 retain: bool,
376 is_fixture_response: bool,
377 ) -> Result<(), Box<dyn std::error::Error>> {
378 info!("Handling publish to topic: {} with QoS: {}", topic, qos);
379
380 let qos_level = crate::qos::QoS::from_u8(qos).unwrap_or(crate::qos::QoS::AtMostOnce);
381
382 let packet_id = if qos_level != crate::qos::QoS::AtMostOnce {
383 self.next_packet_id().await
384 } else {
385 0 };
387
388 let message_state = crate::qos::MessageState {
389 packet_id,
390 topic: topic.to_string(),
391 payload: payload.clone(),
392 qos: qos_level,
393 retained: retain,
394 timestamp: std::time::SystemTime::now()
395 .duration_since(std::time::UNIX_EPOCH)?
396 .as_secs(),
397 };
398
399 if retain {
401 let mut topics = self.topics.write().await;
402 topics.retain_message(topic, payload.clone(), qos);
403 info!("Stored retained message for topic: {}", topic);
404 }
405
406 match qos_level {
408 crate::qos::QoS::AtMostOnce => {
409 self.qos_handler.handle_qo_s0(message_state).await?;
410 }
411 crate::qos::QoS::AtLeastOnce => {
412 self.qos_handler.handle_qo_s1(message_state, client_id).await?;
413 }
414 crate::qos::QoS::ExactlyOnce => {
415 self.qos_handler.handle_qo_s2(message_state, client_id).await?;
416 }
417 }
418
419 if !is_fixture_response {
421 if let Some(fixture) = self.fixture_registry.read().await.find_by_topic(topic) {
422 info!("Found matching fixture: {}", fixture.identifier);
423
424 match self.generate_fixture_response(fixture, topic, &payload) {
426 Ok(response_payload) => {
427 info!("Generated fixture response with {} bytes", response_payload.len());
428 if let Err(e) = self
430 .publish_with_qos(
431 client_id,
432 topic,
433 response_payload,
434 fixture.qos,
435 fixture.retained,
436 )
437 .await
438 {
439 warn!("Failed to publish fixture response: {}", e);
440 }
441 }
442 Err(e) => {
443 warn!("Failed to generate fixture response: {}", e);
444 }
445 }
446 }
447 }
448
449 self.route_to_subscribers(topic, &payload, qos_level).await?;
451
452 Ok(())
458 }
459
460 async fn route_to_subscribers(
462 &self,
463 topic: &str,
464 payload: &[u8],
465 qos: crate::qos::QoS,
466 ) -> Result<(), Box<dyn std::error::Error>> {
467 let topics_read = self.topics.read().await;
468 let subscribers = topics_read.match_topic(topic);
469 for subscriber in &subscribers {
470 info!(
471 "Routing to subscriber: {} on topic filter: {}",
472 subscriber.client_id, subscriber.filter
473 );
474 self.route_message_to_client(&subscriber.client_id, topic, payload, qos).await?;
475 }
476 Ok(())
477 }
478
479 fn generate_fixture_response(
481 &self,
482 fixture: &crate::fixtures::MqttFixture,
483 topic: &str,
484 received_payload: &[u8],
485 ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
486 use mockforge_core::templating;
487
488 let mut env_vars = std::collections::HashMap::new();
490 env_vars.insert("topic".to_string(), topic.to_string());
491
492 if let Ok(received_json) = serde_json::from_slice::<serde_json::Value>(received_payload) {
494 env_vars.insert("payload".to_string(), received_json.to_string());
495 } else {
496 env_vars.insert(
498 "payload".to_string(),
499 String::from_utf8_lossy(received_payload).to_string(),
500 );
501 }
502
503 let context = templating::TemplatingContext::with_env(env_vars);
504
505 let template_str = serde_json::to_string(&fixture.response.payload)?;
507 let expanded_payload = templating::expand_str_with_context(&template_str, &context);
508
509 Ok(expanded_payload.into_bytes())
510 }
511
512 async fn route_message_to_client(
514 &self,
515 client_id: &str,
516 topic: &str,
517 payload: &[u8],
518 qos: crate::qos::QoS,
519 ) -> Result<(), Box<dyn std::error::Error>> {
520 let clients = self.clients.read().await;
522 if let Some(client_state) = clients.get(client_id) {
523 info!("Delivering message to connected client {} on topic {}", client_id, topic);
524
525 if qos != crate::qos::QoS::AtMostOnce {
535 let mut pending_messages = client_state.pending_messages.clone();
536 let message_state = crate::qos::MessageState {
537 packet_id: 0, topic: topic.to_string(),
539 payload: payload.to_vec(),
540 qos,
541 retained: false,
542 timestamp: std::time::SystemTime::now()
543 .duration_since(std::time::UNIX_EPOCH)?
544 .as_secs(),
545 };
546 pending_messages.push(message_state);
547
548 info!(
551 "Added QoS {} message to pending delivery queue for client {}",
552 qos.as_u8(),
553 client_id
554 );
555 }
556
557 Ok(())
558 } else {
559 warn!("Cannot route message to disconnected client: {}", client_id);
560 Err(format!("Client {} is not connected", client_id).into())
561 }
562 }
563
564 pub async fn update_metrics(&self) {
566 }
577}