use anyhow::{anyhow, Context, Result};
use bytes::Bytes;
use quinn::{Connection, Endpoint};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use tracing::{debug, error, info, trace};
use crate::common::auth::AuthConfig;
use crate::common::message::{Message, MessageType};
use crate::common::stream::{RecvStream, SendStream};
use crate::common::tls::create_client_config;
type MessageCallback = Box<dyn Fn(Message) + Send + Sync + 'static>;
type MessageCallbacks = Arc<RwLock<HashMap<MessageType, Vec<MessageCallback>>>>;
type ResponseSender = mpsc::Sender<Message>;
type ResponseReceivers = Arc<RwLock<HashMap<String, ResponseSender>>>;
/// Client for an rsub server over a single QUIC connection.
///
/// Every outgoing frame is written to one bidirectional stream opened while connecting, and
/// every incoming frame is read from its receive half by a background message loop.
pub struct RsubClient {
    /// The underlying QUIC connection.
    connection: Connection,
    /// Outgoing half of the bidirectional stream.
    send_stream: Arc<SendStream>,
    /// Incoming half of the bidirectional stream, shared with the message loop.
    recv_stream: Arc<RecvStream>,
    /// Credentials serialized and sent during the authentication handshake.
    auth: AuthConfig,
    /// Registered message handlers, keyed by message type.
    callbacks: MessageCallbacks,
    /// Pending response waiters, keyed by response tag.
    response_receivers: ResponseReceivers,
}
impl RsubClient {
    /// Connects to the rsub server at `addr`, starts the background message loop and
    /// authenticates with `auth`.
    ///
    /// The server certificate is checked against the CA at `ca_path`, using `"localhost"` as
    /// the TLS server name. Use [`RsubClient::connect_with_server_name`] to verify against a
    /// different name.
    ///
    /// # Errors
    ///
    /// Fails if the TLS config cannot be built, the local endpoint cannot be bound, the QUIC
    /// connection or stream cannot be opened, or authentication fails.
    pub async fn connect(addr: SocketAddr, ca_path: &str, auth: AuthConfig) -> Result<Self> {
        Self::connect_with_server_name(addr, "localhost", ca_path, auth).await
    }

    /// Same as [`RsubClient::connect`], but passes `server_name` as the TLS server name
    /// instead of the hard-coded `"localhost"`.
    ///
    /// # Errors
    ///
    /// Same as [`RsubClient::connect`].
    pub async fn connect_with_server_name(
        addr: SocketAddr,
        server_name: &str,
        ca_path: &str,
        auth: AuthConfig,
    ) -> Result<Self> {
        trace!("Creating new client connection to {}", addr);
        let client_config = create_client_config(ca_path).context("Failed to create TLS config")?;
        // Bind an ephemeral port on the unspecified address of the server's address family,
        // so reaching an IPv4 server does not rely on the OS providing dual-stack IPv6 sockets.
        let bind_addr = if addr.is_ipv4() {
            SocketAddr::from(([0u8; 4], 0))
        } else {
            SocketAddr::from(([0u16; 8], 0))
        };
        let mut endpoint = Endpoint::client(bind_addr).context("Failed to create endpoint")?;
        endpoint.set_default_client_config(client_config);
        trace!("Attempting to establish QUIC connection");
        let connection = endpoint
            .connect(addr, server_name)
            .context("Failed to start connection")?
            .await
            .context("Failed to establish connection")?;
        trace!("QUIC connection established successfully");
        trace!("Opening bi-directional streams");
        let (send, recv) = connection
            .open_bi()
            .await
            .context("Failed to open authentication stream")?;
        let send_stream = SendStream::new("auth_send".to_string(), send);
        let recv_stream = RecvStream::new("auth_recv".to_string(), recv);
        trace!("Bi-directional streams opened successfully");
        trace!("Creating new RsubClient instance");
        let client = Self {
            connection,
            auth,
            send_stream: Arc::new(send_stream),
            recv_stream: Arc::new(recv_stream),
            callbacks: Arc::new(RwLock::new(HashMap::new())),
            response_receivers: Arc::new(RwLock::new(HashMap::new())),
        };
        trace!("Starting message handling loop");
        client.start_message_loop();
        trace!("Beginning authentication process");
        if let Err(err) = client.authenticate().await {
            // The spawned message loop still holds the receive stream; close the connection
            // explicitly so its pending read fails and the task exits instead of lingering.
            client
                .connection
                .close(0u32.into(), b"Authentication failed");
            return Err(err);
        }
        Ok(client)
    }

    /// Spawns the background task that reads messages and dispatches them.
    ///
    /// Each message is passed (as a clone) to every callback registered for its type.
    /// `Response` messages are additionally handed to the waiter registered under the
    /// message's first topic. When reading fails the loop stops and drops every pending
    /// waiter, so callers blocked on a response get an error instead of waiting forever.
    fn start_message_loop(&self) {
        let recv_stream = self.recv_stream.clone();
        let callbacks = self.callbacks.clone();
        let response_receivers = self.response_receivers.clone();
        tokio::spawn(async move {
            trace!("Message handling loop started");
            loop {
                let message = match recv_stream.recv_message().await {
                    Ok(message) => message,
                    Err(e) => {
                        error!("Error reading message: {}", e);
                        trace!("Breaking message loop due to error");
                        break;
                    }
                };
                trace!("Received message: {:?}", message);
                if let Some(handlers) = callbacks.read().await.get(&message.message_type) {
                    trace!(
                        "Found {} handlers for message type {:?}",
                        handlers.len(),
                        message.message_type
                    );
                    for handler in handlers {
                        trace!("Executing message handler");
                        handler(message.clone());
                    }
                }
                if message.message_type == MessageType::Response {
                    if let Some(topic) = message.topics.first() {
                        trace!("Processing response for topic: {}", topic);
                        // Remove the waiter in its own statement so the write lock is
                        // released before awaiting the send below.
                        let waiter = response_receivers.write().await.remove(topic);
                        if let Some(sender) = waiter {
                            trace!("Sending response to waiting receiver");
                            // The waiter may have given up already; nothing to do then.
                            let _ = sender.send(message).await;
                        }
                    }
                }
            }
            // Nothing will deliver responses any more; dropping the senders wakes every
            // pending waiter with a "Failed to receive response" error.
            response_receivers.write().await.clear();
        });
    }

    /// Registers `callback` to run for every incoming message of type `msg_type`.
    ///
    /// Callbacks run synchronously inside the message loop, so they should not block.
    pub async fn on_message(
        &self,
        msg_type: MessageType,
        callback: impl Fn(Message) + Send + Sync + 'static,
    ) {
        trace!("Registering callback for message type {:?}", msg_type);
        self.callbacks
            .write()
            .await
            .entry(msg_type)
            .or_default()
            .push(Box::new(callback));
    }

    /// Waits for the next `Response` message whose first topic equals `tag`.
    ///
    /// Only responses that arrive after the waiter is registered are observed. If the
    /// request was already sent, a very fast response can be missed. Registering a second
    /// waiter for the same tag replaces the first, whose wait then fails.
    ///
    /// # Errors
    ///
    /// Fails if the connection is already closed, or if the message loop stops before a
    /// matching response arrives.
    pub async fn wait_for_response(&self, tag: &str) -> Result<Message> {
        let rx = self.register_response(tag).await?;
        Self::await_response(tag, rx).await
    }

    /// Registers a waiter for the response tagged `tag` and returns its receiving end.
    ///
    /// Call this *before* sending the request, so the response cannot be dispatched while
    /// no waiter exists (it would otherwise be silently dropped).
    async fn register_response(&self, tag: &str) -> Result<mpsc::Receiver<Message>> {
        trace!("Creating response receiver for tag: {}", tag);
        let (tx, rx) = mpsc::channel(1);
        let mut receivers = self.response_receivers.write().await;
        // A closed connection means no further responses can arrive, and the message loop
        // may already have cleared the table; a waiter inserted now would never complete.
        if self.connection.close_reason().is_some() {
            return Err(anyhow!(
                "Connection closed before a response for '{}' could be received",
                tag
            ));
        }
        receivers.insert(tag.to_string(), tx);
        Ok(rx)
    }

    /// Awaits the response delivered to a waiter created by `register_response`.
    async fn await_response(tag: &str, mut rx: mpsc::Receiver<Message>) -> Result<Message> {
        trace!("Waiting for response with tag: {}", tag);
        rx.recv().await.context("Failed to receive response")
    }

    /// Sends the serialized auth config and checks the server's `"auth"` response.
    ///
    /// The first byte of the response data is treated as a status code; only `200` is
    /// accepted.
    async fn authenticate(&self) -> Result<()> {
        trace!("Serializing authentication config");
        let auth_bytes =
            serde_json::to_vec(&self.auth).context("Failed to serialize auth config")?;
        // Register before sending: the message loop is already running and would drop a
        // response that arrives while no "auth" waiter exists, hanging `connect` forever.
        let rx = self.register_response("auth").await?;
        trace!("Sending authentication message");
        self.send_stream
            .write_message(Message::new(MessageType::Auth, Vec::new(), auth_bytes))
            .await
            .context("Failed to send auth frame")?;
        trace!("Waiting for authentication response");
        let response = Self::await_response("auth", rx).await?;
        match response.data.first().copied() {
            Some(200) => {}
            Some(status) => {
                trace!("Authentication failed: status code {}", status);
                return Err(anyhow!("Failed to authenticate"));
            }
            None => {
                trace!("Authentication failed: empty response data");
                return Err(anyhow!("Failed to authenticate"));
            }
        }
        info!("Successfully authenticated with server");
        trace!("Authentication process completed successfully");
        Ok(())
    }

    /// Registers `callback` for incoming messages and sends a subscription for `topics`.
    ///
    /// NOTE(review): the callback is registered for *every* `MessageType::Message`, not only
    /// those on `topics`, so callbacks from separate `subscribe` calls all see each other's
    /// messages. Confirm whether per-topic filtering is intended.
    ///
    /// # Errors
    ///
    /// Fails if the subscription request cannot be written to the stream.
    pub async fn subscribe(
        &self,
        topics: Vec<String>,
        callback: impl Fn(Message) + Send + Sync + 'static,
    ) -> Result<()> {
        trace!("Registering callback for subscription messages");
        self.on_message(MessageType::Message, callback).await;
        trace!("Sending subscription request for topics: {:?}", topics);
        self.send_stream
            .write_message(Message::new(MessageType::Subscribe, topics, Vec::new()))
            .await
            .context("Failed to send subscription")?;
        Ok(())
    }

    /// Publishes `data` to `topics`. The payload is copied into the outgoing message.
    ///
    /// # Errors
    ///
    /// Fails if the message cannot be written to the stream.
    pub async fn publish(&self, topics: Vec<String>, data: Bytes) -> Result<()> {
        trace!("Publishing message to topics: {:?}", topics);
        self.send_stream
            .write_message(Message::new(MessageType::Message, topics, data.to_vec()))
            .await
            .context("Failed to send message")?;
        Ok(())
    }

    /// Closes the QUIC connection immediately with error code 0.
    ///
    /// Data not yet delivered may be lost. This currently always returns `Ok`.
    pub async fn close(&self) -> Result<()> {
        trace!("Closing client connection");
        self.connection
            .close(0u32.into(), b"Client closed connection");
        Ok(())
    }

    /// Returns the underlying QUIC connection.
    pub fn connection(&self) -> &Connection {
        &self.connection
    }

    /// Returns `true` once the connection has a close reason (closed locally, by the
    /// peer, or lost).
    pub async fn is_closed(&self) -> bool {
        let is_closed = self.connection.close_reason().is_some();
        trace!("Connection closed status: {}", is_closed);
        is_closed
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::auth::{AuthType, BasicAuth};
    use tokio::runtime::Runtime;

    const TEST_ADDR: &str = "127.0.0.1:4222";
    const TEST_CERT: &str = "cert2.der";
    const TEST_TOPIC: &str = "topic1";
    /// Upper bound on how long to wait for the published message to come back.
    const DELIVERY_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15);

    /// End-to-end check against a live server on `TEST_ADDR`: connect, subscribe to
    /// `TEST_TOPIC`, publish to it, and expect the message to be delivered back.
    #[test]
    fn test_client_connection() {
        // `try_init` rather than `init`: `init` panics if another test in this binary has
        // already installed a global subscriber.
        let _ = tracing_subscriber::fmt()
            .with_max_level(tracing::Level::TRACE)
            .try_init();
        let runtime = Runtime::new().unwrap();
        runtime.block_on(async {
            let client = RsubClient::connect(
                TEST_ADDR.parse().unwrap(),
                TEST_CERT,
                AuthConfig::new(AuthType::Basic(BasicAuth {
                    username: "admin".to_string(),
                    password: "secret".to_string(),
                })),
            )
            .await
            .expect("Failed to connect");
            // Forward delivered messages to the test body instead of sleeping blindly.
            let (tx, mut rx) = mpsc::unbounded_channel();
            client
                .subscribe(vec![TEST_TOPIC.to_string()], move |message| {
                    debug!("Received message: {:?}", message);
                    // The receiver may be gone once the test has what it needs.
                    let _ = tx.send(message);
                })
                .await
                .expect("Failed to subscribe");
            client
                .publish(vec![TEST_TOPIC.to_string()], Bytes::from("Hello RSub"))
                .await
                .expect("Failed to publish");
            let received = tokio::time::timeout(DELIVERY_TIMEOUT, rx.recv())
                .await
                .expect("Timed out waiting for the published message")
                .expect("Message channel closed unexpectedly");
            assert!(received.topics.iter().any(|topic| topic == TEST_TOPIC));
            assert_eq!(&received.data[..], b"Hello RSub");
            client.close().await.expect("Failed to close client");
        });
    }
}