1use std::sync::Arc;
6use std::time::Duration;
7
8use async_trait::async_trait;
9use agentlink_core::mqtt::{
10 MqttClient, MqttConfig, MqttConnectionState, MqttEvent, MqttMessage, MqttQoS,
11};
12use agentlink_core::error::SdkResult;
13use rumqttc::{AsyncClient, Event as MqttEventLoopEvent, EventLoop, Incoming, MqttOptions, QoS, Transport};
14use tokio::sync::Mutex;
15use tokio::task::JoinHandle;
16
17pub struct NativeMqttClient {
19 client: Arc<Mutex<Option<AsyncClient>>>,
20 state: Arc<Mutex<MqttConnectionState>>,
21 event_callback: Arc<Mutex<Option<Box<dyn Fn(MqttEvent) + Send + Sync>>>>,
22 event_loop_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
23}
24
25impl NativeMqttClient {
26 pub fn new() -> Self {
27 Self {
28 client: Arc::new(Mutex::new(None)),
29 state: Arc::new(Mutex::new(MqttConnectionState::Disconnected)),
30 event_callback: Arc::new(Mutex::new(None)),
31 event_loop_handle: Arc::new(Mutex::new(None)),
32 }
33 }
34
35 fn parse_broker_url(url: &str) -> SdkResult<(String, u16, Option<Transport>)> {
36 let is_tls = url.starts_with("mqtts://");
37 let is_tcp = url.starts_with("mqtt://");
38
39 if !is_tls && !is_tcp {
40 return Err(agentlink_core::error::SdkError::Config(
41 format!("Unsupported MQTT protocol: {}", url)
42 ));
43 }
44
45 let url_part = if is_tls {
46 &url[8..]
47 } else {
48 &url[7..]
49 };
50
51 let parts: Vec<&str> = url_part.split('/').next().unwrap().split(':').collect();
52 let host = parts[0];
53
54 let default_port = if is_tls { 8883 } else { 1883 };
55 let port = parts.get(1)
56 .and_then(|p| p.parse::<u16>().ok())
57 .unwrap_or(default_port);
58
59 let transport = if is_tls {
60 Some(Self::create_tls_transport()?)
61 } else {
62 None
63 };
64
65 Ok((host.to_string(), port, transport))
66 }
67
68 fn create_tls_transport() -> SdkResult<Transport> {
69 use rumqttc::tokio_rustls::rustls::client::danger::{
70 HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier,
71 };
72 use rumqttc::tokio_rustls::rustls::pki_types::{CertificateDer, ServerName, UnixTime};
73 use rumqttc::tokio_rustls::rustls::{DigitallySignedStruct, Error, SignatureScheme};
74
75 #[derive(Debug)]
76 struct NoVerification;
77
78 impl ServerCertVerifier for NoVerification {
79 fn verify_server_cert(
80 &self, _end_entity: &CertificateDer<'_>, _intermediates: &[CertificateDer<'_>],
81 _server_name: &ServerName<'_>, _ocsp_response: &[u8], _now: UnixTime,
82 ) -> Result<ServerCertVerified, Error> {
83 Ok(ServerCertVerified::assertion())
84 }
85
86 fn verify_tls12_signature(
87 &self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &DigitallySignedStruct,
88 ) -> Result<HandshakeSignatureValid, Error> {
89 Ok(HandshakeSignatureValid::assertion())
90 }
91
92 fn verify_tls13_signature(
93 &self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &DigitallySignedStruct,
94 ) -> Result<HandshakeSignatureValid, Error> {
95 Ok(HandshakeSignatureValid::assertion())
96 }
97
98 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
99 vec![
100 SignatureScheme::RSA_PKCS1_SHA256,
101 SignatureScheme::ECDSA_NISTP256_SHA256,
102 SignatureScheme::ECDSA_NISTP384_SHA384,
103 SignatureScheme::ED25519,
104 SignatureScheme::RSA_PSS_SHA256,
105 ]
106 }
107 }
108
109 let config = rumqttc::tokio_rustls::rustls::ClientConfig::builder()
110 .dangerous()
111 .with_custom_certificate_verifier(Arc::new(NoVerification))
112 .with_no_client_auth();
113
114 Ok(Transport::Tls(rumqttc::TlsConfiguration::Rustls(Arc::new(config))))
115 }
116
117 async fn run_event_loop(
118 mut eventloop: EventLoop,
119 state: Arc<Mutex<MqttConnectionState>>,
120 callback: Arc<Mutex<Option<Box<dyn Fn(MqttEvent) + Send + Sync>>>>,
121 ) {
122 loop {
123 match eventloop.poll().await {
124 Ok(notification) => {
125 match notification {
126 MqttEventLoopEvent::Incoming(incoming) => {
127 match incoming {
128 Incoming::Publish(packet) => {
129 let msg = MqttMessage {
130 topic: packet.topic,
131 payload: packet.payload.to_vec(),
132 qos: match packet.qos {
133 QoS::AtMostOnce => MqttQoS::AtMostOnce,
134 QoS::AtLeastOnce => MqttQoS::AtLeastOnce,
135 QoS::ExactlyOnce => MqttQoS::ExactlyOnce,
136 },
137 };
138 if let Some(cb) = callback.lock().await.as_ref() {
139 cb(MqttEvent::MessageReceived(msg));
140 }
141 }
142 Incoming::ConnAck(_) => {
143 *state.lock().await = MqttConnectionState::Connected;
144 if let Some(cb) = callback.lock().await.as_ref() {
145 cb(MqttEvent::Connected);
146 }
147 }
148 Incoming::Disconnect => {
149 *state.lock().await = MqttConnectionState::Disconnected;
150 if let Some(cb) = callback.lock().await.as_ref() {
151 cb(MqttEvent::Disconnected);
152 }
153 }
154 _ => {}
155 }
156 }
157 _ => {}
158 }
159 }
160 Err(e) => {
161 if let Some(cb) = callback.lock().await.as_ref() {
162 cb(MqttEvent::Error { error: e.to_string() });
163 }
164 }
165 }
166 }
167 }
168}
169
170#[async_trait]
171impl MqttClient for NativeMqttClient {
172 async fn connect(&self, config: MqttConfig) -> SdkResult<()> {
173 let (host, port, transport) = Self::parse_broker_url(&config.broker_url)?;
174
175 let mut mqtt_options = MqttOptions::new(&config.client_id, &host, port);
176 mqtt_options.set_keep_alive(Duration::from_secs(config.keep_alive_secs));
177 mqtt_options.set_clean_session(config.clean_session);
178
179 if let Some(transport) = transport {
180 mqtt_options.set_transport(transport);
181 }
182
183 if let Some(username) = config.username {
184 let password = config.password.unwrap_or_default();
185 mqtt_options.set_credentials(username, password);
186 }
187
188 let (client, eventloop) = AsyncClient::new(mqtt_options, 10);
189
190 *self.client.lock().await = Some(client);
191 *self.state.lock().await = MqttConnectionState::Connecting;
192
193 let state = self.state.clone();
195 let callback = self.event_callback.clone();
196 let handle = tokio::spawn(Self::run_event_loop(eventloop, state, callback));
197 *self.event_loop_handle.lock().await = Some(handle);
198
199 Ok(())
200 }
201
202 async fn disconnect(&self) -> SdkResult<()> {
203 if let Some(handle) = self.event_loop_handle.lock().await.take() {
204 handle.abort();
205 }
206
207 if let Some(client) = self.client.lock().await.take() {
208 client.disconnect().await.map_err(|e| {
209 agentlink_core::error::SdkError::Mqtt(e.to_string())
210 })?;
211 }
212
213 *self.state.lock().await = MqttConnectionState::Disconnected;
214
215 Ok(())
216 }
217
218 async fn subscribe(&self, topic: &str, qos: MqttQoS) -> SdkResult<()> {
219 let client = self.client.lock().await;
220 if let Some(ref c) = *client {
221 let rumqtt_qos = match qos {
222 MqttQoS::AtMostOnce => QoS::AtMostOnce,
223 MqttQoS::AtLeastOnce => QoS::AtLeastOnce,
224 MqttQoS::ExactlyOnce => QoS::ExactlyOnce,
225 };
226 c.subscribe(topic, rumqtt_qos).await.map_err(|e| {
227 agentlink_core::error::SdkError::Mqtt(e.to_string())
228 })?;
229 Ok(())
230 } else {
231 Err(agentlink_core::error::SdkError::NotConnected)
232 }
233 }
234
235 async fn unsubscribe(&self, topic: &str) -> SdkResult<()> {
236 let client = self.client.lock().await;
237 if let Some(ref c) = *client {
238 c.unsubscribe(topic).await.map_err(|e| {
239 agentlink_core::error::SdkError::Mqtt(e.to_string())
240 })?;
241 Ok(())
242 } else {
243 Err(agentlink_core::error::SdkError::NotConnected)
244 }
245 }
246
247 async fn publish(&self, message: MqttMessage) -> SdkResult<()> {
248 let client = self.client.lock().await;
249 if let Some(ref c) = *client {
250 let qos = match message.qos {
251 MqttQoS::AtMostOnce => QoS::AtMostOnce,
252 MqttQoS::AtLeastOnce => QoS::AtLeastOnce,
253 MqttQoS::ExactlyOnce => QoS::ExactlyOnce,
254 };
255 c.publish(&message.topic, qos, false, message.payload).await.map_err(|e| {
256 agentlink_core::error::SdkError::Mqtt(e.to_string())
257 })?;
258 Ok(())
259 } else {
260 Err(agentlink_core::error::SdkError::NotConnected)
261 }
262 }
263
264 fn connection_state(&self) -> MqttConnectionState {
265 MqttConnectionState::Disconnected
269 }
270}
271
272pub trait MqttClientExt: MqttClient {
274 fn set_event_callback<F>(&self, callback: F)
275 where
276 F: Fn(MqttEvent) + Send + Sync + 'static;
277}
278
279impl MqttClientExt for NativeMqttClient {
280 fn set_event_callback<F>(&self, callback: F)
281 where
282 F: Fn(MqttEvent) + Send + Sync + 'static,
283 {
284 let _ = callback;
288 }
289}