use crate::common::error::{FlareError, Result};
use crate::transport::connection::Connection;
use crate::transport::events::{ArcObserver, ConnectionEvent};
use async_trait::async_trait;
use std::sync::{Arc, Mutex};
use tokio::sync::Mutex as TokioMutex;
/// Length-prefixed message transport over a single QUIC bidirectional stream.
///
/// All fields are reference-counted so they can be shared with the background
/// receiver task spawned in `new`.
pub struct QUICTransport {
/// Outgoing half of the QUIC stream, guarded by an async mutex because the
/// guard is held across `.await` points while writing.
send_stream: Arc<TokioMutex<quinn::SendStream>>,
/// Observers that are notified of connection events.
observers: Arc<Mutex<Vec<ArcObserver>>>,
/// Timestamp of the most recent activity recorded for this transport.
last_active: Arc<Mutex<std::time::Instant>>,
/// Set to `true` by `close`; polled by the receiver task to know when to stop.
is_closed: Arc<Mutex<bool>>,
}
impl QUICTransport {
    /// Largest payload (in bytes) accepted from the peer in a single frame.
    const MAX_MESSAGE_SIZE: usize = 10 * 1024 * 1024;

    /// Creates a transport over the two halves of a QUIC stream and spawns the
    /// background task that reads framed messages from `recv_stream`.
    ///
    /// Must be called from within a Tokio runtime, because it uses
    /// `tokio::spawn`.
    pub fn new(send_stream: quinn::SendStream, recv_stream: quinn::RecvStream) -> Self {
        let observers = Arc::new(Mutex::new(Vec::new()));
        let last_active = Arc::new(Mutex::new(std::time::Instant::now()));
        let is_closed = Arc::new(Mutex::new(false));

        // The receiver task is the only reader, so it owns the receive stream
        // outright; only the state it shares with `self` needs an `Arc`.
        tokio::spawn(Self::receiver_task(
            recv_stream,
            Arc::clone(&observers),
            Arc::clone(&last_active),
            Arc::clone(&is_closed),
        ));

        Self {
            send_stream: Arc::new(TokioMutex::new(send_stream)),
            observers,
            last_active,
            is_closed,
        }
    }

    /// Locks `mutex`, recovering the guard if a previous holder panicked.
    ///
    /// The protected values here (a flag, a timestamp, a list of observers)
    /// stay valid even after a panic, so poisoning is not treated as fatal.
    fn lock_recovering<T>(mutex: &Mutex<T>) -> std::sync::MutexGuard<'_, T> {
        mutex
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Background loop: reads frames from the peer and forwards them to the
    /// observers until the stream ends, a read fails, or the transport is
    /// flagged as closed.
    ///
    /// Events emitted:
    /// * `ConnectionEvent::Message` for every frame, including zero-length ones;
    /// * `ConnectionEvent::Disconnected` when the peer finishes the stream;
    /// * `ConnectionEvent::Error` when a read fails or a frame is malformed.
    ///
    /// NOTE: the closed flag is only checked between frames, so a task that is
    /// blocked waiting for data will not notice a local close until the next
    /// frame, EOF or error arrives.
    async fn receiver_task(
        mut recv_stream: quinn::RecvStream,
        observers_arc: Arc<Mutex<Vec<ArcObserver>>>,
        last_active: Arc<Mutex<std::time::Instant>>,
        is_closed: Arc<Mutex<bool>>,
    ) {
        use tracing::debug;
        loop {
            // Copy the flag out so no std mutex guard lives across an `.await`.
            let closed = *Self::lock_recovering(&is_closed);
            if closed {
                debug!("[QUIC Transport] Receiver task: connection closed");
                break;
            }

            match Self::read_stream(&mut recv_stream).await {
                Ok(Some(data)) => {
                    // Record activity when data actually arrives.
                    *Self::lock_recovering(&last_active) = std::time::Instant::now();
                    debug!("[QUIC Transport] Received message: {} bytes", data.len());
                    Self::_notify_observers(&observers_arc, &ConnectionEvent::Message(data));
                }
                Ok(None) => {
                    debug!("[QUIC Transport] Stream EOF, closing");
                    Self::_notify_observers(
                        &observers_arc,
                        &ConnectionEvent::Disconnected("Stream closed by peer".to_string()),
                    );
                    break;
                }
                Err(e) => {
                    debug!("[QUIC Transport] Read error: {}", e);
                    Self::_notify_observers(
                        &observers_arc,
                        &ConnectionEvent::Error(FlareError::io(e.to_string())),
                    );
                    break;
                }
            }
        }
        debug!("[QUIC Transport] Receiver task ended");
    }

    /// Reads one frame: a 4-byte big-endian length followed by that many
    /// payload bytes.
    ///
    /// Returns `Ok(None)` when the stream ends cleanly before any byte of a new
    /// frame, and `Ok(Some(payload))` otherwise. The payload may be empty, so a
    /// zero-length frame is distinguishable from end-of-stream.
    ///
    /// # Errors
    ///
    /// Fails if the stream ends in the middle of a frame, if the announced
    /// length exceeds `MAX_MESSAGE_SIZE`, or if the underlying read fails.
    async fn read_stream(recv: &mut quinn::RecvStream) -> Result<Option<Vec<u8>>> {
        let mut length_buf = [0u8; 4];
        let prefix_read = Self::read_until_full(recv, &mut length_buf).await?;
        if prefix_read == 0 {
            // Clean end of stream on a frame boundary.
            return Ok(None);
        }
        if prefix_read < length_buf.len() {
            return Err(FlareError::io(
                "Stream closed while reading length prefix".to_string(),
            ));
        }

        let length = u32::from_be_bytes(length_buf) as usize;
        // Reject oversized frames before allocating a buffer for them.
        if length > Self::MAX_MESSAGE_SIZE {
            return Err(FlareError::io(format!(
                "Message too large: {} bytes",
                length
            )));
        }

        let mut buf = vec![0u8; length];
        let bytes_read = Self::read_until_full(recv, &mut buf).await?;
        if bytes_read < length {
            return Err(FlareError::io(format!(
                "Stream closed while reading message: expected {} bytes, got {}",
                length, bytes_read
            )));
        }
        Ok(Some(buf))
    }

    /// Reads from `recv` until `buf` is full or the stream ends, returning how
    /// many bytes were written into `buf`.
    ///
    /// A result smaller than `buf.len()` means the stream ended early.
    async fn read_until_full(recv: &mut quinn::RecvStream, buf: &mut [u8]) -> Result<usize> {
        let mut filled = 0;
        while filled < buf.len() {
            match recv.read(&mut buf[filled..]).await {
                Ok(Some(0)) | Ok(None) => break,
                Ok(Some(n)) => filled += n,
                Err(e) => return Err(FlareError::io(e.to_string())),
            }
        }
        Ok(filled)
    }

    /// Delivers `event` to every currently registered observer.
    ///
    /// Observers are invoked on a snapshot taken under the lock, so the lock is
    /// not held while callbacks run. A callback may therefore touch the
    /// observer list without deadlocking, and a panicking callback cannot
    /// poison it.
    fn _notify_observers(
        observers_arc: &Arc<Mutex<Vec<ArcObserver>>>,
        event: &ConnectionEvent,
    ) {
        let snapshot: Vec<ArcObserver> = Self::lock_recovering(observers_arc).clone();
        for observer in &snapshot {
            observer.on_event(event);
        }
    }

    /// Delivers `event` to the observers registered on this transport.
    fn notify_observers(&self, event: &ConnectionEvent) {
        Self::_notify_observers(&self.observers, event);
    }
}
#[async_trait]
impl Connection for QUICTransport {
    /// Registers `observer` and immediately sends it a
    /// `ConnectionEvent::Connected` event.
    fn add_observer(&mut self, observer: ArcObserver) {
        observer.on_event(&ConnectionEvent::Connected);
        // Recover from a poisoned lock instead of silently dropping the observer.
        self.observers
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .push(observer);
    }

    /// Unregisters every entry that points to the same allocation as `observer`.
    fn remove_observer(&mut self, observer: ArcObserver) {
        self.observers
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .retain(|o| !Arc::ptr_eq(o, &observer));
    }

    /// Sends `data` as one frame: a 4-byte big-endian length followed by the
    /// payload.
    ///
    /// # Errors
    ///
    /// Fails if `data` is longer than `u32::MAX` bytes (the length prefix could
    /// not represent it) or if writing to the stream fails.
    async fn send(&mut self, data: &[u8]) -> Result<()> {
        // A plain `as u32` cast would silently truncate the prefix and corrupt
        // the framing for oversized payloads, so reject them explicitly.
        if data.len() > u32::MAX as usize {
            return Err(FlareError::io(format!(
                "Message too large: {} bytes",
                data.len()
            )));
        }
        let length_bytes = (data.len() as u32).to_be_bytes();

        *self
            .last_active
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner()) = std::time::Instant::now();

        // Hold the stream lock across both writes so the prefix and payload of
        // one frame are never interleaved with another sender's frame.
        let mut send = self.send_stream.lock().await;
        send.write_all(&length_bytes)
            .await
            .map_err(|e| FlareError::io(e.to_string()))?;
        send.write_all(data)
            .await
            .map_err(|e| FlareError::io(e.to_string()))?;
        Ok(())
    }

    /// Marks the transport as closed, finishes the send stream and notifies
    /// observers with `ConnectionEvent::Disconnected`.
    ///
    /// Calling `close` again after the first call is a no-op that returns
    /// `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns the error from finishing the send stream. Observers are notified
    /// before that error is returned, because the transport is already flagged
    /// as closed at that point.
    async fn close(&mut self) -> Result<()> {
        let already_closed = {
            let mut closed = self
                .is_closed
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            std::mem::replace(&mut *closed, true)
        };
        if already_closed {
            return Ok(());
        }

        let finish_result = {
            let mut send = self.send_stream.lock().await;
            send.finish().map_err(|e| FlareError::io(e.to_string()))
        };
        self.notify_observers(&ConnectionEvent::Disconnected("Closed by client".to_string()));
        finish_result
    }

    /// Returns the instant of the most recent recorded activity.
    fn last_active_time(&self) -> std::time::Instant {
        // The stored `Instant` stays valid even if a lock holder panicked, so
        // read it from the poisoned guard rather than inventing a value.
        match self.last_active.lock() {
            Ok(guard) => *guard,
            Err(poisoned) => *poisoned.into_inner(),
        }
    }

    /// Records the current instant as the most recent activity.
    fn update_active_time(&mut self) {
        *self
            .last_active
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner()) = std::time::Instant::now();
    }
}