use anyhow::{anyhow, Result};
use bytes::Bytes;
use quinn::{Endpoint, VarInt};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::vec;
use tokio::sync::RwLock;
use tracing::{debug, error, info, trace};
use super::config::{Config, ServerConfig, TlsConfig};
use super::connection::ClientConnection;
use crate::common::auth::{AuthConfig, AuthType};
use crate::common::message::{Message, MessageType};
use crate::common::stream::{RecvStream, SendStream};
use crate::common::tls::create_server_config;
#[derive(Clone)]
pub struct RsubServer {
endpoint: Endpoint,
connections: Arc<RwLock<HashMap<String, Arc<ClientConnection>>>>,
send_streams: Arc<RwLock<HashMap<String, Arc<SendStream>>>>,
recv_streams: Arc<RwLock<HashMap<String, Arc<RecvStream>>>>,
topics: Arc<RwLock<HashMap<String, Vec<Arc<SendStream>>>>>,
config: Arc<Config>,
}
impl RsubServer {
pub async fn start(config: Arc<Config>) -> Result<()> {
let config = config.clone();
let server_config = &config.server;
let tls_config = &config.tls;
let addr = SocketAddr::new(server_config.host.parse()?, server_config.port);
let (endpoint, _server_cert) = Self::create_endpoint(&server_config, &tls_config)?;
info!("Server listening on {}", addr);
trace!("Server configuration: {:?}", server_config);
let server = Self::new(endpoint, config.clone());
server.accept_connections().await?;
Ok(())
}
fn new(endpoint: Endpoint, config: Arc<Config>) -> Self {
trace!("Creating new RsubServer instance");
RsubServer {
endpoint,
connections: Arc::new(RwLock::new(HashMap::new())),
send_streams: Arc::new(RwLock::new(HashMap::new())),
recv_streams: Arc::new(RwLock::new(HashMap::new())),
topics: Arc::new(RwLock::new(HashMap::new())),
config,
}
}
fn create_endpoint(
server_config: &ServerConfig,
tls_config: &TlsConfig,
) -> Result<(Endpoint, Vec<u8>)> {
trace!("Creating server endpoint with TLS configuration");
let (quinn_server_config, server_cert) = create_server_config(tls_config)?;
let addr = SocketAddr::new(server_config.host.parse()?, server_config.port);
let endpoint = Endpoint::server(quinn_server_config, addr)?;
Ok((endpoint, server_cert))
}
async fn accept_connections(&self) -> Result<()> {
while let Some(conn) = self.endpoint.accept().await {
let connection = conn.await?;
let client_id = connection.remote_address().to_string();
trace!("New connection attempt from {}", client_id);
let client_connection = ClientConnection::new(connection).await?;
self.connections
.write()
.await
.insert(client_id.clone(), Arc::new(client_connection));
let server = self.clone();
let client_id = client_id.clone();
tokio::spawn(async move {
if let Err(e) = server.handle_client(client_id.clone()).await {
error!("Error handling client {}: {}", client_id, e);
}
});
}
Ok(())
}
async fn handle_client(&self, client_id: String) -> Result<()> {
trace!("Starting to handle client {}", client_id);
let connection = self
.connections
.read()
.await
.get(&client_id)
.cloned()
.ok_or_else(|| anyhow!("Client connection not found"))?;
let (send, recv) = connection.connection.accept_bi().await?;
trace!("Opened bi-directional stream for client {}", client_id);
let send_stream = SendStream::new("auth_send".to_string(), send);
let recv_stream = RecvStream::new("auth_recv".to_string(), recv);
let auth_message = recv_stream.recv_message().await?;
trace!("Received auth message from client {}", client_id);
if auth_message.message_type != MessageType::Auth {
return Err(anyhow!("Expected auth message"));
}
let auth_config: AuthConfig = serde_json::from_slice(&auth_message.data)?;
trace!("Parsed auth config from client {}", client_id);
let is_authenticated = match auth_config.auth_type {
AuthType::Basic(basic_auth) => self
.config
.auth
.validate(format!("{}:{}", basic_auth.username, basic_auth.password).as_str()),
AuthType::Bearer(bearer_auth) => self.config.auth.validate(bearer_auth.token.as_str()),
AuthType::None => true,
};
if !is_authenticated {
trace!("Authentication failed for client {}", client_id);
let response = Message::new(MessageType::Response, vec!["auth".to_string()], vec![0]);
send_stream.write_message(response).await?;
connection
.connection
.close(VarInt::from_u32(0), b"Invalid credentials");
return Err(anyhow!("Invalid credentials"));
}
trace!("Client {} successfully authenticated", client_id);
let response = Message::new(MessageType::Response, vec!["auth".to_string()], vec![200]);
send_stream.write_message(response).await?;
while let Ok(message) = recv_stream.recv_message().await {
trace!(
"Received message from client {}: {:?}",
client_id,
message.message_type
);
match message.message_type {
MessageType::Subscribe => {
for topic in message.topics {
self.topics
.write()
.await
.entry(topic.clone())
.or_default()
.push(Arc::new(send_stream.clone()));
debug!("Client {} subscribed to topic {}", client_id, topic);
trace!(
"Added subscription for client {} to topic {}",
client_id,
topic
);
}
}
MessageType::Message => {
for topic in message.topics.clone() {
trace!(
"Processing message for topic {} from client {}",
topic,
client_id
);
let message_clone = message.clone();
if let Some(subscribers) = self.topics.read().await.get(&topic) {
for subscriber in subscribers {
trace!("Forwarding message to subscriber {}", subscriber.id());
subscriber.write_message(message_clone.clone()).await?;
}
}
}
}
_ => {
trace!(
"Ignoring unsupported message type from client {}",
client_id
);
}
}
}
trace!("Client {} connection ended, starting cleanup", client_id);
self.cleanup_client(&client_id).await;
Ok(())
}
async fn cleanup_client(&self, client_id: &str) {
trace!("Starting cleanup for client {}", client_id);
self.connections.write().await.remove(client_id);
self.send_streams
.write()
.await
.retain(|k, _| !k.starts_with(client_id));
self.recv_streams
.write()
.await
.retain(|k, _| !k.starts_with(client_id));
let mut topics = self.topics.write().await;
for streams in topics.values_mut() {
streams.retain(|s| !s.id().starts_with(client_id));
}
topics.retain(|_, streams| !streams.is_empty());
debug!("Cleaned up resources for client {}", client_id);
trace!("Completed cleanup for client {}", client_id);
}
}