use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::{broadcast, mpsc, Mutex};
use super::command_message::{CommandMessage, MessageType, next_transaction_id};
use super::error::IpcError;
use super::handler::{ModuleHandler, ModuleHandlerExt};
use super::message::ModuleRegistration;
use super::transport::SplitTransport;
/// Cloneable handle for queueing broadcast messages without holding the client itself.
///
/// Instances are produced by `IpcClient::broadcast_sender`, which pairs a clone of the
/// client's outbound `mpsc` sender with the handler's domain string.
#[derive(Clone)]
pub struct BroadcastSender {
/// Sending half of the outbound message queue.
tx: mpsc::Sender<CommandMessage>,
/// Domain prefix that `send` prepends to the topic (`<domain>.<topic>`).
domain: String,
}
impl BroadcastSender {
    /// Queues a broadcast under `<domain>.<topic>`.
    ///
    /// The topic is ASCII-lowercased before the domain prefix is applied.
    ///
    /// # Errors
    /// Returns `IpcError::Channel` when the outbound queue rejects the message.
    pub async fn send(&self, topic: &str, value: serde_json::Value) -> Result<(), IpcError> {
        let lowered = topic.to_ascii_lowercase();
        let qualified = format!("{}.{}", self.domain, lowered);
        self.enqueue(&qualified, value).await
    }

    /// Queues a broadcast under `fqdn_topic` exactly as given (no prefix, no case folding).
    ///
    /// # Errors
    /// Returns `IpcError::Channel` when the outbound queue rejects the message.
    pub async fn send_raw(&self, fqdn_topic: &str, value: serde_json::Value) -> Result<(), IpcError> {
        self.enqueue(fqdn_topic, value).await
    }

    /// Alias of [`BroadcastSender::send_raw`]: the topic is used verbatim.
    ///
    /// # Errors
    /// Returns `IpcError::Channel` when the outbound queue rejects the message.
    pub async fn send_to(&self, topic: &str, value: serde_json::Value) -> Result<(), IpcError> {
        self.enqueue(topic, value).await
    }

    /// Wraps `value` in a broadcast message for `topic` and pushes it onto the queue.
    async fn enqueue(&self, topic: &str, value: serde_json::Value) -> Result<(), IpcError> {
        let message = CommandMessage::broadcast(topic, value);
        match self.tx.send(message).await {
            Ok(()) => Ok(()),
            Err(rejected) => Err(IpcError::Channel(rejected.to_string())),
        }
    }
}
/// Connection settings for an IPC client.
#[derive(Debug, Clone)]
pub struct IpcClientConfig {
    /// Address handed to `TcpStream::connect`, e.g. `"127.0.0.1:9100"`.
    pub server_addr: String,
    /// Period of the heartbeat task's timer.
    pub heartbeat_interval: Duration,
    /// Upper bound on how long the TCP connect may take.
    pub connect_timeout: Duration,
    /// Pause taken before every reconnection attempt.
    pub reconnect_delay: Duration,
    /// Maximum number of reconnection attempts; `0` disables the limit.
    pub max_reconnect_attempts: usize,
    /// Whether a failed message loop triggers reconnection instead of an error return.
    pub auto_reconnect: bool,
}

impl Default for IpcClientConfig {
    /// Local server on port 9100, 5 s heartbeat, 10 s connect timeout, 2 s reconnect
    /// delay, unlimited reconnection attempts, auto-reconnect enabled.
    fn default() -> Self {
        IpcClientConfig {
            server_addr: String::from("127.0.0.1:9100"),
            heartbeat_interval: Duration::from_secs(5),
            connect_timeout: Duration::from_secs(10),
            reconnect_delay: Duration::from_secs(2),
            max_reconnect_attempts: 0,
            auto_reconnect: true,
        }
    }
}

impl IpcClientConfig {
    /// Default configuration with only the server address replaced.
    pub fn new(server_addr: &str) -> Self {
        let defaults = Self::default();
        Self {
            server_addr: String::from(server_addr),
            ..defaults
        }
    }
}
/// IPC client that connects to the server over TCP, registers the module described by
/// its handler, and dispatches incoming messages to that handler.
pub struct IpcClient<H: ModuleHandler> {
/// Connection, heartbeat and reconnection settings.
config: IpcClientConfig,
/// Module handler; shared (behind an async mutex) with the heartbeat task.
handler: Arc<Mutex<H>>,
/// Active transport; `None` until `establish_connection` has created one.
transport: Option<SplitTransport>,
/// Run flag: set by `run`, cleared by `shutdown` or when the handler requests shutdown.
running: Arc<AtomicBool>,
/// Producer side of the outbound queue (cloned into every `BroadcastSender`).
outbound_tx: mpsc::Sender<CommandMessage>,
/// Consumer side of the outbound queue, drained by the outbound task of the message loop.
outbound_rx: Arc<Mutex<mpsc::Receiver<CommandMessage>>>,
/// Sender backing `subscribe_broadcasts`.
/// NOTE(review): nothing in this file publishes on this channel — confirm whether
/// incoming broadcasts are meant to be forwarded here.
broadcast_tx: broadcast::Sender<CommandMessage>,
/// Shared-memory context; populated when a `configure_shm` message is processed.
pub shm_context: Arc<Mutex<Option<Arc<crate::shm::ShmContext>>>>,
}
impl<H: ModuleHandler + 'static> IpcClient<H> {
/// Builds a client that is not yet connected.
///
/// No I/O happens here: the transport stays `None` and the run flag starts `false`.
pub fn new(config: IpcClientConfig, handler: H) -> Self {
    // Outbound queue (capacity 1024) shared with every `BroadcastSender`.
    let (queue_tx, queue_rx) = mpsc::channel(1024);
    // Only the sender is kept; receivers are created on demand via `subscribe()`.
    let (fanout_tx, _) = broadcast::channel(256);

    let handler = Arc::new(Mutex::new(handler));
    let running = Arc::new(AtomicBool::new(false));
    let outbound_rx = Arc::new(Mutex::new(queue_rx));
    let shm_context = Arc::new(Mutex::new(None));

    Self {
        config,
        handler,
        transport: None,
        running,
        outbound_tx: queue_tx,
        outbound_rx,
        broadcast_tx: fanout_tx,
        shm_context,
    }
}
/// Connects to `addr` using the default configuration and registers `handler`.
///
/// # Errors
/// Propagates any error from establishing the connection or registering.
pub async fn connect(addr: &str, handler: H) -> Result<Self, IpcError> {
    // Same steps as `connect_with_config`, with a config derived from the address.
    Self::connect_with_config(IpcClientConfig::new(addr), handler).await
}
/// Connects using an explicit configuration and registers `handler`.
///
/// # Errors
/// Propagates any error from establishing the connection or registering.
pub async fn connect_with_config(config: IpcClientConfig, handler: H) -> Result<Self, IpcError> {
    let mut pending = Self::new(config, handler);
    match pending.establish_connection().await {
        Ok(()) => Ok(pending),
        Err(cause) => Err(cause),
    }
}
/// Opens the TCP connection to `config.server_addr` and registers with the server.
///
/// # Errors
/// * `IpcError::Timeout` when the connect does not finish within `config.connect_timeout`.
/// * `IpcError::Connection` when the connect itself fails.
/// * Whatever `register` returns when registration fails; in that case the freshly
///   created transport is discarded so the client is left in the "not connected" state.
async fn establish_connection(&mut self) -> Result<(), IpcError> {
    log::info!("Connecting to server at {}", self.config.server_addr);
    let stream = tokio::time::timeout(
        self.config.connect_timeout,
        TcpStream::connect(&self.config.server_addr),
    )
    .await
    .map_err(|_| IpcError::Timeout("Connection timeout".to_string()))?
    .map_err(|e| IpcError::Connection(e.to_string()))?;

    // Disabling Nagle is best-effort: a failure is reported but does not abort the connect.
    if let Err(e) = stream.set_nodelay(true) {
        log::warn!("Failed to enable TCP_NODELAY: {}", e);
    }

    // `register` reads `self.transport`, so it has to be stored before registering.
    self.transport = Some(SplitTransport::new(stream));
    if let Err(e) = self.register().await {
        // Do not keep an unregistered connection around (e.g. across a reconnect delay).
        self.transport = None;
        return Err(e);
    }

    log::info!("Connected and registered with server");
    Ok(())
}
/// Sends the module registration to the server and waits for the acknowledgement.
///
/// The registration carries the handler's domain, version and capabilities. The reply
/// must arrive within 5 seconds, be a `Response` whose topic ends with `registerack`,
/// and have `success` set.
///
/// # Errors
/// * `IpcError::Connection("Not connected")` when no transport is present.
/// * `IpcError::Timeout` when no reply arrives within 5 seconds.
/// * `IpcError::Connection` describing the reply when it is not a successful ack.
/// * Any error returned by the transport while sending or receiving.
async fn register(&self) -> Result<(), IpcError> {
    let transport = self
        .transport
        .as_ref()
        .ok_or_else(|| IpcError::Connection("Not connected".to_string()))?;

    let handler = self.handler.lock().await;
    let registration = ModuleRegistration::new(handler.domain(), handler.version())
        .with_capabilities(handler.capabilities());
    // Release the handler lock before doing any network I/O.
    drop(handler);

    let msg = CommandMessage::control(
        "register",
        // A serialization failure falls back to the default JSON value (null).
        serde_json::to_value(&registration).unwrap_or_default(),
    );
    transport.send(&msg).await?;

    // First `?` handles the timeout, second the transport's own receive error.
    let response = tokio::time::timeout(Duration::from_secs(5), transport.recv())
        .await
        .map_err(|_| IpcError::Timeout("Registration timeout".to_string()))??;

    let acknowledged = response.message_type == MessageType::Response
        && response.topic.ends_with("registerack")
        && response.success;
    if acknowledged {
        Ok(())
    } else {
        Err(IpcError::Connection(format!(
            "Registration failed: type={:?}, topic={}, success={}, error={}",
            response.message_type,
            response.topic,
            response.success,
            response.error_message
        )))
    }
}
pub async fn run(mut self) -> Result<(), IpcError> {
self.running.store(true, Ordering::SeqCst);
loop {
if let Err(e) = self.run_message_loop().await {
log::error!("Message loop error: {}", e);
if !self.config.auto_reconnect {
return Err(e);
}
if !self.reconnect().await {
return Err(IpcError::Connection("Failed to reconnect".to_string()));
}
}
if !self.running.load(Ordering::SeqCst) {
break;
}
}
Ok(())
}
async fn run_message_loop(&self) -> Result<(), IpcError> {
let transport = self.transport.as_ref()
.ok_or_else(|| IpcError::Connection("Not connected".to_string()))?
.clone();
let handler = Arc::clone(&self.handler);
let running = Arc::clone(&self.running);
let outbound_rx = Arc::clone(&self.outbound_rx);
let heartbeat_interval = self.config.heartbeat_interval;
let heartbeat_transport = transport.clone();
let heartbeat_running = Arc::clone(&running);
let heartbeat_handler = Arc::clone(&handler);
let heartbeat_task = tokio::spawn(async move {
let mut interval = tokio::time::interval(heartbeat_interval);
while heartbeat_running.load(Ordering::SeqCst) {
interval.tick().await;
if let Err(e) = heartbeat_transport.send(&CommandMessage::heartbeat()).await {
log::warn!("Failed to send heartbeat: {}", e);
break;
}
if let Ok(mut h) = heartbeat_handler.try_lock() {
let _ = h.on_heartbeat().await;
}
}
});
let outbound_transport = transport.clone();
let outbound_running = Arc::clone(&running);
let outbound_task = tokio::spawn(async move {
let mut rx = outbound_rx.lock().await;
while outbound_running.load(Ordering::SeqCst) {
match rx.recv().await {
Some(msg) => {
if let Err(e) = outbound_transport.send(&msg).await {
log::error!("Failed to send outbound message: {}", e);
break;
}
}
None => break,
}
}
});
while running.load(Ordering::SeqCst) {
match transport.recv().await {
Ok(msg) => {
if msg.subtopic() == "configure_shm" {
log::info!("Received SHM configuration command");
#[derive(serde::Deserialize)]
struct ConfigPayload {
shm_id: String,
mappings: Vec<crate::shm::ShmVariableConfig>
}
match serde_json::from_value::<ConfigPayload>(msg.data.clone()) {
Ok(payload) => {
match crate::shm::ShmContext::new(&payload.shm_id, payload.mappings) {
Ok(ctx) => {
log::info!("SHM Connected: ID '{}'", payload.shm_id);
let mut handler_guard = handler.lock().await;
let names = handler_guard.shm_variable_names();
if !names.is_empty() {
let shm_map = ctx.resolve(&names);
log::info!("SHM resolved {}/{} pointers for handler", shm_map.len(), names.len());
if let Err(e) = handler_guard.on_shm_configured(shm_map).await {
log::error!("on_shm_configured error: {}", e);
}
}
drop(handler_guard);
let mut shm_guard = self.shm_context.lock().await;
*shm_guard = Some(Arc::new(ctx));
},
Err(e) => log::error!("Failed to connect to SHM: {}", e),
}
},
Err(e) => log::error!("Invalid configure_shm payload: {}", e),
}
continue;
}
let result = {
let mut handler_guard = handler.lock().await;
handler_guard.process_message(msg).await
};
match result {
Ok((response_msg, should_shutdown)) => {
if response_msg.is_response() {
if let Err(e) = transport.send(&response_msg).await {
log::error!("Failed to send response: {}", e);
break;
}
}
if should_shutdown {
log::info!("Handler requested shutdown, exiting message loop");
running.store(false, Ordering::SeqCst);
break;
}
}
Err(e) => {
log::error!("Handler error: {}", e);
}
}
}
Err(e) => {
log::error!("Receive error: {}", e);
break;
}
}
}
heartbeat_task.abort();
outbound_task.abort();
Ok(())
}
/// Repeatedly tries to re-establish the connection.
///
/// Before every attempt the run flag is checked and `config.reconnect_delay` is slept.
/// Returns `true` as soon as an attempt succeeds, and `false` when the run flag has been
/// cleared or more than `config.max_reconnect_attempts` attempts would be needed
/// (a limit of `0` means no limit).
async fn reconnect(&mut self) -> bool {
    let limit = self.config.max_reconnect_attempts;
    let pause = self.config.reconnect_delay;
    let mut attempt = 0;
    loop {
        if !self.running.load(Ordering::SeqCst) {
            return false;
        }
        attempt += 1;
        let exhausted = limit > 0 && attempt > limit;
        if exhausted {
            log::error!("Max reconnection attempts reached");
            return false;
        }
        log::info!(
            "Reconnection attempt {} (delay: {:?})",
            attempt,
            pause
        );
        tokio::time::sleep(pause).await;
        if let Err(e) = self.establish_connection().await {
            log::warn!("Reconnection failed: {}", e);
            continue;
        }
        log::info!("Reconnected successfully");
        return true;
    }
}
/// Queues a broadcast under `<handler domain>.<topic>`; the topic is used as given.
///
/// # Errors
/// Returns `IpcError::Channel` when the outbound queue rejects the message.
pub async fn broadcast(&self, topic: &str, value: serde_json::Value) -> Result<(), IpcError> {
    // The handler lock is only held long enough to copy the domain.
    let domain = {
        let guard = self.handler.lock().await;
        guard.domain().to_string()
    };
    let qualified = format!("{}.{}", domain, topic);
    let message = CommandMessage::broadcast(&qualified, value);
    match self.outbound_tx.send(message).await {
        Ok(()) => Ok(()),
        Err(rejected) => Err(IpcError::Channel(rejected.to_string())),
    }
}
/// Creates a [`BroadcastSender`] bound to the handler's current domain and sharing
/// this client's outbound queue.
pub async fn broadcast_sender(&self) -> BroadcastSender {
    let domain = {
        let guard = self.handler.lock().await;
        guard.domain().to_string()
    };
    let tx = self.outbound_tx.clone();
    BroadcastSender { tx, domain }
}
pub async fn request(
&self,
subtopic: &str,
data: serde_json::Value,
timeout: Duration,
) -> Result<CommandMessage, IpcError> {
let transport = self.transport.as_ref()
.ok_or_else(|| IpcError::Connection("Not connected".to_string()))?;
let handler = self.handler.lock().await;
let domain = handler.domain().to_string();
drop(handler);
let full_topic = format!("{}.{}", domain, subtopic);
let transaction_id = next_transaction_id();
let msg = CommandMessage::request(&full_topic, data)
.with_transaction_id(transaction_id);
transport.send(&msg).await?;
let deadline = tokio::time::Instant::now() + timeout;
loop {
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
if remaining.is_zero() {
return Err(IpcError::Timeout(format!(
"Timeout waiting for response to {}",
subtopic
)));
}
match tokio::time::timeout(remaining, transport.recv()).await {
Ok(Ok(response)) => {
if response.transaction_id == transaction_id && response.is_response() {
return Ok(response);
}
}
Ok(Err(e)) => return Err(e),
Err(_) => {
return Err(IpcError::Timeout(format!(
"Timeout waiting for response to {}",
subtopic
)));
}
}
}
}
pub async fn request_to(
&self,
topic: &str,
data: serde_json::Value,
timeout: Duration,
) -> Result<CommandMessage, IpcError> {
let transport = self.transport.as_ref()
.ok_or_else(|| IpcError::Connection("Not connected".to_string()))?;
let transaction_id = next_transaction_id();
let msg = CommandMessage::request(topic, data)
.with_transaction_id(transaction_id);
transport.send(&msg).await?;
let deadline = tokio::time::Instant::now() + timeout;
loop {
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
if remaining.is_zero() {
return Err(IpcError::Timeout(format!(
"Timeout waiting for response to {}",
topic
)));
}
match tokio::time::timeout(remaining, transport.recv()).await {
Ok(Ok(response)) => {
if response.transaction_id == transaction_id && response.is_response() {
return Ok(response);
}
}
Ok(Err(e)) => return Err(e),
Err(_) => {
return Err(IpcError::Timeout(format!(
"Timeout waiting for response to {}",
topic
)));
}
}
}
}
/// Queues a broadcast message for `topic` exactly as given (no domain prefix).
///
/// # Errors
/// Returns `IpcError::Channel` when the outbound queue rejects the message.
pub async fn send_to(&self, topic: &str, value: serde_json::Value) -> Result<(), IpcError> {
    let message = CommandMessage::broadcast(topic, value);
    let queued = self.outbound_tx.send(message).await;
    queued.map_err(|rejected| IpcError::Channel(rejected.to_string()))
}
/// Returns a new receiver attached to the client's internal broadcast channel.
///
/// NOTE(review): no code in this file sends on that channel, so as written the
/// receiver never yields a message — confirm where publishing is expected to happen.
pub fn subscribe_broadcasts(&self) -> broadcast::Receiver<CommandMessage> {
self.broadcast_tx.subscribe()
}
/// Clears the run flag.
///
/// The loops in this file only check the flag between awaited operations, so a task
/// that is currently waiting (e.g. for the next incoming message) notices the request
/// only after that wait completes.
pub fn shutdown(&self) {
self.running.store(false, Ordering::SeqCst);
}
/// Returns the current value of the run flag (set by `run`, cleared by `shutdown` or
/// when the handler requests shutdown).
pub fn is_running(&self) -> bool {
self.running.load(Ordering::SeqCst)
}
pub fn get_shm_pointer(&self, name: &str) -> Option<*mut u8> {
if let Ok(shm_guard) = self.shm_context.try_lock() {
if let Some(ref ctx) = *shm_guard {
return unsafe { ctx.get_pointer(name) };
}
}
None
}
}
/// Builder that collects an `IpcClientConfig` and a handler before creating an `IpcClient`.
pub struct IpcClientBuilder<H: ModuleHandler> {
/// Configuration being assembled; starts from `IpcClientConfig::default()`.
config: IpcClientConfig,
/// Handler that will be moved into the client.
handler: H,
}
impl<H: ModuleHandler + 'static> IpcClientBuilder<H> {
pub fn new(handler: H) -> Self {
Self {
config: IpcClientConfig::default(),
handler,
}
}
pub fn server_addr(mut self, addr: &str) -> Self {
self.config.server_addr = addr.to_string();
self
}
pub fn heartbeat_interval(mut self, interval: Duration) -> Self {
self.config.heartbeat_interval = interval;
self
}
pub fn connect_timeout(mut self, timeout: Duration) -> Self {
self.config.connect_timeout = timeout;
self
}
pub fn reconnect_delay(mut self, delay: Duration) -> Self {
self.config.reconnect_delay = delay;
self
}
pub fn max_reconnect_attempts(mut self, attempts: usize) -> Self {
self.config.max_reconnect_attempts = attempts;
self
}
pub fn auto_reconnect(mut self, enabled: bool) -> Self {
self.config.auto_reconnect = enabled;
self
}
pub async fn connect(self) -> Result<IpcClient<H>, IpcError> {
IpcClient::connect_with_config(self.config, self.handler).await
}
pub fn build(self) -> IpcClient<H> {
IpcClient::new(self.config, self.handler)
}
}
// Unit tests that need no server: configuration, builder wiring, transaction ids and
// topic handling of `CommandMessage::request`.
#[cfg(test)]
mod tests {
use super::*;
use crate::ipc::handler::BaseModuleHandler;
// `new` replaces the address and keeps the default heartbeat interval.
#[tokio::test]
async fn test_client_config() {
let config = IpcClientConfig::new("127.0.0.1:9200");
assert_eq!(config.server_addr, "127.0.0.1:9200");
assert_eq!(config.heartbeat_interval, Duration::from_secs(5));
}
// Consecutive transaction ids are expected to increase.
#[tokio::test]
async fn test_transaction_id_generation() {
let id1 = next_transaction_id();
let id2 = next_transaction_id();
assert!(id2 > id1);
}
// Builder setters end up in the built client's config; `build` performs no I/O.
#[test]
fn test_client_builder() {
let handler = BaseModuleHandler::new("TEST");
let client = IpcClientBuilder::new(handler)
.server_addr("127.0.0.1:9300")
.heartbeat_interval(Duration::from_secs(10))
.auto_reconnect(false)
.build();
assert_eq!(client.config.server_addr, "127.0.0.1:9300");
assert_eq!(client.config.heartbeat_interval, Duration::from_secs(10));
assert!(!client.config.auto_reconnect);
}
// NOTE(review): despite its name this test only exercises `CommandMessage::request`
// (the topic is stored verbatim); it never calls `IpcClient::request_to`.
#[test]
fn test_request_to_does_not_prefix_domain() {
let msg = CommandMessage::request("gm.get_layout", serde_json::json!({}));
assert_eq!(msg.topic, "gm.get_layout");
assert!(!msg.topic.starts_with("testmodule."));
let msg2 = CommandMessage::request("system.status", serde_json::json!({"key": "value"}));
assert_eq!(msg2.topic, "system.status");
let msg3 = CommandMessage::request("other.nested.topic", serde_json::json!(null));
assert_eq!(msg3.topic, "other.nested.topic");
}
}