use std::time::{Duration, Instant};
use std::{fs, path::Path};
use regex::Regex;
use rumqttc::{Client, Event, Incoming, MqttOptions, Outgoing, Packet, QoS, Transport};
use serde_json::{Map, Value};
use uuid::Uuid;
use crate::errors::{AcpError, AcpResult};
use crate::json_support;
use crate::transport_auth::{
AuthConfig, auth_config_from_value, auth_parameter, ensure_allowed_auth_types,
normalize_auth_config, serialize_auth_config,
};
pub const DEFAULT_MQTT_QOS: u8 = 1;
pub const DEFAULT_MQTT_TOPIC_PREFIX: &str = "acp/agent";
pub type MqttMessageHandler = dyn FnMut(&Map<String, Value>) -> bool + Send;
#[derive(Debug, Clone)]
pub struct MqttTransportClient {
pub broker_url: String,
pub qos: u8,
pub topic_prefix: String,
pub timeout_seconds: u64,
pub keepalive_seconds: u16,
pub auth: Option<AuthConfig>,
}
impl MqttTransportClient {
pub fn new(
broker_url: impl Into<String>,
qos: Option<u8>,
topic_prefix: Option<String>,
timeout_seconds: u64,
keepalive_seconds: u16,
) -> AcpResult<Self> {
Self::new_with_auth(
broker_url,
qos,
topic_prefix,
timeout_seconds,
keepalive_seconds,
None,
)
}
pub fn new_with_auth(
broker_url: impl Into<String>,
qos: Option<u8>,
topic_prefix: Option<String>,
timeout_seconds: u64,
keepalive_seconds: u16,
auth: Option<AuthConfig>,
) -> AcpResult<Self> {
let broker_url = broker_url.into();
if broker_url.trim().is_empty() {
return Err(AcpError::InvalidArgument(
"broker_url must be provided".to_string(),
));
}
let auth = normalize_auth_config(auth)?;
Ok(Self {
broker_url,
qos: coerce_qos(qos.unwrap_or(DEFAULT_MQTT_QOS)),
topic_prefix: topic_prefix
.as_deref()
.map(str::trim)
.filter(|v| !v.is_empty())
.unwrap_or(DEFAULT_MQTT_TOPIC_PREFIX)
.to_string(),
timeout_seconds: timeout_seconds.max(1),
keepalive_seconds: keepalive_seconds.max(5),
auth,
})
}
pub fn agent_identifier_token(agent_id: &str) -> AcpResult<String> {
let regex = Regex::new(r"^agent:(?P<name>[^@]+)(?:@(?P<domain>.+))?$")
.map_err(|e| AcpError::Validation(format!("invalid agent id regex: {e}")))?;
let captures = regex
.captures(agent_id)
.ok_or_else(|| AcpError::Validation(format!("Invalid agent identifier: {agent_id}")))?;
let name = captures
.name("name")
.map(|m| m.as_str().to_string())
.ok_or_else(|| AcpError::Validation("agent id missing name".to_string()))?;
let domain = captures.name("domain").map(|m| m.as_str().to_string());
let base = if let Some(domain) = domain {
format!("{name}.{domain}")
} else {
name
};
let normalized = base
.chars()
.map(|ch| {
if ch.is_ascii_alphanumeric() || matches!(ch, '.' | '_' | '-') {
ch
} else {
'.'
}
})
.collect::<String>()
.split('.')
.filter(|part| !part.is_empty())
.collect::<Vec<_>>()
.join(".")
.to_lowercase();
Ok(if normalized.is_empty() {
"unknown".to_string()
} else {
normalized
})
}
pub fn topic_for_agent(agent_id: &str, topic_prefix: Option<&str>) -> AcpResult<String> {
let prefix = topic_prefix
.unwrap_or(DEFAULT_MQTT_TOPIC_PREFIX)
.trim_end_matches('/');
Ok(format!(
"{prefix}/{}",
Self::agent_identifier_token(agent_id)?
))
}
pub fn build_service_hint(
agent_id: &str,
broker_url: &str,
topic: Option<&str>,
qos: Option<u8>,
topic_prefix: Option<&str>,
) -> AcpResult<Map<String, Value>> {
Self::build_service_hint_with_auth(agent_id, broker_url, topic, qos, topic_prefix, None)
}
pub fn build_service_hint_with_auth(
agent_id: &str,
broker_url: &str,
topic: Option<&str>,
qos: Option<u8>,
topic_prefix: Option<&str>,
auth: Option<AuthConfig>,
) -> AcpResult<Map<String, Value>> {
let mut hint = Map::new();
hint.insert(
"broker_url".to_string(),
Value::String(broker_url.to_string()),
);
hint.insert(
"topic".to_string(),
Value::String(
topic
.map(str::to_string)
.filter(|v| !v.trim().is_empty())
.unwrap_or(Self::topic_for_agent(agent_id, topic_prefix)?),
),
);
hint.insert(
"qos".to_string(),
Value::Number(coerce_qos(qos.unwrap_or(DEFAULT_MQTT_QOS)).into()),
);
if let Some(auth) = serialize_auth_config(normalize_auth_config(auth)?.as_ref()) {
hint.insert("auth".to_string(), auth);
}
Ok(hint)
}
pub fn publish(
&self,
message: &Map<String, Value>,
recipient_agent_id: &str,
mqtt_service: Option<&Map<String, Value>>,
) -> AcpResult<()> {
let broker_url = pick_string(mqtt_service, "broker_url", &self.broker_url);
let auth = normalize_auth_config(
auth_config_from_value(mqtt_service.and_then(|map| map.get("auth")))?
.or_else(|| self.auth.clone()),
)?;
ensure_allowed_auth_types(
auth.as_ref(),
&["none", "username_password", "mtls", "custom"],
"MQTT transport",
)?;
let topic = pick_string(
mqtt_service,
"topic",
&Self::topic_for_agent(recipient_agent_id, Some(&self.topic_prefix))?,
);
let qos = mqtt_service
.and_then(|m| m.get("qos"))
.and_then(value_as_u8)
.map(coerce_qos)
.unwrap_or(self.qos);
let payload = json_support::canonical_json_string(&Value::Object(message.clone()))?;
let (client, mut connection) = self.open_client(&broker_url, auth.as_ref())?;
client
.publish(topic, to_qos(qos), false, payload)
.map_err(|e| AcpError::Transport(format!("mqtt publish failed: {e}")))?;
let deadline = Instant::now() + Duration::from_secs(self.timeout_seconds.max(1));
while Instant::now() < deadline {
match connection.iter().next() {
Some(Ok(Event::Incoming(Incoming::PubAck(_)))) => break,
Some(Ok(Event::Outgoing(Outgoing::Publish(_)))) => break,
Some(Ok(_)) => continue,
Some(Err(e)) => {
return Err(AcpError::Transport(format!("mqtt event loop error: {e}")));
}
None => break,
}
}
Ok(())
}
pub fn consume<F>(
&self,
agent_id: &str,
mut handler: F,
mqtt_service: Option<&Map<String, Value>>,
max_messages: usize,
poll_timeout: Duration,
) -> AcpResult<usize>
where
F: FnMut(&Map<String, Value>) -> bool + Send,
{
let broker_url = pick_string(mqtt_service, "broker_url", &self.broker_url);
let auth = normalize_auth_config(
auth_config_from_value(mqtt_service.and_then(|map| map.get("auth")))?
.or_else(|| self.auth.clone()),
)?;
ensure_allowed_auth_types(
auth.as_ref(),
&["none", "username_password", "mtls", "custom"],
"MQTT transport",
)?;
let topic = pick_string(
mqtt_service,
"topic",
&Self::topic_for_agent(agent_id, Some(&self.topic_prefix))?,
);
let qos = mqtt_service
.and_then(|m| m.get("qos"))
.and_then(value_as_u8)
.map(coerce_qos)
.unwrap_or(self.qos);
let limit = if max_messages == 0 {
usize::MAX
} else {
max_messages
};
let (client, mut connection) = self.open_client(&broker_url, auth.as_ref())?;
client
.subscribe(topic, to_qos(qos))
.map_err(|e| AcpError::Transport(format!("mqtt subscribe failed: {e}")))?;
let mut processed = 0usize;
let start = Instant::now();
while processed < limit {
if start.elapsed() >= poll_timeout {
break;
}
match connection.iter().next() {
Some(Ok(Event::Incoming(Packet::Publish(publish)))) => {
let should_ack = std::str::from_utf8(&publish.payload)
.ok()
.and_then(|body| json_support::map_from_json(body).ok())
.map(|message| handler(&message))
.unwrap_or(false);
let _ = should_ack;
processed += 1;
}
Some(Ok(_)) => continue,
Some(Err(e)) => {
return Err(AcpError::Transport(format!("mqtt consume failed: {e}")));
}
None => break,
}
}
Ok(processed)
}
fn open_client(
&self,
broker_url: &str,
auth: Option<&AuthConfig>,
) -> AcpResult<(Client, rumqttc::Connection)> {
let parsed = url::Url::parse(broker_url)?;
let host = parsed.host_str().ok_or_else(|| {
AcpError::Validation(format!("Invalid MQTT broker_url: {broker_url}"))
})?;
let scheme = parsed.scheme().to_lowercase();
let port = parsed.port().unwrap_or(match scheme.as_str() {
"mqtts" | "ssl" | "wss" => 8883,
_ => 1883,
});
let mut options = MqttOptions::new(format!("acp-{}", Uuid::new_v4().simple()), host, port);
options.set_keep_alive(Duration::from_secs(u64::from(self.keepalive_seconds)));
options.set_request_channel_capacity(16);
if !parsed.username().is_empty() {
options.set_credentials(parsed.username(), parsed.password().unwrap_or_default());
}
if let Some(auth) = auth {
if auth.auth_type == "username_password" {
let username = auth_parameter(auth, "username", "MQTT username_password auth")?;
let password = auth_parameter(auth, "password", "MQTT username_password auth")?;
options.set_credentials(username, password);
} else if auth.auth_type == "custom" {
if let Some(username) = auth
.parameters
.get("username")
.map(|value| value.trim().to_string())
.filter(|value| !value.is_empty())
{
let password = auth
.parameters
.get("password")
.map(|value| value.trim().to_string())
.unwrap_or_default();
options.set_credentials(username, password);
}
}
}
if let Some(transport) = mqtt_tls_transport(auth)? {
options.set_transport(transport);
} else if matches!(scheme.as_str(), "mqtts" | "ssl" | "wss") {
options.set_transport(Transport::tls_with_default_config());
}
Ok(Client::new(options, 10))
}
}
fn mqtt_tls_transport(auth: Option<&AuthConfig>) -> AcpResult<Option<Transport>> {
let Some(auth) = auth else {
return Ok(None);
};
if auth.auth_type != "mtls" && auth.auth_type != "custom" {
return Ok(None);
}
let cert_path = auth
.parameters
.get("cert_path")
.map(|value| value.trim().to_string())
.filter(|value| !value.is_empty());
let key_path = auth
.parameters
.get("key_path")
.map(|value| value.trim().to_string())
.filter(|value| !value.is_empty());
if auth.auth_type == "mtls" && (cert_path.is_none() || key_path.is_none()) {
return Err(AcpError::Validation(
"MQTT mTLS auth requires auth.parameters.cert_path and auth.parameters.key_path"
.to_string(),
));
}
let ca = auth
.parameters
.get("ca_path")
.map(|value| value.trim().to_string())
.filter(|value| !value.is_empty())
.map(|path| fs::read(Path::new(&path)))
.transpose()
.map_err(|e| AcpError::Validation(format!("unable to read auth.parameters.ca_path: {e}")))?
.unwrap_or_default();
let client_auth = match (cert_path, key_path) {
(Some(cert_path), Some(key_path)) => {
let cert = fs::read(Path::new(&cert_path)).map_err(|e| {
AcpError::Validation(format!("unable to read auth.parameters.cert_path: {e}"))
})?;
let key = fs::read(Path::new(&key_path)).map_err(|e| {
AcpError::Validation(format!("unable to read auth.parameters.key_path: {e}"))
})?;
Some((cert, key))
}
_ => None,
};
if ca.is_empty() && client_auth.is_none() {
return Ok(None);
}
Ok(Some(Transport::tls(ca, client_auth, None)))
}
fn to_qos(value: u8) -> QoS {
match value {
0 => QoS::AtMostOnce,
2 => QoS::ExactlyOnce,
_ => QoS::AtLeastOnce,
}
}
fn pick_string(service: Option<&Map<String, Value>>, key: &str, fallback: &str) -> String {
service
.and_then(|map| map.get(key))
.and_then(Value::as_str)
.map(str::trim)
.filter(|v| !v.is_empty())
.unwrap_or(fallback)
.to_string()
}
fn coerce_qos(value: u8) -> u8 {
if value > 2 { DEFAULT_MQTT_QOS } else { value }
}
fn value_as_u8(value: &Value) -> Option<u8> {
if let Some(number) = value.as_u64() {
return Some(number.min(2) as u8);
}
if let Some(text) = value.as_str() {
return text.parse::<u8>().ok().map(|v| v.min(2));
}
None
}