Skip to main content

agentlink_native/
mqtt.rs

//! Native MQTT client implementation.
//!
//! Uses rumqttc for MQTT connections over plain TCP (`mqtt://`) and TLS (`mqtts://`).
4
5use std::sync::Arc;
6use std::time::Duration;
7
8use async_trait::async_trait;
9use agentlink_core::mqtt::{
10    MqttClient, MqttConfig, MqttConnectionState, MqttEvent, MqttMessage, MqttQoS,
11};
12use agentlink_core::error::SdkResult;
13use rumqttc::{AsyncClient, Event as MqttEventLoopEvent, EventLoop, Incoming, MqttOptions, QoS, Transport};
14use tokio::sync::Mutex;
15use tokio::task::JoinHandle;
16
17/// Native MQTT client using rumqttc
18pub struct NativeMqttClient {
19    client: Arc<Mutex<Option<AsyncClient>>>,
20    state: Arc<Mutex<MqttConnectionState>>,
21    event_callback: Arc<Mutex<Option<Box<dyn Fn(MqttEvent) + Send + Sync>>>>,
22    event_loop_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
23}
24
25impl NativeMqttClient {
26    pub fn new() -> Self {
27        Self {
28            client: Arc::new(Mutex::new(None)),
29            state: Arc::new(Mutex::new(MqttConnectionState::Disconnected)),
30            event_callback: Arc::new(Mutex::new(None)),
31            event_loop_handle: Arc::new(Mutex::new(None)),
32        }
33    }
34
35    fn parse_broker_url(url: &str) -> SdkResult<(String, u16, Option<Transport>)> {
36        let is_tls = url.starts_with("mqtts://");
37        let is_tcp = url.starts_with("mqtt://");
38
39        if !is_tls && !is_tcp {
40            return Err(agentlink_core::error::SdkError::Config(
41                format!("Unsupported MQTT protocol: {}", url)
42            ));
43        }
44
45        let url_part = if is_tls {
46            &url[8..]
47        } else {
48            &url[7..]
49        };
50
51        let parts: Vec<&str> = url_part.split('/').next().unwrap().split(':').collect();
52        let host = parts[0];
53
54        let default_port = if is_tls { 8883 } else { 1883 };
55        let port = parts.get(1)
56            .and_then(|p| p.parse::<u16>().ok())
57            .unwrap_or(default_port);
58
59        let transport = if is_tls {
60            Some(Self::create_tls_transport()?)
61        } else {
62            None
63        };
64
65        Ok((host.to_string(), port, transport))
66    }
67
68    fn create_tls_transport() -> SdkResult<Transport> {
69        use rumqttc::tokio_rustls::rustls::client::danger::{
70            HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier,
71        };
72        use rumqttc::tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, UnixTime};
73        use rumqttc::tokio_rustls::rustls::{DigitallySignedStruct, Error, SignatureScheme};
74
75        #[derive(Debug)]
76        struct NoVerification;
77
78        impl ServerCertVerifier for NoVerification {
79            fn verify_server_cert(
80                &self, _end_entity: &CertificateDer<'_>, _intermediates: &[CertificateDer<'_>],
81                _server_name: &ServerName<'_>, _ocsp_response: &[u8], _now: UnixTime,
82            ) -> Result<ServerCertVerified, Error> {
83                Ok(ServerCertVerified::assertion())
84            }
85
86            fn verify_tls12_signature(
87                &self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &DigitallySignedStruct,
88            ) -> Result<HandshakeSignatureValid, Error> {
89                Ok(HandshakeSignatureValid::assertion())
90            }
91
92            fn verify_tls13_signature(
93                &self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &DigitallySignedStruct,
94            ) -> Result<HandshakeSignatureValid, Error> {
95                Ok(HandshakeSignatureValid::assertion())
96            }
97
98            fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
99                vec![
100                    SignatureScheme::RSA_PKCS1_SHA256,
101                    SignatureScheme::ECDSA_NISTP256_SHA256,
102                    SignatureScheme::ECDSA_NISTP384_SHA384,
103                    SignatureScheme::ED25519,
104                    SignatureScheme::RSA_PSS_SHA256,
105                ]
106            }
107        }
108
109        let config = rumqttc::tokio_rustls::rustls::ClientConfig::builder()
110            .dangerous()
111            .with_custom_certificate_verifier(Arc::new(NoVerification))
112            .with_no_client_auth();
113
114        Ok(Transport::Tls(rumqttc::TlsConfiguration::Rustls(Arc::new(config))))
115    }
116
117    async fn run_event_loop(
118        mut eventloop: EventLoop,
119        state: Arc<Mutex<MqttConnectionState>>,
120        callback: Arc<Mutex<Option<Box<dyn Fn(MqttEvent) + Send + Sync>>>>,
121    ) {
122        loop {
123            match eventloop.poll().await {
124                Ok(notification) => {
125                    match notification {
126                        MqttEventLoopEvent::Incoming(incoming) => {
127                            match incoming {
128                                Incoming::Publish(packet) => {
129                                    let msg = MqttMessage {
130                                        topic: packet.topic,
131                                        payload: packet.payload.to_vec(),
132                                        qos: match packet.qos {
133                                            QoS::AtMostOnce => MqttQoS::AtMostOnce,
134                                            QoS::AtLeastOnce => MqttQoS::AtLeastOnce,
135                                            QoS::ExactlyOnce => MqttQoS::ExactlyOnce,
136                                        },
137                                    };
138                                    if let Some(cb) = callback.lock().await.as_ref() {
139                                        cb(MqttEvent::MessageReceived(msg));
140                                    }
141                                }
142                                Incoming::ConnAck(_) => {
143                                    *state.lock().await = MqttConnectionState::Connected;
144                                    if let Some(cb) = callback.lock().await.as_ref() {
145                                        cb(MqttEvent::Connected);
146                                    }
147                                }
148                                Incoming::Disconnect => {
149                                    *state.lock().await = MqttConnectionState::Disconnected;
150                                    if let Some(cb) = callback.lock().await.as_ref() {
151                                        cb(MqttEvent::Disconnected);
152                                    }
153                                }
154                                _ => {}
155                            }
156                        }
157                        _ => {}
158                    }
159                }
160                Err(e) => {
161                    if let Some(cb) = callback.lock().await.as_ref() {
162                        cb(MqttEvent::Error { error: e.to_string() });
163                    }
164                }
165            }
166        }
167    }
168}
169
170#[async_trait]
171impl MqttClient for NativeMqttClient {
172    async fn connect(&self, config: MqttConfig) -> SdkResult<()> {
173        let (host, port, transport) = Self::parse_broker_url(&config.broker_url)?;
174
175        let mut mqtt_options = MqttOptions::new(&config.client_id, &host, port);
176        mqtt_options.set_keep_alive(Duration::from_secs(config.keep_alive_secs));
177        mqtt_options.set_clean_session(config.clean_session);
178
179        if let Some(transport) = transport {
180            mqtt_options.set_transport(transport);
181        }
182
183        if let Some(username) = config.username {
184            let password = config.password.unwrap_or_default();
185            mqtt_options.set_credentials(username, password);
186        }
187
188        let (client, eventloop) = AsyncClient::new(mqtt_options, 10);
189
190        *self.client.lock().await = Some(client);
191        *self.state.lock().await = MqttConnectionState::Connecting;
192
193        // Start event loop
194        let state = self.state.clone();
195        let callback = self.event_callback.clone();
196        let handle = tokio::spawn(Self::run_event_loop(eventloop, state, callback));
197        *self.event_loop_handle.lock().await = Some(handle);
198
199        Ok(())
200    }
201
202    async fn disconnect(&self) -> SdkResult<()> {
203        if let Some(handle) = self.event_loop_handle.lock().await.take() {
204            handle.abort();
205        }
206
207        if let Some(client) = self.client.lock().await.take() {
208            client.disconnect().await.map_err(|e| {
209                agentlink_core::error::SdkError::Mqtt(e.to_string())
210            })?;
211        }
212
213        *self.state.lock().await = MqttConnectionState::Disconnected;
214
215        Ok(())
216    }
217
218    async fn subscribe(&self, topic: &str, qos: MqttQoS) -> SdkResult<()> {
219        let client = self.client.lock().await;
220        if let Some(ref c) = *client {
221            let rumqtt_qos = match qos {
222                MqttQoS::AtMostOnce => QoS::AtMostOnce,
223                MqttQoS::AtLeastOnce => QoS::AtLeastOnce,
224                MqttQoS::ExactlyOnce => QoS::ExactlyOnce,
225            };
226            c.subscribe(topic, rumqtt_qos).await.map_err(|e| {
227                agentlink_core::error::SdkError::Mqtt(e.to_string())
228            })?;
229            Ok(())
230        } else {
231            Err(agentlink_core::error::SdkError::NotConnected)
232        }
233    }
234
235    async fn unsubscribe(&self, topic: &str) -> SdkResult<()> {
236        let client = self.client.lock().await;
237        if let Some(ref c) = *client {
238            c.unsubscribe(topic).await.map_err(|e| {
239                agentlink_core::error::SdkError::Mqtt(e.to_string())
240            })?;
241            Ok(())
242        } else {
243            Err(agentlink_core::error::SdkError::NotConnected)
244        }
245    }
246
247    async fn publish(&self, message: MqttMessage) -> SdkResult<()> {
248        let client = self.client.lock().await;
249        if let Some(ref c) = *client {
250            let qos = match message.qos {
251                MqttQoS::AtMostOnce => QoS::AtMostOnce,
252                MqttQoS::AtLeastOnce => QoS::AtLeastOnce,
253                MqttQoS::ExactlyOnce => QoS::ExactlyOnce,
254            };
255            c.publish(&message.topic, qos, false, message.payload).await.map_err(|e| {
256                agentlink_core::error::SdkError::Mqtt(e.to_string())
257            })?;
258            Ok(())
259        } else {
260            Err(agentlink_core::error::SdkError::NotConnected)
261        }
262    }
263
264    fn connection_state(&self) -> MqttConnectionState {
265        // This is synchronous, so we need to use try_lock or block
266        // For simplicity, we'll return Disconnected if we can't get the lock
267        // In practice, this should be called from an async context
268        MqttConnectionState::Disconnected
269    }
270}
271
272/// Extension trait for MQTT clients with event support
273pub trait MqttClientExt: MqttClient {
274    fn set_event_callback<F>(&self, callback: F)
275    where
276        F: Fn(MqttEvent) + Send + Sync + 'static;
277}
278
279impl MqttClientExt for NativeMqttClient {
280    fn set_event_callback<F>(&self, callback: F)
281    where
282        F: Fn(MqttEvent) + Send + Sync + 'static,
283    {
284        // This is a bit hacky since we can't easily store the callback
285        // In a real implementation, we'd need to use a channel or similar
286        // For now, we'll just ignore this in the synchronous context
287        let _ = callback;
288    }
289}