1use crate::error::{IoError, IoResult};
12use crate::stream::{SignalStream, StreamConfig};
13use rumqttc::{
14 AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS, TlsConfiguration, Transport,
15};
16use scirs2_core::ndarray::Array1;
17use serde::{Deserialize, Serialize};
18use std::sync::Arc;
19use std::time::Duration;
20use tokio::sync::Mutex;
21use tracing::{debug, error, info, warn};
22
23#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
25pub enum QosLevel {
26 AtMostOnce = 0,
28 #[default]
30 AtLeastOnce = 1,
31 ExactlyOnce = 2,
33}
34
35impl From<QosLevel> for QoS {
36 fn from(level: QosLevel) -> Self {
37 match level {
38 QosLevel::AtMostOnce => QoS::AtMostOnce,
39 QosLevel::AtLeastOnce => QoS::AtLeastOnce,
40 QosLevel::ExactlyOnce => QoS::ExactlyOnce,
41 }
42 }
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize, Default)]
47pub struct TlsConfig {
48 pub ca_cert_path: Option<String>,
50
51 pub client_cert_path: Option<String>,
53
54 pub client_key_path: Option<String>,
56
57 pub alpn: Option<Vec<String>>,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct MqttConfig {
64 pub host: String,
66
67 pub port: u16,
69
70 pub client_id: String,
72
73 pub topics: Vec<String>,
75
76 #[serde(default)]
78 pub qos: QosLevel,
79
80 #[serde(default = "default_keep_alive")]
82 pub keep_alive_secs: u64,
83
84 #[serde(default)]
86 pub use_tls: bool,
87
88 #[serde(default)]
90 pub tls_config: TlsConfig,
91
92 pub username: Option<String>,
94
95 pub password: Option<String>,
97
98 #[serde(default = "default_true")]
100 pub handle_retained: bool,
101
102 #[serde(default = "default_true")]
104 pub auto_reconnect: bool,
105
106 #[serde(default = "default_reconnect_delay")]
108 pub reconnect_delay_ms: u64,
109
110 #[serde(default = "default_max_reconnect_delay")]
112 pub max_reconnect_delay_ms: u64,
113
114 #[serde(default = "default_batch_size")]
116 pub batch_size: usize,
117
118 #[serde(default = "default_batch_timeout")]
120 pub batch_timeout_ms: u64,
121
122 #[serde(default = "default_true")]
124 pub clean_session: bool,
125}
126
/// Serde default for `MqttConfig::keep_alive_secs` (seconds).
const fn default_keep_alive() -> u64 {
    30
}

/// Serde default shared by the boolean flags that are on unless disabled.
const fn default_true() -> bool {
    true
}

/// Serde default for `MqttConfig::reconnect_delay_ms`.
const fn default_reconnect_delay() -> u64 {
    1_000
}

/// Serde default for `MqttConfig::max_reconnect_delay_ms`.
const fn default_max_reconnect_delay() -> u64 {
    30_000
}

/// Serde default for `MqttConfig::batch_size` (samples).
const fn default_batch_size() -> usize {
    100
}

/// Serde default for `MqttConfig::batch_timeout_ms`.
const fn default_batch_timeout() -> u64 {
    100
}
150
151impl Default for MqttConfig {
152 fn default() -> Self {
153 Self {
154 host: "localhost".into(),
155 port: 1883,
156 client_id: "kizzasi-client".into(),
157 topics: vec!["sensors/#".into()],
158 qos: QosLevel::AtLeastOnce,
159 keep_alive_secs: 30,
160 use_tls: false,
161 tls_config: TlsConfig::default(),
162 username: None,
163 password: None,
164 handle_retained: true,
165 auto_reconnect: true,
166 reconnect_delay_ms: 1000,
167 max_reconnect_delay_ms: 30000,
168 batch_size: 100,
169 batch_timeout_ms: 100,
170 clean_session: true,
171 }
172 }
173}
174
175impl MqttConfig {
176 pub fn new(host: &str, port: u16) -> Self {
178 Self {
179 host: host.into(),
180 port,
181 ..Default::default()
182 }
183 }
184
185 pub fn topics(mut self, topics: Vec<String>) -> Self {
187 self.topics = topics;
188 self
189 }
190
191 pub fn topic(mut self, topic: &str) -> Self {
193 self.topics = vec![topic.into()];
194 self
195 }
196
197 pub fn client_id(mut self, id: &str) -> Self {
199 self.client_id = id.into();
200 self
201 }
202
203 pub fn qos(mut self, qos: QosLevel) -> Self {
205 self.qos = qos;
206 self
207 }
208
209 pub fn enable_tls(mut self, tls_config: TlsConfig) -> Self {
211 self.use_tls = true;
212 self.tls_config = tls_config;
213 self
214 }
215
216 pub fn credentials(mut self, username: String, password: String) -> Self {
218 self.username = Some(username);
219 self.password = Some(password);
220 self
221 }
222}
223
224#[derive(Debug, Clone)]
226pub struct MqttMessage {
227 pub topic: String,
229
230 pub payload: Vec<u8>,
232
233 pub qos: QosLevel,
235
236 pub retained: bool,
238}
239
240pub struct MqttClient {
242 config: MqttConfig,
243 stream_config: StreamConfig,
244 buffer: Arc<Mutex<Vec<f32>>>,
245 message_buffer: Arc<Mutex<Vec<MqttMessage>>>,
246 active: Arc<Mutex<bool>>,
247 client: Option<AsyncClient>,
248}
249
250impl MqttClient {
251 pub fn new(mqtt_config: MqttConfig, stream_config: StreamConfig) -> Self {
253 Self {
254 config: mqtt_config,
255 stream_config,
256 buffer: Arc::new(Mutex::new(Vec::new())),
257 message_buffer: Arc::new(Mutex::new(Vec::new())),
258 active: Arc::new(Mutex::new(false)),
259 client: None,
260 }
261 }
262
263 pub async fn connect(&mut self) -> IoResult<()> {
265 let mut options =
266 MqttOptions::new(&self.config.client_id, &self.config.host, self.config.port);
267
268 options.set_keep_alive(Duration::from_secs(self.config.keep_alive_secs));
269 options.set_clean_session(self.config.clean_session);
270
271 if let (Some(username), Some(password)) = (&self.config.username, &self.config.password) {
273 options.set_credentials(username, password);
274 }
275
276 if self.config.use_tls {
278 if let Some(ca_path) = &self.config.tls_config.ca_cert_path {
279 let ca = std::fs::read(ca_path)
280 .map_err(|e| IoError::ConfigError(format!("Failed to read CA cert: {}", e)))?;
281
282 let client_auth = if let (Some(cert_path), Some(key_path)) = (
283 &self.config.tls_config.client_cert_path,
284 &self.config.tls_config.client_key_path,
285 ) {
286 let cert = std::fs::read(cert_path).map_err(|e| {
287 IoError::ConfigError(format!("Failed to read client cert: {}", e))
288 })?;
289 let key = std::fs::read(key_path).map_err(|e| {
290 IoError::ConfigError(format!("Failed to read client key: {}", e))
291 })?;
292 Some((cert, key))
293 } else {
294 None
295 };
296
297 let alpn = self.config.tls_config.alpn.as_ref().map(|protocols| {
298 protocols
299 .iter()
300 .map(|s| s.as_bytes().to_vec())
301 .collect::<Vec<Vec<u8>>>()
302 });
303
304 let tls_config = TlsConfiguration::Simple {
305 ca,
306 alpn,
307 client_auth,
308 };
309
310 options.set_transport(Transport::Tls(tls_config));
311 info!("MQTT TLS/SSL enabled");
312 }
313 }
314
315 let (client, eventloop) = AsyncClient::new(options, 10);
316
317 let qos: QoS = self.config.qos.into();
319 for topic in &self.config.topics {
320 client
321 .subscribe(topic, qos)
322 .await
323 .map_err(|e| IoError::ConnectionFailed(format!("Subscribe failed: {}", e)))?;
324
325 info!(
326 "MQTT subscribed to '{}' with QoS {:?}",
327 topic, self.config.qos
328 );
329 }
330
331 *self.active.lock().await = true;
333
334 self.client = Some(client.clone());
335
336 let buffer = self.buffer.clone();
337 let message_buffer = self.message_buffer.clone();
338 let active = self.active.clone();
339 let config = self.config.clone();
340
341 tokio::spawn(async move {
343 Self::event_loop_task(eventloop, buffer, message_buffer, active, config).await;
344 });
345
346 Ok(())
347 }
348
349 async fn event_loop_task(
351 mut eventloop: EventLoop,
352 buffer: Arc<Mutex<Vec<f32>>>,
353 message_buffer: Arc<Mutex<Vec<MqttMessage>>>,
354 active: Arc<Mutex<bool>>,
355 config: MqttConfig,
356 ) {
357 let mut reconnect_delay = Duration::from_millis(config.reconnect_delay_ms);
358 let max_delay = Duration::from_millis(config.max_reconnect_delay_ms);
359 let batch_timeout = Duration::from_millis(config.batch_timeout_ms);
360 let mut batch: Vec<f32> = Vec::with_capacity(config.batch_size);
361 let mut last_batch_time = tokio::time::Instant::now();
362
363 loop {
364 if !*active.lock().await {
365 break;
366 }
367
368 match eventloop.poll().await {
369 Ok(Event::Incoming(Incoming::Publish(p))) => {
370 debug!("MQTT received on '{}': {} bytes", p.topic, p.payload.len());
371
372 if p.retain && !config.handle_retained {
374 debug!("Skipping retained message");
375 continue;
376 }
377
378 let msg = MqttMessage {
380 topic: p.topic.clone(),
381 payload: p.payload.to_vec(),
382 qos: match p.qos {
383 QoS::AtMostOnce => QosLevel::AtMostOnce,
384 QoS::AtLeastOnce => QosLevel::AtLeastOnce,
385 QoS::ExactlyOnce => QosLevel::ExactlyOnce,
386 },
387 retained: p.retain,
388 };
389
390 message_buffer.lock().await.push(msg);
391
392 if let Ok(values) = serde_json::from_slice::<Vec<f32>>(&p.payload) {
394 batch.extend(values);
395
396 if batch.len() >= config.batch_size
398 || last_batch_time.elapsed() >= batch_timeout
399 {
400 let mut buf = buffer.lock().await;
401 buf.extend(batch.drain(..));
402 last_batch_time = tokio::time::Instant::now();
403 debug!("MQTT batch flushed: {} samples", buf.len());
404 }
405 }
406
407 reconnect_delay = Duration::from_millis(config.reconnect_delay_ms);
409 }
410 Ok(Event::Incoming(Incoming::ConnAck(_))) => {
411 info!("MQTT connection acknowledged");
412 }
413 Ok(Event::Incoming(Incoming::SubAck(_))) => {
414 debug!("MQTT subscription acknowledged");
415 }
416 Ok(Event::Incoming(Incoming::PingResp)) => {
417 debug!("MQTT ping response");
418 }
419 Ok(Event::Outgoing(_)) => {
420 }
422 Err(e) => {
423 error!("MQTT connection error: {}", e);
424
425 if !config.auto_reconnect {
426 *active.lock().await = false;
427 break;
428 }
429
430 warn!("Reconnecting in {:?}...", reconnect_delay);
432 tokio::time::sleep(reconnect_delay).await;
433 reconnect_delay = (reconnect_delay * 2).min(max_delay);
434 }
435 _ => {}
436 }
437 }
438
439 info!("MQTT event loop terminated");
440 }
441
442 pub async fn drain_buffer(&self) -> Vec<f32> {
444 let mut buffer = self.buffer.lock().await;
445 std::mem::take(&mut *buffer)
446 }
447
448 pub async fn drain_messages(&self) -> Vec<MqttMessage> {
450 let mut buffer = self.message_buffer.lock().await;
451 std::mem::take(&mut *buffer)
452 }
453
454 pub async fn publish(
456 &self,
457 topic: &str,
458 payload: Vec<u8>,
459 qos: QosLevel,
460 retain: bool,
461 ) -> IoResult<()> {
462 let client = self
463 .client
464 .as_ref()
465 .ok_or_else(|| IoError::ConnectionFailed("Not connected".into()))?;
466
467 client
468 .publish(topic, qos.into(), retain, payload)
469 .await
470 .map_err(|e| IoError::SendFailed(format!("Publish failed: {}", e)))?;
471
472 debug!("MQTT published to '{}' with QoS {:?}", topic, qos);
473 Ok(())
474 }
475
476 pub async fn is_connected(&self) -> bool {
478 *self.active.lock().await
479 }
480
481 pub async fn disconnect(&mut self) -> IoResult<()> {
483 *self.active.lock().await = false;
484
485 if let Some(client) = &self.client {
486 client
487 .disconnect()
488 .await
489 .map_err(|e| IoError::ConnectionFailed(format!("Disconnect failed: {}", e)))?;
490 }
491
492 self.client = None;
493 info!("MQTT disconnected");
494 Ok(())
495 }
496
497 pub fn stream_config(&self) -> &StreamConfig {
499 &self.stream_config
500 }
501
502 pub fn mqtt_config(&self) -> &MqttConfig {
504 &self.config
505 }
506}
507
508pub struct MqttStream {
510 config: StreamConfig,
511 buffer: Vec<f32>,
512 active: bool,
513}
514
515impl MqttStream {
516 pub fn new(config: StreamConfig) -> Self {
518 Self {
519 config,
520 buffer: Vec::new(),
521 active: true,
522 }
523 }
524
525 pub fn push_data(&mut self, data: Vec<f32>) {
527 self.buffer.extend(data);
528 }
529}
530
531impl SignalStream for MqttStream {
532 fn read(&mut self) -> IoResult<Array1<f32>> {
533 let size = self.config.buffer_size.min(self.buffer.len());
534 if size == 0 {
535 return Ok(Array1::zeros(self.config.buffer_size));
536 }
537
538 let data: Vec<f32> = self.buffer.drain(..size).collect();
539 let mut result = Array1::zeros(self.config.buffer_size);
540 for (i, val) in data.into_iter().enumerate() {
541 result[i] = val;
542 }
543 Ok(result)
544 }
545
546 fn is_active(&self) -> bool {
547 self.active
548 }
549
550 fn config(&self) -> &StreamConfig {
551 &self.config
552 }
553
554 fn close(&mut self) -> IoResult<()> {
555 self.active = false;
556 Ok(())
557 }
558}
559
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mqtt_config() {
        let config = MqttConfig::new("broker.example.com", 1883)
            .topic("sensors/temp")
            .client_id("test-client")
            .qos(QosLevel::ExactlyOnce);

        assert_eq!(config.host, "broker.example.com");
        assert_eq!(config.topics[0], "sensors/temp");
        assert_eq!(config.qos, QosLevel::ExactlyOnce);
    }

    #[test]
    fn test_qos_levels() {
        assert_eq!(QosLevel::AtMostOnce as u8, 0);
        assert_eq!(QosLevel::AtLeastOnce as u8, 1);
        assert_eq!(QosLevel::ExactlyOnce as u8, 2);
    }

    #[test]
    fn test_mqtt_config_credentials() {
        let config = MqttConfig::new("broker.example.com", 1883)
            .credentials("user".to_string(), "pass".to_string());

        assert_eq!(config.username, Some("user".to_string()));
        assert_eq!(config.password, Some("pass".to_string()));
    }

    /// `Default` and the serde default helpers must not drift apart.
    #[test]
    fn test_default_matches_serde_defaults() {
        let config = MqttConfig::default();

        assert_eq!(config.qos, QosLevel::default());
        assert_eq!(config.keep_alive_secs, default_keep_alive());
        assert_eq!(config.reconnect_delay_ms, default_reconnect_delay());
        assert_eq!(config.max_reconnect_delay_ms, default_max_reconnect_delay());
        assert_eq!(config.batch_size, default_batch_size());
        assert_eq!(config.batch_timeout_ms, default_batch_timeout());
        assert_eq!(config.handle_retained, default_true());
        assert_eq!(config.auto_reconnect, default_true());
        assert_eq!(config.clean_session, default_true());
    }

    /// A config file listing only the required fields behaves like `Default`.
    #[test]
    fn test_minimal_json_uses_defaults() {
        let config: MqttConfig = serde_json::from_str(
            r#"{"host":"h","port":1883,"client_id":"c","topics":["t"]}"#,
        )
        .expect("a config with only the required fields should deserialize");
        let defaults = MqttConfig::default();

        assert_eq!(config.qos, defaults.qos);
        assert_eq!(config.keep_alive_secs, defaults.keep_alive_secs);
        assert_eq!(config.use_tls, defaults.use_tls);
        assert_eq!(config.handle_retained, defaults.handle_retained);
        assert_eq!(config.auto_reconnect, defaults.auto_reconnect);
        assert_eq!(config.reconnect_delay_ms, defaults.reconnect_delay_ms);
        assert_eq!(config.max_reconnect_delay_ms, defaults.max_reconnect_delay_ms);
        assert_eq!(config.batch_size, defaults.batch_size);
        assert_eq!(config.batch_timeout_ms, defaults.batch_timeout_ms);
        assert_eq!(config.clean_session, defaults.clean_session);
        assert!(config.username.is_none());
        assert!(config.password.is_none());
    }

    #[test]
    fn test_topic_replaces_topic_list() {
        let config = MqttConfig::new("localhost", 1883)
            .topics(vec!["a/#".to_string(), "b/#".to_string()])
            .topic("c");

        assert_eq!(config.topics, vec!["c".to_string()]);
    }

    #[test]
    fn test_qos_conversion() {
        assert!(matches!(QoS::from(QosLevel::AtMostOnce), QoS::AtMostOnce));
        assert!(matches!(QoS::from(QosLevel::AtLeastOnce), QoS::AtLeastOnce));
        assert!(matches!(QoS::from(QosLevel::ExactlyOnce), QoS::ExactlyOnce));
    }
}