use crate::errors::{KodeBridgeError, Result};
use bytes::Bytes;
#[cfg(unix)]
use interprocess::os::unix::local_socket::ListenerOptionsExt as _;
#[cfg(windows)]
use interprocess::os::windows::local_socket::ListenerOptionsExt as _;
#[cfg(windows)]
use interprocess::os::windows::security_descriptor::SecurityDescriptor;
use interprocess::{
local_socket::{
tokio::prelude::LocalSocketStream, traits::tokio::Listener as _, GenericFilePath, ListenerOptions, Name,
ToFsName as _,
},
TryClone as _,
};
use parking_lot::RwLock;
use serde_json::Value;
use std::{
collections::HashMap,
fmt,
future::Future,
path::Path,
pin::Pin,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::{Duration, Instant},
};
use tokio::{
io::AsyncWriteExt as _,
sync::{broadcast, Semaphore},
time::timeout,
};
use tokio_stream::{Stream, StreamExt as _};
use tracing::{debug, error, info, warn};
#[cfg(windows)]
use widestring::U16CString;
/// Tuning parameters for [`IpcStreamServer`].
#[derive(Debug, Clone)]
pub struct StreamServerConfig {
    /// Maximum number of clients served concurrently (permits in the connection semaphore).
    pub max_connections: usize,
    /// NOTE(review): not read by the server code in this file — confirm it is still needed.
    pub buffer_size: usize,
    /// Upper bound on writing one encoded message to a client; on expiry the client is dropped.
    pub write_timeout: Duration,
    /// Encoded messages longer than this many bytes are skipped for each client.
    pub max_message_size: usize,
    /// NOTE(review): not read by the server code in this file — confirm it is still needed.
    pub enable_logging: bool,
    /// How long shutdown waits for connected clients to finish.
    pub shutdown_timeout: Duration,
    /// Capacity of the broadcast channel that fans messages out to clients.
    pub broadcast_capacity: usize,
    /// Idle time after which a `PING` keepalive is written to a client.
    pub keepalive_interval: Duration,
}

impl Default for StreamServerConfig {
    fn default() -> Self {
        let five_seconds = Duration::from_secs(5);
        Self {
            keepalive_interval: Duration::from_secs(30),
            broadcast_capacity: 1000,
            shutdown_timeout: five_seconds,
            enable_logging: true,
            // 1 MiB
            max_message_size: 1 << 20,
            write_timeout: five_seconds,
            // 64 KiB
            buffer_size: 64 * 1024,
            max_connections: 64,
        }
    }
}
/// A single message fanned out to every connected stream client.
#[derive(Debug, Clone)]
pub enum StreamMessage {
    /// JSON value, written as one line terminated by `\n`.
    Json(Value),
    /// Text, written with a trailing `\n` appended if it is missing.
    Text(String),
    /// Raw bytes, written as-is with no framing added.
    Binary(Bytes),
    /// Keepalive probe, written as `PING\n`.
    Ping,
    /// Encodes as `CLOSE\n`; the server also uses it in-band to tell client
    /// handlers to disconnect.
    Close,
}

impl StreamMessage {
    /// Encodes the message into the exact bytes written to a client socket.
    ///
    /// A JSON value that fails to serialize is replaced by `{}` (best effort;
    /// the serialization error is not reported).
    pub fn to_bytes(&self) -> Bytes {
        match self {
            Self::Json(value) => match serde_json::to_vec(value) {
                // Append the newline to the serialized buffer itself rather
                // than copying everything into a second allocation.
                Ok(mut encoded) => {
                    encoded.push(b'\n');
                    Bytes::from(encoded)
                }
                Err(_) => Bytes::from_static(b"{}\n"),
            },
            Self::Text(text) if text.ends_with('\n') => Bytes::from(text.clone()),
            Self::Text(text) => {
                let mut line = String::with_capacity(text.len() + 1);
                line.push_str(text);
                line.push('\n');
                Bytes::from(line)
            }
            Self::Binary(bytes) => bytes.clone(),
            Self::Ping => Bytes::from_static(b"PING\n"),
            Self::Close => Bytes::from_static(b"CLOSE\n"),
        }
    }

    /// Builds a [`StreamMessage::Json`] from any serializable value
    /// (unsized values such as `str` or slices are accepted too).
    ///
    /// # Errors
    /// Returns a `json_serialize` error if `value` cannot be converted into a
    /// `serde_json::Value`.
    pub fn json<T: serde::Serialize + ?Sized>(value: &T) -> Result<Self> {
        serde_json::to_value(value)
            .map(Self::Json)
            .map_err(|e| KodeBridgeError::json_serialize(format!("Failed to serialize: {}", e)))
    }

    /// Builds a [`StreamMessage::Text`] from anything convertible into a `String`.
    pub fn text<T: Into<String>>(text: T) -> Self {
        Self::Text(text.into())
    }

    /// Builds a [`StreamMessage::Binary`] from anything convertible into `Bytes`.
    pub fn binary<T: Into<Bytes>>(data: T) -> Self {
        Self::Binary(data.into())
    }
}
/// Bookkeeping for one connected stream client.
#[derive(Debug, Clone)]
pub struct StreamClient {
    /// Server-assigned identifier.
    pub client_id: u64,
    /// When the connection was accepted.
    pub connected_at: Instant,
    /// Time of the last successful message delivery (initially the connect time).
    pub last_activity: Instant,
    /// Human-readable label for the client.
    pub endpoint: String,
    /// Number of messages successfully written to the client.
    pub messages_sent: u64,
    /// Number of failed writes plus messages skipped due to lag.
    pub error_count: u64,
}

impl StreamClient {
    /// Creates a record with zeroed counters, timestamped "now".
    fn new(client_id: u64, endpoint: String) -> Self {
        let connected_at = Instant::now();
        Self {
            client_id,
            endpoint,
            connected_at,
            last_activity: connected_at,
            error_count: 0,
            messages_sent: 0,
        }
    }
}
/// Aggregate counters for a running [`IpcStreamServer`].
#[derive(Debug, Clone)]
pub struct StreamServerStats {
    /// Connections accepted since start.
    pub total_connections: u64,
    /// Connections currently being served.
    pub active_connections: u64,
    /// Messages produced by the data source since start.
    pub total_messages: u64,
    /// Errors recorded by the source and client handlers.
    pub total_errors: u64,
    /// Rate of the most recent batch: its size divided by the time since the previous batch.
    pub messages_per_second: f64,
    /// When the stats were created.
    pub started_at: Instant,
    /// When `update_message_rate` last ran (initially the creation time).
    pub last_update: Instant,
}

impl StreamServerStats {
    /// Creates zeroed stats timestamped "now".
    fn new() -> Self {
        let started_at = Instant::now();
        Self {
            started_at,
            last_update: started_at,
            messages_per_second: 0.0,
            total_errors: 0,
            total_messages: 0,
            active_connections: 0,
            total_connections: 0,
        }
    }

    /// Records a batch of `message_count` messages and recomputes the rate
    /// over the interval since the previous batch. The rate is left unchanged
    /// when no measurable time has passed.
    fn update_message_rate(&mut self, message_count: u64) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_update);
        if !elapsed.is_zero() {
            self.messages_per_second = message_count as f64 / elapsed.as_secs_f64();
        }
        self.total_messages += message_count;
        self.last_update = now;
    }
}

impl fmt::Display for StreamServerStats {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            total_connections,
            active_connections,
            total_messages,
            total_errors,
            messages_per_second,
            started_at,
            ..
        } = self;
        write!(
            f,
            "Stream Server Stats: {} connections, {} active, {} messages ({:.1}/s), {} errors, uptime: {:?}",
            total_connections,
            active_connections,
            total_messages,
            messages_per_second,
            total_errors,
            started_at.elapsed()
        )
    }
}
/// A producer of messages that the stream server broadcasts to all clients.
pub trait StreamSource: Send + Sync {
    /// Prepares the source; the server awaits this once after binding its
    /// listener and before accepting clients. An error aborts serving.
    fn initialize(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>>;

    /// Produces the next batch of messages. An empty batch means there is
    /// nothing to send; an error is logged and counted by the server.
    fn next_messages(&mut self) -> Pin<Box<dyn Future<Output = Result<Vec<StreamMessage>>> + Send + '_>>;

    /// Returns `false` once the source is exhausted, after which the server
    /// stops polling it.
    fn has_more(&self) -> bool;

    /// Releases resources after the server stops polling the source.
    fn cleanup(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>>;
}
/// A [`StreamSource`] that calls a generator closure at a fixed interval and
/// emits its result as a JSON message.
pub struct JsonDataSource {
    /// Minimum spacing between generated messages.
    interval: Duration,
    /// When the generator last ran (initialised to construction time).
    last_generated: Instant,
    /// Produces the JSON payload for each tick.
    generator: Box<dyn Fn() -> Result<Value> + Send + Sync>,
}

impl JsonDataSource {
    /// Creates a source that invokes `generator` every `interval`; the first
    /// value becomes due one `interval` after construction.
    pub fn new<F>(generator: F, interval: Duration) -> Self
    where
        F: Fn() -> Result<Value> + Send + Sync + 'static,
    {
        let last_generated = Instant::now();
        let generator: Box<dyn Fn() -> Result<Value> + Send + Sync> = Box::new(generator);
        Self {
            interval,
            last_generated,
            generator,
        }
    }
}
impl StreamSource for JsonDataSource {
fn next_messages(&mut self) -> Pin<Box<dyn Future<Output = Result<Vec<StreamMessage>>> + Send + '_>> {
Box::pin(async move {
let now = Instant::now();
if now.duration_since(self.last_generated) >= self.interval {
self.last_generated = now;
match (self.generator)() {
Ok(value) => Ok(vec![StreamMessage::Json(value)]),
Err(e) => Err(e),
}
} else {
Ok(vec![])
}
})
}
fn has_more(&self) -> bool {
true }
fn initialize(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
Box::pin(async move { Ok(()) })
}
fn cleanup(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
Box::pin(async move { Ok(()) })
}
}
/// A [`StreamSource`] that forwards items from an async [`Stream`], one
/// message per batch, until the stream ends.
pub struct IteratorSource<S> {
    stream: S,
    /// Set once `stream` has yielded `None`. A completed `Stream` must not be
    /// polled again (doing so may panic or misbehave), so this guards every poll.
    finished: bool,
}

impl<S> IteratorSource<S>
where
    S: Stream<Item = StreamMessage> + Send + Unpin,
{
    /// Wraps `stream`; it is polled lazily by `next_messages`.
    pub const fn new(stream: S) -> Self {
        Self {
            stream,
            finished: false,
        }
    }
}

impl<S> StreamSource for IteratorSource<S>
where
    S: Stream<Item = StreamMessage> + Send + Sync + Unpin,
{
    /// Awaits the next stream item. Returns an empty batch once the stream has
    /// ended, without polling it again.
    fn next_messages(&mut self) -> Pin<Box<dyn Future<Output = Result<Vec<StreamMessage>>> + Send + '_>> {
        Box::pin(async move {
            if self.finished {
                return Ok(vec![]);
            }
            match self.stream.next().await {
                Some(message) => Ok(vec![message]),
                None => {
                    self.finished = true;
                    Ok(vec![])
                }
            }
        })
    }

    /// `false` once the underlying stream has ended, so the server stops polling.
    fn has_more(&self) -> bool {
        !self.finished
    }

    /// Nothing to prepare.
    fn initialize(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
        Box::pin(async move { Ok(()) })
    }

    /// Nothing to release.
    fn cleanup(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
        Box::pin(async move { Ok(()) })
    }
}
/// Local-socket server that broadcasts messages from a [`StreamSource`] to
/// every connected client.
pub struct IpcStreamServer {
    /// Local-socket name the listener binds to.
    name: Name<'static>,
    /// Options used to create the listener (e.g. Unix mode, Windows security descriptor).
    listener_options: ListenerOptions<'static>,
    /// Server configuration.
    config: StreamServerConfig,
    /// Limits concurrently served clients to `config.max_connections`.
    connection_semaphore: Arc<Semaphore>,
    /// Connected clients keyed by client id.
    clients: Arc<RwLock<HashMap<u64, StreamClient>>>,
    /// Next client id to hand out.
    client_id_counter: Arc<AtomicU64>,
    /// Statistics shared with the source and client tasks.
    stats: Arc<RwLock<StreamServerStats>>,
    /// Sender feeding all clients; `None` until serving starts.
    broadcast_tx: Option<broadcast::Sender<StreamMessage>>,
    /// Signals the accept loop to stop; `None` until serving starts.
    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
}
impl IpcStreamServer {
pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
let name = path
.as_ref()
.to_fs_name::<GenericFilePath>()
.map_err(|e| KodeBridgeError::configuration(format!("Invalid server path: {}", e)))?
.into_owned();
let config = StreamServerConfig::default();
let connection_semaphore = Arc::new(Semaphore::new(config.max_connections));
let listener_options = ListenerOptions::new();
Ok(Self {
name,
config,
listener_options,
stats: Arc::new(RwLock::new(StreamServerStats::new())),
connection_semaphore,
clients: Arc::new(RwLock::new(HashMap::new())),
client_id_counter: Arc::new(AtomicU64::new(1)),
broadcast_tx: None,
shutdown_tx: None,
})
}
pub fn with_config<P: AsRef<Path>>(path: P, config: StreamServerConfig) -> Result<Self> {
let name = path
.as_ref()
.to_fs_name::<GenericFilePath>()
.map_err(|e| KodeBridgeError::configuration(format!("Invalid server path: {}", e)))?
.into_owned();
let connection_semaphore = Arc::new(Semaphore::new(config.max_connections));
let listener_options = ListenerOptions::new();
Ok(Self {
name,
config,
listener_options,
stats: Arc::new(RwLock::new(StreamServerStats::new())),
connection_semaphore,
clients: Arc::new(RwLock::new(HashMap::new())),
client_id_counter: Arc::new(AtomicU64::new(1)),
broadcast_tx: None,
shutdown_tx: None,
})
}
pub fn with_listener_options(mut self, options: ListenerOptions<'static>) -> Self {
self.listener_options = options;
self
}
#[cfg(unix)]
pub fn with_listener_mode(mut self, mode: libc::mode_t) -> Self {
self.listener_options = self.listener_options.mode(mode);
self
}
#[cfg(windows)]
pub fn with_listener_security_descriptor(mut self, sddl: &str) -> Self {
let sddl = U16CString::from_str(sddl).expect("Invalid SDDL string");
let sd = SecurityDescriptor::deserialize(&sddl).expect("Failed to parse SDDL");
self.listener_options = self.listener_options.security_descriptor(sd);
self
}
pub fn stats(&self) -> StreamServerStats {
self.stats.read().clone()
}
pub fn clients(&self) -> Vec<StreamClient> {
self.clients.read().values().cloned().collect()
}
pub fn broadcast(&self, message: StreamMessage) -> Result<usize> {
if let Some(ref tx) = self.broadcast_tx {
match tx.send(message) {
Ok(_) => Ok(tx.receiver_count()),
Err(_) => Err(KodeBridgeError::connection("No active receivers")),
}
} else {
Err(KodeBridgeError::connection("Server not started"))
}
}
pub async fn serve_with_source<S>(&mut self, mut source: S) -> Result<()>
where
S: StreamSource + 'static,
{
let (broadcast_tx, _) = broadcast::channel(self.config.broadcast_capacity);
self.broadcast_tx = Some(broadcast_tx.clone());
let listener_options = self.listener_options.try_clone()?;
let listener = listener_options
.name(self.name.clone())
.create_tokio()
.map_err(|e| KodeBridgeError::connection(format!("Failed to bind server: {}", e)))?;
info!("🌊 Stream IPC Server listening on {:?}", self.name);
source.initialize().await?;
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
self.shutdown_tx = Some(shutdown_tx);
let source_stats = Arc::clone(&self.stats);
let source_broadcast_tx = broadcast_tx.clone();
let (source_shutdown_tx, source_shutdown_rx) = tokio::sync::oneshot::channel();
let source_shutdown = source_shutdown_rx;
let source_task = tokio::spawn(async move {
let mut source = source;
let mut shutdown_rx = source_shutdown;
loop {
tokio::select! {
result = source.next_messages() => {
match result {
Ok(messages) => {
let message_count = messages.len() as u64;
for message in messages {
if source_broadcast_tx.send(message).is_err() {
debug!("No receivers for broadcast message");
}
}
if message_count > 0 {
source_stats.write().update_message_rate(message_count);
}
}
Err(e) => {
error!("Data source error: {}", e);
source_stats.write().total_errors += 1;
}
}
if !source.has_more() {
info!("Data source exhausted");
break;
}
}
_ = &mut shutdown_rx => {
info!("Data source shutdown requested");
break;
}
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
if let Err(e) = source.cleanup().await {
error!("Error cleaning up data source: {}", e);
}
});
loop {
let permit = tokio::select! {
permit_result = Arc::clone(&self.connection_semaphore).acquire_owned() => {
match permit_result {
Ok(permit) => permit,
Err(_) => {
warn!("Connection limiter closed, stopping stream server");
break;
}
}
}
_ = &mut shutdown_rx => {
info!("Stream server shutdown requested");
break;
}
};
tokio::select! {
accept_result = listener.accept() => {
match accept_result {
Ok(stream) => {
let client_id = self.client_id_counter.fetch_add(1, Ordering::SeqCst);
{
let mut stats = self.stats.write();
stats.total_connections += 1;
stats.active_connections += 1;
}
let client = StreamClient::new(client_id, format!("client_{}", client_id));
self.clients.write().insert(client_id, client);
let config = self.config.clone();
let stats = Arc::clone(&self.stats);
let clients = Arc::clone(&self.clients);
let broadcast_rx = broadcast_tx.subscribe();
tokio::spawn(async move {
if let Err(e) = Self::handle_stream_client(
stream,
client_id,
broadcast_rx,
config,
Arc::clone(&stats),
Arc::clone(&clients),
).await {
error!("Stream client {} error: {}", client_id, e);
stats.write().total_errors += 1;
}
clients.write().remove(&client_id);
{
let mut stats = stats.write();
stats.active_connections = stats.active_connections.saturating_sub(1);
}
drop(permit); });
}
Err(e) => {
drop(permit);
error!("Failed to accept connection: {}", e);
}
}
}
_ = &mut shutdown_rx => {
drop(permit);
info!("Stream server shutdown requested");
break;
}
}
}
let _ = source_shutdown_tx.send(());
source_task.abort();
let _ = broadcast_tx.send(StreamMessage::Close);
let start = Instant::now();
while self.stats.read().active_connections > 0 && start.elapsed() < self.config.shutdown_timeout {
tokio::time::sleep(Duration::from_millis(100)).await;
}
let remaining = self.stats.read().active_connections;
if remaining > 0 {
warn!("Shutting down with {} active connections", remaining);
}
info!("Stream IPC Server stopped");
Ok(())
}
pub async fn serve(&mut self) -> Result<()> {
let dummy_source = JsonDataSource::new(
|| Ok(serde_json::json!({})),
Duration::from_secs(3600), );
self.serve_with_source(dummy_source).await
}
pub fn shutdown(&mut self) {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(());
}
}
#[allow(clippy::cognitive_complexity)]
async fn handle_stream_client(
mut stream: LocalSocketStream,
client_id: u64,
mut broadcast_rx: broadcast::Receiver<StreamMessage>,
config: StreamServerConfig,
stats: Arc<RwLock<StreamServerStats>>,
clients: Arc<RwLock<HashMap<u64, StreamClient>>>,
) -> Result<()> {
debug!("Handling stream client {}", client_id);
let mut last_keepalive = Instant::now();
loop {
tokio::select! {
msg_result = broadcast_rx.recv() => {
match msg_result {
Ok(message) => {
match message {
StreamMessage::Close => {
debug!("Received close message for client {}", client_id);
break;
}
_ => {
let data = message.to_bytes();
if data.len() > config.max_message_size {
warn!("Message too large for client {}, skipping", client_id);
continue;
}
match timeout(config.write_timeout, stream.write_all(&data)).await {
Ok(Ok(())) => {
if stream.flush().await.is_ok() {
if let Some(client) = clients.write().get_mut(&client_id) {
client.messages_sent += 1;
client.last_activity = Instant::now();
}
}
}
Ok(Err(e)) => {
error!("Failed to write to client {}: {}", client_id, e);
if let Some(client) = clients.write().get_mut(&client_id) {
client.error_count += 1;
}
stats.write().total_errors += 1;
break;
}
Err(_) => {
warn!("Write timeout for client {}", client_id);
if let Some(client) = clients.write().get_mut(&client_id) {
client.error_count += 1;
}
stats.write().total_errors += 1;
break;
}
}
}
}
}
Err(broadcast::error::RecvError::Lagged(skipped)) => {
warn!("Client {} lagged behind, skipped {} messages", client_id, skipped);
if let Some(client) = clients.write().get_mut(&client_id) {
client.error_count += skipped;
}
}
Err(broadcast::error::RecvError::Closed) => {
debug!("Broadcast channel closed for client {}", client_id);
break;
}
}
}
_ = tokio::time::sleep(config.keepalive_interval) => {
let now = Instant::now();
if now.duration_since(last_keepalive) >= config.keepalive_interval {
last_keepalive = now;
let ping_data = StreamMessage::Ping.to_bytes();
if let Err(e) = timeout(config.write_timeout, stream.write_all(&ping_data)).await {
warn!("Failed to send keepalive to client {}: {:?}", client_id, e);
break;
}
if let Err(e) = stream.flush().await {
warn!("Failed to flush keepalive to client {}: {}", client_id, e);
break;
}
}
}
}
}
debug!("Stream client {} finished", client_id);
Ok(())
}
}
/// Shows the socket name, configuration, and statistics; other internals are omitted.
impl fmt::Debug for IpcStreamServer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("IpcStreamServer");
        out.field("name", &self.name);
        out.field("config", &self.config);
        out.field("stats", &self.stats);
        out.finish()
    }
}