use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{broadcast, mpsc, Mutex, RwLock};
use super::command_message::{CommandMessage, MessageType};
use super::error::IpcError;
use super::message::ModuleRegistration;
use super::transport::SplitTransport;
/// Settings consumed by `IpcServer::bind`.
#[derive(Debug, Clone)]
pub struct IpcServerConfig {
    /// Address the TCP listener binds to, e.g. `"127.0.0.1:9100"`.
    pub bind_addr: String,
    /// New TCP connections are rejected once this many modules are registered.
    pub max_connections: usize,
    /// Idle limit for modules. Not read by `IpcServer` itself; presumably passed
    /// by callers to `handle_module_connection` — confirm.
    pub heartbeat_timeout: Duration,
    /// How long the server waits for a new client's registration message.
    pub registration_timeout: Duration,
}

impl Default for IpcServerConfig {
    /// Loopback port 9100, at most 64 modules, a 30 s heartbeat timeout and a
    /// 10 s registration timeout.
    fn default() -> Self {
        let heartbeat_timeout = Duration::from_secs(30);
        let registration_timeout = Duration::from_secs(10);
        IpcServerConfig {
            bind_addr: String::from("127.0.0.1:9100"),
            max_connections: 64,
            heartbeat_timeout,
            registration_timeout,
        }
    }
}

impl IpcServerConfig {
    /// Default configuration, except that the listener binds to `bind_addr`.
    pub fn new(bind_addr: &str) -> Self {
        Self {
            bind_addr: String::from(bind_addr),
            ..Self::default()
        }
    }
}
/// One registered module's connection.
///
/// Outbound messages go through a bounded queue (capacity 256) whose receiving
/// end is returned by the constructor; inbound messages are read directly from
/// the transport.
#[derive(Debug)]
pub struct ModuleConnection {
    /// Identifier assigned by the server at registration.
    pub id: usize,
    /// Registration payload supplied by the module.
    pub registration: ModuleRegistration,
    /// Transport to the module; clones are handed out by `clone_transport`.
    transport: SplitTransport,
    /// Liveness flag; `send` and `recv` refuse to operate once it is cleared.
    active: Arc<AtomicBool>,
    /// Producer side of the outbound queue.
    outbound_tx: mpsc::Sender<CommandMessage>,
    /// Last time `update_heartbeat` ran (initially the creation time).
    last_heartbeat: Arc<Mutex<std::time::Instant>>,
}

impl ModuleConnection {
    /// Builds an active connection and returns it together with the receiving
    /// end of its outbound queue, which the caller is expected to drain.
    fn new(
        id: usize,
        registration: ModuleRegistration,
        transport: SplitTransport,
    ) -> (Self, mpsc::Receiver<CommandMessage>) {
        let (sender, receiver) = mpsc::channel(256);
        let active = Arc::new(AtomicBool::new(true));
        let last_heartbeat = Arc::new(Mutex::new(std::time::Instant::now()));
        let connection = Self {
            id,
            registration,
            transport,
            active,
            outbound_tx: sender,
            last_heartbeat,
        };
        (connection, receiver)
    }

    /// Domain this module registered under (its registration name).
    pub fn domain(&self) -> &str {
        &self.registration.name
    }

    /// Whether the connection is still considered live.
    pub fn is_active(&self) -> bool {
        self.active.load(Ordering::SeqCst)
    }

    /// Marks the connection as no longer live. Calling it again has no further effect.
    pub fn deactivate(&self) {
        self.active.store(false, Ordering::SeqCst);
    }

    /// Fails with `IpcError::Connection` once the connection has been deactivated.
    fn ensure_active(&self) -> Result<(), IpcError> {
        if self.is_active() {
            Ok(())
        } else {
            Err(IpcError::Connection("Connection is not active".to_string()))
        }
    }

    /// Queues `msg` on the outbound queue, waiting for capacity if it is full.
    ///
    /// # Errors
    ///
    /// `IpcError::Connection` if the connection is inactive; `IpcError::Channel`
    /// if the receiving side of the queue has gone away.
    pub async fn send(&self, msg: CommandMessage) -> Result<(), IpcError> {
        self.ensure_active()?;
        let outcome = self.outbound_tx.send(msg).await;
        outcome.map_err(|closed| IpcError::Channel(closed.to_string()))
    }

    /// Reads the next message straight from the transport.
    ///
    /// # Errors
    ///
    /// `IpcError::Connection` if the connection is inactive, otherwise whatever
    /// the transport reports.
    pub async fn recv(&self) -> Result<CommandMessage, IpcError> {
        self.ensure_active()?;
        self.transport.recv().await
    }

    /// Records "now" as the most recent sign of life from the module.
    pub async fn update_heartbeat(&self) {
        let mut stamp = self.last_heartbeat.lock().await;
        *stamp = std::time::Instant::now();
    }

    /// `true` when strictly more than `timeout` has passed since the last heartbeat.
    pub async fn is_timed_out(&self, timeout: Duration) -> bool {
        let idle_for = self.last_heartbeat.lock().await.elapsed();
        idle_for > timeout
    }

    /// Returns a clone of the underlying transport.
    pub fn clone_transport(&self) -> SplitTransport {
        self.transport.clone()
    }

    /// Returns another producer handle for the outbound queue.
    pub fn outbound_sender(&self) -> mpsc::Sender<CommandMessage> {
        self.outbound_tx.clone()
    }
}
/// Process-wide source of connection ids. It starts at 1, so the first id handed out is 1.
static CONNECTION_ID_COUNTER: AtomicUsize = AtomicUsize::new(1);

/// Returns a fresh connection id. Each call yields the next counter value.
fn next_connection_id() -> usize {
    // `fetch_add` returns the value *before* the increment, which is the id we hand out.
    let previous = CONNECTION_ID_COUNTER.fetch_add(1, Ordering::SeqCst);
    previous
}
/// TCP server that accepts module connections, runs a registration handshake,
/// and tracks registered modules by connection id and by domain.
///
/// Methods that hold both tables at once lock `connections` before `domain_map`.
pub struct IpcServer {
    /// Configuration passed to `bind`.
    config: IpcServerConfig,
    /// Listening socket created by `bind`.
    listener: TcpListener,
    /// Registered connections keyed by connection id.
    connections: Arc<RwLock<HashMap<usize, Arc<ModuleConnection>>>>,
    /// Domain (registration name) -> connection id.
    domain_map: Arc<RwLock<HashMap<String, usize>>>,
    /// `true` from `bind` until `shutdown`; reported by `is_running`.
    running: Arc<AtomicBool>,
    /// Sender for lifecycle events; receivers come from `subscribe_events`.
    event_tx: broadcast::Sender<ServerEvent>,
}
/// Connection lifecycle notifications published by `IpcServer` to receivers
/// obtained from `IpcServer::subscribe_events`.
#[derive(Debug, Clone)]
pub enum ServerEvent {
    /// A module finished registration and was added to the server's tables.
    ModuleConnected {
        /// Id assigned to the new connection.
        connection_id: usize,
        /// Domain (registration name) the module registered under.
        domain: String,
    },
    /// A module was removed from the server's tables.
    ModuleDisconnected {
        /// Id of the removed connection.
        connection_id: usize,
        /// Domain the removed module had registered under.
        domain: String,
    },
    /// A message received from a module.
    ///
    /// NOTE(review): nothing in this file sends this variant — confirm which
    /// component is expected to publish it.
    MessageReceived {
        connection_id: usize,
        domain: String,
        message: CommandMessage,
    },
}
impl IpcServer {
pub async fn bind(config: IpcServerConfig) -> Result<Self, IpcError> {
let listener = TcpListener::bind(&config.bind_addr)
.await
.map_err(|e| IpcError::Connection(format!("Failed to bind: {}", e)))?;
log::info!("IPC server listening on {}", config.bind_addr);
let (event_tx, _) = broadcast::channel(256);
Ok(Self {
config,
listener,
connections: Arc::new(RwLock::new(HashMap::new())),
domain_map: Arc::new(RwLock::new(HashMap::new())),
running: Arc::new(AtomicBool::new(true)),
event_tx,
})
}
pub async fn accept(&self) -> Result<Arc<ModuleConnection>, IpcError> {
loop {
let (stream, addr) = self.listener.accept().await?;
log::info!("New connection from {}", addr);
{
let connections = self.connections.read().await;
if connections.len() >= self.config.max_connections {
log::warn!("Connection limit reached, rejecting {}", addr);
continue;
}
}
stream.set_nodelay(true).ok();
match self.perform_registration(stream).await {
Ok(connection) => {
let connection = Arc::new(connection);
{
let mut connections = self.connections.write().await;
let mut domain_map = self.domain_map.write().await;
connections.insert(connection.id, Arc::clone(&connection));
domain_map.insert(connection.domain().to_string(), connection.id);
}
let _ = self.event_tx.send(ServerEvent::ModuleConnected {
connection_id: connection.id,
domain: connection.domain().to_string(),
});
log::info!(
"Module '{}' registered (connection {})",
connection.domain(),
connection.id
);
return Ok(connection);
}
Err(e) => {
log::warn!("Registration failed for {}: {}", addr, e);
continue;
}
}
}
}
async fn perform_registration(&self, stream: TcpStream) -> Result<ModuleConnection, IpcError> {
let transport = SplitTransport::new(stream);
let registration_msg = tokio::time::timeout(
self.config.registration_timeout,
transport.recv(),
)
.await
.map_err(|_| IpcError::Timeout("Registration timeout".to_string()))??;
if registration_msg.message_type != MessageType::Control
|| !registration_msg.topic.ends_with("register")
{
return Err(IpcError::InvalidMessage(
"Expected registration message".to_string(),
));
}
let registration: ModuleRegistration =
serde_json::from_value(registration_msg.data.clone())
.map_err(|e| IpcError::InvalidMessage(format!("Invalid registration: {}", e)))?;
{
let domain_map = self.domain_map.read().await;
if domain_map.contains_key(®istration.name) {
return Err(IpcError::Connection(format!(
"Module '{}' already registered",
registration.name
)));
}
}
let ack = registration_msg.into_response(serde_json::json!({"status": "registered"}))
.with_topic("control.registerack");
transport.send(&ack).await?;
let conn_id = next_connection_id();
let (connection, outbound_rx) = ModuleConnection::new(conn_id, registration, transport);
let transport_clone = connection.clone_transport();
let active = Arc::clone(&connection.active);
tokio::spawn(async move {
Self::run_outbound_sender(transport_clone, outbound_rx, active).await;
});
Ok(connection)
}
async fn run_outbound_sender(
transport: SplitTransport,
mut rx: mpsc::Receiver<CommandMessage>,
active: Arc<AtomicBool>,
) {
while active.load(Ordering::SeqCst) {
match rx.recv().await {
Some(msg) => {
if let Err(e) = transport.send(&msg).await {
log::error!("Failed to send to module: {}", e);
active.store(false, Ordering::SeqCst);
break;
}
}
None => break,
}
}
}
pub async fn get_connection(&self, id: usize) -> Option<Arc<ModuleConnection>> {
let connections = self.connections.read().await;
connections.get(&id).cloned()
}
pub async fn get_connection_by_domain(&self, domain: &str) -> Option<Arc<ModuleConnection>> {
let domain_map = self.domain_map.read().await;
if let Some(&id) = domain_map.get(domain) {
drop(domain_map);
self.get_connection(id).await
} else {
None
}
}
pub async fn send_to_domain(&self, domain: &str, msg: CommandMessage) -> Result<(), IpcError> {
let connection = self
.get_connection_by_domain(domain)
.await
.ok_or_else(|| IpcError::ModuleNotFound(domain.to_string()))?;
connection.send(msg).await
}
pub async fn broadcast(&self, msg: CommandMessage) -> Vec<Result<(), IpcError>> {
let connections = self.connections.read().await;
let mut results = Vec::new();
for connection in connections.values() {
results.push(connection.send(msg.clone()).await);
}
results
}
pub async fn remove_connection(&self, id: usize) -> Option<Arc<ModuleConnection>> {
let mut connections = self.connections.write().await;
let mut domain_map = self.domain_map.write().await;
if let Some(connection) = connections.remove(&id) {
domain_map.remove(connection.domain());
connection.deactivate();
let _ = self.event_tx.send(ServerEvent::ModuleDisconnected {
connection_id: id,
domain: connection.domain().to_string(),
});
log::info!(
"Module '{}' disconnected (connection {})",
connection.domain(),
id
);
Some(connection)
} else {
None
}
}
pub async fn connected_domains(&self) -> Vec<String> {
let domain_map = self.domain_map.read().await;
domain_map.keys().cloned().collect()
}
pub async fn connection_count(&self) -> usize {
let connections = self.connections.read().await;
connections.len()
}
pub fn subscribe_events(&self) -> broadcast::Receiver<ServerEvent> {
self.event_tx.subscribe()
}
pub fn is_running(&self) -> bool {
self.running.load(Ordering::SeqCst)
}
pub async fn shutdown(&self) {
self.running.store(false, Ordering::SeqCst);
let mut connections = self.connections.write().await;
for (_, connection) in connections.drain() {
connection.deactivate();
}
let mut domain_map = self.domain_map.write().await;
domain_map.clear();
}
}
/// Drives the receive loop for one registered module. It stops when the module
/// disconnects, when receiving errors, or when the module stays quiet for longer
/// than `heartbeat_timeout`.
///
/// Every received message refreshes the heartbeat. Heartbeat messages are
/// otherwise ignored. Every other message is passed to `on_message`, and any
/// `Some` reply is queued back through `ModuleConnection::send`.
///
/// When this returns, the connection is inactive. Removing it from the
/// `IpcServer` is left to the caller.
///
/// NOTE(review): `transport.recv()` is cancelled every time the 1 s poll
/// elapses. Confirm that `SplitTransport::recv` is cancellation-safe, i.e. that
/// it cannot drop a partially read frame.
pub async fn handle_module_connection<F, Fut>(
    connection: Arc<ModuleConnection>,
    heartbeat_timeout: Duration,
    mut on_message: F,
) where
    F: FnMut(CommandMessage) -> Fut + Send,
    Fut: std::future::Future<Output = Option<CommandMessage>> + Send,
{
    // How often the loop wakes up to re-check liveness and the heartbeat deadline.
    const POLL_INTERVAL: Duration = Duration::from_secs(1);

    let transport = connection.clone_transport();
    loop {
        if !connection.is_active() {
            break;
        }
        if connection.is_timed_out(heartbeat_timeout).await {
            log::warn!("Module '{}' heartbeat timeout", connection.domain());
            connection.deactivate();
            break;
        }

        let incoming = match tokio::time::timeout(POLL_INTERVAL, transport.recv()).await {
            // Nothing arrived within the poll interval; go round again.
            Err(_elapsed) => continue,
            Ok(Err(e)) => {
                log::error!(
                    "Error receiving from module '{}': {}",
                    connection.domain(),
                    e
                );
                connection.deactivate();
                break;
            }
            Ok(Ok(msg)) => msg,
        };

        connection.update_heartbeat().await;
        if incoming.is_heartbeat() {
            continue;
        }

        let reply = on_message(incoming).await;
        if let Some(response) = reply {
            if let Err(e) = connection.send(response).await {
                log::error!("Failed to send response: {}", e);
                connection.deactivate();
                break;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // These tests never await anything, so they use the plain test harness
    // instead of spinning up a Tokio runtime per test.

    #[test]
    fn test_server_config() {
        let config = IpcServerConfig::new("127.0.0.1:9200");
        assert_eq!(config.bind_addr, "127.0.0.1:9200");
        assert_eq!(config.max_connections, 64);
    }

    #[test]
    fn test_new_keeps_default_limits() {
        let config = IpcServerConfig::new("127.0.0.1:9200");
        let defaults = IpcServerConfig::default();
        assert_eq!(config.max_connections, defaults.max_connections);
        assert_eq!(config.heartbeat_timeout, defaults.heartbeat_timeout);
        assert_eq!(config.registration_timeout, defaults.registration_timeout);
    }

    #[test]
    fn test_default_config() {
        let config = IpcServerConfig::default();
        assert_eq!(config.bind_addr, "127.0.0.1:9100");
        assert_eq!(config.max_connections, 64);
        assert_eq!(config.heartbeat_timeout, Duration::from_secs(30));
        assert_eq!(config.registration_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_connection_id_generation() {
        let id1 = next_connection_id();
        let id2 = next_connection_id();
        assert!(id2 > id1);
        // The counter starts at 1, so 0 is never handed out.
        assert_ne!(id1, 0);
    }
}