kizzasi_io/
mqtt.rs

//! MQTT client for IoT/industrial sensor connectivity.
//!
//! Provides an asynchronous MQTT client with:
//! - TLS/SSL support
//! - QoS levels (0, 1, 2)
//! - Retained-message handling
//! - Wildcard topic subscriptions (`+`, `#`)
//! - Auto-reconnection with exponential backoff
//! - Batching of numeric samples
10
11use crate::error::{IoError, IoResult};
12use crate::stream::{SignalStream, StreamConfig};
13use rumqttc::{
14    AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS, TlsConfiguration, Transport,
15};
16use scirs2_core::ndarray::Array1;
17use serde::{Deserialize, Serialize};
18use std::sync::Arc;
19use std::time::Duration;
20use tokio::sync::Mutex;
21use tracing::{debug, error, info, warn};
22
23/// MQTT Quality of Service level
24#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
25pub enum QosLevel {
26    /// At most once delivery (fire and forget)
27    AtMostOnce = 0,
28    /// At least once delivery (acknowledged)
29    #[default]
30    AtLeastOnce = 1,
31    /// Exactly once delivery (assured)
32    ExactlyOnce = 2,
33}
34
35impl From<QosLevel> for QoS {
36    fn from(level: QosLevel) -> Self {
37        match level {
38            QosLevel::AtMostOnce => QoS::AtMostOnce,
39            QosLevel::AtLeastOnce => QoS::AtLeastOnce,
40            QosLevel::ExactlyOnce => QoS::ExactlyOnce,
41        }
42    }
43}
44
45/// TLS/SSL configuration
46#[derive(Debug, Clone, Serialize, Deserialize, Default)]
47pub struct TlsConfig {
48    /// Path to CA certificate file
49    pub ca_cert_path: Option<String>,
50
51    /// Path to client certificate file
52    pub client_cert_path: Option<String>,
53
54    /// Path to client key file
55    pub client_key_path: Option<String>,
56
57    /// ALPN protocols
58    pub alpn: Option<Vec<String>>,
59}
60
61/// Configuration for MQTT connection
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct MqttConfig {
64    /// MQTT broker host
65    pub host: String,
66
67    /// MQTT broker port
68    pub port: u16,
69
70    /// Client ID
71    pub client_id: String,
72
73    /// Topics to subscribe to (supports wildcards: +, #)
74    pub topics: Vec<String>,
75
76    /// QoS level
77    #[serde(default)]
78    pub qos: QosLevel,
79
80    /// Keep alive interval in seconds
81    #[serde(default = "default_keep_alive")]
82    pub keep_alive_secs: u64,
83
84    /// Enable TLS/SSL
85    #[serde(default)]
86    pub use_tls: bool,
87
88    /// TLS configuration
89    #[serde(default)]
90    pub tls_config: TlsConfig,
91
92    /// Username for authentication
93    pub username: Option<String>,
94
95    /// Password for authentication
96    pub password: Option<String>,
97
98    /// Enable retained message handling
99    #[serde(default = "default_true")]
100    pub handle_retained: bool,
101
102    /// Enable auto-reconnection
103    #[serde(default = "default_true")]
104    pub auto_reconnect: bool,
105
106    /// Reconnection delay (ms)
107    #[serde(default = "default_reconnect_delay")]
108    pub reconnect_delay_ms: u64,
109
110    /// Maximum reconnection delay (ms)
111    #[serde(default = "default_max_reconnect_delay")]
112    pub max_reconnect_delay_ms: u64,
113
114    /// Message batch size
115    #[serde(default = "default_batch_size")]
116    pub batch_size: usize,
117
118    /// Batch timeout (ms)
119    #[serde(default = "default_batch_timeout")]
120    pub batch_timeout_ms: u64,
121
122    /// Clean session flag
123    #[serde(default = "default_true")]
124    pub clean_session: bool,
125}
126
/// Serde default for `keep_alive_secs`: thirty seconds.
fn default_keep_alive() -> u64 {
    30
}

/// Serde default for boolean switches that are on unless configured otherwise.
fn default_true() -> bool {
    true
}

/// Serde default for `reconnect_delay_ms`: one second of initial back-off.
fn default_reconnect_delay() -> u64 {
    1_000
}

/// Serde default for `max_reconnect_delay_ms`: back-off capped at thirty seconds.
fn default_max_reconnect_delay() -> u64 {
    30_000
}

/// Serde default for `batch_size`, in samples.
fn default_batch_size() -> usize {
    100
}

/// Serde default for `batch_timeout_ms`, in milliseconds.
fn default_batch_timeout() -> u64 {
    100
}
150
151impl Default for MqttConfig {
152    fn default() -> Self {
153        Self {
154            host: "localhost".into(),
155            port: 1883,
156            client_id: "kizzasi-client".into(),
157            topics: vec!["sensors/#".into()],
158            qos: QosLevel::AtLeastOnce,
159            keep_alive_secs: 30,
160            use_tls: false,
161            tls_config: TlsConfig::default(),
162            username: None,
163            password: None,
164            handle_retained: true,
165            auto_reconnect: true,
166            reconnect_delay_ms: 1000,
167            max_reconnect_delay_ms: 30000,
168            batch_size: 100,
169            batch_timeout_ms: 100,
170            clean_session: true,
171        }
172    }
173}
174
175impl MqttConfig {
176    /// Create a new MQTT configuration
177    pub fn new(host: &str, port: u16) -> Self {
178        Self {
179            host: host.into(),
180            port,
181            ..Default::default()
182        }
183    }
184
185    /// Set topics (supports wildcards)
186    pub fn topics(mut self, topics: Vec<String>) -> Self {
187        self.topics = topics;
188        self
189    }
190
191    /// Set a single topic
192    pub fn topic(mut self, topic: &str) -> Self {
193        self.topics = vec![topic.into()];
194        self
195    }
196
197    /// Set the client ID
198    pub fn client_id(mut self, id: &str) -> Self {
199        self.client_id = id.into();
200        self
201    }
202
203    /// Set QoS level
204    pub fn qos(mut self, qos: QosLevel) -> Self {
205        self.qos = qos;
206        self
207    }
208
209    /// Enable TLS/SSL
210    pub fn enable_tls(mut self, tls_config: TlsConfig) -> Self {
211        self.use_tls = true;
212        self.tls_config = tls_config;
213        self
214    }
215
216    /// Set credentials
217    pub fn credentials(mut self, username: String, password: String) -> Self {
218        self.username = Some(username);
219        self.password = Some(password);
220        self
221    }
222}
223
224/// MQTT message
225#[derive(Debug, Clone)]
226pub struct MqttMessage {
227    /// Topic
228    pub topic: String,
229
230    /// Payload
231    pub payload: Vec<u8>,
232
233    /// QoS level
234    pub qos: QosLevel,
235
236    /// Retained flag
237    pub retained: bool,
238}
239
240/// MQTT client for receiving sensor data
241pub struct MqttClient {
242    config: MqttConfig,
243    stream_config: StreamConfig,
244    buffer: Arc<Mutex<Vec<f32>>>,
245    message_buffer: Arc<Mutex<Vec<MqttMessage>>>,
246    active: Arc<Mutex<bool>>,
247    client: Option<AsyncClient>,
248}
249
250impl MqttClient {
251    /// Create a new MQTT client
252    pub fn new(mqtt_config: MqttConfig, stream_config: StreamConfig) -> Self {
253        Self {
254            config: mqtt_config,
255            stream_config,
256            buffer: Arc::new(Mutex::new(Vec::new())),
257            message_buffer: Arc::new(Mutex::new(Vec::new())),
258            active: Arc::new(Mutex::new(false)),
259            client: None,
260        }
261    }
262
263    /// Connect to the MQTT broker and start receiving messages
264    pub async fn connect(&mut self) -> IoResult<()> {
265        let mut options =
266            MqttOptions::new(&self.config.client_id, &self.config.host, self.config.port);
267
268        options.set_keep_alive(Duration::from_secs(self.config.keep_alive_secs));
269        options.set_clean_session(self.config.clean_session);
270
271        // Set credentials if provided
272        if let (Some(username), Some(password)) = (&self.config.username, &self.config.password) {
273            options.set_credentials(username, password);
274        }
275
276        // Configure TLS if enabled
277        if self.config.use_tls {
278            if let Some(ca_path) = &self.config.tls_config.ca_cert_path {
279                let ca = std::fs::read(ca_path)
280                    .map_err(|e| IoError::ConfigError(format!("Failed to read CA cert: {}", e)))?;
281
282                let client_auth = if let (Some(cert_path), Some(key_path)) = (
283                    &self.config.tls_config.client_cert_path,
284                    &self.config.tls_config.client_key_path,
285                ) {
286                    let cert = std::fs::read(cert_path).map_err(|e| {
287                        IoError::ConfigError(format!("Failed to read client cert: {}", e))
288                    })?;
289                    let key = std::fs::read(key_path).map_err(|e| {
290                        IoError::ConfigError(format!("Failed to read client key: {}", e))
291                    })?;
292                    Some((cert, key))
293                } else {
294                    None
295                };
296
297                let alpn = self.config.tls_config.alpn.as_ref().map(|protocols| {
298                    protocols
299                        .iter()
300                        .map(|s| s.as_bytes().to_vec())
301                        .collect::<Vec<Vec<u8>>>()
302                });
303
304                let tls_config = TlsConfiguration::Simple {
305                    ca,
306                    alpn,
307                    client_auth,
308                };
309
310                options.set_transport(Transport::Tls(tls_config));
311                info!("MQTT TLS/SSL enabled");
312            }
313        }
314
315        let (client, eventloop) = AsyncClient::new(options, 10);
316
317        // Subscribe to all topics
318        let qos: QoS = self.config.qos.into();
319        for topic in &self.config.topics {
320            client
321                .subscribe(topic, qos)
322                .await
323                .map_err(|e| IoError::ConnectionFailed(format!("Subscribe failed: {}", e)))?;
324
325            info!(
326                "MQTT subscribed to '{}' with QoS {:?}",
327                topic, self.config.qos
328            );
329        }
330
331        // Mark as active
332        *self.active.lock().await = true;
333
334        self.client = Some(client.clone());
335
336        let buffer = self.buffer.clone();
337        let message_buffer = self.message_buffer.clone();
338        let active = self.active.clone();
339        let config = self.config.clone();
340
341        // Spawn event loop handler with reconnection
342        tokio::spawn(async move {
343            Self::event_loop_task(eventloop, buffer, message_buffer, active, config).await;
344        });
345
346        Ok(())
347    }
348
349    /// Event loop task with auto-reconnection
350    async fn event_loop_task(
351        mut eventloop: EventLoop,
352        buffer: Arc<Mutex<Vec<f32>>>,
353        message_buffer: Arc<Mutex<Vec<MqttMessage>>>,
354        active: Arc<Mutex<bool>>,
355        config: MqttConfig,
356    ) {
357        let mut reconnect_delay = Duration::from_millis(config.reconnect_delay_ms);
358        let max_delay = Duration::from_millis(config.max_reconnect_delay_ms);
359        let batch_timeout = Duration::from_millis(config.batch_timeout_ms);
360        let mut batch: Vec<f32> = Vec::with_capacity(config.batch_size);
361        let mut last_batch_time = tokio::time::Instant::now();
362
363        loop {
364            if !*active.lock().await {
365                break;
366            }
367
368            match eventloop.poll().await {
369                Ok(Event::Incoming(Incoming::Publish(p))) => {
370                    debug!("MQTT received on '{}': {} bytes", p.topic, p.payload.len());
371
372                    // Handle retained messages
373                    if p.retain && !config.handle_retained {
374                        debug!("Skipping retained message");
375                        continue;
376                    }
377
378                    // Store raw message
379                    let msg = MqttMessage {
380                        topic: p.topic.clone(),
381                        payload: p.payload.to_vec(),
382                        qos: match p.qos {
383                            QoS::AtMostOnce => QosLevel::AtMostOnce,
384                            QoS::AtLeastOnce => QosLevel::AtLeastOnce,
385                            QoS::ExactlyOnce => QosLevel::ExactlyOnce,
386                        },
387                        retained: p.retain,
388                    };
389
390                    message_buffer.lock().await.push(msg);
391
392                    // Try to parse payload as JSON array of floats
393                    if let Ok(values) = serde_json::from_slice::<Vec<f32>>(&p.payload) {
394                        batch.extend(values);
395
396                        // Flush batch if full or timeout
397                        if batch.len() >= config.batch_size
398                            || last_batch_time.elapsed() >= batch_timeout
399                        {
400                            let mut buf = buffer.lock().await;
401                            buf.extend(batch.drain(..));
402                            last_batch_time = tokio::time::Instant::now();
403                            debug!("MQTT batch flushed: {} samples", buf.len());
404                        }
405                    }
406
407                    // Reset reconnect delay on successful message
408                    reconnect_delay = Duration::from_millis(config.reconnect_delay_ms);
409                }
410                Ok(Event::Incoming(Incoming::ConnAck(_))) => {
411                    info!("MQTT connection acknowledged");
412                }
413                Ok(Event::Incoming(Incoming::SubAck(_))) => {
414                    debug!("MQTT subscription acknowledged");
415                }
416                Ok(Event::Incoming(Incoming::PingResp)) => {
417                    debug!("MQTT ping response");
418                }
419                Ok(Event::Outgoing(_)) => {
420                    // Outgoing events don't need handling
421                }
422                Err(e) => {
423                    error!("MQTT connection error: {}", e);
424
425                    if !config.auto_reconnect {
426                        *active.lock().await = false;
427                        break;
428                    }
429
430                    // Exponential backoff
431                    warn!("Reconnecting in {:?}...", reconnect_delay);
432                    tokio::time::sleep(reconnect_delay).await;
433                    reconnect_delay = (reconnect_delay * 2).min(max_delay);
434                }
435                _ => {}
436            }
437        }
438
439        info!("MQTT event loop terminated");
440    }
441
442    /// Get the current buffer contents
443    pub async fn drain_buffer(&self) -> Vec<f32> {
444        let mut buffer = self.buffer.lock().await;
445        std::mem::take(&mut *buffer)
446    }
447
448    /// Get buffered messages
449    pub async fn drain_messages(&self) -> Vec<MqttMessage> {
450        let mut buffer = self.message_buffer.lock().await;
451        std::mem::take(&mut *buffer)
452    }
453
454    /// Publish a message
455    pub async fn publish(
456        &self,
457        topic: &str,
458        payload: Vec<u8>,
459        qos: QosLevel,
460        retain: bool,
461    ) -> IoResult<()> {
462        let client = self
463            .client
464            .as_ref()
465            .ok_or_else(|| IoError::ConnectionFailed("Not connected".into()))?;
466
467        client
468            .publish(topic, qos.into(), retain, payload)
469            .await
470            .map_err(|e| IoError::SendFailed(format!("Publish failed: {}", e)))?;
471
472        debug!("MQTT published to '{}' with QoS {:?}", topic, qos);
473        Ok(())
474    }
475
476    /// Check if client is connected
477    pub async fn is_connected(&self) -> bool {
478        *self.active.lock().await
479    }
480
481    /// Disconnect from the broker
482    pub async fn disconnect(&mut self) -> IoResult<()> {
483        *self.active.lock().await = false;
484
485        if let Some(client) = &self.client {
486            client
487                .disconnect()
488                .await
489                .map_err(|e| IoError::ConnectionFailed(format!("Disconnect failed: {}", e)))?;
490        }
491
492        self.client = None;
493        info!("MQTT disconnected");
494        Ok(())
495    }
496
497    /// Get the stream config
498    pub fn stream_config(&self) -> &StreamConfig {
499        &self.stream_config
500    }
501
502    /// Get the MQTT config
503    pub fn mqtt_config(&self) -> &MqttConfig {
504        &self.config
505    }
506}
507
508/// Synchronous wrapper for MqttClient as SignalStream
509pub struct MqttStream {
510    config: StreamConfig,
511    buffer: Vec<f32>,
512    active: bool,
513}
514
515impl MqttStream {
516    /// Create a new MQTT stream (placeholder for sync API)
517    pub fn new(config: StreamConfig) -> Self {
518        Self {
519            config,
520            buffer: Vec::new(),
521            active: true,
522        }
523    }
524
525    /// Push data into the buffer (called from async context)
526    pub fn push_data(&mut self, data: Vec<f32>) {
527        self.buffer.extend(data);
528    }
529}
530
531impl SignalStream for MqttStream {
532    fn read(&mut self) -> IoResult<Array1<f32>> {
533        let size = self.config.buffer_size.min(self.buffer.len());
534        if size == 0 {
535            return Ok(Array1::zeros(self.config.buffer_size));
536        }
537
538        let data: Vec<f32> = self.buffer.drain(..size).collect();
539        let mut result = Array1::zeros(self.config.buffer_size);
540        for (i, val) in data.into_iter().enumerate() {
541            result[i] = val;
542        }
543        Ok(result)
544    }
545
546    fn is_active(&self) -> bool {
547        self.active
548    }
549
550    fn config(&self) -> &StreamConfig {
551        &self.config
552    }
553
554    fn close(&mut self) -> IoResult<()> {
555        self.active = false;
556        Ok(())
557    }
558}
559
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mqtt_config() {
        let config = MqttConfig::new("broker.example.com", 1883)
            .topic("sensors/temp")
            .client_id("test-client")
            .qos(QosLevel::ExactlyOnce);

        assert_eq!(config.host, "broker.example.com");
        assert_eq!(config.topics[0], "sensors/temp");
        assert_eq!(config.qos, QosLevel::ExactlyOnce);
    }

    #[test]
    fn test_qos_levels() {
        assert_eq!(QosLevel::AtMostOnce as u8, 0);
        assert_eq!(QosLevel::AtLeastOnce as u8, 1);
        assert_eq!(QosLevel::ExactlyOnce as u8, 2);
    }

    #[test]
    fn test_mqtt_config_credentials() {
        let config = MqttConfig::new("broker.example.com", 1883)
            .credentials("user".to_string(), "pass".to_string());

        assert_eq!(config.username, Some("user".to_string()));
        assert_eq!(config.password, Some("pass".to_string()));
    }

    /// `MqttConfig::default()` must agree with the serde field defaults.
    #[test]
    fn test_default_matches_serde_defaults() {
        let config = MqttConfig::default();

        assert_eq!(config.qos, QosLevel::default());
        assert_eq!(config.keep_alive_secs, default_keep_alive());
        assert_eq!(config.handle_retained, default_true());
        assert_eq!(config.auto_reconnect, default_true());
        assert_eq!(config.clean_session, default_true());
        assert_eq!(config.reconnect_delay_ms, default_reconnect_delay());
        assert_eq!(config.max_reconnect_delay_ms, default_max_reconnect_delay());
        assert_eq!(config.batch_size, default_batch_size());
        assert_eq!(config.batch_timeout_ms, default_batch_timeout());
    }

    /// Omitted optional fields fall back to their defaults when deserializing.
    #[test]
    fn test_deserialize_minimal_config() {
        let json = r#"{"host":"broker","port":8883,"client_id":"c1","topics":["sensors/+/temp"]}"#;
        let config: MqttConfig = serde_json::from_str(json).expect("minimal config must parse");

        assert_eq!(config.host, "broker");
        assert_eq!(config.port, 8883);
        assert_eq!(config.topics, vec!["sensors/+/temp".to_string()]);
        assert_eq!(config.qos, QosLevel::AtLeastOnce);
        assert!(!config.use_tls);
        assert!(config.tls_config.ca_cert_path.is_none());
        assert!(config.username.is_none());
        assert!(config.password.is_none());
        assert_eq!(config.keep_alive_secs, 30);
        assert_eq!(config.batch_size, 100);
        assert!(config.clean_session);
    }

    #[test]
    fn test_qos_conversion() {
        assert_eq!(QoS::from(QosLevel::AtMostOnce), QoS::AtMostOnce);
        assert_eq!(QoS::from(QosLevel::AtLeastOnce), QoS::AtLeastOnce);
        assert_eq!(QoS::from(QosLevel::ExactlyOnce), QoS::ExactlyOnce);
    }
}