mockforge_mqtt/
broker.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::RwLock;
4use tracing::{info, warn};
5
6use crate::qos::QoSHandler;
7use crate::spec_registry::MqttSpecRegistry;
8use crate::topics::TopicTree;
9
/// MQTT protocol version spoken by the broker.
///
/// Defaults to [`MqttVersion::V5_0`]. Derives `PartialEq`/`Eq`/`Hash` so callers can
/// compare a configured version (`config.version == MqttVersion::V3_1_1`) or use it as a
/// map/set key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum MqttVersion {
    /// MQTT 3.1.1.
    V3_1_1,
    /// MQTT 5.0 (the default).
    #[default]
    V5_0,
}
17
18/// MQTT broker configuration
19#[derive(Debug, Clone)]
20pub struct MqttConfig {
21    pub port: u16,
22    pub host: String,
23    pub max_connections: usize,
24    pub max_packet_size: usize,
25    pub keep_alive_secs: u16,
26    pub version: MqttVersion,
27}
28
29impl Default for MqttConfig {
30    fn default() -> Self {
31        Self {
32            port: 1883,
33            host: "0.0.0.0".to_string(),
34            max_connections: 1000,
35            max_packet_size: 1024 * 1024, // 1MB
36            keep_alive_secs: 60,
37            version: MqttVersion::default(),
38        }
39    }
40}
41
/// Session state kept for a client, either while it is connected or, for non-clean
/// sessions, persisted between connections.
#[derive(Debug, Clone)]
pub struct ClientSession {
    /// Identifier the client connected with.
    pub client_id: String,
    /// Active subscriptions: topic filter -> requested QoS level.
    pub subscriptions: HashMap<String, u8>,
    /// Whether the client asked for a clean session (state discarded on disconnect).
    pub clean_session: bool,
    /// Seconds since the UNIX epoch at which the current connection was established.
    pub connected_at: u64,
    /// Seconds since the UNIX epoch at which the client was last seen.
    pub last_seen: u64,
}
51
52/// Client state for session management
53#[derive(Debug)]
54pub struct ClientState {
55    pub session: ClientSession,
56    pub pending_messages: Vec<crate::qos::MessageState>, // Messages to send when client reconnects
57}
58
59/// MQTT broker implementation
60pub struct MqttBroker {
61    config: MqttConfig,
62    topics: Arc<RwLock<TopicTree>>,
63    clients: Arc<RwLock<HashMap<String, ClientState>>>,
64    session_store: Arc<RwLock<HashMap<String, ClientSession>>>,
65    qos_handler: QoSHandler,
66    fixture_registry: Arc<RwLock<crate::fixtures::MqttFixtureRegistry>>,
67    next_packet_id: Arc<RwLock<u16>>,
68}
69
70impl MqttBroker {
71    pub fn new(config: MqttConfig, _spec_registry: Arc<MqttSpecRegistry>) -> Self {
72        Self {
73            config,
74            topics: Arc::new(RwLock::new(TopicTree::new())),
75            clients: Arc::new(RwLock::new(HashMap::new())),
76            session_store: Arc::new(RwLock::new(HashMap::new())),
77            qos_handler: QoSHandler::new(),
78            fixture_registry: Arc::new(RwLock::new(crate::fixtures::MqttFixtureRegistry::new())),
79            next_packet_id: Arc::new(RwLock::new(1)),
80        }
81    }
82
83    /// Handle client connection with session management
84    pub async fn client_connect(
85        &self,
86        client_id: &str,
87        clean_session: bool,
88    ) -> Result<(), Box<dyn std::error::Error>> {
89        let now = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
90
91        let mut clients = self.clients.write().await;
92        let mut sessions = self.session_store.write().await;
93
94        if let Some(_existing_client) = clients.get(client_id) {
95            // Client already connected - this shouldn't happen in normal operation
96            info!("Client {} already connected, updating session", client_id);
97        }
98
99        let session = if clean_session {
100            // Clean session: create new session
101            sessions.remove(client_id); // Remove any existing persistent session
102            ClientSession {
103                client_id: client_id.to_string(),
104                subscriptions: HashMap::new(),
105                clean_session: true,
106                connected_at: now,
107                last_seen: now,
108            }
109        } else {
110            // Persistent session: restore or create
111            if let Some(persistent_session) = sessions.get(client_id) {
112                let mut restored_session = persistent_session.clone();
113                restored_session.connected_at = now;
114                restored_session.last_seen = now;
115                restored_session.clean_session = false;
116                restored_session
117            } else {
118                ClientSession {
119                    client_id: client_id.to_string(),
120                    subscriptions: HashMap::new(),
121                    clean_session: false,
122                    connected_at: now,
123                    last_seen: now,
124                }
125            }
126        };
127
128        let client_state = ClientState {
129            session: session.clone(),
130            pending_messages: Vec::new(),
131        };
132
133        clients.insert(client_id.to_string(), client_state);
134
135        // Record metrics
136        // if let Some(metrics) = &self.metrics_registry {
137        //     metrics.mqtt_connections_active.inc();
138        //     metrics.mqtt_connections_total.inc();
139        // }
140
141        info!("Client {} connected with clean_session: {}", client_id, clean_session);
142        Ok(())
143    }
144
145    /// Handle client disconnection with session persistence
146    pub async fn client_disconnect(
147        &self,
148        client_id: &str,
149    ) -> Result<(), Box<dyn std::error::Error>> {
150        let now = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
151
152        let mut clients = self.clients.write().await;
153        let mut sessions = self.session_store.write().await;
154
155        if let Some(client_state) = clients.remove(client_id) {
156            let session = client_state.session;
157
158            if !session.clean_session {
159                // Persist session for non-clean sessions
160                let mut persistent_session = session.clone();
161                persistent_session.last_seen = now;
162                sessions.insert(client_id.to_string(), persistent_session);
163
164                info!("Persisted session for client {}", client_id);
165            } else {
166                // Clean up subscriptions for clean sessions
167                let mut topics = self.topics.write().await;
168                for filter in session.subscriptions.keys() {
169                    topics.unsubscribe(filter, client_id);
170                }
171
172                info!("Cleaned up session for client {}", client_id);
173            }
174        }
175
176        // Record metrics
177        // if let Some(metrics) = &self.metrics_registry {
178        //     metrics.mqtt_connections_active.dec();
179        // }
180
181        Ok(())
182    }
183
184    /// Subscribe client to topics with session persistence
185    pub async fn client_subscribe(
186        &self,
187        client_id: &str,
188        topics: Vec<(String, u8)>,
189    ) -> Result<(), Box<dyn std::error::Error>> {
190        let mut clients = self.clients.write().await;
191        let mut broker_topics = self.topics.write().await;
192
193        if let Some(client_state) = clients.get_mut(client_id) {
194            for (filter, qos) in topics {
195                broker_topics.subscribe(&filter, qos, client_id);
196                client_state.session.subscriptions.insert(filter.clone(), qos);
197
198                // Send retained messages for new subscriptions
199                let retained_messages = broker_topics.get_retained_for_filter(&filter);
200                for (topic, message) in retained_messages {
201                    info!("Sending retained message for topic {} to client {}", topic, client_id);
202                    let qos_level = crate::qos::QoS::from_u8(message.qos)
203                        .unwrap_or(crate::qos::QoS::AtMostOnce);
204                    if let Err(e) = self
205                        .route_message_to_client(client_id, topic, &message.payload, qos_level)
206                        .await
207                    {
208                        warn!("Failed to deliver retained message to client {}: {}", client_id, e);
209                    }
210                }
211            }
212
213            // Update persistent session if not clean
214            if !client_state.session.clean_session {
215                let mut sessions = self.session_store.write().await;
216                if let Some(session) = sessions.get_mut(client_id) {
217                    session.subscriptions.clone_from(&client_state.session.subscriptions);
218                }
219            }
220        }
221
222        Ok(())
223    }
224
225    /// Unsubscribe client from topics
226    pub async fn client_unsubscribe(
227        &self,
228        client_id: &str,
229        filters: Vec<String>,
230    ) -> Result<(), Box<dyn std::error::Error>> {
231        let mut clients = self.clients.write().await;
232        let mut broker_topics = self.topics.write().await;
233
234        if let Some(client_state) = clients.get_mut(client_id) {
235            for filter in filters {
236                broker_topics.unsubscribe(&filter, client_id);
237                client_state.session.subscriptions.remove(&filter);
238            }
239
240            // Update persistent session if not clean
241            if !client_state.session.clean_session {
242                let mut sessions = self.session_store.write().await;
243                if let Some(session) = sessions.get_mut(client_id) {
244                    session.subscriptions.clone_from(&client_state.session.subscriptions);
245                }
246            }
247        }
248
249        Ok(())
250    }
251
252    /// Get broker configuration (for testing)
253    pub fn config(&self) -> &MqttConfig {
254        &self.config
255    }
256
257    /// Get list of active topics (subscription filters and retained topics)
258    pub async fn get_active_topics(&self) -> Vec<String> {
259        let topics = self.topics.read().await;
260        let mut all_topics = topics.get_all_topic_filters();
261        all_topics.extend(topics.get_all_retained_topics());
262        all_topics.sort();
263        all_topics.dedup();
264        all_topics
265    }
266
267    /// Get list of connected clients
268    pub async fn get_connected_clients(&self) -> Vec<String> {
269        let clients = self.clients.read().await;
270        clients.keys().cloned().collect()
271    }
272
273    /// Get client information
274    pub async fn get_client_info(&self, client_id: &str) -> Option<ClientSession> {
275        let clients = self.clients.read().await;
276        clients.get(client_id).map(|state| state.session.clone())
277    }
278
279    /// Disconnect a client
280    pub async fn disconnect_client(
281        &self,
282        client_id: &str,
283    ) -> Result<(), Box<dyn std::error::Error>> {
284        self.client_disconnect(client_id).await
285    }
286
287    /// Get topic statistics
288    pub async fn get_topic_stats(&self) -> crate::topics::TopicStats {
289        let topics = self.topics.read().await;
290        topics.stats()
291    }
292
293    /// Generate next packet ID
294    pub async fn next_packet_id(&self) -> u16 {
295        let mut packet_id = self.next_packet_id.write().await;
296        let id = *packet_id;
297        *packet_id = packet_id.wrapping_add(1);
298        if *packet_id == 0 {
299            *packet_id = 1; // Skip 0 as it's reserved
300        }
301        id
302    }
303
304    pub async fn handle_publish(
305        &self,
306        client_id: &str,
307        topic: &str,
308        payload: Vec<u8>,
309        qos: u8,
310        retain: bool,
311    ) -> Result<(), Box<dyn std::error::Error>> {
312        self.handle_publish_internal(client_id, topic, payload, qos, retain, false)
313            .await
314    }
315
316    /// Publish a message with QoS handling but skip fixture lookup (used for fixture responses)
317    pub async fn publish_with_qos(
318        &self,
319        client_id: &str,
320        topic: &str,
321        payload: Vec<u8>,
322        qos: u8,
323        retain: bool,
324    ) -> Result<(), Box<dyn std::error::Error>> {
325        info!("Publishing with QoS to topic: {} with QoS: {}", topic, qos);
326
327        let qos_level = crate::qos::QoS::from_u8(qos).unwrap_or(crate::qos::QoS::AtMostOnce);
328
329        let packet_id = if qos_level != crate::qos::QoS::AtMostOnce {
330            self.next_packet_id().await
331        } else {
332            0 // QoS 0 doesn't use packet IDs
333        };
334
335        let message_state = crate::qos::MessageState {
336            packet_id,
337            topic: topic.to_string(),
338            payload: payload.clone(),
339            qos: qos_level,
340            retained: retain,
341            timestamp: std::time::SystemTime::now()
342                .duration_since(std::time::UNIX_EPOCH)?
343                .as_secs(),
344        };
345
346        // Handle retained messages
347        if retain {
348            let mut topics = self.topics.write().await;
349            topics.retain_message(topic, payload.clone(), qos);
350            info!("Stored retained message for topic: {}", topic);
351        }
352
353        // Handle based on QoS level
354        match qos_level {
355            crate::qos::QoS::AtMostOnce => {
356                self.qos_handler.handle_qo_s0(message_state).await?;
357            }
358            crate::qos::QoS::AtLeastOnce => {
359                self.qos_handler.handle_qo_s1(message_state, client_id).await?;
360            }
361            crate::qos::QoS::ExactlyOnce => {
362                self.qos_handler.handle_qo_s2(message_state, client_id).await?;
363            }
364        }
365
366        Ok(())
367    }
368
369    async fn handle_publish_internal(
370        &self,
371        client_id: &str,
372        topic: &str,
373        payload: Vec<u8>,
374        qos: u8,
375        retain: bool,
376        is_fixture_response: bool,
377    ) -> Result<(), Box<dyn std::error::Error>> {
378        info!("Handling publish to topic: {} with QoS: {}", topic, qos);
379
380        let qos_level = crate::qos::QoS::from_u8(qos).unwrap_or(crate::qos::QoS::AtMostOnce);
381
382        let packet_id = if qos_level != crate::qos::QoS::AtMostOnce {
383            self.next_packet_id().await
384        } else {
385            0 // QoS 0 doesn't use packet IDs
386        };
387
388        let message_state = crate::qos::MessageState {
389            packet_id,
390            topic: topic.to_string(),
391            payload: payload.clone(),
392            qos: qos_level,
393            retained: retain,
394            timestamp: std::time::SystemTime::now()
395                .duration_since(std::time::UNIX_EPOCH)?
396                .as_secs(),
397        };
398
399        // Handle retained messages
400        if retain {
401            let mut topics = self.topics.write().await;
402            topics.retain_message(topic, payload.clone(), qos);
403            info!("Stored retained message for topic: {}", topic);
404        }
405
406        // Handle based on QoS level
407        match qos_level {
408            crate::qos::QoS::AtMostOnce => {
409                self.qos_handler.handle_qo_s0(message_state).await?;
410            }
411            crate::qos::QoS::AtLeastOnce => {
412                self.qos_handler.handle_qo_s1(message_state, client_id).await?;
413            }
414            crate::qos::QoS::ExactlyOnce => {
415                self.qos_handler.handle_qo_s2(message_state, client_id).await?;
416            }
417        }
418
419        // Check if this matches any fixtures (skip if this is already a fixture response to avoid recursion)
420        if !is_fixture_response {
421            if let Some(fixture) = self.fixture_registry.read().await.find_by_topic(topic) {
422                info!("Found matching fixture: {}", fixture.identifier);
423
424                // Generate response using template engine
425                match self.generate_fixture_response(fixture, topic, &payload) {
426                    Ok(response_payload) => {
427                        info!("Generated fixture response with {} bytes", response_payload.len());
428                        // Publish the response to the same topic as the request (skip fixture lookup to avoid recursion)
429                        if let Err(e) = self
430                            .publish_with_qos(
431                                client_id,
432                                topic,
433                                response_payload,
434                                fixture.qos,
435                                fixture.retained,
436                            )
437                            .await
438                        {
439                            warn!("Failed to publish fixture response: {}", e);
440                        }
441                    }
442                    Err(e) => {
443                        warn!("Failed to generate fixture response: {}", e);
444                    }
445                }
446            }
447        }
448
449        // Route to subscribers
450        self.route_to_subscribers(topic, &payload, qos_level).await?;
451
452        // Record metrics
453        // if let Some(metrics) = &self.metrics_registry {
454        //     metrics.mqtt_messages_published_total.inc();
455        // }
456
457        Ok(())
458    }
459
460    /// Route a message to all subscribers of a topic
461    async fn route_to_subscribers(
462        &self,
463        topic: &str,
464        payload: &[u8],
465        qos: crate::qos::QoS,
466    ) -> Result<(), Box<dyn std::error::Error>> {
467        let topics_read = self.topics.read().await;
468        let subscribers = topics_read.match_topic(topic);
469        for subscriber in &subscribers {
470            info!(
471                "Routing to subscriber: {} on topic filter: {}",
472                subscriber.client_id, subscriber.filter
473            );
474            self.route_message_to_client(&subscriber.client_id, topic, payload, qos).await?;
475        }
476        Ok(())
477    }
478
479    /// Generate a response payload from a fixture using template expansion
480    fn generate_fixture_response(
481        &self,
482        fixture: &crate::fixtures::MqttFixture,
483        topic: &str,
484        received_payload: &[u8],
485    ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
486        use mockforge_core::templating;
487
488        // Create templating context with environment variables
489        let mut env_vars = std::collections::HashMap::new();
490        env_vars.insert("topic".to_string(), topic.to_string());
491
492        // Try to parse received payload as JSON and add it to context
493        if let Ok(received_json) = serde_json::from_slice::<serde_json::Value>(received_payload) {
494            env_vars.insert("payload".to_string(), received_json.to_string());
495        } else {
496            // If not JSON, add as string
497            env_vars.insert(
498                "payload".to_string(),
499                String::from_utf8_lossy(received_payload).to_string(),
500            );
501        }
502
503        let context = templating::TemplatingContext::with_env(env_vars);
504
505        // Use template engine to render payload
506        let template_str = serde_json::to_string(&fixture.response.payload)?;
507        let expanded_payload = templating::expand_str_with_context(&template_str, &context);
508
509        Ok(expanded_payload.into_bytes())
510    }
511
512    /// Route a message to a specific client
513    async fn route_message_to_client(
514        &self,
515        client_id: &str,
516        topic: &str,
517        payload: &[u8],
518        qos: crate::qos::QoS,
519    ) -> Result<(), Box<dyn std::error::Error>> {
520        // Check if client is connected
521        let clients = self.clients.read().await;
522        if let Some(client_state) = clients.get(client_id) {
523            info!("Delivering message to connected client {} on topic {}", client_id, topic);
524
525            // In a real implementation, this would send the actual MQTT PUBLISH packet to the client
526            // For the management layer, we simulate the delivery and record metrics
527
528            // Record metrics
529            // if let Some(metrics) = &self.metrics_registry {
530            //     metrics.mqtt_messages_received_total.inc();
531            // }
532
533            // Add to client's pending messages if QoS requires it
534            if qos != crate::qos::QoS::AtMostOnce {
535                let mut pending_messages = client_state.pending_messages.clone();
536                let message_state = crate::qos::MessageState {
537                    packet_id: 0, // Would be assigned by actual MQTT protocol
538                    topic: topic.to_string(),
539                    payload: payload.to_vec(),
540                    qos,
541                    retained: false,
542                    timestamp: std::time::SystemTime::now()
543                        .duration_since(std::time::UNIX_EPOCH)?
544                        .as_secs(),
545                };
546                pending_messages.push(message_state);
547
548                // Update client state (in real implementation, this would be handled by the MQTT protocol)
549                // For simulation purposes, we just log
550                info!(
551                    "Added QoS {} message to pending delivery queue for client {}",
552                    qos.as_u8(),
553                    client_id
554                );
555            }
556
557            Ok(())
558        } else {
559            warn!("Cannot route message to disconnected client: {}", client_id);
560            Err(format!("Client {} is not connected", client_id).into())
561        }
562    }
563
564    /// Update Prometheus metrics with current broker statistics
565    pub async fn update_metrics(&self) {
566        // if let Some(metrics) = &self.metrics_registry {
567        //     let connected_clients = self.get_connected_clients().await.len() as i64;
568        //     let active_topics = self.get_active_topics().await.len() as i64;
569        //     let topic_stats = self.get_topic_stats().await;
570
571        //     metrics.mqtt_connections_active.set(connected_clients);
572        //     metrics.mqtt_topics_active.set(active_topics);
573        //     metrics.mqtt_subscriptions_active.set(topic_stats.total_subscriptions as i64);
574        //     metrics.mqtt_retained_messages.set(topic_stats.retained_messages as i64);
575        // }
576    }
577}